mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-19 01:07:18 +00:00
221 lines
7.0 KiB
Python
221 lines
7.0 KiB
Python
"""Tests for the benchmark dispatch refactor (create_envs / get_env_processors on EnvConfig)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
|
|
import gymnasium as gym
|
|
import pytest
|
|
import torch
|
|
from gymnasium.envs.registration import register, registry as gym_registry
|
|
|
|
from lerobot.configs.types import PolicyFeature
|
|
from lerobot.envs.configs import EnvConfig, LiberoEnv
|
|
from lerobot.envs.factory import make_env, make_env_config, make_env_pre_post_processors
|
|
from lerobot.processor import LiberoActionProcessorStep, LiberoProcessorStep
|
|
from lerobot.utils.constants import OBS_PREFIX, OBS_STATE
|
|
|
|
# Module-scoped logger named after this test module; not referenced by the tests visible here.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
def test_registry_all_types():
    """make_env_config should resolve every registered EnvConfig subclass via the registry.

    Every registered type name must round-trip: ``make_env_config(t)`` returns an
    ``EnvConfig`` instance whose ``type`` equals ``t``. A non-``EnvConfig`` result is
    reported as a failure instead of being silently skipped, which previously could
    let the loop pass without checking anything.
    """
    known = list(EnvConfig.get_known_choices().keys())
    # Sanity floor on how many env types are expected to be registered.
    assert len(known) >= 6

    for env_type in known:
        cfg = make_env_config(env_type)
        assert isinstance(cfg, EnvConfig), (
            f"{env_type!r} resolved to {type(cfg).__name__}, which is not an EnvConfig"
        )
        assert cfg.type == env_type
|
|
|
|
|
|
def test_unknown_type():
    """make_env_config raises ValueError for a type name no EnvConfig subclass registered."""
    unregistered = "nonexistent"
    with pytest.raises(ValueError, match="not registered"):
        make_env_config(unregistered)
|
|
|
|
|
|
def test_identity_processors():
    """The base-class get_env_processors() yields empty (identity) pre/post pipelines."""
    aloha_cfg = make_env_config("aloha")
    preprocessor, postprocessor = aloha_cfg.get_env_processors()
    # Checked separately so a failure points at the offending pipeline.
    assert len(preprocessor.steps) == 0
    assert len(postprocessor.steps) == 0
|
|
|
|
|
|
def test_delegation():
    """make_env() hands off to cfg.create_envs() and returns its result unchanged.

    A duck-typed stand-in config is used, so the test does not depend on any
    if/elif dispatch over concrete config types.
    """
    sentinel = {"delegated": {0: "marker"}}

    class _FakeCfg:
        # Stand-in config: just a hub_path attribute and a create_envs() hook.
        hub_path = None

        def create_envs(self, n_envs, use_async_envs=False):
            return sentinel

    assert make_env(_FakeCfg(), n_envs=1) is sentinel
|
|
|
|
|
|
def test_processors_delegation():
    """make_env_pre_post_processors delegates to cfg.get_env_processors().

    With no policy config, the result should equal the base-class identity
    pipelines. Both the preprocessor and the postprocessor are therefore
    expected to have no steps; previously only the preprocessor was checked.
    """
    cfg = make_env_config("aloha")
    pre, post = make_env_pre_post_processors(cfg, policy_cfg=None)
    assert len(pre.steps) == 0
    assert len(post.steps) == 0
|
|
|
|
|
|
def test_processors_delegation_supports_legacy_override_signature():
    """External EnvConfig subclasses with the old get_env_processors() signature keep working.

    The subclass below overrides ``get_env_processors(self)`` without a ``policy_cfg``
    parameter. make_env_pre_post_processors must still call it without raising and
    return DataProcessorPipeline instances.
    """
    from lerobot.processor.pipeline import DataProcessorPipeline

    @EnvConfig.register_subclass("_dispatch_legacy_proc_test")
    @dataclass
    class _LegacyEnv(EnvConfig):
        task: str = "x"
        features: dict[str, PolicyFeature] = field(default_factory=dict)

        @property
        def gym_kwargs(self):
            return {}

        # Deliberately the legacy signature: no policy_cfg parameter.
        def get_env_processors(self):
            empty_pre = DataProcessorPipeline(steps=[])
            empty_post = DataProcessorPipeline(steps=[])
            return empty_pre, empty_post

    legacy_cfg = _LegacyEnv()
    pipelines = make_env_pre_post_processors(legacy_cfg, policy_cfg=object())
    pre, post = pipelines
    for pipeline in (pre, post):
        assert isinstance(pipeline, DataProcessorPipeline)
|
|
|
|
|
|
def test_libero_evo1_processors_use_padded_state_and_env_action_dim():
    """EVO1 uses padded LIBERO state features while env actions stay executable.

    For an ``evo1`` policy config, the LIBERO preprocessor's state step gets the
    policy's ``max_state_dim``, and the postprocessor's action step gets the
    environment's action dimension. For any other policy type, the state step's
    ``max_state_dim`` is left as None.
    """

    class _Evo1Config:
        type = "evo1"
        max_state_dim = 24

    class _OtherConfig:
        type = "other"

    cfg = LiberoEnv()

    pre, post = make_env_pre_post_processors(cfg, policy_cfg=_Evo1Config())
    state_step = pre.steps[0]
    assert isinstance(state_step, LiberoProcessorStep)
    assert state_step.max_state_dim == 24

    action_step = post.steps[0]
    assert isinstance(action_step, LiberoActionProcessorStep)
    env_action_dim = cfg.features["action"].shape[0]
    assert action_step.action_dim == env_action_dim
    assert env_action_dim == 7

    pre_other, _ = make_env_pre_post_processors(cfg, policy_cfg=_OtherConfig())
    assert pre_other.steps[0].max_state_dim is None
|
|
|
|
|
|
def test_libero_processor_pads_state_to_max_dim():
    """LiberoProcessorStep zero-pads the assembled robot state up to max_state_dim."""
    step = LiberoProcessorStep(max_state_dim=24)

    eef_pos = torch.tensor([[1.0, 2.0, 3.0]])
    eef_quat = torch.tensor([[0.0, 0.0, 0.0, 1.0]])
    gripper_qpos = torch.tensor([[4.0, 5.0]])
    robot_state = {
        "eef": {"pos": eef_pos, "quat": eef_quat},
        "gripper": {"qpos": gripper_qpos},
    }
    observation = {f"{OBS_PREFIX}robot_state": robot_state}

    state = step.observation(observation)[OBS_STATE]

    assert state.shape == (1, 24)
    # First 8 entries: eef position (3), three zeros, gripper qpos (2).
    # NOTE(review): the zeros presumably come from converting the identity quaternion
    # into a 3-D rotation representation; confirm against LiberoProcessorStep.
    expected_head = torch.tensor([[1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 4.0, 5.0]])
    assert torch.allclose(state[:, :8], expected_head)
    # Everything past the real features must be zero padding.
    padding = state[:, 8:]
    assert torch.count_nonzero(padding).item() == 0
|
|
|
|
|
|
def test_libero_action_processor_slices_padded_action():
    """LiberoActionProcessorStep keeps the leading action_dim entries of a padded action."""
    step = LiberoActionProcessorStep(action_dim=7)
    batch, horizon, padded_dim = 2, 3, 24
    padded_action = torch.arange(batch * horizon * padded_dim, dtype=torch.float32).reshape(
        batch, horizon, padded_dim
    )

    sliced = step.action(padded_action)
    assert sliced.shape == (batch, horizon, 7)
    assert torch.equal(sliced, padded_action[..., :7])

    # An action narrower than action_dim cannot be sliced and must be rejected.
    too_narrow = torch.zeros(2, 6)
    with pytest.raises(ValueError, match="smaller than action_dim=7"):
        step.action(too_narrow)
|
|
|
|
|
|
def test_base_create_envs():
    """Base class create_envs() should build a single-task VectorEnv via gym.make().

    A throwaway CartPole gym id is registered for the duration of the test. Cleanup
    runs in ``finally`` even when an assertion fails:

    * every vector env returned by ``create_envs`` is closed;
    * the temporary gym registration is removed, even if closing an env raises.
    """
    gym_id = "_dispatch_test/CartPole-v99"

    @EnvConfig.register_subclass("_dispatch_base_test")
    @dataclass
    class _Env(EnvConfig):
        task: str = "CartPole-v99"
        fps: int = 10
        features: dict[str, PolicyFeature] = field(default_factory=dict)

        @property
        def package_name(self):
            return "_dispatch_test"

        @property
        def gym_id(self):
            return gym_id

        @property
        def gym_kwargs(self):
            return {}

    # Register only after the config class exists, so a failure while defining or
    # registering the class cannot leave a stale gym id behind.
    if gym_id not in gym_registry:
        register(id=gym_id, entry_point="gymnasium.envs.classic_control:CartPoleEnv")

    # Pre-bound so the cleanup below is safe even if create_envs() itself raises.
    envs = {}
    try:
        envs = _Env().create_envs(n_envs=2)
        assert "_dispatch_base_test" in envs
        env = envs["_dispatch_base_test"][0]
        assert isinstance(env, gym.vector.VectorEnv)
        assert env.num_envs == 2
    finally:
        try:
            # Close in `finally` rather than after the asserts, so a failing
            # assertion does not leak the vector env(s).
            # NOTE(review): assumes each suite maps task ids to vector envs with a
            # dict, matching the envs[suite][0] access above; confirm against
            # EnvConfig.create_envs.
            for task_envs in envs.values():
                for vec_env in task_envs.values():
                    vec_env.close()
        finally:
            gym_registry.pop(gym_id, None)
|
|
|
|
|
|
def test_custom_create_envs_override():
    """A custom EnvConfig subclass can override create_envs().

    make_env() must return exactly what the override built. The returned vector env
    is therefore checked by identity, not just by suite name. The env is created
    only after the config class exists and is closed in ``finally``, so it does not
    leak if class registration or an assertion fails.
    """

    @EnvConfig.register_subclass("_dispatch_custom_test")
    @dataclass
    class _Env(EnvConfig):
        task: str = "x"
        features: dict[str, PolicyFeature] = field(default_factory=dict)

        @property
        def gym_kwargs(self):
            return {}

        def create_envs(self, n_envs, use_async_envs=False):
            # `mock_vec` is resolved from the enclosing test scope when this is called.
            return {"custom_suite": {0: mock_vec}}

    mock_vec = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")])
    try:
        result = make_env(_Env(), n_envs=1)
        assert "custom_suite" in result
        assert result["custom_suite"][0] is mock_vec
    finally:
        mock_vec.close()
|
|
|
|
|
|
def test_custom_get_env_processors_override():
    """A custom EnvConfig subclass can override get_env_processors().

    The override returns a pair of empty pipelines. Both halves of the pair are
    checked; previously the postprocessor was unpacked but never asserted.
    """
    from lerobot.processor.pipeline import DataProcessorPipeline

    @EnvConfig.register_subclass("_dispatch_proc_test")
    @dataclass
    class _Env(EnvConfig):
        task: str = "x"
        features: dict[str, PolicyFeature] = field(default_factory=dict)

        @property
        def gym_kwargs(self):
            return {}

        def get_env_processors(self, policy_cfg=None):
            return DataProcessorPipeline(steps=[]), DataProcessorPipeline(steps=[])

    pre, post = _Env().get_env_processors()
    assert isinstance(pre, DataProcessorPipeline)
    assert isinstance(post, DataProcessorPipeline)
|