fix(processor): enforce signatures

This commit is contained in:
Steven Palma
2025-09-16 17:13:07 +02:00
parent fa8be1c4fe
commit 43eb0e375f
4 changed files with 32 additions and 28 deletions
+2 -3
View File
@@ -14,14 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import combine_feature_dicts
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotAction, RobotProcessorPipeline
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
from lerobot.processor.converters import (
observation_to_transition,
robot_action_to_transition,
@@ -76,7 +75,7 @@ leader_kinematics_solver = RobotKinematics(
)
# Build pipeline to convert follower joints to EE observation
follower_joints_to_ee = RobotProcessorPipeline[dict[str, Any], dict[str, Any]](
follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[
ForwardKinematicsJointsToEE(
kinematics=follower_kinematics_solver, motor_names=list(follower.bus.motors.keys())
+18 -4
View File
@@ -221,7 +221,8 @@ def robot_action_to_transition(action: RobotAction) -> EnvTransition:
Returns:
An `EnvTransition` containing the formatted action.
"""
if not isinstance(action, RobotAction):
raise ValueError(f"Action should be a RobotAction type got {type(action)}")
return create_transition(action=action)
@@ -239,7 +240,8 @@ def observation_to_transition(observation: RobotObservation) -> EnvTransition:
Returns:
An `EnvTransition` containing the formatted observation.
"""
if not isinstance(observation, RobotObservation):
raise ValueError(f"Observation should be a RobotObservation type got {type(observation)}")
return create_transition(observation=observation)
@@ -256,6 +258,9 @@ def transition_to_robot_action(transition: EnvTransition) -> RobotAction:
Returns:
A dictionary representing the raw robot action.
"""
if not isinstance(transition, dict):
raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
action = transition.get(TransitionKey.ACTION)
if not isinstance(action, dict):
raise ValueError(f"Action should be a RobotAction type (dict) got {type(action)}")
@@ -266,6 +271,9 @@ def transition_to_policy_action(transition: EnvTransition) -> PolicyAction:
"""
Convert an `EnvTransition` to a `PolicyAction`.
"""
if not isinstance(transition, dict):
raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
action = transition.get(TransitionKey.ACTION)
if not isinstance(action, PolicyAction):
raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
@@ -276,6 +284,9 @@ def transition_to_observation(transition: EnvTransition) -> RobotObservation:
"""
Convert an `EnvTransition` to a `RobotObservation`.
"""
if not isinstance(transition, dict):
raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
observation = transition.get(TransitionKey.OBSERVATION)
if not isinstance(observation, dict):
raise ValueError(f"Observation should be a RobotObservation (dict) type got {type(observation)}")
@@ -343,6 +354,9 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
Returns:
A batch dictionary with canonical LeRobot field names.
"""
if not isinstance(transition, dict):
raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
batch = {
"action": transition.get(TransitionKey.ACTION),
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
@@ -364,7 +378,7 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
return batch
def identity_transition(tr: EnvTransition) -> EnvTransition:
def identity_transition(transition: EnvTransition) -> EnvTransition:
"""
An identity function for transitions, returning the input unchanged.
@@ -376,4 +390,4 @@ def identity_transition(tr: EnvTransition) -> EnvTransition:
Returns:
The same `EnvTransition`.
"""
return tr
return transition
+3 -10
View File
@@ -48,11 +48,8 @@ from pprint import pformat
from lerobot.configs import parser
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor import (
IdentityProcessorStep,
RobotAction,
RobotProcessorPipeline,
make_default_robot_action_processor,
)
from lerobot.processor.converters import robot_action_to_transition, transition_to_robot_action
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
@@ -96,11 +93,7 @@ def replay(cfg: ReplayConfig):
init_logging()
logging.info(pformat(asdict(cfg)))
robot_action_processor = RobotProcessorPipeline[RobotAction, RobotAction](
steps=[IdentityProcessorStep()],
to_transition=robot_action_to_transition,
to_output=transition_to_robot_action,
)
robot_action_processor = make_default_robot_action_processor()
robot = make_robot_from_config(cfg.robot)
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
@@ -122,7 +115,7 @@ def replay(cfg: ReplayConfig):
processed_action = robot_action_processor(action)
robot.send_action(processed_action)
_ = robot.send_action(processed_action)
dt_s = time.perf_counter() - start_episode_t
busy_wait(1 / dataset.fps - dt_s)
+9 -11
View File
@@ -55,7 +55,6 @@ import logging
import time
from dataclasses import asdict, dataclass
from pprint import pformat
from typing import Any
import rerun as rr
@@ -63,10 +62,9 @@ from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # no
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.processor import (
EnvTransition,
RobotAction,
RobotObservation,
RobotProcessorPipeline,
TransitionKey,
make_default_processors,
)
from lerobot.robots import ( # noqa: F401
@@ -111,9 +109,9 @@ def teleop_loop(
teleop: Teleoperator,
robot: Robot,
fps: int,
teleop_action_processor: RobotProcessorPipeline[RobotAction, EnvTransition],
robot_action_processor: RobotProcessorPipeline[EnvTransition, RobotAction],
robot_observation_processor: RobotProcessorPipeline[dict[str, Any], EnvTransition],
teleop_action_processor: RobotProcessorPipeline[RobotAction, RobotAction],
robot_action_processor: RobotProcessorPipeline[RobotAction, RobotAction],
robot_observation_processor: RobotProcessorPipeline[RobotObservation, RobotObservation],
display_data: bool = False,
duration: float | None = None,
):
@@ -143,13 +141,13 @@ def teleop_loop(
raw_action = teleop.get_action()
# Process teleop action through pipeline
teleop_transition = teleop_action_processor(raw_action)
teleop_action = teleop_action_processor(raw_action)
# Process action for robot through pipeline
robot_action_to_send = robot_action_processor(teleop_transition)
robot_action_to_send = robot_action_processor(teleop_action)
# Send processed action to robot (robot_action_processor.to_output should return dict[str, Any])
robot.send_action(robot_action_to_send)
_ = robot.send_action(robot_action_to_send)
if display_data:
# Get robot observation
@@ -158,8 +156,8 @@ def teleop_loop(
obs_transition = robot_observation_processor(obs)
log_rerun_data(
observation=obs_transition.get(TransitionKey.OBSERVATION),
action=teleop_transition.get(TransitionKey.ACTION),
observation=obs_transition,
action=teleop_action,
)
print("\n" + "-" * (display_len + 10))