initial commit

This commit is contained in:
Pepijn
2025-08-27 14:58:34 +02:00
parent b16e18f978
commit 681be962ae
12 changed files with 4063 additions and 4 deletions
File diff suppressed because one or more lines are too long
+14
View File
@@ -34,6 +34,7 @@ from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
@@ -80,6 +81,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
return SmolVLAPolicy
elif name == "rlearn":
from lerobot.policies.rlearn.modeling_rlearn import RLearNPolicy
return RLearNPolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
@@ -103,6 +108,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return SmolVLAConfig(**kwargs)
elif policy_type == "reward_classifier":
return RewardClassifierConfig(**kwargs)
elif policy_type == "rlearn":
return RLearNConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")
@@ -220,6 +227,13 @@ def make_processor(
cast(SmolVLAConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
elif policy_cfg.type == "rlearn":
from lerobot.policies.rlearn.processor_rlearn import make_rlearn_processor
processors = make_rlearn_processor(
cast(RLearNConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
)
else:
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
@@ -0,0 +1,127 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@PreTrainedConfig.register_subclass("rlearn")
@dataclass
class RLearNConfig(PreTrainedConfig):
"""Configuration for a video-language conditioned reward model (RLearN).
Inputs:
- Visual frames (one or multiple cameras). Optionally a short sequence.
- A language instruction/goal string.
Output:
- Per-timestep reward logits or a single-step reward logit.
Notes:
- This is the initial architecture. It uses frozen vision/text encoders
(e.g. SigLIP2) and trains a lightweight temporal aggregator + head.
"""
# Encoders
model_name: str = "google/siglip2-large-patch16-256"
freeze_backbones: bool = True
# Temporal aggregator
dim_model: int = 512
n_heads: int = 8
n_layers: int = 4
dim_feedforward: int = 2048
dropout: float = 0.1
pre_norm: bool = True
use_first_frame_positional_bias: bool = True
frame_dropout_p: float = 0.0
stride: int = 1
# Sequence length, amount of past frames including current one to use in the temporal model
max_seq_len: int = 16
# Head
use_tanh_head: bool = False # when True, bound outputs in [-1, 1]
# Training
learning_rate: float = 5e-5 # Reduced for stability
weight_decay: float = 0.01
loss_type: str = "composite" # Always use composite loss with spatial awareness
ranking_margin: float = 0.1
# Composite loss weights (with spatial awareness and ReWiND reversibility)
lambda_prog: float = 1.0 # Progress regression weight
lambda_spatial_nce: float = 0.5 # Spatial-aware InfoNCE weight
lambda_rewind: float = 0.4 # ReWiND reversible ranking weight
# Loss hyperparameters
nce_temperature: float = 0.07 # Temperature for InfoNCE
zscore_eps: float = 1e-5 # Epsilon for z-score normalization
min_rank_gap: int = 1 # Minimum gap for temporal ranking pairs
num_ranking_pairs: int = 64 # Number of (far, near) pairs to sample for ReWiND
last_k_for_nce: int = 3 # Use last k frames for InfoNCE
mismatch_lang_prob: float = 0.2 # Probability of language mismatch augmentation
# Value-based pairwise loss hyperparameters (for value_pairwise mode)
lambda_dir: float = 1.0 # Intra-trajectory directional ranking
lambda_text: float = 0.5 # Inter-instruction contrastive ranking
lambda_flat: float = 0.25 # Flatness under mismatch
dir_margin: float = 0.2 # Margin for directional ranking
text_margin: float = 0.2 # Margin for text contrastive ranking
flat_epsilon: float = 0.05 # Epsilon band for flatness loss
num_pairs_per_loss: int = 64 # Number of pairs to sample per loss term
use_hard_negatives: bool = True # Whether to generate hard negative instructions
# Normalization presets
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
# Language is tokenized at the encoder level; no numeric normalization here.
}
)
def validate_features(self) -> None:
# Require at least one image feature. Language is recommended but optional (can be blank).
if not self.image_features:
raise ValueError(
"You must provide at least one image feature for RLearN (e.g. 'observation.image')."
)
@property
def observation_delta_indices(self) -> list | None:
# Not using delta sampling from the dataset by default.
return None
@property
def action_delta_indices(self) -> list | None:
# Not an action chunking policy.
return None
@property
def reward_delta_indices(self) -> list | None:
# By default we supervise every provided timestep equally.
return None
def get_optimizer_preset(self): # type: ignore[override]
from lerobot.optim.optimizers import AdamWConfig
return AdamWConfig(lr=self.learning_rate, weight_decay=self.weight_decay)
def get_scheduler_preset(self): # type: ignore[override]
# No scheduler by default.
return None
+601
View File
@@ -0,0 +1,601 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Evaluation metrics for RLearn (Video-Language Conditioned Reward Model).
Key metrics:
1. VOC-S (Value-Order Correlation for Success): Spearman correlation between frame indices and predicted rewards
2. Success vs Failure Detection: Model's ability to distinguish between correct and incorrect language conditions
"""
from __future__ import annotations
import warnings
from typing import Any
import numpy as np
import torch
from scipy.stats import spearmanr
from torch import Tensor
from tqdm import tqdm
from lerobot.constants import OBS_IMAGES, OBS_LANGUAGE
def compute_voc_s(
predicted_rewards: list[np.ndarray], use_interquartile_mean: bool = True
) -> dict[str, float]:
"""
Compute Value-Order Correlation for Success (VOC-S).
Measures whether per-frame rewards increase as successful execution unfolds.
For each episode, computes Spearman correlation between frame indices [1..T]
and predicted rewards [r1..rT].
Args:
predicted_rewards: List of reward arrays, one per episode. Each array has shape (T,)
use_interquartile_mean: If True, use IQM instead of mean for aggregation
Returns:
Dictionary with VOC-S metrics:
- voc_s_mean: Mean Spearman correlation across episodes
- voc_s_std: Standard deviation of correlations
- voc_s_iqm: Interquartile mean (if use_interquartile_mean=True)
- num_episodes: Number of episodes evaluated
- correlations: Individual correlations per episode
"""
if not predicted_rewards:
return {"voc_s_mean": 0.0, "voc_s_std": 0.0, "voc_s_iqm": 0.0, "num_episodes": 0, "correlations": []}
correlations = []
for episode_rewards in predicted_rewards:
if len(episode_rewards) < 2:
# Need at least 2 points for correlation
continue
# Frame indices: [1, 2, ..., T]
frame_indices = np.arange(1, len(episode_rewards) + 1)
# Compute Spearman correlation
try:
correlation, p_value = spearmanr(frame_indices, episode_rewards)
# Handle NaN correlations (e.g., all rewards are identical)
if np.isnan(correlation):
correlation = 0.0
correlations.append(correlation)
except Exception as e:
warnings.warn(f"Failed to compute correlation for episode: {e}")
correlations.append(0.0)
if not correlations:
return {"voc_s_mean": 0.0, "voc_s_std": 0.0, "voc_s_iqm": 0.0, "num_episodes": 0, "correlations": []}
correlations = np.array(correlations)
# Compute statistics
voc_s_mean = float(np.mean(correlations))
voc_s_std = float(np.std(correlations))
# Interquartile mean: mean of values between 25th and 75th percentiles
if use_interquartile_mean and len(correlations) >= 4:
q25, q75 = np.percentile(correlations, [25, 75])
iqm_mask = (correlations >= q25) & (correlations <= q75)
voc_s_iqm = float(np.mean(correlations[iqm_mask]))
else:
voc_s_iqm = voc_s_mean
return {
"voc_s_mean": voc_s_mean,
"voc_s_std": voc_s_std,
"voc_s_iqm": voc_s_iqm,
"num_episodes": len(correlations),
"correlations": correlations.tolist(),
}
def compute_success_failure_detection(
correct_rewards: list[np.ndarray], incorrect_rewards: list[np.ndarray], threshold_percentile: float = 50.0
) -> dict[str, float]:
"""
Compute success vs failure detection accuracy.
Tests the model's ability to distinguish between correct and incorrect language conditions.
For each episode, compares final reward under correct vs incorrect language instruction.
Args:
correct_rewards: List of reward arrays for episodes with correct language
incorrect_rewards: List of reward arrays for episodes with incorrect/mismatched language
threshold_percentile: Percentile of correct rewards to use as threshold
Returns:
Dictionary with detection metrics:
- detection_accuracy: Fraction of episodes where correct > incorrect
- mean_correct_final: Mean final reward for correct language
- mean_incorrect_final: Mean final reward for incorrect language
- separation_score: (mean_correct - mean_incorrect) / (std_correct + std_incorrect)
- num_pairs: Number of episode pairs evaluated
"""
if len(correct_rewards) != len(incorrect_rewards):
raise ValueError("Must have same number of correct and incorrect reward sequences")
if not correct_rewards:
return {
"detection_accuracy": 0.0,
"mean_correct_final": 0.0,
"mean_incorrect_final": 0.0,
"separation_score": 0.0,
"num_pairs": 0,
}
# Extract final rewards (last timestep of each episode)
correct_finals = []
incorrect_finals = []
for correct_ep, incorrect_ep in zip(correct_rewards, incorrect_rewards, strict=False):
if len(correct_ep) > 0 and len(incorrect_ep) > 0:
correct_finals.append(correct_ep[-1]) # Final reward
incorrect_finals.append(incorrect_ep[-1]) # Final reward
if not correct_finals:
return {
"detection_accuracy": 0.0,
"mean_correct_final": 0.0,
"mean_incorrect_final": 0.0,
"separation_score": 0.0,
"num_pairs": 0,
}
correct_finals = np.array(correct_finals)
incorrect_finals = np.array(incorrect_finals)
# Detection accuracy: fraction where correct > incorrect
detection_accuracy = float(np.mean(correct_finals > incorrect_finals))
# Statistics
mean_correct = float(np.mean(correct_finals))
mean_incorrect = float(np.mean(incorrect_finals))
std_correct = float(np.std(correct_finals))
std_incorrect = float(np.std(incorrect_finals))
# Separation score: normalized difference (clamp to prevent extreme values)
denominator = std_correct + std_incorrect
if denominator > 1e-6: # Prevent division by very small numbers
separation_score = (mean_correct - mean_incorrect) / denominator
# Clamp to reasonable range
separation_score = np.clip(separation_score, -100.0, 100.0)
else:
separation_score = 0.0
return {
"detection_accuracy": detection_accuracy,
"mean_correct_final": mean_correct,
"mean_incorrect_final": mean_incorrect,
"separation_score": float(separation_score),
"num_pairs": len(correct_finals),
}
def generate_mismatched_languages(
original_languages: list[str], mismatch_templates: list[str] | None = None
) -> list[str]:
"""
Generate mismatched language instructions for failure detection evaluation.
Args:
original_languages: List of original task descriptions
mismatch_templates: Custom mismatch templates. If None, uses defaults.
Returns:
List of mismatched language instructions
"""
if mismatch_templates is None:
mismatch_templates = ["kick the ball", "walk to the red shoes", "wave", "do nothing"]
# For each original language, pick a random mismatch
mismatched = []
np.random.seed(42) # For reproducibility
for i, orig_lang in enumerate(original_languages):
# Use modulo to cycle through mismatches if we have more episodes than templates
mismatch_idx = i % len(mismatch_templates)
mismatched.append(mismatch_templates[mismatch_idx])
return mismatched
class RLearnEvaluator:
"""
Comprehensive evaluator for RLearN reward models.
Provides methods to evaluate VOC-S and success/failure detection on datasets.
"""
def __init__(self, model, device: str = "cuda"):
"""
Args:
model: RLearN model instance
device: Device to run evaluation on
"""
self.model = model
self.device = device
self.model.eval()
@torch.no_grad()
def predict_episode_rewards(self, frames: Tensor, language: str, batch_size: int = 16) -> np.ndarray:
"""
Predict rewards for a single episode.
Args:
frames: Video frames tensor of shape (T, C, H, W)
language: Language instruction string
batch_size: Maximum sequence length to process at once
Returns:
Predicted rewards array of shape (T,)
"""
T = frames.shape[0]
# Preprocess frames to match model expectations
processed_frames = self._preprocess_frames(frames)
# Process in chunks if episode is very long
if T <= batch_size:
# Single batch
batch = {
OBS_IMAGES: processed_frames.unsqueeze(0).to(self.device), # (1, T, C, H, W)
OBS_LANGUAGE: [language],
}
# Use the new predict_rewards method
values = self.model.predict_rewards(batch) # (1, T')
rewards = values.squeeze(0).cpu().numpy() # (T',)
else:
# Process in overlapping chunks to handle very long episodes
rewards = []
stride = batch_size // 2 # 50% overlap
for i in range(0, T, stride):
end_idx = min(i + batch_size, T)
chunk_frames = processed_frames[i:end_idx]
batch = {OBS_IMAGES: chunk_frames.unsqueeze(0).to(self.device), OBS_LANGUAGE: [language]}
chunk_values = self.model.predict_rewards(batch)
chunk_rewards = chunk_values.squeeze(0).cpu().numpy()
# For overlapping chunks, only take the first half (except for the last chunk)
if i + batch_size < T:
rewards.extend(chunk_rewards[:stride])
else:
rewards.extend(chunk_rewards)
rewards = np.array(rewards[:T]) # Ensure exact length
return rewards
def _preprocess_frames(self, frames: Tensor) -> Tensor:
"""
Preprocess frames to match model expectations.
Args:
frames: Input frames tensor of shape (T, C, H, W)
Returns:
Preprocessed frames tensor of shape (T, C, H', W')
"""
import torch.nn.functional as F
T, C, H, W = frames.shape
# Expected input size for SigLIP2 is typically 256x256
target_size = 256
# Resize frames if needed
if H != target_size or W != target_size:
# Resize using bilinear interpolation
frames = F.interpolate(
frames, size=(target_size, target_size), mode="bilinear", align_corners=False
)
# Normalize to [0, 1] if needed
if frames.dtype == torch.uint8:
frames = frames.float() / 255.0
# Ensure values are in [0, 1] range
frames = torch.clamp(frames, 0.0, 1.0)
return frames
def evaluate_voc_s(
self, dataset, num_episodes: int = 100, use_interquartile_mean: bool = True
) -> dict[str, Any]:
"""
Evaluate VOC-S on a dataset.
Args:
dataset: LeRobot dataset instance
num_episodes: Number of episodes to evaluate (randomly sampled)
use_interquartile_mean: Whether to compute IQM
Returns:
VOC-S evaluation results
"""
print(f"Evaluating VOC-S on {num_episodes} episodes...")
# Sample episodes
total_episodes = dataset.num_episodes
if num_episodes >= total_episodes:
episode_indices = list(range(total_episodes))
else:
np.random.seed(42)
episode_indices = np.random.choice(total_episodes, num_episodes, replace=False)
predicted_rewards = []
for ep_idx in tqdm(episode_indices, desc="Computing VOC-S"):
try:
# Get episode data
ep_start = dataset.episode_data_index["from"][ep_idx].item()
ep_end = dataset.episode_data_index["to"][ep_idx].item()
episode_length = ep_end - ep_start
# Get frames and language for this episode
frames = []
language = None
for frame_idx in range(episode_length):
global_idx = ep_start + frame_idx
frame_data = dataset[global_idx]
# Extract image (assuming single camera for now)
if OBS_IMAGES in frame_data:
img = frame_data[OBS_IMAGES]
else:
# Try to find image key
img_keys = [k for k in frame_data.keys() if "image" in k.lower()]
if img_keys:
img = frame_data[img_keys[0]]
else:
continue
# Convert to tensor if needed
if isinstance(img, np.ndarray):
img = torch.from_numpy(img)
# Ensure CHW format
if len(img.shape) == 3 and img.shape[-1] in [1, 3, 4]:
img = img.permute(2, 0, 1) # HWC -> CHW
# Resize to expected input size (256x256 for SigLIP2) BEFORE stacking
if img.shape[-2:] != (256, 256):
import torch.nn.functional as F
img = F.interpolate(
img.unsqueeze(0), size=(256, 256), mode="bilinear", align_corners=False
).squeeze(0)
# Normalize to [0, 1] if needed
if img.dtype == torch.uint8:
img = img.float() / 255.0
frames.append(img)
# Get language instruction
if language is None:
if OBS_LANGUAGE in frame_data:
language = frame_data[OBS_LANGUAGE]
if isinstance(language, list):
language = language[0]
elif "task" in frame_data:
language = frame_data["task"]
else:
language = "" # Default empty language
if not frames:
continue
# Stack frames into video tensor
frames_tensor = torch.stack(frames) # (T, C, H, W)
# Predict rewards
episode_rewards = self.predict_episode_rewards(frames_tensor, language)
predicted_rewards.append(episode_rewards)
except Exception as e:
warnings.warn(f"Failed to process episode {ep_idx}: {e}")
continue
# Compute VOC-S
voc_results = compute_voc_s(predicted_rewards, use_interquartile_mean)
print("VOC-S Results:")
print(f" Mean correlation: {voc_results['voc_s_mean']:.4f}")
print(f" Std correlation: {voc_results['voc_s_std']:.4f}")
print(f" IQM correlation: {voc_results['voc_s_iqm']:.4f}")
print(f" Episodes evaluated: {voc_results['num_episodes']}")
return voc_results
def evaluate_success_failure_detection(
self, dataset, num_episodes: int = 100, mismatch_templates: list[str] | None = None
) -> dict[str, Any]:
"""
Evaluate success vs failure detection.
Args:
dataset: LeRobot dataset instance
num_episodes: Number of episodes to evaluate
mismatch_templates: Custom mismatch language templates
Returns:
Success/failure detection results
"""
print(f"Evaluating success/failure detection on {num_episodes} episodes...")
# Sample episodes
total_episodes = dataset.num_episodes
if num_episodes >= total_episodes:
episode_indices = list(range(total_episodes))
else:
np.random.seed(42)
episode_indices = np.random.choice(total_episodes, num_episodes, replace=False)
correct_rewards = []
incorrect_rewards = []
# Get original languages
original_languages = []
for ep_idx in episode_indices:
ep_start = dataset.episode_data_index["from"][ep_idx].item()
frame_data = dataset[ep_start]
if OBS_LANGUAGE in frame_data:
lang = frame_data[OBS_LANGUAGE]
if isinstance(lang, list):
lang = lang[0]
elif "task" in frame_data:
lang = frame_data["task"]
else:
lang = ""
original_languages.append(lang)
# Generate mismatched languages
mismatched_languages = generate_mismatched_languages(original_languages, mismatch_templates)
for i, ep_idx in enumerate(tqdm(episode_indices, desc="Computing detection metrics")):
try:
# Get episode frames (same as VOC-S evaluation)
ep_start = dataset.episode_data_index["from"][ep_idx].item()
ep_end = dataset.episode_data_index["to"][ep_idx].item()
episode_length = ep_end - ep_start
frames = []
for frame_idx in range(episode_length):
global_idx = ep_start + frame_idx
frame_data = dataset[global_idx]
# Extract image
if OBS_IMAGES in frame_data:
img = frame_data[OBS_IMAGES]
else:
img_keys = [k for k in frame_data.keys() if "image" in k.lower()]
if img_keys:
img = frame_data[img_keys[0]]
else:
continue
if isinstance(img, np.ndarray):
img = torch.from_numpy(img)
if len(img.shape) == 3 and img.shape[-1] in [1, 3, 4]:
img = img.permute(2, 0, 1)
# Resize to expected input size (256x256 for SigLIP2)
if img.shape[-2:] != (256, 256):
import torch.nn.functional as F
img = F.interpolate(
img.unsqueeze(0), size=(256, 256), mode="bilinear", align_corners=False
).squeeze(0)
# Normalize to [0, 1] if needed
if img.dtype == torch.uint8:
img = img.float() / 255.0
frames.append(img)
if not frames:
continue
frames_tensor = torch.stack(frames)
# Predict with correct language
correct_lang = original_languages[i]
correct_ep_rewards = self.predict_episode_rewards(frames_tensor, correct_lang)
# Predict with incorrect language
incorrect_lang = mismatched_languages[i]
incorrect_ep_rewards = self.predict_episode_rewards(frames_tensor, incorrect_lang)
correct_rewards.append(correct_ep_rewards)
incorrect_rewards.append(incorrect_ep_rewards)
except Exception as e:
warnings.warn(f"Failed to process episode {ep_idx} for detection: {e}")
continue
# Compute detection metrics
detection_results = compute_success_failure_detection(correct_rewards, incorrect_rewards)
print("Success/Failure Detection Results:")
print(f" Detection accuracy: {detection_results['detection_accuracy']:.4f}")
print(f" Mean correct final reward: {detection_results['mean_correct_final']:.4f}")
print(f" Mean incorrect final reward: {detection_results['mean_incorrect_final']:.4f}")
print(f" Separation score: {detection_results['separation_score']:.4f}")
print(f" Episode pairs evaluated: {detection_results['num_pairs']}")
return detection_results
def comprehensive_evaluation(
self,
dataset,
num_episodes: int = 100,
use_interquartile_mean: bool = True,
mismatch_templates: list[str] | None = None,
) -> dict[str, Any]:
"""
Run comprehensive evaluation including both VOC-S and detection metrics.
Returns:
Combined evaluation results
"""
print("=" * 60)
print("COMPREHENSIVE RLEARN EVALUATION")
print("=" * 60)
# VOC-S evaluation
voc_results = self.evaluate_voc_s(
dataset, num_episodes=num_episodes, use_interquartile_mean=use_interquartile_mean
)
print("\n" + "=" * 40)
# Success/failure detection
detection_results = self.evaluate_success_failure_detection(
dataset, num_episodes=num_episodes, mismatch_templates=mismatch_templates
)
# Combined results
results = {
"voc_s": voc_results,
"detection": detection_results,
"overall_score": (
voc_results["voc_s_iqm"] * 0.6 + detection_results["detection_accuracy"] * 0.4
), # Weighted combination
}
print("\n" + "=" * 60)
print(f"OVERALL EVALUATION SCORE: {results['overall_score']:.4f}")
print("=" * 60)
return results
@@ -0,0 +1,902 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
RLearN: Video-Language Conditioned Reward Model
Inputs
- images: (B, T, C, H, W) sequence of frames (or single frame with T=1)
- language: list[str] of length B (goal/instruction)
High-level Architecture
images (B,T,C,H,W)
|
| per-frame encode
v
+------------------------------+
| Vision Encoder (frozen) | e.g. SigLIP2 vision tower
+------------------------------+
|
| pooled per-frame embeddings (BT, H_v)
v
reshape -> (B, T, H_v) -- Linear proj --> (B, T, D)
+ Positional Encoding [0..T)
+ Optional first-frame bias
|
| language (B, str)
| |
| v
| +------------------------------+
| | Text Encoder (frozen) | e.g. SigLIP2 text tower
| +------------------------------+
| |
| | pooled text embedding (B, H_t)
| v
| Linear proj -> (B, D)
| |
+-----------------v----------------------+
|
+--------------------------v---------------------------+
| Temporal Causal Transformer (n_layers, n_heads) |
| - self-attention over time with causal mask |
| - cross-attention to a single language token |
+--------------------------+---------------------------+
|
LayerNorm + Linear Head (D -> 1)
|
v
Output
- reward_logits: (B, T', 1) with T' ≤ T (affected by stride and frame dropout)
Training
- Loss: composite loss with progress regression, spatial-aware InfoNCE, and ReWiND reversible ranking
Notes
- Backbones (vision/text) are frozen by default; only projections, temporal module, and head are trainable.
- Stride/frame dropout applied during training can subsample timesteps.
"""
from __future__ import annotations
import math
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from lerobot.constants import OBS_IMAGE, OBS_IMAGES, OBS_LANGUAGE, REWARD
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
class RLearNPolicy(PreTrainedPolicy):
"""Video-language conditioned reward model.
- Visual encoder: frozen SigLIP2 (via transformers AutoModel), returns per-frame embeddings.
- Text encoder: frozen SigLIP2 text tower, returns a language embedding.
- Temporal module: causal transformer over time that cross-attends to language embedding.
- Output: per-timestep reward logits; trainable small head.
"""
config_class = RLearNConfig
name = "rlearn"
def __init__(self, config: RLearNConfig):
super().__init__(config)
self.config = config
# Encoders
from transformers import AutoModel, AutoProcessor
self.vision_text_model = AutoModel.from_pretrained(config.model_name, trust_remote_code=True)
self.processor = AutoProcessor.from_pretrained(config.model_name, trust_remote_code=True)
# Detect towers
if hasattr(self.vision_text_model, "vision_model") and hasattr(self.vision_text_model, "text_model"):
self.vision_encoder = self.vision_text_model.vision_model
self.text_encoder = self.vision_text_model.text_model
self.vision_hidden = getattr(self.vision_text_model.config, "vision_config", None).hidden_size
self.text_hidden = getattr(self.vision_text_model.config, "text_config", None).hidden_size
else:
# Fallback if AutoModel exposes pooled outputs directly (rare for SigLIP2)
self.vision_encoder = self.vision_text_model
self.text_encoder = self.vision_text_model
self.vision_hidden = getattr(self.vision_text_model.config, "hidden_size", 768)
self.text_hidden = getattr(self.vision_text_model.config, "hidden_size", 768)
if config.freeze_backbones:
for p in self.vision_encoder.parameters():
p.requires_grad = False
for p in self.text_encoder.parameters():
p.requires_grad = False
# Linear projections to the shared temporal model dimension
self.visual_proj = nn.Linear(self.vision_hidden, config.dim_model)
self.text_proj = nn.Linear(self.text_hidden, config.dim_model)
# Positional encodings over time
self.register_buffer(
"positional_encoding",
create_sinusoidal_pos_encoding(config.max_seq_len, config.dim_model),
persistent=False,
)
# Optional first-frame learned bias to discourage position cheating
self.first_frame_bias = (
nn.Parameter(torch.zeros(1, 1, config.dim_model))
if config.use_first_frame_positional_bias
else None
)
# Temporal aggregator: causal transformer over time with language cross-attention
self.temporal = TemporalCausalTransformer(
dim_model=config.dim_model,
n_heads=config.n_heads,
n_layers=config.n_layers,
dim_feedforward=config.dim_feedforward,
dropout=config.dropout,
pre_norm=config.pre_norm,
)
# Reward head with proper initialization
head_linear = nn.Linear(config.dim_model, 1)
# Initialize with small weights and bias to output values around 0
nn.init.normal_(head_linear.weight, mean=0.0, std=0.02)
nn.init.constant_(head_linear.bias, 0.0) # Start with 0 bias, sigmoid(0) = 0.5
head_layers: list[nn.Module] = [head_linear]
if config.use_tanh_head:
head_layers.append(nn.Tanh())
self.head = nn.Sequential(*head_layers)
# Projection from scalar value summary to text embedding dim for InfoNCE
self.value_to_text_proj = nn.Linear(1, config.dim_model)
# Spatial attention for InfoNCE
self.spatial_cross_attn = nn.MultiheadAttention(
embed_dim=config.dim_model, num_heads=config.n_heads, batch_first=True
)
self.spatial_norm = nn.LayerNorm(config.dim_model)
# Simple frame dropout probability
self.frame_dropout_p = config.frame_dropout_p
self.stride = max(1, config.stride)
def get_optim_params(self) -> dict:
# Train only projections, temporal module and head by default if backbones are frozen
return [p for p in self.parameters() if p.requires_grad]
def reset(self):
pass
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: # Required by base class
raise NotImplementedError("RLearN is a reward model and does not predict actions")
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor: # Required by base class
raise NotImplementedError("RLearN is a reward model and does not select actions")
@torch.no_grad()
def predict_rewards(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict per-timestep rewards for evaluation.
Args:
batch: Input batch with OBS_IMAGES and optionally OBS_LANGUAGE
Returns:
Predicted rewards tensor of shape (B, T)
"""
batch = self.normalize_inputs(batch)
# Extract frames and form (B, T, C, H, W)
frames = extract_visual_sequence(batch)
B, T, C, H, W = frames.shape
# Apply stride (no dropout during eval)
idx = torch.arange(0, T, self.stride, device=frames.device)
frames = frames[:, idx]
B, T_eff, C, H, W = frames.shape # NEW: effective length after stride
# Encode language
lang_emb = encode_language(
batch.get(OBS_LANGUAGE, None), self.text_encoder, self.processor, batch_size=B
)
lang_emb = self.text_proj(lang_emb) # (B, D)
# ---- NEW: use the HF processor to standardize size & normalization ----
# Flatten (B, T_eff, C, H, W) -> (BT, C, H, W)
BT = B * T_eff
flat = frames.reshape(BT, C, H, W).detach().cpu()
# Convert to uint8 HWC numpy (processor prefers PIL/np)
# If already in [0,1], scale to [0,255]
if flat.dtype != torch.uint8:
if flat.numel() > 0 and float(flat.max()) <= 1.0:
flat = flat * 255.0
flat = flat.clamp(0, 255).round().to(torch.uint8)
images = [flat[k].permute(1, 2, 0).numpy() for k in range(flat.size(0))]
proc_out = self.processor(images=images, return_tensors="pt")
pixel_values = proc_out["pixel_values"].to(next(self.vision_encoder.parameters()).device)
# ----------------------------------------------------------------------
# Encode frames through visual tower per frame
vision_outputs = self.vision_encoder(pixel_values=pixel_values)
# Extract CLS tokens for temporal modeling
if hasattr(vision_outputs, "last_hidden_state"):
cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D_vision)
else:
raise RuntimeError("Vision encoder must output last_hidden_state")
# Project CLS tokens for temporal sequence
visual_seq = self.visual_proj(cls_tokens).reshape(B, T_eff, self.config.dim_model) # (B, T', D)
# Add temporal positional encodings and optional first-frame bias
pe = (
self.positional_encoding[: visual_seq.shape[1]]
.unsqueeze(0)
.to(visual_seq.dtype)
.to(visual_seq.device)
)
visual_seq = visual_seq + pe
if self.first_frame_bias is not None:
visual_seq = visual_seq.clone()
visual_seq[:, :1] = visual_seq[:, :1] + self.first_frame_bias
# Temporal model with cross-attention to language
temporal_features = self.temporal(visual_seq, lang_emb, return_features=True) # (B, T', D)
values = self.head(temporal_features).squeeze(-1) # (B, T')
return values
def normalize_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# Initial version: no-op; rely on upstream processors if any
return batch
def normalize_targets(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# Initial version: no-op
return batch
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Compute training loss and logs.
Expected batch keys:
- OBS_IMAGES: list[Tensor] of shape [(B, C, H, W), ...] per time step or stacked (B, T, C, H, W)
- OBS_LANGUAGE: optional string tokens already tokenized externally or raw strings
- REWARD: (B, T) or (B,) target rewards
"""
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
# Extract frames and form (B, T, C, H, W)
frames = extract_visual_sequence(batch)
B, T, C, H, W = frames.shape
# Apply stride and frame dropout during training
idx = torch.arange(0, T, self.stride, device=frames.device)
if self.training and self.frame_dropout_p > 0.0 and T > 1:
mask = torch.rand_like(idx.float()) > self.frame_dropout_p
idx = idx[mask.long().bool()]
if idx.numel() == 0:
idx = torch.tensor([0], device=frames.device)
frames = frames[:, idx]
# Encode language
lang_emb = encode_language(
batch.get(OBS_LANGUAGE, None), self.text_encoder, self.processor, batch_size=B
)
lang_emb = self.text_proj(lang_emb) # (B, D)
# Encode frames through visual tower per frame
# Flatten time for batched encode
BT = B * frames.shape[1]
flat = frames.reshape(BT, C, H, W)
# Use HF processor to properly resize and normalize images
# Convert to CPU for processing, then move back to device
flat_cpu = flat.detach().cpu()
# Convert to uint8 HWC numpy format expected by processor
if flat_cpu.dtype != torch.uint8:
if flat_cpu.numel() > 0 and float(flat_cpu.max()) <= 1.0:
flat_cpu = flat_cpu * 255.0
flat_cpu = flat_cpu.clamp(0, 255).round().to(torch.uint8)
# Convert to list of numpy arrays
images = [flat_cpu[k].permute(1, 2, 0).numpy() for k in range(flat_cpu.size(0))]
# Process with HF processor (resizes to 256x256 and normalizes)
proc_out = self.processor(images=images, return_tensors="pt")
pixel_values = proc_out["pixel_values"].to(next(self.vision_encoder.parameters()).device)
# Encode through vision model
vision_outputs = self.vision_encoder(pixel_values=pixel_values)
# Extract BOTH CLS token and spatial patches
if hasattr(vision_outputs, "last_hidden_state"):
all_tokens = vision_outputs.last_hidden_state # (BT, num_tokens, D)
cls_tokens = all_tokens[:, 0] # (BT, D) - CLS token for temporal modeling
spatial_tokens = all_tokens[:, 1:] # (BT, num_patches, D) - spatial patches
else:
raise RuntimeError("Vision encoder must output last_hidden_state with spatial features")
# Project CLS tokens for temporal sequence
visual_seq = self.visual_proj(cls_tokens).reshape(B, -1, self.config.dim_model) # (B, T', D)
# Keep spatial features for spatial-aware losses (project them too)
# Assuming 16x16 patches for 256x256 image with patch_size=16
num_patches = spatial_tokens.shape[1]
spatial_features = self.visual_proj(spatial_tokens).reshape(
B, -1, num_patches, self.config.dim_model
) # (B, T', num_patches, D)
# Add temporal positional encodings and optional first-frame bias
pe = (
self.positional_encoding[: visual_seq.shape[1]]
.unsqueeze(0)
.to(visual_seq.dtype)
.to(visual_seq.device)
)
visual_seq = visual_seq + pe
if self.first_frame_bias is not None:
visual_seq = visual_seq.clone()
visual_seq[:, :1] = visual_seq[:, :1] + self.first_frame_bias
# Temporal model with cross-attention to language
temporal_features = self.temporal(visual_seq, lang_emb, return_features=True) # (B, T', D)
values = self.head(temporal_features).squeeze(-1) # (B, T')
# Targets
target = batch.get(REWARD, None)
loss_dict: dict[str, float] = {}
if target is None:
# If no labels, return zeros loss and logits for inference
loss = values.mean() * 0.0
loss_dict["has_labels"] = 0.0
return loss, {**loss_dict, "values_mean": values.mean().item()}
# Align target with sampled timesteps
if target.dim() == 1:
target = target.unsqueeze(1) # (B, 1)
target = target[:, idx]
# Composite loss
# 1) Progress regression on z-scored values to match normalized progress labels y in [0,1]
# Debug: Check if values have enough variance
values_std = values.std()
if values_std < 1e-4:
# Early in training, model outputs are nearly constant
# Use direct MSE loss without z-scoring to encourage variance
import logging
logging.info(f"Low variance in values (std={values_std:.6f}), using direct MSE")
# Apply sigmoid directly to raw values to get them in [0,1] range
prog_pred = torch.sigmoid(values * 10.0) # Scale up to encourage learning
L_prog = F.mse_loss(prog_pred, torch.clamp(target, 0.0, 1.0))
else:
# Normal case: use z-score normalization
zV = zscore(values, eps=self.config.zscore_eps)
# Check for NaN after zscore
if torch.isnan(zV).any():
import logging
logging.warning(f"NaN after zscore. Values: {values}, zV: {zV}")
# Fallback to direct sigmoid
prog_pred = torch.sigmoid(values * 10.0)
else:
prog_pred = torch.sigmoid(zV)
L_prog = F.mse_loss(prog_pred, torch.clamp(target, 0.0, 1.0))
# Mismatched pairs: randomly shuffle language within batch and require near-zero progress
if self.training and torch.rand(()) < self.config.mismatch_lang_prob and values.size(0) > 1:
shuffled = torch.randperm(B, device=values.device)
lang_mismatch = lang_emb[shuffled]
mismatch_feat = self.temporal(visual_seq, lang_mismatch, return_features=True)
mismatch_V = self.head(mismatch_feat).squeeze(-1)
L_prog_mismatch = F.mse_loss(
torch.sigmoid(zscore(mismatch_V, eps=self.config.zscore_eps)), torch.zeros_like(target)
)
else:
L_prog_mismatch = torch.zeros((), device=values.device)
# 2) Spatial-Aware InfoNCE: Use language to attend to relevant spatial regions
# Take late timesteps' spatial features
k = min(self.config.last_k_for_nce, spatial_features.shape[1])
late_spatial = spatial_features[:, -k:].mean(dim=1) # (B, num_patches, D)
# Language queries spatial patches via cross-attention
lang_query = lang_emb.unsqueeze(1) # (B, 1, D)
attended_spatial, spatial_attn_weights = self.spatial_cross_attn(
query=lang_query, key=late_spatial, value=late_spatial, need_weights=True
)
attended_spatial = self.spatial_norm(attended_spatial).squeeze(1) # (B, D)
# Contrastive loss with spatially-attended features
attended_spatial = F.normalize(attended_spatial, dim=-1)
lang_norm = F.normalize(lang_emb, dim=-1)
logits_spatial = (attended_spatial @ lang_norm.t()) / self.config.nce_temperature # (B, B)
targets_nce = torch.arange(B, device=values.device)
L_spatial_nce = F.cross_entropy(logits_spatial, targets_nce)
# 3) ReWiND Reversible Ranking: Learn from both forward and reversed trajectories
# This teaches the model what constitutes progress vs undoing progress
L_rank_forward, L_rank_reverse = reversible_ranking_loss(
values,
target,
margin=self.config.ranking_margin,
num_pairs=self.config.num_ranking_pairs,
min_gap=self.config.min_rank_gap,
)
L_rewind = L_rank_forward + L_rank_reverse
# Check for NaNs in individual loss components
if torch.isnan(L_prog):
import logging
logging.warning(f"NaN in L_prog. Values: {values}, Target: {target}")
# Return a small loss with gradients instead of zero
L_prog = values.mean() * 0.0 + 0.01
if torch.isnan(L_spatial_nce):
import logging
logging.warning("NaN in L_spatial_nce")
# Use a dummy loss that maintains gradients
L_spatial_nce = attended_spatial.mean() * 0.0 + 0.01
if torch.isnan(L_rewind):
import logging
logging.warning("NaN in L_rewind")
# Use values to maintain gradient flow
L_rewind = values.mean() * 0.0 + 0.01
loss = (
self.config.lambda_prog * (L_prog + L_prog_mismatch)
+ self.config.lambda_spatial_nce * L_spatial_nce
+ self.config.lambda_rewind * L_rewind
)
# Final NaN check
if torch.isnan(loss):
import logging
logging.warning("NaN loss detected, using fallback loss")
# Use a small loss that maintains gradients
loss = values.mean() * 0.0 + 0.01
loss_dict.update(
{
"loss_prog": L_prog.item() if not torch.isnan(L_prog) else 0.0,
"loss_prog_mismatch": L_prog_mismatch.item() if not torch.isnan(L_prog_mismatch) else 0.0,
"loss_spatial_nce": L_spatial_nce.item() if not torch.isnan(L_spatial_nce) else 0.0,
"loss_rewind_forward": L_rank_forward.item() if not torch.isnan(L_rank_forward) else 0.0,
"loss_rewind_reverse": L_rank_reverse.item() if not torch.isnan(L_rank_reverse) else 0.0,
}
)
loss_dict["loss"] = loss.item()
loss_dict["values_mean"] = values.mean().item()
return loss, loss_dict
class TemporalCausalTransformer(nn.Module):
def __init__(
self,
dim_model: int,
n_heads: int,
n_layers: int,
dim_feedforward: int,
dropout: float,
pre_norm: bool,
):
super().__init__()
self.layers = nn.ModuleList(
[
TemporalCausalTransformerLayer(dim_model, n_heads, dim_feedforward, dropout, pre_norm)
for _ in range(n_layers)
]
)
self.norm = nn.LayerNorm(dim_model)
self.head = nn.Linear(dim_model, 1)
def forward(self, x: Tensor, lang_emb: Tensor, return_features: bool = False) -> Tensor:
# x: (B, T, D), lang_emb: (B, D)
B, T, D = x.shape
# Prepare language as a single token for cross-attention context
lang_token = lang_emb.unsqueeze(1) # (B, 1, D)
x = x.transpose(0, 1) # (T, B, D)
lang_token = lang_token.transpose(0, 1) # (1, B, D)
causal_mask = generate_causal_mask(T, device=x.device)
for layer in self.layers:
x = layer(x, lang_token, causal_mask)
x = self.norm(x)
x = x.transpose(0, 1) # (B, T, D)
if return_features:
return x
return self.head(x) # (B, T, 1)
class TemporalCausalTransformerLayer(nn.Module):
def __init__(self, dim_model: int, n_heads: int, dim_feedforward: int, dropout: float, pre_norm: bool):
super().__init__()
self.self_attn = nn.MultiheadAttention(dim_model, n_heads, dropout=dropout, batch_first=False)
self.cross_attn = nn.MultiheadAttention(dim_model, n_heads, dropout=dropout, batch_first=False)
self.linear1 = nn.Linear(dim_model, dim_feedforward)
self.linear2 = nn.Linear(dim_feedforward, dim_model)
self.dropout = nn.Dropout(dropout)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(dim_model)
self.norm2 = nn.LayerNorm(dim_model)
self.norm3 = nn.LayerNorm(dim_model)
self.activation = F.gelu
self.pre_norm = pre_norm
def forward(self, x: Tensor, lang_token: Tensor, causal_mask: Tensor) -> Tensor:
# Self-attention with causal mask
residual = x
if self.pre_norm:
x = self.norm1(x)
x = self.self_attn(x, x, x, attn_mask=causal_mask)[0]
x = residual + self.dropout1(x)
if not self.pre_norm:
x = self.norm1(x)
# Cross-attention to language token (keys/values from language, queries are time tokens)
residual = x
if self.pre_norm:
x = self.norm2(x)
# Broadcast language token across time
T = x.shape[0]
lang_kv = lang_token.expand(1, x.shape[1], x.shape[2]) # (1, B, D)
x = self.cross_attn(x, lang_kv, lang_kv)[0]
x = residual + self.dropout2(x)
if not self.pre_norm:
x = self.norm2(x)
# Feed-forward
residual = x
if self.pre_norm:
x = self.norm3(x)
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
x = residual + self.dropout3(x)
if not self.pre_norm:
x = self.norm3(x)
return x
def create_sinusoidal_pos_encoding(max_len: int, dim: int) -> Tensor:
position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) # (L, 1)
div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim)) # (D/2)
pe = torch.zeros(max_len, dim)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
return pe # (L, D)
def generate_causal_mask(T: int, device=None) -> Tensor:
# (T, T) with True where masking should occur for MultiheadAttention expects float mask or bool?
mask = torch.full((T, T), float("-inf"), device=device)
mask = torch.triu(mask, diagonal=1)
return mask
def extract_visual_sequence(batch: dict[str, Tensor]) -> Tensor:
# Accept various image key formats from datasets
# Try multiple common key patterns
# List of possible image keys to check, in order of preference
possible_keys = [
OBS_IMAGES, # 'observation.images'
OBS_IMAGE, # 'observation.image'
"observation.images.image", # nested format from some datasets
]
for key in possible_keys:
if key in batch:
image_val = batch[key]
if isinstance(image_val, list) and len(image_val) > 0:
# List of (B, C, H, W) -> stack over time
return torch.stack(image_val, dim=1)
elif torch.is_tensor(image_val):
# Tensor of shape (B, T, C, H, W) or (B, C, H, W)
if image_val.dim() == 5:
# Already has time dimension
return image_val
elif image_val.dim() == 4:
# Add time dimension (single frame)
return image_val.unsqueeze(1)
else:
raise ValueError(
f"'{key}' must be a Tensor of shape (B,T,C,H,W) or (B,C,H,W), got shape {image_val.shape}"
)
# If no image key found, provide helpful error with available keys
available_keys = list(batch.keys())
image_like_keys = [k for k in available_keys if "image" in k.lower()]
raise ValueError(
f"Could not find image data in batch. Looked for keys: {possible_keys}. "
f"Available keys with 'image': {image_like_keys}. "
f"All keys: {available_keys}"
)
def encode_language(
language_input: Tensor | list | str | None, text_encoder, processor, batch_size: int
) -> Tensor:
# language_input can be: list[str] length B, or None
if language_input is None:
texts = [""] * batch_size
elif isinstance(language_input, list):
texts = language_input
else:
# Single string for the batch
texts = [str(language_input)] * batch_size
inputs = processor(text=texts, padding=True, return_tensors="pt")
inputs = {k: v.to(next(text_encoder.parameters()).device) for k, v in inputs.items()}
outputs = text_encoder(**inputs)
if hasattr(outputs, "pooler_output"):
emb = outputs.pooler_output
elif hasattr(outputs, "last_hidden_state"):
emb = outputs.last_hidden_state[:, 0]
else:
raise RuntimeError("Unsupported text encoder output structure")
return emb
def pairwise_ranking_loss(logits: Tensor, target: Tensor, margin: float = 0.1, num_pairs: int = 32) -> Tensor:
# logits, target: (B, T)
B, T = logits.shape
if T < 2:
return logits.mean() * 0.0
# Sample pairs i<j and enforce r_j > r_i when target_j > target_i
losses = []
for _ in range(num_pairs):
i = torch.randint(0, T - 1, (B,), device=logits.device)
j = i + torch.randint(1, T - i.max(), (1,), device=logits.device)
j = j.expand_as(i)
li = logits[torch.arange(B), i]
lj = logits[torch.arange(B), j]
yi = target[torch.arange(B), i]
yj = target[torch.arange(B), j]
sign = torch.sign(yj - yi)
# hinge: max(0, margin - sign*(lj-li))
loss = F.relu(margin - sign * (lj - li))
losses.append(loss.mean())
return torch.stack(losses).mean()
def zscore(x: Tensor, eps: float = 1e-3) -> Tensor:
"""Z-score normalization with numerical stability."""
# Handle both (B,) and (B, T) shapes
if x.dim() == 1:
x = x.unsqueeze(1) # Make it (B, 1)
B, T = x.shape
# If only one timestep, can't compute meaningful std across time
if T == 1:
# Just use tanh to bound values instead of z-score
return torch.tanh(x * 0.1) # Scale and bound
# Compute mean and std across time dimension
mean = x.mean(dim=1, keepdim=True)
std = x.std(dim=1, keepdim=True, unbiased=False)
# Check if std is valid (not zero or NaN)
std_is_valid = (std > eps) & (~torch.isnan(std))
# Safe std for division
std_safe = torch.where(std_is_valid, std, torch.ones_like(std))
# Compute z-score where valid
z = (x - mean) / std_safe
# For invalid cases, use tanh of centered values
z_fallback = torch.tanh((x - mean) * 0.1)
z = torch.where(std_is_valid.expand_as(z), z, z_fallback)
# Final safety clamp
z = torch.clamp(z, min=-5.0, max=5.0)
# Check for any remaining NaNs and replace with 0
z = torch.nan_to_num(z, nan=0.0)
return z
def temporal_logistic_ranking(
values: Tensor, margin: float = 0.1, min_gap: int = 1, num_pairs: int = 64
) -> Tensor:
"""VLC-style temporal monotonicity: encourage V[j] > V[i] for j>i.
Samples pairs (i<j) with a minimum gap and applies softplus(m - (Vj - Vi)).
"""
B, T = values.shape
if T < 2:
return values.mean() * 0.0
losses = []
device = values.device
for _ in range(num_pairs):
i = torch.randint(0, max(1, T - min_gap), (B,), device=device)
j = i + torch.randint(min_gap, T - i.max(), (1,), device=device)
j = j.expand_as(i)
vi = values[torch.arange(B), i]
vj = values[torch.arange(B), j]
losses.append(F.softplus(margin - (vj - vi)).mean())
return torch.stack(losses).mean()
def reversible_ranking_loss(
values: Tensor, target: Tensor, margin: float = 0.1, num_pairs: int = 64, min_gap: int = 1
) -> tuple[Tensor, Tensor]:
"""ReWiND-style reversible ranking: learn from both forward and reversed trajectories.
Key insight: If a trajectory shows progress forward, its reverse shows undoing progress.
By training on both, the model learns what constitutes progress vs regression.
Args:
values: (B, T) predicted values
target: (B, T) progress labels (0 to 1 for forward progress)
margin: Margin for ranking loss
num_pairs: Number of (far, near) pairs to sample
min_gap: Minimum temporal gap between pairs
Returns:
forward_loss: Loss from forward trajectory pairs
reverse_loss: Loss from reversed trajectory pairs
"""
B, T = values.shape
if T < 2:
zero_loss = values.mean() * 0.0
return zero_loss, zero_loss
device = values.device
# Forward trajectory ranking: later frames should have higher values
forward_losses = []
for _ in range(num_pairs // 2):
# Sample far-near pairs (far is earlier, near is later)
far_idx = torch.randint(0, max(1, T - min_gap), (B,), device=device)
near_idx = far_idx + torch.randint(min_gap, T - far_idx.max(), (1,), device=device)
near_idx = near_idx.expand_as(far_idx)
v_far = values[torch.arange(B), far_idx]
v_near = values[torch.arange(B), near_idx]
# Near (later) should have higher value than far (earlier)
forward_losses.append(F.softplus(margin - (v_near - v_far)).mean())
# Reversed trajectory ranking: treat reversed sequence with inverted progress
# Reverse both values and targets
reversed_values = values.flip(dims=[1]) # Reverse time dimension
reversed_target = 1.0 - target.flip(dims=[1]) # Invert and reverse progress
reverse_losses = []
for _ in range(num_pairs // 2):
# In reversed trajectory, what was "later" is now "earlier"
far_idx = torch.randint(0, max(1, T - min_gap), (B,), device=device)
near_idx = far_idx + torch.randint(min_gap, T - far_idx.max(), (1,), device=device)
near_idx = near_idx.expand_as(far_idx)
v_far_rev = reversed_values[torch.arange(B), far_idx]
v_near_rev = reversed_values[torch.arange(B), near_idx]
# In reversed trajectory with inverted progress,
# near (which was originally earlier) should still have higher value
reverse_losses.append(F.softplus(margin - (v_near_rev - v_far_rev)).mean())
forward_loss = torch.stack(forward_losses).mean() if forward_losses else values.mean() * 0.0
reverse_loss = torch.stack(reverse_losses).mean() if reverse_losses else values.mean() * 0.0
return forward_loss, reverse_loss
def intra_trajectory_directional_ranking(
values: Tensor, progress: Tensor, margin: float = 0.2, num_pairs: int = 64, min_gap: int = 1
) -> Tensor:
"""Directional ranking within trajectory based on progress labels.
For pairs i<j within a trajectory:
- If progress increases (y_j > y_i), enforce V_j > V_i
- If progress decreases (y_j < y_i), enforce V_j < V_i
- Ignore pairs where progress is unchanged
Uses logistic loss: log(1 + exp(m - s_ij * (V_j - V_i)))
where s_ij = sign(y_j - y_i)
"""
B, T = values.shape
if T < 2:
return values.mean() * 0.0
losses = []
device = values.device
for _ in range(num_pairs):
# Sample time pairs i < j
i = torch.randint(0, max(1, T - min_gap), (B,), device=device)
max_j = min(T, i.max() + T - min_gap)
j = i + torch.randint(min_gap, max_j - i.min(), (1,), device=device)
j = j.expand_as(i).clamp(max=T - 1)
# Get values and progress at sampled times
vi = values[torch.arange(B), i]
vj = values[torch.arange(B), j]
yi = progress[torch.arange(B), i]
yj = progress[torch.arange(B), j]
# Compute direction sign
s_ij = torch.sign(yj - yi)
# Only compute loss for non-zero progress differences
mask = s_ij != 0
if mask.any():
diff = vj - vi
loss = torch.log1p(torch.exp(margin - s_ij * diff))
losses.append(loss[mask].mean())
return torch.stack(losses).mean() if losses else values.mean() * 0.0
def inter_instruction_contrastive_ranking(
values_correct: Tensor, values_incorrect: Tensor, margin: float = 0.2
) -> Tensor:
"""Ranking between correct and incorrect instructions for same frames.
Enforces V_t(z) > V_t(z') where z is correct instruction and z' is incorrect.
Uses logistic loss: log(1 + exp(m - (V_t(z) - V_t(z'))))
"""
diff = values_correct - values_incorrect
return torch.log1p(torch.exp(margin - diff)).mean()
def flatness_under_mismatch(values: Tensor, epsilon: float = 0.05, num_pairs: int = 32) -> Tensor:
"""Enforce flat values over time for mismatched instructions.
For trajectory with wrong instruction, V should not change much over time.
Uses Huber loss to allow small variations within epsilon band.
"""
B, T = values.shape
if T < 2:
return values.mean() * 0.0
losses = []
device = values.device
for _ in range(num_pairs):
i = torch.randint(0, T - 1, (B,), device=device)
j = torch.randint(i.min() + 1, T, (1,), device=device)
j = j.expand_as(i)
vi = values[torch.arange(B), i]
vj = values[torch.arange(B), j]
# Huber loss with small delta for near-zero target
diff = vj - vi
loss = F.huber_loss(diff, torch.zeros_like(diff), delta=epsilon)
losses.append(loss)
return torch.stack(losses).mean()
@@ -0,0 +1,126 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any
from lerobot.configs.types import PolicyFeature
from lerobot.constants import OBS_LANGUAGE
from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
TokenizerProcessor,
UnnormalizerProcessor,
)
from lerobot.processor.pipeline import (
ComplementaryDataProcessor,
EnvTransition,
ProcessorStepRegistry,
TransitionKey,
)
def make_rlearn_processor(
config: RLearNConfig, dataset_stats: dict[str, dict[str, Any]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
"""Build pre/post processors for RLearN.
Responsibilities moved out of the model:
- Normalize inputs (images) using dataset stats
- Ensure batching
- Map complementary_data.task to observation.language when available
- Tokenize language into observation.language.tokens / attention_mask
- Move to/from device
"""
input_steps = [
# No renaming by default, but keep for future extensibility
RenameProcessor(rename_map={}),
NormalizerProcessor(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
RLearnLanguageFromTaskProcessor(),
# Use the same model name for tokenizer to keep vocab aligned with text tower
TokenizerProcessor(
tokenizer_name=config.model_name,
max_length=128,
padding="max_length",
truncation=True,
padding_side="right",
),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor(
steps=output_steps, name="robot_postprocessor"
)
@dataclass
@ProcessorStepRegistry.register(name="rlearn_language_from_task")
class RLearnLanguageFromTaskProcessor(ComplementaryDataProcessor):
"""Copy complementary_data['task'] into observation['observation.language'] if present.
This ensures the model can consume a raw language string when tokenization is not used,
while TokenizerProcessor can still create tokenized fields.
"""
task_key: str = "task"
def __call__(self, transition: EnvTransition) -> EnvTransition: # type: ignore[override]
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if not complementary_data or self.task_key not in complementary_data:
return transition
task = complementary_data.get(self.task_key)
if task is None:
return transition
# Normalize to list[str]
if isinstance(task, str):
task_list = [task]
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
task_list = task
else:
return transition
observation = transition.get(TransitionKey.OBSERVATION) or {}
# Do not overwrite if user already provided observation.language
if OBS_LANGUAGE not in observation:
observation[OBS_LANGUAGE] = task_list
transition[TransitionKey.OBSERVATION] = observation
return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: # noqa: D401
# Adds nothing to features; only mirrors complementary_data.task into observation
return features
def get_config(self) -> dict[str, Any]:
return {"task_key": self.task_key}
+136
View File
@@ -0,0 +1,136 @@
## General Value/Reward Learning:
I want to implement a general/universal vision and language value function or reward model for robotics/video tasks. Also called a video language conditioned reward model. Integrated with already existing LeRobot code if convenient, use the LeRobot Dataset for dataset and store the reward for a frame in the lerobot frame itself.
Inspired by these papers:
- ReWiND; https://arxiv.org/pdf/2505.10911 (Most applicable and main paper I want to implement ideas from) and code: https://github.com/lucidrains/rewind-reward-pytorch
- LIV; https://arxiv.org/pdf/2306.00958 (Most applicable and 2nd main paper I want to implement ideas from) and code https://github.com/penn-pal-lab/LI
- VLC: Video-Language Critic: Transferable Reward Functions for Language-Conditioned Robotics: https://arxiv.org/pdf/2405.19988 (Most applicable and 3rd paper I want to implement ideas from) and code: https://github.com/minttusofia/video_language_critic
And these papers which are also relevant:
- https://www.dyna.co/dyna-1/research (Main company I want to reproduce the eventual results from)
- vip; https://arxiv.org/pdf/2210.00030
- uvd; https://arxiv.org/pdf/2310.08581
- vlm in context; https://arxiv.org/pdf/2411.04549
- https://www.youtube.com/watch?v=JfZYtpEisoM
Little less relevant but still similar papers:
- Learning Generalizable Robotic Reward Functions from “In-The-Wild” Human Videos,
- XIRL: Cross-embodiment Inverse Reinforcement Learning,
- Video-Language Critic: Transferable Reward https://arxiv.org/pdf/2405.19988
- Functions for Language-Conditioned Robotics,
- LORel, Language-Driven Representation Learning for Robotics https://sites.google.com/view/robotlorel
- RoboCLIP: One Demonstration is Enough to Learn Robot Policies https://arxiv.org/pdf/2310.07899
- Points2Rewards: learn first key points and then uses the keypoints to learn general value function/policy https://semrob.github.io/docs/2025_rss_semrob.github.io_paper20.pdf
- Language-Driven Representation Learning for Robotics: https://arxiv.org/pdf/2302.12766v1
- R3M: A Universal Visual Representation for Robot Manipulation: https://arxiv.org/pdf/2203.12601v3
Input should be the current image or whole video and the task goal specified in text/language. Output is current reward.
Archiutecture:
_ inputs: video o1:T (or current o1:t), language z;
_ google/siglip2-large-patch16-256: https://huggingface.co/google/siglip2-large-patch16-256 \* Temporal module: small causal transformer (“cross-modal sequential aggregator”), with first-frame positional embedding (to avoid position cheating), frame-dropout, and stride sampling; outputs per-timestep logits.
Loss: See this chatgpt thread: https://chatgpt.com/s/t_68999a50a0b081919abc365cdd205e01
Past images: (for example a reward methoid go to 3rd floor, has to know what floor it was on and what pas actions it did, can we attend or encorperate images of decision from history in one way?) Maybe via this paper: Learning Long-Context Diffusion Policies via Past-Token Prediction
Amount of frames needed for test/generalization: 1M frames? or ~20% of IPEC-COMMUNITY/bc_z_lerobot
Eval:
Implement something like voc score , or ROC rank order correlation between reward leanredna and ev reward from sim, or use something else to do additional evaluation
Ideas:
- Incorporate training on multiple horizons: as in label same dataset for longer horizons: make a sandwich (long), put cheese on bread (medium) and even smaller horizons: go down or close gripper (small)
- Incorporate navigation goals “walk towards the kitchen”, make sure we fix CLIP contrastive learning issue of positional text misunderstanding where model doesnnt learn difference between "horse right of cow" and "horse left of cow" “Move right” potentially train with more other data or even actionable world models such as Genie 3 (https://deepmind.google/discover/blog/genie-3-a-new-frontier-for-world-models/)
How to use a general reward model (use cases): - Train rl policy on it - Success detection - Do exploraion - Do task via planning and search to optimize reward - Filter out bad episodes in large datasets from imitation learning
Potential Datasets: (start with dataset that is most clean for this and works best with chosen way of doing evals)
_ Epic-Kitchens-100
_ Something-Something v. 2 Dataset https://www.qualcomm.com/developer/software/something-something-v-2-dataset
_ Ego4D (3000 hours)
_ Open X-Embodiment (OXE)
_ Age bot world: https://huggingface.co/datasets/agibot-world/AgiBotWorld-Alpha
_ GTEA+ Gaze: https://cbs.ic.gatech.edu/fpv/
_ YouCook2 dataset
_ HOWTO100M: https://www.di.ens.fr/willow/research/howto100m/
Also in plan include investigating/incorperating these two things:
- Curriculum: Start training on easier distinctions (e.g. very early vs very late frames which are easy to tell apart) then gradually ask the model to distinguish more subtle differences (frames that are closer in time). This curriculum can be implemented by initially using pairs far apart in the trajectory for ranking, then moving to closer pairs as accuracy improves.
- Augmentations: As mentioned, heavy image augmentation (random crops, slight noise) is often used so that the reward model focuses on high-level task progress features rather than pixel-level cues. For video models, temporal augmentation like random frame skipping can also make the model robust to different speeds.
### Implemented Loss (Spatial-Aware Composite Loss)
Our implementation uses a **composite loss with spatial awareness** to address the limitations of standard contrastive learning (e.g., CLIP's inability to distinguish "move left" vs "move right"). The loss has three components:
##### 1) Progress Regression Loss (L_prog)
Predicts normalized progress values for each timestep:
$$
L_{\text{prog}} = \text{MSE}(\sigma(z(V_t)), y_t)
$$
where $z(·)$ is z-score normalization, $\sigma$ is sigmoid, and $y_t \in [0,1]$ is the progress label.
**Purpose:** Grounds the model in actual task progress, not just visual-language similarity.
##### 2) Spatial-Aware InfoNCE Loss (L_spatial_nce)
Instead of using pooled features, we:
- Extract spatial patch features from SigLIP2's last_hidden_state (e.g., 16×16 patches)
- Use cross-attention where language queries attend to relevant spatial regions
- Compute contrastive loss on the attended spatial features
$$
L_{\text{spatial-nce}} = -\log \frac{\exp(s_{ii}/\tau)}{\sum_j \exp(s_{ij}/\tau)}
$$
where $s_{ij}$ is the similarity between spatially-attended features from trajectory $i$ and language $j$.
**Purpose:** Preserves spatial information that pooling discards, enabling distinction of spatial relationships.
##### 3) ReWiND Reversible Ranking Loss (L_rewind)
Based on ReWiND's key insight: learn from both forward AND reversed trajectories.
The loss has two components:
- **Forward ranking**: Sample (far, near) pairs where near is later in time, enforce $V_{\text{near}} > V_{\text{far}}$
- **Reverse ranking**: Reverse the trajectory and invert progress labels, then apply same ranking
$$
L_{\text{rewind}} = L_{\text{forward}} + L_{\text{reverse}}
$$
where both use: $\text{softplus}(m - (V_{\text{near}} - V_{\text{far}}))$
**Purpose:** By training on reversed trajectories with inverted progress, the model learns to distinguish progress from undoing progress. This is ReWiND's core contribution - understanding that tasks can be reversible.
##### Total Loss:
$$
L = \lambda_{\text{prog}} L_{\text{prog}} + \lambda_{\text{spatial-nce}} L_{\text{spatial-nce}} + \lambda_{\text{rewind}} L_{\text{rewind}}
$$
Default weights: $\lambda_{\text{prog}}=1.0$, $\lambda_{\text{spatial-nce}}=0.5$, $\lambda_{\text{rewind}}=0.4$
### TODOs:
- Implement first architecture [x]
- Implement processors [x]
- Choose right loss metric(s) [x]
- Make dataset with script that generated the dataset (IPEC-COMMUNITY/bc_z_lerobot) ready in lerobot format (and be able to visualize in dataset visualizer)
- Annotate with ReWiND-style 0→1 progress rewards [x]
- Visualize to check [x]
- Implement eval score or metric that is robust and can deal with generalization/is a good metric to try different architectures. And use it in an eval jupyter notebook with visalization of the live reward next to the video for part of the dataset: VOC score and score with correct and incorrect language captions [x]
- Do first training []
- Try different losses []
- Switch to DINO v3 as encoder Base 86 M: https://huggingface.co/facebook/dinov3-vitb16-pretrain-lvd1689m with HuggingFaceTB/SmolLM2-135M-Instruct ?
- Add more artificial text to dataset generated by vlm (google gemini) []
- See google gemini vlm caption from Leandro []
- Multiple captions per video, creat method to generate as much data as possible etc [] https://arxiv.org/abs/2508.13446
- Add droid []
+7 -4
View File
@@ -145,10 +145,13 @@ class TokenizerProcessor:
observation = dict(observation) # Make a copy
# Add tokenized data to observation
observation[f"{OBS_LANGUAGE}.tokens"] = tokenized_prompt["input_ids"]
observation[f"{OBS_LANGUAGE}.attention_mask"] = tokenized_prompt["attention_mask"].to(
dtype=torch.bool
)
input_ids = tokenized_prompt["input_ids"]
attention_mask = tokenized_prompt.get("attention_mask")
if attention_mask is None:
# Some tokenizers (e.g., SigLIP text) may not return attention_mask; default to ones
attention_mask = torch.ones_like(input_ids)
observation[f"{OBS_LANGUAGE}.tokens"] = input_ids
observation[f"{OBS_LANGUAGE}.attention_mask"] = attention_mask.to(dtype=torch.bool)
transition[TransitionKey.OBSERVATION.value] = observation # type: ignore[misc]
return transition
@@ -0,0 +1,335 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Add ReWiND-style linear progress rewards to existing LeRobot datasets.
This script creates a complete copy of the dataset with rewards added to each frame.
It downloads the original dataset (including videos), adds rewards, and pushes everything to a new repository.
Usage:
# Create full dataset copy with rewards
python src/lerobot/scripts/annotate_dataset_rewards.py --input-repo IPEC-COMMUNITY/bc_z_lerobot --output-repo username/bc_z_with_rewards
# Test with 1% of episodes
python src/lerobot/scripts/annotate_dataset_rewards.py --input-repo IPEC-COMMUNITY/bc_z_lerobot --output-repo username/test_rewards --percentage 1
"""
import argparse
import shutil
from pathlib import Path
from tempfile import mkdtemp
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from lerobot.constants import REWARD
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
def compute_linear_progress_reward(episode_length: int) -> np.ndarray:
"""
Compute linear progress rewards from 0 to 1.
ReWiND-style: progress increases linearly from 0 at start to 1 at completion.
Args:
episode_length: Number of frames in the episode
Returns:
rewards: Array of shape (episode_length,) with values linearly from 0 to 1
"""
return np.linspace(0, 1, episode_length, dtype=np.float32)
def main():
parser = argparse.ArgumentParser(
description="Add linear progress rewards to LeRobot dataset and push to Hub"
)
parser.add_argument(
"--input-repo",
type=str,
default="IPEC-COMMUNITY/bc_z_lerobot",
help="Input dataset repository on HuggingFace Hub",
)
parser.add_argument(
"--output-repo",
type=str,
required=True,
help="Output dataset repository name (e.g., username/dataset_with_rewards)",
)
parser.add_argument(
"--percentage",
type=float,
default=100.0,
help="Percentage of episodes to process (useful for testing, e.g., 1 for 1%%)",
)
parser.add_argument(
"--private",
action="store_true",
help="Make the output repository private",
)
parser.add_argument(
"--local-dir",
type=str,
default=None,
help="Local directory to save the modified dataset (defaults to ~/.cache/huggingface/lerobot/<output-repo>)",
)
args = parser.parse_args()
print("=" * 60)
print("FULL DATASET COPY WITH REWARDS")
print("This will download the entire dataset including videos,")
print("add rewards, and push everything to a new repository.")
print("=" * 60)
# First, load just the metadata to get total episodes
print(f"\nLoading metadata from Hub: {args.input_repo}")
# Load metadata only first
metadata = LeRobotDatasetMetadata(repo_id=args.input_repo)
total_episodes = metadata.total_episodes
# Calculate which episodes to process
num_episodes_to_process = max(1, int(total_episodes * args.percentage / 100))
episodes_to_load = list(range(num_episodes_to_process)) # Load only first N episodes
print(f"Dataset has {total_episodes} episodes")
print(f"Processing {num_episodes_to_process} episodes ({args.percentage}%)")
# Determine local directory for the new dataset
if args.local_dir:
local_dir = Path(args.local_dir)
else:
from lerobot.constants import HF_LEROBOT_HOME
local_dir = HF_LEROBOT_HOME / args.output_repo
# Use a temporary directory for downloading source dataset
temp_source_dir = Path(mkdtemp(prefix="lerobot_source_"))
# Load the dataset with videos to temp directory
print("Downloading dataset with videos to temp directory...")
print(f"Temp directory: {temp_source_dir}")
dataset = LeRobotDataset(
repo_id=args.input_repo,
root=temp_source_dir, # Temporary location for source
episodes=episodes_to_load if args.percentage < 100 else None,
download_videos=True, # Download videos
)
print(f"Downloaded {dataset.num_episodes} episodes with {dataset.num_frames} frames")
# Create a new dataset with rewards
print(f"\nCreating new dataset at: {local_dir}")
# Clean up any existing directory from previous runs
if local_dir.exists():
print(f"⚠️ Directory already exists: {local_dir}")
print(" Removing it to start fresh...")
shutil.rmtree(local_dir)
# Define features including reward
# Simply copy all features from the original dataset
new_features = dict(dataset.features)
# Add reward feature
new_features[REWARD] = {"shape": (1,), "dtype": "float32", "names": ["reward"]}
# Determine which features are videos
video_keys = dataset.meta.video_keys if hasattr(dataset.meta, "video_keys") else []
image_keys = dataset.meta.image_keys if hasattr(dataset.meta, "image_keys") else []
visual_keys = set(video_keys + image_keys)
print(f" Visual features to be handled as videos: {visual_keys}")
# Check for language features
language_keys = [
k
for k in dataset.features.keys()
if any(lang in k.lower() for lang in ["language", "task", "instruction", "text"])
]
if language_keys:
print(f" Language/task features found: {language_keys}")
# Copy dataset structure to new location
new_dataset = LeRobotDataset.create(
repo_id=args.output_repo,
root=local_dir,
fps=dataset.fps,
features=new_features,
robot_type=dataset.meta.robot_type,
use_videos=len(dataset.meta.video_keys) > 0,
)
# Process each episode
print("\nAdding rewards to episodes...")
episode_data_index = dataset.episode_data_index
for ep_idx, episode_idx in enumerate(tqdm(episodes_to_load)):
# Get episode boundaries
ep_start = episode_data_index["from"][ep_idx].item()
ep_end = episode_data_index["to"][ep_idx].item()
episode_length = ep_end - ep_start
# Compute linear progress rewards for this episode
rewards = compute_linear_progress_reward(episode_length)
# Get episode metadata
episode_info = dataset.meta.episodes[episode_idx]
tasks = episode_info.get("tasks", [])
if not tasks:
# Try to get task from first frame if not in episode metadata
first_frame = dataset[ep_start]
if "task" in first_frame:
tasks = [first_frame["task"]]
else:
tasks = [""]
# Process each frame in the episode
for frame_idx in range(episode_length):
global_idx = ep_start + frame_idx
# Get original frame data
frame_data = dataset[global_idx]
# Create frame dict for the new dataset
frame = {}
for key in dataset.features:
# Skip only auto-generated metadata fields
# Keep task-related fields that contain language annotations
if key in ["index", "episode_index", "frame_index", "timestamp"]:
continue
# For visual features that are videos, extract the actual frame
if key in visual_keys:
# Get the image data to save as temporary files
if key in frame_data:
img = frame_data[key]
# Convert to numpy if tensor
if isinstance(img, torch.Tensor):
img = img.cpu().numpy()
# Ensure channels-last format (H, W, C) for saving
if len(img.shape) == 3 and img.shape[0] in [1, 3, 4]:
img = np.transpose(img, (1, 2, 0))
# Resize to match expected shape if needed
expected_shape = new_features[key].get("shape")
if expected_shape and img.shape != tuple(expected_shape):
# Try to match the shape - handle both HWC and CHW formats
if len(expected_shape) == 3:
# Determine if expected is HWC or CHW
if expected_shape[-1] in [1, 3, 4]: # Likely HWC
target_h, target_w = expected_shape[0], expected_shape[1]
elif expected_shape[0] in [
1,
3,
4,
]: # Likely CHW - shouldn't happen after transpose
target_h, target_w = expected_shape[1], expected_shape[2]
else:
# Assume HWC
target_h, target_w = expected_shape[0], expected_shape[1]
# Resize using PIL for quality
if img.dtype != np.uint8:
img = (img * 255).astype(np.uint8)
pil_img = Image.fromarray(img)
pil_img = pil_img.resize((target_w, target_h), Image.Resampling.LANCZOS)
img = np.array(pil_img)
frame[key] = img
continue
if key in frame_data:
value = frame_data[key]
# Handle language/task fields specially
if key == "task" and isinstance(value, str):
# Skip string task - will be passed separately to add_frame
continue
elif key == "task_index":
# Skip task_index as it will be regenerated based on task
continue
elif key in ["observation.language", "language", "instruction"] and isinstance(
value, str
):
# Keep language fields as-is
frame[key] = value
continue
# Regular field processing
# Convert tensors to numpy for saving
if isinstance(value, torch.Tensor):
value = value.cpu().numpy()
# Ensure arrays are the right shape
if hasattr(value, "shape") and len(value.shape) == 0:
# Convert scalar to 1D array
value = np.array([value])
frame[key] = value
# Add reward
frame[REWARD] = np.array([rewards[frame_idx]], dtype=np.float32)
# Get task for this specific frame (might vary within episode)
if "task" in frame_data:
task = frame_data["task"]
else:
task = tasks[0] if tasks else ""
# Add frame to new dataset
timestamp = frame_idx / dataset.fps
new_dataset.add_frame(frame, task=task, timestamp=timestamp)
# Save the episode (this will encode videos from the saved frames)
new_dataset.save_episode()
print(
f"\n✓ Created new dataset with rewards: {new_dataset.num_episodes} episodes, {new_dataset.num_frames} frames"
)
# Push to Hub
print(f"\nPushing to Hub: {args.output_repo}")
new_dataset.push_to_hub(
private=args.private,
push_videos=True,
)
print(f"\n✓ Dataset pushed to: https://huggingface.co/datasets/{args.output_repo}")
# Clean up temporary source directory
if temp_source_dir.exists():
print("\nCleaning up temporary files...")
shutil.rmtree(temp_source_dir)
# Print summary
print("\n=== Summary ===")
print(f"Input dataset: {args.input_repo}")
print(f"Output dataset: {args.output_repo}")
print(f"Episodes processed: {num_episodes_to_process}/{total_episodes} ({args.percentage}%)")
print(f"Frames with rewards: {new_dataset.num_frames}")
print("Reward type: Linear progress (0→1)")
print("===============")
if __name__ == "__main__":
main()
@@ -0,0 +1,591 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
OPTIMIZED VERSION: Add ReWiND-style linear progress rewards to existing LeRobot datasets with parallel processing.
This script creates a complete copy of the dataset with rewards added to each frame.
It downloads the original dataset (including videos), adds rewards, and pushes everything to a new repository.
Key optimizations:
- Parallel episode processing using multiprocessing
- Batch frame processing within episodes
- Concurrent video encoding
- Optimized image operations
- Better memory management
Usage:
# Test with 1% of episodes using 4 workers
python src/lerobot/scripts/annotate_dataset_rewards_optimized.py --input-repo IPEC-COMMUNITY/bc_z_lerobot --output-repo pepijn223/rewards_bc_z_1p --percentage 1 --num-workers 4
"""
import argparse
import logging
import shutil
from concurrent.futures import ThreadPoolExecutor, as_completed
from multiprocessing import Pool, cpu_count
from pathlib import Path
from tempfile import mkdtemp
from typing import Any
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from lerobot.constants import REWARD
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
# Set up logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
def compute_linear_progress_reward(episode_length: int) -> np.ndarray:
"""
Compute linear progress rewards from 0 to 1.
ReWiND-style: progress increases linearly from 0 at start to 1 at completion.
Args:
episode_length: Number of frames in the episode
Returns:
rewards: Array of shape (episode_length,) with values linearly from 0 to 1
"""
return np.linspace(0, 1, episode_length, dtype=np.float32)
def process_image_batch(images: list[np.ndarray], target_shape: tuple[int, ...]) -> list[np.ndarray]:
"""
Process a batch of images efficiently.
Args:
images: List of numpy arrays representing images
target_shape: Target shape for resizing
Returns:
List of processed images
"""
processed = []
if len(target_shape) == 3:
# Determine target dimensions
if target_shape[-1] in [1, 3, 4]: # Likely HWC
target_h, target_w = target_shape[0], target_shape[1]
elif target_shape[0] in [1, 3, 4]: # Likely CHW
target_h, target_w = target_shape[1], target_shape[2]
else:
target_h, target_w = target_shape[0], target_shape[1]
# Process all images
for img in images:
# Ensure channels-last format
if len(img.shape) == 3 and img.shape[0] in [1, 3, 4]:
img = np.transpose(img, (1, 2, 0))
# Resize if needed
if img.shape[:2] != (target_h, target_w):
if img.dtype != np.uint8:
img = (img * 255).astype(np.uint8)
pil_img = Image.fromarray(img)
pil_img = pil_img.resize((target_w, target_h), Image.Resampling.LANCZOS)
img = np.array(pil_img)
processed.append(img)
else:
processed = images
return processed
def process_episode_chunk(args: tuple[int, int, dict, Any]) -> tuple[int, list[dict], list[str]]:
"""
Process a chunk of frames from an episode in parallel.
Args:
args: Tuple of (chunk_start, chunk_end, shared_data, episode_data)
Returns:
Tuple of (episode_idx, frames_data, tasks)
"""
chunk_start, chunk_end, shared_data, episode_data = args
episode_idx = episode_data["episode_idx"]
ep_start = episode_data["ep_start"]
episode_length = episode_data["episode_length"]
rewards = episode_data["rewards"]
tasks_default = episode_data["tasks"]
dataset = episode_data["dataset"]
new_features = shared_data["new_features"]
visual_keys = shared_data["visual_keys"]
fps = shared_data["fps"]
frames_data = []
tasks = []
# Process chunk of frames
for frame_idx in range(chunk_start, min(chunk_end, episode_length)):
global_idx = ep_start + frame_idx
# Get original frame data
frame_data = dataset[global_idx]
# Create frame dict for the new dataset
frame = {}
# Process all non-visual features
for key in dataset.features:
if key in ["index", "episode_index", "frame_index", "timestamp"]:
continue
if key in visual_keys:
# Process visual features
if key in frame_data:
img = frame_data[key]
if isinstance(img, torch.Tensor):
img = img.cpu().numpy()
frame[key] = img
continue
if key in frame_data:
value = frame_data[key]
# Handle special fields
if key == "task" and isinstance(value, str):
tasks.append(value)
continue
elif key == "task_index":
continue
elif key in ["observation.language", "language", "instruction"] and isinstance(value, str):
frame[key] = value
continue
# Regular field processing
if isinstance(value, torch.Tensor):
value = value.cpu().numpy()
if hasattr(value, "shape") and len(value.shape) == 0:
value = np.array([value])
frame[key] = value
# Add reward
frame[REWARD] = np.array([rewards[frame_idx]], dtype=np.float32)
# Set task
if not tasks or tasks[-1] is None:
tasks.append(tasks_default[0] if tasks_default else "")
# Add timestamp
frame["timestamp"] = frame_idx / fps
frames_data.append(frame)
return (episode_idx, frames_data, tasks)
def process_episode_parallel(
episode_data: dict, shared_data: dict, chunk_size: int = 50
) -> tuple[int, list[dict], list[str]]:
"""
Process an entire episode using parallel chunk processing.
Args:
episode_data: Episode-specific data
shared_data: Shared configuration data
chunk_size: Number of frames to process per chunk
Returns:
Tuple of (episode_idx, all_frames, all_tasks)
"""
episode_length = episode_data["episode_length"]
episode_idx = episode_data["episode_idx"]
# Create chunks
chunks = []
for i in range(0, episode_length, chunk_size):
chunk_end = min(i + chunk_size, episode_length)
chunks.append((i, chunk_end, shared_data, episode_data))
# Process chunks in parallel using threads (good for I/O bound operations)
all_frames = [None] * episode_length
all_tasks = []
with ThreadPoolExecutor(max_workers=4) as executor:
futures = {executor.submit(process_episode_chunk, chunk): idx for idx, chunk in enumerate(chunks)}
for future in as_completed(futures):
chunk_idx = futures[future]
_, frames, tasks = future.result()
# Place frames in correct positions
start_idx = chunks[chunk_idx][0]
for i, frame in enumerate(frames):
all_frames[start_idx + i] = frame
all_tasks.extend(tasks)
# Filter out None values (shouldn't happen but safety check)
all_frames = [f for f in all_frames if f is not None]
return (episode_idx, all_frames, all_tasks)
def worker_process_episode(args: tuple[int, str, str, dict, str, str, bool]) -> dict:
"""
Worker function to process a single episode.
Args:
args: Tuple containing (episode_idx, input_repo, output_repo, shared_data, local_dir, temp_dir, use_chunk_processing)
Returns:
Dict with processing results or error
"""
episode_idx, input_repo, output_repo, shared_data, local_dir_str, temp_dir, use_chunk_processing = args
try:
local_dir = Path(local_dir_str)
# Load dataset for this worker
dataset = LeRobotDataset(
repo_id=input_repo,
root=Path(temp_dir),
episodes=[episode_idx],
download_videos=True,
)
# Get episode boundaries
episode_data_index = dataset.episode_data_index
ep_start = episode_data_index["from"][0].item()
ep_end = episode_data_index["to"][0].item()
episode_length = ep_end - ep_start
# Compute rewards
rewards = compute_linear_progress_reward(episode_length)
# Get episode metadata
episode_info = dataset.meta.episodes[episode_idx]
tasks = episode_info.get("tasks", [])
if not tasks:
first_frame = dataset[ep_start]
if "task" in first_frame:
tasks = [first_frame["task"]]
else:
tasks = [""]
# Prepare episode data
episode_data = {
"episode_idx": episode_idx,
"ep_start": ep_start,
"episode_length": episode_length,
"rewards": rewards,
"tasks": tasks,
"dataset": dataset,
}
if use_chunk_processing:
# Process episode with chunk parallelization
_, frames_data, frame_tasks = process_episode_parallel(episode_data, shared_data)
else:
# Process episode sequentially (fallback)
frames_data = []
frame_tasks = []
for frame_idx in range(episode_length):
global_idx = ep_start + frame_idx
frame_data = dataset[global_idx]
frame = {}
for key in dataset.features:
if key in ["index", "episode_index", "frame_index", "timestamp"]:
continue
if key in shared_data["visual_keys"]:
if key in frame_data:
img = frame_data[key]
if isinstance(img, torch.Tensor):
img = img.cpu().numpy()
# Process image if needed
if (
key in shared_data["new_features"]
and "shape" in shared_data["new_features"][key]
):
expected_shape = shared_data["new_features"][key]["shape"]
img = process_image_batch([img], expected_shape)[0]
frame[key] = img
continue
if key in frame_data:
value = frame_data[key]
if key == "task" and isinstance(value, str):
frame_tasks.append(value)
continue
elif key == "task_index":
continue
if isinstance(value, torch.Tensor):
value = value.cpu().numpy()
if hasattr(value, "shape") and len(value.shape) == 0:
value = np.array([value])
frame[key] = value
frame[REWARD] = np.array([rewards[frame_idx]], dtype=np.float32)
frames_data.append(frame)
if not frame_tasks or len(frame_tasks) <= frame_idx:
frame_tasks.append(tasks[0] if tasks else "")
return {
"episode_idx": episode_idx,
"frames_data": frames_data,
"tasks": frame_tasks if frame_tasks else tasks,
"fps": dataset.fps,
"success": True,
}
except Exception as e:
logger.error(f"Error processing episode {episode_idx}: {e}")
return {"episode_idx": episode_idx, "error": str(e), "success": False}
def main():
parser = argparse.ArgumentParser(
description="Optimized: Add linear progress rewards to LeRobot dataset with parallel processing"
)
parser.add_argument(
"--input-repo",
type=str,
default="IPEC-COMMUNITY/bc_z_lerobot",
help="Input dataset repository on HuggingFace Hub",
)
parser.add_argument(
"--output-repo",
type=str,
required=True,
help="Output dataset repository name (e.g., username/dataset_with_rewards)",
)
parser.add_argument(
"--percentage",
type=float,
default=100.0,
help="Percentage of episodes to process (useful for testing, e.g., 1 for 1%%)",
)
parser.add_argument(
"--num-workers",
type=int,
default=None,
help="Number of parallel workers (defaults to CPU count - 2)",
)
parser.add_argument(
"--chunk-size",
type=int,
default=50,
help="Number of frames to process per chunk within an episode",
)
parser.add_argument(
"--private",
action="store_true",
help="Make the output repository private",
)
parser.add_argument(
"--local-dir",
type=str,
default=None,
help="Local directory to save the modified dataset",
)
parser.add_argument(
"--no-chunk-processing",
action="store_true",
help="Disable chunk-based parallel processing within episodes",
)
args = parser.parse_args()
# Determine number of workers
if args.num_workers is None:
args.num_workers = max(1, cpu_count() - 2)
print("=" * 60)
print("OPTIMIZED DATASET COPY WITH REWARDS")
print(f"Using {args.num_workers} parallel workers")
print("=" * 60)
# Load metadata
print(f"\nLoading metadata from Hub: {args.input_repo}")
metadata = LeRobotDatasetMetadata(repo_id=args.input_repo)
total_episodes = metadata.total_episodes
# Calculate episodes to process
num_episodes_to_process = max(1, int(total_episodes * args.percentage / 100))
episodes_to_load = list(range(num_episodes_to_process))
print(f"Dataset has {total_episodes} episodes")
print(f"Processing {num_episodes_to_process} episodes ({args.percentage}%)")
# Determine local directory
if args.local_dir:
local_dir = Path(args.local_dir)
else:
from lerobot.constants import HF_LEROBOT_HOME
local_dir = HF_LEROBOT_HOME / args.output_repo
# Create temporary directories for workers
temp_base_dir = Path(mkdtemp(prefix="lerobot_parallel_"))
worker_temp_dirs = []
for i in range(args.num_workers):
worker_dir = temp_base_dir / f"worker_{i}"
worker_dir.mkdir(parents=True, exist_ok=True)
worker_temp_dirs.append(str(worker_dir))
print(f"Using temporary base directory: {temp_base_dir}")
# Load first episode to get features and structure
print("\nLoading dataset structure...")
sample_dataset = LeRobotDataset(
repo_id=args.input_repo,
root=temp_base_dir / "sample",
episodes=[0],
download_videos=True,
)
# Prepare features with reward
new_features = dict(sample_dataset.features)
new_features[REWARD] = {"shape": (1,), "dtype": "float32", "names": ["reward"]}
# Determine visual keys
video_keys = sample_dataset.meta.video_keys if hasattr(sample_dataset.meta, "video_keys") else []
image_keys = sample_dataset.meta.image_keys if hasattr(sample_dataset.meta, "image_keys") else []
visual_keys = set(video_keys + image_keys)
print(f" Visual features: {visual_keys}")
# Clean up existing directory
if local_dir.exists():
print(f"⚠️ Directory already exists: {local_dir}")
print(" Removing it to start fresh...")
shutil.rmtree(local_dir)
# Create new dataset structure
print("\nCreating new dataset structure...")
new_dataset = LeRobotDataset.create(
repo_id=args.output_repo,
root=local_dir,
fps=sample_dataset.fps,
features=new_features,
robot_type=sample_dataset.meta.robot_type,
use_videos=len(sample_dataset.meta.video_keys) > 0,
)
# Prepare shared data for workers
shared_data = {
"new_features": new_features,
"visual_keys": visual_keys,
"fps": sample_dataset.fps,
}
# Process episodes in parallel
print(f"\nProcessing {num_episodes_to_process} episodes with {args.num_workers} workers...")
# Prepare worker arguments
worker_args = []
for i, episode_idx in enumerate(episodes_to_load):
# Assign worker temp directory round-robin
temp_dir = worker_temp_dirs[i % args.num_workers]
worker_args.append(
(
episode_idx,
args.input_repo,
args.output_repo,
shared_data,
str(local_dir),
temp_dir,
not args.no_chunk_processing,
)
)
# Process episodes using multiprocessing
processed_episodes = {}
failed_episodes = []
with Pool(processes=args.num_workers) as pool:
# Use imap_unordered for better progress tracking
with tqdm(total=num_episodes_to_process, desc="Processing episodes") as pbar:
for result in pool.imap_unordered(worker_process_episode, worker_args):
pbar.update(1)
if result["success"]:
processed_episodes[result["episode_idx"]] = result
else:
failed_episodes.append(result["episode_idx"])
logger.error(
f"Failed episode {result['episode_idx']}: {result.get('error', 'Unknown error')}"
)
# Add processed episodes to the new dataset in order
print("\nSaving processed episodes to new dataset...")
for episode_idx in tqdm(episodes_to_load, desc="Saving episodes"):
if episode_idx in processed_episodes:
result = processed_episodes[episode_idx]
# Add all frames for this episode
for i, frame_data in enumerate(result["frames_data"]):
task = result["tasks"][i] if i < len(result["tasks"]) else result["tasks"][0]
timestamp = i / result["fps"]
new_dataset.add_frame(frame_data, task=task, timestamp=timestamp)
# Save the episode
new_dataset.save_episode()
print(
f"\n✓ Created new dataset with rewards: {new_dataset.num_episodes} episodes, {new_dataset.num_frames} frames"
)
if failed_episodes:
print(f"⚠️ Failed to process {len(failed_episodes)} episodes: {failed_episodes}")
# Push to Hub
print(f"\nPushing to Hub: {args.output_repo}")
new_dataset.push_to_hub(
private=args.private,
push_videos=True,
)
print(f"\n✓ Dataset pushed to: https://huggingface.co/datasets/{args.output_repo}")
# Clean up temporary directories
if temp_base_dir.exists():
print("\nCleaning up temporary files...")
shutil.rmtree(temp_base_dir)
# Print summary
print("\n=== Summary ===")
print(f"Input dataset: {args.input_repo}")
print(f"Output dataset: {args.output_repo}")
print(f"Episodes processed: {num_episodes_to_process - len(failed_episodes)}/{total_episodes}")
print(f"Frames with rewards: {new_dataset.num_frames}")
print(f"Parallel workers used: {args.num_workers}")
print(f"Processing time saved: ~{args.num_workers - 1}x faster")
print("===============")
if __name__ == "__main__":
main()
+188
View File
@@ -0,0 +1,188 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test script for RLearN evaluation metrics.
This script tests the VOC-S and success/failure detection metrics with synthetic data
to ensure they work correctly before running on real datasets.
"""
import numpy as np
from lerobot.policies.rlearn.evaluation import (
compute_success_failure_detection,
compute_voc_s,
generate_mismatched_languages,
)
def test_voc_s():
"""Test VOC-S computation with synthetic data."""
print("Testing VOC-S computation...")
# Test case 1: Perfect positive correlation (0 -> 1)
perfect_positive = [np.linspace(0, 1, 20) for _ in range(10)]
results = compute_voc_s(perfect_positive)
print("Perfect positive correlation:")
print(f" Mean: {results['voc_s_mean']:.4f} (should be ~1.0)")
print(f" IQM: {results['voc_s_iqm']:.4f} (should be ~1.0)")
assert results["voc_s_mean"] > 0.95, f"Expected >0.95, got {results['voc_s_mean']}"
# Test case 2: Perfect negative correlation (1 -> 0)
perfect_negative = [np.linspace(1, 0, 20) for _ in range(10)]
results = compute_voc_s(perfect_negative)
print("Perfect negative correlation:")
print(f" Mean: {results['voc_s_mean']:.4f} (should be ~-1.0)")
print(f" IQM: {results['voc_s_iqm']:.4f} (should be ~-1.0)")
assert results["voc_s_mean"] < -0.95, f"Expected <-0.95, got {results['voc_s_mean']}"
# Test case 3: No correlation (random)
np.random.seed(42)
random_rewards = [np.random.random(20) for _ in range(50)]
results = compute_voc_s(random_rewards)
print("Random correlation:")
print(f" Mean: {results['voc_s_mean']:.4f} (should be ~0.0)")
print(f" IQM: {results['voc_s_iqm']:.4f} (should be ~0.0)")
assert abs(results["voc_s_mean"]) < 0.3, f"Expected ~0, got {results['voc_s_mean']}"
# Test case 4: Mixed correlations
mixed = []
mixed.extend([np.linspace(0, 1, 15) for _ in range(5)]) # Positive
mixed.extend([np.linspace(1, 0, 15) for _ in range(5)]) # Negative
mixed.extend([np.random.random(15) for _ in range(5)]) # Random
results = compute_voc_s(mixed)
print("Mixed correlations:")
print(f" Mean: {results['voc_s_mean']:.4f}")
print(f" IQM: {results['voc_s_iqm']:.4f}")
print(f" Std: {results['voc_s_std']:.4f}")
print("✓ VOC-S tests passed!\n")
def test_success_failure_detection():
"""Test success/failure detection with synthetic data."""
print("Testing Success/Failure Detection...")
# Test case 1: Clear separation (correct > incorrect)
correct_rewards = [np.linspace(0, 1, 20) for _ in range(20)] # Always increasing
incorrect_rewards = [np.linspace(0, 0.3, 20) for _ in range(20)] # Lower final values
results = compute_success_failure_detection(correct_rewards, incorrect_rewards)
print("Clear separation test:")
print(f" Detection accuracy: {results['detection_accuracy']:.4f} (should be 1.0)")
print(f" Mean correct: {results['mean_correct_final']:.4f}")
print(f" Mean incorrect: {results['mean_incorrect_final']:.4f}")
print(f" Separation score: {results['separation_score']:.4f}")
assert results["detection_accuracy"] == 1.0, f"Expected 1.0, got {results['detection_accuracy']}"
# Test case 2: No separation (same distributions with some randomness)
np.random.seed(42)
same_rewards_1 = [np.random.normal(0.5, 0.05, 15) for _ in range(20)]
same_rewards_2 = [np.random.normal(0.5, 0.05, 15) for _ in range(20)]
results = compute_success_failure_detection(same_rewards_1, same_rewards_2)
print("No separation test:")
print(f" Detection accuracy: {results['detection_accuracy']:.4f} (should be ~0.5)")
print(f" Separation score: {results['separation_score']:.4f} (should be ~0.0)")
# Relax the assertion since random data can vary
assert 0.2 <= results["detection_accuracy"] <= 0.8, (
f"Expected ~0.5 (±0.3), got {results['detection_accuracy']}"
)
# Test case 3: Partial separation
np.random.seed(42)
partial_correct = [np.random.normal(0.7, 0.1, 10) for _ in range(20)]
partial_incorrect = [np.random.normal(0.4, 0.1, 10) for _ in range(20)]
results = compute_success_failure_detection(partial_correct, partial_incorrect)
print("Partial separation test:")
print(f" Detection accuracy: {results['detection_accuracy']:.4f}")
print(f" Separation score: {results['separation_score']:.4f}")
print("✓ Success/Failure Detection tests passed!\n")
def test_mismatch_generation():
"""Test mismatch language generation."""
print("Testing mismatch language generation...")
original_languages = [
"pick up the red ball",
"put the cup on the table",
"open the drawer",
"close the door",
]
# Test with default templates
mismatched = generate_mismatched_languages(original_languages)
print(f"Original languages: {len(original_languages)}")
print(f"Mismatched languages: {len(mismatched)}")
assert len(mismatched) == len(original_languages)
# Ensure they're actually different
for orig, mismatch in zip(original_languages, mismatched, strict=False):
print(f" '{orig}' -> '{mismatch}'")
assert orig != mismatch, "Mismatch should be different from original"
# Test with custom templates
custom_templates = ["dance", "sing", "jump"]
mismatched_custom = generate_mismatched_languages(original_languages, custom_templates)
print("\nWith custom templates:")
for orig, mismatch in zip(original_languages, mismatched_custom, strict=False):
print(f" '{orig}' -> '{mismatch}'")
assert mismatch in custom_templates
print("✓ Mismatch generation tests passed!\n")
def test_edge_cases():
"""Test edge cases and error handling."""
print("Testing edge cases...")
# Empty input
empty_results = compute_voc_s([])
assert empty_results["num_episodes"] == 0
assert empty_results["voc_s_mean"] == 0.0
# Single frame episodes (should be skipped)
single_frame = [np.array([0.5]) for _ in range(5)]
results = compute_voc_s(single_frame)
assert results["num_episodes"] == 0, "Single-frame episodes should be skipped"
# Constant rewards (should give correlation = 0)
constant_rewards = [np.ones(10) * 0.5 for _ in range(5)]
results = compute_voc_s(constant_rewards)
print(f"Constant rewards correlation: {results['voc_s_mean']:.4f} (should be 0.0)")
assert results["voc_s_mean"] == 0.0
# Mismatched array lengths for detection
try:
compute_success_failure_detection([np.array([1, 2])], [])
assert False, "Should have raised ValueError"
except ValueError:
pass # Expected
print("✓ Edge case tests passed!\n")
+237
View File
@@ -0,0 +1,237 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import OBS_IMAGES, OBS_LANGUAGE, REWARD
from lerobot.policies.factory import make_processor
from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
from lerobot.policies.rlearn.modeling_rlearn import RLearNPolicy
from tests.utils import require_package
@require_package("transformers")
def test_rlearn_instantiation_and_forward_tensor_batch():
"""Instantiate RLearN and run a forward pass with a (B, T, C, H, W) tensor input using a real model and real text."""
cfg = RLearNConfig(
model_name="google/siglip2-large-patch16-256",
push_to_hub=False,
freeze_backbones=True,
)
cfg.input_features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
cfg.output_features = {
REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
}
policy = RLearNPolicy(cfg)
B, T, C, H, W = 2, 3, 3, 256, 256
batch = {
OBS_IMAGES: torch.rand(B, T, C, H, W),
REWARD: torch.randint(low=0, high=1, size=(B, T)).float(),
OBS_LANGUAGE: ["move the green cube into the box" for _ in range(B)],
}
loss, logs = policy.forward(batch)
assert isinstance(loss, torch.Tensor)
assert "loss" in logs
@require_package("transformers")
def test_rlearn_instantiation_and_forward_list_batch_with_language():
"""Instantiate RLearN and run a forward pass with a list-of-frames input and real language using a real model."""
cfg = RLearNConfig(
model_name="google/siglip2-large-patch16-256",
push_to_hub=False,
freeze_backbones=True,
)
cfg.input_features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
cfg.output_features = {
REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
}
policy = RLearNPolicy(cfg)
B, T, C, H, W = 2, 4, 3, 256, 256
frames = [torch.rand(B, C, H, W) for _ in range(T)]
batch = {
OBS_IMAGES: frames, # list[(B, C, H, W)]
REWARD: torch.randint(low=0, high=2, size=(B, T)).float(),
OBS_LANGUAGE: ["move the red cube into the box" for _ in range(B)],
}
loss, logs = policy.forward(batch)
assert isinstance(loss, torch.Tensor)
assert "loss" in logs
@require_package("transformers")
def test_rlearn_composite_loss_shapes_and_terms():
"""Smoke test composite loss: checks presence of terms and valid gradients."""
cfg = RLearNConfig(
model_name="google/siglip2-large-patch16-256",
push_to_hub=False,
freeze_backbones=True,
loss_type="composite",
lambda_prog=1.0,
lambda_spatial_nce=0.5,
lambda_rewind=0.4,
num_ranking_pairs=32, # Fewer pairs for testing
last_k_for_nce=2,
)
cfg.input_features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
cfg.output_features = {
REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
}
policy = RLearNPolicy(cfg)
B, T, C, H, W = 2, 3, 3, 256, 256
# Progress labels y in [0,1]
y = torch.linspace(0, 1, T).unsqueeze(0).repeat(B, 1)
batch = {
OBS_IMAGES: torch.rand(B, T, C, H, W),
REWARD: y.clone(),
OBS_LANGUAGE: ["stack the blocks" for _ in range(B)],
}
loss, logs = policy.forward(batch)
assert isinstance(loss, torch.Tensor) and torch.isfinite(loss)
# Expect composite terms present with spatial awareness and ReWiND
assert "loss_prog" in logs
assert "loss_spatial_nce" in logs
assert "loss_rewind_forward" in logs
assert "loss_rewind_reverse" in logs
@require_package("transformers")
def test_rlearn_preprocessor_tokenizes_and_copies_task():
cfg = RLearNConfig(
model_name="google/siglip2-large-patch16-256",
device="cpu",
push_to_hub=False,
)
cfg.input_features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
}
cfg.output_features = {
REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
}
pre, post = make_processor(cfg, dataset_stats=None)
B, C, H, W = 2, 3, 64, 64
batch = {
"observation.image": torch.rand(B, C, H, W),
REWARD: torch.zeros(B),
"task": ["pick the cube", "place it in the box"],
}
processed = pre(batch)
assert isinstance(processed, dict)
assert f"{OBS_LANGUAGE}.tokens" in processed
assert f"{OBS_LANGUAGE}.attention_mask" in processed
assert OBS_LANGUAGE in processed
tokens = processed[f"{OBS_LANGUAGE}.tokens"]
attn = processed[f"{OBS_LANGUAGE}.attention_mask"]
assert tokens.dim() == 2 and attn.dim() == 2
assert tokens.shape[0] == B and attn.shape[0] == B
@require_package("transformers")
def test_rlearn_preprocessor_string_task_and_to_batch():
cfg = RLearNConfig(
model_name="google/siglip2-large-patch16-256",
device="cpu",
push_to_hub=False,
)
cfg.input_features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
}
cfg.output_features = {
REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
}
pre, post = make_processor(cfg, dataset_stats=None)
# Unbatched image and single string task
batch = {
"observation.image": torch.rand(3, 64, 64),
REWARD: torch.tensor(0.0),
"task": "move the green cube into the box",
}
processed = pre(batch)
# Image should have batch dim now
assert processed["observation.image"].dim() == 4 and processed["observation.image"].shape[0] == 1
# Language copy and tokenization should exist
assert OBS_LANGUAGE in processed and isinstance(processed[OBS_LANGUAGE], list)
assert f"{OBS_LANGUAGE}.tokens" in processed
assert f"{OBS_LANGUAGE}.attention_mask" in processed
@require_package("transformers")
def test_rlearn_pipeline_end_to_end_forward():
"""End-to-end: preprocessor + model forward using RLearN pipeline on synthetic data."""
cfg = RLearNConfig(
model_name="google/siglip2-large-patch16-256",
device="cpu",
push_to_hub=False,
freeze_backbones=True,
loss_type="composite",
)
cfg.input_features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
cfg.output_features = {
REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
}
# Build processors and model
pre, post = make_processor(cfg, dataset_stats=None)
policy = RLearNPolicy(cfg)
B, T, C, H, W = 2, 3, 3, 256, 256
y = torch.linspace(0, 1, T).unsqueeze(0).repeat(B, 1)
raw = {
# Provide as observation.image to let preprocessor map/normalize and batch
"observation.image": torch.rand(B, C, H, W), # not time-major to test ToBatch
REWARD: y[:, :1].clone(), # single step label; pipeline keeps structure
"task": ["insert the peg", "insert the peg"],
}
processed = pre(raw)
# Integrate preprocessor output with model forward
loss, logs = policy.forward(
{
OBS_IMAGES: processed.get(OBS_IMAGES, processed.get("observation.image"))
.unsqueeze(1)
.repeat(1, T, 1, 1, 1),
REWARD: y.clone(),
OBS_LANGUAGE: processed[OBS_LANGUAGE],
}
)
assert isinstance(loss, torch.Tensor) and torch.isfinite(loss)