delete policies

This commit is contained in:
Eugene Mironov
2025-11-06 17:56:10 +07:00
parent e09a6a90e1
commit 83f1de035e
+191 -77
View File
@@ -45,6 +45,7 @@ Usage:
--torch_compile_mode=max-autotune --torch_compile_mode=max-autotune
""" """
import gc
import logging import logging
import os import os
import random import random
@@ -174,28 +175,7 @@ class RTCEvaluator:
self.cfg = cfg self.cfg = cfg
self.device = cfg.device self.device = cfg.device
# Load policy # Load dataset first (needed for preprocessor)
logging.info(f"Loading policy from {cfg.policy.pretrained_path}")
policy_class = get_policy_class(cfg.policy.type)
self.policy = policy_class.from_pretrained(cfg.policy.pretrained_path)
self.policy = self.policy.to(self.device)
self.policy.eval()
# Configure RTC
cfg.rtc.enabled = True
cfg.rtc.debug = True # Enable debug tracking for visualization
self.policy.config.rtc_config = cfg.rtc
self.policy.init_rtc_processor()
# Apply torch.compile if enabled
if cfg.use_torch_compile:
self._apply_torch_compile()
logging.info(f"Policy loaded: {self.policy.name}")
logging.info(f"RTC enabled: {cfg.rtc.enabled}")
logging.info(f"Execution horizon: {cfg.rtc.execution_horizon}")
# Load dataset
logging.info(f"Loading dataset: {cfg.dataset.repo_id}") logging.info(f"Loading dataset: {cfg.dataset.repo_id}")
self.dataset = LeRobotDataset(cfg.dataset.repo_id, delta_timestamps={"action": np.arange(50) / 30}) self.dataset = LeRobotDataset(cfg.dataset.repo_id, delta_timestamps={"action": np.arange(50) / 30})
logging.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes") logging.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes")
@@ -209,56 +189,177 @@ class RTCEvaluator:
}, },
) )
def _apply_torch_compile(self): # Initialize three separate policy instances
"""Apply torch.compile to the policy model for faster inference.""" # Note: These policies are initialized here but will be freed sequentially during
# evaluation to manage memory. Large models (e.g., VLAs with billions of parameters)
# cannot fit three instances in memory simultaneously. Each policy is deleted and
# memory is freed (via torch.cuda.empty_cache()) immediately after its use.
logging.info("=" * 80)
logging.info("Initializing three policy instances:")
logging.info(" 1. policy_prev_chunk (for generating previous chunk)")
logging.info(" 2. policy_no_rtc (for non-RTC inference)")
logging.info(" 3. policy_rtc (for RTC inference)")
logging.info(" Note: Policies will be freed sequentially during evaluation to manage memory")
logging.info("=" * 80)
# Policy 1: For generating previous chunk (RTC disabled, no debug)
self.policy_prev_chunk = self._init_policy(
name="policy_prev_chunk",
rtc_enabled=False,
rtc_debug=False,
)
# Policy 2: For non-RTC inference (RTC disabled, debug enabled)
self.policy_no_rtc = self._init_policy(
name="policy_no_rtc",
rtc_enabled=False,
rtc_debug=True,
)
# Policy 3: For RTC inference (RTC enabled, debug enabled)
self.policy_rtc = self._init_policy(
name="policy_rtc",
rtc_enabled=True,
rtc_debug=True,
)
logging.info("=" * 80)
logging.info("All policies initialized successfully")
logging.info("=" * 80)
def _init_policy(self, name: str, rtc_enabled: bool, rtc_debug: bool):
    """Load one pretrained policy and attach a dedicated RTC configuration.

    The policy class is resolved from ``self.cfg.policy.type`` and loaded from
    ``self.cfg.policy.pretrained_path``, then moved to ``self.device`` and put
    into eval mode. A new ``RTCConfig`` is built for this instance: ``enabled``
    and ``debug`` come from the arguments, every other field is copied from
    ``self.cfg.rtc``. Finally the policy's RTC processor is (re)initialized.

    Args:
        name: Label for this instance; used in log messages and forwarded to
            ``_apply_torch_compile``.
        rtc_enabled: Value stored in the ``enabled`` field of the RTC config.
        rtc_debug: Value stored in the ``debug`` field of the RTC config.

    Returns:
        The configured policy. When ``self.cfg.use_torch_compile`` is truthy,
        this is whatever ``_apply_torch_compile`` returns for it.
    """
    logging.info(f"Initializing {name}...")

    cfg = self.cfg

    # Resolve the class, load the checkpoint and place it on the target device.
    loader = get_policy_class(cfg.policy.type)
    loaded_policy = loader.from_pretrained(cfg.policy.pretrained_path).to(self.device)
    loaded_policy.eval()

    # Each instance gets its own RTCConfig object so that toggling
    # enabled/debug here does not touch the shared cfg.rtc.
    shared_rtc = cfg.rtc
    loaded_policy.config.rtc_config = RTCConfig(
        enabled=rtc_enabled,
        execution_horizon=shared_rtc.execution_horizon,
        max_guidance_weight=shared_rtc.max_guidance_weight,
        prefix_attention_schedule=shared_rtc.prefix_attention_schedule,
        debug=rtc_debug,
        debug_maxlen=shared_rtc.debug_maxlen,
    )
    loaded_policy.init_rtc_processor()

    logging.info(f" RTC enabled: {rtc_enabled}")
    logging.info(f" RTC debug: {rtc_debug}")

    # Optional compilation step; the helper hands back the policy to use.
    result = (
        self._apply_torch_compile(loaded_policy, name)
        if cfg.use_torch_compile
        else loaded_policy
    )

    logging.info(f"{name} initialized successfully")
    return result
