use large offset for initial frame (ugly)

This commit is contained in:
Pepijn
2025-11-26 11:53:12 +01:00
parent cc2e91febe
commit 425eced2de
2 changed files with 15 additions and 15 deletions
+4 -3
View File
@@ -266,9 +266,10 @@ def run_inference(
for current_frame in tqdm(range(len(video_embeddings)), desc="Creating slices"):
# Compute frame indices using SARM pattern:
# [initial_frame, t-(7*gap), t-(6*gap), ..., t-gap, t]
deltas = model.config.observation_delta_indices(current_frame)
frame_indices = [max(0, current_frame + delta) for delta in deltas]
# [initial_frame (0), t-(7*gap), t-(6*gap), ..., t-gap, t]
# The first delta is -100000 which clamps to 0 (episode start)
deltas = model.config.observation_delta_indices
frame_indices = [max(0, min(current_frame + delta, len(video_embeddings) - 1)) for delta in deltas]
# Extract slice
video_slice = video_embeddings[frame_indices]
+10 -11
View File
@@ -129,25 +129,24 @@ class SARMConfig(PreTrainedConfig):
"""Validate input and output features."""
pass
def observation_delta_indices(self, episode_frame_index: int) -> list[int]:
"""Compute delta indices for SARM temporal sampling.
@property
def observation_delta_indices(self) -> list[int]:
"""Load frames for SARM temporal sampling.
Per SARM paper (Section A.4), the model uses 9 frames:
- Frame 0: Initial frame of the episode (delta = -episode_frame_index)
- Frame 0: Initial frame of the episode
- Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame
The dataloader converts these to seconds: delta_seconds = delta / fps
This means the first delta (-episode_frame_index) becomes -current_time,
which correctly points to t=0 (the initial frame).
The first delta uses a large negative offset (-1_000_000) that will be clamped
to the episode start (frame 0) by the dataset loader. This ensures we always
get the initial frame regardless of the current position in the episode.
Args:
episode_frame_index: Current frame index within the episode (0, 1, 2, ...)
Returns:
9 delta indices: [-episode_frame_index, -(7*gap), -(6*gap), ..., -gap, 0]
9 delta indices: [-1_000_000, -(7*gap), -(6*gap), ..., -gap, 0]
"""
# First delta: negative of current frame index to reach frame 0
initial_frame_delta = -episode_frame_index
# First delta: large negative to always clamp to episode start (frame 0)
initial_frame_delta = -1_000_000
# Remaining 8 deltas: consecutive frames with frame_gap spacing
num_consecutive = self.num_frames - 1 # 8 frames