diff --git a/examples/rtc/eval_dataset.py b/examples/rtc/eval_dataset.py
index de949cb2e..f5e9a424a 100644
--- a/examples/rtc/eval_dataset.py
+++ b/examples/rtc/eval_dataset.py
@@ -27,7 +27,7 @@ Usage:
         --use_torch_compile=true \
         --torch_compile_mode=max-autotune

-    # With torch.compile for faster inference (PyTorch 2.0+)
+    # With torch.compile on CUDA
     uv run python examples/rtc/eval_dataset.py \
         --policy.path=helper2424/smolvla_check_rtc_last3 \
         --dataset.repo_id=helper2424/check_rtc \
@@ -202,31 +202,6 @@ class RTCEvaluator:
         logging.info("  Note: Policies will be freed sequentially during evaluation to manage memory")
         logging.info("=" * 80)

-        # Policy 1: For generating previous chunk (RTC disabled, no debug)
-        self.policy_prev_chunk = self._init_policy(
-            name="policy_prev_chunk",
-            rtc_enabled=False,
-            rtc_debug=False,
-        )
-
-        # Policy 2: For non-RTC inference (RTC disabled, debug enabled)
-        self.policy_no_rtc = self._init_policy(
-            name="policy_no_rtc",
-            rtc_enabled=False,
-            rtc_debug=True,
-        )
-
-        # Policy 3: For RTC inference (RTC enabled, debug enabled)
-        self.policy_rtc = self._init_policy(
-            name="policy_rtc",
-            rtc_enabled=True,
-            rtc_debug=True,
-        )
-
-        logging.info("=" * 80)
-        logging.info("All policies initialized successfully")
-        logging.info("=" * 80)
-
     def _init_policy(self, name: str, rtc_enabled: bool, rtc_debug: bool):
         """Initialize a single policy instance with specified RTC configuration.

@@ -290,8 +265,11 @@ class RTCEvaluator:
             logging.info(f"  [{policy_name}] Applying torch.compile to predict_action_chunk...")
             logging.info(f"    Backend: {self.cfg.torch_compile_backend}")
             logging.info(f"    Mode: {self.cfg.torch_compile_mode}")
+            logging.info("    Note: Debug tracker excluded from compilation via @torch._dynamo.disable")

             # Compile the predict_action_chunk method
+            # The debug tracker is excluded from compilation via the @torch._dynamo.disable
+            # decorator on Tracker.track(), so its .item() calls never enter the compiled graph
             original_method = policy.predict_action_chunk
             compiled_method = torch.compile(
                 original_method,
@@ -368,56 +346,77 @@ class RTCEvaluator:
         preprocessed_first_sample = self.preprocessor(first_sample)
         preprocessed_second_sample = self.preprocessor(second_sample)

+        # Policy 1: For generating previous chunk (RTC disabled, no debug)
+        policy_prev_chunk = self._init_policy(
+            name="policy_prev_chunk",
+            rtc_enabled=False,
+            rtc_debug=False,
+        )
+
+        # Policy 2: For non-RTC inference (RTC disabled, debug enabled)
+        policy_no_rtc = self._init_policy(
+            name="policy_no_rtc",
+            rtc_enabled=False,
+            rtc_debug=True,
+        )
+
+        # Policy 3: For RTC inference (RTC enabled, debug enabled)
+        policy_rtc = self._init_policy(
+            name="policy_rtc",
+            rtc_enabled=True,
+            rtc_debug=True,
+        )
+
         # Step 1: Generate previous chunk using policy_prev_chunk
         # This policy is only used to generate the reference chunk and then freed
         logging.info("Step 1: Generating previous chunk with policy_prev_chunk")
         with torch.no_grad():
-            prev_chunk_left_over = self.policy_prev_chunk.predict_action_chunk(
+            prev_chunk_left_over = policy_prev_chunk.predict_action_chunk(
                 preprocessed_first_sample,
             )[:, :25, :].squeeze(0)

         logging.info(f"  Generated prev_chunk shape: {prev_chunk_left_over.shape}")

         # Destroy policy_prev_chunk to free memory for large models
-        self._destroy_policy(self.policy_prev_chunk, "policy_prev_chunk")
+        self._destroy_policy(policy_prev_chunk, "policy_prev_chunk")

         # Sample noise (use same noise for both RTC and non-RTC for fair comparison)
-        noise_size = (1, self.policy_no_rtc.config.chunk_size, self.policy_no_rtc.config.max_action_dim)
-        noise = self.policy_no_rtc.model.sample_noise(noise_size, self.device)
+        noise_size = (1, policy_no_rtc.config.chunk_size, policy_no_rtc.config.max_action_dim)
+        noise = policy_no_rtc.model.sample_noise(noise_size, self.device)
         noise_clone = noise.clone()

         # Step 2: Generate actions WITHOUT RTC using policy_no_rtc
         logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc")
-        self.policy_no_rtc.rtc_processor.reset_tracker()
+        policy_no_rtc.rtc_processor.reset_tracker()

         with torch.no_grad():
-            _ = self.policy_no_rtc.predict_action_chunk(
+            _ = policy_no_rtc.predict_action_chunk(
                 preprocessed_second_sample,
                 noise=noise,
             )

-        no_rtc_tracked_steps = self.policy_no_rtc.rtc_processor.tracker.get_all_steps()
+        no_rtc_tracked_steps = policy_no_rtc.rtc_processor.tracker.get_all_steps()
         logging.info(f"  Tracked {len(no_rtc_tracked_steps)} steps without RTC")

         # Destroy policy_no_rtc to free memory before loading policy_rtc
-        self._destroy_policy(self.policy_no_rtc, "policy_no_rtc")
+        self._destroy_policy(policy_no_rtc, "policy_no_rtc")

         # Step 3: Generate actions WITH RTC using policy_rtc
         logging.info("Step 3: Generating actions WITH RTC with policy_rtc")
-        self.policy_rtc.rtc_processor.reset_tracker()
+        policy_rtc.rtc_processor.reset_tracker()

         with torch.no_grad():
-            _ = self.policy_rtc.predict_action_chunk(
+            _ = policy_rtc.predict_action_chunk(
                 preprocessed_second_sample,
                 noise=noise_clone,
                 inference_delay=self.cfg.inference_delay,
                 prev_chunk_left_over=prev_chunk_left_over,
                 execution_horizon=self.cfg.rtc.execution_horizon,
             )

-        rtc_tracked_steps = self.policy_rtc.rtc_processor.get_all_debug_steps()
+        rtc_tracked_steps = policy_rtc.rtc_processor.get_all_debug_steps()
         logging.info(f"  Tracked {len(rtc_tracked_steps)} steps with RTC")

         # Save num_steps before destroying policy (needed for plotting)
-        num_steps = self.policy_rtc.config.num_steps
+        num_steps = policy_rtc.config.num_steps

         # Destroy policy_rtc after final use
-        self._destroy_policy(self.policy_rtc, "policy_rtc")
+        self._destroy_policy(policy_rtc, "policy_rtc")

         # Plot and save results
         logging.info("=" * 80)
diff --git a/src/lerobot/policies/rtc/debug_tracker.py b/src/lerobot/policies/rtc/debug_tracker.py
index 135281f9e..f143c223b 100644
--- a/src/lerobot/policies/rtc/debug_tracker.py
+++ b/src/lerobot/policies/rtc/debug_tracker.py
@@ -19,6 +19,7 @@ from dataclasses import dataclass, field

 from typing import Any

+import torch
 from torch import Tensor

@@ -120,6 +121,7 @@ class Tracker:
         self._steps.clear()
         self._step_counter = 0

+    @torch._dynamo.disable
     def track(
         self,
         time: float | Tensor,
@@ -139,6 +141,9 @@
         If a step with the given time already exists, it will be updated with the new data.
         Otherwise, a new step will be created. Only non-None fields are updated/set.

+        Note: This method is excluded from torch.compile to avoid graph breaks from
+        operations like .item(), which are incompatible with compiled graphs.
+
         Args:
             time (float | Tensor): Time parameter - used as the key to identify the step.
             x_t (Tensor | None): Current latent/state tensor.
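A note for reviewers on the `@torch._dynamo.disable` approach: the decorator tells Dynamo to skip tracing the decorated function and execute it eagerly, so data-dependent operations such as `.item()` and Python list mutation inside the tracker never have to be captured in the compiled graph. Below is a minimal sketch of the pattern; `Tracker` and `denoise_step` here are illustrative stand-ins, not the actual lerobot classes, and in this PR the compiled entry point is `predict_action_chunk`.

```python
import torch


class Tracker:
    """Illustrative stand-in for the RTC debug tracker."""

    def __init__(self):
        self.steps = []

    @torch._dynamo.disable
    def track(self, time: torch.Tensor, x_t: torch.Tensor) -> None:
        # Executes eagerly: .item() and list.append() are never traced by
        # Dynamo, so they stay out of the compiled region entirely.
        self.steps.append((time.item(), x_t.detach().cpu()))


tracker = Tracker()


@torch.compile
def denoise_step(x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    x_next = x_t - 0.1 * t * x_t  # toy update standing in for the real flow step
    tracker.track(t, x_next)  # excluded from compilation, runs eagerly
    return x_next


x = torch.randn(1, 50, 32)
for t in torch.linspace(1.0, 0.0, steps=5):
    x = denoise_step(x, t)
print(len(tracker.steps))  # 5
```

The trade-off is that Dynamo splits cleanly around the disabled call instead of attempting to trace `.item()` inside it; since the tracker only runs with `rtc_debug=True`, that cost is confined to debug evaluations.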