# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch

from lerobot.optim.optimizers import (
    AdamConfig,
    AdamWConfig,
    MultiAdamConfig,
    SGDConfig,
    load_optimizer_state,
    load_optimizer_state_dict,
    save_optimizer_state,
)
from lerobot.utils.constants import (
    OPTIMIZER_PARAM_GROUPS,
    OPTIMIZER_STATE,
)


def _actor_critic_temperature_config():
    """Return a fresh MultiAdamConfig with one learning rate per actor/critic/temperature group."""
    return MultiAdamConfig(
        lr=1e-3,
        optimizer_groups={
            "actor": {"lr": 1e-4},
            "critic": {"lr": 5e-4},
            "temperature": {"lr": 2e-3},
        },
    )


@pytest.mark.parametrize(
    "config_cls, expected_class",
    [
        (AdamConfig, torch.optim.Adam),
        (AdamWConfig, torch.optim.AdamW),
        (SGDConfig, torch.optim.SGD),
        (MultiAdamConfig, dict),
    ],
)
def test_optimizer_build(config_cls, expected_class, model_params):
    """Each config builds an object of the expected class whose lr default matches the config."""
    cfg = config_cls()

    if config_cls is not MultiAdamConfig:
        built = cfg.build(model_params)
        assert isinstance(built, expected_class)
        assert built.defaults["lr"] == cfg.lr
        return

    # MultiAdamConfig takes a name -> params mapping and returns a mapping of optimizers.
    built = cfg.build({"default": model_params})
    assert isinstance(built, expected_class)
    default_optimizer = built["default"]
    assert isinstance(default_optimizer, torch.optim.Adam)
    assert default_optimizer.defaults["lr"] == cfg.lr


def test_save_optimizer_state(optimizer, tmp_path):
    """Saving an optimizer writes both the state file and the param-groups file."""
    save_optimizer_state(optimizer, tmp_path)
    for filename in (OPTIMIZER_STATE, OPTIMIZER_PARAM_GROUPS):
        assert (tmp_path / filename).is_file()


def test_save_and_load_optimizer_state(model_params, optimizer, tmp_path):
    """A freshly built Adam optimizer loaded from disk reports the same state dict as the saved one."""
    save_optimizer_state(optimizer, tmp_path)

    fresh_optimizer = AdamConfig().build(model_params)
    restored_optimizer = load_optimizer_state(fresh_optimizer, tmp_path)

    torch.testing.assert_close(optimizer.state_dict(), restored_optimizer.state_dict())


def test_save_and_load_fsdp_optimizer_state_dict_roundtrip(tmp_path):
    """Round-trip an FSDP-style full optimizer state dict through save and load.

    The FSDP full optimizer state dict is keyed by parameter FQNs (dotted strings) rather than the
    integer indices of the single-GPU path. This checks that such a dict survives the safetensors
    save -> read round-trip used by the FSDP save/resume path
    (save_optimizer_state(optim_state_dict=...) followed by load_optimizer_state_dict), and that the
    flatten/unflatten "/" separator does not corrupt it.
    """
    full_osd = {
        "state": {
            "model.layers.0.weight": {
                "step": torch.tensor(3.0),
                "exp_avg": torch.randn(4, 4),
                "exp_avg_sq": torch.randn(4, 4),
            },
            "model.layers.0.bias": {
                "step": torch.tensor(3.0),
                "exp_avg": torch.randn(4),
                "exp_avg_sq": torch.randn(4),
            },
        },
        "param_groups": [
            {"lr": 1e-4, "betas": [0.9, 0.999], "eps": 1e-8, "weight_decay": 0.0, "params": [0, 1]}
        ],
    }

    # An optimizer instance is still passed positionally; the state dict under test is supplied
    # through ``optim_state_dict``.
    placeholder_optimizer = torch.optim.Adam([torch.nn.Parameter(torch.randn(1))])
    save_optimizer_state(placeholder_optimizer, tmp_path, optim_state_dict=full_osd)

    for filename in (OPTIMIZER_STATE, OPTIMIZER_PARAM_GROUPS):
        assert (tmp_path / filename).is_file()

    loaded = load_optimizer_state_dict(tmp_path)

    # FQN keys must be preserved verbatim (not int-cast, not split on their dots).
    assert set(loaded["state"].keys()) == set(full_osd["state"].keys())
    for fqn, expected_entry in full_osd["state"].items():
        restored_entry = loaded["state"][fqn]
        for key, expected_tensor in expected_entry.items():
            torch.testing.assert_close(restored_entry[key], expected_tensor)
    assert loaded["param_groups"] == full_osd["param_groups"]


@pytest.fixture
def base_params_dict():
    """Provide one random parameter per named group: actor, critic and temperature."""
    group_shapes = {
        "actor": (10, 10),
        "critic": (5, 5),
        "temperature": (3, 3),
    }
    return {name: [torch.nn.Parameter(torch.randn(*shape))] for name, shape in group_shapes.items()}


