This commit is contained in:
Pepijn
2025-11-27 17:36:00 +01:00
parent 2889c0650a
commit 73dd4f10f7
5 changed files with 75 additions and 258 deletions
@@ -73,7 +73,7 @@ import json
import time import time
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp import multiprocessing as mp
import pandas as pd import pandas as pd
@@ -82,7 +82,6 @@ from pydantic import BaseModel, Field
from qwen_vl_utils import process_vision_info from qwen_vl_utils import process_vision_info
from rich.console import Console from rich.console import Console
from rich.panel import Panel from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn
from rich.tree import Tree from rich.tree import Tree
from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor
@@ -176,7 +175,6 @@ class VideoAnnotator:
self.console.print(f"[cyan]Loading model: {model_name}...[/cyan]") self.console.print(f"[cyan]Loading model: {model_name}...[/cyan]")
# Load model and processor
self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained( self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
model_name, model_name,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
@@ -237,18 +235,17 @@ class VideoAnnotator:
cmd = [ cmd = [
'ffmpeg', 'ffmpeg',
'-i', str(file_path), '-i', str(file_path),
'-ss', str(start_timestamp), # Start time '-ss', str(start_timestamp),
'-t', str(duration), # Duration '-t', str(duration),
'-r', str(target_fps), # Output FPS '-r', str(target_fps),
'-c:v', 'libx264', # H.264 codec '-c:v', 'libx264',
'-preset', 'ultrafast', # Faster encoding '-preset', 'ultrafast',
'-crf', '23', # Better quality (lower = better) '-crf', '23',
'-an', # Remove audio '-an',
'-y', # Overwrite output file '-y',
str(tmp_path) str(tmp_path)
] ]
# Run ffmpeg (suppress output)
subprocess.run( subprocess.run(
cmd, cmd,
stdout=subprocess.DEVNULL, stdout=subprocess.DEVNULL,
@@ -309,7 +306,6 @@ class VideoAnnotator:
# Calculate episode duration # Calculate episode duration
if end_timestamp is None: if end_timestamp is None:
# Get video metadata (suppress AV1 warnings)
import cv2 import cv2
import os import os
import sys import sys
@@ -27,13 +27,13 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
class SARMConfig(PreTrainedConfig): class SARMConfig(PreTrainedConfig):
"""Configuration class for SARM (Stage-Aware Reward Modeling)""" """Configuration class for SARM (Stage-Aware Reward Modeling)"""
# CLIP encoding parameters # CLIP params
image_dim: int = 512 image_dim: int = 512
text_dim: int = 512 text_dim: int = 512
num_frames: int = 9 # 1 initial + 8 consecutive frames num_frames: int = 9 # 1 initial + 8 consecutive frames
frame_gap: int = 30 # Frame gap between consecutive frames (at 30 fps = 1 second) frame_gap: int = 30 # Frame gap between frames (at 30 fps = 1 second)
# Architecture parameters # Architecture params
hidden_dim: int = 768 hidden_dim: int = 768
num_heads: int = 12 num_heads: int = 12
num_layers: int = 8 num_layers: int = 8
@@ -44,7 +44,7 @@ class SARMConfig(PreTrainedConfig):
max_length: int = num_frames # Maximum video sequence length (matches num_frames) max_length: int = num_frames # Maximum video sequence length (matches num_frames)
use_temporal_sampler: bool = True # Always enable temporal sequence loading use_temporal_sampler: bool = True # Always enable temporal sequence loading
# Training parameters # Training params
batch_size: int = 64 batch_size: int = 64
clip_batch_size: int = 64 # Batch size for CLIP encoding clip_batch_size: int = 64 # Batch size for CLIP encoding
dropout: float = 0.1 dropout: float = 0.1
@@ -80,14 +80,14 @@ class SARMConfig(PreTrainedConfig):
def __post_init__(self): def __post_init__(self):
super().__post_init__() super().__post_init__()
# Add the image_key as VISUAL (this is the raw image from dataset) # Add the image_key as VISUAL
if self.image_key: if self.image_key:
self.input_features[self.image_key] = PolicyFeature( self.input_features[self.image_key] = PolicyFeature(
shape=(480, 640, 3), shape=(480, 640, 3),
type=FeatureType.VISUAL type=FeatureType.VISUAL
) )
# Add state_key as STATE (raw state from dataset, will be padded to max_state_dim) # Add state_key as STATE
self.input_features[self.state_key] = PolicyFeature( self.input_features[self.state_key] = PolicyFeature(
shape=(self.max_state_dim,), # Single frame state, temporal sampling handles sequence shape=(self.max_state_dim,), # Single frame state, temporal sampling handles sequence
type=FeatureType.STATE type=FeatureType.STATE
@@ -151,16 +151,13 @@ class SARMConfig(PreTrainedConfig):
to the episode start (frame 0) by the dataset loader. This ensures we always to the episode start (frame 0) by the dataset loader. This ensures we always
get the initial frame regardless of the current position in the episode. get the initial frame regardless of the current position in the episode.
Returns: Returns:
9 delta indices: [-1_000_000, -(7*gap), -(6*gap), ..., -gap, 0] 9 delta indices: [-1_000_000, -(7*gap), -(6*gap), ..., -gap, 0]
""" """
initial_frame_delta = -1_000_000 initial_frame_delta = -1_000_000
# Remaining consecutive frames with frame_gap spacing num_consecutive = self.num_frames - 1 # 9 - 1 = 8
num_consecutive = self.num_frames - 1 consecutive_deltas = list(range(-self.frame_gap * (num_consecutive - 1), 1, self.frame_gap)) # [-210, -180, -150, -120, -90, -60, -30, 0]
consecutive_deltas = list(range(-self.frame_gap * (num_consecutive - 1), 1, self.frame_gap))
return [initial_frame_delta] + consecutive_deltas return [initial_frame_delta] + consecutive_deltas
@property @property
+30 -97
View File
@@ -19,7 +19,6 @@ from typing import List, Union, Optional
import random import random
import numpy as np import numpy as np
import pandas as pd
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@@ -31,9 +30,6 @@ from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.sarm_utils import compute_priors, compute_cumulative_progress_batch, pad_state_to_max_dim from lerobot.policies.sarm.sarm_utils import compute_priors, compute_cumulative_progress_batch, pad_state_to_max_dim
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
class SARMTransformer(nn.Module): class SARMTransformer(nn.Module):
""" """
SARM Transformer model for stage-aware reward prediction. SARM Transformer model for stage-aware reward prediction.
@@ -41,8 +37,6 @@ class SARMTransformer(nn.Module):
This model has a dual-head architecture: This model has a dual-head architecture:
1. Stage estimator: Predicts the high-level task stage (classification) 1. Stage estimator: Predicts the high-level task stage (classification)
2. Subtask estimator: Predicts fine-grained progress within the stage (regression) 2. Subtask estimator: Predicts fine-grained progress within the stage (regression)
The subtask estimator is conditioned on the stage prediction.
""" """
def __init__( def __init__(
@@ -104,7 +98,7 @@ class SARMTransformer(nn.Module):
nn.Linear(512, num_stages) nn.Linear(512, num_stages)
) )
# Subtask estimator head (regression, conditioned on stage) # Subtask estimator head (regression)
self.stage_embedding = nn.Embedding(num_stages, hidden_dim // 4) self.stage_embedding = nn.Embedding(num_stages, hidden_dim // 4)
subtask_input_dim = hidden_dim + hidden_dim // 4 subtask_input_dim = hidden_dim + hidden_dim // 4
self.subtask_head = nn.Sequential( self.subtask_head = nn.Sequential(
@@ -219,13 +213,14 @@ class SARMRewardModel(PreTrainedPolicy):
def __init__(self, config: SARMConfig, dataset_stats: dict | None = None, dataset_meta=None): def __init__(self, config: SARMConfig, dataset_stats: dict | None = None, dataset_meta=None):
super().__init__(config, dataset_stats) super().__init__(config, dataset_stats)
config.validate_features()
self.config = config self.config = config
self.dataset_stats = dataset_stats self.dataset_stats = dataset_stats
self.device = torch.device(config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu") self.device = torch.device(config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu")
# Detect num_stages from dataset annotations before building the model # Load temporal proportions from dataset
if dataset_meta is not None: if config.temporal_proportions is None and dataset_meta is not None:
self._update_num_stages_from_dataset(dataset_meta) self._load_temporal_proportions(dataset_meta)
logging.info("Loading CLIP encoder") logging.info("Loading CLIP encoder")
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
@@ -246,76 +241,36 @@ class SARMRewardModel(PreTrainedPolicy):
temporal_proportions=config.temporal_proportions temporal_proportions=config.temporal_proportions
) )
self.sarm_transformer.to(self.device) self.sarm_transformer.to(self.device)
logging.info(f"SARM Reward Model initialized on {self.device}") logging.info(f"SARM initialized on {self.device}")
def _update_num_stages_from_dataset(self, dataset_meta) -> None: def _load_temporal_proportions(self, dataset_meta) -> None:
"""Update num_stages and temporal_proportions from dataset subtask annotations. """
Load pre-computed temporal proportions from dataset metadata JSON file.
Implements SARM Paper Formula (1): The temporal proportions are computed during dataset annotation using SARM Paper Formula (1):
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i) ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
""" """
episodes = dataset_meta.episodes import json
if episodes is None or len(episodes) == 0:
raise ValueError("No episodes found, using default num_stages")
if 'subtask_names' not in episodes.column_names: proportions_path = dataset_meta.root / "meta" / "temporal_proportions.json"
raise ValueError("No subtask annotations found in dataset, using default num_stages")
episodes_df = episodes.to_pandas() if not proportions_path.exists():
raise ValueError(
f"Temporal proportions not found at {proportions_path}. "
"Run the subtask annotation tool first to compute and save temporal proportions."
)
# Collect subtask durations and trajectory lengths for compute_priors with open(proportions_path, "r") as f:
all_subtask_names = set() temporal_proportions_dict = json.load(f)
subtask_durations_per_trajectory = {}
trajectory_lengths = {}
for ep_idx in episodes_df.index:
subtask_names_ep = episodes_df.loc[ep_idx, 'subtask_names']
if subtask_names_ep is None or (isinstance(subtask_names_ep, float) and pd.isna(subtask_names_ep)):
continue
all_subtask_names.update(subtask_names_ep)
# Compute durations if available
if 'subtask_start_frames' in episodes_df.columns and 'subtask_end_frames' in episodes_df.columns:
start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames']
end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames']
# Compute total trajectory length T_i
total_traj_length = sum(end_frames[i] - start_frames[i] for i in range(len(subtask_names_ep)))
if total_traj_length <= 0:
continue
for i, name in enumerate(subtask_names_ep):
duration = end_frames[i] - start_frames[i]
if name not in subtask_durations_per_trajectory:
subtask_durations_per_trajectory[name] = []
trajectory_lengths[name] = []
subtask_durations_per_trajectory[name].append(duration)
trajectory_lengths[name].append(total_traj_length)
if not all_subtask_names:
raise ValueError("No valid subtask names found, using default num_stages")
# Sort subtask names for consistent ordering # Sort subtask names for consistent ordering
subtask_names = sorted(list(all_subtask_names)) subtask_names = sorted(temporal_proportions_dict.keys())
num_stages = len(subtask_names)
# Compute temporal proportions using Paper Formula (1) self.config.num_stages = len(subtask_names)
temporal_proportions_dict = compute_priors(
subtask_durations_per_trajectory,
trajectory_lengths,
subtask_names
)
temporal_proportions = [temporal_proportions_dict[name] for name in subtask_names]
self.config.num_stages = num_stages
self.config.subtask_names = subtask_names self.config.subtask_names = subtask_names
self.config.temporal_proportions = temporal_proportions self.config.temporal_proportions = [temporal_proportions_dict[name] for name in subtask_names]
logging.info(f"Auto-detected {num_stages} subtasks: {subtask_names}") logging.info(f"Loaded {len(subtask_names)} subtasks: {subtask_names}")
logging.info(f"Temporal proportions: {temporal_proportions_dict}") logging.info(f"Temporal proportions: {temporal_proportions_dict}")
def to(self, device): def to(self, device):
@@ -503,43 +458,23 @@ class SARMRewardModel(PreTrainedPolicy):
return rewards return rewards
def load_pretrained_checkpoint(self, checkpoint_path: str, strict: bool = False):
"""Load pretrained model weights from a checkpoint file."""
logging.info(f"Loading pretrained checkpoint from {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
if "model_state_dict" in checkpoint:
state_dict = checkpoint["model_state_dict"]
else:
state_dict = checkpoint
missing_keys, unexpected_keys = self.sarm_transformer.load_state_dict(state_dict, strict=strict)
if missing_keys:
logging.warning(f"Missing keys when loading checkpoint: {missing_keys}")
if unexpected_keys:
logging.warning(f"Unexpected keys when loading checkpoint: {unexpected_keys}")
logging.info("Checkpoint loaded successfully")
def train(self, mode: bool = True): def train(self, mode: bool = True):
"""Set training mode. Note: CLIP encoder always stays in eval mode (frozen).""" """Overwrite train method to ensure CLIP encoder stays frozen during training"""
super().train(mode) super().train(mode)
self.clip_model.eval() self.clip_model.eval()
self.sarm_transformer.train(mode) self.sarm_transformer.train(mode)
return self return self
def eval(self): def eval(self):
"""Set evaluation mode.""" """Overwrite eval method to ensure CLIP encoder stays frozen during evaluation"""
return self.train(False) return self.train(False)
def parameters(self): def parameters(self):
"""Return trainable parameters (only SARM transformer, not encoders).""" """Override to return trainable parameters (only SARM transformer, not CLIP encoder)."""
return self.sarm_transformer.parameters() return self.sarm_transformer.parameters()
def get_optim_params(self): def get_optim_params(self):
"""Return optimizer parameters for the policy.""" """Override to return optimizer parameters (only SARM transformer, not CLIP encoder)."""
return self.parameters() return self.parameters()
def reset(self): def reset(self):
@@ -619,9 +554,7 @@ class SARMRewardModel(PreTrainedPolicy):
# Extract required features # Extract required features
video_features = observation['video_features'].to(self.device) video_features = observation['video_features'].to(self.device)
text_features = observation['text_features'].to(self.device) text_features = observation['text_features'].to(self.device)
state_features = observation.get('state_features') state_features = observation.get('state_features').to(self.device)
if state_features is not None:
state_features = state_features.to(self.device)
batch_size = video_features.shape[0] batch_size = video_features.shape[0]
max_length = self.config.num_frames max_length = self.config.num_frames
@@ -643,7 +576,7 @@ class SARMRewardModel(PreTrainedPolicy):
if progress_from_annotations.dim() == 3 and progress_from_annotations.shape[0] == 1: if progress_from_annotations.dim() == 3 and progress_from_annotations.shape[0] == 1:
progress_from_annotations = progress_from_annotations.expand(batch_size, -1, -1) progress_from_annotations = progress_from_annotations.expand(batch_size, -1, -1)
# Process each sample: apply temporal augmentation (SARM paper A.4) # Process each sample: apply temporal REWIND augmentation
processed_videos = [] processed_videos = []
processed_states = [] processed_states = []
progress_targets = [] progress_targets = []
@@ -653,7 +586,7 @@ class SARMRewardModel(PreTrainedPolicy):
state = state_features[i] if state_features is not None else None state = state_features[i] if state_features is not None else None
progress = progress_from_annotations[i].squeeze(-1) # (T,) progress = progress_from_annotations[i].squeeze(-1) # (T,)
# Apply temporal augmentation with 50% probability: appends up to 4 reversed frames to simulate failures/recoveries # Apply temporal REWIND augmentation with 50% probability: appends up to 4 reversed frames to simulate failures/recoveries
if random.random() < 0.5: if random.random() < 0.5:
video, progress, state = self._apply_temporal_augmentation(video, progress, state, max_length) video, progress, state = self._apply_temporal_augmentation(video, progress, state, max_length)
+24 -132
View File
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
import logging import logging
from typing import Any from typing import Any
import numpy as np import numpy as np
@@ -23,7 +24,7 @@ import pandas as pd
from transformers import CLIPModel, CLIPProcessor from transformers import CLIPModel, CLIPProcessor
from lerobot.policies.sarm.configuration_sarm import SARMConfig from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.sarm_utils import compute_priors, compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim from lerobot.policies.sarm.sarm_utils import compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim
from lerobot.processor import ( from lerobot.processor import (
ProcessorStep, ProcessorStep,
PolicyProcessorPipeline, PolicyProcessorPipeline,
@@ -57,105 +58,19 @@ class SARMEncodingProcessorStep(ProcessorStep):
self.image_key = image_key or config.image_key self.image_key = image_key or config.image_key
self.dataset_meta = dataset_meta self.dataset_meta = dataset_meta
self.dataset_stats = dataset_stats self.dataset_stats = dataset_stats
self.temporal_proportions = None self.temporal_proportions = {name: prop for name, prop in zip(self.config.subtask_names, self.config.temporal_proportions)}
self.subtask_names = None self.subtask_names = self.config.subtask_names
if dataset_meta is not None:
self._compute_temporal_proportions()
self._init_encoders() self.device = torch.device(
def _init_encoders(self):
"""Initialize CLIP encoder for both images and text (per SARM paper A.4)."""
device = torch.device(
self.config.device if self.config.device self.config.device if self.config.device
else "cuda" if torch.cuda.is_available() else "cpu" else "cuda" if torch.cuda.is_available() else "cpu"
) )
logging.info("Initializing CLIP encoder for SARM (images + text)...")
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True) self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
self.clip_model.to(device) self.clip_model.to(self.device)
self.clip_model.eval() self.clip_model.eval()
self.device = device
def _compute_temporal_proportions(self):
"""Compute temporal proportions for each subtask from dataset annotations.
Implements SARM Paper Formula (1):
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
This averages the proportion of time spent on each subtask within each trajectory,
giving equal weight to all trajectories regardless of absolute length.
"""
if self.dataset_meta is None or not hasattr(self.dataset_meta, 'episodes'):
return
# Check if subtask annotations exist
episodes = self.dataset_meta.episodes
if episodes is None or len(episodes) == 0:
return
# Check for subtask_names column
if 'subtask_names' not in episodes.column_names:
logging.info("No subtask annotations found in dataset")
return
episodes_df = episodes.to_pandas()
# Collect subtask durations and trajectory lengths for compute_priors
subtask_durations_per_trajectory = {}
trajectory_lengths = {}
all_subtask_names = set()
for ep_idx in episodes_df.index:
subtask_names_ep = episodes_df.loc[ep_idx, 'subtask_names']
# Skip episodes without annotations
if subtask_names_ep is None or (isinstance(subtask_names_ep, float) and pd.isna(subtask_names_ep)):
continue
start_times = episodes_df.loc[ep_idx, 'subtask_start_times']
end_times = episodes_df.loc[ep_idx, 'subtask_end_times']
# Track unique subtask names
all_subtask_names.update(subtask_names_ep)
# Compute total trajectory length T_i (sum of all subtask durations)
total_traj_length = sum(end_times[i] - start_times[i] for i in range(len(subtask_names_ep)))
if total_traj_length <= 0:
continue
# Store duration and trajectory length for each subtask occurrence
for i, name in enumerate(subtask_names_ep):
duration = end_times[i] - start_times[i]
if name not in subtask_durations_per_trajectory:
subtask_durations_per_trajectory[name] = []
trajectory_lengths[name] = []
subtask_durations_per_trajectory[name].append(duration)
trajectory_lengths[name].append(total_traj_length)
if not all_subtask_names:
logging.info("No valid subtask annotations found")
return
# Sort subtask names for consistent ordering
self.subtask_names = sorted(list(all_subtask_names))
self.config.num_stages = len(self.subtask_names)
self.config.subtask_names = self.subtask_names
# Compute temporal proportions using Paper Formula (1)
self.temporal_proportions = compute_priors(
subtask_durations_per_trajectory,
trajectory_lengths,
self.subtask_names
)
self.config.temporal_proportions = [self.temporal_proportions[name] for name in self.subtask_names]
logging.info(f"Computed temporal proportions for {len(self.subtask_names)} subtasks: {self.temporal_proportions}")
def _find_episode_for_frame(self, frame_idx: int) -> int: def _find_episode_for_frame(self, frame_idx: int) -> int:
"""Find the episode index for a given frame index.""" """Find the episode index for a given frame index."""
for ep_idx in range(len(self.dataset_meta.episodes)): for ep_idx in range(len(self.dataset_meta.episodes)):
@@ -437,10 +352,8 @@ class SARMEncodingProcessorStep(ProcessorStep):
else: else:
state_tensor = torch.tensor(state_data, dtype=torch.float32) state_tensor = torch.tensor(state_data, dtype=torch.float32)
# Pad state
observation['state_features'] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim) observation['state_features'] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)
# Extract complementary data (includes task from dataset)
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if not isinstance(comp_data, dict): if not isinstance(comp_data, dict):
raise ValueError("COMPLEMENTARY_DATA must be a dictionary") raise ValueError("COMPLEMENTARY_DATA must be a dictionary")
@@ -595,7 +508,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Add encoded features to the observation features.""" """Add encoded features to the observation features."""
# Add the encoded features (state uses max_state_dim, padded with zeros)
features[PipelineFeatureType.OBSERVATION]['video_features'] = PolicyFeature( features[PipelineFeatureType.OBSERVATION]['video_features'] = PolicyFeature(
type=FeatureType.VISUAL, type=FeatureType.VISUAL,
shape=(self.config.num_frames, self.config.image_dim) shape=(self.config.num_frames, self.config.image_dim)
@@ -622,60 +534,40 @@ def make_sarm_pre_post_processors(
""" """
Create pre-processor and post-processor pipelines for SARM. Create pre-processor and post-processor pipelines for SARM.
Per SARM paper (Appendix A.4): "We employ a frozen clip-vit-base-patch32 encoder
to process both RGB image sequences and task descriptions."
The pre-processing pipeline: The pre-processing pipeline:
1. Adds batch dimension 1. Adds batch dimension
2. Normalizes observation.state using NormalizerProcessorStep (MEAN_STD) 2. Normalizes observation.state using NormalizerProcessorStep (MEAN_STD)
3. SARMEncodingProcessorStep: 3. SARMEncodingProcessorStep:
- Encodes images with CLIP (512-dim) - Encodes images with CLIP
- Pads states to max_state_dim - Pads states to max_state_dim
- Encodes text with CLIP (512-dim) - Encodes text with CLIP
4. Moves data to device 4. Moves data to device
The post-processing pipeline: The post-processing pipeline:
1. Moves data to CPU (no unnormalization - outputs are rewards) 1. Moves data to CPU
Args:
config: SARM configuration
dataset_stats: Dataset statistics for normalization
dataset_meta: Dataset metadata for computing episode info
Returns:
Tuple of (preprocessor, postprocessor) pipelines
""" """
input_steps = [
AddBatchDimensionProcessorStep(),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
SARMEncodingProcessorStep(
config=config,
dataset_meta=dataset_meta,
dataset_stats=dataset_stats
),
DeviceProcessorStep(device=config.device),
]
output_steps = [
DeviceProcessorStep(device="cpu"),
]
return ( return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps, steps=[
AddBatchDimensionProcessorStep(),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
SARMEncodingProcessorStep(
config=config,
dataset_meta=dataset_meta,
dataset_stats=dataset_stats
),
DeviceProcessorStep(device=config.device),
],
name=POLICY_PREPROCESSOR_DEFAULT_NAME, name=POLICY_PREPROCESSOR_DEFAULT_NAME,
), ),
PolicyProcessorPipeline[PolicyAction, PolicyAction]( PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps, steps=[DeviceProcessorStep(device="cpu")],
name=POLICY_POSTPROCESSOR_DEFAULT_NAME, name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition, to_transition=policy_action_to_transition,
to_output=transition_to_policy_action, to_output=transition_to_policy_action,
), ),
) )
-1
View File
@@ -246,4 +246,3 @@ def pad_state_to_max_dim(state: torch.Tensor, max_state_dim: int) -> torch.Tenso
# Pad with zeros on the right # Pad with zeros on the right
padding = (0, max_state_dim - current_dim) # (left, right) for last dim padding = (0, max_state_dim - current_dim) # (left, right) for last dim
return F.pad(state, padding, mode='constant', value=0) return F.pad(state, padding, mode='constant', value=0)