Fix progress conversion and add initial frame

This commit is contained in:
Pepijn
2025-11-26 11:02:42 +01:00
parent c66aef878c
commit cc2e91febe
4 changed files with 156 additions and 74 deletions
+16 -34
View File
@@ -224,14 +224,14 @@ def run_inference(
"""
Run SARM inference on video frames and joint states.
For each frame t, creates a temporal sequence of 9 frames using SARM's pattern:
[t-240, t-210, t-180, t-150, t-120, t-90, t-60, t-30, t]
This matches the training pattern where frames are loaded with 30-frame gaps
relative to the current frame.
Sampling pattern (per SARM paper, Section A.4):
- Frame 0: initial frame of the episode (index 0)
- Frames 1-8: 8 frames spaced frame_gap apart, ending at the current frame t
Pattern: [frame_0, t-(7*gap), t-(6*gap), ..., t-gap, t]
Args:
model: SARM model
frames: Video frames (num_frames, H, W, C)
frames: Video frames (num_frames, H, W, C) - all frames from ONE episode
states: Joint states (num_frames, state_dim)
task_description: Task description text
batch_size: Batch size for processing slices
@@ -247,7 +247,12 @@ def run_inference(
logger.info("Encoding task description with MiniLM...")
text_embedding = model.encode_text(task_description)
logger.info("Creating video slices (SARM approach)...")
# Get config values
num_frames_model = model.config.num_frames # 9
frame_gap = model.config.frame_gap # 30
logger.info("Creating video slices (SARM paper: initial frame + 8 consecutive)...")
# Convert to tensors
video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
text_embedding = torch.tensor(text_embedding, dtype=torch.float32)
@@ -256,33 +261,14 @@ def run_inference(
else:
state_embeddings = None
# Create video slices: for each frame i, create a sequence using SARM's pattern
# For SARM: 9 frames relative to current, with 30-frame gaps
# Pattern: [current-240, current-210, ..., current-30, current]
num_frames_model = model.config.num_frames
frame_gap = model.config.frame_gap
video_slices = []
state_slices = []
last_frame_indices = []
for i in tqdm(range(len(video_embeddings)), desc="Creating slices"):
# For SARM, create sequence relative to current frame (matching training pattern)
# Pattern: [current-240, current-210, ..., current-30, current]
# This matches observation_delta_indices: range(-240, 1, 30)
# Compute frame indices for this slice (relative to current frame i)
frame_indices = []
for j in range(num_frames_model):
# Start from -(num_frames_model-1) * frame_gap and go to 0
offset = -(num_frames_model - 1 - j) * frame_gap
idx = i + offset
# Clamp to valid range [0, current_frame]
if idx < 0:
idx = 0 # Pad with first available frame
frame_indices.append(idx)
for current_frame in tqdm(range(len(video_embeddings)), desc="Creating slices"):
# Compute frame indices using SARM pattern:
# [initial_frame, t-(7*gap), t-(6*gap), ..., t-gap, t]
deltas = model.config.observation_delta_indices(current_frame)
frame_indices = [max(0, current_frame + delta) for delta in deltas]
# Extract slice
video_slice = video_embeddings[frame_indices]
@@ -292,9 +278,6 @@ def run_inference(
state_slice = state_embeddings[frame_indices]
state_slices.append(state_slice)
# Track which frame index corresponds to the "current" frame
last_frame_indices.append(min(i, len(frame_indices) - 1))
video_slices = torch.stack(video_slices) # (num_frames, num_frames_model, 512)
if state_embeddings is not None:
state_slices = torch.stack(state_slices) # (num_frames, num_frames_model, state_dim)
@@ -320,7 +303,6 @@ def run_inference(
)
# Extract last frame predictions (the "current" frame)
# For SARM, we take the last frame in each sequence
batch_progress = progress_preds[:, -1, 0].cpu().numpy()
batch_stages = stage_probs[:, -1, :].cpu().numpy()
+22 -10
View File
@@ -44,6 +44,7 @@ class SARMConfig(PreTrainedConfig):
num_layers: int = 8
num_stages: int = 5 # Number of task stages for classification (auto-updated from annotations if available)
subtask_names: list | None = None # List of subtask names (auto-populated from annotations)
temporal_proportions: list | None = None # Temporal proportions for each stage (auto-computed from annotations)
# Temporal parameters
max_length: int = num_frames # Maximum video sequence length (matches num_frames)
@@ -128,20 +129,31 @@ class SARMConfig(PreTrainedConfig):
"""Validate input and output features."""
pass
def observation_delta_indices(self, episode_frame_index: int) -> list[int]:
    """Compute frame deltas, relative to the current frame, for SARM temporal sampling.

    The result has ``self.num_frames`` entries (9 with the default config):
    - entry 0 is ``-episode_frame_index``, the delta that reaches the first
      frame of the episode from the current frame;
    - the remaining ``self.num_frames - 1`` entries are spaced
      ``self.frame_gap`` apart and end at 0 (the current frame).

    For early frames the spaced deltas can point before the episode start;
    no clamping is done here, callers clamp the resulting absolute index.

    NOTE(review): this takes an argument and is called as a method, so it
    must not be decorated with ``@property``. The earlier property-style
    variant (no argument, 9 evenly spaced deltas) is superseded by this one.
    NOTE(review): the previous docstring states that the dataloader converts
    these deltas to seconds (delta / fps) -- not verifiable from this block.

    Args:
        episode_frame_index: Current frame index within the episode (0, 1, 2, ...).

    Returns:
        ``[-episode_frame_index, -(n-2)*gap, ..., -gap, 0]`` where
        ``n = self.num_frames`` and ``gap = self.frame_gap``.
    """
    # Number of frames sampled at frame_gap intervals (everything except the initial frame).
    trailing_count = self.num_frames - 1
    # Evenly spaced deltas ending at 0; the stop value of 1 makes 0 inclusive.
    trailing_deltas = list(range(-self.frame_gap * (trailing_count - 1), 1, self.frame_gap))
    # Prepend the delta that jumps back to frame 0 of the episode.
    return [-episode_frame_index, *trailing_deltas]
@property
def action_delta_indices(self) -> None:
+79 -16
View File
@@ -19,6 +19,7 @@ from typing import List, Union, Dict, Optional
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -71,13 +72,31 @@ class SARMTransformer(nn.Module):
num_layers: int = 8,
num_stages: int = 5,
max_length: int = 9,
dropout: float = 0.1
dropout: float = 0.1,
temporal_proportions: list[float] | None = None
):
super().__init__()
self.hidden_dim = hidden_dim
self.max_length = max_length
self.num_stages = num_stages
# Store temporal proportions for progress conversion (Paper Eq. 4)
# ŷ = P_{k-1} + ᾱ_k × τ̂
if temporal_proportions is None:
raise ValueError(
"temporal_proportions is required for SARM. "
"Provide subtask annotations in your dataset or set temporal_proportions in config."
)
# ᾱ_k: proportion for each stage
alpha = torch.tensor(temporal_proportions, dtype=torch.float32)
# P_k: cumulative proportion up to stage k (P_0 = 0)
cumulative = torch.zeros(num_stages + 1, dtype=torch.float32)
cumulative[1:] = torch.cumsum(alpha, dim=0)
self.register_buffer('alpha', alpha)
self.register_buffer('cumulative_prior', cumulative)
# Project video, text, and state to same dimension
self.video_proj = nn.Linear(video_dim, hidden_dim)
self.text_proj = nn.Linear(text_dim, hidden_dim)
@@ -97,24 +116,26 @@ class SARMTransformer(nn.Module):
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
# Stage estimator head (classification)
# Paper A.4: "2 layers with hidden dimension of 512"
self.stage_head = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.LayerNorm(hidden_dim // 2),
nn.Linear(hidden_dim, 512),
nn.LayerNorm(512),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, num_stages)
nn.Linear(512, num_stages)
)
# Subtask estimator head (regression, conditioned on stage)
# Takes concatenated [features, stage_embedding]
# Paper A.4: "2 layers with hidden dimension of 512"
self.stage_embedding = nn.Embedding(num_stages, hidden_dim // 4)
subtask_input_dim = hidden_dim + hidden_dim // 4
self.subtask_head = nn.Sequential(
nn.Linear(subtask_input_dim, hidden_dim // 2),
nn.LayerNorm(hidden_dim // 2),
nn.Linear(subtask_input_dim, 512),
nn.LayerNorm(512),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, 1),
nn.Linear(512, 1),
nn.Sigmoid()
)
@@ -189,7 +210,17 @@ class SARMTransformer(nn.Module):
conditioned_features = torch.cat([frame_features, stage_embeds], dim=-1)
# Subtask progress estimation (conditioned on stage)
progress_preds = self.subtask_head(conditioned_features) # [batch_size, seq_len, 1]
# τ̂ = within-subtask progress (0-1)
tau_preds = self.subtask_head(conditioned_features) # [batch_size, seq_len, 1]
# Convert τ̂ to cumulative progress ŷ using Paper Eq. 4:
# ŷ = P_{k-1} + ᾱ_k × τ̂
# P_{k-1} = cumulative prior up to stage k-1
# ᾱ_k = temporal proportion of stage k
P_k_minus_1 = self.cumulative_prior[stage_indices] # [batch_size, seq_len]
alpha_k = self.alpha[stage_indices] # [batch_size, seq_len]
progress_preds = P_k_minus_1.unsqueeze(-1) + alpha_k.unsqueeze(-1) * tau_preds
return stage_logits, stage_probs, progress_preds
@@ -263,7 +294,8 @@ class SARMRewardModel(PreTrainedPolicy):
"2. Ensure dataset_stats contains 'observation.state' or 'state' key"
)
# Initialize SARM transformer
# Initialize SARM transformer with temporal proportions for progress conversion
temporal_proportions = getattr(config, 'temporal_proportions', None)
self.sarm_transformer = SARMTransformer(
video_dim=config.image_dim,
text_dim=config.text_dim,
@@ -273,7 +305,8 @@ class SARMRewardModel(PreTrainedPolicy):
num_layers=config.num_layers,
num_stages=config.num_stages,
max_length=config.max_length,
dropout=config.dropout
dropout=config.dropout,
temporal_proportions=temporal_proportions
)
self.sarm_transformer.to(self.device)
@@ -281,7 +314,7 @@ class SARMRewardModel(PreTrainedPolicy):
logging.info(f"SARM Reward Model initialized on {self.device}")
def _update_num_stages_from_dataset(self, dataset_meta) -> None:
"""Update num_stages in config based on dataset subtask annotations."""
"""Update num_stages and temporal_proportions from dataset subtask annotations."""
episodes = dataset_meta.episodes
if episodes is None or len(episodes) == 0:
raise ValueError("No episodes found, using default num_stages")
@@ -291,14 +324,28 @@ class SARMRewardModel(PreTrainedPolicy):
episodes_df = episodes.to_pandas()
# Collect all unique subtask names
# Collect all unique subtask names and compute durations
all_subtask_names = set()
subtask_durations = {}
for ep_idx in episodes_df.index:
subtask_names = episodes_df.loc[ep_idx, 'subtask_names']
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
continue
all_subtask_names.update(subtask_names)
# Compute durations if available
if 'subtask_start_frames' in episodes_df.columns and 'subtask_end_frames' in episodes_df.columns:
start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames']
end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames']
for i, name in enumerate(subtask_names):
duration = end_frames[i] - start_frames[i]
if name not in subtask_durations:
subtask_durations[name] = []
subtask_durations[name].append(duration)
if not all_subtask_names:
raise ValueError("No valid subtask names found, using default num_stages")
@@ -306,10 +353,26 @@ class SARMRewardModel(PreTrainedPolicy):
subtask_names = sorted(list(all_subtask_names))
num_stages = len(subtask_names)
# Compute temporal proportions (Paper Eq. 1: ᾱ_k)
avg_durations = {}
for name in subtask_names:
if name in subtask_durations and subtask_durations[name]:
avg_durations[name] = np.mean(subtask_durations[name])
else:
avg_durations[name] = 1.0 # Default
total_duration = sum(avg_durations.values())
if total_duration > 0:
temporal_proportions = [avg_durations[name] / total_duration for name in subtask_names]
else:
temporal_proportions = [1.0 / num_stages] * num_stages
self.config.num_stages = num_stages
self.config.subtask_names = subtask_names
self.config.temporal_proportions = temporal_proportions
logging.info(f"Auto-detected {num_stages} subtasks from dataset: {subtask_names}, using {num_stages} stages")
logging.info(f"Auto-detected {num_stages} subtasks: {subtask_names}")
logging.info(f"Temporal proportions: {dict(zip(subtask_names, temporal_proportions))}")
def to(self, device):
"""Override to method to ensure all components move together."""
@@ -357,7 +420,7 @@ class SARMRewardModel(PreTrainedPolicy):
# Batch process frames with CLIP
for i in range(0, len(frames), self.config.clip_batch_size):
batch = frames[i:i + self.config.clip_batch_size]
inputs = self.clip_processor(images=batch, return_tensors="pt", padding=True)
inputs = self.clip_processor(images=batch, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Get image embeddings from CLIP
@@ -578,8 +641,8 @@ class SARMRewardModel(PreTrainedPolicy):
state: torch.Tensor | None,
max_length: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
"""Apply rewind augmentation: append 2-4 reversed frames (SARM paper)."""
num_reverse = random.randint(2, min(4, max_length - 1))
"""Apply rewind augmentation: append up to 4 reversed frames (SARM paper A.4)."""
num_reverse = random.randint(1, min(4, max_length - 1))
# Reverse and take frames (skip first which is last of original)
reversed_video = video.flip(0)[1:num_reverse + 1]
+39 -14
View File
@@ -208,15 +208,31 @@ class SARMEncodingProcessorStep(ProcessorStep):
return episode_indices
def _compute_absolute_indices(self, frame_idx: int, ep_start: int, num_frames: int) -> torch.Tensor:
"""Compute absolute frame indices for a sequence."""
"""Compute absolute frame indices for a sequence.
(per SARM paper Section A.4):
- Frame 0: Initial frame of the episode (ep_start)
- Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame
Pattern: [ep_start, t-(7*gap), t-(6*gap), ..., t-gap, t]
"""
frame_gap = getattr(self.config, 'frame_gap', 1)
if frame_gap > 1:
indices = [max(ep_start, frame_idx - (num_frames - 1 - i) * frame_gap) for i in range(num_frames)]
return torch.tensor(indices)
else:
start_idx = max(ep_start, frame_idx - num_frames + 1)
return torch.arange(start_idx, frame_idx + 1)
indices = []
# First frame is the episode's initial frame
indices.append(ep_start)
# Remaining frames are consecutive with frame_gap spacing
num_consecutive = num_frames - 1
for i in range(num_consecutive):
offset = -(num_consecutive - 1 - i) * frame_gap
idx = max(ep_start, frame_idx + offset)
indices.append(idx)
return torch.tensor(indices)
def _compute_episode_metadata(
self,
@@ -324,6 +340,10 @@ class SARMEncodingProcessorStep(ProcessorStep):
) -> tuple[torch.Tensor, torch.Tensor] | tuple[None, None]:
"""Compute stage labels and progress targets for a single sample.
Label layout (per SARM paper, Section A.4):
- Position 0: initial frame of the episode (stage and progress are taken at frame 0)
- Positions 1-8: 8 frames spaced frame_gap apart, ending at the current frame
Args:
frame_idx: The frame index for this sample
ep_idx: The episode index
@@ -348,7 +368,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
# Get episode boundaries
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
# Get frame gap for temporal sampling
# Get config values
frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1
# Generate labels for each frame in the sequence
@@ -356,12 +376,15 @@ class SARMEncodingProcessorStep(ProcessorStep):
progress_targets = []
for i in range(seq_len):
# Calculate actual frame index for this position in sequence
if frame_gap > 1:
offset = -(seq_len - 1 - i) * frame_gap
current_frame = max(0, frame_idx + offset - ep_start)
if i == 0:
# Position 0: Initial frame of the episode
current_frame = 0 # Relative to episode start
else:
current_frame = max(0, frame_idx - seq_len + 1 + i - ep_start)
# Positions 1-8: consecutive frames with frame_gap spacing
num_consecutive = seq_len - 1
offset = -(num_consecutive - i) * frame_gap
current_frame = max(0, frame_idx + offset - ep_start)
stage_idx, cumulative_progress = self._compute_stage_and_progress_for_frame(
current_frame, subtask_names, subtask_start_frames, subtask_end_frames
@@ -564,7 +587,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
batch_imgs = images_list[i:i + self.config.clip_batch_size]
# Process with CLIP
inputs = self.clip_processor(images=batch_imgs, return_tensors="pt", padding=True)
inputs = self.clip_processor(images=batch_imgs, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Get image embeddings
@@ -707,3 +730,5 @@ def make_sarm_pre_post_processors(
),
)