diff --git a/src/lerobot/datasets/temporal_sampler.py b/src/lerobot/datasets/temporal_sampler.py index de07942b2..895e96295 100644 --- a/src/lerobot/datasets/temporal_sampler.py +++ b/src/lerobot/datasets/temporal_sampler.py @@ -64,9 +64,7 @@ class SARMTemporalSampler(Sampler): self.shuffle = shuffle self.samples_per_epoch = samples_per_epoch - # Minimum frames needed for SARM pattern: - # 8 consecutive frames with frame_gap spacing = 7 * frame_gap + 1 - # (Plus the initial frame which is always available) + # Minimum frames needed for SARM pattern: 8 consecutive frames with frame_gap spacing = 7 * frame_gap + 1 self.min_frames_needed = 7 * frame_gap + 1 if seed is not None: @@ -138,7 +136,3 @@ class SARMTemporalSampler(Sampler): for i in range(self.samples_per_epoch): idx = i % len(self.all_valid_positions) yield int(self.all_valid_positions[idx]) - - -# Backwards compatibility alias -TemporalSequenceSampler = SARMTemporalSampler diff --git a/src/lerobot/policies/sarm/__init__.py b/src/lerobot/policies/sarm/__init__.py index c936e1632..36d896cd5 100644 --- a/src/lerobot/policies/sarm/__init__.py +++ b/src/lerobot/policies/sarm/__init__.py @@ -18,7 +18,6 @@ from lerobot.policies.sarm.configuration_sarm import SARMConfig from lerobot.policies.sarm.modeling_sarm import ( SARMRewardModel, SARMTransformer, - compute_stage_loss, ) from lerobot.policies.sarm.processor_sarm import ( SARMEncodingProcessorStep, @@ -29,7 +28,6 @@ __all__ = [ "SARMConfig", "SARMRewardModel", "SARMTransformer", - "compute_stage_loss", "SARMEncodingProcessorStep", "make_sarm_pre_post_processors", ] diff --git a/src/lerobot/policies/sarm/configuration_sarm.py b/src/lerobot/policies/sarm/configuration_sarm.py index 930cb4010..16ba5d6da 100644 --- a/src/lerobot/policies/sarm/configuration_sarm.py +++ b/src/lerobot/policies/sarm/configuration_sarm.py @@ -17,7 +17,7 @@ from dataclasses import dataclass, field from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import PolicyFeature, FeatureType +from lerobot.configs.types import PolicyFeature, FeatureType, NormalizationMode from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig @@ -27,63 +27,83 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig class SARMConfig(PreTrainedConfig): """Configuration class for SARM (Stage-Aware Reward Modeling)""" - # Visual encoding parameters - image_dim: int = 512 # CLIP embedding dimension + # CLIP encoding parameters + image_dim: int = 512 + text_dim: int = 512 num_frames: int = 9 # 1 initial + 8 consecutive frames frame_gap: int = 30 # Frame gap between consecutive frames (at 30 fps = 1 second) - # Text encoding parameters (CLIP text encoder output dimension) - text_dim: int = 512 - - # Joint state parameters - state_dim: int | None = None # Auto-detected from dataset if None - # Architecture parameters hidden_dim: int = 768 num_heads: int = 12 num_layers: int = 8 - num_stages: int = 5 # Number of task stages for classification (auto-updated from annotations if available) + max_state_dim: int = 32 + num_stages: int = 5 # Number of task stages (auto-updated from annotations if available) subtask_names: list | None = None # List of subtask names (auto-populated from annotations) temporal_proportions: list | None = None # Temporal proportions for each stage (auto-computed from annotations) - - # Temporal parameters max_length: int = num_frames # Maximum video sequence length (matches num_frames) use_temporal_sampler: bool = True # Always enable temporal sequence loading # Training parameters batch_size: int = 64 clip_batch_size: int = 64 # Batch size for CLIP encoding - gradient_checkpointing: bool = False # Enable gradient checkpointing - dropout: float = 0.1 # Dropout rate + dropout: float = 0.1 stage_loss_weight: float = 1.0 # Weight for stage classification loss when using subtask annotations pretrained_model_path: str | None = None - device: str | None = None # Processor settings image_key: str = "observation.images.top" # Key for image used from the dataset task_description: str = "perform the task" # Default task description - # Video_features and text_features are generated by the processor from raw images/text, we don't declare them as VISUAL/LANGUAGE here to avoid validation errors - input_features: dict = field(default_factory=lambda: { - "state_features": PolicyFeature(shape=(9, 14), type=FeatureType.STATE) # Example: 7 DOF × 2 arms - }) + # State key in the dataset (for normalization) + state_key: str = "observation.state" + + # Populated by the processor (video_features, state_features, text_features) + input_features: dict = field(default_factory=lambda: {}) + + # Output features output_features: dict = field(default_factory=lambda: { - "stage": PolicyFeature(shape=(1,), type=FeatureType.REWARD), - "progress": PolicyFeature(shape=(1,), type=FeatureType.REWARD) + "stage": PolicyFeature(shape=(9, 5), type=FeatureType.REWARD), + "progress": PolicyFeature(shape=(9, 1), type=FeatureType.REWARD), }) + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MEAN_STD, + "LANGUAGE": NormalizationMode.IDENTITY, + "REWARD": NormalizationMode.IDENTITY, + } + ) def __post_init__(self): super().__post_init__() - - # Add the image_key, the processor will transform this into video_features - if self.image_key and self.image_key not in self.input_features: + + # Add the image_key as VISUAL (this is the raw image from dataset) + if self.image_key: self.input_features[self.image_key] = PolicyFeature( shape=(480, 640, 3), type=FeatureType.VISUAL ) + # Add state_key as STATE (raw state from dataset, will be padded to max_state_dim) + self.input_features[self.state_key] = PolicyFeature( + shape=(self.max_state_dim,), # Single frame state, temporal sampling handles sequence + type=FeatureType.STATE + ) + + # Update output features with actual dimensions + self.output_features["stage"] = PolicyFeature( + shape=(self.num_frames, self.num_stages), + type=FeatureType.REWARD + ) + self.output_features["progress"] = PolicyFeature( + shape=(self.num_frames, 1), + type=FeatureType.REWARD + ) + # Validate configuration if self.hidden_dim % self.num_heads != 0: raise ValueError( @@ -95,9 +115,6 @@ class SARMConfig(PreTrainedConfig): f"max_length ({self.max_length}) must equal num_frames ({self.num_frames})" ) - if self.dropout < 0 or self.dropout >= 1: - raise ValueError(f"dropout must be in [0, 1), got {self.dropout}") - if self.num_stages < 2: raise ValueError(f"num_stages must be at least 2, got {self.num_stages}") @@ -139,11 +156,10 @@ class SARMConfig(PreTrainedConfig): Returns: 9 delta indices: [-1_000_000, -(7*gap), -(6*gap), ..., -gap, 0] """ - # First delta: large negative to always clamp to episode start (frame 0) initial_frame_delta = -1_000_000 - # Remaining 8 deltas: consecutive frames with frame_gap spacing - num_consecutive = self.num_frames - 1 # 8 frames + # Remaining consecutive frames with frame_gap spacing + num_consecutive = self.num_frames - 1 consecutive_deltas = list(range(-self.frame_gap * (num_consecutive - 1), 1, self.frame_gap)) return [initial_frame_delta] + consecutive_deltas diff --git a/src/lerobot/policies/sarm/modeling_sarm.py b/src/lerobot/policies/sarm/modeling_sarm.py index 47e63ec09..d00e2aab1 100644 --- a/src/lerobot/policies/sarm/modeling_sarm.py +++ b/src/lerobot/policies/sarm/modeling_sarm.py @@ -15,7 +15,7 @@ # limitations under the License. import logging -from typing import List, Union, Dict, Optional +from typing import List, Union, Optional import random import numpy as np @@ -28,9 +28,12 @@ from transformers import CLIPModel, CLIPProcessor from torch import Tensor from lerobot.policies.sarm.configuration_sarm import SARMConfig +from lerobot.policies.sarm.sarm_utils import compute_priors, compute_cumulative_progress_batch, pad_state_to_max_dim from lerobot.policies.pretrained import PreTrainedPolicy + + class SARMTransformer(nn.Module): """ SARM Transformer model for stage-aware reward prediction. @@ -45,8 +48,8 @@ class SARMTransformer(nn.Module): def __init__( self, video_dim: int = 512, - text_dim: int = 512, # CLIP text encoder output dimension (per SARM paper A.4) - state_dim: int = 14, + text_dim: int = 512, + max_state_dim: int = 32, hidden_dim: int = 768, num_heads: int = 12, num_layers: int = 8, @@ -59,9 +62,8 @@ class SARMTransformer(nn.Module): self.hidden_dim = hidden_dim self.max_length = max_length self.num_stages = num_stages + self.max_state_dim = max_state_dim - # Store temporal proportions for progress conversion (Paper Eq. 4) - # ŷ = P_{k-1} + ᾱ_k × τ̂ if temporal_proportions is None: raise ValueError( "temporal_proportions is required for SARM. " @@ -77,15 +79,13 @@ class SARMTransformer(nn.Module): self.register_buffer('alpha', alpha) self.register_buffer('cumulative_prior', cumulative) - # Project video, text, and state to same dimension self.video_proj = nn.Linear(video_dim, hidden_dim) self.text_proj = nn.Linear(text_dim, hidden_dim) - self.state_proj = nn.Linear(state_dim, hidden_dim) + self.state_proj = nn.Linear(max_state_dim, hidden_dim) # Position embedding only for the first frame self.first_pos_embed = nn.Parameter(torch.randn(1, hidden_dim)) - # Transformer encoder encoder_layer = nn.TransformerEncoderLayer( d_model=hidden_dim, nhead=num_heads, @@ -96,7 +96,6 @@ class SARMTransformer(nn.Module): self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) # Stage estimator head (classification) - # Paper A.4: "2 layers with hidden dimension of 512" self.stage_head = nn.Sequential( nn.Linear(hidden_dim, 512), nn.LayerNorm(512), @@ -106,8 +105,6 @@ class SARMTransformer(nn.Module): ) # Subtask estimator head (regression, conditioned on stage) - # Takes concatenated [features, stage_embedding] - # Paper A.4: "2 layers with hidden dimension of 512" self.stage_embedding = nn.Embedding(num_stages, hidden_dim // 4) subtask_input_dim = hidden_dim + hidden_dim // 4 self.subtask_head = nn.Sequential( @@ -119,13 +116,13 @@ class SARMTransformer(nn.Module): nn.Sigmoid() ) - # Attention mask for causal self-attention + # Attention mask self.register_buffer("attention_mask", None, persistent=False) def _get_attention_mask(self, seq_length: int, device: torch.device) -> torch.Tensor: """Generate or retrieve cached causal attention mask.""" if self.attention_mask is None or self.attention_mask.shape[0] != seq_length: - # Create causal mask (upper triangular with -inf) + # Create causal mask mask = nn.Transformer.generate_square_subsequent_mask(seq_length, device=device) self.attention_mask = mask return self.attention_mask @@ -149,15 +146,17 @@ class SARMTransformer(nn.Module): - Stage logits for each frame (batch_size, seq_len, num_stages) - Stage probabilities (batch_size, seq_len, num_stages) - Progress predictions for each frame (batch_size, seq_len, 1) - """ - batch_size = video_frames.shape[0] - + """ # Project inputs to common dimension video_embed = self.video_proj(video_frames) # [batch_size, seq_len, hidden_dim] text_embed = self.text_proj(text_embed).unsqueeze(1) # [batch_size, 1, hidden_dim] - state_embed = self.state_proj(state_features) # [batch_size, seq_len, hidden_dim] - # Fuse video and state features (simple addition) + # Pad state features to max_state_dim before projection + state_features_padded = pad_state_to_max_dim(state_features, self.max_state_dim) + + state_embed = self.state_proj(state_features_padded) # [batch_size, seq_len, hidden_dim] + + # Fuse video and state features video_embed = video_embed + state_embed # Add positional embedding to first video frame @@ -173,7 +172,7 @@ class SARMTransformer(nn.Module): # Pass through transformer with causal masking transformed = self.transformer(sequence, mask=attention_mask, is_causal=True) - # Get frame features (exclude text token) + # Get frame features frame_features = transformed[:, 1:] # [batch_size, seq_len, hidden_dim] # Stage estimation @@ -193,14 +192,11 @@ class SARMTransformer(nn.Module): # τ̂ = within-subtask progress (0-1) tau_preds = self.subtask_head(conditioned_features) # [batch_size, seq_len, 1] - # Convert τ̂ to cumulative progress ŷ using Paper Eq. 4: + # Convert τ̂ to cumulative progress ŷ using Paper Formula (2): # ŷ = P_{k-1} + ᾱ_k × τ̂ - # P_{k-1} = cumulative prior up to stage k-1 - # ᾱ_k = temporal proportion of stage k - P_k_minus_1 = self.cumulative_prior[stage_indices] # [batch_size, seq_len] - alpha_k = self.alpha[stage_indices] # [batch_size, seq_len] - - progress_preds = P_k_minus_1.unsqueeze(-1) + alpha_k.unsqueeze(-1) * tau_preds + progress_preds = compute_cumulative_progress_batch( + tau_preds, stage_indices, self.alpha, self.cumulative_prior + ) return stage_logits, stage_probs, progress_preds @@ -227,65 +223,37 @@ class SARMRewardModel(PreTrainedPolicy): self.dataset_stats = dataset_stats self.device = torch.device(config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu") - # Auto-detect num_stages from dataset annotations before building the model + # Detect num_stages from dataset annotations before building the model if dataset_meta is not None: self._update_num_stages_from_dataset(dataset_meta) - # Initialize CLIP encoder for images AND text (per SARM paper A.4) - logging.info("Loading CLIP encoder for images and text...") + logging.info("Loading CLIP encoder") self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True) self.clip_model.to(self.device) self.clip_model.eval() - # Auto-detect state_dim from dataset_stats - if config.state_dim is None: - logging.info(f"Attempting to auto-detect state_dim. dataset_stats is None: {dataset_stats is None}") - - if dataset_stats is not None: - if "observation.state" in dataset_stats: - config.state_dim = dataset_stats["observation.state"]["mean"].shape[0] - logging.info(f"Auto-detected state_dim={config.state_dim} from dataset_stats['observation.state']") - elif "state" in dataset_stats: - config.state_dim = dataset_stats["state"]["mean"].shape[0] - logging.info(f"Auto-detected state_dim={config.state_dim} from dataset_stats['state']") - else: - logging.warning(f"State keys not found in dataset_stats. Available keys: {list(dataset_stats.keys())}") - else: - logging.warning("dataset_stats is None, cannot auto-detect state_dim") - - # Raise explicit error if still None - if config.state_dim is None: - raise ValueError( - "Could not determine state_dim! " - f"dataset_stats={'None' if dataset_stats is None else f'available with keys: {list(dataset_stats.keys())}'}, " - "config.state_dim=None. " - "Please either:\n" - "1. Provide --policy.state_dim= explicitly, or\n" - "2. Ensure dataset_stats contains 'observation.state' or 'state' key" - ) - - # Initialize SARM transformer with temporal proportions for progress conversion - temporal_proportions = getattr(config, 'temporal_proportions', None) self.sarm_transformer = SARMTransformer( video_dim=config.image_dim, text_dim=config.text_dim, - state_dim=config.state_dim, + max_state_dim=config.max_state_dim, hidden_dim=config.hidden_dim, num_heads=config.num_heads, num_layers=config.num_layers, num_stages=config.num_stages, max_length=config.max_length, dropout=config.dropout, - temporal_proportions=temporal_proportions + temporal_proportions=config.temporal_proportions ) self.sarm_transformer.to(self.device) - - logging.info(f"SARM Reward Model initialized on {self.device}") def _update_num_stages_from_dataset(self, dataset_meta) -> None: - """Update num_stages and temporal_proportions from dataset subtask annotations.""" + """Update num_stages and temporal_proportions from dataset subtask annotations. + + Implements SARM Paper Formula (1): + ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i) + """ episodes = dataset_meta.episodes if episodes is None or len(episodes) == 0: raise ValueError("No episodes found, using default num_stages") @@ -295,27 +263,38 @@ class SARMRewardModel(PreTrainedPolicy): episodes_df = episodes.to_pandas() - # Collect all unique subtask names and compute durations + # Collect subtask durations and trajectory lengths for compute_priors all_subtask_names = set() - subtask_durations = {} + subtask_durations_per_trajectory = {} + trajectory_lengths = {} for ep_idx in episodes_df.index: - subtask_names = episodes_df.loc[ep_idx, 'subtask_names'] - if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)): + subtask_names_ep = episodes_df.loc[ep_idx, 'subtask_names'] + if subtask_names_ep is None or (isinstance(subtask_names_ep, float) and pd.isna(subtask_names_ep)): continue - all_subtask_names.update(subtask_names) + all_subtask_names.update(subtask_names_ep) # Compute durations if available if 'subtask_start_frames' in episodes_df.columns and 'subtask_end_frames' in episodes_df.columns: start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames'] end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames'] - for i, name in enumerate(subtask_names): + # Compute total trajectory length T_i + total_traj_length = sum(end_frames[i] - start_frames[i] for i in range(len(subtask_names_ep))) + + if total_traj_length <= 0: + continue + + for i, name in enumerate(subtask_names_ep): duration = end_frames[i] - start_frames[i] - if name not in subtask_durations: - subtask_durations[name] = [] - subtask_durations[name].append(duration) + + if name not in subtask_durations_per_trajectory: + subtask_durations_per_trajectory[name] = [] + trajectory_lengths[name] = [] + + subtask_durations_per_trajectory[name].append(duration) + trajectory_lengths[name].append(total_traj_length) if not all_subtask_names: raise ValueError("No valid subtask names found, using default num_stages") @@ -324,26 +303,20 @@ class SARMRewardModel(PreTrainedPolicy): subtask_names = sorted(list(all_subtask_names)) num_stages = len(subtask_names) - # Compute temporal proportions (Paper Eq. 1: ᾱ_k) - avg_durations = {} - for name in subtask_names: - if name in subtask_durations and subtask_durations[name]: - avg_durations[name] = np.mean(subtask_durations[name]) - else: - avg_durations[name] = 1.0 # Default - - total_duration = sum(avg_durations.values()) - if total_duration > 0: - temporal_proportions = [avg_durations[name] / total_duration for name in subtask_names] - else: - temporal_proportions = [1.0 / num_stages] * num_stages + # Compute temporal proportions using Paper Formula (1) + temporal_proportions_dict = compute_priors( + subtask_durations_per_trajectory, + trajectory_lengths, + subtask_names + ) + temporal_proportions = [temporal_proportions_dict[name] for name in subtask_names] self.config.num_stages = num_stages self.config.subtask_names = subtask_names self.config.temporal_proportions = temporal_proportions logging.info(f"Auto-detected {num_stages} subtasks: {subtask_names}") - logging.info(f"Temporal proportions: {dict(zip(subtask_names, temporal_proportions))}") + logging.info(f"Temporal proportions: {temporal_proportions_dict}") def to(self, device): """Override to method to ensure all components move together.""" @@ -475,7 +448,6 @@ class SARMRewardModel(PreTrainedPolicy): If return_stages=True: Tuple of (rewards, stage_probs) """ - # Convert to tensors if needed if isinstance(text_embeddings, np.ndarray): text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32) if isinstance(video_embeddings, np.ndarray): @@ -535,16 +507,13 @@ class SARMRewardModel(PreTrainedPolicy): def load_pretrained_checkpoint(self, checkpoint_path: str, strict: bool = False): """Load pretrained model weights from a checkpoint file.""" logging.info(f"Loading pretrained checkpoint from {checkpoint_path}") - checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False) - # Handle different checkpoint formats if "model_state_dict" in checkpoint: state_dict = checkpoint["model_state_dict"] else: state_dict = checkpoint - - # Load only the SARMTransformer weights + missing_keys, unexpected_keys = self.sarm_transformer.load_state_dict(state_dict, strict=strict) if missing_keys: @@ -557,9 +526,7 @@ class SARMRewardModel(PreTrainedPolicy): def train(self, mode: bool = True): """Set training mode. Note: CLIP encoder always stays in eval mode (frozen).""" super().train(mode) - # Keep CLIP encoder in eval mode (frozen per SARM paper) self.clip_model.eval() - # Only transformer can be trained self.sarm_transformer.train(mode) return self @@ -686,8 +653,7 @@ class SARMRewardModel(PreTrainedPolicy): state = state_features[i] if state_features is not None else None progress = progress_from_annotations[i].squeeze(-1) # (T,) - # Apply temporal augmentation with 50% probability (SARM paper A.4) - # Appends up to 4 reversed frames to simulate failures/recoveries + # Apply temporal augmentation with 50% probability: appends up to 4 reversed frames to simulate failures/recoveries if random.random() < 0.5: video, progress, state = self._apply_temporal_augmentation(video, progress, state, max_length) @@ -729,7 +695,7 @@ class SARMRewardModel(PreTrainedPolicy): total_loss = total_loss + self.config.stage_loss_weight * stage_loss output_dict['stage_loss'] = stage_loss.item() - # Misaligned loss: 20% probability (SARM paper - improve video-language alignment) + # Misaligned loss: 20% probability if random.random() < 0.2: shuffle_idx = torch.randperm(batch_size, device=self.device) _, _, misaligned_preds = self.sarm_transformer( diff --git a/src/lerobot/policies/sarm/processor_sarm.py b/src/lerobot/policies/sarm/processor_sarm.py index 3530b0d1c..b820c4966 100644 --- a/src/lerobot/policies/sarm/processor_sarm.py +++ b/src/lerobot/policies/sarm/processor_sarm.py @@ -23,16 +23,19 @@ import pandas as pd from transformers import CLIPModel, CLIPProcessor from lerobot.policies.sarm.configuration_sarm import SARMConfig +from lerobot.policies.sarm.sarm_utils import compute_priors, compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim from lerobot.processor import ( ProcessorStep, PolicyProcessorPipeline, PolicyAction, DeviceProcessorStep, AddBatchDimensionProcessorStep, + NormalizerProcessorStep, ) from lerobot.processor.converters import ( policy_action_to_transition, transition_to_policy_action, + from_tensor_to_numpy, ) from lerobot.processor.pipeline import PipelineFeatureType from lerobot.processor.core import EnvTransition, TransitionKey @@ -41,20 +44,7 @@ from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PR class SARMEncodingProcessorStep(ProcessorStep): - """ - ProcessorStep that encodes images and text for SARM training. - - Per SARM paper (Appendix A.4): "We employ a frozen clip-vit-base-patch32 encoder - to process both RGB image sequences and task descriptions." - - This step handles: - - CLIP image encoding (512-dim) - - CLIP text encoding (512-dim) - - Joint state normalization - - Supports temporal sequences: (B, T, C, H, W) → (B, T, 512) video features - """ - + """ProcessorStep that encodes images and text with CLIP.""" def __init__( self, config: SARMConfig, @@ -69,8 +59,6 @@ class SARMEncodingProcessorStep(ProcessorStep): self.task_description = task_description or config.task_description self.dataset_meta = dataset_meta self.dataset_stats = dataset_stats - - # Compute temporal proportions from subtask annotations if available self.temporal_proportions = None self.subtask_names = None if dataset_meta is not None: @@ -94,7 +82,14 @@ class SARMEncodingProcessorStep(ProcessorStep): self.device = device def _compute_temporal_proportions(self): - """Compute temporal proportions for each subtask from dataset annotations.""" + """Compute temporal proportions for each subtask from dataset annotations. + + Implements SARM Paper Formula (1): + ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i) + + This averages the proportion of time spent on each subtask within each trajectory, + giving equal weight to all trajectories regardless of absolute length. + """ if self.dataset_meta is None or not hasattr(self.dataset_meta, 'episodes'): return @@ -108,32 +103,42 @@ class SARMEncodingProcessorStep(ProcessorStep): logging.info("No subtask annotations found in dataset") return - # Convert to pandas episodes_df = episodes.to_pandas() - # Collect all subtask names and compute average durations - subtask_durations = {} + # Collect subtask durations and trajectory lengths for compute_priors + subtask_durations_per_trajectory = {} + trajectory_lengths = {} all_subtask_names = set() for ep_idx in episodes_df.index: - subtask_names = episodes_df.loc[ep_idx, 'subtask_names'] + subtask_names_ep = episodes_df.loc[ep_idx, 'subtask_names'] # Skip episodes without annotations - if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)): + if subtask_names_ep is None or (isinstance(subtask_names_ep, float) and pd.isna(subtask_names_ep)): continue start_times = episodes_df.loc[ep_idx, 'subtask_start_times'] end_times = episodes_df.loc[ep_idx, 'subtask_end_times'] # Track unique subtask names - all_subtask_names.update(subtask_names) + all_subtask_names.update(subtask_names_ep) - # Compute durations - for i, name in enumerate(subtask_names): + # Compute total trajectory length T_i (sum of all subtask durations) + total_traj_length = sum(end_times[i] - start_times[i] for i in range(len(subtask_names_ep))) + + if total_traj_length <= 0: + continue + + # Store duration and trajectory length for each subtask occurrence + for i, name in enumerate(subtask_names_ep): duration = end_times[i] - start_times[i] - if name not in subtask_durations: - subtask_durations[name] = [] - subtask_durations[name].append(duration) + + if name not in subtask_durations_per_trajectory: + subtask_durations_per_trajectory[name] = [] + trajectory_lengths[name] = [] + + subtask_durations_per_trajectory[name].append(duration) + trajectory_lengths[name].append(total_traj_length) if not all_subtask_names: logging.info("No valid subtask annotations found") @@ -142,44 +147,17 @@ class SARMEncodingProcessorStep(ProcessorStep): # Sort subtask names for consistent ordering self.subtask_names = sorted(list(all_subtask_names)) self.config.num_stages = len(self.subtask_names) - self.config.subtask_names = self.subtask_names # Store in config for reference + self.config.subtask_names = self.subtask_names - # Compute average duration for each subtask - avg_durations = {} - for name in self.subtask_names: - if name in subtask_durations: - avg_durations[name] = np.mean(subtask_durations[name]) - else: - avg_durations[name] = 0.0 - - # Normalize to get proportions - total_duration = sum(avg_durations.values()) - if total_duration > 0: - self.temporal_proportions = { - name: avg_durations[name] / total_duration - for name in self.subtask_names - } - else: - raise ValueError( - "Cannot compute temporal proportions: all subtask durations are zero. " - "Check that your dataset has valid subtask annotations with start/end times." - ) - - # Store in config for the model to use in progress output conversion (SARM paper Eq. 4) + # Compute temporal proportions using Paper Formula (1) + self.temporal_proportions = compute_priors( + subtask_durations_per_trajectory, + trajectory_lengths, + self.subtask_names + ) self.config.temporal_proportions = [self.temporal_proportions[name] for name in self.subtask_names] - logging.info(f"Computed temporal proportions for {len(self.subtask_names)} subtasks: {self.temporal_proportions}") - def _to_numpy_array(self, x) -> np.ndarray: - """Convert input to a 1D numpy array.""" - if isinstance(x, torch.Tensor): - arr = x.cpu().numpy() - else: - arr = np.array(x) - if arr.ndim == 0: - arr = np.array([arr.item()]) - return arr - def _find_episode_for_frame(self, frame_idx: int) -> int: """Find the episode index for a given frame index.""" for ep_idx in range(len(self.dataset_meta.episodes)): @@ -187,14 +165,14 @@ class SARMEncodingProcessorStep(ProcessorStep): ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] if ep_start <= frame_idx < ep_end: return ep_idx - return 0 # Fallback + return 0 def _get_episode_indices(self, frame_indices: np.ndarray, episode_index) -> np.ndarray: """Get episode indices for each frame index.""" if episode_index is None: return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices]) - episode_indices = self._to_numpy_array(episode_index) + episode_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(episode_index))) # If single episode but multiple frames, compute episode for each frame if len(episode_indices) == 1 and len(frame_indices) > 1: @@ -211,22 +189,16 @@ class SARMEncodingProcessorStep(ProcessorStep): Pattern: [ep_start, t-(7*gap), t-(6*gap), ..., t-gap, t] """ - frame_gap = getattr(self.config, 'frame_gap', 1) - indices = [] - - - # First frame is the episode's initial frame - indices.append(ep_start) + indices.append(ep_start) # First frame is the episode's initial frame # Remaining frames are consecutive with frame_gap spacing num_consecutive = num_frames - 1 for i in range(num_consecutive): - offset = -(num_consecutive - 1 - i) * frame_gap + offset = -(num_consecutive - 1 - i) * self.config.frame_gap idx = max(ep_start, frame_idx + offset) indices.append(idx) - return torch.tensor(indices) def _compute_episode_metadata( @@ -269,6 +241,14 @@ class SARMEncodingProcessorStep(ProcessorStep): ) -> tuple[int, float]: """Compute stage index and cumulative progress for a single frame. + Implements SARM Paper Formula (2): + y_t = P_{k-1} + ᾱ_k × τ_t + + where: + - τ_t = (t - s_k) / (e_k - s_k) is within-subtask progress + - P_{k-1} is cumulative prior (sum of previous subtask proportions) + - ᾱ_k is the temporal proportion for subtask k + Args: current_frame: Frame index relative to episode start subtask_names: List of subtask names for this episode @@ -278,53 +258,46 @@ class SARMEncodingProcessorStep(ProcessorStep): Returns: Tuple of (stage_idx, cumulative_progress) """ - stage_idx = -1 - cumulative_progress = 0.0 + # Get temporal proportions as list for compute_cumulative_progress + temporal_proportions_list = [ + self.temporal_proportions.get(name, 0.0) for name in self.subtask_names + ] # Find which subtask this frame belongs to for j, (name, start_frame, end_frame) in enumerate(zip(subtask_names, subtask_start_frames, subtask_end_frames)): if current_frame >= start_frame and current_frame <= end_frame: - # Found the subtask + # Found the subtask, get its global index stage_idx = self.subtask_names.index(name) if name in self.subtask_names else 0 - # Calculate within-subtask progress - subtask_duration = end_frame - start_frame - if subtask_duration > 0: - within_subtask_progress = (current_frame - start_frame) / subtask_duration - else: - within_subtask_progress = 1.0 + # Compute τ_t using utility function (Paper Formula 2) + tau = compute_tau(current_frame, start_frame, end_frame) - # Calculate cumulative progress from completed subtasks - for k in range(j): - prev_name = subtask_names[k] - if prev_name in self.temporal_proportions: - cumulative_progress += self.temporal_proportions[prev_name] - - # Add current subtask's partial progress - if name in self.temporal_proportions: - cumulative_progress += self.temporal_proportions[name] * within_subtask_progress + # Compute cumulative progress using utility function (Paper Formula 2) + cumulative_progress = compute_cumulative_progress_batch( + tau, stage_idx, temporal_proportions_list + ) return stage_idx, cumulative_progress - # No matching subtask found - estimate based on position + # No matching subtask found if current_frame < subtask_start_frames[0]: return 0, 0.0 elif current_frame > subtask_end_frames[-1]: return len(self.subtask_names) - 1, 1.0 else: - # Between subtasks - use previous subtask's end state + # Between subtasks - use previous subtask's end state (tau = 1.0) for j in range(len(subtask_names) - 1): if current_frame > subtask_end_frames[j] and current_frame < subtask_start_frames[j + 1]: name = subtask_names[j] stage_idx = self.subtask_names.index(name) if name in self.subtask_names else j - # Sum up all completed subtasks - for k in range(j + 1): - prev_name = subtask_names[k] - if prev_name in self.temporal_proportions: - cumulative_progress += self.temporal_proportions[prev_name] + + # Completed subtask, so tau = 1.0 + cumulative_progress = compute_cumulative_progress_batch( + 1.0, stage_idx, temporal_proportions_list + ) return stage_idx, cumulative_progress - return 0, 0.0 # Fallback + return 0, 0.0 def _compute_labels_for_sample( self, @@ -359,13 +332,8 @@ class SARMEncodingProcessorStep(ProcessorStep): subtask_start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames'] subtask_end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames'] - - # Get episode boundaries ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] - # Get config values - frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1 - # Generate labels for each frame in the sequence stage_labels = [] progress_targets = [] @@ -377,7 +345,7 @@ class SARMEncodingProcessorStep(ProcessorStep): else: # Positions 1-8: consecutive frames with frame_gap spacing num_consecutive = seq_len - 1 - offset = -(num_consecutive - i) * frame_gap + offset = -(num_consecutive - i) * self.config.frame_gap current_frame = max(0, frame_idx + offset - ep_start) @@ -388,7 +356,6 @@ class SARMEncodingProcessorStep(ProcessorStep): stage_labels.append(stage_idx) progress_targets.append(cumulative_progress) - # Convert to tensors stage_labels = torch.tensor(stage_labels, dtype=torch.long) progress_targets = torch.tensor(progress_targets, dtype=torch.float32).unsqueeze(-1) @@ -411,7 +378,7 @@ class SARMEncodingProcessorStep(ProcessorStep): is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1 # Normalize inputs to numpy arrays - frame_indices = self._to_numpy_array(frame_index) + frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index))) episode_indices = self._get_episode_indices(frame_indices, episode_index) # Determine sequence length @@ -422,7 +389,6 @@ class SARMEncodingProcessorStep(ProcessorStep): episodes_df = self.dataset_meta.episodes.to_pandas() - # Process all samples all_stage_labels = [] all_progress_targets = [] @@ -450,7 +416,7 @@ class SARMEncodingProcessorStep(ProcessorStep): if not isinstance(observation, dict): raise ValueError("Observation must be a dictionary") - # 1. Encode images with CLIP + # Encode images with CLIP image = observation.get(self.image_key) if image is None: raise ValueError(f"Image not found in observation for key: {self.image_key}") @@ -460,27 +426,27 @@ class SARMEncodingProcessorStep(ProcessorStep): video_features = self._encode_images_batch(image) observation['video_features'] = video_features - # 2. Extract and normalize joint states - state_data = observation.get("state") or observation.get("observation.state") + # Extract state and pad to max_state_dim (already normalized by NormalizerProcessorStep) + state_key = self.config.state_key + state_data = observation.get(state_key) if state_data is None: - raise ValueError("State data not found in observation (expected 'state' or 'observation.state')") + state_data = observation.get("state") or observation.get("observation.state") + if state_data is None: + raise ValueError(f"State data not found in observation (expected '{state_key}', 'state', or 'observation.state')") if isinstance(state_data, torch.Tensor): - state_data = state_data.cpu().numpy() + state_tensor = state_data.float() + else: + state_tensor = torch.tensor(state_data, dtype=torch.float32) - state_key = "state" if "state" in observation else "observation.state" - if self.dataset_stats and state_key in self.dataset_stats: - mean = self.dataset_stats[state_key]['mean'] - std = self.dataset_stats[state_key]['std'] - state_data = (state_data - mean) / (std + 1e-8) + # Pad state + observation['state_features'] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim) - observation['state_features'] = torch.tensor(state_data, dtype=torch.float32) - - # 3. Encode text with CLIP (per SARM paper A.4) + # Encode text with CLIP batch_size = video_features.shape[0] observation['text_features'] = self._encode_text_clip(self.task_description, batch_size) - # 4. Extract frame/episode indices from complementary data + # Extract frame/episode indices from complementary data comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) if not isinstance(comp_data, dict): raise ValueError("COMPLEMENTARY_DATA must be a dictionary") @@ -493,10 +459,10 @@ class SARMEncodingProcessorStep(ProcessorStep): if episode_index is None: raise ValueError("Episode index ('episode_index') not found in COMPLEMENTARY_DATA") - # 5. Compute episode metadata if dataset_meta is available + # Compute episode metadata if dataset_meta is available if self.dataset_meta is not None: is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1 - frame_indices = self._to_numpy_array(frame_index) + frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index))) episode_indices = self._get_episode_indices(frame_indices, episode_index) # Determine number of frames from video features @@ -512,7 +478,7 @@ class SARMEncodingProcessorStep(ProcessorStep): observation['remaining_length'] = remaining observation['episode_length'] = ep_lengths - # 6. Generate stage labels and progress targets from subtask annotations + # Generate stage labels and progress targets from subtask annotations if self.temporal_proportions is not None and self.dataset_meta is not None: stage_labels, progress_targets = self._generate_stage_and_progress_labels( frame_index, episode_index, video_features @@ -539,14 +505,10 @@ class SARMEncodingProcessorStep(ProcessorStep): # Check if we have temporal dimension has_temporal = len(images.shape) == 5 - if has_temporal: - # Shape: (B, T, C, H, W) + if has_temporal: # Shape: (B, T, C, H, W) batch_size, seq_length = images.shape[0], images.shape[1] - - # Reshape to (B*T, C, H, W) to process all frames at once - images = images.reshape(batch_size * seq_length, *images.shape[2:]) - elif len(images.shape) == 4: - # Shape: (B, C, H, W) + images = images.reshape(batch_size * seq_length, *images.shape[2:]) + elif len(images.shape) == 4: # Shape: (B, C, H, W) batch_size = images.shape[0] seq_length = 1 else: @@ -608,7 +570,7 @@ class SARMEncodingProcessorStep(ProcessorStep): Returns: Encoded text features with shape (B, 512) """ - # Use CLIP's tokenizer directly for text (avoids image processor validation issues) + # Use CLIP's tokenizer directly for text tokenizer = self.clip_processor.tokenizer inputs = tokenizer([text], return_tensors="pt", padding=True, truncation=True) inputs = {k: v.to(self.device) for k, v in inputs.items()} @@ -629,7 +591,7 @@ class SARMEncodingProcessorStep(ProcessorStep): self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: """Add encoded features to the observation features.""" - # Add the encoded features + # Add the encoded features (state uses max_state_dim, padded with zeros) features[PipelineFeatureType.OBSERVATION]['video_features'] = PolicyFeature( type=FeatureType.VISUAL, shape=(self.config.num_frames, self.config.image_dim) @@ -640,7 +602,7 @@ class SARMEncodingProcessorStep(ProcessorStep): ) features[PipelineFeatureType.OBSERVATION]['state_features'] = PolicyFeature( type=FeatureType.STATE, - shape=(self.config.num_frames, self.config.state_dim) + shape=(self.config.num_frames, self.config.max_state_dim) ) return features @@ -660,11 +622,16 @@ def make_sarm_pre_post_processors( to process both RGB image sequences and task descriptions." The pre-processing pipeline: - 1. Encodes images with CLIP (512-dim) - 2. Encodes text with CLIP (512-dim) - 3. Normalizes joint states - 4. Adds batch dimension - 5. Moves data to device + 1. Adds batch dimension + 2. Normalizes observation.state using NormalizerProcessorStep (MEAN_STD) + 3. SARMEncodingProcessorStep: + - Encodes images with CLIP (512-dim) + - Pads states to max_state_dim + - Encodes text with CLIP (512-dim) + 4. Moves data to device + + The post-processing pipeline: + 1. Moves data to CPU (no unnormalization - outputs are rewards) Args: config: SARM configuration @@ -676,6 +643,11 @@ def make_sarm_pre_post_processors( """ input_steps = [ AddBatchDimensionProcessorStep(), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), SARMEncodingProcessorStep( config=config, dataset_meta=dataset_meta, diff --git a/src/lerobot/policies/sarm/sarm_utils.py b/src/lerobot/policies/sarm/sarm_utils.py new file mode 100644 index 000000000..ccfa10361 --- /dev/null +++ b/src/lerobot/policies/sarm/sarm_utils.py @@ -0,0 +1,249 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utility functions for SARM progress label computation. + +Implements formulas from the SARM paper: +- Formula (1): Compute dataset-level temporal proportions (priors) ᾱ_k +- Formula (2): Compute normalized progress targets y_t = P_{k-1} + ᾱ_k × τ_t +""" + +import numpy as np +import torch +import torch.nn.functional as F +from typing import Sequence + + +def compute_priors( + subtask_durations_per_trajectory: dict[str, list[float]], + trajectory_lengths: dict[str, list[float]], + subtask_names: list[str], +) -> dict[str, float]: + """ + Compute dataset-level temporal proportions (priors) for each subtask. + + Implements SARM Paper Formula (1): + ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i) + + where: + - M is the number of trajectories + - L_{i,k} is the length of subtask k in trajectory i + - T_i is the total length of trajectory i + + This averages the PROPORTION of each subtask within each trajectory, + giving equal weight to all trajectories regardless of their absolute length. + + Args: + subtask_durations_per_trajectory: Dict mapping subtask name to list of + (duration, trajectory_length) tuples for each occurrence + trajectory_lengths: Dict mapping subtask name to list of trajectory lengths + for each occurrence of that subtask + subtask_names: Ordered list of subtask names + + Returns: + Dict mapping subtask name to its temporal proportion (ᾱ_k) + """ + if not subtask_names: + raise ValueError("subtask_names cannot be empty") + + # Compute proportion per occurrence: L_{i,k} / T_i + subtask_proportions = {} + for name in subtask_names: + if name in subtask_durations_per_trajectory and name in trajectory_lengths: + durations = subtask_durations_per_trajectory[name] + traj_lengths = trajectory_lengths[name] + + if len(durations) != len(traj_lengths): + raise ValueError( + f"Mismatch in lengths for subtask '{name}': " + f"{len(durations)} durations vs {len(traj_lengths)} trajectory lengths" + ) + + # Compute L_{i,k} / T_i for each occurrence + proportions = [] + for duration, traj_len in zip(durations, traj_lengths): + if traj_len > 0: + proportions.append(duration / traj_len) + + # Average across all occurrences: (1/M) × Σ_i (L_{i,k} / T_i) + subtask_proportions[name] = np.mean(proportions) if proportions else 0.0 + else: + subtask_proportions[name] = 0.0 + + # Normalize to ensure sum = 1 (handles floating point errors and missing subtasks) + total = sum(subtask_proportions.values()) + if total > 0: + subtask_proportions = { + name: prop / total for name, prop in subtask_proportions.items() + } + else: + raise ValueError("Cannot compute temporal proportions: all proportions are zero. " + "Check that your dataset has valid subtask annotations with start/end times.") + + return subtask_proportions + + +def compute_tau( + current_frame: int | float, + subtask_start: int | float, + subtask_end: int | float, +) -> float: + """ + Compute within-subtask normalized time τ_t. + + Implements part of SARM Paper Formula (2): + τ_t = (t - s_k) / (e_k - s_k) ∈ [0, 1] + + where: + - t is the current frame + - s_k is the start frame of subtask k + - e_k is the end frame of subtask k + + Args: + current_frame: Current frame index (t) + subtask_start: Start frame of the subtask (s_k) + subtask_end: End frame of the subtask (e_k) + + Returns: + Within-subtask progress τ_t ∈ [0, 1] + """ + subtask_duration = subtask_end - subtask_start + + if subtask_duration <= 0: + return 1.0 + + tau = (current_frame - subtask_start) / subtask_duration + + return float(np.clip(tau, 0.0, 1.0)) + + +def compute_cumulative_progress_batch( + tau: torch.Tensor | float, + stage_indices: torch.Tensor | int, + alpha: torch.Tensor | Sequence[float], + cumulative_prior: torch.Tensor | None = None, +) -> torch.Tensor | float: + """ + Compute cumulative normalized progress from within-subtask progress. + + This function implements the core formula used in SARM for both: + + **Formula 2 (Training labels):** + y_t = P_{k-1} + ᾱ_k × τ_t ∈ [0, 1] + + Used to compute ground-truth progress labels from subtask annotations. + - τ_t comes from annotated frame position: τ_t = (t - s_k) / (e_k - s_k) + - k is the known subtask from annotations + + **Formula 4 (Inference predictions):** + ŷ_{1:N} = P̂_{k-1, 1:N} + ᾱ_{k, 1:N} × τ̂_{1:N} ∈ [0, 1] + + Used to convert model outputs to cumulative progress during inference. + - τ̂ comes from the subtask MLP head (conditioned on predicted stage) + - k = Ŝ is the predicted stage from Formula 3: Ŝ = argmax(softmax(Ψ)) + + The formulas are mathematically identical; only the source of inputs differs: + - Training: τ and k from annotations → ground-truth labels + - Inference: τ̂ and Ŝ from model → predicted progress + + where: + - P_{k-1} = Σ_{j=1}^{k-1} ᾱ_j is the cumulative prior (sum of previous proportions) + - ᾱ_k is the temporal proportion for subtask k (from Formula 1) + - τ is within-subtask progress ∈ [0, 1] + + This ensures: + - y at start of subtask k = P_{k-1} + - y at end of subtask k = P_k + + Supports both scalar and batched tensor inputs: + - Scalar: tau (float), stage_indices (int), alpha (list/sequence) + - Batch: tau (Tensor), stage_indices (Tensor), alpha (Tensor), cumulative_prior (Tensor) + + Args: + tau: Within-subtask progress τ ∈ [0, 1]. + For training: computed from frame position in annotated subtask. + For inference: predicted by subtask MLP head. + Scalar float or Tensor with shape (..., 1) + stage_indices: Index of current subtask k (0-indexed). + For training: known from annotations. + For inference: predicted via argmax(stage_probs) (Formula 3). + Scalar int or Tensor with shape (...) + alpha: Temporal proportions ᾱ with shape (num_stages,) or Sequence[float]. + Computed from dataset annotations using Formula 1. + cumulative_prior: Optional. Cumulative priors P with shape (num_stages + 1,) + where cumulative_prior[k] = P_k = Σ_{j=1}^{k} ᾱ_j. + If None, will be computed from alpha. + + Returns: + Cumulative progress y ∈ [0, 1]. + Scalar float if inputs are scalar, otherwise Tensor with shape (..., 1) + """ + if not isinstance(tau, torch.Tensor): + if not alpha: + raise ValueError("alpha (temporal_proportions) cannot be empty") + + if isinstance(alpha, torch.Tensor): + alpha_list = alpha.tolist() + else: + alpha_list = list(alpha) + + if stage_indices < 0 or stage_indices >= len(alpha_list): + raise ValueError( + f"stage_indices {stage_indices} out of range " + f"for {len(alpha_list)} subtasks" + ) + + # P_{k-1} = sum of proportions for subtasks 0 to k-1 + P_k_minus_1 = sum(alpha_list[:stage_indices]) + + # ᾱ_k = proportion for current subtask + alpha_k = alpha_list[stage_indices] + + # y_t = P_{k-1} + ᾱ_k × τ_t + y_t = P_k_minus_1 + alpha_k * tau + + return float(np.clip(y_t, 0.0, 1.0)) + + if not isinstance(alpha, torch.Tensor): + alpha = torch.tensor(alpha, dtype=torch.float32) + + # Compute cumulative_prior if not provided + if cumulative_prior is None: + cumulative_prior = torch.zeros(len(alpha) + 1, dtype=alpha.dtype, device=alpha.device) + cumulative_prior[1:] = torch.cumsum(alpha, dim=0) + + # P_{k-1} for each predicted stage + P_k_minus_1 = cumulative_prior[stage_indices] + + # ᾱ_k for each predicted stage + alpha_k = alpha[stage_indices] + + # ŷ = P_{k-1} + ᾱ_k × τ̂ + progress = P_k_minus_1.unsqueeze(-1) + alpha_k.unsqueeze(-1) * tau + + return progress + +def pad_state_to_max_dim(state: torch.Tensor, max_state_dim: int) -> torch.Tensor: + """Pad the state tensor's last dimension to max_state_dim with zeros.""" + current_dim = state.shape[-1] + if current_dim >= max_state_dim: + return state[..., :max_state_dim] # Truncate if larger + + # Pad with zeros on the right + padding = (0, max_state_dim - current_dim) # (left, right) for last dim + return F.pad(state, padding, mode='constant', value=0) + diff --git a/tests/policies/test_sarm_utils.py b/tests/policies/test_sarm_utils.py new file mode 100644 index 000000000..2ae2b6468 --- /dev/null +++ b/tests/policies/test_sarm_utils.py @@ -0,0 +1,392 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for SARM utility functions. + +Tests the implementation of SARM paper formulas: +- Formula (1): compute_priors - dataset-level temporal proportions +- Formula (2): compute_tau, compute_cumulative_progress - progress labels +""" + +import pytest +import numpy as np +import torch + +from lerobot.policies.sarm.sarm_utils import ( + compute_priors, + compute_tau, + compute_cumulative_progress_batch, +) + + +class TestComputePriors: + """Tests for compute_priors (SARM Paper Formula 1). + + Formula: ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i) + + Key insight: This averages the PROPORTION of each subtask within each trajectory, + giving equal weight to all trajectories regardless of absolute length. + """ + + def test_basic_two_trajectories_equal_proportions(self): + """Test with two trajectories that have equal proportions.""" + # Both trajectories: subtask1 = 50%, subtask2 = 50% + subtask_durations = { + 'subtask1': [50, 100], # durations + 'subtask2': [50, 100], + } + trajectory_lengths = { + 'subtask1': [100, 200], + 'subtask2': [100, 200], + } + subtask_names = ['subtask1', 'subtask2'] + + result = compute_priors(subtask_durations, trajectory_lengths, subtask_names) + + # Both should be 0.5 + assert abs(result['subtask1'] - 0.5) < 1e-6 + assert abs(result['subtask2'] - 0.5) < 1e-6 + + def test_paper_example_different_from_avg_durations(self): + """Test that compute_priors differs from naive average duration approach. + + This is the key test showing the difference between: + - Paper formula: average of (L_i,k / T_i) + - Naive approach: mean(L_i,k) / sum(mean(L_i,j)) + """ + # Episode 1: T=100, subtask1=80, subtask2=20 (proportions: 0.8, 0.2) + # Episode 2: T=200, subtask1=40, subtask2=160 (proportions: 0.2, 0.8) + subtask_durations = { + 'subtask1': [80, 40], + 'subtask2': [20, 160], + } + trajectory_lengths = { + 'subtask1': [100, 200], + 'subtask2': [100, 200], + } + subtask_names = ['subtask1', 'subtask2'] + + result = compute_priors(subtask_durations, trajectory_lengths, subtask_names) + + # Paper formula: + # ᾱ_1 = (1/2) × (80/100 + 40/200) = (1/2) × (0.8 + 0.2) = 0.5 + # ᾱ_2 = (1/2) × (20/100 + 160/200) = (1/2) × (0.2 + 0.8) = 0.5 + assert abs(result['subtask1'] - 0.5) < 1e-6 + assert abs(result['subtask2'] - 0.5) < 1e-6 + + + def test_single_trajectory(self): + """Test with a single trajectory.""" + subtask_durations = { + 'reach': [30], + 'grasp': [20], + 'lift': [50], + } + trajectory_lengths = { + 'reach': [100], + 'grasp': [100], + 'lift': [100], + } + subtask_names = ['grasp', 'lift', 'reach'] # sorted order + + result = compute_priors(subtask_durations, trajectory_lengths, subtask_names) + + assert abs(result['reach'] - 0.3) < 1e-6 + assert abs(result['grasp'] - 0.2) < 1e-6 + assert abs(result['lift'] - 0.5) < 1e-6 + + def test_sum_to_one(self): + """Test that proportions always sum to 1.""" + subtask_durations = { + 'a': [10, 20, 30], + 'b': [40, 50, 60], + 'c': [50, 30, 10], + } + trajectory_lengths = { + 'a': [100, 100, 100], + 'b': [100, 100, 100], + 'c': [100, 100, 100], + } + subtask_names = ['a', 'b', 'c'] + + result = compute_priors(subtask_durations, trajectory_lengths, subtask_names) + + total = sum(result.values()) + assert abs(total - 1.0) < 1e-6 + + def test_empty_subtask_names_raises(self): + """Test that empty subtask_names raises an error.""" + with pytest.raises(ValueError, match="subtask_names cannot be empty"): + compute_priors({}, {}, []) + + def test_missing_subtask_gets_zero_before_normalization(self): + """Test handling of subtasks that appear in some but not all trajectories.""" + # subtask1 appears in both, subtask2 only in first + subtask_durations = { + 'subtask1': [50, 100], + 'subtask2': [50], # only in first trajectory + } + trajectory_lengths = { + 'subtask1': [100, 200], + 'subtask2': [100], + } + subtask_names = ['subtask1', 'subtask2'] + + result = compute_priors(subtask_durations, trajectory_lengths, subtask_names) + + # subtask1: (50/100 + 100/200) / 2 = (0.5 + 0.5) / 2 = 0.5 + # subtask2: 50/100 = 0.5 (only one occurrence) + # After normalization: both should be 0.5 + assert result['subtask1'] > 0 + assert result['subtask2'] > 0 + assert abs(sum(result.values()) - 1.0) < 1e-6 + + +class TestComputeTau: + """Tests for compute_tau (within-subtask progress). + + Formula: τ_t = (t - s_k) / (e_k - s_k) ∈ [0, 1] + """ + + def test_at_start(self): + """τ should be 0 at subtask start.""" + tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=50) + assert tau == 0.0 + + def test_at_end(self): + """τ should be 1 at subtask end.""" + tau = compute_tau(current_frame=50, subtask_start=10, subtask_end=50) + assert tau == 1.0 + + def test_at_middle(self): + """τ should be 0.5 at subtask midpoint.""" + tau = compute_tau(current_frame=30, subtask_start=10, subtask_end=50) + assert abs(tau - 0.5) < 1e-6 + + def test_quarter_progress(self): + """Test τ at 25% through subtask.""" + tau = compute_tau(current_frame=20, subtask_start=0, subtask_end=80) + assert abs(tau - 0.25) < 1e-6 + + def test_zero_duration_subtask(self): + """τ should be 1.0 for zero-duration subtask.""" + tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=10) + assert tau == 1.0 + + def test_clamps_below_zero(self): + """τ should be clamped to 0 if frame is before subtask.""" + tau = compute_tau(current_frame=5, subtask_start=10, subtask_end=50) + assert tau == 0.0 + + def test_clamps_above_one(self): + """τ should be clamped to 1 if frame is after subtask.""" + tau = compute_tau(current_frame=60, subtask_start=10, subtask_end=50) + assert tau == 1.0 + + def test_float_inputs(self): + """Test with float frame indices (from interpolation).""" + tau = compute_tau(current_frame=25.5, subtask_start=10.0, subtask_end=50.0) + expected = (25.5 - 10.0) / (50.0 - 10.0) + assert abs(tau - expected) < 1e-6 + + +class TestComputeCumulativeProgressBatchScalar: + """Tests for compute_cumulative_progress_batch with scalar inputs (normalized progress y_t). + + Formula: y_t = P_{k-1} + ᾱ_k × τ_t ∈ [0, 1] + """ + + def test_first_subtask_start(self): + """y should be 0 at start of first subtask.""" + proportions = [0.3, 0.5, 0.2] + y = compute_cumulative_progress_batch(tau=0.0, stage_indices=0, alpha=proportions) + assert y == 0.0 + + def test_first_subtask_end(self): + """y should equal ᾱ_1 at end of first subtask.""" + proportions = [0.3, 0.5, 0.2] + y = compute_cumulative_progress_batch(tau=1.0, stage_indices=0, alpha=proportions) + assert abs(y - 0.3) < 1e-6 + + def test_second_subtask_start(self): + """y should equal P_1 at start of second subtask.""" + proportions = [0.3, 0.5, 0.2] + y = compute_cumulative_progress_batch(tau=0.0, stage_indices=1, alpha=proportions) + assert abs(y - 0.3) < 1e-6 + + def test_second_subtask_end(self): + """y should equal P_2 at end of second subtask.""" + proportions = [0.3, 0.5, 0.2] + y = compute_cumulative_progress_batch(tau=1.0, stage_indices=1, alpha=proportions) + assert abs(y - 0.8) < 1e-6 # 0.3 + 0.5 + + def test_third_subtask_end(self): + """y should be 1.0 at end of last subtask.""" + proportions = [0.3, 0.5, 0.2] + y = compute_cumulative_progress_batch(tau=1.0, stage_indices=2, alpha=proportions) + assert abs(y - 1.0) < 1e-6 + + def test_midpoint_of_subtask(self): + """Test progress at midpoint of a subtask.""" + proportions = [0.4, 0.6] + # At τ=0.5 in subtask 1: y = P_0 + ᾱ_1 × 0.5 = 0 + 0.4 × 0.5 = 0.2 + y = compute_cumulative_progress_batch(tau=0.5, stage_indices=0, alpha=proportions) + assert abs(y - 0.2) < 1e-6 + + # At τ=0.5 in subtask 2: y = P_1 + ᾱ_2 × 0.5 = 0.4 + 0.6 × 0.5 = 0.7 + y = compute_cumulative_progress_batch(tau=0.5, stage_indices=1, alpha=proportions) + assert abs(y - 0.7) < 1e-6 + + def test_uniform_proportions(self): + """Test with uniform proportions.""" + proportions = [0.25, 0.25, 0.25, 0.25] + + # At end of each subtask, progress should be 0.25, 0.5, 0.75, 1.0 + for i in range(4): + y = compute_cumulative_progress_batch(tau=1.0, stage_indices=i, alpha=proportions) + expected = (i + 1) * 0.25 + assert abs(y - expected) < 1e-6 + + +class TestComputeCumulativeProgressBatchTensor: + """Tests for compute_cumulative_progress_batch with tensor inputs (GPU batch version).""" + + def test_tensor_matches_scalar_version(self): + """Test that tensor version matches scalar version.""" + proportions = [0.3, 0.5, 0.2] + alpha = torch.tensor(proportions, dtype=torch.float32) + cumulative = torch.zeros(len(proportions) + 1, dtype=torch.float32) + cumulative[1:] = torch.cumsum(alpha, dim=0) + + test_cases = [ + (0.0, 0), # start of subtask 0 + (1.0, 0), # end of subtask 0 + (0.0, 1), # start of subtask 1 + (0.5, 1), # middle of subtask 1 + (1.0, 2), # end of subtask 2 + ] + + for tau_val, stage_idx in test_cases: + # Scalar version + expected = compute_cumulative_progress_batch(tau_val, stage_idx, proportions) + + # Tensor version (single element) + tau = torch.tensor([[[tau_val]]]) # (1, 1, 1) + stages = torch.tensor([[stage_idx]]) # (1, 1) + result = compute_cumulative_progress_batch(tau, stages, alpha, cumulative) + + assert abs(result[0, 0, 0].item() - expected) < 1e-6 + + def test_batch_processing(self): + """Test batch processing with multiple samples.""" + proportions = [0.4, 0.6] + alpha = torch.tensor(proportions, dtype=torch.float32) + cumulative = torch.zeros(3, dtype=torch.float32) + cumulative[1:] = torch.cumsum(alpha, dim=0) + + # Batch of 2 samples, sequence length 3 + tau = torch.tensor([ + [[0.0], [0.5], [1.0]], # sample 1 + [[0.0], [0.5], [1.0]], # sample 2 + ]) + stages = torch.tensor([ + [0, 0, 0], # sample 1: all in subtask 0 + [1, 1, 1], # sample 2: all in subtask 1 + ]) + + result = compute_cumulative_progress_batch(tau, stages, alpha, cumulative) + + # Sample 1: subtask 0 with tau 0, 0.5, 1.0 -> y = 0, 0.2, 0.4 + assert abs(result[0, 0, 0].item() - 0.0) < 1e-6 + assert abs(result[0, 1, 0].item() - 0.2) < 1e-6 + assert abs(result[0, 2, 0].item() - 0.4) < 1e-6 + + # Sample 2: subtask 1 with tau 0, 0.5, 1.0 -> y = 0.4, 0.7, 1.0 + assert abs(result[1, 0, 0].item() - 0.4) < 1e-6 + assert abs(result[1, 1, 0].item() - 0.7) < 1e-6 + assert abs(result[1, 2, 0].item() - 1.0) < 1e-6 + + def test_auto_compute_cumulative_prior(self): + """Test that cumulative_prior is auto-computed when not provided.""" + proportions = [0.3, 0.5, 0.2] + alpha = torch.tensor(proportions, dtype=torch.float32) + + tau = torch.tensor([[[0.5]]]) + stages = torch.tensor([[1]]) + + # Without cumulative_prior (should auto-compute) + result = compute_cumulative_progress_batch(tau, stages, alpha) + + # Expected: P_0 + alpha_1 * 0.5 = 0.3 + 0.5 * 0.5 = 0.55 + assert abs(result[0, 0, 0].item() - 0.55) < 1e-6 + + +class TestEndToEndProgressLabeling: + """End-to-end tests for progress label computation.""" + + def test_consistent_semantic_meaning(self): + """Test that same subtask completion maps to same progress across trajectories. + + This is the key semantic property: "end of subtask 1" should always + mean the same progress value regardless of trajectory speed. + """ + proportions = [0.3, 0.5, 0.2] + + # Fast trajectory: subtask 1 ends at frame 30 (of 100) + tau_fast = compute_tau(30, 0, 30) # = 1.0 + y_fast = compute_cumulative_progress_batch(tau_fast, 0, proportions) + + # Slow trajectory: subtask 1 ends at frame 90 (of 300) + tau_slow = compute_tau(90, 0, 90) # = 1.0 + y_slow = compute_cumulative_progress_batch(tau_slow, 0, proportions) + + # Both should map to same progress (0.3 = end of subtask 1) + assert abs(y_fast - y_slow) < 1e-6 + assert abs(y_fast - 0.3) < 1e-6 + + def test_monotonic_within_subtask(self): + """Test that progress is monotonically increasing within a subtask.""" + proportions = [0.4, 0.6] + + prev_y = -1 + for tau in np.linspace(0, 1, 11): + y = compute_cumulative_progress_batch(tau, 0, proportions) + assert y > prev_y or (tau == 0 and y == 0) + prev_y = y + + def test_continuous_across_subtasks(self): + """Test that progress is continuous at subtask boundaries.""" + proportions = [0.3, 0.5, 0.2] + + # End of subtask 0 (tau=1.0) + y_end_0 = compute_cumulative_progress_batch(1.0, 0, proportions) + + # Start of subtask 1 (tau=0.0) + y_start_1 = compute_cumulative_progress_batch(0.0, 1, proportions) + + # Should be equal (P_1 = 0.3) + assert abs(y_end_0 - y_start_1) < 1e-6 + + # End of subtask 1 + y_end_1 = compute_cumulative_progress_batch(1.0, 1, proportions) + + # Start of subtask 2 + y_start_2 = compute_cumulative_progress_batch(0.0, 2, proportions) + + # Should be equal (P_2 = 0.8) + assert abs(y_end_1 - y_start_2) < 1e-6 +