chore (batch handling): Enhance processing components with batch conversion utilities

This commit is contained in:
Adil Zouitine
2025-07-06 21:29:51 +02:00
parent c227107f60
commit b08149a113
6 changed files with 606 additions and 53 deletions
+4
View File
@@ -25,9 +25,11 @@ from .pipeline import (
ActionProcessor, ActionProcessor,
DoneProcessor, DoneProcessor,
EnvTransition, EnvTransition,
IdentityProcessor,
InfoProcessor, InfoProcessor,
ObservationProcessor, ObservationProcessor,
ProcessorStep, ProcessorStep,
ProcessorStepRegistry,
RewardProcessor, RewardProcessor,
RobotProcessor, RobotProcessor,
TruncatedProcessor, TruncatedProcessor,
@@ -39,12 +41,14 @@ __all__ = [
"DeviceProcessor", "DeviceProcessor",
"DoneProcessor", "DoneProcessor",
"EnvTransition", "EnvTransition",
"IdentityProcessor",
"ImageProcessor", "ImageProcessor",
"InfoProcessor", "InfoProcessor",
"NormalizerProcessor", "NormalizerProcessor",
"UnnormalizerProcessor", "UnnormalizerProcessor",
"ObservationProcessor", "ObservationProcessor",
"ProcessorStep", "ProcessorStep",
"ProcessorStepRegistry",
"RenameProcessor", "RenameProcessor",
"RewardProcessor", "RewardProcessor",
"RobotProcessor", "RobotProcessor",
+6 -1
View File
@@ -33,6 +33,9 @@ class DeviceProcessor:
device: str = "cpu" device: str = "cpu"
def __post_init__(self):
self.non_blocking = "cuda" in self.device
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
observation: dict[str, torch.Tensor] = transition[TransitionIndex.OBSERVATION] observation: dict[str, torch.Tensor] = transition[TransitionIndex.OBSERVATION]
action = transition[TransitionIndex.ACTION] action = transition[TransitionIndex.ACTION]
@@ -43,7 +46,9 @@ class DeviceProcessor:
complementary_data = transition[TransitionIndex.COMPLEMENTARY_DATA] complementary_data = transition[TransitionIndex.COMPLEMENTARY_DATA]
if observation is not None: if observation is not None:
observation = {k: v.to(self.device) for k, v in observation.items()} observation = {
k: v.to(self.device, non_blocking=self.non_blocking) for k, v in observation.items()
}
if action is not None: if action is not None:
action = action.to(self.device) action = action.to(self.device)
+61 -26
View File
@@ -1,12 +1,13 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Mapping from typing import Any, Mapping, Optional, Set
import numpy as np import numpy as np
import torch import torch
from torch import Tensor from torch import Tensor
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex
@@ -45,8 +46,18 @@ class NormalizerProcessor:
the normalize_keys parameter. the normalize_keys parameter.
""" """
stats: dict[str, dict[str, Any]] # Features and normalisation map are mandatory to match the design of normalize.py
normalize_keys: set[str] | None = None features: dict[str, PolicyFeature]
norm_map: dict[FeatureType, NormalizationMode]
# Pre-computed statistics coming from dataset.meta.stats for instance.
stats: Optional[dict[str, dict[str, Any]]] = None
# Explicit subset of keys to normalise. If ``None`` every key (except
# "action") found in ``stats`` will be normalised. Using a ``set`` makes
# membership checks O(1).
normalize_keys: Optional[Set[str]] = None
eps: float = 1e-8 eps: float = 1e-8
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False) _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
@@ -55,24 +66,48 @@ class NormalizerProcessor:
def from_lerobot_dataset( def from_lerobot_dataset(
cls, cls,
dataset: LeRobotDataset, dataset: LeRobotDataset,
features: dict[str, PolicyFeature],
norm_map: dict[FeatureType, NormalizationMode],
*, *,
normalize_keys: set[str] | None = None, normalize_keys: Optional[Set[str]] = None,
eps: float = 1e-8, eps: float = 1e-8,
) -> NormalizerProcessor: ) -> "NormalizerProcessor":
return cls(stats=dataset.meta.stats, normalize_keys=normalize_keys, eps=eps) """Factory helper that pulls statistics from a :class:`LeRobotDataset`.
The features and norm_map parameters are mandatory to match the design
pattern used in normalize.py.
"""
return cls(
features=features,
norm_map=norm_map,
stats=dataset.meta.stats,
normalize_keys=normalize_keys,
eps=eps,
)
def __post_init__(self): def __post_init__(self):
# Convert statistics once so we avoid repeated numpy→Tensor conversions
# during runtime.
self.stats = self.stats or {}
self._tensor_stats = _convert_stats_to_tensors(self.stats) self._tensor_stats = _convert_stats_to_tensors(self.stats)
# Ensure *normalize_keys* is a set for fast look-ups and compare by
# value later when returning the configuration.
if self.normalize_keys is not None and not isinstance(self.normalize_keys, set):
self.normalize_keys = set(self.normalize_keys)
def _normalize_obs(self, observation): def _normalize_obs(self, observation):
if observation is None: if observation is None:
return None return None
keys_to_norm = ( # Decide which keys should be normalised for this call.
self.normalize_keys if self.normalize_keys is not None:
if self.normalize_keys is not None keys_to_norm = self.normalize_keys
else {k for k in self._tensor_stats if k != "action"} else:
) # Use feature map to skip action keys.
keys_to_norm = {k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION}
processed = dict(observation) processed = dict(observation)
for key in keys_to_norm: for key in keys_to_norm:
if key not in processed or key not in self._tensor_stats: if key not in processed or key not in self._tensor_stats:
@@ -126,7 +161,11 @@ class NormalizerProcessor:
) )
def get_config(self) -> dict[str, Any]: def get_config(self) -> dict[str, Any]:
return {"normalize_keys": list(self.normalize_keys) if self.normalize_keys else None, "eps": self.eps} config = {"eps": self.eps}
if self.normalize_keys is not None:
# Serialise as a list for YAML / JSON friendliness
config["normalize_keys"] = sorted(self.normalize_keys)
return config
def state_dict(self) -> dict[str, Tensor]: def state_dict(self) -> dict[str, Tensor]:
flat = {} flat = {}
@@ -154,8 +193,9 @@ class UnnormalizerProcessor:
transform. transform.
""" """
stats: dict[str, dict[str, Any]] features: dict[str, PolicyFeature]
unnormalize_keys: set[str] | None = None norm_map: dict[FeatureType, NormalizationMode]
stats: Optional[dict[str, dict[str, Any]]] = None
eps: float = 1e-8 eps: float = 1e-8
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False) _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
@@ -164,23 +204,21 @@ class UnnormalizerProcessor:
def from_lerobot_dataset( def from_lerobot_dataset(
cls, cls,
dataset: LeRobotDataset, dataset: LeRobotDataset,
features: dict[str, PolicyFeature],
norm_map: dict[FeatureType, NormalizationMode],
*, *,
unnormalize_keys: set[str] | None = None,
eps: float = 1e-8, eps: float = 1e-8,
) -> UnnormalizerProcessor: ) -> "UnnormalizerProcessor":
return cls(stats=dataset.meta.stats, unnormalize_keys=unnormalize_keys, eps=eps) return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, eps=eps)
def __post_init__(self): def __post_init__(self):
self.stats = self.stats or {}
self._tensor_stats = _convert_stats_to_tensors(self.stats) self._tensor_stats = _convert_stats_to_tensors(self.stats)
def _unnormalize_obs(self, observation): def _unnormalize_obs(self, observation):
if observation is None: if observation is None:
return None return None
keys = ( keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION]
self.unnormalize_keys
if self.unnormalize_keys is not None
else {k for k in self._tensor_stats if k != "action"}
)
processed = dict(observation) processed = dict(observation)
for key in keys: for key in keys:
if key not in processed or key not in self._tensor_stats: if key not in processed or key not in self._tensor_stats:
@@ -231,10 +269,7 @@ class UnnormalizerProcessor:
) )
def get_config(self) -> dict[str, Any]: def get_config(self) -> dict[str, Any]:
return { return {"eps": self.eps}
"unnormalize_keys": list(self.unnormalize_keys) if self.unnormalize_keys else None,
"eps": self.eps,
}
def state_dict(self) -> dict[str, Tensor]: def state_dict(self) -> dict[str, Tensor]:
flat = {} flat = {}
+121 -7
View File
@@ -42,7 +42,7 @@ class TransitionIndex(IntEnum):
# (observation, action, reward, done, truncated, info, complementary_data) # (observation, action, reward, done, truncated, info, complementary_data)
EnvTransition = Tuple[ EnvTransition = Tuple[
Any | None, # observation dict[str, Any] | None, # observation
Any | None, # action Any | None, # action
float | None, # reward float | None, # reward
bool | None, # done bool | None, # done
@@ -162,6 +162,79 @@ class ProcessorStep(Protocol):
def reset(self) -> None: ... def reset(self) -> None: ...
def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition:  # noqa: D401
    """Convert a *batch* dict coming from LeRobot replay/dataset code into an
    ``EnvTransition`` tuple.

    The function is intentionally **strictly positional**: it maps well-known
    keys to the fixed slot order used inside the pipeline. Missing keys are
    filled with defaults (``None``, ``0.0``, ``False`` or ``{}``).

    Keys recognised (case-sensitive):

    * "observation.*" (keys starting with "observation." are grouped into the
      observation dict; each key is kept in full, prefix included)
    * "action"
    * "next.reward"
    * "next.done"
    * "next.truncated"
    * "info"

    Additional keys are ignored so that existing dataloaders can carry extra
    metadata without breaking the processor.

    Args:
        batch: Flat mapping using the key names listed above.

    Returns:
        A 7-tuple ``(observation, action, reward, done, truncated, info,
        complementary_data)``. ``observation`` is ``None`` when the batch has
        no "observation.*" key; ``complementary_data`` is always a new empty
        dict.
    """
    # Group every "observation.*" entry under one dict. The keys keep their
    # "observation." prefix, so the reverse conversion can copy them back
    # into a flat batch without renaming anything.
    observation = {key: value for key, value in batch.items() if key.startswith("observation.")}

    return (
        observation or None,  # no observation keys at all -> None, not {}
        batch.get("action"),
        batch.get("next.reward", 0.0),
        batch.get("next.done", False),
        batch.get("next.truncated", False),
        batch.get("info", {}),
        {},
    )
def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:  # noqa: D401
    """Inverse of :func:`_default_batch_to_transition`. Returns a dict with
    the canonical field names used throughout *LeRobot*.

    Args:
        transition: 7-tuple ``(observation, action, reward, done, truncated,
            info, complementary_data)``.

    Returns:
        A flat dict with the keys "action", "next.reward", "next.done",
        "next.truncated" and "info", plus one entry per key of the
        observation dict. The ``complementary_data`` slot is not written to
        the batch.
    """
    observation, action, reward, done, truncated, info, _ = transition

    batch = {
        "action": action,
        "next.reward": reward,
        "next.done": done,
        "next.truncated": truncated,
        "info": info,
    }

    # Observation entries are copied verbatim under their own keys; no prefix
    # is added and no key filtering is done. A non-dict observation
    # (including ``None``) contributes nothing to the batch.
    if isinstance(observation, dict):
        batch.update(observation)

    return batch
@dataclass @dataclass
class RobotProcessor(ModelHubMixin): class RobotProcessor(ModelHubMixin):
""" """
@@ -200,6 +273,13 @@ class RobotProcessor(ModelHubMixin):
name: str = "RobotProcessor" name: str = "RobotProcessor"
seed: int | None = None seed: int | None = None
to_transition: Callable[[dict[str, Any]], EnvTransition] = field(
default_factory=lambda: _default_batch_to_transition, repr=False
)
to_batch: Callable[[EnvTransition], dict[str, Any]] = field(
default_factory=lambda: _default_transition_to_batch, repr=False
)
# Processor-level hooks # Processor-level hooks
# A hook can optionally return a modified transition. If it returns # A hook can optionally return a modified transition. If it returns
# ``None`` the current value is left untouched. # ``None`` the current value is left untouched.
@@ -211,14 +291,29 @@ class RobotProcessor(ModelHubMixin):
) )
reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False) reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False)
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, data: EnvTransition | dict[str, Any]):
"""Run *transition* through every step, firing hooks on the way.""" """Process *data* through all steps.
# Basic validation with helpful error message The method accepts **either** the classic :pydata:`EnvTransition` tuple
**or** a *batch* dictionary (like the ones returned by
:class:`lerobot.utils.buffer.ReplayBuffer` or
:class:`lerobot.datasets.lerobot_dataset.LeRobotDataset`). If a dict is
supplied it is first converted to the internal tuple format using
:pyattr:`to_transition`; after all steps are executed the tuple is
transformed back into a dict with :pyattr:`to_batch` and the result is
returned thereby preserving the caller's original data type.
"""
called_with_batch = isinstance(data, dict)
transition = self.to_transition(data) if called_with_batch else data
# Basic validation with helpful error message for tuple input
if not isinstance(transition, tuple) or len(transition) != 7: if not isinstance(transition, tuple) or len(transition) != 7:
raise ValueError( raise ValueError(
f"EnvTransition must be a 7-tuple of (observation, action, reward, done, truncated, info, complementary_data), " "EnvTransition must be a 7-tuple of (observation, action, reward, done, "
f"got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}" "truncated, info, complementary_data). "
f"Got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}."
) )
for idx, processor_step in enumerate(self.steps): for idx, processor_step in enumerate(self.steps):
@@ -234,7 +329,7 @@ class RobotProcessor(ModelHubMixin):
if updated is not None: if updated is not None:
transition = updated transition = updated
return transition return self.to_batch(transition) if called_with_batch else transition
def step_through(self, transition: EnvTransition) -> Iterable[EnvTransition]: def step_through(self, transition: EnvTransition) -> Iterable[EnvTransition]:
"""Yield the intermediate Transition instances after each processor step.""" """Yield the intermediate Transition instances after each processor step."""
@@ -737,3 +832,22 @@ class InfoProcessor:
*transition[TransitionIndex.COMPLEMENTARY_DATA :], *transition[TransitionIndex.COMPLEMENTARY_DATA :],
) )
return transition return transition
class IdentityProcessor:
    """Pass-through step: hands back whatever transition it is given."""

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return *transition* itself, unmodified."""
        return transition

    def get_config(self) -> dict[str, Any]:
        """Return an empty configuration dict."""
        return dict()

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Return an empty state dict."""
        return dict()

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """Accept *state* and ignore it."""
        return None

    def reset(self) -> None:
        """Do nothing."""
        return None
+288
View File
@@ -0,0 +1,288 @@
import torch
from lerobot.processor.pipeline import (
RobotProcessor,
TransitionIndex,
_default_batch_to_transition,
_default_transition_to_batch,
)
def _dummy_batch():
    """Build a sample batch keyed with the flat ``observation.*`` / ``next.*`` names."""
    # Two random camera images first (left, then right), then the fixed fields.
    sample = {f"observation.image.{side}": torch.randn(1, 3, 128, 128) for side in ("left", "right")}
    sample["observation.state"] = torch.tensor([[0.1, 0.2, 0.3, 0.4]])
    sample["action"] = torch.tensor([[0.5]])
    sample["next.reward"] = 1.0
    sample["next.done"] = False
    sample["next.truncated"] = False
    sample["info"] = {"key": "value"}
    return sample
def test_observation_grouping_roundtrip():
    """A batch sent through a processor with no steps comes back with the same content."""

    def observation_names(batch):
        return {name for name in batch if name.startswith("observation.")}

    proc = RobotProcessor([])
    batch_in = _dummy_batch()
    batch_out = proc(batch_in)

    # The same set of observation.* keys must be present before and after.
    assert observation_names(batch_in) == observation_names(batch_out)

    # Tensor-valued entries: the three observations, then the action.
    for name in ("observation.image.left", "observation.image.right", "observation.state", "action"):
        assert torch.allclose(batch_out[name], batch_in[name])

    # Remaining entries are compared with plain equality.
    for name in ("next.reward", "next.done", "next.truncated", "info"):
        assert batch_out[name] == batch_in[name]
