mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-14 16:19:45 +00:00
refactor(rl): make algorithm a nested config so all SAC hyperparameters are JSON-addressable
This commit is contained in:
@@ -348,7 +348,7 @@ def test_optimization_step_can_be_set_for_resume():
|
||||
def test_make_algorithm_returns_sac_for_sac_policy():
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
assert algorithm.optimizers == {}
|
||||
|
||||
@@ -357,7 +357,7 @@ def test_make_optimizers_creates_expected_keys():
|
||||
"""make_optimizers_and_scheduler() should populate the algorithm with Adam optimizers."""
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||
optimizers = algorithm.make_optimizers_and_scheduler()
|
||||
assert "actor" in optimizers
|
||||
assert "critic" in optimizers
|
||||
@@ -370,7 +370,7 @@ def test_actor_side_no_optimizers():
|
||||
"""Actor-side usage: no optimizers needed, make_optimizers_and_scheduler is not called."""
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
assert algorithm.optimizers == {}
|
||||
|
||||
@@ -379,18 +379,16 @@ def test_make_algorithm_uses_sac_algorithm_defaults():
|
||||
"""make_algorithm populates SACAlgorithmConfig with its own defaults."""
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||
assert algorithm.config.utd_ratio == 1
|
||||
assert algorithm.config.policy_update_freq == 1
|
||||
assert algorithm.config.grad_clip_norm == 40.0
|
||||
|
||||
|
||||
def test_make_algorithm_raises_for_unknown_type():
|
||||
class FakeConfig:
|
||||
type = "unknown_algo"
|
||||
|
||||
with pytest.raises(ValueError, match="No RLAlgorithmConfig"):
|
||||
make_algorithm(policy=None, policy_cfg=FakeConfig(), algorithm_name="unknown_algo")
|
||||
def test_unknown_algorithm_name_raises_in_registry():
|
||||
"""The ChoiceRegistry is the source of truth for unknown algorithm names."""
|
||||
with pytest.raises(KeyError):
|
||||
RLAlgorithmConfig.get_choice_class("unknown_algo")
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
@@ -523,5 +521,5 @@ def test_make_algorithm_uses_build_algorithm():
|
||||
"""make_algorithm should delegate to config.build_algorithm (no hardcoded if/else)."""
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
|
||||
Reference in New Issue
Block a user