mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
fix
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -186,7 +186,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
min_value=config.reward_min_value,
|
||||
max_value=config.reward_max_value,
|
||||
num_bins=config.reward_hl_gauss_loss_num_bins,
|
||||
) if config.use_hl_gauss_loss else None
|
||||
) # Always provide config, HLGaussLayer needs it even for regression mode
|
||||
)
|
||||
|
||||
# Simple frame dropout probability
|
||||
@@ -559,10 +559,13 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# DEBUG: Print targets and predictions occasionally during training
|
||||
if self.training and torch.rand(1).item() < 0.02: # ~2% chance to debug print
|
||||
with torch.no_grad():
|
||||
# Get raw MLP outputs before HLGauss
|
||||
raw_outputs = video_frame_embeds
|
||||
preds = self.hl_gauss_layer(video_frame_embeds).squeeze(-1)
|
||||
print(f"\n=== DEBUG TRAINING ===")
|
||||
print(f"Target range: [{target.min():.3f}, {target.max():.3f}]")
|
||||
print(f"Target mean: {target.mean():.3f}")
|
||||
print(f"Raw MLP range: [{raw_outputs.min():.3f}, {raw_outputs.max():.3f}]")
|
||||
print(f"Pred range: [{preds.min():.3f}, {preds.max():.3f}]")
|
||||
print(f"Pred mean: {preds.mean():.3f}")
|
||||
print(f"Loss: {loss:.4f}")
|
||||
|
||||
Reference in New Issue
Block a user