From 5c1d930a34406804f5a291732c75c43cbe6a1ea7 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Sun, 31 Aug 2025 18:32:47 +0200 Subject: [PATCH] add stride --- .../policies/rlearn/configuration_rlearn.py | 2 ++ src/lerobot/policies/rlearn/modeling_rlearn.py | 17 +++++++++-------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/lerobot/policies/rlearn/configuration_rlearn.py b/src/lerobot/policies/rlearn/configuration_rlearn.py index 5724c1ca3..d3e02d442 100644 --- a/src/lerobot/policies/rlearn/configuration_rlearn.py +++ b/src/lerobot/policies/rlearn/configuration_rlearn.py @@ -55,6 +55,8 @@ class RLearNConfig(PreTrainedConfig): # Sequence length, amount of past frames including current one to use in the temporal model max_seq_len: int = 16 + # Temporal sampling stride (2 = skip every other frame for wider temporal coverage) + temporal_sampling_stride: int = 2 # Training learning_rate: float = 1e-3 diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index b6768b0d8..6f5a224eb 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -579,6 +579,7 @@ class RLearNPolicy(PreTrainedPolicy): # Only apply mismatch loss if we have actual mismatches if any(mismatch_mask): + print("Applying mismatch loss!!!") # Re-encode with mismatched language lang_embeds_mm, mask_mm = self._encode_language_tokens(shuffled_commands, device) lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm) @@ -662,10 +663,6 @@ class RLearNPolicy(PreTrainedPolicy): pred_std = sample_preds.std() target_std = sample_targets.std() print(f" Variation - Target std: {target_std:.4f} | Pred std: {pred_std:.4f}") - if pred_std < 0.01: - print(f" ⚠️ PREDICTIONS STUCK (std={pred_std:.5f})") - else: - print(f" ✓ Predictions varying normally") else: # For longer sequences, show first 8 and last 8 print(f" Targets: {sample_targets[:8]} ... 
{sample_targets[-8:]}") @@ -897,17 +894,21 @@ class RLearNPolicy(PreTrainedPolicy): ep_length = ep_end - ep_start episode_lengths.append(ep_length) - # Choose random anchor - need at least T-1 frames before for [-15..0] window - min_anchor = T - 1 + # Choose random anchor - need enough preceding frames for strided sampling + # For T=16 and stride=2, we need frames [anchor-30, anchor-28, ..., anchor-2, anchor] + stride = self.config.temporal_sampling_stride + min_anchor = (T - 1) * stride # Need (T-1)*stride frames before anchor max_anchor = max(min_anchor, ep_length - 1) anchor = torch.randint(min_anchor, max_anchor + 1, (1,)).item() anchor_positions.append(anchor) - # Build window indices with reflection padding + # Build window indices with configurable stride sampling and reflection padding window_indices = [] frame_indices_for_progress = [] # Track actual frame positions for progress had_oob = False - for delta in range(-(T-1), 1): # [-15, -14, ..., 0] for T=16 + # Sample with stride: [anchor-(T-1)*stride, anchor-(T-2)*stride, ..., anchor-stride, anchor] + for i in range(T): + delta = -(T - 1 - i) * stride # Offset from anchor: i=0 is the oldest frame, i=T-1 is the anchor (delta=0) idx = anchor + delta actual_frame_idx = idx # Store the actual frame index before reflection if idx < 0: