mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
refactor: RL stack refactoring — RLAlgorithm, RLTrainer, DataMixer, and SAC restructuring
This commit is contained in:
@@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
@@ -38,9 +37,6 @@ def test_classifier_output():
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@pytest.mark.skip(
|
||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||
)
|
||||
def test_binary_classifier_with_default_params():
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
@@ -82,9 +78,6 @@ def test_binary_classifier_with_default_params():
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@pytest.mark.skip(
|
||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||
)
|
||||
def test_multiclass_classifier():
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
@@ -124,9 +117,6 @@ def test_multiclass_classifier():
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@pytest.mark.skip(
|
||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||
)
|
||||
def test_default_device():
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
@@ -139,9 +129,6 @@ def test_default_device():
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@pytest.mark.skip(
|
||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||
)
|
||||
def test_explicit_device_setup():
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
|
||||
+187
-209
@@ -14,8 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
@@ -23,6 +21,7 @@ from torch import Tensor, nn
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import MLP, SACPolicy
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.random_utils import seeded_context, set_seed
|
||||
|
||||
@@ -138,41 +137,6 @@ def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: i
|
||||
}
|
||||
|
||||
|
||||
def make_optimizers(policy: SACPolicy, has_discrete_action: bool = False) -> dict[str, torch.optim.Optimizer]:
|
||||
"""Create optimizers for the SAC policy."""
|
||||
optimizer_actor = torch.optim.Adam(
|
||||
# Handle the case of shared encoder where the encoder weights are not optimized with the actor gradient
|
||||
params=[
|
||||
p
|
||||
for n, p in policy.actor.named_parameters()
|
||||
if not policy.config.shared_encoder or not n.startswith("encoder")
|
||||
],
|
||||
lr=policy.config.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(
|
||||
params=policy.critic_ensemble.parameters(),
|
||||
lr=policy.config.critic_lr,
|
||||
)
|
||||
optimizer_temperature = torch.optim.Adam(
|
||||
params=[policy.log_alpha],
|
||||
lr=policy.config.critic_lr,
|
||||
)
|
||||
|
||||
optimizers = {
|
||||
"actor": optimizer_actor,
|
||||
"critic": optimizer_critic,
|
||||
"temperature": optimizer_temperature,
|
||||
}
|
||||
|
||||
if has_discrete_action:
|
||||
optimizers["discrete_critic"] = torch.optim.Adam(
|
||||
params=policy.discrete_critic.parameters(),
|
||||
lr=policy.config.critic_lr,
|
||||
)
|
||||
|
||||
return optimizers
|
||||
|
||||
|
||||
def create_default_config(
|
||||
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
|
||||
) -> SACConfig:
|
||||
@@ -212,7 +176,6 @@ def create_config_with_visual_input(
|
||||
"std": torch.randn(3, 1, 1),
|
||||
}
|
||||
|
||||
# Let make tests a little bit faster
|
||||
config.state_encoder_hidden_dim = 32
|
||||
config.latent_dim = 32
|
||||
|
||||
@@ -220,75 +183,112 @@ def create_config_with_visual_input(
|
||||
return config
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_policy_with_default_config(batch_size: int, state_dim: int, action_dim: int):
|
||||
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
|
||||
def _make_algorithm(config: SACConfig) -> tuple[SACAlgorithm, SACPolicy]:
|
||||
"""Helper to create policy + algorithm pair for tests that need critics."""
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
algorithm.make_optimizers_and_scheduler()
|
||||
return algorithm, policy
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
assert temperature_loss.item() is not None
|
||||
assert temperature_loss.shape == ()
|
||||
|
||||
temperature_loss.backward()
|
||||
optimizers["temperature"].step()
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_policy_select_action(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
policy = SACPolicy(config=config)
|
||||
policy.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape == (batch_size, action_dim)
|
||||
# squeeze(0) removes batch dim when batch_size==1
|
||||
assert selected_action.shape[-1] == action_dim
|
||||
|
||||
|
||||
def test_sac_policy_select_action_with_discrete():
|
||||
"""select_action should return continuous + discrete actions."""
|
||||
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
||||
config.num_discrete_actions = 3
|
||||
policy = SACPolicy(config=config)
|
||||
policy.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch(batch_size=1, state_dim=10)
|
||||
# Squeeze to unbatched (single observation)
|
||||
observation_batch = {k: v.squeeze(0) for k, v in observation_batch.items()}
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape[-1] == 7 # 6 continuous + 1 discrete
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
def test_sac_policy_forward(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
policy = SACPolicy(config=config)
|
||||
policy.eval()
|
||||
|
||||
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
|
||||
with torch.no_grad():
|
||||
output = policy.forward(batch)
|
||||
assert "action" in output
|
||||
assert "log_prob" in output
|
||||
assert "action_mean" in output
|
||||
assert output["action"].shape == (batch_size, action_dim)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_training_through_algorithm(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
|
||||
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
|
||||
forward_batch = algorithm._prepare_forward_batch(batch)
|
||||
|
||||
critic_loss = algorithm._compute_loss_critic(forward_batch)
|
||||
assert critic_loss.item() is not None
|
||||
assert critic_loss.shape == ()
|
||||
algorithm.optimizers["critic"].zero_grad()
|
||||
critic_loss.backward()
|
||||
algorithm.optimizers["critic"].step()
|
||||
|
||||
actor_loss = algorithm._compute_loss_actor(forward_batch)
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
algorithm.optimizers["actor"].zero_grad()
|
||||
actor_loss.backward()
|
||||
algorithm.optimizers["actor"].step()
|
||||
|
||||
temp_loss = algorithm._compute_loss_temperature(forward_batch)
|
||||
assert temp_loss.item() is not None
|
||||
assert temp_loss.shape == ()
|
||||
algorithm.optimizers["temperature"].zero_grad()
|
||||
temp_loss.backward()
|
||||
algorithm.optimizers["temperature"].step()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_training_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
forward_batch = algorithm._prepare_forward_batch(batch)
|
||||
|
||||
policy.train()
|
||||
critic_loss = algorithm._compute_loss_critic(forward_batch)
|
||||
assert critic_loss.item() is not None
|
||||
assert critic_loss.shape == ()
|
||||
algorithm.optimizers["critic"].zero_grad()
|
||||
critic_loss.backward()
|
||||
algorithm.optimizers["critic"].step()
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
actor_loss = algorithm._compute_loss_actor(forward_batch)
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
algorithm.optimizers["actor"].zero_grad()
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
assert temperature_loss.item() is not None
|
||||
assert temperature_loss.shape == ()
|
||||
|
||||
temperature_loss.backward()
|
||||
optimizers["temperature"].step()
|
||||
algorithm.optimizers["actor"].step()
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
@@ -296,210 +296,181 @@ def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_di
|
||||
batch_size=batch_size, state_dim=state_dim
|
||||
)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape == (batch_size, action_dim)
|
||||
assert selected_action.shape[-1] == action_dim
|
||||
|
||||
|
||||
# Let's check best candidates for pretrained encoders
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size,state_dim,action_dim,vision_encoder_name",
|
||||
[(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
|
||||
)
|
||||
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
|
||||
@pytest.mark.skip(
|
||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||
)
|
||||
def test_sac_policy_with_pretrained_encoder(
|
||||
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
|
||||
):
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.vision_encoder_name = vision_encoder_name
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
forward_batch = algorithm._prepare_forward_batch(batch)
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
critic_loss = algorithm._compute_loss_critic(forward_batch)
|
||||
assert critic_loss.item() is not None
|
||||
assert critic_loss.shape == ()
|
||||
algorithm.optimizers["critic"].zero_grad()
|
||||
critic_loss.backward()
|
||||
algorithm.optimizers["critic"].step()
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
actor_loss = algorithm._compute_loss_actor(forward_batch)
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
|
||||
def test_sac_policy_with_shared_encoder():
|
||||
def test_sac_training_with_shared_encoder():
|
||||
batch_size = 2
|
||||
action_dim = 10
|
||||
state_dim = 10
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.shared_encoder = True
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
forward_batch = algorithm._prepare_forward_batch(batch)
|
||||
|
||||
policy.train()
|
||||
critic_loss = algorithm._compute_loss_critic(forward_batch)
|
||||
assert critic_loss.shape == ()
|
||||
algorithm.optimizers["critic"].zero_grad()
|
||||
critic_loss.backward()
|
||||
algorithm.optimizers["critic"].step()
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
actor_loss = algorithm._compute_loss_actor(forward_batch)
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
algorithm.optimizers["actor"].zero_grad()
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
algorithm.optimizers["actor"].step()
|
||||
|
||||
|
||||
def test_sac_policy_with_discrete_critic():
|
||||
def test_sac_training_with_discrete_critic():
|
||||
batch_size = 2
|
||||
continuous_action_dim = 9
|
||||
full_action_dim = continuous_action_dim + 1 # the last action is discrete
|
||||
full_action_dim = continuous_action_dim + 1
|
||||
state_dim = 10
|
||||
config = create_config_with_visual_input(
|
||||
state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True
|
||||
)
|
||||
config.num_discrete_actions = 5
|
||||
|
||||
num_discrete_actions = 5
|
||||
config.num_discrete_actions = num_discrete_actions
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim
|
||||
)
|
||||
forward_batch = algorithm._prepare_forward_batch(batch)
|
||||
|
||||
policy.train()
|
||||
critic_loss = algorithm._compute_loss_critic(forward_batch)
|
||||
assert critic_loss.shape == ()
|
||||
algorithm.optimizers["critic"].zero_grad()
|
||||
critic_loss.backward()
|
||||
algorithm.optimizers["critic"].step()
|
||||
|
||||
optimizers = make_optimizers(policy, has_discrete_action=True)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
discrete_critic_loss = policy.forward(batch, model="discrete_critic")["loss_discrete_critic"]
|
||||
assert discrete_critic_loss.item() is not None
|
||||
discrete_critic_loss = algorithm._compute_loss_discrete_critic(forward_batch)
|
||||
assert discrete_critic_loss.shape == ()
|
||||
algorithm.optimizers["discrete_critic"].zero_grad()
|
||||
discrete_critic_loss.backward()
|
||||
optimizers["discrete_critic"].step()
|
||||
algorithm.optimizers["discrete_critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
actor_loss = algorithm._compute_loss_actor(forward_batch)
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
algorithm.optimizers["actor"].zero_grad()
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
algorithm.optimizers["actor"].step()
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim
|
||||
)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape == (batch_size, full_action_dim)
|
||||
|
||||
discrete_actions = selected_action[:, -1].long()
|
||||
discrete_action_values = set(discrete_actions.tolist())
|
||||
|
||||
assert all(action in range(num_discrete_actions) for action in discrete_action_values), (
|
||||
f"Discrete action {discrete_action_values} is not in range({num_discrete_actions})"
|
||||
)
|
||||
# Policy.select_action now handles both continuous + discrete
|
||||
selected_action = policy.select_action({k: v.squeeze(0) for k, v in observation_batch.items()})
|
||||
assert selected_action.shape[-1] == continuous_action_dim + 1
|
||||
|
||||
|
||||
def test_sac_policy_with_default_entropy():
|
||||
def test_sac_algorithm_target_entropy():
|
||||
config = create_default_config(continuous_action_dim=10, state_dim=10)
|
||||
policy = SACPolicy(config=config)
|
||||
assert policy.target_entropy == -5.0
|
||||
_, policy = _make_algorithm(config)
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
assert algorithm.target_entropy == -5.0
|
||||
|
||||
|
||||
def test_sac_policy_default_target_entropy_with_discrete_action():
|
||||
def test_sac_algorithm_target_entropy_with_discrete_action():
|
||||
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
|
||||
config.num_discrete_actions = 5
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
policy = SACPolicy(config=config)
|
||||
assert policy.target_entropy == -3.0
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
assert algorithm.target_entropy == -3.5
|
||||
|
||||
|
||||
def test_sac_policy_with_predefined_entropy():
|
||||
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
||||
config.target_entropy = -3.5
|
||||
def test_sac_algorithm_temperature():
|
||||
import math
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
assert policy.target_entropy == pytest.approx(-3.5)
|
||||
|
||||
|
||||
def test_sac_policy_update_temperature():
|
||||
"""Test that temperature property is always in sync with log_alpha."""
|
||||
config = create_default_config(continuous_action_dim=10, state_dim=10)
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
policy = SACPolicy(config=config)
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
|
||||
assert policy.temperature == pytest.approx(1.0)
|
||||
policy.log_alpha.data = torch.tensor([math.log(0.1)])
|
||||
# Temperature property automatically reflects log_alpha changes
|
||||
assert policy.temperature == pytest.approx(0.1)
|
||||
assert algorithm.temperature == pytest.approx(1.0)
|
||||
algorithm.log_alpha.data = torch.tensor([math.log(0.1)])
|
||||
assert algorithm.temperature == pytest.approx(0.1)
|
||||
|
||||
|
||||
def test_sac_policy_update_target_network():
|
||||
def test_sac_algorithm_update_target_network():
|
||||
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
||||
config.critic_target_update_weight = 1.0
|
||||
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
|
||||
for p in policy.critic_ensemble.parameters():
|
||||
for p in algorithm.critic_ensemble.parameters():
|
||||
p.data = torch.ones_like(p.data)
|
||||
|
||||
policy.update_target_networks()
|
||||
for p in policy.critic_target.parameters():
|
||||
assert torch.allclose(p.data, torch.ones_like(p.data)), (
|
||||
f"Target network {p.data} is not equal to {torch.ones_like(p.data)}"
|
||||
)
|
||||
algorithm._update_target_networks()
|
||||
for p in algorithm.critic_target.parameters():
|
||||
assert torch.allclose(p.data, torch.ones_like(p.data))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_critics", [1, 3])
|
||||
def test_sac_policy_with_critics_number_of_heads(num_critics: int):
|
||||
def test_sac_algorithm_with_critics_number_of_heads(num_critics: int):
|
||||
batch_size = 2
|
||||
action_dim = 10
|
||||
state_dim = 10
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.num_critics = num_critics
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
|
||||
assert len(policy.critic_ensemble.critics) == num_critics
|
||||
assert len(algorithm.critic_ensemble.critics) == num_critics
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
forward_batch = algorithm._prepare_forward_batch(batch)
|
||||
|
||||
policy.train()
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
critic_loss = algorithm._compute_loss_critic(forward_batch)
|
||||
assert critic_loss.shape == ()
|
||||
algorithm.optimizers["critic"].zero_grad()
|
||||
critic_loss.backward()
|
||||
algorithm.optimizers["critic"].step()
|
||||
|
||||
|
||||
def test_sac_policy_save_and_load(tmp_path):
|
||||
"""Test that the policy can be saved and loaded from pretrained."""
|
||||
root = tmp_path / "test_sac_save_and_load"
|
||||
|
||||
state_dim = 10
|
||||
@@ -513,34 +484,41 @@ def test_sac_policy_save_and_load(tmp_path):
|
||||
loaded_policy = SACPolicy.from_pretrained(root, config=config)
|
||||
loaded_policy.eval()
|
||||
|
||||
batch = create_default_train_batch(batch_size=1, state_dim=10, action_dim=10)
|
||||
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
|
||||
for k in policy.state_dict():
|
||||
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
|
||||
|
||||
with torch.no_grad():
|
||||
with seeded_context(12):
|
||||
# Collect policy values before saving
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
actions = policy.select_action(observation_batch)
|
||||
|
||||
with seeded_context(12):
|
||||
# Collect policy values after loading
|
||||
loaded_cirtic_loss = loaded_policy.forward(batch, model="critic")["loss_critic"]
|
||||
loaded_actor_loss = loaded_policy.forward(batch, model="actor")["loss_actor"]
|
||||
loaded_temperature_loss = loaded_policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
|
||||
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
loaded_actions = loaded_policy.select_action(loaded_observation_batch)
|
||||
|
||||
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
|
||||
for k in policy.state_dict():
|
||||
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
|
||||
|
||||
# Compare values before and after saving and loading
|
||||
# They should be the same
|
||||
assert torch.allclose(cirtic_loss, loaded_cirtic_loss)
|
||||
assert torch.allclose(actor_loss, loaded_actor_loss)
|
||||
assert torch.allclose(temperature_loss, loaded_temperature_loss)
|
||||
assert torch.allclose(actions, loaded_actions)
|
||||
|
||||
|
||||
def test_sac_policy_save_and_load_with_discrete_critic(tmp_path):
|
||||
"""Discrete critic should be saved/loaded as part of the policy."""
|
||||
root = tmp_path / "test_sac_save_and_load_discrete"
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 6
|
||||
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.num_discrete_actions = 3
|
||||
policy = SACPolicy(config=config)
|
||||
policy.eval()
|
||||
policy.save_pretrained(root)
|
||||
|
||||
loaded_policy = SACPolicy.from_pretrained(root, config=config)
|
||||
loaded_policy.eval()
|
||||
|
||||
assert loaded_policy.discrete_critic is not None
|
||||
dc_keys = [k for k in loaded_policy.state_dict() if k.startswith("discrete_critic.")]
|
||||
assert len(dc_keys) > 0
|
||||
|
||||
for k in policy.state_dict():
|
||||
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
|
||||
|
||||
Reference in New Issue
Block a user