mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
fix progress conversion and adding initial frame
This commit is contained in:
@@ -224,14 +224,14 @@ def run_inference(
|
||||
"""
|
||||
Run SARM inference on video frames and joint states.
|
||||
|
||||
For each frame t, creates a temporal sequence of 9 frames using SARM's pattern:
|
||||
[t-240, t-210, t-180, t-150, t-120, t-90, t-60, t-30, t]
|
||||
This matches the training pattern where frames are loaded with 30-frame gaps
|
||||
relative to the current frame.
|
||||
(per SARM paper Section A.4):
|
||||
- Frame 0: Initial frame of the episode (frame 0)
|
||||
- Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame t
|
||||
Pattern: [frame_0, t-(7*gap), t-(6*gap), ..., t-gap, t]
|
||||
|
||||
Args:
|
||||
model: SARM model
|
||||
frames: Video frames (num_frames, H, W, C)
|
||||
frames: Video frames (num_frames, H, W, C) - all frames from ONE episode
|
||||
states: Joint states (num_frames, state_dim)
|
||||
task_description: Task description text
|
||||
batch_size: Batch size for processing slices
|
||||
@@ -247,7 +247,12 @@ def run_inference(
|
||||
logger.info("Encoding task description with MiniLM...")
|
||||
text_embedding = model.encode_text(task_description)
|
||||
|
||||
logger.info("Creating video slices (SARM approach)...")
|
||||
# Get config values
|
||||
num_frames_model = model.config.num_frames # 9
|
||||
frame_gap = model.config.frame_gap # 30
|
||||
|
||||
logger.info("Creating video slices (SARM paper: initial frame + 8 consecutive)...")
|
||||
|
||||
# Convert to tensors
|
||||
video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
|
||||
text_embedding = torch.tensor(text_embedding, dtype=torch.float32)
|
||||
@@ -256,33 +261,14 @@ def run_inference(
|
||||
else:
|
||||
state_embeddings = None
|
||||
|
||||
# Create video slices: for each frame i, create a sequence using SARM's pattern
|
||||
# For SARM: 9 frames relative to current, with 30-frame gaps
|
||||
# Pattern: [current-240, current-210, ..., current-30, current]
|
||||
num_frames_model = model.config.num_frames
|
||||
frame_gap = model.config.frame_gap
|
||||
|
||||
video_slices = []
|
||||
state_slices = []
|
||||
last_frame_indices = []
|
||||
|
||||
for i in tqdm(range(len(video_embeddings)), desc="Creating slices"):
|
||||
# For SARM, create sequence relative to current frame (matching training pattern)
|
||||
# Pattern: [current-240, current-210, ..., current-30, current]
|
||||
# This matches observation_delta_indices: range(-240, 1, 30)
|
||||
|
||||
# Compute frame indices for this slice (relative to current frame i)
|
||||
frame_indices = []
|
||||
for j in range(num_frames_model):
|
||||
# Start from -(num_frames_model-1) * frame_gap and go to 0
|
||||
offset = -(num_frames_model - 1 - j) * frame_gap
|
||||
idx = i + offset
|
||||
|
||||
# Clamp to valid range [0, current_frame]
|
||||
if idx < 0:
|
||||
idx = 0 # Pad with first available frame
|
||||
|
||||
frame_indices.append(idx)
|
||||
for current_frame in tqdm(range(len(video_embeddings)), desc="Creating slices"):
|
||||
# Compute frame indices using SARM pattern:
|
||||
# [initial_frame, t-(7*gap), t-(6*gap), ..., t-gap, t]
|
||||
deltas = model.config.observation_delta_indices(current_frame)
|
||||
frame_indices = [max(0, current_frame + delta) for delta in deltas]
|
||||
|
||||
# Extract slice
|
||||
video_slice = video_embeddings[frame_indices]
|
||||
@@ -292,9 +278,6 @@ def run_inference(
|
||||
state_slice = state_embeddings[frame_indices]
|
||||
state_slices.append(state_slice)
|
||||
|
||||
# Track which frame index corresponds to the "current" frame
|
||||
last_frame_indices.append(min(i, len(frame_indices) - 1))
|
||||
|
||||
video_slices = torch.stack(video_slices) # (num_frames, num_frames_model, 512)
|
||||
if state_embeddings is not None:
|
||||
state_slices = torch.stack(state_slices) # (num_frames, num_frames_model, state_dim)
|
||||
@@ -320,7 +303,6 @@ def run_inference(
|
||||
)
|
||||
|
||||
# Extract last frame predictions (the "current" frame)
|
||||
# For SARM, we take the last frame in each sequence
|
||||
batch_progress = progress_preds[:, -1, 0].cpu().numpy()
|
||||
batch_stages = stage_probs[:, -1, :].cpu().numpy()
|
||||
|
||||
|
||||
@@ -44,6 +44,7 @@ class SARMConfig(PreTrainedConfig):
|
||||
num_layers: int = 8
|
||||
num_stages: int = 5 # Number of task stages for classification (auto-updated from annotations if available)
|
||||
subtask_names: list | None = None # List of subtask names (auto-populated from annotations)
|
||||
temporal_proportions: list | None = None # Temporal proportions for each stage (auto-computed from annotations)
|
||||
|
||||
# Temporal parameters
|
||||
max_length: int = num_frames # Maximum video sequence length (matches num_frames)
|
||||
@@ -128,20 +129,31 @@ class SARMConfig(PreTrainedConfig):
|
||||
"""Validate input and output features."""
|
||||
pass
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int]:
|
||||
"""Load frames for SARM temporal sampling.
|
||||
def observation_delta_indices(self, episode_frame_index: int) -> list[int]:
|
||||
"""Compute delta indices for SARM temporal sampling.
|
||||
|
||||
SARM uses 9 frames: 1 initial frame + 8 consecutive frames with frame_gap spacing.
|
||||
Per SARM paper (Section A.4), the model uses 9 frames:
|
||||
- Frame 0: Initial frame of the episode (delta = -episode_frame_index)
|
||||
- Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame
|
||||
|
||||
The dataloader converts these to seconds: delta_seconds = delta / fps
|
||||
This means the first delta (-episode_frame_index) becomes -current_time,
|
||||
which correctly points to t=0 (the initial frame).
|
||||
|
||||
Args:
|
||||
episode_frame_index: Current frame index within the episode (0, 1, 2, ...)
|
||||
|
||||
Returns:
|
||||
Indices for loading: [-(8*frame_gap), ..., -frame_gap, 0]
|
||||
9 delta indices: [-episode_frame_index, -(7*gap), -(6*gap), ..., -gap, 0]
|
||||
"""
|
||||
# For SARM: we need the initial frame (from episode start) plus 8 consecutive frames
|
||||
# The dataset will load relative to current frame
|
||||
# We'll handle the "initial frame" logic in the processor
|
||||
# For now, load the last 8*frame_gap frames
|
||||
return list(range(-self.frame_gap * (self.num_frames - 1), 1, self.frame_gap))
|
||||
# First delta: negative of current frame index to reach frame 0
|
||||
initial_frame_delta = -episode_frame_index
|
||||
|
||||
# Remaining 8 deltas: consecutive frames with frame_gap spacing
|
||||
num_consecutive = self.num_frames - 1 # 8 frames
|
||||
consecutive_deltas = list(range(-self.frame_gap * (num_consecutive - 1), 1, self.frame_gap))
|
||||
|
||||
return [initial_frame_delta] + consecutive_deltas
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> None:
|
||||
|
||||
@@ -19,6 +19,7 @@ from typing import List, Union, Dict, Optional
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@@ -71,13 +72,31 @@ class SARMTransformer(nn.Module):
|
||||
num_layers: int = 8,
|
||||
num_stages: int = 5,
|
||||
max_length: int = 9,
|
||||
dropout: float = 0.1
|
||||
dropout: float = 0.1,
|
||||
temporal_proportions: list[float] | None = None
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_dim
|
||||
self.max_length = max_length
|
||||
self.num_stages = num_stages
|
||||
|
||||
# Store temporal proportions for progress conversion (Paper Eq. 4)
|
||||
# ŷ = P_{k-1} + ᾱ_k × τ̂
|
||||
if temporal_proportions is None:
|
||||
raise ValueError(
|
||||
"temporal_proportions is required for SARM. "
|
||||
"Provide subtask annotations in your dataset or set temporal_proportions in config."
|
||||
)
|
||||
|
||||
# ᾱ_k: proportion for each stage
|
||||
alpha = torch.tensor(temporal_proportions, dtype=torch.float32)
|
||||
|
||||
# P_k: cumulative proportion up to stage k (P_0 = 0)
|
||||
cumulative = torch.zeros(num_stages + 1, dtype=torch.float32)
|
||||
cumulative[1:] = torch.cumsum(alpha, dim=0)
|
||||
self.register_buffer('alpha', alpha)
|
||||
self.register_buffer('cumulative_prior', cumulative)
|
||||
|
||||
# Project video, text, and state to same dimension
|
||||
self.video_proj = nn.Linear(video_dim, hidden_dim)
|
||||
self.text_proj = nn.Linear(text_dim, hidden_dim)
|
||||
@@ -97,24 +116,26 @@ class SARMTransformer(nn.Module):
|
||||
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
|
||||
|
||||
# Stage estimator head (classification)
|
||||
# Paper A.4: "2 layers with hidden dimension of 512"
|
||||
self.stage_head = nn.Sequential(
|
||||
nn.Linear(hidden_dim, hidden_dim // 2),
|
||||
nn.LayerNorm(hidden_dim // 2),
|
||||
nn.Linear(hidden_dim, 512),
|
||||
nn.LayerNorm(512),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim // 2, num_stages)
|
||||
nn.Linear(512, num_stages)
|
||||
)
|
||||
|
||||
# Subtask estimator head (regression, conditioned on stage)
|
||||
# Takes concatenated [features, stage_embedding]
|
||||
# Paper A.4: "2 layers with hidden dimension of 512"
|
||||
self.stage_embedding = nn.Embedding(num_stages, hidden_dim // 4)
|
||||
subtask_input_dim = hidden_dim + hidden_dim // 4
|
||||
self.subtask_head = nn.Sequential(
|
||||
nn.Linear(subtask_input_dim, hidden_dim // 2),
|
||||
nn.LayerNorm(hidden_dim // 2),
|
||||
nn.Linear(subtask_input_dim, 512),
|
||||
nn.LayerNorm(512),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim // 2, 1),
|
||||
nn.Linear(512, 1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
@@ -189,7 +210,17 @@ class SARMTransformer(nn.Module):
|
||||
conditioned_features = torch.cat([frame_features, stage_embeds], dim=-1)
|
||||
|
||||
# Subtask progress estimation (conditioned on stage)
|
||||
progress_preds = self.subtask_head(conditioned_features) # [batch_size, seq_len, 1]
|
||||
# τ̂ = within-subtask progress (0-1)
|
||||
tau_preds = self.subtask_head(conditioned_features) # [batch_size, seq_len, 1]
|
||||
|
||||
# Convert τ̂ to cumulative progress ŷ using Paper Eq. 4:
|
||||
# ŷ = P_{k-1} + ᾱ_k × τ̂
|
||||
# P_{k-1} = cumulative prior up to stage k-1
|
||||
# ᾱ_k = temporal proportion of stage k
|
||||
P_k_minus_1 = self.cumulative_prior[stage_indices] # [batch_size, seq_len]
|
||||
alpha_k = self.alpha[stage_indices] # [batch_size, seq_len]
|
||||
|
||||
progress_preds = P_k_minus_1.unsqueeze(-1) + alpha_k.unsqueeze(-1) * tau_preds
|
||||
|
||||
return stage_logits, stage_probs, progress_preds
|
||||
|
||||
@@ -263,7 +294,8 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
"2. Ensure dataset_stats contains 'observation.state' or 'state' key"
|
||||
)
|
||||
|
||||
# Initialize SARM transformer
|
||||
# Initialize SARM transformer with temporal proportions for progress conversion
|
||||
temporal_proportions = getattr(config, 'temporal_proportions', None)
|
||||
self.sarm_transformer = SARMTransformer(
|
||||
video_dim=config.image_dim,
|
||||
text_dim=config.text_dim,
|
||||
@@ -273,7 +305,8 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
num_layers=config.num_layers,
|
||||
num_stages=config.num_stages,
|
||||
max_length=config.max_length,
|
||||
dropout=config.dropout
|
||||
dropout=config.dropout,
|
||||
temporal_proportions=temporal_proportions
|
||||
)
|
||||
self.sarm_transformer.to(self.device)
|
||||
|
||||
@@ -281,7 +314,7 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
logging.info(f"SARM Reward Model initialized on {self.device}")
|
||||
|
||||
def _update_num_stages_from_dataset(self, dataset_meta) -> None:
|
||||
"""Update num_stages in config based on dataset subtask annotations."""
|
||||
"""Update num_stages and temporal_proportions from dataset subtask annotations."""
|
||||
episodes = dataset_meta.episodes
|
||||
if episodes is None or len(episodes) == 0:
|
||||
raise ValueError("No episodes found, using default num_stages")
|
||||
@@ -291,14 +324,28 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
|
||||
episodes_df = episodes.to_pandas()
|
||||
|
||||
# Collect all unique subtask names
|
||||
# Collect all unique subtask names and compute durations
|
||||
all_subtask_names = set()
|
||||
subtask_durations = {}
|
||||
|
||||
for ep_idx in episodes_df.index:
|
||||
subtask_names = episodes_df.loc[ep_idx, 'subtask_names']
|
||||
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
|
||||
continue
|
||||
|
||||
all_subtask_names.update(subtask_names)
|
||||
|
||||
# Compute durations if available
|
||||
if 'subtask_start_frames' in episodes_df.columns and 'subtask_end_frames' in episodes_df.columns:
|
||||
start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames']
|
||||
end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames']
|
||||
|
||||
for i, name in enumerate(subtask_names):
|
||||
duration = end_frames[i] - start_frames[i]
|
||||
if name not in subtask_durations:
|
||||
subtask_durations[name] = []
|
||||
subtask_durations[name].append(duration)
|
||||
|
||||
if not all_subtask_names:
|
||||
raise ValueError("No valid subtask names found, using default num_stages")
|
||||
|
||||
@@ -306,10 +353,26 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
subtask_names = sorted(list(all_subtask_names))
|
||||
num_stages = len(subtask_names)
|
||||
|
||||
# Compute temporal proportions (Paper Eq. 1: ᾱ_k)
|
||||
avg_durations = {}
|
||||
for name in subtask_names:
|
||||
if name in subtask_durations and subtask_durations[name]:
|
||||
avg_durations[name] = np.mean(subtask_durations[name])
|
||||
else:
|
||||
avg_durations[name] = 1.0 # Default
|
||||
|
||||
total_duration = sum(avg_durations.values())
|
||||
if total_duration > 0:
|
||||
temporal_proportions = [avg_durations[name] / total_duration for name in subtask_names]
|
||||
else:
|
||||
temporal_proportions = [1.0 / num_stages] * num_stages
|
||||
|
||||
self.config.num_stages = num_stages
|
||||
self.config.subtask_names = subtask_names
|
||||
self.config.temporal_proportions = temporal_proportions
|
||||
|
||||
logging.info(f"Auto-detected {num_stages} subtasks from dataset: {subtask_names}, using {num_stages} stages")
|
||||
logging.info(f"Auto-detected {num_stages} subtasks: {subtask_names}")
|
||||
logging.info(f"Temporal proportions: {dict(zip(subtask_names, temporal_proportions))}")
|
||||
|
||||
def to(self, device):
|
||||
"""Override to method to ensure all components move together."""
|
||||
@@ -357,7 +420,7 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
# Batch process frames with CLIP
|
||||
for i in range(0, len(frames), self.config.clip_batch_size):
|
||||
batch = frames[i:i + self.config.clip_batch_size]
|
||||
inputs = self.clip_processor(images=batch, return_tensors="pt", padding=True)
|
||||
inputs = self.clip_processor(images=batch, return_tensors="pt")
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
|
||||
# Get image embeddings from CLIP
|
||||
@@ -578,8 +641,8 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
state: torch.Tensor | None,
|
||||
max_length: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
||||
"""Apply rewind augmentation: append 2-4 reversed frames (SARM paper)."""
|
||||
num_reverse = random.randint(2, min(4, max_length - 1))
|
||||
"""Apply rewind augmentation: append up to 4 reversed frames (SARM paper A.4)."""
|
||||
num_reverse = random.randint(1, min(4, max_length - 1))
|
||||
|
||||
# Reverse and take frames (skip first which is last of original)
|
||||
reversed_video = video.flip(0)[1:num_reverse + 1]
|
||||
|
||||
@@ -208,15 +208,31 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
return episode_indices
|
||||
|
||||
def _compute_absolute_indices(self, frame_idx: int, ep_start: int, num_frames: int) -> torch.Tensor:
|
||||
"""Compute absolute frame indices for a sequence."""
|
||||
"""Compute absolute frame indices for a sequence.
|
||||
|
||||
(per SARM paper Section A.4):
|
||||
- Frame 0: Initial frame of the episode (ep_start)
|
||||
- Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame
|
||||
Pattern: [ep_start, t-(7*gap), t-(6*gap), ..., t-gap, t]
|
||||
|
||||
"""
|
||||
frame_gap = getattr(self.config, 'frame_gap', 1)
|
||||
|
||||
if frame_gap > 1:
|
||||
indices = [max(ep_start, frame_idx - (num_frames - 1 - i) * frame_gap) for i in range(num_frames)]
|
||||
return torch.tensor(indices)
|
||||
else:
|
||||
start_idx = max(ep_start, frame_idx - num_frames + 1)
|
||||
return torch.arange(start_idx, frame_idx + 1)
|
||||
indices = []
|
||||
|
||||
|
||||
# First frame is the episode's initial frame
|
||||
indices.append(ep_start)
|
||||
|
||||
# Remaining frames are consecutive with frame_gap spacing
|
||||
num_consecutive = num_frames - 1
|
||||
for i in range(num_consecutive):
|
||||
offset = -(num_consecutive - 1 - i) * frame_gap
|
||||
idx = max(ep_start, frame_idx + offset)
|
||||
indices.append(idx)
|
||||
|
||||
|
||||
return torch.tensor(indices)
|
||||
|
||||
def _compute_episode_metadata(
|
||||
self,
|
||||
@@ -324,6 +340,10 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
) -> tuple[torch.Tensor, torch.Tensor] | tuple[None, None]:
|
||||
"""Compute stage labels and progress targets for a single sample.
|
||||
|
||||
(per SARM paper Section A.4):
|
||||
- Frame 0: Initial frame of episode (stage at frame 0, progress at frame 0)
|
||||
- Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame
|
||||
|
||||
Args:
|
||||
frame_idx: The frame index for this sample
|
||||
ep_idx: The episode index
|
||||
@@ -348,7 +368,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
# Get episode boundaries
|
||||
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
|
||||
|
||||
# Get frame gap for temporal sampling
|
||||
# Get config values
|
||||
frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1
|
||||
|
||||
# Generate labels for each frame in the sequence
|
||||
@@ -356,12 +376,15 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
progress_targets = []
|
||||
|
||||
for i in range(seq_len):
|
||||
# Calculate actual frame index for this position in sequence
|
||||
if frame_gap > 1:
|
||||
offset = -(seq_len - 1 - i) * frame_gap
|
||||
current_frame = max(0, frame_idx + offset - ep_start)
|
||||
if i == 0:
|
||||
# Position 0: Initial frame of the episode
|
||||
current_frame = 0 # Relative to episode start
|
||||
else:
|
||||
current_frame = max(0, frame_idx - seq_len + 1 + i - ep_start)
|
||||
# Positions 1-8: consecutive frames with frame_gap spacing
|
||||
num_consecutive = seq_len - 1
|
||||
offset = -(num_consecutive - i) * frame_gap
|
||||
current_frame = max(0, frame_idx + offset - ep_start)
|
||||
|
||||
|
||||
stage_idx, cumulative_progress = self._compute_stage_and_progress_for_frame(
|
||||
current_frame, subtask_names, subtask_start_frames, subtask_end_frames
|
||||
@@ -564,7 +587,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
batch_imgs = images_list[i:i + self.config.clip_batch_size]
|
||||
|
||||
# Process with CLIP
|
||||
inputs = self.clip_processor(images=batch_imgs, return_tensors="pt", padding=True)
|
||||
inputs = self.clip_processor(images=batch_imgs, return_tensors="pt")
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
|
||||
# Get image embeddings
|
||||
@@ -707,3 +730,5 @@ def make_sarm_pre_post_processors(
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user