mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 21:19:53 +00:00
fix(processor): enforce signatures
This commit is contained in:
@@ -14,14 +14,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||||
from lerobot.datasets.utils import combine_feature_dicts
|
from lerobot.datasets.utils import combine_feature_dicts
|
||||||
from lerobot.model.kinematics import RobotKinematics
|
from lerobot.model.kinematics import RobotKinematics
|
||||||
from lerobot.processor import RobotAction, RobotProcessorPipeline
|
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
|
||||||
from lerobot.processor.converters import (
|
from lerobot.processor.converters import (
|
||||||
observation_to_transition,
|
observation_to_transition,
|
||||||
robot_action_to_transition,
|
robot_action_to_transition,
|
||||||
@@ -76,7 +75,7 @@ leader_kinematics_solver = RobotKinematics(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Build pipeline to convert follower joints to EE observation
|
# Build pipeline to convert follower joints to EE observation
|
||||||
follower_joints_to_ee = RobotProcessorPipeline[dict[str, Any], dict[str, Any]](
|
follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||||
steps=[
|
steps=[
|
||||||
ForwardKinematicsJointsToEE(
|
ForwardKinematicsJointsToEE(
|
||||||
kinematics=follower_kinematics_solver, motor_names=list(follower.bus.motors.keys())
|
kinematics=follower_kinematics_solver, motor_names=list(follower.bus.motors.keys())
|
||||||
|
|||||||
@@ -221,7 +221,8 @@ def robot_action_to_transition(action: RobotAction) -> EnvTransition:
|
|||||||
Returns:
|
Returns:
|
||||||
An `EnvTransition` containing the formatted action.
|
An `EnvTransition` containing the formatted action.
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(action, RobotAction):
|
||||||
|
raise ValueError(f"Action should be a RobotAction type got {type(action)}")
|
||||||
return create_transition(action=action)
|
return create_transition(action=action)
|
||||||
|
|
||||||
|
|
||||||
@@ -239,7 +240,8 @@ def observation_to_transition(observation: RobotObservation) -> EnvTransition:
|
|||||||
Returns:
|
Returns:
|
||||||
An `EnvTransition` containing the formatted observation.
|
An `EnvTransition` containing the formatted observation.
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(observation, RobotObservation):
|
||||||
|
raise ValueError(f"Observation should be a RobotObservation type got {type(observation)}")
|
||||||
return create_transition(observation=observation)
|
return create_transition(observation=observation)
|
||||||
|
|
||||||
|
|
||||||
@@ -256,6 +258,9 @@ def transition_to_robot_action(transition: EnvTransition) -> RobotAction:
|
|||||||
Returns:
|
Returns:
|
||||||
A dictionary representing the raw robot action.
|
A dictionary representing the raw robot action.
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(transition, dict):
|
||||||
|
raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
|
||||||
|
|
||||||
action = transition.get(TransitionKey.ACTION)
|
action = transition.get(TransitionKey.ACTION)
|
||||||
if not isinstance(action, dict):
|
if not isinstance(action, dict):
|
||||||
raise ValueError(f"Action should be a RobotAction type (dict) got {type(action)}")
|
raise ValueError(f"Action should be a RobotAction type (dict) got {type(action)}")
|
||||||
@@ -266,6 +271,9 @@ def transition_to_policy_action(transition: EnvTransition) -> PolicyAction:
|
|||||||
"""
|
"""
|
||||||
Convert an `EnvTransition` to a `PolicyAction`.
|
Convert an `EnvTransition` to a `PolicyAction`.
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(transition, dict):
|
||||||
|
raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
|
||||||
|
|
||||||
action = transition.get(TransitionKey.ACTION)
|
action = transition.get(TransitionKey.ACTION)
|
||||||
if not isinstance(action, PolicyAction):
|
if not isinstance(action, PolicyAction):
|
||||||
raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
|
raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
|
||||||
@@ -276,6 +284,9 @@ def transition_to_observation(transition: EnvTransition) -> RobotObservation:
|
|||||||
"""
|
"""
|
||||||
Convert an `EnvTransition` to a `RobotObservation`.
|
Convert an `EnvTransition` to a `RobotObservation`.
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(transition, dict):
|
||||||
|
raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
|
||||||
|
|
||||||
observation = transition.get(TransitionKey.OBSERVATION)
|
observation = transition.get(TransitionKey.OBSERVATION)
|
||||||
if not isinstance(observation, dict):
|
if not isinstance(observation, dict):
|
||||||
raise ValueError(f"Observation should be a RobotObservation (dict) type got {type(observation)}")
|
raise ValueError(f"Observation should be a RobotObservation (dict) type got {type(observation)}")
|
||||||
@@ -343,6 +354,9 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
|
|||||||
Returns:
|
Returns:
|
||||||
A batch dictionary with canonical LeRobot field names.
|
A batch dictionary with canonical LeRobot field names.
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(transition, dict):
|
||||||
|
raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
|
||||||
|
|
||||||
batch = {
|
batch = {
|
||||||
"action": transition.get(TransitionKey.ACTION),
|
"action": transition.get(TransitionKey.ACTION),
|
||||||
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
|
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
|
||||||
@@ -364,7 +378,7 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
|
|||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|
||||||
def identity_transition(tr: EnvTransition) -> EnvTransition:
|
def identity_transition(transition: EnvTransition) -> EnvTransition:
|
||||||
"""
|
"""
|
||||||
An identity function for transitions, returning the input unchanged.
|
An identity function for transitions, returning the input unchanged.
|
||||||
|
|
||||||
@@ -376,4 +390,4 @@ def identity_transition(tr: EnvTransition) -> EnvTransition:
|
|||||||
Returns:
|
Returns:
|
||||||
The same `EnvTransition`.
|
The same `EnvTransition`.
|
||||||
"""
|
"""
|
||||||
return tr
|
return transition
|
||||||
|
|||||||
+3
-10
@@ -48,11 +48,8 @@ from pprint import pformat
|
|||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
IdentityProcessorStep,
|
make_default_robot_action_processor,
|
||||||
RobotAction,
|
|
||||||
RobotProcessorPipeline,
|
|
||||||
)
|
)
|
||||||
from lerobot.processor.converters import robot_action_to_transition, transition_to_robot_action
|
|
||||||
from lerobot.robots import ( # noqa: F401
|
from lerobot.robots import ( # noqa: F401
|
||||||
Robot,
|
Robot,
|
||||||
RobotConfig,
|
RobotConfig,
|
||||||
@@ -96,11 +93,7 @@ def replay(cfg: ReplayConfig):
|
|||||||
init_logging()
|
init_logging()
|
||||||
logging.info(pformat(asdict(cfg)))
|
logging.info(pformat(asdict(cfg)))
|
||||||
|
|
||||||
robot_action_processor = RobotProcessorPipeline[RobotAction, RobotAction](
|
robot_action_processor = make_default_robot_action_processor()
|
||||||
steps=[IdentityProcessorStep()],
|
|
||||||
to_transition=robot_action_to_transition,
|
|
||||||
to_output=transition_to_robot_action,
|
|
||||||
)
|
|
||||||
|
|
||||||
robot = make_robot_from_config(cfg.robot)
|
robot = make_robot_from_config(cfg.robot)
|
||||||
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
|
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
|
||||||
@@ -122,7 +115,7 @@ def replay(cfg: ReplayConfig):
|
|||||||
|
|
||||||
processed_action = robot_action_processor(action)
|
processed_action = robot_action_processor(action)
|
||||||
|
|
||||||
robot.send_action(processed_action)
|
_ = robot.send_action(processed_action)
|
||||||
|
|
||||||
dt_s = time.perf_counter() - start_episode_t
|
dt_s = time.perf_counter() - start_episode_t
|
||||||
busy_wait(1 / dataset.fps - dt_s)
|
busy_wait(1 / dataset.fps - dt_s)
|
||||||
|
|||||||
@@ -55,7 +55,6 @@ import logging
|
|||||||
import time
|
import time
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import rerun as rr
|
import rerun as rr
|
||||||
|
|
||||||
@@ -63,10 +62,9 @@ from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # no
|
|||||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
EnvTransition,
|
|
||||||
RobotAction,
|
RobotAction,
|
||||||
|
RobotObservation,
|
||||||
RobotProcessorPipeline,
|
RobotProcessorPipeline,
|
||||||
TransitionKey,
|
|
||||||
make_default_processors,
|
make_default_processors,
|
||||||
)
|
)
|
||||||
from lerobot.robots import ( # noqa: F401
|
from lerobot.robots import ( # noqa: F401
|
||||||
@@ -111,9 +109,9 @@ def teleop_loop(
|
|||||||
teleop: Teleoperator,
|
teleop: Teleoperator,
|
||||||
robot: Robot,
|
robot: Robot,
|
||||||
fps: int,
|
fps: int,
|
||||||
teleop_action_processor: RobotProcessorPipeline[RobotAction, EnvTransition],
|
teleop_action_processor: RobotProcessorPipeline[RobotAction, RobotAction],
|
||||||
robot_action_processor: RobotProcessorPipeline[EnvTransition, RobotAction],
|
robot_action_processor: RobotProcessorPipeline[RobotAction, RobotAction],
|
||||||
robot_observation_processor: RobotProcessorPipeline[dict[str, Any], EnvTransition],
|
robot_observation_processor: RobotProcessorPipeline[RobotObservation, RobotObservation],
|
||||||
display_data: bool = False,
|
display_data: bool = False,
|
||||||
duration: float | None = None,
|
duration: float | None = None,
|
||||||
):
|
):
|
||||||
@@ -143,13 +141,13 @@ def teleop_loop(
|
|||||||
raw_action = teleop.get_action()
|
raw_action = teleop.get_action()
|
||||||
|
|
||||||
# Process teleop action through pipeline
|
# Process teleop action through pipeline
|
||||||
teleop_transition = teleop_action_processor(raw_action)
|
teleop_action = teleop_action_processor(raw_action)
|
||||||
|
|
||||||
# Process action for robot through pipeline
|
# Process action for robot through pipeline
|
||||||
robot_action_to_send = robot_action_processor(teleop_transition)
|
robot_action_to_send = robot_action_processor(teleop_action)
|
||||||
|
|
||||||
# Send processed action to robot (robot_action_processor.to_output should return dict[str, Any])
|
# Send processed action to robot (robot_action_processor.to_output should return dict[str, Any])
|
||||||
robot.send_action(robot_action_to_send)
|
_ = robot.send_action(robot_action_to_send)
|
||||||
|
|
||||||
if display_data:
|
if display_data:
|
||||||
# Get robot observation
|
# Get robot observation
|
||||||
@@ -158,8 +156,8 @@ def teleop_loop(
|
|||||||
obs_transition = robot_observation_processor(obs)
|
obs_transition = robot_observation_processor(obs)
|
||||||
|
|
||||||
log_rerun_data(
|
log_rerun_data(
|
||||||
observation=obs_transition.get(TransitionKey.OBSERVATION),
|
observation=obs_transition,
|
||||||
action=teleop_transition.get(TransitionKey.ACTION),
|
action=teleop_action,
|
||||||
)
|
)
|
||||||
|
|
||||||
print("\n" + "-" * (display_len + 10))
|
print("\n" + "-" * (display_len + 10))
|
||||||
|
|||||||
Reference in New Issue
Block a user