lerobot/tests/policies/test_policies.py
Pepijn 919184d6f8 feat(envs): lazy env init + AsyncVectorEnv as default for n_envs > 1 (#3274)
* docs(benchmarks): add benchmark integration guide and standardize benchmark docs

Add a comprehensive guide for adding new benchmarks to LeRobot, and
refactor the existing LIBERO and Meta-World docs to follow the new
standardized template.

Made-with: Cursor

* refactor(envs): move dispatch logic from factory into EnvConfig subclasses

Replace hardcoded if/elif chains in factory.py with create_envs() and
get_env_processors() methods on EnvConfig. New benchmarks now only need
to register a config subclass — no factory.py edits required.

Net -23 lines: factory.py shrinks from ~200 to ~70 lines of logic.

Made-with: Cursor

* docs(benchmarks): clean up adding-benchmarks guide for clarity

Rewrite for simpler language, better structure, and easier navigation.
Move quick-reference table to the top, fold eval explanation into
architecture section, condense the doc template to a bulleted outline.

Made-with: Cursor

* fix link

* fix task count

* fix: enable SmolVLA eval on LIBERO with custom camera mappings

- Thread camera_name_mapping from LiberoEnv config through to gym envs
- Sync features_map with camera_name_mapping in LiberoEnv.__post_init__
- Fix render() to use first available camera instead of hardcoded "image"
- Handle non-dict final_info in rollout by falling back to info["is_success"]
- Add use_peft legacy field to SmolVLAConfig for checkpoint compat
- Add defaults to GR00TN15Config init=False fields for transformers 5.3

Made-with: Cursor

* fix: use direct AutoresetMode import for gymnasium compat

Made-with: Cursor

* fix: handle gymnasium < 1.0 without AutoresetMode

Made-with: Cursor

* refactor: revert policy changes, keep env-only camera mapping fixes

- Revert GR00T N1.5 default_factory/default changes (transformers compat)
- Revert SmolVLA use_peft legacy field
- Apply ruff formatting fixes
- camera_name_mapping stays entirely in env/eval layer (no policy changes)

Made-with: Cursor

* Update docs/source/env_processor.mdx

Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* feat(envs): lazy env init + AsyncVectorEnv as default for n_envs > 1

LiberoEnv and MetaworldEnv previously allocated GPU resources (EGL context,
OpenGL framebuffer) in __init__, before AsyncVectorEnv's fork(). Worker
processes inherited stale GPU handles, causing EGL_BAD_CONTEXT crashes on
first render.

Fix: defer OffScreenRenderEnv / MT1 construction to _ensure_env(), called on
first reset() or step() inside the worker subprocess. Each worker creates its
own clean context after fork().
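
The deferral described above can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the LeRobot implementation: `LazyEnv`, `make_backend` and `FakeBackend` are hypothetical names, and only `_ensure_env()`, `reset()` and `step()` are taken from the message above.

```python
from typing import Any, Callable, Optional


class LazyEnv:
    """Hypothetical wrapper that defers building an expensive backend."""

    def __init__(self, make_backend: Callable[[], Any]) -> None:
        # Store only the factory. Nothing GPU-related is allocated here,
        # so the object is safe to create before worker processes start.
        self._make_backend = make_backend
        self._env: Optional[Any] = None

    def _ensure_env(self) -> Any:
        # Build the backend on first use, in whichever process calls it.
        if self._env is None:
            self._env = self._make_backend()
        return self._env

    def reset(self, **kwargs: Any) -> Any:
        return self._ensure_env().reset(**kwargs)

    def step(self, action: Any) -> Any:
        return self._ensure_env().step(action)


class FakeBackend:
    """Stand-in for a renderer-backed simulator (illustration only)."""

    instances = 0

    def __init__(self) -> None:
        FakeBackend.instances += 1

    def reset(self, **kwargs: Any):
        return {"obs": 0}, {}

    def step(self, action: Any):
        return {"obs": action}, 0.0, False, False, {}
```

Constructing `LazyEnv(FakeBackend)` creates no backend; the first `reset()` or `step()` does, and later calls reuse it.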

Also fixes lerobot_eval.py:170 (add_envs_task TODO): replace with
env.call("task") which works with both SyncVectorEnv and AsyncVectorEnv.

AsyncVectorEnv is now the default for n_envs > 1; auto-downgraded to
SyncVectorEnv when n_envs=1 (no benefit, less overhead).

Expected speedup: ~15-20x for LIBERO Spatial with batch_size=50.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: close envs between tasks to prevent worker process accumulation

eval_policy_all never closed environments after each task completed,
causing AsyncVectorEnv worker processes to accumulate (N_tasks × n_envs).
This led to OOM, BrokenPipeError and EOFError on multi-task benchmarks.
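
A minimal sketch of the close-per-task pattern, assuming a simple sequential loop. `evaluate_all`, `make_env`, `run_task` and `CountingEnv` are hypothetical names for illustration and do not mirror the actual `eval_policy_all` signature.

```python
from typing import Any, Callable, Dict, Iterable


def evaluate_all(
    tasks: Iterable[str],
    make_env: Callable[[str], Any],
    run_task: Callable[[Any], float],
) -> Dict[str, float]:
    """Hypothetical multi-task loop that releases each env before the next task."""
    results: Dict[str, float] = {}
    for task in tasks:
        env = make_env(task)
        try:
            results[task] = run_task(env)
        finally:
            # Runs on success and on error, so worker processes never pile up.
            env.close()
    return results


class CountingEnv:
    """Stand-in env that tracks how many instances are currently open."""

    open_count = 0
    peak_open = 0

    def __init__(self, task: str) -> None:
        self.task = task
        CountingEnv.open_count += 1
        CountingEnv.peak_open = max(CountingEnv.peak_open, CountingEnv.open_count)

    def close(self) -> None:
        CountingEnv.open_count -= 1
```

With this structure at most one task's environments are alive at a time, instead of N_tasks × n_envs.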

Also fixes:
- AsyncVectorEnv compat in envs/utils.py (use get_attr/call instead of .envs)
- Tuple task handling in tokenizer_processor and lerobot_eval
- _LazyAsyncVectorEnv for deferred worker spawning in LIBERO

Made-with: Cursor

* fix(eval): use task_description instead of task for language conditioning

env.call("task") returns the LIBERO task name with underscores
(e.g. "pick_up_the_black_bowl_...") instead of the natural language
description ("pick up the black bowl ..."). The VLM tokenizes these
completely differently, causing 0.0 reward across all episodes.

Made-with: Cursor

* docs: update adding_benchmarks for async env changes

- Replace add_envs_task reference with env.call("task_description")
- Update use_async_envs default to True
- Add note about lazy GPU init for AsyncVectorEnv compatibility

Made-with: Cursor

* feat(eval): batch_size=auto + faster env loading

- batch_size=0 (default) auto-tunes based on CPU cores, capped by
  n_episodes and 64. Removes the need for users to guess the right
  value. The old batch_size > n_episodes error is replaced by silently
  clamping to n_episodes.
- _LazyAsyncVectorEnv accepts pre-computed spaces so only one temp env
  is created per suite (not per task). For libero_spatial (10 tasks)
  this avoids 9 redundant LiberoEnv instantiations during env setup.
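
The auto-tuning rule in the first bullet might look roughly like this. The message does not give the exact formula, so the one-env-per-core heuristic is an assumption, and `resolve_batch_size` with its parameters is a hypothetical helper rather than a LeRobot function.

```python
import os
from typing import Optional


def resolve_batch_size(
    batch_size: int,
    n_episodes: int,
    cpu_cores: Optional[int] = None,
    cap: int = 64,
) -> int:
    """Hypothetical helper: 0 means 'auto'; the result never exceeds n_episodes."""
    if cpu_cores is None:
        cpu_cores = os.cpu_count() or 1
    if batch_size == 0:
        # Assumed heuristic: one env per CPU core, bounded by the cap of 64.
        batch_size = min(cpu_cores, cap)
    # Clamp silently instead of raising when batch_size > n_episodes.
    return max(1, min(batch_size, n_episodes))
```

For example, on a 32-core machine with 10 episodes the auto value clamps to 10, while an explicit `batch_size=50` with 20 episodes becomes 20 rather than raising an error.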

Made-with: Cursor

* docs: add evaluation guide and update benchmarks doc

- New docs/source/evaluation.mdx covering lerobot-eval usage, batch_size
  auto-tuning, AsyncVectorEnv performance, tuning tips, output format,
  multi-task evaluation, and programmatic usage.
- Add evaluation page to _toctree.yml under Benchmarks section.
- Update adding_benchmarks.mdx to reference batch_size auto default and
  link to the evaluation guide.

Made-with: Cursor

* docs(evaluation): remove benchmark table, rename section header

Made-with: Cursor

* perf(eval): shared memory, observation passthrough, task prefetch

- AsyncVectorEnv now uses shared_memory=True for zero-copy observation transfer
- LiberoEnvConfig.gym_kwargs passes observation_height/width to the env
- eval_policy_all prefetches next task's workers while current task runs

Made-with: Cursor

* style: ruff format

Made-with: Cursor

* chore: revert env_processor.mdx changes (not part of this PR)

Made-with: Cursor

* ci(benchmarks): add isolated integration tests for libero and metaworld

Each benchmark gets its own Docker image (lerobot[libero] / lerobot[metaworld]
only) so incompatible dep trees cannot collide. A 1-episode smoke eval runs
per benchmark on GPU runners.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* ci(benchmarks): pin action hashes and use uv sync --locked

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* ci(benchmarks): trigger only on envs/ or lerobot_eval.py changes

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(ci): set LIBERO_DATA_FOLDER to bypass interactive stdin prompt

libero/__init__.py calls input() to ask about a custom dataset path,
which raises EOFError when stdin is closed inside Docker. Setting
LIBERO_DATA_FOLDER skips the prompt entirely.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* docs(benchmarks): add CI smoke test step to adding_benchmarks guide

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(ci): pre-create libero config in Dockerfile to bypass stdin prompt

libero/__init__.py calls input() when ~/.libero/config.yaml is missing.
We write the config at image build time (without importing libero) so
the prompt never fires at runtime. Also trigger CI on pyproject.toml changes.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(ci): use shell to create libero config instead of multiline python -c

The multiline RUN python -c "..." was being parsed as Dockerfile
instructions. Use printf to write ~/.libero/config.yaml directly.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(ci): point libero config to bundled package init_files

The config was pointing to /tmp/libero_init which doesn't exist.
Use importlib.util.find_spec to locate the hf-libero package directory
and write paths to the actual bundled bddl_files/init_files/assets.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(ci): add smolvla extra to benchmark Dockerfiles

num2words (required by SmolVLM processor) is declared in lerobot[smolvla],
not lerobot[libero/metaworld]. Install both extras together.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(eval): render_frame covers _LazyAsyncVectorEnv

isinstance(env, AsyncVectorEnv) silently skipped _LazyAsyncVectorEnv,
causing video rendering to produce no frames on the default async path.
Switch to hasattr(env, "call") so any async-compatible env (including
_LazyAsyncVectorEnv) hits the call("render") branch.
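
A small sketch of why the capability check is more robust than the type check. It assumes the lazy env wraps, rather than subclasses, the async vector env; `PlainVectorEnv`, `LazyWrapper` and `render_frames` are hypothetical stand-ins, not the real classes.

```python
from typing import Any, List


class PlainVectorEnv:
    """Stand-in for a vector env exposing call(name)."""

    def call(self, name: str) -> tuple:
        return (f"{name}-frame",)


class LazyWrapper:
    """Wraps instead of subclassing, so isinstance checks on the inner type fail."""

    def __init__(self, inner: PlainVectorEnv) -> None:
        self._inner = inner

    def call(self, name: str) -> tuple:
        return self._inner.call(name)


def render_frames(env: Any) -> List[Any]:
    # Duck typing: anything with call() takes the vectorized render path.
    if hasattr(env, "call"):
        return list(env.call("render"))
    return [env.render()]
```

Here `isinstance(LazyWrapper(...), PlainVectorEnv)` is false, yet `render_frames` still reaches the `call("render")` branch.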

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* refactor(envs): remove unused _get_sub_env_attr helper

_get_sub_env_attr was defined but never called anywhere in the codebase.
_sub_env_has_attr (its sibling) is kept — it is actively used in utils.py.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* chore: apply prettier formatting to docs

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* docs(env_processor): remove deprecated add_envs_task from pipeline example

add_envs_task is replaced by env.call("task_description") in this PR.
Remove it from the pipeline walkthrough and renumber the steps (8→7).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* refactor(envs): remove __del__ from _LazyAsyncVectorEnv

__del__ is unreliable as a cleanup mechanism. close() is already called
explicitly in the eval loop's finally block, so the finalizer is redundant.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(eval): prefetch next task's workers after close to avoid GPU memory overlap

Previously, next task's AsyncVectorEnv workers were spawned while the
current task was still running, causing both tasks' GPU contexts to coexist.
Moving the prefetch start into the finally block (after env.close()) ensures
workers for task N+1 only spin up once task N has released GPU memory.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* refactor(envs): move _LazyAsyncVectorEnv to utils and apply to metaworld

_LazyAsyncVectorEnv lived in libero.py but metaworld had the same OOM
problem: all tasks' AsyncVectorEnv workers were spawned eagerly, wasting
GPU memory for tasks not yet running.

Move the class to envs/utils.py so both environments share it, then apply
the same is_async + lazy wrapping pattern in create_metaworld_envs.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* chore: remove out-of-scope benchmark/CI/docs files from PR

Benchmark CI workflow, Dockerfiles, benchmark docs, evaluation smoke-test
doc, and dispatch tests belong in a separate PR. Scope this PR to the
async env init changes only.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* chore: restore adding_benchmarks + test_dispatch, drop env_processor changes

- Restore docs/source/adding_benchmarks.mdx (belongs in this PR)
- Restore tests/envs/test_dispatch.py (belongs in this PR)
- Revert docs/source/env_processor.mdx to main (out of scope for this PR)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* docs(adding_benchmarks): remove CI smoke test step (coming in separate PR)

Step 7 (Dockerfile + benchmark_tests.yml CI job) and its table rows are
out of scope for this PR. The CI infrastructure will be added on top in a
follow-up PR.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* refactor(envs): remove unused add_envs_task

Replaced by env.call("task_description") in lerobot_eval.py. No callers
remain in the codebase.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* style: fix prettier formatting in env_processor.mdx

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(eval): catch AttributeError and NotImplementedError explicitly for task description

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(envs): use forkserver context and close envs in test to prevent deadlock

AsyncVectorEnv with default fork context leaks worker processes between
test_policy parametrized cases; subsequent env creation deadlocks because
new forked workers inherit stale pipe FDs from previous test's leaked workers.

- configs.py: pass context="forkserver" to AsyncVectorEnv (matches _LazyAsyncVectorEnv)
- test_policies.py: call close_envs(envs) at end of test_policy to clean up workers

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(envs): default use_async_envs=False in create_envs and make_env

Tests that call make_env(n_envs=2) without passing use_async_envs were
getting AsyncVectorEnv, whose forked workers can't resolve gym namespaces
registered at runtime. Default to False (sync) so existing tests pass.

lerobot_eval.py explicitly passes cfg.eval.use_async_envs, so the CLI
async behaviour (controlled by EvalConfig.use_async_envs) is unchanged.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 10:29:20 +02:00


#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from copy import deepcopy
from pathlib import Path

import einops
import pytest
import torch
from packaging import version
from safetensors.torch import load_file

from lerobot import available_policies
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.feature_utils import dataset_to_policy_features
from lerobot.datasets.utils import cycle
from lerobot.envs.factory import make_env, make_env_config
from lerobot.envs.utils import close_envs, preprocess_observation
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.act.modeling_act import ACTTemporalEnsembler
from lerobot.policies.factory import (
    get_policy_class,
    make_policy,
    make_policy_config,
    make_pre_post_processors,
)
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.vqbet.modeling_vqbet import VQBeTHead
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.utils.random_utils import seeded_context
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel


@pytest.fixture
def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_path):
    # Create only one camera input, and make it square, to fit all current policy constraints,
    # e.g. vqbet and tdmpc work with one camera only, and tdmpc requires it to be square.
    camera_features = {
        f"{OBS_IMAGES}.laptop": {
            "shape": (84, 84, 3),
            "names": ["height", "width", "channels"],
            "info": None,
        },
    }
    motor_features = {
        ACTION: {
            "dtype": "float32",
            "shape": (6,),
            "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
        },
        OBS_STATE: {
            "dtype": "float32",
            "shape": (6,),
            "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
        },
    }
    info = info_factory(
        total_episodes=1,
        total_frames=1,
        total_tasks=1,
        camera_features=camera_features,
        motor_features=motor_features,
    )
    ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info)
    return ds_meta


@pytest.mark.parametrize("policy_name", available_policies)
def test_get_policy_and_config_classes(policy_name: str):
    """Check that the correct policy and config classes are returned."""
    policy_cls = get_policy_class(policy_name)
    policy_cfg = make_policy_config(policy_name)
    assert policy_cls.name == policy_name
    assert issubclass(
        policy_cfg.__class__, inspect.signature(policy_cls.__init__).parameters["config"].annotation
    )


@pytest.mark.parametrize(
    "ds_repo_id,env_name,env_kwargs,policy_name,policy_kwargs",
    [
        ("lerobot/pusht", "pusht", {}, "diffusion", {}),
        ("lerobot/pusht", "pusht", {}, "vqbet", {}),
        ("lerobot/pusht", "pusht", {}, "act", {}),
        ("lerobot/aloha_sim_insertion_human", "aloha", {"task": "AlohaInsertion-v0"}, "act", {}),
        (
            "lerobot/aloha_sim_insertion_scripted",
            "aloha",
            {"task": "AlohaInsertion-v0"},
            "act",
            {},
        ),
        (
            "lerobot/aloha_sim_insertion_human",
            "aloha",
            {"task": "AlohaInsertion-v0"},
            "diffusion",
            {},
        ),
        (
            "lerobot/aloha_sim_transfer_cube_human",
            "aloha",
            {"task": "AlohaTransferCube-v0"},
            "act",
            {},
        ),
        (
            "lerobot/aloha_sim_transfer_cube_scripted",
            "aloha",
            {"task": "AlohaTransferCube-v0"},
            "act",
            {},
        ),
    ],
)
@require_env
def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
    """
    Tests:
        - Making the policy object.
        - Checking that the policy follows the correct protocol and subclasses nn.Module
            and PyTorchModelHubMixin.
        - Updating the policy.
        - Using the policy to select actions at inference time.
        - Checking that the selected action can be applied to the environment.

    Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
          and for now we add tests as we see fit.
    """
    if policy_name == "vqbet" and DEVICE == "mps":
        pytest.skip("VQBet does not support MPS backend")
    if policy_name == "act" and "aloha" in ds_repo_id and DEVICE == "mps":
        pytest.skip("ACT with aloha has batch mutation issues on MPS")

    train_cfg = TrainPipelineConfig(
        # TODO(rcadene, aliberts): remove dataset download
        dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
        policy=make_policy_config(policy_name, push_to_hub=False, **policy_kwargs),
        env=make_env_config(env_name, **env_kwargs),
    )
    train_cfg.policy.device = DEVICE
    train_cfg.validate()

    # Check that we can make the policy object.
    dataset = make_dataset(train_cfg)
    preprocessor, _ = make_pre_post_processors(train_cfg.policy, None)
    policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
    assert isinstance(policy, PreTrainedPolicy)

    # Check that we run select_actions and get the appropriate output.
    envs = make_env(train_cfg.env, n_envs=2)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=0,
        batch_size=2,
        shuffle=True,
        pin_memory=DEVICE != "cpu",
        drop_last=True,
    )
    dl_iter = cycle(dataloader)

    batch = next(dl_iter)

    for key in batch:
        if isinstance(batch[key], torch.Tensor):
            batch[key] = batch[key].to(DEVICE, non_blocking=True)

    # Test updating the policy (and test that it does not mutate the batch)
    batch_ = deepcopy(batch)
    policy.forward(batch)
    assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass."
    assert all(
        torch.equal(batch[k], batch_[k]) if isinstance(batch[k], torch.Tensor) else batch[k] == batch_[k]
        for k in batch
    ), "Batch values are not the same after a forward pass."

    # reset the policy and environment
    policy.reset()
    # For testing purposes, we only need a single environment instance.
    # So here we unwrap the first suite_name and first task_id to grab
    # the actual env object (SyncVectorEnv) that exposes `.reset()`.
    suite_name = next(iter(envs))
    task_id = next(iter(envs[suite_name]))
    env = envs[suite_name][task_id]
    observation, _ = env.reset(seed=train_cfg.seed)

    # apply transform to normalize the observations
    observation = preprocess_observation(observation)

    # send observation to device/gpu
    observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}

    # get the next action for the environment (also check that the observation batch is not modified)
    observation_ = deepcopy(observation)
    with torch.inference_mode():
        action = policy.select_action(observation).cpu().numpy()
    assert set(observation) == set(observation_), (
        "Observation batch keys are not the same after a forward pass."
    )
    assert all(torch.equal(observation[k], observation_[k]) for k in observation), (
        "Observation batch values are not the same after a forward pass."
    )

    # Test stepping the environment with the selected action
    env.step(action)

    # Close the envs so no worker processes leak into the next parametrized case
    close_envs(envs)


# TODO(rcadene, aliberts): This test is quite end-to-end. Move this test in test_optimizer?
def test_act_backbone_lr():
    """
    Test that the ACT policy can be instantiated with a different learning rate for the backbone.
    """

    cfg = TrainPipelineConfig(
        # TODO(rcadene, aliberts): remove dataset download
        dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
        policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001, push_to_hub=False),
    )
    cfg.policy.device = DEVICE
    cfg.validate()  # Needed for auto-setting some parameters

    assert cfg.policy.optimizer_lr == 0.01
    assert cfg.policy.optimizer_lr_backbone == 0.001

    dataset = make_dataset(cfg)
    preprocessor, _ = make_pre_post_processors(cfg.policy, None)
    policy = make_policy(cfg.policy, ds_meta=dataset.meta)
    optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
    assert len(optimizer.param_groups) == 2
    assert optimizer.param_groups[0]["lr"] == cfg.policy.optimizer_lr
    assert optimizer.param_groups[1]["lr"] == cfg.policy.optimizer_lr_backbone
    assert len(optimizer.param_groups[0]["params"]) == 133
    assert len(optimizer.param_groups[1]["params"]) == 20


@pytest.mark.parametrize("policy_name", available_policies)
def test_policy_defaults(dummy_dataset_metadata, policy_name: str):
    """Check that the policy can be instantiated with defaults."""
    policy_cls = get_policy_class(policy_name)
    policy_cfg = make_policy_config(policy_name)
    features = dataset_to_policy_features(dummy_dataset_metadata.features)
    policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
    policy_cfg.input_features = {
        key: ft for key, ft in features.items() if key not in policy_cfg.output_features
    }
    policy_cls(policy_cfg)


@pytest.mark.parametrize("policy_name", available_policies)
def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name: str):
    policy_cls = get_policy_class(policy_name)
    policy_cfg = make_policy_config(policy_name)
    features = dataset_to_policy_features(dummy_dataset_metadata.features)
    policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
    policy_cfg.input_features = {
        key: ft for key, ft in features.items() if key not in policy_cfg.output_features
    }
    policy = policy_cls(policy_cfg)
    policy.to(policy_cfg.device)
    save_dir = tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}"
    policy.save_pretrained(save_dir)
    loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg)
    torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)


@pytest.mark.parametrize("multikey", [True, False])
def test_multikey_construction(multikey: bool):
    """
    Asserts that multiple keys with type State/Action are correctly processed by the policy constructor,
    preventing erroneous creation of the policy object.
    """
    input_features = {
        OBS_STATE: PolicyFeature(
            type=FeatureType.STATE,
            shape=(10,),
        ),
    }
    output_features = {
        ACTION: PolicyFeature(
            type=FeatureType.ACTION,
            shape=(5,),
        ),
    }

    if multikey:
        """Simulates the complete state/action is constructed from more granular multiple
        keys, of the same type as the overall state/action"""
        input_features = {}
        input_features[f"{OBS_STATE}.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
        input_features[f"{OBS_STATE}.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
        input_features[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(10,))

        output_features = {}
        output_features["action.first_three_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(3,))
        output_features["action.last_two_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(2,))
        output_features[ACTION] = PolicyFeature(
            type=FeatureType.ACTION,
            shape=(5,),
        )

    config = ACTConfig(input_features=input_features, output_features=output_features)

    state_condition = config.robot_state_feature == input_features[OBS_STATE]
    action_condition = config.action_feature == output_features[ACTION]

    assert state_condition, (
        f"Discrepancy detected. Robot state feature is {config.robot_state_feature} but policy expects {input_features[OBS_STATE]}"
    )
    assert action_condition, (
        f"Discrepancy detected. Action feature is {config.action_feature} but policy expects {output_features[ACTION]}"
    )


@pytest.mark.parametrize(
    "ds_repo_id, policy_name, policy_kwargs, file_name_extra",
    [
        # TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it
        # was changed to true. For some reason, tests would pass locally, but not in CI. So here we override
        # to test with `policy.use_mpc=false`.
        # TODO(rcadene): the diffusion model was normalizing the image in mean=0.5 std=0.5 which is a hack; it was
        # not supposed to normalize the image at all. In our current codebase we don't normalize at all. But there
        # is still a minor difference that fails the test. However, by testing to normalize the image with 0.5 0.5
        # in the current codebase, the test passes. Thus, we deactivate this test for now.
        (
            "lerobot/pusht",
            "diffusion",
            {
                "n_action_steps": 8,
                "num_inference_steps": 10,
                "down_dims": [128, 256, 512],
            },
            "",
        ),
        ("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""),
        (
            "lerobot/aloha_sim_insertion_human",
            "act",
            {"n_action_steps": 1000, "chunk_size": 1000},
            "1000_steps",
        ),
    ],
)
# As artifacts have been generated on an x86_64 kernel, this test won't
# pass if it's run on another platform due to floating point errors
@require_x86_64_kernel
@require_cpu
def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs: dict, file_name_extra: str):
    """
    NOTE: If this test does not pass, and you have intentionally changed something in the policy:
        1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
           include a report on what changed and how that affected the outputs.
        2. Go to the `if __name__ == "__main__"` block of
           `tests/artifacts/policies/save_policy_to_safetensors.py` and add the policies you want to update
           the test artifacts for.
        3. Run `python tests/artifacts/policies/save_policy_to_safetensors.py`. The test artifact
           should be updated.
        4. Check that this test now passes.
        5. Remember to restore `tests/artifacts/policies/save_policy_to_safetensors.py` to its original state.
        6. Remember to stage and commit the resulting changes to `tests/artifacts`.

    NOTE: If the test does not pass, and you didn't change the policy, it is likely that the test artifact
    is out of date. For example, some PyTorch versions have different randomness, see this PR:
    https://github.com/huggingface/lerobot/pull/1127.

    NOTE: If the test doesn't pass and you changed neither the policy nor the dependency versions,
    but you did change your processor, you might have to update the test artifact.
    """

    # NOTE: ACT policy has different randomness, after PyTorch 2.7.0
    if policy_name == "act" and version.parse(torch.__version__) < version.parse("2.7.0"):
        pytest.skip(f"Skipping act policy test with PyTorch {torch.__version__}. Requires PyTorch >= 2.7.0")

    ds_name = ds_repo_id.split("/")[-1]
    artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
    saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")
    saved_grad_stats = load_file(artifact_dir / "grad_stats.safetensors")
    saved_param_stats = load_file(artifact_dir / "param_stats.safetensors")
    saved_actions = load_file(artifact_dir / "actions.safetensors")

    output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)

    for key in saved_output_dict:
        torch.testing.assert_close(output_dict[key], saved_output_dict[key])
    for key in saved_grad_stats:
        torch.testing.assert_close(grad_stats[key], saved_grad_stats[key])
    for key in saved_param_stats:
        torch.testing.assert_close(param_stats[key], saved_param_stats[key])
    for key in saved_actions:
        rtol, atol = (2e-3, 5e-6) if policy_name == "diffusion" else (None, None)  # HACK
        torch.testing.assert_close(actions[key], saved_actions[key], rtol=rtol, atol=atol)


def test_act_temporal_ensembler():
    """Check that the online method in ACTTemporalEnsembler matches a simple offline calculation."""
    temporal_ensemble_coeff = 0.01
    chunk_size = 100
    episode_length = 101
    ensembler = ACTTemporalEnsembler(temporal_ensemble_coeff, chunk_size)
    # A batch of arbitrary sequences of 1D actions we wish to compute the average over. We'll keep the
    # "action space" in [-1, 1]. Apart from that, there is no real reason for the numbers chosen.
    with seeded_context(0):
        # Dimension is (batch, episode_length, chunk_size, action_dim(=1))
        # Stepping through the episode_length dim is like running inference at each rollout step and getting
        # a different action chunk.
        batch_seq = torch.stack(
            [
                torch.rand(episode_length, chunk_size) * 0.05 - 0.6,
                torch.rand(episode_length, chunk_size) * 0.02 - 0.01,
                torch.rand(episode_length, chunk_size) * 0.2 + 0.3,
            ],
            dim=0,
        ).unsqueeze(-1)  # unsqueeze for action dim
    batch_size = batch_seq.shape[0]
    # Exponential weighting (normalized). Unsqueeze once to match the position of the `episode_length`
    # dimension of `batch_seq`.
    weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(-1)

    # Simulate stepping through a rollout and computing a batch of actions with model on each step.
    for i in range(episode_length):
        # Mock a batch of actions.
        actions = torch.zeros(size=(batch_size, chunk_size, 1)) + batch_seq[:, i]
        online_avg = ensembler.update(actions)
        # Simple offline calculation: avg = Σ(aᵢ*wᵢ) / Σ(wᵢ).
        # Note: The complicated bit here is the slicing. Think about the (episode_length, chunk_size) grid.
        # What we want to do is take diagonal slices across it starting from the left.
        # eg: chunk_size=4, episode_length=6
        # ┌───────┐
        # │0 1 2 3│
        # │1 2 3 4│
        # │2 3 4 5│
        # │3 4 5 6│
        # │4 5 6 7│
        # │5 6 7 8│
        # └───────┘
        chunk_indices = torch.arange(min(i, chunk_size - 1), -1, -1)
        episode_step_indices = torch.arange(i + 1)[-len(chunk_indices) :]
        seq_slice = batch_seq[:, episode_step_indices, chunk_indices]
        offline_avg = (
            einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum") / weights[: i + 1].sum()
        )
        # Sanity check. The average should be between the extrema.
        assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg)
        assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
        # Selected atol=1e-4 keeping in mind actions in [-1, 1] and accepting 0.01% error.
        torch.testing.assert_close(online_avg, offline_avg, rtol=1e-4, atol=1e-4)


def test_vqbet_discretize_keeps_buffers_on_device():
    """Regression test: VQBeTHead.discretize() must not move registered buffers off the model device.

    Previously, `self.vqvae_model.discretized = torch.tensor(True)` replaced the
    registered buffer with a new CPU tensor, causing DDP to crash with:
        RuntimeError: No backend type associated with device type cpu

    The fix uses `.fill_(True)` to update in-place, preserving device placement.
    """
    config = VQBeTConfig()
    config.input_features = {
        OBS_IMAGES: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 96, 96)),
        OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(6,)),
    }
    config.output_features = {
        ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)),
    }
    # Tiny sizes for fast CPU/GPU execution.
    config.n_vqvae_training_steps = 3
    config.vqvae_n_embed = 8
    config.vqvae_embedding_dim = 32
    config.vqvae_enc_hidden_dim = 32
    config.action_chunk_size = 2
    config.crop_shape = (84, 84)

    head = VQBeTHead(config).to(DEVICE)
    vqvae = head.vqvae_model

    dummy_actions = torch.randn(4, config.action_chunk_size, config.action_feature.shape[0], device=DEVICE)

    n_steps = config.n_vqvae_training_steps
    for _ in range(n_steps):
        head.discretize(n_steps, dummy_actions)

    assert vqvae.discretized.device.type == torch.device(DEVICE).type, (
        "vqvae_model.discretized was moved off the model device after discretize(). "
        "Use .fill_(True) instead of = torch.tensor(True) to keep the buffer on device."
    )
    assert vqvae.vq_layer.freeze_codebook.device.type == torch.device(DEVICE).type, (
        "vq_layer.freeze_codebook was moved off the model device after discretize(). "
        "Use .fill_(True) instead of = torch.tensor(True) to keep the buffer on device."
    )