mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 02:59:50 +00:00
refactor(rl/sac): consolidate hyperparameter ownership and clean up discrete critic
This commit is contained in:
@@ -55,9 +55,6 @@ def test_gaussian_actor_config_default_initialization():
|
||||
# Basic parameters
|
||||
assert config.device == "cpu"
|
||||
assert config.storage_device == "cpu"
|
||||
assert config.discount == 0.99
|
||||
assert config.temperature_init == 1.0
|
||||
assert config.num_critics == 2
|
||||
|
||||
# Architecture specifics
|
||||
assert config.vision_encoder_name is None
|
||||
@@ -66,6 +63,8 @@ def test_gaussian_actor_config_default_initialization():
|
||||
assert config.shared_encoder is True
|
||||
assert config.num_discrete_actions is None
|
||||
assert config.image_embedding_pooling_dim == 8
|
||||
assert config.state_encoder_hidden_dim == 256
|
||||
assert config.latent_dim == 256
|
||||
|
||||
# Training parameters
|
||||
assert config.online_steps == 1000000
|
||||
@@ -73,20 +72,6 @@ def test_gaussian_actor_config_default_initialization():
|
||||
assert config.offline_buffer_capacity == 100000
|
||||
assert config.async_prefetch is False
|
||||
assert config.online_step_before_learning == 100
|
||||
assert config.policy_update_freq == 1
|
||||
|
||||
# SAC algorithm parameters
|
||||
assert config.num_subsample_critics is None
|
||||
assert config.critic_lr == 3e-4
|
||||
assert config.actor_lr == 3e-4
|
||||
assert config.temperature_lr == 3e-4
|
||||
assert config.critic_target_update_weight == 0.005
|
||||
assert config.utd_ratio == 1
|
||||
assert config.state_encoder_hidden_dim == 256
|
||||
assert config.latent_dim == 256
|
||||
assert config.target_entropy is None
|
||||
assert config.use_backup_entropy is True
|
||||
assert config.grad_clip_norm == 40.0
|
||||
|
||||
# Dataset stats defaults
|
||||
expected_dataset_stats = {
|
||||
@@ -105,11 +90,6 @@ def test_gaussian_actor_config_default_initialization():
|
||||
}
|
||||
assert config.dataset_stats == expected_dataset_stats
|
||||
|
||||
# Critic network configuration
|
||||
assert config.critic_network_kwargs.hidden_dims == [256, 256]
|
||||
assert config.critic_network_kwargs.activate_final is True
|
||||
assert config.critic_network_kwargs.final_activation is None
|
||||
|
||||
# Actor network configuration
|
||||
assert config.actor_network_kwargs.hidden_dims == [256, 256]
|
||||
assert config.actor_network_kwargs.activate_final is True
|
||||
@@ -135,7 +115,6 @@ def test_gaussian_actor_config_default_initialization():
|
||||
assert config.concurrency.learner == "threads"
|
||||
|
||||
assert isinstance(config.actor_network_kwargs, ActorNetworkConfig)
|
||||
assert isinstance(config.critic_network_kwargs, CriticNetworkConfig)
|
||||
assert isinstance(config.policy_kwargs, PolicyConfig)
|
||||
assert isinstance(config.actor_learner_config, ActorLearnerConfig)
|
||||
assert isinstance(config.concurrency, ConcurrencyConfig)
|
||||
@@ -178,15 +157,15 @@ def test_concurrency_config():
|
||||
def test_gaussian_actor_config_custom_initialization():
|
||||
config = GaussianActorConfig(
|
||||
device="cpu",
|
||||
discount=0.95,
|
||||
temperature_init=0.5,
|
||||
num_critics=3,
|
||||
latent_dim=128,
|
||||
state_encoder_hidden_dim=128,
|
||||
num_discrete_actions=3,
|
||||
)
|
||||
|
||||
assert config.device == "cpu"
|
||||
assert config.discount == 0.95
|
||||
assert config.temperature_init == 0.5
|
||||
assert config.num_critics == 3
|
||||
assert config.latent_dim == 128
|
||||
assert config.state_encoder_hidden_dim == 128
|
||||
assert config.num_discrete_actions == 3
|
||||
|
||||
|
||||
def test_validate_features():
|
||||
|
||||
@@ -404,19 +404,16 @@ def test_sac_training_with_discrete_critic():
|
||||
|
||||
|
||||
def test_sac_algorithm_target_entropy():
|
||||
"""Target entropy is an SAC hyperparameter and lives on the algorithm."""
|
||||
config = create_default_config(continuous_action_dim=10, state_dim=10)
|
||||
_, policy = _make_algorithm(config)
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
algorithm, _ = _make_algorithm(config)
|
||||
assert algorithm.target_entropy == -5.0
|
||||
|
||||
|
||||
def test_sac_algorithm_target_entropy_with_discrete_action():
|
||||
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
|
||||
config.num_discrete_actions = 5
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
policy = GaussianActorPolicy(config=config)
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
algorithm, _ = _make_algorithm(config)
|
||||
assert algorithm.target_entropy == -3.5
|
||||
|
||||
|
||||
@@ -435,8 +432,8 @@ def test_sac_algorithm_temperature():
|
||||
|
||||
def test_sac_algorithm_update_target_network():
|
||||
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
||||
config.critic_target_update_weight = 1.0
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
algo_config.critic_target_update_weight = 1.0
|
||||
policy = GaussianActorPolicy(config=config)
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
|
||||
@@ -454,9 +451,13 @@ def test_sac_algorithm_with_critics_number_of_heads(num_critics: int):
|
||||
action_dim = 10
|
||||
state_dim = 10
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.num_critics = num_critics
|
||||
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
policy = GaussianActorPolicy(config=config)
|
||||
policy.train()
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
algo_config.num_critics = num_critics
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
algorithm.make_optimizers_and_scheduler()
|
||||
|
||||
assert len(algorithm.critic_ensemble.critics) == num_critics
|
||||
|
||||
|
||||
@@ -327,7 +327,6 @@ def test_learner_algorithm_wiring():
|
||||
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
|
||||
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
|
||||
},
|
||||
use_torch_compile=False,
|
||||
)
|
||||
sac_cfg.validate_features()
|
||||
|
||||
@@ -412,7 +411,6 @@ def test_initial_and_periodic_weight_push_consistency():
|
||||
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
|
||||
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
|
||||
},
|
||||
use_torch_compile=False,
|
||||
)
|
||||
sac_cfg.validate_features()
|
||||
|
||||
@@ -450,7 +448,6 @@ def test_actor_side_algorithm_select_action_and_load_weights():
|
||||
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
|
||||
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
|
||||
},
|
||||
use_torch_compile=False,
|
||||
)
|
||||
sac_cfg.validate_features()
|
||||
|
||||
|
||||
@@ -44,8 +44,6 @@ def _make_sac_config(
|
||||
state_dim: int = 10,
|
||||
action_dim: int = 6,
|
||||
num_discrete_actions: int | None = None,
|
||||
utd_ratio: int = 1,
|
||||
policy_update_freq: int = 1,
|
||||
with_images: bool = False,
|
||||
) -> GaussianActorConfig:
|
||||
config = GaussianActorConfig(
|
||||
@@ -55,10 +53,7 @@ def _make_sac_config(
|
||||
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
|
||||
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
|
||||
},
|
||||
utd_ratio=utd_ratio,
|
||||
policy_update_freq=policy_update_freq,
|
||||
num_discrete_actions=num_discrete_actions,
|
||||
use_torch_compile=False,
|
||||
)
|
||||
if with_images:
|
||||
config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
|
||||
@@ -83,14 +78,14 @@ def _make_algorithm(
|
||||
sac_cfg = _make_sac_config(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
utd_ratio=utd_ratio,
|
||||
policy_update_freq=policy_update_freq,
|
||||
num_discrete_actions=num_discrete_actions,
|
||||
with_images=with_images,
|
||||
)
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
policy.train()
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
|
||||
algo_config.utd_ratio = utd_ratio
|
||||
algo_config.policy_update_freq = policy_update_freq
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
algorithm.make_optimizers_and_scheduler()
|
||||
return algorithm, policy
|
||||
@@ -136,13 +131,16 @@ def test_sac_algorithm_config_registered():
|
||||
|
||||
|
||||
def test_sac_algorithm_config_from_policy_config():
|
||||
"""from_policy_config should copy algorithm hyperparameters from the policy config."""
|
||||
sac_cfg = _make_sac_config(utd_ratio=4, policy_update_freq=2)
|
||||
"""from_policy_config embeds the policy config and uses SAC defaults."""
|
||||
sac_cfg = _make_sac_config()
|
||||
algo_cfg = SACAlgorithmConfig.from_policy_config(sac_cfg)
|
||||
assert algo_cfg.sac_config is sac_cfg
|
||||
assert algo_cfg.utd_ratio == 4
|
||||
assert algo_cfg.policy_update_freq == 2
|
||||
assert algo_cfg.grad_clip_norm == sac_cfg.grad_clip_norm
|
||||
assert algo_cfg.policy_config is sac_cfg
|
||||
assert algo_cfg.discrete_critic_network_kwargs is sac_cfg.discrete_critic_network_kwargs
|
||||
# Defaults come from SACAlgorithmConfig, not from the policy config.
|
||||
assert algo_cfg.utd_ratio == 1
|
||||
assert algo_cfg.policy_update_freq == 1
|
||||
assert algo_cfg.grad_clip_norm == 40.0
|
||||
assert algo_cfg.actor_lr == 3e-4
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
@@ -377,12 +375,14 @@ def test_actor_side_no_optimizers():
|
||||
assert algorithm.optimizers == {}
|
||||
|
||||
|
||||
def test_make_algorithm_copies_config_fields():
|
||||
sac_cfg = _make_sac_config(utd_ratio=5, policy_update_freq=3)
|
||||
def test_make_algorithm_uses_sac_algorithm_defaults():
|
||||
"""make_algorithm populates SACAlgorithmConfig with its own defaults."""
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
assert algorithm.config.utd_ratio == 5
|
||||
assert algorithm.config.policy_update_freq == 3
|
||||
assert algorithm.config.utd_ratio == 1
|
||||
assert algorithm.config.policy_update_freq == 1
|
||||
assert algorithm.config.grad_clip_norm == 40.0
|
||||
|
||||
|
||||
def test_make_algorithm_raises_for_unknown_type():
|
||||
@@ -431,10 +431,10 @@ def test_load_weights_round_trip_with_discrete_critic():
|
||||
|
||||
assert "discrete_critic" in weights
|
||||
assert len(weights["discrete_critic"]) > 0
|
||||
dst_dc_state_dict = algo_dst.discrete_critic.state_dict()
|
||||
dst_discrete_critic_state_dict = algo_dst.policy.discrete_critic.state_dict()
|
||||
for key, tensor in weights["discrete_critic"].items():
|
||||
assert torch.equal(
|
||||
dst_dc_state_dict[key].cpu(),
|
||||
dst_discrete_critic_state_dict[key].cpu(),
|
||||
tensor.cpu(),
|
||||
), f"Discrete critic param '{key}' mismatch after load_weights"
|
||||
|
||||
@@ -446,6 +446,47 @@ def test_load_weights_ignores_missing_discrete_critic():
|
||||
algorithm.load_weights(weights, device="cpu")
|
||||
|
||||
|
||||
def test_actor_side_weight_sync_with_discrete_critic():
    """Check that learner-exported weights land on a fresh actor-side policy.

    Flow under test: ``algorithm.get_weights()`` on the learner, then
    ``policy.load_actor_weights()`` on an actor that has no algorithm or
    optimizer attached. Both the actor network and the discrete critic must
    match the exported tensors, and the discrete critic must differ from its
    freshly initialised state so the comparison is not vacuous.
    """
    # --- Learner: run one update so the exported weights differ from init. ---
    learner_algo, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
    # NOTE(review): batch action_dim is 7 while the config uses 6 — presumably
    # 6 continuous dims plus 1 discrete dim; confirm against _batch_iterator.
    learner_algo.update(_batch_iterator(action_dim=7))
    weights = learner_algo.get_weights()

    # --- Actor: brand-new policy built from an equivalent config. ---
    actor_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6)
    actor_policy = GaussianActorPolicy(config=actor_cfg)

    # Clone the pre-sync discrete critic tensors; state_dict() values would
    # otherwise alias the live parameters and change along with them.
    pre_sync_discrete_critic = {}
    for name, value in actor_policy.discrete_critic.state_dict().items():
        pre_sync_discrete_critic[name] = value.clone()

    actor_policy.load_actor_weights(weights, device="cpu")

    # Every exported actor tensor must now be present verbatim on the policy.
    synced_actor = actor_policy.actor.state_dict()
    for key, tensor in weights["policy"].items():
        actor_matches = torch.equal(synced_actor[key].cpu(), tensor.cpu())
        assert actor_matches, f"Actor param '{key}' not synced by load_actor_weights"

    # Same requirement for the discrete critic.
    synced_discrete_critic = actor_policy.discrete_critic.state_dict()
    for key, tensor in weights["discrete_critic"].items():
        critic_matches = torch.equal(synced_discrete_critic[key].cpu(), tensor.cpu())
        assert critic_matches, f"Discrete critic param '{key}' not synced by load_actor_weights"

    # Sanity: at least one discrete critic tensor moved away from its initial
    # value, otherwise the equality checks above would pass trivially.
    changed = False
    for key, before in pre_sync_discrete_critic.items():
        if key not in synced_discrete_critic:
            continue
        if not torch.equal(before, synced_discrete_critic[key]):
            changed = True
            break
    assert changed, "Discrete critic weights did not change between init and after sync"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# TrainingStats generic losses dict
|
||||
# ===========================================================================
|
||||
@@ -468,8 +509,9 @@ def test_training_stats_generic_losses():
|
||||
|
||||
def test_build_algorithm_via_config():
|
||||
"""SACAlgorithmConfig.build_algorithm should produce a working SACAlgorithm."""
|
||||
sac_cfg = _make_sac_config(utd_ratio=2)
|
||||
sac_cfg = _make_sac_config()
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
|
||||
algo_config.utd_ratio = 2
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
|
||||
algorithm = algo_config.build_algorithm(policy)
|
||||
|
||||
Reference in New Issue
Block a user