refactor(rl): make algorithm a nested config so all SAC hyperparameters are JSON-addressable

This commit is contained in:
Khalil Meftah
2026-04-27 13:39:03 +02:00
parent 21c16a27f0
commit 9ce9e01469
9 changed files with 86 additions and 56 deletions
+7 -6
View File
@@ -26,9 +26,9 @@ pytest.importorskip("grpc")
from torch.multiprocessing import Event, Queue
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from lerobot.rl.train_rl import TrainRLServerPipelineConfig
from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
from lerobot.utils.transition import Transition
from tests.utils import skip_if_package_missing
@@ -314,7 +314,7 @@ def test_learner_algorithm_wiring():
get_weights() output is serializable."""
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.rl.algorithms.sac import SACAlgorithm
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
from lerobot.transport.utils import state_to_bytes
state_dim = 10
@@ -333,7 +333,7 @@ def test_learner_algorithm_wiring():
policy = GaussianActorPolicy(config=sac_cfg)
policy.train()
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
assert isinstance(algorithm, SACAlgorithm)
optimizers = algorithm.make_optimizers_and_scheduler()
@@ -400,6 +400,7 @@ def test_initial_and_periodic_weight_push_consistency():
and produce identical structures."""
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.rl.algorithms.sac import SACAlgorithmConfig
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
state_dim = 10
@@ -416,7 +417,7 @@ def test_initial_and_periodic_weight_push_consistency():
policy = GaussianActorPolicy(config=sac_cfg)
policy.train()
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
algorithm.make_optimizers_and_scheduler()
# Simulate initial push (same code path the learner now uses)
@@ -437,7 +438,7 @@ def test_actor_side_algorithm_select_action_and_load_weights():
"""Simulate actor: create algorithm without optimizers, select_action, load_weights."""
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.rl.algorithms.sac import SACAlgorithm
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
state_dim = 10
action_dim = 6
@@ -454,7 +455,7 @@ def test_actor_side_algorithm_select_action_and_load_weights():
# Actor side: no optimizers
policy = GaussianActorPolicy(config=sac_cfg)
policy.eval()
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
assert isinstance(algorithm, SACAlgorithm)
assert algorithm.optimizers == {}
+9 -11
View File
@@ -348,7 +348,7 @@ def test_optimization_step_can_be_set_for_resume():
def test_make_algorithm_returns_sac_for_sac_policy():
sac_cfg = _make_sac_config()
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
assert isinstance(algorithm, SACAlgorithm)
assert algorithm.optimizers == {}
@@ -357,7 +357,7 @@ def test_make_optimizers_creates_expected_keys():
"""make_optimizers_and_scheduler() should populate the algorithm with Adam optimizers."""
sac_cfg = _make_sac_config()
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
optimizers = algorithm.make_optimizers_and_scheduler()
assert "actor" in optimizers
assert "critic" in optimizers
@@ -370,7 +370,7 @@ def test_actor_side_no_optimizers():
"""Actor-side usage: no optimizers needed, make_optimizers_and_scheduler is not called."""
sac_cfg = _make_sac_config()
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
assert isinstance(algorithm, SACAlgorithm)
assert algorithm.optimizers == {}
@@ -379,18 +379,16 @@ def test_make_algorithm_uses_sac_algorithm_defaults():
"""make_algorithm populates SACAlgorithmConfig with its own defaults."""
sac_cfg = _make_sac_config()
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
assert algorithm.config.utd_ratio == 1
assert algorithm.config.policy_update_freq == 1
assert algorithm.config.grad_clip_norm == 40.0
def test_make_algorithm_raises_for_unknown_type():
class FakeConfig:
type = "unknown_algo"
with pytest.raises(ValueError, match="No RLAlgorithmConfig"):
make_algorithm(policy=None, policy_cfg=FakeConfig(), algorithm_name="unknown_algo")
def test_unknown_algorithm_name_raises_in_registry():
"""The ChoiceRegistry is the source of truth for unknown algorithm names."""
with pytest.raises(KeyError):
RLAlgorithmConfig.get_choice_class("unknown_algo")
# ===========================================================================
@@ -523,5 +521,5 @@ def test_make_algorithm_uses_build_algorithm():
"""make_algorithm should delegate to config.build_algorithm (no hardcoded if/else)."""
sac_cfg = _make_sac_config()
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
assert isinstance(algorithm, SACAlgorithm)