mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
refactor(rl): make algorithm a nested config so all SAC hyperparameters are JSON-addressable
This commit is contained in:
@@ -207,18 +207,3 @@ class TrainPipelineConfig(HubMixin):
|
||||
cli_args = kwargs.pop("cli_args", [])
|
||||
with draccus.config_type("json"):
|
||||
return draccus.parse(cls, config_file, args=cli_args)
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
||||
# NOTE: In RL, we don't need an offline dataset
|
||||
# TODO: Make `TrainPipelineConfig.dataset` optional
|
||||
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
|
||||
|
||||
# Algorithm name registered in RLAlgorithmConfig registry
|
||||
algorithm: str = "sac"
|
||||
|
||||
# Data mixer strategy name. Currently supports "online_offline"
|
||||
mixer: str = "online_offline"
|
||||
# Fraction sampled from online replay when using OnlineOfflineMixer
|
||||
online_ratio: float = 0.5
|
||||
|
||||
@@ -60,11 +60,11 @@ from torch.multiprocessing import Queue
|
||||
|
||||
from lerobot.cameras import opencv # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
from lerobot.processor import TransitionKey
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.rl.queue import get_last_item_from_queue
|
||||
from lerobot.rl.train_rl import TrainRLServerPipelineConfig
|
||||
from lerobot.robots import so_follower # noqa: F401
|
||||
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
|
||||
@@ -20,16 +20,5 @@ from lerobot.rl.algorithms.base import RLAlgorithm
|
||||
from lerobot.rl.algorithms.configs import RLAlgorithmConfig
|
||||
|
||||
|
||||
def make_algorithm(
|
||||
policy: torch.nn.Module,
|
||||
policy_cfg,
|
||||
*,
|
||||
algorithm_name: str,
|
||||
) -> RLAlgorithm:
|
||||
known = RLAlgorithmConfig.get_known_choices()
|
||||
if algorithm_name not in known:
|
||||
raise ValueError(f"No RLAlgorithmConfig registered for '{algorithm_name}'. Known: {list(known)}")
|
||||
|
||||
config_cls = RLAlgorithmConfig.get_choice_class(algorithm_name)
|
||||
algo_config = config_cls.from_policy_config(policy_cfg)
|
||||
return algo_config.build_algorithm(policy)
|
||||
def make_algorithm(cfg: RLAlgorithmConfig, policy: torch.nn.Module) -> RLAlgorithm:
|
||||
return cfg.build_algorithm(policy)
|
||||
|
||||
@@ -34,9 +34,6 @@ if TYPE_CHECKING:
|
||||
class SACAlgorithmConfig(RLAlgorithmConfig):
|
||||
"""SAC algorithm hyperparameters."""
|
||||
|
||||
# Policy config
|
||||
policy_config: GaussianActorConfig
|
||||
|
||||
# Optimizer learning rates
|
||||
actor_lr: float = 3e-4
|
||||
critic_lr: float = 3e-4
|
||||
@@ -69,6 +66,9 @@ class SACAlgorithmConfig(RLAlgorithmConfig):
|
||||
# torch.compile is currently disabled by default
|
||||
use_torch_compile: bool = False
|
||||
|
||||
# Policy config
|
||||
policy_config: GaussianActorConfig | None = None
|
||||
|
||||
@classmethod
|
||||
def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig:
|
||||
"""Build an algorithm config with default hyperparameters for a given policy."""
|
||||
@@ -78,6 +78,13 @@ class SACAlgorithmConfig(RLAlgorithmConfig):
|
||||
)
|
||||
|
||||
def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm:
|
||||
if self.policy_config is None:
|
||||
raise ValueError(
|
||||
"SACAlgorithmConfig.policy_config is None. "
|
||||
"It must be populated (typically by TrainRLServerPipelineConfig.validate) "
|
||||
"before calling build_algorithm()."
|
||||
)
|
||||
|
||||
from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm
|
||||
|
||||
return SACAlgorithm(policy=policy, config=self)
|
||||
|
||||
@@ -17,9 +17,9 @@ import logging
|
||||
|
||||
from lerobot.cameras import opencv # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.policies import make_policy
|
||||
from lerobot.rl.train_rl import TrainRLServerPipelineConfig
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
RobotConfig,
|
||||
make_robot_from_config,
|
||||
|
||||
@@ -69,7 +69,6 @@ from lerobot.common.train_utils import (
|
||||
)
|
||||
from lerobot.common.wandb_utils import WandBLogger
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.datasets import LeRobotDataset, make_dataset
|
||||
from lerobot.policies import make_policy, make_pre_post_processors
|
||||
from lerobot.rl.algorithms.base import RLAlgorithm
|
||||
@@ -77,6 +76,7 @@ from lerobot.rl.algorithms.factory import make_algorithm
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.data_sources import OnlineOfflineMixer
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.rl.train_rl import TrainRLServerPipelineConfig
|
||||
from lerobot.rl.trainer import RLTrainer
|
||||
from lerobot.robots import so_follower # noqa: F401
|
||||
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
||||
@@ -316,11 +316,7 @@ def add_actor_information_and_train(
|
||||
|
||||
policy.train()
|
||||
|
||||
algorithm = make_algorithm(
|
||||
policy=policy,
|
||||
policy_cfg=cfg.policy,
|
||||
algorithm_name=cfg.algorithm,
|
||||
)
|
||||
algorithm = make_algorithm(cfg=cfg.algorithm, policy=policy)
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Top-level pipeline config for distributed RL training (actor / learner)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.rl.algorithms.configs import RLAlgorithmConfig
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithmConfig # noqa: F401
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
||||
# NOTE: In RL, we don't need an offline dataset
|
||||
# TODO: Make `TrainPipelineConfig.dataset` optional
|
||||
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
|
||||
|
||||
# Algorithm config (a `draccus.ChoiceRegistry` subclass selected by `type`,
|
||||
# e.g. ``"type": "sac"``). When omitted, defaults to a SAC config with
|
||||
# default hyperparameters. The top-level `policy` is injected into
|
||||
# ``algorithm.policy_config`` at validation time.
|
||||
algorithm: RLAlgorithmConfig | None = None
|
||||
|
||||
# Data mixer strategy name. Currently supports "online_offline".
|
||||
mixer: str = "online_offline"
|
||||
# Fraction sampled from online replay when using OnlineOfflineMixer.
|
||||
online_ratio: float = 0.5
|
||||
|
||||
def validate(self) -> None:
|
||||
super().validate()
|
||||
|
||||
if self.algorithm is None:
|
||||
sac_cls = RLAlgorithmConfig.get_choice_class("sac")
|
||||
self.algorithm = sac_cls()
|
||||
|
||||
# The pipeline owns the policy config; inject it so the algorithm can
|
||||
# introspect policy architecture (e.g. ``num_discrete_actions``).
|
||||
if getattr(self.algorithm, "policy_config", None) is None:
|
||||
self.algorithm.policy_config = self.policy
|
||||
@@ -26,9 +26,9 @@ pytest.importorskip("grpc")
|
||||
|
||||
from torch.multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from lerobot.rl.train_rl import TrainRLServerPipelineConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.transition import Transition
|
||||
from tests.utils import skip_if_package_missing
|
||||
@@ -314,7 +314,7 @@ def test_learner_algorithm_wiring():
|
||||
get_weights() output is serializable."""
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
from lerobot.rl.algorithms.factory import make_algorithm
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
from lerobot.transport.utils import state_to_bytes
|
||||
|
||||
state_dim = 10
|
||||
@@ -333,7 +333,7 @@ def test_learner_algorithm_wiring():
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
policy.train()
|
||||
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
|
||||
optimizers = algorithm.make_optimizers_and_scheduler()
|
||||
@@ -400,6 +400,7 @@ def test_initial_and_periodic_weight_push_consistency():
|
||||
and produce identical structures."""
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
from lerobot.rl.algorithms.factory import make_algorithm
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithmConfig
|
||||
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
|
||||
|
||||
state_dim = 10
|
||||
@@ -416,7 +417,7 @@ def test_initial_and_periodic_weight_push_consistency():
|
||||
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
policy.train()
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||
algorithm.make_optimizers_and_scheduler()
|
||||
|
||||
# Simulate initial push (same code path the learner now uses)
|
||||
@@ -437,7 +438,7 @@ def test_actor_side_algorithm_select_action_and_load_weights():
|
||||
"""Simulate actor: create algorithm without optimizers, select_action, load_weights."""
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
from lerobot.rl.algorithms.factory import make_algorithm
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 6
|
||||
@@ -454,7 +455,7 @@ def test_actor_side_algorithm_select_action_and_load_weights():
|
||||
# Actor side: no optimizers
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
policy.eval()
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
assert algorithm.optimizers == {}
|
||||
|
||||
|
||||
@@ -348,7 +348,7 @@ def test_optimization_step_can_be_set_for_resume():
|
||||
def test_make_algorithm_returns_sac_for_sac_policy():
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
assert algorithm.optimizers == {}
|
||||
|
||||
@@ -357,7 +357,7 @@ def test_make_optimizers_creates_expected_keys():
|
||||
"""make_optimizers_and_scheduler() should populate the algorithm with Adam optimizers."""
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||
optimizers = algorithm.make_optimizers_and_scheduler()
|
||||
assert "actor" in optimizers
|
||||
assert "critic" in optimizers
|
||||
@@ -370,7 +370,7 @@ def test_actor_side_no_optimizers():
|
||||
"""Actor-side usage: no optimizers needed, make_optimizers_and_scheduler is not called."""
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
assert algorithm.optimizers == {}
|
||||
|
||||
@@ -379,18 +379,16 @@ def test_make_algorithm_uses_sac_algorithm_defaults():
|
||||
"""make_algorithm populates SACAlgorithmConfig with its own defaults."""
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||
assert algorithm.config.utd_ratio == 1
|
||||
assert algorithm.config.policy_update_freq == 1
|
||||
assert algorithm.config.grad_clip_norm == 40.0
|
||||
|
||||
|
||||
def test_make_algorithm_raises_for_unknown_type():
|
||||
class FakeConfig:
|
||||
type = "unknown_algo"
|
||||
|
||||
with pytest.raises(ValueError, match="No RLAlgorithmConfig"):
|
||||
make_algorithm(policy=None, policy_cfg=FakeConfig(), algorithm_name="unknown_algo")
|
||||
def test_unknown_algorithm_name_raises_in_registry():
|
||||
"""The ChoiceRegistry is the source of truth for unknown algorithm names."""
|
||||
with pytest.raises(KeyError):
|
||||
RLAlgorithmConfig.get_choice_class("unknown_algo")
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
@@ -523,5 +521,5 @@ def test_make_algorithm_uses_build_algorithm():
|
||||
"""make_algorithm should delegate to config.build_algorithm (no hardcoded if/else)."""
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
|
||||
Reference in New Issue
Block a user