mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 00:29:52 +00:00
simple eval
This commit is contained in:
@@ -56,9 +56,6 @@ class RLearNConfig(PreTrainedConfig):
|
||||
# Sequence length, amount of past frames including current one to use in the temporal model
|
||||
max_seq_len: int = 16
|
||||
|
||||
# Head
|
||||
use_tanh_head: bool = False # when True, bound outputs in [-1, 1]
|
||||
|
||||
# Training
|
||||
learning_rate: float = 1e-3
|
||||
weight_decay: float = 0.01
|
||||
|
||||
@@ -255,10 +255,10 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len)
|
||||
B, T, C, H, W = frames.shape
|
||||
|
||||
# Apply stride (no dropout during eval)
|
||||
idx = torch.arange(0, T, self.stride, device=frames.device)
|
||||
frames = frames[:, idx]
|
||||
T_eff = frames.shape[1]
|
||||
# CRITICAL FIX: Do NOT apply stride during evaluation
|
||||
# During evaluation, we want to process all frames in the sliding window
|
||||
# Stride should only be used during training to reduce computational cost
|
||||
T_eff = T # Use all frames during evaluation
|
||||
|
||||
# Get language commands
|
||||
commands = batch.get(OBS_LANGUAGE, None)
|
||||
|
||||
Reference in New Issue
Block a user