def _as_frame_batch(frames):
    """Return ``frames`` with the frame index on the leading axis.

    A blanket ``squeeze(0)`` is unsafe here: when exactly one frame was
    decoded it removes the frame axis itself, and iterating the result then
    walks over channels instead of frames.  Only a genuine extra leading
    batch axis of size 1 is dropped.

    NOTE(review): assumes the decoder returns ``(num_frames, C, H, W)``,
    possibly wrapped in one leading batch axis -- confirm against
    ``decode_video_frames``.
    """
    if frames.dim() == 5 and frames.shape[0] == 1:
        return frames[0]
    if frames.dim() == 3:
        # A single bare frame: restore the frame axis so callers can iterate.
        return frames.unsqueeze(0)
    return frames


def convert_dataset_videos_to_images(
    repo_id: str, root: str | None = None, *, frames_per_batch: int = 256
) -> int:
    """Convert all videos in a LeRobot dataset to individual PNG files.

    For every episode and every video key, frames are decoded at
    ``i / dataset.fps`` for ``i`` in ``range(episode_length)`` and written to
    ``<dataset.root>/images/chunk-<chunk>/<video_key>/episode_<ep>_<frame>.png``.

    Args:
        repo_id: Dataset repo ID passed to ``LeRobotDataset``.
        root: Optional local root directory passed to ``LeRobotDataset``.
        frames_per_batch: Maximum number of frames decoded per call, so that
            a long episode is never held in memory as one tensor.

    Returns:
        The total number of PNG files written.

    Raises:
        ValueError: If ``frames_per_batch`` is smaller than 1.

    A video that is missing is skipped with a warning.  A video that fails to
    decode or save is logged with its traceback and skipped; frames already
    written for it are kept and counted.
    """
    from lerobot.datasets.lerobot_dataset import LeRobotDataset
    from lerobot.datasets.video_utils import decode_video_frames
    from torchvision.transforms import ToPILImage

    if frames_per_batch < 1:
        raise ValueError(f"frames_per_batch must be >= 1, got {frames_per_batch}")

    dataset = LeRobotDataset(repo_id, root=root, download_videos=True)

    # Built once: the transform is stateless, so there is no reason to
    # re-import torchvision and re-create it for every frame.
    to_pil = ToPILImage()

    total_episodes = dataset.meta.total_episodes
    total_frames_processed = 0

    for ep_idx in range(total_episodes):
        logging.info("Processing episode %d/%d", ep_idx, total_episodes)

        # The timestamps depend only on the episode, not on the camera.
        ep_length = dataset.meta.episodes[ep_idx]["length"]
        timestamps = [i / dataset.fps for i in range(ep_length)]

        for vid_key in dataset.meta.video_keys:
            video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, vid_key)

            if not video_path.exists():
                logging.warning("Video not found: %s", video_path)
                continue

            img_dir = dataset.root / f"images/chunk-{dataset.meta.get_episode_chunk(ep_idx)}/{vid_key}"
            img_dir.mkdir(parents=True, exist_ok=True)

            saved = 0
            try:
                for start in range(0, ep_length, frames_per_batch):
                    batch_timestamps = timestamps[start : start + frames_per_batch]
                    frames = decode_video_frames(
                        video_path, batch_timestamps, dataset.tolerance_s, dataset.video_backend
                    )
                    for offset, frame in enumerate(_as_frame_batch(frames)):
                        img_path = img_dir / f"episode_{ep_idx:06d}_{start + offset:06d}.png"
                        to_pil(frame).save(img_path)
                        saved += 1
            except Exception:
                # Best effort: one broken video must not abort the whole
                # conversion, but keep the traceback for diagnosis.
                logging.exception("Failed to process %s", video_path)
            else:
                logging.info("  Extracted %d frames to %s", saved, img_dir)

            total_frames_processed += saved

    logging.info("Conversion complete! Processed %d total frames", total_frames_processed)
    # NOTE(review): this message assumes the dataset loader reads the image
    # layout written above when videos are not downloaded -- confirm.
    logging.info("You can now use download_videos=False to use the extracted images")
    return total_frames_processed


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert LeRobot video dataset to images")
    parser.add_argument("repo_id", help="Dataset repo ID (e.g., 'kenmacken/record-test-2')")
    parser.add_argument("--root", help="Local root directory", default=None)

    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    convert_dataset_videos_to_images(args.repo_id, args.root)
def _add_video_frame_caching(dataset, cache_size=1000):
    """Add LRU caching to video decoding to avoid re-decoding the same frames.

    The dataset's ``_query_videos`` method is replaced, per instance, by a
    wrapper that memoises results keyed on the episode index and the
    requested timestamps (rounded to 6 decimals).

    Args:
        dataset: A ``LeRobotDataset`` or ``MultiLeRobotDataset``.  Any other
            type is left untouched and a warning is printed.
        cache_size: Maximum number of cached *queries* (one entry per distinct
            ``(episode, timestamps)`` request, not per frame).

    The underlying ``functools.lru_cache`` object is exposed on each
    instrumented dataset as ``_cached_decode_frames`` so callers can read
    ``cache_info()``.

    NOTE(review): the cache lives in closures on the dataset object; how this
    interacts with DataLoader worker processes (per-worker copies, pickling
    under the "spawn" start method) is not visible here -- confirm.
    """
    import types
    from functools import lru_cache

    from lerobot.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset

    def instrument_dataset_caching(ds):
        if not hasattr(ds, '_query_videos'):
            return
        if hasattr(ds, '_cached_decode_frames'):
            # Already instrumented: wrapping again would stack a second cache
            # on top of the first and double the memory held per entry.
            return

        # Bound method of the un-instrumented dataset.
        original_query_videos = ds._query_videos

        def make_cache_key(query_timestamps, ep_idx):
            # lru_cache needs a hashable key, so lists become tuples.  Keys
            # are sorted so dict ordering cannot produce distinct cache
            # entries, and timestamps are rounded to microsecond precision so
            # float noise does not defeat the cache.
            per_video = (
                (vid_key, tuple(round(ts, 6) for ts in query_timestamps[vid_key]))
                for vid_key in sorted(query_timestamps)
            )
            return (ep_idx, *per_video)

        @lru_cache(maxsize=cache_size)
        def cached_decode_frames(cache_key):
            # The key carries everything needed to rebuild the original call.
            ep_idx, *per_video = cache_key
            query_timestamps = {vid_key: list(ts_tuple) for vid_key, ts_tuple in per_video}
            return original_query_videos(query_timestamps, ep_idx)

        def cached_query_videos(self, query_timestamps, ep_idx):
            result = cached_decode_frames(make_cache_key(query_timestamps, ep_idx))
            # Hand out a shallow copy so a caller that adds or replaces keys
            # cannot corrupt the cached entry.  The contained values are
            # still shared and must not be modified in place.
            return dict(result) if isinstance(result, dict) else result

        # Bind the cached method to the instance
        ds._query_videos = types.MethodType(cached_query_videos, ds)
        ds._cached_decode_frames = cached_decode_frames  # Keep reference for cache info

        print(f"Added video frame caching with size {cache_size}")

    # Handle both single and multi datasets
    if isinstance(dataset, MultiLeRobotDataset):
        for ds in dataset._datasets:
            instrument_dataset_caching(ds)
    elif isinstance(dataset, LeRobotDataset):
        instrument_dataset_caching(dataset)
    else:
        print(f"Warning: Unknown dataset type {type(dataset)}, skipping video caching")
