diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 147db78f7..9589ef695 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -15,6 +15,14 @@ # limitations under the License. from .batch_processor import ToBatchProcessor +from .converters import ( + batch_to_transition, + create_transition, + merge_transitions, + transition_to_batch, + transition_to_dataset_frame, +) +from .core import EnvTransition, TransitionKey from .delta_action_processor import MapDeltaActionToRobotAction, MapTensorToDeltaActionDict from .device_processor import DeviceProcessor from .gym_action_processor import Numpy2TorchActionProcessor, Torch2NumpyActionProcessor @@ -33,7 +41,6 @@ from .observation_processor import VanillaObservationProcessor from .pipeline import ( ActionProcessor, DoneProcessor, - EnvTransition, IdentityProcessor, InfoProcessor, ObservationProcessor, @@ -42,7 +49,6 @@ from .pipeline import ( ProcessorStepRegistry, RewardProcessor, RobotProcessor, - TransitionKey, TruncatedProcessor, ) from .rename_processor import RenameProcessor @@ -52,22 +58,24 @@ __all__ = [ "ActionProcessor", "AddTeleopActionAsComplimentaryData", "AddTeleopEventsAsInfo", + "batch_to_transition", + "create_transition", "DeviceProcessor", "DoneProcessor", - "MapDeltaActionToRobotAction", - "MapTensorToDeltaActionDict", "EnvTransition", "GripperPenaltyProcessor", + "hotswap_stats", "IdentityProcessor", "ImageCropResizeProcessor", "InfoProcessor", "InterventionActionProcessor", "JointVelocityProcessor", "MapDeltaActionToRobotAction", + "MapTensorToDeltaActionDict", + "merge_transitions", "MotorCurrentProcessor", "NormalizerProcessor", - "UnnormalizerProcessor", - "hotswap_stats", + "Numpy2TorchActionProcessor", "ObservationProcessor", "ProcessorKwargs", "ProcessorStep", @@ -76,12 +84,14 @@ __all__ = [ "RewardClassifierProcessor", "RewardProcessor", "RobotProcessor", + "TimeLimitProcessor", "ToBatchProcessor", "TokenizerProcessor", - 
"TimeLimitProcessor", - "Numpy2TorchActionProcessor", "Torch2NumpyActionProcessor", + "transition_to_batch", + "transition_to_dataset_frame", "TransitionKey", "TruncatedProcessor", + "UnnormalizerProcessor", "VanillaObservationProcessor", ] diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index fb3d9b860..550bb470d 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -16,7 +16,7 @@ from __future__ import annotations -from collections.abc import Iterable, Sequence +from collections.abc import Sequence from copy import deepcopy from functools import singledispatch from typing import Any @@ -27,7 +27,7 @@ from scipy.spatial.transform import Rotation from lerobot.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD, TRUNCATED -from .pipeline import EnvTransition, TransitionKey +from .core import EnvTransition, TransitionKey @singledispatch @@ -139,7 +139,8 @@ def _(value: dict, *, device=None, **kwargs) -> dict: return result -def _from_tensor(x: Any): +def _from_tensor(x: torch.Tensor | Any) -> np.ndarray | float | int | Any: + """Convert tensor to numpy/scalar if needed.""" if isinstance(x, torch.Tensor): return x.item() if x.numel() == 1 else x.detach().cpu().numpy() return x @@ -159,17 +160,76 @@ def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any], return state, images -def make_obs_act_transition( - *, obs: dict[str, Any] | None = None, act: dict[str, Any] | None = None +# ============================================================================ +# Private Helper Functions (Common Logic) +# ============================================================================ + + +def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]: + """Extract complementary data (pad flags, task, index, task_index).""" + pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k} + task_key = {"task": batch["task"]} if "task" in batch else {} + 
index_key = {"index": batch["index"]} if "index" in batch else {} + task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {} + + return {**pad_keys, **task_key, **index_key, **task_index_key} + + +def _merge_transitions(base: EnvTransition, other: EnvTransition) -> EnvTransition: + """Merge two transitions, with other taking precedence.""" + out = deepcopy(base) + + for key in ( + TransitionKey.OBSERVATION, + TransitionKey.ACTION, + TransitionKey.INFO, + TransitionKey.COMPLEMENTARY_DATA, + ): + if other.get(key): + out.setdefault(key, {}).update(deepcopy(other[key])) + + for k in (TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED): + if k in other: + out[k] = other[k] + return out + + +# ============================================================================ +# Core Conversion Functions +# ============================================================================ + + +def create_transition( + observation: dict[str, Any] | None = None, + action: dict[str, Any] | None = None, + reward: float = 0.0, + done: bool = False, + truncated: bool = False, + info: dict[str, Any] | None = None, + complementary_data: dict[str, Any] | None = None, ) -> EnvTransition: + """Create an EnvTransition with sensible defaults. + + Args: + observation: Observation dictionary. + action: Action dictionary. + reward: Scalar reward value. + done: Episode termination flag. + truncated: Episode truncation flag. + info: Additional info dictionary. + complementary_data: Complementary data dictionary. + + Returns: + Complete EnvTransition dictionary. 
+ """ return { - TransitionKey.OBSERVATION: {} if obs is None else obs, - TransitionKey.ACTION: {} if act is None else act, - TransitionKey.INFO: {}, - TransitionKey.COMPLEMENTARY_DATA: {}, - TransitionKey.REWARD: None, - TransitionKey.DONE: None, - TransitionKey.TRUNCATED: None, + TransitionKey.OBSERVATION: observation, + TransitionKey.ACTION: action, + TransitionKey.REWARD: reward, + TransitionKey.DONE: done, + TransitionKey.TRUNCATED: truncated, + TransitionKey.INFO: info if info is not None else {}, + TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {}, } @@ -187,7 +247,7 @@ def to_transition_teleop_action(action: dict[str, Any]) -> EnvTransition: arr = np.array(v) if np.isscalar(v) else v act_dict[f"{ACTION}.{k}"] = to_tensor(arr) - return make_obs_act_transition(act=act_dict) + return create_transition(observation={}, action=act_dict) # TODO(Adil, Pepijn): Overtime we can maybe add these converters to pipeline.py itself @@ -205,7 +265,7 @@ def to_transition_robot_observation(observation: dict[str, Any]) -> EnvTransitio for cam, img in images.items(): obs_dict[f"{OBS_IMAGES}.{cam}"] = img - return make_obs_act_transition(obs=obs_dict) + return create_transition(observation=obs_dict, action={}) def to_output_robot_action(transition: EnvTransition) -> dict[str, Any]: @@ -226,69 +286,60 @@ def to_output_robot_action(transition: EnvTransition) -> dict[str, Any]: return out -def to_dataset_frame( - transitions_or_transition: EnvTransition | Iterable[EnvTransition], features: dict[str, dict] -) -> dict[str, any]: - """ - Converts a single EnvTransition or an iterable of them into a flat, - dataset-friendly dictionary for training or evaluation, according to - the provided `features` spec. +def merge_transitions(transitions: Sequence[EnvTransition] | EnvTransition) -> EnvTransition: + """Merge multiple transitions or return single transition. 
Args: - transitions_or_transition: Either a single EnvTransition dict - or an iterable of them (which will be merged). - features (dict[str, dict]): - A feature specification dictionary: - - 'action': dict with 'names': list of action feature names - - 'observation.state': dict with 'names': list of state feature names - - keys starting with 'observation.images.' are passed through + transitions: Either a single transition or iterable of transitions. Returns: - batch (dict[str, any]): Flat dictionary containing: - - numpy arrays for "observation.state" and "action" - - any image tensors defined in features - - next.{reward,done,truncated} - - info dict - - *_is_pad flags and task from complementary_data + Merged EnvTransition. + """ + if isinstance(transitions, dict): # Single transition (EnvTransition is a TypedDict; isinstance must use dict) + return transitions + + items = list(transitions) + if not items: + raise ValueError("merge_transitions() requires a non-empty sequence of transitions") + + result = items[0] + for t in items[1:]: + result = _merge_transitions(result, t) + return result + + +def transition_to_dataset_frame( + transitions_or_transition: EnvTransition | Sequence[EnvTransition], features: dict[str, dict] +) -> dict[str, Any]: + """Convert a single EnvTransition or an iterable of them into a flat, dataset-friendly dictionary for training or evaluation. + + Processes transitions according to the provided feature specification and returns + data in the format expected by machine learning models and datasets. + + Args: + transitions_or_transition: Either a single EnvTransition dict or an iterable of them + (which will be merged using merge_transitions). + features: Feature specification dictionary with the following structure: + - 'action': dict with 'names': list of action feature names + - 'observation.state': dict with 'names': list of state feature names + - keys starting with 'observation.images.'
are passed through as-is + + Returns: + Flat dictionary containing: + - numpy arrays for "observation.state" and "action" (vectorized from feature names) + - any image tensors defined in features (passed through unchanged) + - next.{reward,done,truncated} scalar values + - info dict + - *_is_pad flags and task from complementary_data """ action_names = features.get(ACTION, {}).get("names", []) obs_state_names = features.get(OBS_STATE, {}).get("names", []) image_keys = [k for k in features if k.startswith(OBS_IMAGES)] - def _merge(base: EnvTransition, other: EnvTransition) -> EnvTransition: - out = deepcopy(base) - for key in ( - TransitionKey.OBSERVATION, - TransitionKey.ACTION, - TransitionKey.INFO, - TransitionKey.COMPLEMENTARY_DATA, - ): - if other.get(key): - out.setdefault(key, {}).update(deepcopy(other[key])) - for k in (TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED): - if k in other: - out[k] = other[k] - return out - - def _ensure_transition(obj) -> EnvTransition: - # single transition - if isinstance(obj, dict) and any(isinstance(k, TransitionKey) for k in obj): - return obj - # iterable of transitions - if isinstance(obj, Iterable): - items = list(obj) - if not items: - return {} - acc = items[0] - for t in items[1:]: - acc = _merge(acc, t) - return acc - raise TypeError("Expected EnvTransition or iterable of them") - - tr = _ensure_transition(transitions_or_transition) + tr = merge_transitions(transitions_or_transition) obs = tr.get(TransitionKey.OBSERVATION, {}) or {} act = tr.get(TransitionKey.ACTION, {}) or {} - batch: dict[str, any] = {} + batch: dict[str, Any] = {} # Images passthrough for k in image_keys: @@ -305,6 +356,7 @@ def to_dataset_frame( vals = [_from_tensor(act.get(f"{ACTION}.{n}", 0.0)) for n in action_names] batch[ACTION] = np.asarray(vals, dtype=np.float32) + # Add transition metadata if tr.get(TransitionKey.REWARD) is not None: batch[REWARD] = _from_tensor(tr[TransitionKey.REWARD]) if tr.get(TransitionKey.DONE) is 
not None: @@ -324,3 +376,90 @@ def to_dataset_frame( batch["task"] = comp["task"] return batch + + +def batch_to_transition(batch: dict[str, Any]) -> EnvTransition: + """Convert a batch dict coming from LeRobot replay/dataset code into an EnvTransition dictionary. + + The function maps well known keys to the EnvTransition structure. Missing keys are + filled with sane defaults (None or 0.0/False). + + Keys recognised (case-sensitive): + * "observation.*" (keys starting with "observation." are grouped into observation dict) + * "action" + * "next.reward" + * "next.done" + * "next.truncated" + * "info" + * "_is_pad" patterns (padding flags) + * "task", "index", "task_index" (complementary data) + + Additional keys are ignored so that existing dataloaders can carry extra + metadata without breaking the processor. + + Args: + batch: Batch dictionary from datasets or dataloaders containing the above keys. + + Returns: + EnvTransition dictionary with properly structured transition data. + """ + + # Validate input type + if not isinstance(batch, dict): + raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}") + + # Extract observation keys + observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")} + complementary_data = _extract_complementary_data(batch) + + return create_transition( + observation=observation_keys if observation_keys else None, + action=batch.get("action"), + reward=batch.get("next.reward", 0.0), + done=batch.get("next.done", False), + truncated=batch.get("next.truncated", False), + info=batch.get("info", {}), + complementary_data=complementary_data if complementary_data else None, + ) + + +def transition_to_batch(transition: EnvTransition) -> dict[str, Any]: + """Inverse of batch_to_transition. Returns a dict with canonical field names used throughout LeRobot. + + Converts an EnvTransition back to the batch format expected by datasets, dataloaders, + and other LeRobot components. 
+ + Output format: + * "action": Action data from transition + * "next.reward": Reward value (defaults to 0.0) + * "next.done": Done flag (defaults to False) + * "next.truncated": Truncated flag (defaults to False) + * "info": Info dictionary (defaults to {}) + * Flattened observation keys (e.g., "observation.state", "observation.images.cam1") + * Complementary data fields ("task", "index", "task_index", padding flags) + + Args: + transition: EnvTransition dictionary to convert. + + Returns: + Batch dictionary with canonical LeRobot field names suitable for dataloaders. + """ + batch = { + "action": transition.get(TransitionKey.ACTION), + "next.reward": transition.get(TransitionKey.REWARD, 0.0), + "next.done": transition.get(TransitionKey.DONE, False), + "next.truncated": transition.get(TransitionKey.TRUNCATED, False), + "info": transition.get(TransitionKey.INFO, {}), + } + + # Add complementary data + comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + if comp_data: + batch.update(comp_data) + + # Flatten observation dict + observation = transition.get(TransitionKey.OBSERVATION) + if isinstance(observation, dict): + batch.update(observation) + + return batch diff --git a/src/lerobot/processor/core.py b/src/lerobot/processor/core.py new file mode 100644 index 000000000..a60a52d02 --- /dev/null +++ b/src/lerobot/processor/core.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from enum import Enum +from typing import Any, TypedDict + +import torch + + +class TransitionKey(str, Enum): + """Keys for accessing EnvTransition dictionary components.""" + + # TODO(Steven): Use consts + OBSERVATION = "observation" + ACTION = "action" + REWARD = "reward" + DONE = "done" + TRUNCATED = "truncated" + INFO = "info" + COMPLEMENTARY_DATA = "complementary_data" + + +EnvTransition = TypedDict( + "EnvTransition", + { + TransitionKey.OBSERVATION.value: dict[str, Any] | None, + TransitionKey.ACTION.value: Any | torch.Tensor | None, + TransitionKey.REWARD.value: float | torch.Tensor | None, + TransitionKey.DONE.value: bool | torch.Tensor | None, + TransitionKey.TRUNCATED.value: bool | torch.Tensor | None, + TransitionKey.INFO.value: dict[str, Any] | None, + TransitionKey.COMPLEMENTARY_DATA.value: dict[str, Any] | None, + }, +) diff --git a/src/lerobot/processor/device_processor.py b/src/lerobot/processor/device_processor.py index 78a3ad797..c099d050a 100644 --- a/src/lerobot/processor/device_processor.py +++ b/src/lerobot/processor/device_processor.py @@ -18,7 +18,8 @@ from typing import Any import torch -from lerobot.processor.pipeline import EnvTransition, ProcessorStep, ProcessorStepRegistry, TransitionKey +from lerobot.processor.core import EnvTransition, TransitionKey +from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry from lerobot.utils.utils import get_safe_torch_device diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 84d3e4164..7054ba439 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -22,7 +22,6 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Sequence from copy import deepcopy from dataclasses import dataclass, field -from enum import Enum from pathlib import 
Path from typing import Any, Generic, TypedDict, TypeVar, cast @@ -33,37 +32,13 @@ from safetensors.torch import load_file, save_file from lerobot.configs.types import PolicyFeature +from .converters import batch_to_transition, transition_to_batch +from .core import EnvTransition, TransitionKey + # Type variable for generic processor output type TOutput = TypeVar("TOutput") -class TransitionKey(str, Enum): - """Keys for accessing EnvTransition dictionary components.""" - - # TODO(Steven): Use consts - OBSERVATION = "observation" - ACTION = "action" - REWARD = "reward" - DONE = "done" - TRUNCATED = "truncated" - INFO = "info" - COMPLEMENTARY_DATA = "complementary_data" - - -EnvTransition = TypedDict( - "EnvTransition", - { - TransitionKey.OBSERVATION.value: dict[str, Any] | None, - TransitionKey.ACTION.value: Any | torch.Tensor | None, - TransitionKey.REWARD.value: float | torch.Tensor | None, - TransitionKey.DONE.value: bool | torch.Tensor | None, - TransitionKey.TRUNCATED.value: bool | torch.Tensor | None, - TransitionKey.INFO.value: dict[str, Any] | None, - TransitionKey.COMPLEMENTARY_DATA.value: dict[str, Any] | None, - }, -) - - class ProcessorStepRegistry: """Registry for processor steps that enables saving/loading by name instead of module path.""" @@ -199,93 +174,6 @@ class ProcessorStep(ABC): return features -def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401 - """Convert a *batch* dict coming from Learobot replay/dataset code into an - ``EnvTransition`` dictionary. - - The function maps well known keys to the EnvTransition structure. Missing keys are - filled with sane defaults (``None`` or ``0.0``/``False``). - - Keys recognised (case-sensitive): - - * "observation.*" (keys starting with "observation." 
are grouped into observation dict) - * "action" - * "next.reward" - * "next.done" - * "next.truncated" - * "info" - - Additional keys are ignored so that existing dataloaders can carry extra - metadata without breaking the processor. - """ - - # Validate input type - if not isinstance(batch, dict): - raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}") - - # Extract observation keys - observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")} - observation = observation_keys if observation_keys else None - - # Extract padding, task, index, and task_index keys for complementary data - pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k} - task_key = {"task": batch["task"]} if "task" in batch else {} - index_key = {"index": batch["index"]} if "index" in batch else {} - task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {} - complementary_data = ( - {**pad_keys, **task_key, **index_key, **task_index_key} - if pad_keys or task_key or index_key or task_index_key - else {} - ) - - transition: EnvTransition = { - TransitionKey.OBSERVATION: observation, - TransitionKey.ACTION: batch.get("action"), - TransitionKey.REWARD: batch.get("next.reward", 0.0), - TransitionKey.DONE: batch.get("next.done", False), - TransitionKey.TRUNCATED: batch.get("next.truncated", False), - TransitionKey.INFO: batch.get("info", {}), - TransitionKey.COMPLEMENTARY_DATA: complementary_data, - } - return transition - - -def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]: # noqa: D401 - """Inverse of :pyfunc:`_default_batch_to_transition`. Returns a dict with - the canonical field names used throughout *LeRobot*. 
- """ - - batch = { - "action": transition.get(TransitionKey.ACTION), - "next.reward": transition.get(TransitionKey.REWARD, 0.0), - "next.done": transition.get(TransitionKey.DONE, False), - "next.truncated": transition.get(TransitionKey.TRUNCATED, False), - "info": transition.get(TransitionKey.INFO, {}), - } - - # Add padding, task, index, and task_index data from complementary_data - complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) - if complementary_data: - pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k} - batch.update(pad_data) - - if "task" in complementary_data: - batch["task"] = complementary_data["task"] - - if "index" in complementary_data: - batch["index"] = complementary_data["index"] - - if "task_index" in complementary_data: - batch["task_index"] = complementary_data["task_index"] - - # Handle observation - flatten dict to observation.* keys if it's a dict - observation = transition.get(TransitionKey.OBSERVATION) - if isinstance(observation, dict): - batch.update(observation) - - return batch - - class ProcessorKwargs(TypedDict, total=False): """Keyword arguments for RobotProcessor constructor.""" @@ -357,15 +245,13 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]): steps: Sequence[ProcessorStep] = field(default_factory=list) name: str = "RobotProcessor" - to_transition: Callable[[dict[str, Any]], EnvTransition] = field( - default_factory=lambda: _default_batch_to_transition, repr=False - ) + to_transition: Callable[[dict[str, Any]], EnvTransition] = field(default=batch_to_transition, repr=False) to_output: Callable[[EnvTransition], TOutput] = field( # Cast is necessary here: Working around Python type-checker limitation. # _default_transition_to_batch returns dict[str, Any], but we need it to be TOutput # for the generic to work. When no explicit type is given, TOutput defaults to dict[str, Any], # making this cast safe. 
- default_factory=lambda: cast(Callable[[EnvTransition], TOutput], _default_transition_to_batch), + default_factory=lambda: cast(Callable[[EnvTransition], TOutput], transition_to_batch), repr=False, ) @@ -767,11 +653,11 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]): return cls( steps=steps, name=loaded_config.get("name", "RobotProcessor"), - to_transition=to_transition or _default_batch_to_transition, + to_transition=to_transition or batch_to_transition, # Cast is necessary here: Same type-checker limitation as above. # When to_output is None, we use the default which returns dict[str, Any]. # The cast ensures type consistency with the generic TOutput parameter. - to_output=to_output or cast(Callable[[EnvTransition], TOutput], _default_transition_to_batch), + to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch), ) def __len__(self) -> int: diff --git a/src/lerobot/record.py b/src/lerobot/record.py index 0ebe23501..e83aa8e52 100644 --- a/src/lerobot/record.py +++ b/src/lerobot/record.py @@ -78,10 +78,10 @@ from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.processor import RobotProcessor from lerobot.processor.converters import ( - to_dataset_frame, to_output_robot_action, to_transition_robot_observation, to_transition_teleop_action, + transition_to_dataset_frame, ) from lerobot.processor.pipeline import IdentityProcessor, TransitionKey from lerobot.processor.rename_processor import rename_stats @@ -308,7 +308,7 @@ def record_loop( # Get action from either policy or teleop if policy is not None and preprocessor is not None and postprocessor is not None: if dataset is not None: - observation_frame = to_dataset_frame( + observation_frame = transition_to_dataset_frame( obs_transition, dataset.features ) # Convert the observation to the dataset format @@ -366,7 +366,7 @@ def record_loop( # Write to dataset if dataset is not None: - # 
If to_dataset_frame is provided, use it to merge the transitions. + # If transition_to_dataset_frame is provided, use it to merge the transitions. merged = [] if obs_transition is not None: # The observation from the robot merged.append(obs_transition) @@ -374,7 +374,7 @@ def record_loop( merged.append(teleop_transition) if policy_transition is not None: # The action from policy merged.append(policy_transition) - frame = to_dataset_frame( + frame = transition_to_dataset_frame( merged if len(merged) > 1 else merged[0], dataset.features ) # Convert the observation to the dataset format dataset.add_frame(frame, task=single_task) diff --git a/src/lerobot/scripts/rl/gym_manipulator.py b/src/lerobot/scripts/rl/gym_manipulator.py index 835b85190..206fc3a05 100644 --- a/src/lerobot/scripts/rl/gym_manipulator.py +++ b/src/lerobot/scripts/rl/gym_manipulator.py @@ -46,6 +46,7 @@ from lerobot.processor import ( ToBatchProcessor, Torch2NumpyActionProcessor, VanillaObservationProcessor, + create_transition, ) from lerobot.processor.pipeline import EnvTransition, TransitionKey from lerobot.robots import ( # noqa: F401 @@ -98,21 +99,6 @@ class GymManipulatorConfig: device: str = "cpu" -def create_transition( - observation=None, action=None, reward=0.0, done=False, truncated=False, info=None, complementary_data=None -) -> dict[str, Any]: - """Create an EnvTransition dictionary with default values.""" - return { - TransitionKey.OBSERVATION: observation, - TransitionKey.ACTION: action, - TransitionKey.REWARD: reward, - TransitionKey.DONE: done, - TransitionKey.TRUNCATED: truncated, - TransitionKey.INFO: info if info is not None else {}, - TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {}, - } - - def reset_follower_position(robot_arm: Robot, target_position: np.ndarray) -> None: """Reset robot arm to target position using smooth trajectory.""" current_position_dict = robot_arm.bus.sync_read("Present_Position") diff --git 
a/tests/processor/test_batch_conversion.py b/tests/processor/test_batch_conversion.py index 63894025d..eb09383a8 100644 --- a/tests/processor/test_batch_conversion.py +++ b/tests/processor/test_batch_conversion.py @@ -1,11 +1,7 @@ import torch -from lerobot.processor.pipeline import ( - RobotProcessor, - TransitionKey, - _default_batch_to_transition, - _default_transition_to_batch, -) +from lerobot.processor.converters import batch_to_transition, transition_to_batch +from lerobot.processor.pipeline import RobotProcessor, TransitionKey def _dummy_batch(): @@ -48,7 +44,7 @@ def test_observation_grouping_roundtrip(): def test_batch_to_transition_observation_grouping(): - """Test that _default_batch_to_transition correctly groups observation.* keys.""" + """Test that batch_to_transition correctly groups observation.* keys.""" batch = { "observation.image.top": torch.randn(1, 3, 128, 128), "observation.image.left": torch.randn(1, 3, 128, 128), @@ -60,7 +56,7 @@ def test_batch_to_transition_observation_grouping(): "info": {"episode": 42}, } - transition = _default_batch_to_transition(batch) + transition = batch_to_transition(batch) # Check observation is a dict with all observation.* keys assert isinstance(transition[TransitionKey.OBSERVATION], dict) @@ -87,7 +83,7 @@ def test_batch_to_transition_observation_grouping(): def test_transition_to_batch_observation_flattening(): - """Test that _default_transition_to_batch correctly flattens observation dict.""" + """Test that transition_to_batch correctly flattens observation dict.""" observation_dict = { "observation.image.top": torch.randn(1, 3, 128, 128), "observation.image.left": torch.randn(1, 3, 128, 128), @@ -104,7 +100,7 @@ def test_transition_to_batch_observation_flattening(): TransitionKey.COMPLEMENTARY_DATA: {}, } - batch = _default_transition_to_batch(transition) + batch = transition_to_batch(transition) # Check that observation.* keys are flattened back to batch assert "observation.image.top" in batch @@ -134,7 
+130,7 @@ def test_no_observation_keys(): "info": {"test": "no_obs"}, } - transition = _default_batch_to_transition(batch) + transition = batch_to_transition(batch) # Observation should be None when no observation.* keys assert transition[TransitionKey.OBSERVATION] is None @@ -147,7 +143,7 @@ def test_no_observation_keys(): assert transition[TransitionKey.INFO] == {"test": "no_obs"} # Round trip should work - reconstructed_batch = _default_transition_to_batch(transition) + reconstructed_batch = transition_to_batch(transition) assert reconstructed_batch["action"] == "action_data" assert reconstructed_batch["next.reward"] == 2.0 assert not reconstructed_batch["next.done"] @@ -159,7 +155,7 @@ def test_minimal_batch(): """Test with minimal batch containing only observation.* and action.""" batch = {"observation.state": "minimal_state", "action": "minimal_action"} - transition = _default_batch_to_transition(batch) + transition = batch_to_transition(batch) # Check observation assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"} @@ -173,7 +169,7 @@ def test_minimal_batch(): assert transition[TransitionKey.COMPLEMENTARY_DATA] == {} # Round trip - reconstructed_batch = _default_transition_to_batch(transition) + reconstructed_batch = transition_to_batch(transition) assert reconstructed_batch["observation.state"] == "minimal_state" assert reconstructed_batch["action"] == "minimal_action" assert reconstructed_batch["next.reward"] == 0.0 @@ -186,7 +182,7 @@ def test_empty_batch(): """Test behavior with empty batch.""" batch = {} - transition = _default_batch_to_transition(batch) + transition = batch_to_transition(batch) # All fields should have defaults assert transition[TransitionKey.OBSERVATION] is None @@ -198,7 +194,7 @@ def test_empty_batch(): assert transition[TransitionKey.COMPLEMENTARY_DATA] == {} # Round trip - reconstructed_batch = _default_transition_to_batch(transition) + reconstructed_batch = transition_to_batch(transition) assert 
reconstructed_batch["action"] is None assert reconstructed_batch["next.reward"] == 0.0 assert not reconstructed_batch["next.done"] @@ -219,8 +215,8 @@ def test_complex_nested_observation(): "info": {"episode_length": 200, "success": True}, } - transition = _default_batch_to_transition(batch) - reconstructed_batch = _default_transition_to_batch(transition) + transition = batch_to_transition(batch) + reconstructed_batch = transition_to_batch(transition) # Check that all observation keys are preserved original_obs_keys = {k for k in batch if k.startswith("observation.")} @@ -254,7 +250,7 @@ def test_custom_converter(): def to_tr(batch): # Custom converter that modifies the reward - tr = _default_batch_to_transition(batch) + tr = batch_to_transition(batch) # Double the reward reward = tr.get(TransitionKey.REWARD, 0.0) new_tr = tr.copy() @@ -262,7 +258,7 @@ def test_custom_converter(): return new_tr def to_batch(tr): - batch = _default_transition_to_batch(tr) + batch = transition_to_batch(tr) return batch processor = RobotProcessor(steps=[], to_transition=to_tr, to_output=to_batch) diff --git a/tests/processor/test_converters.py b/tests/processor/test_converters.py index ac2015b48..a7b5a7b8f 100644 --- a/tests/processor/test_converters.py +++ b/tests/processor/test_converters.py @@ -3,11 +3,13 @@ import pytest import torch from lerobot.processor.converters import ( - to_dataset_frame, + batch_to_transition, to_output_robot_action, to_tensor, to_transition_robot_observation, to_transition_teleop_action, + transition_to_batch, + transition_to_dataset_frame, ) from lerobot.processor.pipeline import TransitionKey @@ -107,7 +109,7 @@ def test_to_output_robot_action_strips_prefix_and_filters_pos_keys_only(): assert out["gripper.pos"] == pytest.approx(33.0) -def test_to_dataset_frame_merge_and_pack_vectors_and_metadata(): +def test_transition_to_dataset_frame_merge_and_pack_vectors_and_metadata(): # Fabricate dataset features (as stored in dataset.meta["features"]) features = 
{ # Action vector: 3 elements in specific order @@ -160,7 +162,7 @@ def test_to_dataset_frame_merge_and_pack_vectors_and_metadata(): } # Directly call the refactored function - batch = to_dataset_frame([teleop_transition, robot_transition], features) + batch = transition_to_dataset_frame([teleop_transition, robot_transition], features) # Images passthrough assert "observation.images.front" in batch @@ -377,3 +379,117 @@ def test_to_tensor_unsupported_type(): with pytest.raises(TypeError, match="Unsupported type for tensor conversion"): to_tensor(object()) + + +def create_transition( + observation=None, action=None, reward=0.0, done=False, truncated=False, info=None, complementary_data=None +): + """Helper to create an EnvTransition dictionary.""" + return { + TransitionKey.OBSERVATION: observation, + TransitionKey.ACTION: action, + TransitionKey.REWARD: reward, + TransitionKey.DONE: done, + TransitionKey.TRUNCATED: truncated, + TransitionKey.INFO: info if info is not None else {}, + TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {}, + } + + +def test_batch_to_transition_with_index_fields(): + """Test that batch_to_transition handles index and task_index fields correctly.""" + + # Create batch with index and task_index fields + batch = { + "observation.state": torch.randn(1, 7), + "action": torch.randn(1, 4), + "next.reward": 1.5, + "next.done": False, + "task": ["pick_cube"], + "index": torch.tensor([42], dtype=torch.int64), + "task_index": torch.tensor([3], dtype=torch.int64), + } + + transition = batch_to_transition(batch) + + # Check basic transition structure + assert TransitionKey.OBSERVATION in transition + assert TransitionKey.ACTION in transition + assert TransitionKey.COMPLEMENTARY_DATA in transition + + # Check that index and task_index are in complementary_data + comp_data = transition[TransitionKey.COMPLEMENTARY_DATA] + assert "index" in comp_data + assert "task_index" in comp_data + assert "task" in 
comp_data + + # Verify values + assert torch.equal(comp_data["index"], batch["index"]) + assert torch.equal(comp_data["task_index"], batch["task_index"]) + assert comp_data["task"] == batch["task"] + + +def test_transition_to_batch_with_index_fields(): + """Test that transition_to_batch handles index and task_index fields correctly.""" + + # Create transition with index and task_index in complementary_data + transition = create_transition( + observation={"observation.state": torch.randn(1, 7)}, + action=torch.randn(1, 4), + reward=1.5, + done=False, + complementary_data={ + "task": ["navigate"], + "index": torch.tensor([100], dtype=torch.int64), + "task_index": torch.tensor([5], dtype=torch.int64), + }, + ) + + batch = transition_to_batch(transition) + + # Check that index and task_index are in the batch + assert "index" in batch + assert "task_index" in batch + assert "task" in batch + + # Verify values + assert torch.equal(batch["index"], transition[TransitionKey.COMPLEMENTARY_DATA]["index"]) + assert torch.equal(batch["task_index"], transition[TransitionKey.COMPLEMENTARY_DATA]["task_index"]) + assert batch["task"] == transition[TransitionKey.COMPLEMENTARY_DATA]["task"] + + +def test_batch_to_transition_without_index_fields(): + """Test that conversion works without index and task_index fields.""" + + # Batch without index/task_index + batch = { + "observation.state": torch.randn(1, 7), + "action": torch.randn(1, 4), + "task": ["pick_cube"], + } + + transition = batch_to_transition(batch) + comp_data = transition[TransitionKey.COMPLEMENTARY_DATA] + + # Should have task but not index/task_index + assert "task" in comp_data + assert "index" not in comp_data + assert "task_index" not in comp_data + + +def test_transition_to_batch_without_index_fields(): + """Test that conversion works without index and task_index fields.""" + + # Transition without index/task_index + transition = create_transition( + observation={"observation.state": torch.randn(1, 7)}, + 
action=torch.randn(1, 4), + complementary_data={"task": ["navigate"]}, + ) + + batch = transition_to_batch(transition) + + # Should have task but not index/task_index + assert "task" in batch + assert "index" not in batch + assert "task_index" not in batch diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 57db5e727..4f5813024 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -1659,109 +1659,6 @@ def test_state_file_naming_with_multiple_processors(): assert loaded_post.steps[0].window_size == 10 -def test_default_batch_to_transition_with_index_fields(): - """Test that _default_batch_to_transition handles index and task_index fields correctly.""" - from lerobot.processor.pipeline import _default_batch_to_transition - - # Create batch with index and task_index fields - batch = { - "observation.state": torch.randn(1, 7), - "action": torch.randn(1, 4), - "next.reward": 1.5, - "next.done": False, - "task": ["pick_cube"], - "index": torch.tensor([42], dtype=torch.int64), - "task_index": torch.tensor([3], dtype=torch.int64), - } - - transition = _default_batch_to_transition(batch) - - # Check basic transition structure - assert TransitionKey.OBSERVATION in transition - assert TransitionKey.ACTION in transition - assert TransitionKey.COMPLEMENTARY_DATA in transition - - # Check that index and task_index are in complementary_data - comp_data = transition[TransitionKey.COMPLEMENTARY_DATA] - assert "index" in comp_data - assert "task_index" in comp_data - assert "task" in comp_data - - # Verify values - assert torch.equal(comp_data["index"], batch["index"]) - assert torch.equal(comp_data["task_index"], batch["task_index"]) - assert comp_data["task"] == batch["task"] - - -def test_default_transition_to_batch_with_index_fields(): - """Test that _default_transition_to_batch handles index and task_index fields correctly.""" - from lerobot.processor.pipeline import _default_transition_to_batch - - # Create 
transition with index and task_index in complementary_data - transition = create_transition( - observation={"observation.state": torch.randn(1, 7)}, - action=torch.randn(1, 4), - reward=1.5, - done=False, - complementary_data={ - "task": ["navigate"], - "index": torch.tensor([100], dtype=torch.int64), - "task_index": torch.tensor([5], dtype=torch.int64), - }, - ) - - batch = _default_transition_to_batch(transition) - - # Check that index and task_index are in the batch - assert "index" in batch - assert "task_index" in batch - assert "task" in batch - - # Verify values - assert torch.equal(batch["index"], transition[TransitionKey.COMPLEMENTARY_DATA]["index"]) - assert torch.equal(batch["task_index"], transition[TransitionKey.COMPLEMENTARY_DATA]["task_index"]) - assert batch["task"] == transition[TransitionKey.COMPLEMENTARY_DATA]["task"] - - -def test_batch_to_transition_without_index_fields(): - """Test that conversion works without index and task_index fields.""" - from lerobot.processor.pipeline import _default_batch_to_transition - - # Batch without index/task_index - batch = { - "observation.state": torch.randn(1, 7), - "action": torch.randn(1, 4), - "task": ["pick_cube"], - } - - transition = _default_batch_to_transition(batch) - comp_data = transition[TransitionKey.COMPLEMENTARY_DATA] - - # Should have task but not index/task_index - assert "task" in comp_data - assert "index" not in comp_data - assert "task_index" not in comp_data - - -def test_transition_to_batch_without_index_fields(): - """Test that conversion works without index and task_index fields.""" - from lerobot.processor.pipeline import _default_transition_to_batch - - # Transition without index/task_index - transition = create_transition( - observation={"observation.state": torch.randn(1, 7)}, - action=torch.randn(1, 4), - complementary_data={"task": ["navigate"]}, - ) - - batch = _default_transition_to_batch(transition) - - # Should have task but not index/task_index - assert "task" in batch - 
assert "index" not in batch - assert "task_index" not in batch - - def test_override_with_device_strings(): """Test overriding device parameters with string values."""