Fix progress conversion and add initial frame

This commit is contained in:
Pepijn
2025-11-26 11:02:42 +01:00
parent c66aef878c
commit cc2e91febe
4 changed files with 156 additions and 74 deletions
+16 -34
View File
@@ -224,14 +224,14 @@ def run_inference(
""" """
Run SARM inference on video frames and joint states. Run SARM inference on video frames and joint states.
For each frame t, creates a temporal sequence of 9 frames using SARM's pattern: (per SARM paper Section A.4):
[t-240, t-210, t-180, t-150, t-120, t-90, t-60, t-30, t] - Frame 0: Initial frame of the episode (frame 0)
This matches the training pattern where frames are loaded with 30-frame gaps - Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame t
relative to the current frame. Pattern: [frame_0, t-(7*gap), t-(6*gap), ..., t-gap, t]
Args: Args:
model: SARM model model: SARM model
frames: Video frames (num_frames, H, W, C) frames: Video frames (num_frames, H, W, C) - all frames from ONE episode
states: Joint states (num_frames, state_dim) states: Joint states (num_frames, state_dim)
task_description: Task description text task_description: Task description text
batch_size: Batch size for processing slices batch_size: Batch size for processing slices
@@ -247,7 +247,12 @@ def run_inference(
logger.info("Encoding task description with MiniLM...") logger.info("Encoding task description with MiniLM...")
text_embedding = model.encode_text(task_description) text_embedding = model.encode_text(task_description)
logger.info("Creating video slices (SARM approach)...") # Get config values
num_frames_model = model.config.num_frames # 9
frame_gap = model.config.frame_gap # 30
logger.info("Creating video slices (SARM paper: initial frame + 8 consecutive)...")
# Convert to tensors # Convert to tensors
video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32) video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
text_embedding = torch.tensor(text_embedding, dtype=torch.float32) text_embedding = torch.tensor(text_embedding, dtype=torch.float32)
@@ -256,33 +261,14 @@ def run_inference(
else: else:
state_embeddings = None state_embeddings = None
# Create video slices: for each frame i, create a sequence using SARM's pattern
# For SARM: 9 frames relative to current, with 30-frame gaps
# Pattern: [current-240, current-210, ..., current-30, current]
num_frames_model = model.config.num_frames
frame_gap = model.config.frame_gap
video_slices = [] video_slices = []
state_slices = [] state_slices = []
last_frame_indices = []
for i in tqdm(range(len(video_embeddings)), desc="Creating slices"): for current_frame in tqdm(range(len(video_embeddings)), desc="Creating slices"):
# For SARM, create sequence relative to current frame (matching training pattern) # Compute frame indices using SARM pattern:
# Pattern: [current-240, current-210, ..., current-30, current] # [initial_frame, t-(7*gap), t-(6*gap), ..., t-gap, t]
# This matches observation_delta_indices: range(-240, 1, 30) deltas = model.config.observation_delta_indices(current_frame)
frame_indices = [max(0, current_frame + delta) for delta in deltas]
# Compute frame indices for this slice (relative to current frame i)
frame_indices = []
for j in range(num_frames_model):
# Start from -(num_frames_model-1) * frame_gap and go to 0
offset = -(num_frames_model - 1 - j) * frame_gap
idx = i + offset
# Clamp to valid range [0, current_frame]
if idx < 0:
idx = 0 # Pad with first available frame
frame_indices.append(idx)
# Extract slice # Extract slice
video_slice = video_embeddings[frame_indices] video_slice = video_embeddings[frame_indices]
@@ -292,9 +278,6 @@ def run_inference(
state_slice = state_embeddings[frame_indices] state_slice = state_embeddings[frame_indices]
state_slices.append(state_slice) state_slices.append(state_slice)
# Track which frame index corresponds to the "current" frame
last_frame_indices.append(min(i, len(frame_indices) - 1))
video_slices = torch.stack(video_slices) # (num_frames, num_frames_model, 512) video_slices = torch.stack(video_slices) # (num_frames, num_frames_model, 512)
if state_embeddings is not None: if state_embeddings is not None:
state_slices = torch.stack(state_slices) # (num_frames, num_frames_model, state_dim) state_slices = torch.stack(state_slices) # (num_frames, num_frames_model, state_dim)
@@ -320,7 +303,6 @@ def run_inference(
) )
# Extract last frame predictions (the "current" frame) # Extract last frame predictions (the "current" frame)
# For SARM, we take the last frame in each sequence
batch_progress = progress_preds[:, -1, 0].cpu().numpy() batch_progress = progress_preds[:, -1, 0].cpu().numpy()
batch_stages = stage_probs[:, -1, :].cpu().numpy() batch_stages = stage_probs[:, -1, :].cpu().numpy()
+22 -10
View File
@@ -44,6 +44,7 @@ class SARMConfig(PreTrainedConfig):
num_layers: int = 8 num_layers: int = 8
num_stages: int = 5 # Number of task stages for classification (auto-updated from annotations if available) num_stages: int = 5 # Number of task stages for classification (auto-updated from annotations if available)
subtask_names: list | None = None # List of subtask names (auto-populated from annotations) subtask_names: list | None = None # List of subtask names (auto-populated from annotations)
temporal_proportions: list | None = None # Temporal proportions for each stage (auto-computed from annotations)
# Temporal parameters # Temporal parameters
max_length: int = num_frames # Maximum video sequence length (matches num_frames) max_length: int = num_frames # Maximum video sequence length (matches num_frames)
@@ -128,20 +129,31 @@ class SARMConfig(PreTrainedConfig):
"""Validate input and output features.""" """Validate input and output features."""
pass pass
@property def observation_delta_indices(self, episode_frame_index: int) -> list[int]:
def observation_delta_indices(self) -> list[int]: """Compute delta indices for SARM temporal sampling.
"""Load frames for SARM temporal sampling.
SARM uses 9 frames: 1 initial frame + 8 consecutive frames with frame_gap spacing. Per SARM paper (Section A.4), the model uses 9 frames:
- Frame 0: Initial frame of the episode (delta = -episode_frame_index)
- Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame
The dataloader converts these to seconds: delta_seconds = delta / fps
This means the first delta (-episode_frame_index) becomes -current_time,
which correctly points to t=0 (the initial frame).
Args:
episode_frame_index: Current frame index within the episode (0, 1, 2, ...)
Returns: Returns:
Indices for loading: [-(8*frame_gap), ..., -frame_gap, 0] 9 delta indices: [-episode_frame_index, -(7*gap), -(6*gap), ..., -gap, 0]
""" """
# For SARM: we need the initial frame (from episode start) plus 8 consecutive frames # First delta: negative of current frame index to reach frame 0
# The dataset will load relative to current frame initial_frame_delta = -episode_frame_index
# We'll handle the "initial frame" logic in the processor
# For now, load the last 8*frame_gap frames # Remaining 8 deltas: consecutive frames with frame_gap spacing
return list(range(-self.frame_gap * (self.num_frames - 1), 1, self.frame_gap)) num_consecutive = self.num_frames - 1 # 8 frames
consecutive_deltas = list(range(-self.frame_gap * (num_consecutive - 1), 1, self.frame_gap))
return [initial_frame_delta] + consecutive_deltas
@property @property
def action_delta_indices(self) -> None: def action_delta_indices(self) -> None:
+79 -16
View File
@@ -19,6 +19,7 @@ from typing import List, Union, Dict, Optional
import random import random
import numpy as np import numpy as np
import pandas as pd
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@@ -71,13 +72,31 @@ class SARMTransformer(nn.Module):
num_layers: int = 8, num_layers: int = 8,
num_stages: int = 5, num_stages: int = 5,
max_length: int = 9, max_length: int = 9,
dropout: float = 0.1 dropout: float = 0.1,
temporal_proportions: list[float] | None = None
): ):
super().__init__() super().__init__()
self.hidden_dim = hidden_dim self.hidden_dim = hidden_dim
self.max_length = max_length self.max_length = max_length
self.num_stages = num_stages self.num_stages = num_stages
# Store temporal proportions for progress conversion (Paper Eq. 4)
# ŷ = P_{k-1} + ᾱ_k × τ̂
if temporal_proportions is None:
raise ValueError(
"temporal_proportions is required for SARM. "
"Provide subtask annotations in your dataset or set temporal_proportions in config."
)
# ᾱ_k: proportion for each stage
alpha = torch.tensor(temporal_proportions, dtype=torch.float32)
# P_k: cumulative proportion up to stage k (P_0 = 0)
cumulative = torch.zeros(num_stages + 1, dtype=torch.float32)
cumulative[1:] = torch.cumsum(alpha, dim=0)
self.register_buffer('alpha', alpha)
self.register_buffer('cumulative_prior', cumulative)
# Project video, text, and state to same dimension # Project video, text, and state to same dimension
self.video_proj = nn.Linear(video_dim, hidden_dim) self.video_proj = nn.Linear(video_dim, hidden_dim)
self.text_proj = nn.Linear(text_dim, hidden_dim) self.text_proj = nn.Linear(text_dim, hidden_dim)
@@ -97,24 +116,26 @@ class SARMTransformer(nn.Module):
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
# Stage estimator head (classification) # Stage estimator head (classification)
# Paper A.4: "2 layers with hidden dimension of 512"
self.stage_head = nn.Sequential( self.stage_head = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2), nn.Linear(hidden_dim, 512),
nn.LayerNorm(hidden_dim // 2), nn.LayerNorm(512),
nn.GELU(), nn.GELU(),
nn.Dropout(dropout), nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, num_stages) nn.Linear(512, num_stages)
) )
# Subtask estimator head (regression, conditioned on stage) # Subtask estimator head (regression, conditioned on stage)
# Takes concatenated [features, stage_embedding] # Takes concatenated [features, stage_embedding]
# Paper A.4: "2 layers with hidden dimension of 512"
self.stage_embedding = nn.Embedding(num_stages, hidden_dim // 4) self.stage_embedding = nn.Embedding(num_stages, hidden_dim // 4)
subtask_input_dim = hidden_dim + hidden_dim // 4 subtask_input_dim = hidden_dim + hidden_dim // 4
self.subtask_head = nn.Sequential( self.subtask_head = nn.Sequential(
nn.Linear(subtask_input_dim, hidden_dim // 2), nn.Linear(subtask_input_dim, 512),
nn.LayerNorm(hidden_dim // 2), nn.LayerNorm(512),
nn.GELU(), nn.GELU(),
nn.Dropout(dropout), nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, 1), nn.Linear(512, 1),
nn.Sigmoid() nn.Sigmoid()
) )
@@ -189,7 +210,17 @@ class SARMTransformer(nn.Module):
conditioned_features = torch.cat([frame_features, stage_embeds], dim=-1) conditioned_features = torch.cat([frame_features, stage_embeds], dim=-1)
# Subtask progress estimation (conditioned on stage) # Subtask progress estimation (conditioned on stage)
progress_preds = self.subtask_head(conditioned_features) # [batch_size, seq_len, 1] # τ̂ = within-subtask progress (0-1)
tau_preds = self.subtask_head(conditioned_features) # [batch_size, seq_len, 1]
# Convert τ̂ to cumulative progress ŷ using Paper Eq. 4:
# ŷ = P_{k-1} + ᾱ_k × τ̂
# P_{k-1} = cumulative prior up to stage k-1
# ᾱ_k = temporal proportion of stage k
P_k_minus_1 = self.cumulative_prior[stage_indices] # [batch_size, seq_len]
alpha_k = self.alpha[stage_indices] # [batch_size, seq_len]
progress_preds = P_k_minus_1.unsqueeze(-1) + alpha_k.unsqueeze(-1) * tau_preds
return stage_logits, stage_probs, progress_preds return stage_logits, stage_probs, progress_preds
@@ -263,7 +294,8 @@ class SARMRewardModel(PreTrainedPolicy):
"2. Ensure dataset_stats contains 'observation.state' or 'state' key" "2. Ensure dataset_stats contains 'observation.state' or 'state' key"
) )
# Initialize SARM transformer # Initialize SARM transformer with temporal proportions for progress conversion
temporal_proportions = getattr(config, 'temporal_proportions', None)
self.sarm_transformer = SARMTransformer( self.sarm_transformer = SARMTransformer(
video_dim=config.image_dim, video_dim=config.image_dim,
text_dim=config.text_dim, text_dim=config.text_dim,
@@ -273,7 +305,8 @@ class SARMRewardModel(PreTrainedPolicy):
num_layers=config.num_layers, num_layers=config.num_layers,
num_stages=config.num_stages, num_stages=config.num_stages,
max_length=config.max_length, max_length=config.max_length,
dropout=config.dropout dropout=config.dropout,
temporal_proportions=temporal_proportions
) )
self.sarm_transformer.to(self.device) self.sarm_transformer.to(self.device)
@@ -281,7 +314,7 @@ class SARMRewardModel(PreTrainedPolicy):
logging.info(f"SARM Reward Model initialized on {self.device}") logging.info(f"SARM Reward Model initialized on {self.device}")
def _update_num_stages_from_dataset(self, dataset_meta) -> None: def _update_num_stages_from_dataset(self, dataset_meta) -> None:
"""Update num_stages in config based on dataset subtask annotations.""" """Update num_stages and temporal_proportions from dataset subtask annotations."""
episodes = dataset_meta.episodes episodes = dataset_meta.episodes
if episodes is None or len(episodes) == 0: if episodes is None or len(episodes) == 0:
raise ValueError("No episodes found, using default num_stages") raise ValueError("No episodes found, using default num_stages")
@@ -291,14 +324,28 @@ class SARMRewardModel(PreTrainedPolicy):
episodes_df = episodes.to_pandas() episodes_df = episodes.to_pandas()
# Collect all unique subtask names # Collect all unique subtask names and compute durations
all_subtask_names = set() all_subtask_names = set()
subtask_durations = {}
for ep_idx in episodes_df.index: for ep_idx in episodes_df.index:
subtask_names = episodes_df.loc[ep_idx, 'subtask_names'] subtask_names = episodes_df.loc[ep_idx, 'subtask_names']
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)): if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
continue continue
all_subtask_names.update(subtask_names) all_subtask_names.update(subtask_names)
# Compute durations if available
if 'subtask_start_frames' in episodes_df.columns and 'subtask_end_frames' in episodes_df.columns:
start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames']
end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames']
for i, name in enumerate(subtask_names):
duration = end_frames[i] - start_frames[i]
if name not in subtask_durations:
subtask_durations[name] = []
subtask_durations[name].append(duration)
if not all_subtask_names: if not all_subtask_names:
raise ValueError("No valid subtask names found, using default num_stages") raise ValueError("No valid subtask names found, using default num_stages")
@@ -306,10 +353,26 @@ class SARMRewardModel(PreTrainedPolicy):
subtask_names = sorted(list(all_subtask_names)) subtask_names = sorted(list(all_subtask_names))
num_stages = len(subtask_names) num_stages = len(subtask_names)
# Compute temporal proportions (Paper Eq. 1: ᾱ_k)
avg_durations = {}
for name in subtask_names:
if name in subtask_durations and subtask_durations[name]:
avg_durations[name] = np.mean(subtask_durations[name])
else:
avg_durations[name] = 1.0 # Default
total_duration = sum(avg_durations.values())
if total_duration > 0:
temporal_proportions = [avg_durations[name] / total_duration for name in subtask_names]
else:
temporal_proportions = [1.0 / num_stages] * num_stages
self.config.num_stages = num_stages self.config.num_stages = num_stages
self.config.subtask_names = subtask_names self.config.subtask_names = subtask_names
self.config.temporal_proportions = temporal_proportions
logging.info(f"Auto-detected {num_stages} subtasks from dataset: {subtask_names}, using {num_stages} stages") logging.info(f"Auto-detected {num_stages} subtasks: {subtask_names}")
logging.info(f"Temporal proportions: {dict(zip(subtask_names, temporal_proportions))}")
def to(self, device): def to(self, device):
"""Override to method to ensure all components move together.""" """Override to method to ensure all components move together."""
@@ -357,7 +420,7 @@ class SARMRewardModel(PreTrainedPolicy):
# Batch process frames with CLIP # Batch process frames with CLIP
for i in range(0, len(frames), self.config.clip_batch_size): for i in range(0, len(frames), self.config.clip_batch_size):
batch = frames[i:i + self.config.clip_batch_size] batch = frames[i:i + self.config.clip_batch_size]
inputs = self.clip_processor(images=batch, return_tensors="pt", padding=True) inputs = self.clip_processor(images=batch, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()} inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Get image embeddings from CLIP # Get image embeddings from CLIP
@@ -578,8 +641,8 @@ class SARMRewardModel(PreTrainedPolicy):
state: torch.Tensor | None, state: torch.Tensor | None,
max_length: int, max_length: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
"""Apply rewind augmentation: append 2-4 reversed frames (SARM paper).""" """Apply rewind augmentation: append up to 4 reversed frames (SARM paper A.4)."""
num_reverse = random.randint(2, min(4, max_length - 1)) num_reverse = random.randint(1, min(4, max_length - 1))
# Reverse and take frames (skip first which is last of original) # Reverse and take frames (skip first which is last of original)
reversed_video = video.flip(0)[1:num_reverse + 1] reversed_video = video.flip(0)[1:num_reverse + 1]
+38 -13
View File
@@ -208,15 +208,31 @@ class SARMEncodingProcessorStep(ProcessorStep):
return episode_indices return episode_indices
def _compute_absolute_indices(self, frame_idx: int, ep_start: int, num_frames: int) -> torch.Tensor: def _compute_absolute_indices(self, frame_idx: int, ep_start: int, num_frames: int) -> torch.Tensor:
"""Compute absolute frame indices for a sequence.""" """Compute absolute frame indices for a sequence.
(per SARM paper Section A.4):
- Frame 0: Initial frame of the episode (ep_start)
- Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame
Pattern: [ep_start, t-(7*gap), t-(6*gap), ..., t-gap, t]
"""
frame_gap = getattr(self.config, 'frame_gap', 1) frame_gap = getattr(self.config, 'frame_gap', 1)
if frame_gap > 1: indices = []
indices = [max(ep_start, frame_idx - (num_frames - 1 - i) * frame_gap) for i in range(num_frames)]
# First frame is the episode's initial frame
indices.append(ep_start)
# Remaining frames are consecutive with frame_gap spacing
num_consecutive = num_frames - 1
for i in range(num_consecutive):
offset = -(num_consecutive - 1 - i) * frame_gap
idx = max(ep_start, frame_idx + offset)
indices.append(idx)
return torch.tensor(indices) return torch.tensor(indices)
else:
start_idx = max(ep_start, frame_idx - num_frames + 1)
return torch.arange(start_idx, frame_idx + 1)
def _compute_episode_metadata( def _compute_episode_metadata(
self, self,
@@ -324,6 +340,10 @@ class SARMEncodingProcessorStep(ProcessorStep):
) -> tuple[torch.Tensor, torch.Tensor] | tuple[None, None]: ) -> tuple[torch.Tensor, torch.Tensor] | tuple[None, None]:
"""Compute stage labels and progress targets for a single sample. """Compute stage labels and progress targets for a single sample.
(per SARM paper Section A.4):
- Frame 0: Initial frame of episode (stage at frame 0, progress at frame 0)
- Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame
Args: Args:
frame_idx: The frame index for this sample frame_idx: The frame index for this sample
ep_idx: The episode index ep_idx: The episode index
@@ -348,7 +368,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
# Get episode boundaries # Get episode boundaries
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
# Get frame gap for temporal sampling # Get config values
frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1 frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1
# Generate labels for each frame in the sequence # Generate labels for each frame in the sequence
@@ -356,12 +376,15 @@ class SARMEncodingProcessorStep(ProcessorStep):
progress_targets = [] progress_targets = []
for i in range(seq_len): for i in range(seq_len):
# Calculate actual frame index for this position in sequence if i == 0:
if frame_gap > 1: # Position 0: Initial frame of the episode
offset = -(seq_len - 1 - i) * frame_gap current_frame = 0 # Relative to episode start
current_frame = max(0, frame_idx + offset - ep_start)
else: else:
current_frame = max(0, frame_idx - seq_len + 1 + i - ep_start) # Positions 1-8: consecutive frames with frame_gap spacing
num_consecutive = seq_len - 1
offset = -(num_consecutive - i) * frame_gap
current_frame = max(0, frame_idx + offset - ep_start)
stage_idx, cumulative_progress = self._compute_stage_and_progress_for_frame( stage_idx, cumulative_progress = self._compute_stage_and_progress_for_frame(
current_frame, subtask_names, subtask_start_frames, subtask_end_frames current_frame, subtask_names, subtask_start_frames, subtask_end_frames
@@ -564,7 +587,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
batch_imgs = images_list[i:i + self.config.clip_batch_size] batch_imgs = images_list[i:i + self.config.clip_batch_size]
# Process with CLIP # Process with CLIP
inputs = self.clip_processor(images=batch_imgs, return_tensors="pt", padding=True) inputs = self.clip_processor(images=batch_imgs, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()} inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Get image embeddings # Get image embeddings
@@ -707,3 +730,5 @@ def make_sarm_pre_post_processors(
), ),
) )