def _apply_torch_compile(self, policy, policy_name: str):
    """Wrap ``policy.predict_action_chunk`` with ``torch.compile`` (best effort).

    The compiled callable is stored on the policy instance under the same
    attribute name, so callers keep using ``policy.predict_action_chunk``.
    Any failure while setting this up is logged (with traceback) and the
    policy is returned unmodified; this method never raises.

    NOTE(review): ``torch.compile`` generally defers the actual compilation to
    the first call of the wrapped function, so backend errors may surface at
    inference time rather than here -- confirm against the PyTorch version in
    use.

    Args:
        policy: Policy instance whose ``predict_action_chunk`` is compiled.
        policy_name: Label used in log messages only.

    Returns:
        The same ``policy`` object, with ``predict_action_chunk`` replaced by
        its compiled wrapper when compilation setup succeeded.
    """
    # Version guard: cannot raise, so it lives outside the try block.
    if not hasattr(torch, "compile"):
        logging.warning(
            f" [{policy_name}] torch.compile is not available. Requires PyTorch 2.0+. "
            f"Current version: {torch.__version__}. Skipping compilation."
        )
        return policy

    try:
        logging.info(f" [{policy_name}] Applying torch.compile to predict_action_chunk...")
        logging.info(f" Backend: {self.cfg.torch_compile_backend}")
        logging.info(f" Mode: {self.cfg.torch_compile_mode}")

        # Compile the bound method and shadow it on the instance.
        policy.predict_action_chunk = torch.compile(
            policy.predict_action_chunk,
            backend=self.cfg.torch_compile_backend,
            mode=self.cfg.torch_compile_mode,
        )
    except Exception as e:
        # Deliberate best-effort boundary: evaluation continues uncompiled.
        # exc_info keeps the traceback, which the message alone would lose.
        logging.error(f" [{policy_name}] Failed to apply torch.compile: {e}", exc_info=True)
        logging.warning(f" [{policy_name}] Continuing without torch.compile")
    else:
        logging.info(f" ✓ [{policy_name}] Successfully compiled predict_action_chunk")

    return policy
