fix(processor): enforce signatures

This commit is contained in:
Steven Palma
2025-09-16 17:13:07 +02:00
parent fa8be1c4fe
commit 43eb0e375f
4 changed files with 32 additions and 28 deletions
+2 -3
View File
@@ -14,14 +14,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import combine_feature_dicts from lerobot.datasets.utils import combine_feature_dicts
from lerobot.model.kinematics import RobotKinematics from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotAction, RobotProcessorPipeline from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
from lerobot.processor.converters import ( from lerobot.processor.converters import (
observation_to_transition, observation_to_transition,
robot_action_to_transition, robot_action_to_transition,
@@ -76,7 +75,7 @@ leader_kinematics_solver = RobotKinematics(
) )
# Build pipeline to convert follower joints to EE observation # Build pipeline to convert follower joints to EE observation
follower_joints_to_ee = RobotProcessorPipeline[dict[str, Any], dict[str, Any]]( follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[ steps=[
ForwardKinematicsJointsToEE( ForwardKinematicsJointsToEE(
kinematics=follower_kinematics_solver, motor_names=list(follower.bus.motors.keys()) kinematics=follower_kinematics_solver, motor_names=list(follower.bus.motors.keys())
+18 -4
View File
@@ -221,7 +221,8 @@ def robot_action_to_transition(action: RobotAction) -> EnvTransition:
Returns: Returns:
An `EnvTransition` containing the formatted action. An `EnvTransition` containing the formatted action.
""" """
if not isinstance(action, RobotAction):
raise ValueError(f"Action should be a RobotAction type got {type(action)}")
return create_transition(action=action) return create_transition(action=action)
@@ -239,7 +240,8 @@ def observation_to_transition(observation: RobotObservation) -> EnvTransition:
Returns: Returns:
An `EnvTransition` containing the formatted observation. An `EnvTransition` containing the formatted observation.
""" """
if not isinstance(observation, RobotObservation):
raise ValueError(f"Observation should be a RobotObservation type got {type(observation)}")
return create_transition(observation=observation) return create_transition(observation=observation)
@@ -256,6 +258,9 @@ def transition_to_robot_action(transition: EnvTransition) -> RobotAction:
Returns: Returns:
A dictionary representing the raw robot action. A dictionary representing the raw robot action.
""" """
if not isinstance(transition, dict):
raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
action = transition.get(TransitionKey.ACTION) action = transition.get(TransitionKey.ACTION)
if not isinstance(action, dict): if not isinstance(action, dict):
raise ValueError(f"Action should be a RobotAction type (dict) got {type(action)}") raise ValueError(f"Action should be a RobotAction type (dict) got {type(action)}")
@@ -266,6 +271,9 @@ def transition_to_policy_action(transition: EnvTransition) -> PolicyAction:
""" """
Convert an `EnvTransition` to a `PolicyAction`. Convert an `EnvTransition` to a `PolicyAction`.
""" """
if not isinstance(transition, dict):
raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
action = transition.get(TransitionKey.ACTION) action = transition.get(TransitionKey.ACTION)
if not isinstance(action, PolicyAction): if not isinstance(action, PolicyAction):
raise ValueError(f"Action should be a PolicyAction type got {type(action)}") raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
@@ -276,6 +284,9 @@ def transition_to_observation(transition: EnvTransition) -> RobotObservation:
""" """
Convert an `EnvTransition` to a `RobotObservation`. Convert an `EnvTransition` to a `RobotObservation`.
""" """
if not isinstance(transition, dict):
raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
observation = transition.get(TransitionKey.OBSERVATION) observation = transition.get(TransitionKey.OBSERVATION)
if not isinstance(observation, dict): if not isinstance(observation, dict):
raise ValueError(f"Observation should be a RobotObservation (dict) type got {type(observation)}") raise ValueError(f"Observation should be a RobotObservation (dict) type got {type(observation)}")
@@ -343,6 +354,9 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
Returns: Returns:
A batch dictionary with canonical LeRobot field names. A batch dictionary with canonical LeRobot field names.
""" """
if not isinstance(transition, dict):
raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
batch = { batch = {
"action": transition.get(TransitionKey.ACTION), "action": transition.get(TransitionKey.ACTION),
"next.reward": transition.get(TransitionKey.REWARD, 0.0), "next.reward": transition.get(TransitionKey.REWARD, 0.0),
@@ -364,7 +378,7 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
return batch return batch
def identity_transition(tr: EnvTransition) -> EnvTransition: def identity_transition(transition: EnvTransition) -> EnvTransition:
""" """
An identity function for transitions, returning the input unchanged. An identity function for transitions, returning the input unchanged.
@@ -376,4 +390,4 @@ def identity_transition(tr: EnvTransition) -> EnvTransition:
Returns: Returns:
The same `EnvTransition`. The same `EnvTransition`.
""" """
return tr return transition
+3 -10
View File
@@ -48,11 +48,8 @@ from pprint import pformat
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor import ( from lerobot.processor import (
IdentityProcessorStep, make_default_robot_action_processor,
RobotAction,
RobotProcessorPipeline,
) )
from lerobot.processor.converters import robot_action_to_transition, transition_to_robot_action
from lerobot.robots import ( # noqa: F401 from lerobot.robots import ( # noqa: F401
Robot, Robot,
RobotConfig, RobotConfig,
@@ -96,11 +93,7 @@ def replay(cfg: ReplayConfig):
init_logging() init_logging()
logging.info(pformat(asdict(cfg))) logging.info(pformat(asdict(cfg)))
robot_action_processor = RobotProcessorPipeline[RobotAction, RobotAction]( robot_action_processor = make_default_robot_action_processor()
steps=[IdentityProcessorStep()],
to_transition=robot_action_to_transition,
to_output=transition_to_robot_action,
)
robot = make_robot_from_config(cfg.robot) robot = make_robot_from_config(cfg.robot)
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode]) dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
@@ -122,7 +115,7 @@ def replay(cfg: ReplayConfig):
processed_action = robot_action_processor(action) processed_action = robot_action_processor(action)
robot.send_action(processed_action) _ = robot.send_action(processed_action)
dt_s = time.perf_counter() - start_episode_t dt_s = time.perf_counter() - start_episode_t
busy_wait(1 / dataset.fps - dt_s) busy_wait(1 / dataset.fps - dt_s)
+9 -11
View File
@@ -55,7 +55,6 @@ import logging
import time import time
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from pprint import pformat from pprint import pformat
from typing import Any
import rerun as rr import rerun as rr
@@ -63,10 +62,9 @@ from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # no
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.processor import ( from lerobot.processor import (
EnvTransition,
RobotAction, RobotAction,
RobotObservation,
RobotProcessorPipeline, RobotProcessorPipeline,
TransitionKey,
make_default_processors, make_default_processors,
) )
from lerobot.robots import ( # noqa: F401 from lerobot.robots import ( # noqa: F401
@@ -111,9 +109,9 @@ def teleop_loop(
teleop: Teleoperator, teleop: Teleoperator,
robot: Robot, robot: Robot,
fps: int, fps: int,
teleop_action_processor: RobotProcessorPipeline[RobotAction, EnvTransition], teleop_action_processor: RobotProcessorPipeline[RobotAction, RobotAction],
robot_action_processor: RobotProcessorPipeline[EnvTransition, RobotAction], robot_action_processor: RobotProcessorPipeline[RobotAction, RobotAction],
robot_observation_processor: RobotProcessorPipeline[dict[str, Any], EnvTransition], robot_observation_processor: RobotProcessorPipeline[RobotObservation, RobotObservation],
display_data: bool = False, display_data: bool = False,
duration: float | None = None, duration: float | None = None,
): ):
@@ -143,13 +141,13 @@ def teleop_loop(
raw_action = teleop.get_action() raw_action = teleop.get_action()
# Process teleop action through pipeline # Process teleop action through pipeline
teleop_transition = teleop_action_processor(raw_action) teleop_action = teleop_action_processor(raw_action)
# Process action for robot through pipeline # Process action for robot through pipeline
robot_action_to_send = robot_action_processor(teleop_transition) robot_action_to_send = robot_action_processor(teleop_action)
# Send processed action to robot (robot_action_processor.to_output should return dict[str, Any]) # Send processed action to robot (robot_action_processor.to_output should return dict[str, Any])
robot.send_action(robot_action_to_send) _ = robot.send_action(robot_action_to_send)
if display_data: if display_data:
# Get robot observation # Get robot observation
@@ -158,8 +156,8 @@ def teleop_loop(
obs_transition = robot_observation_processor(obs) obs_transition = robot_observation_processor(obs)
log_rerun_data( log_rerun_data(
observation=obs_transition.get(TransitionKey.OBSERVATION), observation=obs_transition,
action=teleop_transition.get(TransitionKey.ACTION), action=teleop_action,
) )
print("\n" + "-" * (display_len + 10)) print("\n" + "-" * (display_len + 10))