mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
refactor(rl): make algorithm a nested config so all SAC hyperparameters are JSON-addressable
This commit is contained in:
@@ -207,18 +207,3 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
cli_args = kwargs.pop("cli_args", [])
|
cli_args = kwargs.pop("cli_args", [])
|
||||||
with draccus.config_type("json"):
|
with draccus.config_type("json"):
|
||||||
return draccus.parse(cls, config_file, args=cli_args)
|
return draccus.parse(cls, config_file, args=cli_args)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(kw_only=True)
|
|
||||||
class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
|
||||||
# NOTE: In RL, we don't need an offline dataset
|
|
||||||
# TODO: Make `TrainPipelineConfig.dataset` optional
|
|
||||||
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made its type non-optional
|
|
||||||
|
|
||||||
# Algorithm name registered in RLAlgorithmConfig registry
|
|
||||||
algorithm: str = "sac"
|
|
||||||
|
|
||||||
# Data mixer strategy name. Currently supports "online_offline"
|
|
||||||
mixer: str = "online_offline"
|
|
||||||
# Fraction sampled from online replay when using OnlineOfflineMixer
|
|
||||||
online_ratio: float = 0.5
|
|
||||||
|
|||||||
@@ -60,11 +60,11 @@ from torch.multiprocessing import Queue
|
|||||||
|
|
||||||
from lerobot.cameras import opencv # noqa: F401
|
from lerobot.cameras import opencv # noqa: F401
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
|
||||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||||
from lerobot.processor import TransitionKey
|
from lerobot.processor import TransitionKey
|
||||||
from lerobot.rl.process import ProcessSignalHandler
|
from lerobot.rl.process import ProcessSignalHandler
|
||||||
from lerobot.rl.queue import get_last_item_from_queue
|
from lerobot.rl.queue import get_last_item_from_queue
|
||||||
|
from lerobot.rl.train_rl import TrainRLServerPipelineConfig
|
||||||
from lerobot.robots import so_follower # noqa: F401
|
from lerobot.robots import so_follower # noqa: F401
|
||||||
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
||||||
from lerobot.teleoperators.utils import TeleopEvents
|
from lerobot.teleoperators.utils import TeleopEvents
|
||||||
|
|||||||
@@ -20,16 +20,5 @@ from lerobot.rl.algorithms.base import RLAlgorithm
|
|||||||
from lerobot.rl.algorithms.configs import RLAlgorithmConfig
|
from lerobot.rl.algorithms.configs import RLAlgorithmConfig
|
||||||
|
|
||||||
|
|
||||||
def make_algorithm(
|
def make_algorithm(cfg: RLAlgorithmConfig, policy: torch.nn.Module) -> RLAlgorithm:
|
||||||
policy: torch.nn.Module,
|
return cfg.build_algorithm(policy)
|
||||||
policy_cfg,
|
|
||||||
*,
|
|
||||||
algorithm_name: str,
|
|
||||||
) -> RLAlgorithm:
|
|
||||||
known = RLAlgorithmConfig.get_known_choices()
|
|
||||||
if algorithm_name not in known:
|
|
||||||
raise ValueError(f"No RLAlgorithmConfig registered for '{algorithm_name}'. Known: {list(known)}")
|
|
||||||
|
|
||||||
config_cls = RLAlgorithmConfig.get_choice_class(algorithm_name)
|
|
||||||
algo_config = config_cls.from_policy_config(policy_cfg)
|
|
||||||
return algo_config.build_algorithm(policy)
|
|
||||||
|
|||||||
@@ -34,9 +34,6 @@ if TYPE_CHECKING:
|
|||||||
class SACAlgorithmConfig(RLAlgorithmConfig):
|
class SACAlgorithmConfig(RLAlgorithmConfig):
|
||||||
"""SAC algorithm hyperparameters."""
|
"""SAC algorithm hyperparameters."""
|
||||||
|
|
||||||
# Policy config
|
|
||||||
policy_config: GaussianActorConfig
|
|
||||||
|
|
||||||
# Optimizer learning rates
|
# Optimizer learning rates
|
||||||
actor_lr: float = 3e-4
|
actor_lr: float = 3e-4
|
||||||
critic_lr: float = 3e-4
|
critic_lr: float = 3e-4
|
||||||
@@ -69,6 +66,9 @@ class SACAlgorithmConfig(RLAlgorithmConfig):
|
|||||||
# torch.compile is currently disabled by default
|
# torch.compile is currently disabled by default
|
||||||
use_torch_compile: bool = False
|
use_torch_compile: bool = False
|
||||||
|
|
||||||
|
# Policy config
|
||||||
|
policy_config: GaussianActorConfig | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig:
|
def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig:
|
||||||
"""Build an algorithm config with default hyperparameters for a given policy."""
|
"""Build an algorithm config with default hyperparameters for a given policy."""
|
||||||
@@ -78,6 +78,13 @@ class SACAlgorithmConfig(RLAlgorithmConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm:
|
def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm:
|
||||||
|
if self.policy_config is None:
|
||||||
|
raise ValueError(
|
||||||
|
"SACAlgorithmConfig.policy_config is None. "
|
||||||
|
"It must be populated (typically by TrainRLServerPipelineConfig.validate) "
|
||||||
|
"before calling build_algorithm()."
|
||||||
|
)
|
||||||
|
|
||||||
from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm
|
from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm
|
||||||
|
|
||||||
return SACAlgorithm(policy=policy, config=self)
|
return SACAlgorithm(policy=policy, config=self)
|
||||||
|
|||||||
@@ -17,9 +17,9 @@ import logging
|
|||||||
|
|
||||||
from lerobot.cameras import opencv # noqa: F401
|
from lerobot.cameras import opencv # noqa: F401
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
|
||||||
from lerobot.datasets import LeRobotDataset
|
from lerobot.datasets import LeRobotDataset
|
||||||
from lerobot.policies import make_policy
|
from lerobot.policies import make_policy
|
||||||
|
from lerobot.rl.train_rl import TrainRLServerPipelineConfig
|
||||||
from lerobot.robots import ( # noqa: F401
|
from lerobot.robots import ( # noqa: F401
|
||||||
RobotConfig,
|
RobotConfig,
|
||||||
make_robot_from_config,
|
make_robot_from_config,
|
||||||
|
|||||||
@@ -69,7 +69,6 @@ from lerobot.common.train_utils import (
|
|||||||
)
|
)
|
||||||
from lerobot.common.wandb_utils import WandBLogger
|
from lerobot.common.wandb_utils import WandBLogger
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
|
||||||
from lerobot.datasets import LeRobotDataset, make_dataset
|
from lerobot.datasets import LeRobotDataset, make_dataset
|
||||||
from lerobot.policies import make_policy, make_pre_post_processors
|
from lerobot.policies import make_policy, make_pre_post_processors
|
||||||
from lerobot.rl.algorithms.base import RLAlgorithm
|
from lerobot.rl.algorithms.base import RLAlgorithm
|
||||||
@@ -77,6 +76,7 @@ from lerobot.rl.algorithms.factory import make_algorithm
|
|||||||
from lerobot.rl.buffer import ReplayBuffer
|
from lerobot.rl.buffer import ReplayBuffer
|
||||||
from lerobot.rl.data_sources import OnlineOfflineMixer
|
from lerobot.rl.data_sources import OnlineOfflineMixer
|
||||||
from lerobot.rl.process import ProcessSignalHandler
|
from lerobot.rl.process import ProcessSignalHandler
|
||||||
|
from lerobot.rl.train_rl import TrainRLServerPipelineConfig
|
||||||
from lerobot.rl.trainer import RLTrainer
|
from lerobot.rl.trainer import RLTrainer
|
||||||
from lerobot.robots import so_follower # noqa: F401
|
from lerobot.robots import so_follower # noqa: F401
|
||||||
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
||||||
@@ -316,11 +316,7 @@ def add_actor_information_and_train(
|
|||||||
|
|
||||||
policy.train()
|
policy.train()
|
||||||
|
|
||||||
algorithm = make_algorithm(
|
algorithm = make_algorithm(cfg=cfg.algorithm, policy=policy)
|
||||||
policy=policy,
|
|
||||||
policy_cfg=cfg.policy,
|
|
||||||
algorithm_name=cfg.algorithm,
|
|
||||||
)
|
|
||||||
|
|
||||||
preprocessor, postprocessor = make_pre_post_processors(
|
preprocessor, postprocessor = make_pre_post_processors(
|
||||||
policy_cfg=cfg.policy,
|
policy_cfg=cfg.policy,
|
||||||
|
|||||||
@@ -0,0 +1,54 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Top-level pipeline config for distributed RL training (actor / learner)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from lerobot.configs.default import DatasetConfig
|
||||||
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
|
from lerobot.rl.algorithms.configs import RLAlgorithmConfig
|
||||||
|
from lerobot.rl.algorithms.sac import SACAlgorithmConfig # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(kw_only=True)
|
||||||
|
class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
||||||
|
# NOTE: In RL, we don't need an offline dataset
|
||||||
|
# TODO: Make `TrainPipelineConfig.dataset` optional
|
||||||
|
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made its type non-optional
|
||||||
|
|
||||||
|
# Algorithm config (a `draccus.ChoiceRegistry` subclass selected by `type`,
|
||||||
|
# e.g. ``"type": "sac"``). When omitted, defaults to a SAC config with
|
||||||
|
# default hyperparameters. The top-level `policy` is injected into
|
||||||
|
# ``algorithm.policy_config`` at validation time.
|
||||||
|
algorithm: RLAlgorithmConfig | None = None
|
||||||
|
|
||||||
|
# Data mixer strategy name. Currently supports "online_offline".
|
||||||
|
mixer: str = "online_offline"
|
||||||
|
# Fraction sampled from online replay when using OnlineOfflineMixer.
|
||||||
|
online_ratio: float = 0.5
|
||||||
|
|
||||||
|
def validate(self) -> None:
|
||||||
|
super().validate()
|
||||||
|
|
||||||
|
if self.algorithm is None:
|
||||||
|
sac_cls = RLAlgorithmConfig.get_choice_class("sac")
|
||||||
|
self.algorithm = sac_cls()
|
||||||
|
|
||||||
|
# The pipeline owns the policy config; inject it so the algorithm can
|
||||||
|
# introspect policy architecture (e.g. ``num_discrete_actions``).
|
||||||
|
if getattr(self.algorithm, "policy_config", None) is None:
|
||||||
|
self.algorithm.policy_config = self.policy
|
||||||
@@ -26,9 +26,9 @@ pytest.importorskip("grpc")
|
|||||||
|
|
||||||
from torch.multiprocessing import Event, Queue
|
from torch.multiprocessing import Event, Queue
|
||||||
|
|
||||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||||
|
from lerobot.rl.train_rl import TrainRLServerPipelineConfig
|
||||||
from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
|
from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
|
||||||
from lerobot.utils.transition import Transition
|
from lerobot.utils.transition import Transition
|
||||||
from tests.utils import skip_if_package_missing
|
from tests.utils import skip_if_package_missing
|
||||||
@@ -314,7 +314,7 @@ def test_learner_algorithm_wiring():
|
|||||||
get_weights() output is serializable."""
|
get_weights() output is serializable."""
|
||||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||||
from lerobot.rl.algorithms.factory import make_algorithm
|
from lerobot.rl.algorithms.factory import make_algorithm
|
||||||
from lerobot.rl.algorithms.sac import SACAlgorithm
|
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||||
from lerobot.transport.utils import state_to_bytes
|
from lerobot.transport.utils import state_to_bytes
|
||||||
|
|
||||||
state_dim = 10
|
state_dim = 10
|
||||||
@@ -333,7 +333,7 @@ def test_learner_algorithm_wiring():
|
|||||||
policy = GaussianActorPolicy(config=sac_cfg)
|
policy = GaussianActorPolicy(config=sac_cfg)
|
||||||
policy.train()
|
policy.train()
|
||||||
|
|
||||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||||
assert isinstance(algorithm, SACAlgorithm)
|
assert isinstance(algorithm, SACAlgorithm)
|
||||||
|
|
||||||
optimizers = algorithm.make_optimizers_and_scheduler()
|
optimizers = algorithm.make_optimizers_and_scheduler()
|
||||||
@@ -400,6 +400,7 @@ def test_initial_and_periodic_weight_push_consistency():
|
|||||||
and produce identical structures."""
|
and produce identical structures."""
|
||||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||||
from lerobot.rl.algorithms.factory import make_algorithm
|
from lerobot.rl.algorithms.factory import make_algorithm
|
||||||
|
from lerobot.rl.algorithms.sac import SACAlgorithmConfig
|
||||||
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
|
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
|
||||||
|
|
||||||
state_dim = 10
|
state_dim = 10
|
||||||
@@ -416,7 +417,7 @@ def test_initial_and_periodic_weight_push_consistency():
|
|||||||
|
|
||||||
policy = GaussianActorPolicy(config=sac_cfg)
|
policy = GaussianActorPolicy(config=sac_cfg)
|
||||||
policy.train()
|
policy.train()
|
||||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||||
algorithm.make_optimizers_and_scheduler()
|
algorithm.make_optimizers_and_scheduler()
|
||||||
|
|
||||||
# Simulate initial push (same code path the learner now uses)
|
# Simulate initial push (same code path the learner now uses)
|
||||||
@@ -437,7 +438,7 @@ def test_actor_side_algorithm_select_action_and_load_weights():
|
|||||||
"""Simulate actor: create algorithm without optimizers, select_action, load_weights."""
|
"""Simulate actor: create algorithm without optimizers, select_action, load_weights."""
|
||||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||||
from lerobot.rl.algorithms.factory import make_algorithm
|
from lerobot.rl.algorithms.factory import make_algorithm
|
||||||
from lerobot.rl.algorithms.sac import SACAlgorithm
|
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||||
|
|
||||||
state_dim = 10
|
state_dim = 10
|
||||||
action_dim = 6
|
action_dim = 6
|
||||||
@@ -454,7 +455,7 @@ def test_actor_side_algorithm_select_action_and_load_weights():
|
|||||||
# Actor side: no optimizers
|
# Actor side: no optimizers
|
||||||
policy = GaussianActorPolicy(config=sac_cfg)
|
policy = GaussianActorPolicy(config=sac_cfg)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||||
assert isinstance(algorithm, SACAlgorithm)
|
assert isinstance(algorithm, SACAlgorithm)
|
||||||
assert algorithm.optimizers == {}
|
assert algorithm.optimizers == {}
|
||||||
|
|
||||||
|
|||||||
@@ -348,7 +348,7 @@ def test_optimization_step_can_be_set_for_resume():
|
|||||||
def test_make_algorithm_returns_sac_for_sac_policy():
|
def test_make_algorithm_returns_sac_for_sac_policy():
|
||||||
sac_cfg = _make_sac_config()
|
sac_cfg = _make_sac_config()
|
||||||
policy = GaussianActorPolicy(config=sac_cfg)
|
policy = GaussianActorPolicy(config=sac_cfg)
|
||||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||||
assert isinstance(algorithm, SACAlgorithm)
|
assert isinstance(algorithm, SACAlgorithm)
|
||||||
assert algorithm.optimizers == {}
|
assert algorithm.optimizers == {}
|
||||||
|
|
||||||
@@ -357,7 +357,7 @@ def test_make_optimizers_creates_expected_keys():
|
|||||||
"""make_optimizers_and_scheduler() should populate the algorithm with Adam optimizers."""
|
"""make_optimizers_and_scheduler() should populate the algorithm with Adam optimizers."""
|
||||||
sac_cfg = _make_sac_config()
|
sac_cfg = _make_sac_config()
|
||||||
policy = GaussianActorPolicy(config=sac_cfg)
|
policy = GaussianActorPolicy(config=sac_cfg)
|
||||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||||
optimizers = algorithm.make_optimizers_and_scheduler()
|
optimizers = algorithm.make_optimizers_and_scheduler()
|
||||||
assert "actor" in optimizers
|
assert "actor" in optimizers
|
||||||
assert "critic" in optimizers
|
assert "critic" in optimizers
|
||||||
@@ -370,7 +370,7 @@ def test_actor_side_no_optimizers():
|
|||||||
"""Actor-side usage: no optimizers needed, make_optimizers_and_scheduler is not called."""
|
"""Actor-side usage: no optimizers needed, make_optimizers_and_scheduler is not called."""
|
||||||
sac_cfg = _make_sac_config()
|
sac_cfg = _make_sac_config()
|
||||||
policy = GaussianActorPolicy(config=sac_cfg)
|
policy = GaussianActorPolicy(config=sac_cfg)
|
||||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||||
assert isinstance(algorithm, SACAlgorithm)
|
assert isinstance(algorithm, SACAlgorithm)
|
||||||
assert algorithm.optimizers == {}
|
assert algorithm.optimizers == {}
|
||||||
|
|
||||||
@@ -379,18 +379,16 @@ def test_make_algorithm_uses_sac_algorithm_defaults():
|
|||||||
"""make_algorithm populates SACAlgorithmConfig with its own defaults."""
|
"""make_algorithm populates SACAlgorithmConfig with its own defaults."""
|
||||||
sac_cfg = _make_sac_config()
|
sac_cfg = _make_sac_config()
|
||||||
policy = GaussianActorPolicy(config=sac_cfg)
|
policy = GaussianActorPolicy(config=sac_cfg)
|
||||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||||
assert algorithm.config.utd_ratio == 1
|
assert algorithm.config.utd_ratio == 1
|
||||||
assert algorithm.config.policy_update_freq == 1
|
assert algorithm.config.policy_update_freq == 1
|
||||||
assert algorithm.config.grad_clip_norm == 40.0
|
assert algorithm.config.grad_clip_norm == 40.0
|
||||||
|
|
||||||
|
|
||||||
def test_make_algorithm_raises_for_unknown_type():
|
def test_unknown_algorithm_name_raises_in_registry():
|
||||||
class FakeConfig:
|
"""The ChoiceRegistry is the source of truth for unknown algorithm names."""
|
||||||
type = "unknown_algo"
|
with pytest.raises(KeyError):
|
||||||
|
RLAlgorithmConfig.get_choice_class("unknown_algo")
|
||||||
with pytest.raises(ValueError, match="No RLAlgorithmConfig"):
|
|
||||||
make_algorithm(policy=None, policy_cfg=FakeConfig(), algorithm_name="unknown_algo")
|
|
||||||
|
|
||||||
|
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
@@ -523,5 +521,5 @@ def test_make_algorithm_uses_build_algorithm():
|
|||||||
"""make_algorithm should delegate to config.build_algorithm (no hardcoded if/else)."""
|
"""make_algorithm should delegate to config.build_algorithm (no hardcoded if/else)."""
|
||||||
sac_cfg = _make_sac_config()
|
sac_cfg = _make_sac_config()
|
||||||
policy = GaussianActorPolicy(config=sac_cfg)
|
policy = GaussianActorPolicy(config=sac_cfg)
|
||||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||||
assert isinstance(algorithm, SACAlgorithm)
|
assert isinstance(algorithm, SACAlgorithm)
|
||||||
|
|||||||
Reference in New Issue
Block a user