mirror of https://github.com/huggingface/lerobot.git, synced 2026-05-16 00:59:46 +00:00
delete policies
+191 -77
@@ -45,6 +45,7 @@ Usage:
         --torch_compile_mode=max-autotune
 """
 
+import gc
 import logging
 import os
 import random
@@ -174,28 +175,7 @@ class RTCEvaluator:
         self.cfg = cfg
         self.device = cfg.device
 
-        # Load policy
-        logging.info(f"Loading policy from {cfg.policy.pretrained_path}")
-        policy_class = get_policy_class(cfg.policy.type)
-        self.policy = policy_class.from_pretrained(cfg.policy.pretrained_path)
-        self.policy = self.policy.to(self.device)
-        self.policy.eval()
-
-        # Configure RTC
-        cfg.rtc.enabled = True
-        cfg.rtc.debug = True  # Enable debug tracking for visualization
-        self.policy.config.rtc_config = cfg.rtc
-        self.policy.init_rtc_processor()
-
-        # Apply torch.compile if enabled
-        if cfg.use_torch_compile:
-            self._apply_torch_compile()
-
-        logging.info(f"Policy loaded: {self.policy.name}")
-        logging.info(f"RTC enabled: {cfg.rtc.enabled}")
-        logging.info(f"Execution horizon: {cfg.rtc.execution_horizon}")
-
-        # Load dataset
+        # Load dataset first (needed for preprocessor)
         logging.info(f"Loading dataset: {cfg.dataset.repo_id}")
         self.dataset = LeRobotDataset(cfg.dataset.repo_id, delta_timestamps={"action": np.arange(50) / 30})
         logging.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes")
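For context on the delta_timestamps argument above: LeRobotDataset returns, for each sampled frame, the values of a key at the requested time offsets, stacked into one tensor. With offsets np.arange(50) / 30, each sample carries the next 50 actions of a 30 fps episode, matching the policy's 50-step chunk. A small sketch of the arithmetic (the 30 fps rate and 50-step chunk are assumptions read off the surrounding code):

import numpy as np

fps = 30         # assumed dataset frame rate
chunk_size = 50  # assumed to match the policy's action chunk size

offsets = np.arange(chunk_size) / fps  # [0.000, 0.033, ..., 1.633] seconds
# Passing delta_timestamps={"action": offsets} makes dataset[i]["action"]
# a (50, action_dim) tensor: the action at frame i plus the 49 that follow,
# the shape predict_action_chunk outputs are compared against.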
@@ -209,56 +189,177 @@ class RTCEvaluator:
             },
         )
 
-    def _apply_torch_compile(self):
-        """Apply torch.compile to the policy model for faster inference."""
+        # Initialize three separate policy instances
+        # Note: These policies are initialized here but will be freed sequentially during
+        # evaluation to manage memory. Large models (e.g., VLAs with billions of parameters)
+        # cannot fit three instances in memory simultaneously. Each policy is deleted and
+        # memory is freed (via torch.cuda.empty_cache()) immediately after its use.
+        logging.info("=" * 80)
+        logging.info("Initializing three policy instances:")
+        logging.info("  1. policy_prev_chunk (for generating previous chunk)")
+        logging.info("  2. policy_no_rtc (for non-RTC inference)")
+        logging.info("  3. policy_rtc (for RTC inference)")
+        logging.info("  Note: Policies will be freed sequentially during evaluation to manage memory")
+        logging.info("=" * 80)
+
+        # Policy 1: For generating previous chunk (RTC disabled, no debug)
+        self.policy_prev_chunk = self._init_policy(
+            name="policy_prev_chunk",
+            rtc_enabled=False,
+            rtc_debug=False,
+        )
+
+        # Policy 2: For non-RTC inference (RTC disabled, debug enabled)
+        self.policy_no_rtc = self._init_policy(
+            name="policy_no_rtc",
+            rtc_enabled=False,
+            rtc_debug=True,
+        )
+
+        # Policy 3: For RTC inference (RTC enabled, debug enabled)
+        self.policy_rtc = self._init_policy(
+            name="policy_rtc",
+            rtc_enabled=True,
+            rtc_debug=True,
+        )
+
+        logging.info("=" * 80)
+        logging.info("All policies initialized successfully")
+        logging.info("=" * 80)
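The note above describes a load → use → free discipline: only one large policy is resident on the accelerator at a time. A minimal sketch of that pattern, with a hypothetical make_policy() standing in for the from_pretrained(...) loader and torch.nn.Linear standing in for a real policy:

import gc
import torch

def make_policy():
    # Hypothetical loader; a real run would call from_pretrained(...)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.nn.Linear(512, 512).to(device)

for stage in ("prev_chunk", "no_rtc", "rtc"):
    policy = make_policy()        # load
    # ... run this stage's inference ...
    policy.cpu()                  # drop GPU residency
    del policy                    # release the last Python reference
    gc.collect()                  # reclaim the object immediately
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached blocks to the allocator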
+    def _init_policy(self, name: str, rtc_enabled: bool, rtc_debug: bool):
+        """Initialize a single policy instance with the specified RTC configuration.
+
+        Args:
+            name: Name identifier for logging purposes
+            rtc_enabled: Whether to enable RTC for this policy
+            rtc_debug: Whether to enable debug tracking for this policy
+
+        Returns:
+            Configured policy instance with optional torch.compile applied
+        """
+        logging.info(f"Initializing {name}...")
+
+        # Load policy from pretrained
+        policy_class = get_policy_class(self.cfg.policy.type)
+        policy = policy_class.from_pretrained(self.cfg.policy.pretrained_path)
+        policy = policy.to(self.device)
+        policy.eval()
+
+        # Configure RTC
+        rtc_config = RTCConfig(
+            enabled=rtc_enabled,
+            execution_horizon=self.cfg.rtc.execution_horizon,
+            max_guidance_weight=self.cfg.rtc.max_guidance_weight,
+            prefix_attention_schedule=self.cfg.rtc.prefix_attention_schedule,
+            debug=rtc_debug,
+            debug_maxlen=self.cfg.rtc.debug_maxlen,
+        )
+        policy.config.rtc_config = rtc_config
+        policy.init_rtc_processor()
+
+        logging.info(f"  RTC enabled: {rtc_enabled}")
+        logging.info(f"  RTC debug: {rtc_debug}")
+
+        # Apply torch.compile to predict_action_chunk method if enabled
+        if self.cfg.use_torch_compile:
+            policy = self._apply_torch_compile(policy, name)
+
+        logging.info(f"✓ {name} initialized successfully")
+        return policy
+    def _apply_torch_compile(self, policy, policy_name: str):
+        """Apply torch.compile to the policy's predict_action_chunk method.
+
+        Args:
+            policy: Policy instance to compile
+            policy_name: Name for logging purposes
+
+        Returns:
+            Policy with compiled predict_action_chunk method
+        """
         try:
             # Check if torch.compile is available (PyTorch 2.0+)
             if not hasattr(torch, "compile"):
                 logging.warning(
-                    "torch.compile is not available. Requires PyTorch 2.0+. "
+                    f" [{policy_name}] torch.compile is not available. Requires PyTorch 2.0+. "
                     f"Current version: {torch.__version__}. Skipping compilation."
                 )
-                return
+                return policy
 
-            logging.info("Applying torch.compile to policy model...")
-            logging.info(f" Backend: {self.cfg.torch_compile_backend}")
-            logging.info(f" Mode: {self.cfg.torch_compile_mode}")
+            logging.info(f" [{policy_name}] Applying torch.compile to predict_action_chunk...")
+            logging.info(f" Backend: {self.cfg.torch_compile_backend}")
+            logging.info(f" Mode: {self.cfg.torch_compile_mode}")
 
-            # Compile the policy's model (not the policy itself to preserve methods)
-            if hasattr(self.policy, "model"):
-                original_model = self.policy.model
-                compiled_model = torch.compile(
-                    original_model,
-                    backend=self.cfg.torch_compile_backend,
-                    mode=self.cfg.torch_compile_mode,
-                )
-                self.policy.model = compiled_model
-                logging.info("✓ Successfully compiled policy.model")
-            else:
-                logging.warning(
-                    "Policy does not have a 'model' attribute. "
-                    "Attempting to compile entire policy (may not work for all policy types)."
-                )
-                self.policy = torch.compile(
-                    self.policy,
-                    backend=self.cfg.torch_compile_backend,
-                    mode=self.cfg.torch_compile_mode,
-                )
-                logging.info("✓ Successfully compiled policy")
+            # Compile the predict_action_chunk method
+            original_method = policy.predict_action_chunk
+            compiled_method = torch.compile(
+                original_method,
+                backend=self.cfg.torch_compile_backend,
+                mode=self.cfg.torch_compile_mode,
+            )
+            policy.predict_action_chunk = compiled_method
+            logging.info(f" ✓ [{policy_name}] Successfully compiled predict_action_chunk")
 
         except Exception as e:
-            logging.error(f"Failed to apply torch.compile: {e}")
-            logging.warning("Continuing without torch.compile")
+            logging.error(f" [{policy_name}] Failed to apply torch.compile: {e}")
+            logging.warning(f" [{policy_name}] Continuing without torch.compile")
 
+        return policy
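Where the old code compiled the underlying nn.Module, the new helper compiles the bound predict_action_chunk method and assigns it back onto the instance. torch.compile accepts arbitrary callables, not just modules, so this is valid; a minimal self-contained sketch with a toy module (Toy and its method are illustrative stand-ins, backend left at the "inductor" default):

import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(8, 8)

    def predict_action_chunk(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs).tanh()

policy = Toy().eval()
if hasattr(torch, "compile"):  # PyTorch 2.0+ guard, as in the patch
    # Instance attribute shadows the class method; later calls hit the compiled version
    policy.predict_action_chunk = torch.compile(policy.predict_action_chunk)
with torch.no_grad():
    out = policy.predict_action_chunk(torch.randn(1, 8))  # first call triggers compilation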
+    def _destroy_policy(self, policy, policy_name: str):
+        """Explicitly destroy a policy and free all associated memory.
+
+        This method performs aggressive cleanup to ensure maximum memory is freed,
+        which is critical for large models (e.g., VLAs with billions of parameters).
+
+        Args:
+            policy: Policy instance to destroy
+            policy_name: Name for logging purposes
+        """
+        logging.info(f"  Destroying {policy_name} and freeing memory...")
+
+        try:
+            # Step 1: Move policy to CPU to free GPU/MPS memory
+            policy.cpu()
+
+            # Step 2: Delete the policy object
+            del policy
+
+            # Step 3: Force garbage collection to reclaim memory immediately
+            gc.collect()
+
+            # Step 4: Clear device-specific caches
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+                torch.cuda.synchronize()  # Ensure all operations complete
+
+            if torch.backends.mps.is_available():
+                torch.mps.empty_cache()
+
+            logging.info(f"  ✓ {policy_name} destroyed and memory freed")
+
+        except Exception as e:
+            logging.warning(f"  Warning: Error during {policy_name} cleanup: {e}")
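A quick way to confirm this cleanup sequence actually releases device memory is torch.cuda.memory_allocated(), which reports live tensor bytes on the device (CUDA only; the Linear layer below is an illustrative stand-in for a large policy):

import gc
import torch

if torch.cuda.is_available():
    before = torch.cuda.memory_allocated()
    policy = torch.nn.Linear(4096, 4096).cuda()  # stand-in for a large policy
    print(f"loaded: {torch.cuda.memory_allocated() - before:,} bytes")
    policy.cpu()
    del policy
    gc.collect()
    torch.cuda.empty_cache()
    print(f"after cleanup: {torch.cuda.memory_allocated() - before:,} bytes")  # back to ~0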
     def run_evaluation(self):
-        """Run evaluation on two random dataset samples."""
+        """Run evaluation on two random dataset samples using three separate policies.
+
+        Note: Policies are deinitialized after each step to free memory. Large models
+        (e.g., VLA models with billions of parameters) cannot fit three instances in
+        memory simultaneously. By deleting and garbage collecting after each step,
+        we ensure only one policy is loaded at a time.
+        """
         # Create output directory
         os.makedirs(self.cfg.output_dir, exist_ok=True)
         logging.info(f"Output directory: {self.cfg.output_dir}")
 
         logging.info("=" * 80)
         logging.info("Starting RTC evaluation")
         logging.info(f"Inference delay: {self.cfg.inference_delay}")
         logging.info("=" * 80)
 
         # Load two random samples from the dataset
         data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True)
         loader_iter = iter(data_loader)
         first_sample = next(loader_iter)
@@ -267,50 +368,65 @@ class RTCEvaluator:
         preprocessed_first_sample = self.preprocessor(first_sample)
         preprocessed_second_sample = self.preprocessor(second_sample)
 
-        # Don't postprocess the previous chunk
-        prev_chunk_left_over = self.policy.predict_action_chunk(
-            preprocessed_first_sample,
-        )[:, :25, :].squeeze(0)
+        # Step 1: Generate previous chunk using policy_prev_chunk
+        # This policy is only used to generate the reference chunk and then freed
+        logging.info("Step 1: Generating previous chunk with policy_prev_chunk")
+        with torch.no_grad():
+            prev_chunk_left_over = self.policy_prev_chunk.predict_action_chunk(
+                preprocessed_first_sample,
+            )[:, :25, :].squeeze(0)
+        logging.info(f"  Generated prev_chunk shape: {prev_chunk_left_over.shape}")
 
-        self.policy.rtc_processor.reset_tracker()
-
-        logging.info("Resetting tracker")
+        # Destroy policy_prev_chunk to free memory for large models
+        self._destroy_policy(self.policy_prev_chunk, "policy_prev_chunk")
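The hard-coded [:, :25, :] keeps the first 25 predicted actions as the simulated leftover of an earlier chunk; assuming the 50-step chunk size used elsewhere in this script, that is half the chunk. A shape-bookkeeping sketch (all sizes illustrative):

import torch

batch, chunk_size, action_dim = 1, 50, 32           # illustrative sizes
chunk = torch.randn(batch, chunk_size, action_dim)  # stand-in for predict_action_chunk output
leftover = chunk[:, :25, :].squeeze(0)              # as in the patch
print(leftover.shape)                               # torch.Size([25, 32])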
         # Sample noise (use same noise for both RTC and non-RTC for fair comparison)
-        noise_size = (1, self.policy.config.chunk_size, self.policy.config.max_action_dim)
-        noise = self.policy.model.sample_noise(noise_size, self.device)
+        noise_size = (1, self.policy_no_rtc.config.chunk_size, self.policy_no_rtc.config.max_action_dim)
+        noise = self.policy_no_rtc.model.sample_noise(noise_size, self.device)
         noise_clone = noise.clone()
 
-        # Generate actions WITHOUT RTC
-        logging.info("Generating actions WITHOUT RTC")
-        self.policy.config.rtc_config.enabled = False
+        # Step 2: Generate actions WITHOUT RTC using policy_no_rtc
+        logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc")
+        self.policy_no_rtc.rtc_processor.reset_tracker()
         with torch.no_grad():
-            _ = self.policy.predict_action_chunk(
+            _ = self.policy_no_rtc.predict_action_chunk(
                 preprocessed_second_sample,
                 noise=noise,
             )
+        no_rtc_tracked_steps = self.policy_no_rtc.rtc_processor.tracker.get_all_steps()
+        logging.info(f"  Tracked {len(no_rtc_tracked_steps)} steps without RTC")
 
-        no_rtc_tracked_steps = self.policy.rtc_processor.tracker.get_all_steps()
-        self.policy.rtc_processor.reset_tracker()
+        # Destroy policy_no_rtc to free memory before loading policy_rtc
+        self._destroy_policy(self.policy_no_rtc, "policy_no_rtc")
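Drawing the noise once and cloning it pins both runs to the same starting point for denoising, so any difference between the tracked trajectories is attributable to RTC guidance alone. The clone (rather than passing the same tensor twice) also guards against the first run mutating its input in place. A minimal sketch of the invariants:

import torch

torch.manual_seed(0)
noise = torch.randn(1, 50, 32)  # one draw shared by both conditions
noise_clone = noise.clone()     # independent storage for the second run

assert torch.equal(noise, noise_clone)              # identical starting point
assert noise.data_ptr() != noise_clone.data_ptr()   # no aliasing between runs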
 
-        # Generate actions WITH RTC
-        logging.info("Generating actions WITH RTC")
-        self.policy.config.rtc_config.enabled = True
+        # Step 3: Generate actions WITH RTC using policy_rtc
+        logging.info("Step 3: Generating actions WITH RTC with policy_rtc")
+        self.policy_rtc.rtc_processor.reset_tracker()
         with torch.no_grad():
-            _ = self.policy.predict_action_chunk(
+            _ = self.policy_rtc.predict_action_chunk(
                 preprocessed_second_sample,
                 noise=noise_clone,
                 inference_delay=self.cfg.inference_delay,
                 prev_chunk_left_over=prev_chunk_left_over,
                 execution_horizon=self.cfg.rtc.execution_horizon,
             )
+        rtc_tracked_steps = self.policy_rtc.rtc_processor.get_all_debug_steps()
+        logging.info(f"  Tracked {len(rtc_tracked_steps)} steps with RTC")
 
-        rtc_tracked_steps = self.policy.rtc_processor.get_all_debug_steps()
+        # Save num_steps before destroying policy (needed for plotting)
+        num_steps = self.policy_rtc.config.num_steps
 
-        self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over)
+        # Destroy policy_rtc after final use
+        self._destroy_policy(self.policy_rtc, "policy_rtc")
 
+        # Plot and save results
+        logging.info("=" * 80)
+        logging.info("Plotting results...")
+        self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps)
+        logging.info("=" * 80)
+        logging.info("Evaluation completed successfully")
 
-    def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over):
+    def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps):
         # Create side-by-side figures for denoising visualization
         fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)")
         fig_vt, axs_vt = self._create_figure("v_t Denoising: No RTC (left) vs RTC (right)")
@@ -318,8 +434,6 @@ class RTCEvaluator:
         fig_x1t, axs_x1t = self._create_figure(
             "x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)"
         )
 
-        num_steps = self.policy.config.num_steps
-
         self._plot_denoising_steps_from_tracker(
             rtc_tracked_steps,
             axs_xt[:, 1],  # Right column for x_t