refactor(rl): make algorithm a nested config so all SAC hyperparameters are JSON-addressable

This commit is contained in:
Khalil Meftah
2026-04-27 13:39:03 +02:00
parent 21c16a27f0
commit 9ce9e01469
9 changed files with 86 additions and 56 deletions
-15
View File
@@ -207,18 +207,3 @@ class TrainPipelineConfig(HubMixin):
cli_args = kwargs.pop("cli_args", []) cli_args = kwargs.pop("cli_args", [])
with draccus.config_type("json"): with draccus.config_type("json"):
return draccus.parse(cls, config_file, args=cli_args) return draccus.parse(cls, config_file, args=cli_args)
@dataclass(kw_only=True)
class TrainRLServerPipelineConfig(TrainPipelineConfig):
# NOTE: In RL, we don't need an offline dataset
# TODO: Make `TrainPipelineConfig.dataset` optional
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
# Algorithm name registered in RLAlgorithmConfig registry
algorithm: str = "sac"
# Data mixer strategy name. Currently supports "online_offline"
mixer: str = "online_offline"
# Fraction sampled from online replay when using OnlineOfflineMixer
online_ratio: float = 0.5
+1 -1
View File
@@ -60,11 +60,11 @@ from torch.multiprocessing import Queue
from lerobot.cameras import opencv # noqa: F401 from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
from lerobot.processor import TransitionKey from lerobot.processor import TransitionKey
from lerobot.rl.process import ProcessSignalHandler from lerobot.rl.process import ProcessSignalHandler
from lerobot.rl.queue import get_last_item_from_queue from lerobot.rl.queue import get_last_item_from_queue
from lerobot.rl.train_rl import TrainRLServerPipelineConfig
from lerobot.robots import so_follower # noqa: F401 from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401 from lerobot.teleoperators import gamepad, so_leader # noqa: F401
from lerobot.teleoperators.utils import TeleopEvents from lerobot.teleoperators.utils import TeleopEvents
+2 -13
View File
@@ -20,16 +20,5 @@ from lerobot.rl.algorithms.base import RLAlgorithm
from lerobot.rl.algorithms.configs import RLAlgorithmConfig from lerobot.rl.algorithms.configs import RLAlgorithmConfig
def make_algorithm( def make_algorithm(cfg: RLAlgorithmConfig, policy: torch.nn.Module) -> RLAlgorithm:
policy: torch.nn.Module, return cfg.build_algorithm(policy)
policy_cfg,
*,
algorithm_name: str,
) -> RLAlgorithm:
known = RLAlgorithmConfig.get_known_choices()
if algorithm_name not in known:
raise ValueError(f"No RLAlgorithmConfig registered for '{algorithm_name}'. Known: {list(known)}")
config_cls = RLAlgorithmConfig.get_choice_class(algorithm_name)
algo_config = config_cls.from_policy_config(policy_cfg)
return algo_config.build_algorithm(policy)
@@ -34,9 +34,6 @@ if TYPE_CHECKING:
class SACAlgorithmConfig(RLAlgorithmConfig): class SACAlgorithmConfig(RLAlgorithmConfig):
"""SAC algorithm hyperparameters.""" """SAC algorithm hyperparameters."""
# Policy config
policy_config: GaussianActorConfig
# Optimizer learning rates # Optimizer learning rates
actor_lr: float = 3e-4 actor_lr: float = 3e-4
critic_lr: float = 3e-4 critic_lr: float = 3e-4
@@ -69,6 +66,9 @@ class SACAlgorithmConfig(RLAlgorithmConfig):
# torch.compile is currently disabled by default # torch.compile is currently disabled by default
use_torch_compile: bool = False use_torch_compile: bool = False
# Policy config
policy_config: GaussianActorConfig | None = None
@classmethod @classmethod
def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig: def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig:
"""Build an algorithm config with default hyperparameters for a given policy.""" """Build an algorithm config with default hyperparameters for a given policy."""
@@ -78,6 +78,13 @@ class SACAlgorithmConfig(RLAlgorithmConfig):
) )
def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm: def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm:
if self.policy_config is None:
raise ValueError(
"SACAlgorithmConfig.policy_config is None. "
"It must be populated (typically by TrainRLServerPipelineConfig.validate) "
"before calling build_algorithm()."
)
from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm
return SACAlgorithm(policy=policy, config=self) return SACAlgorithm(policy=policy, config=self)
+1 -1
View File
@@ -17,9 +17,9 @@ import logging
from lerobot.cameras import opencv # noqa: F401 from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.datasets import LeRobotDataset from lerobot.datasets import LeRobotDataset
from lerobot.policies import make_policy from lerobot.policies import make_policy
from lerobot.rl.train_rl import TrainRLServerPipelineConfig
from lerobot.robots import ( # noqa: F401 from lerobot.robots import ( # noqa: F401
RobotConfig, RobotConfig,
make_robot_from_config, make_robot_from_config,
+2 -6
View File
@@ -69,7 +69,6 @@ from lerobot.common.train_utils import (
) )
from lerobot.common.wandb_utils import WandBLogger from lerobot.common.wandb_utils import WandBLogger
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.datasets import LeRobotDataset, make_dataset from lerobot.datasets import LeRobotDataset, make_dataset
from lerobot.policies import make_policy, make_pre_post_processors from lerobot.policies import make_policy, make_pre_post_processors
from lerobot.rl.algorithms.base import RLAlgorithm from lerobot.rl.algorithms.base import RLAlgorithm
@@ -77,6 +76,7 @@ from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.rl.buffer import ReplayBuffer from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.data_sources import OnlineOfflineMixer from lerobot.rl.data_sources import OnlineOfflineMixer
from lerobot.rl.process import ProcessSignalHandler from lerobot.rl.process import ProcessSignalHandler
from lerobot.rl.train_rl import TrainRLServerPipelineConfig
from lerobot.rl.trainer import RLTrainer from lerobot.rl.trainer import RLTrainer
from lerobot.robots import so_follower # noqa: F401 from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401 from lerobot.teleoperators import gamepad, so_leader # noqa: F401
@@ -316,11 +316,7 @@ def add_actor_information_and_train(
policy.train() policy.train()
algorithm = make_algorithm( algorithm = make_algorithm(cfg=cfg.algorithm, policy=policy)
policy=policy,
policy_cfg=cfg.policy,
algorithm_name=cfg.algorithm,
)
preprocessor, postprocessor = make_pre_post_processors( preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy, policy_cfg=cfg.policy,
+54
View File
@@ -0,0 +1,54 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Top-level pipeline config for distributed RL training (actor / learner)."""
from __future__ import annotations
from dataclasses import dataclass
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.rl.algorithms.configs import RLAlgorithmConfig
from lerobot.rl.algorithms.sac import SACAlgorithmConfig # noqa: F401
@dataclass(kw_only=True)
class TrainRLServerPipelineConfig(TrainPipelineConfig):
# NOTE: In RL, we don't need an offline dataset
# TODO: Make `TrainPipelineConfig.dataset` optional
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
# Algorithm config (a `draccus.ChoiceRegistry` subclass selected by `type`,
# e.g. ``"type": "sac"``). When omitted, defaults to a SAC config with
# default hyperparameters. The top-level `policy` is injected into
# ``algorithm.policy_config`` at validation time.
algorithm: RLAlgorithmConfig | None = None
# Data mixer strategy name. Currently supports "online_offline".
mixer: str = "online_offline"
# Fraction sampled from online replay when using OnlineOfflineMixer.
online_ratio: float = 0.5
def validate(self) -> None:
super().validate()
if self.algorithm is None:
sac_cls = RLAlgorithmConfig.get_choice_class("sac")
self.algorithm = sac_cls()
# The pipeline owns the policy config; inject it so the algorithm can
# introspect policy architecture (e.g. ``num_discrete_actions``).
if getattr(self.algorithm, "policy_config", None) is None:
self.algorithm.policy_config = self.policy
+7 -6
View File
@@ -26,9 +26,9 @@ pytest.importorskip("grpc")
from torch.multiprocessing import Event, Queue from torch.multiprocessing import Event, Queue
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from lerobot.rl.train_rl import TrainRLServerPipelineConfig
from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
from lerobot.utils.transition import Transition from lerobot.utils.transition import Transition
from tests.utils import skip_if_package_missing from tests.utils import skip_if_package_missing
@@ -314,7 +314,7 @@ def test_learner_algorithm_wiring():
get_weights() output is serializable.""" get_weights() output is serializable."""
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.factory import make_algorithm from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.rl.algorithms.sac import SACAlgorithm from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
from lerobot.transport.utils import state_to_bytes from lerobot.transport.utils import state_to_bytes
state_dim = 10 state_dim = 10
@@ -333,7 +333,7 @@ def test_learner_algorithm_wiring():
policy = GaussianActorPolicy(config=sac_cfg) policy = GaussianActorPolicy(config=sac_cfg)
policy.train() policy.train()
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
assert isinstance(algorithm, SACAlgorithm) assert isinstance(algorithm, SACAlgorithm)
optimizers = algorithm.make_optimizers_and_scheduler() optimizers = algorithm.make_optimizers_and_scheduler()
@@ -400,6 +400,7 @@ def test_initial_and_periodic_weight_push_consistency():
and produce identical structures.""" and produce identical structures."""
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.factory import make_algorithm from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.rl.algorithms.sac import SACAlgorithmConfig
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
state_dim = 10 state_dim = 10
@@ -416,7 +417,7 @@ def test_initial_and_periodic_weight_push_consistency():
policy = GaussianActorPolicy(config=sac_cfg) policy = GaussianActorPolicy(config=sac_cfg)
policy.train() policy.train()
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
algorithm.make_optimizers_and_scheduler() algorithm.make_optimizers_and_scheduler()
# Simulate initial push (same code path the learner now uses) # Simulate initial push (same code path the learner now uses)
@@ -437,7 +438,7 @@ def test_actor_side_algorithm_select_action_and_load_weights():
"""Simulate actor: create algorithm without optimizers, select_action, load_weights.""" """Simulate actor: create algorithm without optimizers, select_action, load_weights."""
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.factory import make_algorithm from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.rl.algorithms.sac import SACAlgorithm from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
state_dim = 10 state_dim = 10
action_dim = 6 action_dim = 6
@@ -454,7 +455,7 @@ def test_actor_side_algorithm_select_action_and_load_weights():
# Actor side: no optimizers # Actor side: no optimizers
policy = GaussianActorPolicy(config=sac_cfg) policy = GaussianActorPolicy(config=sac_cfg)
policy.eval() policy.eval()
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
assert isinstance(algorithm, SACAlgorithm) assert isinstance(algorithm, SACAlgorithm)
assert algorithm.optimizers == {} assert algorithm.optimizers == {}
+9 -11
View File
@@ -348,7 +348,7 @@ def test_optimization_step_can_be_set_for_resume():
def test_make_algorithm_returns_sac_for_sac_policy(): def test_make_algorithm_returns_sac_for_sac_policy():
sac_cfg = _make_sac_config() sac_cfg = _make_sac_config()
policy = GaussianActorPolicy(config=sac_cfg) policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
assert isinstance(algorithm, SACAlgorithm) assert isinstance(algorithm, SACAlgorithm)
assert algorithm.optimizers == {} assert algorithm.optimizers == {}
@@ -357,7 +357,7 @@ def test_make_optimizers_creates_expected_keys():
"""make_optimizers_and_scheduler() should populate the algorithm with Adam optimizers.""" """make_optimizers_and_scheduler() should populate the algorithm with Adam optimizers."""
sac_cfg = _make_sac_config() sac_cfg = _make_sac_config()
policy = GaussianActorPolicy(config=sac_cfg) policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
optimizers = algorithm.make_optimizers_and_scheduler() optimizers = algorithm.make_optimizers_and_scheduler()
assert "actor" in optimizers assert "actor" in optimizers
assert "critic" in optimizers assert "critic" in optimizers
@@ -370,7 +370,7 @@ def test_actor_side_no_optimizers():
"""Actor-side usage: no optimizers needed, make_optimizers_and_scheduler is not called.""" """Actor-side usage: no optimizers needed, make_optimizers_and_scheduler is not called."""
sac_cfg = _make_sac_config() sac_cfg = _make_sac_config()
policy = GaussianActorPolicy(config=sac_cfg) policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
assert isinstance(algorithm, SACAlgorithm) assert isinstance(algorithm, SACAlgorithm)
assert algorithm.optimizers == {} assert algorithm.optimizers == {}
@@ -379,18 +379,16 @@ def test_make_algorithm_uses_sac_algorithm_defaults():
"""make_algorithm populates SACAlgorithmConfig with its own defaults.""" """make_algorithm populates SACAlgorithmConfig with its own defaults."""
sac_cfg = _make_sac_config() sac_cfg = _make_sac_config()
policy = GaussianActorPolicy(config=sac_cfg) policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
assert algorithm.config.utd_ratio == 1 assert algorithm.config.utd_ratio == 1
assert algorithm.config.policy_update_freq == 1 assert algorithm.config.policy_update_freq == 1
assert algorithm.config.grad_clip_norm == 40.0 assert algorithm.config.grad_clip_norm == 40.0
def test_make_algorithm_raises_for_unknown_type(): def test_unknown_algorithm_name_raises_in_registry():
class FakeConfig: """The ChoiceRegistry is the source of truth for unknown algorithm names."""
type = "unknown_algo" with pytest.raises(KeyError):
RLAlgorithmConfig.get_choice_class("unknown_algo")
with pytest.raises(ValueError, match="No RLAlgorithmConfig"):
make_algorithm(policy=None, policy_cfg=FakeConfig(), algorithm_name="unknown_algo")
# =========================================================================== # ===========================================================================
@@ -523,5 +521,5 @@ def test_make_algorithm_uses_build_algorithm():
"""make_algorithm should delegate to config.build_algorithm (no hardcoded if/else).""" """make_algorithm should delegate to config.build_algorithm (no hardcoded if/else)."""
sac_cfg = _make_sac_config() sac_cfg = _make_sac_config()
policy = GaussianActorPolicy(config=sac_cfg) policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
assert isinstance(algorithm, SACAlgorithm) assert isinstance(algorithm, SACAlgorithm)