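Log a randomly sampled batch sequence in the training debug output (instead of always the first sample) and expand the frame-specific layer comments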

This commit is contained in:
Pepijn
2025-08-31 01:33:58 +02:00
parent 1f38712c95
commit 852713dc84
+13 -6
View File
@@ -162,14 +162,17 @@ class RLearNPolicy(PreTrainedPolicy):
self.to_lang_tokens = nn.Linear(self.text_hidden, config.dim_model)
self.to_video_tokens = nn.Linear(self.vision_hidden, config.dim_model)
# Stronger temporal positional encoding to distinguish between frames
# This helps the model learn distinct representations for each frame in the sequence
# Stronger temporal positional encoding
self.temporal_pos_embedding = nn.Parameter(torch.randn(config.max_seq_len, config.dim_model) * 0.1)
# Add frame-specific processing to prevent over-smoothing
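# CRITICAL: Frame-specific linear layers prevent temporal over-smoothing
# Problem: Transformer attention was making all per-frame predictions identical (e.g. all 16 predictions at 0.34)
# Solution: Each temporal position gets its own specialized nn.Linear layer (stored in `frame_specific_mlp`; each entry is a single linear layer, not a multi-layer MLP)
# Frame 0 → layer[0], Frame 1 → layer[1], ..., Frame max_seq_len-1 → layer[max_seq_len-1] (Frame 15 → layer[15] when max_seq_len is 16)
# This creates distinct pathways for each frame while preserving attention context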
self.frame_specific_mlp = nn.ModuleList([
nn.Linear(config.dim_model, config.dim_model)
for _ in range(config.max_seq_len)
for _ in range(config.max_seq_len) # one separate nn.Linear per frame position (max_seq_len in total, e.g. 16)
])
# Register / memory / attention sink tokens
@@ -607,6 +610,9 @@ class RLearNPolicy(PreTrainedPolicy):
raw_logits = self.reward_head(normalized_embeds).squeeze(-1)
preds = self.sigmoid(raw_logits)
# Randomly sample a sequence from the batch for detailed analysis
sample_idx = torch.randint(0, B, (1,)).item()
print(f"\n=== DEBUG TRAINING ===")
# Target statistics
print(f"Target min: {target.min():.6f}")
@@ -621,8 +627,9 @@ class RLearNPolicy(PreTrainedPolicy):
print(f"Sigmoid pred range: [{preds.min():.3f}, {preds.max():.3f}]")
print(f"Sigmoid pred mean: {preds.mean():.3f}")
print(f"Loss: {loss:.4f}")
print("First sample targets (all 16):", target[0].cpu().numpy())
print("First sample preds (all 16):", preds[0].cpu().numpy())
# Show randomly sampled sequence for comparison
print(f"Sample {sample_idx} targets (all 16):", target[sample_idx].cpu().numpy())
print(f"Sample {sample_idx} preds (all 16): ", preds[sample_idx].cpu().numpy())
print("="*25)
total_forward_time = time.perf_counter() - forward_start