diff --git a/src/lerobot/policies/pi0/processor_pi0.py b/src/lerobot/policies/pi0/processor_pi0.py index e288c9e43..3c2b18635 100644 --- a/src/lerobot/policies/pi0/processor_pi0.py +++ b/src/lerobot/policies/pi0/processor_pi0.py @@ -17,6 +17,7 @@ import torch +from lerobot.configs.types import PolicyFeature from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.processor import ( @@ -64,6 +65,9 @@ class Pi0NewLineProcessor(ComplementaryDataProcessor): return new_complementary_data + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + def make_pi0_pre_post_processors( config: PI0Config, diff --git a/src/lerobot/policies/smolvla/processor_smolvla.py b/src/lerobot/policies/smolvla/processor_smolvla.py index 0f27bdfa6..0b535e238 100644 --- a/src/lerobot/policies/smolvla/processor_smolvla.py +++ b/src/lerobot/policies/smolvla/processor_smolvla.py @@ -16,6 +16,7 @@ import torch +from lerobot.configs.types import PolicyFeature from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.processor import ( @@ -107,3 +108,6 @@ class SmolVLANewLineProcessor(ComplementaryDataProcessor): # If task is neither string nor list of strings, leave unchanged return new_complementary_data + + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features diff --git a/src/lerobot/processor/batch_processor.py b/src/lerobot/processor/batch_processor.py index aab575ef7..cfa57bb26 100644 --- a/src/lerobot/processor/batch_processor.py +++ b/src/lerobot/processor/batch_processor.py @@ -15,6 +15,7 @@ from dataclasses import dataclass, field from torch import Tensor +from lerobot.configs.types import PolicyFeature from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from lerobot.processor.pipeline import ( ActionProcessor, @@ -37,6 +38,9 @@ class ToBatchProcessorAction(ActionProcessor): return action.unsqueeze(0) + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + @dataclass @ProcessorStepRegistry.register(name="to_batch_processor_observation") @@ -63,6 +67,9 @@ class ToBatchProcessorObservation(ObservationProcessor): observation[key] = value.unsqueeze(0) return observation + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + @dataclass @ProcessorStepRegistry.register(name="to_batch_processor_complementary_data") @@ -89,6 +96,9 @@ class ToBatchProcessorComplementaryData(ComplementaryDataProcessor): complementary_data["task_index"] = task_index_value.unsqueeze(0) return complementary_data + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + @dataclass @ProcessorStepRegistry.register(name="to_batch_processor") @@ -140,3 +150,6 @@ class ToBatchProcessor(ProcessorStep): transition = self.to_batch_observation_processor(transition) transition = self.to_batch_complementary_data_processor(transition) return transition + + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index 550bb470d..2dec92dda 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -295,7 +295,8 @@ def merge_transitions(transitions: Sequence[EnvTransition] | EnvTransition) -> E Returns: Merged EnvTransition. """ - if isinstance(transitions, EnvTransition): # Single transition + + if not isinstance(transitions, Sequence): # Single transition return transitions items = list(transitions) diff --git a/src/lerobot/processor/delta_action_processor.py b/src/lerobot/processor/delta_action_processor.py index 63eff9aad..38964f1fa 100644 --- a/src/lerobot/processor/delta_action_processor.py +++ b/src/lerobot/processor/delta_action_processor.py @@ -45,6 +45,9 @@ class MapTensorToDeltaActionDict(ActionProcessor): delta_action["action.gripper"] = action[3] return delta_action + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + @ProcessorStepRegistry.register("map_delta_action_to_robot_action") @dataclass diff --git a/src/lerobot/processor/device_processor.py b/src/lerobot/processor/device_processor.py index c099d050a..3f68d9d03 100644 --- a/src/lerobot/processor/device_processor.py +++ b/src/lerobot/processor/device_processor.py @@ -18,6 +18,7 @@ from typing import Any import torch +from lerobot.configs.types import PolicyFeature from lerobot.processor.core import EnvTransition, TransitionKey from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry from lerobot.utils.utils import get_safe_torch_device @@ -127,3 +128,6 @@ class DeviceProcessor(ProcessorStep): def get_config(self) -> dict[str, Any]: """Return configuration for serialization.""" return {"device": self.device, "float_dtype": self.float_dtype} + + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features diff --git a/src/lerobot/processor/gym_action_processor.py b/src/lerobot/processor/gym_action_processor.py index 41b320370..54142ca9f 100644 --- a/src/lerobot/processor/gym_action_processor.py +++ b/src/lerobot/processor/gym_action_processor.py @@ -16,6 +16,7 @@ from dataclasses import dataclass import numpy as np import torch +from lerobot.configs.types import PolicyFeature from lerobot.processor.converters import to_tensor from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry @@ -48,6 +49,9 @@ class Torch2NumpyActionProcessor(ActionProcessor): return numpy_action + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + @ProcessorStepRegistry.register("numpy2torch_action_processor") @dataclass @@ -62,3 +66,6 @@ class Numpy2TorchActionProcessor(ActionProcessor): ) torch_action = to_tensor(action, dtype=None) # Preserve original dtype return torch_action + + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py index c75e40fff..b271c560a 100644 --- a/src/lerobot/processor/hil_processor.py +++ b/src/lerobot/processor/hil_processor.py @@ -39,6 +39,9 @@ class AddTeleopActionAsComplimentaryData(ComplementaryDataProcessor): new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action() return new_complementary_data + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + @ProcessorStepRegistry.register("add_teleop_action_as_info") @dataclass @@ -53,6 +56,9 @@ class AddTeleopEventsAsInfo(InfoProcessor): new_info.update(teleop_events) return new_info + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + @ProcessorStepRegistry.register("image_crop_resize_processor") @dataclass @@ -127,6 +133,9 @@ class TimeLimitProcessor(TruncatedProcessor): def reset(self) -> None: self.current_step = 0 + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + @dataclass @ProcessorStepRegistry.register("gripper_penalty_processor") @@ -173,6 +182,9 @@ class GripperPenaltyProcessor(ComplementaryDataProcessor): """Reset the processor state.""" self.last_gripper_state = None + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + @dataclass @ProcessorStepRegistry.register("intervention_action_processor") @@ -243,6 +255,9 @@ class InterventionActionProcessor(ProcessorStep): "terminate_on_success": self.terminate_on_success, } + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + @dataclass @ProcessorStepRegistry.register("reward_classifier_processor") @@ -312,3 +327,6 @@ class RewardClassifierProcessor(ProcessorStep): "success_reward": self.success_reward, "terminate_on_success": self.terminate_on_success, } + + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index b88d8b5af..4b6949b2d 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -211,6 +211,9 @@ class NormalizerProcessor(_NormalizationMixin, ProcessorStep): return new_transition + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + @dataclass @ProcessorStepRegistry.register(name="unnormalizer_processor") @@ -249,6 +252,9 @@ class UnnormalizerProcessor(_NormalizationMixin, ProcessorStep): return new_transition + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, Any]]) -> RobotProcessor: """ diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 7054ba439..abe64599d 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -169,7 +169,7 @@ class ProcessorStep(ABC): def reset(self) -> None: return None - # TODO(Steven): Consider making this abstract so it is more explicit + @abstractmethod def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features @@ -1091,3 +1091,6 @@ class IdentityProcessor(ProcessorStep): def __call__(self, transition: EnvTransition) -> EnvTransition: return transition + + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features diff --git a/src/lerobot/robots/so100_follower/robot_kinematic_processor.py b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py index 39bab604f..62cffc4d2 100644 --- a/src/lerobot/robots/so100_follower/robot_kinematic_processor.py +++ b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py @@ -220,6 +220,9 @@ class EEBoundsAndSafety(ActionProcessor): self._last_pos = None self._last_twist = None + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + @ProcessorStepRegistry.register("inverse_kinematics_ee_to_joints") @dataclass @@ -444,3 +447,6 @@ class AddRobotObservationAsComplimentaryData(ComplementaryDataProcessor): if isinstance(k, str) and k.endswith(".pos") } return new_comp + + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features