From 9ce9e0146933fa3c5c6164f490ba1f8bf910cbfd Mon Sep 17 00:00:00 2001
From: Khalil Meftah
Date: Mon, 27 Apr 2026 13:39:03 +0200
Subject: [PATCH] refactor(rl): make algorithm a nested config so all SAC hyperparameters are JSON-addressable
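
Previously TrainRLServerPipelineConfig.algorithm was a plain string and
the factory rebuilt a SACAlgorithmConfig from hardcoded defaults, so
individual hyperparameters could not be overridden from a config file.
The field is now a nested RLAlgorithmConfig (a draccus ChoiceRegistry
subclass selected by "type"), make_algorithm() delegates to
cfg.build_algorithm(policy), and the pipeline config moves to
lerobot/rl/train_rl.py, next to the algorithm configs it now imports.
validate() injects the top-level policy into algorithm.policy_config,
so config files never need to set it by hand.

Every SAC hyperparameter is now reachable from the training JSON
(values below are illustrative, not new defaults):

    "algorithm": {
        "type": "sac",
        "critic_lr": 1e-4,
        "utd_ratio": 2,
        "policy_update_freq": 2
    }

or from draccus dotted CLI overrides such as --algorithm.critic_lr=1e-4.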
Known: {list(known)}") - - config_cls = RLAlgorithmConfig.get_choice_class(algorithm_name) - algo_config = config_cls.from_policy_config(policy_cfg) - return algo_config.build_algorithm(policy) +def make_algorithm(cfg: RLAlgorithmConfig, policy: torch.nn.Module) -> RLAlgorithm: + return cfg.build_algorithm(policy) diff --git a/src/lerobot/rl/algorithms/sac/configuration_sac.py b/src/lerobot/rl/algorithms/sac/configuration_sac.py index cbce441d4..0ccb740f3 100644 --- a/src/lerobot/rl/algorithms/sac/configuration_sac.py +++ b/src/lerobot/rl/algorithms/sac/configuration_sac.py @@ -34,9 +34,6 @@ if TYPE_CHECKING: class SACAlgorithmConfig(RLAlgorithmConfig): """SAC algorithm hyperparameters.""" - # Policy config - policy_config: GaussianActorConfig - # Optimizer learning rates actor_lr: float = 3e-4 critic_lr: float = 3e-4 @@ -69,6 +66,9 @@ class SACAlgorithmConfig(RLAlgorithmConfig): # torch.compile is currently disabled by default use_torch_compile: bool = False + # Policy config + policy_config: GaussianActorConfig | None = None + @classmethod def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig: """Build an algorithm config with default hyperparameters for a given policy.""" @@ -78,6 +78,13 @@ class SACAlgorithmConfig(RLAlgorithmConfig): ) def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm: + if self.policy_config is None: + raise ValueError( + "SACAlgorithmConfig.policy_config is None. " + "It must be populated (typically by TrainRLServerPipelineConfig.validate) " + "before calling build_algorithm()." + ) + from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm return SACAlgorithm(policy=policy, config=self) diff --git a/src/lerobot/rl/eval_policy.py b/src/lerobot/rl/eval_policy.py index 4398351c5..b7eb25e95 100644 --- a/src/lerobot/rl/eval_policy.py +++ b/src/lerobot/rl/eval_policy.py @@ -17,9 +17,9 @@ import logging from lerobot.cameras import opencv # noqa: F401 from lerobot.configs import parser -from lerobot.configs.train import TrainRLServerPipelineConfig from lerobot.datasets import LeRobotDataset from lerobot.policies import make_policy +from lerobot.rl.train_rl import TrainRLServerPipelineConfig from lerobot.robots import ( # noqa: F401 RobotConfig, make_robot_from_config, diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index 334b8bbb2..afeb44f22 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -69,7 +69,6 @@ from lerobot.common.train_utils import ( ) from lerobot.common.wandb_utils import WandBLogger from lerobot.configs import parser -from lerobot.configs.train import TrainRLServerPipelineConfig from lerobot.datasets import LeRobotDataset, make_dataset from lerobot.policies import make_policy, make_pre_post_processors from lerobot.rl.algorithms.base import RLAlgorithm @@ -77,6 +76,7 @@ from lerobot.rl.algorithms.factory import make_algorithm from lerobot.rl.buffer import ReplayBuffer from lerobot.rl.data_sources import OnlineOfflineMixer from lerobot.rl.process import ProcessSignalHandler +from lerobot.rl.train_rl import TrainRLServerPipelineConfig from lerobot.rl.trainer import RLTrainer from lerobot.robots import so_follower # noqa: F401 from lerobot.teleoperators import gamepad, so_leader # noqa: F401 @@ -316,11 +316,7 @@ def add_actor_information_and_train( policy.train() - algorithm = make_algorithm( - policy=policy, - policy_cfg=cfg.policy, - algorithm_name=cfg.algorithm, - ) + algorithm = make_algorithm(cfg=cfg.algorithm, policy=policy) preprocessor, postprocessor = 
diff --git a/src/lerobot/rl/algorithms/sac/configuration_sac.py b/src/lerobot/rl/algorithms/sac/configuration_sac.py
index cbce441d4..0ccb740f3 100644
--- a/src/lerobot/rl/algorithms/sac/configuration_sac.py
+++ b/src/lerobot/rl/algorithms/sac/configuration_sac.py
@@ -34,9 +34,6 @@ if TYPE_CHECKING:
 class SACAlgorithmConfig(RLAlgorithmConfig):
     """SAC algorithm hyperparameters."""
 
-    # Policy config
-    policy_config: GaussianActorConfig
-
     # Optimizer learning rates
     actor_lr: float = 3e-4
     critic_lr: float = 3e-4
@@ -69,6 +66,9 @@ class SACAlgorithmConfig(RLAlgorithmConfig):
     # torch.compile is currently disabled by default
     use_torch_compile: bool = False
 
+    # Policy config
+    policy_config: GaussianActorConfig | None = None
+
     @classmethod
     def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig:
         """Build an algorithm config with default hyperparameters for a given policy."""
@@ -78,6 +78,13 @@
         )
 
     def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm:
+        if self.policy_config is None:
+            raise ValueError(
+                "SACAlgorithmConfig.policy_config is None. "
+                "It must be populated (typically by TrainRLServerPipelineConfig.validate) "
+                "before calling build_algorithm()."
+            )
+
         from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm
 
         return SACAlgorithm(policy=policy, config=self)
diff --git a/src/lerobot/rl/eval_policy.py b/src/lerobot/rl/eval_policy.py
index 4398351c5..b7eb25e95 100644
--- a/src/lerobot/rl/eval_policy.py
+++ b/src/lerobot/rl/eval_policy.py
@@ -17,9 +17,9 @@ import logging
 
 from lerobot.cameras import opencv  # noqa: F401
 from lerobot.configs import parser
-from lerobot.configs.train import TrainRLServerPipelineConfig
 from lerobot.datasets import LeRobotDataset
 from lerobot.policies import make_policy
+from lerobot.rl.train_rl import TrainRLServerPipelineConfig
 from lerobot.robots import (  # noqa: F401
     RobotConfig,
     make_robot_from_config,
diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py
index 334b8bbb2..afeb44f22 100644
--- a/src/lerobot/rl/learner.py
+++ b/src/lerobot/rl/learner.py
@@ -69,7 +69,6 @@ from lerobot.common.train_utils import (
 )
 from lerobot.common.wandb_utils import WandBLogger
 from lerobot.configs import parser
-from lerobot.configs.train import TrainRLServerPipelineConfig
 from lerobot.datasets import LeRobotDataset, make_dataset
 from lerobot.policies import make_policy, make_pre_post_processors
 from lerobot.rl.algorithms.base import RLAlgorithm
@@ -77,6 +76,7 @@ from lerobot.rl.algorithms.factory import make_algorithm
 from lerobot.rl.buffer import ReplayBuffer
 from lerobot.rl.data_sources import OnlineOfflineMixer
 from lerobot.rl.process import ProcessSignalHandler
+from lerobot.rl.train_rl import TrainRLServerPipelineConfig
 from lerobot.rl.trainer import RLTrainer
 from lerobot.robots import so_follower  # noqa: F401
 from lerobot.teleoperators import gamepad, so_leader  # noqa: F401
@@ -316,11 +316,7 @@
 
     policy.train()
 
-    algorithm = make_algorithm(
-        policy=policy,
-        policy_cfg=cfg.policy,
-        algorithm_name=cfg.algorithm,
-    )
+    algorithm = make_algorithm(cfg=cfg.algorithm, policy=policy)
 
     preprocessor, postprocessor = make_pre_post_processors(
         policy_cfg=cfg.policy,
diff --git a/src/lerobot/rl/train_rl.py b/src/lerobot/rl/train_rl.py
new file mode 100644
index 000000000..442856bf5
--- /dev/null
+++ b/src/lerobot/rl/train_rl.py
@@ -0,0 +1,54 @@
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Top-level pipeline config for distributed RL training (actor / learner)."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from lerobot.configs.default import DatasetConfig
+from lerobot.configs.train import TrainPipelineConfig
+from lerobot.rl.algorithms.configs import RLAlgorithmConfig
+from lerobot.rl.algorithms.sac import SACAlgorithmConfig  # noqa: F401
+
+
+@dataclass(kw_only=True)
+class TrainRLServerPipelineConfig(TrainPipelineConfig):
+    # NOTE: In RL, we don't need an offline dataset
+    # TODO: Make `TrainPipelineConfig.dataset` optional
+    dataset: DatasetConfig | None = None  # type: ignore[assignment] # because the parent class has made its type non-optional
+
+    # Algorithm config (a `draccus.ChoiceRegistry` subclass selected by `type`,
+    # e.g. ``"type": "sac"``). When omitted, defaults to a SAC config with
+    # default hyperparameters. The top-level `policy` is injected into
+    # ``algorithm.policy_config`` at validation time.
+    algorithm: RLAlgorithmConfig | None = None
+
+    # Data mixer strategy name. Currently supports "online_offline".
+    mixer: str = "online_offline"
+    # Fraction sampled from online replay when using OnlineOfflineMixer.
+    online_ratio: float = 0.5
+
+    def validate(self) -> None:
+        super().validate()
+
+        if self.algorithm is None:
+            sac_cls = RLAlgorithmConfig.get_choice_class("sac")
+            self.algorithm = sac_cls()
+
+        # The pipeline owns the policy config; inject it so the algorithm can
+        # introspect policy architecture (e.g. ``num_discrete_actions``).
+        if getattr(self.algorithm, "policy_config", None) is None:
+            self.algorithm.policy_config = self.policy
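
validate() makes the nested field self-healing: a config that never mentions
`algorithm` still ends up with a fully populated SAC config. A sketch of the
resulting invariants (assumes the inherited TrainPipelineConfig fields pass
`super().validate()` with their defaults; `policy_cfg` / `policy` as in the
earlier sketch):

    from lerobot.rl.algorithms.sac import SACAlgorithmConfig
    from lerobot.rl.train_rl import TrainRLServerPipelineConfig

    cfg = TrainRLServerPipelineConfig(policy=policy_cfg)  # no algorithm given
    cfg.validate()

    assert isinstance(cfg.algorithm, SACAlgorithmConfig)  # defaulted via the "sac" registry entry
    assert cfg.algorithm.policy_config is policy_cfg      # injected from the top-level policy
    algorithm = cfg.algorithm.build_algorithm(policy)     # no longer raises: policy_config is set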
diff --git a/tests/rl/test_actor_learner.py b/tests/rl/test_actor_learner.py
index 3dc65118f..33c150997 100644
--- a/tests/rl/test_actor_learner.py
+++ b/tests/rl/test_actor_learner.py
@@ -26,9 +26,9 @@ pytest.importorskip("grpc")
 
 from torch.multiprocessing import Event, Queue
 
-from lerobot.configs.train import TrainRLServerPipelineConfig
 from lerobot.configs.types import FeatureType, PolicyFeature
 from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
+from lerobot.rl.train_rl import TrainRLServerPipelineConfig
 from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
 from lerobot.utils.transition import Transition
 from tests.utils import skip_if_package_missing
@@ -314,7 +314,7 @@ def test_learner_algorithm_wiring():
     get_weights() output is serializable."""
     from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
     from lerobot.rl.algorithms.factory import make_algorithm
-    from lerobot.rl.algorithms.sac import SACAlgorithm
+    from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
     from lerobot.transport.utils import state_to_bytes
 
     state_dim = 10
@@ -333,7 +333,7 @@ def test_learner_algorithm_wiring():
     policy = GaussianActorPolicy(config=sac_cfg)
     policy.train()
 
-    algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
+    algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
     assert isinstance(algorithm, SACAlgorithm)
 
     optimizers = algorithm.make_optimizers_and_scheduler()
@@ -400,6 +400,7 @@ def test_initial_and_periodic_weight_push_consistency():
     and produce identical structures."""
     from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
     from lerobot.rl.algorithms.factory import make_algorithm
+    from lerobot.rl.algorithms.sac import SACAlgorithmConfig
    from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
 
     state_dim = 10
@@ -416,7 +417,7 @@ def test_initial_and_periodic_weight_push_consistency():
     policy = GaussianActorPolicy(config=sac_cfg)
     policy.train()
 
-    algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
+    algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
     algorithm.make_optimizers_and_scheduler()
 
     # Simulate initial push (same code path the learner now uses)
@@ -437,7 +438,7 @@ def test_actor_side_algorithm_select_action_and_load_weights():
     """Simulate actor: create algorithm without optimizers, select_action, load_weights."""
     from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
     from lerobot.rl.algorithms.factory import make_algorithm
-    from lerobot.rl.algorithms.sac import SACAlgorithm
+    from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
 
     state_dim = 10
     action_dim = 6
@@ -454,7 +455,7 @@ def test_actor_side_algorithm_select_action_and_load_weights():
 
     # Actor side: no optimizers
     policy = GaussianActorPolicy(config=sac_cfg)
     policy.eval()
-    algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
+    algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
     assert isinstance(algorithm, SACAlgorithm)
     assert algorithm.optimizers == {}
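
Both weight-push tests boil down to one invariant: the structure returned by
get_weights() must survive the transport byte round-trip. A sketch of that
check (helper names from the lerobot.transport.utils imports above; assumes
get_weights() returns a mapping, per the test docstrings):

    from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes

    weights = algorithm.get_weights()
    restored = bytes_to_state_dict(state_to_bytes(weights))
    assert restored.keys() == weights.keys()  # identical structure after the round-trip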
diff --git a/tests/rl/test_sac_algorithm.py b/tests/rl/test_sac_algorithm.py
index a2653b6cb..4c4568da0 100644
--- a/tests/rl/test_sac_algorithm.py
+++ b/tests/rl/test_sac_algorithm.py
@@ -348,7 +348,7 @@ def test_optimization_step_can_be_set_for_resume():
 def test_make_algorithm_returns_sac_for_sac_policy():
     sac_cfg = _make_sac_config()
     policy = GaussianActorPolicy(config=sac_cfg)
-    algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
+    algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
     assert isinstance(algorithm, SACAlgorithm)
     assert algorithm.optimizers == {}
 
@@ -357,7 +357,7 @@ def test_make_optimizers_creates_expected_keys():
     """make_optimizers_and_scheduler() should populate the algorithm with Adam optimizers."""
     sac_cfg = _make_sac_config()
     policy = GaussianActorPolicy(config=sac_cfg)
-    algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
+    algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
     optimizers = algorithm.make_optimizers_and_scheduler()
     assert "actor" in optimizers
     assert "critic" in optimizers
@@ -370,7 +370,7 @@ def test_actor_side_no_optimizers():
     """Actor-side usage: no optimizers needed, make_optimizers_and_scheduler is not called."""
     sac_cfg = _make_sac_config()
     policy = GaussianActorPolicy(config=sac_cfg)
-    algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
+    algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
     assert isinstance(algorithm, SACAlgorithm)
     assert algorithm.optimizers == {}
 
@@ -379,18 +379,16 @@ def test_make_algorithm_uses_sac_algorithm_defaults():
     """make_algorithm populates SACAlgorithmConfig with its own defaults."""
     sac_cfg = _make_sac_config()
     policy = GaussianActorPolicy(config=sac_cfg)
-    algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
+    algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
     assert algorithm.config.utd_ratio == 1
     assert algorithm.config.policy_update_freq == 1
     assert algorithm.config.grad_clip_norm == 40.0
 
 
-def test_make_algorithm_raises_for_unknown_type():
-    class FakeConfig:
-        type = "unknown_algo"
-
-    with pytest.raises(ValueError, match="No RLAlgorithmConfig"):
-        make_algorithm(policy=None, policy_cfg=FakeConfig(), algorithm_name="unknown_algo")
+def test_unknown_algorithm_name_raises_in_registry():
+    """The ChoiceRegistry is the source of truth for unknown algorithm names."""
+    with pytest.raises(KeyError):
+        RLAlgorithmConfig.get_choice_class("unknown_algo")
 
 
 # ===========================================================================
@@ -523,5 +521,5 @@ def test_make_algorithm_uses_build_algorithm():
     """make_algorithm should delegate to config.build_algorithm (no hardcoded if/else)."""
     sac_cfg = _make_sac_config()
     policy = GaussianActorPolicy(config=sac_cfg)
-    algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
+    algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
     assert isinstance(algorithm, SACAlgorithm)