refactor(policies): rename policies/sac → policies/gaussian_actor

This commit is contained in:
Khalil Meftah
2026-04-23 19:13:18 +02:00
parent 8065bf15c7
commit 06255996ea
24 changed files with 185 additions and 168 deletions
+11 -11
View File
@@ -28,7 +28,7 @@ from torch.multiprocessing import Event, Queue
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
from lerobot.utils.transition import Transition
from tests.utils import skip_if_package_missing
@@ -81,7 +81,7 @@ def cfg():
port = find_free_port()
policy_cfg = SACConfig()
policy_cfg = GaussianActorConfig()
policy_cfg.actor_learner_config.learner_host = "127.0.0.1"
policy_cfg.actor_learner_config.learner_port = port
policy_cfg.concurrency.actor = "threads"
@@ -312,7 +312,7 @@ def test_learner_algorithm_wiring():
"""Verify that make_algorithm constructs an SACAlgorithm from config,
make_optimizers_and_scheduler() creates the right optimizers, update() works, and
get_weights() output is serializable."""
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.rl.algorithms.sac import SACAlgorithm
from lerobot.transport.utils import state_to_bytes
@@ -320,7 +320,7 @@ def test_learner_algorithm_wiring():
state_dim = 10
action_dim = 6
sac_cfg = SACConfig(
sac_cfg = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
dataset_stats={
@@ -331,7 +331,7 @@ def test_learner_algorithm_wiring():
)
sac_cfg.validate_features()
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
policy.train()
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
@@ -399,13 +399,13 @@ def test_learner_algorithm_wiring():
def test_initial_and_periodic_weight_push_consistency():
"""Both initial and periodic weight pushes should use algorithm.get_weights()
and produce identical structures."""
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
state_dim = 10
action_dim = 6
sac_cfg = SACConfig(
sac_cfg = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
dataset_stats={
@@ -416,7 +416,7 @@ def test_initial_and_periodic_weight_push_consistency():
)
sac_cfg.validate_features()
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
policy.train()
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
algorithm.make_optimizers_and_scheduler()
@@ -437,13 +437,13 @@ def test_initial_and_periodic_weight_push_consistency():
def test_actor_side_algorithm_select_action_and_load_weights():
"""Simulate actor: create algorithm without optimizers, select_action, load_weights."""
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.rl.algorithms.sac import SACAlgorithm
state_dim = 10
action_dim = 6
sac_cfg = SACConfig(
sac_cfg = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
dataset_stats={
@@ -455,7 +455,7 @@ def test_actor_side_algorithm_select_action_and_load_weights():
sac_cfg.validate_features()
# Actor side: no optimizers
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
policy.eval()
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
assert isinstance(algorithm, SACAlgorithm)
+14 -14
View File
@@ -22,8 +22,8 @@ pytest.importorskip("grpc")
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.configs import RLAlgorithmConfig, TrainingStats
from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
@@ -47,8 +47,8 @@ def _make_sac_config(
utd_ratio: int = 1,
policy_update_freq: int = 1,
with_images: bool = False,
) -> SACConfig:
config = SACConfig(
) -> GaussianActorConfig:
config = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
dataset_stats={
@@ -79,7 +79,7 @@ def _make_algorithm(
policy_update_freq: int = 1,
num_discrete_actions: int | None = None,
with_images: bool = False,
) -> tuple[SACAlgorithm, SACPolicy]:
) -> tuple[SACAlgorithm, GaussianActorPolicy]:
sac_cfg = _make_sac_config(
state_dim=state_dim,
action_dim=action_dim,
@@ -88,7 +88,7 @@ def _make_algorithm(
num_discrete_actions=num_discrete_actions,
with_images=with_images,
)
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
policy.train()
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
@@ -349,7 +349,7 @@ def test_optimization_step_can_be_set_for_resume():
def test_make_algorithm_returns_sac_for_sac_policy():
sac_cfg = _make_sac_config()
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
assert isinstance(algorithm, SACAlgorithm)
assert algorithm.optimizers == {}
@@ -358,7 +358,7 @@ def test_make_algorithm_returns_sac_for_sac_policy():
def test_make_optimizers_creates_expected_keys():
"""make_optimizers_and_scheduler() should populate the algorithm with Adam optimizers."""
sac_cfg = _make_sac_config()
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
optimizers = algorithm.make_optimizers_and_scheduler()
assert "actor" in optimizers
@@ -371,7 +371,7 @@ def test_make_optimizers_creates_expected_keys():
def test_actor_side_no_optimizers():
"""Actor-side usage: no optimizers needed, make_optimizers_and_scheduler is not called."""
sac_cfg = _make_sac_config()
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
assert isinstance(algorithm, SACAlgorithm)
assert algorithm.optimizers == {}
@@ -379,7 +379,7 @@ def test_actor_side_no_optimizers():
def test_make_algorithm_copies_config_fields():
sac_cfg = _make_sac_config(utd_ratio=5, policy_update_freq=3)
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
assert algorithm.config.utd_ratio == 5
assert algorithm.config.policy_update_freq == 3
@@ -404,7 +404,7 @@ def test_load_weights_round_trip():
algo_src.update(_batch_iterator())
sac_cfg = _make_sac_config(state_dim=10, action_dim=6)
policy_dst = SACPolicy(config=sac_cfg)
policy_dst = GaussianActorPolicy(config=sac_cfg)
algo_dst = SACAlgorithm(policy=policy_dst, config=algo_src.config)
weights = algo_src.get_weights()
@@ -423,7 +423,7 @@ def test_load_weights_round_trip_with_discrete_critic():
algo_src.update(_batch_iterator(action_dim=7))
sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6)
policy_dst = SACPolicy(config=sac_cfg)
policy_dst = GaussianActorPolicy(config=sac_cfg)
algo_dst = SACAlgorithm(policy=policy_dst, config=algo_src.config)
weights = algo_src.get_weights()
@@ -470,7 +470,7 @@ def test_build_algorithm_via_config():
"""SACAlgorithmConfig.build_algorithm should produce a working SACAlgorithm."""
sac_cfg = _make_sac_config(utd_ratio=2)
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = algo_config.build_algorithm(policy)
assert isinstance(algorithm, SACAlgorithm)
@@ -480,6 +480,6 @@ def test_build_algorithm_via_config():
def test_make_algorithm_uses_build_algorithm():
"""make_algorithm should delegate to config.build_algorithm (no hardcoded if/else)."""
sac_cfg = _make_sac_config()
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
assert isinstance(algorithm, SACAlgorithm)