mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 15:09:51 +00:00
refactor(processors): add transform_features method to various processors (#1843)
This commit is contained in:
@@ -17,6 +17,7 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from lerobot.configs.types import PolicyFeature
|
||||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
@@ -64,6 +65,9 @@ class Pi0NewLineProcessor(ComplementaryDataProcessor):
|
|||||||
|
|
||||||
return new_complementary_data
|
return new_complementary_data
|
||||||
|
|
||||||
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
def make_pi0_pre_post_processors(
|
def make_pi0_pre_post_processors(
|
||||||
config: PI0Config,
|
config: PI0Config,
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from lerobot.configs.types import PolicyFeature
|
||||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
@@ -107,3 +108,6 @@ class SmolVLANewLineProcessor(ComplementaryDataProcessor):
|
|||||||
# If task is neither string nor list of strings, leave unchanged
|
# If task is neither string nor list of strings, leave unchanged
|
||||||
|
|
||||||
return new_complementary_data
|
return new_complementary_data
|
||||||
|
|
||||||
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from dataclasses import dataclass, field
|
|||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
from lerobot.configs.types import PolicyFeature
|
||||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||||
from lerobot.processor.pipeline import (
|
from lerobot.processor.pipeline import (
|
||||||
ActionProcessor,
|
ActionProcessor,
|
||||||
@@ -37,6 +38,9 @@ class ToBatchProcessorAction(ActionProcessor):
|
|||||||
|
|
||||||
return action.unsqueeze(0)
|
return action.unsqueeze(0)
|
||||||
|
|
||||||
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ProcessorStepRegistry.register(name="to_batch_processor_observation")
|
@ProcessorStepRegistry.register(name="to_batch_processor_observation")
|
||||||
@@ -63,6 +67,9 @@ class ToBatchProcessorObservation(ObservationProcessor):
|
|||||||
observation[key] = value.unsqueeze(0)
|
observation[key] = value.unsqueeze(0)
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ProcessorStepRegistry.register(name="to_batch_processor_complementary_data")
|
@ProcessorStepRegistry.register(name="to_batch_processor_complementary_data")
|
||||||
@@ -89,6 +96,9 @@ class ToBatchProcessorComplementaryData(ComplementaryDataProcessor):
|
|||||||
complementary_data["task_index"] = task_index_value.unsqueeze(0)
|
complementary_data["task_index"] = task_index_value.unsqueeze(0)
|
||||||
return complementary_data
|
return complementary_data
|
||||||
|
|
||||||
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ProcessorStepRegistry.register(name="to_batch_processor")
|
@ProcessorStepRegistry.register(name="to_batch_processor")
|
||||||
@@ -140,3 +150,6 @@ class ToBatchProcessor(ProcessorStep):
|
|||||||
transition = self.to_batch_observation_processor(transition)
|
transition = self.to_batch_observation_processor(transition)
|
||||||
transition = self.to_batch_complementary_data_processor(transition)
|
transition = self.to_batch_complementary_data_processor(transition)
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|||||||
@@ -295,7 +295,8 @@ def merge_transitions(transitions: Sequence[EnvTransition] | EnvTransition) -> E
|
|||||||
Returns:
|
Returns:
|
||||||
Merged EnvTransition.
|
Merged EnvTransition.
|
||||||
"""
|
"""
|
||||||
if isinstance(transitions, EnvTransition): # Single transition
|
|
||||||
|
if not isinstance(transitions, Sequence): # Single transition
|
||||||
return transitions
|
return transitions
|
||||||
|
|
||||||
items = list(transitions)
|
items = list(transitions)
|
||||||
|
|||||||
@@ -45,6 +45,9 @@ class MapTensorToDeltaActionDict(ActionProcessor):
|
|||||||
delta_action["action.gripper"] = action[3]
|
delta_action["action.gripper"] = action[3]
|
||||||
return delta_action
|
return delta_action
|
||||||
|
|
||||||
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
@ProcessorStepRegistry.register("map_delta_action_to_robot_action")
|
@ProcessorStepRegistry.register("map_delta_action_to_robot_action")
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from typing import Any
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from lerobot.configs.types import PolicyFeature
|
||||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||||
from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry
|
from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry
|
||||||
from lerobot.utils.utils import get_safe_torch_device
|
from lerobot.utils.utils import get_safe_torch_device
|
||||||
@@ -127,3 +128,6 @@ class DeviceProcessor(ProcessorStep):
|
|||||||
def get_config(self) -> dict[str, Any]:
|
def get_config(self) -> dict[str, Any]:
|
||||||
"""Return configuration for serialization."""
|
"""Return configuration for serialization."""
|
||||||
return {"device": self.device, "float_dtype": self.float_dtype}
|
return {"device": self.device, "float_dtype": self.float_dtype}
|
||||||
|
|
||||||
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from dataclasses import dataclass
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from lerobot.configs.types import PolicyFeature
|
||||||
from lerobot.processor.converters import to_tensor
|
from lerobot.processor.converters import to_tensor
|
||||||
from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry
|
from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry
|
||||||
|
|
||||||
@@ -48,6 +49,9 @@ class Torch2NumpyActionProcessor(ActionProcessor):
|
|||||||
|
|
||||||
return numpy_action
|
return numpy_action
|
||||||
|
|
||||||
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
@ProcessorStepRegistry.register("numpy2torch_action_processor")
|
@ProcessorStepRegistry.register("numpy2torch_action_processor")
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -62,3 +66,6 @@ class Numpy2TorchActionProcessor(ActionProcessor):
|
|||||||
)
|
)
|
||||||
torch_action = to_tensor(action, dtype=None) # Preserve original dtype
|
torch_action = to_tensor(action, dtype=None) # Preserve original dtype
|
||||||
return torch_action
|
return torch_action
|
||||||
|
|
||||||
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|||||||
@@ -39,6 +39,9 @@ class AddTeleopActionAsComplimentaryData(ComplementaryDataProcessor):
|
|||||||
new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action()
|
new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action()
|
||||||
return new_complementary_data
|
return new_complementary_data
|
||||||
|
|
||||||
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
@ProcessorStepRegistry.register("add_teleop_action_as_info")
|
@ProcessorStepRegistry.register("add_teleop_action_as_info")
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -53,6 +56,9 @@ class AddTeleopEventsAsInfo(InfoProcessor):
|
|||||||
new_info.update(teleop_events)
|
new_info.update(teleop_events)
|
||||||
return new_info
|
return new_info
|
||||||
|
|
||||||
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
@ProcessorStepRegistry.register("image_crop_resize_processor")
|
@ProcessorStepRegistry.register("image_crop_resize_processor")
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -127,6 +133,9 @@ class TimeLimitProcessor(TruncatedProcessor):
|
|||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
self.current_step = 0
|
self.current_step = 0
|
||||||
|
|
||||||
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||||
@@ -173,6 +182,9 @@ class GripperPenaltyProcessor(ComplementaryDataProcessor):
|
|||||||
"""Reset the processor state."""
|
"""Reset the processor state."""
|
||||||
self.last_gripper_state = None
|
self.last_gripper_state = None
|
||||||
|
|
||||||
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ProcessorStepRegistry.register("intervention_action_processor")
|
@ProcessorStepRegistry.register("intervention_action_processor")
|
||||||
@@ -243,6 +255,9 @@ class InterventionActionProcessor(ProcessorStep):
|
|||||||
"terminate_on_success": self.terminate_on_success,
|
"terminate_on_success": self.terminate_on_success,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ProcessorStepRegistry.register("reward_classifier_processor")
|
@ProcessorStepRegistry.register("reward_classifier_processor")
|
||||||
@@ -312,3 +327,6 @@ class RewardClassifierProcessor(ProcessorStep):
|
|||||||
"success_reward": self.success_reward,
|
"success_reward": self.success_reward,
|
||||||
"terminate_on_success": self.terminate_on_success,
|
"terminate_on_success": self.terminate_on_success,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|||||||
@@ -211,6 +211,9 @@ class NormalizerProcessor(_NormalizationMixin, ProcessorStep):
|
|||||||
|
|
||||||
return new_transition
|
return new_transition
|
||||||
|
|
||||||
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ProcessorStepRegistry.register(name="unnormalizer_processor")
|
@ProcessorStepRegistry.register(name="unnormalizer_processor")
|
||||||
@@ -249,6 +252,9 @@ class UnnormalizerProcessor(_NormalizationMixin, ProcessorStep):
|
|||||||
|
|
||||||
return new_transition
|
return new_transition
|
||||||
|
|
||||||
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, Any]]) -> RobotProcessor:
|
def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, Any]]) -> RobotProcessor:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -169,7 +169,7 @@ class ProcessorStep(ABC):
|
|||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# TODO(Steven): Consider making this abstract so it is more explicit
|
@abstractmethod
|
||||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
return features
|
return features
|
||||||
|
|
||||||
@@ -1091,3 +1091,6 @@ class IdentityProcessor(ProcessorStep):
|
|||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|||||||
@@ -220,6 +220,9 @@ class EEBoundsAndSafety(ActionProcessor):
|
|||||||
self._last_pos = None
|
self._last_pos = None
|
||||||
self._last_twist = None
|
self._last_twist = None
|
||||||
|
|
||||||
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
@ProcessorStepRegistry.register("inverse_kinematics_ee_to_joints")
|
@ProcessorStepRegistry.register("inverse_kinematics_ee_to_joints")
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -444,3 +447,6 @@ class AddRobotObservationAsComplimentaryData(ComplementaryDataProcessor):
|
|||||||
if isinstance(k, str) and k.endswith(".pos")
|
if isinstance(k, str) and k.endswith(".pos")
|
||||||
}
|
}
|
||||||
return new_comp
|
return new_comp
|
||||||
|
|
||||||
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|||||||
Reference in New Issue
Block a user