add stride

This commit is contained in:
Pepijn
2025-08-31 18:32:47 +02:00
parent 8d20ca1625
commit 5c1d930a34
2 changed files with 11 additions and 8 deletions
@@ -55,6 +55,8 @@ class RLearNConfig(PreTrainedConfig):
# Sequence length: number of frames (past frames plus the current one) fed to the temporal model
max_seq_len: int = 16
# Temporal sampling stride: spacing between sampled frames (1 = consecutive frames, 2 = every other frame, for wider temporal coverage)
temporal_sampling_stride: int = 2
# Training
learning_rate: float = 1e-3
@@ -579,6 +579,7 @@ class RLearNPolicy(PreTrainedPolicy):
# Only apply mismatch loss if we have actual mismatches
if any(mismatch_mask):
print("Applying mismatch loss!!!")
# Re-encode with mismatched language
lang_embeds_mm, mask_mm = self._encode_language_tokens(shuffled_commands, device)
lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm)
@@ -662,10 +663,6 @@ class RLearNPolicy(PreTrainedPolicy):
pred_std = sample_preds.std()
target_std = sample_targets.std()
print(f" Variation - Target std: {target_std:.4f} | Pred std: {pred_std:.4f}")
if pred_std < 0.01:
print(f" ⚠️ PREDICTIONS STUCK (std={pred_std:.5f})")
else:
print(f" ✓ Predictions varying normally")
else:
# For longer sequences, show first 8 and last 8
print(f" Targets: {sample_targets[:8]} ... {sample_targets[-8:]}")
@@ -897,17 +894,21 @@ class RLearNPolicy(PreTrainedPolicy):
ep_length = ep_end - ep_start
episode_lengths.append(ep_length)
# Choose random anchor - need at least T-1 frames before for [-15..0] window
min_anchor = T - 1
# Choose random anchor - need enough frames before for stride sampling
# For T=16 and stride=2, we need frames [anchor-30, anchor-28, ..., anchor-2, anchor]
stride = self.config.temporal_sampling_stride
min_anchor = (T - 1) * stride # Need (T-1)*stride frames before anchor
max_anchor = max(min_anchor, ep_length - 1)
anchor = torch.randint(min_anchor, max_anchor + 1, (1,)).item()
anchor_positions.append(anchor)
# Build window indices with reflection padding
# Build window indices with configurable stride sampling and reflection padding
window_indices = []
frame_indices_for_progress = [] # Track actual frame positions for progress
had_oob = False
for delta in range(-(T-1), 1): # [-15, -14, ..., 0] for T=16
# Sample with stride: [anchor-(T-1)*stride, anchor-(T-2)*stride, ..., anchor-stride, anchor]
for i in range(T):
delta = -(T - 1 - i) * stride # Work backwards from anchor with stride spacing
idx = anchor + delta
actual_frame_idx = idx # Store the actual frame index before reflection
if idx < 0: