hls_gaus true

This commit is contained in:
Pepijn
2025-09-01 14:19:07 +02:00
parent cb0fb8ad15
commit ee48a80e4d
2 changed files with 14 additions and 13 deletions
@@ -102,10 +102,11 @@ class RLearNConfig(PreTrainedConfig):
@property
def observation_delta_indices(self) -> list | None:
# Use temporal sequences: past frames from -(max_seq_len-1) to current (0)
# This gives us max_seq_len frames total, e.g. [-3, -2, -1, 0] for max_seq_len=4
# The dataset will handle padding/repeating frames for episodes shorter than this
return list(range(1 - self.max_seq_len, 1))
# Request a long enough context so in-window stride sampling can be >1.
# We ask for (max_seq_len * temporal_sampling_stride) frames ending at t=0.
# Example: max_seq_len=16, temporal_sampling_stride=3 → 48 deltas → ~46 frames available.
total_needed = self.max_seq_len * max(1, int(self.temporal_sampling_stride))
return list(range(1 - total_needed, 1))
@property
def action_delta_indices(self) -> list | None:
@@ -792,16 +792,16 @@ class RLearNPolicy(PreTrainedPolicy):
ep_length = ep_end - ep_start
episode_lengths.append(ep_length)
# Proper window-relative stride sampling within available frames
stride = self.config.temporal_sampling_stride
# Ensure we have room for T frames at given stride; shrink stride if needed
if available_T <= 1:
effective_stride = 1
else:
effective_stride = max(1, min(stride, (available_T - 1) // max(T - 1, 1) if (T - 1) > 0 else 1))
# Proper window-relative stride sampling; allow clamping for out-of-bounds
effective_stride = max(1, int(self.config.temporal_sampling_stride))
min_anchor_in_window = (T - 1) * effective_stride
max_anchor_in_window = max(min_anchor_in_window, available_T - 1)
anchor_in_window = torch.randint(min_anchor_in_window, max_anchor_in_window + 1, (1,)).item()
max_anchor_in_window = available_T - 1
if min_anchor_in_window <= max_anchor_in_window:
anchor_in_window = torch.randint(min_anchor_in_window, max_anchor_in_window + 1, (1,)).item()
else:
# Not enough frames to satisfy stride; anchor at the last available frame,
# earlier indices will clamp to 0 (repeat first frame)
anchor_in_window = max_anchor_in_window
# Convert window-anchor to episode-anchor (absolute frame index within episode)
cur_frame_idx = frame_indices[b_idx].item()