From 7da15ba069a9113b398f242dbd9dea5a53c1a8aa Mon Sep 17 00:00:00 2001 From: Pepijn Date: Sun, 31 Aug 2025 13:40:13 +0200 Subject: [PATCH] simple eval --- src/lerobot/policies/rlearn/configuration_rlearn.py | 3 --- src/lerobot/policies/rlearn/modeling_rlearn.py | 8 ++++---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/lerobot/policies/rlearn/configuration_rlearn.py b/src/lerobot/policies/rlearn/configuration_rlearn.py index 6605309cc..22609f06f 100644 --- a/src/lerobot/policies/rlearn/configuration_rlearn.py +++ b/src/lerobot/policies/rlearn/configuration_rlearn.py @@ -56,9 +56,6 @@ class RLearNConfig(PreTrainedConfig): # Sequence length, amount of past frames including current one to use in the temporal model max_seq_len: int = 16 - # Head - use_tanh_head: bool = False # when True, bound outputs in [-1, 1] - # Training learning_rate: float = 1e-3 weight_decay: float = 0.01 diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index 53facee5e..60d141ebc 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -255,10 +255,10 @@ class RLearNPolicy(PreTrainedPolicy): frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len) B, T, C, H, W = frames.shape - # Apply stride (no dropout during eval) - idx = torch.arange(0, T, self.stride, device=frames.device) - frames = frames[:, idx] - T_eff = frames.shape[1] + # CRITICAL FIX: Do NOT apply stride during evaluation + # During evaluation, we want to process all frames in the sliding window + # Stride should only be used during training to reduce computational cost + T_eff = T # Use all frames during evaluation # Get language commands commands = batch.get(OBS_LANGUAGE, None)