video backend to dataset for RLearN optimization + dataset_kwargs = {} + if getattr(cfg.policy, "type", None) == "rlearn" and hasattr(cfg.policy, "video_backend"): + dataset_kwargs["video_backend"] = cfg.policy.video_backend + logging.info(f"Using video backend: {cfg.policy.video_backend}") + + dataset = make_dataset(cfg, **dataset_kwargs) # Add video decoding timing for RLearN debugging if getattr(cfg.policy, "type", None) == "rlearn": @@ -432,6 +490,12 @@ def train(cfg: TrainPipelineConfig): if recent_decodes: avg_video_decode = sum(recent_decodes) / len(recent_decodes) print(f" └─ Video decoding: ~{avg_video_decode:.2f} ms/call (included in data loading)") + + # Show cache hit rate if available + if hasattr(ds, '_cached_decode_frames'): + cache_info = ds._cached_decode_frames.cache_info() + hit_rate = cache_info.hits / max(cache_info.hits + cache_info.misses, 1) * 100 + print(f" └─ Cache hit rate: {hit_rate:.1f}% ({cache_info.hits}H/{cache_info.misses}M, size={cache_info.currsize})") except Exception: pass diff --git a/test_video_backends.py b/test_video_backends.py new file mode 100644 index 000000000..57b092151 --- /dev/null +++ b/test_video_backends.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python + +""" +Quick benchmark to test video decoding speed across different backends. 
def test_video_backend(video_path, backend_name, num_frames=10, fps=30):
    """Time the decoding of the first ``num_frames`` frames with one backend.

    Args:
        video_path: Path of the video file to decode.
        backend_name: Backend name passed to ``decode_video_frames``.
        num_frames: Number of consecutive frames to request, starting at 0.
        fps: Frame rate used to turn frame indices into timestamps.  It must
            match the video, because frames are requested with a tolerance of
            ``1e-4`` seconds.

    Returns:
        ``(decode_time_seconds, frames_decoded)``.  On any failure the error
        is printed (truncated to 50 characters) and ``(inf, 0)`` is returned,
        so a broken backend never aborts the comparison.
    """
    try:
        from lerobot.datasets.video_utils import decode_video_frames

        timestamps = [i / fps for i in range(num_frames)]

        # Time the decoding
        start_time = time.perf_counter()
        frames = decode_video_frames(video_path, timestamps, tolerance_s=1e-4, backend=backend_name)
        decode_time = time.perf_counter() - start_time

        # The frame index is the leading axis; axis 1 is the channel axis, so
        # reading shape[1] would always report the channel count.  Only a 5-D
        # result has its frames on axis 1.
        # NOTE(review): assumes a (num_frames, C, H, W) result -- confirm.
        frames_decoded = frames.shape[1] if frames.dim() == 5 else frames.shape[0]
        ms_per_frame = (decode_time * 1000) / max(frames_decoded, 1)

        print(f"✅ {backend_name:12} | {decode_time*1000:6.1f}ms total | {ms_per_frame:6.1f}ms/frame | {frames_decoded} frames")
        return decode_time, frames_decoded

    except Exception as e:
        print(f"❌ {backend_name:12} | ERROR: {str(e)[:50]}...")
        return float('inf'), 0


# This is a benchmark helper, not a unit test.  Its name matches pytest's
# default discovery pattern, so opt out of collection explicitly; otherwise
# pytest fails on the missing "video_path" fixture.
test_video_backend.__test__ = False


def main(video_dir=None, fps=30):
    """Benchmark every known backend on one video and print a recommendation.

    Args:
        video_dir: Directory searched recursively for ``*.mp4`` files.  When
            omitted, the original default cache location is used.
        fps: Frame rate forwarded to ``test_video_backend``.
    """
    if video_dir is None:
        video_dir = Path.home() / ".cache/huggingface/lerobot/kenmacken/record-test-2/videos"

    # Sorted so the benchmarked file is the same on every run.
    video_files = sorted(Path(video_dir).rglob("*.mp4"))

    if not video_files:
        print("❌ No video files found! Check the path.")
        return

    test_video = video_files[0]
    print(f"Testing video: {test_video.name}")
    print(f"File size: {test_video.stat().st_size / 1024 / 1024:.1f} MB")
    print("-" * 60)

    backends = ["torchcodec", "pyav", "video_reader"]
    results = {backend: test_video_backend(test_video, backend, fps=fps) for backend in backends}

    print("-" * 60)
    print("RECOMMENDATION:")

    # Backends that failed report an infinite time and are excluded.
    valid_results = {k: v for k, v in results.items() if v[0] != float('inf')}
    if not valid_results:
        print("❌ No backends worked!")
        return

    fastest_name, (fastest_time, _) = min(valid_results.items(), key=lambda x: x[1][0])
    print(f"🚀 Use '{fastest_name}' - fastest backend!")
    print(f"   Add to your config: video_backend: \"{fastest_name}\"")

    slowest_time = max(decode_time for decode_time, _ in valid_results.values())
    if fastest_time > 0:  # guard against a zero reading from a coarse clock
        speedup = slowest_time / fastest_time
        print(f"   Speedup vs slowest: {speedup:.1f}x faster")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Benchmark video decoding backends")
    parser.add_argument("video_dir", nargs="?", default=None, help="Directory containing .mp4 files")
    parser.add_argument("--fps", type=float, default=30, help="Frame rate of the videos")
    cli_args = parser.parse_args()

    main(cli_args.video_dir, cli_args.fps)