delete policies

This commit is contained in:
Eugene Mironov
2025-11-06 17:56:10 +07:00
parent 6fdee95923
commit 6aa940346d
+191 -77
View File
@@ -45,6 +45,7 @@ Usage:
--torch_compile_mode=max-autotune
"""
import gc
import logging
import os
import random
@@ -174,28 +175,7 @@ class RTCEvaluator:
self.cfg = cfg
self.device = cfg.device
# Load policy
logging.info(f"Loading policy from {cfg.policy.pretrained_path}")
policy_class = get_policy_class(cfg.policy.type)
self.policy = policy_class.from_pretrained(cfg.policy.pretrained_path)
self.policy = self.policy.to(self.device)
self.policy.eval()
# Configure RTC
cfg.rtc.enabled = True
cfg.rtc.debug = True # Enable debug tracking for visualization
self.policy.config.rtc_config = cfg.rtc
self.policy.init_rtc_processor()
# Apply torch.compile if enabled
if cfg.use_torch_compile:
self._apply_torch_compile()
logging.info(f"Policy loaded: {self.policy.name}")
logging.info(f"RTC enabled: {cfg.rtc.enabled}")
logging.info(f"Execution horizon: {cfg.rtc.execution_horizon}")
# Load dataset
# Load dataset first (needed for preprocessor)
logging.info(f"Loading dataset: {cfg.dataset.repo_id}")
self.dataset = LeRobotDataset(cfg.dataset.repo_id, delta_timestamps={"action": np.arange(50) / 30})
logging.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes")
@@ -209,56 +189,177 @@ class RTCEvaluator:
},
)
def _apply_torch_compile(self):
"""Apply torch.compile to the policy model for faster inference."""
# Initialize three separate policy instances
# Note: These policies are initialized here but will be freed sequentially during
# evaluation to manage memory. Large models (e.g., VLAs with billions of parameters)
# cannot fit three instances in memory simultaneously. Each policy is deleted and
# memory is freed (via torch.cuda.empty_cache()) immediately after its use.
logging.info("=" * 80)
logging.info("Initializing three policy instances:")
logging.info(" 1. policy_prev_chunk (for generating previous chunk)")
logging.info(" 2. policy_no_rtc (for non-RTC inference)")
logging.info(" 3. policy_rtc (for RTC inference)")
logging.info(" Note: Policies will be freed sequentially during evaluation to manage memory")
logging.info("=" * 80)
# Policy 1: For generating previous chunk (RTC disabled, no debug)
self.policy_prev_chunk = self._init_policy(
name="policy_prev_chunk",
rtc_enabled=False,
rtc_debug=False,
)
# Policy 2: For non-RTC inference (RTC disabled, debug enabled)
self.policy_no_rtc = self._init_policy(
name="policy_no_rtc",
rtc_enabled=False,
rtc_debug=True,
)
# Policy 3: For RTC inference (RTC enabled, debug enabled)
self.policy_rtc = self._init_policy(
name="policy_rtc",
rtc_enabled=True,
rtc_debug=True,
)
logging.info("=" * 80)
logging.info("All policies initialized successfully")
logging.info("=" * 80)
def _init_policy(self, name: str, rtc_enabled: bool, rtc_debug: bool):
"""Initialize a single policy instance with specified RTC configuration.
Args:
name: Name identifier for logging purposes
rtc_enabled: Whether to enable RTC for this policy
rtc_debug: Whether to enable debug tracking for this policy
Returns:
Configured policy instance with optional torch.compile applied
"""
logging.info(f"Initializing {name}...")
# Load policy from pretrained
policy_class = get_policy_class(self.cfg.policy.type)
policy = policy_class.from_pretrained(self.cfg.policy.pretrained_path)
policy = policy.to(self.device)
policy.eval()
# Configure RTC
rtc_config = RTCConfig(
enabled=rtc_enabled,
execution_horizon=self.cfg.rtc.execution_horizon,
max_guidance_weight=self.cfg.rtc.max_guidance_weight,
prefix_attention_schedule=self.cfg.rtc.prefix_attention_schedule,
debug=rtc_debug,
debug_maxlen=self.cfg.rtc.debug_maxlen,
)
policy.config.rtc_config = rtc_config
policy.init_rtc_processor()
logging.info(f" RTC enabled: {rtc_enabled}")
logging.info(f" RTC debug: {rtc_debug}")
# Apply torch.compile to predict_action_chunk method if enabled
if self.cfg.use_torch_compile:
policy = self._apply_torch_compile(policy, name)
logging.info(f"{name} initialized successfully")
return policy
def _apply_torch_compile(self, policy, policy_name: str):
"""Apply torch.compile to the policy's predict_action_chunk method.
Args:
policy: Policy instance to compile
policy_name: Name for logging purposes
Returns:
Policy with compiled predict_action_chunk method
"""
try:
# Check if torch.compile is available (PyTorch 2.0+)
if not hasattr(torch, "compile"):
logging.warning(
"torch.compile is not available. Requires PyTorch 2.0+. "
f" [{policy_name}] torch.compile is not available. Requires PyTorch 2.0+. "
f"Current version: {torch.__version__}. Skipping compilation."
)
return
return policy
logging.info("Applying torch.compile to policy model...")
logging.info(f" Backend: {self.cfg.torch_compile_backend}")
logging.info(f" Mode: {self.cfg.torch_compile_mode}")
logging.info(f" [{policy_name}] Applying torch.compile to predict_action_chunk...")
logging.info(f" Backend: {self.cfg.torch_compile_backend}")
logging.info(f" Mode: {self.cfg.torch_compile_mode}")
# Compile the policy's model (not the policy itself to preserve methods)
if hasattr(self.policy, "model"):
original_model = self.policy.model
compiled_model = torch.compile(
original_model,
backend=self.cfg.torch_compile_backend,
mode=self.cfg.torch_compile_mode,
)
self.policy.model = compiled_model
logging.info("✓ Successfully compiled policy.model")
else:
logging.warning(
"Policy does not have a 'model' attribute. "
"Attempting to compile entire policy (may not work for all policy types)."
)
self.policy = torch.compile(
self.policy,
backend=self.cfg.torch_compile_backend,
mode=self.cfg.torch_compile_mode,
)
logging.info("✓ Successfully compiled policy")
# Compile the predict_action_chunk method
original_method = policy.predict_action_chunk
compiled_method = torch.compile(
original_method,
backend=self.cfg.torch_compile_backend,
mode=self.cfg.torch_compile_mode,
)
policy.predict_action_chunk = compiled_method
logging.info(f" ✓ [{policy_name}] Successfully compiled predict_action_chunk")
except Exception as e:
logging.error(f"Failed to apply torch.compile: {e}")
logging.warning("Continuing without torch.compile")
logging.error(f" [{policy_name}] Failed to apply torch.compile: {e}")
logging.warning(f" [{policy_name}] Continuing without torch.compile")
return policy
def _destroy_policy(self, policy, policy_name: str):
"""Explicitly destroy a policy and free all associated memory.
This method performs aggressive cleanup to ensure maximum memory is freed,
which is critical for large models (e.g., VLAs with billions of parameters).
Args:
policy: Policy instance to destroy
policy_name: Name for logging purposes
"""
logging.info(f" Destroying {policy_name} and freeing memory...")
try:
# Step 1: Move policy to CPU to free GPU/MPS memory
policy.cpu()
# Step 2: Delete the policy object
del policy
# Step 3: Force garbage collection to reclaim memory immediately
gc.collect()
# Step 4: Clear device-specific caches
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize() # Ensure all operations complete
if torch.backends.mps.is_available():
torch.mps.empty_cache()
logging.info(f"{policy_name} destroyed and memory freed")
except Exception as e:
logging.warning(f" Warning: Error during {policy_name} cleanup: {e}")
def run_evaluation(self):
"""Run evaluation on two random dataset samples."""
"""Run evaluation on two random dataset samples using three separate policies.
Note: Policies are deinitalized after each step to free memory. Large models
(e.g., VLA models with billions of parameters) cannot fit three instances in
memory simultaneously. By deleting and garbage collecting after each step,
we ensure only one policy is loaded at a time.
"""
# Create output directory
os.makedirs(self.cfg.output_dir, exist_ok=True)
logging.info(f"Output directory: {self.cfg.output_dir}")
logging.info("=" * 80)
logging.info("Starting RTC evaluation")
logging.info(f"Inference delay: {self.cfg.inference_delay}")
logging.info("=" * 80)
# Load two random samples from dataset
data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True)
loader_iter = iter(data_loader)
first_sample = next(loader_iter)
@@ -267,50 +368,65 @@ class RTCEvaluator:
preprocessed_first_sample = self.preprocessor(first_sample)
preprocessed_second_sample = self.preprocessor(second_sample)
# Don't postprocess the previous chunk
prev_chunk_left_over = self.policy.predict_action_chunk(
preprocessed_first_sample,
)[:, :25, :].squeeze(0)
# Step 1: Generate previous chunk using policy_prev_chunk
# This policy is only used to generate the reference chunk and then freed
logging.info("Step 1: Generating previous chunk with policy_prev_chunk")
with torch.no_grad():
prev_chunk_left_over = self.policy_prev_chunk.predict_action_chunk(
preprocessed_first_sample,
)[:, :25, :].squeeze(0)
logging.info(f" Generated prev_chunk shape: {prev_chunk_left_over.shape}")
self.policy.rtc_processor.reset_tracker()
logging.info("Resetting tracker")
# Destroy policy_prev_chunk to free memory for large models
self._destroy_policy(self.policy_prev_chunk, "policy_prev_chunk")
# Sample noise (use same noise for both RTC and non-RTC for fair comparison)
noise_size = (1, self.policy.config.chunk_size, self.policy.config.max_action_dim)
noise = self.policy.model.sample_noise(noise_size, self.device)
noise_size = (1, self.policy_no_rtc.config.chunk_size, self.policy_no_rtc.config.max_action_dim)
noise = self.policy_no_rtc.model.sample_noise(noise_size, self.device)
noise_clone = noise.clone()
# Generate actions WITHOUT RTC
logging.info("Generating actions WITHOUT RTC")
self.policy.config.rtc_config.enabled = False
# Step 2: Generate actions WITHOUT RTC using policy_no_rtc
logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc")
self.policy_no_rtc.rtc_processor.reset_tracker()
with torch.no_grad():
_ = self.policy.predict_action_chunk(
_ = self.policy_no_rtc.predict_action_chunk(
preprocessed_second_sample,
noise=noise,
)
no_rtc_tracked_steps = self.policy_no_rtc.rtc_processor.tracker.get_all_steps()
logging.info(f" Tracked {len(no_rtc_tracked_steps)} steps without RTC")
no_rtc_tracked_steps = self.policy.rtc_processor.tracker.get_all_steps()
self.policy.rtc_processor.reset_tracker()
# Destroy policy_no_rtc to free memory before loading policy_rtc
self._destroy_policy(self.policy_no_rtc, "policy_no_rtc")
# Generate actions WITH RTC
logging.info("Generating actions WITH RTC")
self.policy.config.rtc_config.enabled = True
# Step 3: Generate actions WITH RTC using policy_rtc
logging.info("Step 3: Generating actions WITH RTC with policy_rtc")
self.policy_rtc.rtc_processor.reset_tracker()
with torch.no_grad():
_ = self.policy.predict_action_chunk(
_ = self.policy_rtc.predict_action_chunk(
preprocessed_second_sample,
noise=noise_clone,
inference_delay=self.cfg.inference_delay,
prev_chunk_left_over=prev_chunk_left_over,
execution_horizon=self.cfg.rtc.execution_horizon,
)
rtc_tracked_steps = self.policy_rtc.rtc_processor.get_all_debug_steps()
logging.info(f" Tracked {len(rtc_tracked_steps)} steps with RTC")
rtc_tracked_steps = self.policy.rtc_processor.get_all_debug_steps()
# Save num_steps before destroying policy (needed for plotting)
num_steps = self.policy_rtc.config.num_steps
self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over)
# Destroy policy_rtc after final use
self._destroy_policy(self.policy_rtc, "policy_rtc")
# Plot and save results
logging.info("=" * 80)
logging.info("Plotting results...")
self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps)
logging.info("=" * 80)
logging.info("Evaluation completed successfully")
def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over):
def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps):
# Create side-by-side figures for denoising visualization
fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)")
fig_vt, axs_vt = self._create_figure("v_t Denoising: No RTC (left) vs RTC (right)")
@@ -318,8 +434,6 @@ class RTCEvaluator:
fig_x1t, axs_x1t = self._create_figure(
"x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)"
)
num_steps = self.policy.config.num_steps
self._plot_denoising_steps_from_tracker(
rtc_tracked_steps,
axs_xt[:, 1], # Right column for x_t