mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 04:30:10 +00:00
add tests, implement formula 1,2 correctly and cleanup
This commit is contained in:
@@ -64,9 +64,7 @@ class SARMTemporalSampler(Sampler):
|
||||
self.shuffle = shuffle
|
||||
self.samples_per_epoch = samples_per_epoch
|
||||
|
||||
# Minimum frames needed for SARM pattern:
|
||||
# 8 consecutive frames with frame_gap spacing = 7 * frame_gap + 1
|
||||
# (Plus the initial frame which is always available)
|
||||
# Minimum frames needed for SARM pattern: 8 consecutive frames with frame_gap spacing = 7 * frame_gap + 1
|
||||
self.min_frames_needed = 7 * frame_gap + 1
|
||||
|
||||
if seed is not None:
|
||||
@@ -138,7 +136,3 @@ class SARMTemporalSampler(Sampler):
|
||||
for i in range(self.samples_per_epoch):
|
||||
idx = i % len(self.all_valid_positions)
|
||||
yield int(self.all_valid_positions[idx])
|
||||
|
||||
|
||||
# Backwards compatibility alias
|
||||
TemporalSequenceSampler = SARMTemporalSampler
|
||||
|
||||
@@ -18,7 +18,6 @@ from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.policies.sarm.modeling_sarm import (
|
||||
SARMRewardModel,
|
||||
SARMTransformer,
|
||||
compute_stage_loss,
|
||||
)
|
||||
from lerobot.policies.sarm.processor_sarm import (
|
||||
SARMEncodingProcessorStep,
|
||||
@@ -29,7 +28,6 @@ __all__ = [
|
||||
"SARMConfig",
|
||||
"SARMRewardModel",
|
||||
"SARMTransformer",
|
||||
"compute_stage_loss",
|
||||
"SARMEncodingProcessorStep",
|
||||
"make_sarm_pre_post_processors",
|
||||
]
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import PolicyFeature, FeatureType
|
||||
from lerobot.configs.types import PolicyFeature, FeatureType, NormalizationMode
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
@@ -27,63 +27,83 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
class SARMConfig(PreTrainedConfig):
|
||||
"""Configuration class for SARM (Stage-Aware Reward Modeling)"""
|
||||
|
||||
# Visual encoding parameters
|
||||
image_dim: int = 512 # CLIP embedding dimension
|
||||
# CLIP encoding parameters
|
||||
image_dim: int = 512
|
||||
text_dim: int = 512
|
||||
num_frames: int = 9 # 1 initial + 8 consecutive frames
|
||||
frame_gap: int = 30 # Frame gap between consecutive frames (at 30 fps = 1 second)
|
||||
|
||||
# Text encoding parameters (CLIP text encoder output dimension)
|
||||
text_dim: int = 512
|
||||
|
||||
# Joint state parameters
|
||||
state_dim: int | None = None # Auto-detected from dataset if None
|
||||
|
||||
# Architecture parameters
|
||||
hidden_dim: int = 768
|
||||
num_heads: int = 12
|
||||
num_layers: int = 8
|
||||
num_stages: int = 5 # Number of task stages for classification (auto-updated from annotations if available)
|
||||
max_state_dim: int = 32
|
||||
num_stages: int = 5 # Number of task stages (auto-updated from annotations if available)
|
||||
subtask_names: list | None = None # List of subtask names (auto-populated from annotations)
|
||||
temporal_proportions: list | None = None # Temporal proportions for each stage (auto-computed from annotations)
|
||||
|
||||
# Temporal parameters
|
||||
max_length: int = num_frames # Maximum video sequence length (matches num_frames)
|
||||
use_temporal_sampler: bool = True # Always enable temporal sequence loading
|
||||
|
||||
# Training parameters
|
||||
batch_size: int = 64
|
||||
clip_batch_size: int = 64 # Batch size for CLIP encoding
|
||||
gradient_checkpointing: bool = False # Enable gradient checkpointing
|
||||
dropout: float = 0.1 # Dropout rate
|
||||
dropout: float = 0.1
|
||||
stage_loss_weight: float = 1.0 # Weight for stage classification loss when using subtask annotations
|
||||
|
||||
pretrained_model_path: str | None = None
|
||||
|
||||
device: str | None = None
|
||||
|
||||
# Processor settings
|
||||
image_key: str = "observation.images.top" # Key for image used from the dataset
|
||||
task_description: str = "perform the task" # Default task description
|
||||
|
||||
# Video_features and text_features are generated by the processor from raw images/text, we don't declare them as VISUAL/LANGUAGE here to avoid validation errors
|
||||
input_features: dict = field(default_factory=lambda: {
|
||||
"state_features": PolicyFeature(shape=(9, 14), type=FeatureType.STATE) # Example: 7 DOF × 2 arms
|
||||
})
|
||||
# State key in the dataset (for normalization)
|
||||
state_key: str = "observation.state"
|
||||
|
||||
# Populated by the processor (video_features, state_features, text_features)
|
||||
input_features: dict = field(default_factory=lambda: {})
|
||||
|
||||
# Output features
|
||||
output_features: dict = field(default_factory=lambda: {
|
||||
"stage": PolicyFeature(shape=(1,), type=FeatureType.REWARD),
|
||||
"progress": PolicyFeature(shape=(1,), type=FeatureType.REWARD)
|
||||
"stage": PolicyFeature(shape=(9, 5), type=FeatureType.REWARD),
|
||||
"progress": PolicyFeature(shape=(9, 1), type=FeatureType.REWARD),
|
||||
})
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"LANGUAGE": NormalizationMode.IDENTITY,
|
||||
"REWARD": NormalizationMode.IDENTITY,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
# Add the image_key, the processor will transform this into video_features
|
||||
if self.image_key and self.image_key not in self.input_features:
|
||||
# Add the image_key as VISUAL (this is the raw image from dataset)
|
||||
if self.image_key:
|
||||
self.input_features[self.image_key] = PolicyFeature(
|
||||
shape=(480, 640, 3),
|
||||
type=FeatureType.VISUAL
|
||||
)
|
||||
|
||||
# Add state_key as STATE (raw state from dataset, will be padded to max_state_dim)
|
||||
self.input_features[self.state_key] = PolicyFeature(
|
||||
shape=(self.max_state_dim,), # Single frame state, temporal sampling handles sequence
|
||||
type=FeatureType.STATE
|
||||
)
|
||||
|
||||
# Update output features with actual dimensions
|
||||
self.output_features["stage"] = PolicyFeature(
|
||||
shape=(self.num_frames, self.num_stages),
|
||||
type=FeatureType.REWARD
|
||||
)
|
||||
self.output_features["progress"] = PolicyFeature(
|
||||
shape=(self.num_frames, 1),
|
||||
type=FeatureType.REWARD
|
||||
)
|
||||
|
||||
# Validate configuration
|
||||
if self.hidden_dim % self.num_heads != 0:
|
||||
raise ValueError(
|
||||
@@ -95,9 +115,6 @@ class SARMConfig(PreTrainedConfig):
|
||||
f"max_length ({self.max_length}) must equal num_frames ({self.num_frames})"
|
||||
)
|
||||
|
||||
if self.dropout < 0 or self.dropout >= 1:
|
||||
raise ValueError(f"dropout must be in [0, 1), got {self.dropout}")
|
||||
|
||||
if self.num_stages < 2:
|
||||
raise ValueError(f"num_stages must be at least 2, got {self.num_stages}")
|
||||
|
||||
@@ -139,11 +156,10 @@ class SARMConfig(PreTrainedConfig):
|
||||
Returns:
|
||||
9 delta indices: [-1_000_000, -(7*gap), -(6*gap), ..., -gap, 0]
|
||||
"""
|
||||
# First delta: large negative to always clamp to episode start (frame 0)
|
||||
initial_frame_delta = -1_000_000
|
||||
|
||||
# Remaining 8 deltas: consecutive frames with frame_gap spacing
|
||||
num_consecutive = self.num_frames - 1 # 8 frames
|
||||
# Remaining consecutive frames with frame_gap spacing
|
||||
num_consecutive = self.num_frames - 1
|
||||
consecutive_deltas = list(range(-self.frame_gap * (num_consecutive - 1), 1, self.frame_gap))
|
||||
|
||||
return [initial_frame_delta] + consecutive_deltas
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import List, Union, Dict, Optional
|
||||
from typing import List, Union, Optional
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
@@ -28,9 +28,12 @@ from transformers import CLIPModel, CLIPProcessor
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.policies.sarm.sarm_utils import compute_priors, compute_cumulative_progress_batch, pad_state_to_max_dim
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
|
||||
|
||||
|
||||
class SARMTransformer(nn.Module):
|
||||
"""
|
||||
SARM Transformer model for stage-aware reward prediction.
|
||||
@@ -45,8 +48,8 @@ class SARMTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
video_dim: int = 512,
|
||||
text_dim: int = 512, # CLIP text encoder output dimension (per SARM paper A.4)
|
||||
state_dim: int = 14,
|
||||
text_dim: int = 512,
|
||||
max_state_dim: int = 32,
|
||||
hidden_dim: int = 768,
|
||||
num_heads: int = 12,
|
||||
num_layers: int = 8,
|
||||
@@ -59,9 +62,8 @@ class SARMTransformer(nn.Module):
|
||||
self.hidden_dim = hidden_dim
|
||||
self.max_length = max_length
|
||||
self.num_stages = num_stages
|
||||
self.max_state_dim = max_state_dim
|
||||
|
||||
# Store temporal proportions for progress conversion (Paper Eq. 4)
|
||||
# ŷ = P_{k-1} + ᾱ_k × τ̂
|
||||
if temporal_proportions is None:
|
||||
raise ValueError(
|
||||
"temporal_proportions is required for SARM. "
|
||||
@@ -77,15 +79,13 @@ class SARMTransformer(nn.Module):
|
||||
self.register_buffer('alpha', alpha)
|
||||
self.register_buffer('cumulative_prior', cumulative)
|
||||
|
||||
# Project video, text, and state to same dimension
|
||||
self.video_proj = nn.Linear(video_dim, hidden_dim)
|
||||
self.text_proj = nn.Linear(text_dim, hidden_dim)
|
||||
self.state_proj = nn.Linear(state_dim, hidden_dim)
|
||||
self.state_proj = nn.Linear(max_state_dim, hidden_dim)
|
||||
|
||||
# Position embedding only for the first frame
|
||||
self.first_pos_embed = nn.Parameter(torch.randn(1, hidden_dim))
|
||||
|
||||
# Transformer encoder
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=hidden_dim,
|
||||
nhead=num_heads,
|
||||
@@ -96,7 +96,6 @@ class SARMTransformer(nn.Module):
|
||||
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
|
||||
|
||||
# Stage estimator head (classification)
|
||||
# Paper A.4: "2 layers with hidden dimension of 512"
|
||||
self.stage_head = nn.Sequential(
|
||||
nn.Linear(hidden_dim, 512),
|
||||
nn.LayerNorm(512),
|
||||
@@ -106,8 +105,6 @@ class SARMTransformer(nn.Module):
|
||||
)
|
||||
|
||||
# Subtask estimator head (regression, conditioned on stage)
|
||||
# Takes concatenated [features, stage_embedding]
|
||||
# Paper A.4: "2 layers with hidden dimension of 512"
|
||||
self.stage_embedding = nn.Embedding(num_stages, hidden_dim // 4)
|
||||
subtask_input_dim = hidden_dim + hidden_dim // 4
|
||||
self.subtask_head = nn.Sequential(
|
||||
@@ -119,13 +116,13 @@ class SARMTransformer(nn.Module):
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
# Attention mask for causal self-attention
|
||||
# Attention mask
|
||||
self.register_buffer("attention_mask", None, persistent=False)
|
||||
|
||||
def _get_attention_mask(self, seq_length: int, device: torch.device) -> torch.Tensor:
|
||||
"""Generate or retrieve cached causal attention mask."""
|
||||
if self.attention_mask is None or self.attention_mask.shape[0] != seq_length:
|
||||
# Create causal mask (upper triangular with -inf)
|
||||
# Create causal mask
|
||||
mask = nn.Transformer.generate_square_subsequent_mask(seq_length, device=device)
|
||||
self.attention_mask = mask
|
||||
return self.attention_mask
|
||||
@@ -150,14 +147,16 @@ class SARMTransformer(nn.Module):
|
||||
- Stage probabilities (batch_size, seq_len, num_stages)
|
||||
- Progress predictions for each frame (batch_size, seq_len, 1)
|
||||
"""
|
||||
batch_size = video_frames.shape[0]
|
||||
|
||||
# Project inputs to common dimension
|
||||
video_embed = self.video_proj(video_frames) # [batch_size, seq_len, hidden_dim]
|
||||
text_embed = self.text_proj(text_embed).unsqueeze(1) # [batch_size, 1, hidden_dim]
|
||||
|
||||
state_embed = self.state_proj(state_features) # [batch_size, seq_len, hidden_dim]
|
||||
# Fuse video and state features (simple addition)
|
||||
# Pad state features to max_state_dim before projection
|
||||
state_features_padded = pad_state_to_max_dim(state_features, self.max_state_dim)
|
||||
|
||||
state_embed = self.state_proj(state_features_padded) # [batch_size, seq_len, hidden_dim]
|
||||
|
||||
# Fuse video and state features
|
||||
video_embed = video_embed + state_embed
|
||||
|
||||
# Add positional embedding to first video frame
|
||||
@@ -173,7 +172,7 @@ class SARMTransformer(nn.Module):
|
||||
# Pass through transformer with causal masking
|
||||
transformed = self.transformer(sequence, mask=attention_mask, is_causal=True)
|
||||
|
||||
# Get frame features (exclude text token)
|
||||
# Get frame features
|
||||
frame_features = transformed[:, 1:] # [batch_size, seq_len, hidden_dim]
|
||||
|
||||
# Stage estimation
|
||||
@@ -193,14 +192,11 @@ class SARMTransformer(nn.Module):
|
||||
# τ̂ = within-subtask progress (0-1)
|
||||
tau_preds = self.subtask_head(conditioned_features) # [batch_size, seq_len, 1]
|
||||
|
||||
# Convert τ̂ to cumulative progress ŷ using Paper Eq. 4:
|
||||
# Convert τ̂ to cumulative progress ŷ using Paper Formula (2):
|
||||
# ŷ = P_{k-1} + ᾱ_k × τ̂
|
||||
# P_{k-1} = cumulative prior up to stage k-1
|
||||
# ᾱ_k = temporal proportion of stage k
|
||||
P_k_minus_1 = self.cumulative_prior[stage_indices] # [batch_size, seq_len]
|
||||
alpha_k = self.alpha[stage_indices] # [batch_size, seq_len]
|
||||
|
||||
progress_preds = P_k_minus_1.unsqueeze(-1) + alpha_k.unsqueeze(-1) * tau_preds
|
||||
progress_preds = compute_cumulative_progress_batch(
|
||||
tau_preds, stage_indices, self.alpha, self.cumulative_prior
|
||||
)
|
||||
|
||||
return stage_logits, stage_probs, progress_preds
|
||||
|
||||
@@ -227,65 +223,37 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
self.dataset_stats = dataset_stats
|
||||
self.device = torch.device(config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Auto-detect num_stages from dataset annotations before building the model
|
||||
# Detect num_stages from dataset annotations before building the model
|
||||
if dataset_meta is not None:
|
||||
self._update_num_stages_from_dataset(dataset_meta)
|
||||
|
||||
# Initialize CLIP encoder for images AND text (per SARM paper A.4)
|
||||
logging.info("Loading CLIP encoder for images and text...")
|
||||
logging.info("Loading CLIP encoder")
|
||||
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
|
||||
self.clip_model.to(self.device)
|
||||
self.clip_model.eval()
|
||||
|
||||
# Auto-detect state_dim from dataset_stats
|
||||
if config.state_dim is None:
|
||||
logging.info(f"Attempting to auto-detect state_dim. dataset_stats is None: {dataset_stats is None}")
|
||||
|
||||
if dataset_stats is not None:
|
||||
if "observation.state" in dataset_stats:
|
||||
config.state_dim = dataset_stats["observation.state"]["mean"].shape[0]
|
||||
logging.info(f"Auto-detected state_dim={config.state_dim} from dataset_stats['observation.state']")
|
||||
elif "state" in dataset_stats:
|
||||
config.state_dim = dataset_stats["state"]["mean"].shape[0]
|
||||
logging.info(f"Auto-detected state_dim={config.state_dim} from dataset_stats['state']")
|
||||
else:
|
||||
logging.warning(f"State keys not found in dataset_stats. Available keys: {list(dataset_stats.keys())}")
|
||||
else:
|
||||
logging.warning("dataset_stats is None, cannot auto-detect state_dim")
|
||||
|
||||
# Raise explicit error if still None
|
||||
if config.state_dim is None:
|
||||
raise ValueError(
|
||||
"Could not determine state_dim! "
|
||||
f"dataset_stats={'None' if dataset_stats is None else f'available with keys: {list(dataset_stats.keys())}'}, "
|
||||
"config.state_dim=None. "
|
||||
"Please either:\n"
|
||||
"1. Provide --policy.state_dim=<your_state_dimension> explicitly, or\n"
|
||||
"2. Ensure dataset_stats contains 'observation.state' or 'state' key"
|
||||
)
|
||||
|
||||
# Initialize SARM transformer with temporal proportions for progress conversion
|
||||
temporal_proportions = getattr(config, 'temporal_proportions', None)
|
||||
self.sarm_transformer = SARMTransformer(
|
||||
video_dim=config.image_dim,
|
||||
text_dim=config.text_dim,
|
||||
state_dim=config.state_dim,
|
||||
max_state_dim=config.max_state_dim,
|
||||
hidden_dim=config.hidden_dim,
|
||||
num_heads=config.num_heads,
|
||||
num_layers=config.num_layers,
|
||||
num_stages=config.num_stages,
|
||||
max_length=config.max_length,
|
||||
dropout=config.dropout,
|
||||
temporal_proportions=temporal_proportions
|
||||
temporal_proportions=config.temporal_proportions
|
||||
)
|
||||
self.sarm_transformer.to(self.device)
|
||||
|
||||
|
||||
logging.info(f"SARM Reward Model initialized on {self.device}")
|
||||
|
||||
def _update_num_stages_from_dataset(self, dataset_meta) -> None:
|
||||
"""Update num_stages and temporal_proportions from dataset subtask annotations."""
|
||||
"""Update num_stages and temporal_proportions from dataset subtask annotations.
|
||||
|
||||
Implements SARM Paper Formula (1):
|
||||
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
|
||||
"""
|
||||
episodes = dataset_meta.episodes
|
||||
if episodes is None or len(episodes) == 0:
|
||||
raise ValueError("No episodes found, using default num_stages")
|
||||
@@ -295,27 +263,38 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
|
||||
episodes_df = episodes.to_pandas()
|
||||
|
||||
# Collect all unique subtask names and compute durations
|
||||
# Collect subtask durations and trajectory lengths for compute_priors
|
||||
all_subtask_names = set()
|
||||
subtask_durations = {}
|
||||
subtask_durations_per_trajectory = {}
|
||||
trajectory_lengths = {}
|
||||
|
||||
for ep_idx in episodes_df.index:
|
||||
subtask_names = episodes_df.loc[ep_idx, 'subtask_names']
|
||||
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
|
||||
subtask_names_ep = episodes_df.loc[ep_idx, 'subtask_names']
|
||||
if subtask_names_ep is None or (isinstance(subtask_names_ep, float) and pd.isna(subtask_names_ep)):
|
||||
continue
|
||||
|
||||
all_subtask_names.update(subtask_names)
|
||||
all_subtask_names.update(subtask_names_ep)
|
||||
|
||||
# Compute durations if available
|
||||
if 'subtask_start_frames' in episodes_df.columns and 'subtask_end_frames' in episodes_df.columns:
|
||||
start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames']
|
||||
end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames']
|
||||
|
||||
for i, name in enumerate(subtask_names):
|
||||
# Compute total trajectory length T_i
|
||||
total_traj_length = sum(end_frames[i] - start_frames[i] for i in range(len(subtask_names_ep)))
|
||||
|
||||
if total_traj_length <= 0:
|
||||
continue
|
||||
|
||||
for i, name in enumerate(subtask_names_ep):
|
||||
duration = end_frames[i] - start_frames[i]
|
||||
if name not in subtask_durations:
|
||||
subtask_durations[name] = []
|
||||
subtask_durations[name].append(duration)
|
||||
|
||||
if name not in subtask_durations_per_trajectory:
|
||||
subtask_durations_per_trajectory[name] = []
|
||||
trajectory_lengths[name] = []
|
||||
|
||||
subtask_durations_per_trajectory[name].append(duration)
|
||||
trajectory_lengths[name].append(total_traj_length)
|
||||
|
||||
if not all_subtask_names:
|
||||
raise ValueError("No valid subtask names found, using default num_stages")
|
||||
@@ -324,26 +303,20 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
subtask_names = sorted(list(all_subtask_names))
|
||||
num_stages = len(subtask_names)
|
||||
|
||||
# Compute temporal proportions (Paper Eq. 1: ᾱ_k)
|
||||
avg_durations = {}
|
||||
for name in subtask_names:
|
||||
if name in subtask_durations and subtask_durations[name]:
|
||||
avg_durations[name] = np.mean(subtask_durations[name])
|
||||
else:
|
||||
avg_durations[name] = 1.0 # Default
|
||||
|
||||
total_duration = sum(avg_durations.values())
|
||||
if total_duration > 0:
|
||||
temporal_proportions = [avg_durations[name] / total_duration for name in subtask_names]
|
||||
else:
|
||||
temporal_proportions = [1.0 / num_stages] * num_stages
|
||||
# Compute temporal proportions using Paper Formula (1)
|
||||
temporal_proportions_dict = compute_priors(
|
||||
subtask_durations_per_trajectory,
|
||||
trajectory_lengths,
|
||||
subtask_names
|
||||
)
|
||||
temporal_proportions = [temporal_proportions_dict[name] for name in subtask_names]
|
||||
|
||||
self.config.num_stages = num_stages
|
||||
self.config.subtask_names = subtask_names
|
||||
self.config.temporal_proportions = temporal_proportions
|
||||
|
||||
logging.info(f"Auto-detected {num_stages} subtasks: {subtask_names}")
|
||||
logging.info(f"Temporal proportions: {dict(zip(subtask_names, temporal_proportions))}")
|
||||
logging.info(f"Temporal proportions: {temporal_proportions_dict}")
|
||||
|
||||
def to(self, device):
|
||||
"""Override to method to ensure all components move together."""
|
||||
@@ -475,7 +448,6 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
If return_stages=True:
|
||||
Tuple of (rewards, stage_probs)
|
||||
"""
|
||||
# Convert to tensors if needed
|
||||
if isinstance(text_embeddings, np.ndarray):
|
||||
text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32)
|
||||
if isinstance(video_embeddings, np.ndarray):
|
||||
@@ -535,16 +507,13 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
def load_pretrained_checkpoint(self, checkpoint_path: str, strict: bool = False):
|
||||
"""Load pretrained model weights from a checkpoint file."""
|
||||
logging.info(f"Loading pretrained checkpoint from {checkpoint_path}")
|
||||
|
||||
checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
|
||||
|
||||
# Handle different checkpoint formats
|
||||
if "model_state_dict" in checkpoint:
|
||||
state_dict = checkpoint["model_state_dict"]
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
# Load only the SARMTransformer weights
|
||||
missing_keys, unexpected_keys = self.sarm_transformer.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
if missing_keys:
|
||||
@@ -557,9 +526,7 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
def train(self, mode: bool = True):
|
||||
"""Set training mode. Note: CLIP encoder always stays in eval mode (frozen)."""
|
||||
super().train(mode)
|
||||
# Keep CLIP encoder in eval mode (frozen per SARM paper)
|
||||
self.clip_model.eval()
|
||||
# Only transformer can be trained
|
||||
self.sarm_transformer.train(mode)
|
||||
return self
|
||||
|
||||
@@ -686,8 +653,7 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
state = state_features[i] if state_features is not None else None
|
||||
progress = progress_from_annotations[i].squeeze(-1) # (T,)
|
||||
|
||||
# Apply temporal augmentation with 50% probability (SARM paper A.4)
|
||||
# Appends up to 4 reversed frames to simulate failures/recoveries
|
||||
# Apply temporal augmentation with 50% probability: appends up to 4 reversed frames to simulate failures/recoveries
|
||||
if random.random() < 0.5:
|
||||
video, progress, state = self._apply_temporal_augmentation(video, progress, state, max_length)
|
||||
|
||||
@@ -729,7 +695,7 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
total_loss = total_loss + self.config.stage_loss_weight * stage_loss
|
||||
output_dict['stage_loss'] = stage_loss.item()
|
||||
|
||||
# Misaligned loss: 20% probability (SARM paper - improve video-language alignment)
|
||||
# Misaligned loss: 20% probability
|
||||
if random.random() < 0.2:
|
||||
shuffle_idx = torch.randperm(batch_size, device=self.device)
|
||||
_, _, misaligned_preds = self.sarm_transformer(
|
||||
|
||||
@@ -23,16 +23,19 @@ import pandas as pd
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.policies.sarm.sarm_utils import compute_priors, compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim
|
||||
from lerobot.processor import (
|
||||
ProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
PolicyAction,
|
||||
DeviceProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
from_tensor_to_numpy,
|
||||
)
|
||||
from lerobot.processor.pipeline import PipelineFeatureType
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
@@ -41,20 +44,7 @@ from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PR
|
||||
|
||||
|
||||
class SARMEncodingProcessorStep(ProcessorStep):
|
||||
"""
|
||||
ProcessorStep that encodes images and text for SARM training.
|
||||
|
||||
Per SARM paper (Appendix A.4): "We employ a frozen clip-vit-base-patch32 encoder
|
||||
to process both RGB image sequences and task descriptions."
|
||||
|
||||
This step handles:
|
||||
- CLIP image encoding (512-dim)
|
||||
- CLIP text encoding (512-dim)
|
||||
- Joint state normalization
|
||||
|
||||
Supports temporal sequences: (B, T, C, H, W) → (B, T, 512) video features
|
||||
"""
|
||||
|
||||
"""ProcessorStep that encodes images and text with CLIP."""
|
||||
def __init__(
|
||||
self,
|
||||
config: SARMConfig,
|
||||
@@ -69,8 +59,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
self.task_description = task_description or config.task_description
|
||||
self.dataset_meta = dataset_meta
|
||||
self.dataset_stats = dataset_stats
|
||||
|
||||
# Compute temporal proportions from subtask annotations if available
|
||||
self.temporal_proportions = None
|
||||
self.subtask_names = None
|
||||
if dataset_meta is not None:
|
||||
@@ -94,7 +82,14 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
self.device = device
|
||||
|
||||
def _compute_temporal_proportions(self):
|
||||
"""Compute temporal proportions for each subtask from dataset annotations."""
|
||||
"""Compute temporal proportions for each subtask from dataset annotations.
|
||||
|
||||
Implements SARM Paper Formula (1):
|
||||
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
|
||||
|
||||
This averages the proportion of time spent on each subtask within each trajectory,
|
||||
giving equal weight to all trajectories regardless of absolute length.
|
||||
"""
|
||||
if self.dataset_meta is None or not hasattr(self.dataset_meta, 'episodes'):
|
||||
return
|
||||
|
||||
@@ -108,32 +103,42 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
logging.info("No subtask annotations found in dataset")
|
||||
return
|
||||
|
||||
# Convert to pandas
|
||||
episodes_df = episodes.to_pandas()
|
||||
|
||||
# Collect all subtask names and compute average durations
|
||||
subtask_durations = {}
|
||||
# Collect subtask durations and trajectory lengths for compute_priors
|
||||
subtask_durations_per_trajectory = {}
|
||||
trajectory_lengths = {}
|
||||
all_subtask_names = set()
|
||||
|
||||
for ep_idx in episodes_df.index:
|
||||
subtask_names = episodes_df.loc[ep_idx, 'subtask_names']
|
||||
subtask_names_ep = episodes_df.loc[ep_idx, 'subtask_names']
|
||||
|
||||
# Skip episodes without annotations
|
||||
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
|
||||
if subtask_names_ep is None or (isinstance(subtask_names_ep, float) and pd.isna(subtask_names_ep)):
|
||||
continue
|
||||
|
||||
start_times = episodes_df.loc[ep_idx, 'subtask_start_times']
|
||||
end_times = episodes_df.loc[ep_idx, 'subtask_end_times']
|
||||
|
||||
# Track unique subtask names
|
||||
all_subtask_names.update(subtask_names)
|
||||
all_subtask_names.update(subtask_names_ep)
|
||||
|
||||
# Compute durations
|
||||
for i, name in enumerate(subtask_names):
|
||||
# Compute total trajectory length T_i (sum of all subtask durations)
|
||||
total_traj_length = sum(end_times[i] - start_times[i] for i in range(len(subtask_names_ep)))
|
||||
|
||||
if total_traj_length <= 0:
|
||||
continue
|
||||
|
||||
# Store duration and trajectory length for each subtask occurrence
|
||||
for i, name in enumerate(subtask_names_ep):
|
||||
duration = end_times[i] - start_times[i]
|
||||
if name not in subtask_durations:
|
||||
subtask_durations[name] = []
|
||||
subtask_durations[name].append(duration)
|
||||
|
||||
if name not in subtask_durations_per_trajectory:
|
||||
subtask_durations_per_trajectory[name] = []
|
||||
trajectory_lengths[name] = []
|
||||
|
||||
subtask_durations_per_trajectory[name].append(duration)
|
||||
trajectory_lengths[name].append(total_traj_length)
|
||||
|
||||
if not all_subtask_names:
|
||||
logging.info("No valid subtask annotations found")
|
||||
@@ -142,44 +147,17 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
# Sort subtask names for consistent ordering
|
||||
self.subtask_names = sorted(list(all_subtask_names))
|
||||
self.config.num_stages = len(self.subtask_names)
|
||||
self.config.subtask_names = self.subtask_names # Store in config for reference
|
||||
self.config.subtask_names = self.subtask_names
|
||||
|
||||
# Compute average duration for each subtask
|
||||
avg_durations = {}
|
||||
for name in self.subtask_names:
|
||||
if name in subtask_durations:
|
||||
avg_durations[name] = np.mean(subtask_durations[name])
|
||||
else:
|
||||
avg_durations[name] = 0.0
|
||||
|
||||
# Normalize to get proportions
|
||||
total_duration = sum(avg_durations.values())
|
||||
if total_duration > 0:
|
||||
self.temporal_proportions = {
|
||||
name: avg_durations[name] / total_duration
|
||||
for name in self.subtask_names
|
||||
}
|
||||
else:
|
||||
raise ValueError(
|
||||
"Cannot compute temporal proportions: all subtask durations are zero. "
|
||||
"Check that your dataset has valid subtask annotations with start/end times."
|
||||
)
|
||||
|
||||
# Store in config for the model to use in progress output conversion (SARM paper Eq. 4)
|
||||
# Compute temporal proportions using Paper Formula (1)
|
||||
self.temporal_proportions = compute_priors(
|
||||
subtask_durations_per_trajectory,
|
||||
trajectory_lengths,
|
||||
self.subtask_names
|
||||
)
|
||||
self.config.temporal_proportions = [self.temporal_proportions[name] for name in self.subtask_names]
|
||||
|
||||
logging.info(f"Computed temporal proportions for {len(self.subtask_names)} subtasks: {self.temporal_proportions}")
|
||||
|
||||
def _to_numpy_array(self, x) -> np.ndarray:
|
||||
"""Convert input to a 1D numpy array."""
|
||||
if isinstance(x, torch.Tensor):
|
||||
arr = x.cpu().numpy()
|
||||
else:
|
||||
arr = np.array(x)
|
||||
if arr.ndim == 0:
|
||||
arr = np.array([arr.item()])
|
||||
return arr
|
||||
|
||||
def _find_episode_for_frame(self, frame_idx: int) -> int:
|
||||
"""Find the episode index for a given frame index."""
|
||||
for ep_idx in range(len(self.dataset_meta.episodes)):
|
||||
@@ -187,14 +165,14 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
|
||||
if ep_start <= frame_idx < ep_end:
|
||||
return ep_idx
|
||||
return 0 # Fallback
|
||||
return 0
|
||||
|
||||
def _get_episode_indices(self, frame_indices: np.ndarray, episode_index) -> np.ndarray:
|
||||
"""Get episode indices for each frame index."""
|
||||
if episode_index is None:
|
||||
return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])
|
||||
|
||||
episode_indices = self._to_numpy_array(episode_index)
|
||||
episode_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(episode_index)))
|
||||
|
||||
# If single episode but multiple frames, compute episode for each frame
|
||||
if len(episode_indices) == 1 and len(frame_indices) > 1:
|
||||
@@ -211,22 +189,16 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
Pattern: [ep_start, t-(7*gap), t-(6*gap), ..., t-gap, t]
|
||||
|
||||
"""
|
||||
frame_gap = getattr(self.config, 'frame_gap', 1)
|
||||
|
||||
indices = []
|
||||
|
||||
|
||||
# First frame is the episode's initial frame
|
||||
indices.append(ep_start)
|
||||
indices.append(ep_start) # First frame is the episode's initial frame
|
||||
|
||||
# Remaining frames are consecutive with frame_gap spacing
|
||||
num_consecutive = num_frames - 1
|
||||
for i in range(num_consecutive):
|
||||
offset = -(num_consecutive - 1 - i) * frame_gap
|
||||
offset = -(num_consecutive - 1 - i) * self.config.frame_gap
|
||||
idx = max(ep_start, frame_idx + offset)
|
||||
indices.append(idx)
|
||||
|
||||
|
||||
return torch.tensor(indices)
|
||||
|
||||
def _compute_episode_metadata(
|
||||
@@ -269,6 +241,14 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
) -> tuple[int, float]:
|
||||
"""Compute stage index and cumulative progress for a single frame.
|
||||
|
||||
Implements SARM Paper Formula (2):
|
||||
y_t = P_{k-1} + ᾱ_k × τ_t
|
||||
|
||||
where:
|
||||
- τ_t = (t - s_k) / (e_k - s_k) is within-subtask progress
|
||||
- P_{k-1} is cumulative prior (sum of previous subtask proportions)
|
||||
- ᾱ_k is the temporal proportion for subtask k
|
||||
|
||||
Args:
|
||||
current_frame: Frame index relative to episode start
|
||||
subtask_names: List of subtask names for this episode
|
||||
@@ -278,53 +258,46 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
Returns:
|
||||
Tuple of (stage_idx, cumulative_progress)
|
||||
"""
|
||||
stage_idx = -1
|
||||
cumulative_progress = 0.0
|
||||
# Get temporal proportions as list for compute_cumulative_progress
|
||||
temporal_proportions_list = [
|
||||
self.temporal_proportions.get(name, 0.0) for name in self.subtask_names
|
||||
]
|
||||
|
||||
# Find which subtask this frame belongs to
|
||||
for j, (name, start_frame, end_frame) in enumerate(zip(subtask_names, subtask_start_frames, subtask_end_frames)):
|
||||
if current_frame >= start_frame and current_frame <= end_frame:
|
||||
# Found the subtask
|
||||
# Found the subtask, get its global index
|
||||
stage_idx = self.subtask_names.index(name) if name in self.subtask_names else 0
|
||||
|
||||
# Calculate within-subtask progress
|
||||
subtask_duration = end_frame - start_frame
|
||||
if subtask_duration > 0:
|
||||
within_subtask_progress = (current_frame - start_frame) / subtask_duration
|
||||
else:
|
||||
within_subtask_progress = 1.0
|
||||
# Compute τ_t using utility function (Paper Formula 2)
|
||||
tau = compute_tau(current_frame, start_frame, end_frame)
|
||||
|
||||
# Calculate cumulative progress from completed subtasks
|
||||
for k in range(j):
|
||||
prev_name = subtask_names[k]
|
||||
if prev_name in self.temporal_proportions:
|
||||
cumulative_progress += self.temporal_proportions[prev_name]
|
||||
|
||||
# Add current subtask's partial progress
|
||||
if name in self.temporal_proportions:
|
||||
cumulative_progress += self.temporal_proportions[name] * within_subtask_progress
|
||||
# Compute cumulative progress using utility function (Paper Formula 2)
|
||||
cumulative_progress = compute_cumulative_progress_batch(
|
||||
tau, stage_idx, temporal_proportions_list
|
||||
)
|
||||
|
||||
return stage_idx, cumulative_progress
|
||||
|
||||
# No matching subtask found - estimate based on position
|
||||
# No matching subtask found
|
||||
if current_frame < subtask_start_frames[0]:
|
||||
return 0, 0.0
|
||||
elif current_frame > subtask_end_frames[-1]:
|
||||
return len(self.subtask_names) - 1, 1.0
|
||||
else:
|
||||
# Between subtasks - use previous subtask's end state
|
||||
# Between subtasks - use previous subtask's end state (tau = 1.0)
|
||||
for j in range(len(subtask_names) - 1):
|
||||
if current_frame > subtask_end_frames[j] and current_frame < subtask_start_frames[j + 1]:
|
||||
name = subtask_names[j]
|
||||
stage_idx = self.subtask_names.index(name) if name in self.subtask_names else j
|
||||
# Sum up all completed subtasks
|
||||
for k in range(j + 1):
|
||||
prev_name = subtask_names[k]
|
||||
if prev_name in self.temporal_proportions:
|
||||
cumulative_progress += self.temporal_proportions[prev_name]
|
||||
|
||||
# Completed subtask, so tau = 1.0
|
||||
cumulative_progress = compute_cumulative_progress_batch(
|
||||
1.0, stage_idx, temporal_proportions_list
|
||||
)
|
||||
return stage_idx, cumulative_progress
|
||||
|
||||
return 0, 0.0 # Fallback
|
||||
return 0, 0.0
|
||||
|
||||
def _compute_labels_for_sample(
|
||||
self,
|
||||
@@ -359,13 +332,8 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
|
||||
subtask_start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames']
|
||||
subtask_end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames']
|
||||
|
||||
# Get episode boundaries
|
||||
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
|
||||
|
||||
# Get config values
|
||||
frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1
|
||||
|
||||
# Generate labels for each frame in the sequence
|
||||
stage_labels = []
|
||||
progress_targets = []
|
||||
@@ -377,7 +345,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
else:
|
||||
# Positions 1-8: consecutive frames with frame_gap spacing
|
||||
num_consecutive = seq_len - 1
|
||||
offset = -(num_consecutive - i) * frame_gap
|
||||
offset = -(num_consecutive - i) * self.config.frame_gap
|
||||
current_frame = max(0, frame_idx + offset - ep_start)
|
||||
|
||||
|
||||
@@ -388,7 +356,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
stage_labels.append(stage_idx)
|
||||
progress_targets.append(cumulative_progress)
|
||||
|
||||
# Convert to tensors
|
||||
stage_labels = torch.tensor(stage_labels, dtype=torch.long)
|
||||
progress_targets = torch.tensor(progress_targets, dtype=torch.float32).unsqueeze(-1)
|
||||
|
||||
@@ -411,7 +378,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1
|
||||
|
||||
# Normalize inputs to numpy arrays
|
||||
frame_indices = self._to_numpy_array(frame_index)
|
||||
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
|
||||
episode_indices = self._get_episode_indices(frame_indices, episode_index)
|
||||
|
||||
# Determine sequence length
|
||||
@@ -422,7 +389,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
|
||||
episodes_df = self.dataset_meta.episodes.to_pandas()
|
||||
|
||||
# Process all samples
|
||||
all_stage_labels = []
|
||||
all_progress_targets = []
|
||||
|
||||
@@ -450,7 +416,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
if not isinstance(observation, dict):
|
||||
raise ValueError("Observation must be a dictionary")
|
||||
|
||||
# 1. Encode images with CLIP
|
||||
# Encode images with CLIP
|
||||
image = observation.get(self.image_key)
|
||||
if image is None:
|
||||
raise ValueError(f"Image not found in observation for key: {self.image_key}")
|
||||
@@ -460,27 +426,27 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
video_features = self._encode_images_batch(image)
|
||||
observation['video_features'] = video_features
|
||||
|
||||
# 2. Extract and normalize joint states
|
||||
state_data = observation.get("state") or observation.get("observation.state")
|
||||
# Extract state and pad to max_state_dim (already normalized by NormalizerProcessorStep)
|
||||
state_key = self.config.state_key
|
||||
state_data = observation.get(state_key)
|
||||
if state_data is None:
|
||||
raise ValueError("State data not found in observation (expected 'state' or 'observation.state')")
|
||||
state_data = observation.get("state") or observation.get("observation.state")
|
||||
if state_data is None:
|
||||
raise ValueError(f"State data not found in observation (expected '{state_key}', 'state', or 'observation.state')")
|
||||
|
||||
if isinstance(state_data, torch.Tensor):
|
||||
state_data = state_data.cpu().numpy()
|
||||
state_tensor = state_data.float()
|
||||
else:
|
||||
state_tensor = torch.tensor(state_data, dtype=torch.float32)
|
||||
|
||||
state_key = "state" if "state" in observation else "observation.state"
|
||||
if self.dataset_stats and state_key in self.dataset_stats:
|
||||
mean = self.dataset_stats[state_key]['mean']
|
||||
std = self.dataset_stats[state_key]['std']
|
||||
state_data = (state_data - mean) / (std + 1e-8)
|
||||
# Pad state
|
||||
observation['state_features'] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)
|
||||
|
||||
observation['state_features'] = torch.tensor(state_data, dtype=torch.float32)
|
||||
|
||||
# 3. Encode text with CLIP (per SARM paper A.4)
|
||||
# Encode text with CLIP
|
||||
batch_size = video_features.shape[0]
|
||||
observation['text_features'] = self._encode_text_clip(self.task_description, batch_size)
|
||||
|
||||
# 4. Extract frame/episode indices from complementary data
|
||||
# Extract frame/episode indices from complementary data
|
||||
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
if not isinstance(comp_data, dict):
|
||||
raise ValueError("COMPLEMENTARY_DATA must be a dictionary")
|
||||
@@ -493,10 +459,10 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
if episode_index is None:
|
||||
raise ValueError("Episode index ('episode_index') not found in COMPLEMENTARY_DATA")
|
||||
|
||||
# 5. Compute episode metadata if dataset_meta is available
|
||||
# Compute episode metadata if dataset_meta is available
|
||||
if self.dataset_meta is not None:
|
||||
is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1
|
||||
frame_indices = self._to_numpy_array(frame_index)
|
||||
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
|
||||
episode_indices = self._get_episode_indices(frame_indices, episode_index)
|
||||
|
||||
# Determine number of frames from video features
|
||||
@@ -512,7 +478,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
observation['remaining_length'] = remaining
|
||||
observation['episode_length'] = ep_lengths
|
||||
|
||||
# 6. Generate stage labels and progress targets from subtask annotations
|
||||
# Generate stage labels and progress targets from subtask annotations
|
||||
if self.temporal_proportions is not None and self.dataset_meta is not None:
|
||||
stage_labels, progress_targets = self._generate_stage_and_progress_labels(
|
||||
frame_index, episode_index, video_features
|
||||
@@ -539,14 +505,10 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
# Check if we have temporal dimension
|
||||
has_temporal = len(images.shape) == 5
|
||||
|
||||
if has_temporal:
|
||||
# Shape: (B, T, C, H, W)
|
||||
if has_temporal: # Shape: (B, T, C, H, W)
|
||||
batch_size, seq_length = images.shape[0], images.shape[1]
|
||||
|
||||
# Reshape to (B*T, C, H, W) to process all frames at once
|
||||
images = images.reshape(batch_size * seq_length, *images.shape[2:])
|
||||
elif len(images.shape) == 4:
|
||||
# Shape: (B, C, H, W)
|
||||
elif len(images.shape) == 4: # Shape: (B, C, H, W)
|
||||
batch_size = images.shape[0]
|
||||
seq_length = 1
|
||||
else:
|
||||
@@ -608,7 +570,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
Returns:
|
||||
Encoded text features with shape (B, 512)
|
||||
"""
|
||||
# Use CLIP's tokenizer directly for text (avoids image processor validation issues)
|
||||
# Use CLIP's tokenizer directly for text
|
||||
tokenizer = self.clip_processor.tokenizer
|
||||
inputs = tokenizer([text], return_tensors="pt", padding=True, truncation=True)
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
@@ -629,7 +591,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Add encoded features to the observation features."""
|
||||
# Add the encoded features
|
||||
# Add the encoded features (state uses max_state_dim, padded with zeros)
|
||||
features[PipelineFeatureType.OBSERVATION]['video_features'] = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(self.config.num_frames, self.config.image_dim)
|
||||
@@ -640,7 +602,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
)
|
||||
features[PipelineFeatureType.OBSERVATION]['state_features'] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.config.num_frames, self.config.state_dim)
|
||||
shape=(self.config.num_frames, self.config.max_state_dim)
|
||||
)
|
||||
return features
|
||||
|
||||
@@ -660,11 +622,16 @@ def make_sarm_pre_post_processors(
|
||||
to process both RGB image sequences and task descriptions."
|
||||
|
||||
The pre-processing pipeline:
|
||||
1. Encodes images with CLIP (512-dim)
|
||||
2. Encodes text with CLIP (512-dim)
|
||||
3. Normalizes joint states
|
||||
4. Adds batch dimension
|
||||
5. Moves data to device
|
||||
1. Adds batch dimension
|
||||
2. Normalizes observation.state using NormalizerProcessorStep (MEAN_STD)
|
||||
3. SARMEncodingProcessorStep:
|
||||
- Encodes images with CLIP (512-dim)
|
||||
- Pads states to max_state_dim
|
||||
- Encodes text with CLIP (512-dim)
|
||||
4. Moves data to device
|
||||
|
||||
The post-processing pipeline:
|
||||
1. Moves data to CPU (no unnormalization - outputs are rewards)
|
||||
|
||||
Args:
|
||||
config: SARM configuration
|
||||
@@ -676,6 +643,11 @@ def make_sarm_pre_post_processors(
|
||||
"""
|
||||
input_steps = [
|
||||
AddBatchDimensionProcessorStep(),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
SARMEncodingProcessorStep(
|
||||
config=config,
|
||||
dataset_meta=dataset_meta,
|
||||
|
||||
@@ -0,0 +1,249 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Utility functions for SARM progress label computation.
|
||||
|
||||
Implements formulas from the SARM paper:
|
||||
- Formula (1): Compute dataset-level temporal proportions (priors) ᾱ_k
|
||||
- Formula (2): Compute normalized progress targets y_t = P_{k-1} + ᾱ_k × τ_t
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
def compute_priors(
    subtask_durations_per_trajectory: dict[str, list[float]],
    trajectory_lengths: dict[str, list[float]],
    subtask_names: list[str],
) -> dict[str, float]:
    """
    Compute dataset-level temporal proportions (priors) for each subtask.

    Implements SARM Paper Formula (1):
        ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)

    where:
        - L_{i,k} is the length of subtask k in trajectory i
        - T_i is the total length of trajectory i

    Each occurrence contributes its PROPORTION of the trajectory, so every
    trajectory carries equal weight regardless of its absolute length.
    Note that the average for a subtask is taken over the occurrences listed
    for that subtask (entries whose trajectory length is not positive are
    ignored), and the per-subtask averages are then renormalized to sum to 1.

    Args:
        subtask_durations_per_trajectory: Dict mapping subtask name to the list
            of durations L_{i,k}, one entry per occurrence of that subtask.
        trajectory_lengths: Dict mapping subtask name to the list of trajectory
            lengths T_i, aligned index-by-index with the durations list.
        subtask_names: Ordered list of subtask names. Names missing from either
            dict get a proportion of 0.0.

    Returns:
        Dict mapping every name in ``subtask_names`` to its temporal proportion
        (ᾱ_k) as a builtin ``float``; the values sum to 1.

    Raises:
        ValueError: If ``subtask_names`` is empty, if the two lists for a
            subtask differ in length, or if every proportion is zero.
    """
    if not subtask_names:
        raise ValueError("subtask_names cannot be empty")

    # Mean of L_{i,k} / T_i per subtask.
    subtask_proportions: dict[str, float] = {}
    for name in subtask_names:
        if name not in subtask_durations_per_trajectory or name not in trajectory_lengths:
            subtask_proportions[name] = 0.0
            continue

        durations = subtask_durations_per_trajectory[name]
        traj_lengths = trajectory_lengths[name]

        if len(durations) != len(traj_lengths):
            raise ValueError(
                f"Mismatch in lengths for subtask '{name}': "
                f"{len(durations)} durations vs {len(traj_lengths)} trajectory lengths"
            )

        # Skip non-positive trajectory lengths to avoid dividing by zero.
        proportions = [
            duration / traj_len
            for duration, traj_len in zip(durations, traj_lengths)
            if traj_len > 0
        ]

        # float(...) keeps numpy scalars out of the result so it matches the
        # annotated return type.
        subtask_proportions[name] = float(np.mean(proportions)) if proportions else 0.0

    # Normalize to ensure sum = 1 (handles floating point errors and missing subtasks)
    total = sum(subtask_proportions.values())
    if total <= 0:
        raise ValueError("Cannot compute temporal proportions: all proportions are zero. "
                         "Check that your dataset has valid subtask annotations with start/end times.")

    return {name: float(prop / total) for name, prop in subtask_proportions.items()}
|
||||
|
||||
|
||||
def compute_tau(
|
||||
current_frame: int | float,
|
||||
subtask_start: int | float,
|
||||
subtask_end: int | float,
|
||||
) -> float:
|
||||
"""
|
||||
Compute within-subtask normalized time τ_t.
|
||||
|
||||
Implements part of SARM Paper Formula (2):
|
||||
τ_t = (t - s_k) / (e_k - s_k) ∈ [0, 1]
|
||||
|
||||
where:
|
||||
- t is the current frame
|
||||
- s_k is the start frame of subtask k
|
||||
- e_k is the end frame of subtask k
|
||||
|
||||
Args:
|
||||
current_frame: Current frame index (t)
|
||||
subtask_start: Start frame of the subtask (s_k)
|
||||
subtask_end: End frame of the subtask (e_k)
|
||||
|
||||
Returns:
|
||||
Within-subtask progress τ_t ∈ [0, 1]
|
||||
"""
|
||||
subtask_duration = subtask_end - subtask_start
|
||||
|
||||
if subtask_duration <= 0:
|
||||
return 1.0
|
||||
|
||||
tau = (current_frame - subtask_start) / subtask_duration
|
||||
|
||||
return float(np.clip(tau, 0.0, 1.0))
|
||||
|
||||
|
||||
def compute_cumulative_progress_batch(
|
||||
tau: torch.Tensor | float,
|
||||
stage_indices: torch.Tensor | int,
|
||||
alpha: torch.Tensor | Sequence[float],
|
||||
cumulative_prior: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | float:
|
||||
"""
|
||||
Compute cumulative normalized progress from within-subtask progress.
|
||||
|
||||
This function implements the core formula used in SARM for both:
|
||||
|
||||
**Formula 2 (Training labels):**
|
||||
y_t = P_{k-1} + ᾱ_k × τ_t ∈ [0, 1]
|
||||
|
||||
Used to compute ground-truth progress labels from subtask annotations.
|
||||
- τ_t comes from annotated frame position: τ_t = (t - s_k) / (e_k - s_k)
|
||||
- k is the known subtask from annotations
|
||||
|
||||
**Formula 4 (Inference predictions):**
|
||||
ŷ_{1:N} = P̂_{k-1, 1:N} + ᾱ_{k, 1:N} × τ̂_{1:N} ∈ [0, 1]
|
||||
|
||||
Used to convert model outputs to cumulative progress during inference.
|
||||
- τ̂ comes from the subtask MLP head (conditioned on predicted stage)
|
||||
- k = Ŝ is the predicted stage from Formula 3: Ŝ = argmax(softmax(Ψ))
|
||||
|
||||
The formulas are mathematically identical; only the source of inputs differs:
|
||||
- Training: τ and k from annotations → ground-truth labels
|
||||
- Inference: τ̂ and Ŝ from model → predicted progress
|
||||
|
||||
where:
|
||||
- P_{k-1} = Σ_{j=1}^{k-1} ᾱ_j is the cumulative prior (sum of previous proportions)
|
||||
- ᾱ_k is the temporal proportion for subtask k (from Formula 1)
|
||||
- τ is within-subtask progress ∈ [0, 1]
|
||||
|
||||
This ensures:
|
||||
- y at start of subtask k = P_{k-1}
|
||||
- y at end of subtask k = P_k
|
||||
|
||||
Supports both scalar and batched tensor inputs:
|
||||
- Scalar: tau (float), stage_indices (int), alpha (list/sequence)
|
||||
- Batch: tau (Tensor), stage_indices (Tensor), alpha (Tensor), cumulative_prior (Tensor)
|
||||
|
||||
Args:
|
||||
tau: Within-subtask progress τ ∈ [0, 1].
|
||||
For training: computed from frame position in annotated subtask.
|
||||
For inference: predicted by subtask MLP head.
|
||||
Scalar float or Tensor with shape (..., 1)
|
||||
stage_indices: Index of current subtask k (0-indexed).
|
||||
For training: known from annotations.
|
||||
For inference: predicted via argmax(stage_probs) (Formula 3).
|
||||
Scalar int or Tensor with shape (...)
|
||||
alpha: Temporal proportions ᾱ with shape (num_stages,) or Sequence[float].
|
||||
Computed from dataset annotations using Formula 1.
|
||||
cumulative_prior: Optional. Cumulative priors P with shape (num_stages + 1,)
|
||||
where cumulative_prior[k] = P_k = Σ_{j=1}^{k} ᾱ_j.
|
||||
If None, will be computed from alpha.
|
||||
|
||||
Returns:
|
||||
Cumulative progress y ∈ [0, 1].
|
||||
Scalar float if inputs are scalar, otherwise Tensor with shape (..., 1)
|
||||
"""
|
||||
if not isinstance(tau, torch.Tensor):
|
||||
if not alpha:
|
||||
raise ValueError("alpha (temporal_proportions) cannot be empty")
|
||||
|
||||
if isinstance(alpha, torch.Tensor):
|
||||
alpha_list = alpha.tolist()
|
||||
else:
|
||||
alpha_list = list(alpha)
|
||||
|
||||
if stage_indices < 0 or stage_indices >= len(alpha_list):
|
||||
raise ValueError(
|
||||
f"stage_indices {stage_indices} out of range "
|
||||
f"for {len(alpha_list)} subtasks"
|
||||
)
|
||||
|
||||
# P_{k-1} = sum of proportions for subtasks 0 to k-1
|
||||
P_k_minus_1 = sum(alpha_list[:stage_indices])
|
||||
|
||||
# ᾱ_k = proportion for current subtask
|
||||
alpha_k = alpha_list[stage_indices]
|
||||
|
||||
# y_t = P_{k-1} + ᾱ_k × τ_t
|
||||
y_t = P_k_minus_1 + alpha_k * tau
|
||||
|
||||
return float(np.clip(y_t, 0.0, 1.0))
|
||||
|
||||
if not isinstance(alpha, torch.Tensor):
|
||||
alpha = torch.tensor(alpha, dtype=torch.float32)
|
||||
|
||||
# Compute cumulative_prior if not provided
|
||||
if cumulative_prior is None:
|
||||
cumulative_prior = torch.zeros(len(alpha) + 1, dtype=alpha.dtype, device=alpha.device)
|
||||
cumulative_prior[1:] = torch.cumsum(alpha, dim=0)
|
||||
|
||||
# P_{k-1} for each predicted stage
|
||||
P_k_minus_1 = cumulative_prior[stage_indices]
|
||||
|
||||
# ᾱ_k for each predicted stage
|
||||
alpha_k = alpha[stage_indices]
|
||||
|
||||
# ŷ = P_{k-1} + ᾱ_k × τ̂
|
||||
progress = P_k_minus_1.unsqueeze(-1) + alpha_k.unsqueeze(-1) * tau
|
||||
|
||||
return progress
|
||||
|
||||
def pad_state_to_max_dim(state: torch.Tensor, max_state_dim: int) -> torch.Tensor:
    """Force the last dimension of ``state`` to ``max_state_dim``.

    Shorter states are zero-padded on the right; states that are already at
    least ``max_state_dim`` wide are sliced down to the first ``max_state_dim``
    entries. All leading dimensions are left untouched.
    """
    missing = max_state_dim - state.shape[-1]
    if missing <= 0:
        # Nothing to add: keep only the leading max_state_dim features.
        return state[..., :max_state_dim]

    # F.pad takes (left, right) amounts for the last dimension.
    return F.pad(state, (0, missing), mode='constant', value=0)
|
||||
|
||||
@@ -0,0 +1,392 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Tests for SARM utility functions.
|
||||
|
||||
Tests the implementation of SARM paper formulas:
|
||||
- Formula (1): compute_priors - dataset-level temporal proportions
|
||||
- Formula (2): compute_tau, compute_cumulative_progress_batch - progress labels
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.policies.sarm.sarm_utils import (
|
||||
compute_priors,
|
||||
compute_tau,
|
||||
compute_cumulative_progress_batch,
|
||||
)
|
||||
|
||||
|
||||
class TestComputePriors:
    """Tests for compute_priors (SARM Paper Formula 1).

    Formula: ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)

    Key insight: This averages the PROPORTION of each subtask within each trajectory,
    giving equal weight to all trajectories regardless of absolute length.
    """

    def test_basic_two_trajectories_equal_proportions(self):
        """Test with two trajectories that have equal proportions."""
        # Both trajectories: subtask1 = 50%, subtask2 = 50%
        subtask_durations = {
            'subtask1': [50, 100],  # durations
            'subtask2': [50, 100],
        }
        trajectory_lengths = {
            'subtask1': [100, 200],
            'subtask2': [100, 200],
        }
        subtask_names = ['subtask1', 'subtask2']

        result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)

        # Both should be 0.5
        assert abs(result['subtask1'] - 0.5) < 1e-6
        assert abs(result['subtask2'] - 0.5) < 1e-6

    def test_paper_example_different_from_avg_durations(self):
        """Test that compute_priors differs from naive average duration approach.

        This is the key test showing the difference between:
        - Paper formula: average of (L_i,k / T_i)
        - Naive approach: mean(L_i,k) / sum(mean(L_i,j))
        """
        # Episode 1: T=100, subtask1=80, subtask2=20 (proportions: 0.8, 0.2)
        # Episode 2: T=200, subtask1=40, subtask2=160 (proportions: 0.2, 0.8)
        subtask_durations = {
            'subtask1': [80, 40],
            'subtask2': [20, 160],
        }
        trajectory_lengths = {
            'subtask1': [100, 200],
            'subtask2': [100, 200],
        }
        subtask_names = ['subtask1', 'subtask2']

        result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)

        # Paper formula:
        # ᾱ_1 = (1/2) × (80/100 + 40/200) = (1/2) × (0.8 + 0.2) = 0.5
        # ᾱ_2 = (1/2) × (20/100 + 160/200) = (1/2) × (0.2 + 0.8) = 0.5
        assert abs(result['subtask1'] - 0.5) < 1e-6
        assert abs(result['subtask2'] - 0.5) < 1e-6

        # Naive approach: mean durations are 60 and 90, giving 0.4 / 0.6.
        # Assert the difference explicitly so the test lives up to its name.
        mean_durations = {
            name: float(np.mean(values)) for name, values in subtask_durations.items()
        }
        naive_subtask1 = mean_durations['subtask1'] / sum(mean_durations.values())
        assert abs(naive_subtask1 - 0.4) < 1e-6
        assert abs(result['subtask1'] - naive_subtask1) > 0.05

    def test_single_trajectory(self):
        """Test with a single trajectory."""
        subtask_durations = {
            'reach': [30],
            'grasp': [20],
            'lift': [50],
        }
        trajectory_lengths = {
            'reach': [100],
            'grasp': [100],
            'lift': [100],
        }
        subtask_names = ['grasp', 'lift', 'reach']  # sorted order

        result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)

        assert abs(result['reach'] - 0.3) < 1e-6
        assert abs(result['grasp'] - 0.2) < 1e-6
        assert abs(result['lift'] - 0.5) < 1e-6

    def test_sum_to_one(self):
        """Test that proportions always sum to 1."""
        subtask_durations = {
            'a': [10, 20, 30],
            'b': [40, 50, 60],
            'c': [50, 30, 10],
        }
        trajectory_lengths = {
            'a': [100, 100, 100],
            'b': [100, 100, 100],
            'c': [100, 100, 100],
        }
        subtask_names = ['a', 'b', 'c']

        result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)

        total = sum(result.values())
        assert abs(total - 1.0) < 1e-6

    def test_empty_subtask_names_raises(self):
        """Test that empty subtask_names raises an error."""
        with pytest.raises(ValueError, match="subtask_names cannot be empty"):
            compute_priors({}, {}, [])

    def test_mismatched_lengths_raises(self):
        """Durations and trajectory lengths for a subtask must align one-to-one."""
        with pytest.raises(ValueError, match="Mismatch in lengths"):
            compute_priors({'subtask1': [50, 100]}, {'subtask1': [100]}, ['subtask1'])

    def test_missing_subtask_gets_zero_before_normalization(self):
        """Test handling of subtasks that appear in some but not all trajectories."""
        # subtask1 appears in both, subtask2 only in first
        subtask_durations = {
            'subtask1': [50, 100],
            'subtask2': [50],  # only in first trajectory
        }
        trajectory_lengths = {
            'subtask1': [100, 200],
            'subtask2': [100],
        }
        subtask_names = ['subtask1', 'subtask2']

        result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)

        # subtask1: (50/100 + 100/200) / 2 = (0.5 + 0.5) / 2 = 0.5
        # subtask2: 50/100 = 0.5 (only one occurrence)
        # After normalization: both should be 0.5
        assert abs(result['subtask1'] - 0.5) < 1e-6
        assert abs(result['subtask2'] - 0.5) < 1e-6
        assert abs(sum(result.values()) - 1.0) < 1e-6
