Mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-16 17:20:05 +00:00)
refactor(rl): make algorithm a nested config so all SAC hyperparameters are JSON-addressable
This commit is contained in:
@@ -26,9 +26,9 @@ pytest.importorskip("grpc")
|
||||
|
||||
from torch.multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from lerobot.rl.train_rl import TrainRLServerPipelineConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.transition import Transition
|
||||
from tests.utils import skip_if_package_missing
|
||||
@@ -314,7 +314,7 @@ def test_learner_algorithm_wiring():
|
||||
get_weights() output is serializable."""
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
from lerobot.rl.algorithms.factory import make_algorithm
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
from lerobot.transport.utils import state_to_bytes
|
||||
|
||||
state_dim = 10
|
||||
@@ -333,7 +333,7 @@ def test_learner_algorithm_wiring():
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
policy.train()
|
||||
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
|
||||
optimizers = algorithm.make_optimizers_and_scheduler()
|
||||
@@ -400,6 +400,7 @@ def test_initial_and_periodic_weight_push_consistency():
|
||||
and produce identical structures."""
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
from lerobot.rl.algorithms.factory import make_algorithm
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithmConfig
|
||||
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
|
||||
|
||||
state_dim = 10
|
||||
@@ -416,7 +417,7 @@ def test_initial_and_periodic_weight_push_consistency():
|
||||
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
policy.train()
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||
algorithm.make_optimizers_and_scheduler()
|
||||
|
||||
# Simulate initial push (same code path the learner now uses)
|
||||
@@ -437,7 +438,7 @@ def test_actor_side_algorithm_select_action_and_load_weights():
|
||||
"""Simulate actor: create algorithm without optimizers, select_action, load_weights."""
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
from lerobot.rl.algorithms.factory import make_algorithm
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 6
|
||||
@@ -454,7 +455,7 @@ def test_actor_side_algorithm_select_action_and_load_weights():
|
||||
# Actor side: no optimizers
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
policy.eval()
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
assert algorithm.optimizers == {}
|
||||
|
||||
|
||||
@@ -348,7 +348,7 @@ def test_optimization_step_can_be_set_for_resume():
|
||||
def test_make_algorithm_returns_sac_for_sac_policy():
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
assert algorithm.optimizers == {}
|
||||
|
||||
@@ -357,7 +357,7 @@ def test_make_optimizers_creates_expected_keys():
|
||||
"""make_optimizers_and_scheduler() should populate the algorithm with Adam optimizers."""
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||
optimizers = algorithm.make_optimizers_and_scheduler()
|
||||
assert "actor" in optimizers
|
||||
assert "critic" in optimizers
|
||||
@@ -370,7 +370,7 @@ def test_actor_side_no_optimizers():
|
||||
"""Actor-side usage: no optimizers needed, make_optimizers_and_scheduler is not called."""
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
assert algorithm.optimizers == {}
|
||||
|
||||
@@ -379,18 +379,16 @@ def test_make_algorithm_uses_sac_algorithm_defaults():
|
||||
"""make_algorithm populates SACAlgorithmConfig with its own defaults."""
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||
assert algorithm.config.utd_ratio == 1
|
||||
assert algorithm.config.policy_update_freq == 1
|
||||
assert algorithm.config.grad_clip_norm == 40.0
|
||||
|
||||
|
||||
def test_make_algorithm_raises_for_unknown_type():
|
||||
class FakeConfig:
|
||||
type = "unknown_algo"
|
||||
|
||||
with pytest.raises(ValueError, match="No RLAlgorithmConfig"):
|
||||
make_algorithm(policy=None, policy_cfg=FakeConfig(), algorithm_name="unknown_algo")
|
||||
def test_unknown_algorithm_name_raises_in_registry():
|
||||
"""The ChoiceRegistry is the source of truth for unknown algorithm names."""
|
||||
with pytest.raises(KeyError):
|
||||
RLAlgorithmConfig.get_choice_class("unknown_algo")
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
@@ -523,5 +521,5 @@ def test_make_algorithm_uses_build_algorithm():
|
||||
"""make_algorithm should delegate to config.build_algorithm (no hardcoded if/else)."""
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
|
||||
Reference in New Issue
Block a user