mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 13:09:43 +00:00
chore (batch handling): Enhance processing components with batch conversion utilities
This commit is contained in:
@@ -25,9 +25,11 @@ from .pipeline import (
|
|||||||
ActionProcessor,
|
ActionProcessor,
|
||||||
DoneProcessor,
|
DoneProcessor,
|
||||||
EnvTransition,
|
EnvTransition,
|
||||||
|
IdentityProcessor,
|
||||||
InfoProcessor,
|
InfoProcessor,
|
||||||
ObservationProcessor,
|
ObservationProcessor,
|
||||||
ProcessorStep,
|
ProcessorStep,
|
||||||
|
ProcessorStepRegistry,
|
||||||
RewardProcessor,
|
RewardProcessor,
|
||||||
RobotProcessor,
|
RobotProcessor,
|
||||||
TruncatedProcessor,
|
TruncatedProcessor,
|
||||||
@@ -39,12 +41,14 @@ __all__ = [
|
|||||||
"DeviceProcessor",
|
"DeviceProcessor",
|
||||||
"DoneProcessor",
|
"DoneProcessor",
|
||||||
"EnvTransition",
|
"EnvTransition",
|
||||||
|
"IdentityProcessor",
|
||||||
"ImageProcessor",
|
"ImageProcessor",
|
||||||
"InfoProcessor",
|
"InfoProcessor",
|
||||||
"NormalizerProcessor",
|
"NormalizerProcessor",
|
||||||
"UnnormalizerProcessor",
|
"UnnormalizerProcessor",
|
||||||
"ObservationProcessor",
|
"ObservationProcessor",
|
||||||
"ProcessorStep",
|
"ProcessorStep",
|
||||||
|
"ProcessorStepRegistry",
|
||||||
"RenameProcessor",
|
"RenameProcessor",
|
||||||
"RewardProcessor",
|
"RewardProcessor",
|
||||||
"RobotProcessor",
|
"RobotProcessor",
|
||||||
|
|||||||
@@ -33,6 +33,9 @@ class DeviceProcessor:
|
|||||||
|
|
||||||
device: str = "cpu"
|
device: str = "cpu"
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self.non_blocking = "cuda" in self.device
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
observation: dict[str, torch.Tensor] = transition[TransitionIndex.OBSERVATION]
|
observation: dict[str, torch.Tensor] = transition[TransitionIndex.OBSERVATION]
|
||||||
action = transition[TransitionIndex.ACTION]
|
action = transition[TransitionIndex.ACTION]
|
||||||
@@ -43,7 +46,9 @@ class DeviceProcessor:
|
|||||||
complementary_data = transition[TransitionIndex.COMPLEMENTARY_DATA]
|
complementary_data = transition[TransitionIndex.COMPLEMENTARY_DATA]
|
||||||
|
|
||||||
if observation is not None:
|
if observation is not None:
|
||||||
observation = {k: v.to(self.device) for k, v in observation.items()}
|
observation = {
|
||||||
|
k: v.to(self.device, non_blocking=self.non_blocking) for k, v in observation.items()
|
||||||
|
}
|
||||||
if action is not None:
|
if action is not None:
|
||||||
action = action.to(self.device)
|
action = action.to(self.device)
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Mapping
|
from typing import Any, Mapping, Optional, Set
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex
|
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex
|
||||||
|
|
||||||
@@ -45,8 +46,18 @@ class NormalizerProcessor:
|
|||||||
the normalize_keys parameter.
|
the normalize_keys parameter.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
stats: dict[str, dict[str, Any]]
|
# Features and normalisation map are mandatory to match the design of normalize.py
|
||||||
normalize_keys: set[str] | None = None
|
features: dict[str, PolicyFeature]
|
||||||
|
norm_map: dict[FeatureType, NormalizationMode]
|
||||||
|
|
||||||
|
# Pre-computed statistics coming from dataset.meta.stats for instance.
|
||||||
|
stats: Optional[dict[str, dict[str, Any]]] = None
|
||||||
|
|
||||||
|
# Explicit subset of keys to normalise. If ``None`` every key (except
|
||||||
|
# "action") found in ``stats`` will be normalised. Using a ``set`` makes
|
||||||
|
# membership checks O(1).
|
||||||
|
normalize_keys: Optional[Set[str]] = None
|
||||||
|
|
||||||
eps: float = 1e-8
|
eps: float = 1e-8
|
||||||
|
|
||||||
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||||
@@ -55,24 +66,48 @@ class NormalizerProcessor:
|
|||||||
def from_lerobot_dataset(
|
def from_lerobot_dataset(
|
||||||
cls,
|
cls,
|
||||||
dataset: LeRobotDataset,
|
dataset: LeRobotDataset,
|
||||||
|
features: dict[str, PolicyFeature],
|
||||||
|
norm_map: dict[FeatureType, NormalizationMode],
|
||||||
*,
|
*,
|
||||||
normalize_keys: set[str] | None = None,
|
normalize_keys: Optional[Set[str]] = None,
|
||||||
eps: float = 1e-8,
|
eps: float = 1e-8,
|
||||||
) -> NormalizerProcessor:
|
) -> "NormalizerProcessor":
|
||||||
return cls(stats=dataset.meta.stats, normalize_keys=normalize_keys, eps=eps)
|
"""Factory helper that pulls statistics from a :class:`LeRobotDataset`.
|
||||||
|
|
||||||
|
The features and norm_map parameters are mandatory to match the design
|
||||||
|
pattern used in normalize.py.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
features=features,
|
||||||
|
norm_map=norm_map,
|
||||||
|
stats=dataset.meta.stats,
|
||||||
|
normalize_keys=normalize_keys,
|
||||||
|
eps=eps,
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
# Convert statistics once so we avoid repeated numpy→Tensor conversions
|
||||||
|
# during runtime.
|
||||||
|
self.stats = self.stats or {}
|
||||||
self._tensor_stats = _convert_stats_to_tensors(self.stats)
|
self._tensor_stats = _convert_stats_to_tensors(self.stats)
|
||||||
|
|
||||||
|
# Ensure *normalize_keys* is a set for fast look-ups and compare by
|
||||||
|
# value later when returning the configuration.
|
||||||
|
if self.normalize_keys is not None and not isinstance(self.normalize_keys, set):
|
||||||
|
self.normalize_keys = set(self.normalize_keys)
|
||||||
|
|
||||||
def _normalize_obs(self, observation):
|
def _normalize_obs(self, observation):
|
||||||
if observation is None:
|
if observation is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
keys_to_norm = (
|
# Decide which keys should be normalised for this call.
|
||||||
self.normalize_keys
|
if self.normalize_keys is not None:
|
||||||
if self.normalize_keys is not None
|
keys_to_norm = self.normalize_keys
|
||||||
else {k for k in self._tensor_stats if k != "action"}
|
else:
|
||||||
)
|
# Use feature map to skip action keys.
|
||||||
|
keys_to_norm = {k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION}
|
||||||
|
|
||||||
processed = dict(observation)
|
processed = dict(observation)
|
||||||
for key in keys_to_norm:
|
for key in keys_to_norm:
|
||||||
if key not in processed or key not in self._tensor_stats:
|
if key not in processed or key not in self._tensor_stats:
|
||||||
@@ -126,7 +161,11 @@ class NormalizerProcessor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
def get_config(self) -> dict[str, Any]:
|
||||||
return {"normalize_keys": list(self.normalize_keys) if self.normalize_keys else None, "eps": self.eps}
|
config = {"eps": self.eps}
|
||||||
|
if self.normalize_keys is not None:
|
||||||
|
# Serialise as a list for YAML / JSON friendliness
|
||||||
|
config["normalize_keys"] = sorted(self.normalize_keys)
|
||||||
|
return config
|
||||||
|
|
||||||
def state_dict(self) -> dict[str, Tensor]:
|
def state_dict(self) -> dict[str, Tensor]:
|
||||||
flat = {}
|
flat = {}
|
||||||
@@ -154,8 +193,9 @@ class UnnormalizerProcessor:
|
|||||||
transform.
|
transform.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
stats: dict[str, dict[str, Any]]
|
features: dict[str, PolicyFeature]
|
||||||
unnormalize_keys: set[str] | None = None
|
norm_map: dict[FeatureType, NormalizationMode]
|
||||||
|
stats: Optional[dict[str, dict[str, Any]]] = None
|
||||||
eps: float = 1e-8
|
eps: float = 1e-8
|
||||||
|
|
||||||
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||||
@@ -164,23 +204,21 @@ class UnnormalizerProcessor:
|
|||||||
def from_lerobot_dataset(
|
def from_lerobot_dataset(
|
||||||
cls,
|
cls,
|
||||||
dataset: LeRobotDataset,
|
dataset: LeRobotDataset,
|
||||||
|
features: dict[str, PolicyFeature],
|
||||||
|
norm_map: dict[FeatureType, NormalizationMode],
|
||||||
*,
|
*,
|
||||||
unnormalize_keys: set[str] | None = None,
|
|
||||||
eps: float = 1e-8,
|
eps: float = 1e-8,
|
||||||
) -> UnnormalizerProcessor:
|
) -> "UnnormalizerProcessor":
|
||||||
return cls(stats=dataset.meta.stats, unnormalize_keys=unnormalize_keys, eps=eps)
|
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, eps=eps)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
self.stats = self.stats or {}
|
||||||
self._tensor_stats = _convert_stats_to_tensors(self.stats)
|
self._tensor_stats = _convert_stats_to_tensors(self.stats)
|
||||||
|
|
||||||
def _unnormalize_obs(self, observation):
|
def _unnormalize_obs(self, observation):
|
||||||
if observation is None:
|
if observation is None:
|
||||||
return None
|
return None
|
||||||
keys = (
|
keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION]
|
||||||
self.unnormalize_keys
|
|
||||||
if self.unnormalize_keys is not None
|
|
||||||
else {k for k in self._tensor_stats if k != "action"}
|
|
||||||
)
|
|
||||||
processed = dict(observation)
|
processed = dict(observation)
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if key not in processed or key not in self._tensor_stats:
|
if key not in processed or key not in self._tensor_stats:
|
||||||
@@ -231,10 +269,7 @@ class UnnormalizerProcessor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
def get_config(self) -> dict[str, Any]:
|
||||||
return {
|
return {"eps": self.eps}
|
||||||
"unnormalize_keys": list(self.unnormalize_keys) if self.unnormalize_keys else None,
|
|
||||||
"eps": self.eps,
|
|
||||||
}
|
|
||||||
|
|
||||||
def state_dict(self) -> dict[str, Tensor]:
|
def state_dict(self) -> dict[str, Tensor]:
|
||||||
flat = {}
|
flat = {}
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ class TransitionIndex(IntEnum):
|
|||||||
|
|
||||||
# (observation, action, reward, done, truncated, info, complementary_data)
|
# (observation, action, reward, done, truncated, info, complementary_data)
|
||||||
EnvTransition = Tuple[
|
EnvTransition = Tuple[
|
||||||
Any | None, # observation
|
dict[str, Any] | None, # observation
|
||||||
Any | None, # action
|
Any | None, # action
|
||||||
float | None, # reward
|
float | None, # reward
|
||||||
bool | None, # done
|
bool | None, # done
|
||||||
@@ -162,6 +162,79 @@ class ProcessorStep(Protocol):
|
|||||||
def reset(self) -> None: ...
|
def reset(self) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401
|
||||||
|
"""Convert a *batch* dict coming from Learobot replay/dataset code into an
|
||||||
|
``EnvTransition`` tuple.
|
||||||
|
|
||||||
|
The function is intentionally **strictly positional** – it maps well known
|
||||||
|
keys to the fixed slot order used inside the pipeline. Missing keys are
|
||||||
|
filled with sane defaults (``None`` or ``0.0``/``False``).
|
||||||
|
|
||||||
|
Keys recognised (case-sensitive):
|
||||||
|
|
||||||
|
* "observation.*" (keys starting with "observation." are grouped into observation dict)
|
||||||
|
* "action"
|
||||||
|
* "next.reward"
|
||||||
|
* "next.done"
|
||||||
|
* "next.truncated"
|
||||||
|
* "info"
|
||||||
|
|
||||||
|
Additional keys are ignored so that existing dataloaders can carry extra
|
||||||
|
metadata without breaking the processor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Handle observation and observation.* keys
|
||||||
|
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||||
|
|
||||||
|
observation = None
|
||||||
|
if observation_keys:
|
||||||
|
observation = {}
|
||||||
|
# Add observation.* keys to the observation dict, removing the "observation." prefix
|
||||||
|
for key, value in observation_keys.items():
|
||||||
|
observation[key] = value
|
||||||
|
|
||||||
|
return (
|
||||||
|
observation,
|
||||||
|
batch.get("action"),
|
||||||
|
batch.get("next.reward", 0.0),
|
||||||
|
batch.get("next.done", False),
|
||||||
|
batch.get("next.truncated", False),
|
||||||
|
batch.get("info", {}),
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]: # noqa: D401
|
||||||
|
"""Inverse of :pyfunc:`_default_batch_to_transition`. Returns a dict with
|
||||||
|
the canonical field names used throughout *LeRobot*.
|
||||||
|
"""
|
||||||
|
|
||||||
|
(
|
||||||
|
observation,
|
||||||
|
action,
|
||||||
|
reward,
|
||||||
|
done,
|
||||||
|
truncated,
|
||||||
|
info,
|
||||||
|
_,
|
||||||
|
) = transition
|
||||||
|
|
||||||
|
batch = {
|
||||||
|
"action": action,
|
||||||
|
"next.reward": reward,
|
||||||
|
"next.done": done,
|
||||||
|
"next.truncated": truncated,
|
||||||
|
"info": info,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Handle observation - flatten dict to observation.* keys if it's a dict
|
||||||
|
if isinstance(observation, dict):
|
||||||
|
# Check if this looks like a dict that was created from observation.* keys
|
||||||
|
for key, value in observation.items():
|
||||||
|
batch[key] = value
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RobotProcessor(ModelHubMixin):
|
class RobotProcessor(ModelHubMixin):
|
||||||
"""
|
"""
|
||||||
@@ -200,6 +273,13 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
name: str = "RobotProcessor"
|
name: str = "RobotProcessor"
|
||||||
seed: int | None = None
|
seed: int | None = None
|
||||||
|
|
||||||
|
to_transition: Callable[[dict[str, Any]], EnvTransition] = field(
|
||||||
|
default_factory=lambda: _default_batch_to_transition, repr=False
|
||||||
|
)
|
||||||
|
to_batch: Callable[[EnvTransition], dict[str, Any]] = field(
|
||||||
|
default_factory=lambda: _default_transition_to_batch, repr=False
|
||||||
|
)
|
||||||
|
|
||||||
# Processor-level hooks
|
# Processor-level hooks
|
||||||
# A hook can optionally return a modified transition. If it returns
|
# A hook can optionally return a modified transition. If it returns
|
||||||
# ``None`` the current value is left untouched.
|
# ``None`` the current value is left untouched.
|
||||||
@@ -211,14 +291,29 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
)
|
)
|
||||||
reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False)
|
reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False)
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, data: EnvTransition | dict[str, Any]):
|
||||||
"""Run *transition* through every step, firing hooks on the way."""
|
"""Process *data* through all steps.
|
||||||
|
|
||||||
# Basic validation with helpful error message
|
The method accepts **either** the classic :pydata:`EnvTransition` tuple
|
||||||
|
**or** a *batch* dictionary (like the ones returned by
|
||||||
|
:class:`lerobot.utils.buffer.ReplayBuffer` or
|
||||||
|
:class:`lerobot.datasets.lerobot_dataset.LeRobotDataset`). If a dict is
|
||||||
|
supplied it is first converted to the internal tuple format using
|
||||||
|
:pyattr:`to_transition`; after all steps are executed the tuple is
|
||||||
|
transformed back into a dict with :pyattr:`to_batch` and the result is
|
||||||
|
returned – thereby preserving the caller's original data type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
called_with_batch = isinstance(data, dict)
|
||||||
|
|
||||||
|
transition = self.to_transition(data) if called_with_batch else data
|
||||||
|
|
||||||
|
# Basic validation with helpful error message for tuple input
|
||||||
if not isinstance(transition, tuple) or len(transition) != 7:
|
if not isinstance(transition, tuple) or len(transition) != 7:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"EnvTransition must be a 7-tuple of (observation, action, reward, done, truncated, info, complementary_data), "
|
"EnvTransition must be a 7-tuple of (observation, action, reward, done, "
|
||||||
f"got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}"
|
"truncated, info, complementary_data). "
|
||||||
|
f"Got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}."
|
||||||
)
|
)
|
||||||
|
|
||||||
for idx, processor_step in enumerate(self.steps):
|
for idx, processor_step in enumerate(self.steps):
|
||||||
@@ -234,7 +329,7 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
if updated is not None:
|
if updated is not None:
|
||||||
transition = updated
|
transition = updated
|
||||||
|
|
||||||
return transition
|
return self.to_batch(transition) if called_with_batch else transition
|
||||||
|
|
||||||
def step_through(self, transition: EnvTransition) -> Iterable[EnvTransition]:
|
def step_through(self, transition: EnvTransition) -> Iterable[EnvTransition]:
|
||||||
"""Yield the intermediate Transition instances after each processor step."""
|
"""Yield the intermediate Transition instances after each processor step."""
|
||||||
@@ -737,3 +832,22 @@ class InfoProcessor:
|
|||||||
*transition[TransitionIndex.COMPLEMENTARY_DATA :],
|
*transition[TransitionIndex.COMPLEMENTARY_DATA :],
|
||||||
)
|
)
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
|
|
||||||
|
class IdentityProcessor:
|
||||||
|
"""Identity processor that does nothing."""
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
return transition
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
pass
|
||||||
|
|||||||
@@ -0,0 +1,288 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.processor.pipeline import (
|
||||||
|
RobotProcessor,
|
||||||
|
TransitionIndex,
|
||||||
|
_default_batch_to_transition,
|
||||||
|
_default_transition_to_batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _dummy_batch():
|
||||||
|
"""Create a dummy batch using the new format with observation.* and next.* keys."""
|
||||||
|
return {
|
||||||
|
"observation.image.left": torch.randn(1, 3, 128, 128),
|
||||||
|
"observation.image.right": torch.randn(1, 3, 128, 128),
|
||||||
|
"observation.state": torch.tensor([[0.1, 0.2, 0.3, 0.4]]),
|
||||||
|
"action": torch.tensor([[0.5]]),
|
||||||
|
"next.reward": 1.0,
|
||||||
|
"next.done": False,
|
||||||
|
"next.truncated": False,
|
||||||
|
"info": {"key": "value"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_observation_grouping_roundtrip():
|
||||||
|
"""Test that observation.* keys are properly grouped and ungrouped."""
|
||||||
|
proc = RobotProcessor([])
|
||||||
|
batch_in = _dummy_batch()
|
||||||
|
batch_out = proc(batch_in)
|
||||||
|
|
||||||
|
# Check that all observation.* keys are preserved
|
||||||
|
original_obs_keys = {k: v for k, v in batch_in.items() if k.startswith("observation.")}
|
||||||
|
reconstructed_obs_keys = {k: v for k, v in batch_out.items() if k.startswith("observation.")}
|
||||||
|
|
||||||
|
assert set(original_obs_keys.keys()) == set(reconstructed_obs_keys.keys())
|
||||||
|
|
||||||
|
# Check tensor values
|
||||||
|
assert torch.allclose(batch_out["observation.image.left"], batch_in["observation.image.left"])
|
||||||
|
assert torch.allclose(batch_out["observation.image.right"], batch_in["observation.image.right"])
|
||||||
|
assert torch.allclose(batch_out["observation.state"], batch_in["observation.state"])
|
||||||
|
|
||||||
|
# Check other fields
|
||||||
|
assert torch.allclose(batch_out["action"], batch_in["action"])
|
||||||
|
assert batch_out["next.reward"] == batch_in["next.reward"]
|
||||||
|
assert batch_out["next.done"] == batch_in["next.done"]
|
||||||
|
assert batch_out["next.truncated"] == batch_in["next.truncated"]
|
||||||
|
assert batch_out["info"] == batch_in["info"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch_to_transition_observation_grouping():
|
||||||
|
"""Test that _default_batch_to_transition correctly groups observation.* keys."""
|
||||||
|
batch = {
|
||||||
|
"observation.image.top": torch.randn(1, 3, 128, 128),
|
||||||
|
"observation.image.left": torch.randn(1, 3, 128, 128),
|
||||||
|
"observation.state": [1, 2, 3, 4],
|
||||||
|
"action": "action_data",
|
||||||
|
"next.reward": 1.5,
|
||||||
|
"next.done": True,
|
||||||
|
"next.truncated": False,
|
||||||
|
"info": {"episode": 42},
|
||||||
|
}
|
||||||
|
|
||||||
|
transition = _default_batch_to_transition(batch)
|
||||||
|
|
||||||
|
# Check observation is a dict with all observation.* keys
|
||||||
|
assert isinstance(transition[TransitionIndex.OBSERVATION], dict)
|
||||||
|
assert "observation.image.top" in transition[TransitionIndex.OBSERVATION]
|
||||||
|
assert "observation.image.left" in transition[TransitionIndex.OBSERVATION]
|
||||||
|
assert "observation.state" in transition[TransitionIndex.OBSERVATION]
|
||||||
|
|
||||||
|
# Check values are preserved
|
||||||
|
assert torch.allclose(
|
||||||
|
transition[TransitionIndex.OBSERVATION]["observation.image.top"], batch["observation.image.top"]
|
||||||
|
)
|
||||||
|
assert torch.allclose(
|
||||||
|
transition[TransitionIndex.OBSERVATION]["observation.image.left"], batch["observation.image.left"]
|
||||||
|
)
|
||||||
|
assert transition[TransitionIndex.OBSERVATION]["observation.state"] == [1, 2, 3, 4]
|
||||||
|
|
||||||
|
# Check other fields
|
||||||
|
assert transition[TransitionIndex.ACTION] == "action_data"
|
||||||
|
assert transition[TransitionIndex.REWARD] == 1.5
|
||||||
|
assert transition[TransitionIndex.DONE]
|
||||||
|
assert not transition[TransitionIndex.TRUNCATED]
|
||||||
|
assert transition[TransitionIndex.INFO] == {"episode": 42}
|
||||||
|
assert transition[TransitionIndex.COMPLEMENTARY_DATA] == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_transition_to_batch_observation_flattening():
|
||||||
|
"""Test that _default_transition_to_batch correctly flattens observation dict."""
|
||||||
|
observation_dict = {
|
||||||
|
"observation.image.top": torch.randn(1, 3, 128, 128),
|
||||||
|
"observation.image.left": torch.randn(1, 3, 128, 128),
|
||||||
|
"observation.state": [1, 2, 3, 4],
|
||||||
|
}
|
||||||
|
|
||||||
|
transition = (
|
||||||
|
observation_dict, # observation
|
||||||
|
"action_data", # action
|
||||||
|
1.5, # reward
|
||||||
|
True, # done
|
||||||
|
False, # truncated
|
||||||
|
{"episode": 42}, # info
|
||||||
|
{}, # complementary_data
|
||||||
|
)
|
||||||
|
|
||||||
|
batch = _default_transition_to_batch(transition)
|
||||||
|
|
||||||
|
# Check that observation.* keys are flattened back to batch
|
||||||
|
assert "observation.image.top" in batch
|
||||||
|
assert "observation.image.left" in batch
|
||||||
|
assert "observation.state" in batch
|
||||||
|
|
||||||
|
# Check values are preserved
|
||||||
|
assert torch.allclose(batch["observation.image.top"], observation_dict["observation.image.top"])
|
||||||
|
assert torch.allclose(batch["observation.image.left"], observation_dict["observation.image.left"])
|
||||||
|
assert batch["observation.state"] == [1, 2, 3, 4]
|
||||||
|
|
||||||
|
# Check other fields are mapped to next.* format
|
||||||
|
assert batch["action"] == "action_data"
|
||||||
|
assert batch["next.reward"] == 1.5
|
||||||
|
assert batch["next.done"]
|
||||||
|
assert not batch["next.truncated"]
|
||||||
|
assert batch["info"] == {"episode": 42}
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_observation_keys():
|
||||||
|
"""Test behavior when there are no observation.* keys."""
|
||||||
|
batch = {
|
||||||
|
"action": "action_data",
|
||||||
|
"next.reward": 2.0,
|
||||||
|
"next.done": False,
|
||||||
|
"next.truncated": True,
|
||||||
|
"info": {"test": "no_obs"},
|
||||||
|
}
|
||||||
|
|
||||||
|
transition = _default_batch_to_transition(batch)
|
||||||
|
|
||||||
|
# Observation should be None when no observation.* keys
|
||||||
|
assert transition[TransitionIndex.OBSERVATION] is None
|
||||||
|
|
||||||
|
# Check other fields
|
||||||
|
assert transition[TransitionIndex.ACTION] == "action_data"
|
||||||
|
assert transition[TransitionIndex.REWARD] == 2.0
|
||||||
|
assert not transition[TransitionIndex.DONE]
|
||||||
|
assert transition[TransitionIndex.TRUNCATED]
|
||||||
|
assert transition[TransitionIndex.INFO] == {"test": "no_obs"}
|
||||||
|
|
||||||
|
# Round trip should work
|
||||||
|
reconstructed_batch = _default_transition_to_batch(transition)
|
||||||
|
assert reconstructed_batch["action"] == "action_data"
|
||||||
|
assert reconstructed_batch["next.reward"] == 2.0
|
||||||
|
assert not reconstructed_batch["next.done"]
|
||||||
|
assert reconstructed_batch["next.truncated"]
|
||||||
|
assert reconstructed_batch["info"] == {"test": "no_obs"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_minimal_batch():
|
||||||
|
"""Test with minimal batch containing only observation.* and action."""
|
||||||
|
batch = {"observation.state": "minimal_state", "action": "minimal_action"}
|
||||||
|
|
||||||
|
transition = _default_batch_to_transition(batch)
|
||||||
|
|
||||||
|
# Check observation
|
||||||
|
assert transition[TransitionIndex.OBSERVATION] == {"observation.state": "minimal_state"}
|
||||||
|
assert transition[TransitionIndex.ACTION] == "minimal_action"
|
||||||
|
|
||||||
|
# Check defaults
|
||||||
|
assert transition[TransitionIndex.REWARD] == 0.0
|
||||||
|
assert not transition[TransitionIndex.DONE]
|
||||||
|
assert not transition[TransitionIndex.TRUNCATED]
|
||||||
|
assert transition[TransitionIndex.INFO] == {}
|
||||||
|
assert transition[TransitionIndex.COMPLEMENTARY_DATA] == {}
|
||||||
|
|
||||||
|
# Round trip
|
||||||
|
reconstructed_batch = _default_transition_to_batch(transition)
|
||||||
|
assert reconstructed_batch["observation.state"] == "minimal_state"
|
||||||
|
assert reconstructed_batch["action"] == "minimal_action"
|
||||||
|
assert reconstructed_batch["next.reward"] == 0.0
|
||||||
|
assert not reconstructed_batch["next.done"]
|
||||||
|
assert not reconstructed_batch["next.truncated"]
|
||||||
|
assert reconstructed_batch["info"] == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_batch():
|
||||||
|
"""Test behavior with empty batch."""
|
||||||
|
batch = {}
|
||||||
|
|
||||||
|
transition = _default_batch_to_transition(batch)
|
||||||
|
|
||||||
|
# All fields should have defaults
|
||||||
|
assert transition[TransitionIndex.OBSERVATION] is None
|
||||||
|
assert transition[TransitionIndex.ACTION] is None
|
||||||
|
assert transition[TransitionIndex.REWARD] == 0.0
|
||||||
|
assert not transition[TransitionIndex.DONE]
|
||||||
|
assert not transition[TransitionIndex.TRUNCATED]
|
||||||
|
assert transition[TransitionIndex.INFO] == {}
|
||||||
|
assert transition[TransitionIndex.COMPLEMENTARY_DATA] == {}
|
||||||
|
|
||||||
|
# Round trip
|
||||||
|
reconstructed_batch = _default_transition_to_batch(transition)
|
||||||
|
assert reconstructed_batch["action"] is None
|
||||||
|
assert reconstructed_batch["next.reward"] == 0.0
|
||||||
|
assert not reconstructed_batch["next.done"]
|
||||||
|
assert not reconstructed_batch["next.truncated"]
|
||||||
|
assert reconstructed_batch["info"] == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_complex_nested_observation():
|
||||||
|
"""Test with complex nested observation data."""
|
||||||
|
batch = {
|
||||||
|
"observation.image.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890},
|
||||||
|
"observation.image.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891},
|
||||||
|
"observation.state": torch.randn(7),
|
||||||
|
"action": torch.randn(8),
|
||||||
|
"next.reward": 3.14,
|
||||||
|
"next.done": False,
|
||||||
|
"next.truncated": True,
|
||||||
|
"info": {"episode_length": 200, "success": True},
|
||||||
|
}
|
||||||
|
|
||||||
|
transition = _default_batch_to_transition(batch)
|
||||||
|
reconstructed_batch = _default_transition_to_batch(transition)
|
||||||
|
|
||||||
|
# Check that all observation keys are preserved
|
||||||
|
original_obs_keys = {k for k in batch.keys() if k.startswith("observation.")}
|
||||||
|
reconstructed_obs_keys = {k for k in reconstructed_batch.keys() if k.startswith("observation.")}
|
||||||
|
|
||||||
|
assert original_obs_keys == reconstructed_obs_keys
|
||||||
|
|
||||||
|
# Check tensor values
|
||||||
|
assert torch.allclose(batch["observation.state"], reconstructed_batch["observation.state"])
|
||||||
|
|
||||||
|
# Check nested dict with tensors
|
||||||
|
assert torch.allclose(
|
||||||
|
batch["observation.image.top"]["image"], reconstructed_batch["observation.image.top"]["image"]
|
||||||
|
)
|
||||||
|
assert torch.allclose(
|
||||||
|
batch["observation.image.left"]["image"], reconstructed_batch["observation.image.left"]["image"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check action tensor
|
||||||
|
assert torch.allclose(batch["action"], reconstructed_batch["action"])
|
||||||
|
|
||||||
|
# Check other fields
|
||||||
|
assert batch["next.reward"] == reconstructed_batch["next.reward"]
|
||||||
|
assert batch["next.done"] == reconstructed_batch["next.done"]
|
||||||
|
assert batch["next.truncated"] == reconstructed_batch["next.truncated"]
|
||||||
|
assert batch["info"] == reconstructed_batch["info"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_converter():
|
||||||
|
"""Test that custom converters can still be used."""
|
||||||
|
|
||||||
|
def to_tr(batch):
|
||||||
|
# Custom converter that modifies the reward
|
||||||
|
tr = _default_batch_to_transition(batch)
|
||||||
|
# Double the reward
|
||||||
|
reward = tr[TransitionIndex.REWARD] * 2 if tr[TransitionIndex.REWARD] is not None else 0.0
|
||||||
|
return (
|
||||||
|
tr[TransitionIndex.OBSERVATION],
|
||||||
|
tr[TransitionIndex.ACTION],
|
||||||
|
reward,
|
||||||
|
tr[TransitionIndex.DONE],
|
||||||
|
tr[TransitionIndex.TRUNCATED],
|
||||||
|
tr[TransitionIndex.INFO],
|
||||||
|
tr[TransitionIndex.COMPLEMENTARY_DATA],
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_batch(tr):
|
||||||
|
# Custom converter that adds a custom field
|
||||||
|
batch = _default_transition_to_batch(tr)
|
||||||
|
batch["custom_field"] = "custom_value"
|
||||||
|
return batch
|
||||||
|
|
||||||
|
proc = RobotProcessor([], to_transition=to_tr, to_batch=to_batch)
|
||||||
|
batch = _dummy_batch()
|
||||||
|
out = proc(batch)
|
||||||
|
|
||||||
|
# Check that custom modifications were applied
|
||||||
|
assert out["next.reward"] == batch["next.reward"] * 2
|
||||||
|
assert out["custom_field"] == "custom_value"
|
||||||
|
|
||||||
|
# Check that observation.* keys are still preserved
|
||||||
|
original_obs_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||||
|
output_obs_keys = {k: v for k, v in out.items() if k.startswith("observation.")}
|
||||||
|
|
||||||
|
assert set(original_obs_keys.keys()) == set(output_obs_keys.keys())
|
||||||
@@ -4,6 +4,7 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
from lerobot.processor.normalize_processor import (
|
from lerobot.processor.normalize_processor import (
|
||||||
NormalizerProcessor,
|
NormalizerProcessor,
|
||||||
UnnormalizerProcessor,
|
UnnormalizerProcessor,
|
||||||
@@ -76,6 +77,21 @@ def test_unsupported_type():
|
|||||||
_convert_stats_to_tensors(stats)
|
_convert_stats_to_tensors(stats)
|
||||||
|
|
||||||
|
|
||||||
|
# Helper functions to create feature maps and norm maps
|
||||||
|
def _create_observation_features():
|
||||||
|
return {
|
||||||
|
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
|
||||||
|
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _create_observation_norm_map():
|
||||||
|
return {
|
||||||
|
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||||
|
FeatureType.STATE: NormalizationMode.MIN_MAX,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# Fixtures for observation normalisation tests using NormalizerProcessor
|
# Fixtures for observation normalisation tests using NormalizerProcessor
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def observation_stats():
|
def observation_stats():
|
||||||
@@ -94,7 +110,9 @@ def observation_stats():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def observation_normalizer(observation_stats):
|
def observation_normalizer(observation_stats):
|
||||||
"""Return a NormalizerProcessor that only has observation stats (no action)."""
|
"""Return a NormalizerProcessor that only has observation stats (no action)."""
|
||||||
return NormalizerProcessor(stats=observation_stats)
|
features = _create_observation_features()
|
||||||
|
norm_map = _create_observation_norm_map()
|
||||||
|
return NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats)
|
||||||
|
|
||||||
|
|
||||||
def test_mean_std_normalization(observation_normalizer):
|
def test_mean_std_normalization(observation_normalizer):
|
||||||
@@ -129,7 +147,11 @@ def test_min_max_normalization(observation_normalizer):
|
|||||||
|
|
||||||
|
|
||||||
def test_selective_normalization(observation_stats):
|
def test_selective_normalization(observation_stats):
|
||||||
normalizer = NormalizerProcessor(stats=observation_stats, normalize_keys={"observation.image"})
|
features = _create_observation_features()
|
||||||
|
norm_map = _create_observation_norm_map()
|
||||||
|
normalizer = NormalizerProcessor(
|
||||||
|
features=features, norm_map=norm_map, stats=observation_stats, normalize_keys={"observation.image"}
|
||||||
|
)
|
||||||
|
|
||||||
observation = {
|
observation = {
|
||||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||||
@@ -148,7 +170,9 @@ def test_selective_normalization(observation_stats):
|
|||||||
|
|
||||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||||
def test_device_compatibility(observation_stats):
|
def test_device_compatibility(observation_stats):
|
||||||
normalizer = NormalizerProcessor(stats=observation_stats)
|
features = _create_observation_features()
|
||||||
|
norm_map = _create_observation_norm_map()
|
||||||
|
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats)
|
||||||
observation = {
|
observation = {
|
||||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(),
|
"observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(),
|
||||||
}
|
}
|
||||||
@@ -165,10 +189,19 @@ def test_from_lerobot_dataset():
|
|||||||
mock_dataset = Mock()
|
mock_dataset = Mock()
|
||||||
mock_dataset.meta.stats = {
|
mock_dataset.meta.stats = {
|
||||||
"observation.image": {"mean": [0.5], "std": [0.2]},
|
"observation.image": {"mean": [0.5], "std": [0.2]},
|
||||||
"action": {"mean": [0.0], "std": [1.0]}, # Should be filtered out
|
"action": {"mean": [0.0], "std": [1.0]},
|
||||||
}
|
}
|
||||||
|
|
||||||
normalizer = NormalizerProcessor.from_lerobot_dataset(mock_dataset)
|
features = {
|
||||||
|
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
|
||||||
|
"action": PolicyFeature(FeatureType.ACTION, (1,)),
|
||||||
|
}
|
||||||
|
norm_map = {
|
||||||
|
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||||
|
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizer = NormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
|
||||||
|
|
||||||
# Both observation and action statistics should be present in tensor stats
|
# Both observation and action statistics should be present in tensor stats
|
||||||
assert "observation.image" in normalizer._tensor_stats
|
assert "observation.image" in normalizer._tensor_stats
|
||||||
@@ -180,7 +213,9 @@ def test_state_dict_save_load(observation_normalizer):
|
|||||||
state_dict = observation_normalizer.state_dict()
|
state_dict = observation_normalizer.state_dict()
|
||||||
|
|
||||||
# Create new normalizer and load state
|
# Create new normalizer and load state
|
||||||
new_normalizer = NormalizerProcessor(stats={})
|
features = _create_observation_features()
|
||||||
|
norm_map = _create_observation_norm_map()
|
||||||
|
new_normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={})
|
||||||
new_normalizer.load_state_dict(state_dict)
|
new_normalizer.load_state_dict(state_dict)
|
||||||
|
|
||||||
# Test that it works the same
|
# Test that it works the same
|
||||||
@@ -210,8 +245,30 @@ def action_stats_min_max():
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _create_action_features():
|
||||||
|
return {
|
||||||
|
"action": PolicyFeature(FeatureType.ACTION, (3,)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _create_action_norm_map_mean_std():
|
||||||
|
return {
|
||||||
|
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _create_action_norm_map_min_max():
|
||||||
|
return {
|
||||||
|
FeatureType.ACTION: NormalizationMode.MIN_MAX,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_mean_std_unnormalization(action_stats_mean_std):
|
def test_mean_std_unnormalization(action_stats_mean_std):
|
||||||
unnormalizer = UnnormalizerProcessor(stats={"action": action_stats_mean_std})
|
features = _create_action_features()
|
||||||
|
norm_map = _create_action_norm_map_mean_std()
|
||||||
|
unnormalizer = UnnormalizerProcessor(
|
||||||
|
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
|
||||||
|
)
|
||||||
|
|
||||||
normalized_action = torch.tensor([1.0, -0.5, 2.0])
|
normalized_action = torch.tensor([1.0, -0.5, 2.0])
|
||||||
transition = (None, normalized_action, None, None, None, None, None)
|
transition = (None, normalized_action, None, None, None, None, None)
|
||||||
@@ -225,7 +282,11 @@ def test_mean_std_unnormalization(action_stats_mean_std):
|
|||||||
|
|
||||||
|
|
||||||
def test_min_max_unnormalization(action_stats_min_max):
|
def test_min_max_unnormalization(action_stats_min_max):
|
||||||
unnormalizer = UnnormalizerProcessor(stats={"action": action_stats_min_max})
|
features = _create_action_features()
|
||||||
|
norm_map = _create_action_norm_map_min_max()
|
||||||
|
unnormalizer = UnnormalizerProcessor(
|
||||||
|
features=features, norm_map=norm_map, stats={"action": action_stats_min_max}
|
||||||
|
)
|
||||||
|
|
||||||
# Actions in [-1, 1]
|
# Actions in [-1, 1]
|
||||||
normalized_action = torch.tensor([0.0, -1.0, 1.0])
|
normalized_action = torch.tensor([0.0, -1.0, 1.0])
|
||||||
@@ -247,7 +308,11 @@ def test_min_max_unnormalization(action_stats_min_max):
|
|||||||
|
|
||||||
|
|
||||||
def test_numpy_action_input(action_stats_mean_std):
|
def test_numpy_action_input(action_stats_mean_std):
|
||||||
unnormalizer = UnnormalizerProcessor(stats={"action": action_stats_mean_std})
|
features = _create_action_features()
|
||||||
|
norm_map = _create_action_norm_map_mean_std()
|
||||||
|
unnormalizer = UnnormalizerProcessor(
|
||||||
|
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
|
||||||
|
)
|
||||||
|
|
||||||
normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32)
|
normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32)
|
||||||
transition = (None, normalized_action, None, None, None, None, None)
|
transition = (None, normalized_action, None, None, None, None, None)
|
||||||
@@ -261,7 +326,11 @@ def test_numpy_action_input(action_stats_mean_std):
|
|||||||
|
|
||||||
|
|
||||||
def test_none_action(action_stats_mean_std):
|
def test_none_action(action_stats_mean_std):
|
||||||
unnormalizer = UnnormalizerProcessor(stats={"action": action_stats_mean_std})
|
features = _create_action_features()
|
||||||
|
norm_map = _create_action_norm_map_mean_std()
|
||||||
|
unnormalizer = UnnormalizerProcessor(
|
||||||
|
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
|
||||||
|
)
|
||||||
|
|
||||||
transition = (None, None, None, None, None, None, None)
|
transition = (None, None, None, None, None, None, None)
|
||||||
result = unnormalizer(transition)
|
result = unnormalizer(transition)
|
||||||
@@ -273,7 +342,9 @@ def test_none_action(action_stats_mean_std):
|
|||||||
def test_action_from_lerobot_dataset():
|
def test_action_from_lerobot_dataset():
|
||||||
mock_dataset = Mock()
|
mock_dataset = Mock()
|
||||||
mock_dataset.meta.stats = {"action": {"mean": [0.0], "std": [1.0]}}
|
mock_dataset.meta.stats = {"action": {"mean": [0.0], "std": [1.0]}}
|
||||||
unnormalizer = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset)
|
features = {"action": PolicyFeature(FeatureType.ACTION, (1,))}
|
||||||
|
norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
|
||||||
|
unnormalizer = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
|
||||||
assert "mean" in unnormalizer._tensor_stats["action"]
|
assert "mean" in unnormalizer._tensor_stats["action"]
|
||||||
|
|
||||||
|
|
||||||
@@ -296,9 +367,27 @@ def full_stats():
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _create_full_features():
|
||||||
|
return {
|
||||||
|
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
|
||||||
|
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||||
|
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _create_full_norm_map():
|
||||||
|
return {
|
||||||
|
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||||
|
FeatureType.STATE: NormalizationMode.MIN_MAX,
|
||||||
|
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def normalizer_processor(full_stats):
|
def normalizer_processor(full_stats):
|
||||||
return NormalizerProcessor(stats=full_stats)
|
features = _create_full_features()
|
||||||
|
norm_map = _create_full_norm_map()
|
||||||
|
return NormalizerProcessor(features=features, norm_map=norm_map, stats=full_stats)
|
||||||
|
|
||||||
|
|
||||||
def test_combined_normalization(normalizer_processor):
|
def test_combined_normalization(normalizer_processor):
|
||||||
@@ -331,7 +420,12 @@ def test_processor_from_lerobot_dataset(full_stats):
|
|||||||
mock_dataset = Mock()
|
mock_dataset = Mock()
|
||||||
mock_dataset.meta.stats = full_stats
|
mock_dataset.meta.stats = full_stats
|
||||||
|
|
||||||
processor = NormalizerProcessor.from_lerobot_dataset(mock_dataset, normalize_keys={"observation.image"})
|
features = _create_full_features()
|
||||||
|
norm_map = _create_full_norm_map()
|
||||||
|
|
||||||
|
processor = NormalizerProcessor.from_lerobot_dataset(
|
||||||
|
mock_dataset, features, norm_map, normalize_keys={"observation.image"}
|
||||||
|
)
|
||||||
|
|
||||||
assert processor.normalize_keys == {"observation.image"}
|
assert processor.normalize_keys == {"observation.image"}
|
||||||
assert "observation.image" in processor._tensor_stats
|
assert "observation.image" in processor._tensor_stats
|
||||||
@@ -339,7 +433,11 @@ def test_processor_from_lerobot_dataset(full_stats):
|
|||||||
|
|
||||||
|
|
||||||
def test_get_config(full_stats):
|
def test_get_config(full_stats):
|
||||||
processor = NormalizerProcessor(stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6)
|
features = _create_full_features()
|
||||||
|
norm_map = _create_full_norm_map()
|
||||||
|
processor = NormalizerProcessor(
|
||||||
|
features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6
|
||||||
|
)
|
||||||
|
|
||||||
config = processor.get_config()
|
config = processor.get_config()
|
||||||
assert config == {"normalize_keys": ["observation.image"], "eps": 1e-6}
|
assert config == {"normalize_keys": ["observation.image"], "eps": 1e-6}
|
||||||
@@ -366,7 +464,9 @@ def test_integration_with_robot_processor(normalizer_processor):
|
|||||||
# Edge case tests
|
# Edge case tests
|
||||||
def test_empty_observation():
|
def test_empty_observation():
|
||||||
stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
|
stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
|
||||||
normalizer = NormalizerProcessor(stats=stats)
|
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
|
||||||
|
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||||
|
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||||
|
|
||||||
transition = (None, None, None, None, None, None, None)
|
transition = (None, None, None, None, None, None, None)
|
||||||
result = normalizer(transition)
|
result = normalizer(transition)
|
||||||
@@ -375,19 +475,23 @@ def test_empty_observation():
|
|||||||
|
|
||||||
|
|
||||||
def test_empty_stats():
|
def test_empty_stats():
|
||||||
normalizer = NormalizerProcessor(stats={})
|
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
|
||||||
|
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||||
|
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={})
|
||||||
observation = {"observation.image": torch.tensor([0.5])}
|
observation = {"observation.image": torch.tensor([0.5])}
|
||||||
transition = (observation, None, None, None, None, None, None)
|
transition = (observation, None, None, None, None, None, None)
|
||||||
|
|
||||||
result = normalizer(transition)
|
result = normalizer(transition)
|
||||||
# Should return observation unchanged
|
# Should return observation unchanged since no stats are available
|
||||||
assert torch.allclose(result[0]["observation.image"], observation["observation.image"])
|
assert torch.allclose(result[0]["observation.image"], observation["observation.image"])
|
||||||
|
|
||||||
|
|
||||||
def test_partial_stats():
|
def test_partial_stats():
|
||||||
"""If statistics are incomplete, the value should pass through unchanged."""
|
"""If statistics are incomplete, the value should pass through unchanged."""
|
||||||
stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max)
|
stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max)
|
||||||
normalizer = NormalizerProcessor(stats=stats)
|
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
|
||||||
|
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||||
|
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||||
observation = {"observation.image": torch.tensor([0.7])}
|
observation = {"observation.image": torch.tensor([0.7])}
|
||||||
transition = (observation, None, None, None, None, None, None)
|
transition = (observation, None, None, None, None, None, None)
|
||||||
|
|
||||||
@@ -399,6 +503,9 @@ def test_missing_action_stats_no_error():
|
|||||||
mock_dataset = Mock()
|
mock_dataset = Mock()
|
||||||
mock_dataset.meta.stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
|
mock_dataset.meta.stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
|
||||||
|
|
||||||
processor = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset)
|
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
|
||||||
|
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||||
|
|
||||||
|
processor = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
|
||||||
# The tensor stats should not contain the 'action' key
|
# The tensor stats should not contain the 'action' key
|
||||||
assert "action" not in processor._tensor_stats
|
assert "action" not in processor._tensor_stats
|
||||||
|
|||||||
Reference in New Issue
Block a user