feat(policies): convert save_policy_to_safetensors with pipeline

This commit is contained in:
AdilZouitine
2025-08-08 13:21:50 +02:00
parent 5f759b1637
commit e4fd30a8d4
17 changed files with 35 additions and 126 deletions
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f3e4c8e85e146b043fd4e4984947c2a6f01627f174a19f18b5914cf690579d77
oid sha256:ee0c29d3782aa1cadcf4dc6ed767d9460ff00fff9fc70b460502340b832eefcc
size 5104
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9b5f557e30aead3731c38cbd85af8c706395d8689a918ad88805b5a886245603
size 33400
oid sha256:ea76e6711959fd3f905ec2bdc306f488920f00ec99421e4870d05f6205eb323e
size 31672
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2e6625cabfeb4800abc80252cf9112a9271c154edd01eb291658f143c951610b
oid sha256:c2b8f8532c7a0b776de5e536b8b54e30b1a0c2e3d5cc25a2d86fe43e40ae5e8c
size 515400
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:021562ee3e4814425e367ed0c144d6fbe2eb28838247085716cf0b58fd69a075
size 33400
oid sha256:eca0d87a699620e4fec7e68539b0be91e4cc933f6bf12032da52c182ab6f38cf
size 31672
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a32376dde65a1562403afd1db3e56c7e6b987ebaf6c3c601336e77155b9e608c
oid sha256:19eaaa85f66ba4aa6388dbb83819ffad6ea4363247208f871a8dc385689f6fc8
size 992
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:12ee532c53173d0361ebb979f087b229cc045aa3d9e6b94cfd4290af54fd1201
oid sha256:227296eaeeb54acdc3dae2eb8af3d4d08fb87e245337624447140b1e91cfd002
size 47424
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:010c01181b95625051276d69cb4209423c21f2e30a3fa9464ae67064a2ba4c22
size 49120
oid sha256:778fddbbaa64248cee35cb377c02cc2b6076f7ce5855146de677128900617ddf
size 47424
@@ -23,7 +23,8 @@ from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy, make_policy_config
from lerobot.policies.factory import make_policy, make_policy_config, make_processor
from lerobot.processor import TransitionKey
from lerobot.utils.random_utils import set_seed
@@ -37,7 +38,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
train_cfg.validate() # Needed for auto-setting some parameters
dataset = make_dataset(train_cfg)
dataset_stats = dataset.meta.stats
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
preprocessor, postprocessor = make_processor(train_cfg.policy, dataset_stats=dataset_stats)
policy.train()
optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy)
@@ -49,7 +52,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
)
batch = next(iter(dataloader))
batch = preprocessor(batch)
loss, output_dict = policy.forward(batch)
if output_dict is not None:
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
output_dict["loss"] = loss
@@ -96,7 +101,12 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
else:
actions_queue = train_cfg.policy.n_action_repeats
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
actions = {}
for i in range(actions_queue):
unnormalized_action = policy.select_action(obs).contiguous()
action_robot = postprocessor({TransitionKey.ACTION: unnormalized_action}).get(TransitionKey.ACTION)
actions[str(i)] = action_robot
return output_dict, grad_stats, param_stats, actions
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c5edc5600d7206f027cb696a597bc99fcdd9073a15fa130b8031c52c0a7c134b
oid sha256:d640988f2269cf6aa03c8ee17f9d096edace83d837f90025011fafec5bf53c61
size 200
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10
oid sha256:32ddf36af25791935b395c7641531cda14d5c4a2cf654a2e76ac45271665d07a
size 16904
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b
oid sha256:22a1031a2acfc36a455bff73ffbe097cfeb7742b6485e7422507e78d7a682703
size 164
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170
size 36312
oid sha256:b5dca7940998421ae58e9e26b2b2641b058d23b0270b7a147ebf85fbbdce7184
size 35496
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a9c08753ddc43b6c02a176418b81eb784146e59f4fc914591cbd3582ade392bb
oid sha256:2212ae7b910d14d723214f5af50985e419f7bd0f4261565ef48b1ef495443d6d
size 200
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10
oid sha256:32ddf36af25791935b395c7641531cda14d5c4a2cf654a2e76ac45271665d07a
size 16904
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b
oid sha256:22a1031a2acfc36a455bff73ffbe097cfeb7742b6485e7422507e78d7a682703
size 164
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170
size 36312
oid sha256:b5dca7940998421ae58e9e26b2b2641b058d23b0270b7a147ebf85fbbdce7184
size 35496
+3 -104
View File
@@ -26,7 +26,7 @@ from safetensors.torch import load_file
from lerobot import available_policies
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.utils import cycle, dataset_to_policy_features
@@ -41,7 +41,6 @@ from lerobot.policies.factory import (
make_policy_config,
make_processor,
)
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.random_utils import seeded_context
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
@@ -266,108 +265,6 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
@pytest.mark.parametrize("insert_temporal_dim", [False, True])
def test_normalize(insert_temporal_dim):
"""
Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise
an exception when the forward pass is called without the stats having been provided.
TODO(rcadene, alexander-soare): This should also test that the normalization / unnormalization works as
expected.
"""
input_features = {
"observation.image": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 96, 96),
),
"observation.state": PolicyFeature(
type=FeatureType.STATE,
shape=(10,),
),
}
output_features = {
"action": PolicyFeature(
type=FeatureType.ACTION,
shape=(5,),
),
}
norm_map = {
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
dataset_stats = {
"observation.image": {
"mean": torch.randn(3, 1, 1),
"std": torch.randn(3, 1, 1),
"min": torch.randn(3, 1, 1),
"max": torch.randn(3, 1, 1),
},
"observation.state": {
"mean": torch.randn(10),
"std": torch.randn(10),
"min": torch.randn(10),
"max": torch.randn(10),
},
"action": {
"mean": torch.randn(5),
"std": torch.randn(5),
"min": torch.randn(5),
"max": torch.randn(5),
},
}
bsize = 2
input_batch = {
"observation.image": torch.randn(bsize, 3, 96, 96),
"observation.state": torch.randn(bsize, 10),
}
output_batch = {
"action": torch.randn(bsize, 5),
}
if insert_temporal_dim:
tdim = 4
for key in input_batch:
# [2,3,96,96] -> [2,tdim,3,96,96]
input_batch[key] = torch.stack([input_batch[key]] * tdim, dim=1)
for key in output_batch:
output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1)
# test without stats
normalize = Normalize(input_features, norm_map, stats=None)
with pytest.raises(AssertionError):
normalize(input_batch)
# test with stats
normalize = Normalize(input_features, norm_map, stats=dataset_stats)
normalize(input_batch)
# test loading pretrained models
new_normalize = Normalize(input_features, norm_map, stats=None)
new_normalize.load_state_dict(normalize.state_dict())
new_normalize(input_batch)
# test without stats
unnormalize = Unnormalize(output_features, norm_map, stats=None)
with pytest.raises(AssertionError):
unnormalize(output_batch)
# test with stats
unnormalize = Unnormalize(output_features, norm_map, stats=dataset_stats)
unnormalize(output_batch)
# test loading pretrained models
new_unnormalize = Unnormalize(output_features, norm_map, stats=None)
new_unnormalize.load_state_dict(unnormalize.state_dict())
unnormalize(output_batch)
@pytest.mark.parametrize("multikey", [True, False])
def test_multikey_construction(multikey: bool):
"""
@@ -467,6 +364,8 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
NOTE: If the test does not pass, and you don't change the policy, it is likely that the test artifact
is out of date. For example, some PyTorch versions have different randomness, see this PR:
https://github.com/huggingface/lerobot/pull/1127.
NOTE: If the test don't pass and you don't change the policy, and note the dependencies version,
and you changed your processor, you might have to update the test artifact.
"""