mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
add decode logging
This commit is contained in:
@@ -0,0 +1,74 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
"""
|
||||||
|
Convert video dataset to image dataset for faster training.
|
||||||
|
This pre-extracts all frames from MP4 files to PNG images.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
def convert_dataset_videos_to_images(repo_id: str, root: str | None = None):
    """Convert all videos in a LeRobot dataset to individual PNG files.

    For every episode and every video key of the dataset, the episode's video
    is decoded at ``i / dataset.fps`` for ``i`` in ``range(episode_length)``
    and each frame is written to
    ``<dataset.root>/images/chunk-<chunk>/<video_key>/episode_<ep>_<frame>.png``.

    Processing is best-effort: a missing video is logged as a warning and a
    video that fails to decode or save is logged as an error; in both cases
    the conversion continues with the next video.

    Args:
        repo_id: Dataset repo ID passed to ``LeRobotDataset``.
        root: Optional local root directory passed to ``LeRobotDataset``.
    """
    # Imported lazily so that merely importing this module stays cheap.
    import torchvision.transforms as T

    from lerobot.datasets.lerobot_dataset import LeRobotDataset
    from lerobot.datasets.video_utils import decode_video_frames

    # Load dataset
    dataset = LeRobotDataset(repo_id, root=root, download_videos=True)

    # The tensor -> PIL converter is stateless; build it once, not per frame.
    to_pil = T.ToPILImage()

    total_episodes = dataset.meta.total_episodes
    total_frames_processed = 0

    for ep_idx in range(total_episodes):
        logging.info(f"Processing episode {ep_idx}/{total_episodes}")

        for vid_key in dataset.meta.video_keys:
            video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, vid_key)

            if not video_path.exists():
                logging.warning(f"Video not found: {video_path}")
                continue

            # Create image directory
            img_dir = dataset.root / f"images/chunk-{dataset.meta.get_episode_chunk(ep_idx)}/{vid_key}"
            img_dir.mkdir(parents=True, exist_ok=True)

            # One timestamp per recorded frame so the whole episode is decoded.
            ep_length = dataset.meta.episodes[ep_idx]["length"]
            timestamps = [i / dataset.fps for i in range(ep_length)]

            try:
                frames = decode_video_frames(video_path, timestamps, dataset.tolerance_s, dataset.video_backend)

                # NOTE(review): decode_video_frames is assumed to return
                # (num_frames, C, H, W) - confirm against video_utils. Only a
                # leading singleton *batch* axis of a 5-D result is dropped:
                # an unconditional squeeze(0) would also collapse the frame
                # axis of a single-frame episode and make the loop below
                # iterate over channels instead of frames.
                if frames.dim() == 5:
                    frames = frames.squeeze(0)

                # Save each frame as PNG
                for i, frame in enumerate(frames):
                    img_path = img_dir / f"episode_{ep_idx:06d}_{i:06d}.png"
                    to_pil(frame).save(img_path)

                num_frames = len(frames)
                total_frames_processed += num_frames
                logging.info(f" Extracted {num_frames} frames to {img_dir}")

            except Exception as e:
                # Deliberate best-effort boundary: one broken video must not
                # abort the conversion of the remaining episodes.
                logging.error(f"Failed to process {video_path}: {e}")
                continue

    logging.info(f"Conversion complete! Processed {total_frames_processed} total frames")
    logging.info("You can now use download_videos=False to use the extracted images")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command-line interface: one positional dataset id plus an optional
    # local root directory.
    cli = argparse.ArgumentParser(description="Convert LeRobot video dataset to images")
    cli.add_argument("repo_id", help="Dataset repo ID (e.g., 'kenmacken/record-test-2')")
    cli.add_argument("--root", default=None, help="Local root directory")
    cli_args = cli.parse_args()

    # INFO level so the per-episode progress messages are visible.
    logging.basicConfig(level=logging.INFO)

    convert_dataset_videos_to_images(cli_args.repo_id, cli_args.root)
|
||||||
@@ -67,6 +67,7 @@ class RLearNConfig(PreTrainedConfig):
|
|||||||
# Performance optimizations
|
# Performance optimizations
|
||||||
use_amp: bool = True # Mixed precision training for speed boost
|
use_amp: bool = True # Mixed precision training for speed boost
|
||||||
compile_model: bool = True # torch.compile for additional speedup
|
compile_model: bool = True # torch.compile for additional speedup
|
||||||
|
video_backend: str = "pyav" # Use PyAV for faster video decoding (vs torchcodec)
|
||||||
|
|
||||||
# ReWiND-specific parameters
|
# ReWiND-specific parameters
|
||||||
use_video_rewind: bool = True # Enable video rewinding augmentation
|
use_video_rewind: bool = True # Enable video rewinding augmentation
|
||||||
|
|||||||
@@ -86,6 +86,58 @@ def _add_video_decoding_timing(dataset):
|
|||||||
instrument_dataset(dataset)
|
instrument_dataset(dataset)
|
||||||
else:
|
else:
|
||||||
print(f"Warning: Unknown dataset type {type(dataset)}, skipping video timing instrumentation")
|
print(f"Warning: Unknown dataset type {type(dataset)}, skipping video timing instrumentation")
|
||||||
|
|
||||||
|
|
||||||
|
def _add_video_frame_caching(dataset, cache_size=1000):
    """Add LRU caching to video decoding to avoid re-decoding the same frames.

    The ``_query_videos`` method of the dataset (or of every sub-dataset of a
    ``MultiLeRobotDataset``) is replaced on the instance by a wrapper that
    memoises results per ``(episode index, video key, timestamps)``. The
    underlying ``functools.lru_cache`` function is exposed on the dataset as
    ``_cached_decode_frames`` so callers can read ``cache_info()``.

    Args:
        dataset: A ``LeRobotDataset`` or ``MultiLeRobotDataset``; any other
            type is reported with a warning and left untouched.
        cache_size: ``maxsize`` of the per-dataset LRU cache.
    """
    import types
    from functools import lru_cache

    from lerobot.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset

    def make_cache_key(query_timestamps, ep_idx):
        """Build a hashable key: ``(ep_idx, (vid_key, timestamps), ...)``."""
        key_parts = [ep_idx]
        # Sorted so the key does not depend on dict insertion order.
        for vid_key in sorted(query_timestamps):
            # Round to microsecond precision so float noise does not defeat the cache.
            ts_tuple = tuple(round(ts, 6) for ts in query_timestamps[vid_key])
            key_parts.append((vid_key, ts_tuple))
        return tuple(key_parts)

    def instrument_dataset_caching(ds):
        if not hasattr(ds, "_query_videos"):
            return

        # Already instrumented: wrapping again would stack a second cache on
        # top of the first and replace `_cached_decode_frames`, falsifying the
        # reported hit statistics.
        if hasattr(ds, "_cached_decode_frames"):
            return

        # Store original (bound) method
        original_query_videos = ds._query_videos

        @lru_cache(maxsize=cache_size)
        def cached_decode_frames(cache_key, ep_idx):
            # Reconstruct query_timestamps from cache_key (element 0 is ep_idx).
            query_timestamps = {vid_key: list(ts_tuple) for vid_key, ts_tuple in cache_key[1:]}
            return original_query_videos(query_timestamps, ep_idx)

        def cached_query_videos(self, query_timestamps, ep_idx):
            cache_key = make_cache_key(query_timestamps, ep_idx)
            frames = cached_decode_frames(cache_key, ep_idx)
            # Hand out a shallow copy so a caller adding or removing keys
            # cannot corrupt the cached entry. The contained values are still
            # shared with the cache and must not be modified in place.
            return dict(frames) if isinstance(frames, dict) else frames

        # Bind the cached method to the instance
        ds._query_videos = types.MethodType(cached_query_videos, ds)
        ds._cached_decode_frames = cached_decode_frames  # Keep reference for cache info

        print(f"Added video frame caching with size {cache_size}")

    # Handle both single and multi datasets
    if isinstance(dataset, MultiLeRobotDataset):
        for ds in dataset._datasets:
            instrument_dataset_caching(ds)
    elif isinstance(dataset, LeRobotDataset):
        instrument_dataset_caching(dataset)
    else:
        print(f"Warning: Unknown dataset type {type(dataset)}, skipping video caching")
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
from torch.amp import GradScaler
|
from torch.amp import GradScaler
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
@@ -243,7 +295,13 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
|
||||||
logging.info("Creating dataset")
|
logging.info("Creating dataset")
|
||||||
dataset = make_dataset(cfg)
|
# Pass video backend to dataset for RLearN optimization
|
||||||
|
dataset_kwargs = {}
|
||||||
|
if getattr(cfg.policy, "type", None) == "rlearn" and hasattr(cfg.policy, "video_backend"):
|
||||||
|
dataset_kwargs["video_backend"] = cfg.policy.video_backend
|
||||||
|
logging.info(f"Using video backend: {cfg.policy.video_backend}")
|
||||||
|
|
||||||
|
dataset = make_dataset(cfg, **dataset_kwargs)
|
||||||
|
|
||||||
# Add video decoding timing for RLearN debugging
|
# Add video decoding timing for RLearN debugging
|
||||||
if getattr(cfg.policy, "type", None) == "rlearn":
|
if getattr(cfg.policy, "type", None) == "rlearn":
|
||||||
@@ -432,6 +490,12 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
if recent_decodes:
|
if recent_decodes:
|
||||||
avg_video_decode = sum(recent_decodes) / len(recent_decodes)
|
avg_video_decode = sum(recent_decodes) / len(recent_decodes)
|
||||||
print(f" └─ Video decoding: ~{avg_video_decode:.2f} ms/call (included in data loading)")
|
print(f" └─ Video decoding: ~{avg_video_decode:.2f} ms/call (included in data loading)")
|
||||||
|
|
||||||
|
# Show cache hit rate if available
|
||||||
|
if hasattr(ds, '_cached_decode_frames'):
|
||||||
|
cache_info = ds._cached_decode_frames.cache_info()
|
||||||
|
hit_rate = cache_info.hits / max(cache_info.hits + cache_info.misses, 1) * 100
|
||||||
|
print(f" └─ Cache hit rate: {hit_rate:.1f}% ({cache_info.hits}H/{cache_info.misses}M, size={cache_info.currsize})")
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,73 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
"""
|
||||||
|
Quick benchmark to test video decoding speed across different backends.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def test_video_backend(video_path, backend_name, num_frames=10, fps=30):
    """Time the decoding of the first frames of a video with one backend.

    Prints a one-line report and never raises: any failure (including a
    failed import of the decoder) is printed as an error line instead.

    Args:
        video_path: Path of the video file to decode.
        backend_name: Backend identifier forwarded to ``decode_video_frames``.
        num_frames: Number of leading frames to request.
        fps: Frame rate used to turn frame indices into timestamps; adjust it
            if the video was not recorded at 30 fps.

    Returns:
        ``(decode_time_seconds, frames_decoded)`` on success, or
        ``(float('inf'), 0)`` on failure.
    """
    try:
        from lerobot.datasets.video_utils import decode_video_frames

        # Create timestamps for first N frames
        timestamps = [i / fps for i in range(num_frames)]

        # Time the decoding
        start_time = time.perf_counter()
        frames = decode_video_frames(video_path, timestamps, tolerance_s=1e-4, backend=backend_name)
        decode_time = time.perf_counter() - start_time

        # NOTE(review): decode_video_frames is assumed to return
        # (num_frames, C, H, W) - confirm against video_utils. For that layout
        # the frame count is shape[0]; shape[1] would be the channel count.
        # A 5-D (batch, num_frames, C, H, W) result is still read at shape[1].
        frames_decoded = frames.shape[1] if frames.dim() == 5 else frames.shape[0]
        ms_per_frame = (decode_time * 1000) / max(frames_decoded, 1)

        print(f"✅ {backend_name:12} | {decode_time*1000:6.1f}ms total | {ms_per_frame:6.1f}ms/frame | {frames_decoded} frames")
        return decode_time, frames_decoded

    except Exception as e:
        # Best-effort benchmark: an unavailable backend is a result, not a crash.
        print(f"❌ {backend_name:12} | ERROR: {str(e)[:50]}...")
        return float('inf'), 0
