Compare commits

...

41 Commits

Author SHA1 Message Date
Pepijn 112eb70a65 Add uniform sampling and transition smoothing 2025-11-28 17:15:57 +01:00
Pepijn 6e3b972534 add visualize subtask annotations 2025-11-28 16:59:29 +01:00
Pepijn fa5004bd8c fix formatting 2025-11-28 13:27:20 +01:00
Pepijn b98c70376b Fix visualization and change prompt 2025-11-28 12:16:16 +01:00
Pepijn 2fa045eedc fix normalization in visualization 2025-11-28 10:52:24 +01:00
Pepijn adc476d8af simplify and cleanup code and move compute_temporal_proportions to utils 2025-11-27 19:38:32 +01:00
Pepijn 73dd4f10f7 simplify 2025-11-27 17:36:00 +01:00
Pepijn 2889c0650a use task from dataset, cleanup visualizer 2025-11-27 14:14:52 +01:00
Pepijn f2ad86831d add tests, implement formula 1,2 correctly and cleanup 2025-11-27 14:04:01 +01:00
Pepijn 3ed0425d2c Remove rewind, use clip tokenizer 2025-11-26 21:06:20 +01:00
Pepijn 425eced2de use large offset for initial frame (ugly) 2025-11-26 11:53:12 +01:00
Pepijn cc2e91febe fix progress conversion and adding initial frame 2025-11-26 11:02:42 +01:00
Pepijn c66aef878c add small logging 2025-11-25 22:54:35 +01:00
Pepijn 599c2477c5 change loadig subtasks 2025-11-25 22:48:46 +01:00
Pepijn 456d9fe3ff pass dataset metadata to policy 2025-11-25 22:13:23 +01:00
Pepijn 006185ff4a revert lerobot_train changes 2025-11-25 22:09:27 +01:00
Pepijn 2dc2a3ae55 add subtask init and detection 2025-11-25 22:06:20 +01:00
Pepijn 0c99b768f4 add episode inddex to complementary data 2025-11-25 18:34:56 +01:00
Pepijn c774818eda cleanup and refactor 2025-11-25 17:47:36 +01:00
Pepijn 3b31c2d9d3 pass stats 2025-11-25 16:25:58 +01:00
Pepijn 6b6a82bbdf raise if no state key is found 2025-11-25 16:21:29 +01:00
Pepijn 7beb20819e get state input from dataset stats 2025-11-25 16:17:28 +01:00
Pepijn 9a5a0ad575 change validation 2025-11-25 16:03:13 +01:00
Pepijn 2af40615b8 add image validation 2025-11-25 14:48:52 +01:00
Pepijn 8d2fb5d298 change expected features 2025-11-25 13:51:01 +01:00
Pepijn d286ea30d4 add reward output 2025-11-25 13:44:04 +01:00
Pepijn ca67231892 update sarm processor 2025-11-25 13:40:04 +01:00
Pepijn 5245332e36 print batch size 2025-11-25 13:26:30 +01:00
Pepijn 4367348327 change order train log 2025-11-25 13:05:56 +01:00
Pepijn c2c0dbf52e cleanup 2025-11-25 11:49:27 +01:00
Pepijn 3d28dc3681 Merge branch 'main' into feat/add_rewind 2025-11-24 19:23:05 +01:00
Pepijn 9bd69bb236 Add script to generate embedding for dataset (#2138)
* Add generate and validate script

* fix precommit

* Improve generate embeddings function by using dataset tools (#2206)

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-11-18 17:13:55 +01:00
Pepijn 52b080fd8c fix rewind discrepancies 2025-11-18 16:09:16 +01:00
Pepijn 0d84f4724d fix spawn 2025-11-18 15:44:24 +01:00
Pepijn 1ffdc6f49e subtasks 2025-11-18 15:28:40 +01:00
Pepijn Kooijmans f688eb160b Merge branch 'feat/add_rewind' of https://github.com/huggingface/lerobot into feat/add_rewind 2025-11-18 15:00:30 +01:00
Pepijn Kooijmans 69868360c7 add sarm 2025-11-18 15:00:05 +01:00
Pepijn 3c9149e909 small fix 2025-11-18 13:47:05 +01:00
Pepijn cf0f878dbb add annotation 2025-11-18 13:34:21 +01:00
Pepijn 1da9eee095 make rewind pretrained policy 2025-10-28 10:29:35 +01:00
Pepijn d9f0c8c3ae add initial modeling 2025-10-15 12:52:33 +02:00
21 changed files with 5842 additions and 10 deletions
File diff suppressed because it is too large.
@@ -0,0 +1,525 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Visualize SARM Subtask Annotations
This script creates visualizations of the subtask annotations generated by subtask_annotation.py.
For each episode, it shows:
- A timeline with dashed vertical lines at subtask boundaries
- Sample frames from the episode at key points (start, middle, end of each subtask)
- Color-coded subtask segments
Usage:
python visualize_subtask_annotations.py --repo-id pepijn223/mydataset --video-key observation.images.top --num-episodes 5
"""
import argparse
import random
from pathlib import Path
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D
from rich.console import Console
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import load_episodes
from lerobot.policies.sarm.sarm_utils import SubtaskAnnotation, Subtask, Timestamp
def timestamp_to_seconds(timestamp: str) -> float:
"""Convert MM:SS or SS timestamp to seconds"""
parts = timestamp.split(":")
if len(parts) == 2:
return int(parts[0]) * 60 + int(parts[1])
else:
return int(parts[0])
def load_annotations_from_dataset(dataset_path: Path) -> dict[int, SubtaskAnnotation]:
"""
Load annotations from LeRobot dataset parquet files.
Reads subtask annotations from the episodes metadata parquet files.
"""
episodes_dataset = load_episodes(dataset_path)
if episodes_dataset is None or len(episodes_dataset) == 0:
return {}
# Check if subtask columns exist
if "subtask_names" not in episodes_dataset.column_names:
return {}
# Convert to pandas DataFrame for easier access
episodes_df = episodes_dataset.to_pandas()
annotations = {}
for ep_idx in episodes_df.index:
subtask_names = episodes_df.loc[ep_idx, "subtask_names"]
# Skip episodes without annotations
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
continue
start_times = episodes_df.loc[ep_idx, "subtask_start_times"]
end_times = episodes_df.loc[ep_idx, "subtask_end_times"]
# Reconstruct SubtaskAnnotation from stored data
subtasks = []
for i, name in enumerate(subtask_names):
# Convert seconds back to MM:SS format
start_sec = int(start_times[i])
end_sec = int(end_times[i])
start_str = f"{start_sec // 60:02d}:{start_sec % 60:02d}"
end_str = f"{end_sec // 60:02d}:{end_sec % 60:02d}"
subtasks.append(
Subtask(
name=name,
timestamps=Timestamp(start=start_str, end=end_str)
)
)
annotations[int(ep_idx)] = SubtaskAnnotation(subtasks=subtasks)
return annotations
# Color palette for subtasks (colorblind-friendly)
SUBTASK_COLORS = [
"#E69F00", # Orange
"#56B4E9", # Sky blue
"#009E73", # Bluish green
"#F0E442", # Yellow
"#0072B2", # Blue
"#D55E00", # Vermillion
"#CC79A7", # Reddish purple
"#999999", # Gray
]
def extract_frame_from_video(video_path: Path, timestamp: float) -> np.ndarray | None:
"""Extract a single frame from video at given timestamp."""
cap = cv2.VideoCapture(str(video_path))
if not cap.isOpened():
return None
# Set position to timestamp
cap.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)
ret, frame = cap.read()
cap.release()
if ret:
# Convert BGR to RGB
return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
return None
def visualize_episode(
episode_idx: int,
annotation,
video_path: Path,
video_start_timestamp: float,
video_end_timestamp: float,
fps: int,
output_path: Path,
video_key: str,
):
"""
Create visualization for a single episode.
Shows:
- Top row: Sample frames from the episode (one per subtask)
- Bottom: Timeline with subtask segments and boundary lines
"""
subtasks = annotation.subtasks
num_subtasks = len(subtasks)
if num_subtasks == 0:
print(f"No subtasks found for episode {episode_idx}")
return
# Calculate episode duration
episode_duration = video_end_timestamp - video_start_timestamp
# Extract sample frames - get frame from middle of each subtask
sample_frames = []
frame_timestamps = []
for subtask in subtasks:
start_sec = timestamp_to_seconds(subtask.timestamps.start)
end_sec = timestamp_to_seconds(subtask.timestamps.end)
mid_sec = (start_sec + end_sec) / 2
# Convert to video timestamp (add video_start_timestamp offset)
video_timestamp = video_start_timestamp + mid_sec
frame_timestamps.append(mid_sec)
frame = extract_frame_from_video(video_path, video_timestamp)
sample_frames.append(frame)
# Create figure
fig = plt.figure(figsize=(16, 10))
# Use a dark background for better contrast
fig.patch.set_facecolor('#1a1a2e')
# Calculate grid layout
# Top section: frames (variable number of columns based on subtasks)
# Bottom section: timeline
# Create gridspec
gs = fig.add_gridspec(
2, max(num_subtasks, 1),
height_ratios=[2, 1],
hspace=0.3,
wspace=0.1,
left=0.05, right=0.95,
top=0.88, bottom=0.1
)
# Add title
fig.suptitle(
f"Episode {episode_idx} - Subtask Annotations",
fontsize=18,
fontweight='bold',
color='white',
y=0.96
)
# Add subtitle with video info
fig.text(
0.5, 0.91,
f"Camera: {video_key} | Duration: {episode_duration:.1f}s | {num_subtasks} subtasks",
ha='center',
fontsize=11,
color='#888888'
)
# Plot sample frames
for i, (frame, subtask) in enumerate(zip(sample_frames, subtasks)):
ax = fig.add_subplot(gs[0, i])
ax.set_facecolor('#16213e')
if frame is not None:
ax.imshow(frame)
else:
ax.text(0.5, 0.5, "Frame\nN/A", ha='center', va='center',
fontsize=12, color='white', transform=ax.transAxes)
ax.set_title(
f"{subtask.name}",
fontsize=10,
fontweight='bold',
color=SUBTASK_COLORS[i % len(SUBTASK_COLORS)],
pad=8
)
ax.axis('off')
# Add frame timestamp below
ax.text(
0.5, -0.08,
f"t={frame_timestamps[i]:.1f}s",
ha='center',
fontsize=9,
color='#888888',
transform=ax.transAxes
)
# Create timeline subplot spanning all columns
ax_timeline = fig.add_subplot(gs[1, :])
ax_timeline.set_facecolor('#16213e')
# Get total duration from last subtask end time
total_duration = timestamp_to_seconds(subtasks[-1].timestamps.end)
# Draw subtask segments as colored bars
bar_height = 0.6
bar_y = 0.5
for i, subtask in enumerate(subtasks):
start_sec = timestamp_to_seconds(subtask.timestamps.start)
end_sec = timestamp_to_seconds(subtask.timestamps.end)
color = SUBTASK_COLORS[i % len(SUBTASK_COLORS)]
# Draw segment bar
rect = mpatches.FancyBboxPatch(
(start_sec, bar_y - bar_height/2),
end_sec - start_sec,
bar_height,
boxstyle="round,pad=0.02,rounding_size=0.1",
facecolor=color,
edgecolor='white',
linewidth=1.5,
alpha=0.85
)
ax_timeline.add_patch(rect)
# Add subtask label inside bar
mid_x = (start_sec + end_sec) / 2
duration = end_sec - start_sec
# Only add text if segment is wide enough
if duration > total_duration * 0.08:
ax_timeline.text(
mid_x, bar_y,
subtask.name,
ha='center', va='center',
fontsize=9,
fontweight='bold',
color='black' if i % len(SUBTASK_COLORS) == 3 else 'white', # Yellow needs dark text
rotation=0 if duration > total_duration * 0.15 else 45
)
# Draw boundary lines (dashed vertical lines between subtasks)
boundary_times = []
for i, subtask in enumerate(subtasks):
start_sec = timestamp_to_seconds(subtask.timestamps.start)
end_sec = timestamp_to_seconds(subtask.timestamps.end)
# Add start boundary (except for first subtask at t=0)
if i == 0 and start_sec > 0:
boundary_times.append(start_sec)
elif i > 0:
boundary_times.append(start_sec)
# Add end boundary for last subtask
if i == len(subtasks) - 1:
boundary_times.append(end_sec)
# Draw dashed lines at boundaries
for t in boundary_times:
ax_timeline.axvline(
x=t,
ymin=0.1, ymax=0.9,
color='white',
linestyle='--',
linewidth=2,
alpha=0.9
)
# Add time label below line
ax_timeline.text(
t, 0.0,
f"{int(t//60):02d}:{int(t%60):02d}",
ha='center', va='top',
fontsize=8,
color='#cccccc'
)
# Add start line at t=0
ax_timeline.axvline(x=0, ymin=0.1, ymax=0.9, color='#00ff00', linestyle='-', linewidth=2.5, alpha=0.9)
ax_timeline.text(0, 0.0, "00:00", ha='center', va='top', fontsize=8, color='#00ff00', fontweight='bold')
# Configure timeline axes
ax_timeline.set_xlim(-total_duration * 0.02, total_duration * 1.02)
ax_timeline.set_ylim(-0.3, 1.2)
ax_timeline.set_xlabel("Time (seconds)", fontsize=11, color='white', labelpad=10)
ax_timeline.set_ylabel("")
# Style the axes
ax_timeline.spines['top'].set_visible(False)
ax_timeline.spines['right'].set_visible(False)
ax_timeline.spines['left'].set_visible(False)
ax_timeline.spines['bottom'].set_color('#444444')
ax_timeline.tick_params(axis='x', colors='#888888', labelsize=9)
ax_timeline.tick_params(axis='y', left=False, labelleft=False)
# Add x-axis ticks at regular intervals
tick_interval = max(1, int(total_duration / 10))
ax_timeline.set_xticks(np.arange(0, total_duration + tick_interval, tick_interval))
# Add legend explaining line styles
legend_elements = [
Line2D([0], [0], color='#00ff00', linewidth=2.5, linestyle='-', label='Start'),
Line2D([0], [0], color='white', linewidth=2, linestyle='--', label='Subtask boundary'),
]
ax_timeline.legend(
handles=legend_elements,
loc='upper right',
framealpha=0.3,
facecolor='#16213e',
edgecolor='#444444',
fontsize=9,
labelcolor='white'
)
# Save figure
plt.savefig(output_path, dpi=150, facecolor=fig.get_facecolor(), edgecolor='none', bbox_inches='tight')
plt.close()
return output_path
def main():
parser = argparse.ArgumentParser(
description="Visualize SARM subtask annotations",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="HuggingFace dataset repository ID",
)
parser.add_argument(
"--num-episodes",
type=int,
default=5,
help="Number of random episodes to visualize (default: 5)",
)
parser.add_argument(
"--episodes",
type=int,
nargs="+",
default=None,
help="Specific episode indices to visualize (overrides --num-episodes)",
)
parser.add_argument(
"--video-key",
type=str,
default=None,
help="Camera/video key to use. If not specified, uses first available.",
)
parser.add_argument(
"--output-dir",
type=str,
default="./subtask_viz",
help="Output directory for visualizations (default: ./subtask_viz)",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Random seed for reproducibility",
)
args = parser.parse_args()
console = Console()
# Set random seed if specified
if args.seed is not None:
random.seed(args.seed)
console.print(f"\n[cyan]Loading dataset: {args.repo_id}[/cyan]")
dataset = LeRobotDataset(args.repo_id, download_videos=True)
fps = dataset.fps
# Get video key
if args.video_key:
if args.video_key not in dataset.meta.video_keys:
console.print(f"[red]Error: Video key '{args.video_key}' not found[/red]")
console.print(f"[yellow]Available: {', '.join(dataset.meta.video_keys)}[/yellow]")
return
video_key = args.video_key
else:
video_key = dataset.meta.video_keys[0]
console.print(f"[cyan]Using camera: {video_key}[/cyan]")
console.print(f"[cyan]FPS: {fps}[/cyan]")
# Load annotations
console.print(f"\n[cyan]Loading annotations...[/cyan]")
annotations = load_annotations_from_dataset(dataset.root)
if not annotations:
console.print("[red]Error: No annotations found in dataset[/red]")
console.print("[yellow]Run subtask_annotation.py first to generate annotations[/yellow]")
return
console.print(f"[green]Found {len(annotations)} annotated episodes[/green]")
# Determine which episodes to visualize
if args.episodes:
episode_indices = args.episodes
# Validate episodes exist
for ep in episode_indices:
if ep not in annotations:
console.print(f"[yellow]Warning: Episode {ep} has no annotation, skipping[/yellow]")
episode_indices = [ep for ep in episode_indices if ep in annotations]
else:
# Random selection
available_episodes = list(annotations.keys())
num_to_select = min(args.num_episodes, len(available_episodes))
episode_indices = random.sample(available_episodes, num_to_select)
episode_indices.sort()
if not episode_indices:
console.print("[red]Error: No valid episodes to visualize[/red]")
return
console.print(f"[cyan]Visualizing episodes: {episode_indices}[/cyan]")
# Create output directory
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Generate visualizations
for ep_idx in episode_indices:
console.print(f"\n[cyan]Processing episode {ep_idx}...[/cyan]")
annotation = annotations[ep_idx]
# Get video path and timestamps
video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, video_key)
if not video_path.exists():
console.print(f"[red]Video not found: {video_path}[/red]")
continue
# Get episode-specific timestamps within the video file
video_path_key = f"videos/{video_key}/from_timestamp"
video_path_key_to = f"videos/{video_key}/to_timestamp"
video_start_timestamp = float(dataset.meta.episodes[video_path_key][ep_idx])
video_end_timestamp = float(dataset.meta.episodes[video_path_key_to][ep_idx])
# Create visualization
output_path = output_dir / f"episode_{ep_idx:04d}_subtasks.png"
try:
visualize_episode(
episode_idx=ep_idx,
annotation=annotation,
video_path=video_path,
video_start_timestamp=video_start_timestamp,
video_end_timestamp=video_end_timestamp,
fps=fps,
output_path=output_path,
video_key=video_key,
)
console.print(f"[green]✓ Saved: {output_path}[/green]")
except Exception as e:
console.print(f"[red]✗ Failed to visualize episode {ep_idx}: {e}[/red]")
# Print summary
console.print(f"\n[bold green]{'=' * 50}[/bold green]")
console.print(f"[bold green]Visualization Complete![/bold green]")
console.print(f"[bold green]{'=' * 50}[/bold green]")
console.print(f"Output directory: {output_dir.absolute()}")
console.print(f"Episodes visualized: {len(episode_indices)}")
if __name__ == "__main__":
main()
@@ -0,0 +1,761 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Inference script for SARM (Stage-Aware Reward Model).
This script loads a trained SARM model and runs inference on a dataset episode,
generating visualizations of the predicted task stages and progress over time.
Example usage:
python scripts/visualize_sarm_predictions.py \
--model-id username/sarm-model \
--dataset-repo lerobot/aloha_sim_insertion_human \
--episode-index 0 \
--output-dir outputs/sarm_viz \
--task-description "insert the peg into the socket"
"""
import argparse
import json
import logging
from pathlib import Path
from typing import Optional
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
from lerobot.policies.sarm.sarm_utils import (
pad_state_to_max_dim,
compute_tau,
compute_cumulative_progress_batch,
)
from lerobot.datasets.utils import load_stats
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser(description="Run SARM inference and visualize predictions")
# Model arguments
parser.add_argument(
"--model-id",
type=str,
required=True,
help="HuggingFace model ID or local path to trained SARM model"
)
# Dataset arguments
parser.add_argument(
"--dataset-repo",
type=str,
required=True,
help="HuggingFace dataset repository ID (e.g., lerobot/aloha_sim_insertion_human)"
)
parser.add_argument(
"--episode-index",
type=int,
default=0,
help="Index of the episode to visualize (default: 0)"
)
parser.add_argument(
"--task-description",
type=str,
default="perform the task",
help="Task description for the reward model (default: 'perform the task')"
)
# Output arguments
parser.add_argument(
"--output-dir",
type=str,
default="outputs/sarm_inference",
help="Directory to save visualization outputs (default: outputs/sarm_inference)"
)
parser.add_argument(
"--image-key",
type=str,
default=None,
help="Key for images in dataset (e.g., observation.images.image). If not specified, uses model config's image_key"
)
parser.add_argument(
"--state-key",
type=str,
default=None,
help="Key for joint states in dataset. If None, auto-detects from dataset"
)
# Visualization options
parser.add_argument(
"--show-frames",
action="store_true",
help="Include sample frames in the visualization"
)
parser.add_argument(
"--num-sample-frames",
type=int,
default=8,
help="Number of sample frames to show (default: 8)"
)
parser.add_argument(
"--figsize",
type=int,
nargs=2,
default=[14, 8],
help="Figure size as width height (default: 14 8)"
)
# Device
parser.add_argument(
"--device",
type=str,
default=None,
help="Device to run inference on (cuda/cpu, default: auto-detect)"
)
return parser.parse_args()
def load_episode_data(
dataset: LeRobotDataset,
episode_index: int,
image_key: str,
state_key: str | None = None
) -> tuple[np.ndarray, np.ndarray, int, int, str]:
"""
Load all frames and states from a specific episode.
Args:
dataset: LeRobotDataset instance
episode_index: Index of the episode to load
image_key: Key for accessing images in the dataset
state_key: Key for accessing joint states (auto-detected if None)
Returns:
Tuple of (frames, states, start_index, end_index, task_description)
"""
# Get episode boundaries
episode_data = dataset.meta.episodes
start_idx = episode_data["dataset_from_index"][episode_index]
end_idx = episode_data["dataset_to_index"][episode_index]
logger.info(f"Loading episode {episode_index}: frames {start_idx} to {end_idx} ({end_idx - start_idx} frames)")
# Auto-detect state key if not provided
if state_key is None:
first_item = dataset[start_idx]
state_keys = [k for k in first_item.keys() if 'state' in k.lower() or 'qpos' in k.lower()]
if state_keys:
state_key = state_keys[0]
logger.info(f"Auto-detected state key: {state_key}")
# Get task description from the dataset if available
task_description = None
first_item = dataset[start_idx]
if "task" in first_item:
task_description = first_item["task"]
logger.info(f"✓ Extracted task from episode {episode_index}: '{task_description}'")
# Load all frames and states from the episode
frames = []
states = []
for idx in tqdm(range(start_idx, end_idx), desc="Loading frames"):
item = dataset[idx]
# Get image
img = item[image_key]
# Convert to numpy if needed
if isinstance(img, torch.Tensor):
img = img.cpu().numpy()
# Handle different image formats (C, H, W) or (H, W, C)
if img.shape[0] in [1, 3]: # Channel first
img = np.transpose(img, (1, 2, 0))
# Convert to uint8 if needed
if img.dtype != np.uint8:
if img.max() <= 1.0:
img = (img * 255).astype(np.uint8)
else:
img = img.astype(np.uint8)
frames.append(img)
# Get state if available
if state_key and state_key in item:
state = item[state_key]
if isinstance(state, torch.Tensor):
state = state.cpu().numpy()
states.append(state)
frames = np.array(frames)
states = np.array(states) if states else None
logger.info(f"Loaded {len(frames)} frames with shape {frames[0].shape}")
if states is not None:
logger.info(f"Loaded states with shape {states.shape}")
return frames, states, start_idx, end_idx, task_description
@torch.no_grad()
def run_inference(
model: SARMRewardModel,
frames: np.ndarray,
states: Optional[np.ndarray],
task_description: str,
dataset_stats: dict | None = None,
state_key: str = "observation.state",
batch_size: int = 32
) -> tuple[np.ndarray, np.ndarray]:
"""
Run SARM inference on video frames and joint states.
Slice construction (per SARM paper Section A.4): each slice contains the episode's
initial frame plus frames offset around the current frame t; the offsets are taken
from model.config.observation_delta_indices (spaced by frame_gap) and clamped to the
episode bounds.
Args:
model: SARM model
frames: Video frames (num_frames, H, W, C) - all frames from ONE episode
states: Joint states (num_frames, state_dim)
task_description: Task description text
dataset_stats: Dataset statistics for state normalization (same as training)
state_key: Key for state in dataset_stats
batch_size: Batch size for processing slices
Returns:
Tuple of (progress_predictions, stage_predictions)
- progress_predictions: (num_frames,)
- stage_predictions: (num_frames, num_stages)
"""
logger.info("Encoding video frames with CLIP...")
video_embeddings = model.encode_images(frames)
logger.info("Encoding task description with CLIP...")
text_embedding = model.encode_text(task_description)
# Get config values
num_frames_model = model.config.num_frames # 9
frame_gap = model.config.frame_gap # 30
logger.info("Creating video slices (SARM paper: initial frame + 8 consecutive)...")
# Convert to tensors
video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
text_embedding = torch.tensor(text_embedding, dtype=torch.float32)
if states is not None:
state_embeddings = torch.tensor(states, dtype=torch.float32)
# Normalize states using dataset stats (same as training processor)
if dataset_stats is not None and state_key in dataset_stats:
mean = torch.tensor(dataset_stats[state_key]["mean"], dtype=torch.float32)
std = torch.tensor(dataset_stats[state_key]["std"], dtype=torch.float32)
state_embeddings = (state_embeddings - mean) / (std + 1e-8)
logger.info(f"✓ Applied MEAN_STD normalization to states using {state_key}")
else:
logger.warning("⚠ No dataset_stats provided - states not normalized (may differ from training)")
else:
state_embeddings = None
video_slices = []
state_slices = []
for current_frame in tqdm(range(len(video_embeddings)), desc="Creating slices"):
# Compute frame indices using symmetric bidirectional pattern:
# [initial (0), t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
# Boundary handling: clamp to [0, last_valid]
deltas = model.config.observation_delta_indices
last_valid = len(video_embeddings) - 1
frame_indices = []
for delta in deltas:
idx = current_frame + delta
idx = max(0, min(idx, last_valid)) # Clamp to valid range
frame_indices.append(idx)
video_slice = video_embeddings[frame_indices]
video_slices.append(video_slice)
if state_embeddings is not None:
state_slice = state_embeddings[frame_indices]
state_slices.append(state_slice)
video_slices = torch.stack(video_slices) # (num_frames, num_frames_model, 512)
if state_embeddings is not None:
state_slices = torch.stack(state_slices) # (num_frames, num_frames_model, state_dim)
# Pad states to max_state_dim (same as training processor)
state_slices = pad_state_to_max_dim(state_slices, model.config.max_state_dim)
else:
state_slices = None
logger.info("Running SARM inference on all slices...")
# Process in batches
all_progress = []
all_stages = []
for i in tqdm(range(0, len(video_slices), batch_size), desc="Inference"):
batch_video = video_slices[i:i + batch_size].to(model.device)
batch_states = state_slices[i:i + batch_size].to(model.device) if state_slices is not None else None
batch_size_actual = batch_video.shape[0]
# Replicate text embedding for batch
batch_text = text_embedding.unsqueeze(0).repeat(batch_size_actual, 1).to(model.device)
# Get predictions
stage_logits, stage_probs, progress_preds = model.sarm_transformer(
batch_video, batch_text, batch_states
)
# Extract predictions at the "current frame" position
# With symmetric pattern [initial, t-4g, t-3g, t-2g, t-g, t, t+g, t+2g, t+3g],
# the current frame is at position 5 (0-indexed)
current_frame_idx = 5
batch_progress = progress_preds[:, current_frame_idx, 0].cpu().numpy()
batch_stages = stage_probs[:, current_frame_idx, :].cpu().numpy()
all_progress.extend(batch_progress)
all_stages.extend(batch_stages)
return np.array(all_progress), np.array(all_stages)
def compute_ground_truth_progress(
dataset: LeRobotDataset,
episode_index: int,
temporal_proportions: dict[str, float],
subtask_names_ordered: list[str],
) -> tuple[np.ndarray, np.ndarray] | tuple[None, None]:
"""
Compute ground truth progress and stage labels for an episode using annotations.
Uses SARM Paper Formula (2):
y_t = P_{k-1} + ᾱ_k × τ_t
where:
- τ_t = (t - s_k) / (e_k - s_k) is within-subtask progress
- P_{k-1} is cumulative prior (sum of previous subtask proportions)
- ᾱ_k is the temporal proportion for subtask k
Args:
dataset: LeRobotDataset instance
episode_index: Index of the episode
temporal_proportions: Dict mapping subtask name to proportion
subtask_names_ordered: Ordered list of subtask names (for consistent stage indexing)
Returns:
Tuple of (ground_truth_progress, ground_truth_stages) arrays, or (None, None) if no annotations
"""
# Load episode metadata
episodes_df = dataset.meta.episodes.to_pandas()
# Check if annotations exist
if "subtask_names" not in episodes_df.columns:
logger.warning("No subtask_names column found in episodes metadata")
return None, None
ep_subtask_names = episodes_df.loc[episode_index, "subtask_names"]
if ep_subtask_names is None or (isinstance(ep_subtask_names, float) and pd.isna(ep_subtask_names)):
logger.warning(f"No annotations found for episode {episode_index}")
return None, None
subtask_start_frames = episodes_df.loc[episode_index, "subtask_start_frames"]
subtask_end_frames = episodes_df.loc[episode_index, "subtask_end_frames"]
# Get episode boundaries
ep_start = dataset.meta.episodes["dataset_from_index"][episode_index]
ep_end = dataset.meta.episodes["dataset_to_index"][episode_index]
num_frames = ep_end - ep_start
# Get temporal proportions as ordered list
temporal_proportions_list = [
temporal_proportions.get(name, 0.0) for name in subtask_names_ordered
]
logger.info(f"Computing ground truth for {num_frames} frames using {len(ep_subtask_names)} annotated subtasks")
logger.info(f"Subtask names in episode: {ep_subtask_names}")
logger.info(f"Subtask start frames: {subtask_start_frames}")
logger.info(f"Subtask end frames: {subtask_end_frames}")
logger.info(f"Temporal proportions (ordered): {dict(zip(subtask_names_ordered, temporal_proportions_list))}")
# Compute ground truth for each frame
gt_progress = np.zeros(num_frames)
gt_stages = np.zeros(num_frames, dtype=np.int32)
for frame_rel in range(num_frames):
# Find which subtask this frame belongs to
found = False
for j, (name, start_frame, end_frame) in enumerate(zip(ep_subtask_names, subtask_start_frames, subtask_end_frames)):
if frame_rel >= start_frame and frame_rel <= end_frame:
# Found the subtask - get its global index
stage_idx = subtask_names_ordered.index(name) if name in subtask_names_ordered else 0
# Compute τ_t using utility function
tau = compute_tau(frame_rel, start_frame, end_frame)
# Compute cumulative progress using utility function
progress = compute_cumulative_progress_batch(tau, stage_idx, temporal_proportions_list)
gt_progress[frame_rel] = progress
gt_stages[frame_rel] = stage_idx
found = True
break
if not found:
# Handle frames outside annotated subtasks
if frame_rel < subtask_start_frames[0]:
gt_progress[frame_rel] = 0.0
gt_stages[frame_rel] = 0
elif frame_rel > subtask_end_frames[-1]:
gt_progress[frame_rel] = 1.0
gt_stages[frame_rel] = len(subtask_names_ordered) - 1
else:
# Between subtasks - find previous subtask
for j in range(len(ep_subtask_names) - 1):
if frame_rel > subtask_end_frames[j] and frame_rel < subtask_start_frames[j + 1]:
name = ep_subtask_names[j]
stage_idx = subtask_names_ordered.index(name) if name in subtask_names_ordered else j
progress = compute_cumulative_progress_batch(1.0, stage_idx, temporal_proportions_list)
gt_progress[frame_rel] = progress
gt_stages[frame_rel] = stage_idx
break
logger.info(f"✓ Ground truth computed: final={gt_progress[-1]:.3f}, max={gt_progress.max():.3f}")
return gt_progress, gt_stages
def visualize_predictions(
frames: np.ndarray,
progress_predictions: np.ndarray,
stage_predictions: np.ndarray,
task_description: str,
output_path: Path,
num_sample_frames: int = 8,
figsize: tuple = (14, 8),
subtask_names: list[str] | None = None,
temporal_proportions: dict[str, float] | None = None,
ground_truth_progress: np.ndarray | None = None,
ground_truth_stages: np.ndarray | None = None,
):
"""
Create visualization of SARM predictions with optional ground truth comparison.
Args:
frames: Video frames (num_frames, H, W, C)
progress_predictions: Progress predictions (num_frames,)
stage_predictions: Stage probabilities (num_frames, num_stages)
task_description: Task description
output_path: Path to save the figure
num_sample_frames: Number of frames to show
figsize: Figure size (width, height)
subtask_names: Optional list of subtask names for labeling
temporal_proportions: Optional dict of temporal proportions for each subtask
ground_truth_progress: Optional ground truth progress array (num_frames,)
ground_truth_stages: Optional ground truth stage indices array (num_frames,)
"""
num_stages = stage_predictions.shape[1]
stage_colors = plt.cm.tab10(np.linspace(0, 1, num_stages))
# Use subtask names if available, otherwise use generic labels
if subtask_names is not None and len(subtask_names) == num_stages:
stage_labels = subtask_names
else:
stage_labels = [f'Stage {i+1}' for i in range(num_stages)]
# Create figure with progress plot, stage plot, and sample frames
fig = plt.figure(figsize=(figsize[0], figsize[1] + 4))
gs = gridspec.GridSpec(3, 1, height_ratios=[2, 1, 1], hspace=0.3)
ax_progress = fig.add_subplot(gs[0])
ax_stages = fig.add_subplot(gs[1], sharex=ax_progress)
ax_frames = fig.add_subplot(gs[2])
frame_indices = np.arange(len(progress_predictions))
# Plot 1: Progress over time
ax_progress.plot(frame_indices, progress_predictions, linewidth=2, color='#2E86AB', label='Predicted Progress')
ax_progress.fill_between(frame_indices, 0, progress_predictions, alpha=0.3, color='#2E86AB')
# Plot ground truth if available
if ground_truth_progress is not None:
ax_progress.plot(frame_indices, ground_truth_progress, linewidth=2, color='#28A745',
linestyle='--', label='Ground Truth Progress')
ax_progress.fill_between(frame_indices, 0, ground_truth_progress, alpha=0.15, color='#28A745')
ax_progress.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, linewidth=1)
ax_progress.set_ylabel('Task Progress', fontsize=12)
ax_progress.set_title(f'Task: "{task_description}"', fontsize=14, fontweight='bold')
ax_progress.grid(True, alpha=0.3)
ax_progress.set_ylim(-0.05, 1.1)
ax_progress.legend(loc='upper left')
# Add statistics box
stats_text = (
f'Frames: {len(progress_predictions)}\n'
f'Final Progress: {progress_predictions[-1]:.3f}\n'
f'Max Progress: {progress_predictions.max():.3f}\n'
f'Mean Progress: {progress_predictions.mean():.3f}'
)
if ground_truth_progress is not None:
mse = np.mean((progress_predictions - ground_truth_progress) ** 2)
stats_text += f'\nMSE vs GT: {mse:.4f}'
stats_text += f'\nGT Final: {ground_truth_progress[-1]:.3f}'
ax_progress.text(0.98, 0.02, stats_text, transform=ax_progress.transAxes,
fontsize=10, verticalalignment='bottom', horizontalalignment='right',
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
# Plot 2: Stage predictions (stacked area plot)
ax_stages.stackplot(frame_indices, *[stage_predictions[:, i] for i in range(num_stages)],
colors=stage_colors, alpha=0.8, labels=stage_labels)
# Plot ground truth stage as vertical bands or markers
if ground_truth_stages is not None:
# Find stage transition points in ground truth
stage_changes = np.where(np.diff(ground_truth_stages) != 0)[0] + 1
for change_idx in stage_changes:
ax_stages.axvline(x=change_idx, color='black', linestyle='-', alpha=0.7, linewidth=1.5)
ax_progress.axvline(x=change_idx, color='black', linestyle='-', alpha=0.3, linewidth=1)
# Add small markers at bottom showing GT stage
gt_stage_normalized = ground_truth_stages / max(num_stages - 1, 1)
ax_stages.scatter(frame_indices[::30], np.zeros(len(frame_indices[::30])) + 0.02,
c=[stage_colors[s] for s in ground_truth_stages[::30]],
s=20, marker='|', alpha=0.8, label='GT Stage Markers')
ax_stages.set_xlabel('Frame Index', fontsize=12)
ax_stages.set_ylabel('Stage Probability', fontsize=12)
ax_stages.set_ylim(0, 1)
ax_stages.grid(True, alpha=0.3)
# Adjust legend based on number of stages and label lengths
if num_stages <= 5:
ax_stages.legend(loc='upper left', ncol=num_stages, fontsize=8)
else:
ax_stages.legend(loc='upper left', ncol=3, fontsize=7)
# Add vertical lines and labels for expected stage transitions (if temporal proportions available)
if temporal_proportions is not None and subtask_names is not None:
cumulative_progress = 0.0
for i, name in enumerate(stage_labels):
if name in temporal_proportions:
# Find approximate frame where this stage should end
stage_end_progress = cumulative_progress + temporal_proportions[name]
# Find frame index closest to this progress
progress_diffs = np.abs(progress_predictions - stage_end_progress)
stage_end_frame = np.argmin(progress_diffs)
# Draw vertical line
ax_progress.axvline(x=stage_end_frame, color='gray', linestyle=':', alpha=0.5, linewidth=1)
ax_stages.axvline(x=stage_end_frame, color='gray', linestyle=':', alpha=0.5, linewidth=1)
cumulative_progress = stage_end_progress
# Plot 3: Sample frames (if requested)
frame_indices_to_show = np.linspace(0, len(frames) - 1, num_sample_frames, dtype=int)
ax_frames.axis('off')
# Create grid for frames
frame_height = frames[0].shape[0]
frame_width = frames[0].shape[1]
combined_width = frame_width * num_sample_frames
combined_image = np.zeros((frame_height, combined_width, 3), dtype=np.uint8)
for i, frame_idx in enumerate(frame_indices_to_show):
frame = frames[frame_idx]
if frame.shape[-1] == 1:
frame = np.repeat(frame, 3, axis=-1)
# Add frame to combined image
x_start = i * frame_width
x_end = (i + 1) * frame_width
combined_image[:, x_start:x_end] = frame
# Add frame number, progress, and stage
progress_val = progress_predictions[frame_idx]
stage_idx = np.argmax(stage_predictions[frame_idx])
stage_name = stage_labels[stage_idx] if stage_idx < len(stage_labels) else f'{stage_idx+1}'
# Truncate long stage names for display
if len(stage_name) > 15:
stage_name = stage_name[:12] + '...'
label = f'Frame {frame_idx}\nProg: {progress_val:.2f}\n{stage_name}'
# Draw label on image
ax_frames.text(x_start + frame_width / 2, -10, label,
ha='center', va='top', fontsize=7,
bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))
ax_frames.imshow(combined_image)
ax_frames.set_title('Sample Frames', fontsize=12, pad=20)
plt.tight_layout()
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path, dpi=150, bbox_inches='tight')
logger.info(f"Saved visualization to {output_path}")
plt.close()
def main():
args = parse_args()
# Setup device
if args.device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
else:
device = args.device
logger.info(f"Using device: {device}")
# Load model
logger.info(f"Loading SARM model from {args.model_id}...")
model = SARMRewardModel.from_pretrained(args.model_id)
model.to(device)
model.eval()
logger.info("Model loaded successfully")
# Load dataset
logger.info(f"Loading dataset {args.dataset_repo}...")
dataset = LeRobotDataset(args.dataset_repo)
logger.info(f"Dataset loaded: {len(dataset.meta.episodes)} episodes, {len(dataset)} frames")
# Validate episode index
if args.episode_index >= len(dataset.meta.episodes):
raise ValueError(
f"Episode index {args.episode_index} out of range. "
f"Dataset has {len(dataset.meta.episodes)} episodes."
)
image_key = args.image_key if args.image_key is not None else model.config.image_key
state_key = args.state_key if args.state_key is not None else model.config.state_key
logger.info(f"Using image key: {image_key}")
logger.info(f"Using state key: {state_key}")
# Load dataset stats for state normalization (same as training)
dataset_stats = load_stats(dataset.root)
if dataset_stats:
logger.info(f"✓ Loaded dataset stats from {dataset.root}")
else:
logger.warning("⚠ Could not load dataset stats - states will not be normalized")
# Load episode data
frames, states, start_idx, end_idx, dataset_task = load_episode_data(
dataset, args.episode_index, image_key, state_key
)
# Use task description from dataset if available, otherwise use command-line argument
task_description = dataset_task if dataset_task is not None else args.task_description
logger.info(f"Using task description: '{task_description}'")
# Run inference
progress_predictions, stage_predictions = run_inference(
model, frames, states, task_description,
dataset_stats=dataset_stats, state_key=state_key
)
# Extract subtask names and temporal proportions from model config if available
subtask_names = None
temporal_proportions = None
if hasattr(model.config, 'subtask_names') and model.config.subtask_names is not None:
subtask_names = model.config.subtask_names
logger.info(f"✓ Found {len(subtask_names)} subtask names in model config: {subtask_names}")
# Try to load temporal proportions from model config
if hasattr(model.config, 'temporal_proportions') and model.config.temporal_proportions is not None:
temporal_proportions = {
name: prop for name, prop in zip(model.config.subtask_names, model.config.temporal_proportions)
}
logger.info(f"✓ Loaded temporal proportions from model config: {temporal_proportions}")
# Fallback: try to load from dataset meta
if temporal_proportions is None:
proportions_path = dataset.root / "meta" / "temporal_proportions.json"
if proportions_path.exists():
with open(proportions_path, 'r') as f:
temporal_proportions = json.load(f)
logger.info(f"✓ Loaded temporal proportions from dataset: {temporal_proportions}")
# Also extract subtask names from proportions if not already set
if subtask_names is None:
subtask_names = sorted(temporal_proportions.keys())
logger.info(f"✓ Extracted subtask names from proportions: {subtask_names}")
# Compute ground truth progress if annotations are available
ground_truth_progress = None
ground_truth_stages = None
if temporal_proportions is not None and subtask_names is not None:
logger.info("Attempting to compute ground truth progress from annotations...")
ground_truth_progress, ground_truth_stages = compute_ground_truth_progress(
dataset,
args.episode_index,
temporal_proportions,
subtask_names
)
if ground_truth_progress is None:
logger.warning("⚠ Ground truth not available - annotations may be missing for this episode")
else:
logger.warning("⚠ Cannot compute ground truth - temporal_proportions or subtask_names not available")
output_dir = Path(args.output_dir)
output_path = output_dir / f"sarm_prediction_ep{args.episode_index}.png"
visualize_predictions(
frames,
progress_predictions,
stage_predictions,
task_description,
output_path,
num_sample_frames=args.num_sample_frames,
figsize=tuple(args.figsize),
subtask_names=subtask_names,
temporal_proportions=temporal_proportions,
ground_truth_progress=ground_truth_progress,
ground_truth_stages=ground_truth_stages,
)
predictions_path = output_dir / f"predictions_ep{args.episode_index}.npz"
save_dict = {
'progress': progress_predictions,
'stages': stage_predictions
}
if ground_truth_progress is not None:
save_dict['gt_progress'] = ground_truth_progress
save_dict['gt_stages'] = ground_truth_stages
np.savez(predictions_path, **save_dict)
logger.info(f"Saved predictions to {predictions_path}")
logger.info(f"\nVisualization: {output_path}")
if __name__ == "__main__":
main()
@@ -64,9 +64,26 @@ class TrainPipelineConfig(HubMixin):
scheduler: LRSchedulerConfig | None = None
eval: EvalConfig = field(default_factory=EvalConfig)
wandb: WandBConfig = field(default_factory=WandBConfig)
checkpoint_path: Path | None = field(init=False, default=None)
# RA-BC (Reward-Aligned Behavior Cloning) parameters
use_rabc: bool = False # Enable reward-weighted training
reward_model_path: str | None = None # Path to pre-trained reward model (e.g., SARM)
rabc_kappa: float = 0.01 # Hard threshold for high-quality samples
rabc_epsilon: float = 1e-6 # Small constant for numerical stability
rabc_update_freq: int = 1 # Compute rewards every N batches (1 = every batch)
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
def validate(self):
# Validate RA-BC configuration
if self.use_rabc and not self.reward_model_path:
raise ValueError(
"RA-BC is enabled (use_rabc=True) but no reward_model_path provided. "
"Please specify a pre-trained reward model (e.g., SARM) path."
)
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
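A minimal sketch (not part of this diff) of how the RA-BC fields above could be used to weight a behavior-cloning loss, assuming a reward model that returns a per-sample quality score in [0, 1]; `reward_model`, `policy_loss`, and `rabc_weights` are illustrative names, not APIs introduced by this PR.

```python
import torch


def rabc_weights(scores: torch.Tensor, kappa: float = 0.01, epsilon: float = 1e-6) -> torch.Tensor:
    """Turn per-sample reward scores in [0, 1] into loss weights (assumed scheme).

    Samples scoring above (1 - kappa) are treated as high quality and keep full weight;
    the rest are down-weighted by their score. The normalization keeps the batch loss
    scale comparable to unweighted behavior cloning.
    """
    high_quality = scores >= (1.0 - kappa)
    weights = torch.where(high_quality, torch.ones_like(scores), scores)
    return weights / (weights.mean() + epsilon)


# Hypothetical use inside a training step, with cfg holding the fields added above:
#   if cfg.use_rabc and step % cfg.rabc_update_freq == 0:
#       scores = reward_model(batch)                      # e.g. SARM progress, shape (B,)
#   per_sample_loss = policy_loss(batch)                  # shape (B,)
#   loss = (rabc_weights(scores, cfg.rabc_kappa, cfg.rabc_epsilon) * per_sample_loss).mean()
```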
@@ -999,10 +999,18 @@ def _copy_data_with_feature_changes(
df[feature_name] = feature_values
else:
feature_slice = values[frame_idx:end_idx]
if len(feature_slice.shape) == 1:
# 1D array - can assign directly
df[feature_name] = feature_slice
elif len(feature_slice.shape) == 2 and feature_slice.shape[1] == 1:
# 2D array with single column - flatten it
df[feature_name] = feature_slice.flatten()
elif len(feature_slice.shape) == 2:
# 2D array with multiple columns (e.g., embeddings) - convert to list of lists
df[feature_name] = feature_slice.tolist()
else:
# Higher dimensional - convert to list
df[feature_name] = [row.tolist() for row in feature_slice]
frame_idx = end_idx
# Write using the same chunk/file structure as source
@@ -0,0 +1,146 @@
# LeRobot Embedding Generation Script
Generate embeddings for LeRobot datasets to make them more lightweight and efficient for training.
## Overview
This script processes v3.0 LeRobot datasets and adds pre-computed embeddings for:
- **Task embeddings**: Language command embeddings using MiniLM
- **Image embeddings**: Frame embeddings using DinoV2
The resulting dataset can be used more efficiently during training by loading pre-computed embeddings instead of running encoders on-the-fly.
## Supported Encoders
### Image Encoders (DinoV2)
DinoV2 is a self-supervised vision transformer that produces high-quality image embeddings:
- **`dinov2_vits14`**: ViT-S/14 (384-dim) - Fastest, smaller model
- **`dinov2_vitb14`**: ViT-B/14 (768-dim) - **Recommended** - Good balance
- **`dinov2_vitl14`**: ViT-L/14 (1024-dim) - Best quality, slower
### Language Encoders (MiniLM)
MiniLM is a lightweight sentence transformer model:
- **`minilm-l6`**: MiniLM-L6-v2 (384-dim) - Faster
- **`minilm-l12`**: MiniLM-L12-v2 (384-dim) - **Recommended** - Better quality
## Usage
### Basic Command
```bash
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \
--repo-id lerobot/utokyo_xarm_bimanual \
--output-repo-id your-username/utokyo_xarm_bimanual_embeddings \
--image-encoder dinov2_vitb14 \
--language-encoder minilm-l12 \
--push-to-hub
```
### Lightweight Version (No Videos)
Removes video files to significantly reduce storage:
```bash
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \
--repo-id lerobot/utokyo_xarm_bimanual \
--output-repo-id your-username/utokyo_xarm_bimanual_lightweight \
--image-encoder dinov2_vitb14 \
--language-encoder minilm-l12 \
--remove-videos \
--push-to-hub
```
## Output
The script adds new features to your dataset:
### New Features
1. **`task_embedding`**: Language embedding for each frame
- Shape: `[384]` (MiniLM)
- One embedding per frame based on its task
2. **`{camera_key}_embedding`**: Image embedding for each camera view
- Shape: `[384]`, `[768]`, or `[1024]` depending on DinoV2 model
- Examples: `observation.images.top_embedding`, `observation.images.wrist_embedding`
### Using Embeddings in Training
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Load dataset with embeddings
dataset = LeRobotDataset("your-username/utokyo_xarm_bimanual_embeddings")
# Access embeddings
item = dataset[0]
task_emb = item["task_embedding"] # Shape: [384]
img_emb = item["observation.images.top_embedding"] # Shape: [768]
# Use in your policy
# Instead of running encoders during training, use pre-computed embeddings
```
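As a toy illustration (not a LeRobot API; `TinyEmbeddingPolicy` and `action_dim=7` are made up), a small policy head can consume these vectors directly, skipping the image and text encoders at training time:

```python
import torch
import torch.nn as nn

class TinyEmbeddingPolicy(nn.Module):
    """Toy policy head that consumes pre-computed embeddings instead of raw images."""

    def __init__(self, task_dim: int = 384, image_dim: int = 768, action_dim: int = 7):
        super().__init__()
        self.head = nn.Sequential(
            nn.Linear(task_dim + image_dim, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim),
        )

    def forward(self, task_emb: torch.Tensor, img_emb: torch.Tensor) -> torch.Tensor:
        # No DinoV2/MiniLM forward pass here: the embeddings come straight from the dataset.
        return self.head(torch.cat([task_emb, img_emb], dim=-1))

# Reusing task_emb and img_emb loaded in the snippet above (MiniLM 384-dim, DinoV2-B 768-dim)
policy = TinyEmbeddingPolicy()
action = policy(task_emb.unsqueeze(0), img_emb.unsqueeze(0))  # batch of 1
```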
## Extending with New Encoders
The script is designed to be easily extensible. To add a new encoder:
### 1. Create Encoder Class
```python
class MyCustomImageEncoder(ImageEncoder):
"""Your custom image encoder."""
def __init__(self, device: str = "cuda"):
super().__init__(device)
# Load your model
self.model = load_my_model()
self.model = self.model.to(self.device)
self.model.eval()
def encode(self, images: list[np.ndarray]) -> np.ndarray:
"""Encode a batch of images."""
# Your encoding logic here
embeddings = []
for img in images:
emb = self.model(img)
embeddings.append(emb)
return np.array(embeddings)
@property
def embedding_dim(self) -> int:
"""Return embedding dimension."""
return 512 # Your embedding dimension
```
### 2. Add to Factory Function
```python
def get_image_encoder(encoder_name: str, device: str = "cuda") -> ImageEncoder:
encoders = {
"dinov2_vits14": lambda: DinoV2Encoder(model_name="dinov2_vits14", device=device),
"dinov2_vitb14": lambda: DinoV2Encoder(model_name="dinov2_vitb14", device=device),
"dinov2_vitl14": lambda: DinoV2Encoder(model_name="dinov2_vitl14", device=device),
# Add your encoder
"my_custom": lambda: MyCustomImageEncoder(device=device),
}
# ... rest of function
```
## Validating Embeddings
After generating embeddings, you can validate them using `validate_embeddings.py`:
```bash
python src/lerobot/datasets/generating_embeddings/validate_embeddings.py \
--original-repo-id lerobot/utokyo_xarm_bimanual \
--embeddings-repo-id pepijn223/utokyo_xarm_bimanual_embeddings \
--image-encoder dinov2_vitb14 \
--language-encoder minilm-l12 \
--num-samples 20
```
@@ -0,0 +1,147 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
import torch
from PIL import Image
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ImageEncoder:
"""Base class for image encoders."""
def __init__(self, device: str = "cuda"):
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
def encode(self, images: list[np.ndarray]) -> np.ndarray:
"""Encode a batch of images."""
raise NotImplementedError
class DinoV2Encoder(ImageEncoder):
"""DinoV2 image encoder.
DinoV2 is a self-supervised vision transformer that produces high-quality image embeddings.
Supports multiple model sizes (ViT-S/14, ViT-B/14, ViT-L/14).
"""
def __init__(self, model_name: str = "dinov2_vitb14", device: str = "cuda", batch_size: int = 32):
super().__init__(device)
self.batch_size = batch_size
self.model_name = model_name
logger.info(f"Loading DinoV2 model: {model_name}")
self.model = torch.hub.load("facebookresearch/dinov2", model_name) # nosec B614
self.model = self.model.to(self.device)
self.model.eval()
# DinoV2 preprocessing
from torchvision import transforms
self.transform = transforms.Compose(
[
transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
def encode(self, images: list[np.ndarray]) -> np.ndarray:
"""Encode a batch of images."""
embeddings = []
with torch.inference_mode():
for i in range(0, len(images), self.batch_size):
batch_images = images[i : i + self.batch_size]
# Convert numpy arrays to PIL Images and apply transforms
pil_images = [Image.fromarray(img.astype(np.uint8)) for img in batch_images]
tensors = torch.stack([self.transform(img) for img in pil_images]).to(self.device)
# Get embeddings
batch_embeddings = self.model(tensors).cpu().numpy()
embeddings.append(batch_embeddings)
return np.concatenate(embeddings, axis=0)
@property
def embedding_dim(self) -> int:
"""Return the embedding dimension based on model size."""
if "vits14" in self.model_name:
return 384 # DinoV2 ViT-S/14
elif "vitb14" in self.model_name:
return 768 # DinoV2 ViT-B/14
elif "vitl14" in self.model_name:
return 1024 # DinoV2 ViT-L/14
else:
return 768 # Default to ViT-B/14
class LanguageEncoder:
"""Base class for language encoders."""
def __init__(self, device: str = "cuda"):
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
def encode(self, texts: list[str]) -> np.ndarray:
"""Encode a batch of texts."""
raise NotImplementedError
class MiniLMEncoder(LanguageEncoder):
"""MiniLM language encoder.
MiniLM is a lightweight sentence transformer model that produces high-quality text embeddings.
Supports L6 and L12 model sizes.
"""
def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L12-v2", device: str = "cuda"):
super().__init__(device)
self.model_name = model_name
logger.info(f"Loading MiniLM model: {model_name}")
from transformers import AutoModel, AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name).to(self.device)
self.model.eval()
def _mean_pooling(self, model_output, attention_mask):
"""Mean pooling to get sentence embeddings."""
token_embeddings = model_output[0]
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
input_mask_expanded.sum(1), min=1e-9
)
def encode(self, texts: list[str]) -> np.ndarray:
"""Encode a batch of texts."""
with torch.inference_mode():
encoded_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
model_output = self.model(**encoded_input)
embeddings = self._mean_pooling(model_output, encoded_input["attention_mask"])
return embeddings.cpu().numpy()
@property
def embedding_dim(self) -> int:
"""Return the embedding dimension."""
return 384 # Both MiniLM-L6 and L12 output 384-dim embeddings
@@ -0,0 +1,329 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generate embeddings for LeRobot datasets to make them more lightweight and efficient.
This script:
1. Loads a v3.0 LeRobot dataset from the hub
2. Computes embeddings for tasks (language commands) and frames (images)
3. Stores embeddings as new features in the dataset
4. Optionally removes video files to reduce size
5. Pushes the converted dataset to the hub
Current supported encoders:
- Image: DinoV2 (dinov2_vits14, dinov2_vitb14, dinov2_vitl14)
- Language: MiniLM (minilm-l6, minilm-l12)
The architecture is extensible - you can add more encoders by:
1. Creating a new encoder class inheriting from ImageEncoder or LanguageEncoder
2. Implementing the encode() method and embedding_dim property
3. Adding it to the get_image_encoder() or get_language_encoder() factory function
Usage example:
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \
--repo-id lerobot/utokyo_xarm_bimanual \
--output-repo-id lerobot/utokyo_xarm_bimanual_embeddings \
--image-encoder dinov2_vitb14 \
--language-encoder minilm-l12 \
--remove-videos \
--push-to-hub
"""
import argparse
import shutil
from pathlib import Path
import numpy as np
import torch
from tqdm import tqdm
from lerobot.datasets.generating_embeddings.encoders import (
DinoV2Encoder,
ImageEncoder,
LanguageEncoder,
MiniLMEncoder,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def get_image_encoder(encoder_name: str, device: str = "cuda") -> ImageEncoder:
"""Factory function to get image encoder.
To add a new encoder:
1. Create a new class inheriting from ImageEncoder
2. Implement encode() and embedding_dim property
3. Add it to the encoders dictionary below
"""
encoders = {
"dinov2_vits14": lambda: DinoV2Encoder(model_name="dinov2_vits14", device=device),
"dinov2_vitb14": lambda: DinoV2Encoder(model_name="dinov2_vitb14", device=device),
"dinov2_vitl14": lambda: DinoV2Encoder(model_name="dinov2_vitl14", device=device),
}
if encoder_name not in encoders:
raise ValueError(f"Unknown image encoder: {encoder_name}. Available options: {list(encoders.keys())}")
return encoders[encoder_name]()
def get_language_encoder(encoder_name: str, device: str = "cuda") -> LanguageEncoder:
"""Factory function to get language encoder.
To add a new encoder:
1. Create a new class inheriting from LanguageEncoder
2. Implement encode() and embedding_dim property
3. Add it to the encoders dictionary below
"""
encoders = {
"minilm-l6": lambda: MiniLMEncoder(
model_name="sentence-transformers/all-MiniLM-L6-v2", device=device
),
"minilm-l12": lambda: MiniLMEncoder(
model_name="sentence-transformers/all-MiniLM-L12-v2", device=device
),
}
if encoder_name not in encoders:
raise ValueError(
f"Unknown language encoder: {encoder_name}. Available options: {list(encoders.keys())}"
)
return encoders[encoder_name]()
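# Illustrative extension sketch (hypothetical, not part of this PR): the three steps from the
# factory docstrings above, shown for an assumed MPNet sentence encoder that reuses the MiniLM
# encoder's mean-pooling logic. The class name, model id, and registry key are placeholders.
class MPNetEncoder(MiniLMEncoder):
    """Hypothetical 768-dim sentence encoder built on the MiniLM encoder's pooling logic."""

    def __init__(self, model_name: str = "sentence-transformers/all-mpnet-base-v2", device: str = "cuda"):
        super().__init__(model_name=model_name, device=device)

    @property
    def embedding_dim(self) -> int:
        return 768


# Step 3 would then register it in get_language_encoder(), e.g.:
#     "mpnet-base": lambda: MPNetEncoder(device=device),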
def generate_embeddings_for_dataset(
repo_id: str,
output_repo_id: str,
image_encoder: ImageEncoder,
language_encoder: LanguageEncoder,
remove_videos: bool = False,
local_dir: Path | None = None,
output_local_dir: Path | None = None,
push_to_hub: bool = False,
):
"""Generate embeddings for a LeRobot dataset.
Args:
repo_id: Source dataset repository ID
output_repo_id: Output dataset repository ID
image_encoder: Image encoder instance
language_encoder: Language encoder instance
remove_videos: Whether to remove video files
local_dir: Local directory for source dataset
output_local_dir: Local directory for output dataset
push_to_hub: Whether to push to hub after conversion
"""
from lerobot.datasets.dataset_tools import modify_features
print(f"Loading dataset: {repo_id}")
dataset = LeRobotDataset(repo_id, root=local_dir, download_videos=True)
print(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
print("Computing task embeddings...")
unique_tasks = dataset.meta.tasks.index.tolist()
task_embeddings = {}
for task in tqdm(unique_tasks, desc="Encoding tasks"):
# Clean up task text
task_clean = task.strip().capitalize().strip(" .,!?-_")
embedding = language_encoder.encode([task_clean])[0]
task_embeddings[task] = embedding
print(f"Computed {len(task_embeddings)} task embeddings")
print("Processing frames and computing embeddings...")
all_task_embeddings = []
all_images_dict = {cam_key: [] for cam_key in dataset.meta.camera_keys}
for frame_idx in tqdm(range(dataset.num_frames), desc="Processing frames"):
item = dataset.hf_dataset[frame_idx]
ep_idx = item["episode_index"].item()
task = dataset.meta.tasks.iloc[item["task_index"].item()].name
task_emb = task_embeddings[task]
all_task_embeddings.append(task_emb)
for cam_key in dataset.meta.camera_keys:
if cam_key in dataset.meta.video_keys:
current_ts = item["timestamp"].item()
video_frames = dataset._query_videos({cam_key: [current_ts]}, ep_idx)
img = video_frames[cam_key]
if isinstance(img, torch.Tensor):
if img.ndim == 4:
img = img[0] # (T, C, H, W) -> (C, H, W)
elif img.ndim != 3:
raise ValueError(f"Unexpected video frame shape {img.shape} for camera {cam_key}")
img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
else:
img_np = np.array(img)
else:
img = item[cam_key]
if isinstance(img, torch.Tensor):
if img.ndim == 3:
img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
else:
raise ValueError(f"Unexpected image shape {img.shape} for camera {cam_key}")
else:
img_np = np.array(img)
all_images_dict[cam_key].append(img_np)
print("Computing image embeddings...")
image_embeddings_dict = {}
for cam_key, images in all_images_dict.items():
print(f" {cam_key}: {len(images)} images")
embeddings = image_encoder.encode(images)
image_embeddings_dict[cam_key] = embeddings
all_task_embeddings = np.array(all_task_embeddings)
for cam_key in dataset.meta.camera_keys:
image_embeddings_dict[cam_key] = np.array(image_embeddings_dict[cam_key])
img_emb_dim = image_encoder.embedding_dim
lang_emb_dim = language_encoder.embedding_dim
add_features_dict = {
"task_embedding": (
all_task_embeddings,
{"dtype": "float32", "shape": [lang_emb_dim], "names": None},
),
}
for cam_key in dataset.meta.camera_keys:
add_features_dict[f"{cam_key}_embedding"] = (
image_embeddings_dict[cam_key],
{"dtype": "float32", "shape": [img_emb_dim], "names": None},
)
print("Adding embeddings to dataset...")
remove_features_list = None
if remove_videos:
remove_features_list = dataset.meta.video_keys
output_dataset = modify_features(
dataset=dataset,
add_features=add_features_dict,
remove_features=remove_features_list,
output_dir=output_local_dir,
repo_id=output_repo_id,
)
if remove_videos:
print("Removing video files...")
videos_dir = output_dataset.root / "videos"
if videos_dir.exists():
shutil.rmtree(videos_dir)
print(f"Saved to: {output_dataset.root}")
if push_to_hub:
print(f"Pushing to hub: {output_repo_id}")
output_dataset.push_to_hub(push_videos=not remove_videos)
print("Done!")
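# Programmatic usage sketch (an assumption; the supported entry point in this PR is the CLI in
# main() below). Repo ids are placeholders.
def _example_programmatic_run() -> None:
    image_encoder = get_image_encoder("dinov2_vitb14", device="cuda")
    language_encoder = get_language_encoder("minilm-l12", device="cuda")
    generate_embeddings_for_dataset(
        repo_id="lerobot/utokyo_xarm_bimanual",
        output_repo_id="your-username/utokyo_xarm_bimanual_embeddings",
        image_encoder=image_encoder,
        language_encoder=language_encoder,
        remove_videos=False,
        push_to_hub=False,
    )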
def main():
parser = argparse.ArgumentParser(
description="Generate embeddings for LeRobot datasets",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Basic usage with default encoders (DinoV2 ViT-B/14 + MiniLM-L12)
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \\
--repo-id lerobot/utokyo_xarm_bimanual \\
--output-repo-id your-username/utokyo_xarm_bimanual_embeddings \\
--image-encoder dinov2_vitb14 \\
--language-encoder minilm-l12 \\
--push-to-hub
# Generate embeddings and remove videos
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \\
--repo-id lerobot/utokyo_xarm_bimanual \\
--output-repo-id your-username/utokyo_xarm_bimanual_lightweight \\
--image-encoder dinov2_vitb14 \\
--language-encoder minilm-l12 \\
--remove-videos \\
--push-to-hub
Available image encoders:
- dinov2_vits14: DinoV2 ViT-S/14 (384-dim, faster)
- dinov2_vitb14: DinoV2 ViT-B/14 (768-dim, recommended)
- dinov2_vitl14: DinoV2 ViT-L/14 (1024-dim, best quality)
Available language encoders:
- minilm-l6: MiniLM-L6-v2 (384-dim, faster)
- minilm-l12: MiniLM-L12-v2 (384-dim, recommended)
""",
)
parser.add_argument("--repo-id", type=str, required=True, help="Source dataset repository ID")
parser.add_argument("--output-repo-id", type=str, required=True, help="Output dataset repository ID")
parser.add_argument(
"--image-encoder",
type=str,
default="dinov2_vitb14",
help="Image encoder to use (default: dinov2_vitb14)",
)
parser.add_argument(
"--language-encoder",
type=str,
default="minilm-l12",
help="Language encoder to use (default: minilm-l12)",
)
parser.add_argument(
"--remove-videos",
action="store_true",
help="Remove video files after generating embeddings",
)
parser.add_argument("--local-dir", type=str, default=None, help="Local directory for source dataset")
parser.add_argument(
"--output-local-dir", type=str, default=None, help="Local directory for output dataset"
)
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Push the converted dataset to the hub",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device to use for encoding (default: cuda)",
)
args = parser.parse_args()
# Load encoders
image_encoder = get_image_encoder(args.image_encoder, device=args.device)
language_encoder = get_language_encoder(args.language_encoder, device=args.device)
# Generate embeddings
generate_embeddings_for_dataset(
repo_id=args.repo_id,
output_repo_id=args.output_repo_id,
image_encoder=image_encoder,
language_encoder=language_encoder,
remove_videos=args.remove_videos,
local_dir=Path(args.local_dir) if args.local_dir else None,
output_local_dir=Path(args.output_local_dir) if args.output_local_dir else None,
push_to_hub=args.push_to_hub,
)
if __name__ == "__main__":
main()
@@ -0,0 +1,222 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Validate pre-computed embeddings against on-the-fly computed embeddings.
Usage:
python src/lerobot/datasets/generating_embeddings/validate_embeddings.py \
--original-repo-id lerobot/utokyo_xarm_bimanual \
--embeddings-repo-id <your_username>/utokyo_xarm_bimanual_embeddings \
--image-encoder dinov2_vitb14 \
--language-encoder minilm-l12 \
--num-samples 10
"""
import argparse
import numpy as np
import torch
from tqdm import tqdm
from lerobot.datasets.generating_embeddings.encoders import ImageEncoder, LanguageEncoder
from lerobot.datasets.generating_embeddings.generate_embeddings import (
get_image_encoder,
get_language_encoder,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
"""Compute cosine similarity between two vectors."""
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
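# Quick sanity check (illustrative): identical vectors score 1.0, orthogonal vectors 0.0.
#   cosine_similarity(np.array([1.0, 0.0]), np.array([1.0, 0.0]))  -> 1.0
#   cosine_similarity(np.array([1.0, 0.0]), np.array([0.0, 1.0]))  -> 0.0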
def validate_embeddings(
original_repo_id: str,
embeddings_repo_id: str,
image_encoder: ImageEncoder,
language_encoder: LanguageEncoder,
num_samples: int = 10,
device: str = "cuda",
):
"""Validate pre-computed embeddings against on-the-fly embeddings.
Args:
original_repo_id: Original dataset repository ID
embeddings_repo_id: Dataset with pre-computed embeddings repository ID
image_encoder: Image encoder instance
language_encoder: Language encoder instance
num_samples: Number of samples to validate
device: Device to use for encoding
"""
# Load both datasets
print("Loading datasets...")
original_dataset = LeRobotDataset(original_repo_id, download_videos=True)
embeddings_dataset = LeRobotDataset(embeddings_repo_id, download_videos=False)
# Verify both datasets have the same number of frames
assert original_dataset.num_frames == embeddings_dataset.num_frames, (
f"Frame count mismatch: original={original_dataset.num_frames}, "
f"embeddings={embeddings_dataset.num_frames}"
)
camera_keys = original_dataset.meta.camera_keys
# Check embedding features exist
expected_features = ["task_embedding"] + [f"{cam}_embedding" for cam in camera_keys]
for feat in expected_features:
if feat not in embeddings_dataset.features:
raise ValueError(f"Embedding feature not found: {feat}")
# Select random sample indices
sample_indices = np.random.choice(
original_dataset.num_frames, size=min(num_samples, original_dataset.num_frames), replace=False
)
print(f"Validating {len(sample_indices)} samples...")
# Track statistics
task_similarities = []
image_similarities = {cam: [] for cam in camera_keys}
for idx in tqdm(sample_indices, desc="Validating"):
idx = int(idx)
embeddings_item = embeddings_dataset[idx]
precomputed_task_emb = embeddings_item["task_embedding"].numpy()
precomputed_image_embs = {cam: embeddings_item[f"{cam}_embedding"].numpy() for cam in camera_keys}
original_item = original_dataset[idx]
# Get task and compute embedding
task = original_item["task"]
# Clean up task text (same as in generate_embeddings.py)
task_clean = task.strip().capitalize().strip(" .,!?-_")
onthefly_task_emb = language_encoder.encode([task_clean])[0]
# Get images and compute embeddings
onthefly_image_embs = {}
for cam in camera_keys:
img = original_item[cam]
# Convert to numpy if needed
if isinstance(img, torch.Tensor):
if img.ndim == 3: # (C, H, W)
img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
else:
raise ValueError(f"Unexpected image shape: {img.shape}")
else:
img_np = np.array(img)
onthefly_image_embs[cam] = image_encoder.encode([img_np])[0]
# Task embedding comparison
task_sim = cosine_similarity(precomputed_task_emb, onthefly_task_emb)
task_similarities.append(task_sim)
# Image embedding comparison
for cam in camera_keys:
img_sim = cosine_similarity(precomputed_image_embs[cam], onthefly_image_embs[cam])
image_similarities[cam].append(img_sim)
# Results
print("\nResults:")
task_sim_threshold = 0.99
img_sim_threshold = 0.99
task_mean_sim = np.mean(task_similarities)
task_pass = task_mean_sim >= task_sim_threshold
print(f"  Task: {task_mean_sim:.4f} {'✓' if task_pass else '✗'}")
for cam in camera_keys:
cam_mean_sim = np.mean(image_similarities[cam])
cam_pass = cam_mean_sim >= img_sim_threshold
print(f"  {cam}: {cam_mean_sim:.4f} {'✓' if cam_pass else '✗'}")
image_pass = all(np.mean(image_similarities[cam]) >= img_sim_threshold for cam in camera_keys)
print()
if task_pass and image_pass:
print("✓ PASSED")
else:
print("✗ FAILED")
def main():
parser = argparse.ArgumentParser(
description="Validate and compare pre-computed embeddings with on-the-fly embeddings",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Example:
python src/lerobot/datasets/generating_embeddings/validate_embeddings.py \\
--original-repo-id lerobot/utokyo_xarm_bimanual \\
--embeddings-repo-id lerobot/utokyo_xarm_bimanual_embeddings \\
--image-encoder dinov2_vitb14 \\
--language-encoder minilm-l12 \\
--num-samples 20
""",
)
parser.add_argument("--original-repo-id", type=str, required=True, help="Original dataset repository ID")
parser.add_argument(
"--embeddings-repo-id",
type=str,
required=True,
help="Dataset with pre-computed embeddings repository ID",
)
parser.add_argument(
"--image-encoder",
type=str,
default="dinov2_vitb14",
help="Image encoder to use (default: dinov2_vitb14)",
)
parser.add_argument(
"--language-encoder",
type=str,
default="minilm-l12",
help="Language encoder to use (default: minilm-l12)",
)
parser.add_argument(
"--num-samples",
type=int,
default=10,
help="Number of samples to validate (default: 10)",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device to use for encoding (default: cuda)",
)
args = parser.parse_args()
# Load encoders
image_encoder = get_image_encoder(args.image_encoder, device=args.device)
language_encoder = get_language_encoder(args.language_encoder, device=args.device)
# Validate embeddings
validate_embeddings(
original_repo_id=args.original_repo_id,
embeddings_repo_id=args.embeddings_repo_id,
image_encoder=image_encoder,
language_encoder=language_encoder,
num_samples=args.num_samples,
device=args.device,
)
if __name__ == "__main__":
main()
@@ -0,0 +1,151 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SARM Temporal Sampler for reward model training.
Samples frames uniformly from episodes for SARM's 9-frame symmetric pattern:
- 1 initial frame + 4 frames before + current + 3 frames after
Boundary handling: clamp to first/last frame when indices go out of bounds.
This enables truly uniform sampling across entire episodes.
"""
import logging
from typing import Iterator, Optional
import numpy as np
import torch
from torch.utils.data import Sampler
import random
class SARMTemporalSampler(Sampler):
"""
Temporal sampler for SARM reward model training with symmetric/bidirectional sampling.
SARM uses 9 frames per sample:
- Frame 0: Initial frame of the episode (always frame 0)
- Frames 1-8: Symmetric context around current frame
Pattern: [t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
Boundary handling:
- Early frames: backward indices clamp to 0 (e.g., [0,0,0,5,35,65,95,125])
- Late frames: forward indices clamp to last frame (e.g., [850,880,910,940,970,1000,1000,1000])
This enables truly uniform sampling across entire episodes.
Args:
dataset_from_index: Start indices of episodes (global dataset indices)
dataset_to_index: End indices of episodes (global dataset indices)
frame_gap: Gap between consecutive frames (default: 30 = 1 second at 30fps)
shuffle: Whether to shuffle sampling order
seed: Random seed for reproducibility
samples_per_epoch: Number of samples per epoch (default: 6400)
min_episode_length: Minimum episode length to include (default: 1)
"""
def __init__(
self,
dataset_from_index: np.ndarray,
dataset_to_index: np.ndarray,
frame_gap: int = 30,
shuffle: bool = True,
seed: Optional[int] = None,
samples_per_epoch: int = 6400,
min_episode_length: int = 1,
):
self.dataset_from_index = np.array(dataset_from_index)
self.dataset_to_index = np.array(dataset_to_index)
self.frame_gap = frame_gap
self.shuffle = shuffle
self.samples_per_epoch = samples_per_epoch
self.min_episode_length = min_episode_length
if seed is not None:
self.seed = seed
random.seed(seed)
np.random.seed(seed)
self.generator = torch.Generator().manual_seed(seed)
else:
self.generator = torch.Generator()
# Compute valid episodes and sampling positions (ALL frames for uniform sampling)
self._compute_valid_positions()
logging.info(
f"SARMTemporalSampler: {len(self.valid_episodes)} valid episodes, "
f"{len(self.all_valid_positions)} positions (uniform sampling), "
f"{self.samples_per_epoch} samples per epoch, "
f"frame_gap={frame_gap}, symmetric bidirectional pattern"
)
def _compute_valid_positions(self):
"""Compute valid episodes and ALL sampling positions for uniform sampling.
With symmetric bidirectional sampling, we can sample from ANY frame:
- Early frames: backward indices clamp to first frame
- Late frames: forward indices clamp to last frame
"""
self.valid_episodes = []
self.all_valid_positions = []
for ep_idx in range(len(self.dataset_from_index)):
ep_start = self.dataset_from_index[ep_idx]
ep_end = self.dataset_to_index[ep_idx]
episode_length = ep_end - ep_start
# Include all episodes with at least min_episode_length frames
if episode_length >= self.min_episode_length:
self.valid_episodes.append((ep_idx, ep_start, ep_end))
# Include ALL positions in the episode (truly uniform sampling)
for pos in range(ep_start, ep_end):
self.all_valid_positions.append(pos)
self.valid_episodes = np.array(self.valid_episodes)
self.all_valid_positions = np.array(self.all_valid_positions)
if len(self.all_valid_positions) == 0:
raise ValueError(
f"No valid sampling positions found! "
f"Check that episodes have at least {self.min_episode_length} frames."
)
def __len__(self) -> int:
return self.samples_per_epoch
def __iter__(self) -> Iterator[int]:
"""
Yields global dataset indices for uniform sampling across episodes.
Each yielded index represents the "current frame" position.
The dataset's observation_delta_indices then handles loading:
- Frame 0: Episode initial frame (via large negative delta clamping)
- Frames 1-8: Symmetric context around current frame (with boundary clamping)
For early frames: backward indices clamp to first frame (progress ~0%)
For late frames: forward indices clamp to last frame (progress ~100%)
"""
if self.shuffle:
# Randomly sample from all valid positions
for _ in range(self.samples_per_epoch):
idx = np.random.randint(0, len(self.all_valid_positions))
yield int(self.all_valid_positions[idx])
else:
# Sequential sampling with wrap-around
for i in range(self.samples_per_epoch):
idx = i % len(self.all_valid_positions)
yield int(self.all_valid_positions[idx])
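# Wiring sketch (assumptions: a LeRobotDataset whose episode metadata exposes
# "dataset_from_index"/"dataset_to_index", as used elsewhere in this PR; variable names are
# placeholders). Shows how the sampler would feed a DataLoader; each yielded index is a
# "current frame" that the dataset's observation_delta_indices expands into the 9-frame window.
def _build_sarm_dataloader(dataset, batch_size: int = 64):
    from torch.utils.data import DataLoader

    episodes = dataset.meta.episodes
    from_idx = np.array([episodes[i]["dataset_from_index"] for i in range(len(episodes))])
    to_idx = np.array([episodes[i]["dataset_to_index"] for i in range(len(episodes))])
    sampler = SARMTemporalSampler(
        dataset_from_index=from_idx,
        dataset_to_index=to_idx,
        frame_gap=30,
        shuffle=True,
        seed=42,
        samples_per_epoch=6400,
    )
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)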
@@ -35,6 +35,7 @@ from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
@@ -103,6 +104,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
return SmolVLAPolicy
elif name == "sarm":
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
return SARMRewardModel
elif name == "groot":
from lerobot.policies.groot.modeling_groot import GrootPolicy
@@ -322,6 +327,14 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, SARMConfig):
from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors
processors = make_sarm_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
dataset_meta=kwargs.get("dataset_meta"),
)
elif isinstance(policy_cfg, GrootConfig):
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
@@ -405,6 +418,13 @@ def make_policy(
if not cfg.input_features:
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
kwargs["config"] = cfg
# Pass dataset_stats to the policy if available (needed for some policies like SARM)
if ds_meta is not None and hasattr(ds_meta, 'stats'):
kwargs["dataset_stats"] = ds_meta.stats
if ds_meta is not None:
kwargs["dataset_meta"] = ds_meta
if cfg.pretrained_path:
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
@@ -0,0 +1,34 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.modeling_sarm import (
SARMRewardModel,
SARMTransformer,
)
from lerobot.policies.sarm.processor_sarm import (
SARMEncodingProcessorStep,
make_sarm_pre_post_processors,
)
__all__ = [
"SARMConfig",
"SARMRewardModel",
"SARMTransformer",
"SARMEncodingProcessorStep",
"make_sarm_pre_post_processors",
]
@@ -0,0 +1,186 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import PolicyFeature, FeatureType, NormalizationMode
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("sarm")
@dataclass
class SARMConfig(PreTrainedConfig):
"""Configuration class for SARM (Stage-Aware Reward Modeling)"""
# CLIP params
image_dim: int = 512
text_dim: int = 512
num_frames: int = 9  # 1 initial frame + 8 symmetric context frames (4 before, current, 3 after)
frame_gap: int = 30  # Gap between sampled context frames (30 frames = 1 second at 30 fps)
# Architecture params
hidden_dim: int = 768
num_heads: int = 12
num_layers: int = 8
max_state_dim: int = 32
num_stages: int = 5 # Number of task stages (auto-updated from annotations if available)
subtask_names: list | None = None # List of subtask names (auto-populated from annotations)
temporal_proportions: list | None = None # Temporal proportions for each stage (auto-computed from annotations)
max_length: int = num_frames # Maximum video sequence length (matches num_frames)
use_temporal_sampler: bool = True # Always enable temporal sequence loading
# Training params
batch_size: int = 64
clip_batch_size: int = 64 # Batch size for CLIP encoding
dropout: float = 0.1
stage_loss_weight: float = 1.0 # Weight for stage classification loss when using subtask annotations
pretrained_model_path: str | None = None
device: str | None = None
# Processor settings
image_key: str = "observation.images.top" # Key for image used from the dataset
# State key in the dataset (for normalization)
state_key: str = "observation.state"
# Populated by the processor (video_features, state_features, text_features)
input_features: dict = field(default_factory=lambda: {})
# Output features
output_features: dict = field(default_factory=lambda: {
"stage": PolicyFeature(shape=(9, 5), type=FeatureType.REWARD),
"progress": PolicyFeature(shape=(9, 1), type=FeatureType.REWARD),
})
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"LANGUAGE": NormalizationMode.IDENTITY,
"REWARD": NormalizationMode.IDENTITY,
}
)
def __post_init__(self):
super().__post_init__()
# Add the image_key as VISUAL
if self.image_key:
self.input_features[self.image_key] = PolicyFeature(
shape=(480, 640, 3),
type=FeatureType.VISUAL
)
# Add state_key as STATE
self.input_features[self.state_key] = PolicyFeature(
shape=(self.max_state_dim,), # Single frame state, temporal sampling handles sequence
type=FeatureType.STATE
)
# Update output features with actual dimensions
self.output_features["stage"] = PolicyFeature(
shape=(self.num_frames, self.num_stages),
type=FeatureType.REWARD
)
self.output_features["progress"] = PolicyFeature(
shape=(self.num_frames, 1),
type=FeatureType.REWARD
)
# Validate configuration
if self.hidden_dim % self.num_heads != 0:
raise ValueError(
f"hidden_dim ({self.hidden_dim}) must be divisible by num_heads ({self.num_heads})"
)
if self.max_length != self.num_frames:
raise ValueError(
f"max_length ({self.max_length}) must equal num_frames ({self.num_frames})"
)
if self.num_stages < 2:
raise ValueError(f"num_stages must be at least 2, got {self.num_stages}")
def get_optimizer_preset(self) -> AdamWConfig:
"""Get default optimizer configuration for SARM training."""
return AdamWConfig(
lr=5e-5,
weight_decay=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
"""Get default learning rate scheduler configuration."""
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=5e-5,
decay_lr=5e-6,
num_warmup_steps=500,
num_decay_steps=50000,
)
def validate_features(self) -> None:
"""Validate input and output features."""
pass
@property
def observation_delta_indices(self) -> list[int]:
"""Load frames for SARM temporal sampling with SYMMETRIC/BIDIRECTIONAL pattern.
The model uses 9 frames with symmetric context around current frame:
- Frame 0: Initial frame of the episode (clamped via large negative delta)
- Frames 1-8: Symmetric context: 4 before + current + 3 after
Pattern: [initial, t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
Boundary handling (done by dataset loader):
- Early frames: backward indices clamp to 0 (first frame)
- Late frames: forward indices clamp to episode end (last frame)
This enables truly uniform sampling across entire episodes.
Returns:
9 delta indices: [-1_000_000, -4*gap, -3*gap, -2*gap, -gap, 0, gap, 2*gap, 3*gap]
"""
initial_frame_delta = -1_000_000
# Symmetric pattern: 4 frames before, current (0), 3 frames after = 8 context frames
symmetric_deltas = [
-4 * self.frame_gap,
-3 * self.frame_gap,
-2 * self.frame_gap,
-1 * self.frame_gap,
0, # current frame
1 * self.frame_gap,
2 * self.frame_gap,
3 * self.frame_gap,
]
return [initial_frame_delta] + symmetric_deltas
@property
def action_delta_indices(self) -> None:
"""SARM is a reward model, not an action policy."""
return None
@property
def reward_delta_indices(self) -> None:
"""SARM doesn't use delta rewards."""
return None
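# Worked example (follows directly from observation_delta_indices above): with the default
# frame_gap=30 the property returns
#   [-1_000_000, -120, -90, -60, -30, 0, 30, 60, 90]
# The huge negative delta is clamped by the dataset loader to the episode's first frame; the
# remaining entries cover roughly 4..1 seconds before the current frame and 1..3 seconds after
# it at 30 fps.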
@@ -0,0 +1,650 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List, Union, Optional
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
from torch import Tensor
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.sarm_utils import compute_cumulative_progress_batch, pad_state_to_max_dim
from lerobot.policies.pretrained import PreTrainedPolicy
class SARMTransformer(nn.Module):
"""
SARM Transformer model for stage-aware reward prediction.
This model has a dual-head architecture:
1. Stage estimator: Predicts the high-level task stage (classification)
2. Subtask estimator: Predicts fine-grained progress within the stage (regression)
"""
def __init__(
self,
video_dim: int = 512,
text_dim: int = 512,
max_state_dim: int = 32,
hidden_dim: int = 768,
num_heads: int = 12,
num_layers: int = 8,
num_stages: int = 5,
max_length: int = 9,
dropout: float = 0.1,
temporal_proportions: list[float] | None = None
):
super().__init__()
self.hidden_dim = hidden_dim
self.max_length = max_length
self.num_stages = num_stages
self.max_state_dim = max_state_dim
if temporal_proportions is None:
raise ValueError(
"temporal_proportions is required for SARM. "
"Provide subtask annotations in your dataset or set temporal_proportions in config."
)
# ᾱ_k: proportion for each stage
alpha = torch.tensor(temporal_proportions, dtype=torch.float32)
# P_k: cumulative proportion up to stage k (P_0 = 0)
cumulative = torch.zeros(num_stages + 1, dtype=torch.float32)
cumulative[1:] = torch.cumsum(alpha, dim=0)
self.register_buffer('alpha', alpha)
self.register_buffer('cumulative_prior', cumulative)
self.video_proj = nn.Linear(video_dim, hidden_dim)
self.text_proj = nn.Linear(text_dim, hidden_dim)
self.state_proj = nn.Linear(max_state_dim, hidden_dim)
# Position embedding only for the first frame
self.first_pos_embed = nn.Parameter(torch.randn(1, hidden_dim))
encoder_layer = nn.TransformerEncoderLayer(
d_model=hidden_dim,
nhead=num_heads,
dim_feedforward=hidden_dim * 4,
dropout=dropout,
batch_first=True
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
# Stage estimator head (classification)
self.stage_head = nn.Sequential(
nn.Linear(hidden_dim, 512),
nn.LayerNorm(512),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(512, num_stages)
)
# Subtask estimator head (regression)
self.stage_embedding = nn.Embedding(num_stages, hidden_dim // 4)
subtask_input_dim = hidden_dim + hidden_dim // 4
self.subtask_head = nn.Sequential(
nn.Linear(subtask_input_dim, 512),
nn.LayerNorm(512),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(512, 1),
nn.Sigmoid()
)
# Attention mask
self.register_buffer("attention_mask", None, persistent=False)
def _get_attention_mask(self, seq_length: int, device: torch.device) -> torch.Tensor:
"""Generate or retrieve cached causal attention mask."""
if self.attention_mask is None or self.attention_mask.shape[0] != seq_length:
# Create causal mask
mask = nn.Transformer.generate_square_subsequent_mask(seq_length, device=device)
self.attention_mask = mask
return self.attention_mask
def forward(
self,
video_frames: torch.Tensor,
text_embed: torch.Tensor,
state_features: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward pass through the SARM transformer.
Args:
video_frames: Video frame embeddings (batch_size, seq_len, video_dim)
text_embed: Text embeddings (batch_size, text_dim)
state_features: Joint state features (batch_size, seq_len, state_dim)
Returns:
Tuple of:
- Stage logits for each frame (batch_size, seq_len, num_stages)
- Stage probabilities (batch_size, seq_len, num_stages)
- Progress predictions for each frame (batch_size, seq_len, 1)
"""
# Project inputs to common dimension
video_embed = self.video_proj(video_frames) # [batch_size, seq_len, hidden_dim]
text_embed = self.text_proj(text_embed).unsqueeze(1) # [batch_size, 1, hidden_dim]
# Pad state features to max_state_dim before projection
state_features_padded = pad_state_to_max_dim(state_features, self.max_state_dim)
state_embed = self.state_proj(state_features_padded) # [batch_size, seq_len, hidden_dim]
# Fuse video and state features
video_embed = video_embed + state_embed
# Add positional embedding to first video frame
video_embed[:, 0] += self.first_pos_embed
# Combine sequence: [text, video_frames]
sequence = torch.cat([text_embed, video_embed], dim=1)
# Get causal attention mask
seq_length = sequence.shape[1]
attention_mask = self._get_attention_mask(seq_length, sequence.device)
# Pass through transformer with causal masking
transformed = self.transformer(sequence, mask=attention_mask, is_causal=True)
# Get frame features
frame_features = transformed[:, 1:] # [batch_size, seq_len, hidden_dim]
# Stage estimation
stage_logits = self.stage_head(frame_features) # [batch_size, seq_len, num_stages]
stage_probs = F.softmax(stage_logits, dim=-1) # [batch_size, seq_len, num_stages]
# Get predicted stage indices
stage_indices = torch.argmax(stage_probs, dim=-1) # [batch_size, seq_len]
# Get stage embeddings for conditioning
stage_embeds = self.stage_embedding(stage_indices)
# Concatenate frame features with stage embeddings
conditioned_features = torch.cat([frame_features, stage_embeds], dim=-1)
# Subtask progress estimation (conditioned on stage)
# τ̂ = within-subtask progress (0-1)
tau_preds = self.subtask_head(conditioned_features) # [batch_size, seq_len, 1]
# Convert τ̂ to cumulative progress ŷ using Paper Formula (2):
# ŷ = P_{k-1} + ᾱ_k × τ̂
progress_preds = compute_cumulative_progress_batch(
tau_preds, stage_indices, self.alpha, self.cumulative_prior
)
return stage_logits, stage_probs, progress_preds
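# Reference sketch (an assumption about sarm_utils.compute_cumulative_progress_batch, which is
# not shown in this diff): maps within-stage progress τ̂ to cumulative progress ŷ per the paper's
# Formula (2), ŷ = P_{k-1} + ᾱ_k · τ̂, using the buffers registered in SARMTransformer above.
def _cumulative_progress_reference(
    tau: torch.Tensor,               # (B, T, 1) within-stage progress in [0, 1]
    stage_indices: torch.Tensor,     # (B, T) predicted stage index k per frame
    alpha: torch.Tensor,             # (K,) per-stage temporal proportions ᾱ_k
    cumulative_prior: torch.Tensor,  # (K + 1,) cumulative proportions with P_0 = 0
) -> torch.Tensor:
    alpha_k = alpha[stage_indices].unsqueeze(-1)             # ᾱ_k broadcast to (B, T, 1)
    prior_k = cumulative_prior[stage_indices].unsqueeze(-1)  # P_{k-1} broadcast to (B, T, 1)
    return prior_k + alpha_k * tau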
class SARMRewardModel(PreTrainedPolicy):
"""
SARM Reward Model for stage-aware task completion rewards.
Per SARM paper (Appendix A.4): "We employ a frozen clip-vit-base-patch32 encoder
to process both RGB image sequences and task descriptions."
This model combines:
- CLIP for encoding video frames AND text descriptions
- SARMTransformer for predicting task stage and progress
- Optional RA-BC (Reward-Aligned Behavior Cloning) for weighted training
"""
name = "sarm"
config_class = SARMConfig
def __init__(self, config: SARMConfig, dataset_stats: dict | None = None, dataset_meta=None):
super().__init__(config, dataset_stats)
config.validate_features()
self.config = config
self.dataset_stats = dataset_stats
self.device = torch.device(config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu")
# Load temporal proportions from dataset
if config.temporal_proportions is None and dataset_meta is not None:
self._load_temporal_proportions(dataset_meta)
logging.info("Loading CLIP encoder")
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
self.clip_model.to(self.device)
self.clip_model.eval()
self.sarm_transformer = SARMTransformer(
video_dim=config.image_dim,
text_dim=config.text_dim,
max_state_dim=config.max_state_dim,
hidden_dim=config.hidden_dim,
num_heads=config.num_heads,
num_layers=config.num_layers,
num_stages=config.num_stages,
max_length=config.max_length,
dropout=config.dropout,
temporal_proportions=config.temporal_proportions
)
self.sarm_transformer.to(self.device)
logging.info(f"SARM initialized on {self.device}")
def _load_temporal_proportions(self, dataset_meta) -> None:
"""
Load pre-computed temporal proportions from dataset metadata JSON file.
The temporal proportions are computed during dataset annotation using SARM Paper Formula (1):
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
"""
import json
proportions_path = dataset_meta.root / "meta" / "temporal_proportions.json"
if not proportions_path.exists():
raise ValueError(
f"Temporal proportions not found at {proportions_path}. "
"Run the subtask annotation tool first to compute and save temporal proportions."
)
with open(proportions_path, "r") as f:
temporal_proportions_dict = json.load(f)
# Sort subtask names for consistent ordering
subtask_names = sorted(temporal_proportions_dict.keys())
self.config.num_stages = len(subtask_names)
self.config.subtask_names = subtask_names
self.config.temporal_proportions = [temporal_proportions_dict[name] for name in subtask_names]
logging.info(f"Loaded {len(subtask_names)} subtasks: {subtask_names}")
logging.info(f"Temporal proportions: {temporal_proportions_dict}")
def to(self, device):
"""Override to method to ensure all components move together."""
super().to(device)
self.device = device if isinstance(device, torch.device) else torch.device(device)
self.clip_model.to(device)
self.sarm_transformer.to(device)
return self
@torch.no_grad()
def encode_images(self, images: np.ndarray) -> np.ndarray:
"""
Encode video frames using CLIP.
Args:
images: Video frames with shape (num_videos, num_frames, H, W, C) in uint8.
Can also be (num_frames, H, W, C) for a single video.
Returns:
Encoded image features (num_videos, num_frames, 512) or (num_frames, 512).
"""
# Handle single video case
single_video = False
if len(images.shape) == 4:
images = images[np.newaxis, ...]
single_video = True
assert len(images.shape) == 5, f"Expected 5D input (num_videos, num_frames, H, W, C), got {images.shape}"
all_embeddings = []
for video in images:
video_embeddings = []
# Convert frames to PIL images for CLIP processor
frames = []
for frame in video:
if frame.shape[0] == 3: # Channel first
frame = frame.transpose(1, 2, 0)
if frame.dtype != np.uint8:
frame = (frame * 255).astype(np.uint8) if frame.max() <= 1.0 else frame.astype(np.uint8)
frames.append(Image.fromarray(frame))
# Batch process frames with CLIP
for i in range(0, len(frames), self.config.clip_batch_size):
batch = frames[i:i + self.config.clip_batch_size]
inputs = self.clip_processor(images=batch, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Get image embeddings from CLIP
embeddings = self.clip_model.get_image_features(**inputs).detach().cpu()
# Handle single frame case
if embeddings.dim() == 1:
embeddings = embeddings.unsqueeze(0)
video_embeddings.append(embeddings)
video_embeddings = torch.cat(video_embeddings)
all_embeddings.append(video_embeddings)
result = torch.stack(all_embeddings).numpy()
if single_video:
result = result[0]
return result
@torch.no_grad()
def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
"""
Encode text using CLIP text encoder (per SARM paper A.4).
Args:
text: Text string or list of text strings.
Returns:
Encoded text features (batch_size, 512) or (512,) for single text.
"""
if isinstance(text, str):
text = [text]
single_text = True
else:
single_text = False
# Use CLIP's tokenizer directly (avoids image processor validation issues)
tokenizer = self.clip_processor.tokenizer
# Process in batches
all_embeddings = []
for i in range(0, len(text), self.config.batch_size):
batch_text = text[i:i + self.config.batch_size]
inputs = tokenizer(batch_text, return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
text_embeddings = self.clip_model.get_text_features(**inputs)
all_embeddings.append(text_embeddings.cpu())
result = torch.cat(all_embeddings).numpy()
if single_text:
result = result[0]
return result
@torch.no_grad()
def calculate_rewards(
self,
text_embeddings: Union[np.ndarray, torch.Tensor],
video_embeddings: Union[np.ndarray, torch.Tensor],
state_features: Optional[Union[np.ndarray, torch.Tensor]] = None,
return_all_frames: bool = False,
return_stages: bool = False
) -> Union[np.ndarray, tuple]:
"""
Calculate rewards for given text, video, and state representations.
Args:
text_embeddings: Encoded text representations (batch_size, 512)
video_embeddings: Encoded video representations (batch_size, num_frames, 512)
state_features: Joint state features (batch_size, num_frames, state_dim)
return_all_frames: If True, return rewards for all frames
return_stages: If True, also return stage predictions
Returns:
If return_stages=False:
Reward values (batch_size,) or (batch_size, num_frames)
If return_stages=True:
Tuple of (rewards, stage_probs)
"""
if isinstance(text_embeddings, np.ndarray):
text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32)
if isinstance(video_embeddings, np.ndarray):
video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
if state_features is not None and isinstance(state_features, np.ndarray):
state_features = torch.tensor(state_features, dtype=torch.float32)
# Handle single sample case
if text_embeddings.dim() == 1:
text_embeddings = text_embeddings.unsqueeze(0)
video_embeddings = video_embeddings.unsqueeze(0)
if state_features is not None:
state_features = state_features.unsqueeze(0)
single_sample = True
else:
single_sample = False
# Process in batches
all_rewards = []
all_stage_probs = []
for i in range(0, len(video_embeddings), self.config.batch_size):
batch_texts = text_embeddings[i:i + self.config.batch_size].to(self.device)
batch_videos = video_embeddings[i:i + self.config.batch_size].to(self.device)
batch_states = None
if state_features is not None:
batch_states = state_features[i:i + self.config.batch_size].to(self.device)
# Get predictions
stage_logits, stage_probs, progress_preds = self.sarm_transformer(
batch_videos.float(), batch_texts.float(), batch_states.float() if batch_states is not None else None
)
if return_all_frames:
all_rewards.append(progress_preds.squeeze(-1).cpu())
else:
# Return only last frame reward
all_rewards.append(progress_preds[:, -1, 0].cpu())
if return_stages:
all_stage_probs.append(stage_probs.cpu())
rewards = torch.cat(all_rewards).numpy()
if single_sample:
rewards = rewards[0]
if return_stages:
stage_probs = torch.cat(all_stage_probs).numpy()
if single_sample:
stage_probs = stage_probs[0]
return rewards, stage_probs
return rewards
def train(self, mode: bool = True):
"""Override train so the frozen CLIP encoder stays in eval mode during training."""
super().train(mode)
self.clip_model.eval()
self.sarm_transformer.train(mode)
return self
def eval(self):
"""Override eval; delegates to train(False) while keeping the CLIP encoder frozen."""
return self.train(False)
def parameters(self, recurse: bool = True):
"""Return only the trainable SARM transformer parameters (the CLIP encoder is frozen)."""
return self.sarm_transformer.parameters(recurse=recurse)
def get_optim_params(self):
"""Override to return optimizer parameters (only SARM transformer, not CLIP encoder)."""
return self.parameters()
def reset(self):
"""Required by PreTrainedPolicy but not used for reward models."""
pass
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Required by PreTrainedPolicy but not used for reward models."""
raise NotImplementedError("SARM model does not predict action chunks")
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Required by PreTrainedPolicy but not used for SARM."""
raise NotImplementedError("SARM model does not select actions")
def _apply_temporal_augmentation(
self,
video: torch.Tensor,
progress: torch.Tensor,
state: torch.Tensor | None,
max_length: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
"""Apply temporal augmentation by appending reversed frames (SARM paper A.4).
This helps the model learn to handle non-monotonic progress (failures, recoveries).
Appends 1-4 reversed frames to simulate going backwards in task progress.
"""
num_reverse = random.randint(1, min(4, max_length - 1))
# Reverse and take frames (skip first which is last of original)
reversed_video = video.flip(0)[1:num_reverse + 1]
reversed_progress = progress.flip(0)[1:num_reverse + 1]
# Concatenate and trim
video = torch.cat([video, reversed_video], dim=0)[:max_length]
progress = torch.cat([progress, reversed_progress], dim=0)[:max_length]
if state is not None:
reversed_state = state.flip(0)[1:num_reverse + 1]
state = torch.cat([state, reversed_state], dim=0)[:max_length]
return video, progress, state
def _ensure_sequence_length(self, tensor: torch.Tensor, target_len: int) -> torch.Tensor:
"""Pad or trim tensor to target length."""
current_len = tensor.shape[0]
if current_len == target_len:
return tensor
if current_len < target_len:
padding = target_len - current_len
return torch.cat([tensor, tensor[-1:].expand(padding, *tensor.shape[1:])])
return tensor[:target_len]
def forward(self, batch):
"""
Forward pass for SARM reward model training.
Uses annotation-based progress targets following SARM paper Eq. 2:
y_t = P_{k-1} + ᾱ_k × τ_t
where:
- τ_t = (t - s_k) / (e_k - s_k) is within-subtask normalized time
- P_{k-1} is the cumulative prior (sum of previous subtask proportions)
- ᾱ_k is the temporal proportion for subtask k
Args:
batch: Dictionary with 'observation' containing:
- 'video_features': (B, T, 512) pre-encoded video features
- 'text_features': (B, 512) pre-encoded text features (CLIP)
- 'state_features': (B, T, state_dim) joint state features
- 'stage_labels': (B, T) stage labels from annotations
- 'progress_targets': (B, T, 1) progress targets from annotations
Returns:
Tuple of (total_loss, output_dict with loss components)
"""
observation = batch.get('observation', batch)
# Extract required features
video_features = observation['video_features'].to(self.device)
text_features = observation['text_features'].to(self.device)
state_features = observation.get('state_features')
if state_features is not None:
state_features = state_features.to(self.device)
batch_size = video_features.shape[0]
max_length = self.config.num_frames
# Ensure 3D video features (B, T, D)
if video_features.dim() == 2:
video_features = video_features.unsqueeze(1).expand(-1, max_length, -1)
if state_features is not None and state_features.dim() == 2:
state_features = state_features.unsqueeze(1).expand(-1, max_length, -1)
# Get annotation-based progress targets (required for SARM paper formula)
progress_from_annotations = observation.get('progress_targets')
if progress_from_annotations is None:
raise ValueError("progress_targets from annotations is required for SARM training")
progress_from_annotations = progress_from_annotations.to(self.device)
if progress_from_annotations.dim() == 2:
progress_from_annotations = progress_from_annotations.unsqueeze(-1)
if progress_from_annotations.dim() == 3 and progress_from_annotations.shape[0] == 1:
progress_from_annotations = progress_from_annotations.expand(batch_size, -1, -1)
# Process each sample: apply temporal REWIND augmentation
processed_videos = []
processed_states = []
progress_targets = []
for i in range(batch_size):
video = video_features[i]
state = state_features[i] if state_features is not None else None
progress = progress_from_annotations[i].squeeze(-1) # (T,)
# Apply temporal REWIND augmentation with 50% probability: appends up to 4 reversed frames to simulate failures/recoveries
if random.random() < 0.5:
video, progress, state = self._apply_temporal_augmentation(video, progress, state, max_length)
# Ensure correct sequence length
video = self._ensure_sequence_length(video, max_length)
progress = self._ensure_sequence_length(progress.unsqueeze(-1), max_length).squeeze(-1)
if state is not None:
state = self._ensure_sequence_length(state, max_length)
processed_videos.append(video)
progress_targets.append(progress)
if state is not None:
processed_states.append(state)
# Stack into batches
processed_videos = torch.stack(processed_videos)
progress_targets = torch.stack(progress_targets).unsqueeze(-1) # (B, T, 1)
processed_states = torch.stack(processed_states) if processed_states else None
# Get model predictions
stage_logits, stage_probs, progress_preds = self.sarm_transformer(
processed_videos, text_features, processed_states
)
# Compute progress loss (MSE)
progress_loss = F.mse_loss(progress_preds, progress_targets)
output_dict = {'progress_loss': progress_loss.item()}
total_loss = progress_loss
# Compute stage loss (cross-entropy)
stage_labels = observation.get('stage_labels')
if stage_labels is None:
raise ValueError("stage_labels from annotations is required for SARM training")
stage_labels = stage_labels.to(self.device)
if stage_labels.dim() == 1:
stage_labels = stage_labels.unsqueeze(0).expand(batch_size, -1)
stage_loss = compute_stage_loss(stage_logits, stage_labels)
total_loss = total_loss + self.config.stage_loss_weight * stage_loss
output_dict['stage_loss'] = stage_loss.item()
# Misaligned loss: 20% probability
if random.random() < 0.2:
shuffle_idx = torch.randperm(batch_size, device=self.device)
_, _, misaligned_preds = self.sarm_transformer(
processed_videos, text_features[shuffle_idx], processed_states
)
misaligned_loss = F.mse_loss(misaligned_preds, torch.zeros_like(misaligned_preds))
total_loss = total_loss + misaligned_loss
output_dict['misaligned_loss'] = misaligned_loss.item()
output_dict['total_loss'] = total_loss.item()
return total_loss, output_dict
def compute_stage_loss(stage_logits: torch.Tensor, target_stages: torch.Tensor) -> torch.Tensor:
"""Cross-entropy over per-frame stage predictions, flattened across batch and time."""
_, _, num_stages = stage_logits.shape
stage_logits_flat = stage_logits.reshape(-1, num_stages)
target_stages_flat = target_stages.reshape(-1)
loss = F.cross_entropy(stage_logits_flat, target_stages_flat)
return loss
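# Reference sketch (illustrative; in this PR the proportions are computed by the annotation
# tooling and read from meta/temporal_proportions.json): the paper's Formula (1),
# ᾱ_k = (1/M) · Σ_i (L_{i,k} / T_i), i.e. each subtask's average share of episode length over
# M annotated episodes. Argument names are placeholders.
def _temporal_proportions_reference(
    subtask_lengths: list[dict[str, int]],  # per episode: subtask name -> length L_{i,k} in frames
    episode_lengths: list[int],             # per episode: total length T_i in frames
) -> dict[str, float]:
    num_episodes = len(episode_lengths)
    proportions: dict[str, float] = {}
    for lengths, total in zip(subtask_lengths, episode_lengths):
        for name, length in lengths.items():
            proportions[name] = proportions.get(name, 0.0) + (length / total) / num_episodes
    return proportions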
@@ -0,0 +1,644 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import numpy as np
import torch
from PIL import Image
import pandas as pd
from transformers import CLIPModel, CLIPProcessor
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.sarm_utils import compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim
from lerobot.processor import (
ProcessorStep,
PolicyProcessorPipeline,
PolicyAction,
DeviceProcessorStep,
AddBatchDimensionProcessorStep,
NormalizerProcessorStep,
)
from lerobot.processor.converters import (
policy_action_to_transition,
transition_to_policy_action,
from_tensor_to_numpy,
)
from lerobot.processor.pipeline import PipelineFeatureType
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.configs.types import PolicyFeature, FeatureType
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
class SARMEncodingProcessorStep(ProcessorStep):
"""ProcessorStep that encodes images and text with CLIP."""
def __init__(
self,
config: SARMConfig,
image_key: str | None = None,
dataset_meta = None,
dataset_stats: dict | None = None,
):
super().__init__()
self.config = config
self.image_key = image_key or config.image_key
self.dataset_meta = dataset_meta
self.dataset_stats = dataset_stats
self.temporal_proportions = {name: prop for name, prop in zip(self.config.subtask_names, self.config.temporal_proportions)}
self.subtask_names = self.config.subtask_names
self.device = torch.device(
self.config.device if self.config.device
else "cuda" if torch.cuda.is_available() else "cpu"
)
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
self.clip_model.to(self.device)
self.clip_model.eval()
def _find_episode_for_frame(self, frame_idx: int) -> int:
"""Find the episode index for a given frame index."""
for ep_idx in range(len(self.dataset_meta.episodes)):
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
if ep_start <= frame_idx < ep_end:
return ep_idx
return 0
def _get_episode_indices(self, frame_indices: np.ndarray, episode_index) -> np.ndarray:
"""Get episode indices for each frame index."""
if episode_index is None:
return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])
episode_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(episode_index)))
# If single episode but multiple frames, compute episode for each frame
if len(episode_indices) == 1 and len(frame_indices) > 1:
return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])
return episode_indices
def _compute_absolute_indices(self, frame_idx: int, ep_start: int, ep_end: int, num_frames: int) -> torch.Tensor:
"""Compute absolute frame indices for symmetric bidirectional pattern.
Pattern: [ep_start, t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
Boundary handling:
- Backward indices clamp to ep_start (first frame)
- Forward indices clamp to ep_end - 1 (last frame)
"""
indices = []
indices.append(ep_start) # Initial frame is always episode start
# Symmetric context: 4 before, current, 3 after
num_before = 4
num_after = 3
last_valid_frame = ep_end - 1
# Frames before current (clamp to first frame)
for i in range(num_before, 0, -1):
idx = max(ep_start, frame_idx - i * self.config.frame_gap)
indices.append(idx)
# Current frame
indices.append(frame_idx)
# Frames after current (clamp to last frame)
for i in range(1, num_after + 1):
idx = min(last_valid_frame, frame_idx + i * self.config.frame_gap)
indices.append(idx)
return torch.tensor(indices)
def _compute_episode_metadata(
self,
frame_indices: np.ndarray,
episode_indices: np.ndarray,
num_frames: int,
) -> tuple[list | torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute episode metadata for all samples.
Returns:
Tuple of (absolute_frame_indices, remaining_lengths, episode_lengths)
"""
absolute_indices_list = []
remaining_lengths = []
episode_lengths = []
for ep_idx, frame_idx in zip(episode_indices.tolist(), frame_indices.tolist()):
ep_idx, frame_idx = int(ep_idx), int(frame_idx)
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
episode_lengths.append(ep_end - ep_start)
abs_indices = self._compute_absolute_indices(frame_idx, ep_start, ep_end, num_frames)
absolute_indices_list.append(abs_indices)
remaining_lengths.append(ep_end - abs_indices[0].item())
return absolute_indices_list, torch.tensor(remaining_lengths), torch.tensor(episode_lengths)
def _compute_stage_and_progress_for_frame(
self,
current_frame: int,
subtask_names: list,
subtask_start_frames: list,
subtask_end_frames: list,
transition_smoothing_frames: int = 15,
) -> tuple[int, float, dict[int, float] | None]:
"""Compute stage index, cumulative progress, and soft stage labels for a single frame.
Implements SARM Paper Formula (2):
y_t = P_{k-1} + ᾱ_k × τ_t
where:
- τ_t = (t - s_k) / (e_k - s_k) is within-subtask progress
- P_{k-1} is cumulative prior (sum of previous subtask proportions)
- ᾱ_k is the temporal proportion for subtask k
Additionally computes soft stage labels near transitions to mitigate discrete jumps
in the stage classifier. Near stage boundaries, labels are blended between adjacent
stages to encourage smoother predictions.
Args:
current_frame: Frame index relative to episode start
subtask_names: List of subtask names for this episode
subtask_start_frames: List of subtask start frames
subtask_end_frames: List of subtask end frames
transition_smoothing_frames: Number of frames over which to smooth labels near transitions
Returns:
Tuple of (stage_idx, cumulative_progress, soft_stage_labels)
- stage_idx: Hard stage index (for compatibility)
- cumulative_progress: Progress value in [0, 1]
- soft_stage_labels: Dict mapping stage_idx -> probability, or None if not near transition
"""
# Get temporal proportions as list for compute_cumulative_progress
temporal_proportions_list = [
self.temporal_proportions.get(name, 0.0) for name in self.subtask_names
]
num_stages = len(self.subtask_names)
# Find which subtask this frame belongs to
for j, (name, start_frame, end_frame) in enumerate(zip(subtask_names, subtask_start_frames, subtask_end_frames)):
if current_frame >= start_frame and current_frame <= end_frame:
# Found the subtask, get its global index
stage_idx = self.subtask_names.index(name) if name in self.subtask_names else 0
# Compute τ_t using utility function (Paper Formula 2)
tau = compute_tau(current_frame, start_frame, end_frame)
# Compute cumulative progress using utility function (Paper Formula 2)
cumulative_progress = compute_cumulative_progress_batch(
tau, stage_idx, temporal_proportions_list
)
# Compute soft stage labels near transitions
soft_stage_labels = None
frames_from_start = current_frame - start_frame
frames_to_end = end_frame - current_frame
if frames_from_start < transition_smoothing_frames and j > 0:
# Near start of stage - blend with previous stage. The current stage's weight ramps
# from 0.5 at the boundary up to 1.0 at the end of the smoothing window, so the soft
# labels stay continuous across the transition and argmax still matches the hard label.
blend = 0.5 * (1.0 + frames_from_start / transition_smoothing_frames)
prev_name = subtask_names[j - 1]
prev_stage_idx = self.subtask_names.index(prev_name) if prev_name in self.subtask_names else max(0, stage_idx - 1)
soft_stage_labels = {prev_stage_idx: 1.0 - blend, stage_idx: blend}
elif frames_to_end < transition_smoothing_frames and j < len(subtask_names) - 1:
# Near end of stage - blend with next stage. The current stage's weight ramps down
# from 1.0 to 0.5 at the boundary, mirroring the start-of-stage case above.
blend = 0.5 * (1.0 + frames_to_end / transition_smoothing_frames)
next_name = subtask_names[j + 1]
next_stage_idx = self.subtask_names.index(next_name) if next_name in self.subtask_names else min(num_stages - 1, stage_idx + 1)
soft_stage_labels = {stage_idx: blend, next_stage_idx: 1.0 - blend}
return stage_idx, cumulative_progress, soft_stage_labels
# No matching subtask found
if current_frame < subtask_start_frames[0]:
return 0, 0.0, None
elif current_frame > subtask_end_frames[-1]:
return len(self.subtask_names) - 1, 1.0, None
else:
# Between subtasks - use previous subtask's end state (tau = 1.0)
for j in range(len(subtask_names) - 1):
if current_frame > subtask_end_frames[j] and current_frame < subtask_start_frames[j + 1]:
name = subtask_names[j]
stage_idx = self.subtask_names.index(name) if name in self.subtask_names else j
# Completed subtask, so tau = 1.0
cumulative_progress = compute_cumulative_progress_batch(
1.0, stage_idx, temporal_proportions_list
)
return stage_idx, cumulative_progress, None
return 0, 0.0, None
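# Worked example (hypothetical annotations): with global proportions reach=0.3, grasp=0.2,
# lift=0.5 and an episode annotated reach=[0, 60], grasp=[60, 100], lift=[100, 200],
# frame 80 lies mid-grasp: tau = (80 - 60) / (100 - 60) = 0.5 and progress = 0.3 + 0.2 * 0.5 = 0.4,
# with no soft labels because it is more than 15 frames from both boundaries. Frame 105,
# 5 frames into lift, gets soft labels of roughly {grasp: 0.33, lift: 0.67} from the
# half-ramp blending above.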
def _compute_labels_for_sample(
self,
frame_idx: int,
ep_idx: int,
seq_len: int,
episodes_df: pd.DataFrame,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] | tuple[None, None, None]:
"""Compute stage labels, progress targets, and soft stage labels for symmetric bidirectional pattern.
Pattern: [initial, t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
Boundary handling:
- Before episode start: clamp to frame 0 (progress ~0%)
- After episode end: clamp to last frame (progress ~100%)
Soft stage labels are computed near stage transitions to mitigate discrete jumps.
Args:
frame_idx: The frame index for this sample
ep_idx: The episode index
seq_len: Number of frames in the sequence
episodes_df: DataFrame with episode metadata
Returns:
Tuple of (stage_labels, progress_targets, soft_stage_labels):
- stage_labels: (T,) hard stage indices
- progress_targets: (T, 1) progress values
- soft_stage_labels: (T, num_stages) soft probability labels, or None if no transitions nearby
"""
# Check if episode has valid annotations
if ep_idx >= len(episodes_df):
return None, None, None
subtask_names = episodes_df.loc[ep_idx, 'subtask_names']
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
return None, None, None
subtask_start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames']
subtask_end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames']
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
ep_length = ep_end - ep_start
last_valid_frame = ep_length - 1
num_stages = len(self.subtask_names)
# Generate labels for each frame in the sequence
stage_labels = []
progress_targets = []
soft_labels_list = [] # List of soft label dicts (or None)
has_any_soft_labels = False
# Symmetric pattern: initial + 4 before + current + 3 after = 9 frames
num_before = 4
num_after = 3
for i in range(seq_len):
if i == 0:
# Position 0: Initial frame of the episode
current_frame = 0 # Relative to episode start
elif i <= num_before:
# Positions 1-4: frames before current (with clamping to first frame)
offset = -(num_before - i + 1) * self.config.frame_gap
current_frame = max(0, frame_idx + offset - ep_start)
elif i == num_before + 1:
# Position 5: current frame
current_frame = frame_idx - ep_start
else:
# Positions 6-8: frames after current (with clamping to last frame)
offset = (i - num_before - 1) * self.config.frame_gap
current_frame = min(last_valid_frame, frame_idx + offset - ep_start)
stage_idx, cumulative_progress, soft_stage_labels = self._compute_stage_and_progress_for_frame(
current_frame, subtask_names, subtask_start_frames, subtask_end_frames
)
stage_labels.append(stage_idx)
progress_targets.append(cumulative_progress)
soft_labels_list.append(soft_stage_labels)
if soft_stage_labels is not None:
has_any_soft_labels = True
stage_labels = torch.tensor(stage_labels, dtype=torch.long)
progress_targets = torch.tensor(progress_targets, dtype=torch.float32).unsqueeze(-1)
# Convert soft labels to tensor if any exist
soft_stage_labels_tensor = None
if has_any_soft_labels:
soft_stage_labels_tensor = torch.zeros(seq_len, num_stages, dtype=torch.float32)
for i, soft_dict in enumerate(soft_labels_list):
if soft_dict is not None:
for stage_idx, prob in soft_dict.items():
soft_stage_labels_tensor[i, stage_idx] = prob
else:
# Use hard one-hot label
soft_stage_labels_tensor[i, stage_labels[i]] = 1.0
return stage_labels, progress_targets, soft_stage_labels_tensor
def _generate_stage_and_progress_labels(self, frame_index, episode_index, video_features):
"""Generate stage labels, progress targets, and soft stage labels from subtask annotations.
Args:
frame_index: Current frame index or tensor of indices
episode_index: Episode index or tensor of indices
video_features: Video features tensor to determine sequence length
Returns:
Tuple of (stage_labels, progress_targets, soft_stage_labels) or (None, None, None) if no annotations.
- stage_labels: (B, T) hard stage indices
- progress_targets: (B, T, 1) progress values
- soft_stage_labels: (B, T, num_stages) soft probability labels, or None
"""
if self.temporal_proportions is None or episode_index is None:
return None, None, None
# Normalize inputs to numpy arrays
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
episode_indices = self._get_episode_indices(frame_indices, episode_index)
# Determine sequence length
if video_features is not None and video_features.dim() >= 2:
seq_len = video_features.shape[1]
else:
seq_len = 1
episodes_df = self.dataset_meta.episodes.to_pandas()
num_stages = len(self.subtask_names)
all_stage_labels = []
all_progress_targets = []
all_soft_stage_labels = []
has_any_soft_labels = False
for ep_idx, frame_idx in zip(episode_indices.tolist(), frame_indices.tolist()):
stage_labels, progress_targets, soft_labels = self._compute_labels_for_sample(
int(frame_idx), int(ep_idx), seq_len, episodes_df
)
if stage_labels is None:
all_stage_labels.append(torch.zeros(seq_len, dtype=torch.long))
all_progress_targets.append(torch.zeros(seq_len, 1, dtype=torch.float32))
all_soft_stage_labels.append(None)
else:
all_stage_labels.append(stage_labels)
all_progress_targets.append(progress_targets)
all_soft_stage_labels.append(soft_labels)
if soft_labels is not None:
has_any_soft_labels = True
stacked_stage_labels = torch.stack(all_stage_labels, dim=0)
stacked_progress_targets = torch.stack(all_progress_targets, dim=0)
# Stack soft labels if any exist
stacked_soft_labels = None
if has_any_soft_labels:
soft_labels_tensors = []
for i, soft_labels in enumerate(all_soft_stage_labels):
if soft_labels is not None:
soft_labels_tensors.append(soft_labels)
else:
# Create one-hot from hard labels
one_hot = torch.zeros(seq_len, num_stages, dtype=torch.float32)
for t in range(seq_len):
one_hot[t, all_stage_labels[i][t]] = 1.0
soft_labels_tensors.append(one_hot)
stacked_soft_labels = torch.stack(soft_labels_tensors, dim=0)
return stacked_stage_labels, stacked_progress_targets, stacked_soft_labels
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Encode images, text, and normalize states in the transition."""
new_transition = transition.copy() if hasattr(transition, 'copy') else dict(transition)
observation = new_transition.get(TransitionKey.OBSERVATION)
image = observation.get(self.image_key)
if isinstance(image, torch.Tensor):
image = image.cpu().numpy()
video_features = self._encode_images_batch(image)
observation['video_features'] = video_features
# Extract state and pad to max_state_dim (already normalized by NormalizerProcessorStep)
state_key = self.config.state_key
state_data = observation.get(state_key)
if isinstance(state_data, torch.Tensor):
state_tensor = state_data.float()
else:
state_tensor = torch.tensor(state_data, dtype=torch.float32)
observation['state_features'] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
# Get task description from dataset (complementary_data["task"])
task = comp_data.get('task')
if isinstance(task, list):
# If batch, take first task (assuming same task for all items in batch)
task = task[0] if task else ""
# Encode text with CLIP
batch_size = video_features.shape[0]
observation['text_features'] = self._encode_text_clip(task, batch_size)
frame_index = comp_data.get('index')
episode_index = comp_data.get('episode_index')
if frame_index is None:
raise ValueError("Frame index ('index') not found in COMPLEMENTARY_DATA")
if episode_index is None:
raise ValueError("Episode index ('episode_index') not found in COMPLEMENTARY_DATA")
# Compute episode metadata if dataset_meta is available
if self.dataset_meta is not None:
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
episode_indices = self._get_episode_indices(frame_indices, episode_index)
# Determine number of frames from video features
if video_features.dim() >= 2:
num_frames = video_features.shape[1]
else:
num_frames = 1
abs_indices, remaining, ep_lengths = self._compute_episode_metadata(
frame_indices, episode_indices, num_frames
)
observation['absolute_frame_indices'] = abs_indices
observation['remaining_length'] = remaining
observation['episode_length'] = ep_lengths
# Generate stage labels, progress targets, and soft stage labels from subtask annotations
if self.temporal_proportions is not None and self.dataset_meta is not None:
stage_labels, progress_targets, soft_stage_labels = self._generate_stage_and_progress_labels(
frame_index, episode_index, video_features
)
if stage_labels is not None:
observation['stage_labels'] = stage_labels
observation['progress_targets'] = progress_targets
if soft_stage_labels is not None:
observation['soft_stage_labels'] = soft_stage_labels
new_transition[TransitionKey.OBSERVATION] = observation
return new_transition
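# Summary of what this step adds to the observation (shapes for a batch of B samples with
# T = 9 context frames): video_features (B, T, 512), text_features (B, 512), state_features
# with its last dimension padded or truncated to max_state_dim, absolute_frame_indices
# (list of B tensors of length 9), remaining_length (B,), episode_length (B,), and, when
# subtask annotations are available, stage_labels (B, T), progress_targets (B, T, 1) plus
# soft_stage_labels (B, T, num_stages) for frames near a transition.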
@torch.no_grad()
def _encode_images_batch(self, images: np.ndarray) -> torch.Tensor:
"""Encode a batch of images using CLIP.
Args:
images: Batched images with shape: (B, T, C, H, W)
Returns:
Encoded feature vectors with shape (B, T, 512)
"""
batch_size, seq_length = images.shape[0], images.shape[1]
images = images.reshape(batch_size * seq_length, *images.shape[2:])
# Convert to list of PIL images
num_frames = images.shape[0]
images_list = []
for i in range(num_frames):
img = images[i]
if img.shape[0] in [1, 3]: # Channel first (C, H, W)
img = img.transpose(1, 2, 0)
# Handle single channel
if img.shape[-1] == 1:
img = np.repeat(img, 3, axis=-1)
# Convert to uint8
if img.dtype != np.uint8:
img = (img * 255).astype(np.uint8) if img.max() <= 1.0 else img.astype(np.uint8)
images_list.append(Image.fromarray(img))
# Encode each batch
all_embeddings = []
for i in range(0, num_frames, self.config.clip_batch_size):
batch_imgs = images_list[i:i + self.config.clip_batch_size]
# Process with CLIP
inputs = self.clip_processor(images=batch_imgs, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Get image embeddings
embeddings = self.clip_model.get_image_features(**inputs).detach().cpu()
# Handle single frame case
if embeddings.dim() == 1:
embeddings = embeddings.unsqueeze(0)
all_embeddings.append(embeddings)
# Concatenate all embeddings
all_embeddings = torch.cat(all_embeddings) # (B*T, 512)
# Reshape back
all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1) # (B, T, 512)
return all_embeddings
@torch.no_grad()
def _encode_text_clip(self, text: str, batch_size: int) -> torch.Tensor:
"""Encode text using CLIP text encoder (per SARM paper A.4).
Args:
text: Task description text to encode
batch_size: Batch size to replicate for
Returns:
Encoded text features with shape (B, 512)
"""
# Use CLIP's tokenizer directly for text
tokenizer = self.clip_processor.tokenizer
inputs = tokenizer([text], return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Get text features from CLIP
text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu()
# Replicate for batch (B, 512)
text_embedding = text_embedding.expand(batch_size, -1)
return text_embedding
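# Note: self.clip_model and self.clip_processor are initialized elsewhere in this class and
# are expected to be a Hugging Face transformers CLIP pair producing 512-d embeddings.
# A compatible setup (a sketch of an assumption, not necessarily the exact code used here):
#   from transformers import CLIPModel, CLIPProcessor
#   clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval()
#   clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")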
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Add encoded features to the observation features."""
features[PipelineFeatureType.OBSERVATION]['video_features'] = PolicyFeature(
type=FeatureType.VISUAL,
shape=(self.config.num_frames, self.config.image_dim)
)
features[PipelineFeatureType.OBSERVATION]['text_features'] = PolicyFeature(
type=FeatureType.LANGUAGE,
shape=(self.config.text_dim,)
)
features[PipelineFeatureType.OBSERVATION]['state_features'] = PolicyFeature(
type=FeatureType.STATE,
shape=(self.config.num_frames, self.config.max_state_dim)
)
return features
def make_sarm_pre_post_processors(
config: SARMConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
dataset_meta = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Create pre-processor and post-processor pipelines for SARM.
The pre-processing pipeline:
1. Adds batch dimension
2. Normalizes observation.state using NormalizerProcessorStep (MEAN_STD)
3. SARMEncodingProcessorStep:
- Encodes images with CLIP
- Pads states to max_state_dim
- Encodes text with CLIP
4. Moves data to device
The post-processing pipeline:
1. Moves data to CPU
"""
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=[
AddBatchDimensionProcessorStep(),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
SARMEncodingProcessorStep(
config=config,
dataset_meta=dataset_meta,
dataset_stats=dataset_stats
),
DeviceProcessorStep(device=config.device),
],
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=[DeviceProcessorStep(device="cpu")],
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
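A minimal usage sketch for the factory above (the variable names cfg, dataset and raw_batch, and calling the pipeline directly on a dataloader batch, are assumptions that mirror how other LeRobot policies wire their processor pipelines):

preprocessor, postprocessor = make_sarm_pre_post_processors(
    config=cfg,
    dataset_stats=dataset.meta.stats,
    dataset_meta=dataset.meta,
)
batch = preprocessor(raw_batch)  # adds video/text/state features, frame indices and progress labels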
+257
@@ -0,0 +1,257 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
import torch.nn.functional as F
from typing import Sequence, Any
from pydantic import BaseModel, Field
# Pydantic Models for SARM-style Annotation
class Timestamp(BaseModel):
"""Timestamp in MM:SS or SS format"""
start: str = Field(description="Start timestamp (MM:SS or just seconds)")
end: str = Field(description="End timestamp (MM:SS or just seconds)")
class Subtask(BaseModel):
"""Individual subtask/stage - must use EXACT names from provided list"""
name: str = Field(description="Subtask name - MUST match one from the predefined list exactly")
timestamps: Timestamp
class SubtaskAnnotation(BaseModel):
"""Complete annotation for a robot manipulation episode"""
subtasks: list[Subtask] = Field(description="List of all subtasks in temporal order")
def compute_temporal_proportions(annotations: dict[int, Any], fps: int = 30) -> dict[str, float]:
"""
Compute dataset-level temporal proportions (priors) for each subtask.
Implements SARM Paper Formula (1):
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
where:
- M is the number of trajectories (episodes)
- L_{i,k} is the duration of subtask k in trajectory i
- T_i is the total duration of trajectory i
This averages the PROPORTION of each subtask within each trajectory,
giving equal weight to all trajectories regardless of their absolute length.
Args:
annotations: Dict mapping episode index to SubtaskAnnotation object.
Each annotation has a .subtasks list where each subtask has:
- .name: subtask name
- .timestamps.start: start time as "MM:SS" string
- .timestamps.end: end time as "MM:SS" string
fps: Frames per second (unused, kept for API compatibility)
Returns:
Dict mapping subtask name to its temporal proportion (ᾱ_k).
Proportions are normalized to sum to 1.0.
"""
subtask_proportions: dict[str, list[float]] = {}
for annotation in annotations.values():
total_duration = 0
durations: dict[str, int] = {}
for subtask in annotation.subtasks:
start_parts = subtask.timestamps.start.split(":")
end_parts = subtask.timestamps.end.split(":")
start_seconds = int(start_parts[0]) * 60 + int(start_parts[1]) if len(start_parts) == 2 else int(start_parts[0])
end_seconds = int(end_parts[0]) * 60 + int(end_parts[1]) if len(end_parts) == 2 else int(end_parts[0])
duration = end_seconds - start_seconds
# Accumulate in case the same subtask name occurs more than once in an episode
durations[subtask.name] = durations.get(subtask.name, 0) + duration
total_duration += duration
# Calculate L_{i,k} / T_i for each subtask in this trajectory
if total_duration > 0:
for name, duration in durations.items():
if name not in subtask_proportions:
subtask_proportions[name] = []
subtask_proportions[name].append(duration / total_duration)
if not subtask_proportions:
return {}
# Average across trajectories: (1/M) × Σ_i (L_{i,k} / T_i)
avg_proportions = {
name: sum(props) / len(props)
for name, props in subtask_proportions.items()
}
# Normalize to ensure sum = 1
total = sum(avg_proportions.values())
if total > 0:
avg_proportions = {name: prop / total for name, prop in avg_proportions.items()}
return avg_proportions
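# Worked example: two episodes, one of 100 s (subtask1 for 80 s, subtask2 for 20 s) and one
# of 200 s (subtask1 for 40 s, subtask2 for 160 s). Per-episode proportions are (0.8, 0.2)
# and (0.2, 0.8), so the averaged priors are {subtask1: 0.5, subtask2: 0.5}; a naive average
# over the pooled seconds (120 s vs 180 s out of 300 s) would instead give (0.4, 0.6).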
def compute_tau(
current_frame: int | float,
subtask_start: int | float,
subtask_end: int | float,
) -> float:
"""
Compute within-subtask normalized time τ_t.
Implements part of SARM Paper Formula (2):
τ_t = (t - s_k) / (e_k - s_k) ∈ [0, 1]
where:
- t is the current frame
- s_k is the start frame of subtask k
- e_k is the end frame of subtask k
Args:
current_frame: Current frame index (t)
subtask_start: Start frame of the subtask (s_k)
subtask_end: End frame of the subtask (e_k)
Returns:
Within-subtask progress τ_t ∈ [0, 1]
"""
subtask_duration = subtask_end - subtask_start
if subtask_duration <= 0:
return 1.0
tau = (current_frame - subtask_start) / subtask_duration
return float(np.clip(tau, 0.0, 1.0))
def compute_cumulative_progress_batch(
tau: torch.Tensor | float,
stage_indices: torch.Tensor | int,
alpha: torch.Tensor | Sequence[float],
cumulative_prior: torch.Tensor | None = None,
) -> torch.Tensor | float:
"""
Compute cumulative normalized progress from within-subtask progress.
This function implements the core formula used in SARM for both:
**Formula 2 (Training labels):**
y_t = P_{k-1} + ᾱ_k × τ_t ∈ [0, 1]
Used to compute ground-truth progress labels from subtask annotations.
- τ_t comes from annotated frame position: τ_t = (t - s_k) / (e_k - s_k)
- k is the known subtask from annotations
**Formula 4 (Inference predictions):**
ŷ_{1:N} = P̂_{k-1, 1:N} + ᾱ_{k, 1:N} × τ̂_{1:N} ∈ [0, 1]
Used to convert model outputs to cumulative progress during inference.
- τ̂ comes from the subtask MLP head (conditioned on predicted stage)
- k = Ŝ is the predicted stage from Formula 3: Ŝ = argmax(softmax(Ψ))
The formulas are mathematically identical; only the source of inputs differs:
- Training: τ and k from annotations → ground-truth labels
- Inference: τ̂ and Ŝ from model → predicted progress
where:
- P_{k-1} = Σ_{j=1}^{k-1} ᾱ_j is the cumulative prior (sum of previous proportions)
- ᾱ_k is the temporal proportion for subtask k (from Formula 1)
- τ is within-subtask progress ∈ [0, 1]
This ensures:
- y at start of subtask k = P_{k-1}
- y at end of subtask k = P_k
Supports both scalar and batched tensor inputs:
- Scalar: tau (float), stage_indices (int), alpha (list/sequence)
- Batch: tau (Tensor), stage_indices (Tensor), alpha (Tensor), cumulative_prior (Tensor)
Args:
tau: Within-subtask progress τ ∈ [0, 1].
For training: computed from frame position in annotated subtask.
For inference: predicted by subtask MLP head.
Scalar float or Tensor with shape (..., 1)
stage_indices: Index of current subtask k (0-indexed).
For training: known from annotations.
For inference: predicted via argmax(stage_probs) (Formula 3).
Scalar int or Tensor with shape (...)
alpha: Temporal proportions ᾱ with shape (num_stages,) or Sequence[float].
Computed from dataset annotations using Formula 1.
cumulative_prior: Optional. Cumulative priors P with shape (num_stages + 1,)
where cumulative_prior[k] = P_k = Σ_{j=1}^{k} ᾱ_j.
If None, will be computed from alpha.
Returns:
Cumulative progress y ∈ [0, 1].
Scalar float if inputs are scalar, otherwise Tensor with shape (..., 1)
"""
if not isinstance(tau, torch.Tensor):
# Convert alpha to a plain list before the emptiness check; `not alpha` on a
# multi-element tensor would raise instead of evaluating to a boolean
if isinstance(alpha, torch.Tensor):
alpha_list = alpha.tolist()
else:
alpha_list = list(alpha)
if not alpha_list:
raise ValueError("alpha (temporal_proportions) cannot be empty")
if stage_indices < 0 or stage_indices >= len(alpha_list):
raise ValueError(
f"stage_indices {stage_indices} out of range "
f"for {len(alpha_list)} subtasks"
)
# P_{k-1} = sum of proportions for subtasks 0 to k-1
P_k_minus_1 = sum(alpha_list[:stage_indices])
# ᾱ_k = proportion for current subtask
alpha_k = alpha_list[stage_indices]
# y_t = P_{k-1} + ᾱ_k × τ_t
y_t = P_k_minus_1 + alpha_k * tau
return float(np.clip(y_t, 0.0, 1.0))
if not isinstance(alpha, torch.Tensor):
alpha = torch.tensor(alpha, dtype=torch.float32)
# Compute cumulative_prior if not provided
if cumulative_prior is None:
cumulative_prior = torch.zeros(len(alpha) + 1, dtype=alpha.dtype, device=alpha.device)
cumulative_prior[1:] = torch.cumsum(alpha, dim=0)
# P_{k-1} for each predicted stage
P_k_minus_1 = cumulative_prior[stage_indices]
# ᾱ_k for each predicted stage
alpha_k = alpha[stage_indices]
# ŷ = P_{k-1} + ᾱ_k × τ̂
progress = P_k_minus_1.unsqueeze(-1) + alpha_k.unsqueeze(-1) * tau
return progress
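# Worked example: with alpha = [0.3, 0.5, 0.2], tau = 0.5 and stage_indices = 1 the scalar
# path returns 0.3 + 0.5 * 0.5 = 0.55. The tensor path gives the same values element-wise,
# e.g. tau of shape (B, T, 1) with stage_indices of shape (B, T) returns progress of shape
# (B, T, 1).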
def pad_state_to_max_dim(state: torch.Tensor, max_state_dim: int) -> torch.Tensor:
"""Pad the state tensor's last dimension to max_state_dim with zeros."""
current_dim = state.shape[-1]
if current_dim >= max_state_dim:
return state[..., :max_state_dim] # Truncate if larger
# Pad with zeros on the right
padding = (0, max_state_dim - current_dim) # (left, right) for last dim
return F.pad(state, padding, mode='constant', value=0)
+11 -1
@@ -230,6 +230,10 @@ def validate_visual_features_consistency(
) -> None:
"""
Validates visual feature consistency between a policy config and provided dataset/environment features.
Validation passes if EITHER:
- Policy's expected visuals are a subset of dataset (policy uses some cameras, dataset has more)
- Dataset's provided visuals are a subset of policy (policy declares extras for flexibility)
Args:
cfg (PreTrainedConfig): The model or policy configuration containing input_features and type.
@@ -237,5 +241,11 @@ def validate_visual_features_consistency(
"""
expected_visuals = {k for k, v in cfg.input_features.items() if v.type == FeatureType.VISUAL}
provided_visuals = {k for k, v in features.items() if v.type == FeatureType.VISUAL}
if not provided_visuals.issubset(expected_visuals):
# Accept if either direction is a subset
policy_subset_of_dataset = expected_visuals.issubset(provided_visuals)
dataset_subset_of_policy = provided_visuals.issubset(expected_visuals)
if not (policy_subset_of_dataset or dataset_subset_of_policy):
raise_feature_mismatch_error(provided_visuals, expected_visuals)
+2 -1
@@ -170,8 +170,9 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
task_key = {"task": batch["task"]} if "task" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
return {**pad_keys, **task_key, **index_key, **task_index_key}
return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key}
def create_transition(
+63 -3
@@ -61,6 +61,7 @@ def update_policy(
accelerator: Accelerator,
lr_scheduler=None,
lock=None,
rabc_weight_computer=None,
) -> tuple[MetricsTracker, dict]:
"""
Performs a single training step to update the policy's weights.
@@ -85,10 +86,22 @@ def update_policy(
"""
start_time = time.perf_counter()
policy.train()
# Compute RA-BC weights if enabled
rabc_weights = None
if rabc_weight_computer is not None:
rabc_weights = rabc_weight_computer.compute_batch_weights(batch)
# Let accelerator handle mixed precision
with accelerator.autocast():
loss, output_dict = policy.forward(batch)
# Apply RA-BC weights if enabled
if rabc_weights is not None:
# Scale the already-reduced batch loss by the mean RA-BC weight
loss = loss * rabc_weights.mean()
output_dict['rabc_mean_weight'] = rabc_weights.mean().item()
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
# Use accelerator's backward method
@@ -140,8 +153,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
cfg: A `TrainPipelineConfig` object containing all training configurations.
accelerator: Optional Accelerator instance. If None, one will be created automatically.
"""
cfg.validate()
# Create Accelerator if not provided
# It will automatically detect if running in distributed mode or single-process mode
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
@@ -158,6 +169,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
# When using accelerate, only the main process should log to avoid duplicate outputs
is_main_process = accelerator.is_main_process
cfg.validate()
# Only log on main process
if is_main_process:
logging.info(pformat(cfg.to_dict()))
@@ -215,6 +228,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
# Only provide dataset_stats when not resuming from saved processor state
processor_kwargs["dataset_stats"] = dataset.meta.stats
# For SARM, always provide dataset_meta for progress normalization
if cfg.policy.type == "sarm":
processor_kwargs["dataset_meta"] = dataset.meta
if cfg.policy.pretrained_path is not None:
processor_kwargs["preprocessor_overrides"] = {
@@ -246,6 +263,28 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
if is_main_process:
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
# Load reward model for RA-BC if enabled
rabc_weight_computer = None
if cfg.use_rabc:
logging.info(f"Loading reward model for RA-BC from {cfg.reward_model_path}")
from lerobot.policies.factory import get_policy_class
from lerobot.utils.rabc import RABCWeightComputer
# Detect reward model type from path
# For now, assume SARM if not specified
reward_model_class = get_policy_class("sarm")
reward_model = reward_model_class.from_pretrained(cfg.reward_model_path)
reward_model.to(device)
reward_model.eval()
rabc_weight_computer = RABCWeightComputer(
reward_model=reward_model,
kappa=cfg.rabc_kappa,
epsilon=cfg.rabc_epsilon,
device=device,
)
logging.info("RA-BC weight computer initialized")
step = 0 # number of policy updates (forward + backward + optim)
@@ -280,6 +319,18 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
drop_n_last_frames=cfg.policy.drop_n_last_frames,
shuffle=True,
)
elif cfg.policy.type == "sarm" and getattr(cfg.policy, "use_temporal_sampler", False):
# Use SARM temporal sampler for reward model training
from lerobot.datasets.temporal_sampler import SARMTemporalSampler
shuffle = False
sampler = SARMTemporalSampler(
dataset_from_index=dataset.meta.episodes["dataset_from_index"],
dataset_to_index=dataset.meta.episodes["dataset_to_index"],
frame_gap=getattr(cfg.policy, "frame_gap", 30),
shuffle=True,
seed=cfg.seed,
)
else:
shuffle = True
sampler = None
@@ -324,7 +375,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
)
if is_main_process:
logging.info("Start offline training on a fixed dataset")
logging.info(f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}")
for _ in range(step, cfg.steps):
start_time = time.perf_counter()
@@ -340,6 +391,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
cfg.optimizer.grad_clip_norm,
accelerator=accelerator,
lr_scheduler=lr_scheduler,
rabc_weight_computer=rabc_weight_computer,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -356,6 +408,14 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
wandb_log_dict = train_tracker.to_dict()
if output_dict:
wandb_log_dict.update(output_dict)
# Log RA-BC statistics if enabled
if rabc_weight_computer is not None:
rabc_stats = rabc_weight_computer.get_stats()
wandb_log_dict.update({
'rabc_progress_mean': rabc_stats['mean'],
'rabc_progress_std': rabc_stats['std'],
'rabc_samples_seen': rabc_stats['count'],
})
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()
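For reference, the new options read by this script, collected in one place (the field names come from the code above; exactly how they are declared on TrainPipelineConfig and SARMConfig is assumed):

# RA-BC (reward-aligned behavior cloning)
cfg.use_rabc = True
cfg.reward_model_path = "path/to/pretrained_sarm"  # placeholder path
cfg.rabc_kappa = 0.01    # hard threshold on progress deltas
cfg.rabc_epsilon = 1e-6  # numerical stability constant
# SARM temporal sampling
cfg.policy.use_temporal_sampler = True
cfg.policy.frame_gap = 30  # spacing between sampled context frames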
+183
@@ -0,0 +1,183 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Reward-Aligned Behavior Cloning (RA-BC) utilities.
RA-BC uses a pre-trained reward model (e.g., SARM) to compute progress-based weights
for training samples, emphasizing high-quality demonstrations and down-weighting
suboptimal ones.
"""
import logging
import torch
import torch.nn as nn
class RABCWeightComputer:
"""
Computes RA-BC weights for training batches using a pre-trained reward model.
Uses Welford's online algorithm for numerically stable running statistics
and applies soft weighting based on progress deltas.
Args:
reward_model: Pre-trained reward model (e.g., SARM)
kappa: Hard threshold for high-quality samples (default: 0.01)
epsilon: Small constant for numerical stability (default: 1e-6)
device: Device to run reward model on
"""
def __init__(
self,
reward_model: nn.Module,
kappa: float = 0.01,
epsilon: float = 1e-6,
device: torch.device | None = None,
):
self.reward_model = reward_model
self.reward_model.eval() # Always in eval mode
self.kappa = kappa
self.epsilon = epsilon
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Running statistics (Welford's algorithm)
self.mean = 0.0
self.m2 = 0.0
self.count = 0
logging.info(f"RA-BC WeightComputer initialized with kappa={kappa}, epsilon={epsilon}")
def _update_stats(self, deltas: torch.Tensor):
"""Update running statistics using Welford's online algorithm."""
for delta in deltas:
self.count += 1
delta_val = delta.item()
delta_mean = delta_val - self.mean
self.mean += delta_mean / self.count
delta_m2 = delta_val - self.mean
self.m2 += delta_mean * delta_m2
def _compute_weights(self, deltas: torch.Tensor) -> torch.Tensor:
"""Compute RA-BC weights from progress deltas."""
if self.count < 2:
# Not enough data, use uniform weights
return torch.ones_like(deltas)
# Get running statistics
mean = max(self.mean, 0.0) # Clamp mean to non-negative
variance = self.m2 / (self.count - 1)
std = torch.tensor(variance).sqrt().item()
# Soft weights: map deltas in [mean - 2*std, mean + 2*std] linearly onto [0, 1]
lower_bound = mean - 2 * std
weights = (deltas - lower_bound) / (4 * std + self.epsilon)
weights = torch.clamp(weights, 0.0, 1.0)
# Apply hard threshold
high_quality_mask = deltas > self.kappa
weights = torch.where(high_quality_mask, torch.ones_like(weights), weights)
return weights
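# Worked example: with the defaults kappa=0.01, epsilon=1e-6 and running statistics
# mean=0.02, std=0.01, the soft mapping sends delta=0.005 to (0.005 - 0.0) / 0.04 ≈ 0.125,
# clips delta=-0.01 to 0.0, and any delta above kappa (e.g. 0.015) is forced to 1.0 by the
# hard threshold. compute_batch_weights() then rescales the weights to sum to batch_size.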
@torch.no_grad()
def compute_batch_weights(self, batch: dict, chunk_size: int = 1) -> torch.Tensor:
"""
Compute RA-BC weights for a training batch.
This function:
1. Extracts current and next observations from the batch
2. Computes rewards using the reward model
3. Calculates progress deltas
4. Updates running statistics
5. Returns normalized weights
Args:
batch: Training batch containing observations
chunk_size: Size of action chunks for computing deltas (default: 1)
Returns:
Weights tensor (batch_size,) normalized to sum to batch_size
"""
observation = batch.get('observation', batch)
batch_size = next(iter(observation.values())).shape[0]
# Extract features needed for reward computation
# These should already be encoded by the preprocessor
if 'video_features' not in observation or 'text_features' not in observation:
logging.warning("RA-BC: Missing video/text features, using uniform weights")
return torch.ones(batch_size, device=self.device)
video_features = observation['video_features'].to(self.device)
text_features = observation['text_features'].to(self.device)
state_features = observation.get('state_features', None)
if state_features is not None:
state_features = state_features.to(self.device)
# Compute rewards for current observations
# Handle both single-frame and multi-frame features
if video_features.dim() == 3: # (B, T, D)
# Multi-frame: use last frame reward
if hasattr(self.reward_model, 'calculate_rewards'):
current_rewards = self.reward_model.calculate_rewards(
text_features, video_features, state_features,
return_all_frames=False
)
else:
# Fallback for models without calculate_rewards
current_rewards = torch.zeros(batch_size, device=self.device)
else: # (B, D)
# Single frame
if hasattr(self.reward_model, 'calculate_rewards'):
current_rewards = self.reward_model.calculate_rewards(
text_features, video_features.unsqueeze(1), state_features,
return_all_frames=False
)
else:
current_rewards = torch.zeros(batch_size, device=self.device)
if isinstance(current_rewards, tuple):
current_rewards = current_rewards[0]
current_rewards = torch.tensor(current_rewards, device=self.device) if isinstance(current_rewards, (list, tuple)) else current_rewards
# For simplicity, assume progress delta is proportional to reward
# In practice, you'd want to compute next_frame rewards and take differences
# For now, use current reward as a proxy for progress delta
progress_deltas = current_rewards
# Update running statistics
self._update_stats(progress_deltas)
# Compute weights
weights = self._compute_weights(progress_deltas)
# Normalize weights to sum to batch_size (maintains effective batch size)
weight_sum = weights.sum() + self.epsilon
weights = weights * batch_size / weight_sum
return weights
def get_stats(self) -> dict:
"""Get current running statistics."""
std = torch.tensor(self.m2 / (self.count - 1)).sqrt().item() if self.count > 1 else 0.0
return {
'mean': self.mean,
'std': std,
'count': self.count,
}
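A usage sketch (the pretrained-model path is a placeholder, and the loss scaling mirrors update_policy() in the training script above):

from lerobot.policies.factory import get_policy_class
from lerobot.utils.rabc import RABCWeightComputer

reward_model = get_policy_class("sarm").from_pretrained("path/to/pretrained_sarm")
reward_model.eval()
weight_computer = RABCWeightComputer(reward_model=reward_model, kappa=0.01, epsilon=1e-6)

# Inside the training loop, on a batch that has already gone through the SARM preprocessor:
weights = weight_computer.compute_batch_weights(batch)  # (batch_size,), normalized to sum to batch_size
loss = loss * weights.mean()
stats = weight_computer.get_stats()  # running mean/std/count of progress deltas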
+378
@@ -0,0 +1,378 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for SARM utility functions.
Tests the implementation of SARM paper formulas:
- Formula (1): compute_temporal_proportions - dataset-level temporal proportions
- Formula (2): compute_tau, compute_cumulative_progress - progress labels
"""
import numpy as np
import torch
from lerobot.policies.sarm.sarm_utils import (
Subtask,
SubtaskAnnotation,
Timestamp,
compute_cumulative_progress_batch,
compute_tau,
compute_temporal_proportions,
)
def make_annotation(subtasks: list[tuple[str, int, int]]) -> SubtaskAnnotation:
"""Helper to create SubtaskAnnotation from list of (name, start_sec, end_sec)."""
return SubtaskAnnotation(
subtasks=[
Subtask(
name=name,
timestamps=Timestamp(
start=f"{start // 60:02d}:{start % 60:02d}",
end=f"{end // 60:02d}:{end % 60:02d}"
)
)
for name, start, end in subtasks
]
)
class TestComputeTemporalProportions:
"""Tests for compute_temporal_proportions (SARM Paper Formula 1).
Formula: ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
Key insight: This averages the PROPORTION of each subtask within each trajectory,
giving equal weight to all trajectories regardless of absolute length.
"""
def test_basic_two_trajectories_equal_proportions(self):
"""Test with two trajectories that have equal proportions."""
# Both trajectories: subtask1 = 50%, subtask2 = 50%
# Traj 1: T=100s, subtask1=50s, subtask2=50s
# Traj 2: T=200s, subtask1=100s, subtask2=100s
annotations = {
0: make_annotation([('subtask1', 0, 50), ('subtask2', 50, 100)]),
1: make_annotation([('subtask1', 0, 100), ('subtask2', 100, 200)]),
}
result = compute_temporal_proportions(annotations)
# Both should be 0.5
assert abs(result['subtask1'] - 0.5) < 1e-6
assert abs(result['subtask2'] - 0.5) < 1e-6
def test_paper_example_different_from_avg_durations(self):
"""Test that compute_temporal_proportions differs from naive average duration approach.
This is the key test showing the difference between:
- Paper formula: average of (L_i,k / T_i)
- Naive approach: mean(L_i,k) / sum(mean(L_i,j))
"""
# Episode 1: T=100s, subtask1=80s, subtask2=20s (proportions: 0.8, 0.2)
# Episode 2: T=200s, subtask1=40s, subtask2=160s (proportions: 0.2, 0.8)
annotations = {
0: make_annotation([('subtask1', 0, 80), ('subtask2', 80, 100)]),
1: make_annotation([('subtask1', 0, 40), ('subtask2', 40, 200)]),
}
result = compute_temporal_proportions(annotations)
# Paper formula:
# ᾱ_1 = (1/2) × (80/100 + 40/200) = (1/2) × (0.8 + 0.2) = 0.5
# ᾱ_2 = (1/2) × (20/100 + 160/200) = (1/2) × (0.2 + 0.8) = 0.5
assert abs(result['subtask1'] - 0.5) < 1e-6
assert abs(result['subtask2'] - 0.5) < 1e-6
def test_single_trajectory(self):
"""Test with a single trajectory."""
# T=100s, reach=30s, grasp=20s, lift=50s
annotations = {
0: make_annotation([('reach', 0, 30), ('grasp', 30, 50), ('lift', 50, 100)]),
}
result = compute_temporal_proportions(annotations)
assert abs(result['reach'] - 0.3) < 1e-6
assert abs(result['grasp'] - 0.2) < 1e-6
assert abs(result['lift'] - 0.5) < 1e-6
def test_sum_to_one(self):
"""Test that proportions always sum to 1."""
# Three episodes with varying proportions
annotations = {
0: make_annotation([('a', 0, 10), ('b', 10, 50), ('c', 50, 100)]), # 0.1, 0.4, 0.5
1: make_annotation([('a', 0, 20), ('b', 20, 70), ('c', 70, 100)]), # 0.2, 0.5, 0.3
2: make_annotation([('a', 0, 30), ('b', 30, 90), ('c', 90, 100)]), # 0.3, 0.6, 0.1
}
result = compute_temporal_proportions(annotations)
total = sum(result.values())
assert abs(total - 1.0) < 1e-6
def test_empty_annotations_returns_empty(self):
"""Test that empty annotations returns empty dict."""
result = compute_temporal_proportions({})
assert result == {}
def test_uniform_proportions(self):
"""Test with uniform proportions across subtasks."""
# Each subtask takes 25% of each episode
annotations = {
0: make_annotation([('a', 0, 25), ('b', 25, 50), ('c', 50, 75), ('d', 75, 100)]),
1: make_annotation([('a', 0, 50), ('b', 50, 100), ('c', 100, 150), ('d', 150, 200)]),
}
result = compute_temporal_proportions(annotations)
for name in ['a', 'b', 'c', 'd']:
assert abs(result[name] - 0.25) < 1e-6
class TestComputeTau:
"""Tests for compute_tau (within-subtask progress).
Formula: τ_t = (t - s_k) / (e_k - s_k) ∈ [0, 1]
"""
def test_at_start(self):
"""τ should be 0 at subtask start."""
tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=50)
assert tau == 0.0
def test_at_end(self):
"""τ should be 1 at subtask end."""
tau = compute_tau(current_frame=50, subtask_start=10, subtask_end=50)
assert tau == 1.0
def test_at_middle(self):
"""τ should be 0.5 at subtask midpoint."""
tau = compute_tau(current_frame=30, subtask_start=10, subtask_end=50)
assert abs(tau - 0.5) < 1e-6
def test_quarter_progress(self):
"""Test τ at 25% through subtask."""
tau = compute_tau(current_frame=20, subtask_start=0, subtask_end=80)
assert abs(tau - 0.25) < 1e-6
def test_zero_duration_subtask(self):
"""τ should be 1.0 for zero-duration subtask."""
tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=10)
assert tau == 1.0
def test_clamps_below_zero(self):
"""τ should be clamped to 0 if frame is before subtask."""
tau = compute_tau(current_frame=5, subtask_start=10, subtask_end=50)
assert tau == 0.0
def test_clamps_above_one(self):
"""τ should be clamped to 1 if frame is after subtask."""
tau = compute_tau(current_frame=60, subtask_start=10, subtask_end=50)
assert tau == 1.0
def test_float_inputs(self):
"""Test with float frame indices (from interpolation)."""
tau = compute_tau(current_frame=25.5, subtask_start=10.0, subtask_end=50.0)
expected = (25.5 - 10.0) / (50.0 - 10.0)
assert abs(tau - expected) < 1e-6
class TestComputeCumulativeProgressBatchScalar:
"""Tests for compute_cumulative_progress_batch with scalar inputs (normalized progress y_t).
Formula: y_t = P_{k-1} + ᾱ_k × τ_t ∈ [0, 1]
"""
def test_first_subtask_start(self):
"""y should be 0 at start of first subtask."""
proportions = [0.3, 0.5, 0.2]
y = compute_cumulative_progress_batch(tau=0.0, stage_indices=0, alpha=proportions)
assert y == 0.0
def test_first_subtask_end(self):
"""y should equal ᾱ_1 at end of first subtask."""
proportions = [0.3, 0.5, 0.2]
y = compute_cumulative_progress_batch(tau=1.0, stage_indices=0, alpha=proportions)
assert abs(y - 0.3) < 1e-6
def test_second_subtask_start(self):
"""y should equal P_1 at start of second subtask."""
proportions = [0.3, 0.5, 0.2]
y = compute_cumulative_progress_batch(tau=0.0, stage_indices=1, alpha=proportions)
assert abs(y - 0.3) < 1e-6
def test_second_subtask_end(self):
"""y should equal P_2 at end of second subtask."""
proportions = [0.3, 0.5, 0.2]
y = compute_cumulative_progress_batch(tau=1.0, stage_indices=1, alpha=proportions)
assert abs(y - 0.8) < 1e-6 # 0.3 + 0.5
def test_third_subtask_end(self):
"""y should be 1.0 at end of last subtask."""
proportions = [0.3, 0.5, 0.2]
y = compute_cumulative_progress_batch(tau=1.0, stage_indices=2, alpha=proportions)
assert abs(y - 1.0) < 1e-6
def test_midpoint_of_subtask(self):
"""Test progress at midpoint of a subtask."""
proportions = [0.4, 0.6]
# At τ=0.5 in subtask 1: y = P_0 + ᾱ_1 × 0.5 = 0 + 0.4 × 0.5 = 0.2
y = compute_cumulative_progress_batch(tau=0.5, stage_indices=0, alpha=proportions)
assert abs(y - 0.2) < 1e-6
# At τ=0.5 in subtask 2: y = P_1 + ᾱ_2 × 0.5 = 0.4 + 0.6 × 0.5 = 0.7
y = compute_cumulative_progress_batch(tau=0.5, stage_indices=1, alpha=proportions)
assert abs(y - 0.7) < 1e-6
def test_uniform_proportions(self):
"""Test with uniform proportions."""
proportions = [0.25, 0.25, 0.25, 0.25]
# At end of each subtask, progress should be 0.25, 0.5, 0.75, 1.0
for i in range(4):
y = compute_cumulative_progress_batch(tau=1.0, stage_indices=i, alpha=proportions)
expected = (i + 1) * 0.25
assert abs(y - expected) < 1e-6
class TestComputeCumulativeProgressBatchTensor:
"""Tests for compute_cumulative_progress_batch with tensor inputs (GPU batch version)."""
def test_tensor_matches_scalar_version(self):
"""Test that tensor version matches scalar version."""
proportions = [0.3, 0.5, 0.2]
alpha = torch.tensor(proportions, dtype=torch.float32)
cumulative = torch.zeros(len(proportions) + 1, dtype=torch.float32)
cumulative[1:] = torch.cumsum(alpha, dim=0)
test_cases = [
(0.0, 0), # start of subtask 0
(1.0, 0), # end of subtask 0
(0.0, 1), # start of subtask 1
(0.5, 1), # middle of subtask 1
(1.0, 2), # end of subtask 2
]
for tau_val, stage_idx in test_cases:
# Scalar version
expected = compute_cumulative_progress_batch(tau_val, stage_idx, proportions)
# Tensor version (single element)
tau = torch.tensor([[[tau_val]]]) # (1, 1, 1)
stages = torch.tensor([[stage_idx]]) # (1, 1)
result = compute_cumulative_progress_batch(tau, stages, alpha, cumulative)
assert abs(result[0, 0, 0].item() - expected) < 1e-6
def test_batch_processing(self):
"""Test batch processing with multiple samples."""
proportions = [0.4, 0.6]
alpha = torch.tensor(proportions, dtype=torch.float32)
cumulative = torch.zeros(3, dtype=torch.float32)
cumulative[1:] = torch.cumsum(alpha, dim=0)
# Batch of 2 samples, sequence length 3
tau = torch.tensor([
[[0.0], [0.5], [1.0]], # sample 1
[[0.0], [0.5], [1.0]], # sample 2
])
stages = torch.tensor([
[0, 0, 0], # sample 1: all in subtask 0
[1, 1, 1], # sample 2: all in subtask 1
])
result = compute_cumulative_progress_batch(tau, stages, alpha, cumulative)
# Sample 1: subtask 0 with tau 0, 0.5, 1.0 -> y = 0, 0.2, 0.4
assert abs(result[0, 0, 0].item() - 0.0) < 1e-6
assert abs(result[0, 1, 0].item() - 0.2) < 1e-6
assert abs(result[0, 2, 0].item() - 0.4) < 1e-6
# Sample 2: subtask 1 with tau 0, 0.5, 1.0 -> y = 0.4, 0.7, 1.0
assert abs(result[1, 0, 0].item() - 0.4) < 1e-6
assert abs(result[1, 1, 0].item() - 0.7) < 1e-6
assert abs(result[1, 2, 0].item() - 1.0) < 1e-6
def test_auto_compute_cumulative_prior(self):
"""Test that cumulative_prior is auto-computed when not provided."""
proportions = [0.3, 0.5, 0.2]
alpha = torch.tensor(proportions, dtype=torch.float32)
tau = torch.tensor([[[0.5]]])
stages = torch.tensor([[1]])
# Without cumulative_prior (should auto-compute)
result = compute_cumulative_progress_batch(tau, stages, alpha)
# Expected: P_0 + alpha_1 * 0.5 = 0.3 + 0.5 * 0.5 = 0.55
assert abs(result[0, 0, 0].item() - 0.55) < 1e-6
class TestEndToEndProgressLabeling:
"""End-to-end tests for progress label computation."""
def test_consistent_semantic_meaning(self):
"""Test that same subtask completion maps to same progress across trajectories.
This is the key semantic property: "end of subtask 1" should always
mean the same progress value regardless of trajectory speed.
"""
proportions = [0.3, 0.5, 0.2]
# Fast trajectory: subtask 1 ends at frame 30 (of 100)
tau_fast = compute_tau(30, 0, 30) # = 1.0
y_fast = compute_cumulative_progress_batch(tau_fast, 0, proportions)
# Slow trajectory: subtask 1 ends at frame 90 (of 300)
tau_slow = compute_tau(90, 0, 90) # = 1.0
y_slow = compute_cumulative_progress_batch(tau_slow, 0, proportions)
# Both should map to same progress (0.3 = end of subtask 1)
assert abs(y_fast - y_slow) < 1e-6
assert abs(y_fast - 0.3) < 1e-6
def test_monotonic_within_subtask(self):
"""Test that progress is monotonically increasing within a subtask."""
proportions = [0.4, 0.6]
prev_y = -1
for tau in np.linspace(0, 1, 11):
y = compute_cumulative_progress_batch(tau, 0, proportions)
assert y > prev_y or (tau == 0 and y == 0)
prev_y = y
def test_continuous_across_subtasks(self):
"""Test that progress is continuous at subtask boundaries."""
proportions = [0.3, 0.5, 0.2]
# End of subtask 0 (tau=1.0)
y_end_0 = compute_cumulative_progress_batch(1.0, 0, proportions)
# Start of subtask 1 (tau=0.0)
y_start_1 = compute_cumulative_progress_batch(0.0, 1, proportions)
# Should be equal (P_1 = 0.3)
assert abs(y_end_0 - y_start_1) < 1e-6
# End of subtask 1
y_end_1 = compute_cumulative_progress_batch(1.0, 1, proportions)
# Start of subtask 2
y_start_2 = compute_cumulative_progress_batch(0.0, 2, proportions)
# Should be equal (P_2 = 0.8)
assert abs(y_end_1 - y_start_2) < 1e-6