def test_batch_to_transition_observation_grouping():
    """``_default_batch_to_transition`` gathers the observation.* keys into one dict."""
    batch = {
        "observation.image.top": torch.randn(1, 3, 128, 128),
        "observation.image.left": torch.randn(1, 3, 128, 128),
        "observation.state": [1, 2, 3, 4],
        "action": "action_data",
        "next.reward": 1.5,
        "next.done": True,
        "next.truncated": False,
        "info": {"episode": 42},
    }

    transition = _default_batch_to_transition(batch)
    observation = transition[TransitionIndex.OBSERVATION]

    # The observation slot is a dict holding every observation.* key.
    assert isinstance(observation, dict)
    for name in ("observation.image.top", "observation.image.left", "observation.state"):
        assert name in observation

    # Values survive the conversion.
    for name in ("observation.image.top", "observation.image.left"):
        assert torch.allclose(observation[name], batch[name])
    assert observation["observation.state"] == [1, 2, 3, 4]

    # Remaining slots.
    assert transition[TransitionIndex.ACTION] == "action_data"
    assert transition[TransitionIndex.REWARD] == 1.5
    assert transition[TransitionIndex.DONE]
    assert not transition[TransitionIndex.TRUNCATED]
    assert transition[TransitionIndex.INFO] == {"episode": 42}
    assert transition[TransitionIndex.COMPLEMENTARY_DATA] == {}