|
||||
|
||||
|
||||
class TestComputeTau:
    """Checks for compute_tau, the within-subtask progress term.

    Formula: τ_t = (t - s_k) / (e_k - s_k) ∈ [0, 1]
    """

    def test_at_start(self):
        """A frame on the subtask's first frame has τ = 0."""
        assert compute_tau(current_frame=10, subtask_start=10, subtask_end=50) == 0.0

    def test_at_end(self):
        """A frame on the subtask's last frame has τ = 1."""
        assert compute_tau(current_frame=50, subtask_start=10, subtask_end=50) == 1.0

    def test_at_middle(self):
        """Halfway through the subtask τ is 0.5."""
        observed = compute_tau(current_frame=30, subtask_start=10, subtask_end=50)
        assert abs(observed - 0.5) < 1e-6

    def test_quarter_progress(self):
        """A quarter of the way through the subtask τ is 0.25."""
        observed = compute_tau(current_frame=20, subtask_start=0, subtask_end=80)
        assert abs(observed - 0.25) < 1e-6

    def test_zero_duration_subtask(self):
        """A subtask with identical start and end counts as complete (τ = 1)."""
        assert compute_tau(current_frame=10, subtask_start=10, subtask_end=10) == 1.0

    def test_clamps_below_zero(self):
        """Frames before the subtask start are clamped to τ = 0."""
        assert compute_tau(current_frame=5, subtask_start=10, subtask_end=50) == 0.0

    def test_clamps_above_one(self):
        """Frames after the subtask end are clamped to τ = 1."""
        assert compute_tau(current_frame=60, subtask_start=10, subtask_end=50) == 1.0

    def test_float_inputs(self):
        """Fractional frame indices (from interpolation) are supported."""
        frame, start, end = 25.5, 10.0, 50.0
        observed = compute_tau(current_frame=frame, subtask_start=start, subtask_end=end)
        expected = (25.5 - 10.0) / (50.0 - 10.0)
        assert abs(observed - expected) < 1e-6
