mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
fix progress
This commit is contained in:
@@ -10,7 +10,7 @@ Usage:
|
||||
python src/lerobot/policies/rlearn/eval_script.py --model MODEL_NAME --dataset DATASET_REPO --episodes N
|
||||
|
||||
Example:
|
||||
python src/lerobot/policies/rlearn/eval_script.py --model pepijn223/rlearn_mse5 --dataset pepijn223/phone_pipeline_pickup1 --episodes 2
|
||||
python src/lerobot/policies/rlearn/eval_script.py --model pepijn223/rlearn_18 --dataset pepijn223/phone_pipeline_pickup1 --episodes 2
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
@@ -418,6 +418,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
frames,
|
||||
rewind_prob=self.config.rewind_prob,
|
||||
last3_prob=self.config.rewind_last3_prob,
|
||||
anchor_stats=anchor_stats,
|
||||
)
|
||||
|
||||
# Apply stride and frame dropout
|
||||
@@ -484,18 +485,22 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window
|
||||
loss_dict: dict[str, float] = {}
|
||||
|
||||
# Generate progress targets that span full 0-1 range
|
||||
# Generate progress targets based on episode-relative positions
|
||||
if self.training and augmented_target is not None:
|
||||
# Always create targets that span 0-1 across T_eff frames for better distribution
|
||||
target = torch.linspace(0, 1, T_eff, device=device).unsqueeze(0).expand(B, -1)
|
||||
# For rewind augmentation, the augmented_target already contains proper progress values
|
||||
# But we need to handle potential stride/dropout
|
||||
target = augmented_target[:, :T_eff] if augmented_target.shape[1] > T_eff else augmented_target
|
||||
if target.shape[1] < T_eff:
|
||||
# This shouldn't happen but handle it gracefully
|
||||
target = torch.linspace(0, 1, T_eff, device=device).unsqueeze(0).expand(B, -1)
|
||||
else:
|
||||
# Use anchor-based window-relative progress
|
||||
# Use anchor-based episode-relative progress
|
||||
if anchor_stats.get("fallback_used", False):
|
||||
raise ValueError(
|
||||
"Anchor-based sampling failed. Ensure 'episode_index', 'frame_index' are in batch "
|
||||
"and 'episode_data_index' is loaded from episodes.jsonl"
|
||||
)
|
||||
target = self._calculate_anchor_based_progress(T_eff)
|
||||
target = self._calculate_anchor_based_progress(T_eff, anchor_stats)
|
||||
|
||||
# During inference, we might not want to compute loss
|
||||
if not self.training and target is None:
|
||||
@@ -844,7 +849,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
return ep, fr
|
||||
|
||||
def _sample_random_anchor_windows(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Sample random anchor windows for training."""
|
||||
"""Sample random anchor windows for training and compute episode-relative progress."""
|
||||
# Extract episode and frame indices - required for proper anchor sampling
|
||||
episode_indices, frame_indices = self._extract_episode_and_frame_indices(batch)
|
||||
|
||||
@@ -865,6 +870,8 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# Sample random anchors and build windows
|
||||
sampled_frames = []
|
||||
anchor_positions = []
|
||||
window_frame_indices = [] # Store actual frame indices for progress calculation
|
||||
episode_lengths = [] # Store episode lengths for progress calculation
|
||||
oob_count = 0
|
||||
|
||||
for b_idx in range(B):
|
||||
@@ -874,6 +881,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
ep_start = self.episode_data_index["from"][ep_idx].item()
|
||||
ep_end = self.episode_data_index["to"][ep_idx].item()
|
||||
ep_length = ep_end - ep_start
|
||||
episode_lengths.append(ep_length)
|
||||
|
||||
# Choose random anchor - need at least T-1 frames before for [-15..0] window
|
||||
min_anchor = T - 1
|
||||
@@ -883,9 +891,11 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
# Build window indices with reflection padding
|
||||
window_indices = []
|
||||
frame_indices_for_progress = [] # Track actual frame positions for progress
|
||||
had_oob = False
|
||||
for delta in range(-(T-1), 1): # [-15, -14, ..., 0] for T=16
|
||||
idx = anchor + delta
|
||||
actual_frame_idx = idx # Store the actual frame index before reflection
|
||||
if idx < 0:
|
||||
idx = -idx # Reflect at start
|
||||
had_oob = True
|
||||
@@ -893,6 +903,8 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
idx = 2 * (ep_length - 1) - idx # Reflect at end
|
||||
had_oob = True
|
||||
window_indices.append(min(idx, available_T - 1))
|
||||
# For reflected indices, use the reflected position for progress
|
||||
frame_indices_for_progress.append(idx)
|
||||
|
||||
if had_oob:
|
||||
oob_count += 1
|
||||
@@ -900,6 +912,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# Extract frames
|
||||
frame_tensors = [raw_frames[b_idx, idx] for idx in window_indices]
|
||||
sampled_frames.append(torch.stack(frame_tensors))
|
||||
window_frame_indices.append(frame_indices_for_progress)
|
||||
|
||||
frames = torch.stack(sampled_frames, dim=0)
|
||||
|
||||
@@ -908,21 +921,54 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
"anchor_std": float(torch.tensor(anchor_positions).float().std()),
|
||||
"oob_fraction": float(oob_count) / B,
|
||||
"padded_fraction": 0.0, # No padding with reflection approach
|
||||
"fallback_used": False
|
||||
"fallback_used": False,
|
||||
"window_frame_indices": window_frame_indices, # Pass frame indices for progress calculation
|
||||
"episode_lengths": episode_lengths # Pass episode lengths for progress calculation
|
||||
}
|
||||
|
||||
return frames, anchor_stats
|
||||
|
||||
def _calculate_anchor_based_progress(self, T_eff: int) -> Tensor:
|
||||
"""Generate window-relative progress (0 to 1 across actual frames used)."""
|
||||
def _calculate_anchor_based_progress(self, T_eff: int, anchor_stats: dict) -> Tensor:
|
||||
"""Generate episode-relative progress based on actual frame positions within episodes."""
|
||||
device = next(self.parameters()).device
|
||||
# Create progress that spans 0 to 1 across the T_eff frames we actually use
|
||||
# This ensures we get samples at all progress levels including near 1.0
|
||||
if T_eff == 1:
|
||||
progress = torch.tensor([0.5], device=device) # Single frame gets middle progress
|
||||
else:
|
||||
progress = torch.linspace(0, 1, T_eff, device=device) # Full 0-1 range
|
||||
return progress.unsqueeze(0) # (1, T_eff) - will broadcast to (B, T_eff)
|
||||
|
||||
# Extract frame indices and episode lengths from anchor_stats
|
||||
window_frame_indices = anchor_stats.get("window_frame_indices")
|
||||
episode_lengths = anchor_stats.get("episode_lengths")
|
||||
|
||||
if window_frame_indices is None or episode_lengths is None:
|
||||
# Fallback to window-relative progress if episode info not available
|
||||
# This should not happen in normal training
|
||||
if T_eff == 1:
|
||||
progress = torch.tensor([0.5], device=device)
|
||||
else:
|
||||
progress = torch.linspace(0, 1, T_eff, device=device)
|
||||
return progress.unsqueeze(0)
|
||||
|
||||
B = len(window_frame_indices)
|
||||
T = len(window_frame_indices[0]) # Original window size (16)
|
||||
|
||||
# Calculate episode-relative progress for each sample
|
||||
all_progress = []
|
||||
for b_idx in range(B):
|
||||
frame_indices = window_frame_indices[b_idx]
|
||||
ep_length = episode_lengths[b_idx]
|
||||
|
||||
# Calculate progress as frame_index / (episode_length - 1)
|
||||
# This gives us progress from 0.0 to 1.0 across the episode
|
||||
progress = torch.tensor([
|
||||
frame_idx / max(ep_length - 1, 1) for frame_idx in frame_indices
|
||||
], device=device, dtype=torch.float32)
|
||||
|
||||
# If we have stride/dropout (T_eff < T), subsample the progress values
|
||||
if T_eff < T:
|
||||
# Subsample evenly from the progress values
|
||||
indices = torch.linspace(0, T - 1, T_eff, dtype=torch.long)
|
||||
progress = progress[indices]
|
||||
|
||||
all_progress.append(progress)
|
||||
|
||||
return torch.stack(all_progress) # (B, T_eff)
|
||||
|
||||
|
||||
|
||||
@@ -1033,26 +1079,43 @@ def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None
|
||||
return frames
|
||||
|
||||
|
||||
def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: float | None = None) -> tuple[Tensor, Tensor]:
|
||||
"""Apply video rewinding augmentation without constant-value padding.
|
||||
def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: float | None = None, anchor_stats: dict | None = None) -> tuple[Tensor, Tensor]:
|
||||
"""Apply video rewinding augmentation with episode-relative progress.
|
||||
|
||||
This version ensures the rewound sequence is exactly T frames without flat plateaus
|
||||
that drag down the target mean.
|
||||
This version ensures the rewound sequence is exactly T frames and generates
|
||||
episode-relative progress labels based on actual frame positions.
|
||||
|
||||
Args:
|
||||
frames: Tensor of shape (B, T, C, H, W)
|
||||
rewind_prob: Probability of applying rewind augmentation to each video
|
||||
last3_prob: Probability of limiting rewind to last 3 frames
|
||||
anchor_stats: Dictionary containing window_frame_indices and episode_lengths for episode-relative progress
|
||||
|
||||
Returns:
|
||||
Augmented frames and corresponding progress labels
|
||||
Augmented frames and corresponding episode-relative progress labels
|
||||
"""
|
||||
B, T, C, H, W = frames.shape
|
||||
device = frames.device
|
||||
|
||||
# Create default progress labels - will be properly scaled after stride/dropout
|
||||
# Use frame indices that will give 0-1 range after subsampling
|
||||
default_progress = torch.linspace(0, 1, T, device=device).unsqueeze(0).expand(B, -1)
|
||||
# Extract episode information if available
|
||||
window_frame_indices = anchor_stats.get("window_frame_indices") if anchor_stats else None
|
||||
episode_lengths = anchor_stats.get("episode_lengths") if anchor_stats else None
|
||||
|
||||
# Create default progress labels based on episode-relative positions
|
||||
if window_frame_indices and episode_lengths:
|
||||
# Use actual episode-relative progress
|
||||
default_progress = []
|
||||
for b_idx in range(B):
|
||||
frame_indices = window_frame_indices[b_idx]
|
||||
ep_length = episode_lengths[b_idx]
|
||||
progress = torch.tensor([
|
||||
frame_idx / max(ep_length - 1, 1) for frame_idx in frame_indices
|
||||
], device=device, dtype=torch.float32)
|
||||
default_progress.append(progress)
|
||||
default_progress = torch.stack(default_progress)
|
||||
else:
|
||||
# Fallback to window-relative progress
|
||||
default_progress = torch.linspace(0, 1, T, device=device).unsqueeze(0).expand(B, -1)
|
||||
|
||||
# Apply rewind augmentation to each sample in batch independently
|
||||
augmented_frames = []
|
||||
@@ -1095,11 +1158,27 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo
|
||||
reverse_frames = frames[b, max(0, i - k):i].flip(dims=[0])
|
||||
rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0)
|
||||
|
||||
# Create corresponding progress labels without constant padding
|
||||
denom = max(T - 1, 1)
|
||||
forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device)
|
||||
reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k) / denom), k, device=device)
|
||||
rewound_progress = torch.cat([forward_progress, reverse_progress])
|
||||
# Create corresponding progress labels based on episode-relative positions
|
||||
if window_frame_indices and episode_lengths:
|
||||
# Use episode-relative progress for rewind
|
||||
frame_indices = window_frame_indices[b]
|
||||
ep_length = episode_lengths[b]
|
||||
# Forward part: use actual frame indices
|
||||
forward_progress = torch.tensor([
|
||||
frame_indices[idx] / max(ep_length - 1, 1) for idx in range(i)
|
||||
], device=device, dtype=torch.float32)
|
||||
# Reverse part: use reversed frame indices
|
||||
reverse_indices = list(range(max(0, i - k), i))[::-1]
|
||||
reverse_progress = torch.tensor([
|
||||
frame_indices[idx] / max(ep_length - 1, 1) for idx in reverse_indices
|
||||
], device=device, dtype=torch.float32)
|
||||
rewound_progress = torch.cat([forward_progress, reverse_progress])
|
||||
else:
|
||||
# Fallback to window-relative progress
|
||||
denom = max(T - 1, 1)
|
||||
forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device)
|
||||
reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k) / denom), k, device=device)
|
||||
rewound_progress = torch.cat([forward_progress, reverse_progress])
|
||||
|
||||
success = True
|
||||
break
|
||||
@@ -1114,11 +1193,26 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo
|
||||
rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0)
|
||||
|
||||
if rewound_seq.shape[0] == T:
|
||||
# Create progress labels
|
||||
denom = max(T - 1, 1)
|
||||
forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device)
|
||||
reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k_extended) / denom), k_extended, device=device)
|
||||
rewound_progress = torch.cat([forward_progress, reverse_progress])
|
||||
# Create progress labels based on episode-relative positions
|
||||
if window_frame_indices and episode_lengths:
|
||||
frame_indices = window_frame_indices[b]
|
||||
ep_length = episode_lengths[b]
|
||||
# Forward part
|
||||
forward_progress = torch.tensor([
|
||||
frame_indices[idx] / max(ep_length - 1, 1) for idx in range(i)
|
||||
], device=device, dtype=torch.float32)
|
||||
# Extended reverse part
|
||||
reverse_indices = list(range(max(0, i - k_extended), i))[::-1]
|
||||
reverse_progress = torch.tensor([
|
||||
frame_indices[idx] / max(ep_length - 1, 1) for idx in reverse_indices
|
||||
], device=device, dtype=torch.float32)
|
||||
rewound_progress = torch.cat([forward_progress, reverse_progress])
|
||||
else:
|
||||
# Fallback to window-relative progress
|
||||
denom = max(T - 1, 1)
|
||||
forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device)
|
||||
reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k_extended) / denom), k_extended, device=device)
|
||||
rewound_progress = torch.cat([forward_progress, reverse_progress])
|
||||
|
||||
success = True
|
||||
break
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
#!/usr/bin/env python
|
||||
"""Test script to verify episode-relative progress is working correctly."""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
# Simulate what the dataset would provide
|
||||
def create_test_batch(batch_size=2, episode_lengths=[100, 150]):
|
||||
"""Create a test batch with episode information."""
|
||||
batch = {}
|
||||
|
||||
# Simulate episode indices and frame indices
|
||||
batch["episode_index"] = torch.tensor([0, 1]) # Two different episodes
|
||||
batch["frame_index"] = torch.tensor([50, 75]) # Middle of each episode
|
||||
|
||||
# Simulate images (not important for this test)
|
||||
batch["observation.images"] = torch.randn(batch_size, 16, 3, 224, 224)
|
||||
|
||||
# Simulate language
|
||||
batch["observation.language"] = ["Pick up the blue block", "Pick up the red block"]
|
||||
|
||||
return batch
|
||||
|
||||
def test_progress_calculation():
|
||||
"""Test that progress is calculated correctly."""
|
||||
print("Testing Episode-Relative Progress Calculation")
|
||||
print("=" * 60)
|
||||
|
||||
# Simulate episode_data_index
|
||||
episode_data_index = {
|
||||
"from": torch.tensor([0, 100, 250]), # Episode boundaries
|
||||
"to": torch.tensor([100, 250, 400]) # Episode ends
|
||||
}
|
||||
|
||||
# Test case 1: Sample from middle of episode
|
||||
print("\nTest Case 1: Window from middle of 100-frame episode")
|
||||
print("Anchor at frame 50, window frames [35-50]")
|
||||
|
||||
# Expected progress for frames 35-50 in a 100-frame episode
|
||||
expected_progress = [35/99, 36/99, 37/99, 38/99, 39/99, 40/99, 41/99, 42/99,
|
||||
43/99, 44/99, 45/99, 46/99, 47/99, 48/99, 49/99, 50/99]
|
||||
|
||||
print(f"Expected progress range: [{expected_progress[0]:.3f} to {expected_progress[-1]:.3f}]")
|
||||
print(f"This is ~[0.354 to 0.505] - NOT [0.0 to 1.0]!")
|
||||
|
||||
# Test case 2: Sample from end of episode
|
||||
print("\nTest Case 2: Window from end of 150-frame episode")
|
||||
print("Anchor at frame 140, window frames [125-140]")
|
||||
|
||||
# Expected progress for frames 125-140 in a 150-frame episode
|
||||
expected_progress_2 = [125/149, 126/149, 127/149, 128/149, 129/149, 130/149, 131/149, 132/149,
|
||||
133/149, 134/149, 135/149, 136/149, 137/149, 138/149, 139/149, 140/149]
|
||||
|
||||
print(f"Expected progress range: [{expected_progress_2[0]:.3f} to {expected_progress_2[-1]:.3f}]")
|
||||
print(f"This is ~[0.839 to 0.940] - NOT [0.0 to 1.0]!")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✅ Key Insight: Each 16-frame window should have progress values")
|
||||
print(" that reflect its actual position within the episode,")
|
||||
print(" NOT always [0.0 to 1.0]!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_progress_calculation()
|
||||
Reference in New Issue
Block a user