mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 13:09:43 +00:00
simplify
This commit is contained in:
@@ -73,7 +73,7 @@ import json
|
|||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
|
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@@ -82,7 +82,6 @@ from pydantic import BaseModel, Field
|
|||||||
from qwen_vl_utils import process_vision_info
|
from qwen_vl_utils import process_vision_info
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn
|
|
||||||
from rich.tree import Tree
|
from rich.tree import Tree
|
||||||
from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor
|
from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor
|
||||||
|
|
||||||
@@ -176,7 +175,6 @@ class VideoAnnotator:
|
|||||||
|
|
||||||
self.console.print(f"[cyan]Loading model: {model_name}...[/cyan]")
|
self.console.print(f"[cyan]Loading model: {model_name}...[/cyan]")
|
||||||
|
|
||||||
# Load model and processor
|
|
||||||
self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
|
self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
|
||||||
model_name,
|
model_name,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
@@ -237,18 +235,17 @@ class VideoAnnotator:
|
|||||||
cmd = [
|
cmd = [
|
||||||
'ffmpeg',
|
'ffmpeg',
|
||||||
'-i', str(file_path),
|
'-i', str(file_path),
|
||||||
'-ss', str(start_timestamp), # Start time
|
'-ss', str(start_timestamp),
|
||||||
'-t', str(duration), # Duration
|
'-t', str(duration),
|
||||||
'-r', str(target_fps), # Output FPS
|
'-r', str(target_fps),
|
||||||
'-c:v', 'libx264', # H.264 codec
|
'-c:v', 'libx264',
|
||||||
'-preset', 'ultrafast', # Faster encoding
|
'-preset', 'ultrafast',
|
||||||
'-crf', '23', # Better quality (lower = better)
|
'-crf', '23',
|
||||||
'-an', # Remove audio
|
'-an',
|
||||||
'-y', # Overwrite output file
|
'-y',
|
||||||
str(tmp_path)
|
str(tmp_path)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Run ffmpeg (suppress output)
|
|
||||||
subprocess.run(
|
subprocess.run(
|
||||||
cmd,
|
cmd,
|
||||||
stdout=subprocess.DEVNULL,
|
stdout=subprocess.DEVNULL,
|
||||||
@@ -309,7 +306,6 @@ class VideoAnnotator:
|
|||||||
|
|
||||||
# Calculate episode duration
|
# Calculate episode duration
|
||||||
if end_timestamp is None:
|
if end_timestamp is None:
|
||||||
# Get video metadata (suppress AV1 warnings)
|
|
||||||
import cv2
|
import cv2
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|||||||
@@ -27,13 +27,13 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
|||||||
class SARMConfig(PreTrainedConfig):
|
class SARMConfig(PreTrainedConfig):
|
||||||
"""Configuration class for SARM (Stage-Aware Reward Modeling)"""
|
"""Configuration class for SARM (Stage-Aware Reward Modeling)"""
|
||||||
|
|
||||||
# CLIP encoding parameters
|
# CLIP params
|
||||||
image_dim: int = 512
|
image_dim: int = 512
|
||||||
text_dim: int = 512
|
text_dim: int = 512
|
||||||
num_frames: int = 9 # 1 initial + 8 consecutive frames
|
num_frames: int = 9 # 1 initial + 8 consecutive frames
|
||||||
frame_gap: int = 30 # Frame gap between consecutive frames (at 30 fps = 1 second)
|
frame_gap: int = 30 # Frame gap between frames (at 30 fps = 1 second)
|
||||||
|
|
||||||
# Architecture parameters
|
# Architecture params
|
||||||
hidden_dim: int = 768
|
hidden_dim: int = 768
|
||||||
num_heads: int = 12
|
num_heads: int = 12
|
||||||
num_layers: int = 8
|
num_layers: int = 8
|
||||||
@@ -44,7 +44,7 @@ class SARMConfig(PreTrainedConfig):
|
|||||||
max_length: int = num_frames # Maximum video sequence length (matches num_frames)
|
max_length: int = num_frames # Maximum video sequence length (matches num_frames)
|
||||||
use_temporal_sampler: bool = True # Always enable temporal sequence loading
|
use_temporal_sampler: bool = True # Always enable temporal sequence loading
|
||||||
|
|
||||||
# Training parameters
|
# Training params
|
||||||
batch_size: int = 64
|
batch_size: int = 64
|
||||||
clip_batch_size: int = 64 # Batch size for CLIP encoding
|
clip_batch_size: int = 64 # Batch size for CLIP encoding
|
||||||
dropout: float = 0.1
|
dropout: float = 0.1
|
||||||
@@ -80,14 +80,14 @@ class SARMConfig(PreTrainedConfig):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
# Add the image_key as VISUAL (this is the raw image from dataset)
|
# Add the image_key as VISUAL
|
||||||
if self.image_key:
|
if self.image_key:
|
||||||
self.input_features[self.image_key] = PolicyFeature(
|
self.input_features[self.image_key] = PolicyFeature(
|
||||||
shape=(480, 640, 3),
|
shape=(480, 640, 3),
|
||||||
type=FeatureType.VISUAL
|
type=FeatureType.VISUAL
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add state_key as STATE (raw state from dataset, will be padded to max_state_dim)
|
# Add state_key as STATE
|
||||||
self.input_features[self.state_key] = PolicyFeature(
|
self.input_features[self.state_key] = PolicyFeature(
|
||||||
shape=(self.max_state_dim,), # Single frame state, temporal sampling handles sequence
|
shape=(self.max_state_dim,), # Single frame state, temporal sampling handles sequence
|
||||||
type=FeatureType.STATE
|
type=FeatureType.STATE
|
||||||
@@ -151,16 +151,13 @@ class SARMConfig(PreTrainedConfig):
|
|||||||
to the episode start (frame 0) by the dataset loader. This ensures we always
|
to the episode start (frame 0) by the dataset loader. This ensures we always
|
||||||
get the initial frame regardless of the current position in the episode.
|
get the initial frame regardless of the current position in the episode.
|
||||||
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
9 delta indices: [-1_000_000, -(7*gap), -(6*gap), ..., -gap, 0]
|
9 delta indices: [-1_000_000, -(7*gap), -(6*gap), ..., -gap, 0]
|
||||||
"""
|
"""
|
||||||
initial_frame_delta = -1_000_000
|
initial_frame_delta = -1_000_000
|
||||||
|
|
||||||
# Remaining consecutive frames with frame_gap spacing
|
num_consecutive = self.num_frames - 1 # 9 - 1 = 8
|
||||||
num_consecutive = self.num_frames - 1
|
consecutive_deltas = list(range(-self.frame_gap * (num_consecutive - 1), 1, self.frame_gap)) # [-210, -180, -150, -120, -90, -60, -30, 0]
|
||||||
consecutive_deltas = list(range(-self.frame_gap * (num_consecutive - 1), 1, self.frame_gap))
|
|
||||||
|
|
||||||
return [initial_frame_delta] + consecutive_deltas
|
return [initial_frame_delta] + consecutive_deltas
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ from typing import List, Union, Optional
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -31,9 +30,6 @@ from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
|||||||
from lerobot.policies.sarm.sarm_utils import compute_priors, compute_cumulative_progress_batch, pad_state_to_max_dim
|
from lerobot.policies.sarm.sarm_utils import compute_priors, compute_cumulative_progress_batch, pad_state_to_max_dim
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SARMTransformer(nn.Module):
|
class SARMTransformer(nn.Module):
|
||||||
"""
|
"""
|
||||||
SARM Transformer model for stage-aware reward prediction.
|
SARM Transformer model for stage-aware reward prediction.
|
||||||
@@ -41,8 +37,6 @@ class SARMTransformer(nn.Module):
|
|||||||
This model has a dual-head architecture:
|
This model has a dual-head architecture:
|
||||||
1. Stage estimator: Predicts the high-level task stage (classification)
|
1. Stage estimator: Predicts the high-level task stage (classification)
|
||||||
2. Subtask estimator: Predicts fine-grained progress within the stage (regression)
|
2. Subtask estimator: Predicts fine-grained progress within the stage (regression)
|
||||||
|
|
||||||
The subtask estimator is conditioned on the stage prediction.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -104,7 +98,7 @@ class SARMTransformer(nn.Module):
|
|||||||
nn.Linear(512, num_stages)
|
nn.Linear(512, num_stages)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Subtask estimator head (regression, conditioned on stage)
|
# Subtask estimator head (regression)
|
||||||
self.stage_embedding = nn.Embedding(num_stages, hidden_dim // 4)
|
self.stage_embedding = nn.Embedding(num_stages, hidden_dim // 4)
|
||||||
subtask_input_dim = hidden_dim + hidden_dim // 4
|
subtask_input_dim = hidden_dim + hidden_dim // 4
|
||||||
self.subtask_head = nn.Sequential(
|
self.subtask_head = nn.Sequential(
|
||||||
@@ -219,13 +213,14 @@ class SARMRewardModel(PreTrainedPolicy):
|
|||||||
|
|
||||||
def __init__(self, config: SARMConfig, dataset_stats: dict | None = None, dataset_meta=None):
|
def __init__(self, config: SARMConfig, dataset_stats: dict | None = None, dataset_meta=None):
|
||||||
super().__init__(config, dataset_stats)
|
super().__init__(config, dataset_stats)
|
||||||
|
config.validate_features()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.dataset_stats = dataset_stats
|
self.dataset_stats = dataset_stats
|
||||||
self.device = torch.device(config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu")
|
self.device = torch.device(config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
# Detect num_stages from dataset annotations before building the model
|
# Load temporal proportions from dataset
|
||||||
if dataset_meta is not None:
|
if config.temporal_proportions is None and dataset_meta is not None:
|
||||||
self._update_num_stages_from_dataset(dataset_meta)
|
self._load_temporal_proportions(dataset_meta)
|
||||||
|
|
||||||
logging.info("Loading CLIP encoder")
|
logging.info("Loading CLIP encoder")
|
||||||
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||||
@@ -246,76 +241,36 @@ class SARMRewardModel(PreTrainedPolicy):
|
|||||||
temporal_proportions=config.temporal_proportions
|
temporal_proportions=config.temporal_proportions
|
||||||
)
|
)
|
||||||
self.sarm_transformer.to(self.device)
|
self.sarm_transformer.to(self.device)
|
||||||
logging.info(f"SARM Reward Model initialized on {self.device}")
|
logging.info(f"SARM initialized on {self.device}")
|
||||||
|
|
||||||
def _update_num_stages_from_dataset(self, dataset_meta) -> None:
|
def _load_temporal_proportions(self, dataset_meta) -> None:
|
||||||
"""Update num_stages and temporal_proportions from dataset subtask annotations.
|
"""
|
||||||
|
Load pre-computed temporal proportions from dataset metadata JSON file.
|
||||||
|
|
||||||
Implements SARM Paper Formula (1):
|
The temporal proportions are computed during dataset annotation using SARM Paper Formula (1):
|
||||||
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
|
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
|
||||||
"""
|
"""
|
||||||
episodes = dataset_meta.episodes
|
import json
|
||||||
if episodes is None or len(episodes) == 0:
|
|
||||||
raise ValueError("No episodes found, using default num_stages")
|
|
||||||
|
|
||||||
if 'subtask_names' not in episodes.column_names:
|
proportions_path = dataset_meta.root / "meta" / "temporal_proportions.json"
|
||||||
raise ValueError("No subtask annotations found in dataset, using default num_stages")
|
|
||||||
|
|
||||||
episodes_df = episodes.to_pandas()
|
if not proportions_path.exists():
|
||||||
|
raise ValueError(
|
||||||
|
f"Temporal proportions not found at {proportions_path}. "
|
||||||
|
"Run the subtask annotation tool first to compute and save temporal proportions."
|
||||||
|
)
|
||||||
|
|
||||||
# Collect subtask durations and trajectory lengths for compute_priors
|
with open(proportions_path, "r") as f:
|
||||||
all_subtask_names = set()
|
temporal_proportions_dict = json.load(f)
|
||||||
subtask_durations_per_trajectory = {}
|
|
||||||
trajectory_lengths = {}
|
|
||||||
|
|
||||||
for ep_idx in episodes_df.index:
|
|
||||||
subtask_names_ep = episodes_df.loc[ep_idx, 'subtask_names']
|
|
||||||
if subtask_names_ep is None or (isinstance(subtask_names_ep, float) and pd.isna(subtask_names_ep)):
|
|
||||||
continue
|
|
||||||
|
|
||||||
all_subtask_names.update(subtask_names_ep)
|
|
||||||
|
|
||||||
# Compute durations if available
|
|
||||||
if 'subtask_start_frames' in episodes_df.columns and 'subtask_end_frames' in episodes_df.columns:
|
|
||||||
start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames']
|
|
||||||
end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames']
|
|
||||||
|
|
||||||
# Compute total trajectory length T_i
|
|
||||||
total_traj_length = sum(end_frames[i] - start_frames[i] for i in range(len(subtask_names_ep)))
|
|
||||||
|
|
||||||
if total_traj_length <= 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
for i, name in enumerate(subtask_names_ep):
|
|
||||||
duration = end_frames[i] - start_frames[i]
|
|
||||||
|
|
||||||
if name not in subtask_durations_per_trajectory:
|
|
||||||
subtask_durations_per_trajectory[name] = []
|
|
||||||
trajectory_lengths[name] = []
|
|
||||||
|
|
||||||
subtask_durations_per_trajectory[name].append(duration)
|
|
||||||
trajectory_lengths[name].append(total_traj_length)
|
|
||||||
|
|
||||||
if not all_subtask_names:
|
|
||||||
raise ValueError("No valid subtask names found, using default num_stages")
|
|
||||||
|
|
||||||
# Sort subtask names for consistent ordering
|
# Sort subtask names for consistent ordering
|
||||||
subtask_names = sorted(list(all_subtask_names))
|
subtask_names = sorted(temporal_proportions_dict.keys())
|
||||||
num_stages = len(subtask_names)
|
|
||||||
|
|
||||||
# Compute temporal proportions using Paper Formula (1)
|
self.config.num_stages = len(subtask_names)
|
||||||
temporal_proportions_dict = compute_priors(
|
|
||||||
subtask_durations_per_trajectory,
|
|
||||||
trajectory_lengths,
|
|
||||||
subtask_names
|
|
||||||
)
|
|
||||||
temporal_proportions = [temporal_proportions_dict[name] for name in subtask_names]
|
|
||||||
|
|
||||||
self.config.num_stages = num_stages
|
|
||||||
self.config.subtask_names = subtask_names
|
self.config.subtask_names = subtask_names
|
||||||
self.config.temporal_proportions = temporal_proportions
|
self.config.temporal_proportions = [temporal_proportions_dict[name] for name in subtask_names]
|
||||||
|
|
||||||
logging.info(f"Auto-detected {num_stages} subtasks: {subtask_names}")
|
logging.info(f"Loaded {len(subtask_names)} subtasks: {subtask_names}")
|
||||||
logging.info(f"Temporal proportions: {temporal_proportions_dict}")
|
logging.info(f"Temporal proportions: {temporal_proportions_dict}")
|
||||||
|
|
||||||
def to(self, device):
|
def to(self, device):
|
||||||
@@ -503,43 +458,23 @@ class SARMRewardModel(PreTrainedPolicy):
|
|||||||
|
|
||||||
return rewards
|
return rewards
|
||||||
|
|
||||||
|
|
||||||
def load_pretrained_checkpoint(self, checkpoint_path: str, strict: bool = False):
|
|
||||||
"""Load pretrained model weights from a checkpoint file."""
|
|
||||||
logging.info(f"Loading pretrained checkpoint from {checkpoint_path}")
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
|
|
||||||
|
|
||||||
if "model_state_dict" in checkpoint:
|
|
||||||
state_dict = checkpoint["model_state_dict"]
|
|
||||||
else:
|
|
||||||
state_dict = checkpoint
|
|
||||||
|
|
||||||
missing_keys, unexpected_keys = self.sarm_transformer.load_state_dict(state_dict, strict=strict)
|
|
||||||
|
|
||||||
if missing_keys:
|
|
||||||
logging.warning(f"Missing keys when loading checkpoint: {missing_keys}")
|
|
||||||
if unexpected_keys:
|
|
||||||
logging.warning(f"Unexpected keys when loading checkpoint: {unexpected_keys}")
|
|
||||||
|
|
||||||
logging.info("Checkpoint loaded successfully")
|
|
||||||
|
|
||||||
def train(self, mode: bool = True):
|
def train(self, mode: bool = True):
|
||||||
"""Set training mode. Note: CLIP encoder always stays in eval mode (frozen)."""
|
"""Overwrite train method to ensure CLIP encoder stays frozen during training"""
|
||||||
super().train(mode)
|
super().train(mode)
|
||||||
self.clip_model.eval()
|
self.clip_model.eval()
|
||||||
self.sarm_transformer.train(mode)
|
self.sarm_transformer.train(mode)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def eval(self):
|
def eval(self):
|
||||||
"""Set evaluation mode."""
|
"""Overwrite eval method to ensure CLIP encoder stays frozen during evaluation"""
|
||||||
return self.train(False)
|
return self.train(False)
|
||||||
|
|
||||||
def parameters(self):
|
def parameters(self):
|
||||||
"""Return trainable parameters (only SARM transformer, not encoders)."""
|
"""Override to return trainable parameters (only SARM transformer, not CLIP encoder)."""
|
||||||
return self.sarm_transformer.parameters()
|
return self.sarm_transformer.parameters()
|
||||||
|
|
||||||
def get_optim_params(self):
|
def get_optim_params(self):
|
||||||
"""Return optimizer parameters for the policy."""
|
"""Override to return optimizer parameters (only SARM transformer, not CLIP encoder)."""
|
||||||
return self.parameters()
|
return self.parameters()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
@@ -619,9 +554,7 @@ class SARMRewardModel(PreTrainedPolicy):
|
|||||||
# Extract required features
|
# Extract required features
|
||||||
video_features = observation['video_features'].to(self.device)
|
video_features = observation['video_features'].to(self.device)
|
||||||
text_features = observation['text_features'].to(self.device)
|
text_features = observation['text_features'].to(self.device)
|
||||||
state_features = observation.get('state_features')
|
state_features = observation.get('state_features').to(self.device)
|
||||||
if state_features is not None:
|
|
||||||
state_features = state_features.to(self.device)
|
|
||||||
|
|
||||||
batch_size = video_features.shape[0]
|
batch_size = video_features.shape[0]
|
||||||
max_length = self.config.num_frames
|
max_length = self.config.num_frames
|
||||||
@@ -643,7 +576,7 @@ class SARMRewardModel(PreTrainedPolicy):
|
|||||||
if progress_from_annotations.dim() == 3 and progress_from_annotations.shape[0] == 1:
|
if progress_from_annotations.dim() == 3 and progress_from_annotations.shape[0] == 1:
|
||||||
progress_from_annotations = progress_from_annotations.expand(batch_size, -1, -1)
|
progress_from_annotations = progress_from_annotations.expand(batch_size, -1, -1)
|
||||||
|
|
||||||
# Process each sample: apply temporal augmentation (SARM paper A.4)
|
# Process each sample: apply temporal REWIND augmentation
|
||||||
processed_videos = []
|
processed_videos = []
|
||||||
processed_states = []
|
processed_states = []
|
||||||
progress_targets = []
|
progress_targets = []
|
||||||
@@ -653,7 +586,7 @@ class SARMRewardModel(PreTrainedPolicy):
|
|||||||
state = state_features[i] if state_features is not None else None
|
state = state_features[i] if state_features is not None else None
|
||||||
progress = progress_from_annotations[i].squeeze(-1) # (T,)
|
progress = progress_from_annotations[i].squeeze(-1) # (T,)
|
||||||
|
|
||||||
# Apply temporal augmentation with 50% probability: appends up to 4 reversed frames to simulate failures/recoveries
|
# Apply temporal REWIND augmentation with 50% probability: appends up to 4 reversed frames to simulate failures/recoveries
|
||||||
if random.random() < 0.5:
|
if random.random() < 0.5:
|
||||||
video, progress, state = self._apply_temporal_augmentation(video, progress, state, max_length)
|
video, progress, state = self._apply_temporal_augmentation(video, progress, state, max_length)
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -23,7 +24,7 @@ import pandas as pd
|
|||||||
from transformers import CLIPModel, CLIPProcessor
|
from transformers import CLIPModel, CLIPProcessor
|
||||||
|
|
||||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||||
from lerobot.policies.sarm.sarm_utils import compute_priors, compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim
|
from lerobot.policies.sarm.sarm_utils import compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
ProcessorStep,
|
ProcessorStep,
|
||||||
PolicyProcessorPipeline,
|
PolicyProcessorPipeline,
|
||||||
@@ -57,105 +58,19 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
|||||||
self.image_key = image_key or config.image_key
|
self.image_key = image_key or config.image_key
|
||||||
self.dataset_meta = dataset_meta
|
self.dataset_meta = dataset_meta
|
||||||
self.dataset_stats = dataset_stats
|
self.dataset_stats = dataset_stats
|
||||||
self.temporal_proportions = None
|
self.temporal_proportions = {name: prop for name, prop in zip(self.config.subtask_names, self.config.temporal_proportions)}
|
||||||
self.subtask_names = None
|
self.subtask_names = self.config.subtask_names
|
||||||
if dataset_meta is not None:
|
|
||||||
self._compute_temporal_proportions()
|
|
||||||
|
|
||||||
self._init_encoders()
|
self.device = torch.device(
|
||||||
|
|
||||||
def _init_encoders(self):
|
|
||||||
"""Initialize CLIP encoder for both images and text (per SARM paper A.4)."""
|
|
||||||
device = torch.device(
|
|
||||||
self.config.device if self.config.device
|
self.config.device if self.config.device
|
||||||
else "cuda" if torch.cuda.is_available() else "cpu"
|
else "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Initializing CLIP encoder for SARM (images + text)...")
|
|
||||||
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||||
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
|
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
|
||||||
self.clip_model.to(device)
|
self.clip_model.to(self.device)
|
||||||
self.clip_model.eval()
|
self.clip_model.eval()
|
||||||
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
def _compute_temporal_proportions(self):
|
|
||||||
"""Compute temporal proportions for each subtask from dataset annotations.
|
|
||||||
|
|
||||||
Implements SARM Paper Formula (1):
|
|
||||||
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
|
|
||||||
|
|
||||||
This averages the proportion of time spent on each subtask within each trajectory,
|
|
||||||
giving equal weight to all trajectories regardless of absolute length.
|
|
||||||
"""
|
|
||||||
if self.dataset_meta is None or not hasattr(self.dataset_meta, 'episodes'):
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check if subtask annotations exist
|
|
||||||
episodes = self.dataset_meta.episodes
|
|
||||||
if episodes is None or len(episodes) == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check for subtask_names column
|
|
||||||
if 'subtask_names' not in episodes.column_names:
|
|
||||||
logging.info("No subtask annotations found in dataset")
|
|
||||||
return
|
|
||||||
|
|
||||||
episodes_df = episodes.to_pandas()
|
|
||||||
|
|
||||||
# Collect subtask durations and trajectory lengths for compute_priors
|
|
||||||
subtask_durations_per_trajectory = {}
|
|
||||||
trajectory_lengths = {}
|
|
||||||
all_subtask_names = set()
|
|
||||||
|
|
||||||
for ep_idx in episodes_df.index:
|
|
||||||
subtask_names_ep = episodes_df.loc[ep_idx, 'subtask_names']
|
|
||||||
|
|
||||||
# Skip episodes without annotations
|
|
||||||
if subtask_names_ep is None or (isinstance(subtask_names_ep, float) and pd.isna(subtask_names_ep)):
|
|
||||||
continue
|
|
||||||
|
|
||||||
start_times = episodes_df.loc[ep_idx, 'subtask_start_times']
|
|
||||||
end_times = episodes_df.loc[ep_idx, 'subtask_end_times']
|
|
||||||
|
|
||||||
# Track unique subtask names
|
|
||||||
all_subtask_names.update(subtask_names_ep)
|
|
||||||
|
|
||||||
# Compute total trajectory length T_i (sum of all subtask durations)
|
|
||||||
total_traj_length = sum(end_times[i] - start_times[i] for i in range(len(subtask_names_ep)))
|
|
||||||
|
|
||||||
if total_traj_length <= 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Store duration and trajectory length for each subtask occurrence
|
|
||||||
for i, name in enumerate(subtask_names_ep):
|
|
||||||
duration = end_times[i] - start_times[i]
|
|
||||||
|
|
||||||
if name not in subtask_durations_per_trajectory:
|
|
||||||
subtask_durations_per_trajectory[name] = []
|
|
||||||
trajectory_lengths[name] = []
|
|
||||||
|
|
||||||
subtask_durations_per_trajectory[name].append(duration)
|
|
||||||
trajectory_lengths[name].append(total_traj_length)
|
|
||||||
|
|
||||||
if not all_subtask_names:
|
|
||||||
logging.info("No valid subtask annotations found")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Sort subtask names for consistent ordering
|
|
||||||
self.subtask_names = sorted(list(all_subtask_names))
|
|
||||||
self.config.num_stages = len(self.subtask_names)
|
|
||||||
self.config.subtask_names = self.subtask_names
|
|
||||||
|
|
||||||
# Compute temporal proportions using Paper Formula (1)
|
|
||||||
self.temporal_proportions = compute_priors(
|
|
||||||
subtask_durations_per_trajectory,
|
|
||||||
trajectory_lengths,
|
|
||||||
self.subtask_names
|
|
||||||
)
|
|
||||||
self.config.temporal_proportions = [self.temporal_proportions[name] for name in self.subtask_names]
|
|
||||||
logging.info(f"Computed temporal proportions for {len(self.subtask_names)} subtasks: {self.temporal_proportions}")
|
|
||||||
|
|
||||||
def _find_episode_for_frame(self, frame_idx: int) -> int:
|
def _find_episode_for_frame(self, frame_idx: int) -> int:
|
||||||
"""Find the episode index for a given frame index."""
|
"""Find the episode index for a given frame index."""
|
||||||
for ep_idx in range(len(self.dataset_meta.episodes)):
|
for ep_idx in range(len(self.dataset_meta.episodes)):
|
||||||
@@ -437,10 +352,8 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
|||||||
else:
|
else:
|
||||||
state_tensor = torch.tensor(state_data, dtype=torch.float32)
|
state_tensor = torch.tensor(state_data, dtype=torch.float32)
|
||||||
|
|
||||||
# Pad state
|
|
||||||
observation['state_features'] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)
|
observation['state_features'] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)
|
||||||
|
|
||||||
# Extract complementary data (includes task from dataset)
|
|
||||||
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||||
if not isinstance(comp_data, dict):
|
if not isinstance(comp_data, dict):
|
||||||
raise ValueError("COMPLEMENTARY_DATA must be a dictionary")
|
raise ValueError("COMPLEMENTARY_DATA must be a dictionary")
|
||||||
@@ -595,7 +508,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
|||||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||||
"""Add encoded features to the observation features."""
|
"""Add encoded features to the observation features."""
|
||||||
# Add the encoded features (state uses max_state_dim, padded with zeros)
|
|
||||||
features[PipelineFeatureType.OBSERVATION]['video_features'] = PolicyFeature(
|
features[PipelineFeatureType.OBSERVATION]['video_features'] = PolicyFeature(
|
||||||
type=FeatureType.VISUAL,
|
type=FeatureType.VISUAL,
|
||||||
shape=(self.config.num_frames, self.config.image_dim)
|
shape=(self.config.num_frames, self.config.image_dim)
|
||||||
@@ -622,60 +534,40 @@ def make_sarm_pre_post_processors(
|
|||||||
"""
|
"""
|
||||||
Create pre-processor and post-processor pipelines for SARM.
|
Create pre-processor and post-processor pipelines for SARM.
|
||||||
|
|
||||||
Per SARM paper (Appendix A.4): "We employ a frozen clip-vit-base-patch32 encoder
|
|
||||||
to process both RGB image sequences and task descriptions."
|
|
||||||
|
|
||||||
The pre-processing pipeline:
|
The pre-processing pipeline:
|
||||||
1. Adds batch dimension
|
1. Adds batch dimension
|
||||||
2. Normalizes observation.state using NormalizerProcessorStep (MEAN_STD)
|
2. Normalizes observation.state using NormalizerProcessorStep (MEAN_STD)
|
||||||
3. SARMEncodingProcessorStep:
|
3. SARMEncodingProcessorStep:
|
||||||
- Encodes images with CLIP (512-dim)
|
- Encodes images with CLIP
|
||||||
- Pads states to max_state_dim
|
- Pads states to max_state_dim
|
||||||
- Encodes text with CLIP (512-dim)
|
- Encodes text with CLIP
|
||||||
4. Moves data to device
|
4. Moves data to device
|
||||||
|
|
||||||
The post-processing pipeline:
|
The post-processing pipeline:
|
||||||
1. Moves data to CPU (no unnormalization - outputs are rewards)
|
1. Moves data to CPU
|
||||||
|
|
||||||
Args:
|
|
||||||
config: SARM configuration
|
|
||||||
dataset_stats: Dataset statistics for normalization
|
|
||||||
dataset_meta: Dataset metadata for computing episode info
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (preprocessor, postprocessor) pipelines
|
|
||||||
"""
|
"""
|
||||||
input_steps = [
|
|
||||||
AddBatchDimensionProcessorStep(),
|
|
||||||
NormalizerProcessorStep(
|
|
||||||
features={**config.input_features, **config.output_features},
|
|
||||||
norm_map=config.normalization_mapping,
|
|
||||||
stats=dataset_stats,
|
|
||||||
),
|
|
||||||
SARMEncodingProcessorStep(
|
|
||||||
config=config,
|
|
||||||
dataset_meta=dataset_meta,
|
|
||||||
dataset_stats=dataset_stats
|
|
||||||
),
|
|
||||||
DeviceProcessorStep(device=config.device),
|
|
||||||
]
|
|
||||||
|
|
||||||
output_steps = [
|
|
||||||
DeviceProcessorStep(device="cpu"),
|
|
||||||
]
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||||
steps=input_steps,
|
steps=[
|
||||||
|
AddBatchDimensionProcessorStep(),
|
||||||
|
NormalizerProcessorStep(
|
||||||
|
features={**config.input_features, **config.output_features},
|
||||||
|
norm_map=config.normalization_mapping,
|
||||||
|
stats=dataset_stats,
|
||||||
|
),
|
||||||
|
SARMEncodingProcessorStep(
|
||||||
|
config=config,
|
||||||
|
dataset_meta=dataset_meta,
|
||||||
|
dataset_stats=dataset_stats
|
||||||
|
),
|
||||||
|
DeviceProcessorStep(device=config.device),
|
||||||
|
],
|
||||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||||
),
|
),
|
||||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||||
steps=output_steps,
|
steps=[DeviceProcessorStep(device="cpu")],
|
||||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||||
to_transition=policy_action_to_transition,
|
to_transition=policy_action_to_transition,
|
||||||
to_output=transition_to_policy_action,
|
to_output=transition_to_policy_action,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -246,4 +246,3 @@ def pad_state_to_max_dim(state: torch.Tensor, max_state_dim: int) -> torch.Tenso
|
|||||||
# Pad with zeros on the right
|
# Pad with zeros on the right
|
||||||
padding = (0, max_state_dim - current_dim) # (left, right) for last dim
|
padding = (0, max_state_dim - current_dim) # (left, right) for last dim
|
||||||
return F.pad(state, padding, mode='constant', value=0)
|
return F.pad(state, padding, mode='constant', value=0)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user