|
||||
|
||||
|
||||
class TestComputeCumulativeProgressBatchScalar:
    """Scalar-input checks for compute_cumulative_progress_batch (normalized progress y_t).

    Formula: y_t = P_{k-1} + ᾱ_k × τ_t ∈ [0, 1]
    """

    def test_first_subtask_start(self):
        """Progress is exactly 0 when the first subtask begins."""
        alphas = [0.3, 0.5, 0.2]
        assert compute_cumulative_progress_batch(tau=0.0, stage_indices=0, alpha=alphas) == 0.0

    def test_first_subtask_end(self):
        """Finishing the first subtask yields ᾱ_1."""
        alphas = [0.3, 0.5, 0.2]
        progress = compute_cumulative_progress_batch(tau=1.0, stage_indices=0, alpha=alphas)
        assert abs(progress - 0.3) < 1e-6

    def test_second_subtask_start(self):
        """Entering the second subtask yields P_1."""
        alphas = [0.3, 0.5, 0.2]
        progress = compute_cumulative_progress_batch(tau=0.0, stage_indices=1, alpha=alphas)
        assert abs(progress - 0.3) < 1e-6

    def test_second_subtask_end(self):
        """Finishing the second subtask yields P_2."""
        alphas = [0.3, 0.5, 0.2]
        progress = compute_cumulative_progress_batch(tau=1.0, stage_indices=1, alpha=alphas)
        assert abs(progress - 0.8) < 1e-6  # 0.3 + 0.5

    def test_third_subtask_end(self):
        """Finishing the last subtask yields full progress (1.0)."""
        alphas = [0.3, 0.5, 0.2]
        progress = compute_cumulative_progress_batch(tau=1.0, stage_indices=2, alpha=alphas)
        assert abs(progress - 1.0) < 1e-6

    def test_midpoint_of_subtask(self):
        """Progress halfway through each subtask of a two-stage task."""
        alphas = [0.4, 0.6]

        # τ=0.5 in subtask 1: y = P_0 + ᾱ_1 × 0.5 = 0 + 0.4 × 0.5 = 0.2
        first_half = compute_cumulative_progress_batch(tau=0.5, stage_indices=0, alpha=alphas)
        assert abs(first_half - 0.2) < 1e-6

        # τ=0.5 in subtask 2: y = P_1 + ᾱ_2 × 0.5 = 0.4 + 0.6 × 0.5 = 0.7
        second_half = compute_cumulative_progress_batch(tau=0.5, stage_indices=1, alpha=alphas)
        assert abs(second_half - 0.7) < 1e-6

    def test_uniform_proportions(self):
        """With four equal stages, stage ends land on 0.25, 0.5, 0.75, 1.0."""
        alphas = [0.25, 0.25, 0.25, 0.25]

        for stage in range(4):
            progress = compute_cumulative_progress_batch(tau=1.0, stage_indices=stage, alpha=alphas)
            target = (stage + 1) * 0.25
            assert abs(progress - target) < 1e-6
