mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
Feat/pipeline add feature contract (#1637)
* Add feature contract to pipelinestep and pipeline * Add tests * Add processor tests * PR feedback * encorperate pr feedback * type in doc * oops
This commit is contained in:
@@ -18,6 +18,7 @@ from typing import Any
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from lerobot.configs.types import PolicyFeature
|
||||||
from lerobot.processor.pipeline import EnvTransition, TransitionKey
|
from lerobot.processor.pipeline import EnvTransition, TransitionKey
|
||||||
|
|
||||||
|
|
||||||
@@ -74,3 +75,6 @@ class DeviceProcessor:
|
|||||||
def get_config(self) -> dict[str, Any]:
|
def get_config(self) -> dict[str, Any]:
|
||||||
"""Return configuration for serialization."""
|
"""Return configuration for serialization."""
|
||||||
return {"device": self.device}
|
return {"device": self.device}
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|||||||
@@ -204,6 +204,9 @@ class NormalizerProcessor:
|
|||||||
def reset(self):
|
def reset(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ProcessorStepRegistry.register(name="unnormalizer_processor")
|
@ProcessorStepRegistry.register(name="unnormalizer_processor")
|
||||||
@@ -327,3 +330,6 @@ class UnnormalizerProcessor:
|
|||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
from lerobot.configs.types import PolicyFeature
|
||||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||||
|
|
||||||
@@ -110,6 +111,27 @@ class ImageProcessor:
|
|||||||
"""Reset processor state (no-op for this processor)."""
|
"""Reset processor state (no-op for this processor)."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
"""Transforms:
|
||||||
|
pixels -> OBS_IMAGE,
|
||||||
|
observation.pixels -> OBS_IMAGE,
|
||||||
|
pixels.<cam> -> OBS_IMAGES.<cam>,
|
||||||
|
observation.pixels.<cam> -> OBS_IMAGES.<cam>
|
||||||
|
"""
|
||||||
|
if "pixels" in features:
|
||||||
|
features[OBS_IMAGE] = features.pop("pixels")
|
||||||
|
if "observation.pixels" in features:
|
||||||
|
features[OBS_IMAGE] = features.pop("observation.pixels")
|
||||||
|
|
||||||
|
prefixes = ("pixels.", "observation.pixels.")
|
||||||
|
for key in list(features.keys()):
|
||||||
|
for p in prefixes:
|
||||||
|
if key.startswith(p):
|
||||||
|
suffix = key[len(p) :]
|
||||||
|
features[f"{OBS_IMAGES}.{suffix}"] = features.pop(key)
|
||||||
|
break
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class StateProcessor:
|
class StateProcessor:
|
||||||
@@ -169,6 +191,25 @@ class StateProcessor:
|
|||||||
"""Reset processor state (no-op for this processor)."""
|
"""Reset processor state (no-op for this processor)."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
"""Transforms:
|
||||||
|
environment_state -> OBS_ENV_STATE,
|
||||||
|
agent_pos -> OBS_STATE,
|
||||||
|
observation.environment_state -> OBS_ENV_STATE,
|
||||||
|
observation.agent_pos -> OBS_STATE
|
||||||
|
"""
|
||||||
|
pairs = (
|
||||||
|
("environment_state", OBS_ENV_STATE),
|
||||||
|
("agent_pos", OBS_STATE),
|
||||||
|
)
|
||||||
|
for old, new in pairs:
|
||||||
|
if old in features:
|
||||||
|
features[new] = features.pop(old)
|
||||||
|
prefixed = f"observation.{old}"
|
||||||
|
if prefixed in features:
|
||||||
|
features[new] = features.pop(prefixed)
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ProcessorStepRegistry.register(name="observation_processor")
|
@ProcessorStepRegistry.register(name="observation_processor")
|
||||||
@@ -219,3 +260,8 @@ class VanillaObservationProcessor:
|
|||||||
"""Reset processor state."""
|
"""Reset processor state."""
|
||||||
self.image_processor.reset()
|
self.image_processor.reset()
|
||||||
self.state_processor.reset()
|
self.state_processor.reset()
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
features = self.image_processor.feature_contract(features)
|
||||||
|
features = self.state_processor.feature_contract(features)
|
||||||
|
return features
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import importlib
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable, Iterable, Sequence
|
from collections.abc import Callable, Iterable, Sequence
|
||||||
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -29,6 +30,7 @@ from huggingface_hub import ModelHubMixin, hf_hub_download
|
|||||||
from huggingface_hub.errors import HfHubHTTPError
|
from huggingface_hub.errors import HfHubHTTPError
|
||||||
from safetensors.torch import load_file, save_file
|
from safetensors.torch import load_file, save_file
|
||||||
|
|
||||||
|
from lerobot.configs.types import PolicyFeature
|
||||||
from lerobot.utils.utils import get_safe_torch_device
|
from lerobot.utils.utils import get_safe_torch_device
|
||||||
|
|
||||||
|
|
||||||
@@ -141,6 +143,11 @@ class ProcessorStep(Protocol):
|
|||||||
automatically serialise the step's configuration and learnable state using
|
automatically serialise the step's configuration and learnable state using
|
||||||
a safe-to-share JSON + SafeTensors format.
|
a safe-to-share JSON + SafeTensors format.
|
||||||
|
|
||||||
|
|
||||||
|
**Required**:
|
||||||
|
- ``__call__(transition: EnvTransition) -> EnvTransition``
|
||||||
|
- ``feature_contract(features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]``
|
||||||
|
|
||||||
Optional helper protocol:
|
Optional helper protocol:
|
||||||
* ``get_config() -> dict[str, Any]`` – User-defined JSON-serializable
|
* ``get_config() -> dict[str, Any]`` – User-defined JSON-serializable
|
||||||
configuration and state. YOU decide what to save here. This is where all
|
configuration and state. YOU decide what to save here. This is where all
|
||||||
@@ -168,6 +175,8 @@ class ProcessorStep(Protocol):
|
|||||||
|
|
||||||
def reset(self) -> None: ...
|
def reset(self) -> None: ...
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ...
|
||||||
|
|
||||||
|
|
||||||
def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401
|
def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401
|
||||||
"""Convert a *batch* dict coming from Learobot replay/dataset code into an
|
"""Convert a *batch* dict coming from Learobot replay/dataset code into an
|
||||||
@@ -840,6 +849,33 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
|
|
||||||
return f"RobotProcessor({', '.join(parts)})"
|
return f"RobotProcessor({', '.join(parts)})"
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
for i, step in enumerate(self.steps):
|
||||||
|
if not callable(step):
|
||||||
|
raise TypeError(
|
||||||
|
f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition"
|
||||||
|
)
|
||||||
|
|
||||||
|
fc = getattr(step, "feature_contract", None)
|
||||||
|
if not callable(fc):
|
||||||
|
raise TypeError(
|
||||||
|
f"Step {i} ({type(step).__name__}) must define feature_contract(features) -> dict[str, Any]"
|
||||||
|
)
|
||||||
|
|
||||||
|
def feature_contract(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
"""
|
||||||
|
Apply ALL steps in order. Each step must implement
|
||||||
|
feature_contract(features) and return a dict (full or incremental schema).
|
||||||
|
"""
|
||||||
|
features: dict[str, PolicyFeature] = deepcopy(initial_features)
|
||||||
|
|
||||||
|
for _, step in enumerate(self.steps):
|
||||||
|
out = step.feature_contract(features)
|
||||||
|
if not isinstance(out, dict):
|
||||||
|
raise TypeError(f"{step.__class__.__name__}.feature_contract must return dict[str, Any]")
|
||||||
|
features = out
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
class ObservationProcessor:
|
class ObservationProcessor:
|
||||||
"""Base class for processors that modify only the observation component of a transition.
|
"""Base class for processors that modify only the observation component of a transition.
|
||||||
@@ -1145,3 +1181,6 @@ class IdentityProcessor:
|
|||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from typing import Any
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from lerobot.configs.types import PolicyFeature
|
||||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||||
|
|
||||||
|
|
||||||
@@ -53,3 +54,10 @@ class RenameProcessor:
|
|||||||
|
|
||||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
"""Transforms:
|
||||||
|
- Each key in the observation that appears in `rename_map` is renamed to its value.
|
||||||
|
- Keys not in `rename_map` remain unchanged.
|
||||||
|
"""
|
||||||
|
return {self.rename_map.get(k, k): v for k, v in features.items()}
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import traceback
|
|||||||
import pytest
|
import pytest
|
||||||
from serial import SerialException
|
from serial import SerialException
|
||||||
|
|
||||||
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
from tests.utils import DEVICE
|
from tests.utils import DEVICE
|
||||||
|
|
||||||
# Import fixture modules as plugins
|
# Import fixture modules as plugins
|
||||||
@@ -69,3 +70,19 @@ def patch_builtins_input(monkeypatch):
|
|||||||
print(text)
|
print(text)
|
||||||
|
|
||||||
monkeypatch.setattr("builtins.input", print_text)
|
monkeypatch.setattr("builtins.input", print_text)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def policy_feature_factory():
|
||||||
|
"""PolicyFeature factory"""
|
||||||
|
|
||||||
|
def _pf(ft: FeatureType, shape: tuple[int, ...]) -> PolicyFeature:
|
||||||
|
return PolicyFeature(type=ft, shape=shape)
|
||||||
|
|
||||||
|
return _pf
|
||||||
|
|
||||||
|
|
||||||
|
def assert_contract_is_typed(features: dict[str, PolicyFeature]) -> None:
|
||||||
|
assert isinstance(features, dict)
|
||||||
|
assert all(isinstance(k, str) for k in features.keys())
|
||||||
|
assert all(isinstance(v, PolicyFeature) for v in features.values())
|
||||||
|
|||||||
@@ -18,12 +18,15 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from lerobot.configs.types import FeatureType
|
||||||
|
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
ImageProcessor,
|
ImageProcessor,
|
||||||
StateProcessor,
|
StateProcessor,
|
||||||
VanillaObservationProcessor,
|
VanillaObservationProcessor,
|
||||||
)
|
)
|
||||||
from lerobot.processor.pipeline import TransitionKey
|
from lerobot.processor.pipeline import TransitionKey
|
||||||
|
from tests.conftest import assert_contract_is_typed
|
||||||
|
|
||||||
|
|
||||||
def create_transition(
|
def create_transition(
|
||||||
@@ -420,3 +423,79 @@ def test_equivalent_with_image_dict():
|
|||||||
|
|
||||||
for key in original_result:
|
for key in original_result:
|
||||||
torch.testing.assert_close(original_result[key], processor_result[key])
|
torch.testing.assert_close(original_result[key], processor_result[key])
|
||||||
|
|
||||||
|
|
||||||
|
def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory):
|
||||||
|
processor = ImageProcessor()
|
||||||
|
features = {
|
||||||
|
"pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||||
|
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||||
|
}
|
||||||
|
out = processor.feature_contract(features.copy())
|
||||||
|
|
||||||
|
assert OBS_IMAGE in out and out[OBS_IMAGE] == features["pixels"]
|
||||||
|
assert "pixels" not in out
|
||||||
|
assert out["keep"] == features["keep"]
|
||||||
|
assert_contract_is_typed(out)
|
||||||
|
|
||||||
|
|
||||||
|
def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory):
|
||||||
|
processor = ImageProcessor()
|
||||||
|
features = {
|
||||||
|
"observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||||
|
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||||
|
}
|
||||||
|
out = processor.feature_contract(features.copy())
|
||||||
|
|
||||||
|
assert OBS_IMAGE in out and out[OBS_IMAGE] == features["observation.pixels"]
|
||||||
|
assert "observation.pixels" not in out
|
||||||
|
assert out["keep"] == features["keep"]
|
||||||
|
assert_contract_is_typed(out)
|
||||||
|
|
||||||
|
|
||||||
|
def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory):
|
||||||
|
processor = ImageProcessor()
|
||||||
|
features = {
|
||||||
|
"pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||||
|
"pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||||
|
"observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||||
|
"keep": policy_feature_factory(FeatureType.ENV, (7,)),
|
||||||
|
}
|
||||||
|
out = processor.feature_contract(features.copy())
|
||||||
|
|
||||||
|
assert f"{OBS_IMAGES}.front" in out and out[f"{OBS_IMAGES}.front"] == features["pixels.front"]
|
||||||
|
assert f"{OBS_IMAGES}.wrist" in out and out[f"{OBS_IMAGES}.wrist"] == features["pixels.wrist"]
|
||||||
|
assert f"{OBS_IMAGES}.rear" in out and out[f"{OBS_IMAGES}.rear"] == features["observation.pixels.rear"]
|
||||||
|
assert "pixels.front" not in out and "pixels.wrist" not in out and "observation.pixels.rear" not in out
|
||||||
|
assert out["keep"] == features["keep"]
|
||||||
|
assert_contract_is_typed(out)
|
||||||
|
|
||||||
|
|
||||||
|
def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory):
|
||||||
|
processor = StateProcessor()
|
||||||
|
features = {
|
||||||
|
"environment_state": policy_feature_factory(FeatureType.STATE, (3,)),
|
||||||
|
"agent_pos": policy_feature_factory(FeatureType.STATE, (7,)),
|
||||||
|
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||||
|
}
|
||||||
|
out = processor.feature_contract(features.copy())
|
||||||
|
|
||||||
|
assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["environment_state"]
|
||||||
|
assert OBS_STATE in out and out[OBS_STATE] == features["agent_pos"]
|
||||||
|
assert "environment_state" not in out and "agent_pos" not in out
|
||||||
|
assert out["keep"] == features["keep"]
|
||||||
|
assert_contract_is_typed(out)
|
||||||
|
|
||||||
|
|
||||||
|
def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory):
|
||||||
|
proc = StateProcessor()
|
||||||
|
features = {
|
||||||
|
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
|
||||||
|
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
|
||||||
|
}
|
||||||
|
out = proc.feature_contract(features.copy())
|
||||||
|
|
||||||
|
assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["observation.environment_state"]
|
||||||
|
assert OBS_STATE in out and out[OBS_STATE] == features["observation.agent_pos"]
|
||||||
|
assert "environment_state" not in out and "agent_pos" not in out
|
||||||
|
assert_contract_is_typed(out)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -25,8 +26,10 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor
|
from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor
|
||||||
from lerobot.processor.pipeline import TransitionKey
|
from lerobot.processor.pipeline import TransitionKey
|
||||||
|
from tests.conftest import assert_contract_is_typed
|
||||||
|
|
||||||
|
|
||||||
def create_transition(
|
def create_transition(
|
||||||
@@ -88,6 +91,10 @@ class MockStep:
|
|||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
self.counter = 0
|
self.counter = 0
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
# We do not test feature_contract here
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MockStepWithoutOptionalMethods:
|
class MockStepWithoutOptionalMethods:
|
||||||
@@ -106,6 +113,10 @@ class MockStepWithoutOptionalMethods:
|
|||||||
|
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
# We do not test feature_contract here
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MockStepWithTensorState:
|
class MockStepWithTensorState:
|
||||||
@@ -158,6 +169,10 @@ class MockStepWithTensorState:
|
|||||||
self.running_mean.zero_()
|
self.running_mean.zero_()
|
||||||
self.running_count.zero_()
|
self.running_count.zero_()
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
# We do not test feature_contract here
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
def test_empty_pipeline():
|
def test_empty_pipeline():
|
||||||
"""Test pipeline with no steps."""
|
"""Test pipeline with no steps."""
|
||||||
@@ -699,6 +714,10 @@ class MockModuleStep(nn.Module):
|
|||||||
self.running_mean.zero_()
|
self.running_mean.zero_()
|
||||||
self.counter = 0
|
self.counter = 0
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
# We do not test feature_contract here
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
def test_to_device_with_state_dict():
|
def test_to_device_with_state_dict():
|
||||||
"""Test moving pipeline to device for steps with state_dict."""
|
"""Test moving pipeline to device for steps with state_dict."""
|
||||||
@@ -953,6 +972,10 @@ class MockNonModuleStepWithState:
|
|||||||
self.step_count.zero_()
|
self.step_count.zero_()
|
||||||
self.history.clear()
|
self.history.clear()
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
# We do not test feature_contract here
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
def test_to_device_non_module_class():
|
def test_to_device_non_module_class():
|
||||||
"""Test moving pipeline to device for regular classes (non nn.Module) with tensor state.
|
"""Test moving pipeline to device for regular classes (non nn.Module) with tensor state.
|
||||||
@@ -1127,6 +1150,10 @@ class MockStepWithNonSerializableParam:
|
|||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
# We do not test feature_contract here
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
@ProcessorStepRegistry.register("registered_mock_step")
|
@ProcessorStepRegistry.register("registered_mock_step")
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -1162,6 +1189,10 @@ class RegisteredMockStep:
|
|||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
# We do not test feature_contract here
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
class MockEnvironment:
|
class MockEnvironment:
|
||||||
"""Mock environment for testing non-serializable parameters."""
|
"""Mock environment for testing non-serializable parameters."""
|
||||||
@@ -1483,6 +1514,10 @@ class MockStepWithMixedState:
|
|||||||
"list_value": self.list_value,
|
"list_value": self.list_value,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
# We do not test feature_contract here
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
def test_to_device_with_mixed_state_types():
|
def test_to_device_with_mixed_state_types():
|
||||||
"""Test that to() only moves tensor state, while non-tensor state remains in config."""
|
"""Test that to() only moves tensor state, while non-tensor state remains in config."""
|
||||||
@@ -1790,6 +1825,10 @@ def test_state_file_naming_with_registry():
|
|||||||
def load_state_dict(self, state):
|
def load_state_dict(self, state):
|
||||||
self.state_tensor = state["state_tensor"]
|
self.state_tensor = state["state_tensor"]
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
# We do not test feature_contract here
|
||||||
|
return features
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Create pipeline with registered steps
|
# Create pipeline with registered steps
|
||||||
step1 = TestStatefulStep(1)
|
step1 = TestStatefulStep(1)
|
||||||
@@ -1843,6 +1882,10 @@ def test_override_with_nested_config():
|
|||||||
def get_config(self):
|
def get_config(self):
|
||||||
return {"name": self.name, "simple_param": self.simple_param, "nested_config": self.nested_config}
|
return {"name": self.name, "simple_param": self.simple_param, "nested_config": self.nested_config}
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
# We do not test feature_contract here
|
||||||
|
return features
|
||||||
|
|
||||||
try:
|
try:
|
||||||
step = ComplexConfigStep()
|
step = ComplexConfigStep()
|
||||||
pipeline = RobotProcessor([step])
|
pipeline = RobotProcessor([step])
|
||||||
@@ -1931,6 +1974,10 @@ def test_override_with_callables():
|
|||||||
def get_config(self):
|
def get_config(self):
|
||||||
return {"name": self.name}
|
return {"name": self.name}
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
# We do not test feature_contract here
|
||||||
|
return features
|
||||||
|
|
||||||
try:
|
try:
|
||||||
step = CallableStep()
|
step = CallableStep()
|
||||||
pipeline = RobotProcessor([step])
|
pipeline = RobotProcessor([step])
|
||||||
@@ -2059,6 +2106,10 @@ def test_override_with_device_strings():
|
|||||||
def load_state_dict(self, state):
|
def load_state_dict(self, state):
|
||||||
self.buffer = state["buffer"]
|
self.buffer = state["buffer"]
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
# We do not test feature_contract here
|
||||||
|
return features
|
||||||
|
|
||||||
try:
|
try:
|
||||||
step = DeviceAwareStep(device="cpu")
|
step = DeviceAwareStep(device="cpu")
|
||||||
pipeline = RobotProcessor([step])
|
pipeline = RobotProcessor([step])
|
||||||
@@ -2146,3 +2197,166 @@ def test_save_load_with_custom_converter_functions():
|
|||||||
# Should work with standard format (wouldn't work with custom converter)
|
# Should work with standard format (wouldn't work with custom converter)
|
||||||
result = loaded(batch)
|
result = loaded(batch)
|
||||||
assert "observation.image" in result # Standard format preserved
|
assert "observation.image" in result # Standard format preserved
|
||||||
|
|
||||||
|
|
||||||
|
class NonCompliantStep:
|
||||||
|
"""Intentionally non-compliant: missing feature_contract."""
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
return transition
|
||||||
|
|
||||||
|
|
||||||
|
def test_construction_rejects_step_without_feature_contract():
|
||||||
|
with pytest.raises(TypeError, match=r"must define feature_contract\(features\) -> dict\[str, Any\]"):
|
||||||
|
RobotProcessor([NonCompliantStep()])
|
||||||
|
|
||||||
|
|
||||||
|
class NonCallableStep:
|
||||||
|
"""Intentionally non-compliant: missing __call__."""
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
def test_construction_rejects_step_without_call():
|
||||||
|
with pytest.raises(TypeError, match=r"must define __call__"):
|
||||||
|
RobotProcessor([NonCallableStep()])
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FeatureContractAddStep:
|
||||||
|
"""Adds a PolicyFeature"""
|
||||||
|
|
||||||
|
key: str = "a"
|
||||||
|
value: PolicyFeature = PolicyFeature(type=FeatureType.STATE, shape=(1,))
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
return transition
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
features[self.key] = self.value
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FeatureContractMutateStep:
|
||||||
|
"""Mutates a PolicyFeature"""
|
||||||
|
|
||||||
|
key: str = "a"
|
||||||
|
fn: Callable[[PolicyFeature | None], PolicyFeature] = lambda x: x # noqa: E731
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
return transition
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
features[self.key] = self.fn(features.get(self.key))
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FeatureContractBadReturnStep:
|
||||||
|
"""Returns a non-dict"""
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
return transition
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return ["not-a-dict"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FeatureContractRemoveStep:
|
||||||
|
"""Removes a PolicyFeature"""
|
||||||
|
|
||||||
|
key: str
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
return transition
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
features.pop(self.key, None)
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
def test_feature_contract_orders_and_merges(policy_feature_factory):
|
||||||
|
p = RobotProcessor(
|
||||||
|
[
|
||||||
|
FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))),
|
||||||
|
FeatureContractMutateStep("a", lambda v: PolicyFeature(type=v.type, shape=(3,))),
|
||||||
|
FeatureContractAddStep("b", policy_feature_factory(FeatureType.ENV, (2,))),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
out = p.feature_contract({})
|
||||||
|
|
||||||
|
assert out["a"].type == FeatureType.STATE and out["a"].shape == (3,)
|
||||||
|
assert out["b"].type == FeatureType.ENV and out["b"].shape == (2,)
|
||||||
|
assert_contract_is_typed(out)
|
||||||
|
|
||||||
|
|
||||||
|
def test_feature_contract_respects_initial_without_mutation(policy_feature_factory):
|
||||||
|
initial = {
|
||||||
|
"seed": policy_feature_factory(FeatureType.STATE, (7,)),
|
||||||
|
"nested": policy_feature_factory(FeatureType.ENV, (0,)),
|
||||||
|
}
|
||||||
|
p = RobotProcessor(
|
||||||
|
[
|
||||||
|
FeatureContractMutateStep("seed", lambda v: PolicyFeature(type=v.type, shape=(v.shape[0] + 1,))),
|
||||||
|
FeatureContractMutateStep(
|
||||||
|
"nested", lambda v: PolicyFeature(type=v.type, shape=(v.shape[0] + 5,))
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
out = p.feature_contract(initial_features=initial)
|
||||||
|
|
||||||
|
assert out["seed"].shape == (8,)
|
||||||
|
assert out["nested"].shape == (5,)
|
||||||
|
# Initial dict must be preserved
|
||||||
|
assert initial["seed"].shape == (7,)
|
||||||
|
assert initial["nested"].shape == (0,)
|
||||||
|
|
||||||
|
assert_contract_is_typed(out)
|
||||||
|
|
||||||
|
|
||||||
|
def test_feature_contract_type_error_on_bad_step():
|
||||||
|
p = RobotProcessor([FeatureContractAddStep(), FeatureContractBadReturnStep()])
|
||||||
|
with pytest.raises(TypeError, match=r"\w+\.feature_contract must return dict\[str, Any\]"):
|
||||||
|
_ = p.feature_contract({})
|
||||||
|
|
||||||
|
|
||||||
|
def test_feature_contract_execution_order_tracking():
|
||||||
|
class Track:
|
||||||
|
def __init__(self, label):
|
||||||
|
self.label = label
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
return transition
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
code = {"A": 1, "B": 2, "C": 3}[self.label]
|
||||||
|
pf = features.get("order", PolicyFeature(type=FeatureType.ENV, shape=()))
|
||||||
|
features["order"] = PolicyFeature(type=pf.type, shape=pf.shape + (code,))
|
||||||
|
return features
|
||||||
|
|
||||||
|
out = RobotProcessor([Track("A"), Track("B"), Track("C")]).feature_contract({})
|
||||||
|
assert out["order"].shape == (1, 2, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_feature_contract_remove_key(policy_feature_factory):
|
||||||
|
p = RobotProcessor(
|
||||||
|
[
|
||||||
|
FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))),
|
||||||
|
FeatureContractRemoveStep("a"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
out = p.feature_contract({})
|
||||||
|
assert "a" not in out
|
||||||
|
|
||||||
|
|
||||||
|
def test_feature_contract_remove_from_initial(policy_feature_factory):
|
||||||
|
initial = {
|
||||||
|
"keep": policy_feature_factory(FeatureType.STATE, (1,)),
|
||||||
|
"drop": policy_feature_factory(FeatureType.STATE, (1,)),
|
||||||
|
}
|
||||||
|
p = RobotProcessor([FeatureContractRemoveStep("drop")])
|
||||||
|
out = p.feature_contract(initial_features=initial)
|
||||||
|
assert "drop" not in out and out["keep"] == initial["keep"]
|
||||||
|
|||||||
@@ -19,7 +19,9 @@ from pathlib import Path
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from lerobot.configs.types import FeatureType
|
||||||
from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey
|
from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey
|
||||||
|
from tests.conftest import assert_contract_is_typed
|
||||||
|
|
||||||
|
|
||||||
def create_transition(
|
def create_transition(
|
||||||
@@ -406,3 +408,60 @@ def test_value_types_preserved():
|
|||||||
assert processed_obs["old_string"] == "hello"
|
assert processed_obs["old_string"] == "hello"
|
||||||
assert processed_obs["old_dict"] == {"nested": "value"}
|
assert processed_obs["old_dict"] == {"nested": "value"}
|
||||||
assert processed_obs["old_list"] == [1, 2, 3]
|
assert processed_obs["old_list"] == [1, 2, 3]
|
||||||
|
|
||||||
|
|
||||||
|
def test_feature_contract_basic_renaming(policy_feature_factory):
|
||||||
|
processor = RenameProcessor(rename_map={"a": "x", "b": "y"})
|
||||||
|
features = {
|
||||||
|
"a": policy_feature_factory(FeatureType.STATE, (2,)),
|
||||||
|
"b": policy_feature_factory(FeatureType.ACTION, (3,)),
|
||||||
|
"c": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||||
|
}
|
||||||
|
|
||||||
|
out = processor.feature_contract(features.copy())
|
||||||
|
|
||||||
|
# Values preserved and typed
|
||||||
|
assert out["x"] == features["a"]
|
||||||
|
assert out["y"] == features["b"]
|
||||||
|
assert out["c"] == features["c"]
|
||||||
|
|
||||||
|
assert_contract_is_typed(out)
|
||||||
|
# Input not mutated
|
||||||
|
assert set(features) == {"a", "b", "c"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_feature_contract_overlapping_keys(policy_feature_factory):
|
||||||
|
# Overlapping renames: both 'a' and 'b' exist. 'a'->'b', 'b'->'c'
|
||||||
|
processor = RenameProcessor(rename_map={"a": "b", "b": "c"})
|
||||||
|
features = {
|
||||||
|
"a": policy_feature_factory(FeatureType.STATE, (1,)),
|
||||||
|
"b": policy_feature_factory(FeatureType.STATE, (2,)),
|
||||||
|
}
|
||||||
|
out = processor.feature_contract(features)
|
||||||
|
|
||||||
|
assert set(out) == {"b", "c"}
|
||||||
|
assert out["b"] == features["a"] # 'a' renamed to'b'
|
||||||
|
assert out["c"] == features["b"] # 'b' renamed to 'c'
|
||||||
|
assert_contract_is_typed(out)
|
||||||
|
|
||||||
|
|
||||||
|
def test_feature_contract_chained_processors(policy_feature_factory):
|
||||||
|
# Chain two rename processors at the contract level
|
||||||
|
processor1 = RenameProcessor(rename_map={"pos": "agent_position", "img": "camera_image"})
|
||||||
|
processor2 = RenameProcessor(
|
||||||
|
rename_map={"agent_position": "observation.state", "camera_image": "observation.image"}
|
||||||
|
)
|
||||||
|
pipeline = RobotProcessor([processor1, processor2])
|
||||||
|
|
||||||
|
spec = {
|
||||||
|
"pos": policy_feature_factory(FeatureType.STATE, (7,)),
|
||||||
|
"img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||||
|
"extra": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||||
|
}
|
||||||
|
out = pipeline.feature_contract(initial_features=spec)
|
||||||
|
|
||||||
|
assert set(out) == {"observation.state", "observation.image", "extra"}
|
||||||
|
assert out["observation.state"] == spec["pos"]
|
||||||
|
assert out["observation.image"] == spec["img"]
|
||||||
|
assert out["extra"] == spec["extra"]
|
||||||
|
assert_contract_is_typed(out)
|
||||||
|
|||||||
Reference in New Issue
Block a user