From 3ed0425d2cf7c826af40733e364300a0873ad954 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Wed, 26 Nov 2025 21:06:20 +0100 Subject: [PATCH] Remove rewind, use clip tokenizer --- scripts/visualize_rewind_predictions.py | 528 ------------- scripts/visualize_sarm_predictions.py | 2 +- src/lerobot/datasets/rewind_sampler.py | 128 ---- src/lerobot/datasets/temporal_sampler.py | 161 ++-- src/lerobot/datasets/video_sampler.py | 383 ---------- src/lerobot/optim/schedulers.py | 35 - src/lerobot/policies/factory.py | 14 - src/lerobot/policies/rewind/__init__.py | 34 - .../policies/rewind/configuration_rewind.py | 138 ---- .../policies/rewind/modeling_rewind.py | 711 ------------------ .../policies/rewind/processor_rewind.py | 405 ---------- src/lerobot/policies/sarm/__init__.py | 2 - .../policies/sarm/configuration_sarm.py | 10 +- src/lerobot/policies/sarm/modeling_sarm.py | 169 ++--- src/lerobot/policies/sarm/processor_sarm.py | 101 +-- src/lerobot/scripts/lerobot_train.py | 17 +- src/lerobot/utils/rabc.py | 2 +- 17 files changed, 172 insertions(+), 2668 deletions(-) delete mode 100644 scripts/visualize_rewind_predictions.py delete mode 100644 src/lerobot/datasets/rewind_sampler.py delete mode 100644 src/lerobot/datasets/video_sampler.py delete mode 100644 src/lerobot/policies/rewind/__init__.py delete mode 100644 src/lerobot/policies/rewind/configuration_rewind.py delete mode 100644 src/lerobot/policies/rewind/modeling_rewind.py delete mode 100644 src/lerobot/policies/rewind/processor_rewind.py diff --git a/scripts/visualize_rewind_predictions.py b/scripts/visualize_rewind_predictions.py deleted file mode 100644 index b3df946b3..000000000 --- a/scripts/visualize_rewind_predictions.py +++ /dev/null @@ -1,528 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Inference script for ReWiND Reward Model. - -This script loads a trained ReWiND model and runs inference on a dataset episode, -generating visualizations of the predicted task progression over time. 
- -Example usage: - python scripts/visualize_rewind_predictions.py \ - --model-id username/rewind-model \ - --dataset-repo lerobot/aloha_sim_insertion_human \ - --episode-index 0 \ - --output-dir outputs/rewind_viz \ - --task-description "insert the peg into the socket" -""" - -import argparse -import logging -from pathlib import Path -from typing import Optional - -import matplotlib.pyplot as plt -import matplotlib.gridspec as gridspec -import numpy as np -import torch -from tqdm import tqdm - -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.policies.rewind.modeling_rewind import ReWiNDRewardModel - - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def parse_args(): - parser = argparse.ArgumentParser(description="Run ReWiND inference and visualize predictions") - - # Model arguments - parser.add_argument( - "--model-id", - type=str, - required=True, - help="HuggingFace model ID or local path to trained ReWiND model" - ) - - # Dataset arguments - parser.add_argument( - "--dataset-repo", - type=str, - required=True, - help="HuggingFace dataset repository ID (e.g., lerobot/aloha_sim_insertion_human)" - ) - parser.add_argument( - "--episode-index", - type=int, - default=0, - help="Index of the episode to visualize (default: 0)" - ) - parser.add_argument( - "--task-description", - type=str, - default="perform the task", - help="Task description for the reward model (default: 'perform the task')" - ) - - # Output arguments - parser.add_argument( - "--output-dir", - type=str, - default="outputs/rewind_inference", - help="Directory to save visualization outputs (default: outputs/rewind_inference)" - ) - parser.add_argument( - "--image-key", - type=str, - default=None, - help="Key for images in dataset (e.g., observation.images.image for jaco_play). If not specified, uses model config's image_key" - ) - - # Visualization options - parser.add_argument( - "--show-frames", - action="store_true", - help="Include sample frames in the visualization" - ) - parser.add_argument( - "--num-sample-frames", - type=int, - default=8, - help="Number of sample frames to show (default: 8)" - ) - parser.add_argument( - "--figsize", - type=int, - nargs=2, - default=[12, 6], - help="Figure size as width height (default: 12 6)" - ) - - # Device - parser.add_argument( - "--device", - type=str, - default=None, - help="Device to run inference on (cuda/cpu, default: auto-detect)" - ) - - return parser.parse_args() - - -def load_episode_data( - dataset: LeRobotDataset, - episode_index: int, - image_key: str -) -> tuple[np.ndarray, int, int, str]: - """ - Load all frames from a specific episode. 
- - Args: - dataset: LeRobotDataset instance - episode_index: Index of the episode to load - image_key: Key for accessing images in the dataset - - Returns: - Tuple of (frames, start_index, end_index, task_description) - """ - # Get episode boundaries - episode_data = dataset.meta.episodes - start_idx = episode_data["dataset_from_index"][episode_index] - end_idx = episode_data["dataset_to_index"][episode_index] - - logger.info(f"Loading episode {episode_index}: frames {start_idx} to {end_idx} ({end_idx - start_idx} frames)") - - # Get task description from the dataset if available - task_description = None - first_item = dataset[start_idx] - if "task" in first_item: - task_description = first_item["task"] - print(f"✓ Extracted task from episode {episode_index}: '{task_description}'") - - # Load all frames from the episode - frames = [] - for idx in tqdm(range(start_idx, end_idx), desc="Loading frames"): - item = dataset[idx] - # Get image from the item - img = item[image_key] - - # Convert to numpy if needed - if isinstance(img, torch.Tensor): - img = img.cpu().numpy() - - # Handle different image formats (C, H, W) or (H, W, C) - if img.shape[0] in [1, 3]: # Channel first - img = np.transpose(img, (1, 2, 0)) - - # Convert to uint8 if needed - if img.dtype != np.uint8: - if img.max() <= 1.0: - img = (img * 255).astype(np.uint8) - else: - img = img.astype(np.uint8) - - frames.append(img) - - frames = np.array(frames) - logger.info(f"Loaded {len(frames)} frames with shape {frames[0].shape}") - - return frames, start_idx, end_idx, task_description - - -@torch.no_grad() -def run_inference( - model: ReWiNDRewardModel, - frames: np.ndarray, - task_description: str, - batch_size: int = 32 -) -> tuple[np.ndarray, np.ndarray]: - """ - Run ReWiND inference on video frames using the original ReWiND approach. - - This function creates video slices for all frames at once (similar to the original - metaworld_label_reward.py), where each slice contains frames from start up to that point. - - Progress Normalization (from original ReWiND dataset.py): - - Training: progress = [1, 2, ..., N] / remaining_length - where remaining_length = episode_end - sequence_start - - Inference: Starting from frame 0, remaining_length = total_episode_length - So expected progress for frame i = (i + 1) / total_episode_length - - This function computes both: - 1. Model predictions (what the model actually predicts) - 2. 
Expected progress (ground truth based on frame position) - - Args: - model: ReWiND model - frames: Video frames (num_frames, H, W, C) - task_description: Task description text - batch_size: Batch size for processing slices - - Returns: - Tuple of: - - Model predictions for each frame (num_frames,) - - Expected progress for each frame (num_frames,) - """ - total_frames = len(frames) - - logger.info("Encoding video frames with DINO...") - video_embeddings = model.encode_images(frames) - - logger.info("Encoding task description with MiniLM...") - text_embedding = model.encode_text(task_description) - - logger.info("Creating video slices (original ReWiND approach)...") - # Convert to tensors - video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32) - text_embedding = torch.tensor(text_embedding, dtype=torch.float32) - - # Create video slices: for each frame i, create a sequence of frames [0:i+1] - # This matches the original ReWiND inference approach - video_slices = [] - for i in tqdm(range(len(video_embeddings)), desc="Creating slices"): - # Slice from start to current frame (inclusive) - video_slice = video_embeddings[:i + 1] - - # Pad or subsample to max_length - if model.config.subsample_video: - video_slice = model.padding_video(video_slice, model.config.max_length) - - video_slices.append(video_slice) - - video_slices = torch.stack(video_slices) # (num_frames, max_length, 768) - - # Create last_index_mask to extract the relevant prediction for each slice - # For slice i, the last valid frame is at position min(i, max_length-1) - max_length = model.config.max_length - last_index_mask = torch.zeros((len(video_slices), max_length), dtype=torch.bool) - - for i in range(len(video_slices)): - last_frame_idx = min(i, max_length - 1) - last_index_mask[i, last_frame_idx] = 1 - - logger.info("Running ReWiND inference on all slices...") - # Process in batches - all_progress = [] - for i in tqdm(range(0, len(video_slices), batch_size), desc="Inference"): - batch_video = video_slices[i:i + batch_size].to(model.device) - batch_mask = last_index_mask[i:i + batch_size].to(model.device) - batch_size_actual = batch_video.shape[0] - - # Replicate text embedding for batch - batch_text = text_embedding.unsqueeze(0).repeat(batch_size_actual, 1).to(model.device) - - # Get predictions for all frames in batch - progress_preds = model.rewind_transformer(batch_video, batch_text) # (batch, max_length, 1) - progress_preds = progress_preds.squeeze(-1) # (batch, max_length) - - # Extract predictions using the last_index_mask - # This gets the prediction for the last valid frame in each slice - batch_progress = progress_preds[batch_mask].cpu().numpy() - all_progress.extend(batch_progress) - - predictions = np.array(all_progress) - - # Compute expected progress based on original ReWiND normalization - # When starting from frame 0, remaining_length = total_episode_length - # Expected progress for frame i = (i + 1) / total_frames - expected_progress = np.arange(1, total_frames + 1, dtype=np.float32) / total_frames - - logger.info(f"Inference complete. 
Predicted progress range: [{predictions.min():.3f}, {predictions.max():.3f}]") - logger.info(f"Expected progress range: [{expected_progress.min():.3f}, {expected_progress.max():.3f}]") - - return predictions, expected_progress - - -def visualize_predictions( - frames: np.ndarray, - predictions: np.ndarray, - expected_progress: np.ndarray, - task_description: str, - output_path: Path, - show_frames: bool = False, - num_sample_frames: int = 8, - figsize: tuple = (12, 6) -): - """ - Create visualization of ReWiND predictions with expected progress comparison. - - Args: - frames: Video frames (num_frames, H, W, C) - predictions: Model progress predictions (num_frames,) - expected_progress: Expected progress based on frame position (num_frames,) - task_description: Task description - output_path: Path to save the figure - show_frames: Whether to include sample frames - num_sample_frames: Number of frames to show - figsize: Figure size (width, height) - """ - if show_frames: - # Create figure with progress plot and sample frames - fig = plt.figure(figsize=(figsize[0], figsize[1] + 4)) - gs = gridspec.GridSpec(2, 1, height_ratios=[2, 1], hspace=0.3) - - # Progress plot - ax_progress = fig.add_subplot(gs[0]) - else: - # Just progress plot - fig, ax_progress = plt.subplots(1, 1, figsize=figsize) - - # Plot progress over time - frame_indices = np.arange(len(predictions)) - - # Plot expected progress (ground truth) - ax_progress.plot(frame_indices, expected_progress, linewidth=2, color='#A8DADC', - linestyle='--', label='Expected Progress (Linear)', alpha=0.7) - - # Plot model predictions - ax_progress.plot(frame_indices, predictions, linewidth=2.5, color='#2E86AB', - label='Model Predictions') - ax_progress.fill_between(frame_indices, 0, predictions, alpha=0.2, color='#2E86AB') - - # Add reference line at 1.0 - ax_progress.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, linewidth=1) - - # Styling - ax_progress.set_xlabel('Frame Index', fontsize=12) - ax_progress.set_ylabel('Task Progress', fontsize=12) - ax_progress.set_title(f'ReWiND Task Progress Prediction\nTask: "{task_description}"', - fontsize=14, fontweight='bold') - ax_progress.grid(True, alpha=0.3) - ax_progress.set_ylim(-0.05, 1.1) - ax_progress.legend(loc='upper left') - - # Compute alignment metrics - mae = np.mean(np.abs(predictions - expected_progress)) - rmse = np.sqrt(np.mean((predictions - expected_progress) ** 2)) - - # Add statistics box - stats_text = ( - f'Frames: {len(predictions)}\n' - f'Model Final: {predictions[-1]:.3f}\n' - f'Model Max: {predictions.max():.3f}\n' - f'Model Mean: {predictions.mean():.3f}\n' - f'MAE: {mae:.3f}\n' - f'RMSE: {rmse:.3f}' - ) - ax_progress.text(0.98, 0.02, stats_text, transform=ax_progress.transAxes, - fontsize=10, verticalalignment='bottom', horizontalalignment='right', - bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5)) - - # Show sample frames if requested - if show_frames: - # Select evenly spaced frames - frame_indices_to_show = np.linspace(0, len(frames) - 1, num_sample_frames, dtype=int) - - # Create subplot for frames - ax_frames = fig.add_subplot(gs[1]) - ax_frames.axis('off') - - # Create grid for frames - frame_height = frames[0].shape[0] - frame_width = frames[0].shape[1] - - combined_width = frame_width * num_sample_frames - combined_image = np.zeros((frame_height, combined_width, 3), dtype=np.uint8) - - for i, frame_idx in enumerate(frame_indices_to_show): - frame = frames[frame_idx] - if frame.shape[-1] == 1: - frame = np.repeat(frame, 3, axis=-1) - - # Add frame 
to combined image - x_start = i * frame_width - x_end = (i + 1) * frame_width - combined_image[:, x_start:x_end] = frame - - # Add frame number and progress value - progress_val = predictions[frame_idx] - label = f'Frame {frame_idx}\nProgress: {progress_val:.3f}' - - # Draw label on image - ax_frames.text(x_start + frame_width / 2, -10, label, - ha='center', va='top', fontsize=8, - bbox=dict(boxstyle='round', facecolor='white', alpha=0.7)) - - ax_frames.imshow(combined_image) - ax_frames.set_title('Sample Frames', fontsize=12, pad=20) - - # Save figure - plt.tight_layout() - output_path.parent.mkdir(parents=True, exist_ok=True) - plt.savefig(output_path, dpi=150, bbox_inches='tight') - logger.info(f"Saved visualization to {output_path}") - - plt.close() - - -def main(): - args = parse_args() - - # Setup device - if args.device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - else: - device = args.device - logger.info(f"Using device: {device}") - - # Load model - logger.info(f"Loading ReWiND model from {args.model_id}...") - model = ReWiNDRewardModel.from_pretrained(args.model_id) - model.to(device) - model.eval() - logger.info("Model loaded successfully") - - # Load dataset - logger.info(f"Loading dataset {args.dataset_repo}...") - dataset = LeRobotDataset(args.dataset_repo) - logger.info(f"Dataset loaded: {len(dataset.meta.episodes)} episodes, {len(dataset)} frames") - - # Validate episode index - if args.episode_index >= len(dataset.meta.episodes): - raise ValueError( - f"Episode index {args.episode_index} out of range. " - f"Dataset has {len(dataset.meta.episodes)} episodes." - ) - - # Determine which image key to use - image_key = args.image_key if args.image_key is not None else model.config.image_key - logger.info(f"Using image key: {image_key}") - - # Load episode data (this also extracts the task description from the episode) - frames, start_idx, end_idx, dataset_task = load_episode_data(dataset, args.episode_index, image_key) - - # Use task description from dataset if available, otherwise use command-line argument - task_description = dataset_task if dataset_task is not None else args.task_description - logger.info(f"Using task description: '{task_description}'") - - # Run inference - predictions, expected_progress = run_inference(model, frames, task_description) - - # Create visualization - output_dir = Path(args.output_dir) - output_path = output_dir / f"rewind_prediction_ep{args.episode_index}.png" - - visualize_predictions( - frames, - predictions, - expected_progress, - task_description, - output_path, - show_frames=args.show_frames, - num_sample_frames=args.num_sample_frames, - figsize=tuple(args.figsize) - ) - - # Save predictions and expected progress as numpy arrays - predictions_path = output_dir / f"predictions_ep{args.episode_index}.npy" - expected_path = output_dir / f"expected_progress_ep{args.episode_index}.npy" - np.save(predictions_path, predictions) - np.save(expected_path, expected_progress) - logger.info(f"Saved predictions array to {predictions_path}") - logger.info(f"Saved expected progress to {expected_path}") - - # Compute alignment metrics - mae = np.mean(np.abs(predictions - expected_progress)) - rmse = np.sqrt(np.mean((predictions - expected_progress) ** 2)) - correlation = np.corrcoef(predictions, expected_progress)[0, 1] - - # Print summary - logger.info("\n" + "="*60) - logger.info("INFERENCE SUMMARY") - logger.info("="*60) - logger.info(f"Model: {args.model_id}") - logger.info(f"Dataset: {args.dataset_repo}") - 
logger.info(f"Episode: {args.episode_index}") - logger.info(f"Task: {task_description}") - logger.info(f"Frames: {len(frames)}") - logger.info(f"\nModel Predictions:") - logger.info(f" Final: {predictions[-1]:.3f}") - logger.info(f" Max: {predictions.max():.3f}") - logger.info(f" Mean: {predictions.mean():.3f}") - logger.info(f" Std: {predictions.std():.3f}") - logger.info(f"\nExpected Progress (Linear):") - logger.info(f" Final: {expected_progress[-1]:.3f}") - logger.info(f" Mean: {expected_progress.mean():.3f}") - logger.info(f"\nAlignment Metrics:") - logger.info(f" MAE: {mae:.3f}") - logger.info(f" RMSE: {rmse:.3f}") - logger.info(f" Correlation: {correlation:.3f}") - logger.info(f"\nOutput:") - logger.info(f" Visualization: {output_path}") - logger.info("="*60) - - # Diagnostic warnings - if predictions.std() < 0.05: - logger.warning("\n⚠ WARNING: Mode collapse detected (std < 0.05)") - logger.warning(" Model predictions show very low variance.") - logger.warning(" This indicates the model was likely trained with incorrect") - logger.warning(" progress normalization (absolute indices instead of remaining length).") - elif mae > 0.3: - logger.warning("\n⚠ WARNING: High prediction error (MAE > 0.3)") - logger.warning(" Model predictions deviate significantly from expected linear progress.") - logger.warning(" Consider retraining with correct progress normalization.") - elif correlation < 0.5: - logger.warning("\n⚠ WARNING: Low correlation with expected progress (< 0.5)") - logger.warning(" Model predictions don't align well with linear task progression.") - else: - logger.info("\n✓ Model predictions show healthy progression!") - - -if __name__ == "__main__": - main() - diff --git a/scripts/visualize_sarm_predictions.py b/scripts/visualize_sarm_predictions.py index 1841af63c..eeb81d07e 100644 --- a/scripts/visualize_sarm_predictions.py +++ b/scripts/visualize_sarm_predictions.py @@ -244,7 +244,7 @@ def run_inference( logger.info("Encoding video frames with CLIP...") video_embeddings = model.encode_images(frames) - logger.info("Encoding task description with MiniLM...") + logger.info("Encoding task description with CLIP...") text_embedding = model.encode_text(task_description) # Get config values diff --git a/src/lerobot/datasets/rewind_sampler.py b/src/lerobot/datasets/rewind_sampler.py deleted file mode 100644 index 1b1d3f392..000000000 --- a/src/lerobot/datasets/rewind_sampler.py +++ /dev/null @@ -1,128 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -ReWiND Sampler for temporal sequence loading. -""" - -import logging -from typing import Iterator, Optional -import numpy as np -import torch -from torch.utils.data import Sampler -import random - - -class ReWiNDTemporalSampler(Sampler): - """ - Sampler for ReWiND that samples random temporal windows from episodes. 
- - Matches original ReWiND sampling: - - Samples random start and end points within episodes - - Minimum window size of 3 frames - - Can sample from beginning, middle, or end of episodes - - Args: - dataset_from_index: Start indices of episodes - dataset_to_index: End indices of episodes - sequence_length: Maximum sequence length (for padding/subsampling) - stride: Not used (kept for API compatibility) - shuffle: Whether to shuffle sampling order - seed: Random seed - """ - - def __init__( - self, - dataset_from_index: np.ndarray, - dataset_to_index: np.ndarray, - sequence_length: int = 32, - stride: int = 1, - shuffle: bool = True, - seed: Optional[int] = None, - ): - self.dataset_from_index = np.array(dataset_from_index) - self.dataset_to_index = np.array(dataset_to_index) - self.sequence_length = sequence_length - self.shuffle = shuffle - - if seed is not None: - self.seed = seed - random.seed(seed) - np.random.seed(seed) - self.generator = torch.Generator().manual_seed(seed) - else: - self.generator = torch.Generator() - - # Compute valid episodes (those with at least 3 frames) - self._compute_valid_episodes() - - # Number of samples per epoch (matching original ReWiND) - self.samples_per_epoch = 100 * 64 # 100 batches of 64 - - logging.info( - f"ReWiNDTemporalSampler: {len(self.valid_episodes)} valid episodes, " - f"{self.samples_per_epoch} samples per epoch" - ) - - def _compute_valid_episodes(self): - """Compute valid episodes (those with at least 3 frames).""" - self.valid_episodes = [] - - for ep_idx in range(len(self.dataset_from_index)): - ep_start = self.dataset_from_index[ep_idx] - ep_end = self.dataset_to_index[ep_idx] - episode_length = ep_end - ep_start - - if episode_length >= 3: # Minimum 3 frames - self.valid_episodes.append((ep_idx, ep_start, ep_end)) - - self.valid_episodes = np.array(self.valid_episodes) - - def __len__(self) -> int: - return self.samples_per_epoch - - def __iter__(self) -> Iterator[int]: - """ - Yields ONE index per sample (the end of a random window). - - Matches original ReWiND behavior: - 1. Pick random episode - 2. Pick random end frame (at least 3 frames from start) - 3. Yield that end frame index - 4. Dataset/processor loads from episode start to this end frame - 5. Model pads/subsamples to sequence_length (32) - - This allows sampling from anywhere in episodes: - - Early frames → short sequences (mostly padding) → low progress - - Middle frames → medium sequences (some subsampling) → medium progress - - End frames → long sequences (full subsampling) → high progress approaching 1.0 - """ - for _ in range(self.samples_per_epoch): - # Randomly select an episode - ep_idx, ep_start, ep_end = self.valid_episodes[ - np.random.randint(0, len(self.valid_episodes)) - ] - - episode_length = ep_end - ep_start - - # Sample a random end point (must be at least 3 frames from start) - # This matches original: random.randint(start_idx+3, len(progress_dataset)) - end_offset = np.random.randint(3, episode_length + 1) - end_idx = ep_start + end_offset - - # Yield ONLY the end index - # The dataset will load all frames from ep_start to end_idx - yield int(end_idx - 1) # -1 because end_idx is exclusive diff --git a/src/lerobot/datasets/temporal_sampler.py b/src/lerobot/datasets/temporal_sampler.py index 3e210717e..de07942b2 100644 --- a/src/lerobot/datasets/temporal_sampler.py +++ b/src/lerobot/datasets/temporal_sampler.py @@ -15,12 +15,10 @@ # limitations under the License. """ -Temporal Sequence Sampler for reward models and temporal policies. 
+SARM Temporal Sampler for reward model training. -Supports multiple sampling modes: -- "rewind": ReWiND-style sampling (random windows from episode start) -- "sarm": SARM-style sampling (9-frame sequences with specific pattern) -- "custom": Custom temporal sampling +Samples frames from episodes ensuring sufficient temporal history for SARM's +9-frame pattern (1 initial + 8 consecutive with frame_gap spacing). """ import logging @@ -31,24 +29,23 @@ from torch.utils.data import Sampler import random -class TemporalSequenceSampler(Sampler): +class SARMTemporalSampler(Sampler): """ - Generalized temporal sampler for reward models. + Temporal sampler for SARM reward model training. - Supports multiple sampling modes: - - "rewind": Consecutive frames from episode start to random end point (ReWiND: 32 consecutive frames) - - "sarm": 9-frame sequences with 1 initial + 8 consecutive (SARM) - - "custom": Custom temporal sampling + SARM uses 9 frames per sample: + - Frame 0: Initial frame of the episode (always frame 0) + - Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame + + This sampler ensures we only sample from positions that have enough + temporal history (at least 7 * frame_gap frames from episode start). Args: - dataset_from_index: Start indices of episodes - dataset_to_index: End indices of episodes - sequence_length: Maximum sequence length (for padding/subsampling) - stride: Frame stride for consecutive sampling (SARM mode) + dataset_from_index: Start indices of episodes (global dataset indices) + dataset_to_index: End indices of episodes (global dataset indices) + frame_gap: Gap between consecutive frames (default: 30 = 1 second at 30fps) shuffle: Whether to shuffle sampling order - seed: Random seed - sampling_mode: Sampling mode ("rewind", "sarm", or "custom") - min_frames: Minimum frames per episode (default: 3) + seed: Random seed for reproducibility samples_per_epoch: Number of samples per epoch (default: 6400) """ @@ -56,25 +53,21 @@ class TemporalSequenceSampler(Sampler): self, dataset_from_index: np.ndarray, dataset_to_index: np.ndarray, - sequence_length: int = 32, - stride: int = 1, + frame_gap: int = 30, shuffle: bool = True, seed: Optional[int] = None, - sampling_mode: str = "rewind", - min_frames: int = 3, samples_per_epoch: int = 6400, ): self.dataset_from_index = np.array(dataset_from_index) self.dataset_to_index = np.array(dataset_to_index) - self.sequence_length = sequence_length - self.stride = stride + self.frame_gap = frame_gap self.shuffle = shuffle - self.sampling_mode = sampling_mode - self.min_frames = min_frames self.samples_per_epoch = samples_per_epoch - if sampling_mode not in ["rewind", "sarm", "custom"]: - raise ValueError(f"sampling_mode must be 'rewind', 'sarm', or 'custom', got {sampling_mode}") + # Minimum frames needed for SARM pattern: + # 8 consecutive frames with frame_gap spacing = 7 * frame_gap + 1 + # (Plus the initial frame which is always available) + self.min_frames_needed = 7 * frame_gap + 1 if seed is not None: self.seed = seed @@ -84,98 +77,68 @@ class TemporalSequenceSampler(Sampler): else: self.generator = torch.Generator() - # Compute valid episodes - self._compute_valid_episodes() + # Compute valid episodes and sampling positions + self._compute_valid_positions() logging.info( - f"TemporalSequenceSampler ({sampling_mode} mode): " - f"{len(self.valid_episodes)} valid episodes, " - f"{self.samples_per_epoch} samples per epoch" + f"SARMTemporalSampler: {len(self.valid_episodes)} valid episodes, " + 
f"{len(self.all_valid_positions)} valid positions, " + f"{self.samples_per_epoch} samples per epoch, " + f"frame_gap={frame_gap}" ) - def _compute_valid_episodes(self): - """Compute valid episodes based on minimum frame requirement.""" + def _compute_valid_positions(self): + """Compute valid episodes and all valid sampling positions.""" self.valid_episodes = [] + self.all_valid_positions = [] for ep_idx in range(len(self.dataset_from_index)): ep_start = self.dataset_from_index[ep_idx] ep_end = self.dataset_to_index[ep_idx] episode_length = ep_end - ep_start - # For SARM mode, need enough frames for the sequence pattern - if self.sampling_mode == "sarm": - # Need at least sequence_length * stride frames - min_required = self.sequence_length * self.stride - if episode_length >= min_required: - self.valid_episodes.append((ep_idx, ep_start, ep_end)) - else: - # For rewind mode, use min_frames - if episode_length >= self.min_frames: - self.valid_episodes.append((ep_idx, ep_start, ep_end)) + # Episode must have enough frames for SARM pattern + if episode_length >= self.min_frames_needed: + self.valid_episodes.append((ep_idx, ep_start, ep_end)) + + # Valid positions: from min_frames_needed to episode end + # These are global dataset indices + for pos in range(ep_start + self.min_frames_needed - 1, ep_end): + self.all_valid_positions.append(pos) self.valid_episodes = np.array(self.valid_episodes) + self.all_valid_positions = np.array(self.all_valid_positions) + + if len(self.all_valid_positions) == 0: + raise ValueError( + f"No valid sampling positions found! " + f"Episodes need at least {self.min_frames_needed} frames " + f"(7 * frame_gap + 1 = 7 * {self.frame_gap} + 1)." + ) def __len__(self) -> int: return self.samples_per_epoch def __iter__(self) -> Iterator[int]: """ - Yields ONE index per sample. + Yields global dataset indices for sampling. - Sampling behavior depends on mode: - - ReWiND mode: - 1. Pick random episode - 2. Pick random end frame (at least min_frames from start) - 3. Yield that end frame index - 4. Dataset loads from episode start to this end frame - - SARM mode: - 1. Pick random episode - 2. Pick random end frame (must allow sequence_length frames with stride) - 3. Yield that end frame index - 4. Dataset loads sequence_length frames with stride spacing ending at this frame + Each yielded index represents the "current frame" position. 
+ The dataset's observation_delta_indices then handles loading: + - Frame 0: Episode initial frame (via large negative delta clamping) + - Frames 1-8: Consecutive frames ending at the yielded index """ - for _ in range(self.samples_per_epoch): - # Randomly select an episode - ep_idx, ep_start, ep_end = self.valid_episodes[ - np.random.randint(0, len(self.valid_episodes)) - ] - - episode_length = ep_end - ep_start - - if self.sampling_mode == "rewind": - # ReWiND: Sample random end point (at least min_frames from start) - end_offset = np.random.randint(self.min_frames, episode_length + 1) - end_idx = ep_start + end_offset - - # Yield the end index (dataset will load from start to this point) - yield int(end_idx - 1) # -1 because end_idx is exclusive - - elif self.sampling_mode == "sarm": - # SARM: Sample end point that allows full sequence - # We need sequence_length frames with stride spacing - min_end_offset = self.sequence_length * self.stride - - if episode_length >= min_end_offset: - # Can sample anywhere from min_end_offset to episode_length - end_offset = np.random.randint(min_end_offset, episode_length + 1) - else: - # Episode is exactly the minimum length - end_offset = episode_length - - end_idx = ep_start + end_offset - - # Yield the end index (dataset will load sequence with stride) - yield int(end_idx - 1) # -1 because end_idx is exclusive - - else: # custom mode - # Default to rewind-style sampling - end_offset = np.random.randint(self.min_frames, episode_length + 1) - end_idx = ep_start + end_offset - yield int(end_idx - 1) + if self.shuffle: + # Randomly sample from all valid positions + for _ in range(self.samples_per_epoch): + idx = np.random.randint(0, len(self.all_valid_positions)) + yield int(self.all_valid_positions[idx]) + else: + # Sequential sampling with wrap-around + for i in range(self.samples_per_epoch): + idx = i % len(self.all_valid_positions) + yield int(self.all_valid_positions[idx]) # Backwards compatibility alias -ReWiNDTemporalSampler = TemporalSequenceSampler - +TemporalSequenceSampler = SARMTemporalSampler diff --git a/src/lerobot/datasets/video_sampler.py b/src/lerobot/datasets/video_sampler.py deleted file mode 100644 index 222167cc8..000000000 --- a/src/lerobot/datasets/video_sampler.py +++ /dev/null @@ -1,383 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Video sampling utilities for temporal data augmentation and frame selection. - -This module provides utilities for sampling and augmenting video sequences, particularly -for reward model training. 
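The progress targets produced throughout this (now deleted) module follow the original ReWiND normalization, progress_i = i / remaining_length for i = 1..N; a minimal numeric sketch with assumed lengths:

    import torch

    # A 5-frame window whose first frame sits 10 frames before the episode end.
    video_length, remaining_length = 5, 10
    progress = torch.arange(1, video_length + 1, dtype=torch.float32) / remaining_length
    # tensor([0.1000, 0.2000, 0.3000, 0.4000, 0.5000]) -- reaches 1.0 only
    # when the sampled window extends to the episode's final frame.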
It includes functions for: -- Padding/sampling videos to fixed lengths -- Video rewind augmentation for learning to decrease rewards -""" - -import random -from typing import Tuple - -import numpy as np -import torch - - -def sample_video_feature( - video_feature: torch.Tensor, - max_length: int = 32, - random_sample: bool = True, - remaining_length: int = None -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Sample or pad video features to a fixed length with progress targets. - - Progress normalization matches original ReWiND implementation: - - Progress = (position_in_sequence + 1) / remaining_trajectory_length - - remaining_trajectory_length = frames from first sampled frame to episode end - - Original ReWiND logic (dataset.py lines 12493-12499): - video_frames = frames[start_idx:end_idx] - full_frames = frames[start_idx:] # All frames from start to episode end - progress = [1, 2, ..., len(video_frames)] / len(full_frames) - - This ensures all sequences show increasing progress from near-zero, regardless - of where they're sampled from in the episode. - - Uses original ReWiND sampling: random start/end points with minimum 3 frames. - - Args: - video_feature: Video features tensor (num_frames, feature_dim) - max_length: Target sequence length - random_sample: If True, randomly sample frames. If False, uniformly sample consecutive frames. - remaining_length: Remaining trajectory length from first frame to episode end - - Returns: - Tuple of: - - Sampled/padded video features (max_length, feature_dim) - - Progress targets for each frame (max_length,) - """ - video_length = len(video_feature) - - # Original ReWiND sampling: random start/end with minimum 3 frames - if video_length > 3: - # Sample random start index (ensuring we can get at least 3 frames) - start_idx = random.randint(0, max(0, video_length - 3)) - # Sample random end index (at least 3 frames after start, up to video_length) - end_idx = random.randint(min(start_idx + 3, video_length), video_length) - - # Extract the sampled segment - video_feature = video_feature[start_idx:end_idx] - - # Update video_length for the sampled segment - video_length = len(video_feature) - - # Adjust remaining_length to be from start_idx to episode end - if remaining_length is not None: - # The remaining length should be from start_idx to episode end - # If we started at start_idx, we've already consumed start_idx frames - remaining_length = remaining_length - start_idx if remaining_length > start_idx else video_length - - # Generate progress targets using ORIGINAL ReWiND formula - # Progress = (position_in_sequence + 1) / remaining_trajectory_length - progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32) - progress_targets = progress_indices / remaining_length - - if video_length < max_length: - # Pad with last frame - padding_length = max_length - video_length - last_frame = video_feature[-1].unsqueeze(0) - padding_frames = last_frame.repeat(padding_length, 1) - video_feature = torch.cat([video_feature, padding_frames], dim=0) - - # Pad progress with last progress value - last_progress = progress_targets[-1] - padding_progress = torch.full((padding_length,), last_progress) - progress_targets = torch.cat([progress_targets, padding_progress]) - - elif video_length > max_length: - if random_sample: - # Random sampling (maintains temporal order via sorted indices) - frame_idx = sorted(random.sample(range(video_length), max_length)) - else: - # Uniform sampling (consecutive frames with even spacing) - frame_idx = np.linspace(0, 
video_length - 1, max_length, dtype=int) - video_feature = video_feature[frame_idx] - progress_targets = progress_targets[frame_idx] - - return video_feature, progress_targets - - -def sample_reverse_video_feature( - video_feature: torch.Tensor, - max_length: int = 32, - random_sample: bool = True, - remaining_length: int = None -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Sample video with reverse augmentation (video rewind) - ORIGINAL REWIND LOGIC. - - This implements the EXACT video rewind augmentation from the original ReWiND paper: - 1. Take forward sequence (sampled with random start/end, min 3 frames) - 2. Append reversed frames from the END backwards - 3. Progress increases then decreases (simulating task completion then failure) - - Progress normalization matches original ReWiND (same as sample_video_feature). - Original ReWiND logic (dataset.py lines 12526-12541): - progress = [1, 2, ..., len(video_frames)] / len(full_frames) - reverse_progress = progress[::-1][1:selected_end_point] - - Args: - video_feature: Video features tensor (num_frames, feature_dim) - max_length: Target sequence length - random_sample: If True, use random sampling for frame selection - remaining_length: Remaining trajectory length from first frame to episode end - - Returns: - Tuple of: - - Rewound video features (max_length, feature_dim) - - Progress targets for each frame (max_length,) - """ - video_length = len(video_feature) - - # Original logic: start from first half, end in second half, ensure min 3 frames - if video_length > 3: - # Sample start from first half - start_idx = random.randint(0, video_length // 2) - # Sample end from second half - end_idx = random.randint(video_length // 2, video_length) - - # Ensure minimum 3 frames difference (original uses while loop) - while end_idx - start_idx < 3: - start_idx = random.randint(0, video_length // 2) - end_idx = random.randint(video_length // 2, video_length) - - # Extract the forward segment - video_feature = video_feature[start_idx:end_idx] - video_length = len(video_feature) - - # Adjust remaining_length - if remaining_length is not None: - remaining_length = remaining_length - start_idx if remaining_length > start_idx else video_length - - # Generate forward progress targets using ORIGINAL ReWiND formula - # Progress = (position_in_sequence + 1) / remaining_trajectory_length - progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32) - forward_progress = progress_indices / remaining_length - - # ORIGINAL LOGIC: Reverse from END backwards, then append to forward sequence - # Example: video=[A,B,C,D,E] -> reversed=[E,D,C,B,A] -> take some from reversed (skip first) - # Result: [A,B,C,D,E] + [D,C,B] = progress increases then decreases - - # Randomly select how many frames to reverse and append - selected_end_point = random.randint(2, min(video_length, max_length)) - - # Reverse the entire video and its progress - reversed_video = video_feature.flip(0) - reversed_progress = forward_progress.flip(0) - - # Take frames from reversed (skip the first frame which is the last frame of original) - reverse_frames = reversed_video[1:selected_end_point] - reverse_progress = reversed_progress[1:selected_end_point] - - # Concatenate forward + reversed (creates rewind effect) - rewound_video = torch.cat([video_feature, reverse_frames], dim=0) - progress_targets = torch.cat([forward_progress, reverse_progress], dim=0) - - # Pad or sample to target length - if len(rewound_video) < max_length: - # Pad with last frame - padding_length = max_length 
- len(rewound_video) - last_frame = rewound_video[-1].unsqueeze(0) - padding_frames = last_frame.repeat(padding_length, 1) - rewound_video = torch.cat([rewound_video, padding_frames], dim=0) - - # Pad progress with last progress value - last_progress = progress_targets[-1] - padding_progress = torch.full((padding_length,), last_progress) - progress_targets = torch.cat([progress_targets, padding_progress]) - - elif len(rewound_video) > max_length: - # Sample frames to fit max_length - if random_sample: - frame_idx = sorted(random.sample(range(len(rewound_video)), max_length)) - else: - frame_idx = np.linspace(0, len(rewound_video) - 1, max_length, dtype=int) - rewound_video = rewound_video[frame_idx] - progress_targets = progress_targets[frame_idx] - - return rewound_video, progress_targets - - -def sample_sarm_video_feature( - video_feature: torch.Tensor, - num_frames: int = 9, - frame_gap: int = 30, - random_sample: bool = True, - absolute_indices: torch.Tensor = None, - episode_length: int = None -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Sample video features for SARM (Stage-Aware Reward Modeling). - - SARM uses a specific pattern: - - 1 initial frame (from episode start) - - 8 consecutive frames with frame_gap spacing - - Progress normalization matches SARM implementation: - - Progress = absolute_frame_index / total_episode_length - - Args: - video_feature: Video features tensor (num_frames_available, feature_dim) - num_frames: Target number of frames (default: 9) - frame_gap: Gap between consecutive frames (default: 30, i.e., 1 second at 30fps) - random_sample: If True, use random sampling (not used for SARM's fixed pattern) - absolute_indices: Absolute frame indices in the episode (num_frames_available,) - episode_length: Total length of the episode - - Returns: - Tuple of: - - Sampled video features (num_frames, feature_dim) - - Progress targets for each frame (num_frames,) - """ - video_length = len(video_feature) - - # Generate progress targets based on relative position within sampled sequence - # Note: SARM paper uses subtask annotations (Equation 2: yt = Pk−1 + ᾱk * τt) - # Without annotations, we use linear progress relative to sequence position - if absolute_indices is not None and episode_length is not None: - # Compute relative progress: position within sequence / remaining trajectory - # This ensures progress starts near 0 and increases, not starting at 0.8 if sampled from end - first_frame_idx = absolute_indices[0].item() if isinstance(absolute_indices[0], torch.Tensor) else absolute_indices[0] - remaining_length = episode_length - first_frame_idx - - # Progress = (position_in_sequence + 1) / remaining_trajectory_length - progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32) - progress_targets = progress_indices / remaining_length - else: - # Fallback: linear progress - progress_targets = torch.linspace(1.0/video_length, 1.0, video_length) - - # SARM pattern: first frame + (num_frames-1) consecutive frames with frame_gap - # The first frame should be from the beginning of the sequence - # The remaining frames are sampled with frame_gap spacing - - if video_length < num_frames: - # Not enough frames, pad with last frame - sampled_video = video_feature - sampled_progress = progress_targets - - padding_length = num_frames - video_length - last_frame = sampled_video[-1].unsqueeze(0) - padding_frames = last_frame.repeat(padding_length, 1) - sampled_video = torch.cat([sampled_video, padding_frames], dim=0) - - last_progress = sampled_progress[-1] - 
padding_progress = torch.full((padding_length,), last_progress) - sampled_progress = torch.cat([sampled_progress, padding_progress]) - - else: - # Sample frames: first frame + (num_frames-1) with frame_gap - # The indices should represent: [0, gap, 2*gap, 3*gap, ..., (num_frames-1)*gap] - # But we need to ensure we don't exceed video_length - - frame_indices = [0] # First frame - for i in range(1, num_frames): - idx = i * frame_gap - if idx >= video_length: - idx = video_length - 1 - frame_indices.append(idx) - - frame_indices = torch.tensor(frame_indices, dtype=torch.long) - sampled_video = video_feature[frame_indices] - sampled_progress = progress_targets[frame_indices] - - return sampled_video, sampled_progress - - -def sample_sarm_reverse_video_feature( - video_feature: torch.Tensor, - num_frames: int = 9, - frame_gap: int = 30, - random_sample: bool = True, - absolute_indices: torch.Tensor = None, - episode_length: int = None -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Sample video with reverse augmentation for SARM (rewind augmentation). - - Similar to ReWiND's rewind augmentation but adapted for SARM's frame pattern: - 1. Take forward sequence (1 initial + 8 consecutive) - 2. Append some reversed frames from the end backwards - 3. Progress increases then decreases - - Args: - video_feature: Video features tensor (num_frames_available, feature_dim) - num_frames: Target number of frames (default: 9) - frame_gap: Gap between consecutive frames (default: 30) - random_sample: If True, use random sampling for reverse section - absolute_indices: Absolute frame indices in the episode - episode_length: Total length of the episode - - Returns: - Tuple of: - - Rewound video features (num_frames, feature_dim) - - Progress targets for each frame (num_frames,) - """ - video_length = len(video_feature) - - # Generate forward progress targets (relative to sequence, not absolute) - if absolute_indices is not None and episode_length is not None: - # Use same relative progress as normal sampling - first_frame_idx = absolute_indices[0].item() if isinstance(absolute_indices[0], torch.Tensor) else absolute_indices[0] - remaining_length = episode_length - first_frame_idx - progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32) - forward_progress = progress_indices / remaining_length - else: - forward_progress = torch.linspace(1.0/video_length, 1.0, video_length) - - # Sample forward sequence first - forward_video, forward_progress_sampled = sample_sarm_video_feature( - video_feature, num_frames, frame_gap, random_sample, absolute_indices, episode_length - ) - - # Randomly select how many frames to reverse and append - # For SARM, we append 2-4 reversed frames - num_reverse = random.randint(2, min(4, num_frames - 1)) - - # Reverse the video and progress - reversed_video = video_feature.flip(0) - reversed_progress = forward_progress.flip(0) - - # Take frames from reversed (skip the first frame which is the last frame of original) - reverse_frames = reversed_video[1:num_reverse+1] - reverse_progress = reversed_progress[1:num_reverse+1] - - # Concatenate forward + reversed (creates rewind effect) - rewound_video = torch.cat([forward_video, reverse_frames], dim=0) - progress_targets = torch.cat([forward_progress_sampled, reverse_progress], dim=0) - - # Trim to num_frames if necessary - if len(rewound_video) > num_frames: - # Keep the first num_frames - rewound_video = rewound_video[:num_frames] - progress_targets = progress_targets[:num_frames] - elif len(rewound_video) < num_frames: - 
# Pad if necessary - padding_length = num_frames - len(rewound_video) - last_frame = rewound_video[-1].unsqueeze(0) - padding_frames = last_frame.repeat(padding_length, 1) - rewound_video = torch.cat([rewound_video, padding_frames], dim=0) - - last_progress = progress_targets[-1] - padding_progress = torch.full((padding_length,), last_progress) - progress_targets = torch.cat([progress_targets, padding_progress]) - - return rewound_video, progress_targets - diff --git a/src/lerobot/optim/schedulers.py b/src/lerobot/optim/schedulers.py index 10a3ee637..b5d54b396 100644 --- a/src/lerobot/optim/schedulers.py +++ b/src/lerobot/optim/schedulers.py @@ -132,41 +132,6 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig): return LambdaLR(optimizer, lr_lambda, -1) -@LRSchedulerConfig.register_subclass("cosine_with_min_lr") -@dataclass -class CosineWithMinLRSchedulerConfig(LRSchedulerConfig): - """Cosine learning rate scheduler with minimum learning rate floor. - - Used by ReWiND for reward model training. Includes linear warmup phase - followed by cosine annealing with a minimum learning rate. - """ - - num_warmup_steps: int - min_lr: float = 0.0 - - def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR: - def lr_lambda(current_step): - # Get base learning rate from optimizer - base_lr = optimizer.param_groups[0]['lr'] - - if current_step <= self.num_warmup_steps: - # Linear warmup - if self.num_warmup_steps == 0: - return 1.0 - return float(current_step) / float(max(1, self.num_warmup_steps)) - else: - # Cosine annealing with minimum learning rate - progress = (current_step - self.num_warmup_steps) / float( - max(1, num_training_steps - self.num_warmup_steps) - ) - cosine_factor = 0.5 * (1.0 + math.cos(math.pi * progress)) - # Scale between min_lr and base_lr - min_factor = self.min_lr / base_lr if base_lr > 0 else 0.0 - return min_factor + (1.0 - min_factor) * cosine_factor - - return LambdaLR(optimizer, lr_lambda, -1) - - def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None: state_dict = scheduler.state_dict() write_json(state_dict, save_dir / SCHEDULER_STATE) diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 32c028119..1ec651011 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -34,7 +34,6 @@ from lerobot.policies.groot.configuration_groot import GrootConfig from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pi05.configuration_pi05 import PI05Config from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.rewind.configuration_rewind import ReWiNDConfig from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sarm.configuration_sarm import SARMConfig from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig @@ -105,10 +104,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy return SmolVLAPolicy - elif name == "rewind": - from lerobot.policies.rewind.modeling_rewind import ReWiNDRewardModel - - return ReWiNDRewardModel elif name == "sarm": from lerobot.policies.sarm.modeling_sarm import SARMRewardModel @@ -332,15 +327,6 @@ def make_pre_post_processors( dataset_stats=kwargs.get("dataset_stats"), ) - elif isinstance(policy_cfg, ReWiNDConfig): - from lerobot.policies.rewind.processor_rewind import make_rewind_pre_post_processors - - processors = make_rewind_pre_post_processors( - 
config=policy_cfg, - dataset_stats=kwargs.get("dataset_stats"), - dataset_meta=kwargs.get("dataset_meta"), - ) - elif isinstance(policy_cfg, SARMConfig): from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors diff --git a/src/lerobot/policies/rewind/__init__.py b/src/lerobot/policies/rewind/__init__.py deleted file mode 100644 index a1aeff12a..000000000 --- a/src/lerobot/policies/rewind/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from lerobot.policies.rewind.configuration_rewind import ReWiNDConfig -from lerobot.policies.rewind.modeling_rewind import ( - ReWiNDRewardModel, - ReWiNDTransformer, -) -from lerobot.policies.rewind.processor_rewind import ( - ReWiNDEncodingProcessorStep, - make_rewind_pre_post_processors, -) - -__all__ = [ - "ReWiNDConfig", - "ReWiNDRewardModel", - "ReWiNDTransformer", - "ReWiNDEncodingProcessorStep", - "make_rewind_pre_post_processors", -] - diff --git a/src/lerobot/policies/rewind/configuration_rewind.py b/src/lerobot/policies/rewind/configuration_rewind.py deleted file mode 100644 index 9a9b61584..000000000 --- a/src/lerobot/policies/rewind/configuration_rewind.py +++ /dev/null @@ -1,138 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass, field - -from lerobot.configs.policies import PreTrainedConfig -from lerobot.optim.optimizers import AdamWConfig -from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig - - -@PreTrainedConfig.register_subclass("rewind") -@dataclass -class ReWiNDConfig(PreTrainedConfig): - """Configuration class for ReWiND Reward Model. - - ReWiND (Reward from Video and Natural language Descriptions) is a reward model - that computes task completion/progress rewards from video observations and - language task descriptions. - """ - - # Model architecture parameters - video_dim: int = 768 # DINO embedding dimension - text_dim: int = 384 # MiniLM embedding dimension - hidden_dim: int = 512 - num_heads: int = 8 - num_layers: int = 4 - - # Temporal parameters - max_length: int = 32 # Maximum video sequence length, ORIGINAL: 16! 
- subsample_video: bool = True # Whether to pad/subsample videos to max_length - use_temporal_sampler: bool = True # Always enable temporal sequence loading - sequence_stride: int = 1 # Stride between frames when using temporal sampler - rewind_ratio: float = 0.8 # Probability of applying rewind augmentation (original: 0.8) - - # Training parameters - batch_size: int = 64 - dino_batch_size: int = 64 - gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization - - # Model loading - pretrained_model_path: str | None = None - - # Device settings - device: str | None = None - - # Dropout - dropout: float = 0.1 # Dropout rate for transformer - - # Processor settings (for automatic preprocessing) - image_key: str = "observation.images.top" # Key for images in dataset - task_description: str = "perform the task" # Default task description (used if no task field in data) - encode_on_the_fly: bool = True # Encode images/text during training - use_dataset_task: bool = True # Use task descriptions from dataset (per-episode) - - # Features (required by PreTrainedPolicy) - input_features: dict = field(default_factory=lambda: { - "video_features": {"shape": [768], "dtype": "float32"}, - "text_features": {"shape": [384], "dtype": "float32"} - }) - output_features: dict = field(default_factory=lambda: { - "progress": {"shape": [1], "dtype": "float32"} - }) - - def __post_init__(self): - super().__post_init__() - - # Validate configuration - if self.hidden_dim % self.num_heads != 0: - raise ValueError( - f"hidden_dim ({self.hidden_dim}) must be divisible by num_heads ({self.num_heads})" - ) - - if self.max_length <= 0: - raise ValueError(f"max_length must be positive, got {self.max_length}") - - if self.dropout < 0 or self.dropout >= 1: - raise ValueError(f"dropout must be in [0, 1), got {self.dropout}") - - def get_optimizer_preset(self) -> AdamWConfig: - """Get default optimizer configuration for ReWiND training.""" - return AdamWConfig( - lr=1e-4, - weight_decay=1e-4, - betas=(0.9, 0.999), - eps=1e-8, - ) - - def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig: - """Get default learning rate scheduler configuration.""" - return CosineDecayWithWarmupSchedulerConfig( - peak_lr=1e-4, - decay_lr=1e-5, - num_warmup_steps=1000, - num_decay_steps=100000, - ) - - def validate_features(self) -> None: - pass - - @property - def observation_delta_indices(self) -> list[int]: - """Load all frames from episode start up to current frame. - - The sampler yields a random end point in each episode. - This property tells the dataset to load all frames from -(end_idx - start_idx) to 0. - - Since we don't know the exact window size in advance, we load up to max_length frames. - The dataset will automatically clamp to episode boundaries. - - Returns: - Indices for loading history: [-31, -30, ..., -1, 0] for max_length=32 - """ - # Load the last max_length frames (or up to episode start) - return list(range(-(self.max_length - 1), 1)) - - @property - def action_delta_indices(self) -> None: - """ReWiND is a reward model, not an action policy.""" - return None - - @property - def reward_delta_indices(self) -> None: - """ReWiND doesn't use delta rewards.""" - return None - diff --git a/src/lerobot/policies/rewind/modeling_rewind.py b/src/lerobot/policies/rewind/modeling_rewind.py deleted file mode 100644 index 81576770a..000000000 --- a/src/lerobot/policies/rewind/modeling_rewind.py +++ /dev/null @@ -1,711 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. 
team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import List, Union, Dict, Optional -import random - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from PIL import Image -from transformers import AutoModel, AutoTokenizer -import torchvision.transforms as T -from torch import Tensor - -from lerobot.policies.rewind.configuration_rewind import ReWiNDConfig -from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.datasets.video_sampler import sample_video_feature, sample_reverse_video_feature - - -# Helper functions for encoding -def dino_load_image(img: np.ndarray) -> torch.Tensor: - """ - Load an image and return a tensor that can be used as an input to DINOv2. - - Args: - img: Input image as numpy array (H, W, C) in uint8 format. - - Returns: - Transformed image tensor ready for DINO encoder (1, 3, 224, 224). - """ - # Define transform: center crop to 224x224, normalize to [-1, 1] - dino_transform = T.Compose([ - T.ToTensor(), - T.CenterCrop(224), - T.Normalize([0.5], [0.5]) - ]) - - img_pil = Image.fromarray(img) - transformed_img = dino_transform(img_pil)[:3].unsqueeze(0) - - return transformed_img - - -def mean_pooling(model_output, attention_mask): - """ - Mean pooling - take attention mask into account for correct averaging. - - Args: - model_output: Model output containing token embeddings. - attention_mask: Attention mask for the tokens. - - Returns: - Mean-pooled embeddings. - """ - token_embeddings = model_output[0] # First element contains all token embeddings - input_mask_expanded = ( - attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() - ) - return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( - input_mask_expanded.sum(1), min=1e-9 - ) - - -class ReWiNDTransformer(nn.Module): - """ - ReWiND Transformer model for predicting task progress from video and text. - - This model takes video frame embeddings and text embeddings as input, - and predicts a progress score (0-1) for each frame indicating how much - of the task has been completed. 
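# A minimal usage sketch of the mean_pooling helper above, assuming only the
# MiniLM checkpoint that this file already loads elsewhere:
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L12-v2")
minilm = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L12-v2")
encoded = tokenizer(["insert the peg into the socket"], padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    output = minilm(**encoded)
embedding = mean_pooling(output, encoded["attention_mask"])
print(embedding.shape)  # torch.Size([1, 384])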
- """ - - def __init__( - self, - video_dim: int = 768, - text_dim: int = 384, - hidden_dim: int = 512, - num_heads: int = 8, - num_layers: int = 4, - max_length: int = 32, - dropout: float = 0.1 - ): - super().__init__() - self.hidden_dim = hidden_dim - self.max_length = max_length - - # Project video and text to common dimension - self.video_proj = nn.Linear(video_dim, hidden_dim) - self.text_proj = nn.Linear(text_dim, hidden_dim) - - # Position embeddings for video sequence - # We only add positional embedding to the first frame as in the original - self.first_pos_embed = nn.Parameter(torch.randn(1, hidden_dim)) - - # Transformer encoder - encoder_layer = nn.TransformerEncoderLayer( - d_model=hidden_dim, - nhead=num_heads, - dim_feedforward=hidden_dim * 4, - dropout=dropout, - batch_first=True - ) - self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) - - # Progress prediction head (applied to each frame) - self.progress_head = nn.Sequential( - nn.Linear(hidden_dim, hidden_dim // 2), - nn.LayerNorm(hidden_dim // 2), - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(hidden_dim // 2, 1), - nn.Sigmoid() - ) - - # Attention mask for causal self-attention - # Will be created on-demand based on sequence length - self.register_buffer("attention_mask", None, persistent=False) - - def _get_attention_mask(self, seq_length: int, device: torch.device) -> torch.Tensor: - """Generate or retrieve cached causal attention mask.""" - if self.attention_mask is None or self.attention_mask.shape[0] != seq_length: - # Create causal mask (upper triangular with -inf) - mask = nn.Transformer.generate_square_subsequent_mask(seq_length, device=device) - self.attention_mask = mask - return self.attention_mask - - def forward(self, video_frames: torch.Tensor, text_embed: torch.Tensor) -> torch.Tensor: - """ - Forward pass through the ReWiND transformer. - - Args: - video_frames: Video frame embeddings (batch_size, seq_len, video_dim) - text_embed: Text embeddings (batch_size, text_dim) - - Returns: - Progress predictions for each frame (batch_size, seq_len, 1) - """ - batch_size = video_frames.shape[0] - - # Project inputs to common dimension - video_embed = self.video_proj(video_frames) # [batch_size, seq_len, hidden_dim] - text_embed = self.text_proj(text_embed).unsqueeze(1) # [batch_size, 1, hidden_dim] - - # Add positional embedding to first video frame - video_embed[:, 0] += self.first_pos_embed - - # Combine sequence: [text, video_frames] - sequence = torch.cat([text_embed, video_embed], dim=1) - - # Get causal attention mask - seq_length = sequence.shape[1] - attention_mask = self._get_attention_mask(seq_length, sequence.device) - - # Pass through transformer with causal masking - transformed = self.transformer(sequence, mask=attention_mask, is_causal=True) - - # Get progress predictions for each frame (exclude text token) - progress_preds = self.progress_head(transformed[:, 1:]) - - return progress_preds - - -class ReWiNDRewardModel(PreTrainedPolicy): - """ - ReWiND Reward Model for computing task completion rewards from video and text. 
- - This model combines: - - DINO (DINOv2) for encoding video frames - - MiniLM for encoding text descriptions - - ReWiNDTransformer for predicting task progress - """ - - name = "rewind" - config_class = ReWiNDConfig - - def __init__(self, config: ReWiNDConfig, dataset_stats: dict | None = None): - super().__init__(config, dataset_stats) - self.config = config - self.dataset_stats = dataset_stats - self.device = torch.device(config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu") - - # Initialize DINO encoder for images - logging.info("Loading DINO encoder...") - self.dino_encoder = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14") - self.dino_encoder.to(self.device) - self.dino_encoder.eval() - - # Initialize MiniLM encoder for text - logging.info("Loading MiniLM encoder...") - self.minilm_tokenizer = AutoTokenizer.from_pretrained( - "sentence-transformers/all-MiniLM-L12-v2" - ) - self.minilm_model = AutoModel.from_pretrained( - "sentence-transformers/all-MiniLM-L12-v2" - ) - self.minilm_model.to(self.device) - self.minilm_model.eval() - - # Initialize ReWiND transformer with explicit architecture parameters - self.rewind_transformer = ReWiNDTransformer( - video_dim=config.video_dim, - text_dim=config.text_dim, - hidden_dim=config.hidden_dim, - num_heads=config.num_heads, - num_layers=config.num_layers, - max_length=config.max_length, - dropout=config.dropout - ) - self.rewind_transformer.to(self.device) - - logging.info(f"ReWiND Reward Model initialized on {self.device}") - - def to(self, device): - """Override to method to ensure all components move together.""" - super().to(device) - self.device = device if isinstance(device, torch.device) else torch.device(device) - self.dino_encoder.to(device) - self.minilm_model.to(device) - self.rewind_transformer.to(device) - return self - - @torch.no_grad() - def encode_images(self, images: np.ndarray) -> np.ndarray: - """ - Encode video frames using DINO. - - Args: - images: Video frames with shape (num_videos, num_frames, H, W, C) in uint8. - Can also be (num_frames, H, W, C) for a single video. - - Returns: - Encoded image features (num_videos, num_frames, 768) or (num_frames, 768). - """ - # Handle single video case - single_video = False - if len(images.shape) == 4: - images = images[np.newaxis, ...] 
- single_video = True - - assert len(images.shape) == 5, f"Expected 5D input (num_videos, num_frames, H, W, C), got {images.shape}" - - # Ensure channels are in correct position - if images.shape[-1] == 3 and images.shape[2] != 3: - images = np.transpose(images, (0, 1, 4, 2, 3)) - - all_embeddings = [] - - for video in images: - # Process each video - video_embeddings = [] - - # Convert frames to list of numpy arrays - frames = [frame.transpose(1, 2, 0).astype(np.uint8) if frame.shape[0] == 3 else frame for frame in video] - - # Batch process frames with DINO - episode_images_dino = [dino_load_image(frame) for frame in frames] - - # Process in batches - for i in range(0, len(episode_images_dino), self.config.dino_batch_size): - batch = torch.cat(episode_images_dino[i:i + self.config.dino_batch_size]) - batch = batch.to(self.device) - embeddings = self.dino_encoder(batch).squeeze().detach().cpu() - - # Handle single frame case - if embeddings.dim() == 1: - embeddings = embeddings.unsqueeze(0) - - video_embeddings.append(embeddings) - - video_embeddings = torch.cat(video_embeddings) - all_embeddings.append(video_embeddings) - - result = torch.stack(all_embeddings).numpy() - - if single_video: - result = result[0] - - return result - - @torch.no_grad() - def encode_text(self, text: Union[str, List[str]]) -> np.ndarray: - """ - Encode text using MiniLM. - - Args: - text: Text string or list of text strings. - - Returns: - Encoded text features (batch_size, 384) or (384,) for single text. - """ - if isinstance(text, str): - text = [text] - single_text = True - else: - single_text = False - - # Process in batches - all_embeddings = [] - for i in range(0, len(text), self.config.batch_size): - batch_text = text[i:i + self.config.batch_size] - - encoded_input = self.minilm_tokenizer( - batch_text, padding=True, truncation=True, return_tensors="pt" - ).to(self.device) - - model_output = self.minilm_model(**encoded_input) - text_embeddings = mean_pooling(model_output, encoded_input["attention_mask"]) - - all_embeddings.append(text_embeddings.cpu()) - - result = torch.cat(all_embeddings).numpy() - - if single_text: - result = result[0] - - return result - - def padding_video(self, video_frames: torch.Tensor, max_length: int) -> torch.Tensor: - """ - Pad or subsample video frames to a fixed length. - - Args: - video_frames: Video frames tensor (num_frames, embedding_dim) - max_length: Target sequence length - - Returns: - Padded/subsampled video frames (max_length, embedding_dim) - """ - video_length = len(video_frames) - - if isinstance(video_frames, np.ndarray): - video_frames = torch.tensor(video_frames) - - if video_length < max_length: - # Pad with last frame - padding_length = max_length - video_length - last_frame = video_frames[-1].unsqueeze(0) - padding_frames = last_frame.repeat(padding_length, 1) - video_frames = torch.cat([video_frames, padding_frames], dim=0) - - elif video_length > max_length: - # Subsample uniformly - frame_idx = np.linspace(0, video_length - 1, max_length).astype(int) - video_frames = video_frames[frame_idx] - - return video_frames - - @torch.no_grad() - def calculate_rewards( - self, - text_embeddings: Union[np.ndarray, torch.Tensor], - video_embeddings: Union[np.ndarray, torch.Tensor], - return_all_frames: bool = False - ) -> np.ndarray: - """ - Calculate rewards for given text and video representations. 
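# Usage sketch for the encoding and padding helpers above, with hypothetical
# inputs; `model` stands in for a ReWiNDRewardModel instance.
import numpy as np

frames = np.random.randint(0, 255, size=(50, 224, 224, 3), dtype=np.uint8)  # one 50-frame video
video_feats = model.encode_images(frames)          # (50, 768) DINO features
text_feats = model.encode_text("insert the peg")   # (384,) MiniLM features

# padding_video pads short clips with the last frame and uniformly subsamples
# long ones, so either way the result has exactly max_length frames:
padded = model.padding_video(video_feats, 32)      # (32, 768)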
- - Args: - text_embeddings: Encoded text representations (batch_size, 384) - video_embeddings: Encoded video representations (batch_size, num_frames, 768) - return_all_frames: If True, return rewards for all frames. If False, return only last frame. - - Returns: - Reward values (batch_size,) or (batch_size, num_frames) if return_all_frames=True - """ - # Convert to tensors if needed - if isinstance(text_embeddings, np.ndarray): - text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32) - if isinstance(video_embeddings, np.ndarray): - video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32) - - # Handle single sample case - if text_embeddings.dim() == 1: - text_embeddings = text_embeddings.unsqueeze(0) - video_embeddings = video_embeddings.unsqueeze(0) - single_sample = True - else: - single_sample = False - - # Process in batches - all_rewards = [] - for i in range(0, len(video_embeddings), self.config.batch_size): - batch_texts = text_embeddings[i:i + self.config.batch_size].to(self.device) - batch_videos = video_embeddings[i:i + self.config.batch_size].to(self.device) - - # Pad/subsample videos if needed - if self.config.subsample_video: - padded_videos = [] - for video in batch_videos: - padded_video = self.padding_video(video, self.config.max_length) - padded_videos.append(padded_video) - batch_videos = torch.stack(padded_videos).to(self.device) - - # Get progress predictions - rewards = self.rewind_transformer(batch_videos.float(), batch_texts.float()) - - if return_all_frames: - all_rewards.append(rewards.squeeze(-1).cpu()) - else: - # Return only last frame reward - all_rewards.append(rewards[:, -1, 0].cpu()) - - result = torch.cat(all_rewards).numpy() - - if single_sample: - result = result[0] if not return_all_frames else result[0] - - return result - - def load_pretrained_checkpoint(self, checkpoint_path: str, strict: bool = False): - """ - Load pretrained model weights from a checkpoint file. - - Args: - checkpoint_path: Path to the .pth checkpoint file - strict: Whether to strictly enforce that the keys match - """ - logging.info(f"Loading pretrained checkpoint from {checkpoint_path}") - - checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False) - - # Handle different checkpoint formats - if "model_state_dict" in checkpoint: - state_dict = checkpoint["model_state_dict"] - - # Check for architecture parameters in checkpoint - if "args" in checkpoint: - args = checkpoint["args"] - logging.info(f"Checkpoint was trained with: max_length={args.max_length}") - - # Warn if max_length differs - if hasattr(args, 'max_length') and args.max_length != self.config.max_length: - logging.warning( - f"Checkpoint max_length ({args.max_length}) differs from config ({self.config.max_length}). " - "This may cause issues if sequence lengths don't match." - ) - else: - state_dict = checkpoint - - # Load only the ReWiNDTransformer weights - missing_keys, unexpected_keys = self.rewind_transformer.load_state_dict(state_dict, strict=strict) - - if missing_keys: - logging.warning(f"Missing keys when loading checkpoint: {missing_keys}") - if unexpected_keys: - logging.warning(f"Unexpected keys when loading checkpoint: {unexpected_keys}") - - logging.info("Checkpoint loaded successfully") - - def train(self, mode: bool = True): - """Set training mode. 
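# Usage sketch for calculate_rewards above (hypothetical pre-encoded features;
# `model` is a ReWiNDRewardModel instance):
import numpy as np

text_emb = np.random.randn(384).astype(np.float32)       # e.g. from encode_text
video_emb = np.random.randn(50, 768).astype(np.float32)  # e.g. from encode_images
final_reward = model.calculate_rewards(text_emb, video_emb)  # last-frame progress only
progress_curve = model.calculate_rewards(text_emb, video_emb, return_all_frames=True)  # (max_length,)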
Note: DINO and MiniLM encoders always stay in eval mode.""" - super().train(mode) - # Keep encoders in eval mode - self.dino_encoder.eval() - self.minilm_model.eval() - # Only transformer can be trained - self.rewind_transformer.train(mode) - return self - - def eval(self): - """Set evaluation mode.""" - return self.train(False) - - def parameters(self): - """Return trainable parameters (only ReWiND transformer, not encoders).""" - return self.rewind_transformer.parameters() - - def get_optim_params(self): - """Return optimizer parameters for the policy.""" - return self.parameters() - - def reset(self): - """ - This method is required by PreTrainedPolicy but not used for reward models. - The reward model does not maintain state between episodes. - """ - pass - - def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: - """ - This method is required by PreTrainedPolicy but not used for reward models. - The rewind model is not an actor and does not produce action chunks. - """ - raise NotImplementedError("Rewind model does not predict action chunks") - - def select_action(self, batch: dict[str, Tensor]) -> Tensor: - """ - This method is required by PreTrainedPolicy but not used for rewind. - The rewind model is not an actor and does not select actions. - """ - raise NotImplementedError("Rewind model does not select actions") - - def forward(self, batch): - """ - Forward pass compatible with lerobot training pipeline. - - Args: - batch: Dictionary containing observation with: - - 'video_features': Pre-encoded video features (B, 768) or (B, T, 768) - - 'text_features': Pre-encoded text features (B, 384) - - Returns: - loss: Total training loss - output_dict: Dictionary of loss components for logging - """ - # Extract from observation dict - observation = batch.get('observation', batch) - video_features = observation['video_features'].to(self.device) - text_features = observation['text_features'].to(self.device) - - batch_size = video_features.shape[0] - max_length = self.config.max_length - - # Handle both single frames (B, 768) and sequences (B, T, 768) - if video_features.dim() == 2: - # Single frames: replicate to create pseudo-sequences - video_features = video_features.unsqueeze(1).repeat(1, max_length, 1) # (B, max_length, 768) - - # Now video_features is (B, T, 768) where T might be > max_length - - # Process videos (with potential rewind augmentation) - import random - from lerobot.datasets.video_sampler import sample_video_feature, sample_reverse_video_feature - - processed_videos = [] - progress_targets = [] - - # Extract episode metadata for correct progress normalization - absolute_frame_indices = observation.get('absolute_frame_indices', None) - episode_lengths = observation.get('episode_length', None) - remaining_lengths = observation.get('remaining_length', None) - - for i in range(batch_size): - # Get metadata for this sample - current_absolute_indices = None - current_episode_length = None - current_remaining_length = None - - if absolute_frame_indices is not None: - if isinstance(absolute_frame_indices, list): - current_absolute_indices = absolute_frame_indices[i] - else: - current_absolute_indices = absolute_frame_indices - - if episode_lengths is not None: - if isinstance(episode_lengths, torch.Tensor) and episode_lengths.dim() > 0: - current_episode_length = episode_lengths[i].item() - else: - current_episode_length = episode_lengths.item() if isinstance(episode_lengths, torch.Tensor) else episode_lengths - - if remaining_lengths is not None: - if 
isinstance(remaining_lengths, torch.Tensor) and remaining_lengths.dim() > 0: - current_remaining_length = remaining_lengths[i].item() - else: - current_remaining_length = remaining_lengths.item() if isinstance(remaining_lengths, torch.Tensor) else remaining_lengths - - if random.random() < self.config.rewind_ratio: # Use configurable rewind ratio - # Apply video rewind augmentation (now returns tuple) - rewound_video, progress = sample_reverse_video_feature( - video_features[i], - max_length=max_length, - random_sample=True, # Use random sampling (original ReWiND) - remaining_length=current_remaining_length - ) - processed_videos.append(rewound_video.to(self.device)) - progress_targets.append(progress.to(self.device)) - else: - # Normal video sampling (now returns tuple with progress targets) - sampled_video, progress = sample_video_feature( - video_features[i], - max_length=max_length, - random_sample=True, # Use random sampling (original ReWiND) - remaining_length=current_remaining_length - ) - processed_videos.append(sampled_video.to(self.device)) - progress_targets.append(progress.to(self.device)) - - processed_videos = torch.stack(processed_videos) - progress_targets = torch.stack(progress_targets) - - # Compute progress loss - progress_loss = compute_progress_loss( - self.rewind_transformer, - processed_videos, - text_features, - progress_targets - ) - - total_loss = progress_loss - output_dict = {'progress_loss': progress_loss.item()} - - # Compute misaligned loss if requested (20% probability to match original) - if random.random() < 0.2: # 20% chance of adding misalignment loss (original ReWiND uses 20%) - if 'misaligned_video_features' in batch and 'misaligned_text_features' in batch: - misaligned_videos = batch['misaligned_video_features'].to(self.device) - misaligned_texts = batch['misaligned_text_features'].to(self.device) - else: - # Create misaligned pairs by shuffling - shuffle_idx = torch.randperm(batch_size) - misaligned_videos = processed_videos[shuffle_idx] - misaligned_texts = text_features - - # Sample misaligned videos (function now returns tuple) - # For misaligned pairs, we don't need correct progress targets (will be set to 0) - misaligned_videos_sampled = [] - for i in range(batch_size): - # For misaligned videos, use video length as remaining_length - video_len = len(misaligned_videos[i]) - sampled, _ = sample_video_feature( - misaligned_videos[i], - max_length=max_length, - random_sample=True, - remaining_length=video_len # Use video length for misaligned pairs - ) - misaligned_videos_sampled.append(sampled.to(self.device)) - misaligned_videos_sampled = torch.stack(misaligned_videos_sampled) - - misaligned_loss = compute_misaligned_loss( - self.rewind_transformer, - misaligned_videos_sampled, - misaligned_texts - ) - - total_loss = total_loss + misaligned_loss - output_dict['misaligned_loss'] = misaligned_loss.item() - - output_dict['total_loss'] = total_loss.item() - - return total_loss, output_dict - - -# Loss utilities -def compute_progress_loss( - model: ReWiNDTransformer, - video_features: torch.Tensor, - text_features: torch.Tensor, - target_progress: Optional[torch.Tensor] = None -) -> torch.Tensor: - """ - Compute progress prediction loss. - - Args: - model: ReWiNDTransformer model - video_features: Batch of video features (batch_size, max_length, feature_dim) - text_features: Batch of text features (batch_size, text_dim) - target_progress: Optional target progress values (batch_size, max_length). - If None, uses linear progress from 0 to 1. 
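# Intuition sketch for the rewind branch above. The sample_reverse_video_feature
# helper lives in the deleted video_sampler module and is not shown in this patch,
# so this only illustrates the intended targets: progress rises while the clip
# plays forward, then falls over the rewound frames.
import torch

feats = torch.randn(16, 768)                      # hypothetical frame features
clip = torch.cat([feats, feats.flip(0)[1:5]])     # forward clip plus 4 rewound frames
progress = torch.cat([
    torch.linspace(1 / 16, 1.0, 16),              # increasing progress going forward
    torch.linspace(1.0, 1 / 16, 16)[1:5],         # decreasing progress while rewinding
])
assert clip.shape[0] == progress.shape[0] == 20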
- - Returns: - Mean squared error loss - """ - # Get predictions - progress_preds = model(video_features, text_features) - - # Create target progress if not provided - if target_progress is None: - batch_size, max_length = video_features.shape[:2] - target_progress = torch.linspace(0, 1, max_length, device=video_features.device) - target_progress = target_progress.unsqueeze(0).repeat(batch_size, 1) - - # Ensure target has correct shape - if target_progress.dim() == 2: - target_progress = target_progress.unsqueeze(-1) - - # Compute MSE loss - loss = F.mse_loss(progress_preds, target_progress) - - return loss - - -def compute_misaligned_loss( - model: ReWiNDTransformer, - video_features: torch.Tensor, - misaligned_text_features: torch.Tensor -) -> torch.Tensor: - """ - Compute loss for misaligned video-text pairs (should predict 0 progress). - - Args: - model: ReWiNDTransformer model - video_features: Batch of video features (batch_size, max_length, feature_dim) - misaligned_text_features: Batch of misaligned text features (batch_size, text_dim) - - Returns: - Mean squared error loss (predictions should be close to 0) - """ - # Get predictions - progress_preds = model(video_features, misaligned_text_features) - - # Target is all zeros - target_zeros = torch.zeros_like(progress_preds) - - # Compute MSE loss - loss = F.mse_loss(progress_preds, target_zeros) - - return loss diff --git a/src/lerobot/policies/rewind/processor_rewind.py b/src/lerobot/policies/rewind/processor_rewind.py deleted file mode 100644 index d8f3b331c..000000000 --- a/src/lerobot/policies/rewind/processor_rewind.py +++ /dev/null @@ -1,405 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import Dict, Any, List, Optional -import numpy as np -import torch - -from lerobot.policies.rewind.configuration_rewind import ReWiNDConfig -from lerobot.processor import ( - ProcessorStep, - PolicyProcessorPipeline, - PolicyAction, - DeviceProcessorStep, - AddBatchDimensionProcessorStep, -) -from lerobot.processor.converters import ( - policy_action_to_transition, - transition_to_policy_action, -) -from lerobot.processor.pipeline import PipelineFeatureType -from lerobot.processor.core import EnvTransition, TransitionKey -from lerobot.configs.types import PolicyFeature, FeatureType -from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME - -class ReWiNDEncodingProcessorStep(ProcessorStep): - """ - ProcessorStep that encodes images and text for ReWiND training. - - This step handles the DINO (image) and MiniLM (text) encoding that ReWiND needs. - - Supports both single-frame and temporal sequence encoding: - - Single frame: (B, C, H, W) → (B, 768) video features - - Temporal sequence: (B, T, C, H, W) → (B, T, 768) video features - - To use temporal sequences, configure the dataset with delta_timestamps for your image key. 
- For example, to encode sequences of 32 frames: - delta_timestamps = { - "observation.images.top": [i / fps for i in range(-15, 17)] # 32 frames centered on current - } - """ - - def __init__( - self, - config: ReWiNDConfig, - image_key: str | None = None, - task_description: str | None = None, - dataset_meta = None, - ): - super().__init__() - self.config = config - self.image_key = image_key or config.image_key - self.task_description = task_description or config.task_description - self.dataset_meta = dataset_meta # Store dataset metadata for episode info - - # Initialize encoders - self._init_encoders() - - def _init_encoders(self): - """Initialize DINO and MiniLM encoders.""" - from transformers import AutoModel, AutoTokenizer - - device = torch.device( - self.config.device if self.config.device - else "cuda" if torch.cuda.is_available() else "cpu" - ) - - logging.info("Initializing DINO encoder for ReWiND...") - self.dino_encoder = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14") - self.dino_encoder.to(device) - self.dino_encoder.eval() - - logging.info("Initializing MiniLM encoder for ReWiND...") - self.minilm_tokenizer = AutoTokenizer.from_pretrained( - "sentence-transformers/all-MiniLM-L12-v2" - ) - self.minilm_model = AutoModel.from_pretrained( - "sentence-transformers/all-MiniLM-L12-v2" - ) - self.minilm_model.to(device) - self.minilm_model.eval() - - self.device = device - - def __call__(self, transition: EnvTransition) -> EnvTransition: - """Encode images and text in the transition.""" - self._current_transition = transition.copy() if hasattr(transition, 'copy') else dict(transition) - new_transition = self._current_transition - - observation = new_transition.get(TransitionKey.OBSERVATION) - if observation is None or not isinstance(observation, dict): - # If no observation, just return the transition as-is - return new_transition - - # Extract images from observation and encode - # For ReWiND, we need to load the sequence from episode start to current frame - batch_size = 1 - if self.image_key in observation: - image = observation[self.image_key] - - # Handle different image formats - if isinstance(image, torch.Tensor): - image = image.cpu().numpy() - - # Check if we have temporal sequences or single frames - # Temporal sampling: Load from episode start to current frame - # This will be handled by the dataset if configured with delta_timestamps - # Otherwise, we just encode the single frame - video_features = self._encode_images_batch(image) - observation['video_features'] = video_features - - # Get batch size from the encoded features - batch_size = video_features.shape[0] - - # Get task descriptions - check if 'task' field exists in the transition - # This allows per-episode task descriptions (e.g., for datasets with multiple tasks) - task_descriptions = None - if 'task' in new_transition: - task_descriptions = new_transition['task'] - # Convert to list if it's a single string - if isinstance(task_descriptions, str): - task_descriptions = [task_descriptions] * batch_size - - # Encode text - if task_descriptions is not None: - # Encode per-sample task descriptions - text_features = self._encode_text_batch_list(task_descriptions) - else: - # Fall back to config task description if no task field in transition - text_features = self._encode_text_batch(self.task_description, batch_size) - - observation['text_features'] = text_features - - # Compute episode metadata for progress normalization (used by ReWiND) - # We need to pass absolute frame indices and total episode 
length for correct progress calculation - if self.dataset_meta is not None and 'episode_index' in new_transition and 'index' in new_transition: - episode_indices = new_transition['episode_index'] - frame_indices = new_transition['index'] - - # Handle both single samples and batches - if isinstance(episode_indices, (int, np.integer)): - ep_idx = int(episode_indices) - frame_idx = int(frame_indices) - ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] - ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] - episode_length = ep_end - ep_start - - # For temporal sequences with observation_delta_indices: - # If we loaded frames using delta_indices (e.g., [-31, -30, ..., 0]), - # we need to compute the absolute indices of those frames - # The current frame is at frame_idx, and we loaded max_length frames before it - if 'video_features' in observation and len(observation['video_features'].shape) > 1: - # We have a temporal sequence - num_loaded_frames = observation['video_features'].shape[0] if observation['video_features'].dim() == 2 else observation['video_features'].shape[1] - # Absolute indices: from (frame_idx - num_frames + 1) to frame_idx - start_idx = max(ep_start, frame_idx - num_loaded_frames + 1) - absolute_indices = torch.arange(start_idx, frame_idx + 1) - observation['absolute_frame_indices'] = absolute_indices - # Compute remaining length: from first loaded frame to episode end - observation['remaining_length'] = ep_end - start_idx - else: - # Single frame - observation['absolute_frame_indices'] = torch.tensor([frame_idx]) - # Remaining length from this frame to episode end - observation['remaining_length'] = ep_end - frame_idx - - observation['episode_length'] = episode_length - else: - # Batch case - absolute_indices_list = [] - episode_lengths = [] - remaining_lengths = [] - for ep_idx, frame_idx in zip(episode_indices, frame_indices): - ep_idx = int(ep_idx.item() if hasattr(ep_idx, 'item') else ep_idx) - frame_idx = int(frame_idx.item() if hasattr(frame_idx, 'item') else frame_idx) - ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] - ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] - episode_length = ep_end - ep_start - episode_lengths.append(episode_length) - - # Compute absolute indices for this sample - if 'video_features' in observation and len(observation['video_features'].shape) > 1: - num_loaded_frames = observation['video_features'].shape[1] - start_idx = max(ep_start, frame_idx - num_loaded_frames + 1) - absolute_indices = torch.arange(start_idx, frame_idx + 1) - absolute_indices_list.append(absolute_indices) - # Remaining length from first loaded frame to episode end - remaining_lengths.append(ep_end - start_idx) - else: - absolute_indices_list.append(torch.tensor([frame_idx])) - # Remaining length from this frame to episode end - remaining_lengths.append(ep_end - frame_idx) - - observation['absolute_frame_indices'] = absolute_indices_list - observation['episode_length'] = torch.tensor(episode_lengths) - observation['remaining_length'] = torch.tensor(remaining_lengths) - - new_transition[TransitionKey.OBSERVATION] = observation - return new_transition - - @torch.no_grad() - def _encode_images_batch(self, images: np.ndarray) -> torch.Tensor: - """Encode a batch of images (with optional temporal dimension) using DINO. 
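# Worked example of the bookkeeping above, with hypothetical numbers: an episode
# spanning dataset rows [100, 250), current frame 130, and a 32-frame window
# requested via observation_delta_indices; the window clamps at the episode start.
import torch

ep_start, ep_end = 100, 250                    # dataset_from_index / dataset_to_index
frame_idx, num_loaded_frames = 130, 32
start_idx = max(ep_start, frame_idx - num_loaded_frames + 1)      # 100, not 99
absolute_frame_indices = torch.arange(start_idx, frame_idx + 1)   # 31 usable in-episode frames
episode_length = ep_end - ep_start                                # 150
remaining_length = ep_end - start_idx                             # 150 here, since start was clamped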
- - Args: - images: Batched images with shape: - - (B, C, H, W) for single frames, or - - (B, T, C, H, W) for temporal sequences - - Returns: - Encoded feature vectors with shape (B, 768) or (B, T, 768) - """ - from lerobot.policies.rewind.modeling_rewind import dino_load_image - - # Check if we have temporal dimension - has_temporal = len(images.shape) == 5 - - if has_temporal: - # Shape: (B, T, C, H, W) - batch_size, seq_length = images.shape[0], images.shape[1] - - # Reshape to (B*T, C, H, W) to process all frames at once - images = images.reshape(batch_size * seq_length, *images.shape[2:]) - elif len(images.shape) == 4: - # Shape: (B, C, H, W) - batch_size = images.shape[0] - seq_length = 1 - else: - raise ValueError(f"Expected 4D (B, C, H, W) or 5D (B, T, C, H, W) input, got shape {images.shape}") - - # Convert to list of (H, W, C) images - num_frames = images.shape[0] - if images.shape[1] in [1, 3]: # Channel first (N, C, H, W) - images_list = [images[i].transpose(1, 2, 0) for i in range(num_frames)] - else: # Channel last (N, H, W, C) - images_list = [images[i] for i in range(num_frames)] - - # Encode each frame (can batch process with DINO for efficiency) - all_embeddings = [] - for i in range(0, num_frames, self.config.dino_batch_size): - batch_imgs = images_list[i:i + self.config.dino_batch_size] - - # Prepare images for DINO - dino_inputs = [] - for img in batch_imgs: - # Handle single channel - if img.shape[-1] == 1: - img = np.repeat(img, 3, axis=-1) - - # Convert to uint8 - if img.dtype != np.uint8: - img = (img * 255).astype(np.uint8) if img.max() <= 1.0 else img.astype(np.uint8) - - dino_inputs.append(dino_load_image(img)) - - # Batch encode - dino_batch = torch.cat(dino_inputs).to(self.device) - embeddings = self.dino_encoder(dino_batch).detach().cpu() - - # Handle single frame case - if embeddings.dim() == 1: - embeddings = embeddings.unsqueeze(0) - - all_embeddings.append(embeddings) - - # Concatenate all embeddings - all_embeddings = torch.cat(all_embeddings) # (B*T, 768) - - # Reshape back if temporal - if has_temporal: - all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1) # (B, T, 768) - - return all_embeddings - - @torch.no_grad() - def _encode_text_batch(self, text: str, batch_size: int) -> torch.Tensor: - """Encode a text string using MiniLM and replicate for batch. - - Args: - text: Text string to encode - batch_size: Batch size to replicate for - - Returns: - Encoded feature vectors with shape (B, 384) - """ - from lerobot.policies.rewind.modeling_rewind import mean_pooling - - encoded_input = self.minilm_tokenizer( - text, padding=True, truncation=True, return_tensors="pt" - ).to(self.device) - - model_output = self.minilm_model(**encoded_input) - text_embedding = mean_pooling(model_output, encoded_input["attention_mask"]) - text_embedding = text_embedding.squeeze().cpu() - - # Replicate for batch (B, 384) - text_embedding = text_embedding.unsqueeze(0).repeat(batch_size, 1) - - return text_embedding - - @torch.no_grad() - def _encode_text_batch_list(self, text_list: list[str]) -> torch.Tensor: - """Encode a list of text strings using MiniLM (one per sample). 
- - Args: - text_list: List of text strings to encode - - Returns: - Encoded feature vectors with shape (B, 384) - """ - from lerobot.policies.rewind.modeling_rewind import mean_pooling - - # Encode all texts in the batch at once - encoded_input = self.minilm_tokenizer( - text_list, padding=True, truncation=True, return_tensors="pt" - ).to(self.device) - - model_output = self.minilm_model(**encoded_input) - text_embeddings = mean_pooling(model_output, encoded_input["attention_mask"]) - text_embeddings = text_embeddings.cpu() - - return text_embeddings - - def transform_features( - self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] - ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: - """ - Adds video_features and text_features to the observation features. - """ - # Add the encoded features - features[PipelineFeatureType.OBSERVATION]['video_features'] = PolicyFeature( - type=FeatureType.VISUAL, - shape=(768,) # DINO embedding dimension - ) - features[PipelineFeatureType.OBSERVATION]['text_features'] = PolicyFeature( - type=FeatureType.LANGUAGE, - shape=(384,) # MiniLM embedding dimension - ) - return features - - -def make_rewind_pre_post_processors( - config: ReWiNDConfig, - dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, - dataset_meta = None, -) -> tuple[ - PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], - PolicyProcessorPipeline[PolicyAction, PolicyAction], -]: - """ - Create pre-processor and post-processor pipelines for ReWiND. - - The pre-processing pipeline: - 1. Encodes images with DINO (768-dim) - 2. Encodes text with MiniLM (384-dim) - 3. Computes remaining episode length for progress normalization - 4. Adds batch dimension - 5. Moves data to device - - The post-processing pipeline moves data back to CPU. 
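# Usage sketch for make_rewind_pre_post_processors as defined here; `dataset` and
# `raw_batch` are hypothetical stand-ins for a LeRobotDataset and one of its batches.
preprocessor, postprocessor = make_rewind_pre_post_processors(
    config=ReWiNDConfig(),
    dataset_meta=dataset.meta,
)
batch = preprocessor(raw_batch)  # adds video_features (768-dim) and text_features (384-dim)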
- - Args: - config: ReWiND configuration - dataset_stats: Dataset statistics (not used for ReWiND) - dataset_meta: Dataset metadata for computing episode remaining length - - Returns: - Tuple of (preprocessor, postprocessor) pipelines - """ - input_steps = [ - AddBatchDimensionProcessorStep(), - ReWiNDEncodingProcessorStep(config=config, dataset_meta=dataset_meta), - DeviceProcessorStep(device=config.device), - ] - - output_steps = [ - DeviceProcessorStep(device="cpu"), - ] - - return ( - PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( - steps=input_steps, - name=POLICY_PREPROCESSOR_DEFAULT_NAME, - ), - PolicyProcessorPipeline[PolicyAction, PolicyAction]( - steps=output_steps, - name=POLICY_POSTPROCESSOR_DEFAULT_NAME, - to_transition=policy_action_to_transition, - to_output=transition_to_policy_action, - ), - ) - diff --git a/src/lerobot/policies/sarm/__init__.py b/src/lerobot/policies/sarm/__init__.py index 4cda62bd2..c936e1632 100644 --- a/src/lerobot/policies/sarm/__init__.py +++ b/src/lerobot/policies/sarm/__init__.py @@ -19,7 +19,6 @@ from lerobot.policies.sarm.modeling_sarm import ( SARMRewardModel, SARMTransformer, compute_stage_loss, - compute_progress_loss, ) from lerobot.policies.sarm.processor_sarm import ( SARMEncodingProcessorStep, @@ -31,7 +30,6 @@ __all__ = [ "SARMRewardModel", "SARMTransformer", "compute_stage_loss", - "compute_progress_loss", "SARMEncodingProcessorStep", "make_sarm_pre_post_processors", ] diff --git a/src/lerobot/policies/sarm/configuration_sarm.py b/src/lerobot/policies/sarm/configuration_sarm.py index a89e367e8..930cb4010 100644 --- a/src/lerobot/policies/sarm/configuration_sarm.py +++ b/src/lerobot/policies/sarm/configuration_sarm.py @@ -32,8 +32,8 @@ class SARMConfig(PreTrainedConfig): num_frames: int = 9 # 1 initial + 8 consecutive frames frame_gap: int = 30 # Frame gap between consecutive frames (at 30 fps = 1 second) - # Text encoding parameters - text_dim: int = 384 + # Text encoding parameters (CLIP text encoder output dimension) + text_dim: int = 512 # Joint state parameters state_dim: int | None = None # Auto-detected from dataset if None @@ -49,7 +49,6 @@ class SARMConfig(PreTrainedConfig): # Temporal parameters max_length: int = num_frames # Maximum video sequence length (matches num_frames) use_temporal_sampler: bool = True # Always enable temporal sequence loading - sampling_mode: str = "sarm" # Sampling mode: "sarm" or "rewind" # Training parameters batch_size: int = 64 @@ -101,11 +100,6 @@ class SARMConfig(PreTrainedConfig): if self.num_stages < 2: raise ValueError(f"num_stages must be at least 2, got {self.num_stages}") - - if self.sampling_mode not in ["sarm", "rewind", "custom"]: - raise ValueError( - f"sampling_mode must be 'sarm' or 'rewind', got {self.sampling_mode}" - ) def get_optimizer_preset(self) -> AdamWConfig: """Get default optimizer configuration for SARM training.""" diff --git a/src/lerobot/policies/sarm/modeling_sarm.py b/src/lerobot/policies/sarm/modeling_sarm.py index 296e9188e..47e63ec09 100644 --- a/src/lerobot/policies/sarm/modeling_sarm.py +++ b/src/lerobot/policies/sarm/modeling_sarm.py @@ -24,33 +24,13 @@ import torch import torch.nn as nn import torch.nn.functional as F from PIL import Image -from transformers import AutoModel, AutoTokenizer, CLIPModel, CLIPProcessor +from transformers import CLIPModel, CLIPProcessor from torch import Tensor from lerobot.policies.sarm.configuration_sarm import SARMConfig from lerobot.policies.pretrained import PreTrainedPolicy -def mean_pooling(model_output, 
attention_mask):
-    """
-    Mean pooling - take attention mask into account for correct averaging.
-
-    Args:
-        model_output: Model output containing token embeddings.
-        attention_mask: Attention mask for the tokens.
-
-    Returns:
-        Mean-pooled embeddings.
-    """
-    token_embeddings = model_output[0]  # First element contains all token embeddings
-    input_mask_expanded = (
-        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
-    )
-    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
-        input_mask_expanded.sum(1), min=1e-9
-    )
-
-
 class SARMTransformer(nn.Module):
     """
     SARM Transformer model for stage-aware reward prediction.
@@ -65,7 +45,7 @@ class SARMTransformer(nn.Module):
     def __init__(
         self,
         video_dim: int = 512,
-        text_dim: int = 384,
+        text_dim: int = 512,  # CLIP text encoder output dimension (per SARM paper A.4)
         state_dim: int = 14,
         hidden_dim: int = 768,
         num_heads: int = 12,
@@ -204,7 +184,7 @@ class SARMTransformer(nn.Module):
         stage_indices = torch.argmax(stage_probs, dim=-1)  # [batch_size, seq_len]
 
         # Get stage embeddings for conditioning
-        stage_embeds = self.stage_embedding(stage_indices)  # [batch_size, seq_len, hidden_dim//4]
+        stage_embeds = self.stage_embedding(stage_indices)
 
         # Concatenate frame features with stage embeddings
         conditioned_features = torch.cat([frame_features, stage_embeds], dim=-1)
@@ -229,9 +209,11 @@ class SARMRewardModel(PreTrainedPolicy):
     """
     SARM Reward Model for stage-aware task completion rewards.
 
+    Per SARM paper (Appendix A.4): "We employ a frozen clip-vit-base-patch32 encoder
+    to process both RGB image sequences and task descriptions."
+
     This model combines:
-    - CLIP for encoding video frames
-    - MiniLM for encoding text descriptions
+    - CLIP for encoding video frames AND text descriptions
     - SARMTransformer for predicting task stage and progress
     - Optional RA-BC (Reward-Aligned Behavior Cloning) for weighted training
     """
@@ -249,24 +231,13 @@ class SARMRewardModel(PreTrainedPolicy):
         if dataset_meta is not None:
             self._update_num_stages_from_dataset(dataset_meta)
 
-        # Initialize CLIP encoder for images
-        logging.info("Loading CLIP encoder...")
+        # Initialize CLIP encoder for images AND text (per SARM paper A.4)
+        logging.info("Loading CLIP encoder for images and text...")
         self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
         self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
         self.clip_model.to(self.device)
         self.clip_model.eval()
 
-        # Initialize MiniLM encoder for text
-        logging.info("Loading MiniLM encoder...")
-        self.minilm_tokenizer = AutoTokenizer.from_pretrained(
-            "sentence-transformers/all-MiniLM-L12-v2"
-        )
-        self.minilm_model = AutoModel.from_pretrained(
-            "sentence-transformers/all-MiniLM-L12-v2"
-        )
-        self.minilm_model.to(self.device)
-        self.minilm_model.eval()
-
         # Auto-detect state_dim from dataset_stats
         if config.state_dim is None:
             logging.info(f"Attempting to auto-detect state_dim. dataset_stats is None: {dataset_stats is None}")
@@ -379,7 +350,6 @@
         super().to(device)
         self.device = device if isinstance(device, torch.device) else torch.device(device)
         self.clip_model.to(device)
-        self.minilm_model.to(device)
         self.sarm_transformer.to(device)
         return self
 
@@ -445,13 +415,13 @@
     @torch.no_grad()
     def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
         """
-        Encode text using MiniLM.
+        Encode text using CLIP text encoder (per SARM paper A.4).
 
        Args:
             text: Text string or list of text strings.
 
         Returns:
-            Encoded text features (batch_size, 384) or (384,) for single text.
+            Encoded text features (batch_size, 512) or (512,) for single text.
         """
         if isinstance(text, str):
             text = [text]
@@ -459,18 +429,18 @@
         else:
             single_text = False
 
+        # Use CLIP's tokenizer directly (avoids image processor validation issues)
+        tokenizer = self.clip_processor.tokenizer
+
         # Process in batches
         all_embeddings = []
         for i in range(0, len(text), self.config.batch_size):
             batch_text = text[i:i + self.config.batch_size]
 
-            encoded_input = self.minilm_tokenizer(
-                batch_text, padding=True, truncation=True, return_tensors="pt"
-            ).to(self.device)
-
-            model_output = self.minilm_model(**encoded_input)
-            text_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
+            inputs = tokenizer(batch_text, return_tensors="pt", padding=True, truncation=True)
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+            text_embeddings = self.clip_model.get_text_features(**inputs)
 
             all_embeddings.append(text_embeddings.cpu())
 
         result = torch.cat(all_embeddings).numpy()
@@ -493,7 +463,7 @@
         Calculate rewards for given text, video, and state representations.
 
         Args:
-            text_embeddings: Encoded text representations (batch_size, 384)
+            text_embeddings: Encoded text representations (batch_size, 512)
             video_embeddings: Encoded video representations (batch_size, num_frames, 512)
             state_features: Joint state features (batch_size, num_frames, state_dim)
             return_all_frames: If True, return rewards for all frames
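# A standalone check of the CLIP text path introduced here: clip-vit-base-patch32
# projects text to 512 dimensions, which is why text_dim moves from 384 to 512.
import torch
from transformers import CLIPModel, CLIPProcessor

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval()
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)

inputs = clip_processor.tokenizer(["fold the towel"], return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
    text_features = clip_model.get_text_features(**inputs)
print(text_features.shape)  # torch.Size([1, 512])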
@@ -585,11 +555,10 @@
         logging.info("Checkpoint loaded successfully")
 
     def train(self, mode: bool = True):
-        """Set training mode. Note: CLIP and MiniLM encoders always stay in eval mode."""
+        """Set training mode. Note: CLIP encoder always stays in eval mode (frozen)."""
         super().train(mode)
-        # Keep encoders in eval mode
+        # Keep CLIP encoder in eval mode (frozen per SARM paper)
         self.clip_model.eval()
-        self.minilm_model.eval()
         # Only transformer can be trained
         self.sarm_transformer.train(mode)
         return self
@@ -618,30 +587,18 @@
         """Required by PreTrainedPolicy but not used for SARM."""
         raise NotImplementedError("SARM model does not select actions")
 
-    def _get_remaining_length(self, observation: dict, idx: int) -> float | None:
-        """Extract remaining length for a sample from observation metadata."""
-        remaining_lengths = observation.get('remaining_length')
-        if remaining_lengths is None:
-            return None
-        if isinstance(remaining_lengths, torch.Tensor):
-            return remaining_lengths[idx].item() if remaining_lengths.dim() > 0 else remaining_lengths.item()
-        return remaining_lengths
-
-    def _compute_progress_targets(self, remaining_length: float | None, seq_len: int) -> torch.Tensor:
-        """Compute progress targets based on remaining trajectory length."""
-        if remaining_length is not None and remaining_length > 0:
-            return torch.arange(1, seq_len + 1, dtype=torch.float32, device=self.device) / remaining_length
-        else:
-            raise ValueError("Remaining length is None, but is required for progress targets")
-
-    def _apply_rewind_augmentation(
+    def _apply_temporal_augmentation(
         self,
         video: torch.Tensor,
         progress: torch.Tensor,
         state: torch.Tensor | None,
         max_length: int,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
-        """Apply rewind augmentation: append up to 4 reversed frames (SARM paper A.4)."""
+        """Apply temporal augmentation by appending reversed frames (SARM paper A.4).
+
+        This helps the model learn to handle non-monotonic progress (failures, recoveries).
+        Appends 1-4 reversed frames to simulate going backwards in task progress.
+        """
         num_reverse = random.randint(1, min(4, max_length - 1))
 
         # Reverse and take frames (skip first which is last of original)
@@ -672,14 +629,20 @@
         """
         Forward pass for SARM reward model training.
 
+        Uses annotation-based progress targets following SARM paper Eq. 2:
+            y_t = P_{k-1} + ᾱ_k · τ_t
+        where:
+        - τ_t = (t - s_k) / (e_k - s_k) is within-subtask normalized time
+        - P_{k-1} is the cumulative prior (sum of previous subtask proportions)
+        - ᾱ_k is the temporal proportion for subtask k
+
         Args:
             batch: Dictionary with 'observation' containing:
                 - 'video_features': (B, T, 512) pre-encoded video features
-                - 'text_features': (B, 384) pre-encoded text features
+                - 'text_features': (B, 512) pre-encoded text features (CLIP)
                 - 'state_features': (B, T, state_dim) joint state features
-                - 'remaining_length': (B,) remaining trajectory lengths (optional)
-                - 'stage_labels': (B, T) stage labels (optional, from annotations)
-                - 'progress_targets': (B, T, 1) progress targets (optional, from annotations)
+                - 'stage_labels': (B, T) stage labels from annotations
+                - 'progress_targets': (B, T, 1) progress targets from annotations
 
         Returns:
             Tuple of (total_loss, output_dict with loss components)
@@ -702,21 +665,31 @@
         if state_features is not None and state_features.dim() == 2:
             state_features = state_features.unsqueeze(1).expand(-1, max_length, -1)
 
-        # Process each sample: compute progress targets and apply rewind augmentation
+        # Get annotation-based progress targets (required for SARM paper formula)
+        progress_from_annotations = observation.get('progress_targets')
+        if progress_from_annotations is None:
+            raise ValueError("progress_targets from annotations is required for SARM training")
+
+        progress_from_annotations = progress_from_annotations.to(self.device)
+        if progress_from_annotations.dim() == 2:
+            progress_from_annotations = progress_from_annotations.unsqueeze(-1)
+        if progress_from_annotations.dim() == 3 and progress_from_annotations.shape[0] == 1:
+            progress_from_annotations = progress_from_annotations.expand(batch_size, -1, -1)
+
+        # Process each sample: apply temporal augmentation (SARM paper A.4)
         processed_videos = []
         processed_states = []
         progress_targets = []
 
         for i in range(batch_size):
-            remaining_length = self._get_remaining_length(observation, i)
-            progress = self._compute_progress_targets(remaining_length, max_length)
-
             video = video_features[i]
             state = state_features[i] if state_features is not None else None
+            progress = progress_from_annotations[i].squeeze(-1)  # (T,)
 
-            # Apply rewind augmentation with 50% probability (SARM paper)
+            # Apply temporal augmentation with 50% probability (SARM paper A.4)
+            # Appends up to 4 reversed frames to simulate failures/recoveries
            if random.random() < 0.5:
-                video, progress, state = self._apply_rewind_augmentation(video, progress, state, max_length)
+                video, progress, state = self._apply_temporal_augmentation(video, progress, state, max_length)
 
             # Ensure correct sequence length
             video = self._ensure_sequence_length(video, max_length)
@@ -739,32 +712,22 @@
             processed_videos, text_features, processed_states
         )
 
-        # Use annotation-based progress targets
-        progress_from_annotations = observation.get('progress_targets')
-        if progress_from_annotations is not None:
-            progress_from_annotations = progress_from_annotations.to(self.device)
-            if progress_from_annotations.dim() == 2:
-                progress_from_annotations = progress_from_annotations.unsqueeze(-1)
-            if progress_from_annotations.dim() == 3 and progress_from_annotations.shape[0] == 1:
-                progress_from_annotations = progress_from_annotations.expand(batch_size, -1, -1)
-            progress_targets = progress_from_annotations
-
-        # Compute progress loss
+        # Compute progress loss (MSE)
         progress_loss = F.mse_loss(progress_preds, progress_targets)
         output_dict = {'progress_loss': progress_loss.item()}
         total_loss = progress_loss
 
-        # Compute stage loss if labels available
+        # Compute stage loss (cross-entropy)
         stage_labels = observation.get('stage_labels')
-        if stage_labels is not None:
-            stage_labels = stage_labels.to(self.device)
-            if stage_labels.dim() == 1:
-                stage_labels = stage_labels.unsqueeze(0).expand(batch_size, -1)
-            stage_loss = compute_stage_loss(stage_logits, stage_labels)
-            total_loss = total_loss + self.config.stage_loss_weight * stage_loss
-            output_dict['stage_loss'] = stage_loss.item()
-        else:
-            raise ValueError("Stage labels are None, but are required for stage loss")
+        if stage_labels is None:
+            raise ValueError("stage_labels from annotations is required for SARM training")
+
+        stage_labels = stage_labels.to(self.device)
+        if stage_labels.dim() == 1:
+            stage_labels = stage_labels.unsqueeze(0).expand(batch_size, -1)
+        stage_loss = compute_stage_loss(stage_logits, stage_labels)
+        total_loss = total_loss + self.config.stage_loss_weight * stage_loss
+        output_dict['stage_loss'] = stage_loss.item()
 
         # Misaligned loss: 20% probability (SARM paper - improve video-language alignment)
         if random.random() < 0.2:
@@ -786,9 +749,3 @@ def compute_stage_loss(stage_logits: torch.Tensor, target_stages: torch.Tensor)
     loss = F.cross_entropy(stage_logits_flat, target_stages_flat)
 
     return loss
-
-
-def compute_progress_loss(progress_preds: torch.Tensor, target_progress: torch.Tensor) -> torch.Tensor:
-    loss = F.mse_loss(progress_preds, target_progress)
-    return loss
-
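# Worked example of the SARM progress target y_t = P_{k-1} + ᾱ_k · τ_t used in
# forward() above, with hypothetical subtask annotations (the same proportions the
# processor computes from subtask durations below):
durations = {"reach": 40, "grasp": 100, "place": 60}   # frames per subtask
total = sum(durations.values())                         # 200
alphas = [d / total for d in durations.values()]        # ᾱ = [0.2, 0.5, 0.3]
priors = [0.0, 0.2, 0.7]                                # P_{k-1}: cumulative sums of ᾱ

# Frame t = 110 falls in subtask k = 1 ("grasp"), which spans frames [40, 140):
s_k, e_k, t, k = 40, 140, 110, 1
tau = (t - s_k) / (e_k - s_k)        # 0.7, within-subtask normalized time
y_t = priors[k] + alphas[k] * tau    # 0.2 + 0.5 * 0.7 = 0.55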
+ This step handles: - - CLIP (image) encoding - - MiniLM (text) encoding + - CLIP image encoding (512-dim) + - CLIP text encoding (512-dim) - Joint state normalization Supports temporal sequences: (B, T, C, H, W) → (B, T, 512) video features @@ -76,28 +79,18 @@ class SARMEncodingProcessorStep(ProcessorStep): self._init_encoders() def _init_encoders(self): - """Initialize CLIP and MiniLM encoders.""" + """Initialize CLIP encoder for both images and text (per SARM paper A.4).""" device = torch.device( self.config.device if self.config.device else "cuda" if torch.cuda.is_available() else "cpu" ) - logging.info("Initializing CLIP encoder for SARM...") + logging.info("Initializing CLIP encoder for SARM (images + text)...") self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True) self.clip_model.to(device) self.clip_model.eval() - logging.info("Initializing MiniLM encoder for SARM...") - self.minilm_tokenizer = AutoTokenizer.from_pretrained( - "sentence-transformers/all-MiniLM-L12-v2" - ) - self.minilm_model = AutoModel.from_pretrained( - "sentence-transformers/all-MiniLM-L12-v2" - ) - self.minilm_model.to(device) - self.minilm_model.eval() - self.device = device def _compute_temporal_proportions(self): @@ -167,11 +160,13 @@ class SARMEncodingProcessorStep(ProcessorStep): for name in self.subtask_names } else: - # Equal proportions if no duration info - self.temporal_proportions = { - name: 1.0 / len(self.subtask_names) - for name in self.subtask_names - } + raise ValueError( + "Cannot compute temporal proportions: all subtask durations are zero. " + "Check that your dataset has valid subtask annotations with start/end times." + ) + + # Store in config for the model to use in progress output conversion (SARM paper Eq. 4) + self.config.temporal_proportions = [self.temporal_proportions[name] for name in self.subtask_names] logging.info(f"Computed temporal proportions for {len(self.subtask_names)} subtasks: {self.temporal_proportions}") @@ -481,15 +476,9 @@ class SARMEncodingProcessorStep(ProcessorStep): observation['state_features'] = torch.tensor(state_data, dtype=torch.float32) - # 3. Encode text with MiniLM + # 3. Encode text with CLIP (per SARM paper A.4) batch_size = video_features.shape[0] - task_descriptions = new_transition.get('task') - if task_descriptions is not None: - if isinstance(task_descriptions, str): - task_descriptions = [task_descriptions] * batch_size - observation['text_features'] = self._encode_text_batch_list(task_descriptions) - else: - observation['text_features'] = self._encode_text_batch(self.task_description, batch_size) + observation['text_features'] = self._encode_text_clip(self.task_description, batch_size) # 4. Extract frame/episode indices from complementary data comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) @@ -609,54 +598,33 @@ class SARMEncodingProcessorStep(ProcessorStep): return all_embeddings @torch.no_grad() - def _encode_text_batch(self, text: str, batch_size: int) -> torch.Tensor: - """Encode a text string using MiniLM and replicate for batch. + def _encode_text_clip(self, text: str, batch_size: int) -> torch.Tensor: + """Encode text using CLIP text encoder (per SARM paper A.4). 
@@ -481,15 +476,9 @@ class SARMEncodingProcessorStep(ProcessorStep):

         observation['state_features'] = torch.tensor(state_data, dtype=torch.float32)

-        # 3. Encode text with MiniLM
+        # 3. Encode text with CLIP (per SARM paper A.4)
         batch_size = video_features.shape[0]
-        task_descriptions = new_transition.get('task')
-        if task_descriptions is not None:
-            if isinstance(task_descriptions, str):
-                task_descriptions = [task_descriptions] * batch_size
-            observation['text_features'] = self._encode_text_batch_list(task_descriptions)
-        else:
-            observation['text_features'] = self._encode_text_batch(self.task_description, batch_size)
+        observation['text_features'] = self._encode_text_clip(self.task_description, batch_size)

         # 4. Extract frame/episode indices from complementary data
         comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
@@ -609,54 +598,33 @@ class SARMEncodingProcessorStep(ProcessorStep):
         return all_embeddings

     @torch.no_grad()
-    def _encode_text_batch(self, text: str, batch_size: int) -> torch.Tensor:
-        """Encode a text string using MiniLM and replicate for batch.
+    def _encode_text_clip(self, text: str, batch_size: int) -> torch.Tensor:
+        """Encode text using CLIP text encoder (per SARM paper A.4).

         Args:
-            text: Text string to encode
+            text: Task description text to encode
             batch_size: Batch size to replicate for

         Returns:
-            Encoded feature vectors with shape (B, 384)
+            Encoded text features with shape (B, 512)
         """
-        from lerobot.policies.rewind.modeling_rewind import mean_pooling
+        # Use CLIP's tokenizer directly for text (avoids image processor validation issues)
+        tokenizer = self.clip_processor.tokenizer
+        inputs = tokenizer([text], return_tensors="pt", padding=True, truncation=True)
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}

-        encoded_input = self.minilm_tokenizer(
-            text, padding=True, truncation=True, return_tensors="pt"
-        ).to(self.device)
+        # Get text features from CLIP
+        text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu()

-        model_output = self.minilm_model(**encoded_input)
-        text_embedding = mean_pooling(model_output, encoded_input["attention_mask"])
-        text_embedding = text_embedding.squeeze().cpu()
+        # Handle single text case
+        if text_embedding.dim() == 1:
+            text_embedding = text_embedding.unsqueeze(0)

-        # Replicate for batch (B, 384)
-        text_embedding = text_embedding.unsqueeze(0).repeat(batch_size, 1)
+        # Replicate for batch (B, 512)
+        text_embedding = text_embedding.expand(batch_size, -1)

         return text_embedding

-    @torch.no_grad()
-    def _encode_text_batch_list(self, text_list: list[str]) -> torch.Tensor:
-        """Encode a list of text strings using MiniLM.
-
-        Args:
-            text_list: List of text strings to encode
-
-        Returns:
-            Encoded feature vectors with shape (B, 384)
-        """
-        # Encode all texts in the batch at once
-        encoded_input = self.minilm_tokenizer(
-            text_list, padding=True, truncation=True, return_tensors="pt"
-        ).to(self.device)
-
-        model_output = self.minilm_model(**encoded_input)
-        text_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
-        text_embeddings = text_embeddings.cpu()
-
-        return text_embeddings
-
     def transform_features(
         self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
     ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
@@ -688,9 +656,12 @@ def make_sarm_pre_post_processors(
     """
     Create pre-processor and post-processor pipelines for SARM.

+    Per SARM paper (Appendix A.4): "We employ a frozen clip-vit-base-patch32 encoder
+    to process both RGB image sequences and task descriptions."
+
     The pre-processing pipeline:
     1. Encodes images with CLIP (512-dim)
-    2. Encodes text with MiniLM (384-dim)
+    2. Encodes text with CLIP (512-dim)
     3. Normalizes joint states
     4. Adds batch dimension
     5. Moves data to device
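Because _encode_text_clip above is buried inside the processor step, a standalone sketch of the same CLIP text path can be handy for sanity checks; the model id matches the patch, while the function name and sample text are illustrative:

import torch
from transformers import CLIPModel, CLIPProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)

@torch.no_grad()
def encode_text_clip(text: str, batch_size: int) -> torch.Tensor:
    # Tokenize with CLIP's own tokenizer; the full processor would also expect images.
    tokens = clip_processor.tokenizer([text], return_tensors="pt", padding=True, truncation=True)
    tokens = {k: v.to(device) for k, v in tokens.items()}
    features = clip_model.get_text_features(**tokens).cpu()  # shape (1, 512)
    # expand() replicates the row as a view rather than copying (B, 512) floats.
    return features.expand(batch_size, -1)

print(encode_text_clip("pick up the red block", batch_size=4).shape)  # torch.Size([4, 512])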
diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py
index a4fede3d3..7abe8add9 100644
--- a/src/lerobot/scripts/lerobot_train.py
+++ b/src/lerobot/scripts/lerobot_train.py
@@ -229,8 +229,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
             # Only provide dataset_stats when not resuming from saved processor state
             processor_kwargs["dataset_stats"] = dataset.meta.stats

-        # For ReWiND and SARM, always provide dataset_meta for progress normalization
-        if cfg.policy.type in ["rewind", "sarm"]:
+        # For SARM, always provide dataset_meta for progress normalization
+        if cfg.policy.type == "sarm":
             processor_kwargs["dataset_meta"] = dataset.meta

     if cfg.policy.pretrained_path is not None:
@@ -319,20 +319,17 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
             drop_n_last_frames=cfg.policy.drop_n_last_frames,
             shuffle=True,
         )
-    elif cfg.policy.type in ["rewind", "sarm"] and getattr(cfg.policy, "use_temporal_sampler", False):
-        # Use temporal sequence sampler for loading sequences
-        from lerobot.datasets.temporal_sampler import TemporalSequenceSampler
+    elif cfg.policy.type == "sarm" and getattr(cfg.policy, "use_temporal_sampler", False):
+        # Use SARM temporal sampler for reward model training
+        from lerobot.datasets.temporal_sampler import SARMTemporalSampler

         shuffle = False
-        sampling_mode = getattr(cfg.policy, "sampling_mode", cfg.policy.type)
-        sampler = TemporalSequenceSampler(
+        sampler = SARMTemporalSampler(
             dataset_from_index=dataset.meta.episodes["dataset_from_index"],
             dataset_to_index=dataset.meta.episodes["dataset_to_index"],
-            sequence_length=cfg.policy.max_length,
-            stride=getattr(cfg.policy, "sequence_stride", 1) if cfg.policy.type == "rewind" else getattr(cfg.policy, "frame_gap", 30),
+            frame_gap=getattr(cfg.policy, "frame_gap", 30),
             shuffle=True,
             seed=cfg.seed,
-            sampling_mode=sampling_mode,
         )
     else:
         shuffle = True
diff --git a/src/lerobot/utils/rabc.py b/src/lerobot/utils/rabc.py
index e9648de78..e24c18df2 100644
--- a/src/lerobot/utils/rabc.py
+++ b/src/lerobot/utils/rabc.py
@@ -35,7 +35,7 @@ class RABCWeightComputer:
     and applies soft weighting based on progress deltas.

     Args:
-        reward_model: Pre-trained reward model (e.g., SARM, ReWiND)
+        reward_model: Pre-trained reward model (e.g., SARM)
         kappa: Hard threshold for high-quality samples (default: 0.01)
         epsilon: Small constant for numerical stability (default: 1e-6)
         device: Device to run reward model on
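For context on how RABCWeightComputer's kappa and epsilon interact, here is a hedged sketch of one plausible weighting rule consistent with the docstring above: a hard pass for progress deltas at or above kappa and a soft, normalized weight below it. The actual rule in rabc.py may differ.

import torch

def rabc_weights(progress_deltas: torch.Tensor, kappa: float = 0.01, epsilon: float = 1e-6) -> torch.Tensor:
    # Soft weight: non-negative delta normalized by the batch maximum;
    # epsilon keeps the division stable when all deltas are ~0.
    soft = progress_deltas.clamp(min=0) / (progress_deltas.max().clamp(min=0) + epsilon)
    # Hard threshold: deltas at or above kappa count as high-quality samples.
    return torch.where(progress_deltas >= kappa, torch.ones_like(soft), soft)

deltas = torch.tensor([0.02, 0.005, -0.01])
print(rabc_weights(deltas))  # ~tensor([1.0000, 0.2500, 0.0000])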