mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
update losses names in tests
This commit is contained in:
@@ -356,7 +356,7 @@ def test_learner_algorithm_wiring():
|
|||||||
}
|
}
|
||||||
|
|
||||||
stats = algorithm.update(batch_iterator())
|
stats = algorithm.update(batch_iterator())
|
||||||
assert "critic" in stats.losses
|
assert "loss_critic" in stats.losses
|
||||||
|
|
||||||
# get_weights -> state_to_bytes round-trip
|
# get_weights -> state_to_bytes round-trip
|
||||||
weights = algorithm.get_weights()
|
weights = algorithm.get_weights()
|
||||||
@@ -393,7 +393,7 @@ def test_learner_algorithm_wiring():
|
|||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
)
|
)
|
||||||
trainer_stats = trainer.training_step()
|
trainer_stats = trainer.training_step()
|
||||||
assert "critic" in trainer_stats.losses
|
assert "loss_critic" in trainer_stats.losses
|
||||||
|
|
||||||
|
|
||||||
def test_initial_and_periodic_weight_push_consistency():
|
def test_initial_and_periodic_weight_push_consistency():
|
||||||
|
|||||||
@@ -223,16 +223,16 @@ def test_update_returns_training_stats():
|
|||||||
algorithm, _ = _make_algorithm()
|
algorithm, _ = _make_algorithm()
|
||||||
stats = algorithm.update(_batch_iterator())
|
stats = algorithm.update(_batch_iterator())
|
||||||
assert isinstance(stats, TrainingStats)
|
assert isinstance(stats, TrainingStats)
|
||||||
assert "critic" in stats.losses
|
assert "loss_critic" in stats.losses
|
||||||
assert isinstance(stats.losses["critic"], float)
|
assert isinstance(stats.losses["loss_critic"], float)
|
||||||
|
|
||||||
|
|
||||||
def test_update_populates_actor_and_temperature_losses():
|
def test_update_populates_actor_and_temperature_losses():
|
||||||
"""With policy_update_freq=1 and step 0, actor/temperature should be updated."""
|
"""With policy_update_freq=1 and step 0, actor/temperature should be updated."""
|
||||||
algorithm, _ = _make_algorithm(policy_update_freq=1)
|
algorithm, _ = _make_algorithm(policy_update_freq=1)
|
||||||
stats = algorithm.update(_batch_iterator())
|
stats = algorithm.update(_batch_iterator())
|
||||||
assert "actor" in stats.losses
|
assert "loss_actor" in stats.losses
|
||||||
assert "temperature" in stats.losses
|
assert "loss_temperature" in stats.losses
|
||||||
assert "temperature" in stats.extra
|
assert "temperature" in stats.extra
|
||||||
|
|
||||||
|
|
||||||
@@ -244,11 +244,11 @@ def test_update_skips_actor_at_non_update_steps(policy_update_freq):
|
|||||||
|
|
||||||
# Step 0: should update actor
|
# Step 0: should update actor
|
||||||
stats_0 = algorithm.update(it)
|
stats_0 = algorithm.update(it)
|
||||||
assert "actor" in stats_0.losses
|
assert "loss_actor" in stats_0.losses
|
||||||
|
|
||||||
# Step 1: should NOT update actor
|
# Step 1: should NOT update actor
|
||||||
stats_1 = algorithm.update(it)
|
stats_1 = algorithm.update(it)
|
||||||
assert "actor" not in stats_1.losses
|
assert "loss_actor" not in stats_1.losses
|
||||||
|
|
||||||
|
|
||||||
def test_update_increments_optimization_step():
|
def test_update_increments_optimization_step():
|
||||||
@@ -264,7 +264,7 @@ def test_update_increments_optimization_step():
|
|||||||
def test_update_with_discrete_critic():
|
def test_update_with_discrete_critic():
|
||||||
algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
|
algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
|
||||||
stats = algorithm.update(_batch_iterator(action_dim=7)) # continuous + 1 discrete
|
stats = algorithm.update(_batch_iterator(action_dim=7)) # continuous + 1 discrete
|
||||||
assert "discrete_critic" in stats.losses
|
assert "loss_discrete_critic" in stats.losses
|
||||||
assert "discrete_critic" in stats.grad_norms
|
assert "discrete_critic" in stats.grad_norms
|
||||||
|
|
||||||
|
|
||||||
@@ -278,7 +278,7 @@ def test_update_with_utd_ratio(utd_ratio):
|
|||||||
algorithm, _ = _make_algorithm(utd_ratio=utd_ratio)
|
algorithm, _ = _make_algorithm(utd_ratio=utd_ratio)
|
||||||
stats = algorithm.update(_batch_iterator())
|
stats = algorithm.update(_batch_iterator())
|
||||||
assert isinstance(stats, TrainingStats)
|
assert isinstance(stats, TrainingStats)
|
||||||
assert "critic" in stats.losses
|
assert "loss_critic" in stats.losses
|
||||||
assert algorithm.optimization_step == 1
|
assert algorithm.optimization_step == 1
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user