From 06988b2135728c23776be358c36bd1923b488c62 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Tue, 4 Mar 2025 13:32:58 +0100 Subject: [PATCH] WIP stretch 3 robot & teleop --- .../robots/stretch3/configuration_stretch3.py | 4 +- .../common/robots/stretch3/robot_stretch3.py | 27 ++- lerobot/common/robots/utils.py | 4 +- .../configuration_stretch3.py | 25 +++ .../stretch3_gamepad/teleop_stretch3.py | 180 ++++++++++++++++++ 5 files changed, 220 insertions(+), 20 deletions(-) create mode 100644 lerobot/common/teleoperators/stretch3_gamepad/configuration_stretch3.py create mode 100644 lerobot/common/teleoperators/stretch3_gamepad/teleop_stretch3.py diff --git a/lerobot/common/robots/stretch3/configuration_stretch3.py b/lerobot/common/robots/stretch3/configuration_stretch3.py index 79d54670a..47ddb54bb 100644 --- a/lerobot/common/robots/stretch3/configuration_stretch3.py +++ b/lerobot/common/robots/stretch3/configuration_stretch3.py @@ -7,9 +7,9 @@ from lerobot.common.cameras.opencv import OpenCVCameraConfig from ..config import RobotConfig -@RobotConfig.register_subclass("stretch") +@RobotConfig.register_subclass("stretch3") @dataclass -class StretchRobotConfig(RobotConfig): +class Stretch3RobotConfig(RobotConfig): # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as # the number of motors in your follower arms. diff --git a/lerobot/common/robots/stretch3/robot_stretch3.py b/lerobot/common/robots/stretch3/robot_stretch3.py index ffbd6078e..e07e3f1e0 100644 --- a/lerobot/common/robots/stretch3/robot_stretch3.py +++ b/lerobot/common/robots/stretch3/robot_stretch3.py @@ -17,16 +17,16 @@ import time import numpy as np -import torch from stretch_body.gamepad_teleop import GamePadTeleop from stretch_body.robot import Robot as StretchAPI from stretch_body.robot_params import RobotParams from lerobot.common.cameras.utils import make_cameras_from_configs +from lerobot.common.constants import OBS_IMAGES, OBS_STATE from lerobot.common.datasets.utils import get_nested_item from ..robot import Robot -from .configuration_stretch3 import StretchRobotConfig +from .configuration_stretch3 import Stretch3RobotConfig # {lerobot_keys: stretch.api.keys} STRETCH_MOTORS = { @@ -47,10 +47,10 @@ STRETCH_MOTORS = { class Stretch3Robot(Robot): """[Stretch 3](https://hello-robot.com/stretch-3-product), by Hello Robot.""" - config_class = StretchRobotConfig + config_class = Stretch3RobotConfig name = "stretch3" - def __init__(self, config: StretchRobotConfig): + def __init__(self, config: Stretch3RobotConfig): super().__init__(config) self.config = config @@ -121,6 +121,7 @@ class Stretch3Robot(Robot): def get_observation(self) -> dict[str, np.ndarray]: obs_dict = {} + # Read Stretch state before_read_t = time.perf_counter() state = self._get_state() self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t @@ -128,21 +129,15 @@ class Stretch3Robot(Robot): if self.state_keys is None: self.state_keys = list(state) - state = torch.as_tensor(list(state.values())) + state = np.asarray(list(state.values())) + obs_dict[OBS_STATE] = state # Capture images from cameras - images = {} - for name in self.cameras: + for cam_key, cam in self.cameras.items(): before_camread_t = time.perf_counter() - images[name] = self.cameras[name].async_read() - self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] - self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t - - # Populate output dictionaries - obs_dict = {} - obs_dict["observation.state"] = state - for name in self.cameras: - obs_dict[f"observation.images.{name}"] = images[name] + obs_dict[f"{OBS_IMAGES}.{cam_key}"] = cam.async_read() + self.logs[f"read_camera_{cam_key}_dt_s"] = cam.logs["delta_timestamp_s"] + self.logs[f"async_read_camera_{cam_key}_dt_s"] = time.perf_counter() - before_camread_t return obs_dict diff --git a/lerobot/common/robots/utils.py b/lerobot/common/robots/utils.py index 8f2531469..db86e6cce 100644 --- a/lerobot/common/robots/utils.py +++ b/lerobot/common/robots/utils.py @@ -46,9 +46,9 @@ def make_robot_config(robot_type: str, **kwargs) -> RobotConfig: return So100RobotConfig(**kwargs) elif robot_type == "stretch": - from .stretch3.configuration_stretch3 import StretchRobotConfig + from .stretch3.configuration_stretch3 import Stretch3RobotConfig - return StretchRobotConfig(**kwargs) + return Stretch3RobotConfig(**kwargs) elif robot_type == "lekiwi": from .lekiwi.configuration_lekiwi import LeKiwiRobotConfig diff --git a/lerobot/common/teleoperators/stretch3_gamepad/configuration_stretch3.py b/lerobot/common/teleoperators/stretch3_gamepad/configuration_stretch3.py new file mode 100644 index 000000000..507a21589 --- /dev/null +++ b/lerobot/common/teleoperators/stretch3_gamepad/configuration_stretch3.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from ..config import TeleoperatorConfig + + +@TeleoperatorConfig.register_subclass("stretch3") +@dataclass +class Stretch3GamePadConfig(TeleoperatorConfig): + mock: bool = False diff --git a/lerobot/common/teleoperators/stretch3_gamepad/teleop_stretch3.py b/lerobot/common/teleoperators/stretch3_gamepad/teleop_stretch3.py new file mode 100644 index 000000000..adab08428 --- /dev/null +++ b/lerobot/common/teleoperators/stretch3_gamepad/teleop_stretch3.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import numpy as np +from stretch_body.gamepad_teleop import GamePadTeleop +from stretch_body.robot_params import RobotParams + +from lerobot.common.constants import OBS_IMAGES, OBS_STATE +from lerobot.common.datasets.utils import get_nested_item + +from ..teleoperator import Teleoperator +from .configuration_stretch3 import Stretch3GamePadConfig + +# {lerobot_keys: stretch.api.keys} +STRETCH_MOTORS = { + "head_pan.pos": "head.head_pan.pos", + "head_tilt.pos": "head.head_tilt.pos", + "lift.pos": "lift.pos", + "arm.pos": "arm.pos", + "wrist_pitch.pos": "end_of_arm.wrist_pitch.pos", + "wrist_roll.pos": "end_of_arm.wrist_roll.pos", + "wrist_yaw.pos": "end_of_arm.wrist_yaw.pos", + "gripper.pos": "end_of_arm.stretch_gripper.pos", + "base_x.vel": "base.x_vel", + "base_y.vel": "base.y_vel", + "base_theta.vel": "base.theta_vel", +} + + +class Stretch3GamePad(Teleoperator): + """[Stretch 3](https://hello-robot.com/stretch-3-product), by Hello Robot.""" + + config_class = Stretch3GamePadConfig + name = "stretch3" + + def __init__(self, config: Stretch3GamePadConfig): + super().__init__(config) + + self.config = config + self.robot_type = self.config.type + + self.api = GamePadTeleop(robot_instance=False) + + self.is_connected = False + self.logs = {} + + self.teleop = None # TODO remove + + # TODO(aliberts): test this + RobotParams.set_logging_level("WARNING") + RobotParams.set_logging_formatter("brief_console_formatter") + + self.state_keys = None + self.action_keys = None + + @property + def state_feature(self) -> dict: + return { + "dtype": "float32", + "shape": (len(STRETCH_MOTORS),), + "names": {"motors": list(STRETCH_MOTORS)}, + } + + @property + def action_feature(self) -> dict: + return self.state_feature + + @property + def camera_features(self) -> dict[str, dict]: + cam_ft = {} + for cam_key, cam in self.cameras.items(): + cam_ft[cam_key] = { + "shape": (cam.height, cam.width, cam.channels), + "names": ["height", "width", "channels"], + "info": None, + } + return cam_ft + + def connect(self) -> None: + self.is_connected = self.api.startup() + if not self.is_connected: + print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'") + raise ConnectionError() + + for cam in self.cameras.values(): + cam.connect() + self.is_connected = self.is_connected and cam.is_connected + + if not self.is_connected: + print("Could not connect to the cameras, check that all cameras are plugged-in.") + raise ConnectionError() + + self.calibrate() + + def calibrate(self) -> None: + if not self.api.is_homed(): + self.api.home() + + def _get_state(self) -> dict: + status = self.api.get_status() + return {k: get_nested_item(status, v, sep=".") for k, v in STRETCH_MOTORS.items()} + + def get_observation(self) -> dict[str, np.ndarray]: + obs_dict = {} + + # Read Stretch state + before_read_t = time.perf_counter() + state = self._get_state() + self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t + + if self.state_keys is None: + self.state_keys = list(state) + + state = np.asarray(list(state.values())) + obs_dict[OBS_STATE] = state + + # Capture images from cameras + for cam_key, cam in self.cameras.items(): + before_camread_t = time.perf_counter() + obs_dict[f"{OBS_IMAGES}.{cam_key}"] = cam.async_read() + self.logs[f"read_camera_{cam_key}_dt_s"] = cam.logs["delta_timestamp_s"] + self.logs[f"async_read_camera_{cam_key}_dt_s"] = time.perf_counter() - before_camread_t + + return obs_dict + + def send_action(self, action: np.ndarray) -> np.ndarray: + if not self.is_connected: + raise ConnectionError() + + if self.teleop is None: + self.teleop = GamePadTeleop(robot_instance=False) + self.teleop.startup(robot=self) + + if self.action_keys is None: + dummy_action = self.teleop.gamepad_controller.get_state() + self.action_keys = list(dummy_action.keys()) + + action_dict = dict(zip(self.action_keys, action.tolist(), strict=True)) + + before_write_t = time.perf_counter() + self.teleop.do_motion(state=action_dict, robot=self) + self.push_command() + self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t + + # TODO(aliberts): return action_sent when motion is limited + return action + + def print_logs(self) -> None: + pass + # TODO(aliberts): move robot-specific logs logic here + + def teleop_safety_stop(self) -> None: + if self.teleop is not None: + self.teleop._safety_stop(robot=self) + + def disconnect(self) -> None: + self.api.stop() + if self.teleop is not None: + self.teleop.gamepad_controller.stop() + self.teleop.stop() + + for cam in self.cameras.values(): + cam.disconnect() + + self.is_connected = False