mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-05 17:17:01 +00:00
add decode logging
This commit is contained in:
@@ -67,7 +67,6 @@ class RLearNConfig(PreTrainedConfig):
|
||||
# Performance optimizations
|
||||
use_amp: bool = True # Mixed precision training for speed boost
|
||||
compile_model: bool = True # torch.compile for additional speedup
|
||||
video_backend: str = "pyav" # Use PyAV for faster video decoding (vs torchcodec)
|
||||
|
||||
# ReWiND-specific parameters
|
||||
use_video_rewind: bool = True # Enable video rewinding augmentation
|
||||
|
||||
@@ -45,13 +45,25 @@ def _add_video_decoding_timing(dataset):
|
||||
}
|
||||
|
||||
def timed_query_videos(self, query_timestamps, ep_idx):
|
||||
# Debug: print what backend is being used
|
||||
if not hasattr(self, '_backend_logged'):
|
||||
print(f"DEBUG: Video backend in use: {getattr(self, 'video_backend', 'UNKNOWN')}")
|
||||
self._backend_logged = True
|
||||
|
||||
decode_start = time.perf_counter()
|
||||
result = original_query_videos(query_timestamps, ep_idx)
|
||||
decode_time = time.perf_counter() - decode_start
|
||||
|
||||
# Count the frames requested by this call, so the per-call average no longer comes out fractional (e.g. 0.5 frames)
|
||||
actual_frames = 0
|
||||
for key in query_timestamps:
|
||||
actual_frames += len(query_timestamps[key])
|
||||
|
||||
# Accumulate timing
|
||||
timing_stats = self._video_decode_timing
|
||||
timing_stats['decode_times'].append(decode_time * 1000) # Convert to ms
|
||||
timing_stats['actual_frame_counts'] = timing_stats.get('actual_frame_counts', [])
|
||||
timing_stats['actual_frame_counts'].append(actual_frames)
|
||||
|
||||
# Print averaged stats every minute
|
||||
current_time = time.perf_counter()
|
||||
@@ -59,8 +71,9 @@ def _add_video_decoding_timing(dataset):
|
||||
n_samples = len(timing_stats['decode_times'])
|
||||
if n_samples > 0:
|
||||
avg_decode_time = sum(timing_stats['decode_times']) / n_samples
|
||||
total_frames = sum(len(query_timestamps[key]) for key in query_timestamps)
|
||||
avg_frames_per_call = total_frames / n_samples if n_samples > 0 else 0
|
||||
# Use actual frame counts tracked per call
|
||||
actual_counts = timing_stats.get('actual_frame_counts', [])
|
||||
avg_frames_per_call = sum(actual_counts) / len(actual_counts) if actual_counts else 0
|
||||
|
||||
print(f"\nVideo Decoding Timing (last {n_samples} calls):")
|
||||
print(f" Avg decode time: {avg_decode_time:.2f} ms")
|
||||
@@ -70,6 +83,7 @@ def _add_video_decoding_timing(dataset):
|
||||
|
||||
# Reset stats
|
||||
timing_stats['decode_times'] = []
|
||||
timing_stats['actual_frame_counts'] = []
|
||||
timing_stats['last_print_time'] = current_time
|
||||
|
||||
return result
|
||||
@@ -295,11 +309,27 @@ def train(cfg: TrainPipelineConfig):
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
logging.info("Creating dataset")
|
||||
|
||||
# Force PyAV backend for RLearN (proven to be fastest)
|
||||
if getattr(cfg.policy, "type", None) == "rlearn":
|
||||
# Override video backend to use PyAV
|
||||
if hasattr(cfg.dataset, 'video_backend'):
|
||||
original_backend = cfg.dataset.video_backend
|
||||
cfg.dataset.video_backend = 'pyav'
|
||||
logging.info(f"RLearN: Forcing video_backend from '{original_backend}' to 'pyav' for better performance")
|
||||
else:
|
||||
cfg.dataset.video_backend = 'pyav'
|
||||
logging.info("RLearN: Setting video_backend to 'pyav' for better performance")
|
||||
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
# Add video decoding timing for RLearN debugging
|
||||
# Add video decoding timing and caching for RLearN debugging
|
||||
if getattr(cfg.policy, "type", None) == "rlearn":
|
||||
_add_video_decoding_timing(dataset)
|
||||
# Add frame caching for small datasets
|
||||
if hasattr(dataset, 'num_frames') and dataset.num_frames < 1000:
|
||||
_add_video_frame_caching(dataset, cache_size=500)
|
||||
logging.info(f"RLearN: Added frame caching for {dataset.num_frames} frame dataset")
|
||||
|
||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||
|
||||
Reference in New Issue
Block a user