mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
initial commit
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -34,6 +34,7 @@ from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
@@ -80,6 +81,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
return SmolVLAPolicy
|
||||
elif name == "rlearn":
|
||||
from lerobot.policies.rlearn.modeling_rlearn import RLearNPolicy
|
||||
|
||||
return RLearNPolicy
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
@@ -103,6 +108,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return SmolVLAConfig(**kwargs)
|
||||
elif policy_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif policy_type == "rlearn":
|
||||
return RLearNConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
@@ -220,6 +227,13 @@ def make_processor(
|
||||
cast(SmolVLAConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
|
||||
)
|
||||
|
||||
elif policy_cfg.type == "rlearn":
|
||||
from lerobot.policies.rlearn.processor_rlearn import make_rlearn_processor
|
||||
|
||||
processors = make_rlearn_processor(
|
||||
cast(RLearNConfig, policy_cfg), dataset_stats=kwargs.get("dataset_stats")
|
||||
)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
|
||||
|
||||
|
||||
@@ -0,0 +1,127 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("rlearn")
|
||||
@dataclass
|
||||
class RLearNConfig(PreTrainedConfig):
|
||||
"""Configuration for a video-language conditioned reward model (RLearN).
|
||||
|
||||
Inputs:
|
||||
- Visual frames (one or multiple cameras). Optionally a short sequence.
|
||||
- A language instruction/goal string.
|
||||
|
||||
Output:
|
||||
- Per-timestep reward logits or a single-step reward logit.
|
||||
|
||||
Notes:
|
||||
- This is the initial architecture. It uses frozen vision/text encoders
|
||||
(e.g. SigLIP2) and trains a lightweight temporal aggregator + head.
|
||||
"""
|
||||
|
||||
# Encoders
|
||||
model_name: str = "google/siglip2-large-patch16-256"
|
||||
freeze_backbones: bool = True
|
||||
|
||||
# Temporal aggregator
|
||||
dim_model: int = 512
|
||||
n_heads: int = 8
|
||||
n_layers: int = 4
|
||||
dim_feedforward: int = 2048
|
||||
dropout: float = 0.1
|
||||
pre_norm: bool = True
|
||||
use_first_frame_positional_bias: bool = True
|
||||
frame_dropout_p: float = 0.0
|
||||
stride: int = 1
|
||||
|
||||
# Sequence length, amount of past frames including current one to use in the temporal model
|
||||
max_seq_len: int = 16
|
||||
|
||||
# Head
|
||||
use_tanh_head: bool = False # when True, bound outputs in [-1, 1]
|
||||
|
||||
# Training
|
||||
learning_rate: float = 5e-5 # Reduced for stability
|
||||
weight_decay: float = 0.01
|
||||
loss_type: str = "composite" # Always use composite loss with spatial awareness
|
||||
ranking_margin: float = 0.1
|
||||
|
||||
# Composite loss weights (with spatial awareness and ReWiND reversibility)
|
||||
lambda_prog: float = 1.0 # Progress regression weight
|
||||
lambda_spatial_nce: float = 0.5 # Spatial-aware InfoNCE weight
|
||||
lambda_rewind: float = 0.4 # ReWiND reversible ranking weight
|
||||
|
||||
# Loss hyperparameters
|
||||
nce_temperature: float = 0.07 # Temperature for InfoNCE
|
||||
zscore_eps: float = 1e-5 # Epsilon for z-score normalization
|
||||
min_rank_gap: int = 1 # Minimum gap for temporal ranking pairs
|
||||
num_ranking_pairs: int = 64 # Number of (far, near) pairs to sample for ReWiND
|
||||
last_k_for_nce: int = 3 # Use last k frames for InfoNCE
|
||||
mismatch_lang_prob: float = 0.2 # Probability of language mismatch augmentation
|
||||
|
||||
# Value-based pairwise loss hyperparameters (for value_pairwise mode)
|
||||
lambda_dir: float = 1.0 # Intra-trajectory directional ranking
|
||||
lambda_text: float = 0.5 # Inter-instruction contrastive ranking
|
||||
lambda_flat: float = 0.25 # Flatness under mismatch
|
||||
dir_margin: float = 0.2 # Margin for directional ranking
|
||||
text_margin: float = 0.2 # Margin for text contrastive ranking
|
||||
flat_epsilon: float = 0.05 # Epsilon band for flatness loss
|
||||
num_pairs_per_loss: int = 64 # Number of pairs to sample per loss term
|
||||
use_hard_negatives: bool = True # Whether to generate hard negative instructions
|
||||
|
||||
# Normalization presets
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
# Language is tokenized at the encoder level; no numeric normalization here.
|
||||
}
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
# Require at least one image feature. Language is recommended but optional (can be blank).
|
||||
if not self.image_features:
|
||||
raise ValueError(
|
||||
"You must provide at least one image feature for RLearN (e.g. 'observation.image')."
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list | None:
|
||||
# Not using delta sampling from the dataset by default.
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list | None:
|
||||
# Not an action chunking policy.
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> list | None:
|
||||
# By default we supervise every provided timestep equally.
|
||||
return None
|
||||
|
||||
def get_optimizer_preset(self): # type: ignore[override]
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
|
||||
return AdamWConfig(lr=self.learning_rate, weight_decay=self.weight_decay)
|
||||
|
||||
def get_scheduler_preset(self): # type: ignore[override]
|
||||
# No scheduler by default.
|
||||
return None
|
||||
@@ -0,0 +1,601 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Evaluation metrics for RLearn (Video-Language Conditioned Reward Model).
|
||||
|
||||
Key metrics:
|
||||
1. VOC-S (Value-Order Correlation for Success): Spearman correlation between frame indices and predicted rewards
|
||||
2. Success vs Failure Detection: Model's ability to distinguish between correct and incorrect language conditions
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from scipy.stats import spearmanr
|
||||
from torch import Tensor
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.constants import OBS_IMAGES, OBS_LANGUAGE
|
||||
|
||||
|
||||
def compute_voc_s(
|
||||
predicted_rewards: list[np.ndarray], use_interquartile_mean: bool = True
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
Compute Value-Order Correlation for Success (VOC-S).
|
||||
|
||||
Measures whether per-frame rewards increase as successful execution unfolds.
|
||||
For each episode, computes Spearman correlation between frame indices [1..T]
|
||||
and predicted rewards [r1..rT].
|
||||
|
||||
Args:
|
||||
predicted_rewards: List of reward arrays, one per episode. Each array has shape (T,)
|
||||
use_interquartile_mean: If True, use IQM instead of mean for aggregation
|
||||
|
||||
Returns:
|
||||
Dictionary with VOC-S metrics:
|
||||
- voc_s_mean: Mean Spearman correlation across episodes
|
||||
- voc_s_std: Standard deviation of correlations
|
||||
- voc_s_iqm: Interquartile mean (if use_interquartile_mean=True)
|
||||
- num_episodes: Number of episodes evaluated
|
||||
- correlations: Individual correlations per episode
|
||||
"""
|
||||
if not predicted_rewards:
|
||||
return {"voc_s_mean": 0.0, "voc_s_std": 0.0, "voc_s_iqm": 0.0, "num_episodes": 0, "correlations": []}
|
||||
|
||||
correlations = []
|
||||
|
||||
for episode_rewards in predicted_rewards:
|
||||
if len(episode_rewards) < 2:
|
||||
# Need at least 2 points for correlation
|
||||
continue
|
||||
|
||||
# Frame indices: [1, 2, ..., T]
|
||||
frame_indices = np.arange(1, len(episode_rewards) + 1)
|
||||
|
||||
# Compute Spearman correlation
|
||||
try:
|
||||
correlation, p_value = spearmanr(frame_indices, episode_rewards)
|
||||
|
||||
# Handle NaN correlations (e.g., all rewards are identical)
|
||||
if np.isnan(correlation):
|
||||
correlation = 0.0
|
||||
|
||||
correlations.append(correlation)
|
||||
|
||||
except Exception as e:
|
||||
warnings.warn(f"Failed to compute correlation for episode: {e}")
|
||||
correlations.append(0.0)
|
||||
|
||||
if not correlations:
|
||||
return {"voc_s_mean": 0.0, "voc_s_std": 0.0, "voc_s_iqm": 0.0, "num_episodes": 0, "correlations": []}
|
||||
|
||||
correlations = np.array(correlations)
|
||||
|
||||
# Compute statistics
|
||||
voc_s_mean = float(np.mean(correlations))
|
||||
voc_s_std = float(np.std(correlations))
|
||||
|
||||
# Interquartile mean: mean of values between 25th and 75th percentiles
|
||||
if use_interquartile_mean and len(correlations) >= 4:
|
||||
q25, q75 = np.percentile(correlations, [25, 75])
|
||||
iqm_mask = (correlations >= q25) & (correlations <= q75)
|
||||
voc_s_iqm = float(np.mean(correlations[iqm_mask]))
|
||||
else:
|
||||
voc_s_iqm = voc_s_mean
|
||||
|
||||
return {
|
||||
"voc_s_mean": voc_s_mean,
|
||||
"voc_s_std": voc_s_std,
|
||||
"voc_s_iqm": voc_s_iqm,
|
||||
"num_episodes": len(correlations),
|
||||
"correlations": correlations.tolist(),
|
||||
}
|
||||
|
||||
|
||||
def compute_success_failure_detection(
|
||||
correct_rewards: list[np.ndarray], incorrect_rewards: list[np.ndarray], threshold_percentile: float = 50.0
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
Compute success vs failure detection accuracy.
|
||||
|
||||
Tests the model's ability to distinguish between correct and incorrect language conditions.
|
||||
For each episode, compares final reward under correct vs incorrect language instruction.
|
||||
|
||||
Args:
|
||||
correct_rewards: List of reward arrays for episodes with correct language
|
||||
incorrect_rewards: List of reward arrays for episodes with incorrect/mismatched language
|
||||
threshold_percentile: Percentile of correct rewards to use as threshold
|
||||
|
||||
Returns:
|
||||
Dictionary with detection metrics:
|
||||
- detection_accuracy: Fraction of episodes where correct > incorrect
|
||||
- mean_correct_final: Mean final reward for correct language
|
||||
- mean_incorrect_final: Mean final reward for incorrect language
|
||||
- separation_score: (mean_correct - mean_incorrect) / (std_correct + std_incorrect)
|
||||
- num_pairs: Number of episode pairs evaluated
|
||||
"""
|
||||
if len(correct_rewards) != len(incorrect_rewards):
|
||||
raise ValueError("Must have same number of correct and incorrect reward sequences")
|
||||
|
||||
if not correct_rewards:
|
||||
return {
|
||||
"detection_accuracy": 0.0,
|
||||
"mean_correct_final": 0.0,
|
||||
"mean_incorrect_final": 0.0,
|
||||
"separation_score": 0.0,
|
||||
"num_pairs": 0,
|
||||
}
|
||||
|
||||
# Extract final rewards (last timestep of each episode)
|
||||
correct_finals = []
|
||||
incorrect_finals = []
|
||||
|
||||
for correct_ep, incorrect_ep in zip(correct_rewards, incorrect_rewards, strict=False):
|
||||
if len(correct_ep) > 0 and len(incorrect_ep) > 0:
|
||||
correct_finals.append(correct_ep[-1]) # Final reward
|
||||
incorrect_finals.append(incorrect_ep[-1]) # Final reward
|
||||
|
||||
if not correct_finals:
|
||||
return {
|
||||
"detection_accuracy": 0.0,
|
||||
"mean_correct_final": 0.0,
|
||||
"mean_incorrect_final": 0.0,
|
||||
"separation_score": 0.0,
|
||||
"num_pairs": 0,
|
||||
}
|
||||
|
||||
correct_finals = np.array(correct_finals)
|
||||
incorrect_finals = np.array(incorrect_finals)
|
||||
|
||||
# Detection accuracy: fraction where correct > incorrect
|
||||
detection_accuracy = float(np.mean(correct_finals > incorrect_finals))
|
||||
|
||||
# Statistics
|
||||
mean_correct = float(np.mean(correct_finals))
|
||||
mean_incorrect = float(np.mean(incorrect_finals))
|
||||
std_correct = float(np.std(correct_finals))
|
||||
std_incorrect = float(np.std(incorrect_finals))
|
||||
|
||||
# Separation score: normalized difference (clamp to prevent extreme values)
|
||||
denominator = std_correct + std_incorrect
|
||||
if denominator > 1e-6: # Prevent division by very small numbers
|
||||
separation_score = (mean_correct - mean_incorrect) / denominator
|
||||
# Clamp to reasonable range
|
||||
separation_score = np.clip(separation_score, -100.0, 100.0)
|
||||
else:
|
||||
separation_score = 0.0
|
||||
|
||||
return {
|
||||
"detection_accuracy": detection_accuracy,
|
||||
"mean_correct_final": mean_correct,
|
||||
"mean_incorrect_final": mean_incorrect,
|
||||
"separation_score": float(separation_score),
|
||||
"num_pairs": len(correct_finals),
|
||||
}
|
||||
|
||||
|
||||
def generate_mismatched_languages(
|
||||
original_languages: list[str], mismatch_templates: list[str] | None = None
|
||||
) -> list[str]:
|
||||
"""
|
||||
Generate mismatched language instructions for failure detection evaluation.
|
||||
|
||||
Args:
|
||||
original_languages: List of original task descriptions
|
||||
mismatch_templates: Custom mismatch templates. If None, uses defaults.
|
||||
|
||||
Returns:
|
||||
List of mismatched language instructions
|
||||
"""
|
||||
if mismatch_templates is None:
|
||||
mismatch_templates = ["kick the ball", "walk to the red shoes", "wave", "do nothing"]
|
||||
|
||||
# For each original language, pick a random mismatch
|
||||
mismatched = []
|
||||
np.random.seed(42) # For reproducibility
|
||||
|
||||
for i, orig_lang in enumerate(original_languages):
|
||||
# Use modulo to cycle through mismatches if we have more episodes than templates
|
||||
mismatch_idx = i % len(mismatch_templates)
|
||||
mismatched.append(mismatch_templates[mismatch_idx])
|
||||
|
||||
return mismatched
|
||||
|
||||
|
||||
class RLearnEvaluator:
|
||||
"""
|
||||
Comprehensive evaluator for RLearN reward models.
|
||||
|
||||
Provides methods to evaluate VOC-S and success/failure detection on datasets.
|
||||
"""
|
||||
|
||||
def __init__(self, model, device: str = "cuda"):
|
||||
"""
|
||||
Args:
|
||||
model: RLearN model instance
|
||||
device: Device to run evaluation on
|
||||
"""
|
||||
self.model = model
|
||||
self.device = device
|
||||
self.model.eval()
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_episode_rewards(self, frames: Tensor, language: str, batch_size: int = 16) -> np.ndarray:
|
||||
"""
|
||||
Predict rewards for a single episode.
|
||||
|
||||
Args:
|
||||
frames: Video frames tensor of shape (T, C, H, W)
|
||||
language: Language instruction string
|
||||
batch_size: Maximum sequence length to process at once
|
||||
|
||||
Returns:
|
||||
Predicted rewards array of shape (T,)
|
||||
"""
|
||||
T = frames.shape[0]
|
||||
|
||||
# Preprocess frames to match model expectations
|
||||
processed_frames = self._preprocess_frames(frames)
|
||||
|
||||
# Process in chunks if episode is very long
|
||||
if T <= batch_size:
|
||||
# Single batch
|
||||
batch = {
|
||||
OBS_IMAGES: processed_frames.unsqueeze(0).to(self.device), # (1, T, C, H, W)
|
||||
OBS_LANGUAGE: [language],
|
||||
}
|
||||
|
||||
# Use the new predict_rewards method
|
||||
values = self.model.predict_rewards(batch) # (1, T')
|
||||
rewards = values.squeeze(0).cpu().numpy() # (T',)
|
||||
|
||||
else:
|
||||
# Process in overlapping chunks to handle very long episodes
|
||||
rewards = []
|
||||
stride = batch_size // 2 # 50% overlap
|
||||
|
||||
for i in range(0, T, stride):
|
||||
end_idx = min(i + batch_size, T)
|
||||
chunk_frames = processed_frames[i:end_idx]
|
||||
|
||||
batch = {OBS_IMAGES: chunk_frames.unsqueeze(0).to(self.device), OBS_LANGUAGE: [language]}
|
||||
|
||||
chunk_values = self.model.predict_rewards(batch)
|
||||
chunk_rewards = chunk_values.squeeze(0).cpu().numpy()
|
||||
|
||||
# For overlapping chunks, only take the first half (except for the last chunk)
|
||||
if i + batch_size < T:
|
||||
rewards.extend(chunk_rewards[:stride])
|
||||
else:
|
||||
rewards.extend(chunk_rewards)
|
||||
|
||||
rewards = np.array(rewards[:T]) # Ensure exact length
|
||||
|
||||
return rewards
|
||||
|
||||
def _preprocess_frames(self, frames: Tensor) -> Tensor:
|
||||
"""
|
||||
Preprocess frames to match model expectations.
|
||||
|
||||
Args:
|
||||
frames: Input frames tensor of shape (T, C, H, W)
|
||||
|
||||
Returns:
|
||||
Preprocessed frames tensor of shape (T, C, H', W')
|
||||
"""
|
||||
import torch.nn.functional as F
|
||||
|
||||
T, C, H, W = frames.shape
|
||||
|
||||
# Expected input size for SigLIP2 is typically 256x256
|
||||
target_size = 256
|
||||
|
||||
# Resize frames if needed
|
||||
if H != target_size or W != target_size:
|
||||
# Resize using bilinear interpolation
|
||||
frames = F.interpolate(
|
||||
frames, size=(target_size, target_size), mode="bilinear", align_corners=False
|
||||
)
|
||||
|
||||
# Normalize to [0, 1] if needed
|
||||
if frames.dtype == torch.uint8:
|
||||
frames = frames.float() / 255.0
|
||||
|
||||
# Ensure values are in [0, 1] range
|
||||
frames = torch.clamp(frames, 0.0, 1.0)
|
||||
|
||||
return frames
|
||||
|
||||
def evaluate_voc_s(
|
||||
self, dataset, num_episodes: int = 100, use_interquartile_mean: bool = True
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Evaluate VOC-S on a dataset.
|
||||
|
||||
Args:
|
||||
dataset: LeRobot dataset instance
|
||||
num_episodes: Number of episodes to evaluate (randomly sampled)
|
||||
use_interquartile_mean: Whether to compute IQM
|
||||
|
||||
Returns:
|
||||
VOC-S evaluation results
|
||||
"""
|
||||
print(f"Evaluating VOC-S on {num_episodes} episodes...")
|
||||
|
||||
# Sample episodes
|
||||
total_episodes = dataset.num_episodes
|
||||
if num_episodes >= total_episodes:
|
||||
episode_indices = list(range(total_episodes))
|
||||
else:
|
||||
np.random.seed(42)
|
||||
episode_indices = np.random.choice(total_episodes, num_episodes, replace=False)
|
||||
|
||||
predicted_rewards = []
|
||||
|
||||
for ep_idx in tqdm(episode_indices, desc="Computing VOC-S"):
|
||||
try:
|
||||
# Get episode data
|
||||
ep_start = dataset.episode_data_index["from"][ep_idx].item()
|
||||
ep_end = dataset.episode_data_index["to"][ep_idx].item()
|
||||
episode_length = ep_end - ep_start
|
||||
|
||||
# Get frames and language for this episode
|
||||
frames = []
|
||||
language = None
|
||||
|
||||
for frame_idx in range(episode_length):
|
||||
global_idx = ep_start + frame_idx
|
||||
frame_data = dataset[global_idx]
|
||||
|
||||
# Extract image (assuming single camera for now)
|
||||
if OBS_IMAGES in frame_data:
|
||||
img = frame_data[OBS_IMAGES]
|
||||
else:
|
||||
# Try to find image key
|
||||
img_keys = [k for k in frame_data.keys() if "image" in k.lower()]
|
||||
if img_keys:
|
||||
img = frame_data[img_keys[0]]
|
||||
else:
|
||||
continue
|
||||
|
||||
# Convert to tensor if needed
|
||||
if isinstance(img, np.ndarray):
|
||||
img = torch.from_numpy(img)
|
||||
|
||||
# Ensure CHW format
|
||||
if len(img.shape) == 3 and img.shape[-1] in [1, 3, 4]:
|
||||
img = img.permute(2, 0, 1) # HWC -> CHW
|
||||
|
||||
# Resize to expected input size (256x256 for SigLIP2) BEFORE stacking
|
||||
if img.shape[-2:] != (256, 256):
|
||||
import torch.nn.functional as F
|
||||
|
||||
img = F.interpolate(
|
||||
img.unsqueeze(0), size=(256, 256), mode="bilinear", align_corners=False
|
||||
).squeeze(0)
|
||||
|
||||
# Normalize to [0, 1] if needed
|
||||
if img.dtype == torch.uint8:
|
||||
img = img.float() / 255.0
|
||||
|
||||
frames.append(img)
|
||||
|
||||
# Get language instruction
|
||||
if language is None:
|
||||
if OBS_LANGUAGE in frame_data:
|
||||
language = frame_data[OBS_LANGUAGE]
|
||||
if isinstance(language, list):
|
||||
language = language[0]
|
||||
elif "task" in frame_data:
|
||||
language = frame_data["task"]
|
||||
else:
|
||||
language = "" # Default empty language
|
||||
|
||||
if not frames:
|
||||
continue
|
||||
|
||||
# Stack frames into video tensor
|
||||
frames_tensor = torch.stack(frames) # (T, C, H, W)
|
||||
|
||||
# Predict rewards
|
||||
episode_rewards = self.predict_episode_rewards(frames_tensor, language)
|
||||
predicted_rewards.append(episode_rewards)
|
||||
|
||||
except Exception as e:
|
||||
warnings.warn(f"Failed to process episode {ep_idx}: {e}")
|
||||
continue
|
||||
|
||||
# Compute VOC-S
|
||||
voc_results = compute_voc_s(predicted_rewards, use_interquartile_mean)
|
||||
|
||||
print("VOC-S Results:")
|
||||
print(f" Mean correlation: {voc_results['voc_s_mean']:.4f}")
|
||||
print(f" Std correlation: {voc_results['voc_s_std']:.4f}")
|
||||
print(f" IQM correlation: {voc_results['voc_s_iqm']:.4f}")
|
||||
print(f" Episodes evaluated: {voc_results['num_episodes']}")
|
||||
|
||||
return voc_results
|
||||
|
||||
def evaluate_success_failure_detection(
|
||||
self, dataset, num_episodes: int = 100, mismatch_templates: list[str] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Evaluate success vs failure detection.
|
||||
|
||||
Args:
|
||||
dataset: LeRobot dataset instance
|
||||
num_episodes: Number of episodes to evaluate
|
||||
mismatch_templates: Custom mismatch language templates
|
||||
|
||||
Returns:
|
||||
Success/failure detection results
|
||||
"""
|
||||
print(f"Evaluating success/failure detection on {num_episodes} episodes...")
|
||||
|
||||
# Sample episodes
|
||||
total_episodes = dataset.num_episodes
|
||||
if num_episodes >= total_episodes:
|
||||
episode_indices = list(range(total_episodes))
|
||||
else:
|
||||
np.random.seed(42)
|
||||
episode_indices = np.random.choice(total_episodes, num_episodes, replace=False)
|
||||
|
||||
correct_rewards = []
|
||||
incorrect_rewards = []
|
||||
|
||||
# Get original languages
|
||||
original_languages = []
|
||||
for ep_idx in episode_indices:
|
||||
ep_start = dataset.episode_data_index["from"][ep_idx].item()
|
||||
frame_data = dataset[ep_start]
|
||||
|
||||
if OBS_LANGUAGE in frame_data:
|
||||
lang = frame_data[OBS_LANGUAGE]
|
||||
if isinstance(lang, list):
|
||||
lang = lang[0]
|
||||
elif "task" in frame_data:
|
||||
lang = frame_data["task"]
|
||||
else:
|
||||
lang = ""
|
||||
|
||||
original_languages.append(lang)
|
||||
|
||||
# Generate mismatched languages
|
||||
mismatched_languages = generate_mismatched_languages(original_languages, mismatch_templates)
|
||||
|
||||
for i, ep_idx in enumerate(tqdm(episode_indices, desc="Computing detection metrics")):
|
||||
try:
|
||||
# Get episode frames (same as VOC-S evaluation)
|
||||
ep_start = dataset.episode_data_index["from"][ep_idx].item()
|
||||
ep_end = dataset.episode_data_index["to"][ep_idx].item()
|
||||
episode_length = ep_end - ep_start
|
||||
|
||||
frames = []
|
||||
for frame_idx in range(episode_length):
|
||||
global_idx = ep_start + frame_idx
|
||||
frame_data = dataset[global_idx]
|
||||
|
||||
# Extract image
|
||||
if OBS_IMAGES in frame_data:
|
||||
img = frame_data[OBS_IMAGES]
|
||||
else:
|
||||
img_keys = [k for k in frame_data.keys() if "image" in k.lower()]
|
||||
if img_keys:
|
||||
img = frame_data[img_keys[0]]
|
||||
else:
|
||||
continue
|
||||
|
||||
if isinstance(img, np.ndarray):
|
||||
img = torch.from_numpy(img)
|
||||
|
||||
if len(img.shape) == 3 and img.shape[-1] in [1, 3, 4]:
|
||||
img = img.permute(2, 0, 1)
|
||||
|
||||
# Resize to expected input size (256x256 for SigLIP2)
|
||||
if img.shape[-2:] != (256, 256):
|
||||
import torch.nn.functional as F
|
||||
|
||||
img = F.interpolate(
|
||||
img.unsqueeze(0), size=(256, 256), mode="bilinear", align_corners=False
|
||||
).squeeze(0)
|
||||
|
||||
# Normalize to [0, 1] if needed
|
||||
if img.dtype == torch.uint8:
|
||||
img = img.float() / 255.0
|
||||
|
||||
frames.append(img)
|
||||
|
||||
if not frames:
|
||||
continue
|
||||
|
||||
frames_tensor = torch.stack(frames)
|
||||
|
||||
# Predict with correct language
|
||||
correct_lang = original_languages[i]
|
||||
correct_ep_rewards = self.predict_episode_rewards(frames_tensor, correct_lang)
|
||||
|
||||
# Predict with incorrect language
|
||||
incorrect_lang = mismatched_languages[i]
|
||||
incorrect_ep_rewards = self.predict_episode_rewards(frames_tensor, incorrect_lang)
|
||||
|
||||
correct_rewards.append(correct_ep_rewards)
|
||||
incorrect_rewards.append(incorrect_ep_rewards)
|
||||
|
||||
except Exception as e:
|
||||
warnings.warn(f"Failed to process episode {ep_idx} for detection: {e}")
|
||||
continue
|
||||
|
||||
# Compute detection metrics
|
||||
detection_results = compute_success_failure_detection(correct_rewards, incorrect_rewards)
|
||||
|
||||
print("Success/Failure Detection Results:")
|
||||
print(f" Detection accuracy: {detection_results['detection_accuracy']:.4f}")
|
||||
print(f" Mean correct final reward: {detection_results['mean_correct_final']:.4f}")
|
||||
print(f" Mean incorrect final reward: {detection_results['mean_incorrect_final']:.4f}")
|
||||
print(f" Separation score: {detection_results['separation_score']:.4f}")
|
||||
print(f" Episode pairs evaluated: {detection_results['num_pairs']}")
|
||||
|
||||
return detection_results
|
||||
|
||||
def comprehensive_evaluation(
|
||||
self,
|
||||
dataset,
|
||||
num_episodes: int = 100,
|
||||
use_interquartile_mean: bool = True,
|
||||
mismatch_templates: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Run comprehensive evaluation including both VOC-S and detection metrics.
|
||||
|
||||
Returns:
|
||||
Combined evaluation results
|
||||
"""
|
||||
print("=" * 60)
|
||||
print("COMPREHENSIVE RLEARN EVALUATION")
|
||||
print("=" * 60)
|
||||
|
||||
# VOC-S evaluation
|
||||
voc_results = self.evaluate_voc_s(
|
||||
dataset, num_episodes=num_episodes, use_interquartile_mean=use_interquartile_mean
|
||||
)
|
||||
|
||||
print("\n" + "=" * 40)
|
||||
|
||||
# Success/failure detection
|
||||
detection_results = self.evaluate_success_failure_detection(
|
||||
dataset, num_episodes=num_episodes, mismatch_templates=mismatch_templates
|
||||
)
|
||||
|
||||
# Combined results
|
||||
results = {
|
||||
"voc_s": voc_results,
|
||||
"detection": detection_results,
|
||||
"overall_score": (
|
||||
voc_results["voc_s_iqm"] * 0.6 + detection_results["detection_accuracy"] * 0.4
|
||||
), # Weighted combination
|
||||
}
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(f"OVERALL EVALUATION SCORE: {results['overall_score']:.4f}")
|
||||
print("=" * 60)
|
||||
|
||||
return results
|
||||
@@ -0,0 +1,902 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
RLearN: Video-Language Conditioned Reward Model
|
||||
|
||||
Inputs
|
||||
- images: (B, T, C, H, W) sequence of frames (or single frame with T=1)
|
||||
- language: list[str] of length B (goal/instruction)
|
||||
|
||||
High-level Architecture
|
||||
|
||||
images (B,T,C,H,W)
|
||||
|
|
||||
| per-frame encode
|
||||
v
|
||||
+------------------------------+
|
||||
| Vision Encoder (frozen) | e.g. SigLIP2 vision tower
|
||||
+------------------------------+
|
||||
|
|
||||
| pooled per-frame embeddings (BT, H_v)
|
||||
v
|
||||
reshape -> (B, T, H_v) -- Linear proj --> (B, T, D)
|
||||
+ Positional Encoding [0..T)
|
||||
+ Optional first-frame bias
|
||||
|
|
||||
| language (B, str)
|
||||
| |
|
||||
| v
|
||||
| +------------------------------+
|
||||
| | Text Encoder (frozen) | e.g. SigLIP2 text tower
|
||||
| +------------------------------+
|
||||
| |
|
||||
| | pooled text embedding (B, H_t)
|
||||
| v
|
||||
| Linear proj -> (B, D)
|
||||
| |
|
||||
+-----------------v----------------------+
|
||||
|
|
||||
+--------------------------v---------------------------+
|
||||
| Temporal Causal Transformer (n_layers, n_heads) |
|
||||
| - self-attention over time with causal mask |
|
||||
| - cross-attention to a single language token |
|
||||
+--------------------------+---------------------------+
|
||||
|
|
||||
LayerNorm + Linear Head (D -> 1)
|
||||
|
|
||||
v
|
||||
Output
|
||||
- reward_logits: (B, T', 1) with T' ≤ T (affected by stride and frame dropout)
|
||||
|
||||
Training
|
||||
- Loss: composite loss with progress regression, spatial-aware InfoNCE, and ReWiND reversible ranking
|
||||
|
||||
Notes
|
||||
- Backbones (vision/text) are frozen by default; only projections, temporal module, and head are trainable.
|
||||
- Stride/frame dropout applied during training can subsample timesteps.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.constants import OBS_IMAGE, OBS_IMAGES, OBS_LANGUAGE, REWARD
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
|
||||
|
||||
|
||||
class RLearNPolicy(PreTrainedPolicy):
|
||||
"""Video-language conditioned reward model.
|
||||
|
||||
- Visual encoder: frozen SigLIP2 (via transformers AutoModel), returns per-frame embeddings.
|
||||
- Text encoder: frozen SigLIP2 text tower, returns a language embedding.
|
||||
- Temporal module: causal transformer over time that cross-attends to language embedding.
|
||||
- Output: per-timestep reward logits; trainable small head.
|
||||
"""
|
||||
|
||||
config_class = RLearNConfig
|
||||
name = "rlearn"
|
||||
|
||||
def __init__(self, config: RLearNConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Encoders
|
||||
from transformers import AutoModel, AutoProcessor
|
||||
|
||||
self.vision_text_model = AutoModel.from_pretrained(config.model_name, trust_remote_code=True)
|
||||
self.processor = AutoProcessor.from_pretrained(config.model_name, trust_remote_code=True)
|
||||
|
||||
# Detect towers
|
||||
if hasattr(self.vision_text_model, "vision_model") and hasattr(self.vision_text_model, "text_model"):
|
||||
self.vision_encoder = self.vision_text_model.vision_model
|
||||
self.text_encoder = self.vision_text_model.text_model
|
||||
self.vision_hidden = getattr(self.vision_text_model.config, "vision_config", None).hidden_size
|
||||
self.text_hidden = getattr(self.vision_text_model.config, "text_config", None).hidden_size
|
||||
else:
|
||||
# Fallback if AutoModel exposes pooled outputs directly (rare for SigLIP2)
|
||||
self.vision_encoder = self.vision_text_model
|
||||
self.text_encoder = self.vision_text_model
|
||||
self.vision_hidden = getattr(self.vision_text_model.config, "hidden_size", 768)
|
||||
self.text_hidden = getattr(self.vision_text_model.config, "hidden_size", 768)
|
||||
|
||||
if config.freeze_backbones:
|
||||
for p in self.vision_encoder.parameters():
|
||||
p.requires_grad = False
|
||||
for p in self.text_encoder.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
# Linear projections to the shared temporal model dimension
|
||||
self.visual_proj = nn.Linear(self.vision_hidden, config.dim_model)
|
||||
self.text_proj = nn.Linear(self.text_hidden, config.dim_model)
|
||||
|
||||
# Positional encodings over time
|
||||
self.register_buffer(
|
||||
"positional_encoding",
|
||||
create_sinusoidal_pos_encoding(config.max_seq_len, config.dim_model),
|
||||
persistent=False,
|
||||
)
|
||||
# Optional first-frame learned bias to discourage position cheating
|
||||
self.first_frame_bias = (
|
||||
nn.Parameter(torch.zeros(1, 1, config.dim_model))
|
||||
if config.use_first_frame_positional_bias
|
||||
else None
|
||||
)
|
||||
|
||||
# Temporal aggregator: causal transformer over time with language cross-attention
|
||||
self.temporal = TemporalCausalTransformer(
|
||||
dim_model=config.dim_model,
|
||||
n_heads=config.n_heads,
|
||||
n_layers=config.n_layers,
|
||||
dim_feedforward=config.dim_feedforward,
|
||||
dropout=config.dropout,
|
||||
pre_norm=config.pre_norm,
|
||||
)
|
||||
|
||||
# Reward head with proper initialization
|
||||
head_linear = nn.Linear(config.dim_model, 1)
|
||||
# Initialize with small weights and bias to output values around 0
|
||||
nn.init.normal_(head_linear.weight, mean=0.0, std=0.02)
|
||||
nn.init.constant_(head_linear.bias, 0.0) # Start with 0 bias, sigmoid(0) = 0.5
|
||||
|
||||
head_layers: list[nn.Module] = [head_linear]
|
||||
if config.use_tanh_head:
|
||||
head_layers.append(nn.Tanh())
|
||||
self.head = nn.Sequential(*head_layers)
|
||||
# Projection from scalar value summary to text embedding dim for InfoNCE
|
||||
self.value_to_text_proj = nn.Linear(1, config.dim_model)
|
||||
|
||||
# Spatial attention for InfoNCE
|
||||
self.spatial_cross_attn = nn.MultiheadAttention(
|
||||
embed_dim=config.dim_model, num_heads=config.n_heads, batch_first=True
|
||||
)
|
||||
self.spatial_norm = nn.LayerNorm(config.dim_model)
|
||||
|
||||
# Simple frame dropout probability
|
||||
self.frame_dropout_p = config.frame_dropout_p
|
||||
self.stride = max(1, config.stride)
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
# Train only projections, temporal module and head by default if backbones are frozen
|
||||
return [p for p in self.parameters() if p.requires_grad]
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: # Required by base class
|
||||
raise NotImplementedError("RLearN is a reward model and does not predict actions")
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor: # Required by base class
|
||||
raise NotImplementedError("RLearN is a reward model and does not select actions")
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_rewards(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict per-timestep rewards for evaluation.
|
||||
|
||||
Args:
|
||||
batch: Input batch with OBS_IMAGES and optionally OBS_LANGUAGE
|
||||
|
||||
Returns:
|
||||
Predicted rewards tensor of shape (B, T)
|
||||
"""
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
# Extract frames and form (B, T, C, H, W)
|
||||
frames = extract_visual_sequence(batch)
|
||||
B, T, C, H, W = frames.shape
|
||||
|
||||
# Apply stride (no dropout during eval)
|
||||
idx = torch.arange(0, T, self.stride, device=frames.device)
|
||||
frames = frames[:, idx]
|
||||
B, T_eff, C, H, W = frames.shape # NEW: effective length after stride
|
||||
|
||||
# Encode language
|
||||
lang_emb = encode_language(
|
||||
batch.get(OBS_LANGUAGE, None), self.text_encoder, self.processor, batch_size=B
|
||||
)
|
||||
lang_emb = self.text_proj(lang_emb) # (B, D)
|
||||
|
||||
# ---- NEW: use the HF processor to standardize size & normalization ----
|
||||
# Flatten (B, T_eff, C, H, W) -> (BT, C, H, W)
|
||||
BT = B * T_eff
|
||||
flat = frames.reshape(BT, C, H, W).detach().cpu()
|
||||
|
||||
# Convert to uint8 HWC numpy (processor prefers PIL/np)
|
||||
# If already in [0,1], scale to [0,255]
|
||||
if flat.dtype != torch.uint8:
|
||||
if flat.numel() > 0 and float(flat.max()) <= 1.0:
|
||||
flat = flat * 255.0
|
||||
flat = flat.clamp(0, 255).round().to(torch.uint8)
|
||||
|
||||
images = [flat[k].permute(1, 2, 0).numpy() for k in range(flat.size(0))]
|
||||
|
||||
proc_out = self.processor(images=images, return_tensors="pt")
|
||||
pixel_values = proc_out["pixel_values"].to(next(self.vision_encoder.parameters()).device)
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
# Encode frames through visual tower per frame
|
||||
vision_outputs = self.vision_encoder(pixel_values=pixel_values)
|
||||
|
||||
# Extract CLS tokens for temporal modeling
|
||||
if hasattr(vision_outputs, "last_hidden_state"):
|
||||
cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D_vision)
|
||||
else:
|
||||
raise RuntimeError("Vision encoder must output last_hidden_state")
|
||||
|
||||
# Project CLS tokens for temporal sequence
|
||||
visual_seq = self.visual_proj(cls_tokens).reshape(B, T_eff, self.config.dim_model) # (B, T', D)
|
||||
|
||||
# Add temporal positional encodings and optional first-frame bias
|
||||
pe = (
|
||||
self.positional_encoding[: visual_seq.shape[1]]
|
||||
.unsqueeze(0)
|
||||
.to(visual_seq.dtype)
|
||||
.to(visual_seq.device)
|
||||
)
|
||||
visual_seq = visual_seq + pe
|
||||
if self.first_frame_bias is not None:
|
||||
visual_seq = visual_seq.clone()
|
||||
visual_seq[:, :1] = visual_seq[:, :1] + self.first_frame_bias
|
||||
|
||||
# Temporal model with cross-attention to language
|
||||
temporal_features = self.temporal(visual_seq, lang_emb, return_features=True) # (B, T', D)
|
||||
values = self.head(temporal_features).squeeze(-1) # (B, T')
|
||||
|
||||
return values
|
||||
|
||||
def normalize_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
# Initial version: no-op; rely on upstream processors if any
|
||||
return batch
|
||||
|
||||
def normalize_targets(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
# Initial version: no-op
|
||||
return batch
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Compute training loss and logs.
|
||||
|
||||
Expected batch keys:
|
||||
- OBS_IMAGES: list[Tensor] of shape [(B, C, H, W), ...] per time step or stacked (B, T, C, H, W)
|
||||
- OBS_LANGUAGE: optional string tokens already tokenized externally or raw strings
|
||||
- REWARD: (B, T) or (B,) target rewards
|
||||
"""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
# Extract frames and form (B, T, C, H, W)
|
||||
frames = extract_visual_sequence(batch)
|
||||
B, T, C, H, W = frames.shape
|
||||
|
||||
# Apply stride and frame dropout during training
|
||||
idx = torch.arange(0, T, self.stride, device=frames.device)
|
||||
if self.training and self.frame_dropout_p > 0.0 and T > 1:
|
||||
mask = torch.rand_like(idx.float()) > self.frame_dropout_p
|
||||
idx = idx[mask.long().bool()]
|
||||
if idx.numel() == 0:
|
||||
idx = torch.tensor([0], device=frames.device)
|
||||
frames = frames[:, idx]
|
||||
|
||||
# Encode language
|
||||
lang_emb = encode_language(
|
||||
batch.get(OBS_LANGUAGE, None), self.text_encoder, self.processor, batch_size=B
|
||||
)
|
||||
lang_emb = self.text_proj(lang_emb) # (B, D)
|
||||
|
||||
# Encode frames through visual tower per frame
|
||||
# Flatten time for batched encode
|
||||
BT = B * frames.shape[1]
|
||||
flat = frames.reshape(BT, C, H, W)
|
||||
|
||||
# Use HF processor to properly resize and normalize images
|
||||
# Convert to CPU for processing, then move back to device
|
||||
flat_cpu = flat.detach().cpu()
|
||||
|
||||
# Convert to uint8 HWC numpy format expected by processor
|
||||
if flat_cpu.dtype != torch.uint8:
|
||||
if flat_cpu.numel() > 0 and float(flat_cpu.max()) <= 1.0:
|
||||
flat_cpu = flat_cpu * 255.0
|
||||
flat_cpu = flat_cpu.clamp(0, 255).round().to(torch.uint8)
|
||||
|
||||
# Convert to list of numpy arrays
|
||||
images = [flat_cpu[k].permute(1, 2, 0).numpy() for k in range(flat_cpu.size(0))]
|
||||
|
||||
# Process with HF processor (resizes to 256x256 and normalizes)
|
||||
proc_out = self.processor(images=images, return_tensors="pt")
|
||||
pixel_values = proc_out["pixel_values"].to(next(self.vision_encoder.parameters()).device)
|
||||
|
||||
# Encode through vision model
|
||||
vision_outputs = self.vision_encoder(pixel_values=pixel_values)
|
||||
|
||||
# Extract BOTH CLS token and spatial patches
|
||||
if hasattr(vision_outputs, "last_hidden_state"):
|
||||
all_tokens = vision_outputs.last_hidden_state # (BT, num_tokens, D)
|
||||
cls_tokens = all_tokens[:, 0] # (BT, D) - CLS token for temporal modeling
|
||||
spatial_tokens = all_tokens[:, 1:] # (BT, num_patches, D) - spatial patches
|
||||
else:
|
||||
raise RuntimeError("Vision encoder must output last_hidden_state with spatial features")
|
||||
|
||||
# Project CLS tokens for temporal sequence
|
||||
visual_seq = self.visual_proj(cls_tokens).reshape(B, -1, self.config.dim_model) # (B, T', D)
|
||||
|
||||
# Keep spatial features for spatial-aware losses (project them too)
|
||||
# Assuming 16x16 patches for 256x256 image with patch_size=16
|
||||
num_patches = spatial_tokens.shape[1]
|
||||
spatial_features = self.visual_proj(spatial_tokens).reshape(
|
||||
B, -1, num_patches, self.config.dim_model
|
||||
) # (B, T', num_patches, D)
|
||||
|
||||
# Add temporal positional encodings and optional first-frame bias
|
||||
pe = (
|
||||
self.positional_encoding[: visual_seq.shape[1]]
|
||||
.unsqueeze(0)
|
||||
.to(visual_seq.dtype)
|
||||
.to(visual_seq.device)
|
||||
)
|
||||
visual_seq = visual_seq + pe
|
||||
if self.first_frame_bias is not None:
|
||||
visual_seq = visual_seq.clone()
|
||||
visual_seq[:, :1] = visual_seq[:, :1] + self.first_frame_bias
|
||||
|
||||
# Temporal model with cross-attention to language
|
||||
temporal_features = self.temporal(visual_seq, lang_emb, return_features=True) # (B, T', D)
|
||||
values = self.head(temporal_features).squeeze(-1) # (B, T')
|
||||
|
||||
# Targets
|
||||
target = batch.get(REWARD, None)
|
||||
loss_dict: dict[str, float] = {}
|
||||
if target is None:
|
||||
# If no labels, return zeros loss and logits for inference
|
||||
loss = values.mean() * 0.0
|
||||
loss_dict["has_labels"] = 0.0
|
||||
return loss, {**loss_dict, "values_mean": values.mean().item()}
|
||||
|
||||
# Align target with sampled timesteps
|
||||
if target.dim() == 1:
|
||||
target = target.unsqueeze(1) # (B, 1)
|
||||
target = target[:, idx]
|
||||
|
||||
# Composite loss
|
||||
# 1) Progress regression on z-scored values to match normalized progress labels y in [0,1]
|
||||
|
||||
# Debug: Check if values have enough variance
|
||||
values_std = values.std()
|
||||
if values_std < 1e-4:
|
||||
# Early in training, model outputs are nearly constant
|
||||
# Use direct MSE loss without z-scoring to encourage variance
|
||||
import logging
|
||||
|
||||
logging.info(f"Low variance in values (std={values_std:.6f}), using direct MSE")
|
||||
# Apply sigmoid directly to raw values to get them in [0,1] range
|
||||
prog_pred = torch.sigmoid(values * 10.0) # Scale up to encourage learning
|
||||
L_prog = F.mse_loss(prog_pred, torch.clamp(target, 0.0, 1.0))
|
||||
else:
|
||||
# Normal case: use z-score normalization
|
||||
zV = zscore(values, eps=self.config.zscore_eps)
|
||||
# Check for NaN after zscore
|
||||
if torch.isnan(zV).any():
|
||||
import logging
|
||||
|
||||
logging.warning(f"NaN after zscore. Values: {values}, zV: {zV}")
|
||||
# Fallback to direct sigmoid
|
||||
prog_pred = torch.sigmoid(values * 10.0)
|
||||
else:
|
||||
prog_pred = torch.sigmoid(zV)
|
||||
|
||||
L_prog = F.mse_loss(prog_pred, torch.clamp(target, 0.0, 1.0))
|
||||
|
||||
# Mismatched pairs: randomly shuffle language within batch and require near-zero progress
|
||||
if self.training and torch.rand(()) < self.config.mismatch_lang_prob and values.size(0) > 1:
|
||||
shuffled = torch.randperm(B, device=values.device)
|
||||
lang_mismatch = lang_emb[shuffled]
|
||||
mismatch_feat = self.temporal(visual_seq, lang_mismatch, return_features=True)
|
||||
mismatch_V = self.head(mismatch_feat).squeeze(-1)
|
||||
L_prog_mismatch = F.mse_loss(
|
||||
torch.sigmoid(zscore(mismatch_V, eps=self.config.zscore_eps)), torch.zeros_like(target)
|
||||
)
|
||||
else:
|
||||
L_prog_mismatch = torch.zeros((), device=values.device)
|
||||
|
||||
# 2) Spatial-Aware InfoNCE: Use language to attend to relevant spatial regions
|
||||
# Take late timesteps' spatial features
|
||||
k = min(self.config.last_k_for_nce, spatial_features.shape[1])
|
||||
late_spatial = spatial_features[:, -k:].mean(dim=1) # (B, num_patches, D)
|
||||
|
||||
# Language queries spatial patches via cross-attention
|
||||
lang_query = lang_emb.unsqueeze(1) # (B, 1, D)
|
||||
attended_spatial, spatial_attn_weights = self.spatial_cross_attn(
|
||||
query=lang_query, key=late_spatial, value=late_spatial, need_weights=True
|
||||
)
|
||||
attended_spatial = self.spatial_norm(attended_spatial).squeeze(1) # (B, D)
|
||||
|
||||
# Contrastive loss with spatially-attended features
|
||||
attended_spatial = F.normalize(attended_spatial, dim=-1)
|
||||
lang_norm = F.normalize(lang_emb, dim=-1)
|
||||
logits_spatial = (attended_spatial @ lang_norm.t()) / self.config.nce_temperature # (B, B)
|
||||
targets_nce = torch.arange(B, device=values.device)
|
||||
L_spatial_nce = F.cross_entropy(logits_spatial, targets_nce)
|
||||
|
||||
# 3) ReWiND Reversible Ranking: Learn from both forward and reversed trajectories
|
||||
# This teaches the model what constitutes progress vs undoing progress
|
||||
L_rank_forward, L_rank_reverse = reversible_ranking_loss(
|
||||
values,
|
||||
target,
|
||||
margin=self.config.ranking_margin,
|
||||
num_pairs=self.config.num_ranking_pairs,
|
||||
min_gap=self.config.min_rank_gap,
|
||||
)
|
||||
L_rewind = L_rank_forward + L_rank_reverse
|
||||
|
||||
# Check for NaNs in individual loss components
|
||||
if torch.isnan(L_prog):
|
||||
import logging
|
||||
|
||||
logging.warning(f"NaN in L_prog. Values: {values}, Target: {target}")
|
||||
# Return a small loss with gradients instead of zero
|
||||
L_prog = values.mean() * 0.0 + 0.01
|
||||
|
||||
if torch.isnan(L_spatial_nce):
|
||||
import logging
|
||||
|
||||
logging.warning("NaN in L_spatial_nce")
|
||||
# Use a dummy loss that maintains gradients
|
||||
L_spatial_nce = attended_spatial.mean() * 0.0 + 0.01
|
||||
|
||||
if torch.isnan(L_rewind):
|
||||
import logging
|
||||
|
||||
logging.warning("NaN in L_rewind")
|
||||
# Use values to maintain gradient flow
|
||||
L_rewind = values.mean() * 0.0 + 0.01
|
||||
|
||||
loss = (
|
||||
self.config.lambda_prog * (L_prog + L_prog_mismatch)
|
||||
+ self.config.lambda_spatial_nce * L_spatial_nce
|
||||
+ self.config.lambda_rewind * L_rewind
|
||||
)
|
||||
|
||||
# Final NaN check
|
||||
if torch.isnan(loss):
|
||||
import logging
|
||||
|
||||
logging.warning("NaN loss detected, using fallback loss")
|
||||
# Use a small loss that maintains gradients
|
||||
loss = values.mean() * 0.0 + 0.01
|
||||
|
||||
loss_dict.update(
|
||||
{
|
||||
"loss_prog": L_prog.item() if not torch.isnan(L_prog) else 0.0,
|
||||
"loss_prog_mismatch": L_prog_mismatch.item() if not torch.isnan(L_prog_mismatch) else 0.0,
|
||||
"loss_spatial_nce": L_spatial_nce.item() if not torch.isnan(L_spatial_nce) else 0.0,
|
||||
"loss_rewind_forward": L_rank_forward.item() if not torch.isnan(L_rank_forward) else 0.0,
|
||||
"loss_rewind_reverse": L_rank_reverse.item() if not torch.isnan(L_rank_reverse) else 0.0,
|
||||
}
|
||||
)
|
||||
|
||||
loss_dict["loss"] = loss.item()
|
||||
loss_dict["values_mean"] = values.mean().item()
|
||||
return loss, loss_dict
|
||||
|
||||
|
||||
class TemporalCausalTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_model: int,
|
||||
n_heads: int,
|
||||
n_layers: int,
|
||||
dim_feedforward: int,
|
||||
dropout: float,
|
||||
pre_norm: bool,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
TemporalCausalTransformerLayer(dim_model, n_heads, dim_feedforward, dropout, pre_norm)
|
||||
for _ in range(n_layers)
|
||||
]
|
||||
)
|
||||
self.norm = nn.LayerNorm(dim_model)
|
||||
self.head = nn.Linear(dim_model, 1)
|
||||
|
||||
def forward(self, x: Tensor, lang_emb: Tensor, return_features: bool = False) -> Tensor:
|
||||
# x: (B, T, D), lang_emb: (B, D)
|
||||
B, T, D = x.shape
|
||||
# Prepare language as a single token for cross-attention context
|
||||
lang_token = lang_emb.unsqueeze(1) # (B, 1, D)
|
||||
|
||||
x = x.transpose(0, 1) # (T, B, D)
|
||||
lang_token = lang_token.transpose(0, 1) # (1, B, D)
|
||||
causal_mask = generate_causal_mask(T, device=x.device)
|
||||
for layer in self.layers:
|
||||
x = layer(x, lang_token, causal_mask)
|
||||
x = self.norm(x)
|
||||
x = x.transpose(0, 1) # (B, T, D)
|
||||
if return_features:
|
||||
return x
|
||||
return self.head(x) # (B, T, 1)
|
||||
|
||||
|
||||
class TemporalCausalTransformerLayer(nn.Module):
|
||||
def __init__(self, dim_model: int, n_heads: int, dim_feedforward: int, dropout: float, pre_norm: bool):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(dim_model, n_heads, dropout=dropout, batch_first=False)
|
||||
self.cross_attn = nn.MultiheadAttention(dim_model, n_heads, dropout=dropout, batch_first=False)
|
||||
self.linear1 = nn.Linear(dim_model, dim_feedforward)
|
||||
self.linear2 = nn.Linear(dim_feedforward, dim_model)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
self.dropout3 = nn.Dropout(dropout)
|
||||
self.norm1 = nn.LayerNorm(dim_model)
|
||||
self.norm2 = nn.LayerNorm(dim_model)
|
||||
self.norm3 = nn.LayerNorm(dim_model)
|
||||
self.activation = F.gelu
|
||||
self.pre_norm = pre_norm
|
||||
|
||||
def forward(self, x: Tensor, lang_token: Tensor, causal_mask: Tensor) -> Tensor:
|
||||
# Self-attention with causal mask
|
||||
residual = x
|
||||
if self.pre_norm:
|
||||
x = self.norm1(x)
|
||||
x = self.self_attn(x, x, x, attn_mask=causal_mask)[0]
|
||||
x = residual + self.dropout1(x)
|
||||
if not self.pre_norm:
|
||||
x = self.norm1(x)
|
||||
|
||||
# Cross-attention to language token (keys/values from language, queries are time tokens)
|
||||
residual = x
|
||||
if self.pre_norm:
|
||||
x = self.norm2(x)
|
||||
# Broadcast language token across time
|
||||
T = x.shape[0]
|
||||
lang_kv = lang_token.expand(1, x.shape[1], x.shape[2]) # (1, B, D)
|
||||
x = self.cross_attn(x, lang_kv, lang_kv)[0]
|
||||
x = residual + self.dropout2(x)
|
||||
if not self.pre_norm:
|
||||
x = self.norm2(x)
|
||||
|
||||
# Feed-forward
|
||||
residual = x
|
||||
if self.pre_norm:
|
||||
x = self.norm3(x)
|
||||
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
||||
x = residual + self.dropout3(x)
|
||||
if not self.pre_norm:
|
||||
x = self.norm3(x)
|
||||
return x
|
||||
|
||||
|
||||
def create_sinusoidal_pos_encoding(max_len: int, dim: int) -> Tensor:
|
||||
position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) # (L, 1)
|
||||
div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim)) # (D/2)
|
||||
pe = torch.zeros(max_len, dim)
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
return pe # (L, D)
|
||||
|
||||
|
||||
def generate_causal_mask(T: int, device=None) -> Tensor:
|
||||
# (T, T) with True where masking should occur for MultiheadAttention expects float mask or bool?
|
||||
mask = torch.full((T, T), float("-inf"), device=device)
|
||||
mask = torch.triu(mask, diagonal=1)
|
||||
return mask
|
||||
|
||||
|
||||
def extract_visual_sequence(batch: dict[str, Tensor]) -> Tensor:
|
||||
# Accept various image key formats from datasets
|
||||
# Try multiple common key patterns
|
||||
|
||||
# List of possible image keys to check, in order of preference
|
||||
possible_keys = [
|
||||
OBS_IMAGES, # 'observation.images'
|
||||
OBS_IMAGE, # 'observation.image'
|
||||
"observation.images.image", # nested format from some datasets
|
||||
]
|
||||
|
||||
for key in possible_keys:
|
||||
if key in batch:
|
||||
image_val = batch[key]
|
||||
|
||||
if isinstance(image_val, list) and len(image_val) > 0:
|
||||
# List of (B, C, H, W) -> stack over time
|
||||
return torch.stack(image_val, dim=1)
|
||||
elif torch.is_tensor(image_val):
|
||||
# Tensor of shape (B, T, C, H, W) or (B, C, H, W)
|
||||
if image_val.dim() == 5:
|
||||
# Already has time dimension
|
||||
return image_val
|
||||
elif image_val.dim() == 4:
|
||||
# Add time dimension (single frame)
|
||||
return image_val.unsqueeze(1)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"'{key}' must be a Tensor of shape (B,T,C,H,W) or (B,C,H,W), got shape {image_val.shape}"
|
||||
)
|
||||
|
||||
# If no image key found, provide helpful error with available keys
|
||||
available_keys = list(batch.keys())
|
||||
image_like_keys = [k for k in available_keys if "image" in k.lower()]
|
||||
raise ValueError(
|
||||
f"Could not find image data in batch. Looked for keys: {possible_keys}. "
|
||||
f"Available keys with 'image': {image_like_keys}. "
|
||||
f"All keys: {available_keys}"
|
||||
)
|
||||
|
||||
|
||||
def encode_language(
|
||||
language_input: Tensor | list | str | None, text_encoder, processor, batch_size: int
|
||||
) -> Tensor:
|
||||
# language_input can be: list[str] length B, or None
|
||||
if language_input is None:
|
||||
texts = [""] * batch_size
|
||||
elif isinstance(language_input, list):
|
||||
texts = language_input
|
||||
else:
|
||||
# Single string for the batch
|
||||
texts = [str(language_input)] * batch_size
|
||||
|
||||
inputs = processor(text=texts, padding=True, return_tensors="pt")
|
||||
inputs = {k: v.to(next(text_encoder.parameters()).device) for k, v in inputs.items()}
|
||||
outputs = text_encoder(**inputs)
|
||||
if hasattr(outputs, "pooler_output"):
|
||||
emb = outputs.pooler_output
|
||||
elif hasattr(outputs, "last_hidden_state"):
|
||||
emb = outputs.last_hidden_state[:, 0]
|
||||
else:
|
||||
raise RuntimeError("Unsupported text encoder output structure")
|
||||
return emb
|
||||
|
||||
|
||||
def pairwise_ranking_loss(logits: Tensor, target: Tensor, margin: float = 0.1, num_pairs: int = 32) -> Tensor:
|
||||
# logits, target: (B, T)
|
||||
B, T = logits.shape
|
||||
if T < 2:
|
||||
return logits.mean() * 0.0
|
||||
# Sample pairs i<j and enforce r_j > r_i when target_j > target_i
|
||||
losses = []
|
||||
for _ in range(num_pairs):
|
||||
i = torch.randint(0, T - 1, (B,), device=logits.device)
|
||||
j = i + torch.randint(1, T - i.max(), (1,), device=logits.device)
|
||||
j = j.expand_as(i)
|
||||
li = logits[torch.arange(B), i]
|
||||
lj = logits[torch.arange(B), j]
|
||||
yi = target[torch.arange(B), i]
|
||||
yj = target[torch.arange(B), j]
|
||||
sign = torch.sign(yj - yi)
|
||||
# hinge: max(0, margin - sign*(lj-li))
|
||||
loss = F.relu(margin - sign * (lj - li))
|
||||
losses.append(loss.mean())
|
||||
return torch.stack(losses).mean()
|
||||
|
||||
|
||||
def zscore(x: Tensor, eps: float = 1e-3) -> Tensor:
|
||||
"""Z-score normalization with numerical stability."""
|
||||
# Handle both (B,) and (B, T) shapes
|
||||
if x.dim() == 1:
|
||||
x = x.unsqueeze(1) # Make it (B, 1)
|
||||
|
||||
B, T = x.shape
|
||||
|
||||
# If only one timestep, can't compute meaningful std across time
|
||||
if T == 1:
|
||||
# Just use tanh to bound values instead of z-score
|
||||
return torch.tanh(x * 0.1) # Scale and bound
|
||||
|
||||
# Compute mean and std across time dimension
|
||||
mean = x.mean(dim=1, keepdim=True)
|
||||
std = x.std(dim=1, keepdim=True, unbiased=False)
|
||||
|
||||
# Check if std is valid (not zero or NaN)
|
||||
std_is_valid = (std > eps) & (~torch.isnan(std))
|
||||
|
||||
# Safe std for division
|
||||
std_safe = torch.where(std_is_valid, std, torch.ones_like(std))
|
||||
|
||||
# Compute z-score where valid
|
||||
z = (x - mean) / std_safe
|
||||
|
||||
# For invalid cases, use tanh of centered values
|
||||
z_fallback = torch.tanh((x - mean) * 0.1)
|
||||
z = torch.where(std_is_valid.expand_as(z), z, z_fallback)
|
||||
|
||||
# Final safety clamp
|
||||
z = torch.clamp(z, min=-5.0, max=5.0)
|
||||
|
||||
# Check for any remaining NaNs and replace with 0
|
||||
z = torch.nan_to_num(z, nan=0.0)
|
||||
|
||||
return z
|
||||
|
||||
|
||||
def temporal_logistic_ranking(
|
||||
values: Tensor, margin: float = 0.1, min_gap: int = 1, num_pairs: int = 64
|
||||
) -> Tensor:
|
||||
"""VLC-style temporal monotonicity: encourage V[j] > V[i] for j>i.
|
||||
|
||||
Samples pairs (i<j) with a minimum gap and applies softplus(m - (Vj - Vi)).
|
||||
"""
|
||||
B, T = values.shape
|
||||
if T < 2:
|
||||
return values.mean() * 0.0
|
||||
losses = []
|
||||
device = values.device
|
||||
for _ in range(num_pairs):
|
||||
i = torch.randint(0, max(1, T - min_gap), (B,), device=device)
|
||||
j = i + torch.randint(min_gap, T - i.max(), (1,), device=device)
|
||||
j = j.expand_as(i)
|
||||
vi = values[torch.arange(B), i]
|
||||
vj = values[torch.arange(B), j]
|
||||
losses.append(F.softplus(margin - (vj - vi)).mean())
|
||||
return torch.stack(losses).mean()
|
||||
|
||||
|
||||
def reversible_ranking_loss(
|
||||
values: Tensor, target: Tensor, margin: float = 0.1, num_pairs: int = 64, min_gap: int = 1
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
"""ReWiND-style reversible ranking: learn from both forward and reversed trajectories.
|
||||
|
||||
Key insight: If a trajectory shows progress forward, its reverse shows undoing progress.
|
||||
By training on both, the model learns what constitutes progress vs regression.
|
||||
|
||||
Args:
|
||||
values: (B, T) predicted values
|
||||
target: (B, T) progress labels (0 to 1 for forward progress)
|
||||
margin: Margin for ranking loss
|
||||
num_pairs: Number of (far, near) pairs to sample
|
||||
min_gap: Minimum temporal gap between pairs
|
||||
|
||||
Returns:
|
||||
forward_loss: Loss from forward trajectory pairs
|
||||
reverse_loss: Loss from reversed trajectory pairs
|
||||
"""
|
||||
B, T = values.shape
|
||||
if T < 2:
|
||||
zero_loss = values.mean() * 0.0
|
||||
return zero_loss, zero_loss
|
||||
|
||||
device = values.device
|
||||
|
||||
# Forward trajectory ranking: later frames should have higher values
|
||||
forward_losses = []
|
||||
for _ in range(num_pairs // 2):
|
||||
# Sample far-near pairs (far is earlier, near is later)
|
||||
far_idx = torch.randint(0, max(1, T - min_gap), (B,), device=device)
|
||||
near_idx = far_idx + torch.randint(min_gap, T - far_idx.max(), (1,), device=device)
|
||||
near_idx = near_idx.expand_as(far_idx)
|
||||
|
||||
v_far = values[torch.arange(B), far_idx]
|
||||
v_near = values[torch.arange(B), near_idx]
|
||||
|
||||
# Near (later) should have higher value than far (earlier)
|
||||
forward_losses.append(F.softplus(margin - (v_near - v_far)).mean())
|
||||
|
||||
# Reversed trajectory ranking: treat reversed sequence with inverted progress
|
||||
# Reverse both values and targets
|
||||
reversed_values = values.flip(dims=[1]) # Reverse time dimension
|
||||
reversed_target = 1.0 - target.flip(dims=[1]) # Invert and reverse progress
|
||||
|
||||
reverse_losses = []
|
||||
for _ in range(num_pairs // 2):
|
||||
# In reversed trajectory, what was "later" is now "earlier"
|
||||
far_idx = torch.randint(0, max(1, T - min_gap), (B,), device=device)
|
||||
near_idx = far_idx + torch.randint(min_gap, T - far_idx.max(), (1,), device=device)
|
||||
near_idx = near_idx.expand_as(far_idx)
|
||||
|
||||
v_far_rev = reversed_values[torch.arange(B), far_idx]
|
||||
v_near_rev = reversed_values[torch.arange(B), near_idx]
|
||||
|
||||
# In reversed trajectory with inverted progress,
|
||||
# near (which was originally earlier) should still have higher value
|
||||
reverse_losses.append(F.softplus(margin - (v_near_rev - v_far_rev)).mean())
|
||||
|
||||
forward_loss = torch.stack(forward_losses).mean() if forward_losses else values.mean() * 0.0
|
||||
reverse_loss = torch.stack(reverse_losses).mean() if reverse_losses else values.mean() * 0.0
|
||||
|
||||
return forward_loss, reverse_loss
|
||||
|
||||
|
||||
def intra_trajectory_directional_ranking(
|
||||
values: Tensor, progress: Tensor, margin: float = 0.2, num_pairs: int = 64, min_gap: int = 1
|
||||
) -> Tensor:
|
||||
"""Directional ranking within trajectory based on progress labels.
|
||||
|
||||
For pairs i<j within a trajectory:
|
||||
- If progress increases (y_j > y_i), enforce V_j > V_i
|
||||
- If progress decreases (y_j < y_i), enforce V_j < V_i
|
||||
- Ignore pairs where progress is unchanged
|
||||
|
||||
Uses logistic loss: log(1 + exp(m - s_ij * (V_j - V_i)))
|
||||
where s_ij = sign(y_j - y_i)
|
||||
"""
|
||||
B, T = values.shape
|
||||
if T < 2:
|
||||
return values.mean() * 0.0
|
||||
|
||||
losses = []
|
||||
device = values.device
|
||||
|
||||
for _ in range(num_pairs):
|
||||
# Sample time pairs i < j
|
||||
i = torch.randint(0, max(1, T - min_gap), (B,), device=device)
|
||||
max_j = min(T, i.max() + T - min_gap)
|
||||
j = i + torch.randint(min_gap, max_j - i.min(), (1,), device=device)
|
||||
j = j.expand_as(i).clamp(max=T - 1)
|
||||
|
||||
# Get values and progress at sampled times
|
||||
vi = values[torch.arange(B), i]
|
||||
vj = values[torch.arange(B), j]
|
||||
yi = progress[torch.arange(B), i]
|
||||
yj = progress[torch.arange(B), j]
|
||||
|
||||
# Compute direction sign
|
||||
s_ij = torch.sign(yj - yi)
|
||||
|
||||
# Only compute loss for non-zero progress differences
|
||||
mask = s_ij != 0
|
||||
if mask.any():
|
||||
diff = vj - vi
|
||||
loss = torch.log1p(torch.exp(margin - s_ij * diff))
|
||||
losses.append(loss[mask].mean())
|
||||
|
||||
return torch.stack(losses).mean() if losses else values.mean() * 0.0
|
||||
|
||||
|
||||
def inter_instruction_contrastive_ranking(
|
||||
values_correct: Tensor, values_incorrect: Tensor, margin: float = 0.2
|
||||
) -> Tensor:
|
||||
"""Ranking between correct and incorrect instructions for same frames.
|
||||
|
||||
Enforces V_t(z) > V_t(z') where z is correct instruction and z' is incorrect.
|
||||
Uses logistic loss: log(1 + exp(m - (V_t(z) - V_t(z'))))
|
||||
"""
|
||||
diff = values_correct - values_incorrect
|
||||
return torch.log1p(torch.exp(margin - diff)).mean()
|
||||
|
||||
|
||||
def flatness_under_mismatch(values: Tensor, epsilon: float = 0.05, num_pairs: int = 32) -> Tensor:
|
||||
"""Enforce flat values over time for mismatched instructions.
|
||||
|
||||
For trajectory with wrong instruction, V should not change much over time.
|
||||
Uses Huber loss to allow small variations within epsilon band.
|
||||
"""
|
||||
B, T = values.shape
|
||||
if T < 2:
|
||||
return values.mean() * 0.0
|
||||
|
||||
losses = []
|
||||
device = values.device
|
||||
|
||||
for _ in range(num_pairs):
|
||||
i = torch.randint(0, T - 1, (B,), device=device)
|
||||
j = torch.randint(i.min() + 1, T, (1,), device=device)
|
||||
j = j.expand_as(i)
|
||||
|
||||
vi = values[torch.arange(B), i]
|
||||
vj = values[torch.arange(B), j]
|
||||
|
||||
# Huber loss with small delta for near-zero target
|
||||
diff = vj - vi
|
||||
loss = F.huber_loss(diff, torch.zeros_like(diff), delta=epsilon)
|
||||
losses.append(loss)
|
||||
|
||||
return torch.stack(losses).mean()
|
||||
@@ -0,0 +1,126 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import OBS_LANGUAGE
|
||||
from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
TokenizerProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
from lerobot.processor.pipeline import (
|
||||
ComplementaryDataProcessor,
|
||||
EnvTransition,
|
||||
ProcessorStepRegistry,
|
||||
TransitionKey,
|
||||
)
|
||||
|
||||
|
||||
def make_rlearn_processor(
|
||||
config: RLearNConfig, dataset_stats: dict[str, dict[str, Any]] | None = None
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
"""Build pre/post processors for RLearN.
|
||||
|
||||
Responsibilities moved out of the model:
|
||||
- Normalize inputs (images) using dataset stats
|
||||
- Ensure batching
|
||||
- Map complementary_data.task to observation.language when available
|
||||
- Tokenize language into observation.language.tokens / attention_mask
|
||||
- Move to/from device
|
||||
"""
|
||||
|
||||
input_steps = [
|
||||
# No renaming by default, but keep for future extensibility
|
||||
RenameProcessor(rename_map={}),
|
||||
NormalizerProcessor(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
ToBatchProcessor(),
|
||||
RLearnLanguageFromTaskProcessor(),
|
||||
# Use the same model name for tokenizer to keep vocab aligned with text tower
|
||||
TokenizerProcessor(
|
||||
tokenizer_name=config.model_name,
|
||||
max_length=128,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
padding_side="right",
|
||||
),
|
||||
DeviceProcessor(device=config.device),
|
||||
]
|
||||
|
||||
output_steps = [
|
||||
DeviceProcessor(device="cpu"),
|
||||
UnnormalizerProcessor(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
|
||||
return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor(
|
||||
steps=output_steps, name="robot_postprocessor"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="rlearn_language_from_task")
|
||||
class RLearnLanguageFromTaskProcessor(ComplementaryDataProcessor):
|
||||
"""Copy complementary_data['task'] into observation['observation.language'] if present.
|
||||
|
||||
This ensures the model can consume a raw language string when tokenization is not used,
|
||||
while TokenizerProcessor can still create tokenized fields.
|
||||
"""
|
||||
|
||||
task_key: str = "task"
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition: # type: ignore[override]
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if not complementary_data or self.task_key not in complementary_data:
|
||||
return transition
|
||||
|
||||
task = complementary_data.get(self.task_key)
|
||||
if task is None:
|
||||
return transition
|
||||
|
||||
# Normalize to list[str]
|
||||
if isinstance(task, str):
|
||||
task_list = [task]
|
||||
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
||||
task_list = task
|
||||
else:
|
||||
return transition
|
||||
|
||||
observation = transition.get(TransitionKey.OBSERVATION) or {}
|
||||
# Do not overwrite if user already provided observation.language
|
||||
if OBS_LANGUAGE not in observation:
|
||||
observation[OBS_LANGUAGE] = task_list
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
return transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: # noqa: D401
|
||||
# Adds nothing to features; only mirrors complementary_data.task into observation
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"task_key": self.task_key}
|
||||
@@ -0,0 +1,136 @@
|
||||
## General Value/Reward Learning:
|
||||
|
||||
I want to implement a general/universal vision and language value function or reward model for robotics/video tasks. Also called a video language conditioned reward model. Integrated with already existing LeRobot code if convenient, use the LeRobot Dataset for dataset and store the reward for a frame in the lerobot frame itself.
|
||||
|
||||
Inspired by these papers:
|
||||
|
||||
- ReWiND; https://arxiv.org/pdf/2505.10911 (Most applicable and main paper I want to implement ideas from) and code: https://github.com/lucidrains/rewind-reward-pytorch
|
||||
- LIV; https://arxiv.org/pdf/2306.00958 (Most applicable and 2nd main paper I want to implement ideas from) and code https://github.com/penn-pal-lab/LI
|
||||
- VLC: Video-Language Critic: Transferable Reward Functions for Language-Conditioned Robotics: https://arxiv.org/pdf/2405.19988 (Most applicable and 3rd paper I want to implement ideas from) and code: https://github.com/minttusofia/video_language_critic
|
||||
|
||||
And these papers which are also relevant:
|
||||
|
||||
- https://www.dyna.co/dyna-1/research (Main company I want to reproduce the eventual results from)
|
||||
- vip; https://arxiv.org/pdf/2210.00030
|
||||
- uvd; https://arxiv.org/pdf/2310.08581
|
||||
- vlm in context; https://arxiv.org/pdf/2411.04549
|
||||
- https://www.youtube.com/watch?v=JfZYtpEisoM
|
||||
|
||||
Little less relevant but still similar papers:
|
||||
|
||||
- Learning Generalizable Robotic Reward Functions from “In-The-Wild” Human Videos,
|
||||
- XIRL: Cross-embodiment Inverse Reinforcement Learning,
|
||||
- Video-Language Critic: Transferable Reward https://arxiv.org/pdf/2405.19988
|
||||
- Functions for Language-Conditioned Robotics,
|
||||
- LORel, Language-Driven Representation Learning for Robotics https://sites.google.com/view/robotlorel
|
||||
- RoboCLIP: One Demonstration is Enough to Learn Robot Policies https://arxiv.org/pdf/2310.07899
|
||||
- Points2Rewards: learn first key points and then uses the keypoints to learn general value function/policy https://semrob.github.io/docs/2025_rss_semrob.github.io_paper20.pdf
|
||||
- Language-Driven Representation Learning for Robotics: https://arxiv.org/pdf/2302.12766v1
|
||||
- R3M: A Universal Visual Representation for Robot Manipulation: https://arxiv.org/pdf/2203.12601v3
|
||||
|
||||
Input should be the current image or whole video and the task goal specified in text/language. Output is current reward.
|
||||
Archiutecture:
|
||||
_ inputs: video o1:T (or current o1:t), language z;
|
||||
_ google/siglip2-large-patch16-256: https://huggingface.co/google/siglip2-large-patch16-256 \* Temporal module: small causal transformer (“cross-modal sequential aggregator”), with first-frame positional embedding (to avoid position cheating), frame-dropout, and stride sampling; outputs per-timestep logits.
|
||||
|
||||
Loss: See this chatgpt thread: https://chatgpt.com/s/t_68999a50a0b081919abc365cdd205e01
|
||||
|
||||
Past images: (for example a reward methoid go to 3rd floor, has to know what floor it was on and what pas actions it did, can we attend or encorperate images of decision from history in one way?) Maybe via this paper: Learning Long-Context Diffusion Policies via Past-Token Prediction
|
||||
|
||||
Amount of frames needed for test/generalization: 1M frames? or ~20% of IPEC-COMMUNITY/bc_z_lerobot
|
||||
|
||||
Eval:
|
||||
Implement something like voc score , or ROC rank order correlation between reward leanredna and ev reward from sim, or use something else to do additional evaluation
|
||||
|
||||
Ideas:
|
||||
|
||||
- Incorporate training on multiple horizons: as in label same dataset for longer horizons: make a sandwich (long), put cheese on bread (medium) and even smaller horizons: go down or close gripper (small)
|
||||
- Incorporate navigation goals “walk towards the kitchen”, make sure we fix CLIP contrastive learning issue of positional text misunderstanding where model doesnnt learn difference between "horse right of cow" and "horse left of cow" “Move right” potentially train with more other data or even actionable world models such as Genie 3 (https://deepmind.google/discover/blog/genie-3-a-new-frontier-for-world-models/)
|
||||
|
||||
How to use a general reward model (use cases): - Train rl policy on it - Success detection - Do exploraion - Do task via planning and search to optimize reward - Filter out bad episodes in large datasets from imitation learning
|
||||
|
||||
Potential Datasets: (start with dataset that is most clean for this and works best with chosen way of doing evals)
|
||||
_ Epic-Kitchens-100
|
||||
_ Something-Something v. 2 Dataset https://www.qualcomm.com/developer/software/something-something-v-2-dataset
|
||||
_ Ego4D (3000 hours)
|
||||
_ Open X-Embodiment (OXE)
|
||||
_ Age bot world: https://huggingface.co/datasets/agibot-world/AgiBotWorld-Alpha
|
||||
_ GTEA+ Gaze: https://cbs.ic.gatech.edu/fpv/
|
||||
_ YouCook2 dataset
|
||||
_ HOWTO100M: https://www.di.ens.fr/willow/research/howto100m/
|
||||
|
||||
Also in plan include investigating/incorperating these two things:
|
||||
|
||||
- Curriculum: Start training on easier distinctions (e.g. very early vs very late frames which are easy to tell apart) then gradually ask the model to distinguish more subtle differences (frames that are closer in time). This curriculum can be implemented by initially using pairs far apart in the trajectory for ranking, then moving to closer pairs as accuracy improves.
|
||||
- Augmentations: As mentioned, heavy image augmentation (random crops, slight noise) is often used so that the reward model focuses on high-level task progress features rather than pixel-level cues. For video models, temporal augmentation like random frame skipping can also make the model robust to different speeds.
|
||||
|
||||
### Implemented Loss (Spatial-Aware Composite Loss)
|
||||
|
||||
Our implementation uses a **composite loss with spatial awareness** to address the limitations of standard contrastive learning (e.g., CLIP's inability to distinguish "move left" vs "move right"). The loss has three components:
|
||||
|
||||
##### 1) Progress Regression Loss (L_prog)
|
||||
|
||||
Predicts normalized progress values for each timestep:
|
||||
|
||||
$$
|
||||
L_{\text{prog}} = \text{MSE}(\sigma(z(V_t)), y_t)
|
||||
$$
|
||||
|
||||
where $z(·)$ is z-score normalization, $\sigma$ is sigmoid, and $y_t \in [0,1]$ is the progress label.
|
||||
**Purpose:** Grounds the model in actual task progress, not just visual-language similarity.
|
||||
|
||||
##### 2) Spatial-Aware InfoNCE Loss (L_spatial_nce)
|
||||
|
||||
Instead of using pooled features, we:
|
||||
|
||||
- Extract spatial patch features from SigLIP2's last_hidden_state (e.g., 16×16 patches)
|
||||
- Use cross-attention where language queries attend to relevant spatial regions
|
||||
- Compute contrastive loss on the attended spatial features
|
||||
|
||||
$$
|
||||
L_{\text{spatial-nce}} = -\log \frac{\exp(s_{ii}/\tau)}{\sum_j \exp(s_{ij}/\tau)}
|
||||
$$
|
||||
|
||||
where $s_{ij}$ is the similarity between spatially-attended features from trajectory $i$ and language $j$.
|
||||
**Purpose:** Preserves spatial information that pooling discards, enabling distinction of spatial relationships.
|
||||
|
||||
##### 3) ReWiND Reversible Ranking Loss (L_rewind)
|
||||
|
||||
Based on ReWiND's key insight: learn from both forward AND reversed trajectories.
|
||||
The loss has two components:
|
||||
|
||||
- **Forward ranking**: Sample (far, near) pairs where near is later in time, enforce $V_{\text{near}} > V_{\text{far}}$
|
||||
- **Reverse ranking**: Reverse the trajectory and invert progress labels, then apply same ranking
|
||||
|
||||
$$
|
||||
L_{\text{rewind}} = L_{\text{forward}} + L_{\text{reverse}}
|
||||
$$
|
||||
|
||||
where both use: $\text{softplus}(m - (V_{\text{near}} - V_{\text{far}}))$
|
||||
|
||||
**Purpose:** By training on reversed trajectories with inverted progress, the model learns to distinguish progress from undoing progress. This is ReWiND's core contribution - understanding that tasks can be reversible.
|
||||
|
||||
##### Total Loss:
|
||||
|
||||
$$
|
||||
L = \lambda_{\text{prog}} L_{\text{prog}} + \lambda_{\text{spatial-nce}} L_{\text{spatial-nce}} + \lambda_{\text{rewind}} L_{\text{rewind}}
|
||||
$$
|
||||
|
||||
Default weights: $\lambda_{\text{prog}}=1.0$, $\lambda_{\text{spatial-nce}}=0.5$, $\lambda_{\text{rewind}}=0.4$
|
||||
|
||||
### TODOs:
|
||||
|
||||
- Implement first architecture [x]
|
||||
- Implement processors [x]
|
||||
- Choose right loss metric(s) [x]
|
||||
- Make dataset with script that generated the dataset (IPEC-COMMUNITY/bc_z_lerobot) ready in lerobot format (and be able to visualize in dataset visualizer)
|
||||
- Annotate with ReWiND-style 0→1 progress rewards [x]
|
||||
- Visualize to check [x]
|
||||
- Implement eval score or metric that is robust and can deal with generalization/is a good metric to try different architectures. And use it in an eval jupyter notebook with visalization of the live reward next to the video for part of the dataset: VOC score and score with correct and incorrect language captions [x]
|
||||
- Do first training []
|
||||
- Try different losses []
|
||||
- Switch to DINO v3 as encoder Base 86 M: https://huggingface.co/facebook/dinov3-vitb16-pretrain-lvd1689m with HuggingFaceTB/SmolLM2-135M-Instruct ?
|
||||
- Add more artificial text to dataset generated by vlm (google gemini) []
|
||||
- See google gemini vlm caption from Leandro []
|
||||
- Multiple captions per video, creat method to generate as much data as possible etc [] https://arxiv.org/abs/2508.13446
|
||||
- Add droid []
|
||||
@@ -145,10 +145,13 @@ class TokenizerProcessor:
|
||||
observation = dict(observation) # Make a copy
|
||||
|
||||
# Add tokenized data to observation
|
||||
observation[f"{OBS_LANGUAGE}.tokens"] = tokenized_prompt["input_ids"]
|
||||
observation[f"{OBS_LANGUAGE}.attention_mask"] = tokenized_prompt["attention_mask"].to(
|
||||
dtype=torch.bool
|
||||
)
|
||||
input_ids = tokenized_prompt["input_ids"]
|
||||
attention_mask = tokenized_prompt.get("attention_mask")
|
||||
if attention_mask is None:
|
||||
# Some tokenizers (e.g., SigLIP text) may not return attention_mask; default to ones
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
observation[f"{OBS_LANGUAGE}.tokens"] = input_ids
|
||||
observation[f"{OBS_LANGUAGE}.attention_mask"] = attention_mask.to(dtype=torch.bool)
|
||||
|
||||
transition[TransitionKey.OBSERVATION.value] = observation # type: ignore[misc]
|
||||
return transition
|
||||
|
||||
@@ -0,0 +1,335 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Add ReWiND-style linear progress rewards to existing LeRobot datasets.
|
||||
|
||||
This script creates a complete copy of the dataset with rewards added to each frame.
|
||||
It downloads the original dataset (including videos), adds rewards, and pushes everything to a new repository.
|
||||
|
||||
Usage:
|
||||
# Create full dataset copy with rewards
|
||||
python src/lerobot/scripts/annotate_dataset_rewards.py --input-repo IPEC-COMMUNITY/bc_z_lerobot --output-repo username/bc_z_with_rewards
|
||||
|
||||
# Test with 1% of episodes
|
||||
python src/lerobot/scripts/annotate_dataset_rewards.py --input-repo IPEC-COMMUNITY/bc_z_lerobot --output-repo username/test_rewards --percentage 1
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from tempfile import mkdtemp
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.constants import REWARD
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
|
||||
|
||||
def compute_linear_progress_reward(episode_length: int) -> np.ndarray:
|
||||
"""
|
||||
Compute linear progress rewards from 0 to 1.
|
||||
|
||||
ReWiND-style: progress increases linearly from 0 at start to 1 at completion.
|
||||
|
||||
Args:
|
||||
episode_length: Number of frames in the episode
|
||||
|
||||
Returns:
|
||||
rewards: Array of shape (episode_length,) with values linearly from 0 to 1
|
||||
"""
|
||||
return np.linspace(0, 1, episode_length, dtype=np.float32)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Add linear progress rewards to LeRobot dataset and push to Hub"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-repo",
|
||||
type=str,
|
||||
default="IPEC-COMMUNITY/bc_z_lerobot",
|
||||
help="Input dataset repository on HuggingFace Hub",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-repo",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Output dataset repository name (e.g., username/dataset_with_rewards)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--percentage",
|
||||
type=float,
|
||||
default=100.0,
|
||||
help="Percentage of episodes to process (useful for testing, e.g., 1 for 1%%)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--private",
|
||||
action="store_true",
|
||||
help="Make the output repository private",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Local directory to save the modified dataset (defaults to ~/.cache/huggingface/lerobot/<output-repo>)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("=" * 60)
|
||||
print("FULL DATASET COPY WITH REWARDS")
|
||||
print("This will download the entire dataset including videos,")
|
||||
print("add rewards, and push everything to a new repository.")
|
||||
print("=" * 60)
|
||||
|
||||
# First, load just the metadata to get total episodes
|
||||
print(f"\nLoading metadata from Hub: {args.input_repo}")
|
||||
|
||||
# Load metadata only first
|
||||
metadata = LeRobotDatasetMetadata(repo_id=args.input_repo)
|
||||
total_episodes = metadata.total_episodes
|
||||
|
||||
# Calculate which episodes to process
|
||||
num_episodes_to_process = max(1, int(total_episodes * args.percentage / 100))
|
||||
episodes_to_load = list(range(num_episodes_to_process)) # Load only first N episodes
|
||||
|
||||
print(f"Dataset has {total_episodes} episodes")
|
||||
print(f"Processing {num_episodes_to_process} episodes ({args.percentage}%)")
|
||||
|
||||
# Determine local directory for the new dataset
|
||||
if args.local_dir:
|
||||
local_dir = Path(args.local_dir)
|
||||
else:
|
||||
from lerobot.constants import HF_LEROBOT_HOME
|
||||
|
||||
local_dir = HF_LEROBOT_HOME / args.output_repo
|
||||
|
||||
# Use a temporary directory for downloading source dataset
|
||||
temp_source_dir = Path(mkdtemp(prefix="lerobot_source_"))
|
||||
|
||||
# Load the dataset with videos to temp directory
|
||||
print("Downloading dataset with videos to temp directory...")
|
||||
print(f"Temp directory: {temp_source_dir}")
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=args.input_repo,
|
||||
root=temp_source_dir, # Temporary location for source
|
||||
episodes=episodes_to_load if args.percentage < 100 else None,
|
||||
download_videos=True, # Download videos
|
||||
)
|
||||
|
||||
print(f"Downloaded {dataset.num_episodes} episodes with {dataset.num_frames} frames")
|
||||
|
||||
# Create a new dataset with rewards
|
||||
print(f"\nCreating new dataset at: {local_dir}")
|
||||
|
||||
# Clean up any existing directory from previous runs
|
||||
if local_dir.exists():
|
||||
print(f"⚠️ Directory already exists: {local_dir}")
|
||||
print(" Removing it to start fresh...")
|
||||
shutil.rmtree(local_dir)
|
||||
|
||||
# Define features including reward
|
||||
# Simply copy all features from the original dataset
|
||||
new_features = dict(dataset.features)
|
||||
|
||||
# Add reward feature
|
||||
new_features[REWARD] = {"shape": (1,), "dtype": "float32", "names": ["reward"]}
|
||||
|
||||
# Determine which features are videos
|
||||
video_keys = dataset.meta.video_keys if hasattr(dataset.meta, "video_keys") else []
|
||||
image_keys = dataset.meta.image_keys if hasattr(dataset.meta, "image_keys") else []
|
||||
visual_keys = set(video_keys + image_keys)
|
||||
|
||||
print(f" Visual features to be handled as videos: {visual_keys}")
|
||||
|
||||
# Check for language features
|
||||
language_keys = [
|
||||
k
|
||||
for k in dataset.features.keys()
|
||||
if any(lang in k.lower() for lang in ["language", "task", "instruction", "text"])
|
||||
]
|
||||
if language_keys:
|
||||
print(f" Language/task features found: {language_keys}")
|
||||
|
||||
# Copy dataset structure to new location
|
||||
new_dataset = LeRobotDataset.create(
|
||||
repo_id=args.output_repo,
|
||||
root=local_dir,
|
||||
fps=dataset.fps,
|
||||
features=new_features,
|
||||
robot_type=dataset.meta.robot_type,
|
||||
use_videos=len(dataset.meta.video_keys) > 0,
|
||||
)
|
||||
|
||||
# Process each episode
|
||||
print("\nAdding rewards to episodes...")
|
||||
|
||||
episode_data_index = dataset.episode_data_index
|
||||
|
||||
for ep_idx, episode_idx in enumerate(tqdm(episodes_to_load)):
|
||||
# Get episode boundaries
|
||||
ep_start = episode_data_index["from"][ep_idx].item()
|
||||
ep_end = episode_data_index["to"][ep_idx].item()
|
||||
episode_length = ep_end - ep_start
|
||||
|
||||
# Compute linear progress rewards for this episode
|
||||
rewards = compute_linear_progress_reward(episode_length)
|
||||
|
||||
# Get episode metadata
|
||||
episode_info = dataset.meta.episodes[episode_idx]
|
||||
tasks = episode_info.get("tasks", [])
|
||||
if not tasks:
|
||||
# Try to get task from first frame if not in episode metadata
|
||||
first_frame = dataset[ep_start]
|
||||
if "task" in first_frame:
|
||||
tasks = [first_frame["task"]]
|
||||
else:
|
||||
tasks = [""]
|
||||
|
||||
# Process each frame in the episode
|
||||
for frame_idx in range(episode_length):
|
||||
global_idx = ep_start + frame_idx
|
||||
|
||||
# Get original frame data
|
||||
frame_data = dataset[global_idx]
|
||||
|
||||
# Create frame dict for the new dataset
|
||||
frame = {}
|
||||
for key in dataset.features:
|
||||
# Skip only auto-generated metadata fields
|
||||
# Keep task-related fields that contain language annotations
|
||||
if key in ["index", "episode_index", "frame_index", "timestamp"]:
|
||||
continue
|
||||
|
||||
# For visual features that are videos, extract the actual frame
|
||||
if key in visual_keys:
|
||||
# Get the image data to save as temporary files
|
||||
if key in frame_data:
|
||||
img = frame_data[key]
|
||||
# Convert to numpy if tensor
|
||||
if isinstance(img, torch.Tensor):
|
||||
img = img.cpu().numpy()
|
||||
# Ensure channels-last format (H, W, C) for saving
|
||||
if len(img.shape) == 3 and img.shape[0] in [1, 3, 4]:
|
||||
img = np.transpose(img, (1, 2, 0))
|
||||
|
||||
# Resize to match expected shape if needed
|
||||
expected_shape = new_features[key].get("shape")
|
||||
if expected_shape and img.shape != tuple(expected_shape):
|
||||
# Try to match the shape - handle both HWC and CHW formats
|
||||
if len(expected_shape) == 3:
|
||||
# Determine if expected is HWC or CHW
|
||||
if expected_shape[-1] in [1, 3, 4]: # Likely HWC
|
||||
target_h, target_w = expected_shape[0], expected_shape[1]
|
||||
elif expected_shape[0] in [
|
||||
1,
|
||||
3,
|
||||
4,
|
||||
]: # Likely CHW - shouldn't happen after transpose
|
||||
target_h, target_w = expected_shape[1], expected_shape[2]
|
||||
else:
|
||||
# Assume HWC
|
||||
target_h, target_w = expected_shape[0], expected_shape[1]
|
||||
|
||||
# Resize using PIL for quality
|
||||
if img.dtype != np.uint8:
|
||||
img = (img * 255).astype(np.uint8)
|
||||
pil_img = Image.fromarray(img)
|
||||
pil_img = pil_img.resize((target_w, target_h), Image.Resampling.LANCZOS)
|
||||
img = np.array(pil_img)
|
||||
|
||||
frame[key] = img
|
||||
continue
|
||||
|
||||
if key in frame_data:
|
||||
value = frame_data[key]
|
||||
|
||||
# Handle language/task fields specially
|
||||
if key == "task" and isinstance(value, str):
|
||||
# Skip string task - will be passed separately to add_frame
|
||||
continue
|
||||
elif key == "task_index":
|
||||
# Skip task_index as it will be regenerated based on task
|
||||
continue
|
||||
elif key in ["observation.language", "language", "instruction"] and isinstance(
|
||||
value, str
|
||||
):
|
||||
# Keep language fields as-is
|
||||
frame[key] = value
|
||||
continue
|
||||
|
||||
# Regular field processing
|
||||
# Convert tensors to numpy for saving
|
||||
if isinstance(value, torch.Tensor):
|
||||
value = value.cpu().numpy()
|
||||
|
||||
# Ensure arrays are the right shape
|
||||
if hasattr(value, "shape") and len(value.shape) == 0:
|
||||
# Convert scalar to 1D array
|
||||
value = np.array([value])
|
||||
|
||||
frame[key] = value
|
||||
|
||||
# Add reward
|
||||
frame[REWARD] = np.array([rewards[frame_idx]], dtype=np.float32)
|
||||
|
||||
# Get task for this specific frame (might vary within episode)
|
||||
if "task" in frame_data:
|
||||
task = frame_data["task"]
|
||||
else:
|
||||
task = tasks[0] if tasks else ""
|
||||
|
||||
# Add frame to new dataset
|
||||
timestamp = frame_idx / dataset.fps
|
||||
new_dataset.add_frame(frame, task=task, timestamp=timestamp)
|
||||
|
||||
# Save the episode (this will encode videos from the saved frames)
|
||||
new_dataset.save_episode()
|
||||
|
||||
print(
|
||||
f"\n✓ Created new dataset with rewards: {new_dataset.num_episodes} episodes, {new_dataset.num_frames} frames"
|
||||
)
|
||||
|
||||
# Push to Hub
|
||||
print(f"\nPushing to Hub: {args.output_repo}")
|
||||
new_dataset.push_to_hub(
|
||||
private=args.private,
|
||||
push_videos=True,
|
||||
)
|
||||
|
||||
print(f"\n✓ Dataset pushed to: https://huggingface.co/datasets/{args.output_repo}")
|
||||
|
||||
# Clean up temporary source directory
|
||||
if temp_source_dir.exists():
|
||||
print("\nCleaning up temporary files...")
|
||||
shutil.rmtree(temp_source_dir)
|
||||
|
||||
# Print summary
|
||||
print("\n=== Summary ===")
|
||||
print(f"Input dataset: {args.input_repo}")
|
||||
print(f"Output dataset: {args.output_repo}")
|
||||
print(f"Episodes processed: {num_episodes_to_process}/{total_episodes} ({args.percentage}%)")
|
||||
print(f"Frames with rewards: {new_dataset.num_frames}")
|
||||
print("Reward type: Linear progress (0→1)")
|
||||
print("===============")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,591 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
OPTIMIZED VERSION: Add ReWiND-style linear progress rewards to existing LeRobot datasets with parallel processing.
|
||||
|
||||
This script creates a complete copy of the dataset with rewards added to each frame.
|
||||
It downloads the original dataset (including videos), adds rewards, and pushes everything to a new repository.
|
||||
|
||||
Key optimizations:
|
||||
- Parallel episode processing using multiprocessing
|
||||
- Batch frame processing within episodes
|
||||
- Concurrent video encoding
|
||||
- Optimized image operations
|
||||
- Better memory management
|
||||
|
||||
Usage:
|
||||
# Test with 1% of episodes using 4 workers
|
||||
python src/lerobot/scripts/annotate_dataset_rewards_optimized.py --input-repo IPEC-COMMUNITY/bc_z_lerobot --output-repo pepijn223/rewards_bc_z_1p --percentage 1 --num-workers 4
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import shutil
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from multiprocessing import Pool, cpu_count
|
||||
from pathlib import Path
|
||||
from tempfile import mkdtemp
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.constants import REWARD
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def compute_linear_progress_reward(episode_length: int) -> np.ndarray:
|
||||
"""
|
||||
Compute linear progress rewards from 0 to 1.
|
||||
|
||||
ReWiND-style: progress increases linearly from 0 at start to 1 at completion.
|
||||
|
||||
Args:
|
||||
episode_length: Number of frames in the episode
|
||||
|
||||
Returns:
|
||||
rewards: Array of shape (episode_length,) with values linearly from 0 to 1
|
||||
"""
|
||||
return np.linspace(0, 1, episode_length, dtype=np.float32)
|
||||
|
||||
|
||||
def process_image_batch(images: list[np.ndarray], target_shape: tuple[int, ...]) -> list[np.ndarray]:
|
||||
"""
|
||||
Process a batch of images efficiently.
|
||||
|
||||
Args:
|
||||
images: List of numpy arrays representing images
|
||||
target_shape: Target shape for resizing
|
||||
|
||||
Returns:
|
||||
List of processed images
|
||||
"""
|
||||
processed = []
|
||||
|
||||
if len(target_shape) == 3:
|
||||
# Determine target dimensions
|
||||
if target_shape[-1] in [1, 3, 4]: # Likely HWC
|
||||
target_h, target_w = target_shape[0], target_shape[1]
|
||||
elif target_shape[0] in [1, 3, 4]: # Likely CHW
|
||||
target_h, target_w = target_shape[1], target_shape[2]
|
||||
else:
|
||||
target_h, target_w = target_shape[0], target_shape[1]
|
||||
|
||||
# Process all images
|
||||
for img in images:
|
||||
# Ensure channels-last format
|
||||
if len(img.shape) == 3 and img.shape[0] in [1, 3, 4]:
|
||||
img = np.transpose(img, (1, 2, 0))
|
||||
|
||||
# Resize if needed
|
||||
if img.shape[:2] != (target_h, target_w):
|
||||
if img.dtype != np.uint8:
|
||||
img = (img * 255).astype(np.uint8)
|
||||
pil_img = Image.fromarray(img)
|
||||
pil_img = pil_img.resize((target_w, target_h), Image.Resampling.LANCZOS)
|
||||
img = np.array(pil_img)
|
||||
|
||||
processed.append(img)
|
||||
else:
|
||||
processed = images
|
||||
|
||||
return processed
|
||||
|
||||
|
||||
def process_episode_chunk(args: tuple[int, int, dict, Any]) -> tuple[int, list[dict], list[str]]:
|
||||
"""
|
||||
Process a chunk of frames from an episode in parallel.
|
||||
|
||||
Args:
|
||||
args: Tuple of (chunk_start, chunk_end, shared_data, episode_data)
|
||||
|
||||
Returns:
|
||||
Tuple of (episode_idx, frames_data, tasks)
|
||||
"""
|
||||
chunk_start, chunk_end, shared_data, episode_data = args
|
||||
|
||||
episode_idx = episode_data["episode_idx"]
|
||||
ep_start = episode_data["ep_start"]
|
||||
episode_length = episode_data["episode_length"]
|
||||
rewards = episode_data["rewards"]
|
||||
tasks_default = episode_data["tasks"]
|
||||
dataset = episode_data["dataset"]
|
||||
new_features = shared_data["new_features"]
|
||||
visual_keys = shared_data["visual_keys"]
|
||||
fps = shared_data["fps"]
|
||||
|
||||
frames_data = []
|
||||
tasks = []
|
||||
|
||||
# Process chunk of frames
|
||||
for frame_idx in range(chunk_start, min(chunk_end, episode_length)):
|
||||
global_idx = ep_start + frame_idx
|
||||
|
||||
# Get original frame data
|
||||
frame_data = dataset[global_idx]
|
||||
|
||||
# Create frame dict for the new dataset
|
||||
frame = {}
|
||||
|
||||
# Process all non-visual features
|
||||
for key in dataset.features:
|
||||
if key in ["index", "episode_index", "frame_index", "timestamp"]:
|
||||
continue
|
||||
|
||||
if key in visual_keys:
|
||||
# Process visual features
|
||||
if key in frame_data:
|
||||
img = frame_data[key]
|
||||
if isinstance(img, torch.Tensor):
|
||||
img = img.cpu().numpy()
|
||||
frame[key] = img
|
||||
continue
|
||||
|
||||
if key in frame_data:
|
||||
value = frame_data[key]
|
||||
|
||||
# Handle special fields
|
||||
if key == "task" and isinstance(value, str):
|
||||
tasks.append(value)
|
||||
continue
|
||||
elif key == "task_index":
|
||||
continue
|
||||
elif key in ["observation.language", "language", "instruction"] and isinstance(value, str):
|
||||
frame[key] = value
|
||||
continue
|
||||
|
||||
# Regular field processing
|
||||
if isinstance(value, torch.Tensor):
|
||||
value = value.cpu().numpy()
|
||||
|
||||
if hasattr(value, "shape") and len(value.shape) == 0:
|
||||
value = np.array([value])
|
||||
|
||||
frame[key] = value
|
||||
|
||||
# Add reward
|
||||
frame[REWARD] = np.array([rewards[frame_idx]], dtype=np.float32)
|
||||
|
||||
# Set task
|
||||
if not tasks or tasks[-1] is None:
|
||||
tasks.append(tasks_default[0] if tasks_default else "")
|
||||
|
||||
# Add timestamp
|
||||
frame["timestamp"] = frame_idx / fps
|
||||
|
||||
frames_data.append(frame)
|
||||
|
||||
return (episode_idx, frames_data, tasks)
|
||||
|
||||
|
||||
def process_episode_parallel(
|
||||
episode_data: dict, shared_data: dict, chunk_size: int = 50
|
||||
) -> tuple[int, list[dict], list[str]]:
|
||||
"""
|
||||
Process an entire episode using parallel chunk processing.
|
||||
|
||||
Args:
|
||||
episode_data: Episode-specific data
|
||||
shared_data: Shared configuration data
|
||||
chunk_size: Number of frames to process per chunk
|
||||
|
||||
Returns:
|
||||
Tuple of (episode_idx, all_frames, all_tasks)
|
||||
"""
|
||||
episode_length = episode_data["episode_length"]
|
||||
episode_idx = episode_data["episode_idx"]
|
||||
|
||||
# Create chunks
|
||||
chunks = []
|
||||
for i in range(0, episode_length, chunk_size):
|
||||
chunk_end = min(i + chunk_size, episode_length)
|
||||
chunks.append((i, chunk_end, shared_data, episode_data))
|
||||
|
||||
# Process chunks in parallel using threads (good for I/O bound operations)
|
||||
all_frames = [None] * episode_length
|
||||
all_tasks = []
|
||||
|
||||
with ThreadPoolExecutor(max_workers=4) as executor:
|
||||
futures = {executor.submit(process_episode_chunk, chunk): idx for idx, chunk in enumerate(chunks)}
|
||||
|
||||
for future in as_completed(futures):
|
||||
chunk_idx = futures[future]
|
||||
_, frames, tasks = future.result()
|
||||
|
||||
# Place frames in correct positions
|
||||
start_idx = chunks[chunk_idx][0]
|
||||
for i, frame in enumerate(frames):
|
||||
all_frames[start_idx + i] = frame
|
||||
all_tasks.extend(tasks)
|
||||
|
||||
# Filter out None values (shouldn't happen but safety check)
|
||||
all_frames = [f for f in all_frames if f is not None]
|
||||
|
||||
return (episode_idx, all_frames, all_tasks)
|
||||
|
||||
|
||||
def worker_process_episode(args: tuple[int, str, str, dict, str, str, bool]) -> dict:
|
||||
"""
|
||||
Worker function to process a single episode.
|
||||
|
||||
Args:
|
||||
args: Tuple containing (episode_idx, input_repo, output_repo, shared_data, local_dir, temp_dir, use_chunk_processing)
|
||||
|
||||
Returns:
|
||||
Dict with processing results or error
|
||||
"""
|
||||
episode_idx, input_repo, output_repo, shared_data, local_dir_str, temp_dir, use_chunk_processing = args
|
||||
|
||||
try:
|
||||
local_dir = Path(local_dir_str)
|
||||
|
||||
# Load dataset for this worker
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=input_repo,
|
||||
root=Path(temp_dir),
|
||||
episodes=[episode_idx],
|
||||
download_videos=True,
|
||||
)
|
||||
|
||||
# Get episode boundaries
|
||||
episode_data_index = dataset.episode_data_index
|
||||
ep_start = episode_data_index["from"][0].item()
|
||||
ep_end = episode_data_index["to"][0].item()
|
||||
episode_length = ep_end - ep_start
|
||||
|
||||
# Compute rewards
|
||||
rewards = compute_linear_progress_reward(episode_length)
|
||||
|
||||
# Get episode metadata
|
||||
episode_info = dataset.meta.episodes[episode_idx]
|
||||
tasks = episode_info.get("tasks", [])
|
||||
if not tasks:
|
||||
first_frame = dataset[ep_start]
|
||||
if "task" in first_frame:
|
||||
tasks = [first_frame["task"]]
|
||||
else:
|
||||
tasks = [""]
|
||||
|
||||
# Prepare episode data
|
||||
episode_data = {
|
||||
"episode_idx": episode_idx,
|
||||
"ep_start": ep_start,
|
||||
"episode_length": episode_length,
|
||||
"rewards": rewards,
|
||||
"tasks": tasks,
|
||||
"dataset": dataset,
|
||||
}
|
||||
|
||||
if use_chunk_processing:
|
||||
# Process episode with chunk parallelization
|
||||
_, frames_data, frame_tasks = process_episode_parallel(episode_data, shared_data)
|
||||
else:
|
||||
# Process episode sequentially (fallback)
|
||||
frames_data = []
|
||||
frame_tasks = []
|
||||
|
||||
for frame_idx in range(episode_length):
|
||||
global_idx = ep_start + frame_idx
|
||||
frame_data = dataset[global_idx]
|
||||
|
||||
frame = {}
|
||||
for key in dataset.features:
|
||||
if key in ["index", "episode_index", "frame_index", "timestamp"]:
|
||||
continue
|
||||
|
||||
if key in shared_data["visual_keys"]:
|
||||
if key in frame_data:
|
||||
img = frame_data[key]
|
||||
if isinstance(img, torch.Tensor):
|
||||
img = img.cpu().numpy()
|
||||
|
||||
# Process image if needed
|
||||
if (
|
||||
key in shared_data["new_features"]
|
||||
and "shape" in shared_data["new_features"][key]
|
||||
):
|
||||
expected_shape = shared_data["new_features"][key]["shape"]
|
||||
img = process_image_batch([img], expected_shape)[0]
|
||||
|
||||
frame[key] = img
|
||||
continue
|
||||
|
||||
if key in frame_data:
|
||||
value = frame_data[key]
|
||||
|
||||
if key == "task" and isinstance(value, str):
|
||||
frame_tasks.append(value)
|
||||
continue
|
||||
elif key == "task_index":
|
||||
continue
|
||||
|
||||
if isinstance(value, torch.Tensor):
|
||||
value = value.cpu().numpy()
|
||||
|
||||
if hasattr(value, "shape") and len(value.shape) == 0:
|
||||
value = np.array([value])
|
||||
|
||||
frame[key] = value
|
||||
|
||||
frame[REWARD] = np.array([rewards[frame_idx]], dtype=np.float32)
|
||||
frames_data.append(frame)
|
||||
|
||||
if not frame_tasks or len(frame_tasks) <= frame_idx:
|
||||
frame_tasks.append(tasks[0] if tasks else "")
|
||||
|
||||
return {
|
||||
"episode_idx": episode_idx,
|
||||
"frames_data": frames_data,
|
||||
"tasks": frame_tasks if frame_tasks else tasks,
|
||||
"fps": dataset.fps,
|
||||
"success": True,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing episode {episode_idx}: {e}")
|
||||
return {"episode_idx": episode_idx, "error": str(e), "success": False}
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Optimized: Add linear progress rewards to LeRobot dataset with parallel processing"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-repo",
|
||||
type=str,
|
||||
default="IPEC-COMMUNITY/bc_z_lerobot",
|
||||
help="Input dataset repository on HuggingFace Hub",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-repo",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Output dataset repository name (e.g., username/dataset_with_rewards)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--percentage",
|
||||
type=float,
|
||||
default=100.0,
|
||||
help="Percentage of episodes to process (useful for testing, e.g., 1 for 1%%)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of parallel workers (defaults to CPU count - 2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chunk-size",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of frames to process per chunk within an episode",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--private",
|
||||
action="store_true",
|
||||
help="Make the output repository private",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Local directory to save the modified dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-chunk-processing",
|
||||
action="store_true",
|
||||
help="Disable chunk-based parallel processing within episodes",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Determine number of workers
|
||||
if args.num_workers is None:
|
||||
args.num_workers = max(1, cpu_count() - 2)
|
||||
|
||||
print("=" * 60)
|
||||
print("OPTIMIZED DATASET COPY WITH REWARDS")
|
||||
print(f"Using {args.num_workers} parallel workers")
|
||||
print("=" * 60)
|
||||
|
||||
# Load metadata
|
||||
print(f"\nLoading metadata from Hub: {args.input_repo}")
|
||||
metadata = LeRobotDatasetMetadata(repo_id=args.input_repo)
|
||||
total_episodes = metadata.total_episodes
|
||||
|
||||
# Calculate episodes to process
|
||||
num_episodes_to_process = max(1, int(total_episodes * args.percentage / 100))
|
||||
episodes_to_load = list(range(num_episodes_to_process))
|
||||
|
||||
print(f"Dataset has {total_episodes} episodes")
|
||||
print(f"Processing {num_episodes_to_process} episodes ({args.percentage}%)")
|
||||
|
||||
# Determine local directory
|
||||
if args.local_dir:
|
||||
local_dir = Path(args.local_dir)
|
||||
else:
|
||||
from lerobot.constants import HF_LEROBOT_HOME
|
||||
|
||||
local_dir = HF_LEROBOT_HOME / args.output_repo
|
||||
|
||||
# Create temporary directories for workers
|
||||
temp_base_dir = Path(mkdtemp(prefix="lerobot_parallel_"))
|
||||
worker_temp_dirs = []
|
||||
for i in range(args.num_workers):
|
||||
worker_dir = temp_base_dir / f"worker_{i}"
|
||||
worker_dir.mkdir(parents=True, exist_ok=True)
|
||||
worker_temp_dirs.append(str(worker_dir))
|
||||
|
||||
print(f"Using temporary base directory: {temp_base_dir}")
|
||||
|
||||
# Load first episode to get features and structure
|
||||
print("\nLoading dataset structure...")
|
||||
sample_dataset = LeRobotDataset(
|
||||
repo_id=args.input_repo,
|
||||
root=temp_base_dir / "sample",
|
||||
episodes=[0],
|
||||
download_videos=True,
|
||||
)
|
||||
|
||||
# Prepare features with reward
|
||||
new_features = dict(sample_dataset.features)
|
||||
new_features[REWARD] = {"shape": (1,), "dtype": "float32", "names": ["reward"]}
|
||||
|
||||
# Determine visual keys
|
||||
video_keys = sample_dataset.meta.video_keys if hasattr(sample_dataset.meta, "video_keys") else []
|
||||
image_keys = sample_dataset.meta.image_keys if hasattr(sample_dataset.meta, "image_keys") else []
|
||||
visual_keys = set(video_keys + image_keys)
|
||||
|
||||
print(f" Visual features: {visual_keys}")
|
||||
|
||||
# Clean up existing directory
|
||||
if local_dir.exists():
|
||||
print(f"⚠️ Directory already exists: {local_dir}")
|
||||
print(" Removing it to start fresh...")
|
||||
shutil.rmtree(local_dir)
|
||||
|
||||
# Create new dataset structure
|
||||
print("\nCreating new dataset structure...")
|
||||
new_dataset = LeRobotDataset.create(
|
||||
repo_id=args.output_repo,
|
||||
root=local_dir,
|
||||
fps=sample_dataset.fps,
|
||||
features=new_features,
|
||||
robot_type=sample_dataset.meta.robot_type,
|
||||
use_videos=len(sample_dataset.meta.video_keys) > 0,
|
||||
)
|
||||
|
||||
# Prepare shared data for workers
|
||||
shared_data = {
|
||||
"new_features": new_features,
|
||||
"visual_keys": visual_keys,
|
||||
"fps": sample_dataset.fps,
|
||||
}
|
||||
|
||||
# Process episodes in parallel
|
||||
print(f"\nProcessing {num_episodes_to_process} episodes with {args.num_workers} workers...")
|
||||
|
||||
# Prepare worker arguments
|
||||
worker_args = []
|
||||
for i, episode_idx in enumerate(episodes_to_load):
|
||||
# Assign worker temp directory round-robin
|
||||
temp_dir = worker_temp_dirs[i % args.num_workers]
|
||||
worker_args.append(
|
||||
(
|
||||
episode_idx,
|
||||
args.input_repo,
|
||||
args.output_repo,
|
||||
shared_data,
|
||||
str(local_dir),
|
||||
temp_dir,
|
||||
not args.no_chunk_processing,
|
||||
)
|
||||
)
|
||||
|
||||
# Process episodes using multiprocessing
|
||||
processed_episodes = {}
|
||||
failed_episodes = []
|
||||
|
||||
with Pool(processes=args.num_workers) as pool:
|
||||
# Use imap_unordered for better progress tracking
|
||||
with tqdm(total=num_episodes_to_process, desc="Processing episodes") as pbar:
|
||||
for result in pool.imap_unordered(worker_process_episode, worker_args):
|
||||
pbar.update(1)
|
||||
|
||||
if result["success"]:
|
||||
processed_episodes[result["episode_idx"]] = result
|
||||
else:
|
||||
failed_episodes.append(result["episode_idx"])
|
||||
logger.error(
|
||||
f"Failed episode {result['episode_idx']}: {result.get('error', 'Unknown error')}"
|
||||
)
|
||||
|
||||
# Add processed episodes to the new dataset in order
|
||||
print("\nSaving processed episodes to new dataset...")
|
||||
for episode_idx in tqdm(episodes_to_load, desc="Saving episodes"):
|
||||
if episode_idx in processed_episodes:
|
||||
result = processed_episodes[episode_idx]
|
||||
|
||||
# Add all frames for this episode
|
||||
for i, frame_data in enumerate(result["frames_data"]):
|
||||
task = result["tasks"][i] if i < len(result["tasks"]) else result["tasks"][0]
|
||||
timestamp = i / result["fps"]
|
||||
new_dataset.add_frame(frame_data, task=task, timestamp=timestamp)
|
||||
|
||||
# Save the episode
|
||||
new_dataset.save_episode()
|
||||
|
||||
print(
|
||||
f"\n✓ Created new dataset with rewards: {new_dataset.num_episodes} episodes, {new_dataset.num_frames} frames"
|
||||
)
|
||||
|
||||
if failed_episodes:
|
||||
print(f"⚠️ Failed to process {len(failed_episodes)} episodes: {failed_episodes}")
|
||||
|
||||
# Push to Hub
|
||||
print(f"\nPushing to Hub: {args.output_repo}")
|
||||
new_dataset.push_to_hub(
|
||||
private=args.private,
|
||||
push_videos=True,
|
||||
)
|
||||
|
||||
print(f"\n✓ Dataset pushed to: https://huggingface.co/datasets/{args.output_repo}")
|
||||
|
||||
# Clean up temporary directories
|
||||
if temp_base_dir.exists():
|
||||
print("\nCleaning up temporary files...")
|
||||
shutil.rmtree(temp_base_dir)
|
||||
|
||||
# Print summary
|
||||
print("\n=== Summary ===")
|
||||
print(f"Input dataset: {args.input_repo}")
|
||||
print(f"Output dataset: {args.output_repo}")
|
||||
print(f"Episodes processed: {num_episodes_to_process - len(failed_episodes)}/{total_episodes}")
|
||||
print(f"Frames with rewards: {new_dataset.num_frames}")
|
||||
print(f"Parallel workers used: {args.num_workers}")
|
||||
print(f"Processing time saved: ~{args.num_workers - 1}x faster")
|
||||
print("===============")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,188 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Test script for RLearN evaluation metrics.
|
||||
|
||||
This script tests the VOC-S and success/failure detection metrics with synthetic data
|
||||
to ensure they work correctly before running on real datasets.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.policies.rlearn.evaluation import (
|
||||
compute_success_failure_detection,
|
||||
compute_voc_s,
|
||||
generate_mismatched_languages,
|
||||
)
|
||||
|
||||
|
||||
def test_voc_s():
|
||||
"""Test VOC-S computation with synthetic data."""
|
||||
print("Testing VOC-S computation...")
|
||||
|
||||
# Test case 1: Perfect positive correlation (0 -> 1)
|
||||
perfect_positive = [np.linspace(0, 1, 20) for _ in range(10)]
|
||||
results = compute_voc_s(perfect_positive)
|
||||
|
||||
print("Perfect positive correlation:")
|
||||
print(f" Mean: {results['voc_s_mean']:.4f} (should be ~1.0)")
|
||||
print(f" IQM: {results['voc_s_iqm']:.4f} (should be ~1.0)")
|
||||
assert results["voc_s_mean"] > 0.95, f"Expected >0.95, got {results['voc_s_mean']}"
|
||||
|
||||
# Test case 2: Perfect negative correlation (1 -> 0)
|
||||
perfect_negative = [np.linspace(1, 0, 20) for _ in range(10)]
|
||||
results = compute_voc_s(perfect_negative)
|
||||
|
||||
print("Perfect negative correlation:")
|
||||
print(f" Mean: {results['voc_s_mean']:.4f} (should be ~-1.0)")
|
||||
print(f" IQM: {results['voc_s_iqm']:.4f} (should be ~-1.0)")
|
||||
assert results["voc_s_mean"] < -0.95, f"Expected <-0.95, got {results['voc_s_mean']}"
|
||||
|
||||
# Test case 3: No correlation (random)
|
||||
np.random.seed(42)
|
||||
random_rewards = [np.random.random(20) for _ in range(50)]
|
||||
results = compute_voc_s(random_rewards)
|
||||
|
||||
print("Random correlation:")
|
||||
print(f" Mean: {results['voc_s_mean']:.4f} (should be ~0.0)")
|
||||
print(f" IQM: {results['voc_s_iqm']:.4f} (should be ~0.0)")
|
||||
assert abs(results["voc_s_mean"]) < 0.3, f"Expected ~0, got {results['voc_s_mean']}"
|
||||
|
||||
# Test case 4: Mixed correlations
|
||||
mixed = []
|
||||
mixed.extend([np.linspace(0, 1, 15) for _ in range(5)]) # Positive
|
||||
mixed.extend([np.linspace(1, 0, 15) for _ in range(5)]) # Negative
|
||||
mixed.extend([np.random.random(15) for _ in range(5)]) # Random
|
||||
|
||||
results = compute_voc_s(mixed)
|
||||
print("Mixed correlations:")
|
||||
print(f" Mean: {results['voc_s_mean']:.4f}")
|
||||
print(f" IQM: {results['voc_s_iqm']:.4f}")
|
||||
print(f" Std: {results['voc_s_std']:.4f}")
|
||||
|
||||
print("✓ VOC-S tests passed!\n")
|
||||
|
||||
|
||||
def test_success_failure_detection():
|
||||
"""Test success/failure detection with synthetic data."""
|
||||
print("Testing Success/Failure Detection...")
|
||||
|
||||
# Test case 1: Clear separation (correct > incorrect)
|
||||
correct_rewards = [np.linspace(0, 1, 20) for _ in range(20)] # Always increasing
|
||||
incorrect_rewards = [np.linspace(0, 0.3, 20) for _ in range(20)] # Lower final values
|
||||
|
||||
results = compute_success_failure_detection(correct_rewards, incorrect_rewards)
|
||||
|
||||
print("Clear separation test:")
|
||||
print(f" Detection accuracy: {results['detection_accuracy']:.4f} (should be 1.0)")
|
||||
print(f" Mean correct: {results['mean_correct_final']:.4f}")
|
||||
print(f" Mean incorrect: {results['mean_incorrect_final']:.4f}")
|
||||
print(f" Separation score: {results['separation_score']:.4f}")
|
||||
assert results["detection_accuracy"] == 1.0, f"Expected 1.0, got {results['detection_accuracy']}"
|
||||
|
||||
# Test case 2: No separation (same distributions with some randomness)
|
||||
np.random.seed(42)
|
||||
same_rewards_1 = [np.random.normal(0.5, 0.05, 15) for _ in range(20)]
|
||||
same_rewards_2 = [np.random.normal(0.5, 0.05, 15) for _ in range(20)]
|
||||
|
||||
results = compute_success_failure_detection(same_rewards_1, same_rewards_2)
|
||||
|
||||
print("No separation test:")
|
||||
print(f" Detection accuracy: {results['detection_accuracy']:.4f} (should be ~0.5)")
|
||||
print(f" Separation score: {results['separation_score']:.4f} (should be ~0.0)")
|
||||
# Relax the assertion since random data can vary
|
||||
assert 0.2 <= results["detection_accuracy"] <= 0.8, (
|
||||
f"Expected ~0.5 (±0.3), got {results['detection_accuracy']}"
|
||||
)
|
||||
|
||||
# Test case 3: Partial separation
|
||||
np.random.seed(42)
|
||||
partial_correct = [np.random.normal(0.7, 0.1, 10) for _ in range(20)]
|
||||
partial_incorrect = [np.random.normal(0.4, 0.1, 10) for _ in range(20)]
|
||||
|
||||
results = compute_success_failure_detection(partial_correct, partial_incorrect)
|
||||
|
||||
print("Partial separation test:")
|
||||
print(f" Detection accuracy: {results['detection_accuracy']:.4f}")
|
||||
print(f" Separation score: {results['separation_score']:.4f}")
|
||||
|
||||
print("✓ Success/Failure Detection tests passed!\n")
|
||||
|
||||
|
||||
def test_mismatch_generation():
|
||||
"""Test mismatch language generation."""
|
||||
print("Testing mismatch language generation...")
|
||||
|
||||
original_languages = [
|
||||
"pick up the red ball",
|
||||
"put the cup on the table",
|
||||
"open the drawer",
|
||||
"close the door",
|
||||
]
|
||||
|
||||
# Test with default templates
|
||||
mismatched = generate_mismatched_languages(original_languages)
|
||||
|
||||
print(f"Original languages: {len(original_languages)}")
|
||||
print(f"Mismatched languages: {len(mismatched)}")
|
||||
assert len(mismatched) == len(original_languages)
|
||||
|
||||
# Ensure they're actually different
|
||||
for orig, mismatch in zip(original_languages, mismatched, strict=False):
|
||||
print(f" '{orig}' -> '{mismatch}'")
|
||||
assert orig != mismatch, "Mismatch should be different from original"
|
||||
|
||||
# Test with custom templates
|
||||
custom_templates = ["dance", "sing", "jump"]
|
||||
mismatched_custom = generate_mismatched_languages(original_languages, custom_templates)
|
||||
|
||||
print("\nWith custom templates:")
|
||||
for orig, mismatch in zip(original_languages, mismatched_custom, strict=False):
|
||||
print(f" '{orig}' -> '{mismatch}'")
|
||||
assert mismatch in custom_templates
|
||||
|
||||
print("✓ Mismatch generation tests passed!\n")
|
||||
|
||||
|
||||
def test_edge_cases():
|
||||
"""Test edge cases and error handling."""
|
||||
print("Testing edge cases...")
|
||||
|
||||
# Empty input
|
||||
empty_results = compute_voc_s([])
|
||||
assert empty_results["num_episodes"] == 0
|
||||
assert empty_results["voc_s_mean"] == 0.0
|
||||
|
||||
# Single frame episodes (should be skipped)
|
||||
single_frame = [np.array([0.5]) for _ in range(5)]
|
||||
results = compute_voc_s(single_frame)
|
||||
assert results["num_episodes"] == 0, "Single-frame episodes should be skipped"
|
||||
|
||||
# Constant rewards (should give correlation = 0)
|
||||
constant_rewards = [np.ones(10) * 0.5 for _ in range(5)]
|
||||
results = compute_voc_s(constant_rewards)
|
||||
print(f"Constant rewards correlation: {results['voc_s_mean']:.4f} (should be 0.0)")
|
||||
assert results["voc_s_mean"] == 0.0
|
||||
|
||||
# Mismatched array lengths for detection
|
||||
try:
|
||||
compute_success_failure_detection([np.array([1, 2])], [])
|
||||
assert False, "Should have raised ValueError"
|
||||
except ValueError:
|
||||
pass # Expected
|
||||
|
||||
print("✓ Edge case tests passed!\n")
|
||||
@@ -0,0 +1,237 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import OBS_IMAGES, OBS_LANGUAGE, REWARD
|
||||
from lerobot.policies.factory import make_processor
|
||||
from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
|
||||
from lerobot.policies.rlearn.modeling_rlearn import RLearNPolicy
|
||||
from tests.utils import require_package
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_rlearn_instantiation_and_forward_tensor_batch():
|
||||
"""Instantiate RLearN and run a forward pass with a (B, T, C, H, W) tensor input using a real model and real text."""
|
||||
cfg = RLearNConfig(
|
||||
model_name="google/siglip2-large-patch16-256",
|
||||
push_to_hub=False,
|
||||
freeze_backbones=True,
|
||||
)
|
||||
cfg.input_features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
cfg.output_features = {
|
||||
REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
|
||||
}
|
||||
|
||||
policy = RLearNPolicy(cfg)
|
||||
|
||||
B, T, C, H, W = 2, 3, 3, 256, 256
|
||||
batch = {
|
||||
OBS_IMAGES: torch.rand(B, T, C, H, W),
|
||||
REWARD: torch.randint(low=0, high=1, size=(B, T)).float(),
|
||||
OBS_LANGUAGE: ["move the green cube into the box" for _ in range(B)],
|
||||
}
|
||||
|
||||
loss, logs = policy.forward(batch)
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
assert "loss" in logs
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_rlearn_instantiation_and_forward_list_batch_with_language():
|
||||
"""Instantiate RLearN and run a forward pass with a list-of-frames input and real language using a real model."""
|
||||
cfg = RLearNConfig(
|
||||
model_name="google/siglip2-large-patch16-256",
|
||||
push_to_hub=False,
|
||||
freeze_backbones=True,
|
||||
)
|
||||
cfg.input_features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
cfg.output_features = {
|
||||
REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
|
||||
}
|
||||
|
||||
policy = RLearNPolicy(cfg)
|
||||
|
||||
B, T, C, H, W = 2, 4, 3, 256, 256
|
||||
frames = [torch.rand(B, C, H, W) for _ in range(T)]
|
||||
batch = {
|
||||
OBS_IMAGES: frames, # list[(B, C, H, W)]
|
||||
REWARD: torch.randint(low=0, high=2, size=(B, T)).float(),
|
||||
OBS_LANGUAGE: ["move the red cube into the box" for _ in range(B)],
|
||||
}
|
||||
|
||||
loss, logs = policy.forward(batch)
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
assert "loss" in logs
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_rlearn_composite_loss_shapes_and_terms():
|
||||
"""Smoke test composite loss: checks presence of terms and valid gradients."""
|
||||
cfg = RLearNConfig(
|
||||
model_name="google/siglip2-large-patch16-256",
|
||||
push_to_hub=False,
|
||||
freeze_backbones=True,
|
||||
loss_type="composite",
|
||||
lambda_prog=1.0,
|
||||
lambda_spatial_nce=0.5,
|
||||
lambda_rewind=0.4,
|
||||
num_ranking_pairs=32, # Fewer pairs for testing
|
||||
last_k_for_nce=2,
|
||||
)
|
||||
cfg.input_features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
cfg.output_features = {
|
||||
REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
|
||||
}
|
||||
|
||||
policy = RLearNPolicy(cfg)
|
||||
|
||||
B, T, C, H, W = 2, 3, 3, 256, 256
|
||||
# Progress labels y in [0,1]
|
||||
y = torch.linspace(0, 1, T).unsqueeze(0).repeat(B, 1)
|
||||
batch = {
|
||||
OBS_IMAGES: torch.rand(B, T, C, H, W),
|
||||
REWARD: y.clone(),
|
||||
OBS_LANGUAGE: ["stack the blocks" for _ in range(B)],
|
||||
}
|
||||
|
||||
loss, logs = policy.forward(batch)
|
||||
assert isinstance(loss, torch.Tensor) and torch.isfinite(loss)
|
||||
# Expect composite terms present with spatial awareness and ReWiND
|
||||
assert "loss_prog" in logs
|
||||
assert "loss_spatial_nce" in logs
|
||||
assert "loss_rewind_forward" in logs
|
||||
assert "loss_rewind_reverse" in logs
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_rlearn_preprocessor_tokenizes_and_copies_task():
|
||||
cfg = RLearNConfig(
|
||||
model_name="google/siglip2-large-patch16-256",
|
||||
device="cpu",
|
||||
push_to_hub=False,
|
||||
)
|
||||
cfg.input_features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
|
||||
}
|
||||
cfg.output_features = {
|
||||
REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
|
||||
}
|
||||
|
||||
pre, post = make_processor(cfg, dataset_stats=None)
|
||||
|
||||
B, C, H, W = 2, 3, 64, 64
|
||||
batch = {
|
||||
"observation.image": torch.rand(B, C, H, W),
|
||||
REWARD: torch.zeros(B),
|
||||
"task": ["pick the cube", "place it in the box"],
|
||||
}
|
||||
|
||||
processed = pre(batch)
|
||||
|
||||
assert isinstance(processed, dict)
|
||||
assert f"{OBS_LANGUAGE}.tokens" in processed
|
||||
assert f"{OBS_LANGUAGE}.attention_mask" in processed
|
||||
assert OBS_LANGUAGE in processed
|
||||
|
||||
tokens = processed[f"{OBS_LANGUAGE}.tokens"]
|
||||
attn = processed[f"{OBS_LANGUAGE}.attention_mask"]
|
||||
assert tokens.dim() == 2 and attn.dim() == 2
|
||||
assert tokens.shape[0] == B and attn.shape[0] == B
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_rlearn_preprocessor_string_task_and_to_batch():
|
||||
cfg = RLearNConfig(
|
||||
model_name="google/siglip2-large-patch16-256",
|
||||
device="cpu",
|
||||
push_to_hub=False,
|
||||
)
|
||||
cfg.input_features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
|
||||
}
|
||||
cfg.output_features = {
|
||||
REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
|
||||
}
|
||||
|
||||
pre, post = make_processor(cfg, dataset_stats=None)
|
||||
|
||||
# Unbatched image and single string task
|
||||
batch = {
|
||||
"observation.image": torch.rand(3, 64, 64),
|
||||
REWARD: torch.tensor(0.0),
|
||||
"task": "move the green cube into the box",
|
||||
}
|
||||
|
||||
processed = pre(batch)
|
||||
|
||||
# Image should have batch dim now
|
||||
assert processed["observation.image"].dim() == 4 and processed["observation.image"].shape[0] == 1
|
||||
# Language copy and tokenization should exist
|
||||
assert OBS_LANGUAGE in processed and isinstance(processed[OBS_LANGUAGE], list)
|
||||
assert f"{OBS_LANGUAGE}.tokens" in processed
|
||||
assert f"{OBS_LANGUAGE}.attention_mask" in processed
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_rlearn_pipeline_end_to_end_forward():
|
||||
"""End-to-end: preprocessor + model forward using RLearN pipeline on synthetic data."""
|
||||
cfg = RLearNConfig(
|
||||
model_name="google/siglip2-large-patch16-256",
|
||||
device="cpu",
|
||||
push_to_hub=False,
|
||||
freeze_backbones=True,
|
||||
loss_type="composite",
|
||||
)
|
||||
cfg.input_features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
cfg.output_features = {
|
||||
REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
|
||||
}
|
||||
|
||||
# Build processors and model
|
||||
pre, post = make_processor(cfg, dataset_stats=None)
|
||||
policy = RLearNPolicy(cfg)
|
||||
|
||||
B, T, C, H, W = 2, 3, 3, 256, 256
|
||||
y = torch.linspace(0, 1, T).unsqueeze(0).repeat(B, 1)
|
||||
raw = {
|
||||
# Provide as observation.image to let preprocessor map/normalize and batch
|
||||
"observation.image": torch.rand(B, C, H, W), # not time-major to test ToBatch
|
||||
REWARD: y[:, :1].clone(), # single step label; pipeline keeps structure
|
||||
"task": ["insert the peg", "insert the peg"],
|
||||
}
|
||||
|
||||
processed = pre(raw)
|
||||
# Integrate preprocessor output with model forward
|
||||
loss, logs = policy.forward(
|
||||
{
|
||||
OBS_IMAGES: processed.get(OBS_IMAGES, processed.get("observation.image"))
|
||||
.unsqueeze(1)
|
||||
.repeat(1, T, 1, 1, 1),
|
||||
REWARD: y.clone(),
|
||||
OBS_LANGUAGE: processed[OBS_LANGUAGE],
|
||||
}
|
||||
)
|
||||
assert isinstance(loss, torch.Tensor) and torch.isfinite(loss)
|
||||
Reference in New Issue
Block a user