refactor(processors): add transform_features method to various processors (#1843)

This commit is contained in:
Steven Palma
2025-09-02 17:15:01 +02:00
committed by GitHub
parent 645c87e3a9
commit 2914ae2a96
11 changed files with 71 additions and 2 deletions
@@ -17,6 +17,7 @@
import torch import torch
from lerobot.configs.types import PolicyFeature
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import ( from lerobot.processor import (
@@ -64,6 +65,9 @@ class Pi0NewLineProcessor(ComplementaryDataProcessor):
return new_complementary_data return new_complementary_data
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
def make_pi0_pre_post_processors( def make_pi0_pre_post_processors(
config: PI0Config, config: PI0Config,
@@ -16,6 +16,7 @@
import torch import torch
from lerobot.configs.types import PolicyFeature
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.processor import ( from lerobot.processor import (
@@ -107,3 +108,6 @@ class SmolVLANewLineProcessor(ComplementaryDataProcessor):
# If task is neither string nor list of strings, leave unchanged # If task is neither string nor list of strings, leave unchanged
return new_complementary_data return new_complementary_data
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
+13
View File
@@ -15,6 +15,7 @@ from dataclasses import dataclass, field
from torch import Tensor from torch import Tensor
from lerobot.configs.types import PolicyFeature
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor.pipeline import ( from lerobot.processor.pipeline import (
ActionProcessor, ActionProcessor,
@@ -37,6 +38,9 @@ class ToBatchProcessorAction(ActionProcessor):
return action.unsqueeze(0) return action.unsqueeze(0)
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass @dataclass
@ProcessorStepRegistry.register(name="to_batch_processor_observation") @ProcessorStepRegistry.register(name="to_batch_processor_observation")
@@ -63,6 +67,9 @@ class ToBatchProcessorObservation(ObservationProcessor):
observation[key] = value.unsqueeze(0) observation[key] = value.unsqueeze(0)
return observation return observation
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass @dataclass
@ProcessorStepRegistry.register(name="to_batch_processor_complementary_data") @ProcessorStepRegistry.register(name="to_batch_processor_complementary_data")
@@ -89,6 +96,9 @@ class ToBatchProcessorComplementaryData(ComplementaryDataProcessor):
complementary_data["task_index"] = task_index_value.unsqueeze(0) complementary_data["task_index"] = task_index_value.unsqueeze(0)
return complementary_data return complementary_data
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass @dataclass
@ProcessorStepRegistry.register(name="to_batch_processor") @ProcessorStepRegistry.register(name="to_batch_processor")
@@ -140,3 +150,6 @@ class ToBatchProcessor(ProcessorStep):
transition = self.to_batch_observation_processor(transition) transition = self.to_batch_observation_processor(transition)
transition = self.to_batch_complementary_data_processor(transition) transition = self.to_batch_complementary_data_processor(transition)
return transition return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
+2 -1
View File
@@ -295,7 +295,8 @@ def merge_transitions(transitions: Sequence[EnvTransition] | EnvTransition) -> E
Returns: Returns:
Merged EnvTransition. Merged EnvTransition.
""" """
if isinstance(transitions, EnvTransition): # Single transition
if not isinstance(transitions, Sequence): # Single transition
return transitions return transitions
items = list(transitions) items = list(transitions)
@@ -45,6 +45,9 @@ class MapTensorToDeltaActionDict(ActionProcessor):
delta_action["action.gripper"] = action[3] delta_action["action.gripper"] = action[3]
return delta_action return delta_action
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@ProcessorStepRegistry.register("map_delta_action_to_robot_action") @ProcessorStepRegistry.register("map_delta_action_to_robot_action")
@dataclass @dataclass
@@ -18,6 +18,7 @@ from typing import Any
import torch import torch
from lerobot.configs.types import PolicyFeature
from lerobot.processor.core import EnvTransition, TransitionKey from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry
from lerobot.utils.utils import get_safe_torch_device from lerobot.utils.utils import get_safe_torch_device
@@ -127,3 +128,6 @@ class DeviceProcessor(ProcessorStep):
def get_config(self) -> dict[str, Any]: def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization.""" """Return configuration for serialization."""
return {"device": self.device, "float_dtype": self.float_dtype} return {"device": self.device, "float_dtype": self.float_dtype}
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@@ -16,6 +16,7 @@ from dataclasses import dataclass
import numpy as np import numpy as np
import torch import torch
from lerobot.configs.types import PolicyFeature
from lerobot.processor.converters import to_tensor from lerobot.processor.converters import to_tensor
from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry
@@ -48,6 +49,9 @@ class Torch2NumpyActionProcessor(ActionProcessor):
return numpy_action return numpy_action
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@ProcessorStepRegistry.register("numpy2torch_action_processor") @ProcessorStepRegistry.register("numpy2torch_action_processor")
@dataclass @dataclass
@@ -62,3 +66,6 @@ class Numpy2TorchActionProcessor(ActionProcessor):
) )
torch_action = to_tensor(action, dtype=None) # Preserve original dtype torch_action = to_tensor(action, dtype=None) # Preserve original dtype
return torch_action return torch_action
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
+18
View File
@@ -39,6 +39,9 @@ class AddTeleopActionAsComplimentaryData(ComplementaryDataProcessor):
new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action() new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action()
return new_complementary_data return new_complementary_data
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@ProcessorStepRegistry.register("add_teleop_action_as_info") @ProcessorStepRegistry.register("add_teleop_action_as_info")
@dataclass @dataclass
@@ -53,6 +56,9 @@ class AddTeleopEventsAsInfo(InfoProcessor):
new_info.update(teleop_events) new_info.update(teleop_events)
return new_info return new_info
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@ProcessorStepRegistry.register("image_crop_resize_processor") @ProcessorStepRegistry.register("image_crop_resize_processor")
@dataclass @dataclass
@@ -127,6 +133,9 @@ class TimeLimitProcessor(TruncatedProcessor):
def reset(self) -> None: def reset(self) -> None:
self.current_step = 0 self.current_step = 0
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass @dataclass
@ProcessorStepRegistry.register("gripper_penalty_processor") @ProcessorStepRegistry.register("gripper_penalty_processor")
@@ -173,6 +182,9 @@ class GripperPenaltyProcessor(ComplementaryDataProcessor):
"""Reset the processor state.""" """Reset the processor state."""
self.last_gripper_state = None self.last_gripper_state = None
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass @dataclass
@ProcessorStepRegistry.register("intervention_action_processor") @ProcessorStepRegistry.register("intervention_action_processor")
@@ -243,6 +255,9 @@ class InterventionActionProcessor(ProcessorStep):
"terminate_on_success": self.terminate_on_success, "terminate_on_success": self.terminate_on_success,
} }
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass @dataclass
@ProcessorStepRegistry.register("reward_classifier_processor") @ProcessorStepRegistry.register("reward_classifier_processor")
@@ -312,3 +327,6 @@ class RewardClassifierProcessor(ProcessorStep):
"success_reward": self.success_reward, "success_reward": self.success_reward,
"terminate_on_success": self.terminate_on_success, "terminate_on_success": self.terminate_on_success,
} }
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@@ -211,6 +211,9 @@ class NormalizerProcessor(_NormalizationMixin, ProcessorStep):
return new_transition return new_transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass @dataclass
@ProcessorStepRegistry.register(name="unnormalizer_processor") @ProcessorStepRegistry.register(name="unnormalizer_processor")
@@ -249,6 +252,9 @@ class UnnormalizerProcessor(_NormalizationMixin, ProcessorStep):
return new_transition return new_transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, Any]]) -> RobotProcessor: def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, Any]]) -> RobotProcessor:
""" """
+4 -1
View File
@@ -169,7 +169,7 @@ class ProcessorStep(ABC):
def reset(self) -> None: def reset(self) -> None:
return None return None
# TODO(Steven): Consider making this abstract so it is more explicit @abstractmethod
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features return features
@@ -1091,3 +1091,6 @@ class IdentityProcessor(ProcessorStep):
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
return transition return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@@ -220,6 +220,9 @@ class EEBoundsAndSafety(ActionProcessor):
self._last_pos = None self._last_pos = None
self._last_twist = None self._last_twist = None
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@ProcessorStepRegistry.register("inverse_kinematics_ee_to_joints") @ProcessorStepRegistry.register("inverse_kinematics_ee_to_joints")
@dataclass @dataclass
@@ -444,3 +447,6 @@ class AddRobotObservationAsComplimentaryData(ComplementaryDataProcessor):
if isinstance(k, str) and k.endswith(".pos") if isinstance(k, str) and k.endswith(".pos")
} }
return new_comp return new_comp
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features