diff --git a/scripts/visualize_sarm_predictions.py b/scripts/visualize_sarm_predictions.py
index 20ff651f6..8c9d7ea23 100644
--- a/scripts/visualize_sarm_predictions.py
+++ b/scripts/visualize_sarm_predictions.py
@@ -224,14 +224,14 @@ def run_inference(
     """
     Run SARM inference on video frames and joint states.
 
-    For each frame t, creates a temporal sequence of 9 frames using SARM's pattern:
-    [t-240, t-210, t-180, t-150, t-120, t-90, t-60, t-30, t]
-    This matches the training pattern where frames are loaded with 30-frame gaps
-    relative to the current frame.
+    For each frame t, creates a temporal sequence of 9 frames (per SARM paper Section A.4):
+    - Frame 0: the episode's initial frame
+    - Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame t
+    Pattern: [frame_0, t-(7*gap), t-(6*gap), ..., t-gap, t]
 
     Args:
         model: SARM model
-        frames: Video frames (num_frames, H, W, C)
+        frames: Video frames (num_frames, H, W, C) - all frames from ONE episode
         states: Joint states (num_frames, state_dim)
         task_description: Task description text
         batch_size: Batch size for processing slices
@@ -247,7 +247,12 @@ def run_inference(
     logger.info("Encoding task description with MiniLM...")
     text_embedding = model.encode_text(task_description)
 
-    logger.info("Creating video slices (SARM approach)...")
+    # Get config values
+    num_frames_model = model.config.num_frames  # 9
+    frame_gap = model.config.frame_gap  # 30
+
+    logger.info("Creating video slices (SARM paper: initial frame + 8 consecutive)...")
+
     # Convert to tensors
     video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
     text_embedding = torch.tensor(text_embedding, dtype=torch.float32)
@@ -256,33 +261,14 @@ def run_inference(
     else:
         state_embeddings = None
 
-    # Create video slices: for each frame i, create a sequence using SARM's pattern
-    # For SARM: 9 frames relative to current, with 30-frame gaps
-    # Pattern: [current-240, current-210, ..., current-30, current]
-    num_frames_model = model.config.num_frames
-    frame_gap = model.config.frame_gap
-
     video_slices = []
     state_slices = []
-    last_frame_indices = []
 
-    for i in tqdm(range(len(video_embeddings)), desc="Creating slices"):
-        # For SARM, create sequence relative to current frame (matching training pattern)
-        # Pattern: [current-240, current-210, ..., current-30, current]
-        # This matches observation_delta_indices: range(-240, 1, 30)
-
-        # Compute frame indices for this slice (relative to current frame i)
-        frame_indices = []
-        for j in range(num_frames_model):
-            # Start from -(num_frames_model-1) * frame_gap and go to 0
-            offset = -(num_frames_model - 1 - j) * frame_gap
-            idx = i + offset
-
-            # Clamp to valid range [0, current_frame]
-            if idx < 0:
-                idx = 0  # Pad with first available frame
-
-            frame_indices.append(idx)
+    for current_frame in tqdm(range(len(video_embeddings)), desc="Creating slices"):
+        # Compute frame indices using SARM pattern:
+        # [initial_frame, t-(7*gap), t-(6*gap), ..., t-gap, t]
+        deltas = model.config.observation_delta_indices(current_frame)
+        frame_indices = [max(0, current_frame + delta) for delta in deltas]
 
         # Extract slice
         video_slice = video_embeddings[frame_indices]
@@ -291,9 +277,6 @@ def run_inference(
         if state_embeddings is not None:
             state_slice = state_embeddings[frame_indices]
             state_slices.append(state_slice)
-
-        # Track which frame index corresponds to the "current" frame
-        last_frame_indices.append(min(i, len(frame_indices) - 1))
 
     video_slices = torch.stack(video_slices)  # (num_frames, num_frames_model, 512)
     if state_embeddings is not None:
@@ -320,7 +303,6 @@ def run_inference(
         )
 
         # Extract last frame predictions (the "current" frame)
-        # For SARM, we take the last frame in each sequence
        batch_progress = progress_preds[:, -1, 0].cpu().numpy()
        batch_stages = stage_probs[:, -1, :].cpu().numpy()
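
For quick review, here is a self-contained sketch of the index pattern the rewritten loop produces. `slice_indices` is a hypothetical stand-in for `observation_delta_indices` plus the `max(0, ...)` clamp above, assuming the PR defaults `num_frames=9` and `frame_gap=30`:

```python
# Hypothetical stand-in for observation_delta_indices + clamping in run_inference.
# Assumes num_frames=9 and frame_gap=30 (the defaults used throughout this PR).
def slice_indices(current_frame: int, num_frames: int = 9, frame_gap: int = 30) -> list[int]:
    deltas = [-current_frame] + list(range(-frame_gap * (num_frames - 2), 1, frame_gap))
    return [max(0, current_frame + d) for d in deltas]

print(slice_indices(300))  # [0, 90, 120, 150, 180, 210, 240, 270, 300]
print(slice_indices(10))   # [0, 0, 0, 0, 0, 0, 0, 0, 10] -> early frames clamp to 0
```
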
diff --git a/src/lerobot/policies/sarm/configuration_sarm.py b/src/lerobot/policies/sarm/configuration_sarm.py
index e990fa83a..bccac50b5 100644
--- a/src/lerobot/policies/sarm/configuration_sarm.py
+++ b/src/lerobot/policies/sarm/configuration_sarm.py
@@ -44,6 +44,7 @@ class SARMConfig(PreTrainedConfig):
     num_layers: int = 8
     num_stages: int = 5  # Number of task stages for classification (auto-updated from annotations if available)
     subtask_names: list | None = None  # List of subtask names (auto-populated from annotations)
+    temporal_proportions: list | None = None  # Temporal proportions for each stage (auto-computed from annotations)
 
     # Temporal parameters
     max_length: int = num_frames  # Maximum video sequence length (matches num_frames)
@@ -128,20 +129,31 @@ class SARMConfig(PreTrainedConfig):
         """Validate input and output features."""
         pass
 
-    @property
-    def observation_delta_indices(self) -> list[int]:
-        """Load frames for SARM temporal sampling.
+    def observation_delta_indices(self, episode_frame_index: int) -> list[int]:
+        """Compute delta indices for SARM temporal sampling.
 
-        SARM uses 9 frames: 1 initial frame + 8 consecutive frames with frame_gap spacing.
+        Per SARM paper (Section A.4), the model uses 9 frames:
+        - Frame 0: Initial frame of the episode (delta = -episode_frame_index)
+        - Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame
 
+        The dataloader converts these to seconds: delta_seconds = delta / fps.
+        This means the first delta (-episode_frame_index) becomes -current_time,
+        which correctly points to t=0 (the initial frame).
+
+        Args:
+            episode_frame_index: Current frame index within the episode (0, 1, 2, ...)
+
         Returns:
-            Indices for loading: [-(8*frame_gap), ..., -frame_gap, 0]
+            9 delta indices: [-episode_frame_index, -(7*gap), -(6*gap), ..., -gap, 0]
         """
-        # For SARM: we need the initial frame (from episode start) plus 8 consecutive frames
-        # The dataset will load relative to current frame
-        # We'll handle the "initial frame" logic in the processor
-        # For now, load the last 8*frame_gap frames
-        return list(range(-self.frame_gap * (self.num_frames - 1), 1, self.frame_gap))
+        # First delta: negative of current frame index to reach frame 0
+        initial_frame_delta = -episode_frame_index
+
+        # Remaining 8 deltas: consecutive frames with frame_gap spacing
+        num_consecutive = self.num_frames - 1  # 8 frames
+        consecutive_deltas = list(range(-self.frame_gap * (num_consecutive - 1), 1, self.frame_gap))
+
+        return [initial_frame_delta] + consecutive_deltas
 
     @property
     def action_delta_indices(self) -> None:
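
A sanity check of the new method, hedged: it assumes `SARMConfig` can be constructed with just these two fields, which may not match the real constructor. Note also that turning the `observation_delta_indices` property into a method is API-breaking for any caller that still reads it as an attribute.

```python
# Assumed construction; the real SARMConfig may require more fields.
cfg = SARMConfig(num_frames=9, frame_gap=30)

# Mid-episode: the first delta reaches back to the episode's frame 0.
assert cfg.observation_delta_indices(300) == [-300, -210, -180, -150, -120, -90, -60, -30, 0]

# At frame 0 the initial-frame delta collapses to 0; the remaining negative
# offsets are clamped to the episode start by callers (see run_inference).
assert cfg.observation_delta_indices(0) == [0, -210, -180, -150, -120, -90, -60, -30, 0]
```
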
diff --git a/src/lerobot/policies/sarm/modeling_sarm.py b/src/lerobot/policies/sarm/modeling_sarm.py
index 4cd57dced..296e9188e 100644
--- a/src/lerobot/policies/sarm/modeling_sarm.py
+++ b/src/lerobot/policies/sarm/modeling_sarm.py
@@ -19,6 +19,7 @@ from typing import List, Union, Dict, Optional
 import random
 
 import numpy as np
+import pandas as pd
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -71,13 +72,31 @@ class SARMTransformer(nn.Module):
         num_layers: int = 8,
         num_stages: int = 5,
         max_length: int = 9,
-        dropout: float = 0.1
+        dropout: float = 0.1,
+        temporal_proportions: list[float] | None = None
     ):
         super().__init__()
         self.hidden_dim = hidden_dim
         self.max_length = max_length
         self.num_stages = num_stages
 
+        # Store temporal proportions for progress conversion (Paper Eq. 4):
+        #   ŷ = P_{k-1} + ᾱ_k × τ̂
+        if temporal_proportions is None:
+            raise ValueError(
+                "temporal_proportions is required for SARM. "
+                "Provide subtask annotations in your dataset or set temporal_proportions in config."
+            )
+
+        # ᾱ_k: proportion for each stage
+        alpha = torch.tensor(temporal_proportions, dtype=torch.float32)
+
+        # P_k: cumulative proportion up to stage k (P_0 = 0)
+        cumulative = torch.zeros(num_stages + 1, dtype=torch.float32)
+        cumulative[1:] = torch.cumsum(alpha, dim=0)
+
+        self.register_buffer('alpha', alpha)
+        self.register_buffer('cumulative_prior', cumulative)
+
         # Project video, text, and state to same dimension
         self.video_proj = nn.Linear(video_dim, hidden_dim)
         self.text_proj = nn.Linear(text_dim, hidden_dim)
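
The buffer layout in numbers, as a small sketch with assumed proportions for a hypothetical 3-stage task:

```python
import torch

# Assumed proportions for a hypothetical 3-stage task (sum to 1.0).
alpha = torch.tensor([0.2, 0.5, 0.3])
cumulative = torch.zeros(4)
cumulative[1:] = torch.cumsum(alpha, dim=0)
print(cumulative)  # tensor([0.0000, 0.2000, 0.7000, 1.0000])
# cumulative[k] is the progress mass completed before 0-indexed stage k begins,
# i.e. the P_{k-1} term of Eq. 4.
```
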
@@ -97,24 +116,26 @@ class SARMTransformer(nn.Module):
         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
 
         # Stage estimator head (classification)
+        # Paper A.4: "2 layers with hidden dimension of 512"
         self.stage_head = nn.Sequential(
-            nn.Linear(hidden_dim, hidden_dim // 2),
-            nn.LayerNorm(hidden_dim // 2),
+            nn.Linear(hidden_dim, 512),
+            nn.LayerNorm(512),
             nn.GELU(),
             nn.Dropout(dropout),
-            nn.Linear(hidden_dim // 2, num_stages)
+            nn.Linear(512, num_stages)
         )
 
         # Subtask estimator head (regression, conditioned on stage)
         # Takes concatenated [features, stage_embedding]
+        # Paper A.4: "2 layers with hidden dimension of 512"
         self.stage_embedding = nn.Embedding(num_stages, hidden_dim // 4)
         subtask_input_dim = hidden_dim + hidden_dim // 4
         self.subtask_head = nn.Sequential(
-            nn.Linear(subtask_input_dim, hidden_dim // 2),
-            nn.LayerNorm(hidden_dim // 2),
+            nn.Linear(subtask_input_dim, 512),
+            nn.LayerNorm(512),
             nn.GELU(),
             nn.Dropout(dropout),
-            nn.Linear(hidden_dim // 2, 1),
+            nn.Linear(512, 1),
             nn.Sigmoid()
         )
@@ -189,7 +210,17 @@ class SARMTransformer(nn.Module):
         conditioned_features = torch.cat([frame_features, stage_embeds], dim=-1)
 
         # Subtask progress estimation (conditioned on stage)
-        progress_preds = self.subtask_head(conditioned_features)  # [batch_size, seq_len, 1]
+        # τ̂ = within-subtask progress (0-1)
+        tau_preds = self.subtask_head(conditioned_features)  # [batch_size, seq_len, 1]
+
+        # Convert τ̂ to cumulative progress ŷ using Paper Eq. 4:
+        #   ŷ = P_{k-1} + ᾱ_k × τ̂
+        # P_{k-1} = cumulative prior up to stage k-1
+        # ᾱ_k = temporal proportion of stage k
+        P_k_minus_1 = self.cumulative_prior[stage_indices]  # [batch_size, seq_len]
+        alpha_k = self.alpha[stage_indices]  # [batch_size, seq_len]
+
+        progress_preds = P_k_minus_1.unsqueeze(-1) + alpha_k.unsqueeze(-1) * tau_preds
 
         return stage_logits, stage_probs, progress_preds
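
A shape check for the Eq. 4 conversion, mirroring the gather above with hypothetical sizes; as a scalar example, stage 1 at τ̂ = 0.5 gives ŷ = 0.2 + 0.5 × 0.5 = 0.45 under the assumed proportions:

```python
import torch

B, T, K = 2, 9, 3  # hypothetical batch size, sequence length, stage count
alpha = torch.tensor([0.2, 0.5, 0.3])
cumulative_prior = torch.cat([torch.zeros(1), torch.cumsum(alpha, dim=0)])  # [K+1]

stage_indices = torch.randint(0, K, (B, T))  # argmax of the stage head
tau = torch.rand(B, T, 1)                    # within-stage progress (sigmoid output)

progress = cumulative_prior[stage_indices].unsqueeze(-1) + alpha[stage_indices].unsqueeze(-1) * tau
assert progress.shape == (B, T, 1)
# Each prediction lands in [P_{k-1}, P_k], so stage-k frames can never report
# more cumulative progress than stage-(k+1) frames.
assert (progress >= 0).all() and (progress <= 1).all()
```
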
@@ -263,7 +294,8 @@ class SARMRewardModel(PreTrainedPolicy):
                 "2. Ensure dataset_stats contains 'observation.state' or 'state' key"
             )
 
-        # Initialize SARM transformer
+        # Initialize SARM transformer with temporal proportions for progress conversion
+        temporal_proportions = getattr(config, 'temporal_proportions', None)
         self.sarm_transformer = SARMTransformer(
             video_dim=config.image_dim,
             text_dim=config.text_dim,
@@ -273,7 +305,8 @@ class SARMRewardModel(PreTrainedPolicy):
             num_layers=config.num_layers,
             num_stages=config.num_stages,
             max_length=config.max_length,
-            dropout=config.dropout
+            dropout=config.dropout,
+            temporal_proportions=temporal_proportions
         )
         self.sarm_transformer.to(self.device)
 
@@ -281,7 +314,7 @@ class SARMRewardModel(PreTrainedPolicy):
         logging.info(f"SARM Reward Model initialized on {self.device}")
 
     def _update_num_stages_from_dataset(self, dataset_meta) -> None:
-        """Update num_stages in config based on dataset subtask annotations."""
+        """Update num_stages and temporal_proportions from dataset subtask annotations."""
         episodes = dataset_meta.episodes
         if episodes is None or len(episodes) == 0:
             raise ValueError("No episodes found, using default num_stages")
@@ -291,13 +324,27 @@ class SARMRewardModel(PreTrainedPolicy):
 
         episodes_df = episodes.to_pandas()
 
-        # Collect all unique subtask names
+        # Collect all unique subtask names and compute durations
         all_subtask_names = set()
+        subtask_durations = {}
+
         for ep_idx in episodes_df.index:
             subtask_names = episodes_df.loc[ep_idx, 'subtask_names']
             if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
                 continue
+
             all_subtask_names.update(subtask_names)
+
+            # Compute durations if available
+            if 'subtask_start_frames' in episodes_df.columns and 'subtask_end_frames' in episodes_df.columns:
+                start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames']
+                end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames']
+
+                for i, name in enumerate(subtask_names):
+                    duration = end_frames[i] - start_frames[i]
+                    if name not in subtask_durations:
+                        subtask_durations[name] = []
+                    subtask_durations[name].append(duration)
 
         if not all_subtask_names:
             raise ValueError("No valid subtask names found, using default num_stages")
@@ -305,11 +352,27 @@ class SARMRewardModel(PreTrainedPolicy):
         # Sort subtask names for consistent ordering
         subtask_names = sorted(list(all_subtask_names))
         num_stages = len(subtask_names)
+
+        # Compute temporal proportions (Paper Eq. 1: ᾱ_k)
+        avg_durations = {}
+        for name in subtask_names:
+            if name in subtask_durations and subtask_durations[name]:
+                avg_durations[name] = np.mean(subtask_durations[name])
+            else:
+                avg_durations[name] = 1.0  # Default
+
+        total_duration = sum(avg_durations.values())
+        if total_duration > 0:
+            temporal_proportions = [avg_durations[name] / total_duration for name in subtask_names]
+        else:
+            temporal_proportions = [1.0 / num_stages] * num_stages
 
         self.config.num_stages = num_stages
         self.config.subtask_names = subtask_names
+        self.config.temporal_proportions = temporal_proportions
 
-        logging.info(f"Auto-detected {num_stages} subtasks from dataset: {subtask_names}, using {num_stages} stages")
+        logging.info(f"Auto-detected {num_stages} subtasks: {subtask_names}")
+        logging.info(f"Temporal proportions: {dict(zip(subtask_names, temporal_proportions))}")
 
     def to(self, device):
         """Override to method to ensure all components move together."""
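
Worked example of the proportion computation above, with hypothetical per-episode subtask durations (in frames):

```python
import numpy as np

# Hypothetical durations collected from two annotated episodes.
subtask_durations = {"grasp": [40, 60], "lift": [90, 110], "place": [45, 55]}
subtask_names = sorted(subtask_durations)  # alphabetical, matching the code above

avg_durations = {name: float(np.mean(subtask_durations[name])) for name in subtask_names}
total_duration = sum(avg_durations.values())
temporal_proportions = [avg_durations[name] / total_duration for name in subtask_names]
print(dict(zip(subtask_names, temporal_proportions)))
# {'grasp': 0.25, 'lift': 0.5, 'place': 0.25}
```

One thing to double-check: `sorted()` orders stages alphabetically, while Eq. 4's cumulative prior assumes stage k temporally follows stage k-1, so the annotation names need to sort in execution order.
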
@@ -357,7 +420,7 @@ class SARMRewardModel(PreTrainedPolicy):
         # Batch process frames with CLIP
         for i in range(0, len(frames), self.config.clip_batch_size):
             batch = frames[i:i + self.config.clip_batch_size]
-            inputs = self.clip_processor(images=batch, return_tensors="pt", padding=True)
+            inputs = self.clip_processor(images=batch, return_tensors="pt")
             inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
             # Get image embeddings from CLIP
@@ -578,8 +641,8 @@ class SARMRewardModel(PreTrainedPolicy):
         state: torch.Tensor | None,
         max_length: int,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
-        """Apply rewind augmentation: append 2-4 reversed frames (SARM paper)."""
-        num_reverse = random.randint(2, min(4, max_length - 1))
+        """Apply rewind augmentation: append up to 4 reversed frames (SARM paper A.4)."""
+        num_reverse = random.randint(1, min(4, max_length - 1))
 
         # Reverse and take frames (skip first which is last of original)
         reversed_video = video.flip(0)[1:num_reverse + 1]
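
To illustrate the adjusted rewind augmentation, a deterministic sketch with stand-in 1-D "frames"; in training `num_reverse` is drawn via `random.randint(1, min(4, max_length - 1))`, and the surrounding method's truncation is not shown:

```python
import torch

video = torch.arange(9, dtype=torch.float32)  # stand-in embeddings for frames 0..8
num_reverse = 2  # drawn randomly in training, fixed here for clarity

# flip(0) reverses time; [1:] drops the duplicate of the last original frame.
reversed_video = video.flip(0)[1:num_reverse + 1]
print(torch.cat([video, reversed_video]))
# tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 7., 6.]) -> progress rises, then rewinds
```
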
diff --git a/src/lerobot/policies/sarm/processor_sarm.py b/src/lerobot/policies/sarm/processor_sarm.py
index d5bb1f3fc..59ddf7292 100644
--- a/src/lerobot/policies/sarm/processor_sarm.py
+++ b/src/lerobot/policies/sarm/processor_sarm.py
@@ -208,15 +208,29 @@ class SARMEncodingProcessorStep(ProcessorStep):
         return episode_indices
 
     def _compute_absolute_indices(self, frame_idx: int, ep_start: int, num_frames: int) -> torch.Tensor:
-        """Compute absolute frame indices for a sequence."""
+        """Compute absolute frame indices for a sequence.
+
+        Uses the SARM temporal pattern (per SARM paper Section A.4):
+        - Frame 0: Initial frame of the episode (ep_start)
+        - Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame
+        Pattern: [ep_start, t-(7*gap), t-(6*gap), ..., t-gap, t]
+        """
         frame_gap = getattr(self.config, 'frame_gap', 1)
 
-        if frame_gap > 1:
-            indices = [max(ep_start, frame_idx - (num_frames - 1 - i) * frame_gap) for i in range(num_frames)]
-            return torch.tensor(indices)
-        else:
-            start_idx = max(ep_start, frame_idx - num_frames + 1)
-            return torch.arange(start_idx, frame_idx + 1)
+        indices = []
+
+        # First frame is the episode's initial frame
+        indices.append(ep_start)
+
+        # Remaining frames are consecutive with frame_gap spacing
+        num_consecutive = num_frames - 1
+        for i in range(num_consecutive):
+            offset = -(num_consecutive - 1 - i) * frame_gap
+            idx = max(ep_start, frame_idx + offset)
+            indices.append(idx)
+
+        return torch.tensor(indices)
 
     def _compute_episode_metadata(
         self,
@@ -324,6 +340,10 @@ class SARMEncodingProcessorStep(ProcessorStep):
     ) -> tuple[torch.Tensor, torch.Tensor] | tuple[None, None]:
         """Compute stage labels and progress targets for a single sample.
 
+        Labels follow the SARM temporal pattern (per SARM paper Section A.4):
+        - Frame 0: Initial frame of episode (stage at frame 0, progress at frame 0)
+        - Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame
+
         Args:
             frame_idx: The frame index for this sample
             ep_idx: The episode index
@@ -348,7 +368,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
         # Get episode boundaries
         ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
 
-        # Get frame gap for temporal sampling
+        # Get config values
         frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1
 
         # Generate labels for each frame in the sequence
@@ -356,12 +376,15 @@ class SARMEncodingProcessorStep(ProcessorStep):
         stage_labels = []
         progress_targets = []
 
         for i in range(seq_len):
-            # Calculate actual frame index for this position in sequence
-            if frame_gap > 1:
-                offset = -(seq_len - 1 - i) * frame_gap
-                current_frame = max(0, frame_idx + offset - ep_start)
+            if i == 0:
+                # Position 0: Initial frame of the episode
+                current_frame = 0  # Relative to episode start
             else:
-                current_frame = max(0, frame_idx - seq_len + 1 + i - ep_start)
+                # Positions 1-8: consecutive frames with frame_gap spacing
+                num_consecutive = seq_len - 1
+                offset = -(num_consecutive - i) * frame_gap
+                current_frame = max(0, frame_idx + offset - ep_start)
+
             stage_idx, cumulative_progress = self._compute_stage_and_progress_for_frame(
                 current_frame, subtask_names, subtask_start_frames, subtask_end_frames
@@ -564,7 +587,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
             batch_imgs = images_list[i:i + self.config.clip_batch_size]
 
             # Process with CLIP
-            inputs = self.clip_processor(images=batch_imgs, return_tensors="pt", padding=True)
+            inputs = self.clip_processor(images=batch_imgs, return_tensors="pt")
             inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
             # Get image embeddings
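
Finally, a worked example of the label positions the rewritten loop produces (hypothetical sample: `seq_len=9`, `frame_gap=30`, `frame_idx=300`, `ep_start=0`); these line up with `_compute_absolute_indices`, so each label indexes the same frame that was embedded:

```python
seq_len, frame_gap = 9, 30    # assumed PR defaults
frame_idx, ep_start = 300, 0  # hypothetical sample at frame 300 of episode 0

positions = []
for i in range(seq_len):
    if i == 0:
        positions.append(0)  # initial frame of the episode
    else:
        num_consecutive = seq_len - 1
        positions.append(max(0, frame_idx - (num_consecutive - i) * frame_gap - ep_start))
print(positions)  # [0, 90, 120, 150, 180, 210, 240, 270, 300]
```
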