fix progress conversion and add initial frame

This commit is contained in:
Pepijn
2025-11-26 11:02:42 +01:00
parent c66aef878c
commit cc2e91febe
4 changed files with 156 additions and 74 deletions
+16 -34
@@ -224,14 +224,14 @@ def run_inference(
"""
Run SARM inference on video frames and joint states.
For each frame t, creates a temporal sequence of 9 frames using SARM's pattern:
[t-240, t-210, t-180, t-150, t-120, t-90, t-60, t-30, t]
This matches the training pattern where frames are loaded with 30-frame gaps
relative to the current frame.
(per SARM paper Section A.4):
- Frame 0: Initial frame of the episode (frame 0)
- Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame t
Pattern: [frame_0, t-(7*gap), t-(6*gap), ..., t-gap, t]
Args:
model: SARM model
frames: Video frames (num_frames, H, W, C)
frames: Video frames (num_frames, H, W, C) - all frames from ONE episode
states: Joint states (num_frames, state_dim)
task_description: Task description text
batch_size: Batch size for processing slices
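For illustration, a minimal standalone sketch of the sampling pattern this new docstring describes; the function name and defaults are assumptions for the example, not code from this repo:

# Sketch of the docstring's pattern; sarm_frame_indices is a hypothetical helper.
def sarm_frame_indices(t: int, num_frames: int = 9, frame_gap: int = 30) -> list[int]:
    """Return [frame_0, t-(7*gap), t-(6*gap), ..., t-gap, t], clamped to >= 0."""
    tail = [max(0, t - (num_frames - 2 - j) * frame_gap) for j in range(num_frames - 1)]
    return [0] + tail

print(sarm_frame_indices(300))  # [0, 90, 120, 150, 180, 210, 240, 270, 300]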
@@ -247,7 +247,12 @@ def run_inference(
     logger.info("Encoding task description with MiniLM...")
     text_embedding = model.encode_text(task_description)
 
-    logger.info("Creating video slices (SARM approach)...")
+    # Get config values
+    num_frames_model = model.config.num_frames  # 9
+    frame_gap = model.config.frame_gap  # 30
+
+    logger.info("Creating video slices (SARM paper: initial frame + 8 consecutive)...")
 
     # Convert to tensors
     video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
     text_embedding = torch.tensor(text_embedding, dtype=torch.float32)
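As a sanity check on shapes, a small sketch of this conversion step; the 512-dim video embeddings match the comment later in the diff, while the 384-dim MiniLM text embedding and episode length are assumptions:

import numpy as np
import torch

video_embeddings = np.random.randn(500, 512).astype(np.float32)  # (num_frames, 512) per-frame features
text_embedding = np.random.randn(384).astype(np.float32)         # MiniLM sentence embedding (assumed 384-dim)

video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
text_embedding = torch.tensor(text_embedding, dtype=torch.float32)
print(video_embeddings.shape, text_embedding.shape)  # torch.Size([500, 512]) torch.Size([384])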
@@ -256,33 +261,14 @@ def run_inference(
     else:
         state_embeddings = None
 
-    # Create video slices: for each frame i, create a sequence using SARM's pattern
-    # For SARM: 9 frames relative to current, with 30-frame gaps
-    # Pattern: [current-240, current-210, ..., current-30, current]
-    num_frames_model = model.config.num_frames
-    frame_gap = model.config.frame_gap
-
     video_slices = []
     state_slices = []
-    last_frame_indices = []
-    for i in tqdm(range(len(video_embeddings)), desc="Creating slices"):
-        # For SARM, create sequence relative to current frame (matching training pattern)
-        # Pattern: [current-240, current-210, ..., current-30, current]
-        # This matches observation_delta_indices: range(-240, 1, 30)
-        # Compute frame indices for this slice (relative to current frame i)
-        frame_indices = []
-        for j in range(num_frames_model):
-            # Start from -(num_frames_model-1) * frame_gap and go to 0
-            offset = -(num_frames_model - 1 - j) * frame_gap
-            idx = i + offset
-            # Clamp to valid range [0, current_frame]
-            if idx < 0:
-                idx = 0  # Pad with first available frame
-            frame_indices.append(idx)
+    for current_frame in tqdm(range(len(video_embeddings)), desc="Creating slices"):
+        # Compute frame indices using SARM pattern:
+        # [initial_frame, t-(7*gap), t-(6*gap), ..., t-gap, t]
+        deltas = model.config.observation_delta_indices(current_frame)
+        frame_indices = [max(0, current_frame + delta) for delta in deltas]
 
         # Extract slice
         video_slice = video_embeddings[frame_indices]
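The diff calls model.config.observation_delta_indices(current_frame) without showing its body; one plausible sketch consistent with the docstring's pattern follows. This is an assumption for illustration, not the repo's implementation:

# Hypothetical version of observation_delta_indices under the new scheme.
def observation_delta_indices(current_frame: int, num_frames: int = 9, frame_gap: int = 30) -> list[int]:
    initial_delta = -current_frame  # current_frame + delta == 0, the episode's initial frame
    tail_deltas = [-(num_frames - 2 - j) * frame_gap for j in range(num_frames - 1)]  # -210, ..., -30, 0
    return [initial_delta] + tail_deltas

# With max(0, current_frame + delta) as in the loop above, current_frame=300
# yields frame_indices [0, 90, 120, 150, 180, 210, 240, 270, 300].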
@@ -291,9 +277,6 @@ def run_inference(
         if state_embeddings is not None:
             state_slice = state_embeddings[frame_indices]
             state_slices.append(state_slice)
 
-        # Track which frame index corresponds to the "current" frame
-        last_frame_indices.append(min(i, len(frame_indices) - 1))
-
     video_slices = torch.stack(video_slices)  # (num_frames, num_frames_model, 512)
     if state_embeddings is not None:
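A quick shape check for the stacking step, with the episode length and embedding dim assumed as before:

import torch

slices = [torch.randn(9, 512) for _ in range(500)]  # one (num_frames_model, 512) slice per episode frame
video_slices = torch.stack(slices)
print(video_slices.shape)  # torch.Size([500, 9, 512])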
@@ -320,7 +303,6 @@ def run_inference(
         )
 
         # Extract last frame predictions (the "current" frame)
-        # For SARM, we take the last frame in each sequence
         batch_progress = progress_preds[:, -1, 0].cpu().numpy()
         batch_stages = stage_probs[:, -1, :].cpu().numpy()
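To make the indexing concrete, a small sketch with assumed output shapes (batch size 32 and num_stages=4 are assumptions): index -1 selects the last frame of each 9-frame slice, which under both the old and new patterns is the current frame t.

import torch

progress_preds = torch.rand(32, 9, 1)                      # (batch, num_frames_model, 1)
stage_probs = torch.softmax(torch.rand(32, 9, 4), dim=-1)  # (batch, num_frames_model, num_stages)

batch_progress = progress_preds[:, -1, 0].cpu().numpy()  # (32,) progress at the current frame
batch_stages = stage_probs[:, -1, :].cpu().numpy()       # (32, 4) stage distribution at the current frame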