diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index 556747eb8..8b1bf97ab 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -162,14 +162,17 @@ class RLearNPolicy(PreTrainedPolicy): self.to_lang_tokens = nn.Linear(self.text_hidden, config.dim_model) self.to_video_tokens = nn.Linear(self.vision_hidden, config.dim_model) - # Stronger temporal positional encoding to distinguish between frames - # This helps the model learn distinct representations for each frame in the sequence + # Stronger temporal positional encoding self.temporal_pos_embedding = nn.Parameter(torch.randn(config.max_seq_len, config.dim_model) * 0.1) - # Add frame-specific processing to prevent over-smoothing + # CRITICAL: Frame-specific MLPs prevent temporal over-smoothing + # Problem: Transformer attention was making all 16 predictions identical (e.g. all 0.34) + # Solution: Each temporal position gets its own specialized MLP processing + # Frame 0 → MLP[0], Frame 1 → MLP[1], ..., Frame 15 → MLP[15] + # This creates distinct pathways for each frame while preserving attention context self.frame_specific_mlp = nn.ModuleList([ nn.Linear(config.dim_model, config.dim_model) - for _ in range(config.max_seq_len) + for _ in range(config.max_seq_len) # 16 separate MLPs for 16 frame positions ]) # Register / memory / attention sink tokens @@ -607,6 +610,9 @@ class RLearNPolicy(PreTrainedPolicy): raw_logits = self.reward_head(normalized_embeds).squeeze(-1) preds = self.sigmoid(raw_logits) + # Randomly sample a sequence from the batch for detailed analysis + sample_idx = torch.randint(0, B, (1,)).item() + print(f"\n=== DEBUG TRAINING ===") # Target statistics print(f"Target min: {target.min():.6f}") @@ -621,8 +627,9 @@ class RLearNPolicy(PreTrainedPolicy): print(f"Sigmoid pred range: [{preds.min():.3f}, {preds.max():.3f}]") print(f"Sigmoid pred mean: {preds.mean():.3f}") print(f"Loss: {loss:.4f}") - print("First sample targets (all 16):", target[0].cpu().numpy()) - print("First sample preds (all 16):", preds[0].cpu().numpy()) + # Show randomly sampled sequence for comparison + print(f"Sample {sample_idx} targets (all 16):", target[sample_idx].cpu().numpy()) + print(f"Sample {sample_idx} preds (all 16): ", preds[sample_idx].cpu().numpy()) print("="*25) total_forward_time = time.perf_counter() - forward_start