def _destroy_policy(self, policy, policy_name: str):
    """Release a policy and the device memory it occupies (best effort).

    Steps: move the policy to CPU, clear every attribute on this evaluator
    that still points at the policy object, drop the local reference, run the
    garbage collector, then ask the CUDA / MPS allocators to release their
    cached blocks. Errors during cleanup are logged as warnings and never
    propagated.

    Clearing the evaluator attribute matters: ``del policy`` only removes this
    method's local name, so without it ``self.<attr>`` would keep the whole
    model alive in host memory after the call. Matching attributes are set to
    ``None``. If the caller holds further references of its own, the object is
    only reclaimed once those are gone too.

    Args:
        policy: Policy instance to destroy; must provide ``cpu()``.
        policy_name: Label used in log messages only.
    """
    logging.info(f" Destroying {policy_name} and freeing memory...")
    try:
        # Step 1: Move policy to CPU to free GPU/MPS memory
        policy.cpu()

        # Step 2: Drop every reference this evaluator holds to the object.
        # Identity comparison so that equal-but-distinct policies are untouched.
        for attr_name, attr_value in list(vars(self).items()):
            if attr_value is policy:
                setattr(self, attr_name, None)
        del policy

        # Step 3: Force garbage collection; this also breaks reference cycles
        # such as an instance attribute holding a wrapper around a bound method.
        gc.collect()

        # Step 4: Clear device-specific caches. Synchronize first so pending
        # kernels finish before cached blocks are handed back.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

        # Older PyTorch builds may not expose torch.backends.mps at all.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            torch.mps.empty_cache()

        logging.info(f"{policy_name} destroyed and memory freed")
    except Exception as e:
        logging.warning(f" Warning: Error during {policy_name} cleanup: {e}")
