diff --git a/scripts/visualize_sarm_predictions.py b/scripts/visualize_sarm_predictions.py index 8c9d7ea23..1841af63c 100644 --- a/scripts/visualize_sarm_predictions.py +++ b/scripts/visualize_sarm_predictions.py @@ -266,9 +266,10 @@ def run_inference( for current_frame in tqdm(range(len(video_embeddings)), desc="Creating slices"): # Compute frame indices using SARM pattern: - # [initial_frame, t-(7*gap), t-(6*gap), ..., t-gap, t] - deltas = model.config.observation_delta_indices(current_frame) - frame_indices = [max(0, current_frame + delta) for delta in deltas] + # [initial_frame (0), t-(7*gap), t-(6*gap), ..., t-gap, t] + # The first delta is -100000 which clamps to 0 (episode start) + deltas = model.config.observation_delta_indices + frame_indices = [max(0, min(current_frame + delta, len(video_embeddings) - 1)) for delta in deltas] # Extract slice video_slice = video_embeddings[frame_indices] diff --git a/src/lerobot/policies/sarm/configuration_sarm.py b/src/lerobot/policies/sarm/configuration_sarm.py index bccac50b5..a89e367e8 100644 --- a/src/lerobot/policies/sarm/configuration_sarm.py +++ b/src/lerobot/policies/sarm/configuration_sarm.py @@ -129,25 +129,24 @@ class SARMConfig(PreTrainedConfig): """Validate input and output features.""" pass - def observation_delta_indices(self, episode_frame_index: int) -> list[int]: - """Compute delta indices for SARM temporal sampling. + @property + def observation_delta_indices(self) -> list[int]: + """Load frames for SARM temporal sampling. Per SARM paper (Section A.4), the model uses 9 frames: - - Frame 0: Initial frame of the episode (delta = -episode_frame_index) + - Frame 0: Initial frame of the episode - Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame - The dataloader converts these to seconds: delta_seconds = delta / fps - This means the first delta (-episode_frame_index) becomes -current_time, - which correctly points to t=0 (the initial frame). + The first delta uses a large negative offset (-1_000_000) that will be clamped + to the episode start (frame 0) by the dataset loader. This ensures we always + get the initial frame regardless of the current position in the episode. + - Args: - episode_frame_index: Current frame index within the episode (0, 1, 2, ...) - Returns: - 9 delta indices: [-episode_frame_index, -(7*gap), -(6*gap), ..., -gap, 0] + 9 delta indices: [-1_000_000, -(7*gap), -(6*gap), ..., -gap, 0] """ - # First delta: negative of current frame index to reach frame 0 - initial_frame_delta = -episode_frame_index + # First delta: large negative to always clamp to episode start (frame 0) + initial_frame_delta = -1_000_000 # Remaining 8 deltas: consecutive frames with frame_gap spacing num_consecutive = self.num_frames - 1 # 8 frames