change head init

This commit is contained in:
Pepijn
2025-08-31 01:02:25 +02:00
parent 7739fe12e4
commit a1b1643ff6
2 changed files with 53 additions and 58 deletions
File diff suppressed because one or more lines are too long
@@ -181,6 +181,13 @@ class RLearNPolicy(PreTrainedPolicy):
# MSE regression head with sigmoid activation to bound outputs to [0,1] # MSE regression head with sigmoid activation to bound outputs to [0,1]
self.reward_head = nn.Linear(config.dim_model, 1) self.reward_head = nn.Linear(config.dim_model, 1)
# Initialize with small weights to prevent sigmoid saturation
# Target: sigmoid(0) = 0.5, so we want raw logits around [-2, 2] range
with torch.no_grad():
self.reward_head.weight.normal_(0.0, 0.01) # Much smaller std
self.reward_head.bias.fill_(0.0) # Start at sigmoid(0) = 0.5
self.sigmoid = nn.Sigmoid() self.sigmoid = nn.Sigmoid()
# Simple frame dropout probability # Simple frame dropout probability
@@ -568,7 +575,8 @@ class RLearNPolicy(PreTrainedPolicy):
print(f"Target range: [{target.min():.3f}, {target.max():.3f}]") print(f"Target range: [{target.min():.3f}, {target.max():.3f}]")
# Model output statistics # Model output statistics
print(f"Raw MLP range: [{raw_outputs.min():.3f}, {raw_outputs.max():.3f}]") print(f"Raw MLP range: [{raw_outputs.min():.3f}, {raw_outputs.max():.3f}]")
print(f"Raw logits range: [{raw_logits.min():.3f}, {raw_logits.max():.3f}]") print(f"Raw logits range: [{raw_logits.min():.6f}, {raw_logits.max():.6f}]")
print(f"Raw logits mean: {raw_logits.mean():.6f}")
print(f"Sigmoid pred range: [{preds.min():.3f}, {preds.max():.3f}]") print(f"Sigmoid pred range: [{preds.min():.3f}, {preds.max():.3f}]")
print(f"Sigmoid pred mean: {preds.mean():.3f}") print(f"Sigmoid pred mean: {preds.mean():.3f}")
print(f"Loss: {loss:.4f}") print(f"Loss: {loss:.4f}")