def test_transition_to_batch_observation_flattening():
    """``_default_transition_to_batch`` writes the observation dict back as flat keys."""
    observation_dict = {
        "observation.image.top": torch.randn(1, 3, 128, 128),
        "observation.image.left": torch.randn(1, 3, 128, 128),
        "observation.state": [1, 2, 3, 4],
    }
    # Slot order: observation, action, reward, done, truncated, info, complementary_data.
    transition = (observation_dict, "action_data", 1.5, True, False, {"episode": 42}, {})

    batch = _default_transition_to_batch(transition)

    # Every observation key shows up at the top level of the batch.
    for name in ("observation.image.top", "observation.image.left", "observation.state"):
        assert name in batch

    # Values are preserved.
    for name in ("observation.image.top", "observation.image.left"):
        assert torch.allclose(batch[name], observation_dict[name])
    assert batch["observation.state"] == [1, 2, 3, 4]

    # Other slots land under the action / next.* / info names.
    assert batch["action"] == "action_data"
    assert batch["next.reward"] == 1.5
    assert batch["next.done"]
    assert not batch["next.truncated"]
    assert batch["info"] == {"episode": 42}
def test_no_observation_keys():
    """A batch without observation.* keys yields a ``None`` observation slot."""
    batch = {
        "action": "action_data",
        "next.reward": 2.0,
        "next.done": False,
        "next.truncated": True,
        "info": {"test": "no_obs"},
    }

    transition = _default_batch_to_transition(batch)

    # No observation.* key in the batch -> observation slot is None.
    assert transition[TransitionIndex.OBSERVATION] is None

    # Remaining slots carry the batch values.
    assert transition[TransitionIndex.ACTION] == "action_data"
    assert transition[TransitionIndex.REWARD] == 2.0
    assert not transition[TransitionIndex.DONE]
    assert transition[TransitionIndex.TRUNCATED]
    assert transition[TransitionIndex.INFO] == {"test": "no_obs"}

    # Converting back gives the same field values.
    rebuilt = _default_transition_to_batch(transition)
    assert rebuilt["action"] == "action_data"
    assert rebuilt["next.reward"] == 2.0
    assert not rebuilt["next.done"]
    assert rebuilt["next.truncated"]
    assert rebuilt["info"] == {"test": "no_obs"}
