mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-03 16:17:15 +00:00
cleanup
This commit is contained in:
@@ -76,6 +76,7 @@ Notes
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
from itertools import chain
|
||||
from operator import truediv
|
||||
|
||||
@@ -560,7 +561,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
total_loss = loss + L_mismatch
|
||||
loss_time = time.perf_counter() - loss_start
|
||||
|
||||
# DEBUG: Clean logit regression monitoring
|
||||
# DEBUG: Clean logit regression monitoring with full array printing
|
||||
if self.training and torch.rand(1).item() < 0.03:
|
||||
with torch.no_grad():
|
||||
sample_idx = torch.randint(0, B, (1,)).item()
|
||||
@@ -573,9 +574,41 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
print(f"✓ Has targets >0.8: {has_high_targets} | T_eff: {T_eff}")
|
||||
print(f"Logits: min={raw_logits.min():.3f}, max={raw_logits.max():.3f}, mean={raw_logits.mean():.3f}")
|
||||
print(f"Preds: min={predicted_rewards.min():.3f}, max={predicted_rewards.max():.3f}, mean={predicted_rewards.mean():.3f}")
|
||||
print(f"Sample {sample_idx} (T_eff={T_eff}): target_range=[{sample_targets.min():.3f}, {sample_targets.max():.3f}] pred_range=[{sample_preds.min():.3f}, {sample_preds.max():.3f}]")
|
||||
|
||||
# Show full arrays occasionally (25% chance within debug)
|
||||
show_full = torch.rand(1).item() < 0.25
|
||||
if show_full:
|
||||
print(f"\n📊 FULL SAMPLE {sample_idx} ARRAYS (T_eff={T_eff}):")
|
||||
# Always show full arrays up to 16 frames
|
||||
if T_eff <= 16:
|
||||
print(f" Targets: {sample_targets}")
|
||||
print(f" Preds: {sample_preds}")
|
||||
|
||||
# Show differences and error metrics
|
||||
diffs = sample_preds - sample_targets
|
||||
print(f" Errors: {diffs}")
|
||||
mae = np.abs(diffs).mean()
|
||||
mse = (diffs ** 2).mean()
|
||||
max_error = np.abs(diffs).max()
|
||||
print(f" MAE: {mae:.4f} | MSE: {mse:.4f} | Max Error: {max_error:.4f}")
|
||||
|
||||
# Check if predictions are stuck or varying
|
||||
pred_std = sample_preds.std()
|
||||
target_std = sample_targets.std()
|
||||
print(f" Variation - Target std: {target_std:.4f} | Pred std: {pred_std:.4f}")
|
||||
if pred_std < 0.01:
|
||||
print(f" ⚠️ PREDICTIONS STUCK (std={pred_std:.5f})")
|
||||
else:
|
||||
print(f" ✓ Predictions varying normally")
|
||||
else:
|
||||
# For longer sequences, show first 8 and last 8
|
||||
print(f" Targets: {sample_targets[:8]} ... {sample_targets[-8:]}")
|
||||
print(f" Preds: {sample_preds[:8]} ... {sample_preds[-8:]}")
|
||||
else:
|
||||
print(f"Sample {sample_idx}: T_eff={T_eff}, target ∈ [{sample_targets.min():.3f}, {sample_targets.max():.3f}], pred ∈ [{sample_preds.min():.3f}, {sample_preds.max():.3f}]")
|
||||
|
||||
print(f"Loss: {loss:.6f}")
|
||||
print("=" * 40)
|
||||
print("=" * 60)
|
||||
|
||||
total_forward_time = time.perf_counter() - forward_start
|
||||
|
||||
|
||||
Reference in New Issue
Block a user