mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
feat(policies): convert save_policy_to_safetensors with pipeline
This commit is contained in:
@@ -26,7 +26,7 @@ from safetensors.torch import load_file
|
||||
from lerobot import available_policies
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.utils import cycle, dataset_to_policy_features
|
||||
@@ -41,7 +41,6 @@ from lerobot.policies.factory import (
|
||||
make_policy_config,
|
||||
make_processor,
|
||||
)
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.random_utils import seeded_context
|
||||
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
||||
@@ -266,108 +265,6 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
|
||||
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("insert_temporal_dim", [False, True])
|
||||
def test_normalize(insert_temporal_dim):
|
||||
"""
|
||||
Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise
|
||||
an exception when the forward pass is called without the stats having been provided.
|
||||
|
||||
TODO(rcadene, alexander-soare): This should also test that the normalization / unnormalization works as
|
||||
expected.
|
||||
"""
|
||||
|
||||
input_features = {
|
||||
"observation.image": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 96, 96),
|
||||
),
|
||||
"observation.state": PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(10,),
|
||||
),
|
||||
}
|
||||
output_features = {
|
||||
"action": PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(5,),
|
||||
),
|
||||
}
|
||||
|
||||
norm_map = {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
dataset_stats = {
|
||||
"observation.image": {
|
||||
"mean": torch.randn(3, 1, 1),
|
||||
"std": torch.randn(3, 1, 1),
|
||||
"min": torch.randn(3, 1, 1),
|
||||
"max": torch.randn(3, 1, 1),
|
||||
},
|
||||
"observation.state": {
|
||||
"mean": torch.randn(10),
|
||||
"std": torch.randn(10),
|
||||
"min": torch.randn(10),
|
||||
"max": torch.randn(10),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.randn(5),
|
||||
"std": torch.randn(5),
|
||||
"min": torch.randn(5),
|
||||
"max": torch.randn(5),
|
||||
},
|
||||
}
|
||||
|
||||
bsize = 2
|
||||
input_batch = {
|
||||
"observation.image": torch.randn(bsize, 3, 96, 96),
|
||||
"observation.state": torch.randn(bsize, 10),
|
||||
}
|
||||
output_batch = {
|
||||
"action": torch.randn(bsize, 5),
|
||||
}
|
||||
|
||||
if insert_temporal_dim:
|
||||
tdim = 4
|
||||
|
||||
for key in input_batch:
|
||||
# [2,3,96,96] -> [2,tdim,3,96,96]
|
||||
input_batch[key] = torch.stack([input_batch[key]] * tdim, dim=1)
|
||||
|
||||
for key in output_batch:
|
||||
output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1)
|
||||
|
||||
# test without stats
|
||||
normalize = Normalize(input_features, norm_map, stats=None)
|
||||
with pytest.raises(AssertionError):
|
||||
normalize(input_batch)
|
||||
|
||||
# test with stats
|
||||
normalize = Normalize(input_features, norm_map, stats=dataset_stats)
|
||||
normalize(input_batch)
|
||||
|
||||
# test loading pretrained models
|
||||
new_normalize = Normalize(input_features, norm_map, stats=None)
|
||||
new_normalize.load_state_dict(normalize.state_dict())
|
||||
new_normalize(input_batch)
|
||||
|
||||
# test without stats
|
||||
unnormalize = Unnormalize(output_features, norm_map, stats=None)
|
||||
with pytest.raises(AssertionError):
|
||||
unnormalize(output_batch)
|
||||
|
||||
# test with stats
|
||||
unnormalize = Unnormalize(output_features, norm_map, stats=dataset_stats)
|
||||
unnormalize(output_batch)
|
||||
|
||||
# test loading pretrained models
|
||||
new_unnormalize = Unnormalize(output_features, norm_map, stats=None)
|
||||
new_unnormalize.load_state_dict(unnormalize.state_dict())
|
||||
unnormalize(output_batch)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multikey", [True, False])
|
||||
def test_multikey_construction(multikey: bool):
|
||||
"""
|
||||
@@ -467,6 +364,8 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
|
||||
NOTE: If the test does not pass, and you don't change the policy, it is likely that the test artifact
|
||||
is out of date. For example, some PyTorch versions have different randomness, see this PR:
|
||||
https://github.com/huggingface/lerobot/pull/1127.
|
||||
NOTE: If the test don't pass and you don't change the policy, and note the dependencies version,
|
||||
and you changed your processor, you might have to update the test artifact.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user