mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
add stride
This commit is contained in:
@@ -55,6 +55,8 @@ class RLearNConfig(PreTrainedConfig):
|
||||
|
||||
# Sequence length: number of frames (past frames plus the current one) fed to the temporal model
|
||||
max_seq_len: int = 16
|
||||
# Temporal sampling stride (2 = skip every other frame for wider temporal coverage)
|
||||
temporal_sampling_stride: int = 2
|
||||
|
||||
# Training
|
||||
learning_rate: float = 1e-3
|
||||
|
||||
@@ -579,6 +579,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
# Only apply mismatch loss if we have actual mismatches
|
||||
if any(mismatch_mask):
|
||||
print("Applying mismatch loss!!!")
|
||||
# Re-encode with mismatched language
|
||||
lang_embeds_mm, mask_mm = self._encode_language_tokens(shuffled_commands, device)
|
||||
lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm)
|
||||
@@ -662,10 +663,6 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
pred_std = sample_preds.std()
|
||||
target_std = sample_targets.std()
|
||||
print(f" Variation - Target std: {target_std:.4f} | Pred std: {pred_std:.4f}")
|
||||
if pred_std < 0.01:
|
||||
print(f" ⚠️ PREDICTIONS STUCK (std={pred_std:.5f})")
|
||||
else:
|
||||
print(f" ✓ Predictions varying normally")
|
||||
else:
|
||||
# For longer sequences, show first 8 and last 8
|
||||
print(f" Targets: {sample_targets[:8]} ... {sample_targets[-8:]}")
|
||||
@@ -897,17 +894,21 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
ep_length = ep_end - ep_start
|
||||
episode_lengths.append(ep_length)
|
||||
|
||||
# Choose random anchor - need at least T-1 frames before for [-15..0] window
|
||||
min_anchor = T - 1
|
||||
# Choose a random anchor; it needs enough preceding frames to allow strided sampling
|
||||
# For T=16 and stride=2, we need frames [anchor-30, anchor-28, ..., anchor-2, anchor]
|
||||
stride = self.config.temporal_sampling_stride
|
||||
min_anchor = (T - 1) * stride # Need (T-1)*stride frames before anchor
|
||||
max_anchor = max(min_anchor, ep_length - 1)
|
||||
anchor = torch.randint(min_anchor, max_anchor + 1, (1,)).item()
|
||||
anchor_positions.append(anchor)
|
||||
|
||||
# Build window indices with reflection padding
|
||||
# Build window indices with configurable stride sampling and reflection padding
|
||||
window_indices = []
|
||||
frame_indices_for_progress = [] # Track actual frame positions for progress
|
||||
had_oob = False
|
||||
for delta in range(-(T-1), 1): # [-15, -14, ..., 0] for T=16
|
||||
# Sample with stride: [anchor-(T-1)*stride, anchor-(T-2)*stride, ..., anchor-stride, anchor]
|
||||
for i in range(T):
|
||||
delta = -(T - 1 - i) * stride # Work backwards from anchor with stride spacing
|
||||
idx = anchor + delta
|
||||
actual_frame_idx = idx # Store the actual frame index before reflection
|
||||
if idx < 0:
|
||||
|
||||
Reference in New Issue
Block a user