add tests, implement formula 1,2 correctly and cleanup

This commit is contained in:
Pepijn
2025-11-27 14:04:01 +01:00
parent 3ed0425d2c
commit f2ad86831d
7 changed files with 861 additions and 274 deletions
+1 -7
View File
@@ -64,9 +64,7 @@ class SARMTemporalSampler(Sampler):
self.shuffle = shuffle
self.samples_per_epoch = samples_per_epoch
# Minimum frames needed for SARM pattern:
# 8 consecutive frames with frame_gap spacing = 7 * frame_gap + 1
# (Plus the initial frame which is always available)
# Minimum frames needed for SARM pattern: 8 consecutive frames with frame_gap spacing = 7 * frame_gap + 1
self.min_frames_needed = 7 * frame_gap + 1
if seed is not None:
@@ -138,7 +136,3 @@ class SARMTemporalSampler(Sampler):
for i in range(self.samples_per_epoch):
idx = i % len(self.all_valid_positions)
yield int(self.all_valid_positions[idx])
# Backwards compatibility alias
TemporalSequenceSampler = SARMTemporalSampler
-2
View File
@@ -18,7 +18,6 @@ from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.modeling_sarm import (
SARMRewardModel,
SARMTransformer,
compute_stage_loss,
)
from lerobot.policies.sarm.processor_sarm import (
SARMEncodingProcessorStep,
@@ -29,7 +28,6 @@ __all__ = [
"SARMConfig",
"SARMRewardModel",
"SARMTransformer",
"compute_stage_loss",
"SARMEncodingProcessorStep",
"make_sarm_pre_post_processors",
]
+45 -29
View File
@@ -17,7 +17,7 @@
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import PolicyFeature, FeatureType
from lerobot.configs.types import PolicyFeature, FeatureType, NormalizationMode
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@@ -27,63 +27,83 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
class SARMConfig(PreTrainedConfig):
"""Configuration class for SARM (Stage-Aware Reward Modeling)"""
# Visual encoding parameters
image_dim: int = 512 # CLIP embedding dimension
# CLIP encoding parameters
image_dim: int = 512
text_dim: int = 512
num_frames: int = 9 # 1 initial + 8 consecutive frames
frame_gap: int = 30 # Frame gap between consecutive frames (at 30 fps = 1 second)
# Text encoding parameters (CLIP text encoder output dimension)
text_dim: int = 512
# Joint state parameters
state_dim: int | None = None # Auto-detected from dataset if None
# Architecture parameters
hidden_dim: int = 768
num_heads: int = 12
num_layers: int = 8
num_stages: int = 5 # Number of task stages for classification (auto-updated from annotations if available)
max_state_dim: int = 32
num_stages: int = 5 # Number of task stages (auto-updated from annotations if available)
subtask_names: list | None = None # List of subtask names (auto-populated from annotations)
temporal_proportions: list | None = None # Temporal proportions for each stage (auto-computed from annotations)
# Temporal parameters
max_length: int = num_frames # Maximum video sequence length (matches num_frames)
use_temporal_sampler: bool = True # Always enable temporal sequence loading
# Training parameters
batch_size: int = 64
clip_batch_size: int = 64 # Batch size for CLIP encoding
gradient_checkpointing: bool = False # Enable gradient checkpointing
dropout: float = 0.1 # Dropout rate
dropout: float = 0.1
stage_loss_weight: float = 1.0 # Weight for stage classification loss when using subtask annotations
pretrained_model_path: str | None = None
device: str | None = None
# Processor settings
image_key: str = "observation.images.top" # Key for image used from the dataset
task_description: str = "perform the task" # Default task description
# Video_features and text_features are generated by the processor from raw images/text, we don't declare them as VISUAL/LANGUAGE here to avoid validation errors
input_features: dict = field(default_factory=lambda: {
"state_features": PolicyFeature(shape=(9, 14), type=FeatureType.STATE) # Example: 7 DOF × 2 arms
})
# State key in the dataset (for normalization)
state_key: str = "observation.state"
# Populated by the processor (video_features, state_features, text_features)
input_features: dict = field(default_factory=lambda: {})
# Output features
output_features: dict = field(default_factory=lambda: {
"stage": PolicyFeature(shape=(1,), type=FeatureType.REWARD),
"progress": PolicyFeature(shape=(1,), type=FeatureType.REWARD)
"stage": PolicyFeature(shape=(9, 5), type=FeatureType.REWARD),
"progress": PolicyFeature(shape=(9, 1), type=FeatureType.REWARD),
})
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"LANGUAGE": NormalizationMode.IDENTITY,
"REWARD": NormalizationMode.IDENTITY,
}
)
def __post_init__(self):
super().__post_init__()
# Add the image_key, the processor will transform this into video_features
if self.image_key and self.image_key not in self.input_features:
# Add the image_key as VISUAL (this is the raw image from dataset)
if self.image_key:
self.input_features[self.image_key] = PolicyFeature(
shape=(480, 640, 3),
type=FeatureType.VISUAL
)
# Add state_key as STATE (raw state from dataset, will be padded to max_state_dim)
self.input_features[self.state_key] = PolicyFeature(
shape=(self.max_state_dim,), # Single frame state, temporal sampling handles sequence
type=FeatureType.STATE
)
# Update output features with actual dimensions
self.output_features["stage"] = PolicyFeature(
shape=(self.num_frames, self.num_stages),
type=FeatureType.REWARD
)
self.output_features["progress"] = PolicyFeature(
shape=(self.num_frames, 1),
type=FeatureType.REWARD
)
# Validate configuration
if self.hidden_dim % self.num_heads != 0:
raise ValueError(
@@ -95,9 +115,6 @@ class SARMConfig(PreTrainedConfig):
f"max_length ({self.max_length}) must equal num_frames ({self.num_frames})"
)
if self.dropout < 0 or self.dropout >= 1:
raise ValueError(f"dropout must be in [0, 1), got {self.dropout}")
if self.num_stages < 2:
raise ValueError(f"num_stages must be at least 2, got {self.num_stages}")
@@ -139,11 +156,10 @@ class SARMConfig(PreTrainedConfig):
Returns:
9 delta indices: [-1_000_000, -(7*gap), -(6*gap), ..., -gap, 0]
"""
# First delta: large negative to always clamp to episode start (frame 0)
initial_frame_delta = -1_000_000
# Remaining 8 deltas: consecutive frames with frame_gap spacing
num_consecutive = self.num_frames - 1 # 8 frames
# Remaining consecutive frames with frame_gap spacing
num_consecutive = self.num_frames - 1
consecutive_deltas = list(range(-self.frame_gap * (num_consecutive - 1), 1, self.frame_gap))
return [initial_frame_delta] + consecutive_deltas
+60 -94
View File
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
from typing import List, Union, Dict, Optional
from typing import List, Union, Optional
import random
import numpy as np
@@ -28,9 +28,12 @@ from transformers import CLIPModel, CLIPProcessor
from torch import Tensor
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.sarm_utils import compute_priors, compute_cumulative_progress_batch, pad_state_to_max_dim
from lerobot.policies.pretrained import PreTrainedPolicy
class SARMTransformer(nn.Module):
"""
SARM Transformer model for stage-aware reward prediction.
@@ -45,8 +48,8 @@ class SARMTransformer(nn.Module):
def __init__(
self,
video_dim: int = 512,
text_dim: int = 512, # CLIP text encoder output dimension (per SARM paper A.4)
state_dim: int = 14,
text_dim: int = 512,
max_state_dim: int = 32,
hidden_dim: int = 768,
num_heads: int = 12,
num_layers: int = 8,
@@ -59,9 +62,8 @@ class SARMTransformer(nn.Module):
self.hidden_dim = hidden_dim
self.max_length = max_length
self.num_stages = num_stages
self.max_state_dim = max_state_dim
# Store temporal proportions for progress conversion (Paper Eq. 4)
# ŷ = P_{k-1} + ᾱ_k × τ̂
if temporal_proportions is None:
raise ValueError(
"temporal_proportions is required for SARM. "
@@ -77,15 +79,13 @@ class SARMTransformer(nn.Module):
self.register_buffer('alpha', alpha)
self.register_buffer('cumulative_prior', cumulative)
# Project video, text, and state to same dimension
self.video_proj = nn.Linear(video_dim, hidden_dim)
self.text_proj = nn.Linear(text_dim, hidden_dim)
self.state_proj = nn.Linear(state_dim, hidden_dim)
self.state_proj = nn.Linear(max_state_dim, hidden_dim)
# Position embedding only for the first frame
self.first_pos_embed = nn.Parameter(torch.randn(1, hidden_dim))
# Transformer encoder
encoder_layer = nn.TransformerEncoderLayer(
d_model=hidden_dim,
nhead=num_heads,
@@ -96,7 +96,6 @@ class SARMTransformer(nn.Module):
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
# Stage estimator head (classification)
# Paper A.4: "2 layers with hidden dimension of 512"
self.stage_head = nn.Sequential(
nn.Linear(hidden_dim, 512),
nn.LayerNorm(512),
@@ -106,8 +105,6 @@ class SARMTransformer(nn.Module):
)
# Subtask estimator head (regression, conditioned on stage)
# Takes concatenated [features, stage_embedding]
# Paper A.4: "2 layers with hidden dimension of 512"
self.stage_embedding = nn.Embedding(num_stages, hidden_dim // 4)
subtask_input_dim = hidden_dim + hidden_dim // 4
self.subtask_head = nn.Sequential(
@@ -119,13 +116,13 @@ class SARMTransformer(nn.Module):
nn.Sigmoid()
)
# Attention mask for causal self-attention
# Attention mask
self.register_buffer("attention_mask", None, persistent=False)
def _get_attention_mask(self, seq_length: int, device: torch.device) -> torch.Tensor:
"""Generate or retrieve cached causal attention mask."""
if self.attention_mask is None or self.attention_mask.shape[0] != seq_length:
# Create causal mask (upper triangular with -inf)
# Create causal mask
mask = nn.Transformer.generate_square_subsequent_mask(seq_length, device=device)
self.attention_mask = mask
return self.attention_mask
@@ -150,14 +147,16 @@ class SARMTransformer(nn.Module):
- Stage probabilities (batch_size, seq_len, num_stages)
- Progress predictions for each frame (batch_size, seq_len, 1)
"""
batch_size = video_frames.shape[0]
# Project inputs to common dimension
video_embed = self.video_proj(video_frames) # [batch_size, seq_len, hidden_dim]
text_embed = self.text_proj(text_embed).unsqueeze(1) # [batch_size, 1, hidden_dim]
state_embed = self.state_proj(state_features) # [batch_size, seq_len, hidden_dim]
# Fuse video and state features (simple addition)
# Pad state features to max_state_dim before projection
state_features_padded = pad_state_to_max_dim(state_features, self.max_state_dim)
state_embed = self.state_proj(state_features_padded) # [batch_size, seq_len, hidden_dim]
# Fuse video and state features
video_embed = video_embed + state_embed
# Add positional embedding to first video frame
@@ -173,7 +172,7 @@ class SARMTransformer(nn.Module):
# Pass through transformer with causal masking
transformed = self.transformer(sequence, mask=attention_mask, is_causal=True)
# Get frame features (exclude text token)
# Get frame features
frame_features = transformed[:, 1:] # [batch_size, seq_len, hidden_dim]
# Stage estimation
@@ -193,14 +192,11 @@ class SARMTransformer(nn.Module):
# τ̂ = within-subtask progress (0-1)
tau_preds = self.subtask_head(conditioned_features) # [batch_size, seq_len, 1]
# Convert τ̂ to cumulative progress ŷ using Paper Eq. 4:
# Convert τ̂ to cumulative progress ŷ using Paper Formula (2):
# ŷ = P_{k-1} + ᾱ_k × τ̂
# P_{k-1} = cumulative prior up to stage k-1
# ᾱ_k = temporal proportion of stage k
P_k_minus_1 = self.cumulative_prior[stage_indices] # [batch_size, seq_len]
alpha_k = self.alpha[stage_indices] # [batch_size, seq_len]
progress_preds = P_k_minus_1.unsqueeze(-1) + alpha_k.unsqueeze(-1) * tau_preds
progress_preds = compute_cumulative_progress_batch(
tau_preds, stage_indices, self.alpha, self.cumulative_prior
)
return stage_logits, stage_probs, progress_preds
@@ -227,65 +223,37 @@ class SARMRewardModel(PreTrainedPolicy):
self.dataset_stats = dataset_stats
self.device = torch.device(config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu")
# Auto-detect num_stages from dataset annotations before building the model
# Detect num_stages from dataset annotations before building the model
if dataset_meta is not None:
self._update_num_stages_from_dataset(dataset_meta)
# Initialize CLIP encoder for images AND text (per SARM paper A.4)
logging.info("Loading CLIP encoder for images and text...")
logging.info("Loading CLIP encoder")
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
self.clip_model.to(self.device)
self.clip_model.eval()
# Auto-detect state_dim from dataset_stats
if config.state_dim is None:
logging.info(f"Attempting to auto-detect state_dim. dataset_stats is None: {dataset_stats is None}")
if dataset_stats is not None:
if "observation.state" in dataset_stats:
config.state_dim = dataset_stats["observation.state"]["mean"].shape[0]
logging.info(f"Auto-detected state_dim={config.state_dim} from dataset_stats['observation.state']")
elif "state" in dataset_stats:
config.state_dim = dataset_stats["state"]["mean"].shape[0]
logging.info(f"Auto-detected state_dim={config.state_dim} from dataset_stats['state']")
else:
logging.warning(f"State keys not found in dataset_stats. Available keys: {list(dataset_stats.keys())}")
else:
logging.warning("dataset_stats is None, cannot auto-detect state_dim")
# Raise explicit error if still None
if config.state_dim is None:
raise ValueError(
"Could not determine state_dim! "
f"dataset_stats={'None' if dataset_stats is None else f'available with keys: {list(dataset_stats.keys())}'}, "
"config.state_dim=None. "
"Please either:\n"
"1. Provide --policy.state_dim=<your_state_dimension> explicitly, or\n"
"2. Ensure dataset_stats contains 'observation.state' or 'state' key"
)
# Initialize SARM transformer with temporal proportions for progress conversion
temporal_proportions = getattr(config, 'temporal_proportions', None)
self.sarm_transformer = SARMTransformer(
video_dim=config.image_dim,
text_dim=config.text_dim,
state_dim=config.state_dim,
max_state_dim=config.max_state_dim,
hidden_dim=config.hidden_dim,
num_heads=config.num_heads,
num_layers=config.num_layers,
num_stages=config.num_stages,
max_length=config.max_length,
dropout=config.dropout,
temporal_proportions=temporal_proportions
temporal_proportions=config.temporal_proportions
)
self.sarm_transformer.to(self.device)
logging.info(f"SARM Reward Model initialized on {self.device}")
def _update_num_stages_from_dataset(self, dataset_meta) -> None:
"""Update num_stages and temporal_proportions from dataset subtask annotations."""
"""Update num_stages and temporal_proportions from dataset subtask annotations.
Implements SARM Paper Formula (1):
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
"""
episodes = dataset_meta.episodes
if episodes is None or len(episodes) == 0:
raise ValueError("No episodes found, using default num_stages")
@@ -295,27 +263,38 @@ class SARMRewardModel(PreTrainedPolicy):
episodes_df = episodes.to_pandas()
# Collect all unique subtask names and compute durations
# Collect subtask durations and trajectory lengths for compute_priors
all_subtask_names = set()
subtask_durations = {}
subtask_durations_per_trajectory = {}
trajectory_lengths = {}
for ep_idx in episodes_df.index:
subtask_names = episodes_df.loc[ep_idx, 'subtask_names']
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
subtask_names_ep = episodes_df.loc[ep_idx, 'subtask_names']
if subtask_names_ep is None or (isinstance(subtask_names_ep, float) and pd.isna(subtask_names_ep)):
continue
all_subtask_names.update(subtask_names)
all_subtask_names.update(subtask_names_ep)
# Compute durations if available
if 'subtask_start_frames' in episodes_df.columns and 'subtask_end_frames' in episodes_df.columns:
start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames']
end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames']
for i, name in enumerate(subtask_names):
# Compute total trajectory length T_i
total_traj_length = sum(end_frames[i] - start_frames[i] for i in range(len(subtask_names_ep)))
if total_traj_length <= 0:
continue
for i, name in enumerate(subtask_names_ep):
duration = end_frames[i] - start_frames[i]
if name not in subtask_durations:
subtask_durations[name] = []
subtask_durations[name].append(duration)
if name not in subtask_durations_per_trajectory:
subtask_durations_per_trajectory[name] = []
trajectory_lengths[name] = []
subtask_durations_per_trajectory[name].append(duration)
trajectory_lengths[name].append(total_traj_length)
if not all_subtask_names:
raise ValueError("No valid subtask names found, using default num_stages")
@@ -324,26 +303,20 @@ class SARMRewardModel(PreTrainedPolicy):
subtask_names = sorted(list(all_subtask_names))
num_stages = len(subtask_names)
# Compute temporal proportions (Paper Eq. 1: ᾱ_k)
avg_durations = {}
for name in subtask_names:
if name in subtask_durations and subtask_durations[name]:
avg_durations[name] = np.mean(subtask_durations[name])
else:
avg_durations[name] = 1.0 # Default
total_duration = sum(avg_durations.values())
if total_duration > 0:
temporal_proportions = [avg_durations[name] / total_duration for name in subtask_names]
else:
temporal_proportions = [1.0 / num_stages] * num_stages
# Compute temporal proportions using Paper Formula (1)
temporal_proportions_dict = compute_priors(
subtask_durations_per_trajectory,
trajectory_lengths,
subtask_names
)
temporal_proportions = [temporal_proportions_dict[name] for name in subtask_names]
self.config.num_stages = num_stages
self.config.subtask_names = subtask_names
self.config.temporal_proportions = temporal_proportions
logging.info(f"Auto-detected {num_stages} subtasks: {subtask_names}")
logging.info(f"Temporal proportions: {dict(zip(subtask_names, temporal_proportions))}")
logging.info(f"Temporal proportions: {temporal_proportions_dict}")
def to(self, device):
"""Override to method to ensure all components move together."""
@@ -475,7 +448,6 @@ class SARMRewardModel(PreTrainedPolicy):
If return_stages=True:
Tuple of (rewards, stage_probs)
"""
# Convert to tensors if needed
if isinstance(text_embeddings, np.ndarray):
text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32)
if isinstance(video_embeddings, np.ndarray):
@@ -535,16 +507,13 @@ class SARMRewardModel(PreTrainedPolicy):
def load_pretrained_checkpoint(self, checkpoint_path: str, strict: bool = False):
"""Load pretrained model weights from a checkpoint file."""
logging.info(f"Loading pretrained checkpoint from {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
# Handle different checkpoint formats
if "model_state_dict" in checkpoint:
state_dict = checkpoint["model_state_dict"]
else:
state_dict = checkpoint
# Load only the SARMTransformer weights
missing_keys, unexpected_keys = self.sarm_transformer.load_state_dict(state_dict, strict=strict)
if missing_keys:
@@ -557,9 +526,7 @@ class SARMRewardModel(PreTrainedPolicy):
def train(self, mode: bool = True):
"""Set training mode. Note: CLIP encoder always stays in eval mode (frozen)."""
super().train(mode)
# Keep CLIP encoder in eval mode (frozen per SARM paper)
self.clip_model.eval()
# Only transformer can be trained
self.sarm_transformer.train(mode)
return self
@@ -686,8 +653,7 @@ class SARMRewardModel(PreTrainedPolicy):
state = state_features[i] if state_features is not None else None
progress = progress_from_annotations[i].squeeze(-1) # (T,)
# Apply temporal augmentation with 50% probability (SARM paper A.4)
# Appends up to 4 reversed frames to simulate failures/recoveries
# Apply temporal augmentation with 50% probability: appends up to 4 reversed frames to simulate failures/recoveries
if random.random() < 0.5:
video, progress, state = self._apply_temporal_augmentation(video, progress, state, max_length)
@@ -729,7 +695,7 @@ class SARMRewardModel(PreTrainedPolicy):
total_loss = total_loss + self.config.stage_loss_weight * stage_loss
output_dict['stage_loss'] = stage_loss.item()
# Misaligned loss: 20% probability (SARM paper - improve video-language alignment)
# Misaligned loss: 20% probability
if random.random() < 0.2:
shuffle_idx = torch.randperm(batch_size, device=self.device)
_, _, misaligned_preds = self.sarm_transformer(
+110 -138
View File
@@ -23,16 +23,19 @@ import pandas as pd
from transformers import CLIPModel, CLIPProcessor
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.sarm_utils import compute_priors, compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim
from lerobot.processor import (
ProcessorStep,
PolicyProcessorPipeline,
PolicyAction,
DeviceProcessorStep,
AddBatchDimensionProcessorStep,
NormalizerProcessorStep,
)
from lerobot.processor.converters import (
policy_action_to_transition,
transition_to_policy_action,
from_tensor_to_numpy,
)
from lerobot.processor.pipeline import PipelineFeatureType
from lerobot.processor.core import EnvTransition, TransitionKey
@@ -41,20 +44,7 @@ from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PR
class SARMEncodingProcessorStep(ProcessorStep):
"""
ProcessorStep that encodes images and text for SARM training.
Per SARM paper (Appendix A.4): "We employ a frozen clip-vit-base-patch32 encoder
to process both RGB image sequences and task descriptions."
This step handles:
- CLIP image encoding (512-dim)
- CLIP text encoding (512-dim)
- Joint state normalization
Supports temporal sequences: (B, T, C, H, W) (B, T, 512) video features
"""
"""ProcessorStep that encodes images and text with CLIP."""
def __init__(
self,
config: SARMConfig,
@@ -69,8 +59,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
self.task_description = task_description or config.task_description
self.dataset_meta = dataset_meta
self.dataset_stats = dataset_stats
# Compute temporal proportions from subtask annotations if available
self.temporal_proportions = None
self.subtask_names = None
if dataset_meta is not None:
@@ -94,7 +82,14 @@ class SARMEncodingProcessorStep(ProcessorStep):
self.device = device
def _compute_temporal_proportions(self):
"""Compute temporal proportions for each subtask from dataset annotations."""
"""Compute temporal proportions for each subtask from dataset annotations.
Implements SARM Paper Formula (1):
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
This averages the proportion of time spent on each subtask within each trajectory,
giving equal weight to all trajectories regardless of absolute length.
"""
if self.dataset_meta is None or not hasattr(self.dataset_meta, 'episodes'):
return
@@ -108,32 +103,42 @@ class SARMEncodingProcessorStep(ProcessorStep):
logging.info("No subtask annotations found in dataset")
return
# Convert to pandas
episodes_df = episodes.to_pandas()
# Collect all subtask names and compute average durations
subtask_durations = {}
# Collect subtask durations and trajectory lengths for compute_priors
subtask_durations_per_trajectory = {}
trajectory_lengths = {}
all_subtask_names = set()
for ep_idx in episodes_df.index:
subtask_names = episodes_df.loc[ep_idx, 'subtask_names']
subtask_names_ep = episodes_df.loc[ep_idx, 'subtask_names']
# Skip episodes without annotations
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
if subtask_names_ep is None or (isinstance(subtask_names_ep, float) and pd.isna(subtask_names_ep)):
continue
start_times = episodes_df.loc[ep_idx, 'subtask_start_times']
end_times = episodes_df.loc[ep_idx, 'subtask_end_times']
# Track unique subtask names
all_subtask_names.update(subtask_names)
all_subtask_names.update(subtask_names_ep)
# Compute durations
for i, name in enumerate(subtask_names):
# Compute total trajectory length T_i (sum of all subtask durations)
total_traj_length = sum(end_times[i] - start_times[i] for i in range(len(subtask_names_ep)))
if total_traj_length <= 0:
continue
# Store duration and trajectory length for each subtask occurrence
for i, name in enumerate(subtask_names_ep):
duration = end_times[i] - start_times[i]
if name not in subtask_durations:
subtask_durations[name] = []
subtask_durations[name].append(duration)
if name not in subtask_durations_per_trajectory:
subtask_durations_per_trajectory[name] = []
trajectory_lengths[name] = []
subtask_durations_per_trajectory[name].append(duration)
trajectory_lengths[name].append(total_traj_length)
if not all_subtask_names:
logging.info("No valid subtask annotations found")
@@ -142,44 +147,17 @@ class SARMEncodingProcessorStep(ProcessorStep):
# Sort subtask names for consistent ordering
self.subtask_names = sorted(list(all_subtask_names))
self.config.num_stages = len(self.subtask_names)
self.config.subtask_names = self.subtask_names # Store in config for reference
self.config.subtask_names = self.subtask_names
# Compute average duration for each subtask
avg_durations = {}
for name in self.subtask_names:
if name in subtask_durations:
avg_durations[name] = np.mean(subtask_durations[name])
else:
avg_durations[name] = 0.0
# Normalize to get proportions
total_duration = sum(avg_durations.values())
if total_duration > 0:
self.temporal_proportions = {
name: avg_durations[name] / total_duration
for name in self.subtask_names
}
else:
raise ValueError(
"Cannot compute temporal proportions: all subtask durations are zero. "
"Check that your dataset has valid subtask annotations with start/end times."
)
# Store in config for the model to use in progress output conversion (SARM paper Eq. 4)
# Compute temporal proportions using Paper Formula (1)
self.temporal_proportions = compute_priors(
subtask_durations_per_trajectory,
trajectory_lengths,
self.subtask_names
)
self.config.temporal_proportions = [self.temporal_proportions[name] for name in self.subtask_names]
logging.info(f"Computed temporal proportions for {len(self.subtask_names)} subtasks: {self.temporal_proportions}")
def _to_numpy_array(self, x) -> np.ndarray:
"""Convert input to a 1D numpy array."""
if isinstance(x, torch.Tensor):
arr = x.cpu().numpy()
else:
arr = np.array(x)
if arr.ndim == 0:
arr = np.array([arr.item()])
return arr
def _find_episode_for_frame(self, frame_idx: int) -> int:
"""Find the episode index for a given frame index."""
for ep_idx in range(len(self.dataset_meta.episodes)):
@@ -187,14 +165,14 @@ class SARMEncodingProcessorStep(ProcessorStep):
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
if ep_start <= frame_idx < ep_end:
return ep_idx
return 0 # Fallback
return 0
def _get_episode_indices(self, frame_indices: np.ndarray, episode_index) -> np.ndarray:
"""Get episode indices for each frame index."""
if episode_index is None:
return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])
episode_indices = self._to_numpy_array(episode_index)
episode_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(episode_index)))
# If single episode but multiple frames, compute episode for each frame
if len(episode_indices) == 1 and len(frame_indices) > 1:
@@ -211,22 +189,16 @@ class SARMEncodingProcessorStep(ProcessorStep):
Pattern: [ep_start, t-(7*gap), t-(6*gap), ..., t-gap, t]
"""
frame_gap = getattr(self.config, 'frame_gap', 1)
indices = []
# First frame is the episode's initial frame
indices.append(ep_start)
indices.append(ep_start) # First frame is the episode's initial frame
# Remaining frames are consecutive with frame_gap spacing
num_consecutive = num_frames - 1
for i in range(num_consecutive):
offset = -(num_consecutive - 1 - i) * frame_gap
offset = -(num_consecutive - 1 - i) * self.config.frame_gap
idx = max(ep_start, frame_idx + offset)
indices.append(idx)
return torch.tensor(indices)
def _compute_episode_metadata(
@@ -269,6 +241,14 @@ class SARMEncodingProcessorStep(ProcessorStep):
) -> tuple[int, float]:
"""Compute stage index and cumulative progress for a single frame.
Implements SARM Paper Formula (2):
y_t = P_{k-1} + ᾱ_k × τ_t
where:
- τ_t = (t - s_k) / (e_k - s_k) is within-subtask progress
- P_{k-1} is cumulative prior (sum of previous subtask proportions)
- ᾱ_k is the temporal proportion for subtask k
Args:
current_frame: Frame index relative to episode start
subtask_names: List of subtask names for this episode
@@ -278,53 +258,46 @@ class SARMEncodingProcessorStep(ProcessorStep):
Returns:
Tuple of (stage_idx, cumulative_progress)
"""
stage_idx = -1
cumulative_progress = 0.0
# Get temporal proportions as list for compute_cumulative_progress
temporal_proportions_list = [
self.temporal_proportions.get(name, 0.0) for name in self.subtask_names
]
# Find which subtask this frame belongs to
for j, (name, start_frame, end_frame) in enumerate(zip(subtask_names, subtask_start_frames, subtask_end_frames)):
if current_frame >= start_frame and current_frame <= end_frame:
# Found the subtask
# Found the subtask, get its global index
stage_idx = self.subtask_names.index(name) if name in self.subtask_names else 0
# Calculate within-subtask progress
subtask_duration = end_frame - start_frame
if subtask_duration > 0:
within_subtask_progress = (current_frame - start_frame) / subtask_duration
else:
within_subtask_progress = 1.0
# Compute τ_t using utility function (Paper Formula 2)
tau = compute_tau(current_frame, start_frame, end_frame)
# Calculate cumulative progress from completed subtasks
for k in range(j):
prev_name = subtask_names[k]
if prev_name in self.temporal_proportions:
cumulative_progress += self.temporal_proportions[prev_name]
# Add current subtask's partial progress
if name in self.temporal_proportions:
cumulative_progress += self.temporal_proportions[name] * within_subtask_progress
# Compute cumulative progress using utility function (Paper Formula 2)
cumulative_progress = compute_cumulative_progress_batch(
tau, stage_idx, temporal_proportions_list
)
return stage_idx, cumulative_progress
# No matching subtask found - estimate based on position
# No matching subtask found
if current_frame < subtask_start_frames[0]:
return 0, 0.0
elif current_frame > subtask_end_frames[-1]:
return len(self.subtask_names) - 1, 1.0
else:
# Between subtasks - use previous subtask's end state
# Between subtasks - use previous subtask's end state (tau = 1.0)
for j in range(len(subtask_names) - 1):
if current_frame > subtask_end_frames[j] and current_frame < subtask_start_frames[j + 1]:
name = subtask_names[j]
stage_idx = self.subtask_names.index(name) if name in self.subtask_names else j
# Sum up all completed subtasks
for k in range(j + 1):
prev_name = subtask_names[k]
if prev_name in self.temporal_proportions:
cumulative_progress += self.temporal_proportions[prev_name]
# Completed subtask, so tau = 1.0
cumulative_progress = compute_cumulative_progress_batch(
1.0, stage_idx, temporal_proportions_list
)
return stage_idx, cumulative_progress
return 0, 0.0 # Fallback
return 0, 0.0
def _compute_labels_for_sample(
self,
@@ -359,13 +332,8 @@ class SARMEncodingProcessorStep(ProcessorStep):
subtask_start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames']
subtask_end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames']
# Get episode boundaries
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
# Get config values
frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1
# Generate labels for each frame in the sequence
stage_labels = []
progress_targets = []
@@ -377,7 +345,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
else:
# Positions 1-8: consecutive frames with frame_gap spacing
num_consecutive = seq_len - 1
offset = -(num_consecutive - i) * frame_gap
offset = -(num_consecutive - i) * self.config.frame_gap
current_frame = max(0, frame_idx + offset - ep_start)
@@ -388,7 +356,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
stage_labels.append(stage_idx)
progress_targets.append(cumulative_progress)
# Convert to tensors
stage_labels = torch.tensor(stage_labels, dtype=torch.long)
progress_targets = torch.tensor(progress_targets, dtype=torch.float32).unsqueeze(-1)
@@ -411,7 +378,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1
# Normalize inputs to numpy arrays
frame_indices = self._to_numpy_array(frame_index)
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
episode_indices = self._get_episode_indices(frame_indices, episode_index)
# Determine sequence length
@@ -422,7 +389,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
episodes_df = self.dataset_meta.episodes.to_pandas()
# Process all samples
all_stage_labels = []
all_progress_targets = []
@@ -450,7 +416,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
if not isinstance(observation, dict):
raise ValueError("Observation must be a dictionary")
# 1. Encode images with CLIP
# Encode images with CLIP
image = observation.get(self.image_key)
if image is None:
raise ValueError(f"Image not found in observation for key: {self.image_key}")
@@ -460,27 +426,27 @@ class SARMEncodingProcessorStep(ProcessorStep):
video_features = self._encode_images_batch(image)
observation['video_features'] = video_features
# 2. Extract and normalize joint states
state_data = observation.get("state") or observation.get("observation.state")
# Extract state and pad to max_state_dim (already normalized by NormalizerProcessorStep)
state_key = self.config.state_key
state_data = observation.get(state_key)
if state_data is None:
raise ValueError("State data not found in observation (expected 'state' or 'observation.state')")
state_data = observation.get("state") or observation.get("observation.state")
if state_data is None:
raise ValueError(f"State data not found in observation (expected '{state_key}', 'state', or 'observation.state')")
if isinstance(state_data, torch.Tensor):
state_data = state_data.cpu().numpy()
state_tensor = state_data.float()
else:
state_tensor = torch.tensor(state_data, dtype=torch.float32)
state_key = "state" if "state" in observation else "observation.state"
if self.dataset_stats and state_key in self.dataset_stats:
mean = self.dataset_stats[state_key]['mean']
std = self.dataset_stats[state_key]['std']
state_data = (state_data - mean) / (std + 1e-8)
# Pad state
observation['state_features'] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)
observation['state_features'] = torch.tensor(state_data, dtype=torch.float32)
# 3. Encode text with CLIP (per SARM paper A.4)
# Encode text with CLIP
batch_size = video_features.shape[0]
observation['text_features'] = self._encode_text_clip(self.task_description, batch_size)
# 4. Extract frame/episode indices from complementary data
# Extract frame/episode indices from complementary data
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if not isinstance(comp_data, dict):
raise ValueError("COMPLEMENTARY_DATA must be a dictionary")
@@ -493,10 +459,10 @@ class SARMEncodingProcessorStep(ProcessorStep):
if episode_index is None:
raise ValueError("Episode index ('episode_index') not found in COMPLEMENTARY_DATA")
# 5. Compute episode metadata if dataset_meta is available
# Compute episode metadata if dataset_meta is available
if self.dataset_meta is not None:
is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1
frame_indices = self._to_numpy_array(frame_index)
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
episode_indices = self._get_episode_indices(frame_indices, episode_index)
# Determine number of frames from video features
@@ -512,7 +478,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
observation['remaining_length'] = remaining
observation['episode_length'] = ep_lengths
# 6. Generate stage labels and progress targets from subtask annotations
# Generate stage labels and progress targets from subtask annotations
if self.temporal_proportions is not None and self.dataset_meta is not None:
stage_labels, progress_targets = self._generate_stage_and_progress_labels(
frame_index, episode_index, video_features
@@ -539,14 +505,10 @@ class SARMEncodingProcessorStep(ProcessorStep):
# Check if we have temporal dimension
has_temporal = len(images.shape) == 5
if has_temporal:
# Shape: (B, T, C, H, W)
if has_temporal: # Shape: (B, T, C, H, W)
batch_size, seq_length = images.shape[0], images.shape[1]
# Reshape to (B*T, C, H, W) to process all frames at once
images = images.reshape(batch_size * seq_length, *images.shape[2:])
elif len(images.shape) == 4:
# Shape: (B, C, H, W)
elif len(images.shape) == 4: # Shape: (B, C, H, W)
batch_size = images.shape[0]
seq_length = 1
else:
@@ -608,7 +570,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
Returns:
Encoded text features with shape (B, 512)
"""
# Use CLIP's tokenizer directly for text (avoids image processor validation issues)
# Use CLIP's tokenizer directly for text
tokenizer = self.clip_processor.tokenizer
inputs = tokenizer([text], return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
@@ -629,7 +591,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Add encoded features to the observation features."""
# Add the encoded features
# Add the encoded features (state uses max_state_dim, padded with zeros)
features[PipelineFeatureType.OBSERVATION]['video_features'] = PolicyFeature(
type=FeatureType.VISUAL,
shape=(self.config.num_frames, self.config.image_dim)
@@ -640,7 +602,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
)
features[PipelineFeatureType.OBSERVATION]['state_features'] = PolicyFeature(
type=FeatureType.STATE,
shape=(self.config.num_frames, self.config.state_dim)
shape=(self.config.num_frames, self.config.max_state_dim)
)
return features
@@ -660,11 +622,16 @@ def make_sarm_pre_post_processors(
to process both RGB image sequences and task descriptions."
The pre-processing pipeline:
1. Encodes images with CLIP (512-dim)
2. Encodes text with CLIP (512-dim)
3. Normalizes joint states
4. Adds batch dimension
5. Moves data to device
1. Adds batch dimension
2. Normalizes observation.state using NormalizerProcessorStep (MEAN_STD)
3. SARMEncodingProcessorStep:
- Encodes images with CLIP (512-dim)
- Pads states to max_state_dim
- Encodes text with CLIP (512-dim)
4. Moves data to device
The post-processing pipeline:
1. Moves data to CPU (no unnormalization - outputs are rewards)
Args:
config: SARM configuration
@@ -676,6 +643,11 @@ def make_sarm_pre_post_processors(
"""
input_steps = [
AddBatchDimensionProcessorStep(),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
SARMEncodingProcessorStep(
config=config,
dataset_meta=dataset_meta,
+249
View File
@@ -0,0 +1,249 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility functions for SARM progress label computation.
Implements formulas from the SARM paper:
- Formula (1): Compute dataset-level temporal proportions (priors) ᾱ_k
- Formula (2): Compute normalized progress targets y_t = P_{k-1} + ᾱ_k × τ_t
"""
import numpy as np
import torch
import torch.nn.functional as F
from typing import Sequence
def compute_priors(
subtask_durations_per_trajectory: dict[str, list[float]],
trajectory_lengths: dict[str, list[float]],
subtask_names: list[str],
) -> dict[str, float]:
"""
Compute dataset-level temporal proportions (priors) for each subtask.
Implements SARM Paper Formula (1):
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
where:
- M is the number of trajectories
- L_{i,k} is the length of subtask k in trajectory i
- T_i is the total length of trajectory i
This averages the PROPORTION of each subtask within each trajectory,
giving equal weight to all trajectories regardless of their absolute length.
Args:
subtask_durations_per_trajectory: Dict mapping subtask name to list of
(duration, trajectory_length) tuples for each occurrence
trajectory_lengths: Dict mapping subtask name to list of trajectory lengths
for each occurrence of that subtask
subtask_names: Ordered list of subtask names
Returns:
Dict mapping subtask name to its temporal proportion (ᾱ_k)
"""
if not subtask_names:
raise ValueError("subtask_names cannot be empty")
# Compute proportion per occurrence: L_{i,k} / T_i
subtask_proportions = {}
for name in subtask_names:
if name in subtask_durations_per_trajectory and name in trajectory_lengths:
durations = subtask_durations_per_trajectory[name]
traj_lengths = trajectory_lengths[name]
if len(durations) != len(traj_lengths):
raise ValueError(
f"Mismatch in lengths for subtask '{name}': "
f"{len(durations)} durations vs {len(traj_lengths)} trajectory lengths"
)
# Compute L_{i,k} / T_i for each occurrence
proportions = []
for duration, traj_len in zip(durations, traj_lengths):
if traj_len > 0:
proportions.append(duration / traj_len)
# Average across all occurrences: (1/M) × Σ_i (L_{i,k} / T_i)
subtask_proportions[name] = np.mean(proportions) if proportions else 0.0
else:
subtask_proportions[name] = 0.0
# Normalize to ensure sum = 1 (handles floating point errors and missing subtasks)
total = sum(subtask_proportions.values())
if total > 0:
subtask_proportions = {
name: prop / total for name, prop in subtask_proportions.items()
}
else:
raise ValueError("Cannot compute temporal proportions: all proportions are zero. "
"Check that your dataset has valid subtask annotations with start/end times.")
return subtask_proportions
def compute_tau(
current_frame: int | float,
subtask_start: int | float,
subtask_end: int | float,
) -> float:
"""
Compute within-subtask normalized time τ_t.
Implements part of SARM Paper Formula (2):
τ_t = (t - s_k) / (e_k - s_k) [0, 1]
where:
- t is the current frame
- s_k is the start frame of subtask k
- e_k is the end frame of subtask k
Args:
current_frame: Current frame index (t)
subtask_start: Start frame of the subtask (s_k)
subtask_end: End frame of the subtask (e_k)
Returns:
Within-subtask progress τ_t [0, 1]
"""
subtask_duration = subtask_end - subtask_start
if subtask_duration <= 0:
return 1.0
tau = (current_frame - subtask_start) / subtask_duration
return float(np.clip(tau, 0.0, 1.0))
def compute_cumulative_progress_batch(
tau: torch.Tensor | float,
stage_indices: torch.Tensor | int,
alpha: torch.Tensor | Sequence[float],
cumulative_prior: torch.Tensor | None = None,
) -> torch.Tensor | float:
"""
Compute cumulative normalized progress from within-subtask progress.
This function implements the core formula used in SARM for both:
**Formula 2 (Training labels):**
y_t = P_{k-1} + ᾱ_k × τ_t [0, 1]
Used to compute ground-truth progress labels from subtask annotations.
- τ_t comes from annotated frame position: τ_t = (t - s_k) / (e_k - s_k)
- k is the known subtask from annotations
**Formula 4 (Inference predictions):**
ŷ_{1:N} = P̂_{k-1, 1:N} + ᾱ_{k, 1:N} × τ̂_{1:N} [0, 1]
Used to convert model outputs to cumulative progress during inference.
- τ̂ comes from the subtask MLP head (conditioned on predicted stage)
- k = Ŝ is the predicted stage from Formula 3: Ŝ = argmax(softmax(Ψ))
The formulas are mathematically identical; only the source of inputs differs:
- Training: τ and k from annotations ground-truth labels
- Inference: τ̂ and Ŝ from model predicted progress
where:
- P_{k-1} = Σ_{j=1}^{k-1} ᾱ_j is the cumulative prior (sum of previous proportions)
- ᾱ_k is the temporal proportion for subtask k (from Formula 1)
- τ is within-subtask progress [0, 1]
This ensures:
- y at start of subtask k = P_{k-1}
- y at end of subtask k = P_k
Supports both scalar and batched tensor inputs:
- Scalar: tau (float), stage_indices (int), alpha (list/sequence)
- Batch: tau (Tensor), stage_indices (Tensor), alpha (Tensor), cumulative_prior (Tensor)
Args:
tau: Within-subtask progress τ [0, 1].
For training: computed from frame position in annotated subtask.
For inference: predicted by subtask MLP head.
Scalar float or Tensor with shape (..., 1)
stage_indices: Index of current subtask k (0-indexed).
For training: known from annotations.
For inference: predicted via argmax(stage_probs) (Formula 3).
Scalar int or Tensor with shape (...)
alpha: Temporal proportions with shape (num_stages,) or Sequence[float].
Computed from dataset annotations using Formula 1.
cumulative_prior: Optional. Cumulative priors P with shape (num_stages + 1,)
where cumulative_prior[k] = P_k = Σ_{j=1}^{k} ᾱ_j.
If None, will be computed from alpha.
Returns:
Cumulative progress y [0, 1].
Scalar float if inputs are scalar, otherwise Tensor with shape (..., 1)
"""
if not isinstance(tau, torch.Tensor):
if not alpha:
raise ValueError("alpha (temporal_proportions) cannot be empty")
if isinstance(alpha, torch.Tensor):
alpha_list = alpha.tolist()
else:
alpha_list = list(alpha)
if stage_indices < 0 or stage_indices >= len(alpha_list):
raise ValueError(
f"stage_indices {stage_indices} out of range "
f"for {len(alpha_list)} subtasks"
)
# P_{k-1} = sum of proportions for subtasks 0 to k-1
P_k_minus_1 = sum(alpha_list[:stage_indices])
# ᾱ_k = proportion for current subtask
alpha_k = alpha_list[stage_indices]
# y_t = P_{k-1} + ᾱ_k × τ_t
y_t = P_k_minus_1 + alpha_k * tau
return float(np.clip(y_t, 0.0, 1.0))
if not isinstance(alpha, torch.Tensor):
alpha = torch.tensor(alpha, dtype=torch.float32)
# Compute cumulative_prior if not provided
if cumulative_prior is None:
cumulative_prior = torch.zeros(len(alpha) + 1, dtype=alpha.dtype, device=alpha.device)
cumulative_prior[1:] = torch.cumsum(alpha, dim=0)
# P_{k-1} for each predicted stage
P_k_minus_1 = cumulative_prior[stage_indices]
# ᾱ_k for each predicted stage
alpha_k = alpha[stage_indices]
# ŷ = P_{k-1} + ᾱ_k × τ̂
progress = P_k_minus_1.unsqueeze(-1) + alpha_k.unsqueeze(-1) * tau
return progress
def pad_state_to_max_dim(state: torch.Tensor, max_state_dim: int) -> torch.Tensor:
"""Pad the state tensor's last dimension to max_state_dim with zeros."""
current_dim = state.shape[-1]
if current_dim >= max_state_dim:
return state[..., :max_state_dim] # Truncate if larger
# Pad with zeros on the right
padding = (0, max_state_dim - current_dim) # (left, right) for last dim
return F.pad(state, padding, mode='constant', value=0)
+392
View File
@@ -0,0 +1,392 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for SARM utility functions.
Tests the implementation of SARM paper formulas:
- Formula (1): compute_priors - dataset-level temporal proportions
- Formula (2): compute_tau, compute_cumulative_progress - progress labels
"""
import pytest
import numpy as np
import torch
from lerobot.policies.sarm.sarm_utils import (
compute_priors,
compute_tau,
compute_cumulative_progress_batch,
)
class TestComputePriors:
"""Tests for compute_priors (SARM Paper Formula 1).
Formula: ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
Key insight: This averages the PROPORTION of each subtask within each trajectory,
giving equal weight to all trajectories regardless of absolute length.
"""
def test_basic_two_trajectories_equal_proportions(self):
"""Test with two trajectories that have equal proportions."""
# Both trajectories: subtask1 = 50%, subtask2 = 50%
subtask_durations = {
'subtask1': [50, 100], # durations
'subtask2': [50, 100],
}
trajectory_lengths = {
'subtask1': [100, 200],
'subtask2': [100, 200],
}
subtask_names = ['subtask1', 'subtask2']
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
# Both should be 0.5
assert abs(result['subtask1'] - 0.5) < 1e-6
assert abs(result['subtask2'] - 0.5) < 1e-6
def test_paper_example_different_from_avg_durations(self):
"""Test that compute_priors differs from naive average duration approach.
This is the key test showing the difference between:
- Paper formula: average of (L_i,k / T_i)
- Naive approach: mean(L_i,k) / sum(mean(L_i,j))
"""
# Episode 1: T=100, subtask1=80, subtask2=20 (proportions: 0.8, 0.2)
# Episode 2: T=200, subtask1=40, subtask2=160 (proportions: 0.2, 0.8)
subtask_durations = {
'subtask1': [80, 40],
'subtask2': [20, 160],
}
trajectory_lengths = {
'subtask1': [100, 200],
'subtask2': [100, 200],
}
subtask_names = ['subtask1', 'subtask2']
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
# Paper formula:
# ᾱ_1 = (1/2) × (80/100 + 40/200) = (1/2) × (0.8 + 0.2) = 0.5
# ᾱ_2 = (1/2) × (20/100 + 160/200) = (1/2) × (0.2 + 0.8) = 0.5
assert abs(result['subtask1'] - 0.5) < 1e-6
assert abs(result['subtask2'] - 0.5) < 1e-6
def test_single_trajectory(self):
"""Test with a single trajectory."""
subtask_durations = {
'reach': [30],
'grasp': [20],
'lift': [50],
}
trajectory_lengths = {
'reach': [100],
'grasp': [100],
'lift': [100],
}
subtask_names = ['grasp', 'lift', 'reach'] # sorted order
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
assert abs(result['reach'] - 0.3) < 1e-6
assert abs(result['grasp'] - 0.2) < 1e-6
assert abs(result['lift'] - 0.5) < 1e-6
def test_sum_to_one(self):
"""Test that proportions always sum to 1."""
subtask_durations = {
'a': [10, 20, 30],
'b': [40, 50, 60],
'c': [50, 30, 10],
}
trajectory_lengths = {
'a': [100, 100, 100],
'b': [100, 100, 100],
'c': [100, 100, 100],
}
subtask_names = ['a', 'b', 'c']
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
total = sum(result.values())
assert abs(total - 1.0) < 1e-6
def test_empty_subtask_names_raises(self):
"""Test that empty subtask_names raises an error."""
with pytest.raises(ValueError, match="subtask_names cannot be empty"):
compute_priors({}, {}, [])
def test_missing_subtask_gets_zero_before_normalization(self):
"""Test handling of subtasks that appear in some but not all trajectories."""
# subtask1 appears in both, subtask2 only in first
subtask_durations = {
'subtask1': [50, 100],
'subtask2': [50], # only in first trajectory
}
trajectory_lengths = {
'subtask1': [100, 200],
'subtask2': [100],
}
subtask_names = ['subtask1', 'subtask2']
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
# subtask1: (50/100 + 100/200) / 2 = (0.5 + 0.5) / 2 = 0.5
# subtask2: 50/100 = 0.5 (only one occurrence)
# After normalization: both should be 0.5
assert result['subtask1'] > 0
assert result['subtask2'] > 0
assert abs(sum(result.values()) - 1.0) < 1e-6
class TestComputeTau:
"""Tests for compute_tau (within-subtask progress).
Formula: τ_t = (t - s_k) / (e_k - s_k) [0, 1]
"""
def test_at_start(self):
"""τ should be 0 at subtask start."""
tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=50)
assert tau == 0.0
def test_at_end(self):
"""τ should be 1 at subtask end."""
tau = compute_tau(current_frame=50, subtask_start=10, subtask_end=50)
assert tau == 1.0
def test_at_middle(self):
"""τ should be 0.5 at subtask midpoint."""
tau = compute_tau(current_frame=30, subtask_start=10, subtask_end=50)
assert abs(tau - 0.5) < 1e-6
def test_quarter_progress(self):
"""Test τ at 25% through subtask."""
tau = compute_tau(current_frame=20, subtask_start=0, subtask_end=80)
assert abs(tau - 0.25) < 1e-6
def test_zero_duration_subtask(self):
"""τ should be 1.0 for zero-duration subtask."""
tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=10)
assert tau == 1.0
def test_clamps_below_zero(self):
"""τ should be clamped to 0 if frame is before subtask."""
tau = compute_tau(current_frame=5, subtask_start=10, subtask_end=50)
assert tau == 0.0
def test_clamps_above_one(self):
"""τ should be clamped to 1 if frame is after subtask."""
tau = compute_tau(current_frame=60, subtask_start=10, subtask_end=50)
assert tau == 1.0
def test_float_inputs(self):
"""Test with float frame indices (from interpolation)."""
tau = compute_tau(current_frame=25.5, subtask_start=10.0, subtask_end=50.0)
expected = (25.5 - 10.0) / (50.0 - 10.0)
assert abs(tau - expected) < 1e-6
class TestComputeCumulativeProgressBatchScalar:
"""Tests for compute_cumulative_progress_batch with scalar inputs (normalized progress y_t).
Formula: y_t = P_{k-1} + ᾱ_k × τ_t [0, 1]
"""
def test_first_subtask_start(self):
"""y should be 0 at start of first subtask."""
proportions = [0.3, 0.5, 0.2]
y = compute_cumulative_progress_batch(tau=0.0, stage_indices=0, alpha=proportions)
assert y == 0.0
def test_first_subtask_end(self):
"""y should equal ᾱ_1 at end of first subtask."""
proportions = [0.3, 0.5, 0.2]
y = compute_cumulative_progress_batch(tau=1.0, stage_indices=0, alpha=proportions)
assert abs(y - 0.3) < 1e-6
def test_second_subtask_start(self):
"""y should equal P_1 at start of second subtask."""
proportions = [0.3, 0.5, 0.2]
y = compute_cumulative_progress_batch(tau=0.0, stage_indices=1, alpha=proportions)
assert abs(y - 0.3) < 1e-6
def test_second_subtask_end(self):
"""y should equal P_2 at end of second subtask."""
proportions = [0.3, 0.5, 0.2]
y = compute_cumulative_progress_batch(tau=1.0, stage_indices=1, alpha=proportions)
assert abs(y - 0.8) < 1e-6 # 0.3 + 0.5
def test_third_subtask_end(self):
"""y should be 1.0 at end of last subtask."""
proportions = [0.3, 0.5, 0.2]
y = compute_cumulative_progress_batch(tau=1.0, stage_indices=2, alpha=proportions)
assert abs(y - 1.0) < 1e-6
def test_midpoint_of_subtask(self):
"""Test progress at midpoint of a subtask."""
proportions = [0.4, 0.6]
# At τ=0.5 in subtask 1: y = P_0 + ᾱ_1 × 0.5 = 0 + 0.4 × 0.5 = 0.2
y = compute_cumulative_progress_batch(tau=0.5, stage_indices=0, alpha=proportions)
assert abs(y - 0.2) < 1e-6
# At τ=0.5 in subtask 2: y = P_1 + ᾱ_2 × 0.5 = 0.4 + 0.6 × 0.5 = 0.7
y = compute_cumulative_progress_batch(tau=0.5, stage_indices=1, alpha=proportions)
assert abs(y - 0.7) < 1e-6
def test_uniform_proportions(self):
"""Test with uniform proportions."""
proportions = [0.25, 0.25, 0.25, 0.25]
# At end of each subtask, progress should be 0.25, 0.5, 0.75, 1.0
for i in range(4):
y = compute_cumulative_progress_batch(tau=1.0, stage_indices=i, alpha=proportions)
expected = (i + 1) * 0.25
assert abs(y - expected) < 1e-6
class TestComputeCumulativeProgressBatchTensor:
"""Tests for compute_cumulative_progress_batch with tensor inputs (GPU batch version)."""
def test_tensor_matches_scalar_version(self):
"""Test that tensor version matches scalar version."""
proportions = [0.3, 0.5, 0.2]
alpha = torch.tensor(proportions, dtype=torch.float32)
cumulative = torch.zeros(len(proportions) + 1, dtype=torch.float32)
cumulative[1:] = torch.cumsum(alpha, dim=0)
test_cases = [
(0.0, 0), # start of subtask 0
(1.0, 0), # end of subtask 0
(0.0, 1), # start of subtask 1
(0.5, 1), # middle of subtask 1
(1.0, 2), # end of subtask 2
]
for tau_val, stage_idx in test_cases:
# Scalar version
expected = compute_cumulative_progress_batch(tau_val, stage_idx, proportions)
# Tensor version (single element)
tau = torch.tensor([[[tau_val]]]) # (1, 1, 1)
stages = torch.tensor([[stage_idx]]) # (1, 1)
result = compute_cumulative_progress_batch(tau, stages, alpha, cumulative)
assert abs(result[0, 0, 0].item() - expected) < 1e-6
def test_batch_processing(self):
"""Test batch processing with multiple samples."""
proportions = [0.4, 0.6]
alpha = torch.tensor(proportions, dtype=torch.float32)
cumulative = torch.zeros(3, dtype=torch.float32)
cumulative[1:] = torch.cumsum(alpha, dim=0)
# Batch of 2 samples, sequence length 3
tau = torch.tensor([
[[0.0], [0.5], [1.0]], # sample 1
[[0.0], [0.5], [1.0]], # sample 2
])
stages = torch.tensor([
[0, 0, 0], # sample 1: all in subtask 0
[1, 1, 1], # sample 2: all in subtask 1
])
result = compute_cumulative_progress_batch(tau, stages, alpha, cumulative)
# Sample 1: subtask 0 with tau 0, 0.5, 1.0 -> y = 0, 0.2, 0.4
assert abs(result[0, 0, 0].item() - 0.0) < 1e-6
assert abs(result[0, 1, 0].item() - 0.2) < 1e-6
assert abs(result[0, 2, 0].item() - 0.4) < 1e-6
# Sample 2: subtask 1 with tau 0, 0.5, 1.0 -> y = 0.4, 0.7, 1.0
assert abs(result[1, 0, 0].item() - 0.4) < 1e-6
assert abs(result[1, 1, 0].item() - 0.7) < 1e-6
assert abs(result[1, 2, 0].item() - 1.0) < 1e-6
def test_auto_compute_cumulative_prior(self):
"""Test that cumulative_prior is auto-computed when not provided."""
proportions = [0.3, 0.5, 0.2]
alpha = torch.tensor(proportions, dtype=torch.float32)
tau = torch.tensor([[[0.5]]])
stages = torch.tensor([[1]])
# Without cumulative_prior (should auto-compute)
result = compute_cumulative_progress_batch(tau, stages, alpha)
# Expected: P_0 + alpha_1 * 0.5 = 0.3 + 0.5 * 0.5 = 0.55
assert abs(result[0, 0, 0].item() - 0.55) < 1e-6
class TestEndToEndProgressLabeling:
"""End-to-end tests for progress label computation."""
def test_consistent_semantic_meaning(self):
"""Test that same subtask completion maps to same progress across trajectories.
This is the key semantic property: "end of subtask 1" should always
mean the same progress value regardless of trajectory speed.
"""
proportions = [0.3, 0.5, 0.2]
# Fast trajectory: subtask 1 ends at frame 30 (of 100)
tau_fast = compute_tau(30, 0, 30) # = 1.0
y_fast = compute_cumulative_progress_batch(tau_fast, 0, proportions)
# Slow trajectory: subtask 1 ends at frame 90 (of 300)
tau_slow = compute_tau(90, 0, 90) # = 1.0
y_slow = compute_cumulative_progress_batch(tau_slow, 0, proportions)
# Both should map to same progress (0.3 = end of subtask 1)
assert abs(y_fast - y_slow) < 1e-6
assert abs(y_fast - 0.3) < 1e-6
def test_monotonic_within_subtask(self):
"""Test that progress is monotonically increasing within a subtask."""
proportions = [0.4, 0.6]
prev_y = -1
for tau in np.linspace(0, 1, 11):
y = compute_cumulative_progress_batch(tau, 0, proportions)
assert y > prev_y or (tau == 0 and y == 0)
prev_y = y
def test_continuous_across_subtasks(self):
"""Test that progress is continuous at subtask boundaries."""
proportions = [0.3, 0.5, 0.2]
# End of subtask 0 (tau=1.0)
y_end_0 = compute_cumulative_progress_batch(1.0, 0, proportions)
# Start of subtask 1 (tau=0.0)
y_start_1 = compute_cumulative_progress_batch(0.0, 1, proportions)
# Should be equal (P_1 = 0.3)
assert abs(y_end_0 - y_start_1) < 1e-6
# End of subtask 1
y_end_1 = compute_cumulative_progress_batch(1.0, 1, proportions)
# Start of subtask 2
y_start_2 = compute_cumulative_progress_batch(0.0, 2, proportions)
# Should be equal (P_2 = 0.8)
assert abs(y_end_1 - y_start_2) < 1e-6