diff --git a/tests/rl/test_actor_learner.py b/tests/rl/test_actor_learner.py
index 95a865594..59a0b0f7a 100644
--- a/tests/rl/test_actor_learner.py
+++ b/tests/rl/test_actor_learner.py
@@ -356,7 +356,7 @@ def test_learner_algorithm_wiring():
             }
 
     stats = algorithm.update(batch_iterator())
-    assert "critic" in stats.losses
+    assert "loss_critic" in stats.losses
 
     # get_weights -> state_to_bytes round-trip
     weights = algorithm.get_weights()
@@ -393,7 +393,7 @@ def test_learner_algorithm_wiring():
         batch_size=batch_size,
     )
     trainer_stats = trainer.training_step()
-    assert "critic" in trainer_stats.losses
+    assert "loss_critic" in trainer_stats.losses
 
 
 def test_initial_and_periodic_weight_push_consistency():
diff --git a/tests/rl/test_sac_algorithm.py b/tests/rl/test_sac_algorithm.py
index 79019777f..df69b7312 100644
--- a/tests/rl/test_sac_algorithm.py
+++ b/tests/rl/test_sac_algorithm.py
@@ -223,16 +223,16 @@ def test_update_returns_training_stats():
     algorithm, _ = _make_algorithm()
     stats = algorithm.update(_batch_iterator())
     assert isinstance(stats, TrainingStats)
-    assert "critic" in stats.losses
-    assert isinstance(stats.losses["critic"], float)
+    assert "loss_critic" in stats.losses
+    assert isinstance(stats.losses["loss_critic"], float)
 
 
 def test_update_populates_actor_and_temperature_losses():
     """With policy_update_freq=1 and step 0, actor/temperature should be updated."""
     algorithm, _ = _make_algorithm(policy_update_freq=1)
     stats = algorithm.update(_batch_iterator())
-    assert "actor" in stats.losses
-    assert "temperature" in stats.losses
+    assert "loss_actor" in stats.losses
+    assert "loss_temperature" in stats.losses
     assert "temperature" in stats.extra
 
 
@@ -244,11 +244,11 @@ def test_update_skips_actor_at_non_update_steps(policy_update_freq):
 
     # Step 0: should update actor
     stats_0 = algorithm.update(it)
-    assert "actor" in stats_0.losses
+    assert "loss_actor" in stats_0.losses
 
     # Step 1: should NOT update actor
     stats_1 = algorithm.update(it)
-    assert "actor" not in stats_1.losses
+    assert "loss_actor" not in stats_1.losses
 
 
 def test_update_increments_optimization_step():
@@ -264,7 +264,7 @@ def test_update_increments_optimization_step():
 def test_update_with_discrete_critic():
     algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
     stats = algorithm.update(_batch_iterator(action_dim=7))  # continuous + 1 discrete
-    assert "discrete_critic" in stats.losses
+    assert "loss_discrete_critic" in stats.losses
     assert "discrete_critic" in stats.grad_norms
 
 
@@ -278,7 +278,7 @@ def test_update_with_utd_ratio(utd_ratio):
     algorithm, _ = _make_algorithm(utd_ratio=utd_ratio)
     stats = algorithm.update(_batch_iterator())
     assert isinstance(stats, TrainingStats)
-    assert "critic" in stats.losses
+    assert "loss_critic" in stats.losses
     assert algorithm.optimization_step == 1
 
 