Remove the `# type: ignore[call-arg]` exception from the policy.forward call

This commit is contained in:
Michel Aractingi
2026-01-14 17:09:56 +01:00
parent 94efcea867
commit 0264ac717b
+1 -1
View File
@@ -101,7 +101,7 @@ def update_policy(
if sample_weights is not None:
# Use per-sample loss for weighted training
# Note: Policies supporting sample weighting must implement forward(batch, reduction="none")
-per_sample_loss, output_dict = policy.forward(batch, reduction="none") # type: ignore[call-arg]
+per_sample_loss, output_dict = policy.forward(batch, reduction="none")
# Apply sample weights: L_weighted = Σ(w_i * l_i) / (Σw_i + ε)
# Weights are already normalized to sum to batch_size