diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index 260d48cc8..2203bae2b 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -24,6 +24,68 @@ import torch # Fix tokenizer parallelism conflicts with multiprocessing os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def _add_video_decoding_timing(dataset): + """Add timing instrumentation to video decoding for debugging.""" + from lerobot.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset + + def instrument_dataset(ds): + if not hasattr(ds, '_query_videos'): + return + + # Store original method + original_query_videos = ds._query_videos + + # Initialize timing stats + if not hasattr(ds, '_video_decode_timing'): + ds._video_decode_timing = { + 'decode_times': [], + 'last_print_time': time.perf_counter() + } + + def timed_query_videos(self, query_timestamps, ep_idx): + decode_start = time.perf_counter() + result = original_query_videos(query_timestamps, ep_idx) + decode_time = time.perf_counter() - decode_start + + # Accumulate timing + timing_stats = self._video_decode_timing + timing_stats['decode_times'].append(decode_time * 1000) # Convert to ms + + # Print averaged stats every minute + current_time = time.perf_counter() + if current_time - timing_stats['last_print_time'] >= 60.0: + n_samples = len(timing_stats['decode_times']) + if n_samples > 0: + avg_decode_time = sum(timing_stats['decode_times']) / n_samples + total_frames = sum(len(query_timestamps[key]) for key in query_timestamps) + avg_frames_per_call = total_frames / n_samples if n_samples > 0 else 0 + + print(f"\nVideo Decoding Timing (last {n_samples} calls):") + print(f" Avg decode time: {avg_decode_time:.2f} ms") + print(f" Avg frames/call: {avg_frames_per_call:.1f}") + print(f" Time per frame: {avg_decode_time/max(avg_frames_per_call, 1):.2f} ms/frame") + print("-" * 50) + + # Reset stats + timing_stats['decode_times'] = [] + timing_stats['last_print_time'] = current_time + + return result + + # Bind the method to the instance + import types + ds._query_videos = types.MethodType(timed_query_videos, ds) + + # Handle both single and multi datasets + if isinstance(dataset, MultiLeRobotDataset): + for ds in dataset._datasets: + instrument_dataset(ds) + elif isinstance(dataset, LeRobotDataset): + instrument_dataset(dataset) + else: + print(f"Warning: Unknown dataset type {type(dataset)}, skipping video timing instrumentation") from termcolor import colored from torch.amp import GradScaler from torch.optim import Optimizer @@ -182,6 +244,10 @@ def train(cfg: TrainPipelineConfig): logging.info("Creating dataset") dataset = make_dataset(cfg) + + # Add video decoding timing for RLearN debugging + if getattr(cfg.policy, "type", None) == "rlearn": + _add_video_decoding_timing(dataset) # Create environment used for evaluating checkpoints during training on simulation data. # On real-world data, no need to create an environment as evaluations are done outside train.py, @@ -357,6 +423,18 @@ def train(cfg: TrainPipelineConfig): print(f" Data loading: {avg_data_loading:.2f} ms") print(f" Preprocessing: {avg_preprocessing:.2f} ms") print(f" Total data pipeline: {avg_data_loading + avg_preprocessing:.2f} ms") + + # Show video decoding breakdown if available + try: + ds = dataset._datasets[0] if hasattr(dataset, '_datasets') else dataset + if hasattr(ds, '_video_decode_timing'): + recent_decodes = ds._video_decode_timing.get('decode_times', []) + if recent_decodes: + avg_video_decode = sum(recent_decodes) / len(recent_decodes) + print(f" └─ Video decoding: ~{avg_video_decode:.2f} ms/call (included in data loading)") + except Exception: + pass + print("-" * 50) # Reset stats for next minute