mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 05:29:55 +00:00
simplify
This commit is contained in:
@@ -73,7 +73,7 @@ import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
import multiprocessing as mp
|
||||
|
||||
import pandas as pd
|
||||
@@ -82,7 +82,6 @@ from pydantic import BaseModel, Field
|
||||
from qwen_vl_utils import process_vision_info
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn
|
||||
from rich.tree import Tree
|
||||
from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor
|
||||
|
||||
@@ -176,7 +175,6 @@ class VideoAnnotator:
|
||||
|
||||
self.console.print(f"[cyan]Loading model: {model_name}...[/cyan]")
|
||||
|
||||
# Load model and processor
|
||||
self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch_dtype,
|
||||
@@ -237,18 +235,17 @@ class VideoAnnotator:
|
||||
cmd = [
|
||||
'ffmpeg',
|
||||
'-i', str(file_path),
|
||||
'-ss', str(start_timestamp), # Start time
|
||||
'-t', str(duration), # Duration
|
||||
'-r', str(target_fps), # Output FPS
|
||||
'-c:v', 'libx264', # H.264 codec
|
||||
'-preset', 'ultrafast', # Faster encoding
|
||||
'-crf', '23', # Better quality (lower = better)
|
||||
'-an', # Remove audio
|
||||
'-y', # Overwrite output file
|
||||
'-ss', str(start_timestamp),
|
||||
'-t', str(duration),
|
||||
'-r', str(target_fps),
|
||||
'-c:v', 'libx264',
|
||||
'-preset', 'ultrafast',
|
||||
'-crf', '23',
|
||||
'-an',
|
||||
'-y',
|
||||
str(tmp_path)
|
||||
]
|
||||
|
||||
# Run ffmpeg (suppress output)
|
||||
subprocess.run(
|
||||
cmd,
|
||||
stdout=subprocess.DEVNULL,
|
||||
@@ -309,7 +306,6 @@ class VideoAnnotator:
|
||||
|
||||
# Calculate episode duration
|
||||
if end_timestamp is None:
|
||||
# Get video metadata (suppress AV1 warnings)
|
||||
import cv2
|
||||
import os
|
||||
import sys
|
||||
|
||||
@@ -27,13 +27,13 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
class SARMConfig(PreTrainedConfig):
|
||||
"""Configuration class for SARM (Stage-Aware Reward Modeling)"""
|
||||
|
||||
# CLIP encoding parameters
|
||||
# CLIP params
|
||||
image_dim: int = 512
|
||||
text_dim: int = 512
|
||||
num_frames: int = 9 # 1 initial + 8 consecutive frames
|
||||
frame_gap: int = 30 # Frame gap between consecutive frames (at 30 fps = 1 second)
|
||||
frame_gap: int = 30 # Frame gap between frames (at 30 fps = 1 second)
|
||||
|
||||
# Architecture parameters
|
||||
# Architecture params
|
||||
hidden_dim: int = 768
|
||||
num_heads: int = 12
|
||||
num_layers: int = 8
|
||||
@@ -44,7 +44,7 @@ class SARMConfig(PreTrainedConfig):
|
||||
max_length: int = num_frames # Maximum video sequence length (matches num_frames)
|
||||
use_temporal_sampler: bool = True # Always enable temporal sequence loading
|
||||
|
||||
# Training parameters
|
||||
# Training params
|
||||
batch_size: int = 64
|
||||
clip_batch_size: int = 64 # Batch size for CLIP encoding
|
||||
dropout: float = 0.1
|
||||
@@ -80,14 +80,14 @@ class SARMConfig(PreTrainedConfig):
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
# Add the image_key as VISUAL (this is the raw image from dataset)
|
||||
# Add the image_key as VISUAL
|
||||
if self.image_key:
|
||||
self.input_features[self.image_key] = PolicyFeature(
|
||||
shape=(480, 640, 3),
|
||||
type=FeatureType.VISUAL
|
||||
)
|
||||
|
||||
# Add state_key as STATE (raw state from dataset, will be padded to max_state_dim)
|
||||
# Add state_key as STATE
|
||||
self.input_features[self.state_key] = PolicyFeature(
|
||||
shape=(self.max_state_dim,), # Single frame state, temporal sampling handles sequence
|
||||
type=FeatureType.STATE
|
||||
@@ -151,16 +151,13 @@ class SARMConfig(PreTrainedConfig):
|
||||
to the episode start (frame 0) by the dataset loader. This ensures we always
|
||||
get the initial frame regardless of the current position in the episode.
|
||||
|
||||
|
||||
Returns:
|
||||
9 delta indices: [-1_000_000, -(7*gap), -(6*gap), ..., -gap, 0]
|
||||
"""
|
||||
initial_frame_delta = -1_000_000
|
||||
|
||||
# Remaining consecutive frames with frame_gap spacing
|
||||
num_consecutive = self.num_frames - 1
|
||||
consecutive_deltas = list(range(-self.frame_gap * (num_consecutive - 1), 1, self.frame_gap))
|
||||
|
||||
num_consecutive = self.num_frames - 1 # 9 - 1 = 8
|
||||
consecutive_deltas = list(range(-self.frame_gap * (num_consecutive - 1), 1, self.frame_gap)) # [-210, -180, -150, -120, -90, -60, -30, 0]
|
||||
return [initial_frame_delta] + consecutive_deltas
|
||||
|
||||
@property
|
||||
|
||||
@@ -19,7 +19,6 @@ from typing import List, Union, Optional
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@@ -31,9 +30,6 @@ from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.policies.sarm.sarm_utils import compute_priors, compute_cumulative_progress_batch, pad_state_to_max_dim
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
|
||||
|
||||
|
||||
class SARMTransformer(nn.Module):
|
||||
"""
|
||||
SARM Transformer model for stage-aware reward prediction.
|
||||
@@ -41,8 +37,6 @@ class SARMTransformer(nn.Module):
|
||||
This model has a dual-head architecture:
|
||||
1. Stage estimator: Predicts the high-level task stage (classification)
|
||||
2. Subtask estimator: Predicts fine-grained progress within the stage (regression)
|
||||
|
||||
The subtask estimator is conditioned on the stage prediction.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -104,7 +98,7 @@ class SARMTransformer(nn.Module):
|
||||
nn.Linear(512, num_stages)
|
||||
)
|
||||
|
||||
# Subtask estimator head (regression, conditioned on stage)
|
||||
# Subtask estimator head (regression)
|
||||
self.stage_embedding = nn.Embedding(num_stages, hidden_dim // 4)
|
||||
subtask_input_dim = hidden_dim + hidden_dim // 4
|
||||
self.subtask_head = nn.Sequential(
|
||||
@@ -219,13 +213,14 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
|
||||
def __init__(self, config: SARMConfig, dataset_stats: dict | None = None, dataset_meta=None):
|
||||
super().__init__(config, dataset_stats)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.dataset_stats = dataset_stats
|
||||
self.device = torch.device(config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Detect num_stages from dataset annotations before building the model
|
||||
if dataset_meta is not None:
|
||||
self._update_num_stages_from_dataset(dataset_meta)
|
||||
# Load temporal proportions from dataset
|
||||
if config.temporal_proportions is None and dataset_meta is not None:
|
||||
self._load_temporal_proportions(dataset_meta)
|
||||
|
||||
logging.info("Loading CLIP encoder")
|
||||
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
@@ -246,78 +241,38 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
temporal_proportions=config.temporal_proportions
|
||||
)
|
||||
self.sarm_transformer.to(self.device)
|
||||
logging.info(f"SARM Reward Model initialized on {self.device}")
|
||||
logging.info(f"SARM initialized on {self.device}")
|
||||
|
||||
def _update_num_stages_from_dataset(self, dataset_meta) -> None:
|
||||
"""Update num_stages and temporal_proportions from dataset subtask annotations.
|
||||
|
||||
Implements SARM Paper Formula (1):
|
||||
def _load_temporal_proportions(self, dataset_meta) -> None:
|
||||
"""
|
||||
Load pre-computed temporal proportions from dataset metadata JSON file.
|
||||
|
||||
The temporal proportions are computed during dataset annotation using SARM Paper Formula (1):
|
||||
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
|
||||
"""
|
||||
episodes = dataset_meta.episodes
|
||||
if episodes is None or len(episodes) == 0:
|
||||
raise ValueError("No episodes found, using default num_stages")
|
||||
|
||||
if 'subtask_names' not in episodes.column_names:
|
||||
raise ValueError("No subtask annotations found in dataset, using default num_stages")
|
||||
|
||||
episodes_df = episodes.to_pandas()
|
||||
import json
|
||||
|
||||
# Collect subtask durations and trajectory lengths for compute_priors
|
||||
all_subtask_names = set()
|
||||
subtask_durations_per_trajectory = {}
|
||||
trajectory_lengths = {}
|
||||
proportions_path = dataset_meta.root / "meta" / "temporal_proportions.json"
|
||||
|
||||
for ep_idx in episodes_df.index:
|
||||
subtask_names_ep = episodes_df.loc[ep_idx, 'subtask_names']
|
||||
if subtask_names_ep is None or (isinstance(subtask_names_ep, float) and pd.isna(subtask_names_ep)):
|
||||
continue
|
||||
|
||||
all_subtask_names.update(subtask_names_ep)
|
||||
|
||||
# Compute durations if available
|
||||
if 'subtask_start_frames' in episodes_df.columns and 'subtask_end_frames' in episodes_df.columns:
|
||||
start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames']
|
||||
end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames']
|
||||
|
||||
# Compute total trajectory length T_i
|
||||
total_traj_length = sum(end_frames[i] - start_frames[i] for i in range(len(subtask_names_ep)))
|
||||
|
||||
if total_traj_length <= 0:
|
||||
continue
|
||||
|
||||
for i, name in enumerate(subtask_names_ep):
|
||||
duration = end_frames[i] - start_frames[i]
|
||||
|
||||
if name not in subtask_durations_per_trajectory:
|
||||
subtask_durations_per_trajectory[name] = []
|
||||
trajectory_lengths[name] = []
|
||||
|
||||
subtask_durations_per_trajectory[name].append(duration)
|
||||
trajectory_lengths[name].append(total_traj_length)
|
||||
if not proportions_path.exists():
|
||||
raise ValueError(
|
||||
f"Temporal proportions not found at {proportions_path}. "
|
||||
"Run the subtask annotation tool first to compute and save temporal proportions."
|
||||
)
|
||||
|
||||
if not all_subtask_names:
|
||||
raise ValueError("No valid subtask names found, using default num_stages")
|
||||
with open(proportions_path, "r") as f:
|
||||
temporal_proportions_dict = json.load(f)
|
||||
|
||||
# Sort subtask names for consistent ordering
|
||||
subtask_names = sorted(list(all_subtask_names))
|
||||
num_stages = len(subtask_names)
|
||||
subtask_names = sorted(temporal_proportions_dict.keys())
|
||||
|
||||
# Compute temporal proportions using Paper Formula (1)
|
||||
temporal_proportions_dict = compute_priors(
|
||||
subtask_durations_per_trajectory,
|
||||
trajectory_lengths,
|
||||
subtask_names
|
||||
)
|
||||
temporal_proportions = [temporal_proportions_dict[name] for name in subtask_names]
|
||||
|
||||
self.config.num_stages = num_stages
|
||||
self.config.num_stages = len(subtask_names)
|
||||
self.config.subtask_names = subtask_names
|
||||
self.config.temporal_proportions = temporal_proportions
|
||||
self.config.temporal_proportions = [temporal_proportions_dict[name] for name in subtask_names]
|
||||
|
||||
logging.info(f"Auto-detected {num_stages} subtasks: {subtask_names}")
|
||||
logging.info(f"Loaded {len(subtask_names)} subtasks: {subtask_names}")
|
||||
logging.info(f"Temporal proportions: {temporal_proportions_dict}")
|
||||
|
||||
|
||||
def to(self, device):
|
||||
"""Override to method to ensure all components move together."""
|
||||
super().to(device)
|
||||
@@ -503,43 +458,23 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
|
||||
return rewards
|
||||
|
||||
|
||||
def load_pretrained_checkpoint(self, checkpoint_path: str, strict: bool = False):
|
||||
"""Load pretrained model weights from a checkpoint file."""
|
||||
logging.info(f"Loading pretrained checkpoint from {checkpoint_path}")
|
||||
checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
|
||||
|
||||
if "model_state_dict" in checkpoint:
|
||||
state_dict = checkpoint["model_state_dict"]
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
missing_keys, unexpected_keys = self.sarm_transformer.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
if missing_keys:
|
||||
logging.warning(f"Missing keys when loading checkpoint: {missing_keys}")
|
||||
if unexpected_keys:
|
||||
logging.warning(f"Unexpected keys when loading checkpoint: {unexpected_keys}")
|
||||
|
||||
logging.info("Checkpoint loaded successfully")
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
"""Set training mode. Note: CLIP encoder always stays in eval mode (frozen)."""
|
||||
"""Overwrite train method to ensure CLIP encoder stays frozen during training"""
|
||||
super().train(mode)
|
||||
self.clip_model.eval()
|
||||
self.sarm_transformer.train(mode)
|
||||
return self
|
||||
|
||||
def eval(self):
|
||||
"""Set evaluation mode."""
|
||||
"""Overwrite eval method to ensure CLIP encoder stays frozen during evaluation"""
|
||||
return self.train(False)
|
||||
|
||||
def parameters(self):
|
||||
"""Return trainable parameters (only SARM transformer, not encoders)."""
|
||||
"""Override to return trainable parameters (only SARM transformer, not CLIP encoder)."""
|
||||
return self.sarm_transformer.parameters()
|
||||
|
||||
def get_optim_params(self):
|
||||
"""Return optimizer parameters for the policy."""
|
||||
"""Override to return optimizer parameters (only SARM transformer, not CLIP encoder)."""
|
||||
return self.parameters()
|
||||
|
||||
def reset(self):
|
||||
@@ -619,9 +554,7 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
# Extract required features
|
||||
video_features = observation['video_features'].to(self.device)
|
||||
text_features = observation['text_features'].to(self.device)
|
||||
state_features = observation.get('state_features')
|
||||
if state_features is not None:
|
||||
state_features = state_features.to(self.device)
|
||||
state_features = observation.get('state_features').to(self.device)
|
||||
|
||||
batch_size = video_features.shape[0]
|
||||
max_length = self.config.num_frames
|
||||
@@ -643,7 +576,7 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
if progress_from_annotations.dim() == 3 and progress_from_annotations.shape[0] == 1:
|
||||
progress_from_annotations = progress_from_annotations.expand(batch_size, -1, -1)
|
||||
|
||||
# Process each sample: apply temporal augmentation (SARM paper A.4)
|
||||
# Process each sample: apply temporal REWIND augmentation
|
||||
processed_videos = []
|
||||
processed_states = []
|
||||
progress_targets = []
|
||||
@@ -653,7 +586,7 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
state = state_features[i] if state_features is not None else None
|
||||
progress = progress_from_annotations[i].squeeze(-1) # (T,)
|
||||
|
||||
# Apply temporal augmentation with 50% probability: appends up to 4 reversed frames to simulate failures/recoveries
|
||||
# Apply temporal REWIND augmentation with 50% probability: appends up to 4 reversed frames to simulate failures/recoveries
|
||||
if random.random() < 0.5:
|
||||
video, progress, state = self._apply_temporal_augmentation(video, progress, state, max_length)
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
@@ -23,7 +24,7 @@ import pandas as pd
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.policies.sarm.sarm_utils import compute_priors, compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim
|
||||
from lerobot.policies.sarm.sarm_utils import compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim
|
||||
from lerobot.processor import (
|
||||
ProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
@@ -57,105 +58,19 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
self.image_key = image_key or config.image_key
|
||||
self.dataset_meta = dataset_meta
|
||||
self.dataset_stats = dataset_stats
|
||||
self.temporal_proportions = None
|
||||
self.subtask_names = None
|
||||
if dataset_meta is not None:
|
||||
self._compute_temporal_proportions()
|
||||
|
||||
self._init_encoders()
|
||||
|
||||
def _init_encoders(self):
|
||||
"""Initialize CLIP encoder for both images and text (per SARM paper A.4)."""
|
||||
device = torch.device(
|
||||
self.temporal_proportions = {name: prop for name, prop in zip(self.config.subtask_names, self.config.temporal_proportions)}
|
||||
self.subtask_names = self.config.subtask_names
|
||||
|
||||
self.device = torch.device(
|
||||
self.config.device if self.config.device
|
||||
else "cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
|
||||
logging.info("Initializing CLIP encoder for SARM (images + text)...")
|
||||
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
|
||||
self.clip_model.to(device)
|
||||
self.clip_model.to(self.device)
|
||||
self.clip_model.eval()
|
||||
|
||||
self.device = device
|
||||
|
||||
def _compute_temporal_proportions(self):
|
||||
"""Compute temporal proportions for each subtask from dataset annotations.
|
||||
|
||||
Implements SARM Paper Formula (1):
|
||||
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
|
||||
|
||||
This averages the proportion of time spent on each subtask within each trajectory,
|
||||
giving equal weight to all trajectories regardless of absolute length.
|
||||
"""
|
||||
if self.dataset_meta is None or not hasattr(self.dataset_meta, 'episodes'):
|
||||
return
|
||||
|
||||
# Check if subtask annotations exist
|
||||
episodes = self.dataset_meta.episodes
|
||||
if episodes is None or len(episodes) == 0:
|
||||
return
|
||||
|
||||
# Check for subtask_names column
|
||||
if 'subtask_names' not in episodes.column_names:
|
||||
logging.info("No subtask annotations found in dataset")
|
||||
return
|
||||
|
||||
episodes_df = episodes.to_pandas()
|
||||
|
||||
# Collect subtask durations and trajectory lengths for compute_priors
|
||||
subtask_durations_per_trajectory = {}
|
||||
trajectory_lengths = {}
|
||||
all_subtask_names = set()
|
||||
|
||||
for ep_idx in episodes_df.index:
|
||||
subtask_names_ep = episodes_df.loc[ep_idx, 'subtask_names']
|
||||
|
||||
# Skip episodes without annotations
|
||||
if subtask_names_ep is None or (isinstance(subtask_names_ep, float) and pd.isna(subtask_names_ep)):
|
||||
continue
|
||||
|
||||
start_times = episodes_df.loc[ep_idx, 'subtask_start_times']
|
||||
end_times = episodes_df.loc[ep_idx, 'subtask_end_times']
|
||||
|
||||
# Track unique subtask names
|
||||
all_subtask_names.update(subtask_names_ep)
|
||||
|
||||
# Compute total trajectory length T_i (sum of all subtask durations)
|
||||
total_traj_length = sum(end_times[i] - start_times[i] for i in range(len(subtask_names_ep)))
|
||||
|
||||
if total_traj_length <= 0:
|
||||
continue
|
||||
|
||||
# Store duration and trajectory length for each subtask occurrence
|
||||
for i, name in enumerate(subtask_names_ep):
|
||||
duration = end_times[i] - start_times[i]
|
||||
|
||||
if name not in subtask_durations_per_trajectory:
|
||||
subtask_durations_per_trajectory[name] = []
|
||||
trajectory_lengths[name] = []
|
||||
|
||||
subtask_durations_per_trajectory[name].append(duration)
|
||||
trajectory_lengths[name].append(total_traj_length)
|
||||
|
||||
if not all_subtask_names:
|
||||
logging.info("No valid subtask annotations found")
|
||||
return
|
||||
|
||||
# Sort subtask names for consistent ordering
|
||||
self.subtask_names = sorted(list(all_subtask_names))
|
||||
self.config.num_stages = len(self.subtask_names)
|
||||
self.config.subtask_names = self.subtask_names
|
||||
|
||||
# Compute temporal proportions using Paper Formula (1)
|
||||
self.temporal_proportions = compute_priors(
|
||||
subtask_durations_per_trajectory,
|
||||
trajectory_lengths,
|
||||
self.subtask_names
|
||||
)
|
||||
self.config.temporal_proportions = [self.temporal_proportions[name] for name in self.subtask_names]
|
||||
logging.info(f"Computed temporal proportions for {len(self.subtask_names)} subtasks: {self.temporal_proportions}")
|
||||
|
||||
|
||||
def _find_episode_for_frame(self, frame_idx: int) -> int:
|
||||
"""Find the episode index for a given frame index."""
|
||||
for ep_idx in range(len(self.dataset_meta.episodes)):
|
||||
@@ -437,10 +352,8 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
else:
|
||||
state_tensor = torch.tensor(state_data, dtype=torch.float32)
|
||||
|
||||
# Pad state
|
||||
observation['state_features'] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)
|
||||
|
||||
# Extract complementary data (includes task from dataset)
|
||||
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
if not isinstance(comp_data, dict):
|
||||
raise ValueError("COMPLEMENTARY_DATA must be a dictionary")
|
||||
@@ -595,7 +508,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Add encoded features to the observation features."""
|
||||
# Add the encoded features (state uses max_state_dim, padded with zeros)
|
||||
features[PipelineFeatureType.OBSERVATION]['video_features'] = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(self.config.num_frames, self.config.image_dim)
|
||||
@@ -622,60 +534,40 @@ def make_sarm_pre_post_processors(
|
||||
"""
|
||||
Create pre-processor and post-processor pipelines for SARM.
|
||||
|
||||
Per SARM paper (Appendix A.4): "We employ a frozen clip-vit-base-patch32 encoder
|
||||
to process both RGB image sequences and task descriptions."
|
||||
|
||||
The pre-processing pipeline:
|
||||
1. Adds batch dimension
|
||||
2. Normalizes observation.state using NormalizerProcessorStep (MEAN_STD)
|
||||
3. SARMEncodingProcessorStep:
|
||||
- Encodes images with CLIP (512-dim)
|
||||
- Encodes images with CLIP
|
||||
- Pads states to max_state_dim
|
||||
- Encodes text with CLIP (512-dim)
|
||||
- Encodes text with CLIP
|
||||
4. Moves data to device
|
||||
|
||||
The post-processing pipeline:
|
||||
1. Moves data to CPU (no unnormalization - outputs are rewards)
|
||||
|
||||
Args:
|
||||
config: SARM configuration
|
||||
dataset_stats: Dataset statistics for normalization
|
||||
dataset_meta: Dataset metadata for computing episode info
|
||||
|
||||
Returns:
|
||||
Tuple of (preprocessor, postprocessor) pipelines
|
||||
1. Moves data to CPU
|
||||
"""
|
||||
input_steps = [
|
||||
AddBatchDimensionProcessorStep(),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
SARMEncodingProcessorStep(
|
||||
config=config,
|
||||
dataset_meta=dataset_meta,
|
||||
dataset_stats=dataset_stats
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
output_steps = [
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
steps=[
|
||||
AddBatchDimensionProcessorStep(),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
SARMEncodingProcessorStep(
|
||||
config=config,
|
||||
dataset_meta=dataset_meta,
|
||||
dataset_stats=dataset_stats
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
],
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
steps=[DeviceProcessorStep(device="cpu")],
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -246,4 +246,3 @@ def pad_state_to_max_dim(state: torch.Tensor, max_state_dim: int) -> torch.Tenso
|
||||
# Pad with zeros on the right
|
||||
padding = (0, max_state_dim - current_dim) # (left, right) for last dim
|
||||
return F.pad(state, padding, mode='constant', value=0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user