From e7617076ca1744e90543aea17c3671c0e629bd15 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Sun, 31 Aug 2025 16:03:24 +0200 Subject: [PATCH] cleanup --- .../policies/rlearn/modeling_rlearn.py | 39 +++++++++++++++++-- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index 9e00f0cf6..ae66096fd 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -76,6 +76,7 @@ Notes from __future__ import annotations import math +import numpy as np from itertools import chain from operator import truediv @@ -560,7 +561,7 @@ class RLearNPolicy(PreTrainedPolicy): total_loss = loss + L_mismatch loss_time = time.perf_counter() - loss_start - # DEBUG: Clean logit regression monitoring + # DEBUG: Clean logit regression monitoring with full array printing if self.training and torch.rand(1).item() < 0.03: with torch.no_grad(): sample_idx = torch.randint(0, B, (1,)).item() @@ -573,9 +574,41 @@ class RLearNPolicy(PreTrainedPolicy): print(f"āœ“ Has targets >0.8: {has_high_targets} | T_eff: {T_eff}") print(f"Logits: min={raw_logits.min():.3f}, max={raw_logits.max():.3f}, mean={raw_logits.mean():.3f}") print(f"Preds: min={predicted_rewards.min():.3f}, max={predicted_rewards.max():.3f}, mean={predicted_rewards.mean():.3f}") - print(f"Sample {sample_idx} (T_eff={T_eff}): target_range=[{sample_targets.min():.3f}, {sample_targets.max():.3f}] pred_range=[{sample_preds.min():.3f}, {sample_preds.max():.3f}]") + + # Show full arrays occasionally (25% chance within debug) + show_full = torch.rand(1).item() < 0.25 + if show_full: + print(f"\nšŸ“Š FULL SAMPLE {sample_idx} ARRAYS (T_eff={T_eff}):") + # Always show full arrays up to 16 frames + if T_eff <= 16: + print(f" Targets: {sample_targets}") + print(f" Preds: {sample_preds}") + + # Show differences and error metrics + diffs = sample_preds - sample_targets + print(f" Errors: {diffs}") + mae = np.abs(diffs).mean() + mse = (diffs ** 2).mean() + max_error = np.abs(diffs).max() + print(f" MAE: {mae:.4f} | MSE: {mse:.4f} | Max Error: {max_error:.4f}") + + # Check if predictions are stuck or varying + pred_std = sample_preds.std() + target_std = sample_targets.std() + print(f" Variation - Target std: {target_std:.4f} | Pred std: {pred_std:.4f}") + if pred_std < 0.01: + print(f" āš ļø PREDICTIONS STUCK (std={pred_std:.5f})") + else: + print(f" āœ“ Predictions varying normally") + else: + # For longer sequences, show first 8 and last 8 + print(f" Targets: {sample_targets[:8]} ... {sample_targets[-8:]}") + print(f" Preds: {sample_preds[:8]} ... {sample_preds[-8:]}") + else: + print(f"Sample {sample_idx}: T_eff={T_eff}, target ∈ [{sample_targets.min():.3f}, {sample_targets.max():.3f}], pred ∈ [{sample_preds.min():.3f}, {sample_preds.max():.3f}]") + print(f"Loss: {loss:.6f}") - print("=" * 40) + print("=" * 60) total_forward_time = time.perf_counter() - forward_start