|
||||||
|
|
||||||
|
def main():
    """Benchmark each decoding backend on one cached video and print a recommendation."""
    # Locally cached videos of the dataset under test.
    video_dir = Path.home() / ".cache/huggingface/lerobot/kenmacken/record-test-2/videos"
    video_files = list(video_dir.rglob("*.mp4"))
    if not video_files:
        print("❌ No video files found! Check the path.")
        return

    # Only the first video found is benchmarked.
    test_video = video_files[0]
    separator = "-" * 60
    print(f"Testing video: {test_video.name}")
    print(f"File size: {test_video.stat().st_size / 1024 / 1024:.1f} MB")
    print(separator)

    # backend name -> (decode time in seconds, frames decoded)
    results = {}
    for backend in ("torchcodec", "pyav", "video_reader"):
        decode_time, frames = test_video_backend(test_video, backend)
        results[backend] = (decode_time, frames)

    print(separator)
    print("RECOMMENDATION:")

    # A failed backend is reported with an infinite decode time; drop those.
    working = {name: outcome for name, outcome in results.items() if outcome[0] != float('inf')}
    if not working:
        print("❌ No backends worked!")
        return

    best_name, (best_time, _) = min(working.items(), key=lambda entry: entry[1][0])
    print(f"🚀 Use '{best_name}' - fastest backend!")
    print(f" Add to your config: video_backend: \"{best_name}\"")

    worst_time = max(elapsed for elapsed, _ in working.values())
    speedup = worst_time / best_time
    print(f" Speedup vs slowest: {speedup:.1f}x faster")


if __name__ == "__main__":
    main()
|
||||||
Reference in New Issue
Block a user