def test_minimal_batch():
    """A batch holding only one observation key and an action gets defaults elsewhere."""
    transition = _default_batch_to_transition(
        {"observation.state": "minimal_state", "action": "minimal_action"}
    )

    # Supplied fields.
    assert transition[TransitionIndex.OBSERVATION] == {"observation.state": "minimal_state"}
    assert transition[TransitionIndex.ACTION] == "minimal_action"

    # Defaulted fields.
    assert transition[TransitionIndex.REWARD] == 0.0
    assert not transition[TransitionIndex.DONE]
    assert not transition[TransitionIndex.TRUNCATED]
    assert transition[TransitionIndex.INFO] == {}
    assert transition[TransitionIndex.COMPLEMENTARY_DATA] == {}

    # Round trip back to a batch.
    rebuilt = _default_transition_to_batch(transition)
    assert rebuilt["observation.state"] == "minimal_state"
    assert rebuilt["action"] == "minimal_action"
    assert rebuilt["next.reward"] == 0.0
    assert not rebuilt["next.done"]
    assert not rebuilt["next.truncated"]
    assert rebuilt["info"] == {}
def test_empty_batch():
    """An empty batch converts to a transition made entirely of defaults."""
    transition = _default_batch_to_transition({})

    # Every slot falls back to its default.
    assert transition[TransitionIndex.OBSERVATION] is None
    assert transition[TransitionIndex.ACTION] is None
    assert transition[TransitionIndex.REWARD] == 0.0
    assert not transition[TransitionIndex.DONE]
    assert not transition[TransitionIndex.TRUNCATED]
    assert transition[TransitionIndex.INFO] == {}
    assert transition[TransitionIndex.COMPLEMENTARY_DATA] == {}

    # Round trip back to a batch.
    rebuilt = _default_transition_to_batch(transition)
    assert rebuilt["action"] is None
    assert rebuilt["next.reward"] == 0.0
    assert not rebuilt["next.done"]
    assert not rebuilt["next.truncated"]
    assert rebuilt["info"] == {}
def test_complex_nested_observation():
    """Observation values that are themselves dicts survive a batch round trip."""

    def camera_frame(timestamp):
        return {"image": torch.randn(1, 3, 128, 128), "timestamp": timestamp}

    def observation_names(mapping):
        return {name for name in mapping.keys() if name.startswith("observation.")}

    batch = {
        "observation.image.top": camera_frame(1234567890),
        "observation.image.left": camera_frame(1234567891),
        "observation.state": torch.randn(7),
        "action": torch.randn(8),
        "next.reward": 3.14,
        "next.done": False,
        "next.truncated": True,
        "info": {"episode_length": 200, "success": True},
    }

    reconstructed_batch = _default_transition_to_batch(_default_batch_to_transition(batch))

    # Same observation keys on both sides.
    assert observation_names(batch) == observation_names(reconstructed_batch)

    # Plain tensor observation.
    assert torch.allclose(batch["observation.state"], reconstructed_batch["observation.state"])

    # Tensors stored inside the nested per-camera dicts.
    for camera in ("observation.image.top", "observation.image.left"):
        assert torch.allclose(batch[camera]["image"], reconstructed_batch[camera]["image"])

    # Action tensor.
    assert torch.allclose(batch["action"], reconstructed_batch["action"])

    # Remaining fields.
    for name in ("next.reward", "next.done", "next.truncated", "info"):
        assert batch[name] == reconstructed_batch[name]
