From 69868360c7eb7f283c923028b91fd0074a05b419 Mon Sep 17 00:00:00 2001 From: Pepijn Kooijmans Date: Tue, 18 Nov 2025 15:00:05 +0100 Subject: [PATCH] add sarm --- scripts/visualize_rewind_predictions.py | 528 ++++++++++++ scripts/visualize_sarm_predictions.py | 537 ++++++++++++ src/lerobot/configs/train.py | 14 + src/lerobot/datasets/rewind_sampler.py | 128 +++ src/lerobot/datasets/temporal_sampler.py | 181 ++++ src/lerobot/datasets/video_sampler.py | 323 +++++-- src/lerobot/policies/factory.py | 20 + .../policies/rewind/configuration_rewind.py | 56 +- .../policies/rewind/modeling_rewind.py | 108 ++- .../policies/rewind/processor_rewind.py | 325 +++++-- src/lerobot/policies/sarm/__init__.py | 38 + .../policies/sarm/configuration_sarm.py | 165 ++++ src/lerobot/policies/sarm/modeling_sarm.py | 808 ++++++++++++++++++ src/lerobot/policies/sarm/processor_sarm.py | 552 ++++++++++++ src/lerobot/scripts/lerobot_train.py | 64 ++ src/lerobot/utils/rabc.py | 183 ++++ 16 files changed, 3872 insertions(+), 158 deletions(-) create mode 100644 scripts/visualize_rewind_predictions.py create mode 100644 scripts/visualize_sarm_predictions.py create mode 100644 src/lerobot/datasets/rewind_sampler.py create mode 100644 src/lerobot/datasets/temporal_sampler.py create mode 100644 src/lerobot/policies/sarm/__init__.py create mode 100644 src/lerobot/policies/sarm/configuration_sarm.py create mode 100644 src/lerobot/policies/sarm/modeling_sarm.py create mode 100644 src/lerobot/policies/sarm/processor_sarm.py create mode 100644 src/lerobot/utils/rabc.py diff --git a/scripts/visualize_rewind_predictions.py b/scripts/visualize_rewind_predictions.py new file mode 100644 index 000000000..b3df946b3 --- /dev/null +++ b/scripts/visualize_rewind_predictions.py @@ -0,0 +1,528 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Inference script for ReWiND Reward Model. + +This script loads a trained ReWiND model and runs inference on a dataset episode, +generating visualizations of the predicted task progression over time. 
+ +Example usage: + python scripts/visualize_rewind_predictions.py \ + --model-id username/rewind-model \ + --dataset-repo lerobot/aloha_sim_insertion_human \ + --episode-index 0 \ + --output-dir outputs/rewind_viz \ + --task-description "insert the peg into the socket" +""" + +import argparse +import logging +from pathlib import Path +from typing import Optional + +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec +import numpy as np +import torch +from tqdm import tqdm + +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.policies.rewind.modeling_rewind import ReWiNDRewardModel + + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run ReWiND inference and visualize predictions") + + # Model arguments + parser.add_argument( + "--model-id", + type=str, + required=True, + help="HuggingFace model ID or local path to trained ReWiND model" + ) + + # Dataset arguments + parser.add_argument( + "--dataset-repo", + type=str, + required=True, + help="HuggingFace dataset repository ID (e.g., lerobot/aloha_sim_insertion_human)" + ) + parser.add_argument( + "--episode-index", + type=int, + default=0, + help="Index of the episode to visualize (default: 0)" + ) + parser.add_argument( + "--task-description", + type=str, + default="perform the task", + help="Task description for the reward model (default: 'perform the task')" + ) + + # Output arguments + parser.add_argument( + "--output-dir", + type=str, + default="outputs/rewind_inference", + help="Directory to save visualization outputs (default: outputs/rewind_inference)" + ) + parser.add_argument( + "--image-key", + type=str, + default=None, + help="Key for images in dataset (e.g., observation.images.image for jaco_play). If not specified, uses model config's image_key" + ) + + # Visualization options + parser.add_argument( + "--show-frames", + action="store_true", + help="Include sample frames in the visualization" + ) + parser.add_argument( + "--num-sample-frames", + type=int, + default=8, + help="Number of sample frames to show (default: 8)" + ) + parser.add_argument( + "--figsize", + type=int, + nargs=2, + default=[12, 6], + help="Figure size as width height (default: 12 6)" + ) + + # Device + parser.add_argument( + "--device", + type=str, + default=None, + help="Device to run inference on (cuda/cpu, default: auto-detect)" + ) + + return parser.parse_args() + + +def load_episode_data( + dataset: LeRobotDataset, + episode_index: int, + image_key: str +) -> tuple[np.ndarray, int, int, str]: + """ + Load all frames from a specific episode. 
+ + Args: + dataset: LeRobotDataset instance + episode_index: Index of the episode to load + image_key: Key for accessing images in the dataset + + Returns: + Tuple of (frames, start_index, end_index, task_description) + """ + # Get episode boundaries + episode_data = dataset.meta.episodes + start_idx = episode_data["dataset_from_index"][episode_index] + end_idx = episode_data["dataset_to_index"][episode_index] + + logger.info(f"Loading episode {episode_index}: frames {start_idx} to {end_idx} ({end_idx - start_idx} frames)") + + # Get task description from the dataset if available + task_description = None + first_item = dataset[start_idx] + if "task" in first_item: + task_description = first_item["task"] + print(f"✓ Extracted task from episode {episode_index}: '{task_description}'") + + # Load all frames from the episode + frames = [] + for idx in tqdm(range(start_idx, end_idx), desc="Loading frames"): + item = dataset[idx] + # Get image from the item + img = item[image_key] + + # Convert to numpy if needed + if isinstance(img, torch.Tensor): + img = img.cpu().numpy() + + # Handle different image formats (C, H, W) or (H, W, C) + if img.shape[0] in [1, 3]: # Channel first + img = np.transpose(img, (1, 2, 0)) + + # Convert to uint8 if needed + if img.dtype != np.uint8: + if img.max() <= 1.0: + img = (img * 255).astype(np.uint8) + else: + img = img.astype(np.uint8) + + frames.append(img) + + frames = np.array(frames) + logger.info(f"Loaded {len(frames)} frames with shape {frames[0].shape}") + + return frames, start_idx, end_idx, task_description + + +@torch.no_grad() +def run_inference( + model: ReWiNDRewardModel, + frames: np.ndarray, + task_description: str, + batch_size: int = 32 +) -> tuple[np.ndarray, np.ndarray]: + """ + Run ReWiND inference on video frames using the original ReWiND approach. + + This function creates video slices for all frames at once (similar to the original + metaworld_label_reward.py), where each slice contains frames from start up to that point. + + Progress Normalization (from original ReWiND dataset.py): + - Training: progress = [1, 2, ..., N] / remaining_length + where remaining_length = episode_end - sequence_start + - Inference: Starting from frame 0, remaining_length = total_episode_length + So expected progress for frame i = (i + 1) / total_episode_length + + This function computes both: + 1. Model predictions (what the model actually predicts) + 2. 
Expected progress (ground truth based on frame position) + + Args: + model: ReWiND model + frames: Video frames (num_frames, H, W, C) + task_description: Task description text + batch_size: Batch size for processing slices + + Returns: + Tuple of: + - Model predictions for each frame (num_frames,) + - Expected progress for each frame (num_frames,) + """ + total_frames = len(frames) + + logger.info("Encoding video frames with DINO...") + video_embeddings = model.encode_images(frames) + + logger.info("Encoding task description with MiniLM...") + text_embedding = model.encode_text(task_description) + + logger.info("Creating video slices (original ReWiND approach)...") + # Convert to tensors + video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32) + text_embedding = torch.tensor(text_embedding, dtype=torch.float32) + + # Create video slices: for each frame i, create a sequence of frames [0:i+1] + # This matches the original ReWiND inference approach + video_slices = [] + for i in tqdm(range(len(video_embeddings)), desc="Creating slices"): + # Slice from start to current frame (inclusive) + video_slice = video_embeddings[:i + 1] + + # Pad or subsample to max_length + if model.config.subsample_video: + video_slice = model.padding_video(video_slice, model.config.max_length) + + video_slices.append(video_slice) + + video_slices = torch.stack(video_slices) # (num_frames, max_length, 768) + + # Create last_index_mask to extract the relevant prediction for each slice + # For slice i, the last valid frame is at position min(i, max_length-1) + max_length = model.config.max_length + last_index_mask = torch.zeros((len(video_slices), max_length), dtype=torch.bool) + + for i in range(len(video_slices)): + last_frame_idx = min(i, max_length - 1) + last_index_mask[i, last_frame_idx] = 1 + + logger.info("Running ReWiND inference on all slices...") + # Process in batches + all_progress = [] + for i in tqdm(range(0, len(video_slices), batch_size), desc="Inference"): + batch_video = video_slices[i:i + batch_size].to(model.device) + batch_mask = last_index_mask[i:i + batch_size].to(model.device) + batch_size_actual = batch_video.shape[0] + + # Replicate text embedding for batch + batch_text = text_embedding.unsqueeze(0).repeat(batch_size_actual, 1).to(model.device) + + # Get predictions for all frames in batch + progress_preds = model.rewind_transformer(batch_video, batch_text) # (batch, max_length, 1) + progress_preds = progress_preds.squeeze(-1) # (batch, max_length) + + # Extract predictions using the last_index_mask + # This gets the prediction for the last valid frame in each slice + batch_progress = progress_preds[batch_mask].cpu().numpy() + all_progress.extend(batch_progress) + + predictions = np.array(all_progress) + + # Compute expected progress based on original ReWiND normalization + # When starting from frame 0, remaining_length = total_episode_length + # Expected progress for frame i = (i + 1) / total_frames + expected_progress = np.arange(1, total_frames + 1, dtype=np.float32) / total_frames + + logger.info(f"Inference complete. 
Predicted progress range: [{predictions.min():.3f}, {predictions.max():.3f}]") + logger.info(f"Expected progress range: [{expected_progress.min():.3f}, {expected_progress.max():.3f}]") + + return predictions, expected_progress + + +def visualize_predictions( + frames: np.ndarray, + predictions: np.ndarray, + expected_progress: np.ndarray, + task_description: str, + output_path: Path, + show_frames: bool = False, + num_sample_frames: int = 8, + figsize: tuple = (12, 6) +): + """ + Create visualization of ReWiND predictions with expected progress comparison. + + Args: + frames: Video frames (num_frames, H, W, C) + predictions: Model progress predictions (num_frames,) + expected_progress: Expected progress based on frame position (num_frames,) + task_description: Task description + output_path: Path to save the figure + show_frames: Whether to include sample frames + num_sample_frames: Number of frames to show + figsize: Figure size (width, height) + """ + if show_frames: + # Create figure with progress plot and sample frames + fig = plt.figure(figsize=(figsize[0], figsize[1] + 4)) + gs = gridspec.GridSpec(2, 1, height_ratios=[2, 1], hspace=0.3) + + # Progress plot + ax_progress = fig.add_subplot(gs[0]) + else: + # Just progress plot + fig, ax_progress = plt.subplots(1, 1, figsize=figsize) + + # Plot progress over time + frame_indices = np.arange(len(predictions)) + + # Plot expected progress (ground truth) + ax_progress.plot(frame_indices, expected_progress, linewidth=2, color='#A8DADC', + linestyle='--', label='Expected Progress (Linear)', alpha=0.7) + + # Plot model predictions + ax_progress.plot(frame_indices, predictions, linewidth=2.5, color='#2E86AB', + label='Model Predictions') + ax_progress.fill_between(frame_indices, 0, predictions, alpha=0.2, color='#2E86AB') + + # Add reference line at 1.0 + ax_progress.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, linewidth=1) + + # Styling + ax_progress.set_xlabel('Frame Index', fontsize=12) + ax_progress.set_ylabel('Task Progress', fontsize=12) + ax_progress.set_title(f'ReWiND Task Progress Prediction\nTask: "{task_description}"', + fontsize=14, fontweight='bold') + ax_progress.grid(True, alpha=0.3) + ax_progress.set_ylim(-0.05, 1.1) + ax_progress.legend(loc='upper left') + + # Compute alignment metrics + mae = np.mean(np.abs(predictions - expected_progress)) + rmse = np.sqrt(np.mean((predictions - expected_progress) ** 2)) + + # Add statistics box + stats_text = ( + f'Frames: {len(predictions)}\n' + f'Model Final: {predictions[-1]:.3f}\n' + f'Model Max: {predictions.max():.3f}\n' + f'Model Mean: {predictions.mean():.3f}\n' + f'MAE: {mae:.3f}\n' + f'RMSE: {rmse:.3f}' + ) + ax_progress.text(0.98, 0.02, stats_text, transform=ax_progress.transAxes, + fontsize=10, verticalalignment='bottom', horizontalalignment='right', + bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5)) + + # Show sample frames if requested + if show_frames: + # Select evenly spaced frames + frame_indices_to_show = np.linspace(0, len(frames) - 1, num_sample_frames, dtype=int) + + # Create subplot for frames + ax_frames = fig.add_subplot(gs[1]) + ax_frames.axis('off') + + # Create grid for frames + frame_height = frames[0].shape[0] + frame_width = frames[0].shape[1] + + combined_width = frame_width * num_sample_frames + combined_image = np.zeros((frame_height, combined_width, 3), dtype=np.uint8) + + for i, frame_idx in enumerate(frame_indices_to_show): + frame = frames[frame_idx] + if frame.shape[-1] == 1: + frame = np.repeat(frame, 3, axis=-1) + + # Add frame 
to combined image + x_start = i * frame_width + x_end = (i + 1) * frame_width + combined_image[:, x_start:x_end] = frame + + # Add frame number and progress value + progress_val = predictions[frame_idx] + label = f'Frame {frame_idx}\nProgress: {progress_val:.3f}' + + # Draw label on image + ax_frames.text(x_start + frame_width / 2, -10, label, + ha='center', va='top', fontsize=8, + bbox=dict(boxstyle='round', facecolor='white', alpha=0.7)) + + ax_frames.imshow(combined_image) + ax_frames.set_title('Sample Frames', fontsize=12, pad=20) + + # Save figure + plt.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + plt.savefig(output_path, dpi=150, bbox_inches='tight') + logger.info(f"Saved visualization to {output_path}") + + plt.close() + + +def main(): + args = parse_args() + + # Setup device + if args.device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + else: + device = args.device + logger.info(f"Using device: {device}") + + # Load model + logger.info(f"Loading ReWiND model from {args.model_id}...") + model = ReWiNDRewardModel.from_pretrained(args.model_id) + model.to(device) + model.eval() + logger.info("Model loaded successfully") + + # Load dataset + logger.info(f"Loading dataset {args.dataset_repo}...") + dataset = LeRobotDataset(args.dataset_repo) + logger.info(f"Dataset loaded: {len(dataset.meta.episodes)} episodes, {len(dataset)} frames") + + # Validate episode index + if args.episode_index >= len(dataset.meta.episodes): + raise ValueError( + f"Episode index {args.episode_index} out of range. " + f"Dataset has {len(dataset.meta.episodes)} episodes." + ) + + # Determine which image key to use + image_key = args.image_key if args.image_key is not None else model.config.image_key + logger.info(f"Using image key: {image_key}") + + # Load episode data (this also extracts the task description from the episode) + frames, start_idx, end_idx, dataset_task = load_episode_data(dataset, args.episode_index, image_key) + + # Use task description from dataset if available, otherwise use command-line argument + task_description = dataset_task if dataset_task is not None else args.task_description + logger.info(f"Using task description: '{task_description}'") + + # Run inference + predictions, expected_progress = run_inference(model, frames, task_description) + + # Create visualization + output_dir = Path(args.output_dir) + output_path = output_dir / f"rewind_prediction_ep{args.episode_index}.png" + + visualize_predictions( + frames, + predictions, + expected_progress, + task_description, + output_path, + show_frames=args.show_frames, + num_sample_frames=args.num_sample_frames, + figsize=tuple(args.figsize) + ) + + # Save predictions and expected progress as numpy arrays + predictions_path = output_dir / f"predictions_ep{args.episode_index}.npy" + expected_path = output_dir / f"expected_progress_ep{args.episode_index}.npy" + np.save(predictions_path, predictions) + np.save(expected_path, expected_progress) + logger.info(f"Saved predictions array to {predictions_path}") + logger.info(f"Saved expected progress to {expected_path}") + + # Compute alignment metrics + mae = np.mean(np.abs(predictions - expected_progress)) + rmse = np.sqrt(np.mean((predictions - expected_progress) ** 2)) + correlation = np.corrcoef(predictions, expected_progress)[0, 1] + + # Print summary + logger.info("\n" + "="*60) + logger.info("INFERENCE SUMMARY") + logger.info("="*60) + logger.info(f"Model: {args.model_id}") + logger.info(f"Dataset: {args.dataset_repo}") + 
logger.info(f"Episode: {args.episode_index}") + logger.info(f"Task: {task_description}") + logger.info(f"Frames: {len(frames)}") + logger.info(f"\nModel Predictions:") + logger.info(f" Final: {predictions[-1]:.3f}") + logger.info(f" Max: {predictions.max():.3f}") + logger.info(f" Mean: {predictions.mean():.3f}") + logger.info(f" Std: {predictions.std():.3f}") + logger.info(f"\nExpected Progress (Linear):") + logger.info(f" Final: {expected_progress[-1]:.3f}") + logger.info(f" Mean: {expected_progress.mean():.3f}") + logger.info(f"\nAlignment Metrics:") + logger.info(f" MAE: {mae:.3f}") + logger.info(f" RMSE: {rmse:.3f}") + logger.info(f" Correlation: {correlation:.3f}") + logger.info(f"\nOutput:") + logger.info(f" Visualization: {output_path}") + logger.info("="*60) + + # Diagnostic warnings + if predictions.std() < 0.05: + logger.warning("\n⚠ WARNING: Mode collapse detected (std < 0.05)") + logger.warning(" Model predictions show very low variance.") + logger.warning(" This indicates the model was likely trained with incorrect") + logger.warning(" progress normalization (absolute indices instead of remaining length).") + elif mae > 0.3: + logger.warning("\n⚠ WARNING: High prediction error (MAE > 0.3)") + logger.warning(" Model predictions deviate significantly from expected linear progress.") + logger.warning(" Consider retraining with correct progress normalization.") + elif correlation < 0.5: + logger.warning("\n⚠ WARNING: Low correlation with expected progress (< 0.5)") + logger.warning(" Model predictions don't align well with linear task progression.") + else: + logger.info("\n✓ Model predictions show healthy progression!") + + +if __name__ == "__main__": + main() + diff --git a/scripts/visualize_sarm_predictions.py b/scripts/visualize_sarm_predictions.py new file mode 100644 index 000000000..1944803a2 --- /dev/null +++ b/scripts/visualize_sarm_predictions.py @@ -0,0 +1,537 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Inference script for SARM (Stage-Aware Reward Model). + +This script loads a trained SARM model and runs inference on a dataset episode, +generating visualizations of the predicted task stages and progress over time. 
+ +Example usage: + python scripts/visualize_sarm_predictions.py \ + --model-id username/sarm-model \ + --dataset-repo lerobot/aloha_sim_insertion_human \ + --episode-index 0 \ + --output-dir outputs/sarm_viz \ + --task-description "insert the peg into the socket" +""" + +import argparse +import logging +from pathlib import Path +from typing import Optional + +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec +import matplotlib.patches as mpatches +import numpy as np +import torch +from tqdm import tqdm + +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.policies.sarm.modeling_sarm import SARMRewardModel + + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run SARM inference and visualize predictions") + + # Model arguments + parser.add_argument( + "--model-id", + type=str, + required=True, + help="HuggingFace model ID or local path to trained SARM model" + ) + + # Dataset arguments + parser.add_argument( + "--dataset-repo", + type=str, + required=True, + help="HuggingFace dataset repository ID (e.g., lerobot/aloha_sim_insertion_human)" + ) + parser.add_argument( + "--episode-index", + type=int, + default=0, + help="Index of the episode to visualize (default: 0)" + ) + parser.add_argument( + "--task-description", + type=str, + default="perform the task", + help="Task description for the reward model (default: 'perform the task')" + ) + + # Output arguments + parser.add_argument( + "--output-dir", + type=str, + default="outputs/sarm_inference", + help="Directory to save visualization outputs (default: outputs/sarm_inference)" + ) + parser.add_argument( + "--image-key", + type=str, + default=None, + help="Key for images in dataset (e.g., observation.images.image). If not specified, uses model config's image_key" + ) + parser.add_argument( + "--state-key", + type=str, + default=None, + help="Key for joint states in dataset. If None, auto-detects from dataset" + ) + + # Visualization options + parser.add_argument( + "--show-frames", + action="store_true", + help="Include sample frames in the visualization" + ) + parser.add_argument( + "--num-sample-frames", + type=int, + default=8, + help="Number of sample frames to show (default: 8)" + ) + parser.add_argument( + "--figsize", + type=int, + nargs=2, + default=[14, 8], + help="Figure size as width height (default: 14 8)" + ) + + # Device + parser.add_argument( + "--device", + type=str, + default=None, + help="Device to run inference on (cuda/cpu, default: auto-detect)" + ) + + return parser.parse_args() + + +def load_episode_data( + dataset: LeRobotDataset, + episode_index: int, + image_key: str, + state_key: str | None = None +) -> tuple[np.ndarray, np.ndarray, int, int, str]: + """ + Load all frames and states from a specific episode. 
+ + Args: + dataset: LeRobotDataset instance + episode_index: Index of the episode to load + image_key: Key for accessing images in the dataset + state_key: Key for accessing joint states (auto-detected if None) + + Returns: + Tuple of (frames, states, start_index, end_index, task_description) + """ + # Get episode boundaries + episode_data = dataset.meta.episodes + start_idx = episode_data["dataset_from_index"][episode_index] + end_idx = episode_data["dataset_to_index"][episode_index] + + logger.info(f"Loading episode {episode_index}: frames {start_idx} to {end_idx} ({end_idx - start_idx} frames)") + + # Auto-detect state key if not provided + if state_key is None: + first_item = dataset[start_idx] + state_keys = [k for k in first_item.keys() if 'state' in k.lower() or 'qpos' in k.lower()] + if state_keys: + state_key = state_keys[0] + logger.info(f"Auto-detected state key: {state_key}") + + # Get task description from the dataset if available + task_description = None + first_item = dataset[start_idx] + if "task" in first_item: + task_description = first_item["task"] + logger.info(f"✓ Extracted task from episode {episode_index}: '{task_description}'") + + # Load all frames and states from the episode + frames = [] + states = [] + for idx in tqdm(range(start_idx, end_idx), desc="Loading frames"): + item = dataset[idx] + + # Get image + img = item[image_key] + + # Convert to numpy if needed + if isinstance(img, torch.Tensor): + img = img.cpu().numpy() + + # Handle different image formats (C, H, W) or (H, W, C) + if img.shape[0] in [1, 3]: # Channel first + img = np.transpose(img, (1, 2, 0)) + + # Convert to uint8 if needed + if img.dtype != np.uint8: + if img.max() <= 1.0: + img = (img * 255).astype(np.uint8) + else: + img = img.astype(np.uint8) + + frames.append(img) + + # Get state if available + if state_key and state_key in item: + state = item[state_key] + if isinstance(state, torch.Tensor): + state = state.cpu().numpy() + states.append(state) + + frames = np.array(frames) + states = np.array(states) if states else None + logger.info(f"Loaded {len(frames)} frames with shape {frames[0].shape}") + if states is not None: + logger.info(f"Loaded states with shape {states.shape}") + + return frames, states, start_idx, end_idx, task_description + + +@torch.no_grad() +def run_inference( + model: SARMRewardModel, + frames: np.ndarray, + states: Optional[np.ndarray], + task_description: str, + batch_size: int = 32 +) -> tuple[np.ndarray, np.ndarray]: + """ + Run SARM inference on video frames and joint states. + + For each frame t, creates a temporal sequence of 9 frames using SARM's pattern: + [t-240, t-210, t-180, t-150, t-120, t-90, t-60, t-30, t] + This matches the training pattern where frames are loaded with 30-frame gaps + relative to the current frame. 
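+
+    Worked example of the clamped index pattern (an illustrative sketch, not
+    part of the model API): with the default num_frames=9 and frame_gap=30,
+    the offsets are -(9 - 1 - j) * 30 for j = 0..8, i.e. [-240, -210, ..., 0].
+    For current frame t = 100 this yields raw indices [-140, -110, -80, -50,
+    -20, 10, 40, 70, 100], which clamp to [0, 0, 0, 0, 0, 10, 40, 70, 100],
+    so frames before the episode start are padded with the first available frame.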
+ + Args: + model: SARM model + frames: Video frames (num_frames, H, W, C) + states: Joint states (num_frames, state_dim) + task_description: Task description text + batch_size: Batch size for processing slices + + Returns: + Tuple of (progress_predictions, stage_predictions) + - progress_predictions: (num_frames,) + - stage_predictions: (num_frames, num_stages) + """ + logger.info("Encoding video frames with CLIP...") + video_embeddings = model.encode_images(frames) + + logger.info("Encoding task description with MiniLM...") + text_embedding = model.encode_text(task_description) + + logger.info("Creating video slices (SARM approach)...") + # Convert to tensors + video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32) + text_embedding = torch.tensor(text_embedding, dtype=torch.float32) + if states is not None: + state_embeddings = torch.tensor(states, dtype=torch.float32) + else: + state_embeddings = None + + # Create video slices: for each frame i, create a sequence using SARM's pattern + # For SARM: 9 frames relative to current, with 30-frame gaps + # Pattern: [current-240, current-210, ..., current-30, current] + num_frames_model = model.config.num_frames + frame_gap = model.config.frame_gap + + video_slices = [] + state_slices = [] + last_frame_indices = [] + + for i in tqdm(range(len(video_embeddings)), desc="Creating slices"): + # For SARM, create sequence relative to current frame (matching training pattern) + # Pattern: [current-240, current-210, ..., current-30, current] + # This matches observation_delta_indices: range(-240, 1, 30) + + # Compute frame indices for this slice (relative to current frame i) + frame_indices = [] + for j in range(num_frames_model): + # Start from -(num_frames_model-1) * frame_gap and go to 0 + offset = -(num_frames_model - 1 - j) * frame_gap + idx = i + offset + + # Clamp to valid range [0, current_frame] + if idx < 0: + idx = 0 # Pad with first available frame + + frame_indices.append(idx) + + # Extract slice + video_slice = video_embeddings[frame_indices] + video_slices.append(video_slice) + + if state_embeddings is not None: + state_slice = state_embeddings[frame_indices] + state_slices.append(state_slice) + + # Track which frame index corresponds to the "current" frame + last_frame_indices.append(min(i, len(frame_indices) - 1)) + + video_slices = torch.stack(video_slices) # (num_frames, num_frames_model, 512) + if state_embeddings is not None: + state_slices = torch.stack(state_slices) # (num_frames, num_frames_model, state_dim) + else: + state_slices = None + + logger.info("Running SARM inference on all slices...") + # Process in batches + all_progress = [] + all_stages = [] + + for i in tqdm(range(0, len(video_slices), batch_size), desc="Inference"): + batch_video = video_slices[i:i + batch_size].to(model.device) + batch_states = state_slices[i:i + batch_size].to(model.device) if state_slices is not None else None + batch_size_actual = batch_video.shape[0] + + # Replicate text embedding for batch + batch_text = text_embedding.unsqueeze(0).repeat(batch_size_actual, 1).to(model.device) + + # Get predictions + stage_logits, stage_probs, progress_preds = model.sarm_transformer( + batch_video, batch_text, batch_states + ) + + # Extract last frame predictions (the "current" frame) + # For SARM, we take the last frame in each sequence + batch_progress = progress_preds[:, -1, 0].cpu().numpy() + batch_stages = stage_probs[:, -1, :].cpu().numpy() + + all_progress.extend(batch_progress) + all_stages.extend(batch_stages) + + return 
np.array(all_progress), np.array(all_stages) + + +def visualize_predictions( + frames: np.ndarray, + progress_predictions: np.ndarray, + stage_predictions: np.ndarray, + task_description: str, + output_path: Path, + show_frames: bool = False, + num_sample_frames: int = 8, + figsize: tuple = (14, 8) +): + """ + Create visualization of SARM predictions. + + Args: + frames: Video frames (num_frames, H, W, C) + progress_predictions: Progress predictions (num_frames,) + stage_predictions: Stage probabilities (num_frames, num_stages) + task_description: Task description + output_path: Path to save the figure + show_frames: Whether to include sample frames + num_sample_frames: Number of frames to show + figsize: Figure size (width, height) + """ + num_stages = stage_predictions.shape[1] + stage_colors = plt.cm.tab10(np.linspace(0, 1, num_stages)) + + if show_frames: + # Create figure with progress plot, stage plot, and sample frames + fig = plt.figure(figsize=(figsize[0], figsize[1] + 4)) + gs = gridspec.GridSpec(3, 1, height_ratios=[2, 1, 1], hspace=0.3) + + ax_progress = fig.add_subplot(gs[0]) + ax_stages = fig.add_subplot(gs[1], sharex=ax_progress) + ax_frames = fig.add_subplot(gs[2]) + else: + # Just progress and stage plots + fig = plt.figure(figsize=figsize) + gs = gridspec.GridSpec(2, 1, height_ratios=[2, 1], hspace=0.3) + + ax_progress = fig.add_subplot(gs[0]) + ax_stages = fig.add_subplot(gs[1], sharex=ax_progress) + + frame_indices = np.arange(len(progress_predictions)) + + # Plot 1: Progress over time + ax_progress.plot(frame_indices, progress_predictions, linewidth=2, color='#2E86AB', label='Predicted Progress') + ax_progress.fill_between(frame_indices, 0, progress_predictions, alpha=0.3, color='#2E86AB') + ax_progress.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, linewidth=1) + ax_progress.set_ylabel('Task Progress', fontsize=12) + ax_progress.set_title(f'SARM Task Progress & Stage Prediction\nTask: "{task_description}"', + fontsize=14, fontweight='bold') + ax_progress.grid(True, alpha=0.3) + ax_progress.set_ylim(-0.05, 1.1) + ax_progress.legend(loc='upper left') + + # Add statistics box + stats_text = ( + f'Frames: {len(progress_predictions)}\n' + f'Final Progress: {progress_predictions[-1]:.3f}\n' + f'Max Progress: {progress_predictions.max():.3f}\n' + f'Mean Progress: {progress_predictions.mean():.3f}' + ) + ax_progress.text(0.98, 0.02, stats_text, transform=ax_progress.transAxes, + fontsize=10, verticalalignment='bottom', horizontalalignment='right', + bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5)) + + # Plot 2: Stage predictions (stacked area plot) + ax_stages.stackplot(frame_indices, *[stage_predictions[:, i] for i in range(num_stages)], + colors=stage_colors, alpha=0.8, labels=[f'Stage {i+1}' for i in range(num_stages)]) + ax_stages.set_xlabel('Frame Index', fontsize=12) + ax_stages.set_ylabel('Stage Probability', fontsize=12) + ax_stages.set_ylim(0, 1) + ax_stages.grid(True, alpha=0.3) + ax_stages.legend(loc='upper left', ncol=num_stages, fontsize=8) + + # Plot 3: Sample frames (if requested) + if show_frames: + frame_indices_to_show = np.linspace(0, len(frames) - 1, num_sample_frames, dtype=int) + + ax_frames.axis('off') + + # Create grid for frames + frame_height = frames[0].shape[0] + frame_width = frames[0].shape[1] + + combined_width = frame_width * num_sample_frames + combined_image = np.zeros((frame_height, combined_width, 3), dtype=np.uint8) + + for i, frame_idx in enumerate(frame_indices_to_show): + frame = frames[frame_idx] + if frame.shape[-1] 
== 1: + frame = np.repeat(frame, 3, axis=-1) + + # Add frame to combined image + x_start = i * frame_width + x_end = (i + 1) * frame_width + combined_image[:, x_start:x_end] = frame + + # Add frame number, progress, and stage + progress_val = progress_predictions[frame_idx] + stage_idx = np.argmax(stage_predictions[frame_idx]) + label = f'Frame {frame_idx}\nProg: {progress_val:.2f}\nStage: {stage_idx+1}' + + # Draw label on image + ax_frames.text(x_start + frame_width / 2, -10, label, + ha='center', va='top', fontsize=7, + bbox=dict(boxstyle='round', facecolor='white', alpha=0.7)) + + ax_frames.imshow(combined_image) + ax_frames.set_title('Sample Frames', fontsize=12, pad=20) + + # Save figure + plt.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + plt.savefig(output_path, dpi=150, bbox_inches='tight') + logger.info(f"Saved visualization to {output_path}") + + plt.close() + + +def main(): + args = parse_args() + + # Setup device + if args.device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + else: + device = args.device + logger.info(f"Using device: {device}") + + # Load model + logger.info(f"Loading SARM model from {args.model_id}...") + model = SARMRewardModel.from_pretrained(args.model_id) + model.to(device) + model.eval() + logger.info("Model loaded successfully") + + # Load dataset + logger.info(f"Loading dataset {args.dataset_repo}...") + dataset = LeRobotDataset(args.dataset_repo) + logger.info(f"Dataset loaded: {len(dataset.meta.episodes)} episodes, {len(dataset)} frames") + + # Validate episode index + if args.episode_index >= len(dataset.meta.episodes): + raise ValueError( + f"Episode index {args.episode_index} out of range. " + f"Dataset has {len(dataset.meta.episodes)} episodes." + ) + + # Determine which image key to use + image_key = args.image_key if args.image_key is not None else model.config.image_key + logger.info(f"Using image key: {image_key}") + + # Load episode data + frames, states, start_idx, end_idx, dataset_task = load_episode_data( + dataset, args.episode_index, image_key, args.state_key + ) + + # Use task description from dataset if available, otherwise use command-line argument + task_description = dataset_task if dataset_task is not None else args.task_description + logger.info(f"Using task description: '{task_description}'") + + # Run inference + progress_predictions, stage_predictions = run_inference(model, frames, states, task_description) + + # Create visualization + output_dir = Path(args.output_dir) + output_path = output_dir / f"sarm_prediction_ep{args.episode_index}.png" + + visualize_predictions( + frames, + progress_predictions, + stage_predictions, + task_description, + output_path, + show_frames=args.show_frames, + num_sample_frames=args.num_sample_frames, + figsize=tuple(args.figsize) + ) + + # Save predictions as numpy arrays + predictions_path = output_dir / f"predictions_ep{args.episode_index}.npz" + np.savez(predictions_path, progress=progress_predictions, stages=stage_predictions) + logger.info(f"Saved predictions to {predictions_path}") + + # Print summary + logger.info("\n" + "="*60) + logger.info("INFERENCE SUMMARY") + logger.info("="*60) + logger.info(f"Model: {args.model_id}") + logger.info(f"Dataset: {args.dataset_repo}") + logger.info(f"Episode: {args.episode_index}") + logger.info(f"Task: {task_description}") + logger.info(f"Frames: {len(frames)}") + logger.info(f"Final Progress: {progress_predictions[-1]:.3f}") + logger.info(f"Max Progress: {progress_predictions.max():.3f}") + 
logger.info(f"Mean Progress: {progress_predictions.mean():.3f}") + logger.info(f"Most Common Stage: {np.argmax(np.sum(stage_predictions, axis=0)) + 1}") + logger.info(f"Visualization: {output_path}") + logger.info("="*60) + + +if __name__ == "__main__": + main() + diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index 60a4d81d5..e1734ed37 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -63,11 +63,25 @@ class TrainPipelineConfig(HubMixin): scheduler: LRSchedulerConfig | None = None eval: EvalConfig = field(default_factory=EvalConfig) wandb: WandBConfig = field(default_factory=WandBConfig) + + # RA-BC (Reward-Aligned Behavior Cloning) parameters + use_rabc: bool = False # Enable reward-weighted training + reward_model_path: str | None = None # Path to pre-trained reward model (e.g., SARM) + rabc_kappa: float = 0.01 # Hard threshold for high-quality samples + rabc_epsilon: float = 1e-6 # Small constant for numerical stability + rabc_update_freq: int = 1 # Compute rewards every N batches (1 = every batch) def __post_init__(self): self.checkpoint_path = None def validate(self): + # Validate RA-BC configuration + if self.use_rabc and not self.reward_model_path: + raise ValueError( + "RA-BC is enabled (use_rabc=True) but no reward_model_path provided. " + "Please specify a pre-trained reward model (e.g., SARM) path." + ) + # HACK: We parse again the cli args here to get the pretrained paths if there was some. policy_path = parser.get_path_arg("policy") if policy_path: diff --git a/src/lerobot/datasets/rewind_sampler.py b/src/lerobot/datasets/rewind_sampler.py new file mode 100644 index 000000000..1b1d3f392 --- /dev/null +++ b/src/lerobot/datasets/rewind_sampler.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +ReWiND Sampler for temporal sequence loading. +""" + +import logging +from typing import Iterator, Optional +import numpy as np +import torch +from torch.utils.data import Sampler +import random + + +class ReWiNDTemporalSampler(Sampler): + """ + Sampler for ReWiND that samples random temporal windows from episodes. 
+ + Matches original ReWiND sampling: + - Samples random start and end points within episodes + - Minimum window size of 3 frames + - Can sample from beginning, middle, or end of episodes + + Args: + dataset_from_index: Start indices of episodes + dataset_to_index: End indices of episodes + sequence_length: Maximum sequence length (for padding/subsampling) + stride: Not used (kept for API compatibility) + shuffle: Whether to shuffle sampling order + seed: Random seed + """ + + def __init__( + self, + dataset_from_index: np.ndarray, + dataset_to_index: np.ndarray, + sequence_length: int = 32, + stride: int = 1, + shuffle: bool = True, + seed: Optional[int] = None, + ): + self.dataset_from_index = np.array(dataset_from_index) + self.dataset_to_index = np.array(dataset_to_index) + self.sequence_length = sequence_length + self.shuffle = shuffle + + if seed is not None: + self.seed = seed + random.seed(seed) + np.random.seed(seed) + self.generator = torch.Generator().manual_seed(seed) + else: + self.generator = torch.Generator() + + # Compute valid episodes (those with at least 3 frames) + self._compute_valid_episodes() + + # Number of samples per epoch (matching original ReWiND) + self.samples_per_epoch = 100 * 64 # 100 batches of 64 + + logging.info( + f"ReWiNDTemporalSampler: {len(self.valid_episodes)} valid episodes, " + f"{self.samples_per_epoch} samples per epoch" + ) + + def _compute_valid_episodes(self): + """Compute valid episodes (those with at least 3 frames).""" + self.valid_episodes = [] + + for ep_idx in range(len(self.dataset_from_index)): + ep_start = self.dataset_from_index[ep_idx] + ep_end = self.dataset_to_index[ep_idx] + episode_length = ep_end - ep_start + + if episode_length >= 3: # Minimum 3 frames + self.valid_episodes.append((ep_idx, ep_start, ep_end)) + + self.valid_episodes = np.array(self.valid_episodes) + + def __len__(self) -> int: + return self.samples_per_epoch + + def __iter__(self) -> Iterator[int]: + """ + Yields ONE index per sample (the end of a random window). + + Matches original ReWiND behavior: + 1. Pick random episode + 2. Pick random end frame (at least 3 frames from start) + 3. Yield that end frame index + 4. Dataset/processor loads from episode start to this end frame + 5. Model pads/subsamples to sequence_length (32) + + This allows sampling from anywhere in episodes: + - Early frames → short sequences (mostly padding) → low progress + - Middle frames → medium sequences (some subsampling) → medium progress + - End frames → long sequences (full subsampling) → high progress approaching 1.0 + """ + for _ in range(self.samples_per_epoch): + # Randomly select an episode + ep_idx, ep_start, ep_end = self.valid_episodes[ + np.random.randint(0, len(self.valid_episodes)) + ] + + episode_length = ep_end - ep_start + + # Sample a random end point (must be at least 3 frames from start) + # This matches original: random.randint(start_idx+3, len(progress_dataset)) + end_offset = np.random.randint(3, episode_length + 1) + end_idx = ep_start + end_offset + + # Yield ONLY the end index + # The dataset will load all frames from ep_start to end_idx + yield int(end_idx - 1) # -1 because end_idx is exclusive diff --git a/src/lerobot/datasets/temporal_sampler.py b/src/lerobot/datasets/temporal_sampler.py new file mode 100644 index 000000000..3e210717e --- /dev/null +++ b/src/lerobot/datasets/temporal_sampler.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Temporal Sequence Sampler for reward models and temporal policies. + +Supports multiple sampling modes: +- "rewind": ReWiND-style sampling (random windows from episode start) +- "sarm": SARM-style sampling (9-frame sequences with specific pattern) +- "custom": Custom temporal sampling +""" + +import logging +from typing import Iterator, Optional +import numpy as np +import torch +from torch.utils.data import Sampler +import random + + +class TemporalSequenceSampler(Sampler): + """ + Generalized temporal sampler for reward models. + + Supports multiple sampling modes: + - "rewind": Consecutive frames from episode start to random end point (ReWiND: 32 consecutive frames) + - "sarm": 9-frame sequences with 1 initial + 8 consecutive (SARM) + - "custom": Custom temporal sampling + + Args: + dataset_from_index: Start indices of episodes + dataset_to_index: End indices of episodes + sequence_length: Maximum sequence length (for padding/subsampling) + stride: Frame stride for consecutive sampling (SARM mode) + shuffle: Whether to shuffle sampling order + seed: Random seed + sampling_mode: Sampling mode ("rewind", "sarm", or "custom") + min_frames: Minimum frames per episode (default: 3) + samples_per_epoch: Number of samples per epoch (default: 6400) + """ + + def __init__( + self, + dataset_from_index: np.ndarray, + dataset_to_index: np.ndarray, + sequence_length: int = 32, + stride: int = 1, + shuffle: bool = True, + seed: Optional[int] = None, + sampling_mode: str = "rewind", + min_frames: int = 3, + samples_per_epoch: int = 6400, + ): + self.dataset_from_index = np.array(dataset_from_index) + self.dataset_to_index = np.array(dataset_to_index) + self.sequence_length = sequence_length + self.stride = stride + self.shuffle = shuffle + self.sampling_mode = sampling_mode + self.min_frames = min_frames + self.samples_per_epoch = samples_per_epoch + + if sampling_mode not in ["rewind", "sarm", "custom"]: + raise ValueError(f"sampling_mode must be 'rewind', 'sarm', or 'custom', got {sampling_mode}") + + if seed is not None: + self.seed = seed + random.seed(seed) + np.random.seed(seed) + self.generator = torch.Generator().manual_seed(seed) + else: + self.generator = torch.Generator() + + # Compute valid episodes + self._compute_valid_episodes() + + logging.info( + f"TemporalSequenceSampler ({sampling_mode} mode): " + f"{len(self.valid_episodes)} valid episodes, " + f"{self.samples_per_epoch} samples per epoch" + ) + + def _compute_valid_episodes(self): + """Compute valid episodes based on minimum frame requirement.""" + self.valid_episodes = [] + + for ep_idx in range(len(self.dataset_from_index)): + ep_start = self.dataset_from_index[ep_idx] + ep_end = self.dataset_to_index[ep_idx] + episode_length = ep_end - ep_start + + # For SARM mode, need enough frames for the sequence pattern + if self.sampling_mode == "sarm": + # Need at least sequence_length * stride frames + min_required = self.sequence_length * self.stride + if episode_length >= 
min_required: + self.valid_episodes.append((ep_idx, ep_start, ep_end)) + else: + # For rewind mode, use min_frames + if episode_length >= self.min_frames: + self.valid_episodes.append((ep_idx, ep_start, ep_end)) + + self.valid_episodes = np.array(self.valid_episodes) + + def __len__(self) -> int: + return self.samples_per_epoch + + def __iter__(self) -> Iterator[int]: + """ + Yields ONE index per sample. + + Sampling behavior depends on mode: + + ReWiND mode: + 1. Pick random episode + 2. Pick random end frame (at least min_frames from start) + 3. Yield that end frame index + 4. Dataset loads from episode start to this end frame + + SARM mode: + 1. Pick random episode + 2. Pick random end frame (must allow sequence_length frames with stride) + 3. Yield that end frame index + 4. Dataset loads sequence_length frames with stride spacing ending at this frame + """ + for _ in range(self.samples_per_epoch): + # Randomly select an episode + ep_idx, ep_start, ep_end = self.valid_episodes[ + np.random.randint(0, len(self.valid_episodes)) + ] + + episode_length = ep_end - ep_start + + if self.sampling_mode == "rewind": + # ReWiND: Sample random end point (at least min_frames from start) + end_offset = np.random.randint(self.min_frames, episode_length + 1) + end_idx = ep_start + end_offset + + # Yield the end index (dataset will load from start to this point) + yield int(end_idx - 1) # -1 because end_idx is exclusive + + elif self.sampling_mode == "sarm": + # SARM: Sample end point that allows full sequence + # We need sequence_length frames with stride spacing + min_end_offset = self.sequence_length * self.stride + + if episode_length >= min_end_offset: + # Can sample anywhere from min_end_offset to episode_length + end_offset = np.random.randint(min_end_offset, episode_length + 1) + else: + # Episode is exactly the minimum length + end_offset = episode_length + + end_idx = ep_start + end_offset + + # Yield the end index (dataset will load sequence with stride) + yield int(end_idx - 1) # -1 because end_idx is exclusive + + else: # custom mode + # Default to rewind-style sampling + end_offset = np.random.randint(self.min_frames, episode_length + 1) + end_idx = ep_start + end_offset + yield int(end_idx - 1) + + +# Backwards compatibility alias +ReWiNDTemporalSampler = TemporalSequenceSampler + diff --git a/src/lerobot/datasets/video_sampler.py b/src/lerobot/datasets/video_sampler.py index 1aafa8676..f33e14b1b 100644 --- a/src/lerobot/datasets/video_sampler.py +++ b/src/lerobot/datasets/video_sampler.py @@ -32,21 +32,62 @@ import torch def sample_video_feature( video_feature: torch.Tensor, max_length: int = 32, - random_sample: bool = True -) -> torch.Tensor: + random_sample: bool = True, + remaining_length: int = None, + absolute_indices: torch.Tensor = None, + episode_length: int = None +) -> tuple[torch.Tensor, torch.Tensor]: """ - Sample or pad video features to a fixed length. + Sample or pad video features to a fixed length with progress targets. 
+ + Progress normalization matches original ReWiND implementation: + - Progress = (position_in_sequence + 1) / remaining_trajectory_length + - remaining_trajectory_length = frames from first sampled frame to episode end + + Original ReWiND logic (dataset.py lines 12493-12499): + video_frames = frames[start_idx:end_idx] + full_frames = frames[start_idx:] # All frames from start to episode end + progress = [1, 2, ..., len(video_frames)] / len(full_frames) + + This ensures all sequences show increasing progress from near-zero, regardless + of where they're sampled from in the episode. + + Note: ReWiND uses consecutive frames loaded via observation_delta_indices. + When video_length > max_length, this function can subsample, but ReWiND + typically loads exactly max_length frames, so no subsampling occurs. Args: video_feature: Video features tensor (num_frames, feature_dim) max_length: Target sequence length - random_sample: If True, randomly sample frames. If False, uniformly sample. + random_sample: If True, randomly sample frames. If False, uniformly sample consecutive frames. + ReWiND uses False to preserve temporal order. + remaining_length: Remaining trajectory length from first frame to episode end + absolute_indices: Absolute frame indices in the episode (num_frames,) [for fallback] + episode_length: Total length of the episode [for fallback] Returns: - Sampled/padded video features (max_length, feature_dim) + Tuple of: + - Sampled/padded video features (max_length, feature_dim) + - Progress targets for each frame (max_length,) """ video_length = len(video_feature) + # Generate progress targets using ORIGINAL ReWiND formula + # Progress = (position_in_sequence + 1) / remaining_trajectory_length + if remaining_length is not None: + # CORRECT: Use remaining length from first frame to episode end + progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32) + progress_targets = progress_indices / remaining_length + elif absolute_indices is not None and episode_length is not None: + # Fallback: Compute remaining length from first frame to episode end + first_frame_idx = absolute_indices[0].item() if isinstance(absolute_indices[0], torch.Tensor) else absolute_indices[0] + remaining_length_computed = episode_length - first_frame_idx + progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32) + progress_targets = progress_indices / remaining_length_computed + else: + # Fallback: linear progress (for inference/testing) + progress_targets = torch.linspace(1.0/video_length, 1.0, video_length) + if video_length < max_length: # Pad with last frame padding_length = max_length - video_length @@ -54,35 +95,52 @@ def sample_video_feature( padding_frames = last_frame.repeat(padding_length, 1) video_feature = torch.cat([video_feature, padding_frames], dim=0) + # Pad progress with last progress value + last_progress = progress_targets[-1] + padding_progress = torch.full((padding_length,), last_progress) + progress_targets = torch.cat([progress_targets, padding_progress]) + elif video_length > max_length: if random_sample: - # Random sampling + # Random sampling (maintains temporal order via sorted indices) frame_idx = sorted(random.sample(range(video_length), max_length)) else: - # Uniform sampling + # Uniform sampling (consecutive frames with even spacing) frame_idx = np.linspace(0, video_length - 1, max_length, dtype=int) video_feature = video_feature[frame_idx] + progress_targets = progress_targets[frame_idx] - return video_feature + return video_feature, progress_targets 
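+
+# Worked example of the progress normalization above (an illustrative sketch
+# with hypothetical feature values; nothing below is used elsewhere in this
+# module): a window of 32 frames sampled starting at frame 50 of a 200-frame
+# episode has remaining_length = 200 - 50 = 150, so the targets are
+# [1/150, 2/150, ..., 32/150] and no padding or subsampling is needed.
+#
+#   feats = torch.randn(32, 768)  # hypothetical per-frame DINO embeddings
+#   video, targets = sample_video_feature(
+#       feats, max_length=32, random_sample=False, remaining_length=150
+#   )
+#   assert video.shape == (32, 768)
+#   assert torch.isclose(targets[0], torch.tensor(1.0 / 150))
+#   assert torch.isclose(targets[-1], torch.tensor(32.0 / 150))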
def sample_reverse_video_feature( video_feature: torch.Tensor, max_length: int = 32, - random_sample: bool = True + random_sample: bool = True, + remaining_length: int = None, + absolute_indices: torch.Tensor = None, + episode_length: int = None ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Sample video with reverse augmentation (video rewind). + Sample video with reverse augmentation (video rewind) - ORIGINAL REWIND LOGIC. - This function implements the video rewind augmentation described in the ReWiND paper. - It splits the video at a random point and reverses k frames from that point, creating - a trajectory that looks like it's making progress then regressing. This trains the - reward model to properly decrease rewards when the policy fails. + This implements the EXACT video rewind augmentation from the original ReWiND paper: + 1. Take forward sequence + 2. Append reversed frames from the END backwards + 3. Progress increases then decreases (simulating task completion then failure) + + Progress normalization matches original ReWiND (same as sample_video_feature). + Original ReWiND logic (dataset.py lines 12526-12541): + progress = [1, 2, ..., len(video_frames)] / len(full_frames) + reverse_progress = progress[::-1][1:selected_end_point] Args: video_feature: Video features tensor (num_frames, feature_dim) max_length: Target sequence length random_sample: If True, use random sampling for frame selection + remaining_length: Remaining trajectory length from first frame to episode end + absolute_indices: Absolute frame indices in the episode (num_frames,) [for fallback] + episode_length: Total length of the episode [for fallback] Returns: Tuple of: @@ -91,42 +149,40 @@ def sample_reverse_video_feature( """ video_length = len(video_feature) - # Sample split point (where to start reversing) - split_idx = random.randint(1, min(video_length - 1, max_length - 1)) - - # Sample how many frames to reverse (k in the paper) - max_reverse = min(split_idx, max_length - split_idx) - if max_reverse > 0: - reverse_length = random.randint(1, max_reverse) + # Generate forward progress targets using ORIGINAL ReWiND formula + # Progress = (position_in_sequence + 1) / remaining_trajectory_length + if remaining_length is not None: + # CORRECT: Use remaining length from first frame to episode end + progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32) + forward_progress = progress_indices / remaining_length + elif absolute_indices is not None and episode_length is not None: + # Fallback: Compute remaining length from first frame to episode end + first_frame_idx = absolute_indices[0].item() if isinstance(absolute_indices[0], torch.Tensor) else absolute_indices[0] + remaining_length_computed = episode_length - first_frame_idx + progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32) + forward_progress = progress_indices / remaining_length_computed else: - reverse_length = 0 + # Fallback: linear progress + forward_progress = torch.linspace(1.0/video_length, 1.0, video_length) - # Create rewound video - if reverse_length > 0: - # Forward part: frames 0 to split_idx - forward_frames = video_feature[:split_idx] - - # Reverse part: frames from split_idx-1 going backwards - reverse_frames = video_feature[split_idx - reverse_length:split_idx].flip(0) - - # Combine forward and reverse parts - rewound_video = torch.cat([forward_frames, reverse_frames], dim=0) - - # Create progress targets - # Forward part has increasing progress - forward_progress = torch.linspace(0, split_idx / 
video_length, split_idx) - # Reverse part has decreasing progress - reverse_progress = torch.linspace( - (split_idx - 1) / video_length, - (split_idx - reverse_length) / video_length, - reverse_length - ) - progress_targets = torch.cat([forward_progress, reverse_progress]) - - else: - # No reversal, just use original video - rewound_video = video_feature[:max_length] - progress_targets = torch.linspace(0, min(max_length, video_length) / video_length, len(rewound_video)) + # ORIGINAL LOGIC: Reverse from END backwards, then append to forward sequence + # Example: video=[A,B,C,D,E] -> reversed=[E,D,C,B,A] -> take some from reversed (skip first) + # Result: [A,B,C,D,E] + [D,C,B] = progress increases then decreases + + # Randomly select how many frames to reverse and append + selected_end_point = random.randint(2, min(video_length, max_length)) + + # Reverse the entire video and its progress + reversed_video = video_feature.flip(0) + reversed_progress = forward_progress.flip(0) + + # Take frames from reversed (skip the first frame which is the last frame of original) + reverse_frames = reversed_video[1:selected_end_point] + reverse_progress = reversed_progress[1:selected_end_point] + + # Concatenate forward + reversed (creates rewind effect) + rewound_video = torch.cat([video_feature, reverse_frames], dim=0) + progress_targets = torch.cat([forward_progress, reverse_progress], dim=0) # Pad or sample to target length if len(rewound_video) < max_length: @@ -136,13 +192,13 @@ def sample_reverse_video_feature( padding_frames = last_frame.repeat(padding_length, 1) rewound_video = torch.cat([rewound_video, padding_frames], dim=0) - # Extend progress targets (stay at last progress value) + # Pad progress with last progress value last_progress = progress_targets[-1] padding_progress = torch.full((padding_length,), last_progress) progress_targets = torch.cat([progress_targets, padding_progress]) elif len(rewound_video) > max_length: - # Sample frames + # Sample frames to fit max_length if random_sample: frame_idx = sorted(random.sample(range(len(rewound_video)), max_length)) else: @@ -152,3 +208,170 @@ def sample_reverse_video_feature( return rewound_video, progress_targets + +def sample_sarm_video_feature( + video_feature: torch.Tensor, + num_frames: int = 9, + frame_gap: int = 30, + random_sample: bool = True, + absolute_indices: torch.Tensor = None, + episode_length: int = None +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Sample video features for SARM (Stage-Aware Reward Modeling). 
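+
+    Illustrative call (shapes assumed): for a 300-frame window of CLIP
+    features, sample_sarm_video_feature(feats, num_frames=9, frame_gap=30)
+    selects frame indices [0, 30, 60, ..., 240] and returns features of shape
+    (9, feature_dim) with matching progress targets.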
+ + SARM uses a specific pattern: + - 1 initial frame (from episode start) + - 8 consecutive frames with frame_gap spacing + + Progress normalization matches SARM implementation: + - Progress = absolute_frame_index / total_episode_length + + Args: + video_feature: Video features tensor (num_frames_available, feature_dim) + num_frames: Target number of frames (default: 9) + frame_gap: Gap between consecutive frames (default: 30, i.e., 1 second at 30fps) + random_sample: If True, use random sampling (not used for SARM's fixed pattern) + absolute_indices: Absolute frame indices in the episode (num_frames_available,) + episode_length: Total length of the episode + + Returns: + Tuple of: + - Sampled video features (num_frames, feature_dim) + - Progress targets for each frame (num_frames,) + """ + video_length = len(video_feature) + + # Generate progress targets based on relative position within sampled sequence + # Note: SARM paper uses subtask annotations (Equation 2: yt = Pk−1 + ᾱk * τt) + # Without annotations, we use linear progress relative to sequence position + if absolute_indices is not None and episode_length is not None: + # Compute relative progress: position within sequence / remaining trajectory + # This ensures progress starts near 0 and increases, not starting at 0.8 if sampled from end + first_frame_idx = absolute_indices[0].item() if isinstance(absolute_indices[0], torch.Tensor) else absolute_indices[0] + remaining_length = episode_length - first_frame_idx + + # Progress = (position_in_sequence + 1) / remaining_trajectory_length + progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32) + progress_targets = progress_indices / remaining_length + else: + # Fallback: linear progress + progress_targets = torch.linspace(1.0/video_length, 1.0, video_length) + + # SARM pattern: first frame + (num_frames-1) consecutive frames with frame_gap + # The first frame should be from the beginning of the sequence + # The remaining frames are sampled with frame_gap spacing + + if video_length < num_frames: + # Not enough frames, pad with last frame + sampled_video = video_feature + sampled_progress = progress_targets + + padding_length = num_frames - video_length + last_frame = sampled_video[-1].unsqueeze(0) + padding_frames = last_frame.repeat(padding_length, 1) + sampled_video = torch.cat([sampled_video, padding_frames], dim=0) + + last_progress = sampled_progress[-1] + padding_progress = torch.full((padding_length,), last_progress) + sampled_progress = torch.cat([sampled_progress, padding_progress]) + + else: + # Sample frames: first frame + (num_frames-1) with frame_gap + # The indices should represent: [0, gap, 2*gap, 3*gap, ..., (num_frames-1)*gap] + # But we need to ensure we don't exceed video_length + + frame_indices = [0] # First frame + for i in range(1, num_frames): + idx = i * frame_gap + if idx >= video_length: + idx = video_length - 1 + frame_indices.append(idx) + + frame_indices = torch.tensor(frame_indices, dtype=torch.long) + sampled_video = video_feature[frame_indices] + sampled_progress = progress_targets[frame_indices] + + return sampled_video, sampled_progress + + +def sample_sarm_reverse_video_feature( + video_feature: torch.Tensor, + num_frames: int = 9, + frame_gap: int = 30, + random_sample: bool = True, + absolute_indices: torch.Tensor = None, + episode_length: int = None +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Sample video with reverse augmentation for SARM (rewind augmentation). 
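+
+    Illustrative outcome (values assumed): with num_frames=9 and 3 reversed
+    tail frames, the first 6 targets rise and the last 3 fall, so regressing
+    motion maps to decreasing predicted progress.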
+
+    Similar to ReWiND's rewind augmentation but adapted for SARM's frame pattern:
+    1. Take forward sequence (1 initial + 8 consecutive)
+    2. Append some reversed frames from the end backwards
+    3. Progress increases then decreases
+
+    Args:
+        video_feature: Video features tensor (num_frames_available, feature_dim)
+        num_frames: Target number of frames (default: 9)
+        frame_gap: Gap between consecutive frames (default: 30)
+        random_sample: If True, use random sampling for reverse section
+        absolute_indices: Absolute frame indices in the episode
+        episode_length: Total length of the episode
+
+    Returns:
+        Tuple of:
+        - Rewound video features (num_frames, feature_dim)
+        - Progress targets for each frame (num_frames,)
+    """
+    video_length = len(video_feature)
+
+    # Generate forward progress targets (relative to sequence, not absolute)
+    if absolute_indices is not None and episode_length is not None:
+        # Use same relative progress as normal sampling
+        first_frame_idx = absolute_indices[0].item() if isinstance(absolute_indices[0], torch.Tensor) else absolute_indices[0]
+        remaining_length = episode_length - first_frame_idx
+        progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
+        forward_progress = progress_indices / remaining_length
+    else:
+        forward_progress = torch.linspace(1.0 / video_length, 1.0, video_length)
+
+    # Sample forward sequence first
+    forward_video, forward_progress_sampled = sample_sarm_video_feature(
+        video_feature, num_frames, frame_gap, random_sample, absolute_indices, episode_length
+    )
+
+    # Randomly select how many frames to reverse and append
+    # For SARM, we append 2-4 reversed frames
+    num_reverse = random.randint(2, min(4, num_frames - 1))
+
+    # Reverse the video and progress
+    reversed_video = video_feature.flip(0)
+    reversed_progress = forward_progress.flip(0)
+
+    # Take frames from reversed (skip the first frame which is the last frame of original)
+    reverse_frames = reversed_video[1:num_reverse + 1]
+    reverse_progress = reversed_progress[1:num_reverse + 1]
+
+    # Concatenate forward + reversed (creates rewind effect)
+    rewound_video = torch.cat([forward_video, reverse_frames], dim=0)
+    progress_targets = torch.cat([forward_progress_sampled, reverse_progress], dim=0)
+
+    # Trim to num_frames if necessary. Keep the reversed tail: the forward part
+    # alone already has num_frames frames, so keeping only the first num_frames
+    # would drop every reversed frame and silently undo the augmentation.
+    if len(rewound_video) > num_frames:
+        keep_forward = num_frames - len(reverse_frames)
+        rewound_video = torch.cat([forward_video[:keep_forward], reverse_frames], dim=0)
+        progress_targets = torch.cat([forward_progress_sampled[:keep_forward], reverse_progress], dim=0)
+    elif len(rewound_video) < num_frames:
+        # Pad if necessary
+        padding_length = num_frames - len(rewound_video)
+        last_frame = rewound_video[-1].unsqueeze(0)
+        padding_frames = last_frame.repeat(padding_length, 1)
+        rewound_video = torch.cat([rewound_video, padding_frames], dim=0)
+
+        last_progress = progress_targets[-1]
+        padding_progress = torch.full((padding_length,), last_progress)
+        progress_targets = torch.cat([progress_targets, padding_progress])
+
+    return rewound_video, progress_targets
+
diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py
index 2f8989a11..13823742e 100644
--- a/src/lerobot/policies/factory.py
+++ b/src/lerobot/policies/factory.py
@@ -35,6 +35,7 @@ from lerobot.policies.pi05.configuration_pi05 import PI05Config
 from lerobot.policies.pretrained import PreTrainedPolicy
 from lerobot.policies.rewind.configuration_rewind import ReWiNDConfig
 from lerobot.policies.sac.configuration_sac import SACConfig
+from lerobot.policies.sarm.configuration_sarm import SARMConfig
 from lerobot.policies.sac.reward_model.configuration_classifier import 
RewardClassifierConfig from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig @@ -102,6 +103,14 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy return SmolVLAPolicy + elif name == "rewind": + from lerobot.policies.rewind.modeling_rewind import ReWiNDRewardModel + + return ReWiNDRewardModel + elif name == "sarm": + from lerobot.policies.sarm.modeling_sarm import SARMRewardModel + + return SARMRewardModel else: raise NotImplementedError(f"Policy with name {name} is not implemented.") @@ -300,6 +309,16 @@ def make_pre_post_processors( processors = make_rewind_pre_post_processors( config=policy_cfg, dataset_stats=kwargs.get("dataset_stats"), + dataset_meta=kwargs.get("dataset_meta"), + ) + + elif isinstance(policy_cfg, SARMConfig): + from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors + + processors = make_sarm_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + dataset_meta=kwargs.get("dataset_meta"), ) else: @@ -372,6 +391,7 @@ def make_policy( cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features} kwargs["config"] = cfg + if cfg.pretrained_path: # Load a pretrained policy and override the config if needed (for example, if there are inference-time diff --git a/src/lerobot/policies/rewind/configuration_rewind.py b/src/lerobot/policies/rewind/configuration_rewind.py index a3fa0a035..2280ca634 100644 --- a/src/lerobot/policies/rewind/configuration_rewind.py +++ b/src/lerobot/policies/rewind/configuration_rewind.py @@ -17,8 +17,8 @@ from dataclasses import dataclass, field from lerobot.configs.policies import PreTrainedConfig -from lerobot.optim import OptimizerConfig -from lerobot.optim.schedulers import LRSchedulerConfig +from lerobot.optim.optimizers import AdamWConfig +from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig @PreTrainedConfig.register_subclass("rewind") @@ -41,6 +41,8 @@ class ReWiNDConfig(PreTrainedConfig): # Temporal parameters max_length: int = 32 # Maximum video sequence length subsample_video: bool = True # Whether to pad/subsample videos to max_length + use_temporal_sampler: bool = True # Always enable temporal sequence loading + sequence_stride: int = 1 # Stride between frames when using temporal sampler # Training parameters batch_size: int = 64 @@ -58,8 +60,9 @@ class ReWiNDConfig(PreTrainedConfig): # Processor settings (for automatic preprocessing) image_key: str = "observation.images.top" # Key for images in dataset - task_description: str = "perform the task" # Default task description + task_description: str = "perform the task" # Default task description (used if no task field in data) encode_on_the_fly: bool = True # Encode images/text during training + use_dataset_task: bool = True # Use task descriptions from dataset (per-episode) # Features (required by PreTrainedPolicy) input_features: dict = field(default_factory=lambda: { @@ -85,23 +88,50 @@ class ReWiNDConfig(PreTrainedConfig): if self.dropout < 0 or self.dropout >= 1: raise ValueError(f"dropout must be in [0, 1), got {self.dropout}") - def get_optimizer_preset(self) -> OptimizerConfig: + def get_optimizer_preset(self) -> AdamWConfig: """Get default optimizer configuration for ReWiND training.""" - return OptimizerConfig( - 
name="adamw", + return AdamWConfig( lr=3e-4, weight_decay=1e-4, betas=(0.9, 0.999), eps=1e-8, - grad_clip_norm=1.0 ) - def get_scheduler_preset(self) -> LRSchedulerConfig: + def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig: """Get default learning rate scheduler configuration.""" - return LRSchedulerConfig( - name="cosine", - warmup_steps=1000, - T_max=100000, # Will be overridden by training steps - eta_min=3e-5 + return CosineDecayWithWarmupSchedulerConfig( + peak_lr=3e-4, + decay_lr=3e-5, + num_warmup_steps=1000, + num_decay_steps=100000, ) + + def validate_features(self) -> None: + pass + + @property + def observation_delta_indices(self) -> list[int]: + """Load all frames from episode start up to current frame. + + The sampler yields a random end point in each episode. + This property tells the dataset to load all frames from -(end_idx - start_idx) to 0. + + Since we don't know the exact window size in advance, we load up to max_length frames. + The dataset will automatically clamp to episode boundaries. + + Returns: + Indices for loading history: [-31, -30, ..., -1, 0] for max_length=32 + """ + # Load the last max_length frames (or up to episode start) + return list(range(-(self.max_length - 1), 1)) + + @property + def action_delta_indices(self) -> None: + """ReWiND is a reward model, not an action policy.""" + return None + + @property + def reward_delta_indices(self) -> None: + """ReWiND doesn't use delta rewards.""" + return None diff --git a/src/lerobot/policies/rewind/modeling_rewind.py b/src/lerobot/policies/rewind/modeling_rewind.py index cb0d7be9f..7adc131aa 100644 --- a/src/lerobot/policies/rewind/modeling_rewind.py +++ b/src/lerobot/policies/rewind/modeling_rewind.py @@ -25,6 +25,7 @@ import torch.nn.functional as F from PIL import Image from transformers import AutoModel, AutoTokenizer import torchvision.transforms as T +from torch import Tensor from lerobot.policies.rewind.configuration_rewind import ReWiNDConfig from lerobot.policies.pretrained import PreTrainedPolicy @@ -185,6 +186,7 @@ class ReWiNDRewardModel(PreTrainedPolicy): """ name = "rewind" + config_class = ReWiNDConfig def __init__(self, config: ReWiNDConfig, dataset_stats: dict | None = None): super().__init__(config, dataset_stats) @@ -478,6 +480,24 @@ class ReWiNDRewardModel(PreTrainedPolicy): """Return trainable parameters (only ReWiND transformer, not encoders).""" return self.rewind_transformer.parameters() + def get_optim_params(self): + """Return optimizer parameters for the policy.""" + return self.parameters() + + def reset(self): + """ + This method is required by PreTrainedPolicy but not used for reward models. + The reward model does not maintain state between episodes. + """ + pass + + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """ + This method is required by PreTrainedPolicy but not used for reward models. + The rewind model is not an actor and does not produce action chunks. + """ + raise NotImplementedError("Rewind model does not predict action chunks") + def select_action(self, batch: dict[str, Tensor]) -> Tensor: """ This method is required by PreTrainedPolicy but not used for rewind. @@ -490,22 +510,29 @@ class ReWiNDRewardModel(PreTrainedPolicy): Forward pass compatible with lerobot training pipeline. 
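+
+        Illustrative observation layout (shapes assumed; keys match the code
+        below):
+            batch["observation"]["video_features"]    # (B, T, 768) DINO features
+            batch["observation"]["text_features"]     # (B, 384) MiniLM features
+            batch["observation"]["remaining_length"]  # (B,) optional metadata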
Args: - batch: Dictionary containing: - - 'video_features': Pre-encoded video features (B, T, 768) + batch: Dictionary containing observation with: + - 'video_features': Pre-encoded video features (B, 768) or (B, T, 768) - 'text_features': Pre-encoded text features (B, 384) - - Optional: 'misaligned_video_features', 'misaligned_text_features' Returns: loss: Total training loss output_dict: Dictionary of loss components for logging """ - # Use train_step_fn but without optimizer step (that's handled by training pipeline) - video_features = batch['video_features'].to(self.device) - text_features = batch['text_features'].to(self.device) + # Extract from observation dict + observation = batch.get('observation', batch) + video_features = observation['video_features'].to(self.device) + text_features = observation['text_features'].to(self.device) batch_size = video_features.shape[0] max_length = self.config.max_length + # Handle both single frames (B, 768) and sequences (B, T, 768) + if video_features.dim() == 2: + # Single frames: replicate to create pseudo-sequences + video_features = video_features.unsqueeze(1).repeat(1, max_length, 1) # (B, max_length, 768) + + # Now video_features is (B, T, 768) where T might be > max_length + # Process videos (with potential rewind augmentation) import random from lerobot.datasets.video_sampler import sample_video_feature, sample_reverse_video_feature @@ -513,27 +540,59 @@ class ReWiNDRewardModel(PreTrainedPolicy): processed_videos = [] progress_targets = [] + # Extract episode metadata for correct progress normalization + absolute_frame_indices = observation.get('absolute_frame_indices', None) + episode_lengths = observation.get('episode_length', None) + remaining_lengths = observation.get('remaining_length', None) + for i in range(batch_size): + # Get metadata for this sample + current_absolute_indices = None + current_episode_length = None + current_remaining_length = None + + if absolute_frame_indices is not None: + if isinstance(absolute_frame_indices, list): + current_absolute_indices = absolute_frame_indices[i] + else: + current_absolute_indices = absolute_frame_indices + + if episode_lengths is not None: + if isinstance(episode_lengths, torch.Tensor) and episode_lengths.dim() > 0: + current_episode_length = episode_lengths[i].item() + else: + current_episode_length = episode_lengths.item() if isinstance(episode_lengths, torch.Tensor) else episode_lengths + + if remaining_lengths is not None: + if isinstance(remaining_lengths, torch.Tensor) and remaining_lengths.dim() > 0: + current_remaining_length = remaining_lengths[i].item() + else: + current_remaining_length = remaining_lengths.item() if isinstance(remaining_lengths, torch.Tensor) else remaining_lengths + if random.random() < 0.5: # 50% chance of rewind - # Apply video rewind augmentation + # Apply video rewind augmentation (now returns tuple) rewound_video, progress = sample_reverse_video_feature( video_features[i], max_length=max_length, - random_sample=True + random_sample=False, # Use consecutive frames, not random sampling + remaining_length=current_remaining_length, + absolute_indices=current_absolute_indices, + episode_length=current_episode_length ) - processed_videos.append(rewound_video) - progress_targets.append(progress) + processed_videos.append(rewound_video.to(self.device)) + progress_targets.append(progress.to(self.device)) else: - # Normal video sampling - sampled_video = sample_video_feature( + # Normal video sampling (now returns tuple with progress targets) + sampled_video, 
progress = sample_video_feature( video_features[i], max_length=max_length, - random_sample=True + random_sample=False, # Use consecutive frames, not random sampling + remaining_length=current_remaining_length, + absolute_indices=current_absolute_indices, + episode_length=current_episode_length ) - processed_videos.append(sampled_video) - # Linear progress from 0 to 1 - progress = torch.linspace(0, 1, max_length, device=self.device) - progress_targets.append(progress) + processed_videos.append(sampled_video.to(self.device)) + progress_targets.append(progress.to(self.device)) processed_videos = torch.stack(processed_videos) progress_targets = torch.stack(progress_targets) @@ -549,8 +608,8 @@ class ReWiNDRewardModel(PreTrainedPolicy): total_loss = progress_loss output_dict = {'progress_loss': progress_loss.item()} - # Compute misaligned loss if requested - if random.random() < 0.5: # 50% chance of adding misalignment loss + # Compute misaligned loss if requested (20% probability to match original) + if random.random() < 0.2: # 20% chance of adding misalignment loss (original ReWiND uses 20%) if 'misaligned_video_features' in batch and 'misaligned_text_features' in batch: misaligned_videos = batch['misaligned_video_features'].to(self.device) misaligned_texts = batch['misaligned_text_features'].to(self.device) @@ -560,15 +619,18 @@ class ReWiNDRewardModel(PreTrainedPolicy): misaligned_videos = processed_videos[shuffle_idx] misaligned_texts = text_features - # Sample misaligned videos + # Sample misaligned videos (function now returns tuple) + # For misaligned pairs, we don't need correct progress targets (will be set to 0) misaligned_videos_sampled = [] for i in range(batch_size): - sampled = sample_video_feature( + sampled, _ = sample_video_feature( misaligned_videos[i], max_length=max_length, - random_sample=True + random_sample=True, + absolute_indices=None, + episode_length=None ) - misaligned_videos_sampled.append(sampled) + misaligned_videos_sampled.append(sampled.to(self.device)) misaligned_videos_sampled = torch.stack(misaligned_videos_sampled) misaligned_loss = compute_misaligned_loss( diff --git a/src/lerobot/policies/rewind/processor_rewind.py b/src/lerobot/policies/rewind/processor_rewind.py index 4837e7bad..d8f3b331c 100644 --- a/src/lerobot/policies/rewind/processor_rewind.py +++ b/src/lerobot/policies/rewind/processor_rewind.py @@ -20,16 +20,20 @@ import numpy as np import torch from lerobot.policies.rewind.configuration_rewind import ReWiNDConfig -from lerobot.policies.processor import ( +from lerobot.processor import ( ProcessorStep, PolicyProcessorPipeline, PolicyAction, DeviceProcessorStep, + AddBatchDimensionProcessorStep, ) -from lerobot.policies.processor.transition import ( +from lerobot.processor.converters import ( policy_action_to_transition, transition_to_policy_action, ) +from lerobot.processor.pipeline import PipelineFeatureType +from lerobot.processor.core import EnvTransition, TransitionKey +from lerobot.configs.types import PolicyFeature, FeatureType from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME class ReWiNDEncodingProcessorStep(ProcessorStep): @@ -37,6 +41,16 @@ class ReWiNDEncodingProcessorStep(ProcessorStep): ProcessorStep that encodes images and text for ReWiND training. This step handles the DINO (image) and MiniLM (text) encoding that ReWiND needs. 
+ + Supports both single-frame and temporal sequence encoding: + - Single frame: (B, C, H, W) → (B, 768) video features + - Temporal sequence: (B, T, C, H, W) → (B, T, 768) video features + + To use temporal sequences, configure the dataset with delta_timestamps for your image key. + For example, to encode sequences of 32 frames: + delta_timestamps = { + "observation.images.top": [i / fps for i in range(-15, 17)] # 32 frames centered on current + } """ def __init__( @@ -44,11 +58,13 @@ class ReWiNDEncodingProcessorStep(ProcessorStep): config: ReWiNDConfig, image_key: str | None = None, task_description: str | None = None, + dataset_meta = None, ): super().__init__() self.config = config self.image_key = image_key or config.image_key self.task_description = task_description or config.task_description + self.dataset_meta = dataset_meta # Store dataset metadata for episode info # Initialize encoders self._init_encoders() @@ -79,106 +95,267 @@ class ReWiNDEncodingProcessorStep(ProcessorStep): self.device = device - def __call__(self, batch: Dict[str, Any]) -> Dict[str, Any]: - """Encode images and text in the batch.""" - # Extract images - if self.image_key in batch: - images = batch[self.image_key] + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Encode images and text in the transition.""" + self._current_transition = transition.copy() if hasattr(transition, 'copy') else dict(transition) + new_transition = self._current_transition + + observation = new_transition.get(TransitionKey.OBSERVATION) + if observation is None or not isinstance(observation, dict): + # If no observation, just return the transition as-is + return new_transition + + # Extract images from observation and encode + # For ReWiND, we need to load the sequence from episode start to current frame + batch_size = 1 + if self.image_key in observation: + image = observation[self.image_key] # Handle different image formats - if isinstance(images, torch.Tensor): - images = images.cpu().numpy() + if isinstance(image, torch.Tensor): + image = image.cpu().numpy() - # Encode images - video_features = self._encode_images(images) - batch['video_features'] = video_features + # Check if we have temporal sequences or single frames + # Temporal sampling: Load from episode start to current frame + # This will be handled by the dataset if configured with delta_timestamps + # Otherwise, we just encode the single frame + video_features = self._encode_images_batch(image) + observation['video_features'] = video_features + + # Get batch size from the encoded features + batch_size = video_features.shape[0] + + # Get task descriptions - check if 'task' field exists in the transition + # This allows per-episode task descriptions (e.g., for datasets with multiple tasks) + task_descriptions = None + if 'task' in new_transition: + task_descriptions = new_transition['task'] + # Convert to list if it's a single string + if isinstance(task_descriptions, str): + task_descriptions = [task_descriptions] * batch_size # Encode text - batch_size = len(batch.get('video_features', batch.get(list(batch.keys())[0]))) - task_descriptions = [self.task_description] * batch_size - text_features = self._encode_text(task_descriptions) - batch['text_features'] = text_features + if task_descriptions is not None: + # Encode per-sample task descriptions + text_features = self._encode_text_batch_list(task_descriptions) + else: + # Fall back to config task description if no task field in transition + text_features = 
self._encode_text_batch(self.task_description, batch_size) - return batch + observation['text_features'] = text_features + + # Compute episode metadata for progress normalization (used by ReWiND) + # We need to pass absolute frame indices and total episode length for correct progress calculation + if self.dataset_meta is not None and 'episode_index' in new_transition and 'index' in new_transition: + episode_indices = new_transition['episode_index'] + frame_indices = new_transition['index'] + + # Handle both single samples and batches + if isinstance(episode_indices, (int, np.integer)): + ep_idx = int(episode_indices) + frame_idx = int(frame_indices) + ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] + ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] + episode_length = ep_end - ep_start + + # For temporal sequences with observation_delta_indices: + # If we loaded frames using delta_indices (e.g., [-31, -30, ..., 0]), + # we need to compute the absolute indices of those frames + # The current frame is at frame_idx, and we loaded max_length frames before it + if 'video_features' in observation and len(observation['video_features'].shape) > 1: + # We have a temporal sequence + num_loaded_frames = observation['video_features'].shape[0] if observation['video_features'].dim() == 2 else observation['video_features'].shape[1] + # Absolute indices: from (frame_idx - num_frames + 1) to frame_idx + start_idx = max(ep_start, frame_idx - num_loaded_frames + 1) + absolute_indices = torch.arange(start_idx, frame_idx + 1) + observation['absolute_frame_indices'] = absolute_indices + # Compute remaining length: from first loaded frame to episode end + observation['remaining_length'] = ep_end - start_idx + else: + # Single frame + observation['absolute_frame_indices'] = torch.tensor([frame_idx]) + # Remaining length from this frame to episode end + observation['remaining_length'] = ep_end - frame_idx + + observation['episode_length'] = episode_length + else: + # Batch case + absolute_indices_list = [] + episode_lengths = [] + remaining_lengths = [] + for ep_idx, frame_idx in zip(episode_indices, frame_indices): + ep_idx = int(ep_idx.item() if hasattr(ep_idx, 'item') else ep_idx) + frame_idx = int(frame_idx.item() if hasattr(frame_idx, 'item') else frame_idx) + ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] + ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] + episode_length = ep_end - ep_start + episode_lengths.append(episode_length) + + # Compute absolute indices for this sample + if 'video_features' in observation and len(observation['video_features'].shape) > 1: + num_loaded_frames = observation['video_features'].shape[1] + start_idx = max(ep_start, frame_idx - num_loaded_frames + 1) + absolute_indices = torch.arange(start_idx, frame_idx + 1) + absolute_indices_list.append(absolute_indices) + # Remaining length from first loaded frame to episode end + remaining_lengths.append(ep_end - start_idx) + else: + absolute_indices_list.append(torch.tensor([frame_idx])) + # Remaining length from this frame to episode end + remaining_lengths.append(ep_end - frame_idx) + + observation['absolute_frame_indices'] = absolute_indices_list + observation['episode_length'] = torch.tensor(episode_lengths) + observation['remaining_length'] = torch.tensor(remaining_lengths) + + new_transition[TransitionKey.OBSERVATION] = observation + return new_transition @torch.no_grad() - def _encode_images(self, images: np.ndarray) -> torch.Tensor: - """Encode images using 
DINO.""" + def _encode_images_batch(self, images: np.ndarray) -> torch.Tensor: + """Encode a batch of images (with optional temporal dimension) using DINO. + + Args: + images: Batched images with shape: + - (B, C, H, W) for single frames, or + - (B, T, C, H, W) for temporal sequences + + Returns: + Encoded feature vectors with shape (B, 768) or (B, T, 768) + """ from lerobot.policies.rewind.modeling_rewind import dino_load_image - # Handle single frame case - if len(images.shape) == 4: - images = images[:, np.newaxis, ...] - single_frame = True + # Check if we have temporal dimension + has_temporal = len(images.shape) == 5 + + if has_temporal: + # Shape: (B, T, C, H, W) + batch_size, seq_length = images.shape[0], images.shape[1] + + # Reshape to (B*T, C, H, W) to process all frames at once + images = images.reshape(batch_size * seq_length, *images.shape[2:]) + elif len(images.shape) == 4: + # Shape: (B, C, H, W) + batch_size = images.shape[0] + seq_length = 1 else: - single_frame = False + raise ValueError(f"Expected 4D (B, C, H, W) or 5D (B, T, C, H, W) input, got shape {images.shape}") - batch_size, num_frames, C, H, W = images.shape - - # Convert to (B, T, H, W, C) - if C == 3: - images = images.transpose(0, 1, 3, 4, 2) + # Convert to list of (H, W, C) images + num_frames = images.shape[0] + if images.shape[1] in [1, 3]: # Channel first (N, C, H, W) + images_list = [images[i].transpose(1, 2, 0) for i in range(num_frames)] + else: # Channel last (N, H, W, C) + images_list = [images[i] for i in range(num_frames)] + # Encode each frame (can batch process with DINO for efficiency) all_embeddings = [] - - for video in images: - video_embeddings = [] + for i in range(0, num_frames, self.config.dino_batch_size): + batch_imgs = images_list[i:i + self.config.dino_batch_size] - # Convert to uint8 - if video.dtype != np.uint8: - video = (video * 255).astype(np.uint8) if video.max() <= 1.0 else video.astype(np.uint8) - - frames = [frame for frame in video] - episode_images_dino = [dino_load_image(frame) for frame in frames] - - # Batch process - for i in range(0, len(episode_images_dino), self.config.dino_batch_size): - dino_batch = torch.cat(episode_images_dino[i:i + self.config.dino_batch_size]) - dino_batch = dino_batch.to(self.device) - embeddings = self.dino_encoder(dino_batch).squeeze().detach().cpu() + # Prepare images for DINO + dino_inputs = [] + for img in batch_imgs: + # Handle single channel + if img.shape[-1] == 1: + img = np.repeat(img, 3, axis=-1) - if embeddings.dim() == 1: - embeddings = embeddings.unsqueeze(0) + # Convert to uint8 + if img.dtype != np.uint8: + img = (img * 255).astype(np.uint8) if img.max() <= 1.0 else img.astype(np.uint8) - video_embeddings.append(embeddings) + dino_inputs.append(dino_load_image(img)) - video_embeddings = torch.cat(video_embeddings) - all_embeddings.append(video_embeddings) + # Batch encode + dino_batch = torch.cat(dino_inputs).to(self.device) + embeddings = self.dino_encoder(dino_batch).detach().cpu() + + # Handle single frame case + if embeddings.dim() == 1: + embeddings = embeddings.unsqueeze(0) + + all_embeddings.append(embeddings) - result = torch.stack(all_embeddings) + # Concatenate all embeddings + all_embeddings = torch.cat(all_embeddings) # (B*T, 768) - if single_frame: - result = result.squeeze(1) + # Reshape back if temporal + if has_temporal: + all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1) # (B, T, 768) - return result + return all_embeddings @torch.no_grad() - def _encode_text(self, text: List[str]) -> 
torch.Tensor: - """Encode text using MiniLM.""" + def _encode_text_batch(self, text: str, batch_size: int) -> torch.Tensor: + """Encode a text string using MiniLM and replicate for batch. + + Args: + text: Text string to encode + batch_size: Batch size to replicate for + + Returns: + Encoded feature vectors with shape (B, 384) + """ from lerobot.policies.rewind.modeling_rewind import mean_pooling - all_embeddings = [] + encoded_input = self.minilm_tokenizer( + text, padding=True, truncation=True, return_tensors="pt" + ).to(self.device) - for i in range(0, len(text), self.config.batch_size): - batch_text = text[i:i + self.config.batch_size] - - encoded_input = self.minilm_tokenizer( - batch_text, padding=True, truncation=True, return_tensors="pt" - ).to(self.device) - - model_output = self.minilm_model(**encoded_input) - text_embeddings = mean_pooling(model_output, encoded_input["attention_mask"]) - - all_embeddings.append(text_embeddings.cpu()) + model_output = self.minilm_model(**encoded_input) + text_embedding = mean_pooling(model_output, encoded_input["attention_mask"]) + text_embedding = text_embedding.squeeze().cpu() - result = torch.cat(all_embeddings) + # Replicate for batch (B, 384) + text_embedding = text_embedding.unsqueeze(0).repeat(batch_size, 1) - return result + return text_embedding + + @torch.no_grad() + def _encode_text_batch_list(self, text_list: list[str]) -> torch.Tensor: + """Encode a list of text strings using MiniLM (one per sample). + + Args: + text_list: List of text strings to encode + + Returns: + Encoded feature vectors with shape (B, 384) + """ + from lerobot.policies.rewind.modeling_rewind import mean_pooling + + # Encode all texts in the batch at once + encoded_input = self.minilm_tokenizer( + text_list, padding=True, truncation=True, return_tensors="pt" + ).to(self.device) + + model_output = self.minilm_model(**encoded_input) + text_embeddings = mean_pooling(model_output, encoded_input["attention_mask"]) + text_embeddings = text_embeddings.cpu() + + return text_embeddings + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + Adds video_features and text_features to the observation features. + """ + # Add the encoded features + features[PipelineFeatureType.OBSERVATION]['video_features'] = PolicyFeature( + type=FeatureType.VISUAL, + shape=(768,) # DINO embedding dimension + ) + features[PipelineFeatureType.OBSERVATION]['text_features'] = PolicyFeature( + type=FeatureType.LANGUAGE, + shape=(384,) # MiniLM embedding dimension + ) + return features def make_rewind_pre_post_processors( config: ReWiNDConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, + dataset_meta = None, ) -> tuple[ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], PolicyProcessorPipeline[PolicyAction, PolicyAction], @@ -189,19 +366,23 @@ def make_rewind_pre_post_processors( The pre-processing pipeline: 1. Encodes images with DINO (768-dim) 2. Encodes text with MiniLM (384-dim) - 3. Moves data to device + 3. Computes remaining episode length for progress normalization + 4. Adds batch dimension + 5. Moves data to device - The post-processing pipeline is minimal (just moves to CPU). + The post-processing pipeline moves data back to CPU. 
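+
+    Usage sketch (hedged; assumes a LeRobotDataset `ds` and its `ds.meta`):
+        preprocessor, postprocessor = make_rewind_pre_post_processors(
+            config, dataset_meta=ds.meta
+        )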
Args: config: ReWiND configuration dataset_stats: Dataset statistics (not used for ReWiND) + dataset_meta: Dataset metadata for computing episode remaining length Returns: Tuple of (preprocessor, postprocessor) pipelines """ input_steps = [ - ReWiNDEncodingProcessorStep(config=config), + AddBatchDimensionProcessorStep(), + ReWiNDEncodingProcessorStep(config=config, dataset_meta=dataset_meta), DeviceProcessorStep(device=config.device), ] diff --git a/src/lerobot/policies/sarm/__init__.py b/src/lerobot/policies/sarm/__init__.py new file mode 100644 index 000000000..4cda62bd2 --- /dev/null +++ b/src/lerobot/policies/sarm/__init__.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lerobot.policies.sarm.configuration_sarm import SARMConfig +from lerobot.policies.sarm.modeling_sarm import ( + SARMRewardModel, + SARMTransformer, + compute_stage_loss, + compute_progress_loss, +) +from lerobot.policies.sarm.processor_sarm import ( + SARMEncodingProcessorStep, + make_sarm_pre_post_processors, +) + +__all__ = [ + "SARMConfig", + "SARMRewardModel", + "SARMTransformer", + "compute_stage_loss", + "compute_progress_loss", + "SARMEncodingProcessorStep", + "make_sarm_pre_post_processors", +] + diff --git a/src/lerobot/policies/sarm/configuration_sarm.py b/src/lerobot/policies/sarm/configuration_sarm.py new file mode 100644 index 000000000..67a497106 --- /dev/null +++ b/src/lerobot/policies/sarm/configuration_sarm.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.optim.optimizers import AdamWConfig +from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig + + +@PreTrainedConfig.register_subclass("sarm") +@dataclass +class SARMConfig(PreTrainedConfig): + """Configuration class for SARM (Stage-Aware Reward Modeling). + + SARM is a dual-head reward model that jointly predicts: + 1. High-level task stage (classification) + 2. Fine-grained progress within each stage (regression) + + It uses CLIP for visual encoding and supports joint state input. 
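+
+    Example override (illustrative values; defaults are defined in the fields
+    below):
+        config = SARMConfig(num_stages=5, frame_gap=30, enable_rabc=True)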
+ """ + + # Visual encoding parameters + image_dim: int = 512 # CLIP embedding dimension + num_frames: int = 9 # 1 initial + 8 consecutive frames + frame_gap: int = 30 # Frame gap between consecutive frames (at 30 fps = 1 second) + + # Text encoding parameters + text_dim: int = 384 # MiniLM embedding dimension + + # Joint state parameters + state_dim: int | None = None # Auto-detected from dataset if None + use_joint_state: bool = True # Whether to use joint state input + + # Architecture parameters + hidden_dim: int = 768 # Transformer hidden dimension + num_heads: int = 12 # Number of attention heads + num_layers: int = 8 # Number of transformer layers + num_stages: int = 5 # Number of task stages for classification + + # Temporal parameters + max_length: int = 9 # Maximum video sequence length (should match num_frames) + use_temporal_sampler: bool = True # Always enable temporal sequence loading + sampling_mode: str = "sarm" # Sampling mode: "sarm" or "rewind" + + # Training parameters + batch_size: int = 64 + clip_batch_size: int = 64 # Batch size for CLIP encoding + gradient_checkpointing: bool = False # Enable gradient checkpointing + dropout: float = 0.1 # Dropout rate + + # RA-BC (Reward-Aligned Behavior Cloning) parameters + enable_rabc: bool = False # Enable RA-BC weighted loss + rabc_kappa: float = 0.01 # Hard threshold for high-quality samples + rabc_epsilon: float = 1e-6 # Small constant to avoid division by zero + chunk_length: int = 25 # Action chunk length for computing progress deltas + + # Model loading + pretrained_model_path: str | None = None + + # Device settings + device: str | None = None + + # Processor settings + image_key: str = "observation.images.top" # Key for images in dataset + task_description: str = "perform the task" # Default task description + encode_on_the_fly: bool = True # Encode images/text during training + use_dataset_task: bool = True # Use task descriptions from dataset + + # Features (required by PreTrainedPolicy) + input_features: dict = field(default_factory=lambda: { + "video_features": {"shape": [9, 512], "dtype": "float32"}, + "text_features": {"shape": [384], "dtype": "float32"}, + "state_features": {"shape": [9, 14], "dtype": "float32"} # Example: 7 DOF × 2 arms + }) + output_features: dict = field(default_factory=lambda: { + "stage": {"shape": [1], "dtype": "int64"}, + "progress": {"shape": [1], "dtype": "float32"} + }) + + def __post_init__(self): + super().__post_init__() + + # Validate configuration + if self.hidden_dim % self.num_heads != 0: + raise ValueError( + f"hidden_dim ({self.hidden_dim}) must be divisible by num_heads ({self.num_heads})" + ) + + if self.max_length != self.num_frames: + raise ValueError( + f"max_length ({self.max_length}) must equal num_frames ({self.num_frames})" + ) + + if self.dropout < 0 or self.dropout >= 1: + raise ValueError(f"dropout must be in [0, 1), got {self.dropout}") + + if self.num_stages < 2: + raise ValueError(f"num_stages must be at least 2, got {self.num_stages}") + + if self.sampling_mode not in ["sarm", "rewind", "custom"]: + raise ValueError( + f"sampling_mode must be 'sarm', 'rewind', or 'custom', got {self.sampling_mode}" + ) + + def get_optimizer_preset(self) -> AdamWConfig: + """Get default optimizer configuration for SARM training.""" + return AdamWConfig( + lr=5e-5, + weight_decay=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + ) + + def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig: + """Get default learning rate scheduler configuration.""" + return 
CosineDecayWithWarmupSchedulerConfig( + peak_lr=5e-5, + decay_lr=5e-6, + num_warmup_steps=500, + num_decay_steps=50000, + ) + + def validate_features(self) -> None: + """Validate input and output features.""" + pass + + @property + def observation_delta_indices(self) -> list[int]: + """Load frames for SARM temporal sampling. + + SARM uses 9 frames: 1 initial frame + 8 consecutive frames with frame_gap spacing. + + Returns: + Indices for loading: [-(8*frame_gap), ..., -frame_gap, 0] + """ + # For SARM: we need the initial frame (from episode start) plus 8 consecutive frames + # The dataset will load relative to current frame + # We'll handle the "initial frame" logic in the processor + # For now, load the last 8*frame_gap frames + return list(range(-self.frame_gap * (self.num_frames - 1), 1, self.frame_gap)) + + @property + def action_delta_indices(self) -> None: + """SARM is a reward model, not an action policy.""" + return None + + @property + def reward_delta_indices(self) -> None: + """SARM doesn't use delta rewards.""" + return None + diff --git a/src/lerobot/policies/sarm/modeling_sarm.py b/src/lerobot/policies/sarm/modeling_sarm.py new file mode 100644 index 000000000..1dd53e07d --- /dev/null +++ b/src/lerobot/policies/sarm/modeling_sarm.py @@ -0,0 +1,808 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import List, Union, Dict, Optional +import random + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from transformers import AutoModel, AutoTokenizer, CLIPModel, CLIPProcessor +from torch import Tensor + +from lerobot.policies.sarm.configuration_sarm import SARMConfig +from lerobot.policies.pretrained import PreTrainedPolicy + + +def mean_pooling(model_output, attention_mask): + """ + Mean pooling - take attention mask into account for correct averaging. + + Args: + model_output: Model output containing token embeddings. + attention_mask: Attention mask for the tokens. + + Returns: + Mean-pooled embeddings. + """ + token_embeddings = model_output[0] # First element contains all token embeddings + input_mask_expanded = ( + attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + ) + return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( + input_mask_expanded.sum(1), min=1e-9 + ) + + +class SARMTransformer(nn.Module): + """ + SARM Transformer model for stage-aware reward prediction. + + This model has a dual-head architecture: + 1. Stage estimator: Predicts the high-level task stage (classification) + 2. Subtask estimator: Predicts fine-grained progress within the stage (regression) + + The subtask estimator is conditioned on the stage prediction. 
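+
+    Shape sketch (B = batch, T = max_length, dims per the constructor):
+        inputs:  video (B, T, video_dim), text (B, text_dim),
+                 state (B, T, state_dim)
+        outputs: stage_logits (B, T, num_stages), stage_probs (B, T, num_stages),
+                 progress (B, T, 1)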
+ """ + + def __init__( + self, + video_dim: int = 512, # CLIP dimension + text_dim: int = 384, # MiniLM dimension + state_dim: int = 14, # Joint state dimension + hidden_dim: int = 768, + num_heads: int = 12, + num_layers: int = 8, + num_stages: int = 5, + max_length: int = 9, + dropout: float = 0.1, + use_joint_state: bool = True + ): + super().__init__() + self.hidden_dim = hidden_dim + self.max_length = max_length + self.num_stages = num_stages + self.use_joint_state = use_joint_state + + # Project video, text, and state to common dimension + self.video_proj = nn.Linear(video_dim, hidden_dim) + self.text_proj = nn.Linear(text_dim, hidden_dim) + if use_joint_state: + self.state_proj = nn.Linear(state_dim, hidden_dim) + + # Position embedding only for the first frame + self.first_pos_embed = nn.Parameter(torch.randn(1, hidden_dim)) + + # Transformer encoder (shared backbone) + encoder_layer = nn.TransformerEncoderLayer( + d_model=hidden_dim, + nhead=num_heads, + dim_feedforward=hidden_dim * 4, + dropout=dropout, + batch_first=True + ) + self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + + # Stage estimator head (classification) + self.stage_head = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim // 2), + nn.LayerNorm(hidden_dim // 2), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim // 2, num_stages) + ) + + # Subtask estimator head (regression, conditioned on stage) + # Takes concatenated [features, stage_embedding] + self.stage_embedding = nn.Embedding(num_stages, hidden_dim // 4) + subtask_input_dim = hidden_dim + hidden_dim // 4 + self.subtask_head = nn.Sequential( + nn.Linear(subtask_input_dim, hidden_dim // 2), + nn.LayerNorm(hidden_dim // 2), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim // 2, 1), + nn.Sigmoid() + ) + + # Attention mask for causal self-attention + self.register_buffer("attention_mask", None, persistent=False) + + def _get_attention_mask(self, seq_length: int, device: torch.device) -> torch.Tensor: + """Generate or retrieve cached causal attention mask.""" + if self.attention_mask is None or self.attention_mask.shape[0] != seq_length: + # Create causal mask (upper triangular with -inf) + mask = nn.Transformer.generate_square_subsequent_mask(seq_length, device=device) + self.attention_mask = mask + return self.attention_mask + + def forward( + self, + video_frames: torch.Tensor, + text_embed: torch.Tensor, + state_features: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Forward pass through the SARM transformer. 
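+
+        Token layout: position 0 holds the projected text embedding and
+        positions 1..T the fused (video + state) frame embeddings, so the
+        causal mask lets frame t attend to the text token and frames <= t.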
+ + Args: + video_frames: Video frame embeddings (batch_size, seq_len, video_dim) + text_embed: Text embeddings (batch_size, text_dim) + state_features: Joint state features (batch_size, seq_len, state_dim) + + Returns: + Tuple of: + - Stage logits for each frame (batch_size, seq_len, num_stages) + - Stage probabilities (batch_size, seq_len, num_stages) + - Progress predictions for each frame (batch_size, seq_len, 1) + """ + batch_size = video_frames.shape[0] + + # Project inputs to common dimension + video_embed = self.video_proj(video_frames) # [batch_size, seq_len, hidden_dim] + text_embed = self.text_proj(text_embed).unsqueeze(1) # [batch_size, 1, hidden_dim] + + # Add joint state if provided + if self.use_joint_state and state_features is not None: + state_embed = self.state_proj(state_features) # [batch_size, seq_len, hidden_dim] + # Fuse video and state features (simple addition) + video_embed = video_embed + state_embed + + # Add positional embedding to first video frame + video_embed[:, 0] += self.first_pos_embed + + # Combine sequence: [text, video_frames] + sequence = torch.cat([text_embed, video_embed], dim=1) + + # Get causal attention mask + seq_length = sequence.shape[1] + attention_mask = self._get_attention_mask(seq_length, sequence.device) + + # Pass through transformer with causal masking + transformed = self.transformer(sequence, mask=attention_mask, is_causal=True) + + # Get frame features (exclude text token) + frame_features = transformed[:, 1:] # [batch_size, seq_len, hidden_dim] + + # Stage estimation + stage_logits = self.stage_head(frame_features) # [batch_size, seq_len, num_stages] + stage_probs = F.softmax(stage_logits, dim=-1) # [batch_size, seq_len, num_stages] + + # Get predicted stage indices + stage_indices = torch.argmax(stage_probs, dim=-1) # [batch_size, seq_len] + + # Get stage embeddings for conditioning + stage_embeds = self.stage_embedding(stage_indices) # [batch_size, seq_len, hidden_dim//4] + + # Concatenate frame features with stage embeddings + conditioned_features = torch.cat([frame_features, stage_embeds], dim=-1) + + # Subtask progress estimation (conditioned on stage) + progress_preds = self.subtask_head(conditioned_features) # [batch_size, seq_len, 1] + + return stage_logits, stage_probs, progress_preds + + +class SARMRewardModel(PreTrainedPolicy): + """ + SARM Reward Model for stage-aware task completion rewards. 
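+
+    Typical scoring flow (hedged sketch; single episode, shapes assumed):
+        model = SARMRewardModel(SARMConfig())
+        vid = model.encode_images(frames)          # (num_frames, 512)
+        txt = model.encode_text("fold the towel")  # (384,)
+        reward = model.calculate_rewards(txt, vid)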
+ + This model combines: + - CLIP for encoding video frames + - MiniLM for encoding text descriptions + - SARMTransformer for predicting task stage and progress + - Optional RA-BC (Reward-Aligned Behavior Cloning) for weighted training + """ + + name = "sarm" + config_class = SARMConfig + + def __init__(self, config: SARMConfig, dataset_stats: dict | None = None): + super().__init__(config, dataset_stats) + self.config = config + self.dataset_stats = dataset_stats + self.device = torch.device(config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu") + + # Initialize CLIP encoder for images + logging.info("Loading CLIP encoder...") + self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") + self.clip_model.to(self.device) + self.clip_model.eval() + + # Initialize MiniLM encoder for text + logging.info("Loading MiniLM encoder...") + self.minilm_tokenizer = AutoTokenizer.from_pretrained( + "sentence-transformers/all-MiniLM-L12-v2" + ) + self.minilm_model = AutoModel.from_pretrained( + "sentence-transformers/all-MiniLM-L12-v2" + ) + self.minilm_model.to(self.device) + self.minilm_model.eval() + + # Auto-detect state_dim from input_features if not explicitly set + if config.state_dim is None: + # Look for "observation.state" or "state" in input_features + if "observation.state" in config.input_features: + config.state_dim = config.input_features["observation.state"].shape[0] + logging.info(f"Auto-detected state_dim={config.state_dim} from input_features['observation.state']") + elif "state" in config.input_features: + config.state_dim = config.input_features["state"].shape[0] + logging.info(f"Auto-detected state_dim={config.state_dim} from input_features['state']") + else: + config.state_dim = 14 + logging.warning(f"Could not find state in input_features, using default state_dim={config.state_dim}") + + # Initialize SARM transformer + self.sarm_transformer = SARMTransformer( + video_dim=config.image_dim, + text_dim=config.text_dim, + state_dim=config.state_dim, + hidden_dim=config.hidden_dim, + num_heads=config.num_heads, + num_layers=config.num_layers, + num_stages=config.num_stages, + max_length=config.max_length, + dropout=config.dropout, + use_joint_state=config.use_joint_state + ) + self.sarm_transformer.to(self.device) + + # RA-BC running statistics (for weighted loss) + if config.enable_rabc: + self.register_buffer("rabc_mean", torch.tensor(0.0)) + self.register_buffer("rabc_m2", torch.tensor(0.0)) + self.register_buffer("rabc_count", torch.tensor(0)) + + logging.info(f"SARM Reward Model initialized on {self.device}") + + def to(self, device): + """Override to method to ensure all components move together.""" + super().to(device) + self.device = device if isinstance(device, torch.device) else torch.device(device) + self.clip_model.to(device) + self.minilm_model.to(device) + self.sarm_transformer.to(device) + return self + + @torch.no_grad() + def encode_images(self, images: np.ndarray) -> np.ndarray: + """ + Encode video frames using CLIP. + + Args: + images: Video frames with shape (num_videos, num_frames, H, W, C) in uint8. + Can also be (num_frames, H, W, C) for a single video. + + Returns: + Encoded image features (num_videos, num_frames, 512) or (num_frames, 512). + """ + # Handle single video case + single_video = False + if len(images.shape) == 4: + images = images[np.newaxis, ...] 
+ single_video = True + + assert len(images.shape) == 5, f"Expected 5D input (num_videos, num_frames, H, W, C), got {images.shape}" + + all_embeddings = [] + + for video in images: + video_embeddings = [] + + # Convert frames to PIL images for CLIP processor + frames = [] + for frame in video: + if frame.shape[0] == 3: # Channel first + frame = frame.transpose(1, 2, 0) + if frame.dtype != np.uint8: + frame = (frame * 255).astype(np.uint8) if frame.max() <= 1.0 else frame.astype(np.uint8) + frames.append(Image.fromarray(frame)) + + # Batch process frames with CLIP + for i in range(0, len(frames), self.config.clip_batch_size): + batch = frames[i:i + self.config.clip_batch_size] + inputs = self.clip_processor(images=batch, return_tensors="pt", padding=True) + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + # Get image embeddings from CLIP + embeddings = self.clip_model.get_image_features(**inputs).detach().cpu() + + # Handle single frame case + if embeddings.dim() == 1: + embeddings = embeddings.unsqueeze(0) + + video_embeddings.append(embeddings) + + video_embeddings = torch.cat(video_embeddings) + all_embeddings.append(video_embeddings) + + result = torch.stack(all_embeddings).numpy() + + if single_video: + result = result[0] + + return result + + @torch.no_grad() + def encode_text(self, text: Union[str, List[str]]) -> np.ndarray: + """ + Encode text using MiniLM. + + Args: + text: Text string or list of text strings. + + Returns: + Encoded text features (batch_size, 384) or (384,) for single text. + """ + if isinstance(text, str): + text = [text] + single_text = True + else: + single_text = False + + # Process in batches + all_embeddings = [] + for i in range(0, len(text), self.config.batch_size): + batch_text = text[i:i + self.config.batch_size] + + encoded_input = self.minilm_tokenizer( + batch_text, padding=True, truncation=True, return_tensors="pt" + ).to(self.device) + + model_output = self.minilm_model(**encoded_input) + text_embeddings = mean_pooling(model_output, encoded_input["attention_mask"]) + + all_embeddings.append(text_embeddings.cpu()) + + result = torch.cat(all_embeddings).numpy() + + if single_text: + result = result[0] + + return result + + @torch.no_grad() + def calculate_rewards( + self, + text_embeddings: Union[np.ndarray, torch.Tensor], + video_embeddings: Union[np.ndarray, torch.Tensor], + state_features: Optional[Union[np.ndarray, torch.Tensor]] = None, + return_all_frames: bool = False, + return_stages: bool = False + ) -> Union[np.ndarray, tuple]: + """ + Calculate rewards for given text, video, and state representations. 
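+
+        Example (single sample; `txt` of shape (384,), `vid` of shape
+        (num_frames, 512)):
+            r = model.calculate_rewards(txt, vid)  # last-frame reward, scalar
+            r_all, stages = model.calculate_rewards(
+                txt, vid, return_all_frames=True, return_stages=True
+            )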
+
+        Args:
+            text_embeddings: Encoded text representations (batch_size, 384)
+            video_embeddings: Encoded video representations (batch_size, num_frames, 512)
+            state_features: Joint state features (batch_size, num_frames, state_dim)
+            return_all_frames: If True, return rewards for all frames
+            return_stages: If True, also return stage predictions
+
+        Returns:
+            If return_stages=False:
+                Reward values (batch_size,) or (batch_size, num_frames)
+            If return_stages=True:
+                Tuple of (rewards, stage_probs)
+        """
+        # Convert to tensors if needed
+        if isinstance(text_embeddings, np.ndarray):
+            text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32)
+        if isinstance(video_embeddings, np.ndarray):
+            video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
+        if state_features is not None and isinstance(state_features, np.ndarray):
+            state_features = torch.tensor(state_features, dtype=torch.float32)
+
+        # Handle single sample case
+        if text_embeddings.dim() == 1:
+            text_embeddings = text_embeddings.unsqueeze(0)
+            video_embeddings = video_embeddings.unsqueeze(0)
+            if state_features is not None:
+                state_features = state_features.unsqueeze(0)
+            single_sample = True
+        else:
+            single_sample = False
+
+        # Process in batches
+        all_rewards = []
+        all_stage_probs = []
+
+        for i in range(0, len(video_embeddings), self.config.batch_size):
+            batch_texts = text_embeddings[i:i + self.config.batch_size].to(self.device)
+            batch_videos = video_embeddings[i:i + self.config.batch_size].to(self.device)
+            batch_states = None
+            if state_features is not None:
+                batch_states = state_features[i:i + self.config.batch_size].to(self.device)
+
+            # Get predictions
+            stage_logits, stage_probs, progress_preds = self.sarm_transformer(
+                batch_videos.float(), batch_texts.float(), batch_states.float() if batch_states is not None else None
+            )
+
+            if return_all_frames:
+                all_rewards.append(progress_preds.squeeze(-1).cpu())
+            else:
+                # Return only last frame reward
+                all_rewards.append(progress_preds[:, -1, 0].cpu())
+
+            if return_stages:
+                all_stage_probs.append(stage_probs.cpu())
+
+        rewards = torch.cat(all_rewards).numpy()
+
+        if single_sample:
+            # Strip the batch dimension: scalar reward or (num_frames,) array
+            rewards = rewards[0]
+
+        if return_stages:
+            stage_probs = torch.cat(all_stage_probs).numpy()
+            if single_sample:
+                stage_probs = stage_probs[0]
+            return rewards, stage_probs
+
+        return rewards
+
+    def _update_rabc_stats(self, progress_deltas: torch.Tensor):
+        """Update running statistics for RA-BC using Welford's online algorithm."""
+        if not self.config.enable_rabc:
+            return
+
+        for delta in progress_deltas:
+            self.rabc_count += 1
+            delta_val = delta.item()
+            delta_mean = delta_val - self.rabc_mean
+            self.rabc_mean += delta_mean / self.rabc_count
+            delta_m2 = delta_val - self.rabc_mean
+            self.rabc_m2 += delta_mean * delta_m2
+
+    def _compute_rabc_weights(self, progress_deltas: torch.Tensor) -> torch.Tensor:
+        """Compute RA-BC weights for progress deltas."""
+        if not self.config.enable_rabc or self.rabc_count < 2:
+            return torch.ones_like(progress_deltas)
+
+        # Get running statistics
+        mean = max(self.rabc_mean.item(), 0.0)  # Clamp mean to non-negative
+        variance = self.rabc_m2 / (self.rabc_count - 1)
+        std = torch.sqrt(variance).item()
+
+        # Soft weights: map the 2-sigma window [mean - 2*std, mean + 2*std]
+        # linearly onto [0, 1]
+        lower_bound = mean - 2 * std
+        weights = (progress_deltas - lower_bound) / (4 * std + self.config.rabc_epsilon)
+        weights = torch.clamp(weights, 0.0, 1.0)
+
+        # Apply hard threshold
+        high_quality_mask = progress_deltas > 
self.config.rabc_kappa + weights = torch.where(high_quality_mask, torch.ones_like(weights), weights) + + return weights + + def load_pretrained_checkpoint(self, checkpoint_path: str, strict: bool = False): + """Load pretrained model weights from a checkpoint file.""" + logging.info(f"Loading pretrained checkpoint from {checkpoint_path}") + + checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False) + + # Handle different checkpoint formats + if "model_state_dict" in checkpoint: + state_dict = checkpoint["model_state_dict"] + else: + state_dict = checkpoint + + # Load only the SARMTransformer weights + missing_keys, unexpected_keys = self.sarm_transformer.load_state_dict(state_dict, strict=strict) + + if missing_keys: + logging.warning(f"Missing keys when loading checkpoint: {missing_keys}") + if unexpected_keys: + logging.warning(f"Unexpected keys when loading checkpoint: {unexpected_keys}") + + logging.info("Checkpoint loaded successfully") + + def train(self, mode: bool = True): + """Set training mode. Note: CLIP and MiniLM encoders always stay in eval mode.""" + super().train(mode) + # Keep encoders in eval mode + self.clip_model.eval() + self.minilm_model.eval() + # Only transformer can be trained + self.sarm_transformer.train(mode) + return self + + def eval(self): + """Set evaluation mode.""" + return self.train(False) + + def parameters(self): + """Return trainable parameters (only SARM transformer, not encoders).""" + return self.sarm_transformer.parameters() + + def get_optim_params(self): + """Return optimizer parameters for the policy.""" + return self.parameters() + + def reset(self): + """Required by PreTrainedPolicy but not used for reward models.""" + pass + + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Required by PreTrainedPolicy but not used for reward models.""" + raise NotImplementedError("SARM model does not predict action chunks") + + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Required by PreTrainedPolicy but not used for SARM.""" + raise NotImplementedError("SARM model does not select actions") + + def forward(self, batch): + """ + Forward pass compatible with lerobot training pipeline. 
+ + Args: + batch: Dictionary containing observation with: + - 'video_features': Pre-encoded video features (B, T, 512) + - 'text_features': Pre-encoded text features (B, 384) + - 'state_features': Joint state features (B, T, state_dim) + + Returns: + loss: Total training loss + output_dict: Dictionary of loss components for logging + """ + # Extract from observation dict + observation = batch.get('observation', batch) + video_features = observation['video_features'].to(self.device) + text_features = observation['text_features'].to(self.device) + state_features = observation.get('state_features', None) + if state_features is not None: + state_features = state_features.to(self.device) + + batch_size = video_features.shape[0] + max_length = self.config.num_frames + + # Handle both single frames and sequences + if video_features.dim() == 2: + # Single frames: replicate to create pseudo-sequences + video_features = video_features.unsqueeze(1).repeat(1, max_length, 1) + + if state_features is not None and state_features.dim() == 2: + # Single state: replicate to match sequence length + state_features = state_features.unsqueeze(1).repeat(1, max_length, 1) + + # Apply rewind augmentation (following SARM paper: up to 4 reversed frames) + # Note: video_features are already sampled by dataset (9 frames with 30-frame gaps) + # We just need to compute progress targets and optionally apply rewind + + processed_videos = [] + processed_states = [] + progress_targets = [] + + # Extract episode metadata for correct progress normalization + absolute_frame_indices = observation.get('absolute_frame_indices', None) + episode_lengths = observation.get('episode_length', None) + remaining_lengths = observation.get('remaining_length', None) + + for i in range(batch_size): + # Get metadata for this sample + current_absolute_indices = None + current_episode_length = None + current_remaining_length = None + + if absolute_frame_indices is not None: + if isinstance(absolute_frame_indices, list): + current_absolute_indices = absolute_frame_indices[i] + else: + current_absolute_indices = absolute_frame_indices + + if episode_lengths is not None: + if isinstance(episode_lengths, torch.Tensor) and episode_lengths.dim() > 0: + current_episode_length = episode_lengths[i].item() + else: + current_episode_length = episode_lengths.item() if isinstance(episode_lengths, torch.Tensor) else episode_lengths + + if remaining_lengths is not None: + if isinstance(remaining_lengths, torch.Tensor) and remaining_lengths.dim() > 0: + current_remaining_length = remaining_lengths[i].item() + else: + current_remaining_length = remaining_lengths.item() if isinstance(remaining_lengths, torch.Tensor) else remaining_lengths + + # Compute progress targets directly from metadata (frames already loaded by dataset) + # Progress = (position_in_sequence + 1) / remaining_trajectory_length + if current_remaining_length is not None and current_remaining_length > 0: + # Correct: relative progress from first loaded frame to episode end + progress_indices = torch.arange(1, max_length + 1, dtype=torch.float32, device=self.device) + progress = progress_indices / current_remaining_length + else: + # Fallback: linear progress (when metadata is not available) + logging.warning(f"Sample {i}: No remaining_length metadata, using linear progress fallback") + progress = torch.linspace(1.0/max_length, 1.0, max_length, device=self.device) + + # Apply rewind augmentation with 50% probability (following SARM paper) + # Paper specifies: "appending up to four frames from 
earlier timestamps with reversed order" + if random.random() < 0.5: + # Rewind: append 2-4 reversed frames, trim to max_length + num_reverse = random.randint(2, min(4, max_length - 1)) + + # Reverse video and progress + reversed_video = video_features[i].flip(0) + reversed_progress = progress.flip(0) + + # Take frames from reversed (skip first which is last of original) + reverse_frames = reversed_video[1:num_reverse+1] + reverse_progress = reversed_progress[1:num_reverse+1] + + # Concatenate forward + reversed + rewound_video = torch.cat([video_features[i], reverse_frames], dim=0) + rewound_progress = torch.cat([progress, reverse_progress], dim=0) + + # Trim to max_length + rewound_video = rewound_video[:max_length] + rewound_progress = rewound_progress[:max_length] + + processed_videos.append(rewound_video) + progress_targets.append(rewound_progress) + + # Process state features if available + if state_features is not None: + reversed_state = state_features[i].flip(0) + reverse_state_frames = reversed_state[1:num_reverse+1] + rewound_state = torch.cat([state_features[i], reverse_state_frames], dim=0) + rewound_state = rewound_state[:max_length] + processed_states.append(rewound_state) + else: + # Normal: use frames as-is with forward progress + processed_videos.append(video_features[i]) + progress_targets.append(progress) + + # Process state features if available + if state_features is not None: + processed_states.append(state_features[i]) + + # Ensure all sequences have the same length before stacking + # (sampling functions should return max_length, but double-check) + validated_videos = [] + validated_progress = [] + for i, (vid, prog) in enumerate(zip(processed_videos, progress_targets)): + if len(vid) != max_length: + logging.warning(f"Sample {i}: video length {len(vid)} != {max_length}, padding/trimming") + if len(vid) < max_length: + # Pad + padding = max_length - len(vid) + vid = torch.cat([vid, vid[-1:].repeat(padding, 1)]) + prog = torch.cat([prog, torch.full((padding,), prog[-1], device=prog.device)]) + else: + # Trim + vid = vid[:max_length] + prog = prog[:max_length] + validated_videos.append(vid) + validated_progress.append(prog) + + # Stack processed features + processed_videos = torch.stack(validated_videos) + progress_targets = torch.stack(validated_progress) + + # Ensure progress_targets has the same shape as progress_preds + # progress_preds is (batch_size, num_frames, 1) + # progress_targets is (batch_size, num_frames) -> add last dimension + if progress_targets.dim() == 2: + progress_targets = progress_targets.unsqueeze(-1) # (batch_size, num_frames, 1) + + if state_features is not None and len(processed_states) > 0: + processed_states = torch.stack(processed_states) + else: + processed_states = None + + # Get predictions + stage_logits, stage_probs, progress_preds = self.sarm_transformer( + processed_videos, text_features, processed_states + ) + + # Compute progress loss using augmented targets + progress_loss = F.mse_loss(progress_preds, progress_targets) + + # For now, just use progress loss since we don't have stage annotations + # In future: can add stage loss when we have annotated stage labels + total_loss = progress_loss + + output_dict = { + 'progress_loss': progress_loss.item(), + } + + # Compute misaligned loss (following SARM paper and ReWiND) + # "To improve video-language alignment, task descriptions are occasionally perturbed" + if random.random() < 0.2: # 20% probability (matching ReWiND) + # Create misaligned pairs by shuffling text features + 
shuffle_idx = torch.randperm(batch_size, device=self.device)
+            misaligned_texts = text_features[shuffle_idx]
+
+            # Get predictions for misaligned pairs (should predict zero progress)
+            _, _, misaligned_preds = self.sarm_transformer(
+                processed_videos, misaligned_texts, processed_states
+            )
+
+            # Target is zero progress for misaligned pairs
+            target_zeros = torch.zeros_like(misaligned_preds)
+            misaligned_loss = F.mse_loss(misaligned_preds, target_zeros)
+
+            # Add to total loss
+            total_loss = total_loss + misaligned_loss
+            output_dict['misaligned_loss'] = misaligned_loss.item()
+
+        # RA-BC weighted loss (if enabled)
+        if self.config.enable_rabc:
+            # Compute progress deltas (simplified: use consecutive frame differences)
+            progress_deltas = progress_preds[:, 1:, 0] - progress_preds[:, :-1, 0]
+            progress_deltas = progress_deltas.mean(dim=1)  # Average over sequence
+
+            # Update running statistics
+            self._update_rabc_stats(progress_deltas)
+
+            # Compute weights
+            weights = self._compute_rabc_weights(progress_deltas)
+
+            # Scale the (scalar) batch loss by the mean weight. A true per-sample
+            # weighting would require the unreduced per-sample losses.
+            total_loss = total_loss * weights.mean()
+
+        # Add final total loss to output dict
+        output_dict['total_loss'] = total_loss.item()
+
+        return total_loss, output_dict
+
+
+# Loss utilities
+def compute_stage_loss(
+    stage_logits: torch.Tensor,
+    target_stages: torch.Tensor
+) -> torch.Tensor:
+    """
+    Compute stage classification loss.
+
+    Args:
+        stage_logits: Stage predictions (batch_size, num_frames, num_stages)
+        target_stages: Target stage indices (batch_size, num_frames)
+
+    Returns:
+        Cross-entropy loss
+    """
+    batch_size, num_frames, num_stages = stage_logits.shape
+    stage_logits_flat = stage_logits.reshape(-1, num_stages)
+    target_stages_flat = target_stages.reshape(-1)
+
+    loss = F.cross_entropy(stage_logits_flat, target_stages_flat)
+    return loss
+
+
+def compute_progress_loss(
+    progress_preds: torch.Tensor,
+    target_progress: torch.Tensor
+) -> torch.Tensor:
+    """
+    Compute progress regression loss.
+
+    Args:
+        progress_preds: Progress predictions (batch_size, num_frames, 1)
+        target_progress: Target progress values (batch_size, num_frames, 1)
+
+    Returns:
+        Mean squared error loss
+    """
+    loss = F.mse_loss(progress_preds, target_progress)
+    return loss
+
diff --git a/src/lerobot/policies/sarm/processor_sarm.py b/src/lerobot/policies/sarm/processor_sarm.py
new file mode 100644
index 000000000..9c4fedc53
--- /dev/null
+++ b/src/lerobot/policies/sarm/processor_sarm.py
@@ -0,0 +1,552 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
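+
+"""
+Processor steps for SARM reward models.
+
+The preprocessing pipeline defined here encodes images with CLIP (512-dim)
+and text with MiniLM (384-dim), normalizes joint states, and attaches the
+episode metadata used for progress-target normalization.
+"""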
+
+import logging
+from typing import Any
+import numpy as np
+import torch
+from PIL import Image
+
+from lerobot.policies.sarm.configuration_sarm import SARMConfig
+from lerobot.processor import (
+    ProcessorStep,
+    PolicyProcessorPipeline,
+    PolicyAction,
+    DeviceProcessorStep,
+    AddBatchDimensionProcessorStep,
+)
+from lerobot.processor.converters import (
+    policy_action_to_transition,
+    transition_to_policy_action,
+)
+from lerobot.processor.pipeline import PipelineFeatureType
+from lerobot.processor.core import EnvTransition, TransitionKey
+from lerobot.configs.types import PolicyFeature, FeatureType
+from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
+
+
+class SARMEncodingProcessorStep(ProcessorStep):
+    """
+    ProcessorStep that encodes images and text for SARM training.
+
+    This step handles:
+    - CLIP (image) encoding
+    - MiniLM (text) encoding
+    - Joint state normalization
+
+    Supports temporal sequences: (B, T, C, H, W) → (B, T, 512) video features
+    """
+
+    def __init__(
+        self,
+        config: SARMConfig,
+        image_key: str | None = None,
+        task_description: str | None = None,
+        dataset_meta=None,
+        dataset_stats: dict | None = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.image_key = image_key or config.image_key
+        self.task_description = task_description or config.task_description
+        self.dataset_meta = dataset_meta
+        self.dataset_stats = dataset_stats
+
+        # Initialize encoders
+        self._init_encoders()
+
+    def _init_encoders(self):
+        """Initialize CLIP and MiniLM encoders."""
+        from transformers import AutoModel, AutoTokenizer, CLIPModel, CLIPProcessor
+
+        device = torch.device(
+            self.config.device if self.config.device
+            else "cuda" if torch.cuda.is_available() else "cpu"
+        )
+
+        logging.info("Initializing CLIP encoder for SARM...")
+        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+        self.clip_model.to(device)
+        self.clip_model.eval()
+
+        logging.info("Initializing MiniLM encoder for SARM...")
+        self.minilm_tokenizer = AutoTokenizer.from_pretrained(
+            "sentence-transformers/all-MiniLM-L12-v2"
+        )
+        self.minilm_model = AutoModel.from_pretrained(
+            "sentence-transformers/all-MiniLM-L12-v2"
+        )
+        self.minilm_model.to(device)
+        self.minilm_model.eval()
+
+        self.device = device
+
+    def __call__(self, transition: EnvTransition) -> EnvTransition:
+        """Encode images, text, and normalize states in the transition."""
+        # TransitionKey is already imported at module level
+        self._current_transition = transition.copy() if hasattr(transition, 'copy') else dict(transition)
+        new_transition = self._current_transition
+
+        observation = new_transition.get(TransitionKey.OBSERVATION)
+        if observation is None or not isinstance(observation, dict):
+            return new_transition
+
+        # Extract and encode images
+        batch_size = 1
+        if self.image_key in observation:
+            image = observation[self.image_key]
+
+            # Handle different image formats
+            if isinstance(image, torch.Tensor):
+                image = image.cpu().numpy()
+
+            # Encode images
+            video_features = self._encode_images_batch(image)
+            observation['video_features'] = video_features
+
+            # Get batch size from encoded features
+            batch_size = video_features.shape[0]
+
+        # Extract and normalize joint states
+        if self.config.use_joint_state:
+            # Look for "state" or "observation.state" in observation
+            state_key = None
+            state_data = None
+
+            if "state" in observation:
+                
state_key = "state" + state_data = observation["state"] + elif "observation.state" in observation: + state_key = "observation.state" + state_data = observation["observation.state"] + + if state_data is not None: + if isinstance(state_data, torch.Tensor): + state_data = state_data.cpu().numpy() + + # Normalize if stats available + if self.dataset_stats and state_key in self.dataset_stats: + mean = self.dataset_stats[state_key]['mean'] + std = self.dataset_stats[state_key]['std'] + state_data = (state_data - mean) / (std + 1e-8) + + observation['state_features'] = torch.tensor(state_data, dtype=torch.float32) + else: + # Create dummy state features if not found + if 'video_features' in observation: + num_frames = observation['video_features'].shape[0] if observation['video_features'].dim() == 2 else observation['video_features'].shape[1] + observation['state_features'] = torch.zeros(batch_size, num_frames, self.config.state_dim) + + # Get task descriptions + task_descriptions = None + if 'task' in new_transition: + task_descriptions = new_transition['task'] + if isinstance(task_descriptions, str): + task_descriptions = [task_descriptions] * batch_size + + # Encode text + if task_descriptions is not None: + text_features = self._encode_text_batch_list(task_descriptions) + else: + text_features = self._encode_text_batch(self.task_description, batch_size) + + observation['text_features'] = text_features + + # Compute episode metadata for progress normalization + # Note: Processor runs BEFORE batching, so we need to extract from raw dataset structure + # The dataset provides episode_index and index in the raw item + + # Extract index and episode_index from COMPLEMENTARY_DATA + episode_index = None + frame_index = None + + # Primary location: COMPLEMENTARY_DATA (confirmed from debug logs) + if TransitionKey.COMPLEMENTARY_DATA in new_transition: + comp_data = new_transition[TransitionKey.COMPLEMENTARY_DATA] + if isinstance(comp_data, dict): + frame_index = comp_data.get('index') + episode_index = comp_data.get('episode_index') + + # Fallback: check other locations + if frame_index is None and TransitionKey.OBSERVATION in new_transition: + obs = new_transition[TransitionKey.OBSERVATION] + if isinstance(obs, dict): + frame_index = obs.get('index') + if episode_index is None: + episode_index = obs.get('episode_index') + + # If we have frame_index but no episode_index, compute it from episode boundaries + if frame_index is not None and episode_index is None and self.dataset_meta is not None: + # Convert to int if needed + if isinstance(frame_index, torch.Tensor): + frame_idx = frame_index.item() if frame_index.numel() == 1 else frame_index[0].item() + else: + frame_idx = int(frame_index) + + # Search through episodes to find which one this frame belongs to + for ep_idx in range(len(self.dataset_meta.episodes)): + ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] + ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] + if ep_start <= frame_idx < ep_end: + episode_index = ep_idx + break + + if self.dataset_meta is not None and frame_index is not None: + # Handle batch processing + is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1 + + if is_batch: + # Batch case: process multiple samples at once + batch_size = frame_index.shape[0] + frame_indices = frame_index.cpu().numpy() if isinstance(frame_index, torch.Tensor) else np.array(frame_index) + + # Ensure at least 1D + if frame_indices.ndim == 0: + frame_indices = np.array([frame_indices.item()]) + + 
# Compute episode_index for each frame if not provided + if episode_index is None: + episode_indices = [] + for frame_idx in frame_indices: + frame_idx = int(frame_idx) + # Search through episodes + found = False + for ep_idx in range(len(self.dataset_meta.episodes)): + ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] + ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] + if ep_start <= frame_idx < ep_end: + episode_indices.append(ep_idx) + found = True + break + if not found: + episode_indices.append(0) # Fallback + episode_indices = np.array(episode_indices) + else: + episode_indices = episode_index.cpu().numpy() if isinstance(episode_index, torch.Tensor) else np.array(episode_index) + # Ensure at least 1D + if episode_indices.ndim == 0: + episode_indices = np.array([episode_indices.item()]) + + # CRITICAL FIX: If we have a single episode_index but multiple frame_indices, + # compute the correct episode for each frame (they might be from different episodes) + if len(episode_indices) == 1 and len(frame_indices) > 1: + episode_indices = [] + for frame_idx in frame_indices: + frame_idx = int(frame_idx) + # Search through episodes + found = False + for ep_idx in range(len(self.dataset_meta.episodes)): + ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] + ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] + if ep_start <= frame_idx < ep_end: + episode_indices.append(ep_idx) + found = True + break + if not found: + episode_indices.append(0) # Fallback + episode_indices = np.array(episode_indices) + + # Compute metadata for each sample in batch + absolute_indices_list = [] + remaining_lengths = [] + episode_lengths = [] + + # Convert to list for safe iteration + episode_indices_list = episode_indices.tolist() if hasattr(episode_indices, 'tolist') else list(episode_indices) + frame_indices_list = frame_indices.tolist() if hasattr(frame_indices, 'tolist') else list(frame_indices) + + for i, (ep_idx, frame_idx) in enumerate(zip(episode_indices_list, frame_indices_list)): + ep_idx = int(ep_idx) + frame_idx = int(frame_idx) + ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] + ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] + episode_length = ep_end - ep_start + episode_lengths.append(episode_length) + + # Compute absolute indices for this sample + if 'video_features' in observation and observation['video_features'].dim() > 1: + num_loaded_frames = observation['video_features'].shape[1] # (batch, seq_len, features) + frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1 + + if frame_gap > 1: + absolute_indices = [] + for j in range(num_loaded_frames): + offset = -(num_loaded_frames - 1 - j) * frame_gap + idx = max(ep_start, frame_idx + offset) + absolute_indices.append(idx) + absolute_indices = torch.tensor(absolute_indices) + else: + start_idx = max(ep_start, frame_idx - num_loaded_frames + 1) + absolute_indices = torch.arange(start_idx, frame_idx + 1) + + absolute_indices_list.append(absolute_indices) + remaining_lengths.append(ep_end - absolute_indices[0].item()) + else: + absolute_indices_list.append(torch.tensor([frame_idx])) + remaining_lengths.append(ep_end - frame_idx) + + observation['absolute_frame_indices'] = absolute_indices_list + observation['remaining_length'] = torch.tensor(remaining_lengths) + observation['episode_length'] = torch.tensor(episode_lengths) + else: + # Single sample case + if isinstance(frame_index, torch.Tensor): + frame_idx = frame_index.item() + 
else:
+                    frame_idx = int(frame_index)
+
+                # Get episode_index
+                if episode_index is None:
+                    # Search through episodes
+                    for ep_idx in range(len(self.dataset_meta.episodes)):
+                        ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
+                        ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
+                        if ep_start <= frame_idx < ep_end:
+                            episode_index = ep_idx
+                            break
+                    if episode_index is None:
+                        episode_index = 0  # Fallback
+
+                ep_idx = int(episode_index)  # int() is a no-op if already an int
+                ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
+                ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
+                episode_length = ep_end - ep_start
+
+                # Compute absolute indices
+                if 'video_features' in observation and observation['video_features'].dim() > 0:
+                    num_loaded_frames = observation['video_features'].shape[0]
+                    frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1
+
+                    if frame_gap > 1:
+                        absolute_indices = []
+                        for i in range(num_loaded_frames):
+                            offset = -(num_loaded_frames - 1 - i) * frame_gap
+                            idx = max(ep_start, frame_idx + offset)
+                            absolute_indices.append(idx)
+                        absolute_indices = torch.tensor(absolute_indices)
+                    else:
+                        start_idx = max(ep_start, frame_idx - num_loaded_frames + 1)
+                        absolute_indices = torch.arange(start_idx, frame_idx + 1)
+
+                    observation['absolute_frame_indices'] = absolute_indices
+                    observation['remaining_length'] = ep_end - absolute_indices[0].item()
+                else:
+                    observation['absolute_frame_indices'] = torch.tensor([frame_idx])
+                    observation['remaining_length'] = ep_end - frame_idx
+
+                observation['episode_length'] = episode_length
+
+        new_transition[TransitionKey.OBSERVATION] = observation
+        return new_transition
+
+    @torch.no_grad()
+    def _encode_images_batch(self, images: np.ndarray) -> torch.Tensor:
+        """Encode a batch of images using CLIP.
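+
+        Example: an input of shape (2, 9, 3, 224, 224), i.e. two sequences of
+        nine frames, is flattened to (18, 3, 224, 224) for CLIP and reshaped
+        back to (2, 9, 512).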
+ + Args: + images: Batched images with shape: + - (B, C, H, W) for single frames, or + - (B, T, C, H, W) for temporal sequences + + Returns: + Encoded feature vectors with shape (B, 512) or (B, T, 512) + """ + # Check if we have temporal dimension + has_temporal = len(images.shape) == 5 + + if has_temporal: + # Shape: (B, T, C, H, W) + batch_size, seq_length = images.shape[0], images.shape[1] + + # Reshape to (B*T, C, H, W) to process all frames at once + images = images.reshape(batch_size * seq_length, *images.shape[2:]) + elif len(images.shape) == 4: + # Shape: (B, C, H, W) + batch_size = images.shape[0] + seq_length = 1 + else: + raise ValueError(f"Expected 4D (B, C, H, W) or 5D (B, T, C, H, W) input, got shape {images.shape}") + + # Convert to list of PIL images + num_frames = images.shape[0] + images_list = [] + for i in range(num_frames): + img = images[i] + if img.shape[0] in [1, 3]: # Channel first (C, H, W) + img = img.transpose(1, 2, 0) + + # Handle single channel + if img.shape[-1] == 1: + img = np.repeat(img, 3, axis=-1) + + # Convert to uint8 + if img.dtype != np.uint8: + img = (img * 255).astype(np.uint8) if img.max() <= 1.0 else img.astype(np.uint8) + + images_list.append(Image.fromarray(img)) + + # Encode each batch + all_embeddings = [] + for i in range(0, num_frames, self.config.clip_batch_size): + batch_imgs = images_list[i:i + self.config.clip_batch_size] + + # Process with CLIP + inputs = self.clip_processor(images=batch_imgs, return_tensors="pt", padding=True) + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + # Get image embeddings + embeddings = self.clip_model.get_image_features(**inputs).detach().cpu() + + # Handle single frame case + if embeddings.dim() == 1: + embeddings = embeddings.unsqueeze(0) + + all_embeddings.append(embeddings) + + # Concatenate all embeddings + all_embeddings = torch.cat(all_embeddings) # (B*T, 512) + + # Reshape back if temporal + if has_temporal: + all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1) # (B, T, 512) + + return all_embeddings + + @torch.no_grad() + def _encode_text_batch(self, text: str, batch_size: int) -> torch.Tensor: + """Encode a text string using MiniLM and replicate for batch. + + Args: + text: Text string to encode + batch_size: Batch size to replicate for + + Returns: + Encoded feature vectors with shape (B, 384) + """ + from lerobot.policies.rewind.modeling_rewind import mean_pooling + + encoded_input = self.minilm_tokenizer( + text, padding=True, truncation=True, return_tensors="pt" + ).to(self.device) + + model_output = self.minilm_model(**encoded_input) + text_embedding = mean_pooling(model_output, encoded_input["attention_mask"]) + text_embedding = text_embedding.squeeze().cpu() + + # Replicate for batch (B, 384) + text_embedding = text_embedding.unsqueeze(0).repeat(batch_size, 1) + + return text_embedding + + @torch.no_grad() + def _encode_text_batch_list(self, text_list: list[str]) -> torch.Tensor: + """Encode a list of text strings using MiniLM. 
+ + Args: + text_list: List of text strings to encode + + Returns: + Encoded feature vectors with shape (B, 384) + """ + from lerobot.policies.rewind.modeling_rewind import mean_pooling + + # Encode all texts in the batch at once + encoded_input = self.minilm_tokenizer( + text_list, padding=True, truncation=True, return_tensors="pt" + ).to(self.device) + + model_output = self.minilm_model(**encoded_input) + text_embeddings = mean_pooling(model_output, encoded_input["attention_mask"]) + text_embeddings = text_embeddings.cpu() + + return text_embeddings + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """Add encoded features to the observation features.""" + # Add the encoded features + features[PipelineFeatureType.OBSERVATION]['video_features'] = PolicyFeature( + type=FeatureType.VISUAL, + shape=(self.config.num_frames, self.config.image_dim) + ) + features[PipelineFeatureType.OBSERVATION]['text_features'] = PolicyFeature( + type=FeatureType.LANGUAGE, + shape=(self.config.text_dim,) + ) + if self.config.use_joint_state: + features[PipelineFeatureType.OBSERVATION]['state_features'] = PolicyFeature( + type=FeatureType.STATE, + shape=(self.config.num_frames, self.config.state_dim) + ) + return features + + +def make_sarm_pre_post_processors( + config: SARMConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, + dataset_meta = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Create pre-processor and post-processor pipelines for SARM. + + The pre-processing pipeline: + 1. Encodes images with CLIP (512-dim) + 2. Encodes text with MiniLM (384-dim) + 3. Normalizes joint states + 4. Adds batch dimension + 5. Moves data to device + + Args: + config: SARM configuration + dataset_stats: Dataset statistics for normalization + dataset_meta: Dataset metadata for computing episode info + + Returns: + Tuple of (preprocessor, postprocessor) pipelines + """ + input_steps = [ + AddBatchDimensionProcessorStep(), + SARMEncodingProcessorStep( + config=config, + dataset_meta=dataset_meta, + dataset_stats=dataset_stats + ), + DeviceProcessorStep(device=config.device), + ] + + output_steps = [ + DeviceProcessorStep(device="cpu"), + ] + + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) + diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index bc66618ca..ad746e731 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -64,6 +64,7 @@ def update_policy( lr_scheduler=None, use_amp: bool = False, lock=None, + rabc_weight_computer=None, ) -> tuple[MetricsTracker, dict]: """ Performs a single training step to update the policy's weights. 
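
For reference, the weighting rule that update_policy applies below (implemented
in RABCWeightComputer._compute_weights) reduces to the following sketch, where
mu and sigma are the running statistics over progress deltas and rabc_weight is
an illustrative name, not part of this patch:

    import torch

    def rabc_weight(delta: torch.Tensor, mu: float, sigma: float,
                    kappa: float = 0.01, eps: float = 1e-6) -> torch.Tensor:
        # Soft weight: map the window [mu - 2*sigma, mu + 2*sigma] onto [0, 1]
        w = torch.clamp((delta - (mu - 2 * sigma)) / (4 * sigma + eps), 0.0, 1.0)
        # Hard override: deltas above kappa always receive full weight
        return torch.where(delta > kappa, torch.ones_like(w), w)

(The actual implementation additionally clamps mu to be non-negative before
computing the window.)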
@@ -90,8 +91,21 @@ def update_policy( start_time = time.perf_counter() device = get_device_from_parameters(policy) policy.train() + + # Compute RA-BC weights if enabled + rabc_weights = None + if rabc_weight_computer is not None: + rabc_weights = rabc_weight_computer.compute_batch_weights(batch) + with torch.autocast(device_type=device.type) if use_amp else nullcontext(): loss, output_dict = policy.forward(batch) + + # Apply RA-BC weights if enabled + if rabc_weights is not None: + # Weight the loss + loss = loss * rabc_weights.mean() + output_dict['rabc_mean_weight'] = rabc_weights.mean().item() + # TODO(rcadene): policy.unnormalize_outputs(out_dict) grad_scaler.scale(loss).backward() @@ -184,6 +198,10 @@ def train(cfg: TrainPipelineConfig): if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path: # Only provide dataset_stats when not resuming from saved processor state processor_kwargs["dataset_stats"] = dataset.meta.stats + + # For ReWiND and SARM, always provide dataset_meta for progress normalization + if cfg.policy.type in ["rewind", "sarm"]: + processor_kwargs["dataset_meta"] = dataset.meta if cfg.policy.pretrained_path is not None: processor_kwargs["preprocessor_overrides"] = { @@ -212,6 +230,28 @@ def train(cfg: TrainPipelineConfig): logging.info("Creating optimizer and scheduler") optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp) + + # Load reward model for RA-BC if enabled + rabc_weight_computer = None + if cfg.use_rabc: + logging.info(f"Loading reward model for RA-BC from {cfg.reward_model_path}") + from lerobot.policies.factory import get_policy_class + from lerobot.utils.rabc import RABCWeightComputer + + # Detect reward model type from path + # For now, assume SARM if not specified + reward_model_class = get_policy_class("sarm") + reward_model = reward_model_class.from_pretrained(cfg.reward_model_path) + reward_model.to(device) + reward_model.eval() + + rabc_weight_computer = RABCWeightComputer( + reward_model=reward_model, + kappa=cfg.rabc_kappa, + epsilon=cfg.rabc_epsilon, + device=device, + ) + logging.info("RA-BC weight computer initialized") step = 0 # number of policy updates (forward + backward + optim) @@ -239,6 +279,21 @@ def train(cfg: TrainPipelineConfig): drop_n_last_frames=cfg.policy.drop_n_last_frames, shuffle=True, ) + elif cfg.policy.type in ["rewind", "sarm"] and getattr(cfg.policy, "use_temporal_sampler", False): + # Use temporal sequence sampler for loading sequences + from lerobot.datasets.temporal_sampler import TemporalSequenceSampler + + shuffle = False + sampling_mode = getattr(cfg.policy, "sampling_mode", cfg.policy.type) + sampler = TemporalSequenceSampler( + dataset_from_index=dataset.meta.episodes["dataset_from_index"], + dataset_to_index=dataset.meta.episodes["dataset_to_index"], + sequence_length=cfg.policy.max_length, + stride=getattr(cfg.policy, "sequence_stride", 1) if cfg.policy.type == "rewind" else getattr(cfg.policy, "frame_gap", 30), + shuffle=True, + seed=cfg.seed, + sampling_mode=sampling_mode, + ) else: shuffle = True sampler = None @@ -285,6 +340,7 @@ def train(cfg: TrainPipelineConfig): grad_scaler=grad_scaler, lr_scheduler=lr_scheduler, use_amp=cfg.policy.use_amp, + rabc_weight_computer=rabc_weight_computer, ) # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we @@ -301,6 +357,14 @@ def train(cfg: TrainPipelineConfig): wandb_log_dict = train_tracker.to_dict() if output_dict: 
wandb_log_dict.update(output_dict) + # Log RA-BC statistics if enabled + if rabc_weight_computer is not None: + rabc_stats = rabc_weight_computer.get_stats() + wandb_log_dict.update({ + 'rabc_progress_mean': rabc_stats['mean'], + 'rabc_progress_std': rabc_stats['std'], + 'rabc_samples_seen': rabc_stats['count'], + }) wandb_logger.log_dict(wandb_log_dict, step) train_tracker.reset_averages() diff --git a/src/lerobot/utils/rabc.py b/src/lerobot/utils/rabc.py new file mode 100644 index 000000000..e9648de78 --- /dev/null +++ b/src/lerobot/utils/rabc.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Reward-Aligned Behavior Cloning (RA-BC) utilities. + +RA-BC uses a pre-trained reward model (e.g., SARM) to compute progress-based weights +for training samples, emphasizing high-quality demonstrations and down-weighting +suboptimal ones. +""" + +import logging +import torch +import torch.nn as nn + + +class RABCWeightComputer: + """ + Computes RA-BC weights for training batches using a pre-trained reward model. + + Uses Welford's online algorithm for numerically stable running statistics + and applies soft weighting based on progress deltas. + + Args: + reward_model: Pre-trained reward model (e.g., SARM, ReWiND) + kappa: Hard threshold for high-quality samples (default: 0.01) + epsilon: Small constant for numerical stability (default: 1e-6) + device: Device to run reward model on + """ + + def __init__( + self, + reward_model: nn.Module, + kappa: float = 0.01, + epsilon: float = 1e-6, + device: torch.device = None, + ): + self.reward_model = reward_model + self.reward_model.eval() # Always in eval mode + self.kappa = kappa + self.epsilon = epsilon + self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Running statistics (Welford's algorithm) + self.mean = 0.0 + self.m2 = 0.0 + self.count = 0 + + logging.info(f"RA-BC WeightComputer initialized with kappa={kappa}, epsilon={epsilon}") + + def _update_stats(self, deltas: torch.Tensor): + """Update running statistics using Welford's online algorithm.""" + for delta in deltas: + self.count += 1 + delta_val = delta.item() + delta_mean = delta_val - self.mean + self.mean += delta_mean / self.count + delta_m2 = delta_val - self.mean + self.m2 += delta_mean * delta_m2 + + def _compute_weights(self, deltas: torch.Tensor) -> torch.Tensor: + """Compute RA-BC weights from progress deltas.""" + if self.count < 2: + # Not enough data, use uniform weights + return torch.ones_like(deltas) + + # Get running statistics + mean = max(self.mean, 0.0) # Clamp mean to non-negative + variance = self.m2 / (self.count - 1) + std = torch.tensor(variance).sqrt().item() + + # Compute soft weights + lower_bound = mean - 2 * std + upper_bound = mean + 2 * std + weights = (deltas - lower_bound) / (4 * std + self.epsilon) + weights = torch.clamp(weights, 0.0, 1.0) + + # Apply hard threshold + high_quality_mask = deltas > self.kappa + 
weights = torch.where(high_quality_mask, torch.ones_like(weights), weights)
+
+        return weights
+
+    @torch.no_grad()
+    def compute_batch_weights(self, batch: dict, chunk_size: int = 1) -> torch.Tensor:
+        """
+        Compute RA-BC weights for a training batch.
+
+        This function:
+        1. Extracts encoded features from the batch
+        2. Computes rewards using the reward model
+        3. Uses the rewards as proxy progress deltas (see note below)
+        4. Updates running statistics
+        5. Returns normalized weights
+
+        Args:
+            batch: Training batch containing observations
+            chunk_size: Size of action chunks for computing deltas (default: 1; currently unused)
+
+        Returns:
+            Weights tensor (batch_size,) normalized to sum to batch_size
+        """
+        observation = batch.get('observation', batch)
+        batch_size = next(iter(observation.values())).shape[0]
+
+        # Extract features needed for reward computation
+        # These should already be encoded by the preprocessor
+        if 'video_features' not in observation or 'text_features' not in observation:
+            logging.warning("RA-BC: Missing video/text features, using uniform weights")
+            return torch.ones(batch_size, device=self.device)
+
+        video_features = observation['video_features'].to(self.device)
+        text_features = observation['text_features'].to(self.device)
+        state_features = observation.get('state_features', None)
+        if state_features is not None:
+            state_features = state_features.to(self.device)
+
+        # Compute rewards for current observations
+        # Handle both single-frame and multi-frame features
+        if video_features.dim() == 3:  # (B, T, D)
+            # Multi-frame: use last frame reward
+            if hasattr(self.reward_model, 'calculate_rewards'):
+                current_rewards = self.reward_model.calculate_rewards(
+                    text_features, video_features, state_features,
+                    return_all_frames=False
+                )
+            else:
+                # Fallback for models without calculate_rewards
+                current_rewards = torch.zeros(batch_size, device=self.device)
+        else:  # (B, D)
+            # Single frame
+            if hasattr(self.reward_model, 'calculate_rewards'):
+                current_rewards = self.reward_model.calculate_rewards(
+                    text_features, video_features.unsqueeze(1), state_features,
+                    return_all_frames=False
+                )
+            else:
+                current_rewards = torch.zeros(batch_size, device=self.device)
+
+        if isinstance(current_rewards, tuple):
+            current_rewards = current_rewards[0]
+
+        # calculate_rewards returns a numpy array; convert so the torch ops in
+        # _compute_weights (ones_like, clamp, where) receive a device tensor
+        if not isinstance(current_rewards, torch.Tensor):
+            current_rewards = torch.as_tensor(current_rewards, dtype=torch.float32, device=self.device)
+
+        # For simplicity, assume progress delta is proportional to reward
+        # In practice, you'd want to compute next-frame rewards and take differences
+        # For now, use current reward as a proxy for progress delta
+        progress_deltas = current_rewards
+
+        # Update running statistics
+        self._update_stats(progress_deltas)
+
+        # Compute weights
+        weights = self._compute_weights(progress_deltas)
+
+        # Normalize weights to sum to batch_size (maintains effective batch size)
+        weight_sum = weights.sum() + self.epsilon
+        weights = weights * batch_size / weight_sum
+
+        return weights
+
+    def get_stats(self) -> dict:
+        """Get current running statistics."""
+        std = torch.tensor(self.m2 / (self.count - 1)).sqrt().item() if self.count > 1 else 0.0
+        return {
+            'mean': self.mean,
+            'std': std,
+            'count': self.count,
+        }
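
A minimal end-to-end sketch of driving RABCWeightComputer directly (the
checkpoint id and feature shapes are illustrative, not part of this patch):

    import torch
    from lerobot.policies.factory import get_policy_class
    from lerobot.utils.rabc import RABCWeightComputer

    reward_model = get_policy_class("sarm").from_pretrained("user/sarm-reward")
    computer = RABCWeightComputer(reward_model, kappa=0.01, epsilon=1e-6)

    batch = {
        "observation": {
            "video_features": torch.randn(8, 9, 512),  # (B, T, 512) CLIP features
            "text_features": torch.randn(8, 384),      # (B, 384) MiniLM features
        }
    }
    weights = computer.compute_batch_weights(batch)    # (8,), sums to ~batch size
    print(computer.get_stats())                        # running mean / std / count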