From 73dd4f10f70bf93e730b93782a5a9965d37afe38 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Thu, 27 Nov 2025 17:36:00 +0100 Subject: [PATCH] simplify --- .../dataset_annotation/subtask_annotation.py | 22 +-- .../policies/sarm/configuration_sarm.py | 19 +-- src/lerobot/policies/sarm/modeling_sarm.py | 131 ++++---------- src/lerobot/policies/sarm/processor_sarm.py | 160 +++--------------- src/lerobot/policies/sarm/sarm_utils.py | 1 - 5 files changed, 75 insertions(+), 258 deletions(-) diff --git a/examples/dataset_annotation/subtask_annotation.py b/examples/dataset_annotation/subtask_annotation.py index 4493f91e8..d84f85a6a 100644 --- a/examples/dataset_annotation/subtask_annotation.py +++ b/examples/dataset_annotation/subtask_annotation.py @@ -73,7 +73,7 @@ import json import time from pathlib import Path from typing import Optional -from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed +from concurrent.futures import ProcessPoolExecutor, as_completed import multiprocessing as mp import pandas as pd @@ -82,7 +82,6 @@ from pydantic import BaseModel, Field from qwen_vl_utils import process_vision_info from rich.console import Console from rich.panel import Panel -from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn from rich.tree import Tree from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor @@ -176,7 +175,6 @@ class VideoAnnotator: self.console.print(f"[cyan]Loading model: {model_name}...[/cyan]") - # Load model and processor self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained( model_name, torch_dtype=torch_dtype, @@ -237,18 +235,17 @@ class VideoAnnotator: cmd = [ 'ffmpeg', '-i', str(file_path), - '-ss', str(start_timestamp), # Start time - '-t', str(duration), # Duration - '-r', str(target_fps), # Output FPS - '-c:v', 'libx264', # H.264 codec - '-preset', 'ultrafast', # Faster encoding - '-crf', '23', # Better quality (lower = better) - '-an', # Remove audio - '-y', # Overwrite output file + '-ss', str(start_timestamp), + '-t', str(duration), + '-r', str(target_fps), + '-c:v', 'libx264', + '-preset', 'ultrafast', + '-crf', '23', + '-an', + '-y', str(tmp_path) ] - # Run ffmpeg (suppress output) subprocess.run( cmd, stdout=subprocess.DEVNULL, @@ -309,7 +306,6 @@ class VideoAnnotator: # Calculate episode duration if end_timestamp is None: - # Get video metadata (suppress AV1 warnings) import cv2 import os import sys diff --git a/src/lerobot/policies/sarm/configuration_sarm.py b/src/lerobot/policies/sarm/configuration_sarm.py index 0fa2e0b85..eef461288 100644 --- a/src/lerobot/policies/sarm/configuration_sarm.py +++ b/src/lerobot/policies/sarm/configuration_sarm.py @@ -27,13 +27,13 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig class SARMConfig(PreTrainedConfig): """Configuration class for SARM (Stage-Aware Reward Modeling)""" - # CLIP encoding parameters + # CLIP params image_dim: int = 512 text_dim: int = 512 num_frames: int = 9 # 1 initial + 8 consecutive frames - frame_gap: int = 30 # Frame gap between consecutive frames (at 30 fps = 1 second) + frame_gap: int = 30 # Frame gap between frames (at 30 fps = 1 second) - # Architecture parameters + # Architecture params hidden_dim: int = 768 num_heads: int = 12 num_layers: int = 8 @@ -44,7 +44,7 @@ class SARMConfig(PreTrainedConfig): max_length: int = num_frames # Maximum video sequence length (matches num_frames) use_temporal_sampler: bool = True # Always enable temporal sequence loading - # Training parameters + # Training params batch_size: int = 64 clip_batch_size: int = 64 # Batch size for CLIP encoding dropout: float = 0.1 @@ -80,14 +80,14 @@ class SARMConfig(PreTrainedConfig): def __post_init__(self): super().__post_init__() - # Add the image_key as VISUAL (this is the raw image from dataset) + # Add the image_key as VISUAL if self.image_key: self.input_features[self.image_key] = PolicyFeature( shape=(480, 640, 3), type=FeatureType.VISUAL ) - # Add state_key as STATE (raw state from dataset, will be padded to max_state_dim) + # Add state_key as STATE self.input_features[self.state_key] = PolicyFeature( shape=(self.max_state_dim,), # Single frame state, temporal sampling handles sequence type=FeatureType.STATE @@ -151,16 +151,13 @@ class SARMConfig(PreTrainedConfig): to the episode start (frame 0) by the dataset loader. This ensures we always get the initial frame regardless of the current position in the episode. - Returns: 9 delta indices: [-1_000_000, -(7*gap), -(6*gap), ..., -gap, 0] """ initial_frame_delta = -1_000_000 - # Remaining consecutive frames with frame_gap spacing - num_consecutive = self.num_frames - 1 - consecutive_deltas = list(range(-self.frame_gap * (num_consecutive - 1), 1, self.frame_gap)) - + num_consecutive = self.num_frames - 1 # 9 - 1 = 8 + consecutive_deltas = list(range(-self.frame_gap * (num_consecutive - 1), 1, self.frame_gap)) # [-210, -180, -150, -120, -90, -60, -30, 0] return [initial_frame_delta] + consecutive_deltas @property diff --git a/src/lerobot/policies/sarm/modeling_sarm.py b/src/lerobot/policies/sarm/modeling_sarm.py index d00e2aab1..fba9d2327 100644 --- a/src/lerobot/policies/sarm/modeling_sarm.py +++ b/src/lerobot/policies/sarm/modeling_sarm.py @@ -19,7 +19,6 @@ from typing import List, Union, Optional import random import numpy as np -import pandas as pd import torch import torch.nn as nn import torch.nn.functional as F @@ -31,9 +30,6 @@ from lerobot.policies.sarm.configuration_sarm import SARMConfig from lerobot.policies.sarm.sarm_utils import compute_priors, compute_cumulative_progress_batch, pad_state_to_max_dim from lerobot.policies.pretrained import PreTrainedPolicy - - - class SARMTransformer(nn.Module): """ SARM Transformer model for stage-aware reward prediction. @@ -41,8 +37,6 @@ class SARMTransformer(nn.Module): This model has a dual-head architecture: 1. Stage estimator: Predicts the high-level task stage (classification) 2. Subtask estimator: Predicts fine-grained progress within the stage (regression) - - The subtask estimator is conditioned on the stage prediction. """ def __init__( @@ -104,7 +98,7 @@ class SARMTransformer(nn.Module): nn.Linear(512, num_stages) ) - # Subtask estimator head (regression, conditioned on stage) + # Subtask estimator head (regression) self.stage_embedding = nn.Embedding(num_stages, hidden_dim // 4) subtask_input_dim = hidden_dim + hidden_dim // 4 self.subtask_head = nn.Sequential( @@ -219,13 +213,14 @@ class SARMRewardModel(PreTrainedPolicy): def __init__(self, config: SARMConfig, dataset_stats: dict | None = None, dataset_meta=None): super().__init__(config, dataset_stats) + config.validate_features() self.config = config self.dataset_stats = dataset_stats self.device = torch.device(config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu") - # Detect num_stages from dataset annotations before building the model - if dataset_meta is not None: - self._update_num_stages_from_dataset(dataset_meta) + # Load temporal proportions from dataset + if config.temporal_proportions is None and dataset_meta is not None: + self._load_temporal_proportions(dataset_meta) logging.info("Loading CLIP encoder") self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") @@ -246,78 +241,38 @@ class SARMRewardModel(PreTrainedPolicy): temporal_proportions=config.temporal_proportions ) self.sarm_transformer.to(self.device) - logging.info(f"SARM Reward Model initialized on {self.device}") + logging.info(f"SARM initialized on {self.device}") - def _update_num_stages_from_dataset(self, dataset_meta) -> None: - """Update num_stages and temporal_proportions from dataset subtask annotations. - - Implements SARM Paper Formula (1): + def _load_temporal_proportions(self, dataset_meta) -> None: + """ + Load pre-computed temporal proportions from dataset metadata JSON file. + + The temporal proportions are computed during dataset annotation using SARM Paper Formula (1): ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i) """ - episodes = dataset_meta.episodes - if episodes is None or len(episodes) == 0: - raise ValueError("No episodes found, using default num_stages") - - if 'subtask_names' not in episodes.column_names: - raise ValueError("No subtask annotations found in dataset, using default num_stages") - - episodes_df = episodes.to_pandas() + import json - # Collect subtask durations and trajectory lengths for compute_priors - all_subtask_names = set() - subtask_durations_per_trajectory = {} - trajectory_lengths = {} + proportions_path = dataset_meta.root / "meta" / "temporal_proportions.json" - for ep_idx in episodes_df.index: - subtask_names_ep = episodes_df.loc[ep_idx, 'subtask_names'] - if subtask_names_ep is None or (isinstance(subtask_names_ep, float) and pd.isna(subtask_names_ep)): - continue - - all_subtask_names.update(subtask_names_ep) - - # Compute durations if available - if 'subtask_start_frames' in episodes_df.columns and 'subtask_end_frames' in episodes_df.columns: - start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames'] - end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames'] - - # Compute total trajectory length T_i - total_traj_length = sum(end_frames[i] - start_frames[i] for i in range(len(subtask_names_ep))) - - if total_traj_length <= 0: - continue - - for i, name in enumerate(subtask_names_ep): - duration = end_frames[i] - start_frames[i] - - if name not in subtask_durations_per_trajectory: - subtask_durations_per_trajectory[name] = [] - trajectory_lengths[name] = [] - - subtask_durations_per_trajectory[name].append(duration) - trajectory_lengths[name].append(total_traj_length) + if not proportions_path.exists(): + raise ValueError( + f"Temporal proportions not found at {proportions_path}. " + "Run the subtask annotation tool first to compute and save temporal proportions." + ) - if not all_subtask_names: - raise ValueError("No valid subtask names found, using default num_stages") + with open(proportions_path, "r") as f: + temporal_proportions_dict = json.load(f) # Sort subtask names for consistent ordering - subtask_names = sorted(list(all_subtask_names)) - num_stages = len(subtask_names) + subtask_names = sorted(temporal_proportions_dict.keys()) - # Compute temporal proportions using Paper Formula (1) - temporal_proportions_dict = compute_priors( - subtask_durations_per_trajectory, - trajectory_lengths, - subtask_names - ) - temporal_proportions = [temporal_proportions_dict[name] for name in subtask_names] - - self.config.num_stages = num_stages + self.config.num_stages = len(subtask_names) self.config.subtask_names = subtask_names - self.config.temporal_proportions = temporal_proportions + self.config.temporal_proportions = [temporal_proportions_dict[name] for name in subtask_names] - logging.info(f"Auto-detected {num_stages} subtasks: {subtask_names}") + logging.info(f"Loaded {len(subtask_names)} subtasks: {subtask_names}") logging.info(f"Temporal proportions: {temporal_proportions_dict}") - + def to(self, device): """Override to method to ensure all components move together.""" super().to(device) @@ -503,43 +458,23 @@ class SARMRewardModel(PreTrainedPolicy): return rewards - - def load_pretrained_checkpoint(self, checkpoint_path: str, strict: bool = False): - """Load pretrained model weights from a checkpoint file.""" - logging.info(f"Loading pretrained checkpoint from {checkpoint_path}") - checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False) - - if "model_state_dict" in checkpoint: - state_dict = checkpoint["model_state_dict"] - else: - state_dict = checkpoint - - missing_keys, unexpected_keys = self.sarm_transformer.load_state_dict(state_dict, strict=strict) - - if missing_keys: - logging.warning(f"Missing keys when loading checkpoint: {missing_keys}") - if unexpected_keys: - logging.warning(f"Unexpected keys when loading checkpoint: {unexpected_keys}") - - logging.info("Checkpoint loaded successfully") - def train(self, mode: bool = True): - """Set training mode. Note: CLIP encoder always stays in eval mode (frozen).""" + """Overwrite train method to ensure CLIP encoder stays frozen during training""" super().train(mode) self.clip_model.eval() self.sarm_transformer.train(mode) return self def eval(self): - """Set evaluation mode.""" + """Overwrite eval method to ensure CLIP encoder stays frozen during evaluation""" return self.train(False) def parameters(self): - """Return trainable parameters (only SARM transformer, not encoders).""" + """Override to return trainable parameters (only SARM transformer, not CLIP encoder).""" return self.sarm_transformer.parameters() def get_optim_params(self): - """Return optimizer parameters for the policy.""" + """Override to return optimizer parameters (only SARM transformer, not CLIP encoder).""" return self.parameters() def reset(self): @@ -619,9 +554,7 @@ class SARMRewardModel(PreTrainedPolicy): # Extract required features video_features = observation['video_features'].to(self.device) text_features = observation['text_features'].to(self.device) - state_features = observation.get('state_features') - if state_features is not None: - state_features = state_features.to(self.device) + state_features = observation.get('state_features').to(self.device) batch_size = video_features.shape[0] max_length = self.config.num_frames @@ -643,7 +576,7 @@ class SARMRewardModel(PreTrainedPolicy): if progress_from_annotations.dim() == 3 and progress_from_annotations.shape[0] == 1: progress_from_annotations = progress_from_annotations.expand(batch_size, -1, -1) - # Process each sample: apply temporal augmentation (SARM paper A.4) + # Process each sample: apply temporal REWIND augmentation processed_videos = [] processed_states = [] progress_targets = [] @@ -653,7 +586,7 @@ class SARMRewardModel(PreTrainedPolicy): state = state_features[i] if state_features is not None else None progress = progress_from_annotations[i].squeeze(-1) # (T,) - # Apply temporal augmentation with 50% probability: appends up to 4 reversed frames to simulate failures/recoveries + # Apply temporal REWIND augmentation with 50% probability: appends up to 4 reversed frames to simulate failures/recoveries if random.random() < 0.5: video, progress, state = self._apply_temporal_augmentation(video, progress, state, max_length) diff --git a/src/lerobot/policies/sarm/processor_sarm.py b/src/lerobot/policies/sarm/processor_sarm.py index 45e2f6821..58d20ae8a 100644 --- a/src/lerobot/policies/sarm/processor_sarm.py +++ b/src/lerobot/policies/sarm/processor_sarm.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import logging from typing import Any import numpy as np @@ -23,7 +24,7 @@ import pandas as pd from transformers import CLIPModel, CLIPProcessor from lerobot.policies.sarm.configuration_sarm import SARMConfig -from lerobot.policies.sarm.sarm_utils import compute_priors, compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim +from lerobot.policies.sarm.sarm_utils import compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim from lerobot.processor import ( ProcessorStep, PolicyProcessorPipeline, @@ -57,105 +58,19 @@ class SARMEncodingProcessorStep(ProcessorStep): self.image_key = image_key or config.image_key self.dataset_meta = dataset_meta self.dataset_stats = dataset_stats - self.temporal_proportions = None - self.subtask_names = None - if dataset_meta is not None: - self._compute_temporal_proportions() - - self._init_encoders() - - def _init_encoders(self): - """Initialize CLIP encoder for both images and text (per SARM paper A.4).""" - device = torch.device( + self.temporal_proportions = {name: prop for name, prop in zip(self.config.subtask_names, self.config.temporal_proportions)} + self.subtask_names = self.config.subtask_names + + self.device = torch.device( self.config.device if self.config.device else "cuda" if torch.cuda.is_available() else "cpu" ) - logging.info("Initializing CLIP encoder for SARM (images + text)...") self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True) - self.clip_model.to(device) + self.clip_model.to(self.device) self.clip_model.eval() - - self.device = device - - def _compute_temporal_proportions(self): - """Compute temporal proportions for each subtask from dataset annotations. - - Implements SARM Paper Formula (1): - ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i) - - This averages the proportion of time spent on each subtask within each trajectory, - giving equal weight to all trajectories regardless of absolute length. - """ - if self.dataset_meta is None or not hasattr(self.dataset_meta, 'episodes'): - return - - # Check if subtask annotations exist - episodes = self.dataset_meta.episodes - if episodes is None or len(episodes) == 0: - return - - # Check for subtask_names column - if 'subtask_names' not in episodes.column_names: - logging.info("No subtask annotations found in dataset") - return - - episodes_df = episodes.to_pandas() - - # Collect subtask durations and trajectory lengths for compute_priors - subtask_durations_per_trajectory = {} - trajectory_lengths = {} - all_subtask_names = set() - - for ep_idx in episodes_df.index: - subtask_names_ep = episodes_df.loc[ep_idx, 'subtask_names'] - - # Skip episodes without annotations - if subtask_names_ep is None or (isinstance(subtask_names_ep, float) and pd.isna(subtask_names_ep)): - continue - - start_times = episodes_df.loc[ep_idx, 'subtask_start_times'] - end_times = episodes_df.loc[ep_idx, 'subtask_end_times'] - - # Track unique subtask names - all_subtask_names.update(subtask_names_ep) - - # Compute total trajectory length T_i (sum of all subtask durations) - total_traj_length = sum(end_times[i] - start_times[i] for i in range(len(subtask_names_ep))) - - if total_traj_length <= 0: - continue - - # Store duration and trajectory length for each subtask occurrence - for i, name in enumerate(subtask_names_ep): - duration = end_times[i] - start_times[i] - - if name not in subtask_durations_per_trajectory: - subtask_durations_per_trajectory[name] = [] - trajectory_lengths[name] = [] - - subtask_durations_per_trajectory[name].append(duration) - trajectory_lengths[name].append(total_traj_length) - - if not all_subtask_names: - logging.info("No valid subtask annotations found") - return - - # Sort subtask names for consistent ordering - self.subtask_names = sorted(list(all_subtask_names)) - self.config.num_stages = len(self.subtask_names) - self.config.subtask_names = self.subtask_names - - # Compute temporal proportions using Paper Formula (1) - self.temporal_proportions = compute_priors( - subtask_durations_per_trajectory, - trajectory_lengths, - self.subtask_names - ) - self.config.temporal_proportions = [self.temporal_proportions[name] for name in self.subtask_names] - logging.info(f"Computed temporal proportions for {len(self.subtask_names)} subtasks: {self.temporal_proportions}") - + def _find_episode_for_frame(self, frame_idx: int) -> int: """Find the episode index for a given frame index.""" for ep_idx in range(len(self.dataset_meta.episodes)): @@ -437,10 +352,8 @@ class SARMEncodingProcessorStep(ProcessorStep): else: state_tensor = torch.tensor(state_data, dtype=torch.float32) - # Pad state observation['state_features'] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim) - # Extract complementary data (includes task from dataset) comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) if not isinstance(comp_data, dict): raise ValueError("COMPLEMENTARY_DATA must be a dictionary") @@ -595,7 +508,6 @@ class SARMEncodingProcessorStep(ProcessorStep): self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: """Add encoded features to the observation features.""" - # Add the encoded features (state uses max_state_dim, padded with zeros) features[PipelineFeatureType.OBSERVATION]['video_features'] = PolicyFeature( type=FeatureType.VISUAL, shape=(self.config.num_frames, self.config.image_dim) @@ -622,60 +534,40 @@ def make_sarm_pre_post_processors( """ Create pre-processor and post-processor pipelines for SARM. - Per SARM paper (Appendix A.4): "We employ a frozen clip-vit-base-patch32 encoder - to process both RGB image sequences and task descriptions." - The pre-processing pipeline: 1. Adds batch dimension 2. Normalizes observation.state using NormalizerProcessorStep (MEAN_STD) 3. SARMEncodingProcessorStep: - - Encodes images with CLIP (512-dim) + - Encodes images with CLIP - Pads states to max_state_dim - - Encodes text with CLIP (512-dim) + - Encodes text with CLIP 4. Moves data to device The post-processing pipeline: - 1. Moves data to CPU (no unnormalization - outputs are rewards) - - Args: - config: SARM configuration - dataset_stats: Dataset statistics for normalization - dataset_meta: Dataset metadata for computing episode info - - Returns: - Tuple of (preprocessor, postprocessor) pipelines + 1. Moves data to CPU """ - input_steps = [ - AddBatchDimensionProcessorStep(), - NormalizerProcessorStep( - features={**config.input_features, **config.output_features}, - norm_map=config.normalization_mapping, - stats=dataset_stats, - ), - SARMEncodingProcessorStep( - config=config, - dataset_meta=dataset_meta, - dataset_stats=dataset_stats - ), - DeviceProcessorStep(device=config.device), - ] - - output_steps = [ - DeviceProcessorStep(device="cpu"), - ] - return ( PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( - steps=input_steps, + steps=[ + AddBatchDimensionProcessorStep(), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + SARMEncodingProcessorStep( + config=config, + dataset_meta=dataset_meta, + dataset_stats=dataset_stats + ), + DeviceProcessorStep(device=config.device), + ], name=POLICY_PREPROCESSOR_DEFAULT_NAME, ), PolicyProcessorPipeline[PolicyAction, PolicyAction]( - steps=output_steps, + steps=[DeviceProcessorStep(device="cpu")], name=POLICY_POSTPROCESSOR_DEFAULT_NAME, to_transition=policy_action_to_transition, to_output=transition_to_policy_action, ), ) - - - diff --git a/src/lerobot/policies/sarm/sarm_utils.py b/src/lerobot/policies/sarm/sarm_utils.py index ccfa10361..307480a8c 100644 --- a/src/lerobot/policies/sarm/sarm_utils.py +++ b/src/lerobot/policies/sarm/sarm_utils.py @@ -246,4 +246,3 @@ def pad_state_to_max_dim(state: torch.Tensor, max_state_dim: int) -> torch.Tenso # Pad with zeros on the right padding = (0, max_state_dim - current_dim) # (left, right) for last dim return F.pad(state, padding, mode='constant', value=0) -