def test_custom_converter():
    """User-supplied ``to_transition`` / ``to_batch`` converters are honoured."""

    def to_tr(batch):
        # Default conversion, then swap in a doubled reward (0.0 if it was None).
        tr = _default_batch_to_transition(batch)
        original_reward = tr[TransitionIndex.REWARD]
        reward = original_reward * 2 if original_reward is not None else 0.0
        return (
            tr[TransitionIndex.OBSERVATION],
            tr[TransitionIndex.ACTION],
            reward,
            tr[TransitionIndex.DONE],
            tr[TransitionIndex.TRUNCATED],
            tr[TransitionIndex.INFO],
            tr[TransitionIndex.COMPLEMENTARY_DATA],
        )

    def to_batch(tr):
        # Default conversion plus one extra marker field.
        result = _default_transition_to_batch(tr)
        result["custom_field"] = "custom_value"
        return result

    def observation_names(mapping):
        return {name for name in mapping if name.startswith("observation.")}

    proc = RobotProcessor([], to_transition=to_tr, to_batch=to_batch)
    batch = _dummy_batch()
    out = proc(batch)

    # Both custom converters left their mark.
    assert out["next.reward"] == batch["next.reward"] * 2
    assert out["custom_field"] == "custom_value"

    # The observation.* keys are still all there.
    assert observation_names(batch) == observation_names(out)
+126 -19
View File
@@ -4,6 +4,7 @@ import numpy as np
import pytest import pytest
import torch import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor.normalize_processor import ( from lerobot.processor.normalize_processor import (
NormalizerProcessor, NormalizerProcessor,
UnnormalizerProcessor, UnnormalizerProcessor,
@@ -76,6 +77,21 @@ def test_unsupported_type():
_convert_stats_to_tensors(stats) _convert_stats_to_tensors(stats)
# Helper functions to create feature maps and norm maps
def _create_observation_features():
    """Return the feature map for an image observation and a state observation."""
    features = {}
    features["observation.image"] = PolicyFeature(FeatureType.VISUAL, (3, 96, 96))
    features["observation.state"] = PolicyFeature(FeatureType.STATE, (2,))
    return features
def _create_observation_norm_map():
    """Return the norm map: VISUAL uses MEAN_STD, STATE uses MIN_MAX."""
    norm_map = {}
    norm_map[FeatureType.VISUAL] = NormalizationMode.MEAN_STD
    norm_map[FeatureType.STATE] = NormalizationMode.MIN_MAX
    return norm_map
# Fixtures for observation normalisation tests using NormalizerProcessor # Fixtures for observation normalisation tests using NormalizerProcessor
@pytest.fixture @pytest.fixture
def observation_stats(): def observation_stats():
@@ -94,7 +110,9 @@ def observation_stats():
@pytest.fixture @pytest.fixture
def observation_normalizer(observation_stats): def observation_normalizer(observation_stats):
"""Return a NormalizerProcessor that only has observation stats (no action).""" """Return a NormalizerProcessor that only has observation stats (no action)."""
return NormalizerProcessor(stats=observation_stats) features = _create_observation_features()
norm_map = _create_observation_norm_map()
return NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats)
def test_mean_std_normalization(observation_normalizer): def test_mean_std_normalization(observation_normalizer):
@@ -129,7 +147,11 @@ def test_min_max_normalization(observation_normalizer):
def test_selective_normalization(observation_stats): def test_selective_normalization(observation_stats):
normalizer = NormalizerProcessor(stats=observation_stats, normalize_keys={"observation.image"}) features = _create_observation_features()
norm_map = _create_observation_norm_map()
normalizer = NormalizerProcessor(
features=features, norm_map=norm_map, stats=observation_stats, normalize_keys={"observation.image"}
)
observation = { observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]), "observation.image": torch.tensor([0.7, 0.5, 0.3]),
@@ -148,7 +170,9 @@ def test_selective_normalization(observation_stats):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_device_compatibility(observation_stats): def test_device_compatibility(observation_stats):
normalizer = NormalizerProcessor(stats=observation_stats) features = _create_observation_features()
norm_map = _create_observation_norm_map()
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats)
observation = { observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(), "observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(),
} }
@@ -165,10 +189,19 @@ def test_from_lerobot_dataset():
mock_dataset = Mock() mock_dataset = Mock()
mock_dataset.meta.stats = { mock_dataset.meta.stats = {
"observation.image": {"mean": [0.5], "std": [0.2]}, "observation.image": {"mean": [0.5], "std": [0.2]},
"action": {"mean": [0.0], "std": [1.0]}, # Should be filtered out "action": {"mean": [0.0], "std": [1.0]},
} }
normalizer = NormalizerProcessor.from_lerobot_dataset(mock_dataset) features = {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
"action": PolicyFeature(FeatureType.ACTION, (1,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
FeatureType.ACTION: NormalizationMode.MEAN_STD,
}
normalizer = NormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
# Both observation and action statistics should be present in tensor stats # Both observation and action statistics should be present in tensor stats
assert "observation.image" in normalizer._tensor_stats assert "observation.image" in normalizer._tensor_stats
@@ -180,7 +213,9 @@ def test_state_dict_save_load(observation_normalizer):
state_dict = observation_normalizer.state_dict() state_dict = observation_normalizer.state_dict()
# Create new normalizer and load state # Create new normalizer and load state
new_normalizer = NormalizerProcessor(stats={}) features = _create_observation_features()
norm_map = _create_observation_norm_map()
new_normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={})
new_normalizer.load_state_dict(state_dict) new_normalizer.load_state_dict(state_dict)
# Test that it works the same # Test that it works the same
@@ -210,8 +245,30 @@ def action_stats_min_max():
} }
def _create_action_features():
    """Return a feature map holding a single ACTION feature of shape ``(3,)``."""
    action_feature = PolicyFeature(FeatureType.ACTION, (3,))
    return {"action": action_feature}
