use large offset for initial frame (ugly)

This commit is contained in:
Pepijn
2025-11-26 11:53:12 +01:00
parent cc2e91febe
commit 425eced2de
2 changed files with 15 additions and 15 deletions
+4 -3
View File
@@ -266,9 +266,10 @@ def run_inference(
for current_frame in tqdm(range(len(video_embeddings)), desc="Creating slices"): for current_frame in tqdm(range(len(video_embeddings)), desc="Creating slices"):
# Compute frame indices using SARM pattern: # Compute frame indices using SARM pattern:
# [initial_frame, t-(7*gap), t-(6*gap), ..., t-gap, t] # [initial_frame (0), t-(7*gap), t-(6*gap), ..., t-gap, t]
deltas = model.config.observation_delta_indices(current_frame) # The first delta is -1_000_000 which clamps to 0 (episode start)
frame_indices = [max(0, current_frame + delta) for delta in deltas] deltas = model.config.observation_delta_indices
frame_indices = [max(0, min(current_frame + delta, len(video_embeddings) - 1)) for delta in deltas]
# Extract slice # Extract slice
video_slice = video_embeddings[frame_indices] video_slice = video_embeddings[frame_indices]
+11 -12
View File
@@ -129,25 +129,24 @@ class SARMConfig(PreTrainedConfig):
"""Validate input and output features.""" """Validate input and output features."""
pass pass
def observation_delta_indices(self, episode_frame_index: int) -> list[int]: @property
"""Compute delta indices for SARM temporal sampling. def observation_delta_indices(self) -> list[int]:
"""Load frames for SARM temporal sampling.
Per SARM paper (Section A.4), the model uses 9 frames: Per SARM paper (Section A.4), the model uses 9 frames:
- Frame 0: Initial frame of the episode (delta = -episode_frame_index) - Frame 0: Initial frame of the episode
- Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame - Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame
The dataloader converts these to seconds: delta_seconds = delta / fps The first delta uses a large negative offset (-1_000_000) that will be clamped
This means the first delta (-episode_frame_index) becomes -current_time, to the episode start (frame 0) by the dataset loader. This ensures we always
which correctly points to t=0 (the initial frame). get the initial frame regardless of the current position in the episode.
Args:
episode_frame_index: Current frame index within the episode (0, 1, 2, ...)
Returns: Returns:
9 delta indices: [-episode_frame_index, -(7*gap), -(6*gap), ..., -gap, 0] 9 delta indices: [-1_000_000, -(7*gap), -(6*gap), ..., -gap, 0]
""" """
# First delta: negative of current frame index to reach frame 0 # First delta: large negative to always clamp to episode start (frame 0)
initial_frame_delta = -episode_frame_index initial_frame_delta = -1_000_000
# Remaining 8 deltas: consecutive frames with frame_gap spacing # Remaining 8 deltas: consecutive frames with frame_gap spacing
num_consecutive = self.num_frames - 1 # 8 frames num_consecutive = self.num_frames - 1 # 8 frames