From 086815edb7a1d21b3c823072ecf2377c5fc0d16a Mon Sep 17 00:00:00 2001 From: Pepijn Date: Sun, 31 Aug 2025 17:13:49 +0200 Subject: [PATCH] fix progress --- src/lerobot/policies/rlearn/eval_script.py | 2 +- .../policies/rlearn/modeling_rlearn.py | 162 ++++++++++++++---- test_episode_progress.py | 64 +++++++ 3 files changed, 193 insertions(+), 35 deletions(-) create mode 100644 test_episode_progress.py diff --git a/src/lerobot/policies/rlearn/eval_script.py b/src/lerobot/policies/rlearn/eval_script.py index 69cfc1b1f..a3bd0a8cd 100644 --- a/src/lerobot/policies/rlearn/eval_script.py +++ b/src/lerobot/policies/rlearn/eval_script.py @@ -10,7 +10,7 @@ Usage: python src/lerobot/policies/rlearn/eval_script.py --model MODEL_NAME --dataset DATASET_REPO --episodes N Example: - python src/lerobot/policies/rlearn/eval_script.py --model pepijn223/rlearn_mse5 --dataset pepijn223/phone_pipeline_pickup1 --episodes 2 + python src/lerobot/policies/rlearn/eval_script.py --model pepijn223/rlearn_18 --dataset pepijn223/phone_pipeline_pickup1 --episodes 2 """ import argparse diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index ca3053dae..315df1251 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -418,6 +418,7 @@ class RLearNPolicy(PreTrainedPolicy): frames, rewind_prob=self.config.rewind_prob, last3_prob=self.config.rewind_last3_prob, + anchor_stats=anchor_stats, ) # Apply stride and frame dropout @@ -484,18 +485,22 @@ class RLearNPolicy(PreTrainedPolicy): # IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window loss_dict: dict[str, float] = {} - # Generate progress targets that span full 0-1 range + # Generate progress targets based on episode-relative positions if self.training and augmented_target is not None: - # Always create targets that span 0-1 across T_eff frames for better distribution - target = torch.linspace(0, 1, T_eff, device=device).unsqueeze(0).expand(B, -1) + # For rewind augmentation, the augmented_target already contains proper progress values + # But we need to handle potential stride/dropout + target = augmented_target[:, :T_eff] if augmented_target.shape[1] > T_eff else augmented_target + if target.shape[1] < T_eff: + # This shouldn't happen but handle it gracefully + target = torch.linspace(0, 1, T_eff, device=device).unsqueeze(0).expand(B, -1) else: - # Use anchor-based window-relative progress + # Use anchor-based episode-relative progress if anchor_stats.get("fallback_used", False): raise ValueError( "Anchor-based sampling failed. Ensure 'episode_index', 'frame_index' are in batch " "and 'episode_data_index' is loaded from episodes.jsonl" ) - target = self._calculate_anchor_based_progress(T_eff) + target = self._calculate_anchor_based_progress(T_eff, anchor_stats) # During inference, we might not want to compute loss if not self.training and target is None: @@ -844,7 +849,7 @@ class RLearNPolicy(PreTrainedPolicy): return ep, fr def _sample_random_anchor_windows(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: - """Sample random anchor windows for training.""" + """Sample random anchor windows for training and compute episode-relative progress.""" # Extract episode and frame indices - required for proper anchor sampling episode_indices, frame_indices = self._extract_episode_and_frame_indices(batch) @@ -865,6 +870,8 @@ class RLearNPolicy(PreTrainedPolicy): # Sample random anchors and build windows sampled_frames = [] anchor_positions = [] + window_frame_indices = [] # Store actual frame indices for progress calculation + episode_lengths = [] # Store episode lengths for progress calculation oob_count = 0 for b_idx in range(B): @@ -874,6 +881,7 @@ class RLearNPolicy(PreTrainedPolicy): ep_start = self.episode_data_index["from"][ep_idx].item() ep_end = self.episode_data_index["to"][ep_idx].item() ep_length = ep_end - ep_start + episode_lengths.append(ep_length) # Choose random anchor - need at least T-1 frames before for [-15..0] window min_anchor = T - 1 @@ -883,9 +891,11 @@ class RLearNPolicy(PreTrainedPolicy): # Build window indices with reflection padding window_indices = [] + frame_indices_for_progress = [] # Track actual frame positions for progress had_oob = False for delta in range(-(T-1), 1): # [-15, -14, ..., 0] for T=16 idx = anchor + delta + actual_frame_idx = idx # Store the actual frame index before reflection if idx < 0: idx = -idx # Reflect at start had_oob = True @@ -893,6 +903,8 @@ class RLearNPolicy(PreTrainedPolicy): idx = 2 * (ep_length - 1) - idx # Reflect at end had_oob = True window_indices.append(min(idx, available_T - 1)) + # For reflected indices, use the reflected position for progress + frame_indices_for_progress.append(idx) if had_oob: oob_count += 1 @@ -900,6 +912,7 @@ class RLearNPolicy(PreTrainedPolicy): # Extract frames frame_tensors = [raw_frames[b_idx, idx] for idx in window_indices] sampled_frames.append(torch.stack(frame_tensors)) + window_frame_indices.append(frame_indices_for_progress) frames = torch.stack(sampled_frames, dim=0) @@ -908,21 +921,54 @@ class RLearNPolicy(PreTrainedPolicy): "anchor_std": float(torch.tensor(anchor_positions).float().std()), "oob_fraction": float(oob_count) / B, "padded_fraction": 0.0, # No padding with reflection approach - "fallback_used": False + "fallback_used": False, + "window_frame_indices": window_frame_indices, # Pass frame indices for progress calculation + "episode_lengths": episode_lengths # Pass episode lengths for progress calculation } return frames, anchor_stats - def _calculate_anchor_based_progress(self, T_eff: int) -> Tensor: - """Generate window-relative progress (0 to 1 across actual frames used).""" + def _calculate_anchor_based_progress(self, T_eff: int, anchor_stats: dict) -> Tensor: + """Generate episode-relative progress based on actual frame positions within episodes.""" device = next(self.parameters()).device - # Create progress that spans 0 to 1 across the T_eff frames we actually use - # This ensures we get samples at all progress levels including near 1.0 - if T_eff == 1: - progress = torch.tensor([0.5], device=device) # Single frame gets middle progress - else: - progress = torch.linspace(0, 1, T_eff, device=device) # Full 0-1 range - return progress.unsqueeze(0) # (1, T_eff) - will broadcast to (B, T_eff) + + # Extract frame indices and episode lengths from anchor_stats + window_frame_indices = anchor_stats.get("window_frame_indices") + episode_lengths = anchor_stats.get("episode_lengths") + + if window_frame_indices is None or episode_lengths is None: + # Fallback to window-relative progress if episode info not available + # This should not happen in normal training + if T_eff == 1: + progress = torch.tensor([0.5], device=device) + else: + progress = torch.linspace(0, 1, T_eff, device=device) + return progress.unsqueeze(0) + + B = len(window_frame_indices) + T = len(window_frame_indices[0]) # Original window size (16) + + # Calculate episode-relative progress for each sample + all_progress = [] + for b_idx in range(B): + frame_indices = window_frame_indices[b_idx] + ep_length = episode_lengths[b_idx] + + # Calculate progress as frame_index / (episode_length - 1) + # This gives us progress from 0.0 to 1.0 across the episode + progress = torch.tensor([ + frame_idx / max(ep_length - 1, 1) for frame_idx in frame_indices + ], device=device, dtype=torch.float32) + + # If we have stride/dropout (T_eff < T), subsample the progress values + if T_eff < T: + # Subsample evenly from the progress values + indices = torch.linspace(0, T - 1, T_eff, dtype=torch.long) + progress = progress[indices] + + all_progress.append(progress) + + return torch.stack(all_progress) # (B, T_eff) @@ -1033,26 +1079,43 @@ def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None return frames -def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: float | None = None) -> tuple[Tensor, Tensor]: - """Apply video rewinding augmentation without constant-value padding. +def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: float | None = None, anchor_stats: dict | None = None) -> tuple[Tensor, Tensor]: + """Apply video rewinding augmentation with episode-relative progress. - This version ensures the rewound sequence is exactly T frames without flat plateaus - that drag down the target mean. + This version ensures the rewound sequence is exactly T frames and generates + episode-relative progress labels based on actual frame positions. Args: frames: Tensor of shape (B, T, C, H, W) rewind_prob: Probability of applying rewind augmentation to each video last3_prob: Probability of limiting rewind to last 3 frames + anchor_stats: Dictionary containing window_frame_indices and episode_lengths for episode-relative progress Returns: - Augmented frames and corresponding progress labels + Augmented frames and corresponding episode-relative progress labels """ B, T, C, H, W = frames.shape device = frames.device - # Create default progress labels - will be properly scaled after stride/dropout - # Use frame indices that will give 0-1 range after subsampling - default_progress = torch.linspace(0, 1, T, device=device).unsqueeze(0).expand(B, -1) + # Extract episode information if available + window_frame_indices = anchor_stats.get("window_frame_indices") if anchor_stats else None + episode_lengths = anchor_stats.get("episode_lengths") if anchor_stats else None + + # Create default progress labels based on episode-relative positions + if window_frame_indices and episode_lengths: + # Use actual episode-relative progress + default_progress = [] + for b_idx in range(B): + frame_indices = window_frame_indices[b_idx] + ep_length = episode_lengths[b_idx] + progress = torch.tensor([ + frame_idx / max(ep_length - 1, 1) for frame_idx in frame_indices + ], device=device, dtype=torch.float32) + default_progress.append(progress) + default_progress = torch.stack(default_progress) + else: + # Fallback to window-relative progress + default_progress = torch.linspace(0, 1, T, device=device).unsqueeze(0).expand(B, -1) # Apply rewind augmentation to each sample in batch independently augmented_frames = [] @@ -1095,11 +1158,27 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo reverse_frames = frames[b, max(0, i - k):i].flip(dims=[0]) rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0) - # Create corresponding progress labels without constant padding - denom = max(T - 1, 1) - forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device) - reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k) / denom), k, device=device) - rewound_progress = torch.cat([forward_progress, reverse_progress]) + # Create corresponding progress labels based on episode-relative positions + if window_frame_indices and episode_lengths: + # Use episode-relative progress for rewind + frame_indices = window_frame_indices[b] + ep_length = episode_lengths[b] + # Forward part: use actual frame indices + forward_progress = torch.tensor([ + frame_indices[idx] / max(ep_length - 1, 1) for idx in range(i) + ], device=device, dtype=torch.float32) + # Reverse part: use reversed frame indices + reverse_indices = list(range(max(0, i - k), i))[::-1] + reverse_progress = torch.tensor([ + frame_indices[idx] / max(ep_length - 1, 1) for idx in reverse_indices + ], device=device, dtype=torch.float32) + rewound_progress = torch.cat([forward_progress, reverse_progress]) + else: + # Fallback to window-relative progress + denom = max(T - 1, 1) + forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device) + reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k) / denom), k, device=device) + rewound_progress = torch.cat([forward_progress, reverse_progress]) success = True break @@ -1114,11 +1193,26 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0) if rewound_seq.shape[0] == T: - # Create progress labels - denom = max(T - 1, 1) - forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device) - reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k_extended) / denom), k_extended, device=device) - rewound_progress = torch.cat([forward_progress, reverse_progress]) + # Create progress labels based on episode-relative positions + if window_frame_indices and episode_lengths: + frame_indices = window_frame_indices[b] + ep_length = episode_lengths[b] + # Forward part + forward_progress = torch.tensor([ + frame_indices[idx] / max(ep_length - 1, 1) for idx in range(i) + ], device=device, dtype=torch.float32) + # Extended reverse part + reverse_indices = list(range(max(0, i - k_extended), i))[::-1] + reverse_progress = torch.tensor([ + frame_indices[idx] / max(ep_length - 1, 1) for idx in reverse_indices + ], device=device, dtype=torch.float32) + rewound_progress = torch.cat([forward_progress, reverse_progress]) + else: + # Fallback to window-relative progress + denom = max(T - 1, 1) + forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device) + reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k_extended) / denom), k_extended, device=device) + rewound_progress = torch.cat([forward_progress, reverse_progress]) success = True break diff --git a/test_episode_progress.py b/test_episode_progress.py new file mode 100644 index 000000000..40ddf25d7 --- /dev/null +++ b/test_episode_progress.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python +"""Test script to verify episode-relative progress is working correctly.""" + +import torch +import numpy as np +from pathlib import Path + +# Simulate what the dataset would provide +def create_test_batch(batch_size=2, episode_lengths=[100, 150]): + """Create a test batch with episode information.""" + batch = {} + + # Simulate episode indices and frame indices + batch["episode_index"] = torch.tensor([0, 1]) # Two different episodes + batch["frame_index"] = torch.tensor([50, 75]) # Middle of each episode + + # Simulate images (not important for this test) + batch["observation.images"] = torch.randn(batch_size, 16, 3, 224, 224) + + # Simulate language + batch["observation.language"] = ["Pick up the blue block", "Pick up the red block"] + + return batch + +def test_progress_calculation(): + """Test that progress is calculated correctly.""" + print("Testing Episode-Relative Progress Calculation") + print("=" * 60) + + # Simulate episode_data_index + episode_data_index = { + "from": torch.tensor([0, 100, 250]), # Episode boundaries + "to": torch.tensor([100, 250, 400]) # Episode ends + } + + # Test case 1: Sample from middle of episode + print("\nTest Case 1: Window from middle of 100-frame episode") + print("Anchor at frame 50, window frames [35-50]") + + # Expected progress for frames 35-50 in a 100-frame episode + expected_progress = [35/99, 36/99, 37/99, 38/99, 39/99, 40/99, 41/99, 42/99, + 43/99, 44/99, 45/99, 46/99, 47/99, 48/99, 49/99, 50/99] + + print(f"Expected progress range: [{expected_progress[0]:.3f} to {expected_progress[-1]:.3f}]") + print(f"This is ~[0.354 to 0.505] - NOT [0.0 to 1.0]!") + + # Test case 2: Sample from end of episode + print("\nTest Case 2: Window from end of 150-frame episode") + print("Anchor at frame 140, window frames [125-140]") + + # Expected progress for frames 125-140 in a 150-frame episode + expected_progress_2 = [125/149, 126/149, 127/149, 128/149, 129/149, 130/149, 131/149, 132/149, + 133/149, 134/149, 135/149, 136/149, 137/149, 138/149, 139/149, 140/149] + + print(f"Expected progress range: [{expected_progress_2[0]:.3f} to {expected_progress_2[-1]:.3f}]") + print(f"This is ~[0.839 to 0.940] - NOT [0.0 to 1.0]!") + + print("\n" + "=" * 60) + print("✅ Key Insight: Each 16-frame window should have progress values") + print(" that reflect its actual position within the episode,") + print(" NOT always [0.0 to 1.0]!") + +if __name__ == "__main__": + test_progress_calculation()