mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
fix progress conversion and adding initial frame
This commit is contained in:
@@ -224,14 +224,14 @@ def run_inference(
|
||||
"""
|
||||
Run SARM inference on video frames and joint states.
|
||||
|
||||
For each frame t, creates a temporal sequence of 9 frames using SARM's pattern:
|
||||
[t-240, t-210, t-180, t-150, t-120, t-90, t-60, t-30, t]
|
||||
This matches the training pattern where frames are loaded with 30-frame gaps
|
||||
relative to the current frame.
|
||||
(per SARM paper Section A.4):
|
||||
- Position 0: the initial frame of the episode (frame index 0)
|
||||
- Positions 1-8: 8 frames spaced frame_gap apart, ending at the current frame t
|
||||
Pattern: [frame_0, t-(7*gap), t-(6*gap), ..., t-gap, t]
|
||||
|
||||
Args:
|
||||
model: SARM model
|
||||
frames: Video frames (num_frames, H, W, C)
|
||||
frames: Video frames (num_frames, H, W, C) - all frames from ONE episode
|
||||
states: Joint states (num_frames, state_dim)
|
||||
task_description: Task description text
|
||||
batch_size: Batch size for processing slices
|
||||
@@ -247,7 +247,12 @@ def run_inference(
|
||||
logger.info("Encoding task description with MiniLM...")
|
||||
text_embedding = model.encode_text(task_description)
|
||||
|
||||
logger.info("Creating video slices (SARM approach)...")
|
||||
# Get config values
|
||||
num_frames_model = model.config.num_frames # 9
|
||||
frame_gap = model.config.frame_gap # 30
|
||||
|
||||
logger.info("Creating video slices (SARM paper: initial frame + 8 consecutive)...")
|
||||
|
||||
# Convert to tensors
|
||||
video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
|
||||
text_embedding = torch.tensor(text_embedding, dtype=torch.float32)
|
||||
@@ -256,33 +261,14 @@ def run_inference(
|
||||
else:
|
||||
state_embeddings = None
|
||||
|
||||
# Create video slices: for each frame i, create a sequence using SARM's pattern
|
||||
# For SARM: 9 frames relative to current, with 30-frame gaps
|
||||
# Pattern: [current-240, current-210, ..., current-30, current]
|
||||
num_frames_model = model.config.num_frames
|
||||
frame_gap = model.config.frame_gap
|
||||
|
||||
video_slices = []
|
||||
state_slices = []
|
||||
last_frame_indices = []
|
||||
|
||||
for i in tqdm(range(len(video_embeddings)), desc="Creating slices"):
|
||||
# For SARM, create sequence relative to current frame (matching training pattern)
|
||||
# Pattern: [current-240, current-210, ..., current-30, current]
|
||||
# This matches observation_delta_indices: range(-240, 1, 30)
|
||||
|
||||
# Compute frame indices for this slice (relative to current frame i)
|
||||
frame_indices = []
|
||||
for j in range(num_frames_model):
|
||||
# Start from -(num_frames_model-1) * frame_gap and go to 0
|
||||
offset = -(num_frames_model - 1 - j) * frame_gap
|
||||
idx = i + offset
|
||||
|
||||
# Clamp to valid range [0, current_frame]
|
||||
if idx < 0:
|
||||
idx = 0 # Pad with first available frame
|
||||
|
||||
frame_indices.append(idx)
|
||||
for current_frame in tqdm(range(len(video_embeddings)), desc="Creating slices"):
|
||||
# Compute frame indices using SARM pattern:
|
||||
# [initial_frame, t-(7*gap), t-(6*gap), ..., t-gap, t]
|
||||
deltas = model.config.observation_delta_indices(current_frame)
|
||||
frame_indices = [max(0, current_frame + delta) for delta in deltas]
|
||||
|
||||
# Extract slice
|
||||
video_slice = video_embeddings[frame_indices]
|
||||
@@ -291,9 +277,6 @@ def run_inference(
|
||||
if state_embeddings is not None:
|
||||
state_slice = state_embeddings[frame_indices]
|
||||
state_slices.append(state_slice)
|
||||
|
||||
# Track which frame index corresponds to the "current" frame
|
||||
last_frame_indices.append(min(i, len(frame_indices) - 1))
|
||||
|
||||
video_slices = torch.stack(video_slices) # (num_frames, num_frames_model, 512)
|
||||
if state_embeddings is not None:
|
||||
@@ -320,7 +303,6 @@ def run_inference(
|
||||
)
|
||||
|
||||
# Extract last frame predictions (the "current" frame)
|
||||
# For SARM, we take the last frame in each sequence
|
||||
batch_progress = progress_preds[:, -1, 0].cpu().numpy()
|
||||
batch_stages = stage_probs[:, -1, :].cpu().numpy()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user