Update loss names in tests

This commit is contained in:
Khalil Meftah
2026-04-21 11:53:32 +02:00
parent a84b0e8132
commit a4c0c9e358
2 changed files with 10 additions and 10 deletions
+8 -8
View File
@@ -223,16 +223,16 @@ def test_update_returns_training_stats():
algorithm, _ = _make_algorithm()
stats = algorithm.update(_batch_iterator())
assert isinstance(stats, TrainingStats)
assert "critic" in stats.losses
assert isinstance(stats.losses["critic"], float)
assert "loss_critic" in stats.losses
assert isinstance(stats.losses["loss_critic"], float)
def test_update_populates_actor_and_temperature_losses():
"""With policy_update_freq=1 and step 0, actor/temperature should be updated."""
algorithm, _ = _make_algorithm(policy_update_freq=1)
stats = algorithm.update(_batch_iterator())
assert "actor" in stats.losses
assert "temperature" in stats.losses
assert "loss_actor" in stats.losses
assert "loss_temperature" in stats.losses
assert "temperature" in stats.extra
@@ -244,11 +244,11 @@ def test_update_skips_actor_at_non_update_steps(policy_update_freq):
# Step 0: should update actor
stats_0 = algorithm.update(it)
assert "actor" in stats_0.losses
assert "loss_actor" in stats_0.losses
# Step 1: should NOT update actor
stats_1 = algorithm.update(it)
assert "actor" not in stats_1.losses
assert "loss_actor" not in stats_1.losses
def test_update_increments_optimization_step():
@@ -264,7 +264,7 @@ def test_update_increments_optimization_step():
def test_update_with_discrete_critic():
algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
stats = algorithm.update(_batch_iterator(action_dim=7)) # continuous + 1 discrete
assert "discrete_critic" in stats.losses
assert "loss_discrete_critic" in stats.losses
assert "discrete_critic" in stats.grad_norms
@@ -278,7 +278,7 @@ def test_update_with_utd_ratio(utd_ratio):
algorithm, _ = _make_algorithm(utd_ratio=utd_ratio)
stats = algorithm.update(_batch_iterator())
assert isinstance(stats, TrainingStats)
assert "critic" in stats.losses
assert "loss_critic" in stats.losses
assert algorithm.optimization_step == 1