mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 03:30:10 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
+92
-24
@@ -39,7 +39,13 @@ from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||
from lerobot.scripts.train import make_optimizer_and_scheduler
|
||||
from tests.scripts.save_policy_to_safetensors import get_policy_stats
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||
from tests.utils import (
|
||||
DEFAULT_CONFIG_PATH,
|
||||
DEVICE,
|
||||
require_cpu,
|
||||
require_env,
|
||||
require_x86_64_kernel,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_name", available_policies)
|
||||
@@ -47,37 +53,63 @@ def test_get_policy_and_config_classes(policy_name: str):
|
||||
"""Check that the correct policy and config classes are returned."""
|
||||
policy_cls, config_cls = get_policy_and_config_classes(policy_name)
|
||||
assert policy_cls.name == policy_name
|
||||
assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation)
|
||||
assert issubclass(
|
||||
config_cls,
|
||||
inspect.signature(policy_cls.__init__).parameters["config"].annotation,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"env_name,policy_name,extra_overrides",
|
||||
[
|
||||
("xarm", "tdmpc", ["policy.use_mpc=true", "dataset_repo_id=lerobot/xarm_lift_medium"]),
|
||||
(
|
||||
"xarm",
|
||||
"tdmpc",
|
||||
["policy.use_mpc=true", "dataset_repo_id=lerobot/xarm_lift_medium"],
|
||||
),
|
||||
("pusht", "diffusion", []),
|
||||
("pusht", "vqbet", []),
|
||||
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"]),
|
||||
(
|
||||
"aloha",
|
||||
"act",
|
||||
["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_scripted"],
|
||||
[
|
||||
"env.task=AlohaInsertion-v0",
|
||||
"dataset_repo_id=lerobot/aloha_sim_insertion_human",
|
||||
],
|
||||
),
|
||||
(
|
||||
"aloha",
|
||||
"act",
|
||||
["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_human"],
|
||||
[
|
||||
"env.task=AlohaInsertion-v0",
|
||||
"dataset_repo_id=lerobot/aloha_sim_insertion_scripted",
|
||||
],
|
||||
),
|
||||
(
|
||||
"aloha",
|
||||
"act",
|
||||
["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_scripted"],
|
||||
[
|
||||
"env.task=AlohaTransferCube-v0",
|
||||
"dataset_repo_id=lerobot/aloha_sim_transfer_cube_human",
|
||||
],
|
||||
),
|
||||
(
|
||||
"aloha",
|
||||
"act",
|
||||
[
|
||||
"env.task=AlohaTransferCube-v0",
|
||||
"dataset_repo_id=lerobot/aloha_sim_transfer_cube_scripted",
|
||||
],
|
||||
),
|
||||
# Note: these parameters also need custom logic in the test function for overriding the Hydra config.
|
||||
(
|
||||
"aloha",
|
||||
"diffusion",
|
||||
["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"],
|
||||
[
|
||||
"env.task=AlohaInsertion-v0",
|
||||
"dataset_repo_id=lerobot/aloha_sim_insertion_human",
|
||||
],
|
||||
),
|
||||
# Note: these parameters also need custom logic in the test function for overriding the Hydra config.
|
||||
("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
|
||||
@@ -165,7 +197,9 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
# Test updating the policy (and test that it does not mutate the batch)
|
||||
batch_ = deepcopy(batch)
|
||||
policy.forward(batch)
|
||||
assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass."
|
||||
assert set(batch) == set(
|
||||
batch_
|
||||
), "Batch keys are not the same after a forward pass."
|
||||
assert all(
|
||||
torch.equal(batch[k], batch_[k]) for k in batch
|
||||
), "Batch values are not the same after a forward pass."
|
||||
@@ -178,7 +212,9 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
observation = preprocess_observation(observation)
|
||||
|
||||
# send observation to device/gpu
|
||||
observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
|
||||
observation = {
|
||||
key: observation[key].to(DEVICE, non_blocking=True) for key in observation
|
||||
}
|
||||
|
||||
# get the next action for the environment (also check that the observation batch is not modified)
|
||||
observation_ = deepcopy(observation)
|
||||
@@ -240,7 +276,9 @@ def test_policy_defaults(policy_name: str):
|
||||
)
|
||||
def test_yaml_matches_dataclass(env_name: str, policy_name: str):
|
||||
"""Check that dataclass configs match their respective yaml configs."""
|
||||
hydra_cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=[f"env={env_name}", f"policy={policy_name}"])
|
||||
hydra_cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH, overrides=[f"env={env_name}", f"policy={policy_name}"]
|
||||
)
|
||||
_, policy_cfg_cls = get_policy_and_config_classes(policy_name)
|
||||
policy_cfg_from_hydra = _policy_cfg_from_hydra_cfg(policy_cfg_cls, hydra_cfg)
|
||||
policy_cfg_from_dataclass = policy_cfg_cls()
|
||||
@@ -254,7 +292,10 @@ def test_save_and_load_pretrained(policy_name: str):
|
||||
save_dir = "/tmp/test_save_and_load_pretrained_{policy_cls.__name__}"
|
||||
policy.save_pretrained(save_dir)
|
||||
policy_ = policy_cls.from_pretrained(save_dir)
|
||||
assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True))
|
||||
assert all(
|
||||
torch.equal(p, p_)
|
||||
for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("insert_temporal_dim", [False, True])
|
||||
@@ -343,7 +384,9 @@ def test_normalize(insert_temporal_dim):
|
||||
unnormalize(output_batch)
|
||||
|
||||
# test with stats
|
||||
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=dataset_stats)
|
||||
unnormalize = Unnormalize(
|
||||
output_shapes, unnormalize_output_modes, stats=dataset_stats
|
||||
)
|
||||
unnormalize(output_batch)
|
||||
|
||||
# test loading pretrained models
|
||||
@@ -364,11 +407,20 @@ def test_normalize(insert_temporal_dim):
|
||||
(
|
||||
"pusht",
|
||||
"diffusion",
|
||||
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
||||
[
|
||||
"policy.n_action_steps=8",
|
||||
"policy.num_inference_steps=10",
|
||||
"policy.down_dims=[128, 256, 512]",
|
||||
],
|
||||
"",
|
||||
),
|
||||
("aloha", "act", ["policy.n_action_steps=10"], ""),
|
||||
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
|
||||
(
|
||||
"aloha",
|
||||
"act",
|
||||
["policy.n_action_steps=1000", "policy.chunk_size=1000"],
|
||||
"_1000_steps",
|
||||
),
|
||||
("dora_aloha_real", "act_aloha_real", ["policy.n_action_steps=10"], ""),
|
||||
],
|
||||
)
|
||||
@@ -376,7 +428,9 @@ def test_normalize(insert_temporal_dim):
|
||||
# pass if it's run on another platform due to floating point errors
|
||||
@require_x86_64_kernel
|
||||
@require_cpu
|
||||
def test_backward_compatibility(env_name, policy_name, extra_overrides, file_name_extra):
|
||||
def test_backward_compatibility(
|
||||
env_name, policy_name, extra_overrides, file_name_extra
|
||||
):
|
||||
"""
|
||||
NOTE: If this test does not pass, and you have intentionally changed something in the policy:
|
||||
1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
|
||||
@@ -390,23 +444,34 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam
|
||||
6. Remember to stage and commit the resulting changes to `tests/data`.
|
||||
"""
|
||||
env_policy_dir = (
|
||||
Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}{file_name_extra}"
|
||||
Path("tests/data/save_policy_to_safetensors")
|
||||
/ f"{env_name}_{policy_name}{file_name_extra}"
|
||||
)
|
||||
saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
|
||||
saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")
|
||||
saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors")
|
||||
saved_actions = load_file(env_policy_dir / "actions.safetensors")
|
||||
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(
|
||||
env_name, policy_name, extra_overrides
|
||||
)
|
||||
|
||||
for key in saved_output_dict:
|
||||
assert torch.isclose(output_dict[key], saved_output_dict[key], rtol=0.1, atol=1e-7).all()
|
||||
assert torch.isclose(
|
||||
output_dict[key], saved_output_dict[key], rtol=0.1, atol=1e-7
|
||||
).all()
|
||||
for key in saved_grad_stats:
|
||||
assert torch.isclose(grad_stats[key], saved_grad_stats[key], rtol=0.1, atol=1e-7).all()
|
||||
assert torch.isclose(
|
||||
grad_stats[key], saved_grad_stats[key], rtol=0.1, atol=1e-7
|
||||
).all()
|
||||
for key in saved_param_stats:
|
||||
assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=50, atol=1e-7).all()
|
||||
assert torch.isclose(
|
||||
param_stats[key], saved_param_stats[key], rtol=50, atol=1e-7
|
||||
).all()
|
||||
for key in saved_actions:
|
||||
assert torch.isclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7).all()
|
||||
assert torch.isclose(
|
||||
actions[key], saved_actions[key], rtol=0.1, atol=1e-7
|
||||
).all()
|
||||
|
||||
|
||||
def test_act_temporal_ensembler():
|
||||
@@ -432,7 +497,9 @@ def test_act_temporal_ensembler():
|
||||
batch_size = batch_seq.shape[0]
|
||||
# Exponential weighting (normalized). Unsqueeze once to match the position of the `episode_length`
|
||||
# dimension of `batch_seq`.
|
||||
weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(-1)
|
||||
weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(
|
||||
-1
|
||||
)
|
||||
|
||||
# Simulate stepping through a rollout and computing a batch of actions with model on each step.
|
||||
for i in range(episode_length):
|
||||
@@ -455,7 +522,8 @@ def test_act_temporal_ensembler():
|
||||
episode_step_indices = torch.arange(i + 1)[-len(chunk_indices) :]
|
||||
seq_slice = batch_seq[:, episode_step_indices, chunk_indices]
|
||||
offline_avg = (
|
||||
einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum") / weights[: i + 1].sum()
|
||||
einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum")
|
||||
/ weights[: i + 1].sum()
|
||||
)
|
||||
# Sanity check. The average should be between the extrema.
|
||||
assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg)
|
||||
|
||||
Reference in New Issue
Block a user