Added docstrings to aggregate, fix test_policies.py

This commit is contained in:
Michel Aractingi
2025-07-04 11:27:00 +02:00
parent 830a3b9f27
commit 3dbc3e60fb
3 changed files with 157 additions and 28 deletions
+3 -7
View File
@@ -511,23 +511,19 @@ def test_backward_compatibility(repo_id):
)
# test2 first frames of first episode
i = dataset.meta.episodes["dataset_from_index"][0].item()
i = dataset.meta.episodes[0]["dataset_from_index"]
load_and_compare(i)
load_and_compare(i + 1)
# test 2 frames at the middle of first episode
i = int(
(
dataset.meta.episodes["dataset_to_index"][0].item()
- dataset.meta.episodes["dataset_from_index"][0].item()
)
/ 2
(dataset.meta.episodes[0]["dataset_to_index"] - dataset.meta.episodes[0]["dataset_from_index"]) / 2
)
load_and_compare(i)
load_and_compare(i + 1)
# test 2 last frames of first episode
i = dataset.meta.episodes["dataset_to_index"][0].item()
i = dataset.meta.episodes[0]["dataset_to_index"]
load_and_compare(i - 2)
load_and_compare(i - 1)
+15 -9
View File
@@ -20,6 +20,7 @@ from pathlib import Path
import einops
import pytest
import torch
from packaging import version
from safetensors.torch import load_file
from lerobot import available_policies
@@ -68,11 +69,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
},
}
info = info_factory(
total_episodes=1,
total_frames=1,
total_tasks=1,
camera_features=camera_features,
motor_features=motor_features,
total_episodes=1, total_frames=1, camera_features=camera_features, motor_features=motor_features
)
ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info)
return ds_meta
@@ -141,14 +138,13 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
and for now we add tests as we see fit.
"""
policy_kwargs["device"] = DEVICE
train_cfg = TrainPipelineConfig(
# TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
policy=make_policy_config(policy_name, **policy_kwargs),
policy=make_policy_config(policy_name, push_to_hub=False, **policy_kwargs),
env=make_env_config(env_name, **env_kwargs),
)
train_cfg.validate()
# Check that we can make the policy object.
dataset = make_dataset(train_cfg)
@@ -217,7 +213,7 @@ def test_act_backbone_lr():
cfg = TrainPipelineConfig(
# TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001),
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001, push_to_hub=False),
)
cfg.validate() # Needed for auto-setting some parameters
@@ -413,7 +409,17 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
4. Check that this test now passes.
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
6. Remember to stage and commit the resulting changes to `tests/artifacts`.
NOTE: If the test does not pass, and you don't change the policy, it is likely that the test artifact
is out of date. For example, some PyTorch versions have different randomness, see this PR:
https://github.com/huggingface/lerobot/pull/1127.
"""
# NOTE: ACT policy has different randomness, after PyTorch 2.7.0
if policy_name == "act" and version.parse(torch.__version__) < version.parse("2.7.0"):
pytest.skip(f"Skipping act policy test with PyTorch {torch.__version__}. Requires PyTorch >= 2.7.0")
ds_name = ds_repo_id.split("/")[-1]
artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")