mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 05:59:52 +00:00
fix
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -186,7 +186,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
min_value=config.reward_min_value,
|
min_value=config.reward_min_value,
|
||||||
max_value=config.reward_max_value,
|
max_value=config.reward_max_value,
|
||||||
num_bins=config.reward_hl_gauss_loss_num_bins,
|
num_bins=config.reward_hl_gauss_loss_num_bins,
|
||||||
) if config.use_hl_gauss_loss else None
|
) # Always provide config, HLGaussLayer needs it even for regression mode
|
||||||
)
|
)
|
||||||
|
|
||||||
# Simple frame dropout probability
|
# Simple frame dropout probability
|
||||||
@@ -559,10 +559,13 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
# DEBUG: Print targets and predictions occasionally during training
|
# DEBUG: Print targets and predictions occasionally during training
|
||||||
if self.training and torch.rand(1).item() < 0.02: # ~2% chance to debug print
|
if self.training and torch.rand(1).item() < 0.02: # ~2% chance to debug print
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
# Get raw MLP outputs before HLGauss
|
||||||
|
raw_outputs = video_frame_embeds
|
||||||
preds = self.hl_gauss_layer(video_frame_embeds).squeeze(-1)
|
preds = self.hl_gauss_layer(video_frame_embeds).squeeze(-1)
|
||||||
print(f"\n=== DEBUG TRAINING ===")
|
print(f"\n=== DEBUG TRAINING ===")
|
||||||
print(f"Target range: [{target.min():.3f}, {target.max():.3f}]")
|
print(f"Target range: [{target.min():.3f}, {target.max():.3f}]")
|
||||||
print(f"Target mean: {target.mean():.3f}")
|
print(f"Target mean: {target.mean():.3f}")
|
||||||
|
print(f"Raw MLP range: [{raw_outputs.min():.3f}, {raw_outputs.max():.3f}]")
|
||||||
print(f"Pred range: [{preds.min():.3f}, {preds.max():.3f}]")
|
print(f"Pred range: [{preds.min():.3f}, {preds.max():.3f}]")
|
||||||
print(f"Pred mean: {preds.mean():.3f}")
|
print(f"Pred mean: {preds.mean():.3f}")
|
||||||
print(f"Loss: {loss:.4f}")
|
print(f"Loss: {loss:.4f}")
|
||||||
|
|||||||
Reference in New Issue
Block a user