mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
change head init
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -181,6 +181,13 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
# MSE regression head with sigmoid activation to bound outputs to [0,1]
|
# MSE regression head with sigmoid activation to bound outputs to [0,1]
|
||||||
self.reward_head = nn.Linear(config.dim_model, 1)
|
self.reward_head = nn.Linear(config.dim_model, 1)
|
||||||
|
|
||||||
|
# Initialize with small weights to prevent sigmoid saturation
|
||||||
|
# Target: sigmoid(0) = 0.5, so we want raw logits around [-2, 2] range
|
||||||
|
with torch.no_grad():
|
||||||
|
self.reward_head.weight.normal_(0.0, 0.01) # Much smaller std
|
||||||
|
self.reward_head.bias.fill_(0.0) # Start at sigmoid(0) = 0.5
|
||||||
|
|
||||||
self.sigmoid = nn.Sigmoid()
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
# Simple frame dropout probability
|
# Simple frame dropout probability
|
||||||
@@ -568,7 +575,8 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
print(f"Target range: [{target.min():.3f}, {target.max():.3f}]")
|
print(f"Target range: [{target.min():.3f}, {target.max():.3f}]")
|
||||||
# Model output statistics
|
# Model output statistics
|
||||||
print(f"Raw MLP range: [{raw_outputs.min():.3f}, {raw_outputs.max():.3f}]")
|
print(f"Raw MLP range: [{raw_outputs.min():.3f}, {raw_outputs.max():.3f}]")
|
||||||
print(f"Raw logits range: [{raw_logits.min():.3f}, {raw_logits.max():.3f}]")
|
print(f"Raw logits range: [{raw_logits.min():.6f}, {raw_logits.max():.6f}]")
|
||||||
|
print(f"Raw logits mean: {raw_logits.mean():.6f}")
|
||||||
print(f"Sigmoid pred range: [{preds.min():.3f}, {preds.max():.3f}]")
|
print(f"Sigmoid pred range: [{preds.min():.3f}, {preds.max():.3f}]")
|
||||||
print(f"Sigmoid pred mean: {preds.mean():.3f}")
|
print(f"Sigmoid pred mean: {preds.mean():.3f}")
|
||||||
print(f"Loss: {loss:.4f}")
|
print(f"Loss: {loss:.4f}")
|
||||||
|
|||||||
Reference in New Issue
Block a user