add decode logging

This commit is contained in:
Pepijn
2025-08-30 15:52:24 +02:00
parent 0b5da92a58
commit aed90c8042
+78
View File
@@ -24,6 +24,68 @@ import torch
# Fix tokenizer parallelism conflicts with multiprocessing
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def _add_video_decoding_timing(dataset):
    """Add timing instrumentation to video decoding for debugging.

    Replaces the ``_query_videos`` method of ``dataset`` (or, for a
    ``MultiLeRobotDataset``, of every entry in ``dataset._datasets``) with a
    wrapper that times each call and prints averaged statistics once at least
    60 seconds have passed since the previous report. The statistics are
    stored on each instrumented dataset in ``_video_decode_timing``:

    - ``decode_times``: per-call decode durations in milliseconds
    - ``frame_counts``: per-call number of requested timestamps
    - ``last_print_time``: ``time.perf_counter()`` value of the last report

    Args:
        dataset: A ``LeRobotDataset`` or ``MultiLeRobotDataset``. Any other
            type is left untouched and a warning is printed.
    """
    import types

    from lerobot.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset

    def instrument_dataset(ds):
        """Wrap ``ds._query_videos`` with a timing wrapper (no-op if absent)."""
        if not hasattr(ds, '_query_videos'):
            return

        # Already instrumented: wrapping again would nest wrappers and record
        # every call twice. Bound methods forward attribute reads to the
        # underlying function, so the marker set below is visible here.
        if getattr(ds._query_videos, '_is_video_decode_timed', False):
            return

        # Keep the original bound method so the wrapper can delegate to it.
        original_query_videos = ds._query_videos

        # Initialize timing stats
        if not hasattr(ds, '_video_decode_timing'):
            ds._video_decode_timing = {
                'decode_times': [],
                'last_print_time': time.perf_counter(),
            }
        # Tolerate a pre-existing stats dict that lacks the per-call counts.
        ds._video_decode_timing.setdefault('frame_counts', [])

        def timed_query_videos(self, query_timestamps, ep_idx):
            decode_start = time.perf_counter()
            # `original_query_videos` is already bound, so `self` is not passed.
            result = original_query_videos(query_timestamps, ep_idx)
            current_time = time.perf_counter()

            # Accumulate one sample per call; the frame count must be recorded
            # per call so the averages cover the whole reporting window rather
            # than only the call that happens to trigger the report.
            timing_stats = self._video_decode_timing
            timing_stats['decode_times'].append((current_time - decode_start) * 1000)  # ms
            timing_stats['frame_counts'].append(
                sum(len(timestamps) for timestamps in query_timestamps.values())
            )

            # Print averaged stats every minute
            if current_time - timing_stats['last_print_time'] >= 60.0:
                decode_times = timing_stats['decode_times']
                n_samples = len(decode_times)  # >= 1: a sample was just appended
                avg_decode_time = sum(decode_times) / n_samples
                avg_frames_per_call = sum(timing_stats['frame_counts']) / n_samples

                print(f"\nVideo Decoding Timing (last {n_samples} calls):")
                print(f" Avg decode time: {avg_decode_time:.2f} ms")
                print(f" Avg frames/call: {avg_frames_per_call:.1f}")
                print(f" Time per frame: {avg_decode_time/max(avg_frames_per_call, 1):.2f} ms/frame")
                print("-" * 50)

                # Reset stats for the next reporting window
                timing_stats['decode_times'] = []
                timing_stats['frame_counts'] = []
                timing_stats['last_print_time'] = current_time

            return result

        # Marker consulted by the idempotence guard above.
        timed_query_videos._is_video_decode_timed = True

        # Bind the wrapper to the instance so it receives `ds` as `self`.
        ds._query_videos = types.MethodType(timed_query_videos, ds)

    # Handle both single and multi datasets
    if isinstance(dataset, MultiLeRobotDataset):
        for ds in dataset._datasets:
            instrument_dataset(ds)
    elif isinstance(dataset, LeRobotDataset):
        instrument_dataset(dataset)
    else:
        print(f"Warning: Unknown dataset type {type(dataset)}, skipping video timing instrumentation")
from termcolor import colored
from torch.amp import GradScaler
from torch.optim import Optimizer
@@ -182,6 +244,10 @@ def train(cfg: TrainPipelineConfig):
logging.info("Creating dataset")
dataset = make_dataset(cfg)
# Add video decoding timing for RLearN debugging
if getattr(cfg.policy, "type", None) == "rlearn":
_add_video_decoding_timing(dataset)
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py,
@@ -357,6 +423,18 @@ def train(cfg: TrainPipelineConfig):
print(f" Data loading: {avg_data_loading:.2f} ms")
print(f" Preprocessing: {avg_preprocessing:.2f} ms")
print(f" Total data pipeline: {avg_data_loading + avg_preprocessing:.2f} ms")
# Show video decoding breakdown if available
try:
ds = dataset._datasets[0] if hasattr(dataset, '_datasets') else dataset
if hasattr(ds, '_video_decode_timing'):
recent_decodes = ds._video_decode_timing.get('decode_times', [])
if recent_decodes:
avg_video_decode = sum(recent_decodes) / len(recent_decodes)
print(f" └─ Video decoding: ~{avg_video_decode:.2f} ms/call (included in data loading)")
except Exception:
pass
print("-" * 50)
# Reset stats for next minute