mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 13:09:43 +00:00
simple eval
This commit is contained in:
@@ -56,9 +56,6 @@ class RLearNConfig(PreTrainedConfig):
|
|||||||
# Sequence length, amount of past frames including current one to use in the temporal model
|
# Sequence length, amount of past frames including current one to use in the temporal model
|
||||||
max_seq_len: int = 16
|
max_seq_len: int = 16
|
||||||
|
|
||||||
# Head
|
|
||||||
use_tanh_head: bool = False # when True, bound outputs in [-1, 1]
|
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
learning_rate: float = 1e-3
|
learning_rate: float = 1e-3
|
||||||
weight_decay: float = 0.01
|
weight_decay: float = 0.01
|
||||||
|
|||||||
@@ -255,10 +255,10 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len)
|
frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len)
|
||||||
B, T, C, H, W = frames.shape
|
B, T, C, H, W = frames.shape
|
||||||
|
|
||||||
# Apply stride (no dropout during eval)
|
# CRITICAL FIX: Do NOT apply stride during evaluation
|
||||||
idx = torch.arange(0, T, self.stride, device=frames.device)
|
# During evaluation, we want to process all frames in the sliding window
|
||||||
frames = frames[:, idx]
|
# Stride should only be used during training to reduce computational cost
|
||||||
T_eff = frames.shape[1]
|
T_eff = T # Use all frames during evaluation
|
||||||
|
|
||||||
# Get language commands
|
# Get language commands
|
||||||
commands = batch.get(OBS_LANGUAGE, None)
|
commands = batch.get(OBS_LANGUAGE, None)
|
||||||
|
|||||||
Reference in New Issue
Block a user