mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-29 23:49:43 +00:00
refactor(pipeline): minor improvements (#1684)
* chore(pipeline): remove unused features + device torch + envtransition keys * refactor(pipeline): ImageProcessor & StateProcessor are both implemented directly in VanillaObservationPRocessor * refactor(pipeline): RenameProcessor now inherits from ObservationProcessor + remove unused code * test(pipeline): fix broken test after refactors * docs(pipeline): update docstrings VanillaObservationProcessor * chore(pipeline): move None check to base pipeline classes
This commit is contained in:
@@ -16,11 +16,7 @@
|
|||||||
|
|
||||||
from .device_processor import DeviceProcessor
|
from .device_processor import DeviceProcessor
|
||||||
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor
|
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor
|
||||||
from .observation_processor import (
|
from .observation_processor import VanillaObservationProcessor
|
||||||
ImageProcessor,
|
|
||||||
StateProcessor,
|
|
||||||
VanillaObservationProcessor,
|
|
||||||
)
|
|
||||||
from .pipeline import (
|
from .pipeline import (
|
||||||
ActionProcessor,
|
ActionProcessor,
|
||||||
DoneProcessor,
|
DoneProcessor,
|
||||||
@@ -43,7 +39,6 @@ __all__ = [
|
|||||||
"DoneProcessor",
|
"DoneProcessor",
|
||||||
"EnvTransition",
|
"EnvTransition",
|
||||||
"IdentityProcessor",
|
"IdentityProcessor",
|
||||||
"ImageProcessor",
|
|
||||||
"InfoProcessor",
|
"InfoProcessor",
|
||||||
"NormalizerProcessor",
|
"NormalizerProcessor",
|
||||||
"UnnormalizerProcessor",
|
"UnnormalizerProcessor",
|
||||||
@@ -53,7 +48,6 @@ __all__ = [
|
|||||||
"RenameProcessor",
|
"RenameProcessor",
|
||||||
"RewardProcessor",
|
"RewardProcessor",
|
||||||
"RobotProcessor",
|
"RobotProcessor",
|
||||||
"StateProcessor",
|
|
||||||
"TransitionKey",
|
"TransitionKey",
|
||||||
"TruncatedProcessor",
|
"TruncatedProcessor",
|
||||||
"VanillaObservationProcessor",
|
"VanillaObservationProcessor",
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import torch
|
|||||||
|
|
||||||
from lerobot.configs.types import PolicyFeature
|
from lerobot.configs.types import PolicyFeature
|
||||||
from lerobot.processor.pipeline import EnvTransition, TransitionKey
|
from lerobot.processor.pipeline import EnvTransition, TransitionKey
|
||||||
|
from lerobot.utils.utils import get_safe_torch_device
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -30,10 +31,11 @@ class DeviceProcessor:
|
|||||||
specified device (CPU or GPU) before they are returned.
|
specified device (CPU or GPU) before they are returned.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
device: str = "cpu"
|
device: torch.device = "cpu"
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.non_blocking = "cuda" in self.device
|
self.device = get_safe_torch_device(self.device)
|
||||||
|
self.non_blocking = "cuda" in str(self.device)
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
# Create a copy of the transition
|
# Create a copy of the transition
|
||||||
|
|||||||
@@ -220,7 +220,6 @@ class UnnormalizerProcessor:
|
|||||||
features: dict[str, PolicyFeature]
|
features: dict[str, PolicyFeature]
|
||||||
norm_map: dict[FeatureType, NormalizationMode]
|
norm_map: dict[FeatureType, NormalizationMode]
|
||||||
stats: dict[str, dict[str, Any]] | None = None
|
stats: dict[str, dict[str, Any]] | None = None
|
||||||
eps: float = 1e-8
|
|
||||||
|
|
||||||
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||||
|
|
||||||
@@ -230,10 +229,8 @@ class UnnormalizerProcessor:
|
|||||||
dataset: LeRobotDataset,
|
dataset: LeRobotDataset,
|
||||||
features: dict[str, PolicyFeature],
|
features: dict[str, PolicyFeature],
|
||||||
norm_map: dict[FeatureType, NormalizationMode],
|
norm_map: dict[FeatureType, NormalizationMode],
|
||||||
*,
|
|
||||||
eps: float = 1e-8,
|
|
||||||
) -> UnnormalizerProcessor:
|
) -> UnnormalizerProcessor:
|
||||||
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, eps=eps)
|
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Handle deserialization from JSON config
|
# Handle deserialization from JSON config
|
||||||
@@ -308,7 +305,6 @@ class UnnormalizerProcessor:
|
|||||||
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
def get_config(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"eps": self.eps,
|
|
||||||
"features": {
|
"features": {
|
||||||
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items()
|
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items()
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -13,8 +13,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -23,52 +22,27 @@ from torch import Tensor
|
|||||||
|
|
||||||
from lerobot.configs.types import PolicyFeature
|
from lerobot.configs.types import PolicyFeature
|
||||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
from lerobot.processor.pipeline import ObservationProcessor, ProcessorStepRegistry
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ImageProcessor:
|
@ProcessorStepRegistry.register(name="observation_processor")
|
||||||
"""Process image observations from environment format to policy format.
|
class VanillaObservationProcessor(ObservationProcessor):
|
||||||
|
|
||||||
Converts images from:
|
|
||||||
- Channel-last (H, W, C) to channel-first (C, H, W)
|
|
||||||
- uint8 [0, 255] to float32 [0, 1]
|
|
||||||
- Adds batch dimension if needed
|
|
||||||
- Handles both single images and dictionaries of images
|
|
||||||
"""
|
"""
|
||||||
|
Processes environment observations into the LeRobot format by handling both images and states.
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
Image processing:
|
||||||
observation = transition.get(TransitionKey.OBSERVATION)
|
- Converts channel-last (H, W, C) images to channel-first (C, H, W)
|
||||||
|
- Normalizes uint8 images ([0, 255]) to float32 ([0, 1])
|
||||||
|
- Adds a batch dimension if missing
|
||||||
|
- Supports single images and image dictionaries
|
||||||
|
|
||||||
if observation is None:
|
State processing:
|
||||||
return transition
|
- Maps 'environment_state' to observation.environment_state
|
||||||
|
- Maps 'agent_pos' to observation.state
|
||||||
processed_obs = {}
|
- Converts numpy arrays to tensors
|
||||||
|
- Adds a batch dimension if missing
|
||||||
# Copy all observations first
|
"""
|
||||||
for key, value in observation.items():
|
|
||||||
processed_obs[key] = value
|
|
||||||
|
|
||||||
# Handle pixels key if present
|
|
||||||
pixels = observation.get("pixels")
|
|
||||||
if pixels is not None:
|
|
||||||
# Remove pixels from processed_obs since we'll replace it with processed images
|
|
||||||
processed_obs.pop("pixels", None)
|
|
||||||
# Determine image mapping
|
|
||||||
if isinstance(pixels, dict):
|
|
||||||
imgs = {f"{OBS_IMAGES}.{key}": img for key, img in pixels.items()}
|
|
||||||
else:
|
|
||||||
imgs = {OBS_IMAGE: pixels}
|
|
||||||
|
|
||||||
# Process each image
|
|
||||||
for imgkey, img in imgs.items():
|
|
||||||
processed_img = self._process_single_image(img)
|
|
||||||
processed_obs[imgkey] = processed_img
|
|
||||||
|
|
||||||
# Return new transition with processed observation
|
|
||||||
new_transition = transition.copy()
|
|
||||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
|
||||||
return new_transition
|
|
||||||
|
|
||||||
def _process_single_image(self, img: np.ndarray) -> Tensor:
|
def _process_single_image(self, img: np.ndarray) -> Tensor:
|
||||||
"""Process a single image array."""
|
"""Process a single image array."""
|
||||||
@@ -95,173 +69,89 @@ class ImageProcessor:
|
|||||||
|
|
||||||
return img_tensor
|
return img_tensor
|
||||||
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
def _process_observation(self, observation):
|
||||||
"""Return configuration for serialization."""
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
|
||||||
"""Return state dictionary (empty for this processor)."""
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
|
||||||
"""Load state dictionary (no-op for this processor)."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
"""Reset processor state (no-op for this processor)."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
|
||||||
"""Transforms:
|
|
||||||
pixels -> OBS_IMAGE,
|
|
||||||
observation.pixels -> OBS_IMAGE,
|
|
||||||
pixels.<cam> -> OBS_IMAGES.<cam>,
|
|
||||||
observation.pixels.<cam> -> OBS_IMAGES.<cam>
|
|
||||||
"""
|
"""
|
||||||
if "pixels" in features:
|
Processes both image and state observations.
|
||||||
features[OBS_IMAGE] = features.pop("pixels")
|
|
||||||
if "observation.pixels" in features:
|
|
||||||
features[OBS_IMAGE] = features.pop("observation.pixels")
|
|
||||||
|
|
||||||
prefixes = ("pixels.", "observation.pixels.")
|
|
||||||
for key in list(features.keys()):
|
|
||||||
for p in prefixes:
|
|
||||||
if key.startswith(p):
|
|
||||||
suffix = key[len(p) :]
|
|
||||||
features[f"{OBS_IMAGES}.{suffix}"] = features.pop(key)
|
|
||||||
break
|
|
||||||
return features
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class StateProcessor:
|
|
||||||
"""Process state observations from environment format to policy format.
|
|
||||||
|
|
||||||
Handles:
|
|
||||||
- environment_state -> observation.environment_state
|
|
||||||
- agent_pos -> observation.state
|
|
||||||
- Converts numpy arrays to tensors
|
|
||||||
- Adds batch dimension if needed
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
processed_obs = observation.copy()
|
||||||
observation = transition.get(TransitionKey.OBSERVATION)
|
|
||||||
|
|
||||||
if observation is None:
|
if "pixels" in processed_obs:
|
||||||
return transition
|
pixels = processed_obs.pop("pixels")
|
||||||
|
|
||||||
processed_obs = dict(observation) # Copy existing observations
|
if isinstance(pixels, dict):
|
||||||
|
imgs = {f"{OBS_IMAGES}.{key}": img for key, img in pixels.items()}
|
||||||
|
else:
|
||||||
|
imgs = {OBS_IMAGE: pixels}
|
||||||
|
|
||||||
# Process environment_state
|
for imgkey, img in imgs.items():
|
||||||
if "environment_state" in observation:
|
processed_obs[imgkey] = self._process_single_image(img)
|
||||||
env_state = torch.from_numpy(observation["environment_state"]).float()
|
|
||||||
|
if "environment_state" in processed_obs:
|
||||||
|
env_state_np = processed_obs.pop("environment_state")
|
||||||
|
env_state = torch.from_numpy(env_state_np).float()
|
||||||
if env_state.dim() == 1:
|
if env_state.dim() == 1:
|
||||||
env_state = env_state.unsqueeze(0)
|
env_state = env_state.unsqueeze(0)
|
||||||
processed_obs[OBS_ENV_STATE] = env_state
|
processed_obs[OBS_ENV_STATE] = env_state
|
||||||
# Remove original key
|
|
||||||
del processed_obs["environment_state"]
|
|
||||||
|
|
||||||
# Process agent_pos
|
if "agent_pos" in processed_obs:
|
||||||
if "agent_pos" in observation:
|
agent_pos_np = processed_obs.pop("agent_pos")
|
||||||
agent_pos = torch.from_numpy(observation["agent_pos"]).float()
|
agent_pos = torch.from_numpy(agent_pos_np).float()
|
||||||
if agent_pos.dim() == 1:
|
if agent_pos.dim() == 1:
|
||||||
agent_pos = agent_pos.unsqueeze(0)
|
agent_pos = agent_pos.unsqueeze(0)
|
||||||
processed_obs[OBS_STATE] = agent_pos
|
processed_obs[OBS_STATE] = agent_pos
|
||||||
# Remove original key
|
|
||||||
del processed_obs["agent_pos"]
|
|
||||||
|
|
||||||
# Return new transition with processed observation
|
return processed_obs
|
||||||
new_transition = transition.copy()
|
|
||||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
|
||||||
return new_transition
|
|
||||||
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
def observation(self, observation):
|
||||||
"""Return configuration for serialization."""
|
return self._process_observation(observation)
|
||||||
return {}
|
|
||||||
|
|
||||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
|
||||||
"""Return state dictionary (empty for this processor)."""
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
|
||||||
"""Load state dictionary (no-op for this processor)."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
"""Reset processor state (no-op for this processor)."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
"""Transforms:
|
"""Transforms feature keys to a standardized contract.
|
||||||
environment_state -> OBS_ENV_STATE,
|
|
||||||
agent_pos -> OBS_STATE,
|
This method handles several renaming patterns:
|
||||||
observation.environment_state -> OBS_ENV_STATE,
|
- Exact matches (e.g., 'pixels' -> 'OBS_IMAGE').
|
||||||
observation.agent_pos -> OBS_STATE
|
- Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE').
|
||||||
|
- Prefix matches (e.g., 'pixels.cam1' -> 'OBS_IMAGES.cam1').
|
||||||
|
- Prefixed prefix matches (e.g., 'observation.pixels.cam1' -> 'OBS_IMAGES.cam1').
|
||||||
|
- environment_state -> OBS_ENV_STATE,
|
||||||
|
- agent_pos -> OBS_STATE,
|
||||||
|
- observation.environment_state -> OBS_ENV_STATE,
|
||||||
|
- observation.agent_pos -> OBS_STATE
|
||||||
"""
|
"""
|
||||||
pairs = (
|
exact_pairs = {
|
||||||
("environment_state", OBS_ENV_STATE),
|
"pixels": OBS_IMAGE,
|
||||||
("agent_pos", OBS_STATE),
|
"environment_state": OBS_ENV_STATE,
|
||||||
)
|
"agent_pos": OBS_STATE,
|
||||||
for old, new in pairs:
|
}
|
||||||
if old in features:
|
|
||||||
features[new] = features.pop(old)
|
prefix_pairs = {
|
||||||
prefixed = f"observation.{old}"
|
"pixels.": f"{OBS_IMAGES}.",
|
||||||
if prefixed in features:
|
}
|
||||||
features[new] = features.pop(prefixed)
|
|
||||||
return features
|
for key in list(features.keys()):
|
||||||
|
matched_prefix = False
|
||||||
|
for old_prefix, new_prefix in prefix_pairs.items():
|
||||||
@dataclass
|
prefixed_old = f"observation.{old_prefix}"
|
||||||
@ProcessorStepRegistry.register(name="observation_processor")
|
if key.startswith(prefixed_old):
|
||||||
class VanillaObservationProcessor:
|
suffix = key[len(prefixed_old) :]
|
||||||
"""Complete observation processor that combines image and state processing.
|
features[f"{new_prefix}{suffix}"] = features.pop(key)
|
||||||
|
matched_prefix = True
|
||||||
This processor replicates the functionality of the original preprocess_observation
|
break
|
||||||
function but in a modular, composable way that fits into the pipeline architecture.
|
|
||||||
"""
|
if key.startswith(old_prefix):
|
||||||
|
suffix = key[len(old_prefix) :]
|
||||||
image_processor: ImageProcessor = field(default_factory=ImageProcessor)
|
features[f"{new_prefix}{suffix}"] = features.pop(key)
|
||||||
state_processor: StateProcessor = field(default_factory=StateProcessor)
|
matched_prefix = True
|
||||||
|
break
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
|
||||||
# First process images
|
if matched_prefix:
|
||||||
transition = self.image_processor(transition)
|
continue
|
||||||
# Then process state
|
|
||||||
transition = self.state_processor(transition)
|
for old, new in exact_pairs.items():
|
||||||
return transition
|
if key == old or key == f"observation.{old}":
|
||||||
|
if key in features:
|
||||||
def get_config(self) -> dict[str, Any]:
|
features[new] = features.pop(key)
|
||||||
"""Return configuration for serialization."""
|
break
|
||||||
return {
|
|
||||||
"image_processor": self.image_processor.get_config(),
|
|
||||||
"state_processor": self.state_processor.get_config(),
|
|
||||||
}
|
|
||||||
|
|
||||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
|
||||||
"""Return state dictionary."""
|
|
||||||
state = {}
|
|
||||||
state.update({f"image_processor.{k}": v for k, v in self.image_processor.state_dict().items()})
|
|
||||||
state.update({f"state_processor.{k}": v for k, v in self.state_processor.state_dict().items()})
|
|
||||||
return state
|
|
||||||
|
|
||||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
|
||||||
"""Load state dictionary."""
|
|
||||||
image_state = {
|
|
||||||
k.replace("image_processor.", ""): v for k, v in state.items() if k.startswith("image_processor.")
|
|
||||||
}
|
|
||||||
state_state = {
|
|
||||||
k.replace("state_processor.", ""): v for k, v in state.items() if k.startswith("state_processor.")
|
|
||||||
}
|
|
||||||
|
|
||||||
self.image_processor.load_state_dict(image_state)
|
|
||||||
self.state_processor.load_state_dict(state_state)
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
"""Reset processor state."""
|
|
||||||
self.image_processor.reset()
|
|
||||||
self.state_processor.reset()
|
|
||||||
|
|
||||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
|
||||||
features = self.image_processor.feature_contract(features)
|
|
||||||
features = self.state_processor.feature_contract(features)
|
|
||||||
return features
|
return features
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ from lerobot.configs.types import PolicyFeature
|
|||||||
class TransitionKey(str, Enum):
|
class TransitionKey(str, Enum):
|
||||||
"""Keys for accessing EnvTransition dictionary components."""
|
"""Keys for accessing EnvTransition dictionary components."""
|
||||||
|
|
||||||
|
# TODO(Steven): Use consts
|
||||||
OBSERVATION = "observation"
|
OBSERVATION = "observation"
|
||||||
ACTION = "action"
|
ACTION = "action"
|
||||||
REWARD = "reward"
|
REWARD = "reward"
|
||||||
@@ -45,19 +46,18 @@ class TransitionKey(str, Enum):
|
|||||||
COMPLEMENTARY_DATA = "complementary_data"
|
COMPLEMENTARY_DATA = "complementary_data"
|
||||||
|
|
||||||
|
|
||||||
class EnvTransition(TypedDict, total=False):
|
EnvTransition = TypedDict(
|
||||||
"""Environment transition data structure.
|
"EnvTransition",
|
||||||
|
{
|
||||||
All fields are optional (total=False) to allow flexible usage.
|
TransitionKey.OBSERVATION.value: dict[str, Any] | None,
|
||||||
"""
|
TransitionKey.ACTION.value: Any | torch.Tensor | None,
|
||||||
|
TransitionKey.REWARD.value: float | torch.Tensor | None,
|
||||||
observation: dict[str, Any] | None
|
TransitionKey.DONE.value: bool | torch.Tensor | None,
|
||||||
action: Any | torch.Tensor | None
|
TransitionKey.TRUNCATED.value: bool | torch.Tensor | None,
|
||||||
reward: float | torch.Tensor | None
|
TransitionKey.INFO.value: dict[str, Any] | None,
|
||||||
done: bool | torch.Tensor | None
|
TransitionKey.COMPLEMENTARY_DATA.value: dict[str, Any] | None,
|
||||||
truncated: bool | torch.Tensor | None
|
},
|
||||||
info: dict[str, Any] | None
|
)
|
||||||
complementary_data: dict[str, Any] | None
|
|
||||||
|
|
||||||
|
|
||||||
class ProcessorStepRegistry:
|
class ProcessorStepRegistry:
|
||||||
@@ -135,8 +135,8 @@ class ProcessorStepRegistry:
|
|||||||
class ProcessorStep(Protocol):
|
class ProcessorStep(Protocol):
|
||||||
"""Structural typing interface for a single processor step.
|
"""Structural typing interface for a single processor step.
|
||||||
|
|
||||||
A step is any callable accepting a full `EnvTransition` tuple and
|
A step is any callable accepting a full `EnvTransition` dict and
|
||||||
returning a (possibly modified) tuple of the same structure. Implementers
|
returning a (possibly modified) dict of the same structure. Implementers
|
||||||
are encouraged—but not required—to expose the optional helper methods
|
are encouraged—but not required—to expose the optional helper methods
|
||||||
listed below. When present, these hooks let `RobotProcessor`
|
listed below. When present, these hooks let `RobotProcessor`
|
||||||
automatically serialise the step's configuration and learnable state using
|
automatically serialise the step's configuration and learnable state using
|
||||||
@@ -254,24 +254,22 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
Composable, debuggable post-processing processor for robot transitions.
|
Composable, debuggable post-processing processor for robot transitions.
|
||||||
|
|
||||||
The class orchestrates an ordered collection of small, functional transforms—steps—executed
|
The class orchestrates an ordered collection of small, functional transforms—steps—executed
|
||||||
left-to-right on each incoming `EnvTransition`. It can process both `EnvTransition` tuples
|
left-to-right on each incoming `EnvTransition`. It can process both `EnvTransition` dicts
|
||||||
and batch dictionaries, automatically converting between formats as needed.
|
and batch dictionaries, automatically converting between formats as needed.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
steps: Ordered list of processing steps executed on every call. Defaults to empty list.
|
steps: Ordered list of processing steps executed on every call. Defaults to empty list.
|
||||||
name: Human-readable identifier that is persisted inside the JSON config.
|
name: Human-readable identifier that is persisted inside the JSON config.
|
||||||
Defaults to "RobotProcessor".
|
Defaults to "RobotProcessor".
|
||||||
seed: Global seed forwarded to steps that choose to consume it. Defaults to None.
|
to_transition: Function to convert batch dict to EnvTransition dict.
|
||||||
to_transition: Function to convert batch dict to EnvTransition tuple.
|
|
||||||
Defaults to _default_batch_to_transition.
|
Defaults to _default_batch_to_transition.
|
||||||
to_output: Function to convert EnvTransition tuple to the desired output format.
|
to_output: Function to convert EnvTransition dict to the desired output format.
|
||||||
Usually it is a batch dict or EnvTransition tuple.
|
Usually it is a batch dict or EnvTransition dict.
|
||||||
Defaults to _default_transition_to_batch.
|
Defaults to _default_transition_to_batch.
|
||||||
before_step_hooks: List of hooks called before each step. Each hook receives the step
|
before_step_hooks: List of hooks called before each step. Each hook receives the step
|
||||||
index and transition, and can optionally return a modified transition.
|
index and transition, and can optionally return a modified transition.
|
||||||
after_step_hooks: List of hooks called after each step. Each hook receives the step
|
after_step_hooks: List of hooks called after each step. Each hook receives the step
|
||||||
index and transition, and can optionally return a modified transition.
|
index and transition, and can optionally return a modified transition.
|
||||||
reset_hooks: List of hooks called during processor reset.
|
|
||||||
|
|
||||||
Hook Semantics:
|
Hook Semantics:
|
||||||
- Hooks are executed sequentially in the order they were registered. There is no way to
|
- Hooks are executed sequentially in the order they were registered. There is no way to
|
||||||
@@ -290,7 +288,6 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
|
|
||||||
steps: Sequence[ProcessorStep] = field(default_factory=list)
|
steps: Sequence[ProcessorStep] = field(default_factory=list)
|
||||||
name: str = "RobotProcessor"
|
name: str = "RobotProcessor"
|
||||||
seed: int | None = None
|
|
||||||
|
|
||||||
to_transition: Callable[[dict[str, Any]], EnvTransition] = field(
|
to_transition: Callable[[dict[str, Any]], EnvTransition] = field(
|
||||||
default_factory=lambda: _default_batch_to_transition, repr=False
|
default_factory=lambda: _default_batch_to_transition, repr=False
|
||||||
@@ -303,7 +300,6 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
# Hooks do not modify transitions - they are called for logging, debugging, or monitoring purposes
|
# Hooks do not modify transitions - they are called for logging, debugging, or monitoring purposes
|
||||||
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||||
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||||
reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False)
|
|
||||||
|
|
||||||
def __call__(self, data: EnvTransition | dict[str, Any]):
|
def __call__(self, data: EnvTransition | dict[str, Any]):
|
||||||
"""Process data through all steps.
|
"""Process data through all steps.
|
||||||
@@ -431,7 +427,6 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
|
|
||||||
config: dict[str, Any] = {
|
config: dict[str, Any] = {
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"seed": self.seed,
|
|
||||||
"steps": [],
|
"steps": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -728,7 +723,7 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
f"Make sure override keys match exact step class names or registry names."
|
f"Make sure override keys match exact step class names or registry names."
|
||||||
)
|
)
|
||||||
|
|
||||||
return cls(steps, loaded_config.get("name", "RobotProcessor"), loaded_config.get("seed"))
|
return cls(steps, loaded_config.get("name", "RobotProcessor"))
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""Return the number of steps in the processor."""
|
"""Return the number of steps in the processor."""
|
||||||
@@ -740,7 +735,7 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
* ``slice`` – returns a new RobotProcessor with the sliced steps.
|
* ``slice`` – returns a new RobotProcessor with the sliced steps.
|
||||||
"""
|
"""
|
||||||
if isinstance(idx, slice):
|
if isinstance(idx, slice):
|
||||||
return RobotProcessor(self.steps[idx], self.name, self.seed)
|
return RobotProcessor(self.steps[idx], self.name)
|
||||||
return self.steps[idx]
|
return self.steps[idx]
|
||||||
|
|
||||||
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], None]):
|
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], None]):
|
||||||
@@ -783,71 +778,11 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
f"Hook {fn} not found in after_step_hooks. Make sure to pass the exact same function reference."
|
f"Hook {fn} not found in after_step_hooks. Make sure to pass the exact same function reference."
|
||||||
) from None
|
) from None
|
||||||
|
|
||||||
def register_reset_hook(self, fn: Callable[[], None]):
|
|
||||||
"""Attach fn to be executed when reset is called."""
|
|
||||||
self.reset_hooks.append(fn)
|
|
||||||
|
|
||||||
def unregister_reset_hook(self, fn: Callable[[], None]):
|
|
||||||
"""Remove a previously registered reset hook.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
fn: The exact function reference that was registered. Must be the same object.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the hook is not found in the registered hooks.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
self.reset_hooks.remove(fn)
|
|
||||||
except ValueError:
|
|
||||||
raise ValueError(
|
|
||||||
f"Hook {fn} not found in reset_hooks. Make sure to pass the exact same function reference."
|
|
||||||
) from None
|
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""Clear state in every step that implements ``reset()`` and fire registered hooks."""
|
"""Clear state in every step that implements ``reset()`` and fire registered hooks."""
|
||||||
for step in self.steps:
|
for step in self.steps:
|
||||||
if hasattr(step, "reset"):
|
if hasattr(step, "reset"):
|
||||||
step.reset() # type: ignore[attr-defined]
|
step.reset() # type: ignore[attr-defined]
|
||||||
for fn in self.reset_hooks:
|
|
||||||
fn()
|
|
||||||
|
|
||||||
def profile_steps(
|
|
||||||
self, transition: EnvTransition, num_runs: int = 100, warmup_runs: int = 5
|
|
||||||
) -> dict[str, float]:
|
|
||||||
"""Profile the execution time of each step for performance optimization."""
|
|
||||||
import copy
|
|
||||||
import time
|
|
||||||
|
|
||||||
profile_results = {}
|
|
||||||
|
|
||||||
# Make a copy to avoid altering the original transition
|
|
||||||
transition_copy = copy.deepcopy(transition)
|
|
||||||
|
|
||||||
# Get intermediate transitions for each step using step_through
|
|
||||||
intermediate_transitions = list(self.step_through(transition_copy))
|
|
||||||
|
|
||||||
for idx, processor_step in enumerate(self.steps):
|
|
||||||
step_name = f"step_{idx}_{processor_step.__class__.__name__}"
|
|
||||||
|
|
||||||
# Use the appropriate input transition for this step
|
|
||||||
input_transition = intermediate_transitions[idx]
|
|
||||||
|
|
||||||
# Warm up - copy transition for each run to ensure consistent conditions
|
|
||||||
for _ in range(warmup_runs):
|
|
||||||
transition_copy = copy.deepcopy(input_transition)
|
|
||||||
_ = processor_step(transition_copy)
|
|
||||||
|
|
||||||
# Time the step - copy transition for each run to ensure consistent conditions
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
for _ in range(num_runs):
|
|
||||||
transition_copy = copy.deepcopy(input_transition)
|
|
||||||
_ = processor_step(transition_copy)
|
|
||||||
end_time = time.perf_counter()
|
|
||||||
|
|
||||||
avg_time = (end_time - start_time) / num_runs * 1000 # Convert to milliseconds
|
|
||||||
profile_results[step_name] = avg_time
|
|
||||||
|
|
||||||
return profile_results
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
"""Return a readable string representation of the processor."""
|
"""Return a readable string representation of the processor."""
|
||||||
@@ -864,9 +799,6 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
|
|
||||||
parts = [f"name='{self.name}'", steps_repr]
|
parts = [f"name='{self.name}'", steps_repr]
|
||||||
|
|
||||||
if self.seed is not None:
|
|
||||||
parts.append(f"seed={self.seed}")
|
|
||||||
|
|
||||||
return f"RobotProcessor({', '.join(parts)})"
|
return f"RobotProcessor({', '.join(parts)})"
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -931,6 +863,9 @@ class ObservationProcessor:
|
|||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
observation = transition.get(TransitionKey.OBSERVATION)
|
observation = transition.get(TransitionKey.OBSERVATION)
|
||||||
|
if observation is None:
|
||||||
|
return transition
|
||||||
|
|
||||||
processed_observation = self.observation(observation)
|
processed_observation = self.observation(observation)
|
||||||
# Create a new transition dict with the processed observation
|
# Create a new transition dict with the processed observation
|
||||||
new_transition = transition.copy()
|
new_transition = transition.copy()
|
||||||
@@ -988,6 +923,9 @@ class ActionProcessor:
|
|||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
action = transition.get(TransitionKey.ACTION)
|
action = transition.get(TransitionKey.ACTION)
|
||||||
|
if action is None:
|
||||||
|
return transition
|
||||||
|
|
||||||
processed_action = self.action(action)
|
processed_action = self.action(action)
|
||||||
# Create a new transition dict with the processed action
|
# Create a new transition dict with the processed action
|
||||||
new_transition = transition.copy()
|
new_transition = transition.copy()
|
||||||
@@ -1044,6 +982,9 @@ class RewardProcessor:
|
|||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
reward = transition.get(TransitionKey.REWARD)
|
reward = transition.get(TransitionKey.REWARD)
|
||||||
|
if reward is None:
|
||||||
|
return transition
|
||||||
|
|
||||||
processed_reward = self.reward(reward)
|
processed_reward = self.reward(reward)
|
||||||
# Create a new transition dict with the processed reward
|
# Create a new transition dict with the processed reward
|
||||||
new_transition = transition.copy()
|
new_transition = transition.copy()
|
||||||
@@ -1105,6 +1046,9 @@ class DoneProcessor:
|
|||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
done = transition.get(TransitionKey.DONE)
|
done = transition.get(TransitionKey.DONE)
|
||||||
|
if done is None:
|
||||||
|
return transition
|
||||||
|
|
||||||
processed_done = self.done(done)
|
processed_done = self.done(done)
|
||||||
# Create a new transition dict with the processed done flag
|
# Create a new transition dict with the processed done flag
|
||||||
new_transition = transition.copy()
|
new_transition = transition.copy()
|
||||||
@@ -1162,6 +1106,9 @@ class TruncatedProcessor:
|
|||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
truncated = transition.get(TransitionKey.TRUNCATED)
|
truncated = transition.get(TransitionKey.TRUNCATED)
|
||||||
|
if truncated is None:
|
||||||
|
return transition
|
||||||
|
|
||||||
processed_truncated = self.truncated(truncated)
|
processed_truncated = self.truncated(truncated)
|
||||||
# Create a new transition dict with the processed truncated flag
|
# Create a new transition dict with the processed truncated flag
|
||||||
new_transition = transition.copy()
|
new_transition = transition.copy()
|
||||||
@@ -1224,6 +1171,9 @@ class InfoProcessor:
|
|||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
info = transition.get(TransitionKey.INFO)
|
info = transition.get(TransitionKey.INFO)
|
||||||
|
if info is None:
|
||||||
|
return transition
|
||||||
|
|
||||||
processed_info = self.info(info)
|
processed_info = self.info(info)
|
||||||
# Create a new transition dict with the processed info
|
# Create a new transition dict with the processed info
|
||||||
new_transition = transition.copy()
|
new_transition = transition.copy()
|
||||||
@@ -1267,6 +1217,9 @@ class ComplementaryDataProcessor:
|
|||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||||
|
if complementary_data is None:
|
||||||
|
return transition
|
||||||
|
|
||||||
processed_complementary_data = self.complementary_data(complementary_data)
|
processed_complementary_data = self.complementary_data(complementary_data)
|
||||||
# Create a new transition dict with the processed complementary data
|
# Create a new transition dict with the processed complementary data
|
||||||
new_transition = transition.copy()
|
new_transition = transition.copy()
|
||||||
|
|||||||
@@ -16,24 +16,21 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from lerobot.configs.types import PolicyFeature
|
from lerobot.configs.types import PolicyFeature
|
||||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
from lerobot.processor.pipeline import (
|
||||||
|
ObservationProcessor,
|
||||||
|
ProcessorStepRegistry,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ProcessorStepRegistry.register(name="rename_processor")
|
@ProcessorStepRegistry.register(name="rename_processor")
|
||||||
class RenameProcessor:
|
class RenameProcessor(ObservationProcessor):
|
||||||
"""Rename processor that renames keys in the observation."""
|
"""Rename processor that renames keys in the observation."""
|
||||||
|
|
||||||
rename_map: dict[str, str] = field(default_factory=dict)
|
rename_map: dict[str, str] = field(default_factory=dict)
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def observation(self, observation):
|
||||||
observation = transition.get(TransitionKey.OBSERVATION)
|
|
||||||
if observation is None:
|
|
||||||
return transition
|
|
||||||
|
|
||||||
processed_obs = {}
|
processed_obs = {}
|
||||||
for key, value in observation.items():
|
for key, value in observation.items():
|
||||||
if key in self.rename_map:
|
if key in self.rename_map:
|
||||||
@@ -41,20 +38,11 @@ class RenameProcessor:
|
|||||||
else:
|
else:
|
||||||
processed_obs[key] = value
|
processed_obs[key] = value
|
||||||
|
|
||||||
# Create a new transition with the renamed observation
|
return processed_obs
|
||||||
new_transition = transition.copy()
|
|
||||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
|
||||||
return new_transition
|
|
||||||
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
def get_config(self) -> dict[str, Any]:
|
||||||
return {"rename_map": self.rename_map}
|
return {"rename_map": self.rename_map}
|
||||||
|
|
||||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
"""Transforms:
|
"""Transforms:
|
||||||
- Each key in the observation that appears in `rename_map` is renamed to its value.
|
- Each key in the observation that appears in `rename_map` is renamed to its value.
|
||||||
|
|||||||
@@ -20,11 +20,7 @@ import torch
|
|||||||
|
|
||||||
from lerobot.configs.types import FeatureType
|
from lerobot.configs.types import FeatureType
|
||||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||||
from lerobot.processor import (
|
from lerobot.processor import VanillaObservationProcessor
|
||||||
ImageProcessor,
|
|
||||||
StateProcessor,
|
|
||||||
VanillaObservationProcessor,
|
|
||||||
)
|
|
||||||
from lerobot.processor.pipeline import TransitionKey
|
from lerobot.processor.pipeline import TransitionKey
|
||||||
from tests.conftest import assert_contract_is_typed
|
from tests.conftest import assert_contract_is_typed
|
||||||
|
|
||||||
@@ -46,7 +42,7 @@ def create_transition(
|
|||||||
|
|
||||||
def test_process_single_image():
|
def test_process_single_image():
|
||||||
"""Test processing a single image."""
|
"""Test processing a single image."""
|
||||||
processor = ImageProcessor()
|
processor = VanillaObservationProcessor()
|
||||||
|
|
||||||
# Create a mock image (H, W, C) format, uint8
|
# Create a mock image (H, W, C) format, uint8
|
||||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||||
@@ -72,7 +68,7 @@ def test_process_single_image():
|
|||||||
|
|
||||||
def test_process_image_dict():
|
def test_process_image_dict():
|
||||||
"""Test processing multiple images in a dictionary."""
|
"""Test processing multiple images in a dictionary."""
|
||||||
processor = ImageProcessor()
|
processor = VanillaObservationProcessor()
|
||||||
|
|
||||||
# Create mock images
|
# Create mock images
|
||||||
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
||||||
@@ -95,7 +91,7 @@ def test_process_image_dict():
|
|||||||
|
|
||||||
def test_process_batched_image():
|
def test_process_batched_image():
|
||||||
"""Test processing already batched images."""
|
"""Test processing already batched images."""
|
||||||
processor = ImageProcessor()
|
processor = VanillaObservationProcessor()
|
||||||
|
|
||||||
# Create a batched image (B, H, W, C)
|
# Create a batched image (B, H, W, C)
|
||||||
image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8)
|
image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8)
|
||||||
@@ -112,7 +108,7 @@ def test_process_batched_image():
|
|||||||
|
|
||||||
def test_invalid_image_format():
|
def test_invalid_image_format():
|
||||||
"""Test error handling for invalid image formats."""
|
"""Test error handling for invalid image formats."""
|
||||||
processor = ImageProcessor()
|
processor = VanillaObservationProcessor()
|
||||||
|
|
||||||
# Test wrong channel order (channels first)
|
# Test wrong channel order (channels first)
|
||||||
image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8)
|
image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8)
|
||||||
@@ -125,7 +121,7 @@ def test_invalid_image_format():
|
|||||||
|
|
||||||
def test_invalid_image_dtype():
|
def test_invalid_image_dtype():
|
||||||
"""Test error handling for invalid image dtype."""
|
"""Test error handling for invalid image dtype."""
|
||||||
processor = ImageProcessor()
|
processor = VanillaObservationProcessor()
|
||||||
|
|
||||||
# Test wrong dtype
|
# Test wrong dtype
|
||||||
image = np.random.rand(64, 64, 3).astype(np.float32)
|
image = np.random.rand(64, 64, 3).astype(np.float32)
|
||||||
@@ -138,7 +134,7 @@ def test_invalid_image_dtype():
|
|||||||
|
|
||||||
def test_no_pixels_in_observation():
|
def test_no_pixels_in_observation():
|
||||||
"""Test processor when no pixels are in observation."""
|
"""Test processor when no pixels are in observation."""
|
||||||
processor = ImageProcessor()
|
processor = VanillaObservationProcessor()
|
||||||
|
|
||||||
observation = {"other_data": np.array([1, 2, 3])}
|
observation = {"other_data": np.array([1, 2, 3])}
|
||||||
transition = create_transition(observation=observation)
|
transition = create_transition(observation=observation)
|
||||||
@@ -153,7 +149,7 @@ def test_no_pixels_in_observation():
|
|||||||
|
|
||||||
def test_none_observation():
|
def test_none_observation():
|
||||||
"""Test processor with None observation."""
|
"""Test processor with None observation."""
|
||||||
processor = ImageProcessor()
|
processor = VanillaObservationProcessor()
|
||||||
|
|
||||||
transition = create_transition()
|
transition = create_transition()
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
@@ -163,7 +159,7 @@ def test_none_observation():
|
|||||||
|
|
||||||
def test_serialization_methods():
|
def test_serialization_methods():
|
||||||
"""Test serialization methods."""
|
"""Test serialization methods."""
|
||||||
processor = ImageProcessor()
|
processor = VanillaObservationProcessor()
|
||||||
|
|
||||||
# Test get_config
|
# Test get_config
|
||||||
config = processor.get_config()
|
config = processor.get_config()
|
||||||
@@ -182,7 +178,7 @@ def test_serialization_methods():
|
|||||||
|
|
||||||
def test_process_environment_state():
|
def test_process_environment_state():
|
||||||
"""Test processing environment_state."""
|
"""Test processing environment_state."""
|
||||||
processor = StateProcessor()
|
processor = VanillaObservationProcessor()
|
||||||
|
|
||||||
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
||||||
observation = {"environment_state": env_state}
|
observation = {"environment_state": env_state}
|
||||||
@@ -203,7 +199,7 @@ def test_process_environment_state():
|
|||||||
|
|
||||||
def test_process_agent_pos():
|
def test_process_agent_pos():
|
||||||
"""Test processing agent_pos."""
|
"""Test processing agent_pos."""
|
||||||
processor = StateProcessor()
|
processor = VanillaObservationProcessor()
|
||||||
|
|
||||||
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
|
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
|
||||||
observation = {"agent_pos": agent_pos}
|
observation = {"agent_pos": agent_pos}
|
||||||
@@ -224,7 +220,7 @@ def test_process_agent_pos():
|
|||||||
|
|
||||||
def test_process_batched_states():
|
def test_process_batched_states():
|
||||||
"""Test processing already batched states."""
|
"""Test processing already batched states."""
|
||||||
processor = StateProcessor()
|
processor = VanillaObservationProcessor()
|
||||||
|
|
||||||
env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
|
env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
|
||||||
agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
|
agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
|
||||||
@@ -242,7 +238,7 @@ def test_process_batched_states():
|
|||||||
|
|
||||||
def test_process_both_states():
|
def test_process_both_states():
|
||||||
"""Test processing both environment_state and agent_pos."""
|
"""Test processing both environment_state and agent_pos."""
|
||||||
processor = StateProcessor()
|
processor = VanillaObservationProcessor()
|
||||||
|
|
||||||
env_state = np.array([1.0, 2.0], dtype=np.float32)
|
env_state = np.array([1.0, 2.0], dtype=np.float32)
|
||||||
agent_pos = np.array([0.5, -0.5], dtype=np.float32)
|
agent_pos = np.array([0.5, -0.5], dtype=np.float32)
|
||||||
@@ -267,7 +263,7 @@ def test_process_both_states():
|
|||||||
|
|
||||||
def test_no_states_in_observation():
|
def test_no_states_in_observation():
|
||||||
"""Test processor when no states are in observation."""
|
"""Test processor when no states are in observation."""
|
||||||
processor = StateProcessor()
|
processor = VanillaObservationProcessor()
|
||||||
|
|
||||||
observation = {"other_data": np.array([1, 2, 3])}
|
observation = {"other_data": np.array([1, 2, 3])}
|
||||||
transition = create_transition(observation=observation)
|
transition = create_transition(observation=observation)
|
||||||
@@ -359,17 +355,6 @@ def test_empty_observation():
|
|||||||
assert processed_obs == {}
|
assert processed_obs == {}
|
||||||
|
|
||||||
|
|
||||||
def test_custom_sub_processors():
|
|
||||||
"""Test ObservationProcessor with custom sub-processors."""
|
|
||||||
image_proc = ImageProcessor()
|
|
||||||
state_proc = StateProcessor()
|
|
||||||
processor = VanillaObservationProcessor(image_processor=image_proc, state_processor=state_proc)
|
|
||||||
|
|
||||||
# Should use the provided processors
|
|
||||||
assert processor.image_processor is image_proc
|
|
||||||
assert processor.state_processor is state_proc
|
|
||||||
|
|
||||||
|
|
||||||
def test_equivalent_to_original_function():
|
def test_equivalent_to_original_function():
|
||||||
"""Test that ObservationProcessor produces equivalent results to preprocess_observation."""
|
"""Test that ObservationProcessor produces equivalent results to preprocess_observation."""
|
||||||
# Import the original function for comparison
|
# Import the original function for comparison
|
||||||
@@ -426,7 +411,7 @@ def test_equivalent_with_image_dict():
|
|||||||
|
|
||||||
|
|
||||||
def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory):
|
def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory):
|
||||||
processor = ImageProcessor()
|
processor = VanillaObservationProcessor()
|
||||||
features = {
|
features = {
|
||||||
"pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
"pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||||
@@ -440,7 +425,7 @@ def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory
|
|||||||
|
|
||||||
|
|
||||||
def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory):
|
def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory):
|
||||||
processor = ImageProcessor()
|
processor = VanillaObservationProcessor()
|
||||||
features = {
|
features = {
|
||||||
"observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
"observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||||
@@ -454,7 +439,7 @@ def test_image_processor_feature_contract_observation_pixels_to_image(policy_fea
|
|||||||
|
|
||||||
|
|
||||||
def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory):
|
def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory):
|
||||||
processor = ImageProcessor()
|
processor = VanillaObservationProcessor()
|
||||||
features = {
|
features = {
|
||||||
"pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
"pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||||
"pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
"pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||||
@@ -472,7 +457,7 @@ def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_featu
|
|||||||
|
|
||||||
|
|
||||||
def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory):
|
def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory):
|
||||||
processor = StateProcessor()
|
processor = VanillaObservationProcessor()
|
||||||
features = {
|
features = {
|
||||||
"environment_state": policy_feature_factory(FeatureType.STATE, (3,)),
|
"environment_state": policy_feature_factory(FeatureType.STATE, (3,)),
|
||||||
"agent_pos": policy_feature_factory(FeatureType.STATE, (7,)),
|
"agent_pos": policy_feature_factory(FeatureType.STATE, (7,)),
|
||||||
@@ -488,7 +473,7 @@ def test_state_processor_feature_contract_environment_and_agent_pos(policy_featu
|
|||||||
|
|
||||||
|
|
||||||
def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory):
|
def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory):
|
||||||
proc = StateProcessor()
|
proc = VanillaObservationProcessor()
|
||||||
features = {
|
features = {
|
||||||
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
|
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
|
||||||
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
|
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
|
||||||
|
|||||||
@@ -363,32 +363,6 @@ def test_hooks():
|
|||||||
assert after_calls == [0]
|
assert after_calls == [0]
|
||||||
|
|
||||||
|
|
||||||
def test_reset():
|
|
||||||
"""Test pipeline reset functionality."""
|
|
||||||
step = MockStep("test_step")
|
|
||||||
pipeline = RobotProcessor([step])
|
|
||||||
|
|
||||||
reset_called = []
|
|
||||||
|
|
||||||
def reset_hook():
|
|
||||||
reset_called.append(True)
|
|
||||||
|
|
||||||
pipeline.register_reset_hook(reset_hook)
|
|
||||||
|
|
||||||
# Make some calls to increment counter
|
|
||||||
transition = create_transition()
|
|
||||||
pipeline(transition)
|
|
||||||
pipeline(transition)
|
|
||||||
|
|
||||||
assert step.counter == 2
|
|
||||||
|
|
||||||
# Reset should reset step and call hook
|
|
||||||
pipeline.reset()
|
|
||||||
|
|
||||||
assert step.counter == 0
|
|
||||||
assert len(reset_called) == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_unregister_hooks():
|
def test_unregister_hooks():
|
||||||
"""Test unregistering hooks from the pipeline."""
|
"""Test unregistering hooks from the pipeline."""
|
||||||
step = MockStep("test_step")
|
step = MockStep("test_step")
|
||||||
@@ -428,21 +402,6 @@ def test_unregister_hooks():
|
|||||||
pipeline(transition)
|
pipeline(transition)
|
||||||
assert len(after_calls) == 0
|
assert len(after_calls) == 0
|
||||||
|
|
||||||
# Test reset_hook
|
|
||||||
reset_calls = []
|
|
||||||
|
|
||||||
def reset_hook():
|
|
||||||
reset_calls.append(True)
|
|
||||||
|
|
||||||
pipeline.register_reset_hook(reset_hook)
|
|
||||||
pipeline.reset()
|
|
||||||
assert len(reset_calls) == 1
|
|
||||||
|
|
||||||
pipeline.unregister_reset_hook(reset_hook)
|
|
||||||
reset_calls.clear()
|
|
||||||
pipeline.reset()
|
|
||||||
assert len(reset_calls) == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_unregister_nonexistent_hook():
|
def test_unregister_nonexistent_hook():
|
||||||
"""Test error handling when unregistering hooks that don't exist."""
|
"""Test error handling when unregistering hooks that don't exist."""
|
||||||
@@ -461,9 +420,6 @@ def test_unregister_nonexistent_hook():
|
|||||||
with pytest.raises(ValueError, match="not found in after_step_hooks"):
|
with pytest.raises(ValueError, match="not found in after_step_hooks"):
|
||||||
pipeline.unregister_after_step_hook(some_hook)
|
pipeline.unregister_after_step_hook(some_hook)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="not found in reset_hooks"):
|
|
||||||
pipeline.unregister_reset_hook(reset_hook)
|
|
||||||
|
|
||||||
|
|
||||||
def test_multiple_hooks_and_selective_unregister():
|
def test_multiple_hooks_and_selective_unregister():
|
||||||
"""Test registering multiple hooks and selectively unregistering them."""
|
"""Test registering multiple hooks and selectively unregistering them."""
|
||||||
@@ -552,22 +508,6 @@ def test_hook_execution_order_documentation():
|
|||||||
assert execution_order == ["A", "C", "B"] # B is now last
|
assert execution_order == ["A", "C", "B"] # B is now last
|
||||||
|
|
||||||
|
|
||||||
def test_profile_steps():
|
|
||||||
"""Test step profiling functionality."""
|
|
||||||
step1 = MockStep("step1")
|
|
||||||
step2 = MockStep("step2")
|
|
||||||
pipeline = RobotProcessor([step1, step2])
|
|
||||||
|
|
||||||
transition = create_transition()
|
|
||||||
|
|
||||||
profile_results = pipeline.profile_steps(transition, num_runs=10)
|
|
||||||
|
|
||||||
assert len(profile_results) == 2
|
|
||||||
assert "step_0_MockStep" in profile_results
|
|
||||||
assert "step_1_MockStep" in profile_results
|
|
||||||
assert all(isinstance(time, float) and time >= 0 for time in profile_results.values())
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_and_load_pretrained():
|
def test_save_and_load_pretrained():
|
||||||
"""Test saving and loading pipeline.
|
"""Test saving and loading pipeline.
|
||||||
|
|
||||||
@@ -581,7 +521,7 @@ def test_save_and_load_pretrained():
|
|||||||
step1.counter = 5
|
step1.counter = 5
|
||||||
step2.counter = 10
|
step2.counter = 10
|
||||||
|
|
||||||
pipeline = RobotProcessor([step1, step2], name="TestPipeline", seed=42)
|
pipeline = RobotProcessor([step1, step2], name="TestPipeline")
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
# Save pipeline
|
# Save pipeline
|
||||||
@@ -596,7 +536,6 @@ def test_save_and_load_pretrained():
|
|||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
|
||||||
assert config["name"] == "TestPipeline"
|
assert config["name"] == "TestPipeline"
|
||||||
assert config["seed"] == 42
|
|
||||||
assert len(config["steps"]) == 2
|
assert len(config["steps"]) == 2
|
||||||
|
|
||||||
# Verify counters are saved in config, not in separate state files
|
# Verify counters are saved in config, not in separate state files
|
||||||
@@ -607,7 +546,6 @@ def test_save_and_load_pretrained():
|
|||||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
|
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
assert loaded_pipeline.name == "TestPipeline"
|
assert loaded_pipeline.name == "TestPipeline"
|
||||||
assert loaded_pipeline.seed == 42
|
|
||||||
assert len(loaded_pipeline) == 2
|
assert len(loaded_pipeline) == 2
|
||||||
|
|
||||||
# Check that counter was restored from config
|
# Check that counter was restored from config
|
||||||
@@ -1255,10 +1193,10 @@ def test_repr_with_custom_name():
|
|||||||
def test_repr_with_seed():
|
def test_repr_with_seed():
|
||||||
"""Test __repr__ with seed parameter."""
|
"""Test __repr__ with seed parameter."""
|
||||||
step = MockStep("test_step")
|
step = MockStep("test_step")
|
||||||
pipeline = RobotProcessor([step], seed=42)
|
pipeline = RobotProcessor([step])
|
||||||
repr_str = repr(pipeline)
|
repr_str = repr(pipeline)
|
||||||
|
|
||||||
expected = "RobotProcessor(name='RobotProcessor', steps=1: [MockStep], seed=42)"
|
expected = "RobotProcessor(name='RobotProcessor', steps=1: [MockStep])"
|
||||||
assert repr_str == expected
|
assert repr_str == expected
|
||||||
|
|
||||||
|
|
||||||
@@ -1266,19 +1204,17 @@ def test_repr_with_custom_name_and_seed():
|
|||||||
"""Test __repr__ with both custom name and seed."""
|
"""Test __repr__ with both custom name and seed."""
|
||||||
step1 = MockStep("step1")
|
step1 = MockStep("step1")
|
||||||
step2 = MockStepWithoutOptionalMethods()
|
step2 = MockStepWithoutOptionalMethods()
|
||||||
pipeline = RobotProcessor([step1, step2], name="MyProcessor", seed=123)
|
pipeline = RobotProcessor([step1, step2], name="MyProcessor")
|
||||||
repr_str = repr(pipeline)
|
repr_str = repr(pipeline)
|
||||||
|
|
||||||
expected = (
|
expected = "RobotProcessor(name='MyProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods])"
|
||||||
"RobotProcessor(name='MyProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods], seed=123)"
|
|
||||||
)
|
|
||||||
assert repr_str == expected
|
assert repr_str == expected
|
||||||
|
|
||||||
|
|
||||||
def test_repr_without_seed():
|
def test_repr_without_seed():
|
||||||
"""Test __repr__ when seed is explicitly None (should not show seed)."""
|
"""Test __repr__ when seed is explicitly None (should not show seed)."""
|
||||||
step = MockStep("test_step")
|
step = MockStep("test_step")
|
||||||
pipeline = RobotProcessor([step], name="TestProcessor", seed=None)
|
pipeline = RobotProcessor([step], name="TestProcessor")
|
||||||
repr_str = repr(pipeline)
|
repr_str = repr(pipeline)
|
||||||
|
|
||||||
expected = "RobotProcessor(name='TestProcessor', steps=1: [MockStep])"
|
expected = "RobotProcessor(name='TestProcessor', steps=1: [MockStep])"
|
||||||
@@ -1306,10 +1242,10 @@ def test_repr_edge_case_long_names():
|
|||||||
step3 = MockStepWithTensorState()
|
step3 = MockStepWithTensorState()
|
||||||
step4 = MockNonModuleStepWithState()
|
step4 = MockNonModuleStepWithState()
|
||||||
|
|
||||||
pipeline = RobotProcessor([step1, step2, step3, step4], name="LongNames", seed=999)
|
pipeline = RobotProcessor([step1, step2, step3, step4], name="LongNames")
|
||||||
repr_str = repr(pipeline)
|
repr_str = repr(pipeline)
|
||||||
|
|
||||||
expected = "RobotProcessor(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState], seed=999)"
|
expected = "RobotProcessor(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState])"
|
||||||
assert repr_str == expected
|
assert repr_str == expected
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user