diff --git a/examples/rtc/eval_dataset.py b/examples/rtc/eval_dataset.py
index 737e9f890..43163154d 100644
--- a/examples/rtc/eval_dataset.py
+++ b/examples/rtc/eval_dataset.py
@@ -82,7 +82,8 @@ from lerobot.configs import parser
 from lerobot.configs.default import DatasetConfig
 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.configs.types import RTCAttentionSchedule
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.datasets.factory import resolve_delta_timestamps
+from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
 from lerobot.policies.factory import get_policy_class, make_pre_post_processors
 from lerobot.policies.rtc.configuration_rtc import RTCConfig
 from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
@@ -206,9 +207,21 @@ class RTCEvaluator:
         self.cfg = cfg
         self.device = cfg.device
 
-        # Load dataset first (needed for preprocessor)
+        # Load dataset with proper delta_timestamps based on policy configuration
+        # Calculate delta_timestamps using the same logic as make_dataset factory
         logging.info(f"Loading dataset: {cfg.dataset.repo_id}")
-        self.dataset = LeRobotDataset(cfg.dataset.repo_id, delta_timestamps={"action": np.arange(50) / 30})
+
+        # Get dataset metadata to extract FPS
+        ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id)
+
+        # Calculate delta_timestamps from policy's delta_indices
+        delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
+
+        # Create dataset with calculated delta_timestamps
+        self.dataset = LeRobotDataset(
+            cfg.dataset.repo_id,
+            delta_timestamps=delta_timestamps,
+        )
         logging.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes")
 
         # Create preprocessor/postprocessor
@@ -220,17 +233,12 @@ class RTCEvaluator:
             },
         )
 
-        # Initialize three separate policy instances
-        # Note: These policies are initialized here but will be freed sequentially during
-        # evaluation to manage memory. Large models (e.g., VLAs with billions of parameters)
-        # cannot fit three instances in memory simultaneously. Each policy is deleted and
-        # memory is freed (via torch.cuda.empty_cache()) immediately after its use.
         logging.info("=" * 80)
-        logging.info("Initializing three policy instances:")
-        logging.info("  1. policy_prev_chunk (for generating previous chunk)")
-        logging.info("  2. policy_no_rtc (for non-RTC inference)")
-        logging.info("  3. policy_rtc (for RTC inference)")
-        logging.info("  Note: Policies will be freed sequentially during evaluation to manage memory")
+        logging.info("Ready to run evaluation with sequential policy loading:")
+        logging.info("  1. policy_prev_chunk - Generate reference chunk, then destroy")
+        logging.info("  2. policy_no_rtc - Generate without RTC, then destroy")
+        logging.info("  3. policy_rtc - Generate with RTC, then destroy")
+        logging.info("  Note: Only one policy in memory at a time for efficient memory usage")
         logging.info("=" * 80)
 
     def _init_policy(self, name: str, rtc_enabled: bool, rtc_debug: bool):
@@ -383,30 +391,20 @@ class RTCEvaluator:
         preprocessed_first_sample = self.preprocessor(first_sample)
         preprocessed_second_sample = self.preprocessor(second_sample)
 
-        # Policy 1: For generating previous chunk (RTC disabled, no debug)
+        # ============================================================================
+        # Step 1: Generate previous chunk using policy_prev_chunk
+        # ============================================================================
+        # This policy is only used to generate the reference chunk and then freed
+        logging.info("=" * 80)
+        logging.info("Step 1: Generating previous chunk with policy_prev_chunk")
+        logging.info("=" * 80)
+
+        # Initialize policy 1
         policy_prev_chunk_policy = self._init_policy(
             name="policy_prev_chunk",
             rtc_enabled=False,
             rtc_debug=False,
         )
-
-        # Policy 2: For non-RTC inference (RTC disabled, debug enabled)
-        policy_no_rtc_policy = self._init_policy(
-            name="policy_no_rtc",
-            rtc_enabled=False,
-            rtc_debug=True,
-        )
-
-        # Policy 3: For RTC inference (RTC enabled, debug enabled)
-        policy_rtc_policy = self._init_policy(
-            name="policy_rtc",
-            rtc_enabled=True,
-            rtc_debug=True,
-        )
-
-        # Step 1: Generate previous chunk using policy_prev_chunk
-        # This policy is only used to generate the reference chunk and then freed
-        logging.info("Step 1: Generating previous chunk with policy_prev_chunk")
         with torch.no_grad():
             prev_chunk_left_over = policy_prev_chunk_policy.predict_action_chunk(
                 preprocessed_first_sample,
@@ -416,13 +414,24 @@
         # Destroy policy_prev_chunk to free memory for large models
         self._destroy_policy(policy_prev_chunk_policy, "policy_prev_chunk")
 
+        # ============================================================================
+        # Step 2: Generate actions WITHOUT RTC using policy_no_rtc
+        # ============================================================================
+        logging.info("=" * 80)
+        logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc")
+        logging.info("=" * 80)
+
+        # Initialize policy 2
+        policy_no_rtc_policy = self._init_policy(
+            name="policy_no_rtc",
+            rtc_enabled=False,
+            rtc_debug=True,
+        )
+
         # Sample noise (use same noise for both RTC and non-RTC for fair comparison)
         noise_size = (1, policy_no_rtc_policy.config.chunk_size, policy_no_rtc_policy.config.max_action_dim)
         noise = policy_no_rtc_policy.model.sample_noise(noise_size, self.device)
         noise_clone = noise.clone()
-
-        # Step 2: Generate actions WITHOUT RTC using policy_no_rtc
-        logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc")
         policy_no_rtc_policy.rtc_processor.reset_tracker()
         with torch.no_grad():
             _ = policy_no_rtc_policy.predict_action_chunk(
@@ -435,8 +444,19 @@
         # Destroy policy_no_rtc to free memory before loading policy_rtc
         self._destroy_policy(policy_no_rtc_policy, "policy_no_rtc")
 
+        # ============================================================================
         # Step 3: Generate actions WITH RTC using policy_rtc
+        # ============================================================================
+        logging.info("=" * 80)
         logging.info("Step 3: Generating actions WITH RTC with policy_rtc")
+        logging.info("=" * 80)
+
+        # Initialize policy 3
+        policy_rtc_policy = self._init_policy(
+            name="policy_rtc",
+            rtc_enabled=True,
+            rtc_debug=True,
+        )
         policy_rtc_policy.rtc_processor.reset_tracker()
         with torch.no_grad():
             _ = policy_rtc_policy.predict_action_chunk(
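
For context on the first hunk: the old hardcoded `delta_timestamps={"action": np.arange(50) / 30}` baked a 50-step action horizon at 30 FPS into the script, while `resolve_delta_timestamps` derives the timestamps from the policy's delta indices and the dataset's actual FPS. Below is a minimal sketch of that conversion, assuming the policy config exposes per-key delta indices the way lerobot's `PreTrainedConfig` does (`observation_delta_indices` / `action_delta_indices`); it is not the exact factory implementation.

```python
# Sketch only: frame offsets -> seconds, using the dataset's recording rate.
def sketch_resolve_delta_timestamps(policy_cfg, ds_meta) -> dict[str, list[float]] | None:
    delta_timestamps = {}
    groups = {
        "observation.state": policy_cfg.observation_delta_indices,  # may be None
        "action": policy_cfg.action_delta_indices,
    }
    for key, indices in groups.items():
        if indices is not None:
            delta_timestamps[key] = [i / ds_meta.fps for i in indices]
    return delta_timestamps or None

# With a 30 FPS dataset and action_delta_indices == list(range(50)), this
# reproduces the values the old hardcoded np.arange(50) / 30 produced,
# but stays correct when the chunk size or dataset FPS changes.
```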
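The load-use-destroy flow that replaces the up-front triple initialization can be summarized in a few lines. This is a runnable sketch of the pattern, with a dummy module standing in for the real policies; `load_policy` and the tensor shapes are illustrative, not lerobot API. The key detail is that `del` removes the only live reference before `gc.collect()` and `torch.cuda.empty_cache()` run, so each model's memory is actually reclaimable before the next one loads.

```python
import gc

import torch
import torch.nn as nn


def load_policy(name: str) -> nn.Module:
    # Stand-in for the real checkpoint load; a large model would go to GPU here.
    return nn.Linear(32, 32)


for name in ["policy_prev_chunk", "policy_no_rtc", "policy_rtc"]:
    policy = load_policy(name)
    with torch.no_grad():
        _ = policy(torch.zeros(1, 32))  # stand-in for predict_action_chunk(...)
    del policy                          # drop the only reference to the model
    gc.collect()                        # break any lingering reference cycles
    if torch.cuda.is_available():
        torch.cuda.empty_cache()        # return cached GPU blocks to the driver
```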
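The `noise_clone = noise.clone()` line in the Step 2 hunk matters because both passes must denoise from identical starting noise for the comparison to isolate RTC's effect, while each pass may mutate its own copy. A small illustration of the clone semantics; the shapes and manual seed are arbitrary, and `torch.randn` stands in for `policy.model.sample_noise(...)`.

```python
import torch

torch.manual_seed(0)
noise = torch.randn(1, 50, 32)  # (batch, chunk_size, max_action_dim), illustrative
noise_clone = noise.clone()     # independent copy: in-place edits to one leave the other intact

assert noise.data_ptr() != noise_clone.data_ptr()  # separate storage
assert torch.equal(noise, noise_clone)             # identical starting values
```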