def _create_action_norm_map_mean_std():
    """Return a norm map that assigns MEAN_STD to the ACTION feature type."""
    mode = NormalizationMode.MEAN_STD
    return {FeatureType.ACTION: mode}
def _create_action_norm_map_min_max():
    """Return a norm map that assigns MIN_MAX to the ACTION feature type."""
    mode = NormalizationMode.MIN_MAX
    return {FeatureType.ACTION: mode}
def test_mean_std_unnormalization(action_stats_mean_std): def test_mean_std_unnormalization(action_stats_mean_std):
unnormalizer = UnnormalizerProcessor(stats={"action": action_stats_mean_std}) features = _create_action_features()
norm_map = _create_action_norm_map_mean_std()
unnormalizer = UnnormalizerProcessor(
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
)
normalized_action = torch.tensor([1.0, -0.5, 2.0]) normalized_action = torch.tensor([1.0, -0.5, 2.0])
transition = (None, normalized_action, None, None, None, None, None) transition = (None, normalized_action, None, None, None, None, None)
@@ -225,7 +282,11 @@ def test_mean_std_unnormalization(action_stats_mean_std):
def test_min_max_unnormalization(action_stats_min_max): def test_min_max_unnormalization(action_stats_min_max):
unnormalizer = UnnormalizerProcessor(stats={"action": action_stats_min_max}) features = _create_action_features()
norm_map = _create_action_norm_map_min_max()
unnormalizer = UnnormalizerProcessor(
features=features, norm_map=norm_map, stats={"action": action_stats_min_max}
)
# Actions in [-1, 1] # Actions in [-1, 1]
normalized_action = torch.tensor([0.0, -1.0, 1.0]) normalized_action = torch.tensor([0.0, -1.0, 1.0])
@@ -247,7 +308,11 @@ def test_min_max_unnormalization(action_stats_min_max):
def test_numpy_action_input(action_stats_mean_std): def test_numpy_action_input(action_stats_mean_std):
unnormalizer = UnnormalizerProcessor(stats={"action": action_stats_mean_std}) features = _create_action_features()
norm_map = _create_action_norm_map_mean_std()
unnormalizer = UnnormalizerProcessor(
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
)
normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32) normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32)
transition = (None, normalized_action, None, None, None, None, None) transition = (None, normalized_action, None, None, None, None, None)
@@ -261,7 +326,11 @@ def test_numpy_action_input(action_stats_mean_std):
def test_none_action(action_stats_mean_std): def test_none_action(action_stats_mean_std):
unnormalizer = UnnormalizerProcessor(stats={"action": action_stats_mean_std}) features = _create_action_features()
norm_map = _create_action_norm_map_mean_std()
unnormalizer = UnnormalizerProcessor(
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
)
transition = (None, None, None, None, None, None, None) transition = (None, None, None, None, None, None, None)
result = unnormalizer(transition) result = unnormalizer(transition)
@@ -273,7 +342,9 @@ def test_none_action(action_stats_mean_std):
def test_action_from_lerobot_dataset(): def test_action_from_lerobot_dataset():
mock_dataset = Mock() mock_dataset = Mock()
mock_dataset.meta.stats = {"action": {"mean": [0.0], "std": [1.0]}} mock_dataset.meta.stats = {"action": {"mean": [0.0], "std": [1.0]}}
unnormalizer = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset) features = {"action": PolicyFeature(FeatureType.ACTION, (1,))}
norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
unnormalizer = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
assert "mean" in unnormalizer._tensor_stats["action"] assert "mean" in unnormalizer._tensor_stats["action"]
@@ -296,9 +367,27 @@ def full_stats():
} }
def _create_full_features():
return {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
"action": PolicyFeature(FeatureType.ACTION, (2,)),
}
def _create_full_norm_map():
return {
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
FeatureType.STATE: NormalizationMode.MIN_MAX,
FeatureType.ACTION: NormalizationMode.MEAN_STD,
}
@pytest.fixture @pytest.fixture
def normalizer_processor(full_stats): def normalizer_processor(full_stats):
return NormalizerProcessor(stats=full_stats) features = _create_full_features()
norm_map = _create_full_norm_map()
return NormalizerProcessor(features=features, norm_map=norm_map, stats=full_stats)
def test_combined_normalization(normalizer_processor): def test_combined_normalization(normalizer_processor):
@@ -331,7 +420,12 @@ def test_processor_from_lerobot_dataset(full_stats):
mock_dataset = Mock() mock_dataset = Mock()
mock_dataset.meta.stats = full_stats mock_dataset.meta.stats = full_stats
processor = NormalizerProcessor.from_lerobot_dataset(mock_dataset, normalize_keys={"observation.image"}) features = _create_full_features()
norm_map = _create_full_norm_map()
processor = NormalizerProcessor.from_lerobot_dataset(
mock_dataset, features, norm_map, normalize_keys={"observation.image"}
)
assert processor.normalize_keys == {"observation.image"} assert processor.normalize_keys == {"observation.image"}
assert "observation.image" in processor._tensor_stats assert "observation.image" in processor._tensor_stats
@@ -339,7 +433,11 @@ def test_processor_from_lerobot_dataset(full_stats):
def test_get_config(full_stats): def test_get_config(full_stats):
processor = NormalizerProcessor(stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6) features = _create_full_features()
norm_map = _create_full_norm_map()
processor = NormalizerProcessor(
features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6
)
config = processor.get_config() config = processor.get_config()
assert config == {"normalize_keys": ["observation.image"], "eps": 1e-6} assert config == {"normalize_keys": ["observation.image"], "eps": 1e-6}
@@ -366,7 +464,9 @@ def test_integration_with_robot_processor(normalizer_processor):
# Edge case tests # Edge case tests
def test_empty_observation(): def test_empty_observation():
stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
normalizer = NormalizerProcessor(stats=stats) features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
transition = (None, None, None, None, None, None, None) transition = (None, None, None, None, None, None, None)
result = normalizer(transition) result = normalizer(transition)
@@ -375,19 +475,23 @@ def test_empty_observation():
def test_empty_stats(): def test_empty_stats():
normalizer = NormalizerProcessor(stats={}) features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={})
observation = {"observation.image": torch.tensor([0.5])} observation = {"observation.image": torch.tensor([0.5])}
transition = (observation, None, None, None, None, None, None) transition = (observation, None, None, None, None, None, None)
result = normalizer(transition) result = normalizer(transition)
# Should return observation unchanged # Should return observation unchanged since no stats are available
assert torch.allclose(result[0]["observation.image"], observation["observation.image"]) assert torch.allclose(result[0]["observation.image"], observation["observation.image"])
def test_partial_stats(): def test_partial_stats():
"""If statistics are incomplete, the value should pass through unchanged.""" """If statistics are incomplete, the value should pass through unchanged."""
stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max) stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max)
normalizer = NormalizerProcessor(stats=stats) features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
observation = {"observation.image": torch.tensor([0.7])} observation = {"observation.image": torch.tensor([0.7])}
transition = (observation, None, None, None, None, None, None) transition = (observation, None, None, None, None, None, None)
@@ -399,6 +503,9 @@ def test_missing_action_stats_no_error():
mock_dataset = Mock() mock_dataset = Mock()
mock_dataset.meta.stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} mock_dataset.meta.stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
processor = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset) features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
processor = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
# The tensor stats should not contain the 'action' key # The tensor stats should not contain the 'action' key
assert "action" not in processor._tensor_stats assert "action" not in processor._tensor_stats