def run_evaluation(self): def run_evaluation(self):
"""Run evaluation on two random dataset samples.""" """Run evaluation on two random dataset samples using three separate policies.
Note: Policies are deinitialized after each step to free memory. Large models
(e.g., VLA models with billions of parameters) cannot fit three instances in
memory simultaneously. Each policy is destroyed and garbage collected as soon as
its step finishes. All three policies are still constructed before evaluation
starts, so peak memory usage occurs before Step 1 rather than during evaluation.
"""
# Create output directory # Create output directory
os.makedirs(self.cfg.output_dir, exist_ok=True) os.makedirs(self.cfg.output_dir, exist_ok=True)
logging.info(f"Output directory: {self.cfg.output_dir}") logging.info(f"Output directory: {self.cfg.output_dir}")
logging.info("=" * 80)
logging.info("Starting RTC evaluation") logging.info("Starting RTC evaluation")
logging.info(f"Inference delay: {self.cfg.inference_delay}") logging.info(f"Inference delay: {self.cfg.inference_delay}")
logging.info("=" * 80)
# Load two random samples from dataset
data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True) data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True)
loader_iter = iter(data_loader) loader_iter = iter(data_loader)
first_sample = next(loader_iter) first_sample = next(loader_iter)
@@ -267,50 +368,65 @@ class RTCEvaluator:
preprocessed_first_sample = self.preprocessor(first_sample) preprocessed_first_sample = self.preprocessor(first_sample)
preprocessed_second_sample = self.preprocessor(second_sample) preprocessed_second_sample = self.preprocessor(second_sample)
# Don't postprocess the previous chunk # Step 1: Generate previous chunk using policy_prev_chunk
prev_chunk_left_over = self.policy.predict_action_chunk( # This policy is only used to generate the reference chunk and then freed
preprocessed_first_sample, logging.info("Step 1: Generating previous chunk with policy_prev_chunk")
)[:, :25, :].squeeze(0) with torch.no_grad():
prev_chunk_left_over = self.policy_prev_chunk.predict_action_chunk(
preprocessed_first_sample,
)[:, :25, :].squeeze(0)
logging.info(f" Generated prev_chunk shape: {prev_chunk_left_over.shape}")
self.policy.rtc_processor.reset_tracker() # Destroy policy_prev_chunk to free memory for large models
self._destroy_policy(self.policy_prev_chunk, "policy_prev_chunk")
logging.info("Resetting tracker")
# Sample noise (use same noise for both RTC and non-RTC for fair comparison) # Sample noise (use same noise for both RTC and non-RTC for fair comparison)
noise_size = (1, self.policy.config.chunk_size, self.policy.config.max_action_dim) noise_size = (1, self.policy_no_rtc.config.chunk_size, self.policy_no_rtc.config.max_action_dim)
noise = self.policy.model.sample_noise(noise_size, self.device) noise = self.policy_no_rtc.model.sample_noise(noise_size, self.device)
noise_clone = noise.clone() noise_clone = noise.clone()
# Generate actions WITHOUT RTC # Step 2: Generate actions WITHOUT RTC using policy_no_rtc
logging.info("Generating actions WITHOUT RTC") logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc")
self.policy.config.rtc_config.enabled = False self.policy_no_rtc.rtc_processor.reset_tracker()
with torch.no_grad(): with torch.no_grad():
_ = self.policy.predict_action_chunk( _ = self.policy_no_rtc.predict_action_chunk(
preprocessed_second_sample, preprocessed_second_sample,
noise=noise, noise=noise,
) )
no_rtc_tracked_steps = self.policy_no_rtc.rtc_processor.tracker.get_all_steps()
logging.info(f" Tracked {len(no_rtc_tracked_steps)} steps without RTC")
no_rtc_tracked_steps = self.policy.rtc_processor.tracker.get_all_steps() # Destroy policy_no_rtc to free memory before loading policy_rtc
self.policy.rtc_processor.reset_tracker() self._destroy_policy(self.policy_no_rtc, "policy_no_rtc")
# Generate actions WITH RTC # Step 3: Generate actions WITH RTC using policy_rtc
logging.info("Generating actions WITH RTC") logging.info("Step 3: Generating actions WITH RTC with policy_rtc")
self.policy.config.rtc_config.enabled = True self.policy_rtc.rtc_processor.reset_tracker()
with torch.no_grad(): with torch.no_grad():
_ = self.policy.predict_action_chunk( _ = self.policy_rtc.predict_action_chunk(
preprocessed_second_sample, preprocessed_second_sample,
noise=noise_clone, noise=noise_clone,
inference_delay=self.cfg.inference_delay, inference_delay=self.cfg.inference_delay,
prev_chunk_left_over=prev_chunk_left_over, prev_chunk_left_over=prev_chunk_left_over,
execution_horizon=self.cfg.rtc.execution_horizon, execution_horizon=self.cfg.rtc.execution_horizon,
) )
rtc_tracked_steps = self.policy_rtc.rtc_processor.get_all_debug_steps()
logging.info(f" Tracked {len(rtc_tracked_steps)} steps with RTC")
rtc_tracked_steps = self.policy.rtc_processor.get_all_debug_steps() # Save num_steps before destroying policy (needed for plotting)
num_steps = self.policy_rtc.config.num_steps
self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over) # Destroy policy_rtc after final use
self._destroy_policy(self.policy_rtc, "policy_rtc")
# Plot and save results
logging.info("=" * 80)
logging.info("Plotting results...")
self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps)
logging.info("=" * 80)
logging.info("Evaluation completed successfully") logging.info("Evaluation completed successfully")
def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over): def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps):
# Create side-by-side figures for denoising visualization # Create side-by-side figures for denoising visualization
fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)") fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)")
fig_vt, axs_vt = self._create_figure("v_t Denoising: No RTC (left) vs RTC (right)") fig_vt, axs_vt = self._create_figure("v_t Denoising: No RTC (left) vs RTC (right)")
@@ -318,8 +434,6 @@ class RTCEvaluator:
fig_x1t, axs_x1t = self._create_figure( fig_x1t, axs_x1t = self._create_figure(
"x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)" "x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)"
) )
num_steps = self.policy.config.num_steps
self._plot_denoising_steps_from_tracker( self._plot_denoising_steps_from_tracker(
rtc_tracked_steps, rtc_tracked_steps,
axs_xt[:, 1], # Right column for x_t axs_xt[:, 1], # Right column for x_t