|
||||
|
||||
|
||||
class TestComputeCumulativeProgressBatchTensor:
    """Tensor-input checks for compute_cumulative_progress_batch (batched version)."""

    def test_tensor_matches_scalar_version(self):
        """Single-element tensor calls agree with the scalar code path."""
        proportions = [0.3, 0.5, 0.2]
        alpha = torch.tensor(proportions, dtype=torch.float32)
        # Cumulative prior with a leading zero: [0, ᾱ_1, ᾱ_1+ᾱ_2, ...]
        cumulative = torch.cat(
            [torch.zeros(1, dtype=torch.float32), torch.cumsum(alpha, dim=0)]
        )

        # (τ, 0-based stage index) pairs covering starts, a midpoint and ends.
        scenarios = (
            (0.0, 0),  # start of subtask 0
            (1.0, 0),  # end of subtask 0
            (0.0, 1),  # start of subtask 1
            (0.5, 1),  # middle of subtask 1
            (1.0, 2),  # end of subtask 2
        )

        for tau_val, stage_idx in scenarios:
            # Reference value from the scalar call.
            reference = compute_cumulative_progress_batch(tau_val, stage_idx, proportions)

            # Same query wrapped as tensors: τ has shape (1, 1, 1), stages (1, 1).
            tau_tensor = torch.tensor([[[tau_val]]])
            stage_tensor = torch.tensor([[stage_idx]])
            batched = compute_cumulative_progress_batch(tau_tensor, stage_tensor, alpha, cumulative)

            observed = batched[0, 0, 0].item()
            assert abs(observed - reference) < 1e-6

    def test_batch_processing(self):
        """Several samples and time steps are processed in one call."""
        proportions = [0.4, 0.6]
        alpha = torch.tensor(proportions, dtype=torch.float32)
        cumulative = torch.cat(
            [torch.zeros(1, dtype=torch.float32), torch.cumsum(alpha, dim=0)]
        )

        # Two samples, three time steps each, with τ = 0, 0.5, 1.0.
        tau_steps = [[0.0], [0.5], [1.0]]
        tau = torch.tensor([tau_steps, tau_steps])
        # Sample 0 stays in subtask 0; sample 1 stays in subtask 1.
        stages = torch.tensor([[0] * 3, [1] * 3])

        result = compute_cumulative_progress_batch(tau, stages, alpha, cumulative)

        # Row b lists the expected progress of sample b at each time step:
        #   subtask 0 -> 0, 0.2, 0.4 ; subtask 1 -> 0.4, 0.7, 1.0
        expected_rows = [
            [0.0, 0.2, 0.4],
            [0.4, 0.7, 1.0],
        ]
        for sample, row in enumerate(expected_rows):
            for step, target in enumerate(row):
                assert abs(result[sample, step, 0].item() - target) < 1e-6

    def test_auto_compute_cumulative_prior(self):
        """Omitting the cumulative prior argument still gives the right value."""
        alpha = torch.tensor([0.3, 0.5, 0.2], dtype=torch.float32)

        tau = torch.tensor([[[0.5]]])
        stages = torch.tensor([[1]])

        # Only three arguments: the cumulative prior is left for the function to derive.
        result = compute_cumulative_progress_batch(tau, stages, alpha)

        # Stage index 1 starts at prior 0.3 and spans 0.5: 0.3 + 0.5 × 0.5 = 0.55
        observed = result[0, 0, 0].item()
        assert abs(observed - 0.55) < 1e-6
