This commit is contained in:
Pepijn
2025-11-27 17:36:00 +01:00
parent 2889c0650a
commit 73dd4f10f7
5 changed files with 75 additions and 258 deletions
@@ -73,7 +73,7 @@ import json
import time
from pathlib import Path
from typing import Optional
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp
import pandas as pd
@@ -82,7 +82,6 @@ from pydantic import BaseModel, Field
from qwen_vl_utils import process_vision_info
from rich.console import Console
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn
from rich.tree import Tree
from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor
@@ -176,7 +175,6 @@ class VideoAnnotator:
self.console.print(f"[cyan]Loading model: {model_name}...[/cyan]")
# Load model and processor
self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
model_name,
torch_dtype=torch_dtype,
@@ -237,18 +235,17 @@ class VideoAnnotator:
cmd = [
'ffmpeg',
'-i', str(file_path),
'-ss', str(start_timestamp), # Start time
'-t', str(duration), # Duration
'-r', str(target_fps), # Output FPS
'-c:v', 'libx264', # H.264 codec
'-preset', 'ultrafast', # Faster encoding
'-crf', '23', # Better quality (lower = better)
'-an', # Remove audio
'-y', # Overwrite output file
'-ss', str(start_timestamp),
'-t', str(duration),
'-r', str(target_fps),
'-c:v', 'libx264',
'-preset', 'ultrafast',
'-crf', '23',
'-an',
'-y',
str(tmp_path)
]
# Run ffmpeg (suppress output)
subprocess.run(
cmd,
stdout=subprocess.DEVNULL,
@@ -309,7 +306,6 @@ class VideoAnnotator:
# Calculate episode duration
if end_timestamp is None:
# Get video metadata (suppress AV1 warnings)
import cv2
import os
import sys
@@ -27,13 +27,13 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
class SARMConfig(PreTrainedConfig):
"""Configuration class for SARM (Stage-Aware Reward Modeling)"""
# CLIP encoding parameters
# CLIP params
image_dim: int = 512
text_dim: int = 512
num_frames: int = 9 # 1 initial + 8 consecutive frames
frame_gap: int = 30 # Frame gap between consecutive frames (at 30 fps = 1 second)
frame_gap: int = 30 # Frame gap between frames (at 30 fps = 1 second)
# Architecture parameters
# Architecture params
hidden_dim: int = 768
num_heads: int = 12
num_layers: int = 8
@@ -44,7 +44,7 @@ class SARMConfig(PreTrainedConfig):
max_length: int = num_frames # Maximum video sequence length (matches num_frames)
use_temporal_sampler: bool = True # Always enable temporal sequence loading
# Training parameters
# Training params
batch_size: int = 64
clip_batch_size: int = 64 # Batch size for CLIP encoding
dropout: float = 0.1
@@ -80,14 +80,14 @@ class SARMConfig(PreTrainedConfig):
def __post_init__(self):
super().__post_init__()
# Add the image_key as VISUAL (this is the raw image from dataset)
# Add the image_key as VISUAL
if self.image_key:
self.input_features[self.image_key] = PolicyFeature(
shape=(480, 640, 3),
type=FeatureType.VISUAL
)
# Add state_key as STATE (raw state from dataset, will be padded to max_state_dim)
# Add state_key as STATE
self.input_features[self.state_key] = PolicyFeature(
shape=(self.max_state_dim,), # Single frame state, temporal sampling handles sequence
type=FeatureType.STATE
@@ -151,16 +151,13 @@ class SARMConfig(PreTrainedConfig):
to the episode start (frame 0) by the dataset loader. This ensures we always
get the initial frame regardless of the current position in the episode.
Returns:
9 delta indices: [-1_000_000, -(7*gap), -(6*gap), ..., -gap, 0]
"""
initial_frame_delta = -1_000_000
# Remaining consecutive frames with frame_gap spacing
num_consecutive = self.num_frames - 1
consecutive_deltas = list(range(-self.frame_gap * (num_consecutive - 1), 1, self.frame_gap))
num_consecutive = self.num_frames - 1 # 9 - 1 = 8
consecutive_deltas = list(range(-self.frame_gap * (num_consecutive - 1), 1, self.frame_gap)) # [-210, -180, -150, -120, -90, -60, -30, 0]
return [initial_frame_delta] + consecutive_deltas
@property
+32 -99
View File
@@ -19,7 +19,6 @@ from typing import List, Union, Optional
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -31,9 +30,6 @@ from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.sarm_utils import compute_priors, compute_cumulative_progress_batch, pad_state_to_max_dim
from lerobot.policies.pretrained import PreTrainedPolicy
class SARMTransformer(nn.Module):
"""
SARM Transformer model for stage-aware reward prediction.
@@ -41,8 +37,6 @@ class SARMTransformer(nn.Module):
This model has a dual-head architecture:
1. Stage estimator: Predicts the high-level task stage (classification)
2. Subtask estimator: Predicts fine-grained progress within the stage (regression)
The subtask estimator is conditioned on the stage prediction.
"""
def __init__(
@@ -104,7 +98,7 @@ class SARMTransformer(nn.Module):
nn.Linear(512, num_stages)
)
# Subtask estimator head (regression, conditioned on stage)
# Subtask estimator head (regression)
self.stage_embedding = nn.Embedding(num_stages, hidden_dim // 4)
subtask_input_dim = hidden_dim + hidden_dim // 4
self.subtask_head = nn.Sequential(
@@ -219,13 +213,14 @@ class SARMRewardModel(PreTrainedPolicy):
def __init__(self, config: SARMConfig, dataset_stats: dict | None = None, dataset_meta=None):
super().__init__(config, dataset_stats)
config.validate_features()
self.config = config
self.dataset_stats = dataset_stats
self.device = torch.device(config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu")
# Detect num_stages from dataset annotations before building the model
if dataset_meta is not None:
self._update_num_stages_from_dataset(dataset_meta)
# Load temporal proportions from dataset
if config.temporal_proportions is None and dataset_meta is not None:
self._load_temporal_proportions(dataset_meta)
logging.info("Loading CLIP encoder")
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
@@ -246,78 +241,38 @@ class SARMRewardModel(PreTrainedPolicy):
temporal_proportions=config.temporal_proportions
)
self.sarm_transformer.to(self.device)
logging.info(f"SARM Reward Model initialized on {self.device}")
logging.info(f"SARM initialized on {self.device}")
def _update_num_stages_from_dataset(self, dataset_meta) -> None:
"""Update num_stages and temporal_proportions from dataset subtask annotations.
Implements SARM Paper Formula (1):
def _load_temporal_proportions(self, dataset_meta) -> None:
"""
Load pre-computed temporal proportions from dataset metadata JSON file.
The temporal proportions are computed during dataset annotation using SARM Paper Formula (1):
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
"""
episodes = dataset_meta.episodes
if episodes is None or len(episodes) == 0:
raise ValueError("No episodes found, using default num_stages")
if 'subtask_names' not in episodes.column_names:
raise ValueError("No subtask annotations found in dataset, using default num_stages")
episodes_df = episodes.to_pandas()
import json
# Collect subtask durations and trajectory lengths for compute_priors
all_subtask_names = set()
subtask_durations_per_trajectory = {}
trajectory_lengths = {}
proportions_path = dataset_meta.root / "meta" / "temporal_proportions.json"
for ep_idx in episodes_df.index:
subtask_names_ep = episodes_df.loc[ep_idx, 'subtask_names']
if subtask_names_ep is None or (isinstance(subtask_names_ep, float) and pd.isna(subtask_names_ep)):
continue
all_subtask_names.update(subtask_names_ep)
# Compute durations if available
if 'subtask_start_frames' in episodes_df.columns and 'subtask_end_frames' in episodes_df.columns:
start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames']
end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames']
# Compute total trajectory length T_i
total_traj_length = sum(end_frames[i] - start_frames[i] for i in range(len(subtask_names_ep)))
if total_traj_length <= 0:
continue
for i, name in enumerate(subtask_names_ep):
duration = end_frames[i] - start_frames[i]
if name not in subtask_durations_per_trajectory:
subtask_durations_per_trajectory[name] = []
trajectory_lengths[name] = []
subtask_durations_per_trajectory[name].append(duration)
trajectory_lengths[name].append(total_traj_length)
if not proportions_path.exists():
raise ValueError(
f"Temporal proportions not found at {proportions_path}. "
"Run the subtask annotation tool first to compute and save temporal proportions."
)
if not all_subtask_names:
raise ValueError("No valid subtask names found, using default num_stages")
with open(proportions_path, "r") as f:
temporal_proportions_dict = json.load(f)
# Sort subtask names for consistent ordering
subtask_names = sorted(list(all_subtask_names))
num_stages = len(subtask_names)
subtask_names = sorted(temporal_proportions_dict.keys())
# Compute temporal proportions using Paper Formula (1)
temporal_proportions_dict = compute_priors(
subtask_durations_per_trajectory,
trajectory_lengths,
subtask_names
)
temporal_proportions = [temporal_proportions_dict[name] for name in subtask_names]
self.config.num_stages = num_stages
self.config.num_stages = len(subtask_names)
self.config.subtask_names = subtask_names
self.config.temporal_proportions = temporal_proportions
self.config.temporal_proportions = [temporal_proportions_dict[name] for name in subtask_names]
logging.info(f"Auto-detected {num_stages} subtasks: {subtask_names}")
logging.info(f"Loaded {len(subtask_names)} subtasks: {subtask_names}")
logging.info(f"Temporal proportions: {temporal_proportions_dict}")
def to(self, device):
"""Override to method to ensure all components move together."""
super().to(device)
@@ -503,43 +458,23 @@ class SARMRewardModel(PreTrainedPolicy):
return rewards
def load_pretrained_checkpoint(self, checkpoint_path: str, strict: bool = False):
"""Load pretrained model weights from a checkpoint file."""
logging.info(f"Loading pretrained checkpoint from {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
if "model_state_dict" in checkpoint:
state_dict = checkpoint["model_state_dict"]
else:
state_dict = checkpoint
missing_keys, unexpected_keys = self.sarm_transformer.load_state_dict(state_dict, strict=strict)
if missing_keys:
logging.warning(f"Missing keys when loading checkpoint: {missing_keys}")
if unexpected_keys:
logging.warning(f"Unexpected keys when loading checkpoint: {unexpected_keys}")
logging.info("Checkpoint loaded successfully")
def train(self, mode: bool = True):
"""Set training mode. Note: CLIP encoder always stays in eval mode (frozen)."""
"""Overwrite train method to ensure CLIP encoder stays frozen during training"""
super().train(mode)
self.clip_model.eval()
self.sarm_transformer.train(mode)
return self
def eval(self):
"""Set evaluation mode."""
"""Overwrite eval method to ensure CLIP encoder stays frozen during evaluation"""
return self.train(False)
def parameters(self):
"""Return trainable parameters (only SARM transformer, not encoders)."""
"""Override to return trainable parameters (only SARM transformer, not CLIP encoder)."""
return self.sarm_transformer.parameters()
def get_optim_params(self):
"""Return optimizer parameters for the policy."""
"""Override to return optimizer parameters (only SARM transformer, not CLIP encoder)."""
return self.parameters()
def reset(self):
@@ -619,9 +554,7 @@ class SARMRewardModel(PreTrainedPolicy):
# Extract required features
video_features = observation['video_features'].to(self.device)
text_features = observation['text_features'].to(self.device)
state_features = observation.get('state_features')
if state_features is not None:
state_features = state_features.to(self.device)
state_features = observation.get('state_features').to(self.device)
batch_size = video_features.shape[0]
max_length = self.config.num_frames
@@ -643,7 +576,7 @@ class SARMRewardModel(PreTrainedPolicy):
if progress_from_annotations.dim() == 3 and progress_from_annotations.shape[0] == 1:
progress_from_annotations = progress_from_annotations.expand(batch_size, -1, -1)
# Process each sample: apply temporal augmentation (SARM paper A.4)
# Process each sample: apply temporal REWIND augmentation
processed_videos = []
processed_states = []
progress_targets = []
@@ -653,7 +586,7 @@ class SARMRewardModel(PreTrainedPolicy):
state = state_features[i] if state_features is not None else None
progress = progress_from_annotations[i].squeeze(-1) # (T,)
# Apply temporal augmentation with 50% probability: appends up to 4 reversed frames to simulate failures/recoveries
# Apply temporal REWIND augmentation with 50% probability: appends up to 4 reversed frames to simulate failures/recoveries
if random.random() < 0.5:
video, progress, state = self._apply_temporal_augmentation(video, progress, state, max_length)
+26 -134
View File
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from typing import Any
import numpy as np
@@ -23,7 +24,7 @@ import pandas as pd
from transformers import CLIPModel, CLIPProcessor
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.sarm_utils import compute_priors, compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim
from lerobot.policies.sarm.sarm_utils import compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim
from lerobot.processor import (
ProcessorStep,
PolicyProcessorPipeline,
@@ -57,105 +58,19 @@ class SARMEncodingProcessorStep(ProcessorStep):
self.image_key = image_key or config.image_key
self.dataset_meta = dataset_meta
self.dataset_stats = dataset_stats
self.temporal_proportions = None
self.subtask_names = None
if dataset_meta is not None:
self._compute_temporal_proportions()
self._init_encoders()
def _init_encoders(self):
"""Initialize CLIP encoder for both images and text (per SARM paper A.4)."""
device = torch.device(
self.temporal_proportions = {name: prop for name, prop in zip(self.config.subtask_names, self.config.temporal_proportions)}
self.subtask_names = self.config.subtask_names
self.device = torch.device(
self.config.device if self.config.device
else "cuda" if torch.cuda.is_available() else "cpu"
)
logging.info("Initializing CLIP encoder for SARM (images + text)...")
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
self.clip_model.to(device)
self.clip_model.to(self.device)
self.clip_model.eval()
self.device = device
def _compute_temporal_proportions(self):
"""Compute temporal proportions for each subtask from dataset annotations.
Implements SARM Paper Formula (1):
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
This averages the proportion of time spent on each subtask within each trajectory,
giving equal weight to all trajectories regardless of absolute length.
"""
if self.dataset_meta is None or not hasattr(self.dataset_meta, 'episodes'):
return
# Check if subtask annotations exist
episodes = self.dataset_meta.episodes
if episodes is None or len(episodes) == 0:
return
# Check for subtask_names column
if 'subtask_names' not in episodes.column_names:
logging.info("No subtask annotations found in dataset")
return
episodes_df = episodes.to_pandas()
# Collect subtask durations and trajectory lengths for compute_priors
subtask_durations_per_trajectory = {}
trajectory_lengths = {}
all_subtask_names = set()
for ep_idx in episodes_df.index:
subtask_names_ep = episodes_df.loc[ep_idx, 'subtask_names']
# Skip episodes without annotations
if subtask_names_ep is None or (isinstance(subtask_names_ep, float) and pd.isna(subtask_names_ep)):
continue
start_times = episodes_df.loc[ep_idx, 'subtask_start_times']
end_times = episodes_df.loc[ep_idx, 'subtask_end_times']
# Track unique subtask names
all_subtask_names.update(subtask_names_ep)
# Compute total trajectory length T_i (sum of all subtask durations)
total_traj_length = sum(end_times[i] - start_times[i] for i in range(len(subtask_names_ep)))
if total_traj_length <= 0:
continue
# Store duration and trajectory length for each subtask occurrence
for i, name in enumerate(subtask_names_ep):
duration = end_times[i] - start_times[i]
if name not in subtask_durations_per_trajectory:
subtask_durations_per_trajectory[name] = []
trajectory_lengths[name] = []
subtask_durations_per_trajectory[name].append(duration)
trajectory_lengths[name].append(total_traj_length)
if not all_subtask_names:
logging.info("No valid subtask annotations found")
return
# Sort subtask names for consistent ordering
self.subtask_names = sorted(list(all_subtask_names))
self.config.num_stages = len(self.subtask_names)
self.config.subtask_names = self.subtask_names
# Compute temporal proportions using Paper Formula (1)
self.temporal_proportions = compute_priors(
subtask_durations_per_trajectory,
trajectory_lengths,
self.subtask_names
)
self.config.temporal_proportions = [self.temporal_proportions[name] for name in self.subtask_names]
logging.info(f"Computed temporal proportions for {len(self.subtask_names)} subtasks: {self.temporal_proportions}")
def _find_episode_for_frame(self, frame_idx: int) -> int:
"""Find the episode index for a given frame index."""
for ep_idx in range(len(self.dataset_meta.episodes)):
@@ -437,10 +352,8 @@ class SARMEncodingProcessorStep(ProcessorStep):
else:
state_tensor = torch.tensor(state_data, dtype=torch.float32)
# Pad state
observation['state_features'] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)
# Extract complementary data (includes task from dataset)
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if not isinstance(comp_data, dict):
raise ValueError("COMPLEMENTARY_DATA must be a dictionary")
@@ -595,7 +508,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Add encoded features to the observation features."""
# Add the encoded features (state uses max_state_dim, padded with zeros)
features[PipelineFeatureType.OBSERVATION]['video_features'] = PolicyFeature(
type=FeatureType.VISUAL,
shape=(self.config.num_frames, self.config.image_dim)
@@ -622,60 +534,40 @@ def make_sarm_pre_post_processors(
"""
Create pre-processor and post-processor pipelines for SARM.
Per SARM paper (Appendix A.4): "We employ a frozen clip-vit-base-patch32 encoder
to process both RGB image sequences and task descriptions."
The pre-processing pipeline:
1. Adds batch dimension
2. Normalizes observation.state using NormalizerProcessorStep (MEAN_STD)
3. SARMEncodingProcessorStep:
- Encodes images with CLIP (512-dim)
- Encodes images with CLIP
- Pads states to max_state_dim
- Encodes text with CLIP (512-dim)
- Encodes text with CLIP
4. Moves data to device
The post-processing pipeline:
1. Moves data to CPU (no unnormalization - outputs are rewards)
Args:
config: SARM configuration
dataset_stats: Dataset statistics for normalization
dataset_meta: Dataset metadata for computing episode info
Returns:
Tuple of (preprocessor, postprocessor) pipelines
1. Moves data to CPU
"""
input_steps = [
AddBatchDimensionProcessorStep(),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
SARMEncodingProcessorStep(
config=config,
dataset_meta=dataset_meta,
dataset_stats=dataset_stats
),
DeviceProcessorStep(device=config.device),
]
output_steps = [
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
steps=[
AddBatchDimensionProcessorStep(),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
SARMEncodingProcessorStep(
config=config,
dataset_meta=dataset_meta,
dataset_stats=dataset_stats
),
DeviceProcessorStep(device=config.device),
],
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
steps=[DeviceProcessorStep(device="cpu")],
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
-1
View File
@@ -246,4 +246,3 @@ def pad_state_to_max_dim(state: torch.Tensor, max_state_dim: int) -> torch.Tenso
# Pad with zeros on the right
padding = (0, max_state_dim - current_dim) # (left, right) for last dim
return F.pad(state, padding, mode='constant', value=0)