@pytest.mark.parametrize(
    "config_params, expected_values",
    [
        # Test 1: Basic configuration with different learning rates
        (
            {
                "lr": 1e-3,
                "weight_decay": 1e-4,
                "optimizer_groups": {
                    "actor": {"lr": 1e-4},
                    "critic": {"lr": 5e-4},
                    "temperature": {"lr": 2e-3},
                },
            },
            {
                "actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
                "critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
                "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
            },
        ),
        # Test 2: Different weight decays and beta values
        (
            {
                "lr": 1e-3,
                "weight_decay": 1e-4,
                "optimizer_groups": {
                    "actor": {"lr": 1e-4, "weight_decay": 1e-5},
                    "critic": {"lr": 5e-4, "weight_decay": 1e-6},
                    "temperature": {"lr": 2e-3, "betas": (0.95, 0.999)},
                },
            },
            {
                "actor": {"lr": 1e-4, "weight_decay": 1e-5, "betas": (0.9, 0.999)},
                "critic": {"lr": 5e-4, "weight_decay": 1e-6, "betas": (0.9, 0.999)},
                "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.95, 0.999)},
            },
        ),
        # Test 3: Epsilon parameter customization
        (
            {
                "lr": 1e-3,
                "weight_decay": 1e-4,
                "optimizer_groups": {
                    "actor": {"lr": 1e-4, "eps": 1e-6},
                    "critic": {"lr": 5e-4, "eps": 1e-7},
                    "temperature": {"lr": 2e-3, "eps": 1e-8},
                },
            },
            {
                "actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-6},
                "critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-7},
                "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-8},
            },
        ),
    ],
)
def test_multi_adam_configuration(base_params_dict, config_params, expected_values):
    """MultiAdamConfig builds one Adam optimizer per group carrying the expected hyperparameters."""
    optimizers = MultiAdamConfig(**config_params).build(base_params_dict)

    # Exactly the expected groups, no more and no fewer.
    assert len(optimizers) == len(expected_values)
    assert set(optimizers.keys()) == set(expected_values.keys())

    # Every group is backed by an Adam instance.
    assert all(isinstance(opt, torch.optim.Adam) for opt in optimizers.values())

    # Each expected hyperparameter must appear in the matching optimizer's defaults.
    for group_name, expected_hparams in expected_values.items():
        defaults = optimizers[group_name].defaults
        for hparam, expected in expected_hparams.items():
            assert defaults[hparam] == expected


@pytest.fixture
def multi_optimizers(base_params_dict):
    """Build the actor/critic/temperature optimizers over ``base_params_dict``."""
    return _actor_critic_temperature_config().build(base_params_dict)


def test_save_multi_optimizer_state(multi_optimizers, tmp_path):
    """Saving a dict of optimizers creates one sub-directory per optimizer holding both files."""
    save_optimizer_state(multi_optimizers, tmp_path)

    for name in multi_optimizers:
        optimizer_dir = tmp_path / name
        assert optimizer_dir.is_dir()
        assert (optimizer_dir / OPTIMIZER_STATE).is_file()
        assert (optimizer_dir / OPTIMIZER_PARAM_GROUPS).is_file()


def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers, tmp_path):
    """Optimizers with populated state round-trip through save and load unchanged."""
    # Take one optimization step per group so the saved state is not empty.
    for name, params in base_params_dict.items():
        if name not in multi_optimizers:
            continue
        params[0].sum().backward()
        group_optimizer = multi_optimizers[name]
        group_optimizer.step()
        group_optimizer.zero_grad()

    save_optimizer_state(multi_optimizers, tmp_path)

    # Rebuild optimizers from an identical config, then restore their state from disk.
    fresh_optimizers = _actor_critic_temperature_config().build(base_params_dict)
    restored_optimizers = load_optimizer_state(fresh_optimizers, tmp_path)

    for name in multi_optimizers:
        torch.testing.assert_close(multi_optimizers[name].state_dict(), restored_optimizers[name].state_dict())


def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
    """Test saving and loading optimizer states even when the state is empty (no backward pass)."""
    config = _actor_critic_temperature_config()
    optimizers = config.build(base_params_dict)

    # No backward pass or step has been taken before saving.
    save_optimizer_state(optimizers, tmp_path)

    # Build a second set from the same config and restore into it.
    restored_optimizers = load_optimizer_state(config.build(base_params_dict), tmp_path)

    for name, original in optimizers.items():
        restored = restored_optimizers[name]

        # Hyperparameter defaults must agree between the original and restored optimizers.
        for hparam in ("lr", "weight_decay", "betas"):
            assert original.defaults[hparam] == restored.defaults[hparam]

        # The param_groups section of the state dict must match as well.
        torch.testing.assert_close(
            original.state_dict()["param_groups"], restored.state_dict()["param_groups"]
        )