diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_/actions.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_/actions.safetensors index 8bd63e894..771af2445 100644 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_/actions.safetensors +++ b/tests/artifacts/policies/aloha_sim_insertion_human_act_/actions.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:f3e4c8e85e146b043fd4e4984947c2a6f01627f174a19f18b5914cf690579d77 +oid sha256:ee0c29d3782aa1cadcf4dc6ed767d9460ff00fff9fc70b460502340b832eefcc size 5104 diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_/param_stats.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_/param_stats.safetensors index 724d22b58..3e8df708e 100644 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_/param_stats.safetensors +++ b/tests/artifacts/policies/aloha_sim_insertion_human_act_/param_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:9b5f557e30aead3731c38cbd85af8c706395d8689a918ad88805b5a886245603 -size 33400 +oid sha256:ea76e6711959fd3f905ec2bdc306f488920f00ec99421e4870d05f6205eb323e +size 31672 diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/actions.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/actions.safetensors index 6d912d81a..dd7d4d0e7 100644 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/actions.safetensors +++ b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/actions.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2e6625cabfeb4800abc80252cf9112a9271c154edd01eb291658f143c951610b +oid sha256:c2b8f8532c7a0b776de5e536b8b54e30b1a0c2e3d5cc25a2d86fe43e40ae5e8c size 515400 diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors 
b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors index cc6b4a24b..5da67a1af 100644 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors +++ b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:021562ee3e4814425e367ed0c144d6fbe2eb28838247085716cf0b58fd69a075 -size 33400 +oid sha256:eca0d87a699620e4fec7e68539b0be91e4cc933f6bf12032da52c182ab6f38cf +size 31672 diff --git a/tests/artifacts/policies/pusht_diffusion_/actions.safetensors b/tests/artifacts/policies/pusht_diffusion_/actions.safetensors index 84e14b975..ef581727d 100644 --- a/tests/artifacts/policies/pusht_diffusion_/actions.safetensors +++ b/tests/artifacts/policies/pusht_diffusion_/actions.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a32376dde65a1562403afd1db3e56c7e6b987ebaf6c3c601336e77155b9e608c +oid sha256:19eaaa85f66ba4aa6388dbb83819ffad6ea4363247208f871a8dc385689f6fc8 size 992 diff --git a/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors b/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors index 542297910..e00ed3238 100644 --- a/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors +++ b/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:12ee532c53173d0361ebb979f087b229cc045aa3d9e6b94cfd4290af54fd1201 +oid sha256:227296eaeeb54acdc3dae2eb8af3d4d08fb87e245337624447140b1e91cfd002 size 47424 diff --git a/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors b/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors index e91cd08b7..614cc754e 100644 --- a/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors +++ b/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors @@ -1,3 +1,3 @@ version 
https://git-lfs.github.com/spec/v1 -oid sha256:010c01181b95625051276d69cb4209423c21f2e30a3fa9464ae67064a2ba4c22 -size 49120 +oid sha256:778fddbbaa64248cee35cb377c02cc2b6076f7ce5855146de677128900617ddf +size 47424 diff --git a/tests/artifacts/policies/save_policy_to_safetensors.py b/tests/artifacts/policies/save_policy_to_safetensors.py index 6ccb47c3e..5a4606a8c 100644 --- a/tests/artifacts/policies/save_policy_to_safetensors.py +++ b/tests/artifacts/policies/save_policy_to_safetensors.py @@ -23,7 +23,8 @@ from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig from lerobot.datasets.factory import make_dataset from lerobot.optim.factory import make_optimizer_and_scheduler -from lerobot.policies.factory import make_policy, make_policy_config +from lerobot.policies.factory import make_policy, make_policy_config, make_processor +from lerobot.processor import TransitionKey from lerobot.utils.random_utils import set_seed @@ -37,7 +38,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): train_cfg.validate() # Needed for auto-setting some parameters dataset = make_dataset(train_cfg) + dataset_stats = dataset.meta.stats policy = make_policy(train_cfg.policy, ds_meta=dataset.meta) + preprocessor, postprocessor = make_processor(train_cfg.policy, dataset_stats=dataset_stats) policy.train() optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy) @@ -49,7 +52,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): ) batch = next(iter(dataloader)) + batch = preprocessor(batch) loss, output_dict = policy.forward(batch) + if output_dict is not None: output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)} output_dict["loss"] = loss @@ -96,7 +101,12 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): else: actions_queue = train_cfg.policy.n_action_repeats - actions = {str(i): policy.select_action(obs).contiguous() 
for i in range(actions_queue)} + actions = {} + for i in range(actions_queue): + unnormalized_action = policy.select_action(obs).contiguous() + action_robot = postprocessor({TransitionKey.ACTION: unnormalized_action}).get(TransitionKey.ACTION) + actions[str(i)] = action_robot + return output_dict, grad_stats, param_stats, actions diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/actions.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/actions.safetensors index fa9bf06ab..e23eacffd 100644 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/actions.safetensors +++ b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/actions.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c5edc5600d7206f027cb696a597bc99fcdd9073a15fa130b8031c52c0a7c134b +oid sha256:d640988f2269cf6aa03c8ee17f9d096edace83d837f90025011fafec5bf53c61 size 200 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/grad_stats.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/grad_stats.safetensors index 8d90a671f..e665f73c6 100644 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/grad_stats.safetensors +++ b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/grad_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10 +oid sha256:32ddf36af25791935b395c7641531cda14d5c4a2cf654a2e76ac45271665d07a size 16904 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/output_dict.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/output_dict.safetensors index cde6c6dca..97d783580 100644 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/output_dict.safetensors +++ b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/output_dict.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid 
sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b +oid sha256:22a1031a2acfc36a455bff73ffbe097cfeb7742b6485e7422507e78d7a682703 size 164 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/param_stats.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/param_stats.safetensors index 692377d1f..3090b7051 100644 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/param_stats.safetensors +++ b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/param_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170 -size 36312 +oid sha256:b5dca7940998421ae58e9e26b2b2641b058d23b0270b7a147ebf85fbbdce7184 +size 35496 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/actions.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/actions.safetensors index 7a0b165e2..5ce44048f 100644 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/actions.safetensors +++ b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/actions.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a9c08753ddc43b6c02a176418b81eb784146e59f4fc914591cbd3582ade392bb +oid sha256:2212ae7b910d14d723214f5af50985e419f7bd0f4261565ef48b1ef495443d6d size 200 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/grad_stats.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/grad_stats.safetensors index 8d90a671f..e665f73c6 100644 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/grad_stats.safetensors +++ b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/grad_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10 +oid sha256:32ddf36af25791935b395c7641531cda14d5c4a2cf654a2e76ac45271665d07a size 16904 diff --git 
a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/output_dict.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/output_dict.safetensors index cde6c6dca..97d783580 100644 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/output_dict.safetensors +++ b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/output_dict.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b +oid sha256:22a1031a2acfc36a455bff73ffbe097cfeb7742b6485e7422507e78d7a682703 size 164 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/param_stats.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/param_stats.safetensors index 692377d1f..3090b7051 100644 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/param_stats.safetensors +++ b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/param_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170 -size 36312 +oid sha256:b5dca7940998421ae58e9e26b2b2641b058d23b0270b7a147ebf85fbbdce7184 +size 35496 diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index 79249f672..a135b344f 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -26,7 +26,7 @@ from safetensors.torch import load_file from lerobot import available_policies from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.constants import ACTION, OBS_STATE from lerobot.datasets.factory import make_dataset from lerobot.datasets.utils import cycle, dataset_to_policy_features @@ -41,7 +41,6 @@ from lerobot.policies.factory import ( 
make_policy_config, make_processor, ) -from lerobot.policies.normalize import Normalize, Unnormalize from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.utils.random_utils import seeded_context from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats @@ -266,108 +265,6 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name: torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0) -@pytest.mark.parametrize("insert_temporal_dim", [False, True]) -def test_normalize(insert_temporal_dim): - """ - Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise - an exception when the forward pass is called without the stats having been provided. - - TODO(rcadene, alexander-soare): This should also test that the normalization / unnormalization works as - expected. - """ - - input_features = { - "observation.image": PolicyFeature( - type=FeatureType.VISUAL, - shape=(3, 96, 96), - ), - "observation.state": PolicyFeature( - type=FeatureType.STATE, - shape=(10,), - ), - } - output_features = { - "action": PolicyFeature( - type=FeatureType.ACTION, - shape=(5,), - ), - } - - norm_map = { - "VISUAL": NormalizationMode.MEAN_STD, - "STATE": NormalizationMode.MIN_MAX, - "ACTION": NormalizationMode.MIN_MAX, - } - - dataset_stats = { - "observation.image": { - "mean": torch.randn(3, 1, 1), - "std": torch.randn(3, 1, 1), - "min": torch.randn(3, 1, 1), - "max": torch.randn(3, 1, 1), - }, - "observation.state": { - "mean": torch.randn(10), - "std": torch.randn(10), - "min": torch.randn(10), - "max": torch.randn(10), - }, - "action": { - "mean": torch.randn(5), - "std": torch.randn(5), - "min": torch.randn(5), - "max": torch.randn(5), - }, - } - - bsize = 2 - input_batch = { - "observation.image": torch.randn(bsize, 3, 96, 96), - "observation.state": torch.randn(bsize, 10), - } - output_batch = { - "action": torch.randn(bsize, 5), - } - 
- if insert_temporal_dim: - tdim = 4 - - for key in input_batch: - # [2,3,96,96] -> [2,tdim,3,96,96] - input_batch[key] = torch.stack([input_batch[key]] * tdim, dim=1) - - for key in output_batch: - output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1) - - # test without stats - normalize = Normalize(input_features, norm_map, stats=None) - with pytest.raises(AssertionError): - normalize(input_batch) - - # test with stats - normalize = Normalize(input_features, norm_map, stats=dataset_stats) - normalize(input_batch) - - # test loading pretrained models - new_normalize = Normalize(input_features, norm_map, stats=None) - new_normalize.load_state_dict(normalize.state_dict()) - new_normalize(input_batch) - - # test without stats - unnormalize = Unnormalize(output_features, norm_map, stats=None) - with pytest.raises(AssertionError): - unnormalize(output_batch) - - # test with stats - unnormalize = Unnormalize(output_features, norm_map, stats=dataset_stats) - unnormalize(output_batch) - - # test loading pretrained models - new_unnormalize = Unnormalize(output_features, norm_map, stats=None) - new_unnormalize.load_state_dict(unnormalize.state_dict()) - unnormalize(output_batch) - - @pytest.mark.parametrize("multikey", [True, False]) def test_multikey_construction(multikey: bool): """ @@ -467,6 +364,8 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs NOTE: If the test does not pass, and you don't change the policy, it is likely that the test artifact is out of date. For example, some PyTorch versions have different randomness, see this PR: https://github.com/huggingface/lerobot/pull/1127. + NOTE: If the test does not pass and you changed neither the policy nor the dependency + versions, but you did change your processor, you might have to update the test artifact. """