This commit is contained in:
Pepijn
2025-08-31 16:03:24 +02:00
parent 221e5862ea
commit e7617076ca
+36 -3
View File
@@ -76,6 +76,7 @@ Notes
from __future__ import annotations
import math
import numpy as np
from itertools import chain
from operator import truediv
@@ -560,7 +561,7 @@ class RLearNPolicy(PreTrainedPolicy):
total_loss = loss + L_mismatch
loss_time = time.perf_counter() - loss_start
# DEBUG: Clean logit regression monitoring
# DEBUG: Clean logit regression monitoring with full array printing
if self.training and torch.rand(1).item() < 0.03:
with torch.no_grad():
sample_idx = torch.randint(0, B, (1,)).item()
@@ -573,9 +574,41 @@ class RLearNPolicy(PreTrainedPolicy):
print(f"✓ Has targets >0.8: {has_high_targets} | T_eff: {T_eff}")
print(f"Logits: min={raw_logits.min():.3f}, max={raw_logits.max():.3f}, mean={raw_logits.mean():.3f}")
print(f"Preds: min={predicted_rewards.min():.3f}, max={predicted_rewards.max():.3f}, mean={predicted_rewards.mean():.3f}")
print(f"Sample {sample_idx} (T_eff={T_eff}): target_range=[{sample_targets.min():.3f}, {sample_targets.max():.3f}] pred_range=[{sample_preds.min():.3f}, {sample_preds.max():.3f}]")
# Show full arrays occasionally (25% chance within debug)
show_full = torch.rand(1).item() < 0.25
if show_full:
print(f"\n📊 FULL SAMPLE {sample_idx} ARRAYS (T_eff={T_eff}):")
# Always show full arrays up to 16 frames
if T_eff <= 16:
print(f" Targets: {sample_targets}")
print(f" Preds: {sample_preds}")
# Show differences and error metrics
diffs = sample_preds - sample_targets
print(f" Errors: {diffs}")
mae = np.abs(diffs).mean()
mse = (diffs ** 2).mean()
max_error = np.abs(diffs).max()
print(f" MAE: {mae:.4f} | MSE: {mse:.4f} | Max Error: {max_error:.4f}")
# Check if predictions are stuck or varying
pred_std = sample_preds.std()
target_std = sample_targets.std()
print(f" Variation - Target std: {target_std:.4f} | Pred std: {pred_std:.4f}")
if pred_std < 0.01:
print(f" ⚠️ PREDICTIONS STUCK (std={pred_std:.5f})")
else:
print(f" ✓ Predictions varying normally")
else:
# For longer sequences, show first 8 and last 8
print(f" Targets: {sample_targets[:8]} ... {sample_targets[-8:]}")
print(f" Preds: {sample_preds[:8]} ... {sample_preds[-8:]}")
else:
print(f"Sample {sample_idx}: T_eff={T_eff}, target ∈ [{sample_targets.min():.3f}, {sample_targets.max():.3f}], pred ∈ [{sample_preds.min():.3f}, {sample_preds.max():.3f}]")
print(f"Loss: {loss:.6f}")
print("=" * 40)
print("=" * 60)
total_forward_time = time.perf_counter() - forward_start