mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
delete policies
This commit is contained in:
+191
-77
@@ -45,6 +45,7 @@ Usage:
|
|||||||
--torch_compile_mode=max-autotune
|
--torch_compile_mode=max-autotune
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
@@ -174,28 +175,7 @@ class RTCEvaluator:
|
|||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.device = cfg.device
|
self.device = cfg.device
|
||||||
|
|
||||||
# Load policy
|
# Load dataset first (needed for preprocessor)
|
||||||
logging.info(f"Loading policy from {cfg.policy.pretrained_path}")
|
|
||||||
policy_class = get_policy_class(cfg.policy.type)
|
|
||||||
self.policy = policy_class.from_pretrained(cfg.policy.pretrained_path)
|
|
||||||
self.policy = self.policy.to(self.device)
|
|
||||||
self.policy.eval()
|
|
||||||
|
|
||||||
# Configure RTC
|
|
||||||
cfg.rtc.enabled = True
|
|
||||||
cfg.rtc.debug = True # Enable debug tracking for visualization
|
|
||||||
self.policy.config.rtc_config = cfg.rtc
|
|
||||||
self.policy.init_rtc_processor()
|
|
||||||
|
|
||||||
# Apply torch.compile if enabled
|
|
||||||
if cfg.use_torch_compile:
|
|
||||||
self._apply_torch_compile()
|
|
||||||
|
|
||||||
logging.info(f"Policy loaded: {self.policy.name}")
|
|
||||||
logging.info(f"RTC enabled: {cfg.rtc.enabled}")
|
|
||||||
logging.info(f"Execution horizon: {cfg.rtc.execution_horizon}")
|
|
||||||
|
|
||||||
# Load dataset
|
|
||||||
logging.info(f"Loading dataset: {cfg.dataset.repo_id}")
|
logging.info(f"Loading dataset: {cfg.dataset.repo_id}")
|
||||||
self.dataset = LeRobotDataset(cfg.dataset.repo_id, delta_timestamps={"action": np.arange(50) / 30})
|
self.dataset = LeRobotDataset(cfg.dataset.repo_id, delta_timestamps={"action": np.arange(50) / 30})
|
||||||
logging.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes")
|
logging.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes")
|
||||||
@@ -209,56 +189,177 @@ class RTCEvaluator:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def _apply_torch_compile(self):
|
# Initialize three separate policy instances
|
||||||
"""Apply torch.compile to the policy model for faster inference."""
|
# Note: These policies are initialized here but will be freed sequentially during
|
||||||
|
# evaluation to manage memory. Large models (e.g., VLAs with billions of parameters)
|
||||||
|
# cannot fit three instances in memory simultaneously. Each policy is deleted and
|
||||||
|
# memory is freed (via torch.cuda.empty_cache()) immediately after its use.
|
||||||
|
logging.info("=" * 80)
|
||||||
|
logging.info("Initializing three policy instances:")
|
||||||
|
logging.info(" 1. policy_prev_chunk (for generating previous chunk)")
|
||||||
|
logging.info(" 2. policy_no_rtc (for non-RTC inference)")
|
||||||
|
logging.info(" 3. policy_rtc (for RTC inference)")
|
||||||
|
logging.info(" Note: Policies will be freed sequentially during evaluation to manage memory")
|
||||||
|
logging.info("=" * 80)
|
||||||
|
|
||||||
|
# Policy 1: For generating previous chunk (RTC disabled, no debug)
|
||||||
|
self.policy_prev_chunk = self._init_policy(
|
||||||
|
name="policy_prev_chunk",
|
||||||
|
rtc_enabled=False,
|
||||||
|
rtc_debug=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Policy 2: For non-RTC inference (RTC disabled, debug enabled)
|
||||||
|
self.policy_no_rtc = self._init_policy(
|
||||||
|
name="policy_no_rtc",
|
||||||
|
rtc_enabled=False,
|
||||||
|
rtc_debug=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Policy 3: For RTC inference (RTC enabled, debug enabled)
|
||||||
|
self.policy_rtc = self._init_policy(
|
||||||
|
name="policy_rtc",
|
||||||
|
rtc_enabled=True,
|
||||||
|
rtc_debug=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("=" * 80)
|
||||||
|
logging.info("All policies initialized successfully")
|
||||||
|
logging.info("=" * 80)
|
||||||
|
|
||||||
|
def _init_policy(self, name: str, rtc_enabled: bool, rtc_debug: bool):
|
||||||
|
"""Initialize a single policy instance with specified RTC configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Name identifier for logging purposes
|
||||||
|
rtc_enabled: Whether to enable RTC for this policy
|
||||||
|
rtc_debug: Whether to enable debug tracking for this policy
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured policy instance with optional torch.compile applied
|
||||||
|
"""
|
||||||
|
logging.info(f"Initializing {name}...")
|
||||||
|
|
||||||
|
# Load policy from pretrained
|
||||||
|
policy_class = get_policy_class(self.cfg.policy.type)
|
||||||
|
policy = policy_class.from_pretrained(self.cfg.policy.pretrained_path)
|
||||||
|
policy = policy.to(self.device)
|
||||||
|
policy.eval()
|
||||||
|
|
||||||
|
# Configure RTC
|
||||||
|
rtc_config = RTCConfig(
|
||||||
|
enabled=rtc_enabled,
|
||||||
|
execution_horizon=self.cfg.rtc.execution_horizon,
|
||||||
|
max_guidance_weight=self.cfg.rtc.max_guidance_weight,
|
||||||
|
prefix_attention_schedule=self.cfg.rtc.prefix_attention_schedule,
|
||||||
|
debug=rtc_debug,
|
||||||
|
debug_maxlen=self.cfg.rtc.debug_maxlen,
|
||||||
|
)
|
||||||
|
policy.config.rtc_config = rtc_config
|
||||||
|
policy.init_rtc_processor()
|
||||||
|
|
||||||
|
logging.info(f" RTC enabled: {rtc_enabled}")
|
||||||
|
logging.info(f" RTC debug: {rtc_debug}")
|
||||||
|
|
||||||
|
# Apply torch.compile to predict_action_chunk method if enabled
|
||||||
|
if self.cfg.use_torch_compile:
|
||||||
|
policy = self._apply_torch_compile(policy, name)
|
||||||
|
|
||||||
|
logging.info(f"✓ {name} initialized successfully")
|
||||||
|
return policy
|
||||||
|
|
||||||
|
def _apply_torch_compile(self, policy, policy_name: str):
|
||||||
|
"""Apply torch.compile to the policy's predict_action_chunk method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
policy: Policy instance to compile
|
||||||
|
policy_name: Name for logging purposes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Policy with compiled predict_action_chunk method
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
# Check if torch.compile is available (PyTorch 2.0+)
|
# Check if torch.compile is available (PyTorch 2.0+)
|
||||||
if not hasattr(torch, "compile"):
|
if not hasattr(torch, "compile"):
|
||||||
logging.warning(
|
logging.warning(
|
||||||
"torch.compile is not available. Requires PyTorch 2.0+. "
|
f" [{policy_name}] torch.compile is not available. Requires PyTorch 2.0+. "
|
||||||
f"Current version: {torch.__version__}. Skipping compilation."
|
f"Current version: {torch.__version__}. Skipping compilation."
|
||||||
)
|
)
|
||||||
return
|
return policy
|
||||||
|
|
||||||
logging.info("Applying torch.compile to policy model...")
|
logging.info(f" [{policy_name}] Applying torch.compile to predict_action_chunk...")
|
||||||
logging.info(f" Backend: {self.cfg.torch_compile_backend}")
|
logging.info(f" Backend: {self.cfg.torch_compile_backend}")
|
||||||
logging.info(f" Mode: {self.cfg.torch_compile_mode}")
|
logging.info(f" Mode: {self.cfg.torch_compile_mode}")
|
||||||
|
|
||||||
# Compile the policy's model (not the policy itself to preserve methods)
|
# Compile the predict_action_chunk method
|
||||||
if hasattr(self.policy, "model"):
|
original_method = policy.predict_action_chunk
|
||||||
original_model = self.policy.model
|
compiled_method = torch.compile(
|
||||||
compiled_model = torch.compile(
|
original_method,
|
||||||
original_model,
|
backend=self.cfg.torch_compile_backend,
|
||||||
backend=self.cfg.torch_compile_backend,
|
mode=self.cfg.torch_compile_mode,
|
||||||
mode=self.cfg.torch_compile_mode,
|
)
|
||||||
)
|
policy.predict_action_chunk = compiled_method
|
||||||
self.policy.model = compiled_model
|
logging.info(f" ✓ [{policy_name}] Successfully compiled predict_action_chunk")
|
||||||
logging.info("✓ Successfully compiled policy.model")
|
|
||||||
else:
|
|
||||||
logging.warning(
|
|
||||||
"Policy does not have a 'model' attribute. "
|
|
||||||
"Attempting to compile entire policy (may not work for all policy types)."
|
|
||||||
)
|
|
||||||
self.policy = torch.compile(
|
|
||||||
self.policy,
|
|
||||||
backend=self.cfg.torch_compile_backend,
|
|
||||||
mode=self.cfg.torch_compile_mode,
|
|
||||||
)
|
|
||||||
logging.info("✓ Successfully compiled policy")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Failed to apply torch.compile: {e}")
|
logging.error(f" [{policy_name}] Failed to apply torch.compile: {e}")
|
||||||
logging.warning("Continuing without torch.compile")
|
logging.warning(f" [{policy_name}] Continuing without torch.compile")
|
||||||
|
|
||||||
|
return policy
|
||||||
|
|
||||||
|
def _destroy_policy(self, policy, policy_name: str):
|
||||||
|
"""Explicitly destroy a policy and free all associated memory.
|
||||||
|
|
||||||
|
This method performs aggressive cleanup to ensure maximum memory is freed,
|
||||||
|
which is critical for large models (e.g., VLAs with billions of parameters).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
policy: Policy instance to destroy
|
||||||
|
policy_name: Name for logging purposes
|
||||||
|
"""
|
||||||
|
logging.info(f" Destroying {policy_name} and freeing memory...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Step 1: Move policy to CPU to free GPU/MPS memory
|
||||||
|
policy.cpu()
|
||||||
|
|
||||||
|
# Step 2: Delete the policy object
|
||||||
|
del policy
|
||||||
|
|
||||||
|
# Step 3: Force garbage collection to reclaim memory immediately
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
# Step 4: Clear device-specific caches
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.synchronize() # Ensure all operations complete
|
||||||
|
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
torch.mps.empty_cache()
|
||||||
|
|
||||||
|
logging.info(f" ✓ {policy_name} destroyed and memory freed")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f" Warning: Error during {policy_name} cleanup: {e}")
|
||||||
|
|
||||||
def run_evaluation(self):
|
def run_evaluation(self):
|
||||||
"""Run evaluation on two random dataset samples."""
|
"""Run evaluation on two random dataset samples using three separate policies.
|
||||||
|
|
||||||
|
Note: Policies are deinitalized after each step to free memory. Large models
|
||||||
|
(e.g., VLA models with billions of parameters) cannot fit three instances in
|
||||||
|
memory simultaneously. By deleting and garbage collecting after each step,
|
||||||
|
we ensure only one policy is loaded at a time.
|
||||||
|
"""
|
||||||
# Create output directory
|
# Create output directory
|
||||||
os.makedirs(self.cfg.output_dir, exist_ok=True)
|
os.makedirs(self.cfg.output_dir, exist_ok=True)
|
||||||
logging.info(f"Output directory: {self.cfg.output_dir}")
|
logging.info(f"Output directory: {self.cfg.output_dir}")
|
||||||
|
|
||||||
|
logging.info("=" * 80)
|
||||||
logging.info("Starting RTC evaluation")
|
logging.info("Starting RTC evaluation")
|
||||||
logging.info(f"Inference delay: {self.cfg.inference_delay}")
|
logging.info(f"Inference delay: {self.cfg.inference_delay}")
|
||||||
|
logging.info("=" * 80)
|
||||||
|
|
||||||
|
# Load two random samples from dataset
|
||||||
data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True)
|
data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True)
|
||||||
loader_iter = iter(data_loader)
|
loader_iter = iter(data_loader)
|
||||||
first_sample = next(loader_iter)
|
first_sample = next(loader_iter)
|
||||||
@@ -267,50 +368,65 @@ class RTCEvaluator:
|
|||||||
preprocessed_first_sample = self.preprocessor(first_sample)
|
preprocessed_first_sample = self.preprocessor(first_sample)
|
||||||
preprocessed_second_sample = self.preprocessor(second_sample)
|
preprocessed_second_sample = self.preprocessor(second_sample)
|
||||||
|
|
||||||
# Don't postprocess the previous chunk
|
# Step 1: Generate previous chunk using policy_prev_chunk
|
||||||
prev_chunk_left_over = self.policy.predict_action_chunk(
|
# This policy is only used to generate the reference chunk and then freed
|
||||||
preprocessed_first_sample,
|
logging.info("Step 1: Generating previous chunk with policy_prev_chunk")
|
||||||
)[:, :25, :].squeeze(0)
|
with torch.no_grad():
|
||||||
|
prev_chunk_left_over = self.policy_prev_chunk.predict_action_chunk(
|
||||||
|
preprocessed_first_sample,
|
||||||
|
)[:, :25, :].squeeze(0)
|
||||||
|
logging.info(f" Generated prev_chunk shape: {prev_chunk_left_over.shape}")
|
||||||
|
|
||||||
self.policy.rtc_processor.reset_tracker()
|
# Destroy policy_prev_chunk to free memory for large models
|
||||||
|
self._destroy_policy(self.policy_prev_chunk, "policy_prev_chunk")
|
||||||
logging.info("Resetting tracker")
|
|
||||||
|
|
||||||
# Sample noise (use same noise for both RTC and non-RTC for fair comparison)
|
# Sample noise (use same noise for both RTC and non-RTC for fair comparison)
|
||||||
noise_size = (1, self.policy.config.chunk_size, self.policy.config.max_action_dim)
|
noise_size = (1, self.policy_no_rtc.config.chunk_size, self.policy_no_rtc.config.max_action_dim)
|
||||||
noise = self.policy.model.sample_noise(noise_size, self.device)
|
noise = self.policy_no_rtc.model.sample_noise(noise_size, self.device)
|
||||||
noise_clone = noise.clone()
|
noise_clone = noise.clone()
|
||||||
|
|
||||||
# Generate actions WITHOUT RTC
|
# Step 2: Generate actions WITHOUT RTC using policy_no_rtc
|
||||||
logging.info("Generating actions WITHOUT RTC")
|
logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc")
|
||||||
self.policy.config.rtc_config.enabled = False
|
self.policy_no_rtc.rtc_processor.reset_tracker()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
_ = self.policy.predict_action_chunk(
|
_ = self.policy_no_rtc.predict_action_chunk(
|
||||||
preprocessed_second_sample,
|
preprocessed_second_sample,
|
||||||
noise=noise,
|
noise=noise,
|
||||||
)
|
)
|
||||||
|
no_rtc_tracked_steps = self.policy_no_rtc.rtc_processor.tracker.get_all_steps()
|
||||||
|
logging.info(f" Tracked {len(no_rtc_tracked_steps)} steps without RTC")
|
||||||
|
|
||||||
no_rtc_tracked_steps = self.policy.rtc_processor.tracker.get_all_steps()
|
# Destroy policy_no_rtc to free memory before loading policy_rtc
|
||||||
self.policy.rtc_processor.reset_tracker()
|
self._destroy_policy(self.policy_no_rtc, "policy_no_rtc")
|
||||||
|
|
||||||
# Generate actions WITH RTC
|
# Step 3: Generate actions WITH RTC using policy_rtc
|
||||||
logging.info("Generating actions WITH RTC")
|
logging.info("Step 3: Generating actions WITH RTC with policy_rtc")
|
||||||
self.policy.config.rtc_config.enabled = True
|
self.policy_rtc.rtc_processor.reset_tracker()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
_ = self.policy.predict_action_chunk(
|
_ = self.policy_rtc.predict_action_chunk(
|
||||||
preprocessed_second_sample,
|
preprocessed_second_sample,
|
||||||
noise=noise_clone,
|
noise=noise_clone,
|
||||||
inference_delay=self.cfg.inference_delay,
|
inference_delay=self.cfg.inference_delay,
|
||||||
prev_chunk_left_over=prev_chunk_left_over,
|
prev_chunk_left_over=prev_chunk_left_over,
|
||||||
execution_horizon=self.cfg.rtc.execution_horizon,
|
execution_horizon=self.cfg.rtc.execution_horizon,
|
||||||
)
|
)
|
||||||
|
rtc_tracked_steps = self.policy_rtc.rtc_processor.get_all_debug_steps()
|
||||||
|
logging.info(f" Tracked {len(rtc_tracked_steps)} steps with RTC")
|
||||||
|
|
||||||
rtc_tracked_steps = self.policy.rtc_processor.get_all_debug_steps()
|
# Save num_steps before destroying policy (needed for plotting)
|
||||||
|
num_steps = self.policy_rtc.config.num_steps
|
||||||
|
|
||||||
self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over)
|
# Destroy policy_rtc after final use
|
||||||
|
self._destroy_policy(self.policy_rtc, "policy_rtc")
|
||||||
|
|
||||||
|
# Plot and save results
|
||||||
|
logging.info("=" * 80)
|
||||||
|
logging.info("Plotting results...")
|
||||||
|
self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps)
|
||||||
|
logging.info("=" * 80)
|
||||||
logging.info("Evaluation completed successfully")
|
logging.info("Evaluation completed successfully")
|
||||||
|
|
||||||
def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over):
|
def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps):
|
||||||
# Create side-by-side figures for denoising visualization
|
# Create side-by-side figures for denoising visualization
|
||||||
fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)")
|
fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)")
|
||||||
fig_vt, axs_vt = self._create_figure("v_t Denoising: No RTC (left) vs RTC (right)")
|
fig_vt, axs_vt = self._create_figure("v_t Denoising: No RTC (left) vs RTC (right)")
|
||||||
@@ -318,8 +434,6 @@ class RTCEvaluator:
|
|||||||
fig_x1t, axs_x1t = self._create_figure(
|
fig_x1t, axs_x1t = self._create_figure(
|
||||||
"x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)"
|
"x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)"
|
||||||
)
|
)
|
||||||
|
|
||||||
num_steps = self.policy.config.num_steps
|
|
||||||
self._plot_denoising_steps_from_tracker(
|
self._plot_denoising_steps_from_tracker(
|
||||||
rtc_tracked_steps,
|
rtc_tracked_steps,
|
||||||
axs_xt[:, 1], # Right column for x_t
|
axs_xt[:, 1], # Right column for x_t
|
||||||
|
|||||||
Reference in New Issue
Block a user