fix progress

This commit is contained in:
Pepijn
2025-08-31 17:13:49 +02:00
parent c9243c29b0
commit 086815edb7
3 changed files with 193 additions and 35 deletions
+1 -1
View File
@@ -10,7 +10,7 @@ Usage:
python src/lerobot/policies/rlearn/eval_script.py --model MODEL_NAME --dataset DATASET_REPO --episodes N python src/lerobot/policies/rlearn/eval_script.py --model MODEL_NAME --dataset DATASET_REPO --episodes N
Example: Example:
python src/lerobot/policies/rlearn/eval_script.py --model pepijn223/rlearn_mse5 --dataset pepijn223/phone_pipeline_pickup1 --episodes 2 python src/lerobot/policies/rlearn/eval_script.py --model pepijn223/rlearn_18 --dataset pepijn223/phone_pipeline_pickup1 --episodes 2
""" """
import argparse import argparse
+128 -34
View File
@@ -418,6 +418,7 @@ class RLearNPolicy(PreTrainedPolicy):
frames, frames,
rewind_prob=self.config.rewind_prob, rewind_prob=self.config.rewind_prob,
last3_prob=self.config.rewind_last3_prob, last3_prob=self.config.rewind_last3_prob,
anchor_stats=anchor_stats,
) )
# Apply stride and frame dropout # Apply stride and frame dropout
@@ -484,18 +485,22 @@ class RLearNPolicy(PreTrainedPolicy):
# IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window # IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window
loss_dict: dict[str, float] = {} loss_dict: dict[str, float] = {}
# Generate progress targets that span full 0-1 range # Generate progress targets based on episode-relative positions
if self.training and augmented_target is not None: if self.training and augmented_target is not None:
# Always create targets that span 0-1 across T_eff frames for better distribution # For rewind augmentation, the augmented_target already contains proper progress values
target = torch.linspace(0, 1, T_eff, device=device).unsqueeze(0).expand(B, -1) # But we need to handle potential stride/dropout
target = augmented_target[:, :T_eff] if augmented_target.shape[1] > T_eff else augmented_target
if target.shape[1] < T_eff:
# This shouldn't happen but handle it gracefully
target = torch.linspace(0, 1, T_eff, device=device).unsqueeze(0).expand(B, -1)
else: else:
# Use anchor-based window-relative progress # Use anchor-based episode-relative progress
if anchor_stats.get("fallback_used", False): if anchor_stats.get("fallback_used", False):
raise ValueError( raise ValueError(
"Anchor-based sampling failed. Ensure 'episode_index', 'frame_index' are in batch " "Anchor-based sampling failed. Ensure 'episode_index', 'frame_index' are in batch "
"and 'episode_data_index' is loaded from episodes.jsonl" "and 'episode_data_index' is loaded from episodes.jsonl"
) )
target = self._calculate_anchor_based_progress(T_eff) target = self._calculate_anchor_based_progress(T_eff, anchor_stats)
# During inference, we might not want to compute loss # During inference, we might not want to compute loss
if not self.training and target is None: if not self.training and target is None:
@@ -844,7 +849,7 @@ class RLearNPolicy(PreTrainedPolicy):
return ep, fr return ep, fr
def _sample_random_anchor_windows(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: def _sample_random_anchor_windows(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Sample random anchor windows for training.""" """Sample random anchor windows for training and compute episode-relative progress."""
# Extract episode and frame indices - required for proper anchor sampling # Extract episode and frame indices - required for proper anchor sampling
episode_indices, frame_indices = self._extract_episode_and_frame_indices(batch) episode_indices, frame_indices = self._extract_episode_and_frame_indices(batch)
@@ -865,6 +870,8 @@ class RLearNPolicy(PreTrainedPolicy):
# Sample random anchors and build windows # Sample random anchors and build windows
sampled_frames = [] sampled_frames = []
anchor_positions = [] anchor_positions = []
window_frame_indices = [] # Store actual frame indices for progress calculation
episode_lengths = [] # Store episode lengths for progress calculation
oob_count = 0 oob_count = 0
for b_idx in range(B): for b_idx in range(B):
@@ -874,6 +881,7 @@ class RLearNPolicy(PreTrainedPolicy):
ep_start = self.episode_data_index["from"][ep_idx].item() ep_start = self.episode_data_index["from"][ep_idx].item()
ep_end = self.episode_data_index["to"][ep_idx].item() ep_end = self.episode_data_index["to"][ep_idx].item()
ep_length = ep_end - ep_start ep_length = ep_end - ep_start
episode_lengths.append(ep_length)
# Choose random anchor - need at least T-1 frames before for [-15..0] window # Choose random anchor - need at least T-1 frames before for [-15..0] window
min_anchor = T - 1 min_anchor = T - 1
@@ -883,9 +891,11 @@ class RLearNPolicy(PreTrainedPolicy):
# Build window indices with reflection padding # Build window indices with reflection padding
window_indices = [] window_indices = []
frame_indices_for_progress = [] # Track actual frame positions for progress
had_oob = False had_oob = False
for delta in range(-(T-1), 1): # [-15, -14, ..., 0] for T=16 for delta in range(-(T-1), 1): # [-15, -14, ..., 0] for T=16
idx = anchor + delta idx = anchor + delta
actual_frame_idx = idx # Store the actual frame index before reflection
if idx < 0: if idx < 0:
idx = -idx # Reflect at start idx = -idx # Reflect at start
had_oob = True had_oob = True
@@ -893,6 +903,8 @@ class RLearNPolicy(PreTrainedPolicy):
idx = 2 * (ep_length - 1) - idx # Reflect at end idx = 2 * (ep_length - 1) - idx # Reflect at end
had_oob = True had_oob = True
window_indices.append(min(idx, available_T - 1)) window_indices.append(min(idx, available_T - 1))
# For reflected indices, use the reflected position for progress
frame_indices_for_progress.append(idx)
if had_oob: if had_oob:
oob_count += 1 oob_count += 1
@@ -900,6 +912,7 @@ class RLearNPolicy(PreTrainedPolicy):
# Extract frames # Extract frames
frame_tensors = [raw_frames[b_idx, idx] for idx in window_indices] frame_tensors = [raw_frames[b_idx, idx] for idx in window_indices]
sampled_frames.append(torch.stack(frame_tensors)) sampled_frames.append(torch.stack(frame_tensors))
window_frame_indices.append(frame_indices_for_progress)
frames = torch.stack(sampled_frames, dim=0) frames = torch.stack(sampled_frames, dim=0)
@@ -908,21 +921,54 @@ class RLearNPolicy(PreTrainedPolicy):
"anchor_std": float(torch.tensor(anchor_positions).float().std()), "anchor_std": float(torch.tensor(anchor_positions).float().std()),
"oob_fraction": float(oob_count) / B, "oob_fraction": float(oob_count) / B,
"padded_fraction": 0.0, # No padding with reflection approach "padded_fraction": 0.0, # No padding with reflection approach
"fallback_used": False "fallback_used": False,
"window_frame_indices": window_frame_indices, # Pass frame indices for progress calculation
"episode_lengths": episode_lengths # Pass episode lengths for progress calculation
} }
return frames, anchor_stats return frames, anchor_stats
def _calculate_anchor_based_progress(self, T_eff: int) -> Tensor: def _calculate_anchor_based_progress(self, T_eff: int, anchor_stats: dict) -> Tensor:
"""Generate window-relative progress (0 to 1 across actual frames used).""" """Generate episode-relative progress based on actual frame positions within episodes."""
device = next(self.parameters()).device device = next(self.parameters()).device
# Create progress that spans 0 to 1 across the T_eff frames we actually use
# This ensures we get samples at all progress levels including near 1.0 # Extract frame indices and episode lengths from anchor_stats
if T_eff == 1: window_frame_indices = anchor_stats.get("window_frame_indices")
progress = torch.tensor([0.5], device=device) # Single frame gets middle progress episode_lengths = anchor_stats.get("episode_lengths")
else:
progress = torch.linspace(0, 1, T_eff, device=device) # Full 0-1 range if window_frame_indices is None or episode_lengths is None:
return progress.unsqueeze(0) # (1, T_eff) - will broadcast to (B, T_eff) # Fallback to window-relative progress if episode info not available
# This should not happen in normal training
if T_eff == 1:
progress = torch.tensor([0.5], device=device)
else:
progress = torch.linspace(0, 1, T_eff, device=device)
return progress.unsqueeze(0)
B = len(window_frame_indices)
T = len(window_frame_indices[0]) # Original window size (16)
# Calculate episode-relative progress for each sample
all_progress = []
for b_idx in range(B):
frame_indices = window_frame_indices[b_idx]
ep_length = episode_lengths[b_idx]
# Calculate progress as frame_index / (episode_length - 1)
# This gives us progress from 0.0 to 1.0 across the episode
progress = torch.tensor([
frame_idx / max(ep_length - 1, 1) for frame_idx in frame_indices
], device=device, dtype=torch.float32)
# If we have stride/dropout (T_eff < T), subsample the progress values
if T_eff < T:
# Subsample evenly from the progress values
indices = torch.linspace(0, T - 1, T_eff, dtype=torch.long)
progress = progress[indices]
all_progress.append(progress)
return torch.stack(all_progress) # (B, T_eff)
@@ -1033,26 +1079,43 @@ def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None
return frames return frames
def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: float | None = None) -> tuple[Tensor, Tensor]: def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: float | None = None, anchor_stats: dict | None = None) -> tuple[Tensor, Tensor]:
"""Apply video rewinding augmentation without constant-value padding. """Apply video rewinding augmentation with episode-relative progress.
This version ensures the rewound sequence is exactly T frames without flat plateaus This version ensures the rewound sequence is exactly T frames and generates
that drag down the target mean. episode-relative progress labels based on actual frame positions.
Args: Args:
frames: Tensor of shape (B, T, C, H, W) frames: Tensor of shape (B, T, C, H, W)
rewind_prob: Probability of applying rewind augmentation to each video rewind_prob: Probability of applying rewind augmentation to each video
last3_prob: Probability of limiting rewind to last 3 frames last3_prob: Probability of limiting rewind to last 3 frames
anchor_stats: Dictionary containing window_frame_indices and episode_lengths for episode-relative progress
Returns: Returns:
Augmented frames and corresponding progress labels Augmented frames and corresponding episode-relative progress labels
""" """
B, T, C, H, W = frames.shape B, T, C, H, W = frames.shape
device = frames.device device = frames.device
# Create default progress labels - will be properly scaled after stride/dropout # Extract episode information if available
# Use frame indices that will give 0-1 range after subsampling window_frame_indices = anchor_stats.get("window_frame_indices") if anchor_stats else None
default_progress = torch.linspace(0, 1, T, device=device).unsqueeze(0).expand(B, -1) episode_lengths = anchor_stats.get("episode_lengths") if anchor_stats else None
# Create default progress labels based on episode-relative positions
if window_frame_indices and episode_lengths:
# Use actual episode-relative progress
default_progress = []
for b_idx in range(B):
frame_indices = window_frame_indices[b_idx]
ep_length = episode_lengths[b_idx]
progress = torch.tensor([
frame_idx / max(ep_length - 1, 1) for frame_idx in frame_indices
], device=device, dtype=torch.float32)
default_progress.append(progress)
default_progress = torch.stack(default_progress)
else:
# Fallback to window-relative progress
default_progress = torch.linspace(0, 1, T, device=device).unsqueeze(0).expand(B, -1)
# Apply rewind augmentation to each sample in batch independently # Apply rewind augmentation to each sample in batch independently
augmented_frames = [] augmented_frames = []
@@ -1095,11 +1158,27 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo
reverse_frames = frames[b, max(0, i - k):i].flip(dims=[0]) reverse_frames = frames[b, max(0, i - k):i].flip(dims=[0])
rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0) rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0)
# Create corresponding progress labels without constant padding # Create corresponding progress labels based on episode-relative positions
denom = max(T - 1, 1) if window_frame_indices and episode_lengths:
forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device) # Use episode-relative progress for rewind
reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k) / denom), k, device=device) frame_indices = window_frame_indices[b]
rewound_progress = torch.cat([forward_progress, reverse_progress]) ep_length = episode_lengths[b]
# Forward part: use actual frame indices
forward_progress = torch.tensor([
frame_indices[idx] / max(ep_length - 1, 1) for idx in range(i)
], device=device, dtype=torch.float32)
# Reverse part: use reversed frame indices
reverse_indices = list(range(max(0, i - k), i))[::-1]
reverse_progress = torch.tensor([
frame_indices[idx] / max(ep_length - 1, 1) for idx in reverse_indices
], device=device, dtype=torch.float32)
rewound_progress = torch.cat([forward_progress, reverse_progress])
else:
# Fallback to window-relative progress
denom = max(T - 1, 1)
forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device)
reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k) / denom), k, device=device)
rewound_progress = torch.cat([forward_progress, reverse_progress])
success = True success = True
break break
@@ -1114,11 +1193,26 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo
rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0) rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0)
if rewound_seq.shape[0] == T: if rewound_seq.shape[0] == T:
# Create progress labels # Create progress labels based on episode-relative positions
denom = max(T - 1, 1) if window_frame_indices and episode_lengths:
forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device) frame_indices = window_frame_indices[b]
reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k_extended) / denom), k_extended, device=device) ep_length = episode_lengths[b]
rewound_progress = torch.cat([forward_progress, reverse_progress]) # Forward part
forward_progress = torch.tensor([
frame_indices[idx] / max(ep_length - 1, 1) for idx in range(i)
], device=device, dtype=torch.float32)
# Extended reverse part
reverse_indices = list(range(max(0, i - k_extended), i))[::-1]
reverse_progress = torch.tensor([
frame_indices[idx] / max(ep_length - 1, 1) for idx in reverse_indices
], device=device, dtype=torch.float32)
rewound_progress = torch.cat([forward_progress, reverse_progress])
else:
# Fallback to window-relative progress
denom = max(T - 1, 1)
forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device)
reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k_extended) / denom), k_extended, device=device)
rewound_progress = torch.cat([forward_progress, reverse_progress])
success = True success = True
break break
+64
View File
@@ -0,0 +1,64 @@
#!/usr/bin/env python
"""Test script to verify episode-relative progress is working correctly."""
import torch
import numpy as np
from pathlib import Path
# Simulate what the dataset would provide
def create_test_batch(batch_size=2, episode_lengths=[100, 150]):
"""Create a test batch with episode information."""
batch = {}
# Simulate episode indices and frame indices
batch["episode_index"] = torch.tensor([0, 1]) # Two different episodes
batch["frame_index"] = torch.tensor([50, 75]) # Middle of each episode
# Simulate images (not important for this test)
batch["observation.images"] = torch.randn(batch_size, 16, 3, 224, 224)
# Simulate language
batch["observation.language"] = ["Pick up the blue block", "Pick up the red block"]
return batch
def test_progress_calculation():
"""Test that progress is calculated correctly."""
print("Testing Episode-Relative Progress Calculation")
print("=" * 60)
# Simulate episode_data_index
episode_data_index = {
"from": torch.tensor([0, 100, 250]), # Episode boundaries
"to": torch.tensor([100, 250, 400]) # Episode ends
}
# Test case 1: Sample from middle of episode
print("\nTest Case 1: Window from middle of 100-frame episode")
print("Anchor at frame 50, window frames [35-50]")
# Expected progress for frames 35-50 in a 100-frame episode
expected_progress = [35/99, 36/99, 37/99, 38/99, 39/99, 40/99, 41/99, 42/99,
43/99, 44/99, 45/99, 46/99, 47/99, 48/99, 49/99, 50/99]
print(f"Expected progress range: [{expected_progress[0]:.3f} to {expected_progress[-1]:.3f}]")
print(f"This is ~[0.354 to 0.505] - NOT [0.0 to 1.0]!")
# Test case 2: Sample from end of episode
print("\nTest Case 2: Window from end of 150-frame episode")
print("Anchor at frame 140, window frames [125-140]")
# Expected progress for frames 125-140 in a 150-frame episode
expected_progress_2 = [125/149, 126/149, 127/149, 128/149, 129/149, 130/149, 131/149, 132/149,
133/149, 134/149, 135/149, 136/149, 137/149, 138/149, 139/149, 140/149]
print(f"Expected progress range: [{expected_progress_2[0]:.3f} to {expected_progress_2[-1]:.3f}]")
print(f"This is ~[0.839 to 0.940] - NOT [0.0 to 1.0]!")
print("\n" + "=" * 60)
print("✅ Key Insight: Each 16-frame window should have progress values")
print(" that reflect its actual position within the episode,")
print(" NOT always [0.0 to 1.0]!")
if __name__ == "__main__":
test_progress_calculation()