mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
fix progress
This commit is contained in:
@@ -141,9 +141,12 @@ def extract_episode_frames_and_gt(dataset, episode_idx):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=64, device="cuda"):
|
def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=64, device="cuda"):
|
||||||
"""
|
"""
|
||||||
Sliding-window prediction: for each frame i, create a window [max(0, i-L+1) .. i],
|
Sliding-window prediction for episode-relative progress model.
|
||||||
left-pad by repeating the first frame to length L (<= 16), and take the prediction
|
For each frame i, creates a window and extracts the prediction for that specific frame.
|
||||||
corresponding to the current frame's position in the window.
|
|
||||||
|
NOTE: This assumes we don't have episode context (episode_index, frame_index, episode_length).
|
||||||
|
The model will use its fallback logic for window-relative progress.
|
||||||
|
|
||||||
Returns np.ndarray of shape (T,).
|
Returns np.ndarray of shape (T,).
|
||||||
"""
|
"""
|
||||||
T = frames.shape[0]
|
T = frames.shape[0]
|
||||||
@@ -153,49 +156,45 @@ def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=
|
|||||||
# Preprocessed tensor on device
|
# Preprocessed tensor on device
|
||||||
frames = frames.to(device)
|
frames = frames.to(device)
|
||||||
|
|
||||||
windows = []
|
# Simple approach: predict each 16-frame window and take the last prediction
|
||||||
frame_positions = [] # Track which temporal position each frame should use
|
# This assumes the model can handle the lack of episode context gracefully
|
||||||
|
|
||||||
for i in range(T):
|
|
||||||
start = max(0, i - L + 1)
|
|
||||||
window = frames[start : i + 1] # (len<=L, C, H, W)
|
|
||||||
|
|
||||||
if window.shape[0] < L:
|
|
||||||
pad_needed = L - window.shape[0]
|
|
||||||
pad = window[:1].expand(pad_needed, -1, -1, -1) # repeat first frame
|
|
||||||
window = torch.cat([pad, window], dim=0)
|
|
||||||
|
|
||||||
# IMPROVED FIX: Cycle through MLPs to get varied predictions throughout the episode
|
|
||||||
# This ensures we use all 16 frame-specific MLPs and get varied outputs
|
|
||||||
# Frames 0-15 use MLPs 0-15, frames 16-31 use MLPs 0-15 again, etc.
|
|
||||||
frame_pos = i % L # Cycle through [0, 1, 2, ..., 15, 0, 1, 2, ..., 15, ...]
|
|
||||||
|
|
||||||
windows.append(window)
|
|
||||||
frame_positions.append(frame_pos)
|
|
||||||
|
|
||||||
preds = np.zeros(T, dtype=float)
|
preds = np.zeros(T, dtype=float)
|
||||||
|
|
||||||
for s in range(0, T, batch_size):
|
# Process non-overlapping windows for efficiency
|
||||||
e = min(s + batch_size, T)
|
for start_idx in range(0, T, L):
|
||||||
batch_windows = torch.stack(windows[s:e]) # (B, L, C, H, W)
|
end_idx = min(start_idx + L, T)
|
||||||
batch_positions = frame_positions[s:e]
|
window_frames = frames[start_idx:end_idx]
|
||||||
|
|
||||||
batch = {OBS_IMAGES: batch_windows, OBS_LANGUAGE: [language] * (e - s)} # expects (B, L, C, H, W)
|
# Pad if needed
|
||||||
|
if window_frames.shape[0] < L:
|
||||||
# Model returns (B, L) predictions for each temporal position
|
pad_needed = L - window_frames.shape[0]
|
||||||
values = model.predict_rewards(batch) # torch.Tensor (B, L)
|
if start_idx == 0:
|
||||||
|
# Pad with first frame at beginning
|
||||||
# Debug output removed - issue was identified and fixed
|
pad = window_frames[:1].expand(pad_needed, -1, -1, -1)
|
||||||
|
window_frames = torch.cat([pad, window_frames], dim=0)
|
||||||
if values.dim() == 2:
|
else:
|
||||||
# Extract the prediction corresponding to each frame's position in its window
|
# Pad with last frame at end
|
||||||
batch_preds = []
|
pad = window_frames[-1:].expand(pad_needed, -1, -1, -1)
|
||||||
for b_idx, pos in enumerate(batch_positions):
|
window_frames = torch.cat([window_frames, pad], dim=0)
|
||||||
batch_preds.append(values[b_idx, pos].item())
|
|
||||||
preds[s:e] = np.array(batch_preds)
|
# Create batch (batch size = 1)
|
||||||
|
batch = {
|
||||||
|
OBS_IMAGES: window_frames.unsqueeze(0), # (1, L, C, H, W)
|
||||||
|
OBS_LANGUAGE: [language]
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get predictions for this window
|
||||||
|
window_preds = model.predict_rewards(batch) # (1, L)
|
||||||
|
window_preds = window_preds.squeeze(0).cpu().numpy() # (L,)
|
||||||
|
|
||||||
|
# Extract the relevant predictions for the actual frames
|
||||||
|
actual_frames = min(L, end_idx - start_idx)
|
||||||
|
if start_idx == 0 and window_frames.shape[0] > actual_frames:
|
||||||
|
# Skip padding at beginning
|
||||||
|
preds[start_idx:end_idx] = window_preds[-actual_frames:]
|
||||||
else:
|
else:
|
||||||
# Fallback: if model returns (B,), use as is
|
# Take the first predictions (no beginning padding)
|
||||||
preds[s:e] = values.detach().float().cpu().numpy()
|
preds[start_idx:end_idx] = window_preds[:actual_frames]
|
||||||
|
|
||||||
return preds
|
return preds
|
||||||
|
|
||||||
|
|||||||
@@ -1,64 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
"""Test script to verify episode-relative progress is working correctly."""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Simulate what the dataset would provide
|
|
||||||
def create_test_batch(batch_size=2, episode_lengths=[100, 150]):
|
|
||||||
"""Create a test batch with episode information."""
|
|
||||||
batch = {}
|
|
||||||
|
|
||||||
# Simulate episode indices and frame indices
|
|
||||||
batch["episode_index"] = torch.tensor([0, 1]) # Two different episodes
|
|
||||||
batch["frame_index"] = torch.tensor([50, 75]) # Middle of each episode
|
|
||||||
|
|
||||||
# Simulate images (not important for this test)
|
|
||||||
batch["observation.images"] = torch.randn(batch_size, 16, 3, 224, 224)
|
|
||||||
|
|
||||||
# Simulate language
|
|
||||||
batch["observation.language"] = ["Pick up the blue block", "Pick up the red block"]
|
|
||||||
|
|
||||||
return batch
|
|
||||||
|
|
||||||
def test_progress_calculation():
|
|
||||||
"""Test that progress is calculated correctly."""
|
|
||||||
print("Testing Episode-Relative Progress Calculation")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# Simulate episode_data_index
|
|
||||||
episode_data_index = {
|
|
||||||
"from": torch.tensor([0, 100, 250]), # Episode boundaries
|
|
||||||
"to": torch.tensor([100, 250, 400]) # Episode ends
|
|
||||||
}
|
|
||||||
|
|
||||||
# Test case 1: Sample from middle of episode
|
|
||||||
print("\nTest Case 1: Window from middle of 100-frame episode")
|
|
||||||
print("Anchor at frame 50, window frames [35-50]")
|
|
||||||
|
|
||||||
# Expected progress for frames 35-50 in a 100-frame episode
|
|
||||||
expected_progress = [35/99, 36/99, 37/99, 38/99, 39/99, 40/99, 41/99, 42/99,
|
|
||||||
43/99, 44/99, 45/99, 46/99, 47/99, 48/99, 49/99, 50/99]
|
|
||||||
|
|
||||||
print(f"Expected progress range: [{expected_progress[0]:.3f} to {expected_progress[-1]:.3f}]")
|
|
||||||
print(f"This is ~[0.354 to 0.505] - NOT [0.0 to 1.0]!")
|
|
||||||
|
|
||||||
# Test case 2: Sample from end of episode
|
|
||||||
print("\nTest Case 2: Window from end of 150-frame episode")
|
|
||||||
print("Anchor at frame 140, window frames [125-140]")
|
|
||||||
|
|
||||||
# Expected progress for frames 125-140 in a 150-frame episode
|
|
||||||
expected_progress_2 = [125/149, 126/149, 127/149, 128/149, 129/149, 130/149, 131/149, 132/149,
|
|
||||||
133/149, 134/149, 135/149, 136/149, 137/149, 138/149, 139/149, 140/149]
|
|
||||||
|
|
||||||
print(f"Expected progress range: [{expected_progress_2[0]:.3f} to {expected_progress_2[-1]:.3f}]")
|
|
||||||
print(f"This is ~[0.839 to 0.940] - NOT [0.0 to 1.0]!")
|
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("✅ Key Insight: Each 16-frame window should have progress values")
|
|
||||||
print(" that reflect its actual position within the episode,")
|
|
||||||
print(" NOT always [0.0 to 1.0]!")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_progress_calculation()
|
|
||||||
Reference in New Issue
Block a user