diff --git a/.github/workflows/full_tests.yml b/.github/workflows/full_tests.yml index fd5e422b3..d23b99de0 100644 --- a/.github/workflows/full_tests.yml +++ b/.github/workflows/full_tests.yml @@ -173,6 +173,8 @@ jobs: shell: bash working-directory: /lerobot steps: + - name: Fix ptxas permissions + run: chmod +x /lerobot/.venv/lib/python3.10/site-packages/triton/backends/nvidia/bin/ptxas - name: Run pytest on GPU run: pytest tests -vv --maxfail=10 - name: Run end-to-end tests diff --git a/MANIFEST.in b/MANIFEST.in index c1fb2ea75..c1fce3b5a 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,3 @@ include src/lerobot/templates/lerobot_modelcard_template.md include src/lerobot/datasets/card_template.md +include src/lerobot/envs/metaworld_config.json diff --git a/docker/Dockerfile.internal b/docker/Dockerfile.internal index c1dfa1dae..ed7d10495 100644 --- a/docker/Dockerfile.internal +++ b/docker/Dockerfile.internal @@ -85,6 +85,8 @@ RUN if [ "$UNBOUND_DEPS" = "true" ]; then \ RUN uv pip install --no-cache ".[all]" +RUN chmod +x /lerobot/.venv/lib/python${PYTHON_VERSION}/site-packages/triton/backends/nvidia/bin/ptxas + # Copy the rest of the application source code # Make sure to have the git-LFS files for testing COPY --chown=user_lerobot:user_lerobot . . diff --git a/pyproject.toml b/pyproject.toml index a8b4add16..8b38e40dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -217,6 +217,9 @@ lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main" lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main" # ---------------- Tool Configurations ---------------- +[tool.setuptools.package-data] +lerobot = ["envs/*.json"] + [tool.setuptools.packages.find] where = ["src"] diff --git a/src/lerobot/async_inference/robot_client.py b/src/lerobot/async_inference/robot_client.py index e4d21652a..da576eb48 100644 --- a/src/lerobot/async_inference/robot_client.py +++ b/src/lerobot/async_inference/robot_client.py @@ -49,23 +49,18 @@ import torch from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 -from lerobot.robots import ( # noqa: F401 - Robot, - RobotConfig, - bi_so_follower, - koch_follower, +from lerobot.robots import ( + RobotConfig, # noqa: F401 make_robot_from_config, - omx_follower, - so_follower, ) from lerobot.transport import ( services_pb2, # type: ignore services_pb2_grpc, # type: ignore ) from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks +from lerobot.utils.import_utils import register_third_party_plugins from .configs import RobotClientConfig -from .constants import SUPPORTED_ROBOTS from .helpers import ( Action, FPSTracker, @@ -485,8 +480,9 @@ class RobotClient: def async_client(cfg: RobotClientConfig): logging.info(pformat(asdict(cfg))) - if cfg.robot.type not in SUPPORTED_ROBOTS: - raise ValueError(f"Robot {cfg.robot.type} not yet supported!") + # TODO: Assert if checking robot support is still needed with the plugin system + # if cfg.robot.type not in SUPPORTED_ROBOTS: + # raise ValueError(f"Robot {cfg.robot.type} not yet supported!") client = RobotClient(cfg) @@ -512,4 +508,5 @@ def async_client(cfg: RobotClientConfig): if __name__ == "__main__": + register_third_party_plugins() async_client() # run the client diff --git a/src/lerobot/datasets/card_template.md b/src/lerobot/datasets/card_template.md index ee26a78f5..1eced9f4c 100644 --- a/src/lerobot/datasets/card_template.md +++ b/src/lerobot/datasets/card_template.md @@ -7,6 +7,13 @@ This dataset was created using [LeRobot](https://github.com/huggingface/lerobot). +{% if repo_id is defined and repo_id %} + + + + +{% endif %} + ## Dataset Description {{ dataset_description | default("", true) }} diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index 123d455c6..b62d7d959 100644 --- a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -567,20 +567,22 @@ def _copy_and_reindex_data( def _keep_episodes_from_video_with_av( input_path: Path, output_path: Path, - episodes_to_keep: list[tuple[float, float]], + episodes_to_keep: list[tuple[int, int]], fps: float, vcodec: str = "libsvtav1", pix_fmt: str = "yuv420p", ) -> None: """Keep only specified episodes from a video file using PyAV. - This function decodes frames from specified time ranges and re-encodes them with + This function decodes frames from specified frame ranges and re-encodes them with properly reset timestamps to ensure monotonic progression. Args: input_path: Source video file path. output_path: Destination video file path. - episodes_to_keep: List of (start_time, end_time) tuples for episodes to keep. + episodes_to_keep: List of (start_frame, end_frame) tuples for episodes to keep. + Ranges are half-open intervals: [start_frame, end_frame), where start_frame + is inclusive and end_frame is exclusive. fps: Frame rate of the video. vcodec: Video codec to use for encoding. pix_fmt: Pixel format for output video. @@ -622,9 +624,10 @@ def _keep_episodes_from_video_with_av( # Create set of (start, end) ranges for fast lookup. # Convert to a sorted list for efficient checking. - time_ranges = sorted(episodes_to_keep) + frame_ranges = sorted(episodes_to_keep) # Track frame index for setting PTS and current range being processed. + src_frame_count = 0 frame_count = 0 range_idx = 0 @@ -634,21 +637,20 @@ def _keep_episodes_from_video_with_av( if frame is None: continue - # Get frame timestamp. - frame_time = float(frame.pts * frame.time_base) if frame.pts is not None else 0.0 - - # Check if frame is in any of our desired time ranges. + # Check if frame is in any of our desired frame ranges. # Skip ranges that have already passed. - while range_idx < len(time_ranges) and frame_time >= time_ranges[range_idx][1]: + while range_idx < len(frame_ranges) and src_frame_count >= frame_ranges[range_idx][1]: range_idx += 1 # If we've passed all ranges, stop processing. - if range_idx >= len(time_ranges): + if range_idx >= len(frame_ranges): break # Check if frame is in current range. - start_ts, end_ts = time_ranges[range_idx] - if frame_time < start_ts: + start_frame = frame_ranges[range_idx][0] + + if src_frame_count < start_frame: + src_frame_count += 1 continue # Frame is in range - create a new frame with reset timestamps. @@ -661,6 +663,7 @@ def _keep_episodes_from_video_with_av( for pkt in v_out.encode(new_frame): out.mux(pkt) + src_frame_count += 1 frame_count += 1 # Flush encoder. @@ -749,15 +752,17 @@ def _copy_and_reindex_videos( f"videos/{video_key}/to_timestamp" ] else: - # Build list of time ranges to keep, in sorted order. + # Build list of frame ranges to keep, in sorted order. sorted_keep_episodes = sorted(episodes_in_file, key=lambda x: episode_mapping[x]) - episodes_to_keep_ranges: list[tuple[float, float]] = [] - + episodes_to_keep_ranges: list[tuple[int, int]] = [] for old_idx in sorted_keep_episodes: src_ep = src_dataset.meta.episodes[old_idx] - from_ts = src_ep[f"videos/{video_key}/from_timestamp"] - to_ts = src_ep[f"videos/{video_key}/to_timestamp"] - episodes_to_keep_ranges.append((from_ts, to_ts)) + from_frame = round(src_ep[f"videos/{video_key}/from_timestamp"] * src_dataset.meta.fps) + to_frame = round(src_ep[f"videos/{video_key}/to_timestamp"] * src_dataset.meta.fps) + assert src_ep["length"] == to_frame - from_frame, ( + f"Episode length mismatch: {src_ep['length']} vs {to_frame - from_frame}" + ) + episodes_to_keep_ranges.append((from_frame, to_frame)) # Use PyAV filters to efficiently re-encode only the desired segments. assert src_dataset.meta.video_path is not None diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 83d452a44..bb526740e 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -747,7 +747,7 @@ class LeRobotDataset(torch.utils.data.Dataset): # Check if cached dataset contains all requested episodes if not self._check_cached_episodes_sufficient(): raise FileNotFoundError("Cached dataset doesn't contain all requested episodes") - except (AssertionError, FileNotFoundError, NotADirectoryError): + except (FileNotFoundError, NotADirectoryError): if is_valid_version(self.revision): self.revision = get_safe_version(self.repo_id, self.revision) self.download(download_videos) @@ -839,7 +839,7 @@ class LeRobotDataset(torch.utils.data.Dataset): hub_api.upload_folder(**upload_kwargs) card = create_lerobot_dataset_card( - tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs + tags=tags, dataset_info=self.meta.info, license=license, repo_id=self.repo_id, **card_kwargs ) card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch) @@ -1771,11 +1771,12 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): ) for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True): extra_keys = set(ds.features).difference(intersection_features) - logging.warning( - f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the " - "other datasets." - ) - self.disabled_features.update(extra_keys) + if extra_keys: + logging.warning( + f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the " + "other datasets." + ) + self.disabled_features.update(extra_keys) self.image_transforms = image_transforms self.delta_timestamps = delta_timestamps diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index acc24a9e0..8c8494b87 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -227,16 +227,17 @@ def decode_video_frames_torchvision( min_, argmin_ = dist.min(1) is_within_tol = min_ < tolerance_s - assert is_within_tol.all(), ( - f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})." - "It means that the closest frame that can be loaded from the video is too far away in time." - "This might be due to synchronization issues with timestamps during data collection." - "To be safe, we advise to ignore this item during training." - f"\nqueried timestamps: {query_ts}" - f"\nloaded timestamps: {loaded_ts}" - f"\nvideo: {video_path}" - f"\nbackend: {backend}" - ) + if not is_within_tol.all(): + raise FrameTimestampError( + f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})." + " It means that the closest frame that can be loaded from the video is too far away in time." + " This might be due to synchronization issues with timestamps during data collection." + " To be safe, we advise to ignore this item during training." + f"\nqueried timestamps: {query_ts}" + f"\nloaded timestamps: {loaded_ts}" + f"\nvideo: {video_path}" + f"\nbackend: {backend}" + ) # get closest frames to the query timestamps closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_]) @@ -248,7 +249,11 @@ def decode_video_frames_torchvision( # convert to the pytorch format which is float32 in [0,1] range (and channel first) closest_frames = closest_frames.type(torch.float32) / 255 - assert len(timestamps) == len(closest_frames) + if len(timestamps) != len(closest_frames): + raise FrameTimestampError( + f"Number of retrieved frames ({len(closest_frames)}) does not match " + f"number of queried timestamps ({len(timestamps)})" + ) return closest_frames @@ -353,15 +358,16 @@ def decode_video_frames_torchcodec( min_, argmin_ = dist.min(1) is_within_tol = min_ < tolerance_s - assert is_within_tol.all(), ( - f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})." - "It means that the closest frame that can be loaded from the video is too far away in time." - "This might be due to synchronization issues with timestamps during data collection." - "To be safe, we advise to ignore this item during training." - f"\nqueried timestamps: {query_ts}" - f"\nloaded timestamps: {loaded_ts}" - f"\nvideo: {video_path}" - ) + if not is_within_tol.all(): + raise FrameTimestampError( + f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})." + " It means that the closest frame that can be loaded from the video is too far away in time." + " This might be due to synchronization issues with timestamps during data collection." + " To be safe, we advise to ignore this item during training." + f"\nqueried timestamps: {query_ts}" + f"\nloaded timestamps: {loaded_ts}" + f"\nvideo: {video_path}" + ) # get closest frames to the query timestamps closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_]) diff --git a/src/lerobot/policies/diffusion/configuration_diffusion.py b/src/lerobot/policies/diffusion/configuration_diffusion.py index 8ac0920dd..3d30e0941 100644 --- a/src/lerobot/policies/diffusion/configuration_diffusion.py +++ b/src/lerobot/policies/diffusion/configuration_diffusion.py @@ -139,6 +139,10 @@ class DiffusionConfig(PreTrainedConfig): # Inference num_inference_steps: int | None = None + # Optimization + compile_model: bool = False + compile_mode: str = "reduce-overhead" + # Loss computation do_mask_loss_for_padding: bool = False diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index 1fdc76f10..314ca369c 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -142,6 +142,9 @@ class DiffusionPolicy(PreTrainedPolicy): """Run the batch through the model and compute the loss for training or validation.""" if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + for key in self.config.image_features: + if self.config.n_obs_steps == 1 and batch[key].ndim == 4: + batch[key] = batch[key].unsqueeze(1) batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) loss = self.diffusion.compute_loss(batch) # no output_dict so returning None @@ -182,6 +185,11 @@ class DiffusionModel(nn.Module): self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps) + if config.compile_model: + # Compile the U-Net. "reduce-overhead" is preferred for the small-batch repetitive loops + # common in diffusion inference. + self.unet = torch.compile(self.unet, mode=config.compile_mode) + self.noise_scheduler = _make_noise_scheduler( config.noise_scheduler_type, num_train_timesteps=config.num_train_timesteps, diff --git a/src/lerobot/policies/sarm/processor_sarm.py b/src/lerobot/policies/sarm/processor_sarm.py index 5c617282a..8f2bc23db 100644 --- a/src/lerobot/policies/sarm/processor_sarm.py +++ b/src/lerobot/policies/sarm/processor_sarm.py @@ -277,9 +277,7 @@ class SARMEncodingProcessorStep(ProcessorStep): # When language is perturbed, targets are zero so perturbed samples don't contribute to progress loss if self.dataset_meta is not None: - episodes_df = None - if self.sparse_subtask_names != ["task"]: - episodes_df = self.dataset_meta.episodes.to_pandas() + episodes_df = self.dataset_meta.episodes.to_pandas() # Generate sparse targets if self.sparse_temporal_proportions is not None: diff --git a/src/lerobot/policies/smolvla/smolvlm_with_expert.py b/src/lerobot/policies/smolvla/smolvlm_with_expert.py index 555c40773..caca41dab 100644 --- a/src/lerobot/policies/smolvla/smolvlm_with_expert.py +++ b/src/lerobot/policies/smolvla/smolvlm_with_expert.py @@ -77,7 +77,6 @@ class SmolVLMWithExpertModel(nn.Module): print(f"Loading {model_id} weights ...") self.vlm = AutoModelForImageTextToText.from_pretrained( model_id, - device_map=device, torch_dtype="bfloat16", low_cpu_mem_usage=True, ) diff --git a/src/lerobot/scripts/lerobot_setup_motors.py b/src/lerobot/scripts/lerobot_setup_motors.py index 01af95b61..2c962a6e2 100644 --- a/src/lerobot/scripts/lerobot_setup_motors.py +++ b/src/lerobot/scripts/lerobot_setup_motors.py @@ -43,6 +43,7 @@ from lerobot.teleoperators import ( # noqa: F401 koch_leader, make_teleoperator_from_config, omx_leader, + openarm_mini, so_leader, ) @@ -51,6 +52,7 @@ COMPATIBLE_DEVICES = [ "koch_leader", "omx_follower", "omx_leader", + "openarm_mini", "so100_follower", "so100_leader", "so101_follower", diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 93b99e245..465cbf531 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -24,6 +24,7 @@ import torch from accelerate import Accelerator from termcolor import colored from torch.optim import Optimizer +from tqdm import tqdm from lerobot.configs import parser from lerobot.configs.train import TrainPipelineConfig @@ -51,6 +52,7 @@ from lerobot.utils.utils import ( format_big_number, has_method, init_logging, + inside_slurm, ) @@ -390,6 +392,14 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): ) if is_main_process: + progbar = tqdm( + total=cfg.steps - step, + desc="Training", + unit="step", + disable=inside_slurm(), + position=0, + leave=True, + ) logging.info( f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}" ) @@ -414,6 +424,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we # increment `step` here. step += 1 + if is_main_process: + progbar.update(1) train_tracker.step() is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps @@ -507,6 +519,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): accelerator.wait_for_everyone() + if is_main_process: + progbar.close() + if eval_env: close_envs(eval_env) diff --git a/src/lerobot/teleoperators/openarm_mini/__init__.py b/src/lerobot/teleoperators/openarm_mini/__init__.py new file mode 100644 index 000000000..8620af1d7 --- /dev/null +++ b/src/lerobot/teleoperators/openarm_mini/__init__.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config_openarm_mini import OpenArmMiniConfig +from .openarm_mini import OpenArmMini + +__all__ = ["OpenArmMini", "OpenArmMiniConfig"] diff --git a/src/lerobot/teleoperators/openarm_mini/config_openarm_mini.py b/src/lerobot/teleoperators/openarm_mini/config_openarm_mini.py new file mode 100644 index 000000000..7dc3e0212 --- /dev/null +++ b/src/lerobot/teleoperators/openarm_mini/config_openarm_mini.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from ..config import TeleoperatorConfig + + +@TeleoperatorConfig.register_subclass("openarm_mini") +@dataclass +class OpenArmMiniConfig(TeleoperatorConfig): + """Configuration for OpenArm Mini teleoperator with Feetech motors (dual arms).""" + + port_right: str = "/dev/ttyUSB0" + port_left: str = "/dev/ttyUSB1" + + use_degrees: bool = True diff --git a/src/lerobot/teleoperators/openarm_mini/openarm_mini.py b/src/lerobot/teleoperators/openarm_mini/openarm_mini.py new file mode 100644 index 000000000..3fbcecf24 --- /dev/null +++ b/src/lerobot/teleoperators/openarm_mini/openarm_mini.py @@ -0,0 +1,296 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from typing import Any + +from lerobot.motors import Motor, MotorCalibration, MotorNormMode +from lerobot.motors.feetech import ( + FeetechMotorsBus, + OperatingMode, +) +from lerobot.processor import RobotAction +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected + +from ..teleoperator import Teleoperator +from .config_openarm_mini import OpenArmMiniConfig + +logger = logging.getLogger(__name__) + +# Motors whose direction is inverted during readout +RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5"] +LEFT_MOTORS_TO_FLIP = ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"] + + +class OpenArmMini(Teleoperator): + """ + OpenArm Mini Teleoperator with dual Feetech-based arms (8 motors per arm). + + Each arm has 7 joints plus a gripper, using Feetech STS3215 servos. + """ + + config_class = OpenArmMiniConfig + name = "openarm_mini" + + def __init__(self, config: OpenArmMiniConfig): + super().__init__(config) + self.config = config + + norm_mode_body = MotorNormMode.DEGREES + + motors_right = { + "joint_1": Motor(1, "sts3215", norm_mode_body), + "joint_2": Motor(2, "sts3215", norm_mode_body), + "joint_3": Motor(3, "sts3215", norm_mode_body), + "joint_4": Motor(4, "sts3215", norm_mode_body), + "joint_5": Motor(5, "sts3215", norm_mode_body), + "joint_6": Motor(6, "sts3215", norm_mode_body), + "joint_7": Motor(7, "sts3215", norm_mode_body), + "gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100), + } + + motors_left = { + "joint_1": Motor(1, "sts3215", norm_mode_body), + "joint_2": Motor(2, "sts3215", norm_mode_body), + "joint_3": Motor(3, "sts3215", norm_mode_body), + "joint_4": Motor(4, "sts3215", norm_mode_body), + "joint_5": Motor(5, "sts3215", norm_mode_body), + "joint_6": Motor(6, "sts3215", norm_mode_body), + "joint_7": Motor(7, "sts3215", norm_mode_body), + "gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100), + } + + cal_right = { + k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_") + } + cal_left = { + k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_") + } + + self.bus_right = FeetechMotorsBus( + port=self.config.port_right, + motors=motors_right, + calibration=cal_right, + ) + + self.bus_left = FeetechMotorsBus( + port=self.config.port_left, + motors=motors_left, + calibration=cal_left, + ) + + @property + def action_features(self) -> dict[str, type]: + features: dict[str, type] = {} + for motor in self.bus_right.motors: + features[f"right_{motor}.pos"] = float + for motor in self.bus_left.motors: + features[f"left_{motor}.pos"] = float + return features + + @property + def feedback_features(self) -> dict[str, type]: + return {} + + @property + def is_connected(self) -> bool: + return self.bus_right.is_connected and self.bus_left.is_connected + + @check_if_already_connected + def connect(self, calibrate: bool = True) -> None: + logger.info(f"Connecting right arm on {self.config.port_right}...") + self.bus_right.connect() + logger.info(f"Connecting left arm on {self.config.port_left}...") + self.bus_left.connect() + + if calibrate: + self.calibrate() + + self.configure() + logger.info(f"{self} connected.") + + @property + def is_calibrated(self) -> bool: + return self.bus_right.is_calibrated and self.bus_left.is_calibrated + + def calibrate(self) -> None: + """ + Run calibration procedure for OpenArm Mini. + + 1. Disable torque + 2. Ask user to position arms in hanging position with grippers closed + 3. Set this as zero position via half-turn homing + 4. Interactive gripper calibration (open/close positions) + 5. Save calibration + """ + if self.calibration: + user_input = input( + f"Press ENTER to use existing calibration for {self.id}, " + f"or type 'c' and press ENTER to run new calibration: " + ) + if user_input.strip().lower() != "c": + logger.info(f"Using existing calibration for {self.id}") + cal_right = { + k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_") + } + cal_left = { + k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_") + } + self.bus_right.write_calibration(cal_right) + self.bus_left.write_calibration(cal_left) + return + + logger.info(f"\nRunning calibration for {self}") + + self._calibrate_arm("right", self.bus_right) + self._calibrate_arm("left", self.bus_left) + + self._save_calibration() + print(f"\nCalibration complete and saved to {self.calibration_fpath}") + + def _calibrate_arm(self, arm_name: str, bus: FeetechMotorsBus) -> None: + """Calibrate a single arm with Feetech motors.""" + logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===") + + bus.disable_torque() + + logger.info(f"Setting Phase to 12 for all motors in {arm_name.upper()} arm...") + for motor in bus.motors: + bus.write("Phase", motor, 12) + + for motor in bus.motors: + bus.write("Operating_Mode", motor, OperatingMode.POSITION.value) + + input( + f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n" + "Position the arm in the following configuration:\n" + " - Arm hanging straight down\n" + " - Gripper closed\n" + "Press ENTER when ready..." + ) + + homing_offsets = bus.set_half_turn_homings() + logger.info(f"{arm_name.capitalize()} arm zero position set.") + + print(f"\nSetting motor ranges for {arm_name.upper()} arm\n") + + if self.calibration is None: + self.calibration = {} + + motor_resolution = bus.model_resolution_table[list(bus.motors.values())[0].model] + max_res = motor_resolution - 1 + + for motor_name, motor in bus.motors.items(): + prefixed_name = f"{arm_name}_{motor_name}" + + if motor_name == "gripper": + input( + f"\nGripper Calibration ({arm_name.upper()} arm)\n" + f"Step 1: CLOSE the gripper fully\n" + f"Press ENTER when gripper is closed..." + ) + closed_pos = bus.read("Present_Position", motor_name, normalize=False) + logger.info(f" Gripper closed position recorded: {closed_pos}") + + input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...") + open_pos = bus.read("Present_Position", motor_name, normalize=False) + logger.info(f" Gripper open position recorded: {open_pos}") + + if closed_pos < open_pos: + range_min = int(closed_pos) + range_max = int(open_pos) + drive_mode = 0 + else: + range_min = int(open_pos) + range_max = int(closed_pos) + drive_mode = 1 + + logger.info( + f" {prefixed_name}: range set to [{range_min}, {range_max}] " + f"(0=closed, 100=open, drive_mode={drive_mode})" + ) + else: + range_min = 0 + range_max = max_res + drive_mode = 0 + logger.info(f" {prefixed_name}: range set to [0, {max_res}] (full motor range)") + + self.calibration[prefixed_name] = MotorCalibration( + id=motor.id, + drive_mode=drive_mode, + homing_offset=homing_offsets[motor_name], + range_min=range_min, + range_max=range_max, + ) + + cal_for_bus = { + k.replace(f"{arm_name}_", ""): v + for k, v in self.calibration.items() + if k.startswith(f"{arm_name}_") + } + bus.write_calibration(cal_for_bus) + + def configure(self) -> None: + self.bus_right.disable_torque() + self.bus_right.configure_motors() + for motor in self.bus_right.motors: + self.bus_right.write("Operating_Mode", motor, OperatingMode.POSITION.value) + + self.bus_left.disable_torque() + self.bus_left.configure_motors() + for motor in self.bus_left.motors: + self.bus_left.write("Operating_Mode", motor, OperatingMode.POSITION.value) + + def setup_motors(self) -> None: + print("\nSetting up RIGHT arm motors...") + for motor in reversed(self.bus_right.motors): + input(f"Connect the controller board to the RIGHT '{motor}' motor only and press enter.") + self.bus_right.setup_motor(motor) + print(f"RIGHT '{motor}' motor id set to {self.bus_right.motors[motor].id}") + + print("\nSetting up LEFT arm motors...") + for motor in reversed(self.bus_left.motors): + input(f"Connect the controller board to the LEFT '{motor}' motor only and press enter.") + self.bus_left.setup_motor(motor) + print(f"LEFT '{motor}' motor id set to {self.bus_left.motors[motor].id}") + + @check_if_not_connected + def get_action(self) -> RobotAction: + """Get current action from both arms (read positions from all motors).""" + start = time.perf_counter() + + right_positions = self.bus_right.sync_read("Present_Position") + left_positions = self.bus_left.sync_read("Present_Position") + + action: dict[str, Any] = {} + for motor, val in right_positions.items(): + action[f"right_{motor}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val + for motor, val in left_positions.items(): + action[f"left_{motor}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val + + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read action: {dt_ms:.1f}ms") + return action + + def send_feedback(self, feedback: dict[str, float]) -> None: + raise NotImplementedError("Feedback is not yet implemented for OpenArm Mini.") + + @check_if_not_connected + def disconnect(self) -> None: + self.bus_right.disconnect() + self.bus_left.disconnect() + logger.info(f"{self} disconnected.") diff --git a/src/lerobot/teleoperators/utils.py b/src/lerobot/teleoperators/utils.py index 16454d5ad..db685f396 100644 --- a/src/lerobot/teleoperators/utils.py +++ b/src/lerobot/teleoperators/utils.py @@ -95,6 +95,10 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator": from .bi_openarm_leader import BiOpenArmLeader return BiOpenArmLeader(config) + elif config.type == "openarm_mini": + from .openarm_mini import OpenArmMini + + return OpenArmMini(config) else: try: return cast("Teleoperator", make_device_from_device_class(config)) diff --git a/src/lerobot/utils/control_utils.py b/src/lerobot/utils/control_utils.py index 7cfe177ef..7c605af17 100644 --- a/src/lerobot/utils/control_utils.py +++ b/src/lerobot/utils/control_utils.py @@ -189,7 +189,7 @@ def sanity_check_dataset_name(repo_id, policy_cfg): # Check if dataset_name starts with "eval_" but policy is missing if dataset_name.startswith("eval_") and policy_cfg is None: raise ValueError( - f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided ({policy_cfg.type})." + f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided." ) # Check if dataset_name does not start with "eval_" but policy is provided