add decode logging

This commit is contained in:
Pepijn
2025-08-30 16:16:08 +02:00
parent b1ff7132c1
commit 1234e71cfb
2 changed files with 33 additions and 4 deletions
@@ -67,7 +67,6 @@ class RLearNConfig(PreTrainedConfig):
# Performance optimizations
use_amp: bool = True # Mixed precision training for speed boost
compile_model: bool = True # torch.compile for additional speedup
video_backend: str = "pyav" # Use PyAV for faster video decoding (vs torchcodec)
# ReWiND-specific parameters
use_video_rewind: bool = True # Enable video rewinding augmentation
+33 -3
View File
@@ -45,13 +45,25 @@ def _add_video_decoding_timing(dataset):
}
def timed_query_videos(self, query_timestamps, ep_idx):
# Debug: print what backend is being used
if not hasattr(self, '_backend_logged'):
print(f"DEBUG: Video backend in use: {getattr(self, 'video_backend', 'UNKNOWN')}")
self._backend_logged = True
decode_start = time.perf_counter()
result = original_query_videos(query_timestamps, ep_idx)
decode_time = time.perf_counter() - decode_start
# Debug problematic 0.5 frames issue
actual_frames = 0
for key in query_timestamps:
actual_frames += len(query_timestamps[key])
# Accumulate timing
timing_stats = self._video_decode_timing
timing_stats['decode_times'].append(decode_time * 1000) # Convert to ms
timing_stats['actual_frame_counts'] = timing_stats.get('actual_frame_counts', [])
timing_stats['actual_frame_counts'].append(actual_frames)
# Print averaged stats every minute
current_time = time.perf_counter()
@@ -59,8 +71,9 @@ def _add_video_decoding_timing(dataset):
n_samples = len(timing_stats['decode_times'])
if n_samples > 0:
avg_decode_time = sum(timing_stats['decode_times']) / n_samples
total_frames = sum(len(query_timestamps[key]) for key in query_timestamps)
avg_frames_per_call = total_frames / n_samples if n_samples > 0 else 0
# Use actual frame counts tracked per call
actual_counts = timing_stats.get('actual_frame_counts', [])
avg_frames_per_call = sum(actual_counts) / len(actual_counts) if actual_counts else 0
print(f"\nVideo Decoding Timing (last {n_samples} calls):")
print(f" Avg decode time: {avg_decode_time:.2f} ms")
@@ -70,6 +83,7 @@ def _add_video_decoding_timing(dataset):
# Reset stats
timing_stats['decode_times'] = []
timing_stats['actual_frame_counts'] = []
timing_stats['last_print_time'] = current_time
return result
@@ -295,11 +309,27 @@ def train(cfg: TrainPipelineConfig):
torch.backends.cuda.matmul.allow_tf32 = True
logging.info("Creating dataset")
# Force PyAV backend for RLearN (proven to be fastest)
if getattr(cfg.policy, "type", None) == "rlearn":
# Override video backend to use PyAV
if hasattr(cfg.dataset, 'video_backend'):
original_backend = cfg.dataset.video_backend
cfg.dataset.video_backend = 'pyav'
logging.info(f"RLearN: Forcing video_backend from '{original_backend}' to 'pyav' for better performance")
else:
cfg.dataset.video_backend = 'pyav'
logging.info("RLearN: Setting video_backend to 'pyav' for better performance")
dataset = make_dataset(cfg)
# Add video decoding timing for RLearN debugging
# Add video decoding timing and caching for RLearN debugging
if getattr(cfg.policy, "type", None) == "rlearn":
_add_video_decoding_timing(dataset)
# Add frame caching for small datasets
if hasattr(dataset, 'num_frames') and dataset.num_frames < 1000:
_add_video_frame_caching(dataset, cache_size=500)
logging.info(f"RLearN: Added frame caching for {dataset.num_frames} frame dataset")
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py,