|
||||
|
||||
|
||||
class TestEndToEndProgressLabeling:
    """End-to-end tests for progress label computation (compute_tau + cumulative progress)."""

    def test_consistent_semantic_meaning(self):
        """Test that same subtask completion maps to same progress across trajectories.

        This is the key semantic property: "end of subtask 1" should always
        mean the same progress value regardless of trajectory speed.
        """
        proportions = [0.3, 0.5, 0.2]

        # Fast trajectory: subtask 1 ends at frame 30 (of 100)
        tau_fast = compute_tau(30, 0, 30)  # = 1.0
        y_fast = compute_cumulative_progress_batch(tau_fast, 0, proportions)

        # Slow trajectory: subtask 1 ends at frame 90 (of 300)
        tau_slow = compute_tau(90, 0, 90)  # = 1.0
        y_slow = compute_cumulative_progress_batch(tau_slow, 0, proportions)

        # Both should map to same progress (0.3 = end of subtask 1)
        assert abs(y_fast - y_slow) < 1e-6
        assert abs(y_fast - 0.3) < 1e-6

    def test_monotonic_within_subtask(self):
        """Test that progress is strictly increasing in τ within every subtask."""
        proportions = [0.4, 0.6]

        for stage in range(len(proportions)):
            # -inf sentinel: the first sample needs no special-casing, whatever
            # the starting progress of the stage is.
            prev_y = float("-inf")
            for tau in np.linspace(0, 1, 11):
                y = compute_cumulative_progress_batch(tau, stage, proportions)
                assert y > prev_y, (
                    f"progress not increasing in stage {stage} at tau={tau}: "
                    f"{y} <= {prev_y}"
                )
                prev_y = y

    def test_continuous_across_subtasks(self):
        """Test that progress is continuous at subtask boundaries."""
        proportions = [0.3, 0.5, 0.2]

        # For each adjacent pair of stages, the end of one (tau=1.0) must match
        # the start of the next (tau=0.0): boundaries at 0.3 and 0.8.
        for stage in range(len(proportions) - 1):
            y_end = compute_cumulative_progress_batch(1.0, stage, proportions)
            y_next_start = compute_cumulative_progress_batch(0.0, stage + 1, proportions)
            assert abs(y_end - y_next_start) < 1e-6, (
                f"discontinuity between stage {stage} and stage {stage + 1}: "
                f"{y_end} vs {y_next_start}"
            )
|
||||
|
||||
Reference in New Issue
Block a user