mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
fix progress
This commit is contained in:
@@ -10,7 +10,7 @@ Usage:
|
|||||||
python src/lerobot/policies/rlearn/eval_script.py --model MODEL_NAME --dataset DATASET_REPO --episodes N
|
python src/lerobot/policies/rlearn/eval_script.py --model MODEL_NAME --dataset DATASET_REPO --episodes N
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
python src/lerobot/policies/rlearn/eval_script.py --model pepijn223/rlearn_mse5 --dataset pepijn223/phone_pipeline_pickup1 --episodes 2
|
python src/lerobot/policies/rlearn/eval_script.py --model pepijn223/rlearn_18 --dataset pepijn223/phone_pipeline_pickup1 --episodes 2
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|||||||
@@ -418,6 +418,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
frames,
|
frames,
|
||||||
rewind_prob=self.config.rewind_prob,
|
rewind_prob=self.config.rewind_prob,
|
||||||
last3_prob=self.config.rewind_last3_prob,
|
last3_prob=self.config.rewind_last3_prob,
|
||||||
|
anchor_stats=anchor_stats,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply stride and frame dropout
|
# Apply stride and frame dropout
|
||||||
@@ -484,18 +485,22 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
# IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window
|
# IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window
|
||||||
loss_dict: dict[str, float] = {}
|
loss_dict: dict[str, float] = {}
|
||||||
|
|
||||||
# Generate progress targets that span full 0-1 range
|
# Generate progress targets based on episode-relative positions
|
||||||
if self.training and augmented_target is not None:
|
if self.training and augmented_target is not None:
|
||||||
# Always create targets that span 0-1 across T_eff frames for better distribution
|
# For rewind augmentation, the augmented_target already contains proper progress values
|
||||||
target = torch.linspace(0, 1, T_eff, device=device).unsqueeze(0).expand(B, -1)
|
# But we need to handle potential stride/dropout
|
||||||
|
target = augmented_target[:, :T_eff] if augmented_target.shape[1] > T_eff else augmented_target
|
||||||
|
if target.shape[1] < T_eff:
|
||||||
|
# This shouldn't happen but handle it gracefully
|
||||||
|
target = torch.linspace(0, 1, T_eff, device=device).unsqueeze(0).expand(B, -1)
|
||||||
else:
|
else:
|
||||||
# Use anchor-based window-relative progress
|
# Use anchor-based episode-relative progress
|
||||||
if anchor_stats.get("fallback_used", False):
|
if anchor_stats.get("fallback_used", False):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Anchor-based sampling failed. Ensure 'episode_index', 'frame_index' are in batch "
|
"Anchor-based sampling failed. Ensure 'episode_index', 'frame_index' are in batch "
|
||||||
"and 'episode_data_index' is loaded from episodes.jsonl"
|
"and 'episode_data_index' is loaded from episodes.jsonl"
|
||||||
)
|
)
|
||||||
target = self._calculate_anchor_based_progress(T_eff)
|
target = self._calculate_anchor_based_progress(T_eff, anchor_stats)
|
||||||
|
|
||||||
# During inference, we might not want to compute loss
|
# During inference, we might not want to compute loss
|
||||||
if not self.training and target is None:
|
if not self.training and target is None:
|
||||||
@@ -844,7 +849,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
return ep, fr
|
return ep, fr
|
||||||
|
|
||||||
def _sample_random_anchor_windows(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
def _sample_random_anchor_windows(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||||
"""Sample random anchor windows for training."""
|
"""Sample random anchor windows for training and compute episode-relative progress."""
|
||||||
# Extract episode and frame indices - required for proper anchor sampling
|
# Extract episode and frame indices - required for proper anchor sampling
|
||||||
episode_indices, frame_indices = self._extract_episode_and_frame_indices(batch)
|
episode_indices, frame_indices = self._extract_episode_and_frame_indices(batch)
|
||||||
|
|
||||||
@@ -865,6 +870,8 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
# Sample random anchors and build windows
|
# Sample random anchors and build windows
|
||||||
sampled_frames = []
|
sampled_frames = []
|
||||||
anchor_positions = []
|
anchor_positions = []
|
||||||
|
window_frame_indices = [] # Store actual frame indices for progress calculation
|
||||||
|
episode_lengths = [] # Store episode lengths for progress calculation
|
||||||
oob_count = 0
|
oob_count = 0
|
||||||
|
|
||||||
for b_idx in range(B):
|
for b_idx in range(B):
|
||||||
@@ -874,6 +881,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
ep_start = self.episode_data_index["from"][ep_idx].item()
|
ep_start = self.episode_data_index["from"][ep_idx].item()
|
||||||
ep_end = self.episode_data_index["to"][ep_idx].item()
|
ep_end = self.episode_data_index["to"][ep_idx].item()
|
||||||
ep_length = ep_end - ep_start
|
ep_length = ep_end - ep_start
|
||||||
|
episode_lengths.append(ep_length)
|
||||||
|
|
||||||
# Choose random anchor - need at least T-1 frames before for [-15..0] window
|
# Choose random anchor - need at least T-1 frames before for [-15..0] window
|
||||||
min_anchor = T - 1
|
min_anchor = T - 1
|
||||||
@@ -883,9 +891,11 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
# Build window indices with reflection padding
|
# Build window indices with reflection padding
|
||||||
window_indices = []
|
window_indices = []
|
||||||
|
frame_indices_for_progress = [] # Track actual frame positions for progress
|
||||||
had_oob = False
|
had_oob = False
|
||||||
for delta in range(-(T-1), 1): # [-15, -14, ..., 0] for T=16
|
for delta in range(-(T-1), 1): # [-15, -14, ..., 0] for T=16
|
||||||
idx = anchor + delta
|
idx = anchor + delta
|
||||||
|
actual_frame_idx = idx # Store the actual frame index before reflection
|
||||||
if idx < 0:
|
if idx < 0:
|
||||||
idx = -idx # Reflect at start
|
idx = -idx # Reflect at start
|
||||||
had_oob = True
|
had_oob = True
|
||||||
@@ -893,6 +903,8 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
idx = 2 * (ep_length - 1) - idx # Reflect at end
|
idx = 2 * (ep_length - 1) - idx # Reflect at end
|
||||||
had_oob = True
|
had_oob = True
|
||||||
window_indices.append(min(idx, available_T - 1))
|
window_indices.append(min(idx, available_T - 1))
|
||||||
|
# For reflected indices, use the reflected position for progress
|
||||||
|
frame_indices_for_progress.append(idx)
|
||||||
|
|
||||||
if had_oob:
|
if had_oob:
|
||||||
oob_count += 1
|
oob_count += 1
|
||||||
@@ -900,6 +912,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
# Extract frames
|
# Extract frames
|
||||||
frame_tensors = [raw_frames[b_idx, idx] for idx in window_indices]
|
frame_tensors = [raw_frames[b_idx, idx] for idx in window_indices]
|
||||||
sampled_frames.append(torch.stack(frame_tensors))
|
sampled_frames.append(torch.stack(frame_tensors))
|
||||||
|
window_frame_indices.append(frame_indices_for_progress)
|
||||||
|
|
||||||
frames = torch.stack(sampled_frames, dim=0)
|
frames = torch.stack(sampled_frames, dim=0)
|
||||||
|
|
||||||
@@ -908,21 +921,54 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
"anchor_std": float(torch.tensor(anchor_positions).float().std()),
|
"anchor_std": float(torch.tensor(anchor_positions).float().std()),
|
||||||
"oob_fraction": float(oob_count) / B,
|
"oob_fraction": float(oob_count) / B,
|
||||||
"padded_fraction": 0.0, # No padding with reflection approach
|
"padded_fraction": 0.0, # No padding with reflection approach
|
||||||
"fallback_used": False
|
"fallback_used": False,
|
||||||
|
"window_frame_indices": window_frame_indices, # Pass frame indices for progress calculation
|
||||||
|
"episode_lengths": episode_lengths # Pass episode lengths for progress calculation
|
||||||
}
|
}
|
||||||
|
|
||||||
return frames, anchor_stats
|
return frames, anchor_stats
|
||||||
|
|
||||||
def _calculate_anchor_based_progress(self, T_eff: int) -> Tensor:
|
def _calculate_anchor_based_progress(self, T_eff: int, anchor_stats: dict) -> Tensor:
|
||||||
"""Generate window-relative progress (0 to 1 across actual frames used)."""
|
"""Generate episode-relative progress based on actual frame positions within episodes."""
|
||||||
device = next(self.parameters()).device
|
device = next(self.parameters()).device
|
||||||
# Create progress that spans 0 to 1 across the T_eff frames we actually use
|
|
||||||
# This ensures we get samples at all progress levels including near 1.0
|
# Extract frame indices and episode lengths from anchor_stats
|
||||||
if T_eff == 1:
|
window_frame_indices = anchor_stats.get("window_frame_indices")
|
||||||
progress = torch.tensor([0.5], device=device) # Single frame gets middle progress
|
episode_lengths = anchor_stats.get("episode_lengths")
|
||||||
else:
|
|
||||||
progress = torch.linspace(0, 1, T_eff, device=device) # Full 0-1 range
|
if window_frame_indices is None or episode_lengths is None:
|
||||||
return progress.unsqueeze(0) # (1, T_eff) - will broadcast to (B, T_eff)
|
# Fallback to window-relative progress if episode info not available
|
||||||
|
# This should not happen in normal training
|
||||||
|
if T_eff == 1:
|
||||||
|
progress = torch.tensor([0.5], device=device)
|
||||||
|
else:
|
||||||
|
progress = torch.linspace(0, 1, T_eff, device=device)
|
||||||
|
return progress.unsqueeze(0)
|
||||||
|
|
||||||
|
B = len(window_frame_indices)
|
||||||
|
T = len(window_frame_indices[0]) # Original window size (16)
|
||||||
|
|
||||||
|
# Calculate episode-relative progress for each sample
|
||||||
|
all_progress = []
|
||||||
|
for b_idx in range(B):
|
||||||
|
frame_indices = window_frame_indices[b_idx]
|
||||||
|
ep_length = episode_lengths[b_idx]
|
||||||
|
|
||||||
|
# Calculate progress as frame_index / (episode_length - 1)
|
||||||
|
# This gives us progress from 0.0 to 1.0 across the episode
|
||||||
|
progress = torch.tensor([
|
||||||
|
frame_idx / max(ep_length - 1, 1) for frame_idx in frame_indices
|
||||||
|
], device=device, dtype=torch.float32)
|
||||||
|
|
||||||
|
# If we have stride/dropout (T_eff < T), subsample the progress values
|
||||||
|
if T_eff < T:
|
||||||
|
# Subsample evenly from the progress values
|
||||||
|
indices = torch.linspace(0, T - 1, T_eff, dtype=torch.long)
|
||||||
|
progress = progress[indices]
|
||||||
|
|
||||||
|
all_progress.append(progress)
|
||||||
|
|
||||||
|
return torch.stack(all_progress) # (B, T_eff)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -1033,26 +1079,43 @@ def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None
|
|||||||
return frames
|
return frames
|
||||||
|
|
||||||
|
|
||||||
def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: float | None = None) -> tuple[Tensor, Tensor]:
|
def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: float | None = None, anchor_stats: dict | None = None) -> tuple[Tensor, Tensor]:
|
||||||
"""Apply video rewinding augmentation without constant-value padding.
|
"""Apply video rewinding augmentation with episode-relative progress.
|
||||||
|
|
||||||
This version ensures the rewound sequence is exactly T frames without flat plateaus
|
This version ensures the rewound sequence is exactly T frames and generates
|
||||||
that drag down the target mean.
|
episode-relative progress labels based on actual frame positions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
frames: Tensor of shape (B, T, C, H, W)
|
frames: Tensor of shape (B, T, C, H, W)
|
||||||
rewind_prob: Probability of applying rewind augmentation to each video
|
rewind_prob: Probability of applying rewind augmentation to each video
|
||||||
last3_prob: Probability of limiting rewind to last 3 frames
|
last3_prob: Probability of limiting rewind to last 3 frames
|
||||||
|
anchor_stats: Dictionary containing window_frame_indices and episode_lengths for episode-relative progress
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Augmented frames and corresponding progress labels
|
Augmented frames and corresponding episode-relative progress labels
|
||||||
"""
|
"""
|
||||||
B, T, C, H, W = frames.shape
|
B, T, C, H, W = frames.shape
|
||||||
device = frames.device
|
device = frames.device
|
||||||
|
|
||||||
# Create default progress labels - will be properly scaled after stride/dropout
|
# Extract episode information if available
|
||||||
# Use frame indices that will give 0-1 range after subsampling
|
window_frame_indices = anchor_stats.get("window_frame_indices") if anchor_stats else None
|
||||||
default_progress = torch.linspace(0, 1, T, device=device).unsqueeze(0).expand(B, -1)
|
episode_lengths = anchor_stats.get("episode_lengths") if anchor_stats else None
|
||||||
|
|
||||||
|
# Create default progress labels based on episode-relative positions
|
||||||
|
if window_frame_indices and episode_lengths:
|
||||||
|
# Use actual episode-relative progress
|
||||||
|
default_progress = []
|
||||||
|
for b_idx in range(B):
|
||||||
|
frame_indices = window_frame_indices[b_idx]
|
||||||
|
ep_length = episode_lengths[b_idx]
|
||||||
|
progress = torch.tensor([
|
||||||
|
frame_idx / max(ep_length - 1, 1) for frame_idx in frame_indices
|
||||||
|
], device=device, dtype=torch.float32)
|
||||||
|
default_progress.append(progress)
|
||||||
|
default_progress = torch.stack(default_progress)
|
||||||
|
else:
|
||||||
|
# Fallback to window-relative progress
|
||||||
|
default_progress = torch.linspace(0, 1, T, device=device).unsqueeze(0).expand(B, -1)
|
||||||
|
|
||||||
# Apply rewind augmentation to each sample in batch independently
|
# Apply rewind augmentation to each sample in batch independently
|
||||||
augmented_frames = []
|
augmented_frames = []
|
||||||
@@ -1095,11 +1158,27 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo
|
|||||||
reverse_frames = frames[b, max(0, i - k):i].flip(dims=[0])
|
reverse_frames = frames[b, max(0, i - k):i].flip(dims=[0])
|
||||||
rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0)
|
rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0)
|
||||||
|
|
||||||
# Create corresponding progress labels without constant padding
|
# Create corresponding progress labels based on episode-relative positions
|
||||||
denom = max(T - 1, 1)
|
if window_frame_indices and episode_lengths:
|
||||||
forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device)
|
# Use episode-relative progress for rewind
|
||||||
reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k) / denom), k, device=device)
|
frame_indices = window_frame_indices[b]
|
||||||
rewound_progress = torch.cat([forward_progress, reverse_progress])
|
ep_length = episode_lengths[b]
|
||||||
|
# Forward part: use actual frame indices
|
||||||
|
forward_progress = torch.tensor([
|
||||||
|
frame_indices[idx] / max(ep_length - 1, 1) for idx in range(i)
|
||||||
|
], device=device, dtype=torch.float32)
|
||||||
|
# Reverse part: use reversed frame indices
|
||||||
|
reverse_indices = list(range(max(0, i - k), i))[::-1]
|
||||||
|
reverse_progress = torch.tensor([
|
||||||
|
frame_indices[idx] / max(ep_length - 1, 1) for idx in reverse_indices
|
||||||
|
], device=device, dtype=torch.float32)
|
||||||
|
rewound_progress = torch.cat([forward_progress, reverse_progress])
|
||||||
|
else:
|
||||||
|
# Fallback to window-relative progress
|
||||||
|
denom = max(T - 1, 1)
|
||||||
|
forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device)
|
||||||
|
reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k) / denom), k, device=device)
|
||||||
|
rewound_progress = torch.cat([forward_progress, reverse_progress])
|
||||||
|
|
||||||
success = True
|
success = True
|
||||||
break
|
break
|
||||||
@@ -1114,11 +1193,26 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo
|
|||||||
rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0)
|
rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0)
|
||||||
|
|
||||||
if rewound_seq.shape[0] == T:
|
if rewound_seq.shape[0] == T:
|
||||||
# Create progress labels
|
# Create progress labels based on episode-relative positions
|
||||||
denom = max(T - 1, 1)
|
if window_frame_indices and episode_lengths:
|
||||||
forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device)
|
frame_indices = window_frame_indices[b]
|
||||||
reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k_extended) / denom), k_extended, device=device)
|
ep_length = episode_lengths[b]
|
||||||
rewound_progress = torch.cat([forward_progress, reverse_progress])
|
# Forward part
|
||||||
|
forward_progress = torch.tensor([
|
||||||
|
frame_indices[idx] / max(ep_length - 1, 1) for idx in range(i)
|
||||||
|
], device=device, dtype=torch.float32)
|
||||||
|
# Extended reverse part
|
||||||
|
reverse_indices = list(range(max(0, i - k_extended), i))[::-1]
|
||||||
|
reverse_progress = torch.tensor([
|
||||||
|
frame_indices[idx] / max(ep_length - 1, 1) for idx in reverse_indices
|
||||||
|
], device=device, dtype=torch.float32)
|
||||||
|
rewound_progress = torch.cat([forward_progress, reverse_progress])
|
||||||
|
else:
|
||||||
|
# Fallback to window-relative progress
|
||||||
|
denom = max(T - 1, 1)
|
||||||
|
forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device)
|
||||||
|
reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k_extended) / denom), k_extended, device=device)
|
||||||
|
rewound_progress = torch.cat([forward_progress, reverse_progress])
|
||||||
|
|
||||||
success = True
|
success = True
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -0,0 +1,64 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
"""Test script to verify episode-relative progress is working correctly."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Simulate what the dataset would provide
|
||||||
|
def create_test_batch(batch_size=2, episode_lengths=[100, 150]):
|
||||||
|
"""Create a test batch with episode information."""
|
||||||
|
batch = {}
|
||||||
|
|
||||||
|
# Simulate episode indices and frame indices
|
||||||
|
batch["episode_index"] = torch.tensor([0, 1]) # Two different episodes
|
||||||
|
batch["frame_index"] = torch.tensor([50, 75]) # Middle of each episode
|
||||||
|
|
||||||
|
# Simulate images (not important for this test)
|
||||||
|
batch["observation.images"] = torch.randn(batch_size, 16, 3, 224, 224)
|
||||||
|
|
||||||
|
# Simulate language
|
||||||
|
batch["observation.language"] = ["Pick up the blue block", "Pick up the red block"]
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def test_progress_calculation():
|
||||||
|
"""Test that progress is calculated correctly."""
|
||||||
|
print("Testing Episode-Relative Progress Calculation")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Simulate episode_data_index
|
||||||
|
episode_data_index = {
|
||||||
|
"from": torch.tensor([0, 100, 250]), # Episode boundaries
|
||||||
|
"to": torch.tensor([100, 250, 400]) # Episode ends
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test case 1: Sample from middle of episode
|
||||||
|
print("\nTest Case 1: Window from middle of 100-frame episode")
|
||||||
|
print("Anchor at frame 50, window frames [35-50]")
|
||||||
|
|
||||||
|
# Expected progress for frames 35-50 in a 100-frame episode
|
||||||
|
expected_progress = [35/99, 36/99, 37/99, 38/99, 39/99, 40/99, 41/99, 42/99,
|
||||||
|
43/99, 44/99, 45/99, 46/99, 47/99, 48/99, 49/99, 50/99]
|
||||||
|
|
||||||
|
print(f"Expected progress range: [{expected_progress[0]:.3f} to {expected_progress[-1]:.3f}]")
|
||||||
|
print(f"This is ~[0.354 to 0.505] - NOT [0.0 to 1.0]!")
|
||||||
|
|
||||||
|
# Test case 2: Sample from end of episode
|
||||||
|
print("\nTest Case 2: Window from end of 150-frame episode")
|
||||||
|
print("Anchor at frame 140, window frames [125-140]")
|
||||||
|
|
||||||
|
# Expected progress for frames 125-140 in a 150-frame episode
|
||||||
|
expected_progress_2 = [125/149, 126/149, 127/149, 128/149, 129/149, 130/149, 131/149, 132/149,
|
||||||
|
133/149, 134/149, 135/149, 136/149, 137/149, 138/149, 139/149, 140/149]
|
||||||
|
|
||||||
|
print(f"Expected progress range: [{expected_progress_2[0]:.3f} to {expected_progress_2[-1]:.3f}]")
|
||||||
|
print(f"This is ~[0.839 to 0.940] - NOT [0.0 to 1.0]!")
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("✅ Key Insight: Each 16-frame window should have progress values")
|
||||||
|
print(" that reflect its actual position within the episode,")
|
||||||
|
print(" NOT always [0.0 to 1.0]!")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_progress_calculation()
|
||||||
Reference in New Issue
Block a user