simple eval

This commit is contained in:
Pepijn
2025-08-31 13:40:13 +02:00
parent b0a5b88c21
commit 7da15ba069
2 changed files with 4 additions and 7 deletions
@@ -56,9 +56,6 @@ class RLearNConfig(PreTrainedConfig):
# Sequence length: number of past frames (including the current one) used by the temporal model
max_seq_len: int = 16
# Head
use_tanh_head: bool = False # when True, bound outputs in [-1, 1]
# Training
learning_rate: float = 1e-3
weight_decay: float = 0.01
@@ -255,10 +255,10 @@ class RLearNPolicy(PreTrainedPolicy):
frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len)
B, T, C, H, W = frames.shape
# Apply stride (no dropout during eval)
idx = torch.arange(0, T, self.stride, device=frames.device)
frames = frames[:, idx]
T_eff = frames.shape[1]
# CRITICAL FIX: Do NOT apply stride during evaluation
# During evaluation, we want to process all frames in the sliding window
# Stride should only be used during training to reduce computational cost
T_eff = T # Use all frames during evaluation
# Get language commands
commands = batch.get(OBS_LANGUAGE, None)