diff --git a/examples/rtc/eval_dataset.py b/examples/rtc/eval_dataset.py
index f5ad118c2..de949cb2e 100644
--- a/examples/rtc/eval_dataset.py
+++ b/examples/rtc/eval_dataset.py
@@ -45,6 +45,7 @@ Usage:
     --torch_compile_mode=max-autotune
 """
 
+import gc
 import logging
 import os
 import random
@@ -174,28 +175,7 @@ class RTCEvaluator:
         self.cfg = cfg
         self.device = cfg.device
 
-        # Load policy
-        logging.info(f"Loading policy from {cfg.policy.pretrained_path}")
-        policy_class = get_policy_class(cfg.policy.type)
-        self.policy = policy_class.from_pretrained(cfg.policy.pretrained_path)
-        self.policy = self.policy.to(self.device)
-        self.policy.eval()
-
-        # Configure RTC
-        cfg.rtc.enabled = True
-        cfg.rtc.debug = True  # Enable debug tracking for visualization
-        self.policy.config.rtc_config = cfg.rtc
-        self.policy.init_rtc_processor()
-
-        # Apply torch.compile if enabled
-        if cfg.use_torch_compile:
-            self._apply_torch_compile()
-
-        logging.info(f"Policy loaded: {self.policy.name}")
-        logging.info(f"RTC enabled: {cfg.rtc.enabled}")
-        logging.info(f"Execution horizon: {cfg.rtc.execution_horizon}")
-
-        # Load dataset
+        # Load dataset first (needed for preprocessor)
         logging.info(f"Loading dataset: {cfg.dataset.repo_id}")
         self.dataset = LeRobotDataset(cfg.dataset.repo_id, delta_timestamps={"action": np.arange(50) / 30})
         logging.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes")
@@ -209,56 +189,177 @@ class RTCEvaluator:
             },
         )
 
-    def _apply_torch_compile(self):
-        """Apply torch.compile to the policy model for faster inference."""
+        # Initialize three separate policy instances.
+        # Note: These policies are created here but freed sequentially during evaluation
+        # to limit peak memory. Three instances of a large model (e.g., a VLA with
+        # billions of parameters) may not fit in memory at once, so each policy is
+        # deleted and its memory freed (via torch.cuda.empty_cache()) after its last use.
+        logging.info("=" * 80)
+        logging.info("Initializing three policy instances:")
+        logging.info("  1. policy_prev_chunk (for generating previous chunk)")
+        logging.info("  2. policy_no_rtc (for non-RTC inference)")
+        logging.info("  3. policy_rtc (for RTC inference)")
+        logging.info("  Note: Policies will be freed sequentially during evaluation to manage memory")
+        logging.info("=" * 80)
+
+        # Policy 1: For generating previous chunk (RTC disabled, no debug)
+        self.policy_prev_chunk = self._init_policy(
+            name="policy_prev_chunk",
+            rtc_enabled=False,
+            rtc_debug=False,
+        )
+
+        # Policy 2: For non-RTC inference (RTC disabled, debug enabled)
+        self.policy_no_rtc = self._init_policy(
+            name="policy_no_rtc",
+            rtc_enabled=False,
+            rtc_debug=True,
+        )
+
+        # Policy 3: For RTC inference (RTC enabled, debug enabled)
+        self.policy_rtc = self._init_policy(
+            name="policy_rtc",
+            rtc_enabled=True,
+            rtc_debug=True,
+        )
+
+        logging.info("=" * 80)
+        logging.info("All policies initialized successfully")
+        logging.info("=" * 80)
+
+    def _init_policy(self, name: str, rtc_enabled: bool, rtc_debug: bool):
+        """Initialize a single policy instance with the specified RTC configuration.
+
+        Args:
+            name: Name identifier for logging purposes
+            rtc_enabled: Whether to enable RTC for this policy
+            rtc_debug: Whether to enable debug tracking for this policy
+
+        Returns:
+            Configured policy instance, with torch.compile applied if enabled
+        """
+        logging.info(f"Initializing {name}...")
+
+        # Load policy from pretrained
+        policy_class = get_policy_class(self.cfg.policy.type)
+        policy = policy_class.from_pretrained(self.cfg.policy.pretrained_path)
+        policy = policy.to(self.device)
+        policy.eval()
+
+        # Configure RTC
+        rtc_config = RTCConfig(
+            enabled=rtc_enabled,
+            execution_horizon=self.cfg.rtc.execution_horizon,
+            max_guidance_weight=self.cfg.rtc.max_guidance_weight,
+            prefix_attention_schedule=self.cfg.rtc.prefix_attention_schedule,
+            debug=rtc_debug,
+            debug_maxlen=self.cfg.rtc.debug_maxlen,
+        )
+        policy.config.rtc_config = rtc_config
+        policy.init_rtc_processor()
+
+        logging.info(f"  RTC enabled: {rtc_enabled}")
+        logging.info(f"  RTC debug: {rtc_debug}")
+
+        # Apply torch.compile to the predict_action_chunk method if enabled
+        if self.cfg.use_torch_compile:
+            policy = self._apply_torch_compile(policy, name)
+
+        logging.info(f"✓ {name} initialized successfully")
+        return policy
+
+    def _apply_torch_compile(self, policy, policy_name: str):
+        """Apply torch.compile to the policy's predict_action_chunk method.
+
+        Args:
+            policy: Policy instance to compile
+            policy_name: Name for logging purposes
+
+        Returns:
+            Policy with compiled predict_action_chunk method
+        """
         try:
             # Check if torch.compile is available (PyTorch 2.0+)
             if not hasattr(torch, "compile"):
                 logging.warning(
-                    "torch.compile is not available. Requires PyTorch 2.0+. "
+                    f"  [{policy_name}] torch.compile is not available. Requires PyTorch 2.0+. "
                     f"Current version: {torch.__version__}. Skipping compilation."
                 )
-                return
+                return policy
 
-            logging.info("Applying torch.compile to policy model...")
-            logging.info(f"  Backend: {self.cfg.torch_compile_backend}")
-            logging.info(f"  Mode: {self.cfg.torch_compile_mode}")
+            logging.info(f"  [{policy_name}] Applying torch.compile to predict_action_chunk...")
+            logging.info(f"    Backend: {self.cfg.torch_compile_backend}")
+            logging.info(f"    Mode: {self.cfg.torch_compile_mode}")
 
-            # Compile the policy's model (not the policy itself to preserve methods)
-            if hasattr(self.policy, "model"):
-                original_model = self.policy.model
-                compiled_model = torch.compile(
-                    original_model,
-                    backend=self.cfg.torch_compile_backend,
-                    mode=self.cfg.torch_compile_mode,
-                )
-                self.policy.model = compiled_model
-                logging.info("✓ Successfully compiled policy.model")
-            else:
-                logging.warning(
-                    "Policy does not have a 'model' attribute. "
-                    "Attempting to compile entire policy (may not work for all policy types)."
-                )
-                self.policy = torch.compile(
-                    self.policy,
-                    backend=self.cfg.torch_compile_backend,
-                    mode=self.cfg.torch_compile_mode,
-                )
-                logging.info("✓ Successfully compiled policy")
+            # Compile and rebind predict_action_chunk on this policy instance
+            original_method = policy.predict_action_chunk
+            compiled_method = torch.compile(
+                original_method,
+                backend=self.cfg.torch_compile_backend,
+                mode=self.cfg.torch_compile_mode,
+            )
+            policy.predict_action_chunk = compiled_method
+            logging.info(f"  ✓ [{policy_name}] Successfully compiled predict_action_chunk")
 
         except Exception as e:
-            logging.error(f"Failed to apply torch.compile: {e}")
-            logging.warning("Continuing without torch.compile")
+            logging.error(f"  [{policy_name}] Failed to apply torch.compile: {e}")
+            logging.warning(f"  [{policy_name}] Continuing without torch.compile")
+
+        return policy
+
+    def _destroy_policy(self, policy, policy_name: str):
+        """Release a policy's memory; returns None so callers can clear their reference.
+
+        Moving the model to CPU frees its GPU/MPS memory, which is critical for large
+        models (e.g., VLAs with billions of parameters); ``del`` only drops the local
+        reference, so callers must also clear their own for the object to be collected.
+
+        Args:
+            policy: Policy instance to destroy
+            policy_name: Name for logging purposes
+        """
+        logging.info(f"  Destroying {policy_name} and freeing memory...")
+
+        try:
+            # Step 1: Move policy to CPU to free GPU/MPS memory
+            policy.cpu()
+
+            # Step 2: Drop the local reference (callers should clear theirs as well)
+            del policy
+
+            # Step 3: Run garbage collection so freed objects are reclaimed immediately
+            gc.collect()
+
+            # Step 4: Clear device-specific caches
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+                torch.cuda.synchronize()  # Ensure all operations complete
+
+            if torch.backends.mps.is_available():
+                torch.mps.empty_cache()
+
+            logging.info(f"  ✓ {policy_name} destroyed and memory freed")
+
+        except Exception as e:
+            logging.warning(f"  Error during {policy_name} cleanup: {e}")
 
     def run_evaluation(self):
-        """Run evaluation on two random dataset samples."""
+        """Run evaluation on two random dataset samples using three separate policies.
+
+        Note: Each policy is destroyed after its last use to free memory. Three
+        instances of a large model (e.g., a VLA with billions of parameters) may
+        not fit in accelerator memory at once, so each one is released as soon as
+        it is no longer needed, keeping peak memory usage down.
+        """
         # Create output directory
         os.makedirs(self.cfg.output_dir, exist_ok=True)
         logging.info(f"Output directory: {self.cfg.output_dir}")
 
+        logging.info("=" * 80)
         logging.info("Starting RTC evaluation")
         logging.info(f"Inference delay: {self.cfg.inference_delay}")
+        logging.info("=" * 80)
+
+        # Load two random samples from the dataset
         data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True)
         loader_iter = iter(data_loader)
         first_sample = next(loader_iter)
@@ -267,50 +368,65 @@ class RTCEvaluator:
         preprocessed_first_sample = self.preprocessor(first_sample)
         preprocessed_second_sample = self.preprocessor(second_sample)
 
-        # Don't postprocess the previous chunk
-        prev_chunk_left_over = self.policy.predict_action_chunk(
-            preprocessed_first_sample,
-        )[:, :25, :].squeeze(0)
+        # Step 1: Generate previous chunk using policy_prev_chunk
+        # This policy is only used to generate the reference chunk and then freed
+        logging.info("Step 1: Generating previous chunk with policy_prev_chunk")
+        with torch.no_grad():
+            prev_chunk_left_over = self.policy_prev_chunk.predict_action_chunk(
+                preprocessed_first_sample,
+            )[:, :25, :].squeeze(0)
+        logging.info(f"  Generated prev_chunk shape: {prev_chunk_left_over.shape}")
 
-        self.policy.rtc_processor.reset_tracker()
-
-        logging.info("Resetting tracker")
+        # Destroy policy_prev_chunk and clear our reference so it can be collected
+        self.policy_prev_chunk = self._destroy_policy(self.policy_prev_chunk, "policy_prev_chunk")
 
         # Sample noise (use same noise for both RTC and non-RTC for fair comparison)
-        noise_size = (1, self.policy.config.chunk_size, self.policy.config.max_action_dim)
-        noise = self.policy.model.sample_noise(noise_size, self.device)
+        noise_size = (1, self.policy_no_rtc.config.chunk_size, self.policy_no_rtc.config.max_action_dim)
+        noise = self.policy_no_rtc.model.sample_noise(noise_size, self.device)
         noise_clone = noise.clone()
 
-        # Generate actions WITHOUT RTC
-        logging.info("Generating actions WITHOUT RTC")
-        self.policy.config.rtc_config.enabled = False
+        # Step 2: Generate actions WITHOUT RTC using policy_no_rtc
+        logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc")
+        self.policy_no_rtc.rtc_processor.reset_tracker()
         with torch.no_grad():
-            _ = self.policy.predict_action_chunk(
+            _ = self.policy_no_rtc.predict_action_chunk(
                 preprocessed_second_sample,
                 noise=noise,
             )
+        no_rtc_tracked_steps = self.policy_no_rtc.rtc_processor.tracker.get_all_steps()
+        logging.info(f"  Tracked {len(no_rtc_tracked_steps)} steps without RTC")
 
-        no_rtc_tracked_steps = self.policy.rtc_processor.tracker.get_all_steps()
-        self.policy.rtc_processor.reset_tracker()
+        # Destroy policy_no_rtc and clear our reference before the RTC pass
+        self.policy_no_rtc = self._destroy_policy(self.policy_no_rtc, "policy_no_rtc")
 
-        # Generate actions WITH RTC
-        logging.info("Generating actions WITH RTC")
-        self.policy.config.rtc_config.enabled = True
+        # Step 3: Generate actions WITH RTC using policy_rtc
+        logging.info("Step 3: Generating actions WITH RTC with policy_rtc")
+        self.policy_rtc.rtc_processor.reset_tracker()
         with torch.no_grad():
-            _ = self.policy.predict_action_chunk(
+            _ = self.policy_rtc.predict_action_chunk(
                 preprocessed_second_sample,
                 noise=noise_clone,
                 inference_delay=self.cfg.inference_delay,
                 prev_chunk_left_over=prev_chunk_left_over,
                 execution_horizon=self.cfg.rtc.execution_horizon,
            )
+        rtc_tracked_steps = self.policy_rtc.rtc_processor.get_all_debug_steps()
+        logging.info(f"  Tracked {len(rtc_tracked_steps)} steps with RTC")
 
-        rtc_tracked_steps = self.policy.rtc_processor.get_all_debug_steps()
+        # Save num_steps before destroying policy (needed for plotting)
+        num_steps = self.policy_rtc.config.num_steps
 
-        self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over)
+        # Destroy policy_rtc after its final use and clear our reference
+        self.policy_rtc = self._destroy_policy(self.policy_rtc, "policy_rtc")
+
+        # Plot and save results
+        logging.info("=" * 80)
+        logging.info("Plotting results...")
+        self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps)
+        logging.info("=" * 80)
 
         logging.info("Evaluation completed successfully")
 
-    def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over):
+    def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps):
         # Create side-by-side figures for denoising visualization
         fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)")
         fig_vt, axs_vt = self._create_figure("v_t Denoising: No RTC (left) vs RTC (right)")
@@ -318,8 +434,6 @@ class RTCEvaluator:
         fig_x1t, axs_x1t = self._create_figure(
             "x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)"
        )
-
-        num_steps = self.policy.config.num_steps
         self._plot_denoising_steps_from_tracker(
             rtc_tracked_steps,
             axs_xt[:, 1],  # Right column for x_t
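
For reference, the two patterns this diff leans on (freeing an accelerator-resident model, and compiling a single bound method) can be sketched in isolation. This is a minimal sketch assuming a generic `torch.nn.Module`; the helper names `free_model_memory` and `compile_method` are illustrative and not part of this PR.

```python
import gc

import torch


def free_model_memory(model: torch.nn.Module) -> None:
    """Release a model's accelerator memory (mirrors _destroy_policy above)."""
    model.cpu()  # move parameters/buffers to host, freeing GPU/MPS tensors
    del model  # drops only the local reference; callers must clear theirs too
    gc.collect()  # reclaim unreferenced Python objects promptly
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached CUDA blocks to the driver
        torch.cuda.synchronize()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()


def compile_method(obj, method_name: str, backend: str = "inductor", mode=None):
    """Rebind one method with its torch.compile-wrapped version (PyTorch 2.0+)."""
    if not hasattr(torch, "compile"):
        return obj  # older PyTorch: leave the method uncompiled
    compiled = torch.compile(getattr(obj, method_name), backend=backend, mode=mode)
    setattr(obj, method_name, compiled)  # instance attribute shadows the class method
    return obj


# Hypothetical usage:
#   policy = compile_method(policy, "predict_action_chunk", mode="max-autotune")
#   ... run inference ...
#   free_model_memory(policy)
#   policy = None  # clear the caller's reference as well
```

Compiling just the hot method keeps the rest of the policy's Python surface (config, RTC processor) untouched, which is why the diff rebinds `predict_action_chunk` rather than wrapping the whole policy or its `.model` as the old code did.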