mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-30 06:37:15 +00:00
random sample for log
This commit is contained in:
@@ -162,14 +162,17 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
self.to_lang_tokens = nn.Linear(self.text_hidden, config.dim_model)
|
||||
self.to_video_tokens = nn.Linear(self.vision_hidden, config.dim_model)
|
||||
|
||||
# Stronger temporal positional encoding to distinguish between frames
|
||||
# This helps the model learn distinct representations for each frame in the sequence
|
||||
# Stronger temporal positional encoding
|
||||
self.temporal_pos_embedding = nn.Parameter(torch.randn(config.max_seq_len, config.dim_model) * 0.1)
|
||||
|
||||
# Add frame-specific processing to prevent over-smoothing
|
||||
# CRITICAL: Frame-specific MLPs prevent temporal over-smoothing
|
||||
# Problem: Transformer attention was making all 16 predictions identical (e.g. all 0.34)
|
||||
# Solution: Each temporal position gets its own specialized MLP processing
|
||||
# Frame 0 → MLP[0], Frame 1 → MLP[1], ..., Frame 15 → MLP[15]
|
||||
# This creates distinct pathways for each frame while preserving attention context
|
||||
self.frame_specific_mlp = nn.ModuleList([
|
||||
nn.Linear(config.dim_model, config.dim_model)
|
||||
for _ in range(config.max_seq_len)
|
||||
for _ in range(config.max_seq_len) # 16 separate MLPs for 16 frame positions
|
||||
])
|
||||
|
||||
# Register / memory / attention sink tokens
|
||||
@@ -607,6 +610,9 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
raw_logits = self.reward_head(normalized_embeds).squeeze(-1)
|
||||
preds = self.sigmoid(raw_logits)
|
||||
|
||||
# Randomly sample a sequence from the batch for detailed analysis
|
||||
sample_idx = torch.randint(0, B, (1,)).item()
|
||||
|
||||
print(f"\n=== DEBUG TRAINING ===")
|
||||
# Target statistics
|
||||
print(f"Target min: {target.min():.6f}")
|
||||
@@ -621,8 +627,9 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
print(f"Sigmoid pred range: [{preds.min():.3f}, {preds.max():.3f}]")
|
||||
print(f"Sigmoid pred mean: {preds.mean():.3f}")
|
||||
print(f"Loss: {loss:.4f}")
|
||||
print("First sample targets (all 16):", target[0].cpu().numpy())
|
||||
print("First sample preds (all 16):", preds[0].cpu().numpy())
|
||||
# Show randomly sampled sequence for comparison
|
||||
print(f"Sample {sample_idx} targets (all 16):", target[sample_idx].cpu().numpy())
|
||||
print(f"Sample {sample_idx} preds (all 16): ", preds[sample_idx].cpu().numpy())
|
||||
print("="*25)
|
||||
|
||||
total_forward_time = time.perf_counter() - forward_start
|
||||
|
||||
Reference in New Issue
Block a user