mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 00:29:52 +00:00
fix(processor): enforce signatures
This commit is contained in:
@@ -14,14 +14,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotAction, RobotProcessorPipeline
|
||||
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
observation_to_transition,
|
||||
robot_action_to_transition,
|
||||
@@ -76,7 +75,7 @@ leader_kinematics_solver = RobotKinematics(
|
||||
)
|
||||
|
||||
# Build pipeline to convert follower joints to EE observation
|
||||
follower_joints_to_ee = RobotProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=follower_kinematics_solver, motor_names=list(follower.bus.motors.keys())
|
||||
|
||||
@@ -221,7 +221,8 @@ def robot_action_to_transition(action: RobotAction) -> EnvTransition:
|
||||
Returns:
|
||||
An `EnvTransition` containing the formatted action.
|
||||
"""
|
||||
|
||||
if not isinstance(action, RobotAction):
|
||||
raise ValueError(f"Action should be a RobotAction type got {type(action)}")
|
||||
return create_transition(action=action)
|
||||
|
||||
|
||||
@@ -239,7 +240,8 @@ def observation_to_transition(observation: RobotObservation) -> EnvTransition:
|
||||
Returns:
|
||||
An `EnvTransition` containing the formatted observation.
|
||||
"""
|
||||
|
||||
if not isinstance(observation, RobotObservation):
|
||||
raise ValueError(f"Observation should be a RobotObservation type got {type(observation)}")
|
||||
return create_transition(observation=observation)
|
||||
|
||||
|
||||
@@ -256,6 +258,9 @@ def transition_to_robot_action(transition: EnvTransition) -> RobotAction:
|
||||
Returns:
|
||||
A dictionary representing the raw robot action.
|
||||
"""
|
||||
if not isinstance(transition, dict):
|
||||
raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
|
||||
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if not isinstance(action, dict):
|
||||
raise ValueError(f"Action should be a RobotAction type (dict) got {type(action)}")
|
||||
@@ -266,6 +271,9 @@ def transition_to_policy_action(transition: EnvTransition) -> PolicyAction:
|
||||
"""
|
||||
Convert an `EnvTransition` to a `PolicyAction`.
|
||||
"""
|
||||
if not isinstance(transition, dict):
|
||||
raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
|
||||
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if not isinstance(action, PolicyAction):
|
||||
raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
|
||||
@@ -276,6 +284,9 @@ def transition_to_observation(transition: EnvTransition) -> RobotObservation:
|
||||
"""
|
||||
Convert an `EnvTransition` to a `RobotObservation`.
|
||||
"""
|
||||
if not isinstance(transition, dict):
|
||||
raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
|
||||
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if not isinstance(observation, dict):
|
||||
raise ValueError(f"Observation should be a RobotObservation (dict) type got {type(observation)}")
|
||||
@@ -343,6 +354,9 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
|
||||
Returns:
|
||||
A batch dictionary with canonical LeRobot field names.
|
||||
"""
|
||||
if not isinstance(transition, dict):
|
||||
raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
|
||||
|
||||
batch = {
|
||||
"action": transition.get(TransitionKey.ACTION),
|
||||
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
|
||||
@@ -364,7 +378,7 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
|
||||
return batch
|
||||
|
||||
|
||||
def identity_transition(tr: EnvTransition) -> EnvTransition:
|
||||
def identity_transition(transition: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
An identity function for transitions, returning the input unchanged.
|
||||
|
||||
@@ -376,4 +390,4 @@ def identity_transition(tr: EnvTransition) -> EnvTransition:
|
||||
Returns:
|
||||
The same `EnvTransition`.
|
||||
"""
|
||||
return tr
|
||||
return transition
|
||||
|
||||
+3
-10
@@ -48,11 +48,8 @@ from pprint import pformat
|
||||
from lerobot.configs import parser
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor import (
|
||||
IdentityProcessorStep,
|
||||
RobotAction,
|
||||
RobotProcessorPipeline,
|
||||
make_default_robot_action_processor,
|
||||
)
|
||||
from lerobot.processor.converters import robot_action_to_transition, transition_to_robot_action
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
@@ -96,11 +93,7 @@ def replay(cfg: ReplayConfig):
|
||||
init_logging()
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
robot_action_processor = RobotProcessorPipeline[RobotAction, RobotAction](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=robot_action_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
robot_action_processor = make_default_robot_action_processor()
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
|
||||
@@ -122,7 +115,7 @@ def replay(cfg: ReplayConfig):
|
||||
|
||||
processed_action = robot_action_processor(action)
|
||||
|
||||
robot.send_action(processed_action)
|
||||
_ = robot.send_action(processed_action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
busy_wait(1 / dataset.fps - dt_s)
|
||||
|
||||
@@ -55,7 +55,6 @@ import logging
|
||||
import time
|
||||
from dataclasses import asdict, dataclass
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
import rerun as rr
|
||||
|
||||
@@ -63,10 +62,9 @@ from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # no
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.processor import (
|
||||
EnvTransition,
|
||||
RobotAction,
|
||||
RobotObservation,
|
||||
RobotProcessorPipeline,
|
||||
TransitionKey,
|
||||
make_default_processors,
|
||||
)
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
@@ -111,9 +109,9 @@ def teleop_loop(
|
||||
teleop: Teleoperator,
|
||||
robot: Robot,
|
||||
fps: int,
|
||||
teleop_action_processor: RobotProcessorPipeline[RobotAction, EnvTransition],
|
||||
robot_action_processor: RobotProcessorPipeline[EnvTransition, RobotAction],
|
||||
robot_observation_processor: RobotProcessorPipeline[dict[str, Any], EnvTransition],
|
||||
teleop_action_processor: RobotProcessorPipeline[RobotAction, RobotAction],
|
||||
robot_action_processor: RobotProcessorPipeline[RobotAction, RobotAction],
|
||||
robot_observation_processor: RobotProcessorPipeline[RobotObservation, RobotObservation],
|
||||
display_data: bool = False,
|
||||
duration: float | None = None,
|
||||
):
|
||||
@@ -143,13 +141,13 @@ def teleop_loop(
|
||||
raw_action = teleop.get_action()
|
||||
|
||||
# Process teleop action through pipeline
|
||||
teleop_transition = teleop_action_processor(raw_action)
|
||||
teleop_action = teleop_action_processor(raw_action)
|
||||
|
||||
# Process action for robot through pipeline
|
||||
robot_action_to_send = robot_action_processor(teleop_transition)
|
||||
robot_action_to_send = robot_action_processor(teleop_action)
|
||||
|
||||
# Send processed action to robot (robot_action_processor.to_output should return dict[str, Any])
|
||||
robot.send_action(robot_action_to_send)
|
||||
_ = robot.send_action(robot_action_to_send)
|
||||
|
||||
if display_data:
|
||||
# Get robot observation
|
||||
@@ -158,8 +156,8 @@ def teleop_loop(
|
||||
obs_transition = robot_observation_processor(obs)
|
||||
|
||||
log_rerun_data(
|
||||
observation=obs_transition.get(TransitionKey.OBSERVATION),
|
||||
action=teleop_transition.get(TransitionKey.ACTION),
|
||||
observation=obs_transition,
|
||||
action=teleop_action,
|
||||
)
|
||||
|
||||
print("\n" + "-" * (display_len + 10))
|
||||
|
||||
Reference in New Issue
Block a user