Add tests, implement Formulas (1) and (2) correctly, and clean up

This commit is contained in:
Pepijn
2025-11-27 14:04:01 +01:00
parent 3ed0425d2c
commit f2ad86831d
7 changed files with 861 additions and 274 deletions
+1 -7
View File
@@ -64,9 +64,7 @@ class SARMTemporalSampler(Sampler):
self.shuffle = shuffle self.shuffle = shuffle
self.samples_per_epoch = samples_per_epoch self.samples_per_epoch = samples_per_epoch
# Minimum frames needed for SARM pattern: # Minimum frames needed for SARM pattern: 8 consecutive frames with frame_gap spacing = 7 * frame_gap + 1
# 8 consecutive frames with frame_gap spacing = 7 * frame_gap + 1
# (Plus the initial frame which is always available)
self.min_frames_needed = 7 * frame_gap + 1 self.min_frames_needed = 7 * frame_gap + 1
if seed is not None: if seed is not None:
@@ -138,7 +136,3 @@ class SARMTemporalSampler(Sampler):
for i in range(self.samples_per_epoch): for i in range(self.samples_per_epoch):
idx = i % len(self.all_valid_positions) idx = i % len(self.all_valid_positions)
yield int(self.all_valid_positions[idx]) yield int(self.all_valid_positions[idx])
# Backwards compatibility alias
TemporalSequenceSampler = SARMTemporalSampler
-2
View File
@@ -18,7 +18,6 @@ from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.modeling_sarm import ( from lerobot.policies.sarm.modeling_sarm import (
SARMRewardModel, SARMRewardModel,
SARMTransformer, SARMTransformer,
compute_stage_loss,
) )
from lerobot.policies.sarm.processor_sarm import ( from lerobot.policies.sarm.processor_sarm import (
SARMEncodingProcessorStep, SARMEncodingProcessorStep,
@@ -29,7 +28,6 @@ __all__ = [
"SARMConfig", "SARMConfig",
"SARMRewardModel", "SARMRewardModel",
"SARMTransformer", "SARMTransformer",
"compute_stage_loss",
"SARMEncodingProcessorStep", "SARMEncodingProcessorStep",
"make_sarm_pre_post_processors", "make_sarm_pre_post_processors",
] ]
+45 -29
View File
@@ -17,7 +17,7 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import PolicyFeature, FeatureType from lerobot.configs.types import PolicyFeature, FeatureType, NormalizationMode
from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@@ -27,63 +27,83 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
class SARMConfig(PreTrainedConfig): class SARMConfig(PreTrainedConfig):
"""Configuration class for SARM (Stage-Aware Reward Modeling)""" """Configuration class for SARM (Stage-Aware Reward Modeling)"""
# Visual encoding parameters # CLIP encoding parameters
image_dim: int = 512 # CLIP embedding dimension image_dim: int = 512
text_dim: int = 512
num_frames: int = 9 # 1 initial + 8 consecutive frames num_frames: int = 9 # 1 initial + 8 consecutive frames
frame_gap: int = 30 # Frame gap between consecutive frames (at 30 fps = 1 second) frame_gap: int = 30 # Frame gap between consecutive frames (at 30 fps = 1 second)
# Text encoding parameters (CLIP text encoder output dimension)
text_dim: int = 512
# Joint state parameters
state_dim: int | None = None # Auto-detected from dataset if None
# Architecture parameters # Architecture parameters
hidden_dim: int = 768 hidden_dim: int = 768
num_heads: int = 12 num_heads: int = 12
num_layers: int = 8 num_layers: int = 8
num_stages: int = 5 # Number of task stages for classification (auto-updated from annotations if available) max_state_dim: int = 32
num_stages: int = 5 # Number of task stages (auto-updated from annotations if available)
subtask_names: list | None = None # List of subtask names (auto-populated from annotations) subtask_names: list | None = None # List of subtask names (auto-populated from annotations)
temporal_proportions: list | None = None # Temporal proportions for each stage (auto-computed from annotations) temporal_proportions: list | None = None # Temporal proportions for each stage (auto-computed from annotations)
# Temporal parameters
max_length: int = num_frames # Maximum video sequence length (matches num_frames) max_length: int = num_frames # Maximum video sequence length (matches num_frames)
use_temporal_sampler: bool = True # Always enable temporal sequence loading use_temporal_sampler: bool = True # Always enable temporal sequence loading
# Training parameters # Training parameters
batch_size: int = 64 batch_size: int = 64
clip_batch_size: int = 64 # Batch size for CLIP encoding clip_batch_size: int = 64 # Batch size for CLIP encoding
gradient_checkpointing: bool = False # Enable gradient checkpointing dropout: float = 0.1
dropout: float = 0.1 # Dropout rate
stage_loss_weight: float = 1.0 # Weight for stage classification loss when using subtask annotations stage_loss_weight: float = 1.0 # Weight for stage classification loss when using subtask annotations
pretrained_model_path: str | None = None pretrained_model_path: str | None = None
device: str | None = None device: str | None = None
# Processor settings # Processor settings
image_key: str = "observation.images.top" # Key for image used from the dataset image_key: str = "observation.images.top" # Key for image used from the dataset
task_description: str = "perform the task" # Default task description task_description: str = "perform the task" # Default task description
# Video_features and text_features are generated by the processor from raw images/text, we don't declare them as VISUAL/LANGUAGE here to avoid validation errors # State key in the dataset (for normalization)
input_features: dict = field(default_factory=lambda: { state_key: str = "observation.state"
"state_features": PolicyFeature(shape=(9, 14), type=FeatureType.STATE) # Example: 7 DOF × 2 arms
}) # Populated by the processor (video_features, state_features, text_features)
input_features: dict = field(default_factory=lambda: {})
# Output features
output_features: dict = field(default_factory=lambda: { output_features: dict = field(default_factory=lambda: {
"stage": PolicyFeature(shape=(1,), type=FeatureType.REWARD), "stage": PolicyFeature(shape=(9, 5), type=FeatureType.REWARD),
"progress": PolicyFeature(shape=(1,), type=FeatureType.REWARD) "progress": PolicyFeature(shape=(9, 1), type=FeatureType.REWARD),
}) })
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"LANGUAGE": NormalizationMode.IDENTITY,
"REWARD": NormalizationMode.IDENTITY,
}
)
def __post_init__(self): def __post_init__(self):
super().__post_init__() super().__post_init__()
# Add the image_key, the processor will transform this into video_features # Add the image_key as VISUAL (this is the raw image from dataset)
if self.image_key and self.image_key not in self.input_features: if self.image_key:
self.input_features[self.image_key] = PolicyFeature( self.input_features[self.image_key] = PolicyFeature(
shape=(480, 640, 3), shape=(480, 640, 3),
type=FeatureType.VISUAL type=FeatureType.VISUAL
) )
# Add state_key as STATE (raw state from dataset, will be padded to max_state_dim)
self.input_features[self.state_key] = PolicyFeature(
shape=(self.max_state_dim,), # Single frame state, temporal sampling handles sequence
type=FeatureType.STATE
)
# Update output features with actual dimensions
self.output_features["stage"] = PolicyFeature(
shape=(self.num_frames, self.num_stages),
type=FeatureType.REWARD
)
self.output_features["progress"] = PolicyFeature(
shape=(self.num_frames, 1),
type=FeatureType.REWARD
)
# Validate configuration # Validate configuration
if self.hidden_dim % self.num_heads != 0: if self.hidden_dim % self.num_heads != 0:
raise ValueError( raise ValueError(
@@ -95,9 +115,6 @@ class SARMConfig(PreTrainedConfig):
f"max_length ({self.max_length}) must equal num_frames ({self.num_frames})" f"max_length ({self.max_length}) must equal num_frames ({self.num_frames})"
) )
if self.dropout < 0 or self.dropout >= 1:
raise ValueError(f"dropout must be in [0, 1), got {self.dropout}")
if self.num_stages < 2: if self.num_stages < 2:
raise ValueError(f"num_stages must be at least 2, got {self.num_stages}") raise ValueError(f"num_stages must be at least 2, got {self.num_stages}")
@@ -139,11 +156,10 @@ class SARMConfig(PreTrainedConfig):
Returns: Returns:
9 delta indices: [-1_000_000, -(7*gap), -(6*gap), ..., -gap, 0] 9 delta indices: [-1_000_000, -(7*gap), -(6*gap), ..., -gap, 0]
""" """
# First delta: large negative to always clamp to episode start (frame 0)
initial_frame_delta = -1_000_000 initial_frame_delta = -1_000_000
# Remaining 8 deltas: consecutive frames with frame_gap spacing # Remaining consecutive frames with frame_gap spacing
num_consecutive = self.num_frames - 1 # 8 frames num_consecutive = self.num_frames - 1
consecutive_deltas = list(range(-self.frame_gap * (num_consecutive - 1), 1, self.frame_gap)) consecutive_deltas = list(range(-self.frame_gap * (num_consecutive - 1), 1, self.frame_gap))
return [initial_frame_delta] + consecutive_deltas return [initial_frame_delta] + consecutive_deltas
+60 -94
View File
@@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import List, Union, Dict, Optional from typing import List, Union, Optional
import random import random
import numpy as np import numpy as np
@@ -28,9 +28,12 @@ from transformers import CLIPModel, CLIPProcessor
from torch import Tensor from torch import Tensor
from lerobot.policies.sarm.configuration_sarm import SARMConfig from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.sarm_utils import compute_priors, compute_cumulative_progress_batch, pad_state_to_max_dim
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
class SARMTransformer(nn.Module): class SARMTransformer(nn.Module):
""" """
SARM Transformer model for stage-aware reward prediction. SARM Transformer model for stage-aware reward prediction.
@@ -45,8 +48,8 @@ class SARMTransformer(nn.Module):
def __init__( def __init__(
self, self,
video_dim: int = 512, video_dim: int = 512,
text_dim: int = 512, # CLIP text encoder output dimension (per SARM paper A.4) text_dim: int = 512,
state_dim: int = 14, max_state_dim: int = 32,
hidden_dim: int = 768, hidden_dim: int = 768,
num_heads: int = 12, num_heads: int = 12,
num_layers: int = 8, num_layers: int = 8,
@@ -59,9 +62,8 @@ class SARMTransformer(nn.Module):
self.hidden_dim = hidden_dim self.hidden_dim = hidden_dim
self.max_length = max_length self.max_length = max_length
self.num_stages = num_stages self.num_stages = num_stages
self.max_state_dim = max_state_dim
# Store temporal proportions for progress conversion (Paper Eq. 4)
# ŷ = P_{k-1} + ᾱ_k × τ̂
if temporal_proportions is None: if temporal_proportions is None:
raise ValueError( raise ValueError(
"temporal_proportions is required for SARM. " "temporal_proportions is required for SARM. "
@@ -77,15 +79,13 @@ class SARMTransformer(nn.Module):
self.register_buffer('alpha', alpha) self.register_buffer('alpha', alpha)
self.register_buffer('cumulative_prior', cumulative) self.register_buffer('cumulative_prior', cumulative)
# Project video, text, and state to same dimension
self.video_proj = nn.Linear(video_dim, hidden_dim) self.video_proj = nn.Linear(video_dim, hidden_dim)
self.text_proj = nn.Linear(text_dim, hidden_dim) self.text_proj = nn.Linear(text_dim, hidden_dim)
self.state_proj = nn.Linear(state_dim, hidden_dim) self.state_proj = nn.Linear(max_state_dim, hidden_dim)
# Position embedding only for the first frame # Position embedding only for the first frame
self.first_pos_embed = nn.Parameter(torch.randn(1, hidden_dim)) self.first_pos_embed = nn.Parameter(torch.randn(1, hidden_dim))
# Transformer encoder
encoder_layer = nn.TransformerEncoderLayer( encoder_layer = nn.TransformerEncoderLayer(
d_model=hidden_dim, d_model=hidden_dim,
nhead=num_heads, nhead=num_heads,
@@ -96,7 +96,6 @@ class SARMTransformer(nn.Module):
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
# Stage estimator head (classification) # Stage estimator head (classification)
# Paper A.4: "2 layers with hidden dimension of 512"
self.stage_head = nn.Sequential( self.stage_head = nn.Sequential(
nn.Linear(hidden_dim, 512), nn.Linear(hidden_dim, 512),
nn.LayerNorm(512), nn.LayerNorm(512),
@@ -106,8 +105,6 @@ class SARMTransformer(nn.Module):
) )
# Subtask estimator head (regression, conditioned on stage) # Subtask estimator head (regression, conditioned on stage)
# Takes concatenated [features, stage_embedding]
# Paper A.4: "2 layers with hidden dimension of 512"
self.stage_embedding = nn.Embedding(num_stages, hidden_dim // 4) self.stage_embedding = nn.Embedding(num_stages, hidden_dim // 4)
subtask_input_dim = hidden_dim + hidden_dim // 4 subtask_input_dim = hidden_dim + hidden_dim // 4
self.subtask_head = nn.Sequential( self.subtask_head = nn.Sequential(
@@ -119,13 +116,13 @@ class SARMTransformer(nn.Module):
nn.Sigmoid() nn.Sigmoid()
) )
# Attention mask for causal self-attention # Attention mask
self.register_buffer("attention_mask", None, persistent=False) self.register_buffer("attention_mask", None, persistent=False)
def _get_attention_mask(self, seq_length: int, device: torch.device) -> torch.Tensor: def _get_attention_mask(self, seq_length: int, device: torch.device) -> torch.Tensor:
"""Generate or retrieve cached causal attention mask.""" """Generate or retrieve cached causal attention mask."""
if self.attention_mask is None or self.attention_mask.shape[0] != seq_length: if self.attention_mask is None or self.attention_mask.shape[0] != seq_length:
# Create causal mask (upper triangular with -inf) # Create causal mask
mask = nn.Transformer.generate_square_subsequent_mask(seq_length, device=device) mask = nn.Transformer.generate_square_subsequent_mask(seq_length, device=device)
self.attention_mask = mask self.attention_mask = mask
return self.attention_mask return self.attention_mask
@@ -150,14 +147,16 @@ class SARMTransformer(nn.Module):
- Stage probabilities (batch_size, seq_len, num_stages) - Stage probabilities (batch_size, seq_len, num_stages)
- Progress predictions for each frame (batch_size, seq_len, 1) - Progress predictions for each frame (batch_size, seq_len, 1)
""" """
batch_size = video_frames.shape[0]
# Project inputs to common dimension # Project inputs to common dimension
video_embed = self.video_proj(video_frames) # [batch_size, seq_len, hidden_dim] video_embed = self.video_proj(video_frames) # [batch_size, seq_len, hidden_dim]
text_embed = self.text_proj(text_embed).unsqueeze(1) # [batch_size, 1, hidden_dim] text_embed = self.text_proj(text_embed).unsqueeze(1) # [batch_size, 1, hidden_dim]
state_embed = self.state_proj(state_features) # [batch_size, seq_len, hidden_dim] # Pad state features to max_state_dim before projection
# Fuse video and state features (simple addition) state_features_padded = pad_state_to_max_dim(state_features, self.max_state_dim)
state_embed = self.state_proj(state_features_padded) # [batch_size, seq_len, hidden_dim]
# Fuse video and state features
video_embed = video_embed + state_embed video_embed = video_embed + state_embed
# Add positional embedding to first video frame # Add positional embedding to first video frame
@@ -173,7 +172,7 @@ class SARMTransformer(nn.Module):
# Pass through transformer with causal masking # Pass through transformer with causal masking
transformed = self.transformer(sequence, mask=attention_mask, is_causal=True) transformed = self.transformer(sequence, mask=attention_mask, is_causal=True)
# Get frame features (exclude text token) # Get frame features
frame_features = transformed[:, 1:] # [batch_size, seq_len, hidden_dim] frame_features = transformed[:, 1:] # [batch_size, seq_len, hidden_dim]
# Stage estimation # Stage estimation
@@ -193,14 +192,11 @@ class SARMTransformer(nn.Module):
# τ̂ = within-subtask progress (0-1) # τ̂ = within-subtask progress (0-1)
tau_preds = self.subtask_head(conditioned_features) # [batch_size, seq_len, 1] tau_preds = self.subtask_head(conditioned_features) # [batch_size, seq_len, 1]
# Convert τ̂ to cumulative progress ŷ using Paper Eq. 4: # Convert τ̂ to cumulative progress ŷ using Paper Formula (2):
# ŷ = P_{k-1} + ᾱ_k × τ̂ # ŷ = P_{k-1} + ᾱ_k × τ̂
# P_{k-1} = cumulative prior up to stage k-1 progress_preds = compute_cumulative_progress_batch(
# ᾱ_k = temporal proportion of stage k tau_preds, stage_indices, self.alpha, self.cumulative_prior
P_k_minus_1 = self.cumulative_prior[stage_indices] # [batch_size, seq_len] )
alpha_k = self.alpha[stage_indices] # [batch_size, seq_len]
progress_preds = P_k_minus_1.unsqueeze(-1) + alpha_k.unsqueeze(-1) * tau_preds
return stage_logits, stage_probs, progress_preds return stage_logits, stage_probs, progress_preds
@@ -227,65 +223,37 @@ class SARMRewardModel(PreTrainedPolicy):
self.dataset_stats = dataset_stats self.dataset_stats = dataset_stats
self.device = torch.device(config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu") self.device = torch.device(config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu")
# Auto-detect num_stages from dataset annotations before building the model # Detect num_stages from dataset annotations before building the model
if dataset_meta is not None: if dataset_meta is not None:
self._update_num_stages_from_dataset(dataset_meta) self._update_num_stages_from_dataset(dataset_meta)
# Initialize CLIP encoder for images AND text (per SARM paper A.4) logging.info("Loading CLIP encoder")
logging.info("Loading CLIP encoder for images and text...")
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True) self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
self.clip_model.to(self.device) self.clip_model.to(self.device)
self.clip_model.eval() self.clip_model.eval()
# Auto-detect state_dim from dataset_stats
if config.state_dim is None:
logging.info(f"Attempting to auto-detect state_dim. dataset_stats is None: {dataset_stats is None}")
if dataset_stats is not None:
if "observation.state" in dataset_stats:
config.state_dim = dataset_stats["observation.state"]["mean"].shape[0]
logging.info(f"Auto-detected state_dim={config.state_dim} from dataset_stats['observation.state']")
elif "state" in dataset_stats:
config.state_dim = dataset_stats["state"]["mean"].shape[0]
logging.info(f"Auto-detected state_dim={config.state_dim} from dataset_stats['state']")
else:
logging.warning(f"State keys not found in dataset_stats. Available keys: {list(dataset_stats.keys())}")
else:
logging.warning("dataset_stats is None, cannot auto-detect state_dim")
# Raise explicit error if still None
if config.state_dim is None:
raise ValueError(
"Could not determine state_dim! "
f"dataset_stats={'None' if dataset_stats is None else f'available with keys: {list(dataset_stats.keys())}'}, "
"config.state_dim=None. "
"Please either:\n"
"1. Provide --policy.state_dim=<your_state_dimension> explicitly, or\n"
"2. Ensure dataset_stats contains 'observation.state' or 'state' key"
)
# Initialize SARM transformer with temporal proportions for progress conversion
temporal_proportions = getattr(config, 'temporal_proportions', None)
self.sarm_transformer = SARMTransformer( self.sarm_transformer = SARMTransformer(
video_dim=config.image_dim, video_dim=config.image_dim,
text_dim=config.text_dim, text_dim=config.text_dim,
state_dim=config.state_dim, max_state_dim=config.max_state_dim,
hidden_dim=config.hidden_dim, hidden_dim=config.hidden_dim,
num_heads=config.num_heads, num_heads=config.num_heads,
num_layers=config.num_layers, num_layers=config.num_layers,
num_stages=config.num_stages, num_stages=config.num_stages,
max_length=config.max_length, max_length=config.max_length,
dropout=config.dropout, dropout=config.dropout,
temporal_proportions=temporal_proportions temporal_proportions=config.temporal_proportions
) )
self.sarm_transformer.to(self.device) self.sarm_transformer.to(self.device)
logging.info(f"SARM Reward Model initialized on {self.device}") logging.info(f"SARM Reward Model initialized on {self.device}")
def _update_num_stages_from_dataset(self, dataset_meta) -> None: def _update_num_stages_from_dataset(self, dataset_meta) -> None:
"""Update num_stages and temporal_proportions from dataset subtask annotations.""" """Update num_stages and temporal_proportions from dataset subtask annotations.
Implements SARM Paper Formula (1):
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
"""
episodes = dataset_meta.episodes episodes = dataset_meta.episodes
if episodes is None or len(episodes) == 0: if episodes is None or len(episodes) == 0:
raise ValueError("No episodes found, using default num_stages") raise ValueError("No episodes found, using default num_stages")
@@ -295,27 +263,38 @@ class SARMRewardModel(PreTrainedPolicy):
episodes_df = episodes.to_pandas() episodes_df = episodes.to_pandas()
# Collect all unique subtask names and compute durations # Collect subtask durations and trajectory lengths for compute_priors
all_subtask_names = set() all_subtask_names = set()
subtask_durations = {} subtask_durations_per_trajectory = {}
trajectory_lengths = {}
for ep_idx in episodes_df.index: for ep_idx in episodes_df.index:
subtask_names = episodes_df.loc[ep_idx, 'subtask_names'] subtask_names_ep = episodes_df.loc[ep_idx, 'subtask_names']
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)): if subtask_names_ep is None or (isinstance(subtask_names_ep, float) and pd.isna(subtask_names_ep)):
continue continue
all_subtask_names.update(subtask_names) all_subtask_names.update(subtask_names_ep)
# Compute durations if available # Compute durations if available
if 'subtask_start_frames' in episodes_df.columns and 'subtask_end_frames' in episodes_df.columns: if 'subtask_start_frames' in episodes_df.columns and 'subtask_end_frames' in episodes_df.columns:
start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames'] start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames']
end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames'] end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames']
for i, name in enumerate(subtask_names): # Compute total trajectory length T_i
total_traj_length = sum(end_frames[i] - start_frames[i] for i in range(len(subtask_names_ep)))
if total_traj_length <= 0:
continue
for i, name in enumerate(subtask_names_ep):
duration = end_frames[i] - start_frames[i] duration = end_frames[i] - start_frames[i]
if name not in subtask_durations:
subtask_durations[name] = [] if name not in subtask_durations_per_trajectory:
subtask_durations[name].append(duration) subtask_durations_per_trajectory[name] = []
trajectory_lengths[name] = []
subtask_durations_per_trajectory[name].append(duration)
trajectory_lengths[name].append(total_traj_length)
if not all_subtask_names: if not all_subtask_names:
raise ValueError("No valid subtask names found, using default num_stages") raise ValueError("No valid subtask names found, using default num_stages")
@@ -324,26 +303,20 @@ class SARMRewardModel(PreTrainedPolicy):
subtask_names = sorted(list(all_subtask_names)) subtask_names = sorted(list(all_subtask_names))
num_stages = len(subtask_names) num_stages = len(subtask_names)
# Compute temporal proportions (Paper Eq. 1: ᾱ_k) # Compute temporal proportions using Paper Formula (1)
avg_durations = {} temporal_proportions_dict = compute_priors(
for name in subtask_names: subtask_durations_per_trajectory,
if name in subtask_durations and subtask_durations[name]: trajectory_lengths,
avg_durations[name] = np.mean(subtask_durations[name]) subtask_names
else: )
avg_durations[name] = 1.0 # Default temporal_proportions = [temporal_proportions_dict[name] for name in subtask_names]
total_duration = sum(avg_durations.values())
if total_duration > 0:
temporal_proportions = [avg_durations[name] / total_duration for name in subtask_names]
else:
temporal_proportions = [1.0 / num_stages] * num_stages
self.config.num_stages = num_stages self.config.num_stages = num_stages
self.config.subtask_names = subtask_names self.config.subtask_names = subtask_names
self.config.temporal_proportions = temporal_proportions self.config.temporal_proportions = temporal_proportions
logging.info(f"Auto-detected {num_stages} subtasks: {subtask_names}") logging.info(f"Auto-detected {num_stages} subtasks: {subtask_names}")
logging.info(f"Temporal proportions: {dict(zip(subtask_names, temporal_proportions))}") logging.info(f"Temporal proportions: {temporal_proportions_dict}")
def to(self, device): def to(self, device):
"""Override the `to` method so that all components move to the device together.""" """Override the `to` method so that all components move to the device together."""
@@ -475,7 +448,6 @@ class SARMRewardModel(PreTrainedPolicy):
If return_stages=True: If return_stages=True:
Tuple of (rewards, stage_probs) Tuple of (rewards, stage_probs)
""" """
# Convert to tensors if needed
if isinstance(text_embeddings, np.ndarray): if isinstance(text_embeddings, np.ndarray):
text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32) text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32)
if isinstance(video_embeddings, np.ndarray): if isinstance(video_embeddings, np.ndarray):
@@ -535,16 +507,13 @@ class SARMRewardModel(PreTrainedPolicy):
def load_pretrained_checkpoint(self, checkpoint_path: str, strict: bool = False): def load_pretrained_checkpoint(self, checkpoint_path: str, strict: bool = False):
"""Load pretrained model weights from a checkpoint file.""" """Load pretrained model weights from a checkpoint file."""
logging.info(f"Loading pretrained checkpoint from {checkpoint_path}") logging.info(f"Loading pretrained checkpoint from {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False) checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
# Handle different checkpoint formats
if "model_state_dict" in checkpoint: if "model_state_dict" in checkpoint:
state_dict = checkpoint["model_state_dict"] state_dict = checkpoint["model_state_dict"]
else: else:
state_dict = checkpoint state_dict = checkpoint
# Load only the SARMTransformer weights
missing_keys, unexpected_keys = self.sarm_transformer.load_state_dict(state_dict, strict=strict) missing_keys, unexpected_keys = self.sarm_transformer.load_state_dict(state_dict, strict=strict)
if missing_keys: if missing_keys:
@@ -557,9 +526,7 @@ class SARMRewardModel(PreTrainedPolicy):
def train(self, mode: bool = True): def train(self, mode: bool = True):
"""Set training mode. Note: CLIP encoder always stays in eval mode (frozen).""" """Set training mode. Note: CLIP encoder always stays in eval mode (frozen)."""
super().train(mode) super().train(mode)
# Keep CLIP encoder in eval mode (frozen per SARM paper)
self.clip_model.eval() self.clip_model.eval()
# Only transformer can be trained
self.sarm_transformer.train(mode) self.sarm_transformer.train(mode)
return self return self
@@ -686,8 +653,7 @@ class SARMRewardModel(PreTrainedPolicy):
state = state_features[i] if state_features is not None else None state = state_features[i] if state_features is not None else None
progress = progress_from_annotations[i].squeeze(-1) # (T,) progress = progress_from_annotations[i].squeeze(-1) # (T,)
# Apply temporal augmentation with 50% probability (SARM paper A.4) # Apply temporal augmentation with 50% probability: appends up to 4 reversed frames to simulate failures/recoveries
# Appends up to 4 reversed frames to simulate failures/recoveries
if random.random() < 0.5: if random.random() < 0.5:
video, progress, state = self._apply_temporal_augmentation(video, progress, state, max_length) video, progress, state = self._apply_temporal_augmentation(video, progress, state, max_length)
@@ -729,7 +695,7 @@ class SARMRewardModel(PreTrainedPolicy):
total_loss = total_loss + self.config.stage_loss_weight * stage_loss total_loss = total_loss + self.config.stage_loss_weight * stage_loss
output_dict['stage_loss'] = stage_loss.item() output_dict['stage_loss'] = stage_loss.item()
# Misaligned loss: 20% probability (SARM paper - improve video-language alignment) # Misaligned loss: 20% probability
if random.random() < 0.2: if random.random() < 0.2:
shuffle_idx = torch.randperm(batch_size, device=self.device) shuffle_idx = torch.randperm(batch_size, device=self.device)
_, _, misaligned_preds = self.sarm_transformer( _, _, misaligned_preds = self.sarm_transformer(
+108 -136
View File
@@ -23,16 +23,19 @@ import pandas as pd
from transformers import CLIPModel, CLIPProcessor from transformers import CLIPModel, CLIPProcessor
from lerobot.policies.sarm.configuration_sarm import SARMConfig from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.sarm_utils import compute_priors, compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim
from lerobot.processor import ( from lerobot.processor import (
ProcessorStep, ProcessorStep,
PolicyProcessorPipeline, PolicyProcessorPipeline,
PolicyAction, PolicyAction,
DeviceProcessorStep, DeviceProcessorStep,
AddBatchDimensionProcessorStep, AddBatchDimensionProcessorStep,
NormalizerProcessorStep,
) )
from lerobot.processor.converters import ( from lerobot.processor.converters import (
policy_action_to_transition, policy_action_to_transition,
transition_to_policy_action, transition_to_policy_action,
from_tensor_to_numpy,
) )
from lerobot.processor.pipeline import PipelineFeatureType from lerobot.processor.pipeline import PipelineFeatureType
from lerobot.processor.core import EnvTransition, TransitionKey from lerobot.processor.core import EnvTransition, TransitionKey
@@ -41,20 +44,7 @@ from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PR
class SARMEncodingProcessorStep(ProcessorStep): class SARMEncodingProcessorStep(ProcessorStep):
""" """ProcessorStep that encodes images and text with CLIP."""
ProcessorStep that encodes images and text for SARM training.
Per SARM paper (Appendix A.4): "We employ a frozen clip-vit-base-patch32 encoder
to process both RGB image sequences and task descriptions."
This step handles:
- CLIP image encoding (512-dim)
- CLIP text encoding (512-dim)
- Joint state normalization
Supports temporal sequences: (B, T, C, H, W) → (B, T, 512) video features
"""
def __init__( def __init__(
self, self,
config: SARMConfig, config: SARMConfig,
@@ -69,8 +59,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
self.task_description = task_description or config.task_description self.task_description = task_description or config.task_description
self.dataset_meta = dataset_meta self.dataset_meta = dataset_meta
self.dataset_stats = dataset_stats self.dataset_stats = dataset_stats
# Compute temporal proportions from subtask annotations if available
self.temporal_proportions = None self.temporal_proportions = None
self.subtask_names = None self.subtask_names = None
if dataset_meta is not None: if dataset_meta is not None:
@@ -94,7 +82,14 @@ class SARMEncodingProcessorStep(ProcessorStep):
self.device = device self.device = device
def _compute_temporal_proportions(self): def _compute_temporal_proportions(self):
"""Compute temporal proportions for each subtask from dataset annotations.""" """Compute temporal proportions for each subtask from dataset annotations.
Implements SARM Paper Formula (1):
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
This averages the proportion of time spent on each subtask within each trajectory,
giving equal weight to all trajectories regardless of absolute length.
"""
if self.dataset_meta is None or not hasattr(self.dataset_meta, 'episodes'): if self.dataset_meta is None or not hasattr(self.dataset_meta, 'episodes'):
return return
@@ -108,32 +103,42 @@ class SARMEncodingProcessorStep(ProcessorStep):
logging.info("No subtask annotations found in dataset") logging.info("No subtask annotations found in dataset")
return return
# Convert to pandas
episodes_df = episodes.to_pandas() episodes_df = episodes.to_pandas()
# Collect all subtask names and compute average durations # Collect subtask durations and trajectory lengths for compute_priors
subtask_durations = {} subtask_durations_per_trajectory = {}
trajectory_lengths = {}
all_subtask_names = set() all_subtask_names = set()
for ep_idx in episodes_df.index: for ep_idx in episodes_df.index:
subtask_names = episodes_df.loc[ep_idx, 'subtask_names'] subtask_names_ep = episodes_df.loc[ep_idx, 'subtask_names']
# Skip episodes without annotations # Skip episodes without annotations
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)): if subtask_names_ep is None or (isinstance(subtask_names_ep, float) and pd.isna(subtask_names_ep)):
continue continue
start_times = episodes_df.loc[ep_idx, 'subtask_start_times'] start_times = episodes_df.loc[ep_idx, 'subtask_start_times']
end_times = episodes_df.loc[ep_idx, 'subtask_end_times'] end_times = episodes_df.loc[ep_idx, 'subtask_end_times']
# Track unique subtask names # Track unique subtask names
all_subtask_names.update(subtask_names) all_subtask_names.update(subtask_names_ep)
# Compute durations # Compute total trajectory length T_i (sum of all subtask durations)
for i, name in enumerate(subtask_names): total_traj_length = sum(end_times[i] - start_times[i] for i in range(len(subtask_names_ep)))
if total_traj_length <= 0:
continue
# Store duration and trajectory length for each subtask occurrence
for i, name in enumerate(subtask_names_ep):
duration = end_times[i] - start_times[i] duration = end_times[i] - start_times[i]
if name not in subtask_durations:
subtask_durations[name] = [] if name not in subtask_durations_per_trajectory:
subtask_durations[name].append(duration) subtask_durations_per_trajectory[name] = []
trajectory_lengths[name] = []
subtask_durations_per_trajectory[name].append(duration)
trajectory_lengths[name].append(total_traj_length)
if not all_subtask_names: if not all_subtask_names:
logging.info("No valid subtask annotations found") logging.info("No valid subtask annotations found")
@@ -142,44 +147,17 @@ class SARMEncodingProcessorStep(ProcessorStep):
# Sort subtask names for consistent ordering # Sort subtask names for consistent ordering
self.subtask_names = sorted(list(all_subtask_names)) self.subtask_names = sorted(list(all_subtask_names))
self.config.num_stages = len(self.subtask_names) self.config.num_stages = len(self.subtask_names)
self.config.subtask_names = self.subtask_names # Store in config for reference self.config.subtask_names = self.subtask_names
# Compute average duration for each subtask # Compute temporal proportions using Paper Formula (1)
avg_durations = {} self.temporal_proportions = compute_priors(
for name in self.subtask_names: subtask_durations_per_trajectory,
if name in subtask_durations: trajectory_lengths,
avg_durations[name] = np.mean(subtask_durations[name]) self.subtask_names
else:
avg_durations[name] = 0.0
# Normalize to get proportions
total_duration = sum(avg_durations.values())
if total_duration > 0:
self.temporal_proportions = {
name: avg_durations[name] / total_duration
for name in self.subtask_names
}
else:
raise ValueError(
"Cannot compute temporal proportions: all subtask durations are zero. "
"Check that your dataset has valid subtask annotations with start/end times."
) )
# Store in config for the model to use in progress output conversion (SARM paper Eq. 4)
self.config.temporal_proportions = [self.temporal_proportions[name] for name in self.subtask_names] self.config.temporal_proportions = [self.temporal_proportions[name] for name in self.subtask_names]
logging.info(f"Computed temporal proportions for {len(self.subtask_names)} subtasks: {self.temporal_proportions}") logging.info(f"Computed temporal proportions for {len(self.subtask_names)} subtasks: {self.temporal_proportions}")
def _to_numpy_array(self, x) -> np.ndarray:
"""Convert input to a 1D numpy array."""
if isinstance(x, torch.Tensor):
arr = x.cpu().numpy()
else:
arr = np.array(x)
if arr.ndim == 0:
arr = np.array([arr.item()])
return arr
def _find_episode_for_frame(self, frame_idx: int) -> int: def _find_episode_for_frame(self, frame_idx: int) -> int:
"""Find the episode index for a given frame index.""" """Find the episode index for a given frame index."""
for ep_idx in range(len(self.dataset_meta.episodes)): for ep_idx in range(len(self.dataset_meta.episodes)):
@@ -187,14 +165,14 @@ class SARMEncodingProcessorStep(ProcessorStep):
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
if ep_start <= frame_idx < ep_end: if ep_start <= frame_idx < ep_end:
return ep_idx return ep_idx
return 0 # Fallback return 0
def _get_episode_indices(self, frame_indices: np.ndarray, episode_index) -> np.ndarray: def _get_episode_indices(self, frame_indices: np.ndarray, episode_index) -> np.ndarray:
"""Get episode indices for each frame index.""" """Get episode indices for each frame index."""
if episode_index is None: if episode_index is None:
return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices]) return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])
episode_indices = self._to_numpy_array(episode_index) episode_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(episode_index)))
# If single episode but multiple frames, compute episode for each frame # If single episode but multiple frames, compute episode for each frame
if len(episode_indices) == 1 and len(frame_indices) > 1: if len(episode_indices) == 1 and len(frame_indices) > 1:
@@ -211,22 +189,16 @@ class SARMEncodingProcessorStep(ProcessorStep):
Pattern: [ep_start, t-(7*gap), t-(6*gap), ..., t-gap, t] Pattern: [ep_start, t-(7*gap), t-(6*gap), ..., t-gap, t]
""" """
frame_gap = getattr(self.config, 'frame_gap', 1)
indices = [] indices = []
indices.append(ep_start) # First frame is the episode's initial frame
# First frame is the episode's initial frame
indices.append(ep_start)
# Remaining frames are consecutive with frame_gap spacing # Remaining frames are consecutive with frame_gap spacing
num_consecutive = num_frames - 1 num_consecutive = num_frames - 1
for i in range(num_consecutive): for i in range(num_consecutive):
offset = -(num_consecutive - 1 - i) * frame_gap offset = -(num_consecutive - 1 - i) * self.config.frame_gap
idx = max(ep_start, frame_idx + offset) idx = max(ep_start, frame_idx + offset)
indices.append(idx) indices.append(idx)
return torch.tensor(indices) return torch.tensor(indices)
def _compute_episode_metadata( def _compute_episode_metadata(
@@ -269,6 +241,14 @@ class SARMEncodingProcessorStep(ProcessorStep):
) -> tuple[int, float]: ) -> tuple[int, float]:
"""Compute stage index and cumulative progress for a single frame. """Compute stage index and cumulative progress for a single frame.
Implements SARM Paper Formula (2):
y_t = P_{k-1} + ᾱ_k × τ_t
where:
- τ_t = (t - s_k) / (e_k - s_k) is within-subtask progress
- P_{k-1} is cumulative prior (sum of previous subtask proportions)
- ᾱ_k is the temporal proportion for subtask k
Args: Args:
current_frame: Frame index relative to episode start current_frame: Frame index relative to episode start
subtask_names: List of subtask names for this episode subtask_names: List of subtask names for this episode
@@ -278,53 +258,46 @@ class SARMEncodingProcessorStep(ProcessorStep):
Returns: Returns:
Tuple of (stage_idx, cumulative_progress) Tuple of (stage_idx, cumulative_progress)
""" """
stage_idx = -1 # Get temporal proportions as list for compute_cumulative_progress
cumulative_progress = 0.0 temporal_proportions_list = [
self.temporal_proportions.get(name, 0.0) for name in self.subtask_names
]
# Find which subtask this frame belongs to # Find which subtask this frame belongs to
for j, (name, start_frame, end_frame) in enumerate(zip(subtask_names, subtask_start_frames, subtask_end_frames)): for j, (name, start_frame, end_frame) in enumerate(zip(subtask_names, subtask_start_frames, subtask_end_frames)):
if current_frame >= start_frame and current_frame <= end_frame: if current_frame >= start_frame and current_frame <= end_frame:
# Found the subtask # Found the subtask, get its global index
stage_idx = self.subtask_names.index(name) if name in self.subtask_names else 0 stage_idx = self.subtask_names.index(name) if name in self.subtask_names else 0
# Calculate within-subtask progress # Compute τ_t using utility function (Paper Formula 2)
subtask_duration = end_frame - start_frame tau = compute_tau(current_frame, start_frame, end_frame)
if subtask_duration > 0:
within_subtask_progress = (current_frame - start_frame) / subtask_duration
else:
within_subtask_progress = 1.0
# Calculate cumulative progress from completed subtasks # Compute cumulative progress using utility function (Paper Formula 2)
for k in range(j): cumulative_progress = compute_cumulative_progress_batch(
prev_name = subtask_names[k] tau, stage_idx, temporal_proportions_list
if prev_name in self.temporal_proportions: )
cumulative_progress += self.temporal_proportions[prev_name]
# Add current subtask's partial progress
if name in self.temporal_proportions:
cumulative_progress += self.temporal_proportions[name] * within_subtask_progress
return stage_idx, cumulative_progress return stage_idx, cumulative_progress
# No matching subtask found - estimate based on position # No matching subtask found
if current_frame < subtask_start_frames[0]: if current_frame < subtask_start_frames[0]:
return 0, 0.0 return 0, 0.0
elif current_frame > subtask_end_frames[-1]: elif current_frame > subtask_end_frames[-1]:
return len(self.subtask_names) - 1, 1.0 return len(self.subtask_names) - 1, 1.0
else: else:
# Between subtasks - use previous subtask's end state # Between subtasks - use previous subtask's end state (tau = 1.0)
for j in range(len(subtask_names) - 1): for j in range(len(subtask_names) - 1):
if current_frame > subtask_end_frames[j] and current_frame < subtask_start_frames[j + 1]: if current_frame > subtask_end_frames[j] and current_frame < subtask_start_frames[j + 1]:
name = subtask_names[j] name = subtask_names[j]
stage_idx = self.subtask_names.index(name) if name in self.subtask_names else j stage_idx = self.subtask_names.index(name) if name in self.subtask_names else j
# Sum up all completed subtasks
for k in range(j + 1): # Completed subtask, so tau = 1.0
prev_name = subtask_names[k] cumulative_progress = compute_cumulative_progress_batch(
if prev_name in self.temporal_proportions: 1.0, stage_idx, temporal_proportions_list
cumulative_progress += self.temporal_proportions[prev_name] )
return stage_idx, cumulative_progress return stage_idx, cumulative_progress
return 0, 0.0 # Fallback return 0, 0.0
def _compute_labels_for_sample( def _compute_labels_for_sample(
self, self,
@@ -359,13 +332,8 @@ class SARMEncodingProcessorStep(ProcessorStep):
subtask_start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames'] subtask_start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames']
subtask_end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames'] subtask_end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames']
# Get episode boundaries
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
# Get config values
frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1
# Generate labels for each frame in the sequence # Generate labels for each frame in the sequence
stage_labels = [] stage_labels = []
progress_targets = [] progress_targets = []
@@ -377,7 +345,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
else: else:
# Positions 1-8: consecutive frames with frame_gap spacing # Positions 1-8: consecutive frames with frame_gap spacing
num_consecutive = seq_len - 1 num_consecutive = seq_len - 1
offset = -(num_consecutive - i) * frame_gap offset = -(num_consecutive - i) * self.config.frame_gap
current_frame = max(0, frame_idx + offset - ep_start) current_frame = max(0, frame_idx + offset - ep_start)
@@ -388,7 +356,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
stage_labels.append(stage_idx) stage_labels.append(stage_idx)
progress_targets.append(cumulative_progress) progress_targets.append(cumulative_progress)
# Convert to tensors
stage_labels = torch.tensor(stage_labels, dtype=torch.long) stage_labels = torch.tensor(stage_labels, dtype=torch.long)
progress_targets = torch.tensor(progress_targets, dtype=torch.float32).unsqueeze(-1) progress_targets = torch.tensor(progress_targets, dtype=torch.float32).unsqueeze(-1)
@@ -411,7 +378,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1 is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1
# Normalize inputs to numpy arrays # Normalize inputs to numpy arrays
frame_indices = self._to_numpy_array(frame_index) frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
episode_indices = self._get_episode_indices(frame_indices, episode_index) episode_indices = self._get_episode_indices(frame_indices, episode_index)
# Determine sequence length # Determine sequence length
@@ -422,7 +389,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
episodes_df = self.dataset_meta.episodes.to_pandas() episodes_df = self.dataset_meta.episodes.to_pandas()
# Process all samples
all_stage_labels = [] all_stage_labels = []
all_progress_targets = [] all_progress_targets = []
@@ -450,7 +416,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
if not isinstance(observation, dict): if not isinstance(observation, dict):
raise ValueError("Observation must be a dictionary") raise ValueError("Observation must be a dictionary")
# 1. Encode images with CLIP # Encode images with CLIP
image = observation.get(self.image_key) image = observation.get(self.image_key)
if image is None: if image is None:
raise ValueError(f"Image not found in observation for key: {self.image_key}") raise ValueError(f"Image not found in observation for key: {self.image_key}")
@@ -460,27 +426,27 @@ class SARMEncodingProcessorStep(ProcessorStep):
video_features = self._encode_images_batch(image) video_features = self._encode_images_batch(image)
observation['video_features'] = video_features observation['video_features'] = video_features
# 2. Extract and normalize joint states # Extract state and pad to max_state_dim (already normalized by NormalizerProcessorStep)
state_key = self.config.state_key
state_data = observation.get(state_key)
if state_data is None:
state_data = observation.get("state") or observation.get("observation.state") state_data = observation.get("state") or observation.get("observation.state")
if state_data is None: if state_data is None:
raise ValueError("State data not found in observation (expected 'state' or 'observation.state')") raise ValueError(f"State data not found in observation (expected '{state_key}', 'state', or 'observation.state')")
if isinstance(state_data, torch.Tensor): if isinstance(state_data, torch.Tensor):
state_data = state_data.cpu().numpy() state_tensor = state_data.float()
else:
state_tensor = torch.tensor(state_data, dtype=torch.float32)
state_key = "state" if "state" in observation else "observation.state" # Pad state
if self.dataset_stats and state_key in self.dataset_stats: observation['state_features'] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)
mean = self.dataset_stats[state_key]['mean']
std = self.dataset_stats[state_key]['std']
state_data = (state_data - mean) / (std + 1e-8)
observation['state_features'] = torch.tensor(state_data, dtype=torch.float32) # Encode text with CLIP
# 3. Encode text with CLIP (per SARM paper A.4)
batch_size = video_features.shape[0] batch_size = video_features.shape[0]
observation['text_features'] = self._encode_text_clip(self.task_description, batch_size) observation['text_features'] = self._encode_text_clip(self.task_description, batch_size)
# 4. Extract frame/episode indices from complementary data # Extract frame/episode indices from complementary data
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if not isinstance(comp_data, dict): if not isinstance(comp_data, dict):
raise ValueError("COMPLEMENTARY_DATA must be a dictionary") raise ValueError("COMPLEMENTARY_DATA must be a dictionary")
@@ -493,10 +459,10 @@ class SARMEncodingProcessorStep(ProcessorStep):
if episode_index is None: if episode_index is None:
raise ValueError("Episode index ('episode_index') not found in COMPLEMENTARY_DATA") raise ValueError("Episode index ('episode_index') not found in COMPLEMENTARY_DATA")
# 5. Compute episode metadata if dataset_meta is available # Compute episode metadata if dataset_meta is available
if self.dataset_meta is not None: if self.dataset_meta is not None:
is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1 is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1
frame_indices = self._to_numpy_array(frame_index) frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
episode_indices = self._get_episode_indices(frame_indices, episode_index) episode_indices = self._get_episode_indices(frame_indices, episode_index)
# Determine number of frames from video features # Determine number of frames from video features
@@ -512,7 +478,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
observation['remaining_length'] = remaining observation['remaining_length'] = remaining
observation['episode_length'] = ep_lengths observation['episode_length'] = ep_lengths
# 6. Generate stage labels and progress targets from subtask annotations # Generate stage labels and progress targets from subtask annotations
if self.temporal_proportions is not None and self.dataset_meta is not None: if self.temporal_proportions is not None and self.dataset_meta is not None:
stage_labels, progress_targets = self._generate_stage_and_progress_labels( stage_labels, progress_targets = self._generate_stage_and_progress_labels(
frame_index, episode_index, video_features frame_index, episode_index, video_features
@@ -539,14 +505,10 @@ class SARMEncodingProcessorStep(ProcessorStep):
# Check if we have temporal dimension # Check if we have temporal dimension
has_temporal = len(images.shape) == 5 has_temporal = len(images.shape) == 5
if has_temporal: if has_temporal: # Shape: (B, T, C, H, W)
# Shape: (B, T, C, H, W)
batch_size, seq_length = images.shape[0], images.shape[1] batch_size, seq_length = images.shape[0], images.shape[1]
# Reshape to (B*T, C, H, W) to process all frames at once
images = images.reshape(batch_size * seq_length, *images.shape[2:]) images = images.reshape(batch_size * seq_length, *images.shape[2:])
elif len(images.shape) == 4: elif len(images.shape) == 4: # Shape: (B, C, H, W)
# Shape: (B, C, H, W)
batch_size = images.shape[0] batch_size = images.shape[0]
seq_length = 1 seq_length = 1
else: else:
@@ -608,7 +570,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
Returns: Returns:
Encoded text features with shape (B, 512) Encoded text features with shape (B, 512)
""" """
# Use CLIP's tokenizer directly for text (avoids image processor validation issues) # Use CLIP's tokenizer directly for text
tokenizer = self.clip_processor.tokenizer tokenizer = self.clip_processor.tokenizer
inputs = tokenizer([text], return_tensors="pt", padding=True, truncation=True) inputs = tokenizer([text], return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(self.device) for k, v in inputs.items()} inputs = {k: v.to(self.device) for k, v in inputs.items()}
@@ -629,7 +591,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Add encoded features to the observation features.""" """Add encoded features to the observation features."""
# Add the encoded features # Add the encoded features (state uses max_state_dim, padded with zeros)
features[PipelineFeatureType.OBSERVATION]['video_features'] = PolicyFeature( features[PipelineFeatureType.OBSERVATION]['video_features'] = PolicyFeature(
type=FeatureType.VISUAL, type=FeatureType.VISUAL,
shape=(self.config.num_frames, self.config.image_dim) shape=(self.config.num_frames, self.config.image_dim)
@@ -640,7 +602,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
) )
features[PipelineFeatureType.OBSERVATION]['state_features'] = PolicyFeature( features[PipelineFeatureType.OBSERVATION]['state_features'] = PolicyFeature(
type=FeatureType.STATE, type=FeatureType.STATE,
shape=(self.config.num_frames, self.config.state_dim) shape=(self.config.num_frames, self.config.max_state_dim)
) )
return features return features
@@ -660,11 +622,16 @@ def make_sarm_pre_post_processors(
to process both RGB image sequences and task descriptions." to process both RGB image sequences and task descriptions."
The pre-processing pipeline: The pre-processing pipeline:
1. Encodes images with CLIP (512-dim) 1. Adds batch dimension
2. Encodes text with CLIP (512-dim) 2. Normalizes observation.state using NormalizerProcessorStep (MEAN_STD)
3. Normalizes joint states 3. SARMEncodingProcessorStep:
4. Adds batch dimension - Encodes images with CLIP (512-dim)
5. Moves data to device - Pads states to max_state_dim
- Encodes text with CLIP (512-dim)
4. Moves data to device
The post-processing pipeline:
1. Moves data to CPU (no unnormalization - outputs are rewards)
Args: Args:
config: SARM configuration config: SARM configuration
@@ -676,6 +643,11 @@ def make_sarm_pre_post_processors(
""" """
input_steps = [ input_steps = [
AddBatchDimensionProcessorStep(), AddBatchDimensionProcessorStep(),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
SARMEncodingProcessorStep( SARMEncodingProcessorStep(
config=config, config=config,
dataset_meta=dataset_meta, dataset_meta=dataset_meta,
+249
View File
@@ -0,0 +1,249 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility functions for SARM progress label computation.
Implements formulas from the SARM paper:
- Formula (1): Compute dataset-level temporal proportions (priors) ᾱ_k
- Formula (2): Compute normalized progress targets y_t = P_{k-1} + ᾱ_k × τ_t
"""
import numpy as np
import torch
import torch.nn.functional as F
from typing import Sequence
def compute_priors(
    subtask_durations_per_trajectory: dict[str, list[float]],
    trajectory_lengths: dict[str, list[float]],
    subtask_names: list[str],
) -> dict[str, float]:
    """
    Compute dataset-level temporal proportions (priors) for each subtask.

    Implements SARM Paper Formula (1):
        ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)

    where:
        - L_{i,k} is the length of subtask k in trajectory i
        - T_i is the total length of trajectory i

    The per-trajectory PROPORTION of each subtask is averaged, so every
    trajectory carries equal weight regardless of its absolute length.

    NOTE: the average for a subtask is taken over the occurrences listed for
    that subtask (entries with a positive trajectory length), not over a
    global trajectory count; the final renormalization then forces the
    proportions to sum to 1.

    Args:
        subtask_durations_per_trajectory: Dict mapping subtask name to the list
            of durations L_{i,k}, one entry per occurrence of that subtask.
        trajectory_lengths: Dict mapping subtask name to the list of trajectory
            lengths T_i, aligned index-by-index with the durations list.
        subtask_names: Ordered list of subtask names to compute priors for.

    Returns:
        Dict mapping subtask name to its temporal proportion (ᾱ_k) as a
        built-in ``float``. Subtasks missing from either input dict, or with
        no usable occurrence, get 0.0.

    Raises:
        ValueError: If ``subtask_names`` is empty, if the durations and
            trajectory-length lists of a subtask differ in length, or if the
            proportions do not add up to a positive total.
    """
    if not subtask_names:
        raise ValueError("subtask_names cannot be empty")

    subtask_proportions: dict[str, float] = {}
    for name in subtask_names:
        if name not in subtask_durations_per_trajectory or name not in trajectory_lengths:
            subtask_proportions[name] = 0.0
            continue

        durations = subtask_durations_per_trajectory[name]
        traj_lengths = trajectory_lengths[name]
        if len(durations) != len(traj_lengths):
            raise ValueError(
                f"Mismatch in lengths for subtask '{name}': "
                f"{len(durations)} durations vs {len(traj_lengths)} trajectory lengths"
            )

        # L_{i,k} / T_i per occurrence; occurrences with a non-positive T_i are skipped
        ratios = [
            duration / traj_len
            for duration, traj_len in zip(durations, traj_lengths)
            if traj_len > 0
        ]
        # Cast so callers receive plain Python floats rather than numpy scalars
        subtask_proportions[name] = float(np.mean(ratios)) if ratios else 0.0

    # Renormalize so the proportions sum to 1 (floating point drift, missing subtasks)
    total = sum(subtask_proportions.values())
    if not total > 0:
        raise ValueError("Cannot compute temporal proportions: all proportions are zero. "
                         "Check that your dataset has valid subtask annotations with start/end times.")

    return {name: float(prop / total) for name, prop in subtask_proportions.items()}
def compute_tau(
current_frame: int | float,
subtask_start: int | float,
subtask_end: int | float,
) -> float:
"""
Compute within-subtask normalized time τ_t.
Implements part of SARM Paper Formula (2):
τ_t = (t - s_k) / (e_k - s_k) [0, 1]
where:
- t is the current frame
- s_k is the start frame of subtask k
- e_k is the end frame of subtask k
Args:
current_frame: Current frame index (t)
subtask_start: Start frame of the subtask (s_k)
subtask_end: End frame of the subtask (e_k)
Returns:
Within-subtask progress τ_t [0, 1]
"""
subtask_duration = subtask_end - subtask_start
if subtask_duration <= 0:
return 1.0
tau = (current_frame - subtask_start) / subtask_duration
return float(np.clip(tau, 0.0, 1.0))
def compute_cumulative_progress_batch(
tau: torch.Tensor | float,
stage_indices: torch.Tensor | int,
alpha: torch.Tensor | Sequence[float],
cumulative_prior: torch.Tensor | None = None,
) -> torch.Tensor | float:
"""
Compute cumulative normalized progress from within-subtask progress.
This function implements the core formula used in SARM for both:
**Formula 2 (Training labels):**
y_t = P_{k-1} + ᾱ_k × τ_t [0, 1]
Used to compute ground-truth progress labels from subtask annotations.
- τ_t comes from annotated frame position: τ_t = (t - s_k) / (e_k - s_k)
- k is the known subtask from annotations
**Formula 4 (Inference predictions):**
ŷ_{1:N} = P̂_{k-1, 1:N} + ᾱ_{k, 1:N} × τ̂_{1:N} [0, 1]
Used to convert model outputs to cumulative progress during inference.
- τ̂ comes from the subtask MLP head (conditioned on predicted stage)
- k = Ŝ is the predicted stage from Formula 3: Ŝ = argmax(softmax(Ψ))
The formulas are mathematically identical; only the source of inputs differs:
- Training: τ and k from annotations ground-truth labels
- Inference: τ̂ and Ŝ from model predicted progress
where:
- P_{k-1} = Σ_{j=1}^{k-1} ᾱ_j is the cumulative prior (sum of previous proportions)
- ᾱ_k is the temporal proportion for subtask k (from Formula 1)
- τ is within-subtask progress [0, 1]
This ensures:
- y at start of subtask k = P_{k-1}
- y at end of subtask k = P_k
Supports both scalar and batched tensor inputs:
- Scalar: tau (float), stage_indices (int), alpha (list/sequence)
- Batch: tau (Tensor), stage_indices (Tensor), alpha (Tensor), cumulative_prior (Tensor)
Args:
tau: Within-subtask progress τ [0, 1].
For training: computed from frame position in annotated subtask.
For inference: predicted by subtask MLP head.
Scalar float or Tensor with shape (..., 1)
stage_indices: Index of current subtask k (0-indexed).
For training: known from annotations.
For inference: predicted via argmax(stage_probs) (Formula 3).
Scalar int or Tensor with shape (...)
alpha: Temporal proportions with shape (num_stages,) or Sequence[float].
Computed from dataset annotations using Formula 1.
cumulative_prior: Optional. Cumulative priors P with shape (num_stages + 1,)
where cumulative_prior[k] = P_k = Σ_{j=1}^{k} ᾱ_j.
If None, will be computed from alpha.
Returns:
Cumulative progress y [0, 1].
Scalar float if inputs are scalar, otherwise Tensor with shape (..., 1)
"""
if not isinstance(tau, torch.Tensor):
if not alpha:
raise ValueError("alpha (temporal_proportions) cannot be empty")
if isinstance(alpha, torch.Tensor):
alpha_list = alpha.tolist()
else:
alpha_list = list(alpha)
if stage_indices < 0 or stage_indices >= len(alpha_list):
raise ValueError(
f"stage_indices {stage_indices} out of range "
f"for {len(alpha_list)} subtasks"
)
# P_{k-1} = sum of proportions for subtasks 0 to k-1
P_k_minus_1 = sum(alpha_list[:stage_indices])
# ᾱ_k = proportion for current subtask
alpha_k = alpha_list[stage_indices]
# y_t = P_{k-1} + ᾱ_k × τ_t
y_t = P_k_minus_1 + alpha_k * tau
return float(np.clip(y_t, 0.0, 1.0))
if not isinstance(alpha, torch.Tensor):
alpha = torch.tensor(alpha, dtype=torch.float32)
# Compute cumulative_prior if not provided
if cumulative_prior is None:
cumulative_prior = torch.zeros(len(alpha) + 1, dtype=alpha.dtype, device=alpha.device)
cumulative_prior[1:] = torch.cumsum(alpha, dim=0)
# P_{k-1} for each predicted stage
P_k_minus_1 = cumulative_prior[stage_indices]
# ᾱ_k for each predicted stage
alpha_k = alpha[stage_indices]
# ŷ = P_{k-1} + ᾱ_k × τ̂
progress = P_k_minus_1.unsqueeze(-1) + alpha_k.unsqueeze(-1) * tau
return progress
def pad_state_to_max_dim(state: torch.Tensor, max_state_dim: int) -> torch.Tensor:
    """Force the last dimension of ``state`` to ``max_state_dim``.

    A shorter last dimension is zero-padded on the right; a longer one is
    truncated to its first ``max_state_dim`` entries.
    """
    missing = max_state_dim - state.shape[-1]
    if missing <= 0:
        # Already wide enough: keep only the leading max_state_dim entries
        return state[..., :max_state_dim]

    # F.pad takes (left, right) amounts for the last dimension
    return F.pad(state, (0, missing), mode='constant', value=0)
+392
View File
@@ -0,0 +1,392 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for SARM utility functions.
Tests the implementation of SARM paper formulas:
- Formula (1): compute_priors - dataset-level temporal proportions
- Formula (2): compute_tau, compute_cumulative_progress_batch - progress labels
"""
import pytest
import numpy as np
import torch
from lerobot.policies.sarm.sarm_utils import (
compute_priors,
compute_tau,
compute_cumulative_progress_batch,
)
class TestComputePriors:
    """Checks for ``compute_priors`` (SARM paper, Formula 1).

    Formula: ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)

    Each trajectory contributes the *fraction* of its length spent in subtask k,
    so every trajectory carries the same weight regardless of its absolute length.
    """

    # Absolute tolerance shared by all floating-point comparisons below.
    TOL = 1e-6

    def test_basic_two_trajectories_equal_proportions(self):
        """Two trajectories that both split 50/50 give a prior of 0.5 per subtask."""
        names = ['subtask1', 'subtask2']
        # Each subtask covers half of its trajectory: 50/100 and 100/200.
        durations = {name: [50, 100] for name in names}
        lengths = {name: [100, 200] for name in names}

        priors = compute_priors(durations, lengths, names)

        for name in names:
            assert abs(priors[name] - 0.5) < self.TOL

    def test_paper_example_different_from_avg_durations(self):
        """The paper formula averages per-trajectory proportions, not raw durations.

        With the data below the two approaches disagree:
        - Paper formula: mean of (L_i,k / T_i) -> 0.5 and 0.5
        - Naive approach: mean(L_i,k) / sum(mean(L_i,j)) -> 60/150 = 0.4 and 90/150 = 0.6
        """
        names = ['subtask1', 'subtask2']
        # Episode 1 (T=100): 80 + 20 frames  -> proportions 0.8 / 0.2
        # Episode 2 (T=200): 40 + 160 frames -> proportions 0.2 / 0.8
        durations = {
            'subtask1': [80, 40],
            'subtask2': [20, 160],
        }
        lengths = {name: [100, 200] for name in names}

        priors = compute_priors(durations, lengths, names)

        # ᾱ_1 = (0.8 + 0.2) / 2 = 0.5
        # ᾱ_2 = (0.2 + 0.8) / 2 = 0.5
        for name in names:
            assert abs(priors[name] - 0.5) < self.TOL

    def test_single_trajectory(self):
        """With one trajectory the priors are simply that trajectory's proportions."""
        expected = {'reach': 0.3, 'grasp': 0.2, 'lift': 0.5}
        durations = {
            'reach': [30],
            'grasp': [20],
            'lift': [50],
        }
        lengths = {name: [100] for name in durations}
        names = sorted(durations)  # ['grasp', 'lift', 'reach']

        priors = compute_priors(durations, lengths, names)

        for name, proportion in expected.items():
            assert abs(priors[name] - proportion) < self.TOL

    def test_sum_to_one(self):
        """The returned proportions add up to 1."""
        durations = {
            'a': [10, 20, 30],
            'b': [40, 50, 60],
            'c': [50, 30, 10],
        }
        names = list(durations)  # ['a', 'b', 'c']
        lengths = {name: [100, 100, 100] for name in names}

        priors = compute_priors(durations, lengths, names)

        assert abs(sum(priors.values()) - 1.0) < self.TOL

    def test_empty_subtask_names_raises(self):
        """An empty ``subtask_names`` list is rejected with a ValueError."""
        expect_error = pytest.raises(ValueError, match="subtask_names cannot be empty")
        with expect_error:
            compute_priors({}, {}, [])

    def test_missing_subtask_gets_zero_before_normalization(self):
        """A subtask present in only some trajectories still yields a valid prior."""
        names = ['subtask1', 'subtask2']
        # subtask1 occurs in both trajectories, subtask2 only in the first one.
        durations = {
            'subtask1': [50, 100],
            'subtask2': [50],
        }
        lengths = {
            'subtask1': [100, 200],
            'subtask2': [100],
        }

        priors = compute_priors(durations, lengths, names)

        # Only positivity and normalization are asserted here; the exact split
        # between the two subtasks is not pinned by this test.
        for name in names:
            assert priors[name] > 0
        assert abs(sum(priors.values()) - 1.0) < self.TOL
class TestComputeTau:
    """Checks for ``compute_tau``, the progress within a single subtask.

    Formula: τ_t = (t - s_k) / (e_k - s_k) ∈ [0, 1]
    """

    # Absolute tolerance for comparisons that are not expected to be exact.
    TOL = 1e-6

    def test_at_start(self):
        """The first frame of a subtask maps to τ = 0."""
        start, end = 10, 50
        result = compute_tau(current_frame=start, subtask_start=start, subtask_end=end)
        assert result == 0.0

    def test_at_end(self):
        """The last frame of a subtask maps to τ = 1."""
        start, end = 10, 50
        result = compute_tau(current_frame=end, subtask_start=start, subtask_end=end)
        assert result == 1.0

    def test_at_middle(self):
        """The midpoint frame of a subtask maps to τ = 0.5."""
        result = compute_tau(current_frame=30, subtask_start=10, subtask_end=50)
        assert abs(result - 0.5) < self.TOL

    def test_quarter_progress(self):
        """Frame 20 of a subtask spanning [0, 80] maps to τ = 0.25."""
        result = compute_tau(current_frame=20, subtask_start=0, subtask_end=80)
        assert abs(result - 0.25) < self.TOL

    def test_zero_duration_subtask(self):
        """A subtask whose start equals its end reports τ = 1.0."""
        frame = 10
        result = compute_tau(current_frame=frame, subtask_start=frame, subtask_end=frame)
        assert result == 1.0

    def test_clamps_below_zero(self):
        """A frame before the subtask start is clamped to τ = 0."""
        result = compute_tau(current_frame=5, subtask_start=10, subtask_end=50)
        assert result == 0.0

    def test_clamps_above_one(self):
        """A frame after the subtask end is clamped to τ = 1."""
        result = compute_tau(current_frame=60, subtask_start=10, subtask_end=50)
        assert result == 1.0

    def test_float_inputs(self):
        """Fractional frame indices (e.g. produced by interpolation) are supported."""
        frame, start, end = 25.5, 10.0, 50.0
        result = compute_tau(current_frame=frame, subtask_start=start, subtask_end=end)
        expected = (frame - start) / (end - start)
        assert abs(result - expected) < self.TOL
class TestComputeCumulativeProgressBatchScalar:
    """Checks for ``compute_cumulative_progress_batch`` called with scalar inputs.

    Formula for the normalized progress: y_t = P_{k-1} + ᾱ_k × τ_t ∈ [0, 1]
    """

    # Absolute tolerance for floating-point comparisons.
    TOL = 1e-6

    def test_first_subtask_start(self):
        """Progress is exactly 0 at the very start of the first subtask."""
        alpha_bar = [0.3, 0.5, 0.2]
        progress = compute_cumulative_progress_batch(tau=0.0, stage_indices=0, alpha=alpha_bar)
        assert progress == 0.0

    def test_first_subtask_end(self):
        """Finishing the first subtask yields ᾱ_1."""
        alpha_bar = [0.3, 0.5, 0.2]
        progress = compute_cumulative_progress_batch(tau=1.0, stage_indices=0, alpha=alpha_bar)
        assert abs(progress - 0.3) < self.TOL

    def test_second_subtask_start(self):
        """Entering the second subtask yields P_1."""
        alpha_bar = [0.3, 0.5, 0.2]
        progress = compute_cumulative_progress_batch(tau=0.0, stage_indices=1, alpha=alpha_bar)
        assert abs(progress - 0.3) < self.TOL

    def test_second_subtask_end(self):
        """Finishing the second subtask yields P_2 = 0.3 + 0.5."""
        alpha_bar = [0.3, 0.5, 0.2]
        progress = compute_cumulative_progress_batch(tau=1.0, stage_indices=1, alpha=alpha_bar)
        assert abs(progress - 0.8) < self.TOL

    def test_third_subtask_end(self):
        """Finishing the last subtask yields full progress, 1.0."""
        alpha_bar = [0.3, 0.5, 0.2]
        progress = compute_cumulative_progress_batch(tau=1.0, stage_indices=2, alpha=alpha_bar)
        assert abs(progress - 1.0) < self.TOL

    def test_midpoint_of_subtask(self):
        """Halfway through a subtask adds half of that subtask's proportion."""
        alpha_bar = [0.4, 0.6]

        # Subtask 1 at τ=0.5: y = P_0 + ᾱ_1 × 0.5 = 0 + 0.4 × 0.5 = 0.2
        first_half = compute_cumulative_progress_batch(tau=0.5, stage_indices=0, alpha=alpha_bar)
        assert abs(first_half - 0.2) < self.TOL

        # Subtask 2 at τ=0.5: y = P_1 + ᾱ_2 × 0.5 = 0.4 + 0.6 × 0.5 = 0.7
        second_half = compute_cumulative_progress_batch(tau=0.5, stage_indices=1, alpha=alpha_bar)
        assert abs(second_half - 0.7) < self.TOL

    def test_uniform_proportions(self):
        """Four equal subtasks end at progress 0.25, 0.5, 0.75 and 1.0."""
        alpha_bar = [0.25, 0.25, 0.25, 0.25]
        for stage, expected in enumerate([0.25, 0.5, 0.75, 1.0]):
            progress = compute_cumulative_progress_batch(tau=1.0, stage_indices=stage, alpha=alpha_bar)
            assert abs(progress - expected) < self.TOL
class TestComputeCumulativeProgressBatchTensor:
    """Checks for ``compute_cumulative_progress_batch`` called with tensor inputs (batched version)."""

    # Absolute tolerance for floating-point comparisons.
    TOL = 1e-6

    def test_tensor_matches_scalar_version(self):
        """A single-element tensor call agrees with the equivalent scalar call."""
        proportions = [0.3, 0.5, 0.2]
        alpha = torch.tensor(proportions, dtype=torch.float32)
        # Cumulative prior [0, P_1, P_2, P_3]: a leading zero followed by the running sum.
        cumulative = torch.cat([torch.zeros(1, dtype=torch.float32), torch.cumsum(alpha, dim=0)])

        # (tau, stage index) pairs to compare.
        cases = [
            (0.0, 0),  # start of subtask 0
            (1.0, 0),  # end of subtask 0
            (0.0, 1),  # start of subtask 1
            (0.5, 1),  # middle of subtask 1
            (1.0, 2),  # end of subtask 2
        ]

        for tau_value, stage in cases:
            reference = compute_cumulative_progress_batch(tau_value, stage, proportions)

            tau = torch.tensor([[[tau_value]]])  # shape (1, 1, 1)
            stages = torch.tensor([[stage]])  # shape (1, 1)
            batched = compute_cumulative_progress_batch(tau, stages, alpha, cumulative)

            assert abs(batched[0, 0, 0].item() - reference) < self.TOL

    def test_batch_processing(self):
        """Two samples of three steps each are mapped element-wise."""
        alpha = torch.tensor([0.4, 0.6], dtype=torch.float32)
        cumulative = torch.cat([torch.zeros(1, dtype=torch.float32), torch.cumsum(alpha, dim=0)])

        # Both samples sweep tau over 0, 0.5, 1.0 -> tau has shape (2, 3, 1).
        tau_steps = [[0.0], [0.5], [1.0]]
        tau = torch.tensor([tau_steps, tau_steps])
        # Sample 0 stays in subtask 0, sample 1 stays in subtask 1.
        stages = torch.tensor([[0] * 3, [1] * 3])

        progress = compute_cumulative_progress_batch(tau, stages, alpha, cumulative)

        expected = [
            [0.0, 0.2, 0.4],  # subtask 0: 0 + 0.4 × tau
            [0.4, 0.7, 1.0],  # subtask 1: 0.4 + 0.6 × tau
        ]
        for sample, row in enumerate(expected):
            for step, want in enumerate(row):
                assert abs(progress[sample, step, 0].item() - want) < self.TOL

    def test_auto_compute_cumulative_prior(self):
        """Omitting ``cumulative_prior`` still gives the correct progress."""
        alpha = torch.tensor([0.3, 0.5, 0.2], dtype=torch.float32)
        tau = torch.tensor([[[0.5]]])
        stages = torch.tensor([[1]])

        # No cumulative prior is passed; the function has to derive it from alpha.
        progress = compute_cumulative_progress_batch(tau, stages, alpha)

        # P_1 + ᾱ_2 × 0.5 = 0.3 + 0.5 × 0.5 = 0.55
        assert abs(progress[0, 0, 0].item() - 0.55) < self.TOL
class TestEndToEndProgressLabeling:
    """End-to-end checks combining ``compute_tau`` with ``compute_cumulative_progress_batch``."""

    # Absolute tolerance for floating-point comparisons.
    TOL = 1e-6

    def test_consistent_semantic_meaning(self):
        """Completing a subtask maps to the same progress in every trajectory.

        "End of subtask 1" must correspond to one progress value whether the
        trajectory reached it quickly or slowly.
        """
        alpha_bar = [0.3, 0.5, 0.2]

        # Fast trajectory: subtask 1 spans frames 0-30 (of 100), so tau = 1.0 at frame 30.
        fast = compute_cumulative_progress_batch(compute_tau(30, 0, 30), 0, alpha_bar)
        # Slow trajectory: subtask 1 spans frames 0-90 (of 300), so tau = 1.0 at frame 90.
        slow = compute_cumulative_progress_batch(compute_tau(90, 0, 90), 0, alpha_bar)

        # Both land on 0.3, the end of subtask 1.
        assert abs(fast - slow) < self.TOL
        assert abs(fast - 0.3) < self.TOL

    def test_monotonic_within_subtask(self):
        """Progress strictly increases as tau sweeps from 0 to 1 inside a subtask."""
        alpha_bar = [0.4, 0.6]
        last_progress = -1  # sentinel below any progress value in [0, 1]
        for tau in np.linspace(0, 1, 11):
            progress = compute_cumulative_progress_batch(tau, 0, alpha_bar)
            # The tau == 0 clause can only apply to the first sample.
            assert progress > last_progress or (tau == 0 and progress == 0)
            last_progress = progress

    def test_continuous_across_subtasks(self):
        """Progress does not jump when crossing from one subtask to the next."""
        alpha_bar = [0.3, 0.5, 0.2]

        # Boundary 0 -> 1: end of subtask 0 (tau=1) vs. start of subtask 1 (tau=0); both P_1 = 0.3.
        end_of_0 = compute_cumulative_progress_batch(1.0, 0, alpha_bar)
        start_of_1 = compute_cumulative_progress_batch(0.0, 1, alpha_bar)
        assert abs(end_of_0 - start_of_1) < self.TOL

        # Boundary 1 -> 2: end of subtask 1 vs. start of subtask 2; both P_2 = 0.8.
        end_of_1 = compute_cumulative_progress_batch(1.0, 1, alpha_bar)
        start_of_2 = compute_cumulative_progress_batch(0.0, 2, alpha_bar)
        assert abs(end_of_1 - start_of_2) < self.TOL