From ee48a80e4d018b3150549dbac29449f19485afff Mon Sep 17 00:00:00 2001 From: Pepijn Date: Mon, 1 Sep 2025 14:19:07 +0200 Subject: [PATCH] hls_gaus true --- .../policies/rlearn/configuration_rlearn.py | 9 +++++---- src/lerobot/policies/rlearn/modeling_rlearn.py | 18 +++++++++--------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/lerobot/policies/rlearn/configuration_rlearn.py b/src/lerobot/policies/rlearn/configuration_rlearn.py index 31dca0b67..141434887 100644 --- a/src/lerobot/policies/rlearn/configuration_rlearn.py +++ b/src/lerobot/policies/rlearn/configuration_rlearn.py @@ -102,10 +102,11 @@ class RLearNConfig(PreTrainedConfig): @property def observation_delta_indices(self) -> list | None: - # Use temporal sequences: past frames from -(max_seq_len-1) to current (0) - # This gives us max_seq_len frames total, e.g. [-3, -2, -1, 0] for max_seq_len=4 - # The dataset will handle padding/repeating frames for episodes shorter than this - return list(range(1 - self.max_seq_len, 1)) + # Request a long enough context so in-window stride sampling can be >1. + # We ask for (max_seq_len * temporal_sampling_stride) frames ending at t=0. + # Example: max_seq_len=16, temporal_sampling_stride=3 → 48 deltas → ~46 frames available. + total_needed = self.max_seq_len * max(1, int(self.temporal_sampling_stride)) + return list(range(1 - total_needed, 1)) @property def action_delta_indices(self) -> list | None: diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index bdf210129..008c150cd 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -792,16 +792,16 @@ class RLearNPolicy(PreTrainedPolicy): ep_length = ep_end - ep_start episode_lengths.append(ep_length) - # Proper window-relative stride sampling within available frames - stride = self.config.temporal_sampling_stride - # Ensure we have room for T frames at given stride; shrink stride if needed - if available_T <= 1: - effective_stride = 1 - else: - effective_stride = max(1, min(stride, (available_T - 1) // max(T - 1, 1) if (T - 1) > 0 else 1)) + # Proper window-relative stride sampling; allow clamping for out-of-bounds + effective_stride = max(1, int(self.config.temporal_sampling_stride)) min_anchor_in_window = (T - 1) * effective_stride - max_anchor_in_window = max(min_anchor_in_window, available_T - 1) - anchor_in_window = torch.randint(min_anchor_in_window, max_anchor_in_window + 1, (1,)).item() + max_anchor_in_window = available_T - 1 + if min_anchor_in_window <= max_anchor_in_window: + anchor_in_window = torch.randint(min_anchor_in_window, max_anchor_in_window + 1, (1,)).item() + else: + # Not enough frames to satisfy stride; anchor at the last available frame, + # earlier indices will clamp to 0 (repeat first frame) + anchor_in_window = max_anchor_in_window # Convert window-anchor to episode-anchor (absolute frame index within episode) cur_frame_idx = frame_indices[b_idx].item()