From d36dfcdf7155debdaa17cc8fbae35eb150dc965f Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Tue, 20 Jan 2026 15:00:45 +0100 Subject: [PATCH 001/131] fix(discord link): fixing discord link in CONTRIBUTING.md (#2826) Signed-off-by: Caroline Pascal --- CONTRIBUTING.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index abca0d821..c51a48831 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -14,7 +14,7 @@ You can contribute in many ways: - **Documentation:** Improve examples, guides, and docstrings. - **Feedback:** Submit tickets related to bugs or desired new features. -If you are unsure where to start, join our [Discord Channel](https://discord.gg/JkrYNdmw). +If you are unsure where to start, join our [Discord Channel](https://discord.gg/q8Dzzpym3f). ## Development Setup From 9919b16b3693a756388dd2cccdd04396308fd0ae Mon Sep 17 00:00:00 2001 From: sato_shinji Date: Tue, 20 Jan 2026 14:17:38 +0000 Subject: [PATCH 002/131] fix: ensure action tensors are moved to client_device in async training (#2792) * feat(async_inference): server always sends CPU tensors, client handles device conversion * fix:fix the type annotation of RawObservation in src/lerobot/async_inference/helpers.py * update the import of robot_client --------- Co-authored-by: Sato shinji Co-authored-by: Steven Palma Co-authored-by: KB --- docs/source/async.mdx | 1 + examples/tutorial/async-inf/robot_client.py | 1 + src/lerobot/async_inference/configs.py | 10 ++++++++++ src/lerobot/async_inference/helpers.py | 5 +++-- src/lerobot/async_inference/policy_server.py | 2 ++ src/lerobot/async_inference/robot_client.py | 20 ++++++++++++++++++-- tests/async_inference/test_e2e.py | 10 ++++++++-- 7 files changed, 43 insertions(+), 6 deletions(-) diff --git a/docs/source/async.mdx b/docs/source/async.mdx index 1d3e0edbf..3244fc2a3 100644 --- a/docs/source/async.mdx +++ b/docs/source/async.mdx @@ -195,6 +195,7 @@ client_cfg = RobotClientConfig( robot=robot_cfg, server_address="localhost:8080", policy_device="mps", + client_device="cpu", policy_type="smolvla", pretrained_name_or_path="/smolvla_async", chunk_size_threshold=0.5, diff --git a/examples/tutorial/async-inf/robot_client.py b/examples/tutorial/async-inf/robot_client.py index eb3751169..db6ead3fe 100644 --- a/examples/tutorial/async-inf/robot_client.py +++ b/examples/tutorial/async-inf/robot_client.py @@ -30,6 +30,7 @@ def main(): robot=robot_cfg, server_address=server_address, policy_device="mps", + client_device="cpu", policy_type="act", pretrained_name_or_path="/robot_learning_tutorial_act", chunk_size_threshold=0.5, # g diff --git a/src/lerobot/async_inference/configs.py b/src/lerobot/async_inference/configs.py index d1768a323..2e3fe576d 100644 --- a/src/lerobot/async_inference/configs.py +++ b/src/lerobot/async_inference/configs.py @@ -126,6 +126,12 @@ class RobotClientConfig: # Device configuration policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"}) + client_device: str = field( + default="cpu", + metadata={ + "help": "Device to move actions to after receiving from server (e.g., for downstream planners)" + }, + ) # Control behavior configuration chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"}) @@ -161,6 +167,9 @@ class RobotClientConfig: if not self.policy_device: raise ValueError("policy_device cannot be empty") + if not self.client_device: + raise ValueError("client_device cannot be empty") + if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1: raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}") @@ -184,6 +193,7 @@ class RobotClientConfig: "policy_type": self.policy_type, "pretrained_name_or_path": self.pretrained_name_or_path, "policy_device": self.policy_device, + "client_device": self.client_device, "chunk_size_threshold": self.chunk_size_threshold, "fps": self.fps, "actions_per_chunk": self.actions_per_chunk, diff --git a/src/lerobot/async_inference/helpers.py b/src/lerobot/async_inference/helpers.py index 2158f51ac..8b12920d9 100644 --- a/src/lerobot/async_inference/helpers.py +++ b/src/lerobot/async_inference/helpers.py @@ -18,6 +18,7 @@ import os import time from dataclasses import dataclass, field from pathlib import Path +from typing import Any import torch @@ -39,8 +40,8 @@ from lerobot.utils.utils import init_logging Action = torch.Tensor -# observation as received from the robot -RawObservation = dict[str, torch.Tensor] +# observation as received from the robot (can be numpy arrays, floats, etc.) +RawObservation = dict[str, Any] # observation as those recorded in LeRobot dataset (keys are different) LeRobotObservation = dict[str, torch.Tensor] diff --git a/src/lerobot/async_inference/policy_server.py b/src/lerobot/async_inference/policy_server.py index ab2e6bcd8..aedce2a74 100644 --- a/src/lerobot/async_inference/policy_server.py +++ b/src/lerobot/async_inference/policy_server.py @@ -381,6 +381,8 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer): action_tensor = torch.stack(processed_actions, dim=1).squeeze(0) self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}") + action_tensor = action_tensor.detach().cpu() + """5. Convert to TimedAction list""" action_chunk = self._time_action_chunk( observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep() diff --git a/src/lerobot/async_inference/robot_client.py b/src/lerobot/async_inference/robot_client.py index f26639dc1..e4d21652a 100644 --- a/src/lerobot/async_inference/robot_client.py +++ b/src/lerobot/async_inference/robot_client.py @@ -25,6 +25,7 @@ python src/lerobot/async_inference/robot_client.py \ --policy_type=act \ --pretrained_name_or_path=user/model \ --policy_device=mps \ + --client_device=cpu \ --actions_per_chunk=50 \ --chunk_size_threshold=0.5 \ --aggregate_fn_name=weighted_average \ @@ -40,6 +41,7 @@ from collections.abc import Callable from dataclasses import asdict from pprint import pformat from queue import Queue +from typing import Any import draccus import grpc @@ -47,7 +49,6 @@ import torch from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 -from lerobot.processor import RobotAction from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -285,6 +286,21 @@ class RobotClient: timed_actions = pickle.loads(actions_chunk.data) # nosec deserialize_time = time.perf_counter() - deserialize_start + # Log device type of received actions + if len(timed_actions) > 0: + received_device = timed_actions[0].get_action().device.type + self.logger.debug(f"Received actions on device: {received_device}") + + # Move actions to client_device (e.g., for downstream planners that need GPU) + client_device = self.config.client_device + if client_device != "cpu": + for timed_action in timed_actions: + if timed_action.get_action().device.type != client_device: + timed_action.action = timed_action.get_action().to(client_device) + self.logger.debug(f"Converted actions to device: {client_device}") + else: + self.logger.debug(f"Actions kept on device: {client_device}") + self.action_chunk_size = max(self.action_chunk_size, len(timed_actions)) # Calculate network latency if we have matching observations @@ -351,7 +367,7 @@ class RobotClient: action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)} return action - def control_loop_action(self, verbose: bool = False) -> RobotAction: + def control_loop_action(self, verbose: bool = False) -> dict[str, Any]: """Reading and performing actions in local queue""" # Lock only for queue operations diff --git a/tests/async_inference/test_e2e.py b/tests/async_inference/test_e2e.py index 11941ce32..54ca29b48 100644 --- a/tests/async_inference/test_e2e.py +++ b/tests/async_inference/test_e2e.py @@ -144,12 +144,18 @@ def test_async_inference_e2e(monkeypatch): client = RobotClient(client_config) assert client.start(), "Client failed initial handshake with the server" - # Track action chunks received without modifying RobotClient - action_chunks_received = {"count": 0} + # Track action chunks received and verify device type + action_chunks_received = {"count": 0, "actions_on_cpu": True} original_aggregate = client._aggregate_action_queues def counting_aggregate(*args, **kwargs): action_chunks_received["count"] += 1 + # Check that all received actions are on CPU + if args: + for timed_action in args[0]: # args[0] is the list of TimedAction + action_tensor = timed_action.get_action() + if action_tensor.device.type != "cpu": + action_chunks_received["actions_on_cpu"] = False return original_aggregate(*args, **kwargs) monkeypatch.setattr(client, "_aggregate_action_queues", counting_aggregate) From 9ca680dce28f2a58c6af12a62980b1ae86b1659f Mon Sep 17 00:00:00 2001 From: Tommy in Tongji <36354458+TommyZihao@users.noreply.github.com> Date: Wed, 21 Jan 2026 00:54:24 +0800 Subject: [PATCH 003/131] Update README.md (#2827) Add Chinese doc link. Signed-off-by: Tommy in Tongji <36354458+TommyZihao@users.noreply.github.com> --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 57fec2e5f..d60cd35a9 100644 --- a/README.md +++ b/README.md @@ -128,6 +128,7 @@ Learn how to implement your own simulation environment or benchmark and distribu ## Resources - **[Documentation](https://huggingface.co/docs/lerobot/index):** The complete guide to tutorials & API. +- **[Chinese Tutorials: LeRobot+SO-ARM101中文教程-同济子豪兄](https://zihao-ai.feishu.cn/wiki/space/7589642043471924447)** Detailed doc for assembling, teleoperate, dataset, train, deploy. Verified by Seed Studio and 5 global hackathon players. - **[Discord](https://discord.gg/q8Dzzpym3f):** Join the `LeRobot` server to discuss with the community. - **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments. - **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot. From 0b067df57d21d3a02d6c511f1609172fa39ac29b Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 20 Jan 2026 18:02:38 +0100 Subject: [PATCH 004/131] feat(robots): add context managers (#2828) --- src/lerobot/robots/robot.py | 26 +++++++++++++++++++++++ src/lerobot/teleoperators/teleoperator.py | 26 +++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/src/lerobot/robots/robot.py b/src/lerobot/robots/robot.py index d1021daf4..d165886b9 100644 --- a/src/lerobot/robots/robot.py +++ b/src/lerobot/robots/robot.py @@ -58,6 +58,32 @@ class Robot(abc.ABC): def __str__(self) -> str: return f"{self.id} {self.__class__.__name__}" + def __enter__(self): + """ + Context manager entry. + Automatically connects to the camera. + """ + self.connect() + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + """ + Context manager exit. + Automatically disconnects, ensuring resources are released even on error. + """ + self.disconnect() + + def __del__(self) -> None: + """ + Destructor safety net. + Attempts to disconnect if the object is garbage collected without cleanup. + """ + try: + if self.is_connected: + self.disconnect() + except Exception: # nosec B110 + pass + # TODO(aliberts): create a proper Feature class for this that links with datasets @property @abc.abstractmethod diff --git a/src/lerobot/teleoperators/teleoperator.py b/src/lerobot/teleoperators/teleoperator.py index cd9e3a53d..847b88b7f 100644 --- a/src/lerobot/teleoperators/teleoperator.py +++ b/src/lerobot/teleoperators/teleoperator.py @@ -58,6 +58,32 @@ class Teleoperator(abc.ABC): def __str__(self) -> str: return f"{self.id} {self.__class__.__name__}" + def __enter__(self): + """ + Context manager entry. + Automatically connects to the camera. + """ + self.connect() + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + """ + Context manager exit. + Automatically disconnects, ensuring resources are released even on error. + """ + self.disconnect() + + def __del__(self) -> None: + """ + Destructor safety net. + Attempts to disconnect if the object is garbage collected without cleanup. + """ + try: + if self.is_connected: + self.disconnect() + except Exception: # nosec B110 + pass + @property @abc.abstractmethod def action_features(self) -> dict: From 961277d86e2467257c3916257baa4817632f2ed5 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 22 Jan 2026 12:24:12 +0100 Subject: [PATCH 005/131] chore(dependencies): Bump lerobot to 0.4.4 (#2840) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index fa4b22bdf..75f617e75 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb" [project] name = "lerobot" -version = "0.4.3" +version = "0.4.4" description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch" dynamic = ["readme"] license = { text = "Apache-2.0" } From 6d34a986de44c5f22a9a99ed514f1b16832c3f32 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 22 Jan 2026 12:26:17 +0100 Subject: [PATCH 006/131] feat(ci): trigger manually documentation release version (#2841) --- .github/workflows/documentation.yml | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index 48a10e4bc..c7926c542 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -18,6 +18,11 @@ name: Documentation on: # Allows running this workflow manually from the Actions tab workflow_dispatch: + inputs: + version: + description: 'Version tag (e.g. v0.1.2) - Leave empty for standard main build' + required: false + type: string # Triggers the workflow on push events to main for the docs folder push: @@ -54,7 +59,13 @@ jobs: with: commit_sha: ${{ github.sha }} package: lerobot - additional_args: --not_python_module ${{ github.event_name == 'release' && format('--version {0}', github.event.release.tag_name) || '' }} + additional_args: >- + --not_python_module + ${{ + (github.event_name == 'release' && format('--version {0}', github.event.release.tag_name)) || + (inputs.version != '' && format('--version {0}', inputs.version)) || + '' + }} secrets: token: ${{ secrets.HUGGINGFACE_PUSH }} hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} From 9e10eb4a77856fb2dfb6784c4afc3984748387f7 Mon Sep 17 00:00:00 2001 From: Woojin Wie Date: Mon, 26 Jan 2026 06:29:37 +0900 Subject: [PATCH 007/131] fix(robots): update gripper configuration and calibration settings for OMX (#2815) --- docs/source/_toctree.yml | 2 + docs/source/omx.mdx | 197 ++++++++++++++++++ .../omx_leader/config_omx_leader.py | 2 +- .../teleoperators/omx_leader/omx_leader.py | 10 +- 4 files changed, 209 insertions(+), 2 deletions(-) create mode 100644 docs/source/omx.mdx diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 2b8086cd7..4298758b5 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -99,6 +99,8 @@ title: Unitree G1 - local: earthrover_mini_plus title: Earth Rover Mini + - local: omx + title: OMX title: "Robots" - sections: - local: phone_teleop diff --git a/docs/source/omx.mdx b/docs/source/omx.mdx new file mode 100644 index 000000000..4617ac7bd --- /dev/null +++ b/docs/source/omx.mdx @@ -0,0 +1,197 @@ +## Order and Assemble the parts + +First, assemble the OMX hardware following the official assembly guide. + +OMX Assembly Guide: https://ai.robotis.com/omx/assembly_guide_omx.html + +OMX robots are shipped preconfigured from the factory. Motor IDs, communication parameters, and joint offsets are already set, so no additional motor setup or calibration is required before using LeRobot. + +## Install LeRobot 🤗 + +To install LeRobot, follow our [Installation Guide](./installation) + +In addition to these instructions, you need to install the Dynamixel SDK: + +```bash +pip install -e ".[dynamixel]" +``` + +## Connect the robot + +To find the port for each bus servo adapter, run this script: + +```bash +lerobot-find-port +``` + +This command runs and when prompted, disconnect the USB cable from either the leader or follower arm and press Enter. The output will show 'The port of this MotorsBus is [port]'. This identifies the port for the disconnected arm. Repeat for the other arm to identify both ports. + + + + +Example output on macOS: + +``` +Finding all available ports for the MotorBus. +['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] +Remove the USB cable from your MotorsBus and press Enter when done. + +[...Disconnect corresponding leader or follower arm and press Enter...] + +The port of this MotorsBus is /dev/tty.usbmodem575E0032081 +Reconnect the USB cable. +``` + +Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your leader or follower arm. + + + + +On Linux, we strongly recommend using udev rules to assign persistent and human-readable device names to the OMX leader and follower arms. This avoids issues where device names such as ttyACM0 and ttyACM1 change when the robot is unplugged, replugged, or when the system is rebooted. + +#### 1. Find your device serial numbers + +You should have obtained the port numbers like ../../ttyACM? for the leader and follower using `lerobot-find-port`. You can match those results with the serial numbers using the `ls -l /dev/serial/by-id/` command. +To create udev rules, you need the unique serial number for each OMX device. The easiest way is to list devices under: + +```bash +ls -l /dev/serial/by-id/ +``` + +You will see output similar to: + +```bash +usb-ROBOTIS_OpenRB-150_228BDD7B503059384C2E3120FF0A2B19-if00 -> ../../ttyACM0 +usb-ROBOTIS_OpenRB-150_67E1ED68503059384C2E3120FF092234-if00 -> ../../ttyACM1 +``` + +In each line, the serial number is the long string after `usb-ROBOTIS_OpenRB-150_` and before `-if00`. + +Follower serial: `228BDD7B503059384C2E3120FF0A2B19` + +Leader serial: `67E1ED68503059384C2E3120FF092234` + +#### 2. Create the udev rule + +Create a new udev rule file: + +```bash +sudo nano /etc/udev/rules.d/99-omx.rules +``` + +Paste the following lines, replacing the serial numbers with the values you found above: + +```bash +SUBSYSTEM=="tty", ATTRS{idVendor}=="0403", ATTRS{serial}=="228BDD7B503059384C2E3120FF0A2B19", SYMLINK+="omx_follower" +SUBSYSTEM=="tty", ATTRS{idVendor}=="0403", ATTRS{serial}=="67E1ED68503059384C2E3120FF092234", SYMLINK+="omx_leader" +``` + +Save the file and reload udev rules: + +```bash +sudo udevadm control --reload-rules +sudo udevadm trigger +``` + +Now unplug and replug both devices once. + +#### 3. Verify the symlinks + +Check that the persistent device names exist: + +```bash +ls -l /dev/omx_follower /dev/omx_leader +``` + +You should see them pointing to ttyACM\* devices: + +```bash +/dev/omx_follower -> ttyACM* +/dev/omx_leader -> ttyACM* +``` + +These names remain stable across reboots and reconnections. + + + + +## Teleoperate + +After identifying the correct ports, you can directly teleoperate the follower arm using the leader arm. + + + + +### Teleoperate without camera + +```bash +lerobot-teleoperate \ + --robot.type=omx_follower \ + --robot.port= \ + --robot.id=omx_follower_arm \ + --teleop.type=omx_leader \ + --teleop.port= \ + --teleop.id=omx_leader_arm +``` + +During teleoperation, motions of the leader arm are mirrored in real time by the follower arm. OMX is already preconfigured, teleoperation can begin immediately without any calibration steps. + +### Teleoperate with camera + +You can also enable camera input during teleoperation by providing a camera configuration for the follower arm. + +```bash +lerobot-teleoperate \ + --robot.type=omx_follower \ + --robot.port= \ + --robot.id=omx_follower_arm \ + --robot.cameras="{front: {type: opencv, index_or_path: '/dev/video0', width: 640, height: 480, fps: 30}}" \ + --teleop.type=omx_leader \ + --teleop.port= \ + --teleop.id=omx_leader_arm \ + --display_data=true +``` + +When the camera is enabled, the camera stream is displayed in real time and synchronized with the robot state. This setup is useful for visual monitoring and can be reused later for demonstration recording and imitation learning. + + + + +### Teleoperate without camera + +```bash +lerobot-teleoperate \ + --robot.type=omx_follower \ + --robot.port=/dev/omx_follower \ + --robot.id=omx_follower_arm \ + --teleop.type=omx_leader \ + --teleop.port=/dev/omx_leader \ + --teleop.id=omx_leader_arm +``` + +During teleoperation, motions of the leader arm are mirrored in real time by the follower arm. OMX is already preconfigured, teleoperation can begin immediately without any calibration steps. + +### Teleoperate with camera + +You can also enable camera input during teleoperation by providing a camera configuration for the follower arm. + +```bash +lerobot-teleoperate \ + --robot.type=omx_follower \ + --robot.port=/dev/omx_follower \ + --robot.id=omx_follower_arm \ + --robot.cameras="{front: {type: opencv, index_or_path: '/dev/video0', width: 640, height: 480, fps: 30}}" \ + --teleop.type=omx_leader \ + --teleop.port=/dev/omx_leader \ + --teleop.id=omx_leader_arm \ + --display_data=true +``` + +When the camera is enabled, the camera stream is displayed in real time and synchronized with the robot state. This setup is useful for visual monitoring and can be reused later for demonstration recording and imitation learning. + + + + +Congrats 🎉, your robot is all set to learn a task on its own. + +> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/robotis). diff --git a/src/lerobot/teleoperators/omx_leader/config_omx_leader.py b/src/lerobot/teleoperators/omx_leader/config_omx_leader.py index 3c0420ab2..a0eca38f7 100644 --- a/src/lerobot/teleoperators/omx_leader/config_omx_leader.py +++ b/src/lerobot/teleoperators/omx_leader/config_omx_leader.py @@ -27,4 +27,4 @@ class OmxLeaderConfig(TeleoperatorConfig): # Sets the arm in torque mode with the gripper motor set to this value. This makes it possible to squeeze # the gripper and have it spring back to an open position on its own. - gripper_open_pos: float = 37.0 + gripper_open_pos: float = 60.0 diff --git a/src/lerobot/teleoperators/omx_leader/omx_leader.py b/src/lerobot/teleoperators/omx_leader/omx_leader.py index 4423be714..4264b0485 100644 --- a/src/lerobot/teleoperators/omx_leader/omx_leader.py +++ b/src/lerobot/teleoperators/omx_leader/omx_leader.py @@ -103,7 +103,7 @@ class OmxLeader(Teleoperator): self.calibration[motor] = MotorCalibration( id=m.id, drive_mode=drive_modes[motor], - homing_offset=0, + homing_offset=0 if motor != "gripper" else 100, range_min=0, range_max=4095, ) @@ -123,12 +123,20 @@ class OmxLeader(Teleoperator): # point self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value) + if motor == "gripper": + self.bus.write("Drive_Mode", motor, DriveMode.INVERTED.value) + else: + self.bus.write("Drive_Mode", motor, DriveMode.NON_INVERTED.value) + # Use 'position control current based' for gripper to be limited by the limit of the current. # For the follower gripper, it means it can grasp an object without forcing too much even tho, # its goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch). # For the leader gripper, it means we can use it as a physical trigger, since we can force with our finger # to make it move, and it will move back to its original target position when we release the force. self.bus.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value) + self.bus.write("Current_Limit", "gripper", 100) + self.bus.write("Goal_Current", "gripper", 100) + self.bus.write("Homing_Offset", "gripper", 100) # Set gripper's goal pos in current position mode so that we can use it as a trigger. self.bus.enable_torque("gripper") if self.is_calibrated: From 366bef915cdff80d0a1bcf4834130eb63a479cb5 Mon Sep 17 00:00:00 2001 From: Reece O'Mahoney <66252930+reeceomahoney@users.noreply.github.com> Date: Mon, 26 Jan 2026 16:26:49 +0000 Subject: [PATCH 008/131] add task ids to libero env cfg (#2842) --- src/lerobot/envs/configs.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index 112d3a73f..cd88b37bc 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -260,6 +260,7 @@ class HILSerlRobotEnvConfig(EnvConfig): @dataclass class LiberoEnv(EnvConfig): task: str = "libero_10" # can also choose libero_spatial, libero_object, etc. + task_ids: list[int] | None = None fps: int = 30 episode_length: int | None = None obs_type: str = "pixels_agent_pos" @@ -338,10 +339,10 @@ class LiberoEnv(EnvConfig): @property def gym_kwargs(self) -> dict: - return { - "obs_type": self.obs_type, - "render_mode": self.render_mode, - } + kwargs: dict[str, Any] = {"obs_type": self.obs_type, "render_mode": self.render_mode} + if self.task_ids is not None: + kwargs["task_ids"] = self.task_ids + return kwargs @EnvConfig.register_subclass("metaworld") From 9cfb5ce5468d0f2df568f7ba3f902f13f89802a7 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 26 Jan 2026 17:53:25 +0100 Subject: [PATCH 009/131] feat(motors): add damiao motors & can bus (#2788) * fix(motors): cleanup imports + fix signatures * feat(motors): add damiao canbus + multiple fixes * fix(motors): address comments -> last_state + different gains + sleep * refactor(motors): reduce duplicated code + adressed some comments in the PR * chore(motors): better timeouts * tests(motors): damiao test and imports * chore(deps): fix space * Apply suggestions from code review Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Signed-off-by: Steven Palma * chore(motors): remove normalization tables damiao * fix(motors): imports and signatures * feat(motors): add motor_type_str + recv_id to motor class and _get_motor_recv_id raises if no motor_obj.recv_id * chore(motors): remove normalize from base motor class and damaio * tests(motors): remove bad tests (to be replaced) * chore(motors): updated import check * use constant for kp and kd range and check responses in mit_control_batch() * Add docs on setting up canbus and use damiao otor bus, also add lerobot_setup_can.py and log if there is not response from a write command * precommit format * supress bandit as these are intentional cli commands * fix setup-can * add test * skip test in ci * nit precommit * update doc example * dont import can for tests --------- Signed-off-by: Steven Palma Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Pepijn --- docs/source/_toctree.yml | 2 + docs/source/damiao.mdx | 165 +++++ pyproject.toml | 3 + src/lerobot/motors/__init__.py | 6 +- src/lerobot/motors/calibration_gui.py | 2 +- src/lerobot/motors/damiao/__init__.py | 18 + src/lerobot/motors/damiao/damiao.py | 808 ++++++++++++++++++++++ src/lerobot/motors/damiao/tables.py | 209 ++++++ src/lerobot/motors/dynamixel/dynamixel.py | 11 +- src/lerobot/motors/feetech/feetech.py | 13 +- src/lerobot/motors/motors_bus.py | 101 ++- src/lerobot/scripts/lerobot_setup_can.py | 360 ++++++++++ src/lerobot/utils/import_utils.py | 1 + tests/motors/test_damiao.py | 66 ++ 14 files changed, 1740 insertions(+), 25 deletions(-) create mode 100644 docs/source/damiao.mdx create mode 100644 src/lerobot/motors/damiao/__init__.py create mode 100644 src/lerobot/motors/damiao/damiao.py create mode 100644 src/lerobot/motors/damiao/tables.py create mode 100644 src/lerobot/scripts/lerobot_setup_can.py create mode 100644 tests/motors/test_damiao.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 4298758b5..f86dd11c7 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -115,6 +115,8 @@ title: Notebooks - local: feetech title: Updating Feetech Firmware + - local: damiao + title: Damiao Motors and CAN Bus title: "Resources" - sections: - local: contributing diff --git a/docs/source/damiao.mdx b/docs/source/damiao.mdx new file mode 100644 index 000000000..45388ab9b --- /dev/null +++ b/docs/source/damiao.mdx @@ -0,0 +1,165 @@ +# Damiao Motors and CAN Bus + +This guide covers setup and usage of Damiao motors with LeRobot via CAN bus communication. + +Currently, only Linux is supported, as the OpenArms CAN adapter only has drivers for Linux. + +## Linux CAN Setup + +Before using Damiao motors, you need to set up the CAN interface on your Linux system. + +### Install CAN Utilities + +```bash +sudo apt-get install can-utils +``` + +### Configure CAN Interface (Manual) + +For standard CAN FD (recommended for OpenArms): + +```bash +sudo ip link set can0 down +sudo ip link set can0 type can bitrate 1000000 dbitrate 5000000 fd on +sudo ip link set can0 up +``` + +For standard CAN (without FD): + +```bash +sudo ip link set can0 down +sudo ip link set can0 type can bitrate 1000000 +sudo ip link set can0 up +``` + +### Configure CAN Interface (Using LeRobot) + +LeRobot provides a utility script to setup and test CAN interfaces: + +```bash +# Setup multiple interfaces (e.g., OpenArms Followers with 2 CAN buses) +lerobot-setup-can --mode=setup --interfaces=can0,can1 +``` + +## Debugging CAN Communication + +Use the built-in debug tools to test motor communication: + +```bash +# Test motors on all interfaces +lerobot-setup-can --mode=test --interfaces=can0,can1 + +# Run speed/latency test +lerobot-setup-can --mode=speed --interfaces=can0 +``` + +The test mode will scan for motors (IDs 0x01-0x08) and report which ones respond. Example output: + +``` +can0: UP (CAN FD) + Motor 0x01 (joint_1): ✓ FOUND + → Response 0x11 [FD]: 00112233... + Motor 0x02 (joint_2): ✓ FOUND + Motor 0x03 (joint_3): ✗ No response + ... + Summary: 2/8 motors found +``` + +## Usage + +### Basic Setup + +```python +from lerobot.motors import Motor +from lerobot.motors.damiao import DamiaoMotorsBus + +# Define your motors with send/receive CAN IDs +motors = { + "joint_1": Motor(id=0x01, motor_type_str="dm8009", recv_id=0x11), + "joint_2": Motor(id=0x02, motor_type_str="dm4340", recv_id=0x12), + "joint_3": Motor(id=0x03, motor_type_str="dm4310", recv_id=0x13), +} + +# Create the bus +bus = DamiaoMotorsBus( + port="can0", # Linux socketcan interface + motors=motors, +) + +# Connect +bus.connect() +``` + +### Reading Motor States + +```python +# Read single motor position (degrees) +position = bus.read("Present_Position", "joint_1") + +# Read from multiple motors +positions = bus.sync_read("Present_Position") # All motors +positions = bus.sync_read("Present_Position", ["joint_1", "joint_2"]) + +# Read all states at once (position, velocity, torque) +states = bus.sync_read_all_states() +# Returns: {'joint_1': {'position': 45.2, 'velocity': 1.3, 'torque': 0.5}, ...} +``` + +### Writing Motor Commands + +```python +# Enable torque +bus.enable_torque() + +# Set goal position (degrees) +bus.write("Goal_Position", "joint_1", 45.0) + +# Set positions for multiple motors +bus.sync_write("Goal_Position", { + "joint_1": 45.0, + "joint_2": -30.0, + "joint_3": 90.0, +}) + +# Disable torque +bus.disable_torque() +``` + +## Configuration Options + +| Parameter | Default | Description | +| -------------- | --------- | ----------------------------------------------------------- | +| `port` | - | CAN interface (`can0`) or serial port (`/dev/cu.usbmodem*`) | +| `use_can_fd` | `True` | Enable CAN FD for higher data rates | +| `bitrate` | `1000000` | Nominal bitrate (1 Mbps) | +| `data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) | + +## Motor Configuration + +Each motor requires: + +- `id`: CAN ID for sending commands +- `motor_type`: One of the supported motor types (e.g., `"dm8009"`, `"dm4340"`) +- `recv_id`: CAN ID for receiving responses + +OpenArms default IDs follow the pattern: send ID `0x0N`, receive ID `0x1N` where N is the joint number. + +## Troubleshooting + +### No Response from Motors + +1. **Check power** +2. **Verify CAN wiring**: Check CAN-H, CAN-L, and GND connections +3. **Check motor IDs**: Use Damiao Debugging Tools to verify/configure IDs +4. **Test CAN interface**: Run `candump can0` to see if messages are being received +5. **Run diagnostics**: `lerobot-setup-can --mode=test --interfaces=can0` + +### Motor Timeout Parameter + +If motors were configured with timeout=0, they won't respond to commands. Use Damiao Debugging Tools to set a non-zero timeout value. + +### Verify CAN FD Status + +```bash +ip -d link show can0 | grep fd +``` diff --git a/pyproject.toml b/pyproject.toml index 75f617e75..27126f855 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,6 +102,7 @@ grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"] # Motors feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"] dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"] +damiao = ["python-can>=4.2.0,<5.0.0"] # Robots gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"] @@ -203,6 +204,7 @@ lerobot-info="lerobot.scripts.lerobot_info:main" lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main" lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main" lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main" +lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main" # ---------------- Tool Configurations ---------------- [tool.setuptools.packages.find] @@ -278,6 +280,7 @@ default.extend-ignore-identifiers-re = [ "thw", "inpt", "ROBOTIS", + "OT_VALUE" ] # TODO: Uncomment when ready to use diff --git a/src/lerobot/motors/__init__.py b/src/lerobot/motors/__init__.py index 850ef33d7..5df80d5ba 100644 --- a/src/lerobot/motors/__init__.py +++ b/src/lerobot/motors/__init__.py @@ -14,4 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .motors_bus import Motor, MotorCalibration, MotorNormMode, MotorsBus +from .motors_bus import ( + Motor, + MotorCalibration, + MotorNormMode, +) diff --git a/src/lerobot/motors/calibration_gui.py b/src/lerobot/motors/calibration_gui.py index 9832a1636..02bba454f 100644 --- a/src/lerobot/motors/calibration_gui.py +++ b/src/lerobot/motors/calibration_gui.py @@ -18,7 +18,7 @@ from dataclasses import dataclass os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1" -from lerobot.motors import MotorCalibration, MotorsBus +from .motors_bus import MotorCalibration, MotorsBus BAR_LEN, BAR_THICKNESS = 450, 8 HANDLE_R = 10 diff --git a/src/lerobot/motors/damiao/__init__.py b/src/lerobot/motors/damiao/__init__.py new file mode 100644 index 000000000..8240138cf --- /dev/null +++ b/src/lerobot/motors/damiao/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .damiao import DamiaoMotorsBus +from .tables import * diff --git a/src/lerobot/motors/damiao/damiao.py b/src/lerobot/motors/damiao/damiao.py new file mode 100644 index 000000000..dd0213fc3 --- /dev/null +++ b/src/lerobot/motors/damiao/damiao.py @@ -0,0 +1,808 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Portions of this file are derived from DM_Control_Python by cmjang. +# Licensed under the MIT License; see `LICENSE` for the full text: +# https://github.com/cmjang/DM_Control_Python + +import logging +import time +from contextlib import contextmanager +from copy import deepcopy +from functools import cached_property +from typing import TYPE_CHECKING, Any, TypedDict + +from lerobot.utils.import_utils import _can_available + +if TYPE_CHECKING or _can_available: + import can +else: + can.Message = object + can.interface = None + +import numpy as np + +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.robot_utils import precise_sleep +from lerobot.utils.utils import enter_pressed, move_cursor_up + +from ..motors_bus import Motor, MotorCalibration, MotorsBusBase, NameOrID, Value +from .tables import ( + AVAILABLE_BAUDRATES, + CAN_CMD_DISABLE, + CAN_CMD_ENABLE, + CAN_CMD_REFRESH, + CAN_CMD_SET_ZERO, + CAN_PARAM_ID, + DEFAULT_BAUDRATE, + DEFAULT_TIMEOUT_MS, + MIT_KD_RANGE, + MIT_KP_RANGE, + MOTOR_LIMIT_PARAMS, + MotorType, +) + +logger = logging.getLogger(__name__) + + +LONG_TIMEOUT_SEC = 0.1 +MEDIUM_TIMEOUT_SEC = 0.01 +SHORT_TIMEOUT_SEC = 0.001 +PRECISE_TIMEOUT_SEC = 0.0001 + + +class MotorState(TypedDict): + position: float + velocity: float + torque: float + temp_mos: float + temp_rotor: float + + +class DamiaoMotorsBus(MotorsBusBase): + """ + The Damiao implementation for a MotorsBus using CAN bus communication. + + This class uses python-can for CAN bus communication with Damiao motors. + For more info, see: + - python-can documentation: https://python-can.readthedocs.io/en/stable/ + - Seedstudio documentation: https://wiki.seeedstudio.com/damiao_series/ + - DM_Control_Python repo: https://github.com/cmjang/DM_Control_Python + """ + + # CAN-specific settings + available_baudrates = deepcopy(AVAILABLE_BAUDRATES) + default_baudrate = DEFAULT_BAUDRATE + default_timeout = DEFAULT_TIMEOUT_MS + + def __init__( + self, + port: str, + motors: dict[str, Motor], + calibration: dict[str, MotorCalibration] | None = None, + can_interface: str = "auto", + use_can_fd: bool = True, + bitrate: int = 1000000, + data_bitrate: int | None = 5000000, + ): + """ + Initialize the Damiao motors bus. + + Args: + port: CAN interface name (e.g., "can0" for Linux, "/dev/cu.usbmodem*" for macOS) + motors: Dictionary mapping motor names to Motor objects + calibration: Optional calibration data + can_interface: CAN interface type - "auto" (default), "socketcan" (Linux), or "slcan" (macOS/serial) + use_can_fd: Whether to use CAN FD mode (default: True for OpenArms) + bitrate: Nominal bitrate in bps (default: 1000000 = 1 Mbps) + data_bitrate: Data bitrate for CAN FD in bps (default: 5000000 = 5 Mbps), ignored if use_can_fd is False + """ + super().__init__(port, motors, calibration) + self.port = port + self.can_interface = can_interface + self.use_can_fd = use_can_fd + self.bitrate = bitrate + self.data_bitrate = data_bitrate + self.canbus: can.interface.Bus | None = None + self._is_connected = False + + # Map motor names to CAN IDs + self._motor_can_ids: dict[str, int] = {} + self._recv_id_to_motor: dict[int, str] = {} + self._motor_types: dict[str, MotorType] = {} + + for name, motor in self.motors.items(): + if motor.motor_type_str is None: + raise ValueError(f"Motor '{name}' is missing required 'motor_type'") + self._motor_types[name] = getattr(MotorType, motor.motor_type_str.upper().replace("-", "_")) + + # Map recv_id to motor name for filtering responses + if motor.recv_id is not None: + self._recv_id_to_motor[motor.recv_id] = name + + # State cache for handling packet drops safely + self._last_known_states: dict[str, MotorState] = { + name: { + "position": 0.0, + "velocity": 0.0, + "torque": 0.0, + "temp_mos": 0.0, + "temp_rotor": 0.0, + } + for name in self.motors + } + + # Dynamic gains storage + # Defaults: Kp=10.0 (Stiffness), Kd=0.5 (Damping) + self._gains: dict[str, dict[str, float]] = {name: {"kp": 10.0, "kd": 0.5} for name in self.motors} + + @property + def is_connected(self) -> bool: + """Check if the CAN bus is connected.""" + return self._is_connected and self.canbus is not None + + def connect(self, handshake: bool = True) -> None: + """ + Open the CAN bus and initialize communication. + + Args: + handshake: If True, ping all motors to verify they're present + """ + if self.is_connected: + raise DeviceAlreadyConnectedError( + f"{self.__class__.__name__}('{self.port}') is already connected." + ) + + try: + # Auto-detect interface type based on port name + if self.can_interface == "auto": + if self.port.startswith("/dev/"): + self.can_interface = "slcan" + logger.info(f"Auto-detected slcan interface for port {self.port}") + else: + self.can_interface = "socketcan" + logger.info(f"Auto-detected socketcan interface for port {self.port}") + + # Connect to CAN bus + kwargs = { + "channel": self.port, + "bitrate": self.bitrate, + "interface": self.can_interface, + } + + if self.can_interface == "socketcan" and self.use_can_fd and self.data_bitrate is not None: + kwargs.update({"data_bitrate": self.data_bitrate, "fd": True}) + logger.info( + f"Connected to {self.port} with CAN FD (bitrate={self.bitrate}, data_bitrate={self.data_bitrate})" + ) + else: + logger.info(f"Connected to {self.port} with {self.can_interface} (bitrate={self.bitrate})") + + self.canbus = can.interface.Bus(**kwargs) + self._is_connected = True + + if handshake: + self._handshake() + + logger.debug(f"{self.__class__.__name__} connected via {self.can_interface}.") + except Exception as e: + self._is_connected = False + raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e + + def _handshake(self) -> None: + """ + Verify all motors are present and populate initial state cache. + Raises ConnectionError if any motor fails to respond. + """ + logger.info("Starting handshake with motors...") + missing_motors = [] + + for motor_name in self.motors: + msg = self._refresh_motor(motor_name) + if msg is None: + missing_motors.append(motor_name) + else: + self._process_response(motor_name, msg) + time.sleep(MEDIUM_TIMEOUT_SEC) + + if missing_motors: + raise ConnectionError( + f"Handshake failed. The following motors did not respond: {missing_motors}. " + "Check power (24V) and CAN wiring." + ) + logger.info("Handshake successful. All motors ready.") + + def disconnect(self, disable_torque: bool = True) -> None: + """ + Close the CAN bus connection. + + Args: + disable_torque: If True, disable torque on all motors before disconnecting + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self.__class__.__name__}('{self.port}') is not connected.") + + if disable_torque: + try: + self.disable_torque() + except Exception as e: + logger.warning(f"Failed to disable torque during disconnect: {e}") + + if self.canbus: + self.canbus.shutdown() + self.canbus = None + self._is_connected = False + logger.debug(f"{self.__class__.__name__} disconnected.") + + def configure_motors(self) -> None: + """Configure all motors with default settings.""" + # Damiao motors don't require much configuration in MIT mode + # Just ensure they're enabled + for motor in self.motors: + self._send_simple_command(motor, CAN_CMD_ENABLE) + time.sleep(MEDIUM_TIMEOUT_SEC) + + def _send_simple_command(self, motor: NameOrID, command_byte: int) -> None: + """Helper to send simple 8-byte commands (Enable, Disable, Zero).""" + motor_id = self._get_motor_id(motor) + motor_name = self._get_motor_name(motor) + recv_id = self._get_motor_recv_id(motor) + data = [0xFF] * 7 + [command_byte] + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self.canbus.send(msg) + if msg := self._recv_motor_response(expected_recv_id=recv_id): + self._process_response(motor_name, msg) + else: + logger.debug(f"No response from {motor_name} after command 0x{command_byte:02X}") + + def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + """Enable torque on selected motors.""" + target_motors = self._get_motors_list(motors) + for motor in target_motors: + for _ in range(num_retry + 1): + try: + self._send_simple_command(motor, CAN_CMD_ENABLE) + break + except Exception as e: + if _ == num_retry: + raise e + time.sleep(MEDIUM_TIMEOUT_SEC) + + def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + """Disable torque on selected motors.""" + target_motors = self._get_motors_list(motors) + for motor in target_motors: + for _ in range(num_retry + 1): + try: + self._send_simple_command(motor, CAN_CMD_DISABLE) + break + except Exception as e: + if _ == num_retry: + raise e + time.sleep(MEDIUM_TIMEOUT_SEC) + + @contextmanager + def torque_disabled(self, motors: str | list[str] | None = None): + """ + Context manager that guarantees torque is re-enabled. + + This helper is useful to temporarily disable torque when configuring motors. + """ + self.disable_torque(motors) + try: + yield + finally: + self.enable_torque(motors) + + def set_zero_position(self, motors: str | list[str] | None = None) -> None: + """Set current position as zero for selected motors.""" + target_motors = self._get_motors_list(motors) + for motor in target_motors: + self._send_simple_command(motor, CAN_CMD_SET_ZERO) + time.sleep(MEDIUM_TIMEOUT_SEC) + + def _refresh_motor(self, motor: NameOrID) -> can.Message | None: + """Refresh motor status and return the response.""" + motor_id = self._get_motor_id(motor) + recv_id = self._get_motor_recv_id(motor) + data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0] + msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False) + self.canbus.send(msg) + return self._recv_motor_response(expected_recv_id=recv_id) + + def _recv_motor_response( + self, expected_recv_id: int | None = None, timeout: float = 0.001 + ) -> can.Message | None: + """ + Receive a response from a motor. + + Args: + expected_recv_id: If provided, only return messages from this CAN ID + timeout: Timeout in seconds (default: 1ms for high-speed operation) + Returns: + CAN message if received, None otherwise + """ + try: + start_time = time.time() + messages_seen = [] + while time.time() - start_time < timeout: + msg = self.canbus.recv(timeout=PRECISE_TIMEOUT_SEC) + if msg: + messages_seen.append(f"0x{msg.arbitration_id:02X}") + if expected_recv_id is None or msg.arbitration_id == expected_recv_id: + return msg + logger.debug( + f"Ignoring message from 0x{msg.arbitration_id:02X}, expected 0x{expected_recv_id:02X}" + ) + + if logger.isEnabledFor(logging.DEBUG): + if messages_seen: + logger.debug( + f"Received {len(messages_seen)} msgs from {set(messages_seen)}, expected 0x{expected_recv_id:02X}" + ) + else: + logger.debug(f"No CAN messages received (expected 0x{expected_recv_id:02X})") + except Exception as e: + logger.debug(f"Failed to receive CAN message: {e}") + return None + + def _recv_all_responses( + self, expected_recv_ids: list[int], timeout: float = 0.002 + ) -> dict[int, can.Message]: + """ + Efficiently receive responses from multiple motors at once. + Uses the OpenArms pattern: collect all available messages within timeout. + + Args: + expected_recv_ids: List of CAN IDs we expect responses from + timeout: Total timeout in seconds (default: 2ms) + + Returns: + Dictionary mapping recv_id to CAN message + """ + responses = {} + expected_set = set(expected_recv_ids) + start_time = time.time() + + try: + while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout: + # 100us poll timeout + msg = self.canbus.recv(timeout=PRECISE_TIMEOUT_SEC) + if msg and msg.arbitration_id in expected_set: + responses[msg.arbitration_id] = msg + if len(responses) == len(expected_recv_ids): + break + except Exception as e: + logger.debug(f"Error receiving responses: {e}") + + return responses + + def _encode_mit_packet( + self, + motor_type: MotorType, + kp: float, + kd: float, + position_degrees: float, + velocity_deg_per_sec: float, + torque: float, + ) -> list[int]: + """Helper to encode control parameters into 8 bytes for MIT mode.""" + # Convert degrees to radians + position_rad = np.radians(position_degrees) + velocity_rad_per_sec = np.radians(velocity_deg_per_sec) + + # Get motor limits + pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type] + + # Encode parameters + kp_uint = self._float_to_uint(kp, *MIT_KP_RANGE, 12) + kd_uint = self._float_to_uint(kd, *MIT_KD_RANGE, 12) + q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16) + dq_uint = self._float_to_uint(velocity_rad_per_sec, -vmax, vmax, 12) + tau_uint = self._float_to_uint(torque, -tmax, tmax, 12) + + # Pack data + data = [0] * 8 + data[0] = (q_uint >> 8) & 0xFF + data[1] = q_uint & 0xFF + data[2] = dq_uint >> 4 + data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF) + data[4] = kp_uint & 0xFF + data[5] = kd_uint >> 4 + data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF) + data[7] = tau_uint & 0xFF + return data + + def _mit_control( + self, + motor: NameOrID, + kp: float, + kd: float, + position_degrees: float, + velocity_deg_per_sec: float, + torque: float, + ) -> None: + """Send MIT control command to a motor.""" + motor_id = self._get_motor_id(motor) + motor_name = self._get_motor_name(motor) + motor_type = self._motor_types[motor_name] + + data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque) + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self.canbus.send(msg) + + recv_id = self._get_motor_recv_id(motor) + if msg := self._recv_motor_response(expected_recv_id=recv_id): + self._process_response(motor_name, msg) + else: + logger.debug(f"No response from {motor_name} after MIT control command") + + def _mit_control_batch( + self, + commands: dict[NameOrID, tuple[float, float, float, float, float]], + ) -> None: + """ + Send MIT control commands to multiple motors in batch. + Sends all commands first, then collects responses. + + Args: + commands: Dict mapping motor name/ID to (kp, kd, position_deg, velocity_deg/s, torque) + Example: {'joint_1': (10.0, 0.5, 45.0, 0.0, 0.0), ...} + """ + if not commands: + return + + recv_id_to_motor: dict[int, str] = {} + + # Step 1: Send all MIT control commands + for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items(): + motor_id = self._get_motor_id(motor) + motor_name = self._get_motor_name(motor) + motor_type = self._motor_types[motor_name] + + data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque) + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self.canbus.send(msg) + + recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name + + # Step 2: Collect responses and update state cache + responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=SHORT_TIMEOUT_SEC) + for recv_id, motor_name in recv_id_to_motor.items(): + if msg := responses.get(recv_id): + self._process_response(motor_name, msg) + + def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int: + """Convert float to unsigned integer for CAN transmission.""" + x = max(x_min, min(x_max, x)) # Clamp to range + span = x_max - x_min + data_norm = (x - x_min) / span + return int(data_norm * ((1 << bits) - 1)) + + def _uint_to_float(self, x: int, x_min: float, x_max: float, bits: int) -> float: + """Convert unsigned integer from CAN to float.""" + span = x_max - x_min + data_norm = float(x) / ((1 << bits) - 1) + return data_norm * span + x_min + + def _decode_motor_state( + self, data: bytearray | bytes, motor_type: MotorType + ) -> tuple[float, float, float, int, int]: + """ + Decode motor state from CAN data. + Returns: (position_deg, velocity_deg_s, torque, temp_mos, temp_rotor) + """ + if len(data) < 8: + raise ValueError("Invalid motor state data") + + # Extract encoded values + q_uint = (data[1] << 8) | data[2] + dq_uint = (data[3] << 4) | (data[4] >> 4) + tau_uint = ((data[4] & 0x0F) << 8) | data[5] + t_mos = data[6] + t_rotor = data[7] + + # Get motor limits + pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type] + + # Decode to physical values + position_rad = self._uint_to_float(q_uint, -pmax, pmax, 16) + velocity_rad_per_sec = self._uint_to_float(dq_uint, -vmax, vmax, 12) + torque = self._uint_to_float(tau_uint, -tmax, tmax, 12) + + return np.degrees(position_rad), np.degrees(velocity_rad_per_sec), torque, t_mos, t_rotor + + def _process_response(self, motor: str, msg: can.Message) -> None: + """Decode a message and update the motor state cache.""" + try: + motor_type = self._motor_types[motor] + pos, vel, torque, t_mos, t_rotor = self._decode_motor_state(msg.data, motor_type) + + self._last_known_states[motor] = { + "position": pos, + "velocity": vel, + "torque": torque, + "temp_mos": float(t_mos), + "temp_rotor": float(t_rotor), + } + except Exception as e: + logger.warning(f"Failed to decode response from {motor}: {e}") + + def read(self, data_name: str, motor: str) -> Value: + """Read a value from a single motor. Positions are always in degrees.""" + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + # Refresh motor to get latest state + msg = self._refresh_motor(motor) + if msg is None: + motor_id = self._get_motor_id(motor) + recv_id = self._get_motor_recv_id(motor) + raise ConnectionError( + f"No response from motor '{motor}' (send ID: 0x{motor_id:02X}, recv ID: 0x{recv_id:02X}). " + f"Check that: 1) Motor is powered (24V), 2) CAN wiring is correct, " + f"3) Motor IDs are configured correctly using Damiao Debugging Tools" + ) + + self._process_response(motor, msg) + return self._get_cached_value(motor, data_name) + + def _get_cached_value(self, motor: str, data_name: str) -> Value: + """Retrieve a specific value from the cache.""" + state = self._last_known_states[motor] + mapping: dict[str, Any] = { + "Present_Position": state["position"], + "Present_Velocity": state["velocity"], + "Present_Torque": state["torque"], + "Temperature_MOS": state["temp_mos"], + "Temperature_Rotor": state["temp_rotor"], + } + if data_name not in mapping: + raise ValueError(f"Unknown data_name: {data_name}") + return mapping[data_name] + + def write( + self, + data_name: str, + motor: str, + value: Value, + ) -> None: + """ + Write a value to a single motor. Positions are always in degrees. + Can write 'Goal_Position', 'Kp', or 'Kd'. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if data_name in ("Kp", "Kd"): + self._gains[motor][data_name.lower()] = float(value) + elif data_name == "Goal_Position": + kp = self._gains[motor]["kp"] + kd = self._gains[motor]["kd"] + self._mit_control(motor, kp, kd, float(value), 0.0, 0.0) + else: + raise ValueError(f"Writing {data_name} not supported in MIT mode") + + def sync_read( + self, + data_name: str, + motors: str | list[str] | None = None, + ) -> dict[str, Value]: + """ + Read the same value from multiple motors simultaneously. + """ + target_motors = self._get_motors_list(motors) + self._batch_refresh(target_motors) + + result = {} + for motor in target_motors: + result[motor] = self._get_cached_value(motor, data_name) + return result + + def sync_read_all_states( + self, + motors: str | list[str] | None = None, + *, + num_retry: int = 0, + ) -> dict[str, MotorState]: + """ + Read ALL motor states (position, velocity, torque) from multiple motors in ONE refresh cycle. + + Returns: + Dictionary mapping motor names to state dicts with keys: 'position', 'velocity', 'torque' + Example: {'joint_1': {'position': 45.2, 'velocity': 1.3, 'torque': 0.5}, ...} + """ + target_motors = self._get_motors_list(motors) + self._batch_refresh(target_motors) + + result = {} + for motor in target_motors: + result[motor] = self._last_known_states[motor].copy() + return result + + def _batch_refresh(self, motors: list[str]) -> None: + """Internal helper to refresh a list of motors and update cache.""" + # Send refresh commands + for motor in motors: + motor_id = self._get_motor_id(motor) + data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0] + msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False) + self.canbus.send(msg) + # Small delay to reduce bus congestion if necessary, though removed in sync_read previously + # precise_sleep(PRECISE_SLEEP_SEC) + + # Collect responses + expected_recv_ids = [self._get_motor_recv_id(m) for m in motors] + responses = self._recv_all_responses(expected_recv_ids, timeout=MEDIUM_TIMEOUT_SEC) + + # Update cache + for motor in motors: + recv_id = self._get_motor_recv_id(motor) + msg = responses.get(recv_id) + if msg: + self._process_response(motor, msg) + else: + logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.") + + def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None: + """ + Write values to multiple motors simultaneously. Positions are always in degrees. + """ + if data_name in ("Kp", "Kd"): + key = data_name.lower() + for motor, val in values.items(): + self._gains[motor][key] = float(val) + + elif data_name == "Goal_Position": + # Step 1: Send all MIT control commands + recv_id_to_motor: dict[int, str] = {} + for motor, value_degrees in values.items(): + motor_id = self._get_motor_id(motor) + motor_name = self._get_motor_name(motor) + motor_type = self._motor_types[motor_name] + + kp = self._gains[motor]["kp"] + kd = self._gains[motor]["kd"] + + data = self._encode_mit_packet(motor_type, kp, kd, float(value_degrees), 0.0, 0.0) + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self.canbus.send(msg) + precise_sleep(PRECISE_TIMEOUT_SEC) + + recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name + + # Step 2: Collect responses and update state cache + responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=MEDIUM_TIMEOUT_SEC) + for recv_id, motor_name in recv_id_to_motor.items(): + if msg := responses.get(recv_id): + self._process_response(motor_name, msg) + else: + # Fall back to individual writes + for motor, value in values.items(): + self.write(data_name, motor, value) + + def read_calibration(self) -> dict[str, MotorCalibration]: + """Read calibration data from motors.""" + # Damiao motors don't store calibration internally + # Return existing calibration or empty dict + return self.calibration if self.calibration else {} + + def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None: + """Write calibration data to motors.""" + # Damiao motors don't store calibration internally + # Just cache it in memory + if cache: + self.calibration = calibration_dict + + def record_ranges_of_motion( + self, + motors: NameOrID | list[NameOrID] | None = None, + display_values: bool = True, + ) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: + """ + Interactively record the min/max values of each motor in degrees. + + Move the joints by hand (with torque disabled) while the method streams live positions. + Press Enter to finish. + """ + target_motors = self._get_motors_list(motors) + + self.disable_torque(target_motors) + time.sleep(LONG_TIMEOUT_SEC) + + start_positions = self.sync_read("Present_Position", target_motors) + mins = start_positions.copy() + maxes = start_positions.copy() + + print("\nMove joints through their full range of motion. Press ENTER when done.") + user_pressed_enter = False + + while not user_pressed_enter: + positions = self.sync_read("Present_Position", target_motors) + + for motor in target_motors: + if motor in positions: + mins[motor] = min(positions[motor], mins.get(motor, positions[motor])) + maxes[motor] = max(positions[motor], maxes.get(motor, positions[motor])) + + if display_values: + print("\n" + "=" * 50) + print(f"{'MOTOR':<20} | {'MIN (deg)':>12} | {'POS (deg)':>12} | {'MAX (deg)':>12}") + print("-" * 50) + for motor in target_motors: + if motor in positions: + print( + f"{motor:<20} | {mins[motor]:>12.1f} | {positions[motor]:>12.1f} | {maxes[motor]:>12.1f}" + ) + + if enter_pressed(): + user_pressed_enter = True + + if display_values and not user_pressed_enter: + move_cursor_up(len(target_motors) + 4) + + time.sleep(LONG_TIMEOUT_SEC) + + self.enable_torque(target_motors) + + for motor in target_motors: + if (motor in mins) and (motor in maxes) and (int(abs(maxes[motor] - mins[motor])) < 5): + raise ValueError(f"Motor {motor} has insufficient range of motion (< 5 degrees)") + + return mins, maxes + + def _get_motors_list(self, motors: str | list[str] | None) -> list[str]: + """Convert motor specification to list of motor names.""" + if motors is None: + return list(self.motors.keys()) + elif isinstance(motors, str): + return [motors] + elif isinstance(motors, list): + return motors + else: + raise TypeError(f"Invalid motors type: {type(motors)}") + + def _get_motor_id(self, motor: NameOrID) -> int: + """Get CAN ID for a motor.""" + if isinstance(motor, str): + if motor in self.motors: + return self.motors[motor].id + else: + raise ValueError(f"Unknown motor: {motor}") + else: + return motor + + def _get_motor_name(self, motor: NameOrID) -> str: + """Get motor name from name or ID.""" + if isinstance(motor, str): + return motor + else: + for name, m in self.motors.items(): + if m.id == motor: + return name + raise ValueError(f"Unknown motor ID: {motor}") + + def _get_motor_recv_id(self, motor: NameOrID) -> int: + """Get motor recv_id from name or ID.""" + motor_name = self._get_motor_name(motor) + motor_obj = self.motors.get(motor_name) + if motor_obj and motor_obj.recv_id is not None: + return motor_obj.recv_id + else: + raise ValueError(f"Motor {motor_obj} doesn't have a valid recv_id (None).") + + @cached_property + def is_calibrated(self) -> bool: + """Check if motors are calibrated.""" + return bool(self.calibration) diff --git a/src/lerobot/motors/damiao/tables.py b/src/lerobot/motors/damiao/tables.py new file mode 100644 index 000000000..22d1624fa --- /dev/null +++ b/src/lerobot/motors/damiao/tables.py @@ -0,0 +1,209 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration tables for Damiao motors.""" + +from enum import IntEnum + + +# Motor type definitions +class MotorType(IntEnum): + DM3507 = 0 + DM4310 = 1 + DM4310_48V = 2 + DM4340 = 3 + DM4340_48V = 4 + DM6006 = 5 + DM8006 = 6 + DM8009 = 7 + DM10010L = 8 + DM10010 = 9 + DMH3510 = 10 + DMH6215 = 11 + DMG6220 = 12 + + +# Control modes +class ControlMode(IntEnum): + MIT = 1 + POS_VEL = 2 + VEL = 3 + TORQUE_POS = 4 + + +# Motor variable IDs (RID) +class MotorVariable(IntEnum): + UV_VALUE = 0 + KT_VALUE = 1 + OT_VALUE = 2 + OC_VALUE = 3 + ACC = 4 + DEC = 5 + MAX_SPD = 6 + MST_ID = 7 + ESC_ID = 8 + TIMEOUT = 9 + CTRL_MODE = 10 + DAMP = 11 + INERTIA = 12 + HW_VER = 13 + SW_VER = 14 + SN = 15 + NPP = 16 + RS = 17 + LS = 18 + FLUX = 19 + GR = 20 + PMAX = 21 + VMAX = 22 + TMAX = 23 + I_BW = 24 + KP_ASR = 25 + KI_ASR = 26 + KP_APR = 27 + KI_APR = 28 + OV_VALUE = 29 + GREF = 30 + DETA = 31 + V_BW = 32 + IQ_C1 = 33 + VL_C1 = 34 + CAN_BR = 35 + SUB_VER = 36 + U_OFF = 50 + V_OFF = 51 + K1 = 52 + K2 = 53 + M_OFF = 54 + DIR = 55 + P_M = 80 + XOUT = 81 + + +# Motor limit parameters [PMAX, VMAX, TMAX] +# PMAX: Maximum position (rad) +# VMAX: Maximum velocity (rad/s) +# TMAX: Maximum torque (N·m) +MOTOR_LIMIT_PARAMS = { + MotorType.DM3507: (12.5, 30, 10), + MotorType.DM4310: (12.5, 30, 10), + MotorType.DM4310_48V: (12.5, 50, 10), + MotorType.DM4340: (12.5, 8, 28), + MotorType.DM4340_48V: (12.5, 10, 28), + MotorType.DM6006: (12.5, 45, 20), + MotorType.DM8006: (12.5, 45, 40), + MotorType.DM8009: (12.5, 45, 54), + MotorType.DM10010L: (12.5, 25, 200), + MotorType.DM10010: (12.5, 20, 200), + MotorType.DMH3510: (12.5, 280, 1), + MotorType.DMH6215: (12.5, 45, 10), + MotorType.DMG6220: (12.5, 45, 10), +} + +# Motor model names +MODEL_NAMES = { + MotorType.DM3507: "dm3507", + MotorType.DM4310: "dm4310", + MotorType.DM4310_48V: "dm4310_48v", + MotorType.DM4340: "dm4340", + MotorType.DM4340_48V: "dm4340_48v", + MotorType.DM6006: "dm6006", + MotorType.DM8006: "dm8006", + MotorType.DM8009: "dm8009", + MotorType.DM10010L: "dm10010l", + MotorType.DM10010: "dm10010", + MotorType.DMH3510: "dmh3510", + MotorType.DMH6215: "dmh6215", + MotorType.DMG6220: "dmg6220", +} + +# Motor resolution table (encoder counts per revolution) +MODEL_RESOLUTION = { + "dm3507": 65536, + "dm4310": 65536, + "dm4310_48v": 65536, + "dm4340": 65536, + "dm4340_48v": 65536, + "dm6006": 65536, + "dm8006": 65536, + "dm8009": 65536, + "dm10010l": 65536, + "dm10010": 65536, + "dmh3510": 65536, + "dmh6215": 65536, + "dmg6220": 65536, +} + +# CAN baudrates supported by Damiao motors +AVAILABLE_BAUDRATES = [ + 125000, # 0: 125 kbps + 200000, # 1: 200 kbps + 250000, # 2: 250 kbps + 500000, # 3: 500 kbps + 1000000, # 4: 1 mbps (default for OpenArms) + 2000000, # 5: 2 mbps + 2500000, # 6: 2.5 mbps + 3200000, # 7: 3.2 mbps + 4000000, # 8: 4 mbps + 5000000, # 9: 5 mbps +] +DEFAULT_BAUDRATE = 1000000 # 1 Mbps is standard for OpenArms + +# Default timeout in milliseconds +DEFAULT_TIMEOUT_MS = 1000 + +# OpenArms specific configurations +# Based on: https://docs.openarm.dev/software/setup/configure-test +# OpenArms has 7 DOF per arm (14 total for dual arm) +OPENARMS_ARM_MOTOR_IDS = { + "joint_1": {"send": 0x01, "recv": 0x11}, # J1 - Shoulder pan + "joint_2": {"send": 0x02, "recv": 0x12}, # J2 - Shoulder lift + "joint_3": {"send": 0x03, "recv": 0x13}, # J3 - Elbow flex + "joint_4": {"send": 0x04, "recv": 0x14}, # J4 - Wrist flex + "joint_5": {"send": 0x05, "recv": 0x15}, # J5 - Wrist roll + "joint_6": {"send": 0x06, "recv": 0x16}, # J6 - Wrist pitch + "joint_7": {"send": 0x07, "recv": 0x17}, # J7 - Wrist rotation +} + +OPENARMS_GRIPPER_MOTOR_IDS = { + "gripper": {"send": 0x08, "recv": 0x18}, # J8 - Gripper +} + +# Default motor types for OpenArms +OPENARMS_DEFAULT_MOTOR_TYPES = { + "joint_1": MotorType.DM8009, # Shoulder pan - high torque + "joint_2": MotorType.DM8009, # Shoulder lift - high torque + "joint_3": MotorType.DM4340, # Shoulder rotation + "joint_4": MotorType.DM4340, # Elbow flex + "joint_5": MotorType.DM4310, # Wrist roll + "joint_6": MotorType.DM4310, # Wrist pitch + "joint_7": MotorType.DM4310, # Wrist rotation + "gripper": MotorType.DM4310, # Gripper +} + +# MIT control parameter ranges +MIT_KP_RANGE = (0.0, 500.0) +MIT_KD_RANGE = (0.0, 5.0) + +# CAN frame command IDs +CAN_CMD_ENABLE = 0xFC +CAN_CMD_DISABLE = 0xFD +CAN_CMD_SET_ZERO = 0xFE +CAN_CMD_REFRESH = 0xCC +CAN_CMD_QUERY_PARAM = 0x33 +CAN_CMD_WRITE_PARAM = 0x55 +CAN_CMD_SAVE_PARAM = 0xAA + +# CAN ID for parameter operations +CAN_PARAM_ID = 0x7FF diff --git a/src/lerobot/motors/dynamixel/dynamixel.py b/src/lerobot/motors/dynamixel/dynamixel.py index 01bfcf544..c6752ee96 100644 --- a/src/lerobot/motors/dynamixel/dynamixel.py +++ b/src/lerobot/motors/dynamixel/dynamixel.py @@ -22,9 +22,8 @@ import logging from copy import deepcopy from enum import Enum -from lerobot.motors.encoding_utils import decode_twos_complement, encode_twos_complement - -from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address +from ..encoding_utils import decode_twos_complement, encode_twos_complement +from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address from .tables import ( AVAILABLE_BAUDRATES, MODEL_BAUDRATE_TABLE, @@ -100,7 +99,7 @@ def _split_into_byte_chunks(value: int, length: int) -> list[int]: return data -class DynamixelMotorsBus(MotorsBus): +class DynamixelMotorsBus(SerialMotorsBus): """ The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with the motors. For more info, see the Dynamixel SDK Documentation: @@ -203,9 +202,9 @@ class DynamixelMotorsBus(MotorsBus): for motor in self._get_motors_list(motors): self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry) - def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None: + def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None: addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable") - self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry) + self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry) def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): diff --git a/src/lerobot/motors/feetech/feetech.py b/src/lerobot/motors/feetech/feetech.py index 2ea57af12..7ce3388b6 100644 --- a/src/lerobot/motors/feetech/feetech.py +++ b/src/lerobot/motors/feetech/feetech.py @@ -17,9 +17,8 @@ from copy import deepcopy from enum import Enum from pprint import pformat -from lerobot.motors.encoding_utils import decode_sign_magnitude, encode_sign_magnitude - -from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address +from ..encoding_utils import decode_sign_magnitude, encode_sign_magnitude +from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address from .tables import ( FIRMWARE_MAJOR_VERSION, FIRMWARE_MINOR_VERSION, @@ -96,7 +95,7 @@ def patch_setPacketTimeout(self, packet_length): # noqa: N802 self.packet_timeout = (self.tx_time_per_byte * packet_length) + (self.tx_time_per_byte * 3.0) + 50 -class FeetechMotorsBus(MotorsBus): +class FeetechMotorsBus(SerialMotorsBus): """ The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on the python feetech sdk to communicate with the motors, which is itself based on the dynamixel sdk. @@ -298,11 +297,11 @@ class FeetechMotorsBus(MotorsBus): self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry) self.write("Lock", motor, 0, num_retry=num_retry) - def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None: + def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None: addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable") - self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry) + self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry) addr, length = get_address(self.model_ctrl_table, model, "Lock") - self._write(addr, length, motor_id, 0, num_retry=num_retry) + self._write(addr, length, motor, 0, num_retry=num_retry) def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): diff --git a/src/lerobot/motors/motors_bus.py b/src/lerobot/motors/motors_bus.py index 91bee994a..c04f718b6 100644 --- a/src/lerobot/motors/motors_bus.py +++ b/src/lerobot/motors/motors_bus.py @@ -19,6 +19,8 @@ # TODO(aliberts): Add block noqa when feature below is available # https://github.com/astral-sh/ruff/issues/3711 +from __future__ import annotations + import abc import logging from contextlib import contextmanager @@ -41,6 +43,81 @@ Value: TypeAlias = int | float logger = logging.getLogger(__name__) +class MotorsBusBase(abc.ABC): + """ + Base class for all motor bus implementations. + + This is a minimal interface that all motor buses must implement, regardless of their + communication protocol (serial, CAN, etc.). + """ + + def __init__( + self, + port: str, + motors: dict[str, Motor], + calibration: dict[str, MotorCalibration] | None = None, + ): + self.port = port + self.motors = motors + self.calibration = calibration if calibration else {} + + @abc.abstractmethod + def connect(self, handshake: bool = True) -> None: + """Establish connection to the motors.""" + pass + + @abc.abstractmethod + def disconnect(self, disable_torque: bool = True) -> None: + """Disconnect from the motors.""" + pass + + @property + @abc.abstractmethod + def is_connected(self) -> bool: + """Check if connected to the motors.""" + pass + + @abc.abstractmethod + def read(self, data_name: str, motor: str) -> Value: + """Read a value from a single motor.""" + pass + + @abc.abstractmethod + def write(self, data_name: str, motor: str, value: Value) -> None: + """Write a value to a single motor.""" + pass + + @abc.abstractmethod + def sync_read(self, data_name: str, motors: str | list[str] | None = None) -> dict[str, Value]: + """Read a value from multiple motors.""" + pass + + @abc.abstractmethod + def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None: + """Write values to multiple motors.""" + pass + + @abc.abstractmethod + def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + """Enable torque on selected motors.""" + pass + + @abc.abstractmethod + def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + """Disable torque on selected motors.""" + pass + + @abc.abstractmethod + def read_calibration(self) -> dict[str, MotorCalibration]: + """Read calibration parameters from the motors.""" + pass + + @abc.abstractmethod + def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None: + """Write calibration parameters to the motors.""" + pass + + def get_ctrl_table(model_ctrl_table: dict[str, dict], model: str) -> dict[str, tuple[int, int]]: ctrl_table = model_ctrl_table.get(model) if ctrl_table is None: @@ -97,6 +174,8 @@ class Motor: id: int model: str norm_mode: MotorNormMode + motor_type_str: str | None = None + recv_id: int | None = None class PortHandler(Protocol): @@ -203,15 +282,15 @@ class GroupSyncWrite(Protocol): def txPacket(self): ... -class MotorsBus(abc.ABC): +class SerialMotorsBus(MotorsBusBase): """ - A MotorsBus allows to efficiently read and write to the attached motors. + A SerialMotorsBus allows to efficiently read and write to motors connected via serial communication. It represents several motors daisy-chained together and connected through a serial port. - There are currently two implementations of this abstract class: + There are currently two implementations of this class: - DynamixelMotorsBus - FeetechMotorsBus - Note: This class may evolve in the future should we add support for other types of bus. + This class is specifically for serial-based motor protocols (Dynamixel, Feetech, etc.). A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)). To find the port, you can run our utility script: @@ -260,9 +339,7 @@ class MotorsBus(abc.ABC): motors: dict[str, Motor], calibration: dict[str, MotorCalibration] | None = None, ): - self.port = port - self.motors = motors - self.calibration = calibration if calibration else {} + super().__init__(port, motors, calibration) self.port_handler: PortHandler self.packet_handler: PacketHandler @@ -532,7 +609,7 @@ class MotorsBus(abc.ABC): self.set_baudrate(self.default_baudrate) @abc.abstractmethod - def _find_single_motor(self, motor: str, initial_baudrate: int | None) -> tuple[int, int]: + def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]: pass @abc.abstractmethod @@ -545,13 +622,13 @@ class MotorsBus(abc.ABC): pass @abc.abstractmethod - def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: + def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: """Disable torque on selected motors. Disabling Torque allows to write to the motors' permanent memory area (EPROM/EEPROM). Args: - motors (int | str | list[str] | None, optional): Target motors. Accepts a motor name, an ID, a + motors ( str | list[str] | None, optional): Target motors. Accepts a motor name, an ID, a list of names or `None` to affect every registered motor. Defaults to `None`. num_retry (int, optional): Number of additional retry attempts on communication failure. Defaults to 0. @@ -1194,3 +1271,7 @@ class MotorsBus(abc.ABC): for id_, value in ids_values.items(): data = self._serialize_data(value, length) self.sync_writer.addParam(id_, data) + + +# Backward compatibility alias +MotorsBus: TypeAlias = SerialMotorsBus diff --git a/src/lerobot/scripts/lerobot_setup_can.py b/src/lerobot/scripts/lerobot_setup_can.py new file mode 100644 index 000000000..55de74724 --- /dev/null +++ b/src/lerobot/scripts/lerobot_setup_can.py @@ -0,0 +1,360 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Setup and debug CAN interfaces for Damiao motors (e.g., OpenArms). + +Examples: + +Setup CAN interfaces with CAN FD: +```shell +lerobot-setup-can --mode=setup --interfaces=can0,can1,can2,can3 +``` + +Test motors on a single interface: +```shell +lerobot-setup-can --mode=test --interfaces=can0 +``` + +Test motors on all interfaces: +```shell +lerobot-setup-can --mode=test --interfaces=can0,can1,can2,can3 +``` + +Speed test: +```shell +lerobot-setup-can --mode=speed --interfaces=can0 +``` +""" + +import subprocess +import sys +import time +from dataclasses import dataclass, field + +import draccus + +from lerobot.utils.import_utils import is_package_available + +MOTOR_NAMES = { + 0x01: "joint_1", + 0x02: "joint_2", + 0x03: "joint_3", + 0x04: "joint_4", + 0x05: "joint_5", + 0x06: "joint_6", + 0x07: "joint_7", + 0x08: "gripper", +} + + +@dataclass +class CANSetupConfig: + mode: str = "test" + interfaces: str = "can0" # Comma-separated, e.g. "can0,can1,can2,can3" + bitrate: int = 1000000 + data_bitrate: int = 5000000 + use_fd: bool = True + motor_ids: list[int] = field(default_factory=lambda: list(range(0x01, 0x09))) + timeout: float = 1.0 + speed_iterations: int = 100 + + def get_interfaces(self) -> list[str]: + return [i.strip() for i in self.interfaces.split(",") if i.strip()] + + +def check_interface_status(interface: str) -> tuple[bool, str, bool]: + """Check if CAN interface is UP and configured.""" + try: + result = subprocess.run(["ip", "link", "show", interface], capture_output=True, text=True) # nosec B607 + if result.returncode != 0: + return False, "Interface not found", False + + output = result.stdout + is_up = "UP" in output + is_fd = "fd on" in output.lower() or "canfd" in output.lower() + status = "UP" if is_up else "DOWN" + if is_fd: + status += " (CAN FD)" + + return is_up, status, is_fd + except FileNotFoundError: + return False, "ip command not found", False + + +def setup_interface(interface: str, bitrate: int, data_bitrate: int, use_fd: bool) -> bool: + """Configure a CAN interface.""" + try: + subprocess.run(["sudo", "ip", "link", "set", interface, "down"], check=False, capture_output=True) # nosec B607 + + cmd = ["sudo", "ip", "link", "set", interface, "type", "can", "bitrate", str(bitrate)] + if use_fd: + cmd.extend(["dbitrate", str(data_bitrate), "fd", "on"]) + + result = subprocess.run(cmd, capture_output=True, text=True) # nosec B607 + if result.returncode != 0: + print(f" ✗ Failed to configure: {result.stderr}") + return False + + result = subprocess.run( # nosec B607 + ["sudo", "ip", "link", "set", interface, "up"], capture_output=True, text=True + ) + if result.returncode != 0: + print(f" ✗ Failed to bring up: {result.stderr}") + return False + + return True + except Exception as e: + print(f" ✗ Error: {e}") + return False + + +def test_motor(bus, motor_id: int, timeout: float, use_fd: bool): + """Test a single motor and return responses.""" + import can + + enable_msg = can.Message( + arbitration_id=motor_id, + data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC], + is_extended_id=False, + is_fd=use_fd, + ) + + try: + bus.send(enable_msg) + except Exception as e: + return None, f"Send error: {e}" + + responses = [] + start_time = time.time() + + while time.time() - start_time < timeout: + msg = bus.recv(timeout=0.1) + if msg: + responses.append((msg.arbitration_id, msg.data.hex(), getattr(msg, "is_fd", False))) + + disable_msg = can.Message( + arbitration_id=motor_id, + data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFD], + is_extended_id=False, + is_fd=use_fd, + ) + try: + bus.send(disable_msg) + except Exception: + print(f"Error sending message to motor 0x{motor_id:02X}") + + return responses, None + + +def test_interface(cfg: CANSetupConfig, interface: str): + """Test all motors on a CAN interface.""" + import can + + is_up, status, _ = check_interface_status(interface) + print(f"\n{interface}: {status}") + + if not is_up: + print(f" ⚠ Interface is not UP. Run: lerobot-setup-can --mode=setup --interfaces {interface}") + return {} + + try: + kwargs = {"channel": interface, "interface": "socketcan", "bitrate": cfg.bitrate} + if cfg.use_fd: + kwargs.update({"data_bitrate": cfg.data_bitrate, "fd": True}) + bus = can.interface.Bus(**kwargs) + except Exception as e: + print(f" ✗ Connection failed: {e}") + return {} + + results = {} + try: + while bus.recv(timeout=0.01): + pass + + for motor_id in cfg.motor_ids: + motor_name = MOTOR_NAMES.get(motor_id, f"motor_0x{motor_id:02X}") + responses, error = test_motor(bus, motor_id, cfg.timeout, cfg.use_fd) + + if error: + print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ {error}") + results[motor_id] = {"found": False, "error": error} + elif responses: + print(f" Motor 0x{motor_id:02X} ({motor_name}): ✓ FOUND") + for resp_id, data, is_fd in responses: + fd_flag = " [FD]" if is_fd else "" + print(f" → Response 0x{resp_id:02X}{fd_flag}: {data}") + results[motor_id] = {"found": True, "responses": responses} + else: + print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ No response") + results[motor_id] = {"found": False} + + time.sleep(0.05) + finally: + bus.shutdown() + + found = sum(1 for r in results.values() if r.get("found")) + print(f"\n Summary: {found}/{len(cfg.motor_ids)} motors found") + return results + + +def speed_test(cfg: CANSetupConfig, interface: str): + """Test communication speed with motors.""" + import can + + is_up, status, _ = check_interface_status(interface) + if not is_up: + print(f"{interface}: {status} - skipping") + return + + print(f"\n{interface}: Running speed test ({cfg.speed_iterations} iterations)...") + + try: + kwargs = {"channel": interface, "interface": "socketcan", "bitrate": cfg.bitrate} + if cfg.use_fd: + kwargs.update({"data_bitrate": cfg.data_bitrate, "fd": True}) + bus = can.interface.Bus(**kwargs) + except Exception as e: + print(f" ✗ Connection failed: {e}") + return + + responding_motor = None + for motor_id in cfg.motor_ids: + responses, _ = test_motor(bus, motor_id, 0.5, cfg.use_fd) + if responses: + responding_motor = motor_id + break + + if not responding_motor: + print(" ✗ No responding motors found") + bus.shutdown() + return + + print(f" Testing with motor 0x{responding_motor:02X}...") + latencies = [] + + for _ in range(cfg.speed_iterations): + start = time.perf_counter() + msg = can.Message( + arbitration_id=responding_motor, + data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC], + is_extended_id=False, + is_fd=cfg.use_fd, + ) + bus.send(msg) + resp = bus.recv(timeout=0.1) + if resp: + latencies.append((time.perf_counter() - start) * 1000) + + bus.shutdown() + + if latencies: + avg_latency = sum(latencies) / len(latencies) + hz = 1000.0 / avg_latency if avg_latency > 0 else 0 + print(f" ✓ Success rate: {len(latencies)}/{cfg.speed_iterations}") + print(f" ✓ Avg latency: {avg_latency:.2f} ms") + print(f" ✓ Max frequency: {hz:.1f} Hz") + else: + print(" ✗ No successful responses") + + +def run_setup(cfg: CANSetupConfig): + """Setup CAN interfaces.""" + print("=" * 50) + print("CAN Interface Setup") + print("=" * 50) + print(f"Mode: {'CAN FD' if cfg.use_fd else 'CAN 2.0'}") + print(f"Bitrate: {cfg.bitrate / 1_000_000:.1f} Mbps") + if cfg.use_fd: + print(f"Data bitrate: {cfg.data_bitrate / 1_000_000:.1f} Mbps") + print() + + interfaces = cfg.get_interfaces() + for interface in interfaces: + print(f"Configuring {interface}...") + if setup_interface(interface, cfg.bitrate, cfg.data_bitrate, cfg.use_fd): + is_up, status, _ = check_interface_status(interface) + print(f" ✓ {interface}: {status}") + else: + print(f" ✗ {interface}: Failed") + + print("\nSetup complete!") + print("\nNext: Test motors with:") + print(f" lerobot-setup-can --mode=test --interfaces {','.join(interfaces)}") + + +def run_test(cfg: CANSetupConfig): + """Test motors on CAN interfaces.""" + print("=" * 50) + print("CAN Motor Test") + print("=" * 50) + print(f"Testing motors 0x{min(cfg.motor_ids):02X}-0x{max(cfg.motor_ids):02X}") + print(f"Mode: {'CAN FD' if cfg.use_fd else 'CAN 2.0'}") + print() + + interfaces = cfg.get_interfaces() + all_results = {} + for interface in interfaces: + all_results[interface] = test_interface(cfg, interface) + + total_found = sum(sum(1 for r in res.values() if r.get("found")) for res in all_results.values()) + + print("\n" + "=" * 50) + print("Summary") + print("=" * 50) + print(f"Total motors found: {total_found}") + + if total_found == 0: + print("\n⚠ No motors found! Check:") + print(" 1. Motors are powered (24V)") + print(" 2. CAN wiring (CANH, CANL, GND)") + print(" 3. Motor timeout parameter > 0 (use Damiao tools)") + print(" 4. 120Ω termination at both cable ends") + print(f" 5. Interface configured: lerobot-setup-can --mode=setup --interfaces {interfaces[0]}") + + +def run_speed(cfg: CANSetupConfig): + """Run speed tests on CAN interfaces.""" + print("=" * 50) + print("CAN Speed Test") + print("=" * 50) + + for interface in cfg.get_interfaces(): + speed_test(cfg, interface) + + +@draccus.wrap() +def setup_can(cfg: CANSetupConfig): + if not is_package_available("can"): + print("Error: python-can not installed. Install with: pip install python-can") + sys.exit(1) + + if cfg.mode == "setup": + run_setup(cfg) + elif cfg.mode == "test": + run_test(cfg) + elif cfg.mode == "speed": + run_speed(cfg) + else: + print(f"Unknown mode: {cfg.mode}") + print("Available modes: setup, test, speed") + sys.exit(1) + + +def main(): + setup_can() + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/utils/import_utils.py b/src/lerobot/utils/import_utils.py index a499b96c7..c33a73589 100644 --- a/src/lerobot/utils/import_utils.py +++ b/src/lerobot/utils/import_utils.py @@ -73,6 +73,7 @@ _transformers_available = is_package_available("transformers") _peft_available = is_package_available("peft") _scipy_available = is_package_available("scipy") _reachy2_sdk_available = is_package_available("reachy2_sdk") +_can_available = is_package_available("python-can", "can") def make_device_from_device_class(config: ChoiceRegistry) -> Any: diff --git a/tests/motors/test_damiao.py b/tests/motors/test_damiao.py new file mode 100644 index 000000000..7ce1af34f --- /dev/null +++ b/tests/motors/test_damiao.py @@ -0,0 +1,66 @@ +"""Minimal test script for Damiao motor with ID 3.""" + +import pytest + +from lerobot.utils.import_utils import _can_available + +if not _can_available: + pytest.skip("python-can not available", allow_module_level=True) + +from lerobot.motors import Motor +from lerobot.motors.damiao import DamiaoMotorsBus + + +@pytest.mark.skip(reason="Requires physical Damiao motor and CAN interface") +def test_damiao_motor(): + motors = { + "joint_3": Motor( + id=0x03, + model="damiao", + norm_mode="degrees", + motor_type_str="dm4310", + recv_id=0x13, + ), + } + + bus = DamiaoMotorsBus(port="can0", motors=motors) + + try: + print("Connecting...") + bus.connect() + print("✓ Connected") + + print("Enabling torque...") + bus.enable_torque() + print("✓ Torque enabled") + + print("Reading all states...") + states = bus.sync_read_all_states() + print(f"✓ States: {states}") + + print("Reading position...") + positions = bus.sync_read("Present_Position") + print(f"✓ Position: {positions}") + + print("Testing MIT control batch...") + current_pos = states["joint_3"]["position"] + commands = {"joint_3": (10.0, 0.5, current_pos, 0.0, 0.0)} + bus._mit_control_batch(commands) + print("✓ MIT control batch sent") + + print("Disabling torque...") + bus.disable_torque() + print("✓ Torque disabled") + + print("Setting zero position...") + bus.set_zero_position() + print("✓ Zero position set") + + finally: + print("Disconnecting...") + bus.disconnect(disable_torque=True) + print("✓ Disconnected") + + +if __name__ == "__main__": + test_damiao_motor() From 0c0c171d3543fafe587e0ac1768dd033aaf12ebf Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Tue, 27 Jan 2026 13:33:45 +0100 Subject: [PATCH 010/131] Add robot images to docs (#2862) * Add robot images to docs * increase img size * remove img so100 --- docs/source/earthrover_mini_plus.mdx | 6 ++++++ docs/source/lekiwi.mdx | 6 ++++++ docs/source/so101.mdx | 13 +++++++++++++ 3 files changed, 25 insertions(+) diff --git a/docs/source/earthrover_mini_plus.mdx b/docs/source/earthrover_mini_plus.mdx index e3ffa6b32..d8083336a 100644 --- a/docs/source/earthrover_mini_plus.mdx +++ b/docs/source/earthrover_mini_plus.mdx @@ -1,5 +1,11 @@ # EarthRover Mini Plus +EarthRover Mini Plus + The EarthRover Mini Plus is a fully open source mobile robot that connects through the cloud using the Frodobots SDK. This lets you control the robot and record datasets for training AI models. ## What You Need diff --git a/docs/source/lekiwi.mdx b/docs/source/lekiwi.mdx index 511521580..b339225d8 100644 --- a/docs/source/lekiwi.mdx +++ b/docs/source/lekiwi.mdx @@ -1,5 +1,11 @@ # LeKiwi +LeKiwi + In the steps below, we explain how to assemble the LeKiwi mobile robot. ## Source the parts diff --git a/docs/source/so101.mdx b/docs/source/so101.mdx index cf882b373..7c9df588a 100644 --- a/docs/source/so101.mdx +++ b/docs/source/so101.mdx @@ -1,5 +1,18 @@ # SO-101 +
+ SO-101 + SO-101 +
+ In the steps below, we explain how to assemble our flagship robot, the SO-101. ## Source the parts From f6b1c39b785af0f2f78899f5de6e008f3295e594 Mon Sep 17 00:00:00 2001 From: Reece O'Mahoney <66252930+reeceomahoney@users.noreply.github.com> Date: Tue, 27 Jan 2026 14:31:53 +0000 Subject: [PATCH 011/131] docs: update libero (#2857) * update libero docs * Update docs/source/libero.mdx Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Jade Choghari --------- Signed-off-by: Jade Choghari Co-authored-by: Jade Choghari Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- docs/source/libero.mdx | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/libero.mdx b/docs/source/libero.mdx index 3617f3b25..def974531 100644 --- a/docs/source/libero.mdx +++ b/docs/source/libero.mdx @@ -42,6 +42,7 @@ lerobot-eval \ ``` - `--env.task` picks the suite (`libero_object`, `libero_spatial`, etc.). +- `--env.task_ids` picks task ids to run (`[0]`, `[1,2,3]`, etc.). Omit this flag (or set it to `null`) to run all tasks in the suite. - `--eval.batch_size` controls how many environments run in parallel. - `--eval.n_episodes` sets how many episodes to run in total. From 736b43f3cfb5db2450fa787a45f645e1309caa00 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Wed, 28 Jan 2026 13:31:27 +0100 Subject: [PATCH 012/131] Fix(aggregate.py) Aggregation of datasets when sub-datasets are already a result of a previous merge (#2861) * Fix aggeregation of datasets when subdatasets are already a result of a previous merge * docstring * respond to copilot review + add regression test * Remove unnecessary int conversion for indicies --- src/lerobot/datasets/aggregate.py | 100 ++++++++++++++++++++++++------ tests/datasets/test_aggregate.py | 89 ++++++++++++++++++++++++++ 2 files changed, 171 insertions(+), 18 deletions(-) diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py index 94ffe602e..7020545d2 100644 --- a/src/lerobot/datasets/aggregate.py +++ b/src/lerobot/datasets/aggregate.py @@ -116,6 +116,9 @@ def update_meta_data( Adjusts all indices and timestamps to account for previously aggregated data and videos in the destination dataset. + For data file indices, uses the 'src_to_dst' mapping from aggregate_data() + to correctly map source file indices to their destination locations. + Args: df: DataFrame containing the metadata to be updated. dst_meta: Destination dataset metadata. @@ -129,8 +132,50 @@ def update_meta_data( df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"] df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"] - df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"] - df["data/file_index"] = df["data/file_index"] + data_idx["file"] + + # Update data file indices using source-to-destination mapping + # This is critical for handling datasets that are already results of a merge + data_src_to_dst = data_idx.get("src_to_dst", {}) + if data_src_to_dst: + # Store original indices for lookup + df["_orig_data_chunk"] = df["data/chunk_index"].copy() + df["_orig_data_file"] = df["data/file_index"].copy() + + # Vectorized mapping from (src_chunk, src_file) to (dst_chunk, dst_file) + # This is much faster than per-row iteration for large metadata tables + mapping_index = pd.MultiIndex.from_tuples( + list(data_src_to_dst.keys()), + names=["chunk_index", "file_index"], + ) + mapping_values = list(data_src_to_dst.values()) + mapping_df = pd.DataFrame( + mapping_values, + index=mapping_index, + columns=["dst_chunk", "dst_file"], + ) + + # Construct a MultiIndex for each row based on original data indices + row_index = pd.MultiIndex.from_arrays( + [df["_orig_data_chunk"], df["_orig_data_file"]], + names=["chunk_index", "file_index"], + ) + + # Align mapping to rows; missing keys fall back to the default destination + reindexed = mapping_df.reindex(row_index) + reindexed[["dst_chunk", "dst_file"]] = reindexed[["dst_chunk", "dst_file"]].fillna( + {"dst_chunk": data_idx["chunk"], "dst_file": data_idx["file"]} + ) + + # Assign mapped destination indices back to the DataFrame + df["data/chunk_index"] = reindexed["dst_chunk"].to_numpy() + df["data/file_index"] = reindexed["dst_file"].to_numpy() + + # Clean up temporary columns + df = df.drop(columns=["_orig_data_chunk", "_orig_data_file"]) + else: + # Fallback to simple offset (backward compatibility for single-file sources) + df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"] + df["data/file_index"] = df["data/file_index"] + data_idx["file"] for key, video_idx in videos_idx.items(): # Store original video file indices before updating orig_chunk_col = f"videos/{key}/chunk_index" @@ -146,8 +191,7 @@ def update_meta_data( if src_to_dst: # Map each episode to its correct destination file and apply offset for idx in df.index: - # Convert to Python int to avoid numpy type mismatch in dict lookup - src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"])) + src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"]) # Get destination chunk/file for this source file dst_chunk, dst_file = src_to_dst.get(src_key, (video_idx["chunk"], video_idx["file"])) @@ -163,8 +207,7 @@ def update_meta_data( df[orig_chunk_col] = video_idx["chunk"] df[orig_file_col] = video_idx["file"] for idx in df.index: - # Convert to Python int to avoid numpy type mismatch in dict lookup - src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"])) + src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"]) offset = src_to_offset.get(src_key, 0) df.at[idx, f"videos/{key}/from_timestamp"] += offset df.at[idx, f"videos/{key}/to_timestamp"] += offset @@ -262,6 +305,10 @@ def aggregate_datasets( meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx) + # Clear the src_to_dst mapping after processing each source dataset + # to avoid interference between different source datasets + data_idx.pop("src_to_dst", None) + dst_meta.info["total_episodes"] += src_meta.total_episodes dst_meta.info["total_frames"] += src_meta.total_frames @@ -312,10 +359,6 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu dst_file_durations = video_idx["dst_file_durations"] for src_chunk_idx, src_file_idx in unique_chunk_file_pairs: - # Convert to Python int to ensure consistent dict keys - src_chunk_idx = int(src_chunk_idx) - src_file_idx = int(src_file_idx) - src_path = src_meta.root / DEFAULT_VIDEO_PATH.format( video_key=key, chunk_index=src_chunk_idx, @@ -388,10 +431,16 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si Reads source data files, updates indices to match the aggregated dataset, and writes them to the destination with proper file rotation. + Tracks a `src_to_dst` mapping from source (chunk, file) to destination (chunk, file) + which is critical for correctly updating episode metadata when source datasets + have multiple data files (e.g., from a previous merge operation). + Args: src_meta: Source dataset metadata. dst_meta: Destination dataset metadata. data_idx: Dictionary tracking data chunk and file indices. + data_files_size_in_mb: Maximum size for data files in MB. + chunk_size: Maximum number of files per chunk. Returns: dict: Updated data_idx with current chunk and file indices. @@ -409,6 +458,10 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si # retrieve features schema for proper image typing in parquet hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None + # Track source to destination file mapping for metadata update + # This is critical for handling datasets that are already results of a merge + src_to_dst: dict[tuple[int, int], tuple[int, int]] = {} + for src_chunk_idx, src_file_idx in unique_chunk_file_ids: src_path = src_meta.root / DEFAULT_DATA_PATH.format( chunk_index=src_chunk_idx, file_index=src_file_idx @@ -421,7 +474,9 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si df = pd.read_parquet(src_path) df = update_data_df(df, src_meta, dst_meta) - data_idx = append_or_create_parquet_file( + # Write data and get the actual destination file it was written to + # This avoids duplicating the rotation logic here + data_idx, (dst_chunk, dst_file) = append_or_create_parquet_file( df, src_path, data_idx, @@ -433,6 +488,12 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si hf_features=hf_features, ) + # Record the mapping from source to actual destination + src_to_dst[(src_chunk_idx, src_file_idx)] = (dst_chunk, dst_file) + + # Add the mapping to data_idx for use in metadata update + data_idx["src_to_dst"] = src_to_dst + return data_idx @@ -473,7 +534,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx): videos_idx, ) - meta_idx = append_or_create_parquet_file( + meta_idx, _ = append_or_create_parquet_file( df, src_path, meta_idx, @@ -501,7 +562,7 @@ def append_or_create_parquet_file( contains_images: bool = False, aggr_root: Path = None, hf_features: datasets.Features | None = None, -): +) -> tuple[dict[str, int], tuple[int, int]]: """Appends data to an existing parquet file or creates a new one based on size constraints. Manages file rotation when size limits are exceeded to prevent individual files @@ -519,9 +580,11 @@ def append_or_create_parquet_file( hf_features: Optional HuggingFace Features schema for proper image typing. Returns: - dict: Updated index dictionary with current chunk and file indices. + tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict + and (dst_chunk, dst_file) is the actual destination file the data was written to. """ - dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"]) + dst_chunk, dst_file = idx["chunk"], idx["file"] + dst_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file) if not dst_path.exists(): dst_path.parent.mkdir(parents=True, exist_ok=True) @@ -529,14 +592,15 @@ def append_or_create_parquet_file( to_parquet_with_hf_images(df, dst_path, features=hf_features) else: df.to_parquet(dst_path) - return idx + return idx, (dst_chunk, dst_file) src_size = get_parquet_file_size_in_mb(src_path) dst_size = get_parquet_file_size_in_mb(dst_path) if dst_size + src_size >= max_mb: idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size) - new_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"]) + dst_chunk, dst_file = idx["chunk"], idx["file"] + new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file) new_path.parent.mkdir(parents=True, exist_ok=True) final_df = df target_path = new_path @@ -555,7 +619,7 @@ def append_or_create_parquet_file( else: final_df.to_parquet(target_path) - return idx + return idx, (dst_chunk, dst_file) def finalize_aggregation(aggr_meta, all_metadata): diff --git a/tests/datasets/test_aggregate.py b/tests/datasets/test_aggregate.py index 031c29d60..3609bac24 100644 --- a/tests/datasets/test_aggregate.py +++ b/tests/datasets/test_aggregate.py @@ -525,3 +525,92 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory): assert img.shape[0] == 3, f"Image {image_key} should have 3 channels" assert_dataset_iteration_works(aggr_ds) + + +def test_aggregate_already_merged_dataset(tmp_path, lerobot_dataset_factory): + """Regression test for aggregating a dataset that is itself a result of a previous merge. + + This test reproduces the bug where merging datasets with multiple parquet files + (e.g., from a previous merge with file rotation) would cause FileNotFoundError + because metadata file indices were incorrectly preserved instead of being mapped + to their actual destination files. + + The fix adds src_to_dst tracking in aggregate_data() to correctly map source + file indices to destination file indices. + """ + # Step 1: Create datasets A and B + ds_a = lerobot_dataset_factory( + root=tmp_path / "ds_a", + repo_id=f"{DUMMY_REPO_ID}_a", + total_episodes=4, + total_frames=200, + ) + ds_b = lerobot_dataset_factory( + root=tmp_path / "ds_b", + repo_id=f"{DUMMY_REPO_ID}_b", + total_episodes=4, + total_frames=200, + ) + + # Step 2: Merge A+B into AB with small file size to force multiple files + aggregate_datasets( + repo_ids=[ds_a.repo_id, ds_b.repo_id], + roots=[ds_a.root, ds_b.root], + aggr_repo_id=f"{DUMMY_REPO_ID}_ab", + aggr_root=tmp_path / "ds_ab", + data_files_size_in_mb=0.01, # Force file rotation + ) + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "ds_ab") + ds_ab = LeRobotDataset(f"{DUMMY_REPO_ID}_ab", root=tmp_path / "ds_ab") + + # Verify AB has multiple data files (file rotation occurred) + ab_data_files = list((tmp_path / "ds_ab" / "data").rglob("*.parquet")) + assert len(ab_data_files) > 1, "First merge should create multiple parquet files" + + # Step 3: Create dataset C + ds_c = lerobot_dataset_factory( + root=tmp_path / "ds_c", + repo_id=f"{DUMMY_REPO_ID}_c", + total_episodes=2, + total_frames=100, + ) + + # Step 4: Merge AB+C into final - THIS IS WHERE THE BUG OCCURRED + aggregate_datasets( + repo_ids=[ds_ab.repo_id, ds_c.repo_id], + roots=[ds_ab.root, ds_c.root], + aggr_repo_id=f"{DUMMY_REPO_ID}_abc", + aggr_root=tmp_path / "ds_abc", + ) + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "ds_abc") + ds_abc = LeRobotDataset(f"{DUMMY_REPO_ID}_abc", root=tmp_path / "ds_abc") + + # Step 5: Verify all data files referenced in metadata actually exist + for ep_idx in range(ds_abc.num_episodes): + data_file_path = ds_abc.root / ds_abc.meta.get_data_file_path(ep_idx) + assert data_file_path.exists(), ( + f"Episode {ep_idx} references non-existent file: {data_file_path}\n" + "This indicates the src_to_dst mapping fix is not working correctly." + ) + + # Step 6: Verify we can iterate through the entire dataset without FileNotFoundError + expected_episodes = ds_a.num_episodes + ds_b.num_episodes + ds_c.num_episodes + expected_frames = ds_a.num_frames + ds_b.num_frames + ds_c.num_frames + + assert ds_abc.num_episodes == expected_episodes + assert ds_abc.num_frames == expected_frames + + # This would raise FileNotFoundError before the fix + assert_dataset_iteration_works(ds_abc) From bf337e716da18054e463003fa37f47df2aa9bfe3 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 28 Jan 2026 14:28:51 +0100 Subject: [PATCH 013/131] feat(robots): add OpenArm robot & teleoperator (#2795) * fix(motors): cleanup imports + fix signatures * feat(motors): add damiao canbus + multiple fixes * fix(motors): address comments -> last_state + different gains + sleep * refactor(motors): reduce duplicated code + adressed some comments in the PR * chore(motors): better timeouts * tests(motors): damiao test and imports * chore(deps): fix space * feat(robot): add openarm leader Co-authored-by: Pepijn * feat(robot): add openarm follower Co-authored-by: Pepijn * refactor(robot): remove mechanical compensations and double arm assumption + rename * chore(robots): remove left arm references * refactor(teleop): multiple improvements to leader * refactor(teleop): multiple improvements to leader * feat(robots): add open arm to util CLI * chore(robot): add alias openarm * Apply suggestions from code review Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Signed-off-by: Steven Palma * chore(motors): remove normalization tables damiao * fix(motors): imports and signatures * feat(motors): add motor_type_str + recv_id to motor class and _get_motor_recv_id raises if no motor_obj.recv_id * chore(motors): remove normalize from base motor class and damaio * tests(motors): remove bad tests (to be replaced) * chore(motors): updated import check * fix(robots): open arm mirrored config for joint limits * chore(motors): update position_kd gain values * chore(robots): set to 0 if openarm is calibrated at connect time * chore(robots): remove macos in open arm as can doesn't support it * chore(robots): update for motor_type_str in Motor class * chore(robots): no default value for can port in open arms * use constant for kp and kd range and check responses in mit_control_batch() * Add docs on setting up canbus and use damiao otor bus, also add lerobot_setup_can.py and log if there is not response from a write command * precommit format * supress bandit as these are intentional cli commands * fix setup-can * add test * skip test in ci * nit precommit * update doc example * dont import can for tests * remove comment * Add openarms docs * format * update purchase link * can to none if nit availabl;e * add canfd option in bus * make handshake logic similar to lerobot-can * type hint * type check * add temp teleop test * remove script * mock class * ignore linter --------- Signed-off-by: Steven Palma Co-authored-by: Pepijn Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> --- docs/source/_toctree.yml | 2 + docs/source/openarm.mdx | 258 +++++++++++++ pyproject.toml | 1 + src/lerobot/motors/damiao/damiao.py | 51 ++- src/lerobot/processor/hil_processor.py | 12 +- .../robots/openarm_follower/__init__.py | 20 + .../config_openarm_follower.py | 117 ++++++ .../openarm_follower/openarm_follower.py | 348 ++++++++++++++++++ src/lerobot/robots/utils.py | 4 + src/lerobot/scripts/lerobot_calibrate.py | 2 + .../scripts/lerobot_find_joint_limits.py | 2 + src/lerobot/scripts/lerobot_record.py | 2 + src/lerobot/scripts/lerobot_replay.py | 1 + src/lerobot/scripts/lerobot_teleoperate.py | 2 + .../teleoperators/openarm_leader/__init__.py | 20 + .../openarm_leader/config_openarm_leader.py | 70 ++++ .../openarm_leader/openarm_leader.py | 225 +++++++++++ src/lerobot/teleoperators/utils.py | 14 +- 18 files changed, 1129 insertions(+), 22 deletions(-) create mode 100644 docs/source/openarm.mdx create mode 100644 src/lerobot/robots/openarm_follower/__init__.py create mode 100644 src/lerobot/robots/openarm_follower/config_openarm_follower.py create mode 100644 src/lerobot/robots/openarm_follower/openarm_follower.py create mode 100644 src/lerobot/teleoperators/openarm_leader/__init__.py create mode 100644 src/lerobot/teleoperators/openarm_leader/config_openarm_leader.py create mode 100644 src/lerobot/teleoperators/openarm_leader/openarm_leader.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index f86dd11c7..eb97117af 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -101,6 +101,8 @@ title: Earth Rover Mini - local: omx title: OMX + - local: openarm + title: OpenArm title: "Robots" - sections: - local: phone_teleop diff --git a/docs/source/openarm.mdx b/docs/source/openarm.mdx new file mode 100644 index 000000000..661808749 --- /dev/null +++ b/docs/source/openarm.mdx @@ -0,0 +1,258 @@ +# OpenArm + +[OpenArm](https://openarm.dev) is an open-source 7DOF humanoid arm designed for physical AI research and deployment. + +To get your OpenArm, assembled or DIY, and join the global community, browse verified and certified manufacturers worldwide at [openarm.dev](https://openarm.dev). + +## What's Unique? + +- **Human-Scale Design**: OpenArm is designed with human-like proportions, scaled for a person around 160-165cm tall. This provides an optimal balance between practical reach and manageable inertia for safe, responsive operation. + +- **Safety-First Architecture**: Built with QDD backdrivable motors and high compliance, OpenArm prioritizes safe human-robot interaction while maintaining practical payload capabilities (6.0kg peak / 4.1kg nominal) for real-world tasks. + +- **Built for Durability**: Critical structural components use aluminum and stainless steel construction, ensuring robust performance for repetitive data collection and continuous research use. + +- **Fully Accessible & Buildable**: Every component, from CNC parts and 3D-printed casings to electrical wiring is designed to be purchasable and buildable by individual researchers and labs, with complete fabrication data provided. + +- **Practical & Affordable**: At $6,500 USD for a complete bimanual system, OpenArm delivers research-grade capabilities at a fraction of traditional humanoid robot costs. + +## Platform Requirements + + + **Linux Only**: OpenArm currently only works on Linux. The CAN bus USB adapter + does not have macOS drivers and has not been tested on Windows. + + +## Safety Guide + +Before operating OpenArm, please read the [official safety guide](https://docs.openarm.dev/getting-started/safety-guide). Key points: + +- **Secure installation**: Fasten the arm to a flat, stable surface with screws or clamps +- **Safe distance**: Keep body parts and objects outside the range of motion during operation +- **Protective equipment**: Always wear safety goggles; use additional PPE as needed +- **Payload limits**: Do not exceed specified payload limits (6.0kg peak / 4.1kg nominal per arm) +- **Emergency stop**: Know the location and operation of the emergency stop device +- **Regular inspection**: Check for loose screws, damaged mechanical limits, unusual noises, and wiring damage + +## Hardware Setup + +Follow the official [OpenArm hardware documentation](https://docs.openarm.dev) for: + +- Bill of materials and sourcing +- 3D printing instructions +- Mechanical assembly +- Electrical wiring + +The hardware repositories are available at [github.com/enactic/openarm](https://github.com/enactic/openarm). + +## CAN Bus Setup + +OpenArm uses CAN bus communication with Damiao motors. Once you have the CAN bus USB adapter plugged into your Linux PC, follow the [Damiao Motors and CAN Bus guide](./damiao) to configure the interface. + +Quick setup: + +```bash +# Setup CAN interfaces +lerobot-setup-can --mode=setup --interfaces=can0,can1 + +# Test motor communication +lerobot-setup-can --mode=test --interfaces=can0,can1 +``` + +## Install LeRobot 🤗 + +Follow our [Installation Guide](./installation), then install the Damiao motor support: + +```bash +pip install -e ".[damiao]" +``` + +## Usage + +### Follower Arm (Robot) + + + + +```bash +lerobot-calibrate \ + --robot.type=openarm_follower \ + --robot.port=can0 \ + --robot.side=right \ + --robot.id=my_openarm_follower +``` + + + + +```python +from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig + +config = OpenArmFollowerConfig( + port="can0", + side="right", # or "left" for left arm + id="my_openarm_follower", +) + +follower = OpenArmFollower(config) +follower.connect() + +# Read current state +obs = follower.get_observation() +print(obs) + +# Send action (position in degrees) +action = { + "joint_1.pos": 0.0, + "joint_2.pos": 0.0, + "joint_3.pos": 0.0, + "joint_4.pos": 45.0, + "joint_5.pos": 0.0, + "joint_6.pos": 0.0, + "joint_7.pos": 0.0, + "gripper.pos": 0.0, +} +follower.send_action(action) + +follower.disconnect() +``` + + + + +### Leader Arm (Teleoperator) + +The leader arm is used for teleoperation - manually moving it to control the follower arm. + + + + +```bash +lerobot-calibrate \ + --teleop.type=openarm_leader \ + --teleop.port=can1 \ + --teleop.id=my_openarm_leader +``` + + + + +```python +from lerobot.teleoperators.openarm_leader import OpenArmLeader, OpenArmLeaderConfig + +config = OpenArmLeaderConfig( + port="can1", + id="my_openarm_leader", + manual_control=True, # Disable torque for manual movement +) + +leader = OpenArmLeader(config) +leader.connect() + +# Read current position (as action to send to follower) +action = leader.get_action() +print(action) + +leader.disconnect() +``` + + + + +### Teleoperation + +To teleoperate OpenArm with leader-follower control: + +```bash +lerobot-teleoperate \ + --robot.type=openarm_follower \ + --robot.port=can0 \ + --robot.side=right \ + --robot.id=my_follower \ + --teleop.type=openarm_leader \ + --teleop.port=can1 \ + --teleop.id=my_leader +``` + +### Recording Data + +To record a dataset during teleoperation: + +```bash +lerobot-record \ + --robot.type=openarm_follower \ + --robot.port=can0 \ + --robot.side=right \ + --robot.id=my_follower \ + --teleop.type=openarm_leader \ + --teleop.port=can1 \ + --teleop.id=my_leader \ + --repo-id=my_hf_username/my_openarm_dataset \ + --fps=30 \ + --num-episodes=10 +``` + +## Configuration Options + +### Follower Configuration + +| Parameter | Default | Description | +| --------------------- | --------- | ---------------------------------------------------------- | +| `port` | - | CAN interface (e.g., `can0`) | +| `side` | `None` | Arm side: `"left"`, `"right"`, or `None` for custom limits | +| `use_can_fd` | `True` | Enable CAN FD for higher data rates | +| `can_bitrate` | `1000000` | Nominal bitrate (1 Mbps) | +| `can_data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) | +| `max_relative_target` | `None` | Safety limit for relative target positions | +| `position_kp` | Per-joint | Position control proportional gains | +| `position_kd` | Per-joint | Position control derivative gains | + +### Leader Configuration + +| Parameter | Default | Description | +| ------------------ | --------- | ----------------------------------- | +| `port` | - | CAN interface (e.g., `can1`) | +| `manual_control` | `True` | Disable torque for manual movement | +| `use_can_fd` | `True` | Enable CAN FD for higher data rates | +| `can_bitrate` | `1000000` | Nominal bitrate (1 Mbps) | +| `can_data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) | + +## Motor Configuration + +OpenArm uses Damiao motors with the following default configuration: + +| Joint | Motor Type | Send ID | Recv ID | +| --------------------------- | ---------- | ------- | ------- | +| joint_1 (Shoulder pan) | DM8009 | 0x01 | 0x11 | +| joint_2 (Shoulder lift) | DM8009 | 0x02 | 0x12 | +| joint_3 (Shoulder rotation) | DM4340 | 0x03 | 0x13 | +| joint_4 (Elbow flex) | DM4340 | 0x04 | 0x14 | +| joint_5 (Wrist roll) | DM4310 | 0x05 | 0x15 | +| joint_6 (Wrist pitch) | DM4310 | 0x06 | 0x16 | +| joint_7 (Wrist rotation) | DM4310 | 0x07 | 0x17 | +| gripper | DM4310 | 0x08 | 0x18 | + +## Troubleshooting + +### No Response from Motors + +1. Check power supply connections +2. Verify CAN wiring (CAN-H, CAN-L, GND) +3. Run diagnostics: `lerobot-setup-can --mode=test --interfaces=can0` +4. See the [Damiao troubleshooting guide](./damiao#troubleshooting) for more details + +### CAN Interface Not Found + +Ensure the CAN interface is configured: + +```bash +ip link show can0 +``` + +## Resources + +- [OpenArm Website](https://openarm.dev) +- [OpenArm Documentation](https://docs.openarm.dev) +- [OpenArm GitHub](https://github.com/enactic/openarm) +- [Safety Guide](https://docs.openarm.dev/getting-started/safety-guide) +- [Damiao Motors and CAN Bus](./damiao) diff --git a/pyproject.toml b/pyproject.toml index 27126f855..ea2dfb4a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,6 +105,7 @@ dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"] damiao = ["python-can>=4.2.0,<5.0.0"] # Robots +openarms = ["lerobot[damiao]"] gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"] hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"] lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"] diff --git a/src/lerobot/motors/damiao/damiao.py b/src/lerobot/motors/damiao/damiao.py index dd0213fc3..c79f8d17e 100644 --- a/src/lerobot/motors/damiao/damiao.py +++ b/src/lerobot/motors/damiao/damiao.py @@ -28,8 +28,11 @@ from lerobot.utils.import_utils import _can_available if TYPE_CHECKING or _can_available: import can else: - can.Message = object - can.interface = None + + class can: # noqa: N801 + Message = object + interface = None + import numpy as np @@ -206,11 +209,31 @@ class DamiaoMotorsBus(MotorsBusBase): Raises ConnectionError if any motor fails to respond. """ logger.info("Starting handshake with motors...") - missing_motors = [] + # Drain any pending messages + while self.canbus.recv(timeout=0.01): + pass + + missing_motors = [] for motor_name in self.motors: - msg = self._refresh_motor(motor_name) - if msg is None: + motor_id = self._get_motor_id(motor_name) + recv_id = self._get_motor_recv_id(motor_name) + + # Send enable command + data = [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, CAN_CMD_ENABLE] + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd) + self.canbus.send(msg) + + # Wait for response with longer timeout + response = None + start_time = time.time() + while time.time() - start_time < 0.1: + response = self.canbus.recv(timeout=0.1) + if response and response.arbitration_id == recv_id: + break + response = None + + if response is None: missing_motors.append(motor_name) else: self._process_response(motor_name, msg) @@ -259,7 +282,7 @@ class DamiaoMotorsBus(MotorsBusBase): motor_name = self._get_motor_name(motor) recv_id = self._get_motor_recv_id(motor) data = [0xFF] * 7 + [command_byte] - msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd) self.canbus.send(msg) if msg := self._recv_motor_response(expected_recv_id=recv_id): self._process_response(motor_name, msg) @@ -317,7 +340,7 @@ class DamiaoMotorsBus(MotorsBusBase): motor_id = self._get_motor_id(motor) recv_id = self._get_motor_recv_id(motor) data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0] - msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False) + msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd) self.canbus.send(msg) return self._recv_motor_response(expected_recv_id=recv_id) @@ -439,7 +462,7 @@ class DamiaoMotorsBus(MotorsBusBase): motor_type = self._motor_types[motor_name] data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque) - msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd) self.canbus.send(msg) recv_id = self._get_motor_recv_id(motor) @@ -472,7 +495,7 @@ class DamiaoMotorsBus(MotorsBusBase): motor_type = self._motor_types[motor_name] data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque) - msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd) self.canbus.send(msg) recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name @@ -637,10 +660,10 @@ class DamiaoMotorsBus(MotorsBusBase): for motor in motors: motor_id = self._get_motor_id(motor) data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0] - msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False) + msg = can.Message( + arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd + ) self.canbus.send(msg) - # Small delay to reduce bus congestion if necessary, though removed in sync_read previously - # precise_sleep(PRECISE_SLEEP_SEC) # Collect responses expected_recv_ids = [self._get_motor_recv_id(m) for m in motors] @@ -676,7 +699,9 @@ class DamiaoMotorsBus(MotorsBusBase): kd = self._gains[motor]["kd"] data = self._encode_mit_packet(motor_type, kp, kd, float(value_degrees), 0.0, 0.0) - msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + msg = can.Message( + arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd + ) self.canbus.send(msg) precise_sleep(PRECISE_TIMEOUT_SEC) diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py index f0dbac9c3..6d44ed8cb 100644 --- a/src/lerobot/processor/hil_processor.py +++ b/src/lerobot/processor/hil_processor.py @@ -18,16 +18,18 @@ import math import time from dataclasses import dataclass -from typing import Any, Protocol, TypeVar, runtime_checkable +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, runtime_checkable import numpy as np import torch import torchvision.transforms.functional as F # noqa: N812 from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.teleoperators.teleoperator import Teleoperator from lerobot.teleoperators.utils import TeleopEvents +if TYPE_CHECKING: + from lerobot.teleoperators.teleoperator import Teleoperator + from .core import EnvTransition, PolicyAction, TransitionKey from .pipeline import ( ComplementaryDataProcessorStep, @@ -69,10 +71,10 @@ class HasTeleopEvents(Protocol): # Type variable constrained to Teleoperator subclasses that also implement events -TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator) +TeleopWithEvents = TypeVar("TeleopWithEvents", bound="Teleoperator") -def _check_teleop_with_events(teleop: Teleoperator) -> None: +def _check_teleop_with_events(teleop: "Teleoperator") -> None: """ Runtime check that a teleoperator implements the `HasTeleopEvents` protocol. @@ -103,7 +105,7 @@ class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep): teleop_device: The teleoperator instance to get the action from. """ - teleop_device: Teleoperator + teleop_device: "Teleoperator" def complementary_data(self, complementary_data: dict) -> dict: """ diff --git a/src/lerobot/robots/openarm_follower/__init__.py b/src/lerobot/robots/openarm_follower/__init__.py new file mode 100644 index 000000000..1eb0d9fc7 --- /dev/null +++ b/src/lerobot/robots/openarm_follower/__init__.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config_openarm_follower import OpenArmFollowerConfig +from .openarm_follower import OpenArmFollower + +__all__ = ["OpenArmFollower", "OpenArmFollowerConfig"] diff --git a/src/lerobot/robots/openarm_follower/config_openarm_follower.py b/src/lerobot/robots/openarm_follower/config_openarm_follower.py new file mode 100644 index 000000000..af95b6395 --- /dev/null +++ b/src/lerobot/robots/openarm_follower/config_openarm_follower.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.cameras import CameraConfig + +from ..config import RobotConfig + +LEFT_DEFAULT_JOINTS_LIMITS: dict[str, tuple[float, float]] = { + "joint_1": (-75.0, 75.0), + "joint_2": (-90.0, 9.0), + "joint_3": (-85.0, 85.0), + "joint_4": (0.0, 135.0), + "joint_5": (-85.0, 85.0), + "joint_6": (-40.0, 40.0), + "joint_7": (-80.0, 80.0), + "gripper": (-65.0, 0.0), +} + +RIGHT_DEFAULT_JOINTS_LIMITS: dict[str, tuple[float, float]] = { + "joint_1": (-75.0, 75.0), + "joint_2": (-9.0, 90.0), + "joint_3": (-85.0, 85.0), + "joint_4": (0.0, 135.0), + "joint_5": (-85.0, 85.0), + "joint_6": (-40.0, 40.0), + "joint_7": (-80.0, 80.0), + "gripper": (-65.0, 0.0), +} + + +@RobotConfig.register_subclass("openarm_follower") +@dataclass +class OpenArmFollowerConfig(RobotConfig): + """Configuration for the OpenArms follower robot with Damiao motors.""" + + # CAN interfaces - one per arm + # arm CAN interface (e.g., "can1") + # Linux: "can0", "can1", etc. + port: str + + # side of the arm: "left" or "right". If "None" default values will be used + side: str | None = None + + # CAN interface type: "socketcan" (Linux), "slcan" (serial), or "auto" (auto-detect) + can_interface: str = "socketcan" + + # CAN FD settings (OpenArms uses CAN FD by default) + use_can_fd: bool = True + can_bitrate: int = 1000000 # Nominal bitrate (1 Mbps) + can_data_bitrate: int = 5000000 # Data bitrate for CAN FD (5 Mbps) + + # Whether to disable torque when disconnecting + disable_torque_on_disconnect: bool = True + + # Safety limit for relative target positions + # Set to a positive scalar for all motors, or a dict mapping motor names to limits + max_relative_target: float | dict[str, float] | None = None + + # Camera configurations + cameras: dict[str, CameraConfig] = field(default_factory=dict) + + # Motor configuration for OpenArms (7 DOF per arm) + # Maps motor names to (send_can_id, recv_can_id, motor_type) + # Based on: https://docs.openarm.dev/software/setup/configure-test + # OpenArms uses 4 types of motors: + # - DM8009 (DM-J8009P-2EC) for shoulders (high torque) + # - DM4340P and DM4340 for shoulder rotation and elbow + # - DM4310 (DM-J4310-2EC V1.1) for wrist and gripper + motor_config: dict[str, tuple[int, int, str]] = field( + default_factory=lambda: { + "joint_1": (0x01, 0x11, "dm8009"), # J1 - Shoulder pan (DM8009) + "joint_2": (0x02, 0x12, "dm8009"), # J2 - Shoulder lift (DM8009) + "joint_3": (0x03, 0x13, "dm4340"), # J3 - Shoulder rotation (DM4340) + "joint_4": (0x04, 0x14, "dm4340"), # J4 - Elbow flex (DM4340) + "joint_5": (0x05, 0x15, "dm4310"), # J5 - Wrist roll (DM4310) + "joint_6": (0x06, 0x16, "dm4310"), # J6 - Wrist pitch (DM4310) + "joint_7": (0x07, 0x17, "dm4310"), # J7 - Wrist rotation (DM4310) + "gripper": (0x08, 0x18, "dm4310"), # J8 - Gripper (DM4310) + } + ) + + # MIT control parameters for position control (used in send_action) + # List of 8 values: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper] + position_kp: list[float] = field( + default_factory=lambda: [240.0, 240.0, 240.0, 240.0, 24.0, 31.0, 25.0, 25.0] + ) + position_kd: list[float] = field(default_factory=lambda: [5.0, 5.0, 3.0, 5.0, 0.3, 0.3, 0.3, 0.3]) + + # Values for joint limits. Can be overridden via CLI (for custom values) or by setting config.side to either 'left' or 'right'. + # If config.side is left set to None and no CLI values are passed, the default joint limit values are small for safety. + joint_limits: dict[str, tuple[float, float]] = field( + default_factory=lambda: { + "joint_1": (-5.0, 5.0), + "joint_2": (-5.0, 5.0), + "joint_3": (-5.0, 5.0), + "joint_4": (0.0, 5.0), + "joint_5": (-5.0, 5.0), + "joint_6": (-5.0, 5.0), + "joint_7": (-5.0, 5.0), + "gripper": (-5.0, 0.0), + } + ) diff --git a/src/lerobot/robots/openarm_follower/openarm_follower.py b/src/lerobot/robots/openarm_follower/openarm_follower.py new file mode 100644 index 000000000..c221afd10 --- /dev/null +++ b/src/lerobot/robots/openarm_follower/openarm_follower.py @@ -0,0 +1,348 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from functools import cached_property +from typing import Any + +from lerobot.cameras.utils import make_cameras_from_configs +from lerobot.motors import Motor, MotorCalibration, MotorNormMode +from lerobot.motors.damiao import DamiaoMotorsBus +from lerobot.processor import RobotAction, RobotObservation +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError + +from ..robot import Robot +from ..utils import ensure_safe_goal_position +from .config_openarm_follower import ( + LEFT_DEFAULT_JOINTS_LIMITS, + RIGHT_DEFAULT_JOINTS_LIMITS, + OpenArmFollowerConfig, +) + +logger = logging.getLogger(__name__) + + +class OpenArmFollower(Robot): + """ + OpenArms Follower Robot which uses CAN bus communication to control 7 DOF arm with a gripper. + The arm uses Damiao motors in MIT control mode. + """ + + config_class = OpenArmFollowerConfig + name = "openarm_follower" + + def __init__(self, config: OpenArmFollowerConfig): + super().__init__(config) + self.config = config + + # Arm motors + motors: dict[str, Motor] = {} + for motor_name, (send_id, recv_id, motor_type_str) in config.motor_config.items(): + motor = Motor( + send_id, motor_type_str, MotorNormMode.DEGREES + ) # Always use degrees for Damiao motors + motor.recv_id = recv_id + motor.motor_type_str = motor_type_str + motors[motor_name] = motor + + self.bus = DamiaoMotorsBus( + port=self.config.port, + motors=motors, + calibration=self.calibration, + can_interface=self.config.can_interface, + use_can_fd=self.config.use_can_fd, + bitrate=self.config.can_bitrate, + data_bitrate=self.config.can_data_bitrate if self.config.use_can_fd else None, + ) + + if config.side is not None: + if config.side == "left": + config.joint_limits = LEFT_DEFAULT_JOINTS_LIMITS + elif config.side == "right": + config.joint_limits = RIGHT_DEFAULT_JOINTS_LIMITS + else: + raise ValueError( + "config.side must be either 'left', 'right' (for default values) or 'None' (for CLI values)" + ) + else: + logger.info( + "Set config.side to either 'left' or 'right' to use pre-configured values for joint limits." + ) + logger.info(f"Values used for joint limits: {config.joint_limits}.") + + # Initialize cameras + self.cameras = make_cameras_from_configs(config.cameras) + + @property + def _motors_ft(self) -> dict[str, type]: + """Motor features for observation and action spaces.""" + features: dict[str, type] = {} + for motor in self.bus.motors: + features[f"{motor}.pos"] = float + features[f"{motor}.vel"] = float # Add this + features[f"{motor}.torque"] = float # Add this + return features + + @property + def _cameras_ft(self) -> dict[str, tuple]: + """Camera features for observation space.""" + return { + cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras + } + + @cached_property + def observation_features(self) -> dict[str, type | tuple]: + """Combined observation features from motors and cameras.""" + return {**self._motors_ft, **self._cameras_ft} + + @cached_property + def action_features(self) -> dict[str, type]: + """Action features.""" + return self._motors_ft + + @property + def is_connected(self) -> bool: + """Check if robot is connected.""" + return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + + def connect(self, calibrate: bool = True) -> None: + """ + Connect to the robot and optionally calibrate. + + We assume that at connection time, the arms are in a safe rest position, + and torque can be safely disabled to run calibration if needed. + """ + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + # Connect to CAN bus + logger.info(f"Connecting arm on {self.config.port}...") + self.bus.connect() + + # Run calibration if needed + if not self.is_calibrated and calibrate: + logger.info( + "Mismatch between calibration values in the motor and the calibration file or no calibration file found" + ) + self.calibrate() + + for cam in self.cameras.values(): + cam.connect() + + self.configure() + + if self.is_calibrated: + self.bus.set_zero_position() + + self.bus.enable_torque() + + logger.info(f"{self} connected.") + + @property + def is_calibrated(self) -> bool: + """Check if robot is calibrated.""" + return self.bus.is_calibrated + + def calibrate(self) -> None: + """ + Run calibration procedure for OpenArms robot. + + The calibration procedure: + 1. Disable torque + 2. Ask user to position arms in hanging position with grippers closed + 3. Set this as zero position + 4. Record range of motion for each joint + 5. Save calibration + """ + if self.calibration: + # Calibration file exists, ask user whether to use it or run new calibration + user_input = input( + f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: " + ) + if user_input.strip().lower() != "c": + logger.info(f"Writing calibration file associated with the id {self.id} to the motors") + self.bus.write_calibration(self.calibration) + return + + logger.info(f"\nRunning calibration for {self}") + self.bus.disable_torque() + + # Step 1: Set zero position + input( + "\nCalibration: Set Zero Position)\n" + "Position the arm in the following configuration:\n" + " - Arm hanging straight down\n" + " - Gripper closed\n" + "Press ENTER when ready..." + ) + + # Set current position as zero for all motors + self.bus.set_zero_position() + logger.info("Arm zero position set.") + + logger.info("Setting range: -90° to +90° for safety by default for all joints") + for motor_name, motor in self.bus.motors.items(): + self.calibration[motor_name] = MotorCalibration( + id=motor.id, + drive_mode=0, + homing_offset=0, + range_min=-90, + range_max=90, + ) + + self.bus.write_calibration(self.calibration) + self._save_calibration() + print(f"Calibration saved to {self.calibration_fpath}") + + def configure(self) -> None: + """Configure motors with appropriate settings.""" + # TODO(Steven, Pepijn): Slightly different from what it is happening in the leader + with self.bus.torque_disabled(): + self.bus.configure_motors() + + def setup_motors(self) -> None: + raise NotImplementedError( + "Motor ID configuration is typically done via manufacturer tools for CAN motors." + ) + + def get_observation(self) -> RobotObservation: + """ + Get current observation from robot including position, velocity, and torque. + + Reads all motor states (pos/vel/torque) in one CAN refresh cycle + instead of 3 separate reads. + """ + start = time.perf_counter() + + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + obs_dict: dict[str, Any] = {} + + states = self.bus.sync_read_all_states() + + for motor in self.bus.motors: + state = states.get(motor, {}) + obs_dict[f"{motor}.pos"] = state.get("position", 0.0) + obs_dict[f"{motor}.vel"] = state.get("velocity", 0.0) + obs_dict[f"{motor}.torque"] = state.get("torque", 0.0) + + # Capture images from cameras + for cam_key, cam in self.cameras.items(): + start = time.perf_counter() + obs_dict[cam_key] = cam.async_read() + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") + + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} get_observation took: {dt_ms:.1f}ms") + + return obs_dict + + def send_action( + self, + action: RobotAction, + custom_kp: dict[str, float] | None = None, + custom_kd: dict[str, float] | None = None, + ) -> RobotAction: + """ + Send action command to robot. + + The action magnitude may be clipped based on safety limits. + + Args: + action: Dictionary with motor positions (e.g., "joint_1.pos", "joint_2.pos") + custom_kp: Optional custom kp gains per motor (e.g., {"joint_1": 120.0, "joint_2": 150.0}) + custom_kd: Optional custom kd gains per motor (e.g., {"joint_1": 1.5, "joint_2": 2.0}) + + Returns: + The action actually sent (potentially clipped) + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")} + + # Apply joint limit clipping to arm + for motor_name, position in goal_pos.items(): + if motor_name in self.config.joint_limits: + min_limit, max_limit = self.config.joint_limits[motor_name] + clipped_position = max(min_limit, min(max_limit, position)) + if clipped_position != position: + logger.debug(f"Clipped {motor_name} from {position:.2f}° to {clipped_position:.2f}°") + goal_pos[motor_name] = clipped_position + + # Cap goal position when too far away from present position. + # /!\ Slower fps expected due to reading from the follower. + if self.config.max_relative_target is not None: + present_pos = self.bus.sync_read("Present_Position") + goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()} + goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target) + + # TODO(Steven, Pepijn): Refactor writing + # Motor name to index mapping for gains + motor_index = { + "joint_1": 0, + "joint_2": 1, + "joint_3": 2, + "joint_4": 3, + "joint_5": 4, + "joint_6": 5, + "joint_7": 6, + "gripper": 7, + } + + # Use batch MIT control for arm (sends all commands, then collects responses) + commands = {} + for motor_name, position_degrees in goal_pos.items(): + idx = motor_index.get(motor_name, 0) + # Use custom gains if provided, otherwise use config defaults + if custom_kp is not None and motor_name in custom_kp: + kp = custom_kp[motor_name] + else: + kp = ( + self.config.position_kp[idx] + if isinstance(self.config.position_kp, list) + else self.config.position_kp + ) + if custom_kd is not None and motor_name in custom_kd: + kd = custom_kd[motor_name] + else: + kd = ( + self.config.position_kd[idx] + if isinstance(self.config.position_kd, list) + else self.config.position_kd + ) + commands[motor_name] = (kp, kd, position_degrees, 0.0, 0.0) + + self.bus._mit_control_batch(commands) + + return {f"{motor}.pos": val for motor, val in goal_pos.items()} + + def disconnect(self): + """Disconnect from robot.""" + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + # Disconnect CAN bus + self.bus.disconnect(self.config.disable_torque_on_disconnect) + + # Disconnect cameras + for cam in self.cameras.values(): + cam.disconnect() + + logger.info(f"{self} disconnected.") diff --git a/src/lerobot/robots/utils.py b/src/lerobot/robots/utils.py index 27abaaa86..e0c76cab3 100644 --- a/src/lerobot/robots/utils.py +++ b/src/lerobot/robots/utils.py @@ -60,6 +60,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot: from .reachy2 import Reachy2Robot return Reachy2Robot(config) + elif config.type == "openarm_follower": + from .openarm_follower import OpenArmFollower + + return OpenArmFollower(config) elif config.type == "mock_robot": from tests.mocks.mock_robot import MockRobot diff --git a/src/lerobot/scripts/lerobot_calibrate.py b/src/lerobot/scripts/lerobot_calibrate.py index cbc7684d3..0f79e6aa2 100644 --- a/src/lerobot/scripts/lerobot_calibrate.py +++ b/src/lerobot/scripts/lerobot_calibrate.py @@ -42,6 +42,7 @@ from lerobot.robots import ( # noqa: F401 lekiwi, make_robot_from_config, omx_follower, + openarm_follower, so_follower, ) from lerobot.teleoperators import ( # noqa: F401 @@ -52,6 +53,7 @@ from lerobot.teleoperators import ( # noqa: F401 koch_leader, make_teleoperator_from_config, omx_leader, + openarm_leader, so_leader, ) from lerobot.utils.import_utils import register_third_party_plugins diff --git a/src/lerobot/scripts/lerobot_find_joint_limits.py b/src/lerobot/scripts/lerobot_find_joint_limits.py index 20bbc8615..d928dc5cd 100644 --- a/src/lerobot/scripts/lerobot_find_joint_limits.py +++ b/src/lerobot/scripts/lerobot_find_joint_limits.py @@ -48,6 +48,7 @@ from lerobot.robots import ( # noqa: F401 koch_follower, make_robot_from_config, omx_follower, + openarm_follower, so_follower, ) from lerobot.teleoperators import ( # noqa: F401 @@ -57,6 +58,7 @@ from lerobot.teleoperators import ( # noqa: F401 koch_leader, make_teleoperator_from_config, omx_leader, + openarm_leader, so_leader, ) from lerobot.utils.robot_utils import precise_sleep diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index f03776989..4d334f38f 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -104,6 +104,7 @@ from lerobot.robots import ( # noqa: F401 koch_follower, make_robot_from_config, omx_follower, + openarm_follower, reachy2, so_follower, unitree_g1, @@ -116,6 +117,7 @@ from lerobot.teleoperators import ( # noqa: F401 koch_leader, make_teleoperator_from_config, omx_leader, + openarm_leader, reachy2_teleoperator, so_leader, ) diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index 49c06d643..c3bc3d766 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ b/src/lerobot/scripts/lerobot_replay.py @@ -59,6 +59,7 @@ from lerobot.robots import ( # noqa: F401 koch_follower, make_robot_from_config, omx_follower, + openarm_follower, reachy2, so_follower, unitree_g1, diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index 18d8863d6..a415dd600 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ b/src/lerobot/scripts/lerobot_teleoperate.py @@ -76,6 +76,7 @@ from lerobot.robots import ( # noqa: F401 koch_follower, make_robot_from_config, omx_follower, + openarm_follower, reachy2, so_follower, ) @@ -89,6 +90,7 @@ from lerobot.teleoperators import ( # noqa: F401 koch_leader, make_teleoperator_from_config, omx_leader, + openarm_leader, reachy2_teleoperator, so_leader, ) diff --git a/src/lerobot/teleoperators/openarm_leader/__init__.py b/src/lerobot/teleoperators/openarm_leader/__init__.py new file mode 100644 index 000000000..1493317fe --- /dev/null +++ b/src/lerobot/teleoperators/openarm_leader/__init__.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config_openarm_leader import OpenArmLeaderConfig +from .openarm_leader import OpenArmLeader + +__all__ = ["OpenArmLeader", "OpenArmLeaderConfig"] diff --git a/src/lerobot/teleoperators/openarm_leader/config_openarm_leader.py b/src/lerobot/teleoperators/openarm_leader/config_openarm_leader.py new file mode 100644 index 000000000..c53169b0a --- /dev/null +++ b/src/lerobot/teleoperators/openarm_leader/config_openarm_leader.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from ..config import TeleoperatorConfig + + +@TeleoperatorConfig.register_subclass("openarm_leader") +@dataclass +class OpenArmLeaderConfig(TeleoperatorConfig): + """Configuration for the OpenArms leader/teleoperator with Damiao motors.""" + + # CAN interfaces - one per arm + # Arm CAN interface (e.g., "can3") + # Linux: "can0", "can1", etc. + port: str + + # CAN interface type: "socketcan" (Linux), "slcan" (serial), or "auto" (auto-detect) + can_interface: str = "socketcan" + + # CAN FD settings (OpenArms uses CAN FD by default) + use_can_fd: bool = True + can_bitrate: int = 1000000 # Nominal bitrate (1 Mbps) + can_data_bitrate: int = 5000000 # Data bitrate for CAN FD (5 Mbps) + + # Motor configuration for OpenArms (7 DOF per arm) + # Maps motor names to (send_can_id, recv_can_id, motor_type) + # Based on: https://docs.openarm.dev/software/setup/configure-test + # OpenArms uses 4 types of motors: + # - DM8009 (DM-J8009P-2EC) for shoulders (high torque) + # - DM4340P and DM4340 for shoulder rotation and elbow + # - DM4310 (DM-J4310-2EC V1.1) for wrist and gripper + motor_config: dict[str, tuple[int, int, str]] = field( + default_factory=lambda: { + "joint_1": (0x01, 0x11, "dm8009"), # J1 - Shoulder pan (DM8009) + "joint_2": (0x02, 0x12, "dm8009"), # J2 - Shoulder lift (DM8009) + "joint_3": (0x03, 0x13, "dm4340"), # J3 - Shoulder rotation (DM4340) + "joint_4": (0x04, 0x14, "dm4340"), # J4 - Elbow flex (DM4340) + "joint_5": (0x05, 0x15, "dm4310"), # J5 - Wrist roll (DM4310) + "joint_6": (0x06, 0x16, "dm4310"), # J6 - Wrist pitch (DM4310) + "joint_7": (0x07, 0x17, "dm4310"), # J7 - Wrist rotation (DM4310) + "gripper": (0x08, 0x18, "dm4310"), # J8 - Gripper (DM4310) + } + ) + + # Torque mode settings for manual control + # When enabled, motors have torque disabled for manual movement + manual_control: bool = True + + # TODO(Steven, Pepijn): Not used ... ? + # MIT control parameters (used when manual_control=False for torque control) + # List of 8 values: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper] + position_kp: list[float] = field( + default_factory=lambda: [240.0, 240.0, 240.0, 240.0, 24.0, 31.0, 25.0, 16.0] + ) + position_kd: list[float] = field(default_factory=lambda: [3.0, 3.0, 3.0, 3.0, 0.2, 0.2, 0.2, 0.2]) diff --git a/src/lerobot/teleoperators/openarm_leader/openarm_leader.py b/src/lerobot/teleoperators/openarm_leader/openarm_leader.py new file mode 100644 index 000000000..edf4d7090 --- /dev/null +++ b/src/lerobot/teleoperators/openarm_leader/openarm_leader.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from typing import Any + +from lerobot.motors import Motor, MotorCalibration, MotorNormMode +from lerobot.motors.damiao import DamiaoMotorsBus +from lerobot.processor import RobotAction +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError + +from ..teleoperator import Teleoperator +from .config_openarm_leader import OpenArmLeaderConfig + +logger = logging.getLogger(__name__) + + +class OpenArmLeader(Teleoperator): + """ + OpenArm Leader/Teleoperator Arm with Damiao motors. + + This teleoperator uses CAN bus communication to read positions from + Damiao motors that are manually moved (torque disabled). + """ + + config_class = OpenArmLeaderConfig + name = "openarm_leader" + + def __init__(self, config: OpenArmLeaderConfig): + super().__init__(config) + self.config = config + + # Arm motors + motors: dict[str, Motor] = {} + for motor_name, (send_id, recv_id, motor_type_str) in config.motor_config.items(): + motor = Motor( + send_id, motor_type_str, MotorNormMode.DEGREES + ) # Always use degrees for Damiao motors + motor.recv_id = recv_id + motor.motor_type_str = motor_type_str + motors[motor_name] = motor + + self.bus = DamiaoMotorsBus( + port=self.config.port, + motors=motors, + calibration=self.calibration, + can_interface=self.config.can_interface, + use_can_fd=self.config.use_can_fd, + bitrate=self.config.can_bitrate, + data_bitrate=self.config.can_data_bitrate if self.config.use_can_fd else None, + ) + + @property + def action_features(self) -> dict[str, type]: + """Features produced by this teleoperator.""" + features: dict[str, type] = {} + for motor in self.bus.motors: + features[f"{motor}.pos"] = float + features[f"{motor}.vel"] = float + features[f"{motor}.torque"] = float + return features + + @property + def feedback_features(self) -> dict[str, type]: + """Feedback features (not implemented for OpenArms).""" + return {} + + @property + def is_connected(self) -> bool: + """Check if teleoperator is connected.""" + return self.bus.is_connected + + def connect(self, calibrate: bool = True) -> None: + """ + Connect to the teleoperator. + + For manual control, we disable torque after connecting so the + arm can be moved by hand. + """ + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + # Connect to CAN bus + logger.info(f"Connecting arm on {self.config.port}...") + self.bus.connect() + + # Run calibration if needed + if not self.is_calibrated and calibrate: + logger.info( + "Mismatch between calibration values in the motor and the calibration file or no calibration file found" + ) + self.calibrate() + + self.configure() + + if self.is_calibrated: + self.bus.set_zero_position() + + logger.info(f"{self} connected.") + + @property + def is_calibrated(self) -> bool: + """Check if teleoperator is calibrated.""" + return self.bus.is_calibrated + + def calibrate(self) -> None: + """ + Run calibration procedure for OpenArms leader. + + The calibration procedure: + 1. Disable torque (if not already disabled) + 2. Ask user to position arm in zero position (hanging with gripper closed) + 3. Set this as zero position + 4. Record range of motion for each joint + 5. Save calibration + """ + if self.calibration: + # Calibration file exists, ask user whether to use it or run new calibration + user_input = input( + f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: " + ) + if user_input.strip().lower() != "c": + logger.info(f"Writing calibration file associated with the id {self.id} to the motors") + self.bus.write_calibration(self.calibration) + return + + logger.info(f"\nRunning calibration for {self}") + self.bus.disable_torque() + + # Step 1: Set zero position + input( + "\nCalibration: Set Zero Position)\n" + "Position the arm in the following configuration:\n" + " - Arm hanging straight down\n" + " - Gripper closed\n" + "Press ENTER when ready..." + ) + + # Set current position as zero for all motors + self.bus.set_zero_position() + logger.info("Arm zero position set.") + + logger.info("Setting range: -90° to +90° by default for all joints") + # TODO(Steven, Pepijn): Check if MotorCalibration is actually needed here given that we only use Degrees + for motor_name, motor in self.bus.motors.items(): + self.calibration[motor_name] = MotorCalibration( + id=motor.id, + drive_mode=0, + homing_offset=0, + range_min=-90, + range_max=90, + ) + + self.bus.write_calibration(self.calibration) + self._save_calibration() + print(f"Calibration saved to {self.calibration_fpath}") + + def configure(self) -> None: + """ + Configure motors for manual teleoperation. + + For manual control, we disable torque so the arm can be moved by hand. + """ + + return self.bus.disable_torque() if self.config.manual_control else self.bus.configure_motors() + + def setup_motors(self) -> None: + raise NotImplementedError( + "Motor ID configuration is typically done via manufacturer tools for CAN motors." + ) + + def get_action(self) -> RobotAction: + """ + Get current action from the leader arm. + + This is the main method for teleoperators - it reads the current state + of the leader arm and returns it as an action that can be sent to a follower. + + Reads all motor states (pos/vel/torque) in one CAN refresh cycle. + """ + start = time.perf_counter() + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + action_dict: dict[str, Any] = {} + + # Use sync_read_all_states to get pos/vel/torque in one go + states = self.bus.sync_read_all_states() + for motor in self.bus.motors: + state = states.get(motor, {}) + action_dict[f"{motor}.pos"] = state.get("position") + action_dict[f"{motor}.vel"] = state.get("velocity") + action_dict[f"{motor}.torque"] = state.get("torque") + + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read state: {dt_ms:.1f}ms") + + return action_dict + + def send_feedback(self, feedback: dict[str, float]) -> None: + raise NotImplementedError("Feedback is not yet implemented for OpenArm leader.") + + def disconnect(self) -> None: + """Disconnect from teleoperator.""" + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + # Disconnect CAN bus + # For manual control, ensure torque is disabled before disconnecting + self.bus.disconnect(disable_torque=self.config.manual_control) + logger.info(f"{self} disconnected.") diff --git a/src/lerobot/teleoperators/utils.py b/src/lerobot/teleoperators/utils.py index eec2f119c..8f6bbc787 100644 --- a/src/lerobot/teleoperators/utils.py +++ b/src/lerobot/teleoperators/utils.py @@ -13,12 +13,14 @@ # limitations under the License. from enum import Enum -from typing import cast +from typing import TYPE_CHECKING, cast from lerobot.utils.import_utils import make_device_from_device_class from .config import TeleoperatorConfig -from .teleoperator import Teleoperator + +if TYPE_CHECKING: + from .teleoperator import Teleoperator class TeleopEvents(Enum): @@ -31,7 +33,7 @@ class TeleopEvents(Enum): TERMINATE_EPISODE = "terminate_episode" -def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator: +def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator": # TODO(Steven): Consider just using the make_device_from_device_class for all types if config.type == "keyboard": from .keyboard import KeyboardTeleop @@ -81,8 +83,12 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator: from .reachy2_teleoperator import Reachy2Teleoperator return Reachy2Teleoperator(config) + elif config.type == "openarm_leader": + from .openarm_leader import OpenArmLeader + + return OpenArmLeader(config) else: try: - return cast(Teleoperator, make_device_from_device_class(config)) + return cast("Teleoperator", make_device_from_device_class(config)) except Exception as e: raise ValueError(f"Error creating robot with config {config}: {e}") from e From 149628dfd5b3079a2c0f80832deac1da3e7bd287 Mon Sep 17 00:00:00 2001 From: Martino Russi <77496684+nepyope@users.noreply.github.com> Date: Wed, 28 Jan 2026 15:17:38 +0100 Subject: [PATCH 014/131] add g1 teleoperation (#2791) * add gravity compensation * add g1 teleoperation --------- Co-authored-by: Michel Aractingi --- docs/source/unitree_g1.mdx | 100 +++- pyproject.toml | 6 +- .../robots/unitree_g1/config_unitree_g1.py | 3 + src/lerobot/robots/unitree_g1/g1_utils.py | 2 +- .../unitree_g1/robot_kinematic_processor.py | 313 ++++++++++++ src/lerobot/robots/unitree_g1/unitree_g1.py | 19 +- src/lerobot/scripts/lerobot_calibrate.py | 1 + src/lerobot/scripts/lerobot_record.py | 3 +- src/lerobot/scripts/lerobot_teleoperate.py | 2 + .../teleoperators/unitree_g1/__init__.py | 21 + .../unitree_g1/config_unitree_g1.py | 37 ++ .../teleoperators/unitree_g1/exo_calib.py | 446 ++++++++++++++++++ .../teleoperators/unitree_g1/exo_ik.py | 353 ++++++++++++++ .../teleoperators/unitree_g1/exo_serial.py | 119 +++++ .../teleoperators/unitree_g1/unitree_g1.py | 157 ++++++ src/lerobot/teleoperators/utils.py | 4 + 16 files changed, 1581 insertions(+), 5 deletions(-) create mode 100644 src/lerobot/robots/unitree_g1/robot_kinematic_processor.py create mode 100644 src/lerobot/teleoperators/unitree_g1/__init__.py create mode 100644 src/lerobot/teleoperators/unitree_g1/config_unitree_g1.py create mode 100644 src/lerobot/teleoperators/unitree_g1/exo_calib.py create mode 100644 src/lerobot/teleoperators/unitree_g1/exo_ik.py create mode 100644 src/lerobot/teleoperators/unitree_g1/exo_serial.py create mode 100644 src/lerobot/teleoperators/unitree_g1/unitree_g1.py diff --git a/docs/source/unitree_g1.mdx b/docs/source/unitree_g1.mdx index e6bffdf1b..ea6bf54ad 100644 --- a/docs/source/unitree_g1.mdx +++ b/docs/source/unitree_g1.mdx @@ -188,7 +188,105 @@ Press `Ctrl+C` to stop the policy. ## Running in Simulation Mode (MuJoCo) -You can now test policies before unleashing them on the physical robot using MuJoCo. To do so simply set `is_simulation=True` in config. +You can test policies before deploying on the physical robot using MuJoCo simulation. Set `is_simulation=True` in config or pass `--robot.is_simulation=true` via CLI. + +### Calibrate Exoskeleton Teleoperator + +```bash +lerobot-calibrate \ + --teleop.type=unitree_g1 \ + --teleop.left_arm_config.port=/dev/ttyACM1 \ + --teleop.right_arm_config.port=/dev/ttyACM0 \ + --teleop.id=exo +``` + +### Teleoperate in Simulation + +```bash +lerobot-teleoperate \ + --robot.type=unitree_g1 \ + --robot.is_simulation=true \ + --teleop.type=unitree_g1 \ + --teleop.left_arm_config.port=/dev/ttyACM1 \ + --teleop.right_arm_config.port=/dev/ttyACM0 \ + --teleop.id=exo \ + --fps=100 +``` + +### Record Dataset in Simulation + +```bash +python -m lerobot.scripts.lerobot_record \ + --robot.type=unitree_g1 \ + --robot.is_simulation=true \ + --robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \ + --teleop.type=unitree_g1 \ + --teleop.left_arm_config.port=/dev/ttyACM1 \ + --teleop.right_arm_config.port=/dev/ttyACM0 \ + --teleop.id=exo \ + --dataset.repo_id=your-username/dataset-name \ + --dataset.single_task="Test" \ + --dataset.num_episodes=2 \ + --dataset.episode_time_s=5 \ + --dataset.reset_time_s=5 \ + --dataset.push_to_hub=true +``` + +Example simulation dataset: [nepyope/teleop_test_sim](https://huggingface.co/datasets/nepyope/teleop_test_sim) + +--- + +## Running on Real Robot + +Once the robot server is running on the G1 (see Part 3), you can teleoperate and record on the real robot. + +### Start the Camera Server + +On the robot, start the ZMQ image server: + +```bash +python src/lerobot/cameras/zmq/image_server.py +``` + +Keep this running in a separate terminal for camera streaming during recording. + +### Teleoperate Real Robot + +```bash +lerobot-teleoperate \ + --robot.type=unitree_g1 \ + --robot.is_simulation=false \ + --teleop.type=unitree_g1 \ + --teleop.left_arm_config.port=/dev/ttyACM1 \ + --teleop.right_arm_config.port=/dev/ttyACM0 \ + --teleop.id=exo \ + --fps=100 +``` + +### Record Dataset on Real Robot + +```bash +python -m lerobot.scripts.lerobot_record \ + --robot.type=unitree_g1 \ + --robot.is_simulation=false \ + --robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \ + --teleop.type=unitree_g1 \ + --teleop.left_arm_config.port=/dev/ttyACM1 \ + --teleop.right_arm_config.port=/dev/ttyACM0 \ + --teleop.id=exo \ + --dataset.repo_id=your-username/dataset-name \ + --dataset.single_task="Test" \ + --dataset.num_episodes=2 \ + --dataset.episode_time_s=5 \ + --dataset.reset_time_s=5 \ + --dataset.push_to_hub=true +``` + +**Note**: Update `server_address` to match your robot's camera server IP. + +Example real robot dataset: [nepyope/teleop_test_real](https://huggingface.co/datasets/nepyope/teleop_test_real) + +--- ## Additional Resources diff --git a/pyproject.toml b/pyproject.toml index ea2dfb4a2..210d70b6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,7 +111,11 @@ hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"] lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"] unitree_g1 = [ "pyzmq>=26.2.1,<28.0.0", - "onnxruntime>=1.16.0,<2.0.0" + "onnxruntime>=1.16.0,<2.0.0", + "pin>=3.0.0,<4.0.0", + "meshcat>=0.3.0,<0.4.0", + "matplotlib>=3.9.0,<4.0.0", + "casadi>=3.6.0,<4.0.0", ] reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"] kinematics = ["lerobot[placo-dep]"] diff --git a/src/lerobot/robots/unitree_g1/config_unitree_g1.py b/src/lerobot/robots/unitree_g1/config_unitree_g1.py index 0b163019d..1b81214a6 100644 --- a/src/lerobot/robots/unitree_g1/config_unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/config_unitree_g1.py @@ -65,3 +65,6 @@ class UnitreeG1Config(RobotConfig): # Cameras (ZMQ-based remote cameras) cameras: dict[str, CameraConfig] = field(default_factory=dict) + + # Compensates for gravity on the unitree's arms using the arm ik solver + gravity_compensation: bool = False diff --git a/src/lerobot/robots/unitree_g1/g1_utils.py b/src/lerobot/robots/unitree_g1/g1_utils.py index 3c41ee985..4e37bdcef 100644 --- a/src/lerobot/robots/unitree_g1/g1_utils.py +++ b/src/lerobot/robots/unitree_g1/g1_utils.py @@ -18,7 +18,7 @@ from enum import IntEnum # ruff: noqa: N801, N815 -NUM_MOTORS = 35 +NUM_MOTORS = 29 class G1_29_JointArmIndex(IntEnum): diff --git a/src/lerobot/robots/unitree_g1/robot_kinematic_processor.py b/src/lerobot/robots/unitree_g1/robot_kinematic_processor.py new file mode 100644 index 000000000..d086a9986 --- /dev/null +++ b/src/lerobot/robots/unitree_g1/robot_kinematic_processor.py @@ -0,0 +1,313 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import sys + +import numpy as np + +logger = logging.getLogger(__name__) +parent2_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(parent2_dir) + + +class WeightedMovingFilter: + def __init__(self, weights, data_size=14): + self._window_size = len(weights) + self._weights = np.array(weights) + self._data_size = data_size + self._filtered_data = np.zeros(self._data_size) + self._data_queue = [] + + def _apply_filter(self): + if len(self._data_queue) < self._window_size: + return self._data_queue[-1] + + data_array = np.array(self._data_queue) + temp_filtered_data = np.zeros(self._data_size) + for i in range(self._data_size): + temp_filtered_data[i] = np.convolve(data_array[:, i], self._weights, mode="valid")[-1] + + return temp_filtered_data + + def add_data(self, new_data): + assert len(new_data) == self._data_size + + if len(self._data_queue) > 0 and np.array_equal( + new_data, self._data_queue[-1] + ): # skip duplicate data + return + + if len(self._data_queue) >= self._window_size: + self._data_queue.pop(0) + + self._data_queue.append(new_data) + self._filtered_data = self._apply_filter() + + @property + def filtered_data(self): + return self._filtered_data + + +class G1_29_ArmIK: # noqa: N801 + def __init__(self, unit_test=False): + import casadi + import pinocchio as pin + from huggingface_hub import snapshot_download + from pinocchio import casadi as cpin + + self._pin = pin + np.set_printoptions(precision=5, suppress=True, linewidth=200) + + self.unit_test = unit_test + + self.repo_path = snapshot_download("lerobot/unitree-g1-mujoco") + urdf_path = os.path.join(self.repo_path, "assets", "g1_body29_hand14.urdf") + mesh_dir = os.path.join(self.repo_path, "assets") + + self.robot = self._pin.RobotWrapper.BuildFromURDF(urdf_path, mesh_dir) + + self.mixed_jointsToLockIDs = [ + "left_hip_pitch_joint", + "left_hip_roll_joint", + "left_hip_yaw_joint", + "left_knee_joint", + "left_ankle_pitch_joint", + "left_ankle_roll_joint", + "right_hip_pitch_joint", + "right_hip_roll_joint", + "right_hip_yaw_joint", + "right_knee_joint", + "right_ankle_pitch_joint", + "right_ankle_roll_joint", + "waist_yaw_joint", + "waist_roll_joint", + "waist_pitch_joint", + "left_hand_thumb_0_joint", + "left_hand_thumb_1_joint", + "left_hand_thumb_2_joint", + "left_hand_middle_0_joint", + "left_hand_middle_1_joint", + "left_hand_index_0_joint", + "left_hand_index_1_joint", + "right_hand_thumb_0_joint", + "right_hand_thumb_1_joint", + "right_hand_thumb_2_joint", + "right_hand_index_0_joint", + "right_hand_index_1_joint", + "right_hand_middle_0_joint", + "right_hand_middle_1_joint", + ] + + self.reduced_robot = self.robot.buildReducedRobot( + list_of_joints_to_lock=self.mixed_jointsToLockIDs, + reference_configuration=np.array([0.0] * self.robot.model.nq), + ) + + # Arm joint names in G1 motor order (G1_29_JointArmIndex) + self._arm_joint_names_g1 = [ + "left_shoulder_pitch_joint", + "left_shoulder_roll_joint", + "left_shoulder_yaw_joint", + "left_elbow_joint", + "left_wrist_roll_joint", + "left_wrist_pitch_joint", + "left_wrist_yaw_joint", + "right_shoulder_pitch_joint", + "right_shoulder_roll_joint", + "right_shoulder_yaw_joint", + "right_elbow_joint", + "right_wrist_roll_joint", + "right_wrist_pitch_joint", + "right_wrist_yaw_joint", + ] + # Pinocchio uses its own joint order in q; build index mapping. + self._arm_joint_names_pin = sorted( + self._arm_joint_names_g1, + key=lambda name: self.reduced_robot.model.idx_qs[self.reduced_robot.model.getJointId(name)], + ) + logger.info(f"Pinocchio arm joint order: {self._arm_joint_names_pin}") + self._arm_reorder_g1_to_pin = [ + self._arm_joint_names_g1.index(name) for name in self._arm_joint_names_pin + ] + # Inverse mapping to return tau in G1 motor order. + self._arm_reorder_pin_to_g1 = np.argsort(self._arm_reorder_g1_to_pin) + + self.reduced_robot.model.addFrame( + self._pin.Frame( + "L_ee", + self.reduced_robot.model.getJointId("left_wrist_yaw_joint"), + self._pin.SE3(np.eye(3), np.array([0.05, 0, 0]).T), + self._pin.FrameType.OP_FRAME, + ) + ) + + self.reduced_robot.model.addFrame( + self._pin.Frame( + "R_ee", + self.reduced_robot.model.getJointId("right_wrist_yaw_joint"), + self._pin.SE3(np.eye(3), np.array([0.05, 0, 0]).T), + self._pin.FrameType.OP_FRAME, + ) + ) + + # Creating Casadi models and data for symbolic computing + self.cmodel = cpin.Model(self.reduced_robot.model) + self.cdata = self.cmodel.createData() + + # Creating symbolic variables + self.cq = casadi.SX.sym("q", self.reduced_robot.model.nq, 1) + self.cTf_l = casadi.SX.sym("tf_l", 4, 4) + self.cTf_r = casadi.SX.sym("tf_r", 4, 4) + cpin.framesForwardKinematics(self.cmodel, self.cdata, self.cq) + + # Get the hand joint ID and define the error function + self.L_hand_id = self.reduced_robot.model.getFrameId("L_ee") + self.R_hand_id = self.reduced_robot.model.getFrameId("R_ee") + + self.translational_error = casadi.Function( + "translational_error", + [self.cq, self.cTf_l, self.cTf_r], + [ + casadi.vertcat( + self.cdata.oMf[self.L_hand_id].translation - self.cTf_l[:3, 3], + self.cdata.oMf[self.R_hand_id].translation - self.cTf_r[:3, 3], + ) + ], + ) + self.rotational_error = casadi.Function( + "rotational_error", + [self.cq, self.cTf_l, self.cTf_r], + [ + casadi.vertcat( + cpin.log3(self.cdata.oMf[self.L_hand_id].rotation @ self.cTf_l[:3, :3].T), + cpin.log3(self.cdata.oMf[self.R_hand_id].rotation @ self.cTf_r[:3, :3].T), + ) + ], + ) + + # Defining the optimization problem + self.opti = casadi.Opti() + self.var_q = self.opti.variable(self.reduced_robot.model.nq) + self.var_q_last = self.opti.parameter(self.reduced_robot.model.nq) # for smooth + self.param_tf_l = self.opti.parameter(4, 4) + self.param_tf_r = self.opti.parameter(4, 4) + self.translational_cost = casadi.sumsqr( + self.translational_error(self.var_q, self.param_tf_l, self.param_tf_r) + ) + self.rotation_cost = casadi.sumsqr( + self.rotational_error(self.var_q, self.param_tf_l, self.param_tf_r) + ) + self.regularization_cost = casadi.sumsqr(self.var_q) + self.smooth_cost = casadi.sumsqr(self.var_q - self.var_q_last) + + # Setting optimization constraints and goals + self.opti.subject_to( + self.opti.bounded( + self.reduced_robot.model.lowerPositionLimit, + self.var_q, + self.reduced_robot.model.upperPositionLimit, + ) + ) + self.opti.minimize( + 50 * self.translational_cost + + self.rotation_cost + + 0.02 * self.regularization_cost + + 0.1 * self.smooth_cost + ) + + opts = { + "ipopt": {"print_level": 0, "max_iter": 50, "tol": 1e-6}, + "print_time": False, # print or not + "calc_lam_p": False, # https://github.com/casadi/casadi/wiki/FAQ:-Why-am-I-getting-%22NaN-detected%22in-my-optimization%3F + } + self.opti.solver("ipopt", opts) + + self.init_data = np.zeros(self.reduced_robot.model.nq) + self.smooth_filter = WeightedMovingFilter(np.array([0.4, 0.3, 0.2, 0.1]), 14) + + def solve_ik(self, left_wrist, right_wrist, current_lr_arm_motor_q=None, current_lr_arm_motor_dq=None): + if current_lr_arm_motor_q is not None: + self.init_data = current_lr_arm_motor_q + self.opti.set_initial(self.var_q, self.init_data) + + self.opti.set_value(self.param_tf_l, left_wrist) + self.opti.set_value(self.param_tf_r, right_wrist) + self.opti.set_value(self.var_q_last, self.init_data) # for smooth + + try: + self.opti.solve() + + sol_q = self.opti.value(self.var_q) + self.smooth_filter.add_data(sol_q) + sol_q = self.smooth_filter.filtered_data + + if current_lr_arm_motor_dq is not None: + v = current_lr_arm_motor_dq * 0.0 + else: + v = (sol_q - self.init_data) * 0.0 + + self.init_data = sol_q + + sol_tauff = self._pin.rnea( + self.reduced_robot.model, + self.reduced_robot.data, + sol_q, + v, + np.zeros(self.reduced_robot.model.nv), + ) + + return sol_q, sol_tauff + + except Exception as e: + logger.error(f"ERROR in convergence, plotting debug info.{e}") + + sol_q = self.opti.debug.value(self.var_q) + self.smooth_filter.add_data(sol_q) + sol_q = self.smooth_filter.filtered_data + + if current_lr_arm_motor_dq is not None: + v = current_lr_arm_motor_dq * 0.0 + else: + v = (sol_q - self.init_data) * 0.0 + + self.init_data = sol_q + + logger.error( + f"sol_q:{sol_q} \nmotorstate: \n{current_lr_arm_motor_q} \nleft_pose: \n{left_wrist} \nright_pose: \n{right_wrist}" + ) + + return current_lr_arm_motor_q, np.zeros(self.reduced_robot.model.nv) + + def solve_tau(self, current_lr_arm_motor_q=None, current_lr_arm_motor_dq=None): + try: + q_g1 = np.array(current_lr_arm_motor_q, dtype=float) + if q_g1.shape[0] != len(self._arm_joint_names_g1): + raise ValueError(f"Expected {len(self._arm_joint_names_g1)} arm joints, got {q_g1.shape[0]}") + q_pin = q_g1[self._arm_reorder_g1_to_pin] + sol_tauff = self._pin.rnea( + self.reduced_robot.model, + self.reduced_robot.data, + q_pin, + np.zeros(self.reduced_robot.model.nv), + np.zeros(self.reduced_robot.model.nv), + ) + return sol_tauff[self._arm_reorder_pin_to_g1] + + except Exception as e: + logger.error(f"ERROR in convergence, plotting debug info.{e}") + return np.zeros(self.reduced_robot.model.nv) diff --git a/src/lerobot/robots/unitree_g1/unitree_g1.py b/src/lerobot/robots/unitree_g1/unitree_g1.py index fa6e0da85..01b4f330e 100644 --- a/src/lerobot/robots/unitree_g1/unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/unitree_g1.py @@ -27,7 +27,8 @@ import numpy as np from lerobot.cameras.utils import make_cameras_from_configs from lerobot.envs.factory import make_env from lerobot.processor import RobotAction, RobotObservation -from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex +from lerobot.robots.unitree_g1.g1_utils import G1_29_JointArmIndex, G1_29_JointIndex +from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK from ..robot import Robot from .config_unitree_g1 import UnitreeG1Config @@ -127,6 +128,8 @@ class UnitreeG1(Robot): self.subscribe_thread = None self.remote_controller = self.RemoteController() + self.arm_ik = G1_29_ArmIK() + def _subscribe_motor_state(self): # polls robot state @ 250Hz while not self._shutdown_event.is_set(): start_time = time.time() @@ -361,6 +364,20 @@ class UnitreeG1(Robot): self.msg.motor_cmd[motor.value].kd = self.kd[motor.value] self.msg.motor_cmd[motor.value].tau = 0 + if self.config.gravity_compensation: + # Build action_np from motor commands (arm joints are indices 15-28, local indices 0-13) + action_np = np.zeros(14) + arm_start_idx = G1_29_JointArmIndex.kLeftShoulderPitch.value # 15 + for joint in G1_29_JointArmIndex: + local_idx = joint.value - arm_start_idx + action_np[local_idx] = self.msg.motor_cmd[joint.value].q + tau = self.arm_ik.solve_tau(action_np) + + # Apply tau back to motor commands + for joint in G1_29_JointArmIndex: + local_idx = joint.value - arm_start_idx + self.msg.motor_cmd[joint.value].tau = tau[local_idx] + self.msg.crc = self.crc.Crc(self.msg) self.lowcmd_publisher.Write(self.msg) return action diff --git a/src/lerobot/scripts/lerobot_calibrate.py b/src/lerobot/scripts/lerobot_calibrate.py index 0f79e6aa2..2fa1b2a03 100644 --- a/src/lerobot/scripts/lerobot_calibrate.py +++ b/src/lerobot/scripts/lerobot_calibrate.py @@ -55,6 +55,7 @@ from lerobot.teleoperators import ( # noqa: F401 omx_leader, openarm_leader, so_leader, + unitree_g1, ) from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.utils import init_logging diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 4d334f38f..d621189e8 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -107,7 +107,7 @@ from lerobot.robots import ( # noqa: F401 openarm_follower, reachy2, so_follower, - unitree_g1, + unitree_g1 as unitree_g1_robot, ) from lerobot.teleoperators import ( # noqa: F401 Teleoperator, @@ -120,6 +120,7 @@ from lerobot.teleoperators import ( # noqa: F401 openarm_leader, reachy2_teleoperator, so_leader, + unitree_g1, ) from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop from lerobot.utils.constants import ACTION, OBS_STR diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index a415dd600..958bd00ef 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ b/src/lerobot/scripts/lerobot_teleoperate.py @@ -79,6 +79,7 @@ from lerobot.robots import ( # noqa: F401 openarm_follower, reachy2, so_follower, + unitree_g1 as unitree_g1_robot, ) from lerobot.teleoperators import ( # noqa: F401 Teleoperator, @@ -93,6 +94,7 @@ from lerobot.teleoperators import ( # noqa: F401 openarm_leader, reachy2_teleoperator, so_leader, + unitree_g1, ) from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.robot_utils import precise_sleep diff --git a/src/lerobot/teleoperators/unitree_g1/__init__.py b/src/lerobot/teleoperators/unitree_g1/__init__.py new file mode 100644 index 000000000..45955a0e2 --- /dev/null +++ b/src/lerobot/teleoperators/unitree_g1/__init__.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config_unitree_g1 import ExoskeletonArmPortConfig, UnitreeG1TeleoperatorConfig +from .exo_calib import ExoskeletonCalibration, ExoskeletonJointCalibration +from .exo_ik import ExoskeletonIKHelper +from .exo_serial import ExoskeletonArm +from .unitree_g1 import UnitreeG1Teleoperator diff --git a/src/lerobot/teleoperators/unitree_g1/config_unitree_g1.py b/src/lerobot/teleoperators/unitree_g1/config_unitree_g1.py new file mode 100644 index 000000000..66c4e7f31 --- /dev/null +++ b/src/lerobot/teleoperators/unitree_g1/config_unitree_g1.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from ..config import TeleoperatorConfig + + +@dataclass +class ExoskeletonArmPortConfig: + """Serial port configuration for individual exoskeleton arm.""" + + port: str = "" + baud_rate: int = 115200 + + +@TeleoperatorConfig.register_subclass("unitree_g1") +@dataclass +class UnitreeG1TeleoperatorConfig(TeleoperatorConfig): + left_arm_config: ExoskeletonArmPortConfig = field(default_factory=ExoskeletonArmPortConfig) + right_arm_config: ExoskeletonArmPortConfig = field(default_factory=ExoskeletonArmPortConfig) + + # Frozen joints (comma-separated joint names that won't be moved by IK) + frozen_joints: str = "" diff --git a/src/lerobot/teleoperators/unitree_g1/exo_calib.py b/src/lerobot/teleoperators/unitree_g1/exo_calib.py new file mode 100644 index 000000000..2927a1b55 --- /dev/null +++ b/src/lerobot/teleoperators/unitree_g1/exo_calib.py @@ -0,0 +1,446 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module handles calibration of hall effect sensors used in the exoskeleton. +Each joint has a pair of ADC channels outputting sin and cos values that trace an ellipse +as the joint rotates due to imprecision in magnet/sensor placement. We fit this ellipse to a unit circle, +and calculate arctan2 of the unit circle to get the joint angle. +We then store the ellipse parameters and the zero offset for each joint to be used at runtime. +""" + +import json +import logging +import time +from collections import deque +from dataclasses import dataclass, field +from pathlib import Path + +import numpy as np +import serial + +logger = logging.getLogger(__name__) + + +# exoskeleton joint names -> ADC channel pairs. TODO: add wrist pitch and wrist yaw +JOINTS = { + "shoulder_pitch": (0, 1), + "shoulder_yaw": (2, 3), + "shoulder_roll": (4, 5), + "elbow_flex": (6, 7), + "wrist_roll": (14, 15), +} + + +@dataclass +class ExoskeletonJointCalibration: + name: str # joint name + center_fit: list[float] # center of the ellipse + T: list[list[float]] # 2x2 transformation matrix + zero_offset: float = 0.0 # angle at neutral pose + + +@dataclass +class ExoskeletonCalibration: + """Full calibration data for an exoskeleton arm.""" + + version: int = 2 + side: str = "" + adc_max: int = 2**12 - 1 + joints: list[ExoskeletonJointCalibration] = field(default_factory=list) + + def to_dict(self) -> dict: + return { + "version": self.version, + "side": self.side, + "adc_max": self.adc_max, + "joints": [ + { + "name": j.name, + "center_fit": j.center_fit, + "T": j.T, + "zero_offset": j.zero_offset, + } + for j in self.joints + ], + } + + @classmethod + def from_dict(cls, data: dict) -> "ExoskeletonCalibration": + joints = [ + ExoskeletonJointCalibration( + name=j["name"], + center_fit=j["center_fit"], + T=j["T"], + zero_offset=j.get("zero_offset", 0.0), + ) + for j in data.get("joints", []) + ] + return cls( + version=data.get("version", 2), + side=data.get("side", ""), + adc_max=data.get("adc_max", 2**12 - 1), + joints=joints, + ) + + +@dataclass(frozen=True) +class CalibParams: + fit_every: float = 0.15 + min_fit_points: int = 60 + fit_window: int = 900 + max_fit_points: int = 300 + trim_low: float = 0.05 + trim_high: float = 0.95 + median_window: int = 5 + history: int = 3500 + draw_hz: float = 120.0 + sample_count: int = 50 + + +def normalize_angle(angle: float) -> float: + while angle > np.pi: + angle -= 2 * np.pi + while angle < -np.pi: + angle += 2 * np.pi + return angle + + +def joint_z_and_angle(raw16: list[int], j: ExoskeletonJointCalibration) -> tuple[np.ndarray, float]: + """ + Applies calibration to each joint: raw → centered → ellipse-to-circle → angle. + """ + pair = JOINTS[j.name] + s, c = raw16[pair[0]], raw16[pair[1]] # get sin and cos + p = np.array([float(c) - (2**12 - 1) / 2, float(s) - (2**12 - 1) / 2]) # center the raw values + z = np.asarray(j.T) @ ( + p - np.asarray(j.center_fit) + ) # center the ellipse and invert the transformation matrix to get unit circle coords + ang = float(np.arctan2(z[1], z[0])) - j.zero_offset # calculate the anvgle and apply the zero offset + return z, normalize_angle(-ang) # ensure range is [-pi, pi] + + +def exo_raw_to_angles(raw16: list[int], calib: ExoskeletonCalibration) -> dict[str, float]: + """Convert raw sensor readings to joint angles using calibration.""" + return {j.name: joint_z_and_angle(raw16, j)[1] for j in calib.joints} + + +def run_exo_calibration( + ser: serial.Serial, + side: str, + save_path: Path, + params: CalibParams | None = None, +) -> ExoskeletonCalibration: + """ + Run interactive calibration for an exoskeleton arm. + """ + try: + import cv2 + import matplotlib.pyplot as plt + except ImportError as e: + raise ImportError( + "Calibration requires matplotlib and opencv-python. " + "Install with: pip install matplotlib opencv-python" + ) from e + + from .exo_serial import read_raw_from_serial + + params = params or CalibParams() + joint_list = list(JOINTS.items()) # Convert dict to list for indexing + logger.info(f"Starting calibration for {side} exoskeleton arm") + + def running_median(win: deque) -> float: + return float(np.median(np.fromiter(win, dtype=float))) + + def read_joint_point(raw16: list[int], pair: tuple[int, int]): + s, c = raw16[pair[0]], raw16[pair[1]] + return float(c) - (2**12 - 1) / 2, float(s) - (2**12 - 1) / 2, float(s), float(c) + + def select_fit_subset(xs, ys): + """Select and filter points for ellipse fitting. Trims outliers by radius and downsamples.""" + n = min(params.fit_window, len(xs)) + if n <= 0: + return None, None + x = np.asarray(list(xs)[-n:], dtype=float) # most recent n samples + y = np.asarray(list(ys)[-n:], dtype=float) + r = np.sqrt(x * x + y * y) # radius from origin + if len(r) >= 20: + lo, hi = np.quantile(r, params.trim_low), np.quantile(r, params.trim_high) # outlier bounds + keep = (r >= lo) & (r <= hi) + x, y = x[keep], y[keep] # remove outliers + if len(x) > params.max_fit_points: + idx = np.linspace(0, len(x) - 1, params.max_fit_points).astype(int) # downsample evenly + x, y = x[idx], y[idx] + return x, y + + def fit_ellipse_opencv(x, y): + """Fit ellipse to (x,y) points using OpenCV. Returns center, axes, rotation matrix, and outline.""" + x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float) + if len(x) < 5: + return None + pts = np.stack([x, y], axis=1).astype(np.float32).reshape(-1, 1, 2) + try: + (xc, yc), (w, h), angle_deg = cv2.fitEllipse(pts) # returns center, axes, rotation in degrees + except cv2.error: + return None + a, b = float(w) * 0.5, float(h) * 0.5 # get ellipse major and minor semi-axes + phi = np.deg2rad(float(angle_deg)) # to rad + if b > a: # ensure major axis is a + a, b = b, a + phi += np.pi / 2.0 + if not np.isfinite(a) or not np.isfinite(b) or a <= 1e-6 or b <= 1e-6: + return None + cp, sp = float(np.cos(phi)), float(np.sin(phi)) # + rot = np.array([[cp, -sp], [sp, cp]], dtype=float) # 2x2 rotation matrix + center = np.array([float(xc), float(yc)], dtype=float) # offset vector + tt = np.linspace(0, 2 * np.pi, 360) + outline = (rot @ np.stack([a * np.cos(tt), b * np.sin(tt)])).T + center # for viz + return {"center": center, "a": a, "b": b, "R": rot, "ex": outline[:, 0], "ey": outline[:, 1]} + + # Setup matplotlib + plt.ion() + fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(12, 6)) + ax0.set_xlabel("cos - center") + ax0.set_ylabel("sin - center") + ax0.grid(True, alpha=0.25) + ax0.set_aspect("equal", adjustable="box") + ax1.set_title("Unit circle + angle") + ax1.set_xlabel("x") + ax1.set_ylabel("y") + ax1.grid(True, alpha=0.25) + ax1.set_aspect("equal", adjustable="box") + tt = np.linspace(0, 2 * np.pi, 360) + ax1.plot(np.cos(tt), np.sin(tt), "k-", linewidth=1) + ax0.set_xlim(-2200, 2200) + ax0.set_ylim(-2200, 2200) + ax1.set_xlim(-1.4, 1.4) + ax1.set_ylim(-1.4, 1.4) + + sc0 = ax0.scatter([], [], s=6, animated=True) + (ell_line,) = ax0.plot([], [], "r-", linewidth=2, animated=True) + sc1 = ax1.scatter([], [], s=6, animated=True) + (radius_line,) = ax1.plot([], [], "g-", linewidth=2, animated=True) + angle_text = ax1.text( + 0.02, 0.98, "", transform=ax1.transAxes, va="top", ha="left", fontsize=12, animated=True + ) + + fig.canvas.draw() + bg0 = fig.canvas.copy_from_bbox(ax0.bbox) + bg1 = fig.canvas.copy_from_bbox(ax1.bbox) + + # State + joints_out = [] + joint_idx = 0 + phase = "ellipse" + advance_requested = False + zero_samples = [] + + def on_key(event): + nonlocal advance_requested + if event.key in ("n", "N", "enter", " "): + advance_requested = True + + fig.canvas.mpl_connect("key_press_event", on_key) + + def reset_state(): + return { + "xs": deque(maxlen=params.history), + "ys": deque(maxlen=params.history), + "xu": deque(maxlen=params.history), + "yu": deque(maxlen=params.history), + "win_s": deque(maxlen=params.median_window), + "win_c": deque(maxlen=params.median_window), + "ellipse_cache": None, + "T": None, + "center_fit": None, + "have_transform": False, + "latest_z": None, + "last_fit": 0.0, + } + + state = reset_state() + last_draw = 0.0 + name, pair = joint_list[joint_idx] + fig.canvas.manager.set_window_title(f"[{joint_idx + 1}/{len(joint_list)}] {name} - ELLIPSE") + ax0.set_title(f"{name} raw (filtered)") + logger.info(f"[{joint_idx + 1}/{len(joint_list)}] Calibrating {name}") + logger.info("Step 1: Move joint around to map ellipse, then press 'n'") + + try: + while plt.fignum_exists(fig.number): + name, pair = joint_list[joint_idx] + + # Handles calibration GUI state: ellipse → zero_pose → next joint -> ellipse -> ... + if phase == "ellipse" and advance_requested and state["have_transform"]: + joints_out.append( + { + "name": name, + "center_fit": state["center_fit"].tolist(), + "T": state["T"].tolist(), + } + ) + logger.info(f" -> Ellipse saved for {name}") + phase, zero_samples, advance_requested = "zero_pose", [], False + fig.canvas.manager.set_window_title(f"[{joint_idx + 1}/{len(joint_list)}] {name} - ZERO POSE") + ax0.set_title(f"{name} - hold zero pose") + fig.canvas.draw() + bg0, bg1 = fig.canvas.copy_from_bbox(ax0.bbox), fig.canvas.copy_from_bbox(ax1.bbox) + logger.info(f"Step 2: Hold {name} in zero position, then press 'n'") + + elif phase == "ellipse" and advance_requested and not state["have_transform"]: + logger.info(" (Need valid fit first - keep moving the joint)") + advance_requested = False + + elif phase == "zero_pose" and advance_requested: + if len(zero_samples) >= params.sample_count: + zero_offset = float(np.mean(zero_samples[-params.sample_count :])) + joints_out[-1]["zero_offset"] = zero_offset + logger.info(f" -> {name} zero: {zero_offset:+.3f} rad ({np.degrees(zero_offset):+.1f}°)") + joint_idx += 1 + advance_requested = False + + if joint_idx >= len(joint_list): + # All joints done + calib = ExoskeletonCalibration( + version=2, + side=side, + adc_max=2**12 - 1, + joints=[ + ExoskeletonJointCalibration( + name=j["name"], + center_fit=j["center_fit"], + T=j["T"], + zero_offset=j.get("zero_offset", 0.0), + ) + for j in joints_out + ], + ) + save_path.parent.mkdir(parents=True, exist_ok=True) + with open(save_path, "w") as f: + json.dump(calib.to_dict(), f, indent=2) + logger.info(f"Saved calibration to {save_path}") + logger.info("Calibration complete!") + plt.close(fig) + return calib + + # Next joint + phase, state = "ellipse", reset_state() + name, pair = joint_list[joint_idx] + fig.canvas.manager.set_window_title( + f"[{joint_idx + 1}/{len(joint_list)}] {name} - ELLIPSE" + ) + ax0.set_title(f"{name} raw (filtered)") + fig.canvas.draw() + bg0, bg1 = fig.canvas.copy_from_bbox(ax0.bbox), fig.canvas.copy_from_bbox(ax1.bbox) + logger.info(f"[{joint_idx + 1}/{len(joint_list)}] Calibrating {name}") + logger.info("Step 1: Move joint around to map ellipse, then press 'n'") + else: + logger.info( + f" (Collecting samples: {len(zero_samples)}/{params.sample_count} - hold still)" + ) + advance_requested = False + + # Read sensor + raw16 = read_raw_from_serial(ser) + if raw16 is not None: + x_raw, y_raw, s_raw, c_raw = read_joint_point(raw16, pair) + + if phase == "ellipse": + if state["have_transform"]: + z = state["T"] @ (np.array([x_raw, y_raw]) - state["center_fit"]) + state["xu"].append(float(z[0])) + state["yu"].append(float(z[1])) + state["latest_z"] = (float(z[0]), float(z[1])) + state["win_s"].append(s_raw) + state["win_c"].append(c_raw) + if len(state["win_s"]) >= max(3, params.median_window): + state["ys"].append(running_median(state["win_s"]) - (2**12 - 1) / 2) + state["xs"].append(running_median(state["win_c"]) - (2**12 - 1) / 2) + else: + jdata = joints_out[-1] + z = np.array(jdata["T"]) @ (np.array([x_raw, y_raw]) - np.array(jdata["center_fit"])) + zero_samples.append(float(np.arctan2(z[1], z[0]))) + state["latest_z"] = (float(z[0]), float(z[1])) + + # Ellipse fitting + t = time.time() + if ( + phase == "ellipse" + and (t - state["last_fit"]) >= params.fit_every + and len(state["xs"]) >= params.min_fit_points + ): + xfit, yfit = select_fit_subset(state["xs"], state["ys"]) + if xfit is not None and len(xfit) >= params.min_fit_points: + fit = fit_ellipse_opencv(xfit, yfit) + if fit is not None: + state["center_fit"] = fit["center"] + state["T"] = np.diag([1.0 / fit["a"], 1.0 / fit["b"]]) @ fit["R"].T + state["ellipse_cache"] = (fit["ex"], fit["ey"]) + state["have_transform"] = True + state["last_fit"] = t + + # Drawing + if (t - last_draw) >= 1.0 / params.draw_hz: + fig.canvas.restore_region(bg0) + fig.canvas.restore_region(bg1) + + if phase == "ellipse": + sc0.set_offsets(np.c_[state["xs"], state["ys"]] if state["xs"] else np.empty((0, 2))) + ax0.draw_artist(sc0) + ell_line.set_data(*state["ellipse_cache"] if state["ellipse_cache"] else ([], [])) + ax0.draw_artist(ell_line) + sc1.set_offsets(np.c_[state["xu"], state["yu"]] if state["xu"] else np.empty((0, 2))) + ax1.draw_artist(sc1) + if state["latest_z"]: + zx, zy = state["latest_z"] + radius_line.set_data([0.0, zx], [0.0, zy]) + ang = float(np.arctan2(zy, zx)) + angle_text.set_text( + f"angle: {ang:+.3f} rad ({np.degrees(ang):+.1f}°)\nmove {name}, press 'n' to advance" + ) + else: + radius_line.set_data([], []) + angle_text.set_text("(waiting for fit)") + else: + sc0.set_offsets(np.empty((0, 2))) + ax0.draw_artist(sc0) + ell_line.set_data([], []) + ax0.draw_artist(ell_line) + if state["latest_z"]: + zx, zy = state["latest_z"] + sc1.set_offsets([[zx, zy]]) + radius_line.set_data([0.0, zx], [0.0, zy]) + ang = float(np.arctan2(zy, zx)) + angle_text.set_text( + f"Zero pose for {name}\nangle: {ang:+.3f} rad\nsamples: {len(zero_samples)}/{params.sample_count}\nhold still, press 'n'" + ) + else: + sc1.set_offsets(np.empty((0, 2))) + radius_line.set_data([], []) + angle_text.set_text("(waiting for data)") + ax1.draw_artist(sc1) + + ax1.draw_artist(radius_line) + ax1.draw_artist(angle_text) + fig.canvas.blit(ax0.bbox) + fig.canvas.blit(ax1.bbox) + fig.canvas.flush_events() + last_draw = t + + plt.pause(0.001) + + finally: + plt.close(fig) diff --git a/src/lerobot/teleoperators/unitree_g1/exo_ik.py b/src/lerobot/teleoperators/unitree_g1/exo_ik.py new file mode 100644 index 000000000..92519540f --- /dev/null +++ b/src/lerobot/teleoperators/unitree_g1/exo_ik.py @@ -0,0 +1,353 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +IK helper for exoskeleton-to-G1 teleoperation. We map Exoskeleton joint angles to end-effector pose in world frame, +visualizing the result in meshcat after calibration. +""" + +import logging +import os +from dataclasses import dataclass + +import numpy as np + +from lerobot.robots.unitree_g1.g1_utils import G1_29_JointArmIndex +from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK + +from .exo_calib import JOINTS + +logger = logging.getLogger(__name__) + + +def _frame_id(model, name: str) -> int | None: + try: + fid = model.getFrameId(name) + return fid if 0 <= fid < model.nframes else None + except Exception: + return None + + +@dataclass +class ArmCfg: + side: str # "left" | "right" + urdf: str # exo_left.urdf / exo_right.urdf + root: str # "exo_left" / "exo_right" + g1_ee: str # "l_ee" / "r_ee" + offset: np.ndarray # world offset for viz + target + marker_prefix: str # "left" / "right" + + +class Markers: + """Creates meshcat visualization primitives, showing end-effector frames of exoskeleton and G1""" + + def __init__(self, viewer): + self.v = viewer + + def sphere(self, path: str, r: float, rgba: tuple[float, float, float, float]): + import meshcat.geometry as mg + + c = (int(rgba[0] * 255) << 16) | (int(rgba[1] * 255) << 8) | int(rgba[2] * 255) + self.v[path].set_object( + mg.Sphere(r), + mg.MeshPhongMaterial(color=c, opacity=rgba[3], transparent=rgba[3] < 1.0), + ) + + def axes(self, path: str, axis_len: float = 0.1, axis_w: int = 6): + import meshcat.geometry as mg + + pts = np.array( + [[0, 0, 0], [axis_len, 0, 0], [0, 0, 0], [0, axis_len, 0], [0, 0, 0], [0, 0, axis_len]], + dtype=np.float32, + ).T + cols = np.array( + [[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]], + dtype=np.float32, + ).T + self.v[path].set_object( + mg.LineSegments( + mg.PointsGeometry(position=pts, color=cols), + mg.LineBasicMaterial(linewidth=axis_w, vertexColors=True), + ) + ) + + def tf(self, path: str, mat: np.ndarray): + self.v[path].set_transform(mat) + + +class ExoskeletonIKHelper: + """ + - Loads G1 robot and exoskeleton URDF models via Pinocchio + - Computes forward kinematics on exoskeleton to get end-effector poses + - Solves inverse kinematics on G1 to match those poses + - Provides meshcat visualization showing both robots and targets + + Args: + frozen_joints: List of G1 joint names to exclude from IK (kept at neutral). + """ + + def __init__(self, frozen_joints: list[str] | None = None): + try: + import pinocchio as pin + except ImportError as e: + raise ImportError("ik mode needs pinocchio: pip install pin") from e + + self.pin = pin + self.frozen_joints = frozen_joints or [] + + self.g1_ik = G1_29_ArmIK() + self.robot_g1 = self.g1_ik.reduced_robot + self.robot_g1.data = self.robot_g1.model.createData() + self.q_g1 = pin.neutral(self.robot_g1.model) + + assets_dir = os.path.join(self.g1_ik.repo_path, "assets") + + self.frozen_idx = self._frozen_joint_indices() + + self.arms = [ + ArmCfg( + side="left", + urdf=os.path.join(assets_dir, "exo_left.urdf"), + root="exo_left", + g1_ee="L_ee", + offset=np.array([0.6, 0.3, 0.0]), + marker_prefix="left", + ), + ArmCfg( + side="right", + urdf=os.path.join(assets_dir, "exo_right.urdf"), + root="exo_right", + g1_ee="R_ee", + offset=np.array([0.6, -0.3, 0.0]), + marker_prefix="right", + ), + ] + + self.exo = {} # side -> pin.RobotWrapper + self.q_exo = {} # side -> q + self.ee_id_exo = {} # side -> frame id + self.qmap = {} # side -> {joint_name: q_idx} + self.ee_id_g1 = {} # side -> frame id + + self._load_exo_models(assets_dir) + for a in self.arms: + self.ee_id_g1[a.side] = _frame_id(self.robot_g1.model, a.g1_ee) + + self.viewer = None + self.markers: Markers | None = None + self.viz_g1 = None + self.viz_exo = {} # side -> viz + + def _frozen_joint_indices(self) -> dict[str, int]: + out = {} + m = self.robot_g1.model + for name in self.frozen_joints: + if name in m.names: + jid = m.getJointId(name) + out[name] = m.idx_qs[jid] + logger.info(f"freezing joint: {name} (q_idx={out[name]})") + return out + + def _find_exo_ee(self, model, ee_name: str = "ee") -> int: + ee = _frame_id(model, ee_name) + if ee is not None: + return ee + for fid in reversed(range(model.nframes)): + if model.frames[fid].type == self.pin.FrameType.BODY: + return fid + return 0 + + def _build_joint_map(self, robot) -> dict[str, int]: + m = robot.model + return {n: m.idx_qs[m.getJointId(n)] for n in JOINTS if n in m.names} + + def _load_exo_models(self, assets_dir: str): + pin = self.pin + for a in self.arms: + if not os.path.exists(a.urdf): + logger.warning(f"{a.side} exo urdf not found: {a.urdf}") + continue + r = pin.RobotWrapper.BuildFromURDF(a.urdf, assets_dir) + self.exo[a.side] = r + self.q_exo[a.side] = pin.neutral(r.model) + self.ee_id_exo[a.side] = self._find_exo_ee(r.model) + self.qmap[a.side] = self._build_joint_map(r) + logger.info(f"loaded {a.side} exo urdf: {a.urdf}") + + def init_visualization(self): + """ + Creates a browser-based visualization of exoskeleton and G1 robot, + highlighting end-effector frames and target positions. + """ + try: + from pinocchio.visualize import MeshcatVisualizer + except ImportError as e: + logger.warning(f"meshcat viz unavailable: {e}") + return + + # g1 + self.viz_g1 = MeshcatVisualizer( + self.robot_g1.model, self.robot_g1.collision_model, self.robot_g1.visual_model + ) + self.viz_g1.initViewer(open=True) + self.viz_g1.loadViewerModel("g1") + self.viz_g1.display(self.q_g1) + + self.viewer = self.viz_g1.viewer + self.markers = Markers(self.viewer) + + # exos + for a in self.arms: + if a.side not in self.exo: + continue + r = self.exo[a.side] + v = MeshcatVisualizer(r.model, r.collision_model, r.visual_model) + v.initViewer(open=False) + v.viewer = self.viewer + v.loadViewerModel(a.root) + offset_tf = np.eye(4) + offset_tf[:3, 3] = a.offset + self.viewer[a.root].set_transform(offset_tf) + v.display(self.q_exo[a.side]) + self.viz_exo[a.side] = v + + # markers + for a in self.arms: + p = a.marker_prefix + self.markers.sphere(f"markers/{p}_exo_ee", 0.012, (0.2, 1.0, 0.2, 0.9)) + self.markers.sphere(f"markers/{p}_g1_ee", 0.015, (1.0, 0.2, 0.2, 0.9)) + self.markers.sphere(f"markers/{p}_ik_target", 0.015, (0.1, 0.3, 1.0, 0.9)) + self.markers.axes(f"markers/{p}_exo_axes", 0.06) + self.markers.axes(f"markers/{p}_g1_axes", 0.08) + + logger.info(f"meshcat viz initialized: {self.viewer.url()}") + print(f"\nmeshcat url: {self.viewer.url()}\n") + + def _fk_target_world(self, side: str, angles: dict[str, float]) -> np.ndarray | None: + """returns wrist frame target to be used for G1 IK in 4x4 homogeneous transform. Takes offset into account.""" + if side not in self.exo or not angles: + return None + + pin = self.pin + q = self.q_exo[side] + qmap = self.qmap[side] + + for name, ang in angles.items(): + idx = qmap.get(name) + if idx is not None: + q[idx] = float(ang) + + r = self.exo[side] + pin.forwardKinematics(r.model, r.data, q) + pin.updateFramePlacements(r.model, r.data) + + ee = r.data.oMf[self.ee_id_exo[side]] + target = np.eye(4) + target[:3, :3] = ee.rotation + # offset gets applied in world space + cfg = next(a for a in self.arms if a.side == side) + target[:3, 3] = cfg.offset + ee.translation + return target + + def update_visualization(self): + if self.viewer is None or self.markers is None: + return + + pin = self.pin + + # g1 + if self.viz_g1 is not None: + self.viz_g1.display(self.q_g1) + pin.forwardKinematics(self.robot_g1.model, self.robot_g1.data, self.q_g1) + pin.updateFramePlacements(self.robot_g1.model, self.robot_g1.data) + + for a in self.arms: + fid = self.ee_id_g1.get(a.side) + if fid is None: + continue + ee_tf = self.robot_g1.data.oMf[fid].homogeneous + p = a.marker_prefix + self.markers.tf(f"markers/{p}_g1_ee", ee_tf) + self.markers.tf(f"markers/{p}_g1_axes", ee_tf) + + # exos + for a in self.arms: + side = a.side + v = self.viz_exo.get(side) + if v is None: + continue + + v.display(self.q_exo[side]) + r = self.exo[side] + pin.forwardKinematics(r.model, r.data, self.q_exo[side]) + pin.updateFramePlacements(r.model, r.data) + + ee = r.data.oMf[self.ee_id_exo[side]] + world_tf = (pin.SE3(np.eye(3), a.offset) * ee).homogeneous + p = a.marker_prefix + self.markers.tf(f"markers/{p}_exo_ee", world_tf) + self.markers.tf(f"markers/{p}_exo_axes", world_tf) + + target_tf = np.eye(4) + target_tf[:3, :3] = ee.rotation + target_tf[:3, 3] = a.offset + ee.translation + self.markers.tf(f"markers/{p}_ik_target", target_tf) + + def compute_g1_joints_from_exo( + self, + left_angles: dict[str, float], + right_angles: dict[str, float], + ) -> dict[str, float]: + """ + Performs FK on exoskeleton to get end-effector poses in world frame, + after which it solves IK on G1 to return joint angles matching those poses in G1 motor order. + """ + pin = self.pin + + targets = { + "left": self._fk_target_world("left", left_angles), + "right": self._fk_target_world("right", right_angles), + } + + # fallback to current g1 ee pose if missing target + pin.forwardKinematics(self.robot_g1.model, self.robot_g1.data, self.q_g1) + pin.updateFramePlacements(self.robot_g1.model, self.robot_g1.data) + + for a in self.arms: + if targets[a.side] is not None: + continue + fid = self.ee_id_g1.get(a.side) + if fid is not None: + targets[a.side] = self.robot_g1.data.oMf[fid].homogeneous + + if targets["left"] is None or targets["right"] is None: + logger.warning("missing ik targets, returning current pose") + return {} + + frozen_vals = {n: self.q_g1[i] for n, i in self.frozen_idx.items()} + + self.q_g1, _ = self.g1_ik.solve_ik( + targets["left"], targets["right"], current_lr_arm_motor_q=self.q_g1 + ) + + for n, i in self.frozen_idx.items(): + self.q_g1[i] = frozen_vals[n] + + return { + f"{j.name}.q": float(self.q_g1[i]) + for i, j in enumerate(G1_29_JointArmIndex) + if i < len(self.q_g1) + } diff --git a/src/lerobot/teleoperators/unitree_g1/exo_serial.py b/src/lerobot/teleoperators/unitree_g1/exo_serial.py new file mode 100644 index 000000000..1211c57cc --- /dev/null +++ b/src/lerobot/teleoperators/unitree_g1/exo_serial.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +from dataclasses import dataclass +from pathlib import Path + +import serial + +from .exo_calib import ExoskeletonCalibration, exo_raw_to_angles, run_exo_calibration + +logger = logging.getLogger(__name__) + + +def parse_raw16(line: bytes) -> list[int] | None: + try: + parts = line.decode("utf-8", errors="ignore").split() + if len(parts) < 16: + return None + return [int(x) for x in parts[:16]] + except Exception: + return None + + +def read_raw_from_serial(ser) -> list[int] | None: + """Read latest sample from serial; if buffer is backed up, keep only the newest.""" + last = None + while ser.in_waiting > 0: + b = ser.readline() + if not b: + break + raw16 = parse_raw16(b) + if raw16 is not None: + last = raw16 + if last is None: + b = ser.readline() + if b: + last = parse_raw16(b) + return last + + +@dataclass +class ExoskeletonArm: + port: str + calibration_fpath: Path + side: str + baud_rate: int = 115200 + + _ser: serial.Serial | None = None + calibration: ExoskeletonCalibration | None = None + + def __post_init__(self): + if self.calibration_fpath.is_file(): + self._load_calibration() + + @property + def is_connected(self) -> bool: + return self._ser is not None and getattr(self._ser, "is_open", False) + + @property + def is_calibrated(self) -> bool: + return self.calibration is not None + + def connect(self, calibrate: bool = True) -> None: + if self.is_connected: + return + try: + self._ser = serial.Serial(self.port, self.baud_rate, timeout=0.02) + self._ser.reset_input_buffer() + logger.info(f"connected: {self.port}") + except serial.SerialException as e: + raise ConnectionError(f"failed to connect to {self.port}: {e}") from e + + if calibrate and not self.is_calibrated: + self.calibrate() + + def disconnect(self) -> None: + if self._ser: + try: + self._ser.close() + finally: + self._ser = None + + def _load_calibration(self) -> None: + try: + data = json.loads(self.calibration_fpath.read_text()) + self.calibration = ExoskeletonCalibration.from_dict(data) + logger.info(f"loaded calibration: {self.calibration_fpath}") + except Exception as e: + logger.warning(f"failed to load calibration: {e}") + + def read_raw(self) -> list[int] | None: + if not self._ser: + return None + return read_raw_from_serial(self._ser) + + def get_angles(self) -> dict[str, float]: + if not self.calibration: + raise RuntimeError("exoskeleton not calibrated") + raw = self.read_raw() + return {} if raw is None else exo_raw_to_angles(raw, self.calibration) + + def calibrate(self) -> None: + ser = self._ser + self.calibration = run_exo_calibration(ser, self.side, self.calibration_fpath) diff --git a/src/lerobot/teleoperators/unitree_g1/unitree_g1.py b/src/lerobot/teleoperators/unitree_g1/unitree_g1.py new file mode 100644 index 000000000..3779d83ec --- /dev/null +++ b/src/lerobot/teleoperators/unitree_g1/unitree_g1.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from functools import cached_property + +from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex +from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, TELEOPERATORS + +from ..teleoperator import Teleoperator +from .config_unitree_g1 import UnitreeG1TeleoperatorConfig +from .exo_ik import ExoskeletonIKHelper +from .exo_serial import ExoskeletonArm + +logger = logging.getLogger(__name__) + + +class UnitreeG1Teleoperator(Teleoperator): + """ + Bimanual exoskeleton arms teleoperator for Unitree G1 arms. + + Uses inverse kinematics: exoskeleton FK computes end-effector pose, + G1 IK solves for joint angles. + """ + + config_class = UnitreeG1TeleoperatorConfig + name = "unitree_g1" + + def __init__(self, config: UnitreeG1TeleoperatorConfig): + super().__init__(config) + self.config = config + + # Setup calibration directory + self.calibration_dir = ( + config.calibration_dir + if config.calibration_dir + else HF_LEROBOT_CALIBRATION / TELEOPERATORS / self.name + ) + self.calibration_dir.mkdir(parents=True, exist_ok=True) + + left_id = f"{config.id}_left" if config.id else "left" + right_id = f"{config.id}_right" if config.id else "right" + + # Create exoskeleton arm instances + self.left_arm = ExoskeletonArm( + port=config.left_arm_config.port, + baud_rate=config.left_arm_config.baud_rate, + calibration_fpath=self.calibration_dir / f"{left_id}.json", + side="left", + ) + self.right_arm = ExoskeletonArm( + port=config.right_arm_config.port, + baud_rate=config.right_arm_config.baud_rate, + calibration_fpath=self.calibration_dir / f"{right_id}.json", + side="right", + ) + + self.ik_helper: ExoskeletonIKHelper | None = None + + @cached_property + def action_features(self) -> dict[str, type]: + return {f"{name}.q": float for name in self._g1_joint_names} + + @cached_property + def feedback_features(self) -> dict[str, type]: + return {} + + @property + def is_connected(self) -> bool: + return self.left_arm.is_connected and self.right_arm.is_connected + + @property + def is_calibrated(self) -> bool: + return self.left_arm.is_calibrated and self.right_arm.is_calibrated + + def connect(self, calibrate: bool = True) -> None: + self.left_arm.connect(calibrate) + self.right_arm.connect(calibrate) + + frozen_joints = [j.strip() for j in self.config.frozen_joints.split(",") if j.strip()] + self.ik_helper = ExoskeletonIKHelper(frozen_joints=frozen_joints) + logger.info("IK helper initialized") + + def calibrate(self) -> None: + if not self.left_arm.is_calibrated: + logger.info("Starting calibration for left arm...") + self.left_arm.calibrate() + else: + logger.info("Left arm already calibrated. Skipping.") + + if not self.right_arm.is_calibrated: + logger.info("Starting calibration for right arm...") + self.right_arm.calibrate() + else: + logger.info("Right arm already calibrated. Skipping.") + + logger.info("Starting visualization to verify calibration...") + self.run_visualization_loop() + + def configure(self) -> None: + pass + + def get_action(self) -> dict[str, float]: + left_angles = self.left_arm.get_angles() + right_angles = self.right_arm.get_angles() + return self.ik_helper.compute_g1_joints_from_exo(left_angles, right_angles) + + def send_feedback(self, feedback: dict[str, float]) -> None: + raise NotImplementedError("Exoskeleton arms do not support feedback") + + def disconnect(self) -> None: + self.left_arm.disconnect() + self.right_arm.disconnect() + + def run_visualization_loop(self): + """Run interactive Meshcat visualization loop to verify tracking.""" + if self.ik_helper is None: + frozen_joints = [j.strip() for j in self.config.frozen_joints.split(",") if j.strip()] + self.ik_helper = ExoskeletonIKHelper(frozen_joints=frozen_joints) + + self.ik_helper.init_visualization() + + print("\n" + "=" * 60) + print("Visualization running! Move the exoskeletons to test tracking.") + print("Press Ctrl+C to exit.") + print("=" * 60 + "\n") + + try: + while True: + left_angles = self.left_arm.get_angles() + right_angles = self.right_arm.get_angles() + + self.ik_helper.compute_g1_joints_from_exo(left_angles, right_angles) + self.ik_helper.update_visualization() + + time.sleep(0.01) + + except KeyboardInterrupt: + print("\n\nVisualization stopped.") + + @cached_property + def _g1_joint_names(self) -> list[str]: + return [joint.name for joint in G1_29_JointIndex] diff --git a/src/lerobot/teleoperators/utils.py b/src/lerobot/teleoperators/utils.py index 8f6bbc787..3b42d294e 100644 --- a/src/lerobot/teleoperators/utils.py +++ b/src/lerobot/teleoperators/utils.py @@ -75,6 +75,10 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator": from .homunculus import HomunculusArm return HomunculusArm(config) + elif config.type == "unitree_g1": + from .unitree_g1 import UnitreeG1Teleoperator + + return UnitreeG1Teleoperator(config) elif config.type == "bi_so_leader": from .bi_so_leader import BiSOLeader From 4483184875e0c283066ba304e2404638f8aa803e Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 28 Jan 2026 17:25:57 +0100 Subject: [PATCH 015/131] feat(robots): add bi manual openarm follower and leader (#2835) * fix(motors): cleanup imports + fix signatures * feat(motors): add damiao canbus + multiple fixes * fix(motors): address comments -> last_state + different gains + sleep * refactor(motors): reduce duplicated code + adressed some comments in the PR * chore(motors): better timeouts * tests(motors): damiao test and imports * chore(deps): fix space * feat(robot): add openarm leader Co-authored-by: Pepijn * feat(robot): add openarm follower Co-authored-by: Pepijn * refactor(robot): remove mechanical compensations and double arm assumption + rename * chore(robots): remove left arm references * refactor(teleop): multiple improvements to leader * refactor(teleop): multiple improvements to leader * feat(robots): add open arm to util CLI * chore(robot): add alias openarm * Apply suggestions from code review Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Signed-off-by: Steven Palma * chore(motors): remove normalization tables damiao * fix(motors): imports and signatures * feat(motors): add motor_type_str + recv_id to motor class and _get_motor_recv_id raises if no motor_obj.recv_id * chore(motors): remove normalize from base motor class and damaio * tests(motors): remove bad tests (to be replaced) * chore(motors): updated import check * fix(robots): open arm mirrored config for joint limits * chore(motors): update position_kd gain values * chore(robots): set to 0 if openarm is calibrated at connect time * chore(robots): remove macos in open arm as can doesn't support it * chore(robots): update for motor_type_str in Motor class * chore(robots): no default value for can port in open arms * feat(robots): add bi manual openarm follower and leader * use constant for kp and kd range and check responses in mit_control_batch() * Add docs on setting up canbus and use damiao otor bus, also add lerobot_setup_can.py and log if there is not response from a write command * precommit format * supress bandit as these are intentional cli commands * fix setup-can * add test * skip test in ci * nit precommit * update doc example * dont import can for tests * remove comment * Add openarms docs * format * update purchase link * can to none if nit availabl;e * add canfd option in bus * make handshake logic similar to lerobot-can * type hint * type check * add temp teleop test * remove script * mock class * mock class * ignore linter * pre-commit * Add command for bimanual openarm * fix import * fix import leader * fix import draccus --------- Signed-off-by: Steven Palma Co-authored-by: Pepijn Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> --- docs/source/openarm.mdx | 18 ++ .../robots/bi_openarm_follower/__init__.py | 20 ++ .../bi_openarm_follower.py | 175 ++++++++++++++++++ .../config_bi_openarm_follower.py | 30 +++ .../robots/openarm_follower/__init__.py | 4 +- .../config_openarm_follower.py | 11 +- src/lerobot/robots/utils.py | 4 + src/lerobot/scripts/lerobot_calibrate.py | 2 + .../scripts/lerobot_find_joint_limits.py | 2 + src/lerobot/scripts/lerobot_record.py | 2 + src/lerobot/scripts/lerobot_replay.py | 1 + src/lerobot/scripts/lerobot_teleoperate.py | 2 + .../bi_openarm_leader/__init__.py | 20 ++ .../bi_openarm_leader/bi_openarm_leader.py | 131 +++++++++++++ .../config_bi_openarm_leader.py | 30 +++ .../teleoperators/openarm_leader/__init__.py | 4 +- .../openarm_leader/config_openarm_leader.py | 11 +- src/lerobot/teleoperators/utils.py | 4 + 18 files changed, 461 insertions(+), 10 deletions(-) create mode 100644 src/lerobot/robots/bi_openarm_follower/__init__.py create mode 100644 src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py create mode 100644 src/lerobot/robots/bi_openarm_follower/config_bi_openarm_follower.py create mode 100644 src/lerobot/teleoperators/bi_openarm_leader/__init__.py create mode 100644 src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py create mode 100644 src/lerobot/teleoperators/bi_openarm_leader/config_bi_openarm_leader.py diff --git a/docs/source/openarm.mdx b/docs/source/openarm.mdx index 661808749..cd4ace912 100644 --- a/docs/source/openarm.mdx +++ b/docs/source/openarm.mdx @@ -174,6 +174,24 @@ lerobot-teleoperate \ --teleop.id=my_leader ``` +### Bimanual Teleoperation + +To teleoperate a bimanual OpenArm setup with two leader and two follower arms: + +```bash +lerobot-teleoperate \ + --robot.type=bi_openarm_follower \ + --robot.left_arm_config.port=can0 \ + --robot.left_arm_config.side=left \ + --robot.right_arm_config.port=can1 \ + --robot.right_arm_config.side=right \ + --robot.id=my_bimanual_follower \ + --teleop.type=bi_openarm_leader \ + --teleop.left_arm_config.port=can2 \ + --teleop.right_arm_config.port=can3 \ + --teleop.id=my_bimanual_leader +``` + ### Recording Data To record a dataset during teleoperation: diff --git a/src/lerobot/robots/bi_openarm_follower/__init__.py b/src/lerobot/robots/bi_openarm_follower/__init__.py new file mode 100644 index 000000000..b1dcce431 --- /dev/null +++ b/src/lerobot/robots/bi_openarm_follower/__init__.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .bi_openarm_follower import BiOpenArmFollower +from .config_bi_openarm_follower import BiOpenArmFollowerConfig + +__all__ = ["BiOpenArmFollower", "BiOpenArmFollowerConfig"] diff --git a/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py b/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py new file mode 100644 index 000000000..466eb07e5 --- /dev/null +++ b/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from functools import cached_property + +from lerobot.processor import RobotAction, RobotObservation +from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig + +from ..robot import Robot +from .config_bi_openarm_follower import BiOpenArmFollowerConfig + +logger = logging.getLogger(__name__) + + +class BiOpenArmFollower(Robot): + """ + Bimanual OpenArm Follower Arms + """ + + config_class = BiOpenArmFollowerConfig + name = "bi_openarm_follower" + + def __init__(self, config: BiOpenArmFollowerConfig): + super().__init__(config) + self.config = config + + left_arm_config = OpenArmFollowerConfig( + id=f"{config.id}_left" if config.id else None, + calibration_dir=config.calibration_dir, + port=config.left_arm_config.port, + disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect, + max_relative_target=config.left_arm_config.max_relative_target, + cameras=config.left_arm_config.cameras, + side=config.left_arm_config.side, + can_interface=config.left_arm_config.can_interface, + use_can_fd=config.left_arm_config.use_can_fd, + can_bitrate=config.left_arm_config.can_bitrate, + can_data_bitrate=config.left_arm_config.can_data_bitrate, + motor_config=config.left_arm_config.motor_config, + position_kd=config.left_arm_config.position_kd, + position_kp=config.left_arm_config.position_kp, + joint_limits=config.left_arm_config.joint_limits, + ) + + right_arm_config = OpenArmFollowerConfig( + id=f"{config.id}_right" if config.id else None, + calibration_dir=config.calibration_dir, + port=config.right_arm_config.port, + disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect, + max_relative_target=config.right_arm_config.max_relative_target, + cameras=config.right_arm_config.cameras, + side=config.right_arm_config.side, + can_interface=config.right_arm_config.can_interface, + use_can_fd=config.right_arm_config.use_can_fd, + can_bitrate=config.right_arm_config.can_bitrate, + can_data_bitrate=config.right_arm_config.can_data_bitrate, + motor_config=config.right_arm_config.motor_config, + position_kd=config.right_arm_config.position_kd, + position_kp=config.right_arm_config.position_kp, + joint_limits=config.right_arm_config.joint_limits, + ) + + self.left_arm = OpenArmFollower(left_arm_config) + self.right_arm = OpenArmFollower(right_arm_config) + + # Only for compatibility with other parts of the codebase that expect a `robot.cameras` attribute + self.cameras = {**self.left_arm.cameras, **self.right_arm.cameras} + + @property + def _motors_ft(self) -> dict[str, type]: + left_arm_motors_ft = self.left_arm._motors_ft + right_arm_motors_ft = self.right_arm._motors_ft + + return { + **{f"left_{k}": v for k, v in left_arm_motors_ft.items()}, + **{f"right_{k}": v for k, v in right_arm_motors_ft.items()}, + } + + @property + def _cameras_ft(self) -> dict[str, tuple]: + left_arm_cameras_ft = self.left_arm._cameras_ft + right_arm_cameras_ft = self.right_arm._cameras_ft + + return { + **{f"left_{k}": v for k, v in left_arm_cameras_ft.items()}, + **{f"right_{k}": v for k, v in right_arm_cameras_ft.items()}, + } + + @cached_property + def observation_features(self) -> dict[str, type | tuple]: + return {**self._motors_ft, **self._cameras_ft} + + @cached_property + def action_features(self) -> dict[str, type]: + return self._motors_ft + + @property + def is_connected(self) -> bool: + return self.left_arm.is_connected and self.right_arm.is_connected + + def connect(self, calibrate: bool = True) -> None: + self.left_arm.connect(calibrate) + self.right_arm.connect(calibrate) + + @property + def is_calibrated(self) -> bool: + return self.left_arm.is_calibrated and self.right_arm.is_calibrated + + def calibrate(self) -> None: + self.left_arm.calibrate() + self.right_arm.calibrate() + + def configure(self) -> None: + self.left_arm.configure() + self.right_arm.configure() + + def setup_motors(self) -> None: + raise NotImplementedError( + "Motor ID configuration is typically done via manufacturer tools for CAN motors." + ) + + def get_observation(self) -> RobotObservation: + obs_dict = {} + + # Add "left_" prefix + left_obs = self.left_arm.get_observation() + obs_dict.update({f"left_{key}": value for key, value in left_obs.items()}) + + # Add "right_" prefix + right_obs = self.right_arm.get_observation() + obs_dict.update({f"right_{key}": value for key, value in right_obs.items()}) + + return obs_dict + + def send_action( + self, + action: RobotAction, + custom_kp: dict[str, float] | None = None, + custom_kd: dict[str, float] | None = None, + ) -> RobotAction: + # Remove "left_" prefix + left_action = { + key.removeprefix("left_"): value for key, value in action.items() if key.startswith("left_") + } + # Remove "right_" prefix + right_action = { + key.removeprefix("right_"): value for key, value in action.items() if key.startswith("right_") + } + + sent_action_left = self.left_arm.send_action(left_action, custom_kp, custom_kd) + sent_action_right = self.right_arm.send_action(right_action, custom_kp, custom_kd) + + # Add prefixes back + prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()} + prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()} + + return {**prefixed_sent_action_left, **prefixed_sent_action_right} + + def disconnect(self): + self.left_arm.disconnect() + self.right_arm.disconnect() diff --git a/src/lerobot/robots/bi_openarm_follower/config_bi_openarm_follower.py b/src/lerobot/robots/bi_openarm_follower/config_bi_openarm_follower.py new file mode 100644 index 000000000..9d11f7b4e --- /dev/null +++ b/src/lerobot/robots/bi_openarm_follower/config_bi_openarm_follower.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from lerobot.robots.openarm_follower import OpenArmFollowerConfigBase + +from ..config import RobotConfig + + +@RobotConfig.register_subclass("bi_openarm_follower") +@dataclass +class BiOpenArmFollowerConfig(RobotConfig): + """Configuration class for Bi OpenArm Follower robots.""" + + left_arm_config: OpenArmFollowerConfigBase + right_arm_config: OpenArmFollowerConfigBase diff --git a/src/lerobot/robots/openarm_follower/__init__.py b/src/lerobot/robots/openarm_follower/__init__.py index 1eb0d9fc7..217432fd5 100644 --- a/src/lerobot/robots/openarm_follower/__init__.py +++ b/src/lerobot/robots/openarm_follower/__init__.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .config_openarm_follower import OpenArmFollowerConfig +from .config_openarm_follower import OpenArmFollowerConfig, OpenArmFollowerConfigBase from .openarm_follower import OpenArmFollower -__all__ = ["OpenArmFollower", "OpenArmFollowerConfig"] +__all__ = ["OpenArmFollower", "OpenArmFollowerConfig", "OpenArmFollowerConfigBase"] diff --git a/src/lerobot/robots/openarm_follower/config_openarm_follower.py b/src/lerobot/robots/openarm_follower/config_openarm_follower.py index af95b6395..88d81fd50 100644 --- a/src/lerobot/robots/openarm_follower/config_openarm_follower.py +++ b/src/lerobot/robots/openarm_follower/config_openarm_follower.py @@ -43,10 +43,9 @@ RIGHT_DEFAULT_JOINTS_LIMITS: dict[str, tuple[float, float]] = { } -@RobotConfig.register_subclass("openarm_follower") @dataclass -class OpenArmFollowerConfig(RobotConfig): - """Configuration for the OpenArms follower robot with Damiao motors.""" +class OpenArmFollowerConfigBase: + """Base configuration for the OpenArms follower robot with Damiao motors.""" # CAN interfaces - one per arm # arm CAN interface (e.g., "can1") @@ -115,3 +114,9 @@ class OpenArmFollowerConfig(RobotConfig): "gripper": (-5.0, 0.0), } ) + + +@RobotConfig.register_subclass("openarm_follower") +@dataclass +class OpenArmFollowerConfig(RobotConfig, OpenArmFollowerConfigBase): + pass diff --git a/src/lerobot/robots/utils.py b/src/lerobot/robots/utils.py index e0c76cab3..92da597f1 100644 --- a/src/lerobot/robots/utils.py +++ b/src/lerobot/robots/utils.py @@ -64,6 +64,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot: from .openarm_follower import OpenArmFollower return OpenArmFollower(config) + elif config.type == "bi_openarm_follower": + from .bi_openarm_follower import BiOpenArmFollower + + return BiOpenArmFollower(config) elif config.type == "mock_robot": from tests.mocks.mock_robot import MockRobot diff --git a/src/lerobot/scripts/lerobot_calibrate.py b/src/lerobot/scripts/lerobot_calibrate.py index 2fa1b2a03..eb3df6872 100644 --- a/src/lerobot/scripts/lerobot_calibrate.py +++ b/src/lerobot/scripts/lerobot_calibrate.py @@ -36,6 +36,7 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, + bi_openarm_follower, bi_so_follower, hope_jr, koch_follower, @@ -48,6 +49,7 @@ from lerobot.robots import ( # noqa: F401 from lerobot.teleoperators import ( # noqa: F401 Teleoperator, TeleoperatorConfig, + bi_openarm_leader, bi_so_leader, homunculus, koch_leader, diff --git a/src/lerobot/scripts/lerobot_find_joint_limits.py b/src/lerobot/scripts/lerobot_find_joint_limits.py index d928dc5cd..082d11803 100644 --- a/src/lerobot/scripts/lerobot_find_joint_limits.py +++ b/src/lerobot/scripts/lerobot_find_joint_limits.py @@ -44,6 +44,7 @@ import numpy as np from lerobot.model.kinematics import RobotKinematics from lerobot.robots import ( # noqa: F401 RobotConfig, + bi_openarm_follower, bi_so_follower, koch_follower, make_robot_from_config, @@ -53,6 +54,7 @@ from lerobot.robots import ( # noqa: F401 ) from lerobot.teleoperators import ( # noqa: F401 TeleoperatorConfig, + bi_openarm_leader, bi_so_leader, gamepad, koch_leader, diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index d621189e8..0b39e6fff 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -98,6 +98,7 @@ from lerobot.processor.rename_processor import rename_stats from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, + bi_openarm_follower, bi_so_follower, earthrover_mini_plus, hope_jr, @@ -112,6 +113,7 @@ from lerobot.robots import ( # noqa: F401 from lerobot.teleoperators import ( # noqa: F401 Teleoperator, TeleoperatorConfig, + bi_openarm_leader, bi_so_leader, homunculus, koch_leader, diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index c3bc3d766..5717dffb6 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ b/src/lerobot/scripts/lerobot_replay.py @@ -53,6 +53,7 @@ from lerobot.processor import ( from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, + bi_openarm_follower, bi_so_follower, earthrover_mini_plus, hope_jr, diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index 958bd00ef..b6aa4a750 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ b/src/lerobot/scripts/lerobot_teleoperate.py @@ -70,6 +70,7 @@ from lerobot.processor import ( from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, + bi_openarm_follower, bi_so_follower, earthrover_mini_plus, hope_jr, @@ -84,6 +85,7 @@ from lerobot.robots import ( # noqa: F401 from lerobot.teleoperators import ( # noqa: F401 Teleoperator, TeleoperatorConfig, + bi_openarm_leader, bi_so_leader, gamepad, homunculus, diff --git a/src/lerobot/teleoperators/bi_openarm_leader/__init__.py b/src/lerobot/teleoperators/bi_openarm_leader/__init__.py new file mode 100644 index 000000000..fe728b826 --- /dev/null +++ b/src/lerobot/teleoperators/bi_openarm_leader/__init__.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .bi_openarm_leader import BiOpenArmLeader +from .config_bi_openarm_leader import BiOpenArmLeaderConfig + +__all__ = ["BiOpenArmLeader", "BiOpenArmLeaderConfig"] diff --git a/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py b/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py new file mode 100644 index 000000000..c4383293f --- /dev/null +++ b/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from functools import cached_property + +from lerobot.processor import RobotAction +from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfig + +from ..openarm_leader import OpenArmLeader +from ..teleoperator import Teleoperator +from .config_bi_openarm_leader import BiOpenArmLeaderConfig + +logger = logging.getLogger(__name__) + + +class BiOpenArmLeader(Teleoperator): + """ + Bimanual OpenArm Leader Arms + """ + + config_class = BiOpenArmLeaderConfig + name = "bi_openarm_leader" + + def __init__(self, config: BiOpenArmLeaderConfig): + super().__init__(config) + self.config = config + + left_arm_config = OpenArmLeaderConfig( + id=f"{config.id}_left" if config.id else None, + calibration_dir=config.calibration_dir, + port=config.left_arm_config.port, + can_interface=config.left_arm_config.can_interface, + use_can_fd=config.left_arm_config.use_can_fd, + can_bitrate=config.left_arm_config.can_bitrate, + can_data_bitrate=config.left_arm_config.can_data_bitrate, + motor_config=config.left_arm_config.motor_config, + manual_control=config.left_arm_config.manual_control, + position_kd=config.left_arm_config.position_kd, + position_kp=config.left_arm_config.position_kp, + ) + + right_arm_config = OpenArmLeaderConfig( + id=f"{config.id}_right" if config.id else None, + calibration_dir=config.calibration_dir, + port=config.right_arm_config.port, + can_interface=config.right_arm_config.can_interface, + use_can_fd=config.right_arm_config.use_can_fd, + can_bitrate=config.right_arm_config.can_bitrate, + can_data_bitrate=config.right_arm_config.can_data_bitrate, + motor_config=config.right_arm_config.motor_config, + manual_control=config.right_arm_config.manual_control, + position_kd=config.right_arm_config.position_kd, + position_kp=config.right_arm_config.position_kp, + ) + + self.left_arm = OpenArmLeader(left_arm_config) + self.right_arm = OpenArmLeader(right_arm_config) + + @cached_property + def action_features(self) -> dict[str, type]: + left_arm_features = self.left_arm.action_features + right_arm_features = self.right_arm.action_features + + return { + **{f"left_{k}": v for k, v in left_arm_features.items()}, + **{f"right_{k}": v for k, v in right_arm_features.items()}, + } + + @cached_property + def feedback_features(self) -> dict[str, type]: + return {} + + @property + def is_connected(self) -> bool: + return self.left_arm.is_connected and self.right_arm.is_connected + + def connect(self, calibrate: bool = True) -> None: + self.left_arm.connect(calibrate) + self.right_arm.connect(calibrate) + + @property + def is_calibrated(self) -> bool: + return self.left_arm.is_calibrated and self.right_arm.is_calibrated + + def calibrate(self) -> None: + self.left_arm.calibrate() + self.right_arm.calibrate() + + def configure(self) -> None: + self.left_arm.configure() + self.right_arm.configure() + + def setup_motors(self) -> None: + raise NotImplementedError( + "Motor ID configuration is typically done via manufacturer tools for CAN motors." + ) + + def get_action(self) -> RobotAction: + action_dict = {} + + # Add "left_" prefix + left_action = self.left_arm.get_action() + action_dict.update({f"left_{key}": value for key, value in left_action.items()}) + + # Add "right_" prefix + right_action = self.right_arm.get_action() + action_dict.update({f"right_{key}": value for key, value in right_action.items()}) + + return action_dict + + def send_feedback(self, feedback: dict[str, float]) -> None: + # TODO: Implement force feedback + raise NotImplementedError + + def disconnect(self) -> None: + self.left_arm.disconnect() + self.right_arm.disconnect() diff --git a/src/lerobot/teleoperators/bi_openarm_leader/config_bi_openarm_leader.py b/src/lerobot/teleoperators/bi_openarm_leader/config_bi_openarm_leader.py new file mode 100644 index 000000000..39fc90add --- /dev/null +++ b/src/lerobot/teleoperators/bi_openarm_leader/config_bi_openarm_leader.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfigBase + +from ..config import TeleoperatorConfig + + +@TeleoperatorConfig.register_subclass("bi_openarm_leader") +@dataclass +class BiOpenArmLeaderConfig(TeleoperatorConfig): + """Configuration class for Bi OpenArm Follower robots.""" + + left_arm_config: OpenArmLeaderConfigBase + right_arm_config: OpenArmLeaderConfigBase diff --git a/src/lerobot/teleoperators/openarm_leader/__init__.py b/src/lerobot/teleoperators/openarm_leader/__init__.py index 1493317fe..172cf8228 100644 --- a/src/lerobot/teleoperators/openarm_leader/__init__.py +++ b/src/lerobot/teleoperators/openarm_leader/__init__.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .config_openarm_leader import OpenArmLeaderConfig +from .config_openarm_leader import OpenArmLeaderConfig, OpenArmLeaderConfigBase from .openarm_leader import OpenArmLeader -__all__ = ["OpenArmLeader", "OpenArmLeaderConfig"] +__all__ = ["OpenArmLeader", "OpenArmLeaderConfig", "OpenArmLeaderConfigBase"] diff --git a/src/lerobot/teleoperators/openarm_leader/config_openarm_leader.py b/src/lerobot/teleoperators/openarm_leader/config_openarm_leader.py index c53169b0a..4b12fe730 100644 --- a/src/lerobot/teleoperators/openarm_leader/config_openarm_leader.py +++ b/src/lerobot/teleoperators/openarm_leader/config_openarm_leader.py @@ -19,10 +19,9 @@ from dataclasses import dataclass, field from ..config import TeleoperatorConfig -@TeleoperatorConfig.register_subclass("openarm_leader") @dataclass -class OpenArmLeaderConfig(TeleoperatorConfig): - """Configuration for the OpenArms leader/teleoperator with Damiao motors.""" +class OpenArmLeaderConfigBase: + """Base configuration for the OpenArms leader/teleoperator with Damiao motors.""" # CAN interfaces - one per arm # Arm CAN interface (e.g., "can3") @@ -68,3 +67,9 @@ class OpenArmLeaderConfig(TeleoperatorConfig): default_factory=lambda: [240.0, 240.0, 240.0, 240.0, 24.0, 31.0, 25.0, 16.0] ) position_kd: list[float] = field(default_factory=lambda: [3.0, 3.0, 3.0, 3.0, 0.2, 0.2, 0.2, 0.2]) + + +@TeleoperatorConfig.register_subclass("openarm_leader") +@dataclass +class OpenArmLeaderConfig(TeleoperatorConfig, OpenArmLeaderConfigBase): + pass diff --git a/src/lerobot/teleoperators/utils.py b/src/lerobot/teleoperators/utils.py index 3b42d294e..16454d5ad 100644 --- a/src/lerobot/teleoperators/utils.py +++ b/src/lerobot/teleoperators/utils.py @@ -91,6 +91,10 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator": from .openarm_leader import OpenArmLeader return OpenArmLeader(config) + elif config.type == "bi_openarm_leader": + from .bi_openarm_leader import BiOpenArmLeader + + return BiOpenArmLeader(config) else: try: return cast("Teleoperator", make_device_from_device_class(config)) From 3409ef0dc2fc948468ec573a1e47f8ab7747cee2 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 29 Jan 2026 04:07:47 -0600 Subject: [PATCH 016/131] refactor(cameras): cameras API extension (#2808) * feat(cameras): add new read_latest() method * fix(cameras): fix threading bug + clear state * refactor(cameras): multiple improvements * feat(camera): add context manager to camera base class * chore(camera): slight modifications to opencv * test(cameras): update opencv tests according to the changes * refactor(cameras): reflect desing changes to realsense + deal with depth * test(cameras): fix realsense tests accordingly to new changes * refactor(cameras): update reachymini and zmq accordingly * chore: wrap resource sensitive examples into a try/finally * test(cameras): add test for new read_latest * test(cameras): fix problem with image artifact in opencv tests * test(cameras): fix test_read_latest_high_frequency expectations * Apply suggestions from code review 1 Co-authored-by: Caroline Pascal Signed-off-by: Steven Palma * chore(cameras): address feedback * feat(cameras): add max_age_ms check in read_latest * test(cameras): fix read_latest tests * chore(redundancies): removing redundancies in Reachy 2 camera class * fix(warmup): replacing the arbitrary time.sleep in by an actual warmup in the RealSense camera class * chore(format): formatting latest changes * chore(warning): adding a "to be implemented" warning for read_latest() in Camera base class * chore(warning): making read_latest() warning message shorter and clearer --------- Signed-off-by: Steven Palma Co-authored-by: Caroline Pascal --- examples/backward_compatibility/replay.py | 31 +-- examples/lekiwi/evaluate.py | 88 +++---- examples/lekiwi/record.py | 89 +++---- examples/lekiwi/replay.py | 32 +-- examples/phone_to_so100/evaluate.py | 85 ++++--- examples/phone_to_so100/record.py | 85 ++++--- examples/phone_to_so100/replay.py | 42 ++-- examples/so100_to_so100_EE/evaluate.py | 85 ++++--- examples/so100_to_so100_EE/record.py | 86 ++++--- examples/so100_to_so100_EE/replay.py | 41 +-- src/lerobot/cameras/camera.py | 100 ++++++-- src/lerobot/cameras/opencv/camera_opencv.py | 166 ++++++++---- .../cameras/reachy2_camera/reachy2_camera.py | 67 +++-- .../cameras/realsense/camera_realsense.py | 212 +++++++++++----- src/lerobot/cameras/zmq/camera_zmq.py | 238 ++++++++++++++---- src/lerobot/cameras/zmq/configuration_zmq.py | 1 + src/lerobot/scripts/lerobot_calibrate.py | 7 +- src/lerobot/scripts/lerobot_replay.py | 29 +-- tests/cameras/test_opencv.py | 184 +++++++++----- tests/cameras/test_reachy2_camera.py | 38 +++ tests/cameras/test_realsense.py | 114 +++++---- 21 files changed, 1179 insertions(+), 641 deletions(-) diff --git a/examples/backward_compatibility/replay.py b/examples/backward_compatibility/replay.py index ed78d016f..8de5ba197 100644 --- a/examples/backward_compatibility/replay.py +++ b/examples/backward_compatibility/replay.py @@ -81,24 +81,25 @@ def replay(cfg: ReplayConfig): actions = dataset.hf_dataset.select_columns(ACTION) robot.connect() - log_say("Replaying episode", cfg.play_sounds, blocking=True) - for idx in range(dataset.num_frames): - start_episode_t = time.perf_counter() + try: + log_say("Replaying episode", cfg.play_sounds, blocking=True) + for idx in range(dataset.num_frames): + start_episode_t = time.perf_counter() - action_array = actions[idx][ACTION] - action = {} - for i, name in enumerate(dataset.features[ACTION]["names"]): - key = f"{name.removeprefix('main_')}.pos" - action[key] = action_array[i].item() + action_array = actions[idx][ACTION] + action = {} + for i, name in enumerate(dataset.features[ACTION]["names"]): + key = f"{name.removeprefix('main_')}.pos" + action[key] = action_array[i].item() - action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90) - action["elbow_flex.pos"] -= 90 - robot.send_action(action) + action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90) + action["elbow_flex.pos"] -= 90 + robot.send_action(action) - dt_s = time.perf_counter() - start_episode_t - precise_sleep(max(1 / dataset.fps - dt_s, 0.0)) - - robot.disconnect() + dt_s = time.perf_counter() - start_episode_t + precise_sleep(max(1 / dataset.fps - dt_s, 0.0)) + finally: + robot.disconnect() if __name__ == "__main__": diff --git a/examples/lekiwi/evaluate.py b/examples/lekiwi/evaluate.py index 2f7f9f95f..a3144a442 100644 --- a/examples/lekiwi/evaluate.py +++ b/examples/lekiwi/evaluate.py @@ -78,40 +78,24 @@ def main(): listener, events = init_keyboard_listener() init_rerun(session_name="lekiwi_evaluate") - if not robot.is_connected: - raise ValueError("Robot is not connected!") + try: + if not robot.is_connected: + raise ValueError("Robot is not connected!") - print("Starting evaluate loop...") - recorded_episodes = 0 - while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: - log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}") + print("Starting evaluate loop...") + recorded_episodes = 0 + while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: + log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}") - # Main record loop - record_loop( - robot=robot, - events=events, - fps=FPS, - policy=policy, - preprocessor=preprocessor, # Pass the pre and post policy processors - postprocessor=postprocessor, - dataset=dataset, - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, - teleop_action_processor=teleop_action_processor, - robot_action_processor=robot_action_processor, - robot_observation_processor=robot_observation_processor, - ) - - # Reset the environment if not stopping or re-recording - if not events["stop_recording"] and ( - (recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"] - ): - log_say("Reset the environment") + # Main record loop record_loop( robot=robot, events=events, fps=FPS, + policy=policy, + preprocessor=preprocessor, # Pass the pre and post policy processors + postprocessor=postprocessor, + dataset=dataset, control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, @@ -120,24 +104,42 @@ def main(): robot_observation_processor=robot_observation_processor, ) - if events["rerecord_episode"]: - log_say("Re-record episode") - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and ( + (recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"] + ): + log_say("Reset the environment") + record_loop( + robot=robot, + events=events, + fps=FPS, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, + ) - # Save episode - dataset.save_episode() - recorded_episodes += 1 + if events["rerecord_episode"]: + log_say("Re-record episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue - # Clean up - log_say("Stop recording") - robot.disconnect() - listener.stop() + # Save episode + dataset.save_episode() + recorded_episodes += 1 - dataset.finalize() - dataset.push_to_hub() + finally: + # Clean up + log_say("Stop recording") + robot.disconnect() + listener.stop() + + dataset.finalize() + dataset.push_to_hub() if __name__ == "__main__": diff --git a/examples/lekiwi/record.py b/examples/lekiwi/record.py index 18b9f857e..9292157f7 100644 --- a/examples/lekiwi/record.py +++ b/examples/lekiwi/record.py @@ -74,40 +74,23 @@ def main(): listener, events = init_keyboard_listener() init_rerun(session_name="lekiwi_record") - if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected: - raise ValueError("Robot or teleop is not connected!") + try: + if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected: + raise ValueError("Robot or teleop is not connected!") - print("Starting record loop...") - recorded_episodes = 0 - while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: - log_say(f"Recording episode {recorded_episodes}") + print("Starting record loop...") + recorded_episodes = 0 + while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: + log_say(f"Recording episode {recorded_episodes}") - # Main record loop - record_loop( - robot=robot, - events=events, - fps=FPS, - dataset=dataset, - teleop=[leader_arm, keyboard], - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, - teleop_action_processor=teleop_action_processor, - robot_action_processor=robot_action_processor, - robot_observation_processor=robot_observation_processor, - ) - - # Reset the environment if not stopping or re-recording - if not events["stop_recording"] and ( - (recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"] - ): - log_say("Reset the environment") + # Main record loop record_loop( robot=robot, events=events, fps=FPS, + dataset=dataset, teleop=[leader_arm, keyboard], - control_time_s=RESET_TIME_SEC, + control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, teleop_action_processor=teleop_action_processor, @@ -115,26 +98,44 @@ def main(): robot_observation_processor=robot_observation_processor, ) - if events["rerecord_episode"]: - log_say("Re-record episode") - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and ( + (recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"] + ): + log_say("Reset the environment") + record_loop( + robot=robot, + events=events, + fps=FPS, + teleop=[leader_arm, keyboard], + control_time_s=RESET_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, + ) - # Save episode - dataset.save_episode() - recorded_episodes += 1 + if events["rerecord_episode"]: + log_say("Re-record episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue - # Clean up - log_say("Stop recording") - robot.disconnect() - leader_arm.disconnect() - keyboard.disconnect() - listener.stop() + # Save episode + dataset.save_episode() + recorded_episodes += 1 + finally: + # Clean up + log_say("Stop recording") + robot.disconnect() + leader_arm.disconnect() + keyboard.disconnect() + listener.stop() - dataset.finalize() - dataset.push_to_hub() + dataset.finalize() + dataset.push_to_hub() if __name__ == "__main__": diff --git a/examples/lekiwi/replay.py b/examples/lekiwi/replay.py index 872dacf27..cf89aea16 100644 --- a/examples/lekiwi/replay.py +++ b/examples/lekiwi/replay.py @@ -42,25 +42,27 @@ def main(): # Connect to the robot robot.connect() - if not robot.is_connected: - raise ValueError("Robot is not connected!") + try: + if not robot.is_connected: + raise ValueError("Robot is not connected!") - print("Starting replay loop...") - log_say(f"Replaying episode {EPISODE_IDX}") - for idx in range(len(episode_frames)): - t0 = time.perf_counter() + print("Starting replay loop...") + log_say(f"Replaying episode {EPISODE_IDX}") + for idx in range(len(episode_frames)): + t0 = time.perf_counter() - # Get recorded action from dataset - action = { - name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"]) - } + # Get recorded action from dataset + action = { + name: float(actions[idx][ACTION][i]) + for i, name in enumerate(dataset.features[ACTION]["names"]) + } - # Send action to robot - _ = robot.send_action(action) + # Send action to robot + _ = robot.send_action(action) - precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) - - robot.disconnect() + precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) + finally: + robot.disconnect() if __name__ == "__main__": diff --git a/examples/phone_to_so100/evaluate.py b/examples/phone_to_so100/evaluate.py index 246c923aa..837217eda 100644 --- a/examples/phone_to_so100/evaluate.py +++ b/examples/phone_to_so100/evaluate.py @@ -142,38 +142,24 @@ def main(): listener, events = init_keyboard_listener() init_rerun(session_name="phone_so100_evaluate") - if not robot.is_connected: - raise ValueError("Robot is not connected!") + try: + if not robot.is_connected: + raise ValueError("Robot is not connected!") - print("Starting evaluate loop...") - episode_idx = 0 - for episode_idx in range(NUM_EPISODES): - log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") + print("Starting evaluate loop...") + episode_idx = 0 + for episode_idx in range(NUM_EPISODES): + log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") - # Main record loop - record_loop( - robot=robot, - events=events, - fps=FPS, - policy=policy, - preprocessor=preprocessor, # Pass the pre and post policy processors - postprocessor=postprocessor, - dataset=dataset, - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, - teleop_action_processor=make_default_teleop_action_processor(), - robot_action_processor=robot_ee_to_joints_processor, - robot_observation_processor=robot_joints_to_ee_pose_processor, - ) - - # Reset the environment if not stopping or re-recording - if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]): - log_say("Reset the environment") + # Main record loop record_loop( robot=robot, events=events, fps=FPS, + policy=policy, + preprocessor=preprocessor, # Pass the pre and post policy processors + postprocessor=postprocessor, + dataset=dataset, control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, @@ -182,24 +168,41 @@ def main(): robot_observation_processor=robot_joints_to_ee_pose_processor, ) - if events["rerecord_episode"]: - log_say("Re-record episode") - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and ( + (episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"] + ): + log_say("Reset the environment") + record_loop( + robot=robot, + events=events, + fps=FPS, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=make_default_teleop_action_processor(), + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose_processor, + ) - # Save episode - dataset.save_episode() - episode_idx += 1 + if events["rerecord_episode"]: + log_say("Re-record episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue - # Clean up - log_say("Stop recording") - robot.disconnect() - listener.stop() + # Save episode + dataset.save_episode() + episode_idx += 1 + finally: + # Clean up + log_say("Stop recording") + robot.disconnect() + listener.stop() - dataset.finalize() - dataset.push_to_hub() + dataset.finalize() + dataset.push_to_hub() if __name__ == "__main__": diff --git a/examples/phone_to_so100/record.py b/examples/phone_to_so100/record.py index 7b5b704e2..1f5005db9 100644 --- a/examples/phone_to_so100/record.py +++ b/examples/phone_to_so100/record.py @@ -149,38 +149,23 @@ def main(): listener, events = init_keyboard_listener() init_rerun(session_name="phone_so100_record") - if not robot.is_connected or not phone.is_connected: - raise ValueError("Robot or teleop is not connected!") + try: + if not robot.is_connected or not phone.is_connected: + raise ValueError("Robot or teleop is not connected!") - print("Starting record loop. Move your phone to teleoperate the robot...") - episode_idx = 0 - while episode_idx < NUM_EPISODES and not events["stop_recording"]: - log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}") + print("Starting record loop. Move your phone to teleoperate the robot...") + episode_idx = 0 + while episode_idx < NUM_EPISODES and not events["stop_recording"]: + log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}") - # Main record loop - record_loop( - robot=robot, - events=events, - fps=FPS, - teleop=phone, - dataset=dataset, - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, - teleop_action_processor=phone_to_robot_ee_pose_processor, - robot_action_processor=robot_ee_to_joints_processor, - robot_observation_processor=robot_joints_to_ee_pose, - ) - - # Reset the environment if not stopping or re-recording - if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]): - log_say("Reset the environment") + # Main record loop record_loop( robot=robot, events=events, fps=FPS, teleop=phone, - control_time_s=RESET_TIME_SEC, + dataset=dataset, + control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, teleop_action_processor=phone_to_robot_ee_pose_processor, @@ -188,25 +173,43 @@ def main(): robot_observation_processor=robot_joints_to_ee_pose, ) - if events["rerecord_episode"]: - log_say("Re-recording episode") - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and ( + episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"] + ): + log_say("Reset the environment") + record_loop( + robot=robot, + events=events, + fps=FPS, + teleop=phone, + control_time_s=RESET_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=phone_to_robot_ee_pose_processor, + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose, + ) - # Save episode - dataset.save_episode() - episode_idx += 1 + if events["rerecord_episode"]: + log_say("Re-recording episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue - # Clean up - log_say("Stop recording") - robot.disconnect() - phone.disconnect() - listener.stop() + # Save episode + dataset.save_episode() + episode_idx += 1 + finally: + # Clean up + log_say("Stop recording") + robot.disconnect() + phone.disconnect() + listener.stop() - dataset.finalize() - dataset.push_to_hub() + dataset.finalize() + dataset.push_to_hub() if __name__ == "__main__": diff --git a/examples/phone_to_so100/replay.py b/examples/phone_to_so100/replay.py index 875025dfc..9d7806cf4 100644 --- a/examples/phone_to_so100/replay.py +++ b/examples/phone_to_so100/replay.py @@ -73,32 +73,34 @@ def main(): # Connect to the robot robot.connect() - if not robot.is_connected: - raise ValueError("Robot is not connected!") + try: + if not robot.is_connected: + raise ValueError("Robot is not connected!") - print("Starting replay loop...") - log_say(f"Replaying episode {EPISODE_IDX}") - for idx in range(len(episode_frames)): - t0 = time.perf_counter() + print("Starting replay loop...") + log_say(f"Replaying episode {EPISODE_IDX}") + for idx in range(len(episode_frames)): + t0 = time.perf_counter() - # Get recorded action from dataset - ee_action = { - name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"]) - } + # Get recorded action from dataset + ee_action = { + name: float(actions[idx][ACTION][i]) + for i, name in enumerate(dataset.features[ACTION]["names"]) + } - # Get robot observation - robot_obs = robot.get_observation() + # Get robot observation + robot_obs = robot.get_observation() - # Dataset EE -> robot joints - joint_action = robot_ee_to_joints_processor((ee_action, robot_obs)) + # Dataset EE -> robot joints + joint_action = robot_ee_to_joints_processor((ee_action, robot_obs)) - # Send action to robot - _ = robot.send_action(joint_action) + # Send action to robot + _ = robot.send_action(joint_action) - precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) - - # Clean up - robot.disconnect() + precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) + finally: + # Clean up + robot.disconnect() if __name__ == "__main__": diff --git a/examples/so100_to_so100_EE/evaluate.py b/examples/so100_to_so100_EE/evaluate.py index 87d188f99..b614b89f2 100644 --- a/examples/so100_to_so100_EE/evaluate.py +++ b/examples/so100_to_so100_EE/evaluate.py @@ -142,38 +142,24 @@ def main(): listener, events = init_keyboard_listener() init_rerun(session_name="so100_so100_evaluate") - if not robot.is_connected: - raise ValueError("Robot is not connected!") + try: + if not robot.is_connected: + raise ValueError("Robot is not connected!") - print("Starting evaluate loop...") - episode_idx = 0 - for episode_idx in range(NUM_EPISODES): - log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") + print("Starting evaluate loop...") + episode_idx = 0 + for episode_idx in range(NUM_EPISODES): + log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") - # Main record loop - record_loop( - robot=robot, - events=events, - fps=FPS, - policy=policy, - preprocessor=preprocessor, # Pass the pre and post policy processors - postprocessor=postprocessor, - dataset=dataset, - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, - teleop_action_processor=make_default_teleop_action_processor(), - robot_action_processor=robot_ee_to_joints_processor, - robot_observation_processor=robot_joints_to_ee_pose_processor, - ) - - # Reset the environment if not stopping or re-recording - if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]): - log_say("Reset the environment") + # Main record loop record_loop( robot=robot, events=events, fps=FPS, + policy=policy, + preprocessor=preprocessor, # Pass the pre and post policy processors + postprocessor=postprocessor, + dataset=dataset, control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, @@ -182,24 +168,41 @@ def main(): robot_observation_processor=robot_joints_to_ee_pose_processor, ) - if events["rerecord_episode"]: - log_say("Re-record episode") - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and ( + (episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"] + ): + log_say("Reset the environment") + record_loop( + robot=robot, + events=events, + fps=FPS, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=make_default_teleop_action_processor(), + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose_processor, + ) - # Save episode - dataset.save_episode() - episode_idx += 1 + if events["rerecord_episode"]: + log_say("Re-record episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue - # Clean up - log_say("Stop recording") - robot.disconnect() - listener.stop() + # Save episode + dataset.save_episode() + episode_idx += 1 + finally: + # Clean up + log_say("Stop recording") + robot.disconnect() + listener.stop() - dataset.finalize() - dataset.push_to_hub() + dataset.finalize() + dataset.push_to_hub() if __name__ == "__main__": diff --git a/examples/so100_to_so100_EE/record.py b/examples/so100_to_so100_EE/record.py index eead7a9a8..d85a1c5cc 100644 --- a/examples/so100_to_so100_EE/record.py +++ b/examples/so100_to_so100_EE/record.py @@ -146,38 +146,23 @@ def main(): listener, events = init_keyboard_listener() init_rerun(session_name="recording_phone") - if not leader.is_connected or not follower.is_connected: - raise ValueError("Robot or teleop is not connected!") + try: + if not leader.is_connected or not follower.is_connected: + raise ValueError("Robot or teleop is not connected!") - print("Starting record loop...") - episode_idx = 0 - while episode_idx < NUM_EPISODES and not events["stop_recording"]: - log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}") + print("Starting record loop...") + episode_idx = 0 + while episode_idx < NUM_EPISODES and not events["stop_recording"]: + log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}") - # Main record loop - record_loop( - robot=follower, - events=events, - fps=FPS, - teleop=leader, - dataset=dataset, - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, - teleop_action_processor=leader_joints_to_ee, - robot_action_processor=ee_to_follower_joints, - robot_observation_processor=follower_joints_to_ee, - ) - - # Reset the environment if not stopping or re-recording - if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]): - log_say("Reset the environment") + # Main record loop record_loop( robot=follower, events=events, fps=FPS, teleop=leader, - control_time_s=RESET_TIME_SEC, + dataset=dataset, + control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, teleop_action_processor=leader_joints_to_ee, @@ -185,25 +170,44 @@ def main(): robot_observation_processor=follower_joints_to_ee, ) - if events["rerecord_episode"]: - log_say("Re-recording episode") - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and ( + episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"] + ): + log_say("Reset the environment") + record_loop( + robot=follower, + events=events, + fps=FPS, + teleop=leader, + control_time_s=RESET_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=leader_joints_to_ee, + robot_action_processor=ee_to_follower_joints, + robot_observation_processor=follower_joints_to_ee, + ) - # Save episode - dataset.save_episode() - episode_idx += 1 + if events["rerecord_episode"]: + log_say("Re-recording episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue - # Clean up - log_say("Stop recording") - leader.disconnect() - follower.disconnect() - listener.stop() + # Save episode + dataset.save_episode() + episode_idx += 1 - dataset.finalize() - dataset.push_to_hub() + finally: + # Clean up + log_say("Stop recording") + leader.disconnect() + follower.disconnect() + listener.stop() + + dataset.finalize() + dataset.push_to_hub() if __name__ == "__main__": diff --git a/examples/so100_to_so100_EE/replay.py b/examples/so100_to_so100_EE/replay.py index 7d35a7b44..47a2f6635 100644 --- a/examples/so100_to_so100_EE/replay.py +++ b/examples/so100_to_so100_EE/replay.py @@ -74,32 +74,35 @@ def main(): # Connect to the robot robot.connect() - if not robot.is_connected: - raise ValueError("Robot is not connected!") + try: + if not robot.is_connected: + raise ValueError("Robot is not connected!") - print("Starting replay loop...") - log_say(f"Replaying episode {EPISODE_IDX}") - for idx in range(len(episode_frames)): - t0 = time.perf_counter() + print("Starting replay loop...") + log_say(f"Replaying episode {EPISODE_IDX}") + for idx in range(len(episode_frames)): + t0 = time.perf_counter() - # Get recorded action from dataset - ee_action = { - name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"]) - } + # Get recorded action from dataset + ee_action = { + name: float(actions[idx][ACTION][i]) + for i, name in enumerate(dataset.features[ACTION]["names"]) + } - # Get robot observation - robot_obs = robot.get_observation() + # Get robot observation + robot_obs = robot.get_observation() - # Dataset EE -> robot joints - joint_action = robot_ee_to_joints_processor((ee_action, robot_obs)) + # Dataset EE -> robot joints + joint_action = robot_ee_to_joints_processor((ee_action, robot_obs)) - # Send action to robot - _ = robot.send_action(joint_action) + # Send action to robot + _ = robot.send_action(joint_action) - precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) + precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) - # Clean up - robot.disconnect() + finally: + # Clean up + robot.disconnect() if __name__ == "__main__": diff --git a/src/lerobot/cameras/camera.py b/src/lerobot/cameras/camera.py index bfdb571a7..2894e0215 100644 --- a/src/lerobot/cameras/camera.py +++ b/src/lerobot/cameras/camera.py @@ -15,11 +15,12 @@ # limitations under the License. import abc +import warnings from typing import Any from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing -from .configs import CameraConfig, ColorMode +from .configs import CameraConfig class Camera(abc.ABC): @@ -30,20 +31,12 @@ class Camera(abc.ABC): Manages basic camera properties (FPS, resolution) and core operations: - Connection/disconnection - - Frame capture (sync/async) + - Frame capture (sync/async/latest) Attributes: fps (int | None): Configured frames per second width (int | None): Frame width in pixels height (int | None): Frame height in pixels - - Example: - class MyCamera(Camera): - def __init__(self, config): ... - @property - def is_connected(self) -> bool: ... - def connect(self, warmup=True): ... - # Plus other required methods """ def __init__(self, config: CameraConfig): @@ -56,6 +49,32 @@ class Camera(abc.ABC): self.width: int | None = config.width self.height: int | None = config.height + def __enter__(self): + """ + Context manager entry. + Automatically connects to the camera. + """ + self.connect() + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + """ + Context manager exit. + Automatically disconnects, ensuring resources are released even on error. + """ + self.disconnect() + + def __del__(self) -> None: + """ + Destructor safety net. + Attempts to disconnect if the object is garbage collected without cleanup. + """ + try: + if self.is_connected: + self.disconnect() + except Exception: # nosec B110 + pass + @property @abc.abstractmethod def is_connected(self) -> bool: @@ -89,12 +108,10 @@ class Camera(abc.ABC): pass @abc.abstractmethod - def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: - """Capture and return a single frame from the camera. + def read(self) -> NDArray[Any]: + """Capture and return a single frame from the camera synchronously. - Args: - color_mode: Desired color mode for the output frame. If None, - uses the camera's default color mode. + This is a blocking call that will wait for the hardware and its SDK. Returns: np.ndarray: Captured frame as a numpy array. @@ -103,17 +120,64 @@ class Camera(abc.ABC): @abc.abstractmethod def async_read(self, timeout_ms: float = ...) -> NDArray[Any]: - """Asynchronously capture and return a single frame from the camera. + """Return the most recent new frame. + + This method retrieves the latest frame captured by the background thread. + If a new frame is already available in the buffer (captured since the last call), + it returns it immediately. + + It blocks up to `timeout_ms` only if the buffer is empty or if the latest frame + was already consumed by a previous `async_read` call. + + Essentially, this method return the latest unconsumed frame, waiting if necessary + for a new one to arrive within the specified timeout. + + Usage: + - Ideal for control loops where you want to ensure every processed frame + is fresh, effectively synchronizing your loop to the camera's FPS. + - Causes of a timeout usually include: very low camera FPS, heavy processing load, + or if the camera is disconnected. Args: - timeout_ms: Maximum time to wait for a frame in milliseconds. - Defaults to implementation-specific timeout. + timeout_ms: Maximum time to wait for a new frame in milliseconds. + Defaults to 200ms (0.2s). Returns: np.ndarray: Captured frame as a numpy array. + + Raises: + TimeoutError: If no new frame arrives within `timeout_ms`. """ pass + def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + """Return the most recent frame captured immediately (Peeking). + + This method is non-blocking and returns whatever is currently in the + memory buffer. The frame may be stale, + meaning it could have been captured a while ago (hanging camera scenario e.g.). + + Usage: + Ideal for scenarios requiring zero latency or decoupled frequencies & when + we want a guaranteed frame, such as UI visualization, logging, or + non-critical monitoring. + + Returns: + NDArray[Any]: The frame image (numpy array). + + Raises: + TimeoutError: If the latest frame is older than `max_age_ms`. + NotConnectedError: If the camera is not connected. + RuntimeError: If the camera is connected but has not captured any frames yet. + """ + warnings.warn( + f"{self.__class__.__name__}.read_latest() is not implemented. " + "Please override read_latest(); it will be required in future releases.", + FutureWarning, + stacklevel=2, + ) + return self.async_read() + @abc.abstractmethod def disconnect(self) -> None: """Disconnect from the camera and release resources.""" diff --git a/src/lerobot/cameras/opencv/camera_opencv.py b/src/lerobot/cameras/opencv/camera_opencv.py index b1043ba64..d581e1425 100644 --- a/src/lerobot/cameras/opencv/camera_opencv.py +++ b/src/lerobot/cameras/opencv/camera_opencv.py @@ -70,34 +70,24 @@ class OpenCVCamera(Camera): Example: ```python from lerobot.cameras.opencv import OpenCVCamera - from lerobot.cameras.configuration_opencv import OpenCVCameraConfig, ColorMode, Cv2Rotation + from lerobot.cameras.configuration_opencv import OpenCVCameraConfig # Basic usage with camera index 0 config = OpenCVCameraConfig(index_or_path=0) camera = OpenCVCamera(config) camera.connect() - # Read 1 frame synchronously + # Read 1 frame synchronously (blocking) color_image = camera.read() - print(color_image.shape) - # Read 1 frame asynchronously + # Read 1 frame asynchronously (waits for new frame with a timeout) async_image = camera.async_read() + # Get the latest frame immediately (no wait, returns timestamp) + latest_image, timestamp = camera.read_latest() + # When done, properly disconnect the camera using camera.disconnect() - - # Example with custom settings - custom_config = OpenCVCameraConfig( - index_or_path='/dev/video0', # Or use an index - fps=30, - width=1280, - height=720, - color_mode=ColorMode.RGB, - rotation=Cv2Rotation.ROTATE_90 - ) - custom_camera = OpenCVCamera(custom_config) - # ... connect, read, disconnect ... ``` """ @@ -123,6 +113,7 @@ class OpenCVCamera(Camera): self.stop_event: Event | None = None self.frame_lock: Lock = Lock() self.latest_frame: NDArray[Any] | None = None + self.latest_timestamp: float | None = None self.new_frame_event: Event = Event() self.rotation: int | None = get_cv2_rotation(config.rotation) @@ -146,12 +137,16 @@ class OpenCVCamera(Camera): Connects to the OpenCV camera specified in the configuration. Initializes the OpenCV VideoCapture object, sets desired camera properties - (FPS, width, height), and performs initial checks. + (FPS, width, height), starts the background reading thread and performs initial checks. + + Args: + warmup (bool): If True, waits at connect() time until at least one valid frame + has been captured by the background thread. Defaults to True. Raises: DeviceAlreadyConnectedError: If the camera is already connected. - ConnectionError: If the specified camera index/path is not found or the camera is found but fails to open. - RuntimeError: If the camera opens but fails to apply requested FPS/resolution settings. + ConnectionError: If the specified camera index/path is not found or fails to open. + RuntimeError: If the camera opens but fails to apply requested settings. """ if self.is_connected: raise DeviceAlreadyConnectedError(f"{self} is already connected.") @@ -170,12 +165,16 @@ class OpenCVCamera(Camera): ) self._configure_capture_settings() + self._start_read_thread() - if warmup: + if warmup and self.warmup_s > 0: start_time = time.time() while time.time() - start_time < self.warmup_s: - self.read() + self.async_read(timeout_ms=self.warmup_s * 1000) time.sleep(0.1) + with self.frame_lock: + if self.latest_frame is None: + raise ConnectionError(f"{self} failed to capture frames during warmup.") logger.info(f"{self} connected.") @@ -196,8 +195,7 @@ class OpenCVCamera(Camera): Raises: RuntimeError: If the camera fails to set any of the specified properties to the requested value. - DeviceNotConnectedError: If the camera is not connected when attempting - to configure settings. + DeviceNotConnectedError: If the camera is not connected. """ if not self.is_connected: raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.") @@ -339,6 +337,17 @@ class OpenCVCamera(Camera): return found_cameras_info + def _read_from_hardware(self) -> NDArray[Any]: + if self.videocapture is None: + raise DeviceNotConnectedError(f"{self} videocapture is not initialized") + + ret, frame = self.videocapture.read() + + if not ret: + raise RuntimeError(f"{self} read failed (status={ret}).") + + return frame + def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: """ Reads a single frame synchronously from the camera. @@ -346,11 +355,6 @@ class OpenCVCamera(Camera): This is a blocking call. It waits for the next available frame from the camera hardware via OpenCV. - Args: - color_mode (Optional[ColorMode]): If specified, overrides the default - color mode (`self.color_mode`) for this read operation (e.g., - request RGB even if default is BGR). - Returns: np.ndarray: The captured frame as a NumPy array in the format (height, width, channels), using the specified or default @@ -362,34 +366,34 @@ class OpenCVCamera(Camera): received frame dimensions don't match expectations before rotation. ValueError: If an invalid `color_mode` is requested. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") start_time = time.perf_counter() - if self.videocapture is None: - raise DeviceNotConnectedError(f"{self} videocapture is not initialized") + if color_mode is not None: + logger.warning( + f"{self} read() color_mode parameter is deprecated and will be removed in future versions." + ) - ret, frame = self.videocapture.read() + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") - if not ret or frame is None: - raise RuntimeError(f"{self} read failed (status={ret}).") + if self.thread is None or not self.thread.is_alive(): + raise RuntimeError(f"{self} read thread is not running.") - processed_frame = self._postprocess_image(frame, color_mode) + self.new_frame_event.clear() + frame = self.async_read(timeout_ms=10000) read_duration_ms = (time.perf_counter() - start_time) * 1e3 logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") - return processed_frame + return frame - def _postprocess_image(self, image: NDArray[Any], color_mode: ColorMode | None = None) -> NDArray[Any]: + def _postprocess_image(self, image: NDArray[Any]) -> NDArray[Any]: """ Applies color conversion, dimension validation, and rotation to a raw frame. Args: image (np.ndarray): The raw image frame (expected BGR format from OpenCV). - color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None, - uses the instance's default `self.color_mode`. Returns: np.ndarray: The processed image frame. @@ -399,11 +403,10 @@ class OpenCVCamera(Camera): RuntimeError: If the raw frame dimensions do not match the configured `width` and `height`. """ - requested_color_mode = self.color_mode if color_mode is None else color_mode - if requested_color_mode not in (ColorMode.RGB, ColorMode.BGR): + if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): raise ValueError( - f"Invalid color mode '{requested_color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}." + f"Invalid color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}." ) h, w, c = image.shape @@ -417,7 +420,7 @@ class OpenCVCamera(Camera): raise RuntimeError(f"{self} frame channels={c} do not match expected 3 channels (RGB/BGR).") processed_image = image - if requested_color_mode == ColorMode.RGB: + if self.color_mode == ColorMode.RGB: processed_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]: @@ -431,7 +434,7 @@ class OpenCVCamera(Camera): On each iteration: 1. Reads a color frame - 2. Stores result in latest_frame (thread-safe) + 2. Stores result in latest_frame and updates timestamp (thread-safe) 3. Sets new_frame_event to notify listeners Stops on DeviceNotConnectedError, logs other errors and continues. @@ -439,30 +442,37 @@ class OpenCVCamera(Camera): if self.stop_event is None: raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.") + failure_count = 0 while not self.stop_event.is_set(): try: - color_image = self.read() + raw_frame = self._read_from_hardware() + processed_frame = self._postprocess_image(raw_frame) + capture_time = time.perf_counter() with self.frame_lock: - self.latest_frame = color_image + self.latest_frame = processed_frame + self.latest_timestamp = capture_time self.new_frame_event.set() + failure_count = 0 except DeviceNotConnectedError: break except Exception as e: - logger.warning(f"Error reading frame in background thread for {self}: {e}") + if failure_count <= 10: + failure_count += 1 + logger.warning(f"Error reading frame in background thread for {self}: {e}") + else: + raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e def _start_read_thread(self) -> None: """Starts or restarts the background read thread if it's not running.""" - if self.thread is not None and self.thread.is_alive(): - self.thread.join(timeout=0.1) - if self.stop_event is not None: - self.stop_event.set() + self._stop_read_thread() self.stop_event = Event() self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop") self.thread.daemon = True self.thread.start() + time.sleep(0.1) def _stop_read_thread(self) -> None: """Signals the background read thread to stop and waits for it to join.""" @@ -475,6 +485,11 @@ class OpenCVCamera(Camera): self.thread = None self.stop_event = None + with self.frame_lock: + self.latest_frame = None + self.latest_timestamp = None + self.new_frame_event.clear() + def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: """ Reads the latest available frame asynchronously. @@ -482,6 +497,7 @@ class OpenCVCamera(Camera): This method retrieves the most recent frame captured by the background read thread. It does not block waiting for the camera hardware directly, but may wait up to timeout_ms for the background thread to provide a frame. + It is “best effort” under high FPS. Args: timeout_ms (float): Maximum time in milliseconds to wait for a frame @@ -500,13 +516,12 @@ class OpenCVCamera(Camera): raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): - self._start_read_thread() + raise RuntimeError(f"{self} read thread is not running.") if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0): - thread_alive = self.thread is not None and self.thread.is_alive() raise TimeoutError( f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. " - f"Read thread alive: {thread_alive}." + f"Read thread alive: {self.thread.is_alive()}." ) with self.frame_lock: @@ -518,6 +533,42 @@ class OpenCVCamera(Camera): return frame + def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + """Return the most recent frame captured immediately (Peeking). + + This method is non-blocking and returns whatever is currently in the + memory buffer. The frame may be stale, + meaning it could have been captured a while ago (hanging camera scenario e.g.). + + Returns: + NDArray[Any]: The frame image (numpy array). + + Raises: + TimeoutError: If the latest frame is older than `max_age_ms`. + DeviceNotConnectedError: If the camera is not connected. + RuntimeError: If the camera is connected but has not captured any frames yet. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.thread is None or not self.thread.is_alive(): + raise RuntimeError(f"{self} read thread is not running.") + + with self.frame_lock: + frame = self.latest_frame + timestamp = self.latest_timestamp + + if frame is None or timestamp is None: + raise RuntimeError(f"{self} has not captured any frames yet.") + + age_ms = (time.perf_counter() - timestamp) * 1e3 + if age_ms > max_age_ms: + raise TimeoutError( + f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)." + ) + + return frame + def disconnect(self) -> None: """ Disconnects from the camera and cleans up resources. @@ -538,4 +589,9 @@ class OpenCVCamera(Camera): self.videocapture.release() self.videocapture = None + with self.frame_lock: + self.latest_frame = None + self.latest_timestamp = None + self.new_frame_event.clear() + logger.info(f"{self} disconnected.") diff --git a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py index c8916c5ee..5cede466d 100644 --- a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py +++ b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py @@ -80,6 +80,8 @@ class Reachy2Camera(Camera): self.config = config self.color_mode = config.color_mode + self.latest_frame: NDArray[Any] | None = None + self.latest_timestamp: float | None = None self.cam_manager: CameraManager | None = None @@ -125,12 +127,7 @@ class Reachy2Camera(Camera): """ Reads a single frame synchronously from the camera. - This is a blocking call. - - Args: - color_mode (Optional[ColorMode]): If specified, overrides the default - color mode (`self.color_mode`) for this read operation (e.g., - request RGB even if default is BGR). + This method retrieves the most recent frame available in Reachy 2's low-level software. Returns: np.ndarray: The captured frame as a NumPy array in the format @@ -145,6 +142,11 @@ class Reachy2Camera(Camera): if self.cam_manager is None: raise DeviceNotConnectedError(f"{self} is not connected.") + if color_mode is not None: + logger.warning( + f"{self} read() color_mode parameter is deprecated and will be removed in future versions." + ) + frame: NDArray[Any] = np.empty((0, 0, 3), dtype=np.uint8) if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"): @@ -165,11 +167,18 @@ class Reachy2Camera(Camera): raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.") if frame is None: - return np.empty((0, 0, 3), dtype=np.uint8) + raise RuntimeError(f"Internal error: No frame available for {self}.") - if self.config.color_mode == "rgb": + if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): + raise ValueError( + f"Invalid color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}." + ) + if self.color_mode == ColorMode.RGB: frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + self.latest_frame = frame + self.latest_timestamp = time.perf_counter() + read_duration_ms = (time.perf_counter() - start_time) * 1e3 logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") @@ -177,13 +186,7 @@ class Reachy2Camera(Camera): def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: """ - Reads the latest available frame. - - This method retrieves the most recent frame available in Reachy 2's low-level software. - - Args: - timeout_ms (float): Maximum time in milliseconds to wait for a frame - to become available. Defaults to 200ms (0.2 seconds). + Same as read() Returns: np.ndarray: The latest captured frame as a NumPy array in the format @@ -197,12 +200,38 @@ class Reachy2Camera(Camera): if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") - frame = self.read() + return self.read() - if frame is None: - raise RuntimeError(f"Internal error: No frame available for {self}.") + def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + """Return the most recent frame captured immediately (Peeking). - return frame + This method is non-blocking and returns whatever is currently in the + memory buffer. The frame may be stale, + meaning it could have been captured a while ago (hanging camera scenario e.g.). + + Returns: + tuple[NDArray, float]: + - The frame image (numpy array). + - The timestamp (time.perf_counter) when this frame was captured. + + Raises: + TimeoutError: If the latest frame is older than `max_age_ms`. + DeviceNotConnectedError: If the camera is not connected. + RuntimeError: If the camera is connected but has not captured any frames yet. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.latest_frame is None or self.latest_timestamp is None: + raise RuntimeError(f"{self} has not captured any frames yet.") + + age_ms = (time.perf_counter() - self.latest_timestamp) * 1e3 + if age_ms > max_age_ms: + raise TimeoutError( + f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)." + ) + + return self.latest_frame def disconnect(self) -> None: """ diff --git a/src/lerobot/cameras/realsense/camera_realsense.py b/src/lerobot/cameras/realsense/camera_realsense.py index f2906fdd8..e47f25381 100644 --- a/src/lerobot/cameras/realsense/camera_realsense.py +++ b/src/lerobot/cameras/realsense/camera_realsense.py @@ -72,15 +72,14 @@ class RealSenseCamera(Camera): camera = RealSenseCamera(config) camera.connect() - # Read 1 frame synchronously + # Read 1 frame synchronously (blocking) color_image = camera.read() - print(color_image.shape) - # Read 1 frame asynchronously + # Read 1 frame asynchronously (waits for new frame with a timeout) async_image = camera.async_read() - # When done, properly disconnect the camera using - camera.disconnect() + # Get the latest frame immediately (no wait, returns timestamp) + latest_image, timestamp = camera.read_latest() # Example with depth capture and custom settings custom_config = RealSenseCameraConfig( @@ -133,7 +132,9 @@ class RealSenseCamera(Camera): self.thread: Thread | None = None self.stop_event: Event | None = None self.frame_lock: Lock = Lock() - self.latest_frame: NDArray[Any] | None = None + self.latest_color_frame: NDArray[Any] | None = None + self.latest_depth_frame: NDArray[Any] | None = None + self.latest_timestamp: float | None = None self.new_frame_event: Event = Event() self.rotation: int | None = get_cv2_rotation(config.rotation) @@ -158,6 +159,10 @@ class RealSenseCamera(Camera): Initializes the RealSense pipeline, configures the required streams (color and optionally depth), starts the pipeline, and validates the actual stream settings. + Args: + warmup (bool): If True, waits at connect() time until at least one valid frame + has been captured by the background thread. Defaults to True. + Raises: DeviceAlreadyConnectedError: If the camera is already connected. ValueError: If the configuration is invalid (e.g., missing serial/name, name not unique). @@ -181,15 +186,18 @@ class RealSenseCamera(Camera): ) from e self._configure_capture_settings() + self._start_read_thread() - if warmup: - time.sleep( - 1 - ) # NOTE(Steven): RS cameras need a bit of time to warm up before the first read. If we don't wait, the first read from the warmup will raise. - start_time = time.time() - while time.time() - start_time < self.warmup_s: - self.read() - time.sleep(0.1) + # NOTE(Steven/Caroline): Enforcing at least one second of warmup as RS cameras need a bit of time before the first read. If we don't wait, the first read from the warmup will raise. + self.warmup_s = max(self.warmup_s, 1) + + start_time = time.time() + while time.time() - start_time < self.warmup_s: + self.async_read(timeout_ms=self.warmup_s * 1000) + time.sleep(0.1) + with self.frame_lock: + if self.latest_color_frame is None or self.use_depth and self.latest_depth_frame is None: + raise ConnectionError(f"{self} failed to capture frames during warmup.") logger.info(f"{self} connected.") @@ -319,9 +327,6 @@ class RealSenseCamera(Camera): This is a blocking call. It waits for a coherent set of frames (depth) from the camera hardware via the RealSense pipeline. - Args: - timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms. - Returns: np.ndarray: The depth map as a NumPy array (height, width) of type `np.uint16` (raw depth values in millimeters) and rotation. @@ -330,44 +335,52 @@ class RealSenseCamera(Camera): DeviceNotConnectedError: If the camera is not connected. RuntimeError: If reading frames from the pipeline fails or frames are invalid. """ + if timeout_ms: + logger.warning( + f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions." + ) - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if not self.use_depth: raise RuntimeError( f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}." ) - start_time = time.perf_counter() + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + if self.thread is None or not self.thread.is_alive(): + raise RuntimeError(f"{self} read thread is not running.") + + self.new_frame_event.clear() + + _ = self.async_read(timeout_ms=10000) + + with self.frame_lock: + depth_map = self.latest_depth_frame + + if depth_map is None: + raise RuntimeError("No depth frame available. Ensure camera is streaming.") + + return depth_map + + def _read_from_hardware(self): if self.rs_pipeline is None: raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.") - ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms) + ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=10000) if not ret or frame is None: - raise RuntimeError(f"{self} read_depth failed (status={ret}).") + raise RuntimeError(f"{self} read failed (status={ret}).") - depth_frame = frame.get_depth_frame() - depth_map = np.asanyarray(depth_frame.get_data()) + return frame - depth_map_processed = self._postprocess_image(depth_map, depth_frame=True) - - read_duration_ms = (time.perf_counter() - start_time) * 1e3 - logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") - - return depth_map_processed - - def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> NDArray[Any]: + def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 0) -> NDArray[Any]: """ Reads a single frame (color) synchronously from the camera. This is a blocking call. It waits for a coherent set of frames (color) from the camera hardware via the RealSense pipeline. - Args: - timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms. - Returns: np.ndarray: The captured color frame as a NumPy array (height, width, channels), processed according to `color_mode` and rotation. @@ -378,39 +391,39 @@ class RealSenseCamera(Camera): ValueError: If an invalid `color_mode` is requested. """ + start_time = time.perf_counter() + + if color_mode is not None: + logger.warning( + f"{self} read() color_mode parameter is deprecated and will be removed in future versions." + ) + + if timeout_ms: + logger.warning( + f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions." + ) + if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") - start_time = time.perf_counter() + if self.thread is None or not self.thread.is_alive(): + raise RuntimeError(f"{self} read thread is not running.") - if self.rs_pipeline is None: - raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.") + self.new_frame_event.clear() - ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms) - - if not ret or frame is None: - raise RuntimeError(f"{self} read failed (status={ret}).") - - color_frame = frame.get_color_frame() - color_image_raw = np.asanyarray(color_frame.get_data()) - - color_image_processed = self._postprocess_image(color_image_raw, color_mode) + frame = self.async_read(timeout_ms=10000) read_duration_ms = (time.perf_counter() - start_time) * 1e3 logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") - return color_image_processed + return frame - def _postprocess_image( - self, image: NDArray[Any], color_mode: ColorMode | None = None, depth_frame: bool = False - ) -> NDArray[Any]: + def _postprocess_image(self, image: NDArray[Any], depth_frame: bool = False) -> NDArray[Any]: """ Applies color conversion, dimension validation, and rotation to a raw color frame. Args: image (np.ndarray): The raw image frame (expected RGB format from RealSense). - color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None, - uses the instance's default `self.color_mode`. Returns: np.ndarray: The processed image frame according to `self.color_mode` and `self.rotation`. @@ -421,9 +434,9 @@ class RealSenseCamera(Camera): `width` and `height`. """ - if color_mode and color_mode not in (ColorMode.RGB, ColorMode.BGR): + if self.color_mode and self.color_mode not in (ColorMode.RGB, ColorMode.BGR): raise ValueError( - f"Invalid requested color mode '{color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}." + f"Invalid requested color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}." ) if depth_frame: @@ -454,7 +467,7 @@ class RealSenseCamera(Camera): On each iteration: 1. Reads a color frame with 500ms timeout - 2. Stores result in latest_frame (thread-safe) + 2. Stores result in latest_frame and updates timestamp (thread-safe) 3. Sets new_frame_event to notify listeners Stops on DeviceNotConnectedError, logs other errors and continues. @@ -462,25 +475,41 @@ class RealSenseCamera(Camera): if self.stop_event is None: raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.") + failure_count = 0 while not self.stop_event.is_set(): try: - color_image = self.read(timeout_ms=500) + frame = self._read_from_hardware() + color_frame_raw = frame.get_color_frame() + color_frame = np.asanyarray(color_frame_raw.get_data()) + processed_color_frame = self._postprocess_image(color_frame) + + if self.use_depth: + depth_frame_raw = frame.get_depth_frame() + depth_frame = np.asanyarray(depth_frame_raw.get_data()) + processed_depth_frame = self._postprocess_image(depth_frame, depth_frame=True) + + capture_time = time.perf_counter() with self.frame_lock: - self.latest_frame = color_image + self.latest_color_frame = processed_color_frame + if self.use_depth: + self.latest_depth_frame = processed_depth_frame + self.latest_timestamp = capture_time self.new_frame_event.set() + failure_count = 0 except DeviceNotConnectedError: break except Exception as e: - logger.warning(f"Error reading frame in background thread for {self}: {e}") + if failure_count <= 10: + failure_count += 1 + logger.warning(f"Error reading frame in background thread for {self}: {e}") + else: + raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e def _start_read_thread(self) -> None: """Starts or restarts the background read thread if it's not running.""" - if self.thread is not None and self.thread.is_alive(): - self.thread.join(timeout=0.1) - if self.stop_event is not None: - self.stop_event.set() + self._stop_read_thread() self.stop_event = Event() self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop") @@ -498,6 +527,12 @@ class RealSenseCamera(Camera): self.thread = None self.stop_event = None + with self.frame_lock: + self.latest_color_frame = None + self.latest_depth_frame = None + self.latest_timestamp = None + self.new_frame_event.clear() + # NOTE(Steven): Missing implementation for depth for now def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: """ @@ -506,6 +541,7 @@ class RealSenseCamera(Camera): This method retrieves the most recent color frame captured by the background read thread. It does not block waiting for the camera hardware directly, but may wait up to timeout_ms for the background thread to provide a frame. + It is “best effort” under high FPS. Args: timeout_ms (float): Maximum time in milliseconds to wait for a frame @@ -524,17 +560,16 @@ class RealSenseCamera(Camera): raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): - self._start_read_thread() + raise RuntimeError(f"{self} read thread is not running.") if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0): - thread_alive = self.thread is not None and self.thread.is_alive() raise TimeoutError( f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. " - f"Read thread alive: {thread_alive}." + f"Read thread alive: {self.thread.is_alive()}." ) with self.frame_lock: - frame = self.latest_frame + frame = self.latest_color_frame self.new_frame_event.clear() if frame is None: @@ -542,6 +577,43 @@ class RealSenseCamera(Camera): return frame + # NOTE(Steven): Missing implementation for depth for now + def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + """Return the most recent (color) frame captured immediately (Peeking). + + This method is non-blocking and returns whatever is currently in the + memory buffer. The frame may be stale, + meaning it could have been captured a while ago (hanging camera scenario e.g.). + + Returns: + NDArray[Any]: The frame image (numpy array). + + Raises: + TimeoutError: If the latest frame is older than `max_age_ms`. + DeviceNotConnectedError: If the camera is not connected. + RuntimeError: If the camera is connected but has not captured any frames yet. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.thread is None or not self.thread.is_alive(): + raise RuntimeError(f"{self} read thread is not running.") + + with self.frame_lock: + frame = self.latest_color_frame + timestamp = self.latest_timestamp + + if frame is None or timestamp is None: + raise RuntimeError(f"{self} has not captured any frames yet.") + + age_ms = (time.perf_counter() - timestamp) * 1e3 + if age_ms > max_age_ms: + raise TimeoutError( + f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)." + ) + + return frame + def disconnect(self) -> None: """ Disconnects from the camera, stops the pipeline, and cleans up resources. @@ -565,4 +637,10 @@ class RealSenseCamera(Camera): self.rs_pipeline = None self.rs_profile = None + with self.frame_lock: + self.latest_color_frame = None + self.latest_depth_frame = None + self.latest_timestamp = None + self.new_frame_event.clear() + logger.info(f"{self} disconnected.") diff --git a/src/lerobot/cameras/zmq/camera_zmq.py b/src/lerobot/cameras/zmq/camera_zmq.py index 1a4155f4b..a231a582a 100644 --- a/src/lerobot/cameras/zmq/camera_zmq.py +++ b/src/lerobot/cameras/zmq/camera_zmq.py @@ -45,6 +45,12 @@ logger = logging.getLogger(__name__) class ZMQCamera(Camera): """ + Manages camera interactions via ZeroMQ for receiving frames from a remote server. + + This class connects to a ZMQ Publisher, subscribes to frame topics, and decodes + incoming JSON messages containing Base64 encoded images. It supports both + synchronous and asynchronous frame reading patterns. + Example usage: ```python from lerobot.cameras.zmq import ZMQCamera, ZMQCameraConfig @@ -52,7 +58,16 @@ class ZMQCamera(Camera): config = ZMQCameraConfig(server_address="192.168.123.164", port=5555, camera_name="head_camera") camera = ZMQCamera(config) camera.connect() - frame = camera.read() + + # Read 1 frame synchronously (blocking) + color_image = camera.read() + + # Read 1 frame asynchronously (waits for new frame with a timeout) + async_image = camera.async_read() + + # Get the latest frame immediately (no wait, returns timestamp) + latest_image, timestamp = camera.read_latest() + camera.disconnect() ``` """ @@ -68,14 +83,17 @@ class ZMQCamera(Camera): self.color_mode = config.color_mode self.timeout_ms = config.timeout_ms + # ZMQ Context and Socket self.context: zmq.Context | None = None self.socket: zmq.Socket | None = None self._connected = False + # Threading resources self.thread: Thread | None = None self.stop_event: Event | None = None self.frame_lock: Lock = Lock() self.latest_frame: NDArray[Any] | None = None + self.latest_timestamp: float | None = None self.new_frame_event: Event = Event() def __str__(self) -> str: @@ -83,10 +101,16 @@ class ZMQCamera(Camera): @property def is_connected(self) -> bool: + """Checks if the ZMQ socket is initialized and connected.""" return self._connected and self.context is not None and self.socket is not None def connect(self, warmup: bool = True) -> None: - """Connect to ZMQ camera server.""" + """Connect to ZMQ camera server. + + Args: + warmup (bool): If True, waits for the camera to provide at least one + valid frame before returning. Defaults to True. + """ if self.is_connected: raise DeviceAlreadyConnectedError(f"{self} is already connected.") @@ -103,17 +127,28 @@ class ZMQCamera(Camera): self.socket.connect(f"tcp://{self.server_address}:{self.port}") self._connected = True - # Auto-detect resolution + # Auto-detect resolution if not provided if self.width is None or self.height is None: - h, w = self.read().shape[:2] + # Read directly from hardware because the thread isn't running yet + temp_frame = self._read_from_hardware() + h, w = temp_frame.shape[:2] self.height = h self.width = w - logger.info(f"{self} resolution: {w}x{h}") + logger.info(f"{self} resolution detected: {w}x{h}") + self._start_read_thread() logger.info(f"{self} connected.") if warmup: - time.sleep(0.1) + # Ensure we have captured at least one frame via the thread + start_time = time.time() + while time.time() - start_time < (self.config.warmup_s): # Wait a bit more than timeout + self.async_read(timeout_ms=self.config.warmup_s * 1000) + time.sleep(0.1) + + with self.frame_lock: + if self.latest_frame is None: + raise ConnectionError(f"{self} failed to capture frames during warmup.") except Exception as e: self._cleanup() @@ -134,12 +169,9 @@ class ZMQCamera(Camera): """ZMQ cameras require manual configuration (server address/port).""" return [] - def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: + def _read_from_hardware(self) -> NDArray[Any]: """ - Read a single frame from the ZMQ camera. - - Returns: - np.ndarray: Decoded frame (height, width, 3) + Reads a single frame directly from the ZMQ socket. """ if not self.is_connected or self.socket is None: raise DeviceNotConnectedError(f"{self} is not connected.") @@ -147,6 +179,7 @@ class ZMQCamera(Camera): try: message = self.socket.recv_string() except Exception as e: + # Check for ZMQ timeout (EAGAIN/Again) without requiring global zmq import if type(e).__name__ == "Again": raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e raise @@ -176,42 +209,117 @@ class ZMQCamera(Camera): return frame - def _read_loop(self) -> None: - while self.stop_event and not self.stop_event.is_set(): - try: - frame = self.read() - with self.frame_lock: - self.latest_frame = frame - self.new_frame_event.set() - except DeviceNotConnectedError: - break - except TimeoutError: - pass - except Exception as e: - logger.warning(f"Read error: {e}") + def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: + """ + Reads a single frame synchronously from the camera. - def _start_read_thread(self) -> None: - if self.thread and self.thread.is_alive(): - return - self.stop_event = Event() - self.thread = Thread(target=self._read_loop, daemon=True) - self.thread.start() + This is a blocking call. It waits for the next available frame from the + camera background thread. - def _stop_read_thread(self) -> None: - if self.stop_event: - self.stop_event.set() - if self.thread and self.thread.is_alive(): - self.thread.join(timeout=2.0) - self.thread = None - self.stop_event = None + Returns: + np.ndarray: Decoded frame (height, width, 3) + """ + start_time = time.perf_counter() + + if color_mode is not None: + logger.warning( + f"{self} read() color_mode parameter is deprecated and will be removed in future versions." + ) - def async_read(self, timeout_ms: float = 10000) -> NDArray[Any]: - """Read latest frame asynchronously (non-blocking).""" if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") - if not self.thread or not self.thread.is_alive(): - self._start_read_thread() + if self.thread is None or not self.thread.is_alive(): + raise RuntimeError(f"{self} read thread is not running.") + + self.new_frame_event.clear() + frame = self.async_read(timeout_ms=10000) + + read_duration_ms = (time.perf_counter() - start_time) * 1e3 + logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") + + return frame + + def _read_loop(self) -> None: + """ + Internal loop run by the background thread for asynchronous reading. + """ + if self.stop_event is None: + raise RuntimeError(f"{self}: stop_event is not initialized.") + + failure_count = 0 + while not self.stop_event.is_set(): + try: + frame = self._read_from_hardware() + capture_time = time.perf_counter() + + with self.frame_lock: + self.latest_frame = frame + self.latest_timestamp = capture_time + self.new_frame_event.set() + failure_count = 0 + + except DeviceNotConnectedError: + break + except (TimeoutError, Exception) as e: + if failure_count <= 10: + failure_count += 1 + logger.warning(f"Read error: {e}") + else: + raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e + + def _start_read_thread(self) -> None: + if self.stop_event is not None: + self.stop_event.set() + if self.thread is not None and self.thread.is_alive(): + self.thread.join(timeout=2.0) + + with self.frame_lock: + self.latest_frame = None + self.latest_timestamp = None + self.new_frame_event.clear() + + self.stop_event = Event() + self.thread = Thread(target=self._read_loop, daemon=True, name=f"{self}_read_loop") + self.thread.start() + time.sleep(0.1) + + def _stop_read_thread(self) -> None: + if self.stop_event is not None: + self.stop_event.set() + + if self.thread is not None and self.thread.is_alive(): + self.thread.join(timeout=2.0) + + self.thread = None + self.stop_event = None + + with self.frame_lock: + self.latest_frame = None + self.latest_timestamp = None + self.new_frame_event.clear() + + def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: + """ + Reads the latest available frame asynchronously. + + Args: + timeout_ms (float): Maximum time in milliseconds to wait for a frame + to become available. Defaults to 200ms. + + Returns: + np.ndarray: The latest captured frame. + + Raises: + DeviceNotConnectedError: If the camera is not connected. + TimeoutError: If no frame data becomes available within the specified timeout. + RuntimeError: If the background thread is not running. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.thread is None or not self.thread.is_alive(): + raise RuntimeError(f"{self} read thread is not running.") if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0): raise TimeoutError(f"{self} async_read timeout after {timeout_ms}ms") @@ -225,11 +333,55 @@ class ZMQCamera(Camera): return frame + def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + """Return the most recent frame captured immediately (Peeking). + + This method is non-blocking and returns whatever is currently in the + memory buffer. The frame may be stale, + meaning it could have been captured a while ago (hanging camera scenario e.g.). + + Returns: + NDArray[Any]: The frame image (numpy array). + + Raises: + TimeoutError: If the latest frame is older than `max_age_ms`. + DeviceNotConnectedError: If the camera is not connected. + RuntimeError: If the camera is connected but has not captured any frames yet. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.thread is None or not self.thread.is_alive(): + raise RuntimeError(f"{self} read thread is not running.") + + with self.frame_lock: + frame = self.latest_frame + timestamp = self.latest_timestamp + + if frame is None or timestamp is None: + raise RuntimeError(f"{self} has not captured any frames yet.") + + age_ms = (time.perf_counter() - timestamp) * 1e3 + if age_ms > max_age_ms: + raise TimeoutError( + f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)." + ) + + return frame + def disconnect(self) -> None: """Disconnect from ZMQ camera.""" - if not self.is_connected and not self.thread: + if not self.is_connected and self.thread is None: raise DeviceNotConnectedError(f"{self} not connected.") - self._stop_read_thread() + if self.thread is not None: + self._stop_read_thread() + self._cleanup() + + with self.frame_lock: + self.latest_frame = None + self.latest_timestamp = None + self.new_frame_event.clear() + logger.info(f"{self} disconnected.") diff --git a/src/lerobot/cameras/zmq/configuration_zmq.py b/src/lerobot/cameras/zmq/configuration_zmq.py index 027ae12b5..4e7732cfc 100644 --- a/src/lerobot/cameras/zmq/configuration_zmq.py +++ b/src/lerobot/cameras/zmq/configuration_zmq.py @@ -29,6 +29,7 @@ class ZMQCameraConfig(CameraConfig): camera_name: str = "zmq_camera" color_mode: ColorMode = ColorMode.RGB timeout_ms: int = 5000 + warmup_s: int = 1 def __post_init__(self) -> None: if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): diff --git a/src/lerobot/scripts/lerobot_calibrate.py b/src/lerobot/scripts/lerobot_calibrate.py index eb3df6872..1b30021dd 100644 --- a/src/lerobot/scripts/lerobot_calibrate.py +++ b/src/lerobot/scripts/lerobot_calibrate.py @@ -86,8 +86,11 @@ def calibrate(cfg: CalibrateConfig): device = make_teleoperator_from_config(cfg.device) device.connect(calibrate=False) - device.calibrate() - device.disconnect() + + try: + device.calibrate() + finally: + device.disconnect() def main(): diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index 5717dffb6..c9a559d07 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ b/src/lerobot/scripts/lerobot_replay.py @@ -110,25 +110,26 @@ def replay(cfg: ReplayConfig): robot.connect() - log_say("Replaying episode", cfg.play_sounds, blocking=True) - for idx in range(len(episode_frames)): - start_episode_t = time.perf_counter() + try: + log_say("Replaying episode", cfg.play_sounds, blocking=True) + for idx in range(len(episode_frames)): + start_episode_t = time.perf_counter() - action_array = actions[idx][ACTION] - action = {} - for i, name in enumerate(dataset.features[ACTION]["names"]): - action[name] = action_array[i] + action_array = actions[idx][ACTION] + action = {} + for i, name in enumerate(dataset.features[ACTION]["names"]): + action[name] = action_array[i] - robot_obs = robot.get_observation() + robot_obs = robot.get_observation() - processed_action = robot_action_processor((action, robot_obs)) + processed_action = robot_action_processor((action, robot_obs)) - _ = robot.send_action(processed_action) + _ = robot.send_action(processed_action) - dt_s = time.perf_counter() - start_episode_t - precise_sleep(max(1 / dataset.fps - dt_s, 0.0)) - - robot.disconnect() + dt_s = time.perf_counter() - start_episode_t + precise_sleep(max(1 / dataset.fps - dt_s, 0.0)) + finally: + robot.disconnect() def main(): diff --git a/tests/cameras/test_opencv.py b/tests/cameras/test_opencv.py index 3cf3793b6..feb700631 100644 --- a/tests/cameras/test_opencv.py +++ b/tests/cameras/test_opencv.py @@ -20,7 +20,9 @@ # ``` from pathlib import Path +from unittest.mock import patch +import cv2 import numpy as np import pytest @@ -28,6 +30,50 @@ from lerobot.cameras.configs import Cv2Rotation from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +RealVideoCapture = cv2.VideoCapture + + +class MockLoopingVideoCapture: + """ + Wraps the real OpenCV VideoCapture. + Motivation: cv2.VideoCapture(file.png) is only valid for one read. + Strategy: Read the file once & return the cached frame for subsequent reads. + Consequence: No recurrent I/O operations, but we keep the test artifacts simple. + """ + + def __init__(self, *args, **kwargs): + args_clean = [str(a) if isinstance(a, Path) else a for a in args] + self._real_vc = RealVideoCapture(*args_clean, **kwargs) + self._cached_frame = None + + def read(self): + ret, frame = self._real_vc.read() + + if ret: + self._cached_frame = frame + return ret, frame + + if not ret and self._cached_frame is not None: + return True, self._cached_frame.copy() + + return ret, frame + + def __getattr__(self, name): + return getattr(self._real_vc, name) + + +@pytest.fixture(autouse=True) +def patch_opencv_videocapture(): + """ + Automatically patches cv2.VideoCapture for all tests. + """ + module_path = OpenCVCamera.__module__ + target = f"{module_path}.cv2.VideoCapture" + + with patch(target, new=MockLoopingVideoCapture): + yield + + # NOTE(Steven): more tests + assertions? TEST_ARTIFACTS_DIR = Path(__file__).parent.parent / "artifacts" / "cameras" DEFAULT_PNG_FILE_PATH = TEST_ARTIFACTS_DIR / "image_160x120.png" @@ -43,25 +89,22 @@ def test_abc_implementation(): def test_connect(): - config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) - camera = OpenCVCamera(config) + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0) - camera.connect(warmup=False) - - assert camera.is_connected + with OpenCVCamera(config) as camera: + assert camera.is_connected def test_connect_already_connected(): - config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) - camera = OpenCVCamera(config) - camera.connect(warmup=False) + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0) - with pytest.raises(DeviceAlreadyConnectedError): - camera.connect(warmup=False) + with OpenCVCamera(config) as camera, pytest.raises(DeviceAlreadyConnectedError): + camera.connect() def test_connect_invalid_camera_path(): config = OpenCVCameraConfig(index_or_path="nonexistent/camera.png") + camera = OpenCVCamera(config) with pytest.raises(ConnectionError): @@ -74,27 +117,25 @@ def test_invalid_width_connect(): width=99999, # Invalid width to trigger error height=480, ) - camera = OpenCVCamera(config) + camera = OpenCVCamera(config) with pytest.raises(RuntimeError): camera.connect(warmup=False) @pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES) def test_read(index_or_path): - config = OpenCVCameraConfig(index_or_path=index_or_path) - camera = OpenCVCamera(config) - camera.connect(warmup=False) + config = OpenCVCameraConfig(index_or_path=index_or_path, warmup_s=0) - img = camera.read() - - assert isinstance(img, np.ndarray) + with OpenCVCamera(config) as camera: + img = camera.read() + assert isinstance(img, np.ndarray) def test_read_before_connect(): config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) - camera = OpenCVCamera(config) + camera = OpenCVCamera(config) with pytest.raises(DeviceNotConnectedError): _ = camera.read() @@ -119,32 +160,22 @@ def test_disconnect_before_connect(): @pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES) def test_async_read(index_or_path): - config = OpenCVCameraConfig(index_or_path=index_or_path) - camera = OpenCVCamera(config) - camera.connect(warmup=False) + config = OpenCVCameraConfig(index_or_path=index_or_path, warmup_s=0) - try: + with OpenCVCamera(config) as camera: img = camera.async_read() assert camera.thread is not None assert camera.thread.is_alive() assert isinstance(img, np.ndarray) - finally: - if camera.is_connected: - camera.disconnect() # To stop/join the thread. Otherwise get warnings when the test ends def test_async_read_timeout(): - config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) - camera = OpenCVCamera(config) - camera.connect(warmup=False) + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0) - try: - with pytest.raises(TimeoutError): - camera.async_read(timeout_ms=0) - finally: - if camera.is_connected: - camera.disconnect() + with OpenCVCamera(config) as camera, pytest.raises(TimeoutError): + camera.async_read(timeout_ms=0) # consumes any available frame by then + camera.async_read(timeout_ms=0) # request immediately another one def test_async_read_before_connect(): @@ -155,6 +186,50 @@ def test_async_read_before_connect(): _ = camera.async_read() +def test_read_latest(): + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0) + + with OpenCVCamera(config) as camera: + # ensure at least one fresh frame is captured + frame = camera.read() + latest = camera.read_latest() + + assert isinstance(latest, np.ndarray) + assert latest.shape == frame.shape + + +def test_read_latest_before_connect(): + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) + + camera = OpenCVCamera(config) + with pytest.raises(DeviceNotConnectedError): + _ = camera.read_latest() + + +def test_read_latest_high_frequency(): + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0) + + with OpenCVCamera(config) as camera: + # prime to ensure frames are available + ref = camera.read() + + for _ in range(20): + latest = camera.read_latest() + assert isinstance(latest, np.ndarray) + assert latest.shape == ref.shape + + +def test_read_latest_too_old(): + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0) + + with OpenCVCamera(config) as camera: + # prime to ensure frames are available + _ = camera.read() + + with pytest.raises(TimeoutError): + _ = camera.read_latest(max_age_ms=0) # immediately too old + + def test_fourcc_configuration(): """Test FourCC configuration validation and application.""" @@ -181,18 +256,15 @@ def test_fourcc_configuration(): def test_fourcc_with_camera(): """Test FourCC functionality with actual camera connection.""" - config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, fourcc="MJPG") - camera = OpenCVCamera(config) + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, fourcc="MJPG", warmup_s=0) # Connect should work with MJPG specified - camera.connect(warmup=False) - assert camera.is_connected + with OpenCVCamera(config) as camera: + assert camera.is_connected - # Read should work normally - img = camera.read() - assert isinstance(img, np.ndarray) - - camera.disconnect() + # Read should work normally + img = camera.read() + assert isinstance(img, np.ndarray) @pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES) @@ -211,18 +283,16 @@ def test_rotation(rotation, index_or_path): dimensions = filename.split("_")[-1].split(".")[0] # Assumes filenames format (_wxh.png) original_width, original_height = map(int, dimensions.split("x")) - config = OpenCVCameraConfig(index_or_path=index_or_path, rotation=rotation) - camera = OpenCVCamera(config) - camera.connect(warmup=False) + config = OpenCVCameraConfig(index_or_path=index_or_path, rotation=rotation, warmup_s=0) + with OpenCVCamera(config) as camera: + img = camera.read() + assert isinstance(img, np.ndarray) - img = camera.read() - assert isinstance(img, np.ndarray) - - if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270): - assert camera.width == original_height - assert camera.height == original_width - assert img.shape[:2] == (original_width, original_height) - else: - assert camera.width == original_width - assert camera.height == original_height - assert img.shape[:2] == (original_height, original_width) + if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270): + assert camera.width == original_height + assert camera.height == original_width + assert img.shape[:2] == (original_width, original_height) + else: + assert camera.width == original_width + assert camera.height == original_height + assert img.shape[:2] == (original_height, original_width) diff --git a/tests/cameras/test_reachy2_camera.py b/tests/cameras/test_reachy2_camera.py index 14774bf38..2aebfdf0a 100644 --- a/tests/cameras/test_reachy2_camera.py +++ b/tests/cameras/test_reachy2_camera.py @@ -150,6 +150,44 @@ def test_async_read_before_connect(camera): _ = camera.async_read() +def test_read_latest(camera): + camera.connect() + + frame = camera.read() + latest = camera.read_latest() + + assert isinstance(latest, np.ndarray) + assert latest.shape == frame.shape + + +def test_read_latest_before_connect(camera): + # camera fixture yields an unconnected camera instance + with pytest.raises(DeviceNotConnectedError): + _ = camera.read_latest() + + +def test_read_latest_high_frequency(camera): + camera.connect() + + # prime to ensure frames are available + ref = camera.read() + + for _ in range(20): + latest = camera.read_latest() + assert isinstance(latest, np.ndarray) + assert latest.shape == ref.shape + + +def test_read_latest_too_old(camera): + camera.connect() + + # prime to ensure frames are available + _ = camera.read() + + with pytest.raises(TimeoutError): + _ = camera.read_latest(max_age_ms=0) # immediately too old + + def test_wrong_camera_name(): with pytest.raises(ValueError): _ = Reachy2CameraConfig(name="wrong-name", image_type="left") diff --git a/tests/cameras/test_realsense.py b/tests/cameras/test_realsense.py index fb9912257..1deb73f05 100644 --- a/tests/cameras/test_realsense.py +++ b/tests/cameras/test_realsense.py @@ -62,19 +62,15 @@ def test_abc_implementation(): def test_connect(): - config = RealSenseCameraConfig(serial_number_or_name="042") - camera = RealSenseCamera(config) + config = RealSenseCameraConfig(serial_number_or_name="042", warmup_s=0) - camera.connect(warmup=False) - assert camera.is_connected + with RealSenseCamera(config) as camera: + assert camera.is_connected def test_connect_already_connected(): - config = RealSenseCameraConfig(serial_number_or_name="042") - camera = RealSenseCamera(config) - camera.connect(warmup=False) - - with pytest.raises(DeviceAlreadyConnectedError): + config = RealSenseCameraConfig(serial_number_or_name="042", warmup_s=0) + with RealSenseCamera(config) as camera, pytest.raises(DeviceAlreadyConnectedError): camera.connect(warmup=False) @@ -96,12 +92,10 @@ def test_invalid_width_connect(): def test_read(): - config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30) - camera = RealSenseCamera(config) - camera.connect(warmup=False) - - img = camera.read() - assert isinstance(img, np.ndarray) + config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, warmup_s=0) + with RealSenseCamera(config) as camera: + img = camera.read() + assert isinstance(img, np.ndarray) # TODO(Steven): Fix this test for the latest version of pyrealsense2. @@ -142,32 +136,21 @@ def test_disconnect_before_connect(): def test_async_read(): - config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30) - camera = RealSenseCamera(config) - camera.connect(warmup=False) + config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, warmup_s=0) - try: + with RealSenseCamera(config) as camera: img = camera.async_read() assert camera.thread is not None assert camera.thread.is_alive() assert isinstance(img, np.ndarray) - finally: - if camera.is_connected: - camera.disconnect() # To stop/join the thread. Otherwise get warnings when the test ends def test_async_read_timeout(): - config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30) - camera = RealSenseCamera(config) - camera.connect(warmup=False) - - try: - with pytest.raises(TimeoutError): - camera.async_read(timeout_ms=0) - finally: - if camera.is_connected: - camera.disconnect() + config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, warmup_s=0) + with RealSenseCamera(config) as camera, pytest.raises(TimeoutError): + camera.async_read(timeout_ms=0) # consumes any available frame by then + camera.async_read(timeout_ms=0) # request immediately another one def test_async_read_before_connect(): @@ -178,6 +161,47 @@ def test_async_read_before_connect(): _ = camera.async_read() +def test_read_latest(): + config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, warmup_s=0) + with RealSenseCamera(config) as camera: + img = camera.read() + latest = camera.read_latest() + + assert isinstance(latest, np.ndarray) + assert latest.shape == img.shape + + +def test_read_latest_high_frequency(): + config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, warmup_s=0) + with RealSenseCamera(config) as camera: + # prime with one read to ensure frames are available + ref = camera.read() + + for _ in range(20): + latest = camera.read_latest() + assert isinstance(latest, np.ndarray) + assert latest.shape == ref.shape + + +def test_read_latest_before_connect(): + config = RealSenseCameraConfig(serial_number_or_name="042") + camera = RealSenseCamera(config) + + with pytest.raises(DeviceNotConnectedError): + _ = camera.read_latest() + + +def test_read_latest_too_old(): + config = RealSenseCameraConfig(serial_number_or_name="042") + + with RealSenseCamera(config) as camera: + # prime to ensure frames are available + _ = camera.read() + + with pytest.raises(TimeoutError): + _ = camera.read_latest(max_age_ms=0) # immediately too old + + @pytest.mark.parametrize( "rotation", [ @@ -189,18 +213,16 @@ def test_async_read_before_connect(): ids=["no_rot", "rot90", "rot180", "rot270"], ) def test_rotation(rotation): - config = RealSenseCameraConfig(serial_number_or_name="042", rotation=rotation) - camera = RealSenseCamera(config) - camera.connect(warmup=False) + config = RealSenseCameraConfig(serial_number_or_name="042", rotation=rotation, warmup_s=0) + with RealSenseCamera(config) as camera: + img = camera.read() + assert isinstance(img, np.ndarray) - img = camera.read() - assert isinstance(img, np.ndarray) - - if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270): - assert camera.width == 480 - assert camera.height == 640 - assert img.shape[:2] == (640, 480) - else: - assert camera.width == 640 - assert camera.height == 480 - assert img.shape[:2] == (480, 640) + if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270): + assert camera.width == 480 + assert camera.height == 640 + assert img.shape[:2] == (640, 480) + else: + assert camera.width == 640 + assert camera.height == 480 + assert img.shape[:2] == (480, 640) From 04cbf669cf0565950f8ba66e8a03a66bd8f20d7a Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Fri, 30 Jan 2026 12:23:22 +0100 Subject: [PATCH 017/131] fix(sac): make temperature a property to fix checkpoint resume bug (#2877) * fix(sac): make temperature a property to fix checkpoint resume bug Temperature was stored as a plain float and not restored after loading a checkpoint, causing incorrect loss computations until update_temperature() was called. Changed to a property that always computes from log_alpha, ensuring correct behavior after checkpoint loading. * simplify docstrings --- src/lerobot/policies/sac/modeling_sac.py | 11 ++++++----- src/lerobot/rl/learner.py | 3 --- tests/policies/test_sac_policy.py | 3 ++- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/lerobot/policies/sac/modeling_sac.py b/src/lerobot/policies/sac/modeling_sac.py index c7c6798ed..d5dd71a48 100644 --- a/src/lerobot/policies/sac/modeling_sac.py +++ b/src/lerobot/policies/sac/modeling_sac.py @@ -239,8 +239,10 @@ class SACPolicy( + target_param.data * (1.0 - self.config.critic_target_update_weight) ) - def update_temperature(self): - self.temperature = self.log_alpha.exp().item() + @property + def temperature(self) -> float: + """Return the current temperature value, always in sync with log_alpha.""" + return self.log_alpha.exp().item() def compute_loss_critic( self, @@ -457,11 +459,10 @@ class SACPolicy( dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0) self.target_entropy = -np.prod(dim) / 2 - def _init_temperature(self): - """Set up temperature parameter and initial log_alpha.""" + def _init_temperature(self) -> None: + """Set up temperature parameter (log_alpha).""" temp_init = self.config.temperature_init self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)])) - self.temperature = self.log_alpha.exp().item() class SACObservationEncoder(nn.Module): diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index abc5c9504..ee09ac9ac 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -545,9 +545,6 @@ def add_actor_information_and_train( training_infos["temperature_grad_norm"] = temp_grad_norm training_infos["temperature"] = policy.temperature - # Update temperature - policy.update_temperature() - # Push policy to actors if needed if time.time() - last_time_policy_pushed > policy_parameters_push_frequency: push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy) diff --git a/tests/policies/test_sac_policy.py b/tests/policies/test_sac_policy.py index 8576883bd..6fad2979e 100644 --- a/tests/policies/test_sac_policy.py +++ b/tests/policies/test_sac_policy.py @@ -441,12 +441,13 @@ def test_sac_policy_with_predefined_entropy(): def test_sac_policy_update_temperature(): + """Test that temperature property is always in sync with log_alpha.""" config = create_default_config(continuous_action_dim=10, state_dim=10) policy = SACPolicy(config=config) assert policy.temperature == pytest.approx(1.0) policy.log_alpha.data = torch.tensor([math.log(0.1)]) - policy.update_temperature() + # Temperature property automatically reflects log_alpha changes assert policy.temperature == pytest.approx(0.1) From ec04b7ce3aca23491e42232c1ae723bb4b981993 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Fri, 30 Jan 2026 13:19:42 +0100 Subject: [PATCH 018/131] Feat(dataset_tools.py) Add modify tasks tool (#2875) * feat(datasets): add modify_tasks function for in-place task editing Add a new utility function to modify tasks in LeRobotDataset in-place. This allows users to: - Set a single task for all episodes - Set specific tasks for individual episodes - Combine a default task with per-episode overrides * feat(edit-dataset): add CLI support for modify_tasks operation Integrate the modify_tasks function into lerobot_edit_dataset CLI. Users can now modify dataset tasks via command line: Supports setting a default task, per-episode tasks, or both combined. * test(datasets): add tests for modify_tasks function Add comprehensive test coverage for the modify_tasks utility: - Single task for all episodes - Episode-specific task assignment - Default task with per-episode overrides - Error handling for missing/invalid arguments - Verification of task_index correctness - In-place modification behavior - Metadata preservation * respond to copilot review --- src/lerobot/datasets/dataset_tools.py | 126 +++++++++++++++ src/lerobot/scripts/lerobot_edit_dataset.py | 82 +++++++++- tests/datasets/test_dataset_tools.py | 169 ++++++++++++++++++++ 3 files changed, 374 insertions(+), 3 deletions(-) diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index e2928e2a6..123d455c6 100644 --- a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -1396,6 +1396,132 @@ BYTES_PER_KIB = 1024 BYTES_PER_MIB = BYTES_PER_KIB * BYTES_PER_KIB +def modify_tasks( + dataset: LeRobotDataset, + new_task: str | None = None, + episode_tasks: dict[int, str] | None = None, +) -> LeRobotDataset: + """Modify tasks in a LeRobotDataset. + + This function allows you to either: + 1. Set a single task for the entire dataset (using `new_task`) + 2. Set specific tasks for specific episodes (using `episode_tasks`) + + You can combine both: `new_task` sets the default, and `episode_tasks` overrides + specific episodes. + + The dataset is modified in-place, updating only the task-related files: + - meta/tasks.parquet + - data/**/*.parquet (task_index column) + - meta/episodes/**/*.parquet (tasks column) + - meta/info.json (total_tasks) + + Args: + dataset: The source LeRobotDataset to modify. + new_task: A single task string to apply to all episodes. If None and episode_tasks + is also None, raises an error. + episode_tasks: Optional dict mapping episode indices to their task strings. + Overrides `new_task` for specific episodes. + + + Examples: + Set a single task for all episodes: + dataset = modify_tasks(dataset, new_task="Pick up the cube") + + Set different tasks for specific episodes: + dataset = modify_tasks( + dataset, + episode_tasks={0: "Task A", 1: "Task B", 2: "Task A"} + ) + + Set a default task with overrides: + dataset = modify_tasks( + dataset, + new_task="Default task", + episode_tasks={5: "Special task for episode 5"} + ) + """ + if new_task is None and episode_tasks is None: + raise ValueError("Must specify at least one of new_task or episode_tasks") + + if episode_tasks is not None: + valid_indices = set(range(dataset.meta.total_episodes)) + invalid = set(episode_tasks.keys()) - valid_indices + if invalid: + raise ValueError(f"Invalid episode indices: {invalid}") + + # Ensure episodes metadata is loaded + if dataset.meta.episodes is None: + dataset.meta.episodes = load_episodes(dataset.root) + + # Build the mapping from episode index to task string + episode_to_task: dict[int, str] = {} + for ep_idx in range(dataset.meta.total_episodes): + if episode_tasks and ep_idx in episode_tasks: + episode_to_task[ep_idx] = episode_tasks[ep_idx] + elif new_task is not None: + episode_to_task[ep_idx] = new_task + else: + # Keep original task if not overridden and no default provided + original_tasks = dataset.meta.episodes[ep_idx]["tasks"] + if not original_tasks: + raise ValueError(f"Episode {ep_idx} has no tasks and no default task was provided") + episode_to_task[ep_idx] = original_tasks[0] + + # Collect all unique tasks and create new task mapping + unique_tasks = sorted(set(episode_to_task.values())) + new_task_df = pd.DataFrame({"task_index": list(range(len(unique_tasks)))}, index=unique_tasks) + task_to_index = {task: idx for idx, task in enumerate(unique_tasks)} + + logging.info(f"Modifying tasks in {dataset.repo_id}") + logging.info(f"New tasks: {unique_tasks}") + + root = dataset.root + + # Update data files - modify task_index column + logging.info("Updating data files...") + data_dir = root / DATA_DIR + + for parquet_path in tqdm(sorted(data_dir.rglob("*.parquet")), desc="Updating data"): + df = pd.read_parquet(parquet_path) + + # Build a mapping from episode_index to new task_index for rows in this file + episode_indices_in_file = df["episode_index"].unique() + ep_to_new_task_idx = { + ep_idx: task_to_index[episode_to_task[ep_idx]] for ep_idx in episode_indices_in_file + } + + # Update task_index column + df["task_index"] = df["episode_index"].map(ep_to_new_task_idx) + df.to_parquet(parquet_path, index=False) + + # Update episodes metadata - modify tasks column + logging.info("Updating episodes metadata...") + episodes_dir = root / "meta" / "episodes" + + for parquet_path in tqdm(sorted(episodes_dir.rglob("*.parquet")), desc="Updating episodes"): + df = pd.read_parquet(parquet_path) + + # Update tasks column + df["tasks"] = df["episode_index"].apply(lambda ep_idx: [episode_to_task[ep_idx]]) + df.to_parquet(parquet_path, index=False) + + # Write new tasks.parquet + write_tasks(new_task_df, root) + + # Update info.json + dataset.meta.info["total_tasks"] = len(unique_tasks) + write_info(dataset.meta.info, root) + + # Reload metadata to reflect changes + dataset.meta.tasks = new_task_df + dataset.meta.episodes = load_episodes(root) + + logging.info(f"Tasks: {unique_tasks}") + + return dataset + + def convert_image_to_video_dataset( dataset: LeRobotDataset, output_dir: Path, diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py index 4ba6ce44f..2ca9c520d 100644 --- a/src/lerobot/scripts/lerobot_edit_dataset.py +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -18,7 +18,7 @@ Edit LeRobot datasets using various transformation tools. This script allows you to delete episodes, split datasets, merge datasets, -remove features, and convert image datasets to video format. +remove features, modify tasks, and convert image datasets to video format. When new_repo_id is specified, creates a new dataset. Usage Examples: @@ -66,6 +66,25 @@ Remove camera feature: --operation.type remove_feature \ --operation.feature_names "['observation.images.top']" +Modify tasks - set a single task for all episodes (WARNING: modifies in-place): + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type modify_tasks \ + --operation.new_task "Pick up the cube and place it" + +Modify tasks - set different tasks for specific episodes (WARNING: modifies in-place): + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type modify_tasks \ + --operation.episode_tasks '{"0": "Task A", "1": "Task B", "2": "Task A"}' + +Modify tasks - set default task with overrides for specific episodes (WARNING: modifies in-place): + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type modify_tasks \ + --operation.new_task "Default task" \ + --operation.episode_tasks '{"5": "Special task for episode 5"}' + Convert image dataset to video format and save locally: python -m lerobot.scripts.lerobot_edit_dataset \ --repo_id lerobot/pusht_image \ @@ -100,6 +119,7 @@ from lerobot.datasets.dataset_tools import ( convert_image_to_video_dataset, delete_episodes, merge_datasets, + modify_tasks, remove_feature, split_dataset, ) @@ -132,6 +152,13 @@ class RemoveFeatureConfig: feature_names: list[str] | None = None +@dataclass +class ModifyTasksConfig: + type: str = "modify_tasks" + new_task: str | None = None + episode_tasks: dict[str, str] | None = None + + @dataclass class ConvertImageToVideoConfig: type: str = "convert_image_to_video" @@ -151,7 +178,12 @@ class ConvertImageToVideoConfig: class EditDatasetConfig: repo_id: str operation: ( - DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertImageToVideoConfig + DeleteEpisodesConfig + | SplitConfig + | MergeConfig + | RemoveFeatureConfig + | ModifyTasksConfig + | ConvertImageToVideoConfig ) root: str | None = None new_repo_id: str | None = None @@ -296,6 +328,48 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None: LeRobotDataset(output_repo_id, root=output_dir).push_to_hub() +def handle_modify_tasks(cfg: EditDatasetConfig) -> None: + if not isinstance(cfg.operation, ModifyTasksConfig): + raise ValueError("Operation config must be ModifyTasksConfig") + + new_task = cfg.operation.new_task + episode_tasks_raw = cfg.operation.episode_tasks + + if new_task is None and episode_tasks_raw is None: + raise ValueError("Must specify at least one of new_task or episode_tasks for modify_tasks operation") + + # Warn about in-place modification behavior + if cfg.new_repo_id is not None: + logging.warning("modify_tasks modifies datasets in-place. The --new_repo_id parameter is ignored.") + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) + logging.warning(f"Modifying dataset in-place at {dataset.root}. Original data will be overwritten.") + + # Convert episode_tasks keys from string to int if needed (CLI passes strings) + episode_tasks: dict[int, str] | None = None + if episode_tasks_raw is not None: + episode_tasks = {int(k): v for k, v in episode_tasks_raw.items()} + + logging.info(f"Modifying tasks in {cfg.repo_id}") + if new_task: + logging.info(f" Default task: '{new_task}'") + if episode_tasks: + logging.info(f" Episode-specific tasks: {episode_tasks}") + + modified_dataset = modify_tasks( + dataset, + new_task=new_task, + episode_tasks=episode_tasks, + ) + + logging.info(f"Dataset modified at {dataset.root}") + logging.info(f"Tasks: {list(modified_dataset.meta.tasks.index)}") + + if cfg.push_to_hub: + logging.info(f"Pushing to hub as {cfg.repo_id}") + modified_dataset.push_to_hub() + + def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None: # Note: Parser may create any config type with the right fields, so we access fields directly # instead of checking isinstance() @@ -371,12 +445,14 @@ def edit_dataset(cfg: EditDatasetConfig) -> None: handle_merge(cfg) elif operation_type == "remove_feature": handle_remove_feature(cfg) + elif operation_type == "modify_tasks": + handle_modify_tasks(cfg) elif operation_type == "convert_image_to_video": handle_convert_image_to_video(cfg) else: raise ValueError( f"Unknown operation type: {operation_type}\n" - f"Available operations: delete_episodes, split, merge, remove_feature, convert_to_video" + f"Available operations: delete_episodes, split, merge, remove_feature, modify_tasks, convert_image_to_video" ) diff --git a/tests/datasets/test_dataset_tools.py b/tests/datasets/test_dataset_tools.py index 35a369de9..1de199630 100644 --- a/tests/datasets/test_dataset_tools.py +++ b/tests/datasets/test_dataset_tools.py @@ -26,6 +26,7 @@ from lerobot.datasets.dataset_tools import ( delete_episodes, merge_datasets, modify_features, + modify_tasks, remove_feature, split_dataset, ) @@ -1050,6 +1051,174 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path): assert "reward" in modified_dataset.meta.features +def test_modify_tasks_single_task_for_all(sample_dataset): + """Test setting a single task for all episodes.""" + new_task = "Pick up the cube and place it" + + modified_dataset = modify_tasks(sample_dataset, new_task=new_task) + + # Verify all episodes have the new task + assert len(modified_dataset.meta.tasks) == 1 + assert new_task in modified_dataset.meta.tasks.index + + # Verify task_index is 0 for all frames (only one task) + for i in range(len(modified_dataset)): + item = modified_dataset[i] + assert item["task_index"].item() == 0 + assert item["task"] == new_task + + +def test_modify_tasks_episode_specific(sample_dataset): + """Test setting different tasks for specific episodes.""" + episode_tasks = { + 0: "Task A", + 1: "Task B", + 2: "Task A", + 3: "Task C", + 4: "Task B", + } + + modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks) + + # Verify correct number of unique tasks + unique_tasks = set(episode_tasks.values()) + assert len(modified_dataset.meta.tasks) == len(unique_tasks) + + # Verify each episode has the correct task + for ep_idx, expected_task in episode_tasks.items(): + ep_data = modified_dataset.meta.episodes[ep_idx] + assert ep_data["tasks"][0] == expected_task + + +def test_modify_tasks_default_with_overrides(sample_dataset): + """Test setting a default task with specific overrides.""" + default_task = "Default task" + override_task = "Special task" + episode_tasks = {2: override_task, 4: override_task} + + modified_dataset = modify_tasks( + sample_dataset, + new_task=default_task, + episode_tasks=episode_tasks, + ) + + # Verify correct number of unique tasks + assert len(modified_dataset.meta.tasks) == 2 + assert default_task in modified_dataset.meta.tasks.index + assert override_task in modified_dataset.meta.tasks.index + + # Verify episodes have correct tasks + for ep_idx in range(5): + ep_data = modified_dataset.meta.episodes[ep_idx] + if ep_idx in episode_tasks: + assert ep_data["tasks"][0] == override_task + else: + assert ep_data["tasks"][0] == default_task + + +def test_modify_tasks_no_task_specified(sample_dataset): + """Test error when no task is specified.""" + with pytest.raises(ValueError, match="Must specify at least one of new_task or episode_tasks"): + modify_tasks(sample_dataset) + + +def test_modify_tasks_invalid_episode_indices(sample_dataset): + """Test error with invalid episode indices.""" + with pytest.raises(ValueError, match="Invalid episode indices"): + modify_tasks(sample_dataset, episode_tasks={10: "Task", 20: "Task"}) + + +def test_modify_tasks_updates_info_json(sample_dataset): + """Test that total_tasks is updated in info.json.""" + episode_tasks = {0: "Task A", 1: "Task B", 2: "Task C", 3: "Task A", 4: "Task B"} + + modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks) + + # Verify total_tasks is updated + assert modified_dataset.meta.total_tasks == 3 + + +def test_modify_tasks_preserves_other_metadata(sample_dataset): + """Test that modifying tasks preserves other metadata.""" + original_frames = sample_dataset.meta.total_frames + original_episodes = sample_dataset.meta.total_episodes + original_fps = sample_dataset.meta.fps + + modified_dataset = modify_tasks(sample_dataset, new_task="New task") + + # Verify other metadata is preserved + assert modified_dataset.meta.total_frames == original_frames + assert modified_dataset.meta.total_episodes == original_episodes + assert modified_dataset.meta.fps == original_fps + + +def test_modify_tasks_task_index_correct(sample_dataset): + """Test that task_index values are correct in data files.""" + # Create tasks that will have predictable indices (sorted alphabetically) + episode_tasks = { + 0: "Alpha task", # Will be index 0 + 1: "Beta task", # Will be index 1 + 2: "Alpha task", # Will be index 0 + 3: "Gamma task", # Will be index 2 + 4: "Beta task", # Will be index 1 + } + + modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks) + + # Verify task indices are correct + task_to_expected_idx = { + "Alpha task": 0, + "Beta task": 1, + "Gamma task": 2, + } + + for i in range(len(modified_dataset)): + item = modified_dataset[i] + ep_idx = item["episode_index"].item() + expected_task = episode_tasks[ep_idx] + expected_idx = task_to_expected_idx[expected_task] + assert item["task_index"].item() == expected_idx + assert item["task"] == expected_task + + +def test_modify_tasks_in_place(sample_dataset): + """Test that modify_tasks modifies the dataset in-place.""" + original_root = sample_dataset.root + + modified_dataset = modify_tasks(sample_dataset, new_task="New task") + + # Verify same instance is returned and root is unchanged + assert modified_dataset is sample_dataset + assert modified_dataset.root == original_root + + +def test_modify_tasks_keeps_original_when_not_overridden(sample_dataset): + """Test that original tasks are kept when using episode_tasks without new_task.""" + from lerobot.datasets.utils import load_episodes + + # Ensure episodes metadata is loaded + if sample_dataset.meta.episodes is None: + sample_dataset.meta.episodes = load_episodes(sample_dataset.meta.root) + + # Get original tasks for episodes not being overridden + original_task_ep0 = sample_dataset.meta.episodes[0]["tasks"][0] + original_task_ep1 = sample_dataset.meta.episodes[1]["tasks"][0] + + # Only override episodes 2, 3, 4 + episode_tasks = {2: "New Task A", 3: "New Task B", 4: "New Task A"} + + modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks) + + # Verify original tasks are kept for episodes 0 and 1 + assert modified_dataset.meta.episodes[0]["tasks"][0] == original_task_ep0 + assert modified_dataset.meta.episodes[1]["tasks"][0] == original_task_ep1 + + # Verify new tasks for overridden episodes + assert modified_dataset.meta.episodes[2]["tasks"][0] == "New Task A" + assert modified_dataset.meta.episodes[3]["tasks"][0] == "New Task B" + assert modified_dataset.meta.episodes[4]["tasks"][0] == "New Task A" + + def test_convert_image_to_video_dataset(tmp_path): """Test converting lerobot/pusht_image dataset to video format.""" from lerobot.datasets.lerobot_dataset import LeRobotDataset From 55c0471db9e440e99e801e2e67d645ecd7fdb9d5 Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Fri, 30 Jan 2026 16:57:56 +0100 Subject: [PATCH 019/131] docs(cameras): revising and improving docs on cameras (#2878) * docs(cameras): revising and improving docs on cameras * resolving copilot comments --- docs/source/_toctree.yml | 6 +- docs/source/cameras.mdx | 176 +++++++++++++++++++++------------------ 2 files changed, 99 insertions(+), 83 deletions(-) diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index eb97117af..98417f134 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -7,8 +7,6 @@ - sections: - local: il_robots title: Imitation Learning for Robots - - local: cameras - title: Cameras - local: bring_your_own_policies title: Bring Your Own Policies - local: integrate_hardware @@ -108,6 +106,10 @@ - local: phone_teleop title: Phone title: "Teleoperators" +- sections: + - local: cameras + title: Cameras + title: "Sensors" - sections: - local: torch_accelerators title: PyTorch accelerators diff --git a/docs/source/cameras.mdx b/docs/source/cameras.mdx index 5c35be0ba..8af0f5ae5 100644 --- a/docs/source/cameras.mdx +++ b/docs/source/cameras.mdx @@ -1,12 +1,22 @@ # Cameras -LeRobot offers multiple options for video capture, including phone cameras, built-in laptop cameras, external webcams, and Intel RealSense cameras. To efficiently record frames from most cameras, you can use either the `OpenCVCamera` or `RealSenseCamera` class. For additional compatibility details on the `OpenCVCamera` class, refer to the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html). +LeRobot offers multiple options for video capture: -### Finding your camera +| Class | Supported Cameras | +| ----------------- | ----------------------------------- | +| `OpenCVCamera` | Phone, built-in laptop, USB webcams | +| `ZMQCamera` | Network-connected cameras | +| `RealSenseCamera` | Intel RealSense (with depth) | +| `Reachy2Camera` | Reachy 2 robot cameras | -To instantiate a camera, you need a camera identifier. This identifier might change if you reboot your computer or re-plug your camera, a behavior mostly dependant on your operating system. +> [!TIP] +> For `OpenCVCamera` compatibility details, see the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html). -To find the camera indices of the cameras plugged into your system, run the following script: +### Find your camera + +Every camera requires a unique identifier to be instantiated, allowing you to distinguish between multiple connected devices. + +`OpenCVCamera` and `RealSenseCamera` support auto-discovery. Run the command below to list available devices and their identifiers. Note that these identifiers may change after rebooting your computer or re-plugging the camera, depending on your operating system. ```bash lerobot-find-cameras opencv # or realsense for Intel Realsense cameras @@ -14,7 +24,7 @@ lerobot-find-cameras opencv # or realsense for Intel Realsense cameras The output will look something like this if you have two cameras connected: -``` +```bash --- Detected Cameras --- Camera #0: Name: OpenCV Camera @ 0 @@ -33,13 +43,37 @@ Camera #0: > [!WARNING] > When using Intel RealSense cameras in `macOS`, you could get this [error](https://github.com/IntelRealSense/librealsense/issues/12307): `Error finding RealSense cameras: failed to set power state`, this can be solved by running the same command with `sudo` permissions. Note that using RealSense cameras in `macOS` is unstable. -## Use Cameras +`ZMQCamera` and `Reachy2Camera` do not support auto-discovery. They must be configured manually by providing their network address and port or robot SDK settings. -Below are two examples, demonstrating how to work with the API. +## Use cameras -- **Asynchronous frame capture** using an OpenCV-based camera +### Frame access modes + +All camera classes implement three access modes for capturing frames: + +| Method | Behavior | Blocks? | Best For | +| ------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------- | ---------------------------------------- | +| `read()` | Waits for the camera hardware to return a frame. May block for a long time depending on the camera and SDK. | Yes | Simple scripts, sequential capture | +| `async_read(timeout_ms)` | Returns the latest unconsumed frame from background thread. Blocks only if buffer is empty, up to `timeout_ms`. Raises `TimeoutError` if no frame arrives. | With a timeout | Control loops synchronized to camera FPS | +| `read_latest(max_age_ms)` | Peeks at the most recent frame in buffer (may be stale). Raises `TimeoutError` if frame is older than `max_age_ms`. | No | UI visualization, logging, monitoring | + +### Usage examples + +The following examples show how to use the camera API to configure and capture frames from different camera types. + +- **Blocking and non-blocking frame capture** using an OpenCV-based camera - **Color and depth capture** using an Intel RealSense camera +> [!WARNING] +> Failing to cleanly disconnect cameras can cause resource leaks. Use the context manager protocol to ensure automatic cleanup: +> +> ```python +> with OpenCVCamera(config) as camera: +> ... +> ``` +> +> You can also call `connect()` and `disconnect()` manually, but always use a `finally` block for the latter. + @@ -60,16 +94,30 @@ config = OpenCVCameraConfig( ) # Instantiate and connect an `OpenCVCamera`, performing a warm-up read (default). -camera = OpenCVCamera(config) -camera.connect() +with OpenCVCamera(config) as camera: + + # Read a frame synchronously — blocks until hardware delivers a new frame + frame = camera.read() + print(f"read() call returned frame with shape:", frame.shape) + + # Read a frame asynchronously with a timeout — returns the latest unconsumed frame or waits up to timeout_ms for a new one + try: + for i in range(10): + frame = camera.async_read(timeout_ms=200) + print(f"async_read call returned frame {i} with shape:", frame.shape) + except TimeoutError as e: + print(f"No frame received within timeout: {e}") + + # Instantly return a frame - returns the most recent frame captured by the camera + try: + initial_frame = camera.read_latest(max_age_ms=1000) + for i in range(10): + frame = camera.read_latest(max_age_ms=1000) + print(f"read_latest call returned frame {i} with shape:", frame.shape) + print(f"Was a new frame received by the camera? {not (initial_frame == frame).any()}") + except TimeoutError as e: + print(f"Frame too old: {e}") -# Read frames asynchronously in a loop via `async_read(timeout_ms)` -try: - for i in range(10): - frame = camera.async_read(timeout_ms=200) - print(f"Async frame {i} shape:", frame.shape) -finally: - camera.disconnect() ``` @@ -111,10 +159,10 @@ finally: -## Use your phone +## Use your phone's camera - + To use your iPhone as a camera on macOS, enable the Continuity Camera feature: @@ -124,83 +172,49 @@ To use your iPhone as a camera on macOS, enable the Continuity Camera feature: For more details, visit [Apple support](https://support.apple.com/en-gb/guide/mac-help/mchl77879b8a/mac). -Your iPhone should be detected automatically when running the camera setup script in the next section. - - + -If you want to use your phone as a camera on Linux, follow these steps to set up a virtual camera +If you want to use your phone as a camera using OBS, follow these steps to set up a virtual camera. -1. _Install `v4l2loopback-dkms` and `v4l-utils`_. Those packages are required to create virtual camera devices (`v4l2loopback`) and verify their settings with the `v4l2-ctl` utility from `v4l-utils`. Install them using: +1. _(Linux only) Install `v4l2loopback-dkms` and `v4l-utils`_. These packages create virtual camera devices and verify their settings. Install with: - -```python +```bash sudo apt install v4l2loopback-dkms v4l-utils ``` - -2. _Install [DroidCam](https://droidcam.app) on your phone_. This app is available for both iOS and Android. -3. _Install [OBS Studio](https://obsproject.com)_. This software will help you manage the camera feed. Install it using [Flatpak](https://flatpak.org): +2. _Install the [DroidCam app](https://droidcam.app) on your phone_. This app is available for both iOS and Android. +3. _Download and install [OBS Studio](https://obsproject.com)_. +4. _Download and install the [DroidCam OBS plugin](https://droidcam.app/obs)_. +5. _Start OBS Studio_. - -```python -flatpak install flathub com.obsproject.Studio -``` - - -4. _Install the DroidCam OBS plugin_. This plugin integrates DroidCam with OBS Studio. Install it with: - - -```python -flatpak install flathub com.obsproject.Studio.Plugin.DroidCam -``` - - -5. _Start OBS Studio_. Launch with: - - -```python -flatpak run com.obsproject.Studio -``` - - -6. _Add your phone as a source_. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480`. -7. _Adjust resolution settings_. In OBS Studio, go to `File > Settings > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it in. +6. _Add your phone as a source_. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480` to avoid the watermarks. +7. _Adjust resolution settings_. In OBS Studio, go to `File > Settings > Video` or `OBS > Preferences... > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it. 8. _Start virtual camera_. In OBS Studio, follow the instructions [here](https://obsproject.com/kb/virtual-camera-guide). -9. _Verify the virtual camera setup_. Use `v4l2-ctl` to list the devices: +9. _Verify the virtual camera setup and resolution_. + - **Linux**: Use `v4l2-ctl` to list devices and check resolution: + ```bash + v4l2-ctl --list-devices # find VirtualCam and note its /dev/videoX path + v4l2-ctl -d /dev/videoX --get-fmt-video # replace with your VirtualCam path + ``` + You should see `VirtualCam` listed and resolution `640x480`. + - **macOS**: Open Photo Booth or FaceTime and select "OBS Virtual Camera" as the input. + - **Windows**: The native Camera app doesn't support virtual cameras. Use a video conferencing app (Zoom, Teams) or run `lerobot-find-cameras opencv` directly to verify. - -```python -v4l2-ctl --list-devices -``` - +
+Troubleshooting -You should see an entry like: +> The virtual camera resolution is incorrect. -``` -VirtualCam (platform:v4l2loopback-000): -/dev/video1 -``` +Delete the virtual camera source and recreate it. The resolution cannot be changed after creation. -10. _Check the camera resolution_. Use `v4l2-ctl` to ensure that the virtual camera output resolution is `640x480`. Change `/dev/video1` to the port of your virtual camera from the output of `v4l2-ctl --list-devices`. +> Error reading frame in background thread for OpenCVCamera(X): OpenCVCamera(X) frame width=640 or height=480 do not match configured width=1920 or height=1080. - -```python -v4l2-ctl -d /dev/video1 --get-fmt-video -``` - +This error is caused by OBS Virtual Camera advertising a `1920x1080` resolution despite rescaling. The only fix for now is to comment out the width and height check in `_postprocess_image()`. -You should see an entry like: - -``` ->>> Format Video Capture: ->>> Width/Height : 640/480 ->>> Pixel Format : 'YUYV' (YUYV 4:2:2) -``` - -Troubleshooting: If the resolution is not correct you will have to delete the Virtual Camera port and try again as it cannot be changed. - -If everything is set up correctly, you can proceed with the rest of the tutorial. +
+ +If everything is set up correctly, your phone will appear as a standard OpenCV camera and can be used with `OpenCVCamera`. From 5c6182176f31996fb5d0c51f88a1bc59457ba7a6 Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Fri, 30 Jan 2026 16:58:13 +0100 Subject: [PATCH 020/131] fix(find zmq): adding a clearer not implemented warning for the ZMQ find_cameras method (#2879) Co-authored-by: Martino Russi <77496684+nepyope@users.noreply.github.com> --- src/lerobot/cameras/zmq/camera_zmq.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/lerobot/cameras/zmq/camera_zmq.py b/src/lerobot/cameras/zmq/camera_zmq.py index a231a582a..f29e16a28 100644 --- a/src/lerobot/cameras/zmq/camera_zmq.py +++ b/src/lerobot/cameras/zmq/camera_zmq.py @@ -166,8 +166,10 @@ class ZMQCamera(Camera): @staticmethod def find_cameras() -> list[dict[str, Any]]: - """ZMQ cameras require manual configuration (server address/port).""" - return [] + """ + Detection not implemented for ZMQ cameras. These cameras require manual configuration (server address/port). + """ + raise NotImplementedError("Camera detection is not implemented for ZMQ cameras.") def _read_from_hardware(self) -> NDArray[Any]: """ From b18cef2e260a80db6cbe2327140950964c797b46 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Fri, 30 Jan 2026 10:29:37 -0800 Subject: [PATCH 021/131] feat(dataset): add subtask support (#2860) * add subtask * remove folder * add docs * update doc * add testing * update test * update constant naming + doc * more docs --- docs/source/_toctree.yml | 2 + docs/source/dataset_subtask.mdx | 278 +++++++++++ src/lerobot/datasets/lerobot_dataset.py | 9 + src/lerobot/datasets/utils.py | 9 + src/lerobot/processor/converters.py | 3 +- src/lerobot/processor/tokenizer_processor.py | 46 ++ src/lerobot/utils/constants.py | 3 + tests/datasets/test_subtask_dataset.py | 190 ++++++++ tests/processor/test_tokenizer_processor.py | 465 ++++++++++++++++++- 9 files changed, 1003 insertions(+), 2 deletions(-) create mode 100644 docs/source/dataset_subtask.mdx create mode 100644 tests/datasets/test_subtask_dataset.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 98417f134..d61aac9c1 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -27,6 +27,8 @@ title: Porting Large Datasets - local: using_dataset_tools title: Using the Dataset Tools + - local: dataset_subtask + title: Using Subtasks in the Dataset title: "Datasets" - sections: - local: act diff --git a/docs/source/dataset_subtask.mdx b/docs/source/dataset_subtask.mdx new file mode 100644 index 000000000..beb5d80bd --- /dev/null +++ b/docs/source/dataset_subtask.mdx @@ -0,0 +1,278 @@ +# Using Subtasks in LeRobot Datasets + +Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for: + +- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time +- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models) +- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps + +LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks. + +## What are Subtasks? + +While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps: + +1. "Approach the apple" +2. "Grasp the apple" +3. "Lift the apple" +4. "Move to basket" +5. "Release the apple" + +Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages. + +An overview of subtask annotation showing how frames are labeled with intermediate subtask stages + +

+ Figure: Overview of subtask annotation. +

+ +**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022. + +## Dataset Structure + +Subtask information is stored in the dataset metadata: + +``` +my-dataset/ +├── data/ +│ └── ... +├── meta/ +│ ├── info.json +│ ├── stats.json +│ ├── tasks.parquet +│ ├── subtasks.parquet # Subtask index → subtask string mapping +│ └── episodes/ +│ └── ... +└── videos/ + └── ... +``` + +### Subtasks Parquet File + +The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions: + +| subtask_index | subtask (index column) | +| ------------- | ---------------------- | +| 0 | "Approach the apple" | +| 1 | "Grasp the apple" | +| 2 | "Lift the apple" | +| ... | ... | + +### Frame-Level Annotations + +Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file: + +```python +# Example frame data in the parquet file +{ + "index": 42, + "timestamp": 1.4, + "episode_index": 0, + "task_index": 0, + "subtask_index": 2, # References "Lift the apple" + "observation.state": [...], + "action": [...], +} +``` + +## Annotating Datasets with Subtasks + +We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks: + +**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)** + +After completing your annotation: + +1. Click "Push to Hub" to upload your annotated dataset +2. You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate) + +## Loading Datasets with Subtasks + +When you load a dataset with subtask annotations, the subtask information is automatically available: + +```python +from lerobot.datasets.lerobot_dataset import LeRobotDataset + +# Load a dataset with subtask annotations +dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated") + +# Access a sample +sample = dataset[100] + +# The sample includes both task and subtask information +print(sample["task"]) # "Collect the fruit" +print(sample["subtask"]) # "Grasp the apple" +print(sample["task_index"]) # tensor(0) +print(sample["subtask_index"]) # tensor(2) +``` + +### Checking for Subtask Support + +You can check if a dataset has subtask annotations: + +```python +# Check if subtasks are available +has_subtasks = ( + "subtask_index" in dataset.features + and dataset.meta.subtasks is not None +) + +if has_subtasks: + print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks") + print("Subtasks:", list(dataset.meta.subtasks.index)) +``` + +## Using Subtasks for Training + +### With the Tokenizer Processor + +The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models: + +```python +from lerobot.processor.tokenizer_processor import TokenizerProcessor +from lerobot.processor.pipeline import ProcessorPipeline + +# Create a tokenizer processor +tokenizer_processor = TokenizerProcessor( + tokenizer_name_or_path="google/paligemma-3b-pt-224", + padding="max_length", + max_length=64, +) + +# The processor will automatically tokenize subtasks if present in the batch +# and add them to the observation under: +# - "observation.subtask.tokens" +# - "observation.subtask.attention_mask" +``` + +When subtasks are available in the batch, the tokenizer processor adds: + +- `observation.subtask.tokens`: Tokenized subtask text +- `observation.subtask.attention_mask`: Attention mask for the subtask tokens + +### DataLoader with Subtasks + +```python +import torch +from lerobot.datasets.lerobot_dataset import LeRobotDataset + +dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated") + +dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=16, + shuffle=True, +) + +for batch in dataloader: + # Access subtask information in the batch + subtasks = batch["subtask"] # List of subtask strings + subtask_indices = batch["subtask_index"] # Tensor of subtask indices + + # Use for training hierarchical policies or reward models + print(f"Batch subtasks: {set(subtasks)}") +``` + +## Example Datasets with Subtask Annotations + +Try loading a dataset with subtask annotations: + +```python +from lerobot.datasets.lerobot_dataset import LeRobotDataset + +# Example dataset with subtask annotations +dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated") + +# Explore the subtasks +print("Available subtasks:") +for subtask_name in dataset.meta.subtasks.index: + print(f" - {subtask_name}") + +# Get subtask distribution +subtask_counts = {} +for i in range(len(dataset)): + sample = dataset[i] + subtask = sample["subtask"] + subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1 + +print("\nSubtask distribution:") +for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]): + print(f" {subtask}: {count} frames") +``` + +## Use Cases + +### 1. Hierarchical Policy Training + +Train policies that predict both actions and current subtask: + +```python +class HierarchicalPolicy(nn.Module): + def __init__(self, num_subtasks): + super().__init__() + self.action_head = nn.Linear(hidden_dim, action_dim) + self.subtask_head = nn.Linear(hidden_dim, num_subtasks) + + def forward(self, observations): + features = self.encoder(observations) + actions = self.action_head(features) + subtask_logits = self.subtask_head(features) + return actions, subtask_logits +``` + +### 2. Stage-Aware Reward Modeling (SARM) + +Build reward models that understand task progression: + +```python +# SARM predicts: +# - Stage: Which subtask is being executed (discrete) +# - Progress: How far along the subtask (continuous 0-1) + +class SARMRewardModel(nn.Module): + def forward(self, observations): + features = self.encoder(observations) + stage_logits = self.stage_classifier(features) + progress = self.progress_regressor(features) + return stage_logits, progress +``` + +### 3. Progress Visualization + +Monitor robot execution by tracking subtask progression: + +```python +def visualize_execution(model, observations): + for t, obs in enumerate(observations): + action, subtask_logits = model(obs) + predicted_subtask = subtask_names[subtask_logits.argmax()] + print(f"t={t}: Executing '{predicted_subtask}'") +``` + +## API Reference + +### LeRobotDataset Properties + +| Property | Type | Description | +| --------------------------- | ---------------------- | ------------------------------------------ | +| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices | +| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present | + +### Sample Keys + +When subtasks are available, each sample includes: + +| Key | Type | Description | +| --------------- | -------------- | ------------------------------------ | +| `subtask_index` | `torch.Tensor` | Integer index of the current subtask | +| `subtask` | `str` | Natural language subtask description | + +## Related Resources + +- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation +- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool +- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 6798e7fd7..36bffa190 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -57,6 +57,7 @@ from lerobot.datasets.utils import ( load_info, load_nested_dataset, load_stats, + load_subtasks, load_tasks, update_chunk_file_indices, validate_episode_buffer, @@ -162,6 +163,7 @@ class LeRobotDatasetMetadata: self.info = load_info(self.root) check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) self.tasks = load_tasks(self.root) + self.subtasks = load_subtasks(self.root) self.episodes = load_episodes(self.root) self.stats = load_stats(self.root) @@ -518,6 +520,7 @@ class LeRobotDatasetMetadata: _validate_feature_names(features) obj.tasks = None + obj.subtasks = None obj.episodes = None obj.stats = None obj.info = create_empty_dataset_info( @@ -1075,6 +1078,12 @@ class LeRobotDataset(torch.utils.data.Dataset): # Add task as a string task_idx = item["task_index"].item() item["task"] = self.meta.tasks.iloc[task_idx].name + + # add subtask information if available + if "subtask_index" in self.features and self.meta.subtasks is not None: + subtask_idx = item["subtask_index"].item() + item["subtask"] = self.meta.subtasks.iloc[subtask_idx].name + return item def __repr__(self): diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index ed678af6e..321ecedd5 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -60,6 +60,7 @@ VIDEO_DIR = "videos" CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}" DEFAULT_TASKS_PATH = "meta/tasks.parquet" +DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet" DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet" DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet" DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4" @@ -353,6 +354,14 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame: return tasks +def load_subtasks(local_dir: Path) -> pandas.DataFrame | None: + """Load subtasks from subtasks.parquet if it exists.""" + subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH + if subtasks_path.exists(): + return pd.read_parquet(subtasks_path) + return None + + def write_episodes(episodes: Dataset, local_dir: Path) -> None: """Write episode metadata to a parquet file in the LeRobot v3.0 format. This function writes episode-level metadata to a single parquet file. diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index 4f9485fee..18c7b0220 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -168,11 +168,12 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]: """ pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k} task_key = {"task": batch["task"]} if "task" in batch else {} + subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {} index_key = {"index": batch["index"]} if "index" in batch else {} task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {} episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {} - return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key} + return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key} def create_transition( diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index 5cd1bebb0..df559555a 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -34,6 +34,8 @@ from lerobot.utils.constants import ( ACTION_TOKEN_MASK, ACTION_TOKENS, OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_SUBTASK_ATTENTION_MASK, + OBS_LANGUAGE_SUBTASK_TOKENS, OBS_LANGUAGE_TOKENS, ) from lerobot.utils.import_utils import _transformers_available @@ -139,6 +141,32 @@ class TokenizerProcessorStep(ObservationProcessorStep): return None + def get_subtask(self, transition: EnvTransition) -> list[str] | None: + """ + Extracts the subtask from the transition's complementary data. + + Args: + transition: The environment transition. + + Returns: + A list of subtask strings, or None if the subtask key is not found or the value is None. + """ + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + if complementary_data is None: + return None + + subtask = complementary_data.get("subtask") + if subtask is None: + return None + + # Standardize to a list of strings for the tokenizer + if isinstance(subtask, str): + return [subtask] + elif isinstance(subtask, list) and all(isinstance(t, str) for t in subtask): + return subtask + + return None + def observation(self, observation: RobotObservation) -> RobotObservation: """ Tokenizes the task description and adds it to the observation dictionary. @@ -176,6 +204,24 @@ class TokenizerProcessorStep(ObservationProcessorStep): new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"] new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool) + # Tokenize subtask if available + subtask = self.get_subtask(self.transition) + if subtask is not None: + tokenized_subtask = self._tokenize_text(subtask) + + # Move new tokenized tensors to the detected device + if target_device is not None: + tokenized_subtask = { + k: v.to(target_device) if isinstance(v, torch.Tensor) else v + for k, v in tokenized_subtask.items() + } + + # Add tokenized subtask to the observation + new_observation[OBS_LANGUAGE_SUBTASK_TOKENS] = tokenized_subtask["input_ids"] + new_observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] = tokenized_subtask["attention_mask"].to( + dtype=torch.bool + ) + return new_observation def _detect_device(self, transition: EnvTransition) -> torch.device | None: diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index 43a61b4f7..ecd54844c 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -26,6 +26,9 @@ OBS_IMAGES = OBS_IMAGE + "s" OBS_LANGUAGE = OBS_STR + ".language" OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens" OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask" +OBS_LANGUAGE_SUBTASK = OBS_STR + ".subtask" +OBS_LANGUAGE_SUBTASK_TOKENS = OBS_LANGUAGE_SUBTASK + ".tokens" +OBS_LANGUAGE_SUBTASK_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK + ".attention_mask" ACTION = "action" ACTION_PREFIX = ACTION + "." diff --git a/tests/datasets/test_subtask_dataset.py b/tests/datasets/test_subtask_dataset.py new file mode 100644 index 000000000..f80a6c72d --- /dev/null +++ b/tests/datasets/test_subtask_dataset.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for subtask functionality in LeRobotDataset. + +These tests verify that: +- Subtask information is correctly loaded from datasets that have subtask data +- The __getitem__ method correctly adds subtask strings to returned items +- Subtask handling gracefully handles missing data +""" + +import pandas as pd +import pytest +import torch + +from lerobot.datasets.lerobot_dataset import LeRobotDataset + + +class TestSubtaskDataset: + """Tests for subtask handling in LeRobotDataset.""" + + @pytest.fixture + def subtask_dataset(self): + """Load the test subtask dataset from the hub.""" + # Use lerobot/pusht-subtask dataset with episode 1 + return LeRobotDataset( + repo_id="lerobot/pusht-subtask", + episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + ) + + def test_subtask_dataset_loads(self, subtask_dataset): + """Test that the subtask dataset loads successfully.""" + assert subtask_dataset is not None + assert len(subtask_dataset) > 0 + + def test_subtask_metadata_loaded(self, subtask_dataset): + """Test that subtask metadata is loaded when present in dataset.""" + # The dataset should have subtasks metadata loaded + assert subtask_dataset.meta.subtasks is not None + assert isinstance(subtask_dataset.meta.subtasks, pd.DataFrame) + + def test_subtask_index_in_features(self, subtask_dataset): + """Test that subtask_index is a feature when dataset has subtasks.""" + assert "subtask_index" in subtask_dataset.features + + def test_getitem_returns_subtask_string(self, subtask_dataset): + """Test that __getitem__ correctly adds subtask string to returned item.""" + item = subtask_dataset[0] + + # Subtask should be present in the returned item + assert "subtask" in item + assert isinstance(item["subtask"], str) + assert len(item["subtask"]) > 0 # Should not be empty + + def test_getitem_has_subtask_index(self, subtask_dataset): + """Test that __getitem__ includes subtask_index.""" + item = subtask_dataset[0] + + assert "subtask_index" in item + assert isinstance(item["subtask_index"], torch.Tensor) + + def test_subtask_index_maps_to_valid_subtask(self, subtask_dataset): + """Test that subtask_index correctly maps to a subtask in metadata.""" + item = subtask_dataset[0] + + subtask_idx = item["subtask_index"].item() + subtask_from_metadata = subtask_dataset.meta.subtasks.iloc[subtask_idx].name + + assert item["subtask"] == subtask_from_metadata + + def test_all_items_have_subtask(self, subtask_dataset): + """Test that all items in the dataset have subtask information.""" + for i in range(min(len(subtask_dataset), 5)): # Check first 5 items + item = subtask_dataset[i] + assert "subtask" in item + assert isinstance(item["subtask"], str) + + def test_task_and_subtask_coexist(self, subtask_dataset): + """Test that both task and subtask are present in returned items.""" + item = subtask_dataset[0] + + # Both task and subtask should be present + assert "task" in item + assert "subtask" in item + assert isinstance(item["task"], str) + assert isinstance(item["subtask"], str) + + +class TestSubtaskDatasetMissing: + """Tests for graceful handling when subtask data is missing.""" + + @pytest.fixture + def dataset_without_subtasks(self, tmp_path, empty_lerobot_dataset_factory): + """Create a dataset without subtask information.""" + features = {"state": {"dtype": "float32", "shape": (2,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "no_subtask", features=features) + + # Add some frames and save + for _ in range(5): + dataset.add_frame({"state": torch.randn(2), "task": "Test task"}) + dataset.save_episode() + dataset.finalize() + + # Reload the dataset + return LeRobotDataset(dataset.repo_id, root=dataset.root) + + def test_no_subtask_in_features(self, dataset_without_subtasks): + """Test that subtask_index is not in features when not provided.""" + assert "subtask_index" not in dataset_without_subtasks.features + + def test_getitem_without_subtask(self, dataset_without_subtasks): + """Test that __getitem__ works when subtask is not present.""" + item = dataset_without_subtasks[0] + + # Item should still be retrievable + assert item is not None + assert "state" in item + assert "task" in item + + # Subtask should NOT be present + assert "subtask" not in item + + def test_subtasks_metadata_is_none(self, dataset_without_subtasks): + """Test that subtasks metadata is None when not present.""" + assert dataset_without_subtasks.meta.subtasks is None + + +class TestSubtaskEdgeCases: + """Edge case tests for subtask handling.""" + + def test_subtask_with_multiple_episodes(self): + """Test subtask handling with multiple episodes if available.""" + try: + dataset = LeRobotDataset( + repo_id="lerobot/pusht-subtask", + episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + ) + except Exception: + pytest.skip("Could not load test-subtask dataset") + + # Check first and last items have valid subtasks + first_item = dataset[0] + last_item = dataset[len(dataset) - 1] + + assert "subtask" in first_item + assert "subtask" in last_item + assert isinstance(first_item["subtask"], str) + assert isinstance(last_item["subtask"], str) + + def test_subtask_index_consistency(self): + """Test that same subtask_index returns same subtask string.""" + try: + dataset = LeRobotDataset( + repo_id="lerobot/pusht-subtask", + episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + ) + except Exception: + pytest.skip("Could not load test-subtask dataset") + + if len(dataset) < 2: + pytest.skip("Dataset too small for this test") + + # Collect subtask_index to subtask mappings + subtask_map = {} + for i in range(min(len(dataset), 10)): + item = dataset[i] + idx = item["subtask_index"].item() + subtask = item["subtask"] + + if idx in subtask_map: + # Same index should always return same subtask + assert subtask_map[idx] == subtask, ( + f"Inconsistent subtask for index {idx}: '{subtask_map[idx]}' vs '{subtask}'" + ) + else: + subtask_map[idx] = subtask diff --git a/tests/processor/test_tokenizer_processor.py b/tests/processor/test_tokenizer_processor.py index d6f87f567..64cc8aac8 100644 --- a/tests/processor/test_tokenizer_processor.py +++ b/tests/processor/test_tokenizer_processor.py @@ -27,7 +27,14 @@ import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey from lerobot.processor.converters import create_transition, identity_transition -from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_LANGUAGE, OBS_STATE +from lerobot.utils.constants import ( + ACTION, + OBS_IMAGE, + OBS_LANGUAGE, + OBS_LANGUAGE_SUBTASK_ATTENTION_MASK, + OBS_LANGUAGE_SUBTASK_TOKENS, + OBS_STATE, +) from tests.utils import require_package @@ -1038,3 +1045,459 @@ def test_simulated_accelerate_scenario(): # MockTokenizer squeezes single-item batches, so shape is (max_length,) not (1, max_length) assert tokens.shape == (10,) # MockTokenizer behavior for single string in list assert attention_mask.shape == (10,) + + +# ============================================================================= +# Tests for get_subtask method +# ============================================================================= + + +@require_package("transformers") +def test_get_subtask_missing_key(): + """Test get_subtask returns None when subtask key is missing from complementary_data.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task"}, # No "subtask" key + ) + + result = processor.get_subtask(transition) + assert result is None + + +@require_package("transformers") +def test_get_subtask_none_value(): + """Test get_subtask returns None when subtask value is None.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": None}, + ) + + result = processor.get_subtask(transition) + assert result is None + + +@require_package("transformers") +def test_get_subtask_none_complementary_data(): + """Test get_subtask returns None when complementary_data is None.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data=None, # No complementary data + ) + + result = processor.get_subtask(transition) + assert result is None + + +@require_package("transformers") +def test_get_subtask_string(): + """Test get_subtask returns list with single string when subtask is a string.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": "pick up the cube"}, + ) + + result = processor.get_subtask(transition) + assert result == ["pick up the cube"] + assert isinstance(result, list) + assert len(result) == 1 + + +@require_package("transformers") +def test_get_subtask_list_of_strings(): + """Test get_subtask returns the list when subtask is already a list of strings.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + subtask_list = ["pick up", "move to target", "place down"] + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": subtask_list}, + ) + + result = processor.get_subtask(transition) + assert result == subtask_list + assert isinstance(result, list) + assert len(result) == 3 + + +@require_package("transformers") +def test_get_subtask_unsupported_type_integer(): + """Test get_subtask returns None when subtask is an unsupported type (integer).""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": 123}, + ) + + result = processor.get_subtask(transition) + assert result is None + + +@require_package("transformers") +def test_get_subtask_unsupported_type_mixed_list(): + """Test get_subtask returns None when subtask is a list with mixed types.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": ["valid string", 123, "another string"]}, + ) + + result = processor.get_subtask(transition) + assert result is None + + +@require_package("transformers") +def test_get_subtask_unsupported_type_dict(): + """Test get_subtask returns None when subtask is a dictionary.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": {"key": "value"}}, + ) + + result = processor.get_subtask(transition) + assert result is None + + +@require_package("transformers") +def test_get_subtask_empty_string(): + """Test get_subtask with empty string returns list with empty string.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": ""}, + ) + + result = processor.get_subtask(transition) + assert result == [""] + + +@require_package("transformers") +def test_get_subtask_empty_list(): + """Test get_subtask with empty list returns empty list.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": []}, + ) + + result = processor.get_subtask(transition) + assert result == [] + + +# ============================================================================= +# Tests for subtask tokenization in observation method +# ============================================================================= + + +@require_package("transformers") +def test_subtask_tokenization_when_present(): + """Test that subtask is tokenized and added to observation when present.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": "pick up the red cube"}, + ) + + result = processor(transition) + + # Check that subtask tokens were added to observation + observation = result[TransitionKey.OBSERVATION] + assert OBS_LANGUAGE_SUBTASK_TOKENS in observation + assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation + + # Check token structure + subtask_tokens = observation[OBS_LANGUAGE_SUBTASK_TOKENS] + subtask_attention_mask = observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] + assert isinstance(subtask_tokens, torch.Tensor) + assert isinstance(subtask_attention_mask, torch.Tensor) + assert subtask_tokens.shape == (8,) + assert subtask_attention_mask.shape == (8,) + assert subtask_attention_mask.dtype == torch.bool + + +@require_package("transformers") +def test_subtask_tokenization_not_added_when_none(): + """Test that subtask tokens are NOT added to observation when subtask is None.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task"}, # No subtask + ) + + result = processor(transition) + + # Check that subtask tokens were NOT added to observation + observation = result[TransitionKey.OBSERVATION] + assert OBS_LANGUAGE_SUBTASK_TOKENS not in observation + assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in observation + + # But main task tokens should still be present + assert f"{OBS_LANGUAGE}.tokens" in observation + assert f"{OBS_LANGUAGE}.attention_mask" in observation + + +@require_package("transformers") +def test_subtask_tokenization_not_added_when_subtask_value_is_none(): + """Test that subtask tokens are NOT added when subtask value is explicitly None.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": None}, + ) + + result = processor(transition) + + # Check that subtask tokens were NOT added to observation + observation = result[TransitionKey.OBSERVATION] + assert OBS_LANGUAGE_SUBTASK_TOKENS not in observation + assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in observation + + +@require_package("transformers") +def test_subtask_tokenization_list_of_strings(): + """Test subtask tokenization with list of strings.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": ["pick up", "place down"]}, + ) + + result = processor(transition) + + # Check that subtask tokens were added to observation + observation = result[TransitionKey.OBSERVATION] + assert OBS_LANGUAGE_SUBTASK_TOKENS in observation + assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation + + # Check token structure for batch + subtask_tokens = observation[OBS_LANGUAGE_SUBTASK_TOKENS] + subtask_attention_mask = observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] + assert subtask_tokens.shape == (2, 8) # batch_size=2, seq_len=8 + assert subtask_attention_mask.shape == (2, 8) + + +@require_package("transformers") +def test_subtask_tokenization_device_cpu(): + """Test that subtask tokens are on CPU when other tensors are on CPU.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + # Create transition with CPU tensors + observation = {OBS_STATE: torch.randn(10)} # CPU tensor + action = torch.randn(5) # CPU tensor + transition = create_transition( + observation=observation, + action=action, + complementary_data={"task": "main task", "subtask": "pick up cube"}, + ) + + result = processor(transition) + + # Check that subtask tokens are on CPU + subtask_tokens = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS] + subtask_attention_mask = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] + + assert subtask_tokens.device.type == "cpu" + assert subtask_attention_mask.device.type == "cpu" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@require_package("transformers") +def test_subtask_tokenization_device_cuda(): + """Test that subtask tokens are moved to CUDA when other tensors are on CUDA.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + # Create transition with CUDA tensors + observation = {OBS_STATE: torch.randn(10).cuda()} # CUDA tensor + action = torch.randn(5).cuda() # CUDA tensor + transition = create_transition( + observation=observation, + action=action, + complementary_data={"task": "main task", "subtask": "pick up cube"}, + ) + + result = processor(transition) + + # Check that subtask tokens are on CUDA + subtask_tokens = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS] + subtask_attention_mask = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] + + assert subtask_tokens.device.type == "cuda" + assert subtask_attention_mask.device.type == "cuda" + + +@require_package("transformers") +def test_subtask_tokenization_preserves_other_observation_data(): + """Test that subtask tokenization preserves other observation data.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + original_state = torch.tensor([1.0, 2.0, 3.0]) + transition = create_transition( + observation={"state": original_state.clone()}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": "pick up cube"}, + ) + + result = processor(transition) + observation = result[TransitionKey.OBSERVATION] + + # Check that original observation data is preserved + assert torch.equal(observation["state"], original_state) + + # Check that both task and subtask tokens are present + assert f"{OBS_LANGUAGE}.tokens" in observation + assert f"{OBS_LANGUAGE}.attention_mask" in observation + assert OBS_LANGUAGE_SUBTASK_TOKENS in observation + assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation + + +@require_package("transformers") +def test_subtask_attention_mask_dtype(): + """Test that subtask attention mask has correct dtype (bool).""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": "pick up cube"}, + ) + + result = processor(transition) + observation = result[TransitionKey.OBSERVATION] + + subtask_attention_mask = observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] + assert subtask_attention_mask.dtype == torch.bool + + +@require_package("transformers") +def test_subtask_tokenization_deterministic(): + """Test that subtask tokenization is deterministic for the same input.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": "consistent subtask"}, + ) + + result1 = processor(transition) + result2 = processor(transition) + + subtask_tokens1 = result1[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS] + subtask_tokens2 = result2[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS] + subtask_mask1 = result1[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] + subtask_mask2 = result2[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] + + # Results should be identical + assert torch.equal(subtask_tokens1, subtask_tokens2) + assert torch.equal(subtask_mask1, subtask_mask2) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_subtask_tokenization_integration_with_pipeline(mock_auto_tokenizer): + """Test subtask tokenization works correctly with DataProcessorPipeline.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + tokenizer_processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=6) + robot_processor = DataProcessorPipeline( + [tokenizer_processor], to_transition=identity_transition, to_output=identity_transition + ) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": "subtask instruction"}, + ) + + result = robot_processor(transition) + + # Check that observation exists and both tokenizations were applied + assert TransitionKey.OBSERVATION in result + observation = result[TransitionKey.OBSERVATION] + + # Check task tokens + assert f"{OBS_LANGUAGE}.tokens" in observation + assert f"{OBS_LANGUAGE}.attention_mask" in observation + + # Check subtask tokens + assert OBS_LANGUAGE_SUBTASK_TOKENS in observation + assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation + + # Check shapes + assert observation[f"{OBS_LANGUAGE}.tokens"].shape == (6,) + assert observation[OBS_LANGUAGE_SUBTASK_TOKENS].shape == (6,) + + +@require_package("transformers") +def test_subtask_not_added_for_unsupported_types(): + """Test that subtask tokens are not added when subtask has unsupported type.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8) + + # Test with integer subtask + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": 123}, + ) + + result = processor(transition) + observation = result[TransitionKey.OBSERVATION] + + # Subtask tokens should NOT be added for unsupported types + assert OBS_LANGUAGE_SUBTASK_TOKENS not in observation + assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in observation + + # But main task tokens should still be present + assert f"{OBS_LANGUAGE}.tokens" in observation From 9c24a09665ffe1338a91ee7a05a0e76d10e7e3d4 Mon Sep 17 00:00:00 2001 From: Hirokazu Ishida <38597814+HiroIshida@users.noreply.github.com> Date: Tue, 3 Feb 2026 04:05:58 +0900 Subject: [PATCH 022/131] docs: update document in response to Simplify configs PR (#1596) * docs: update document input/output_shapes -> input/output_features * fix inconsistent quote (suggested by copilot reviewer) * docs: shapes => PolicyFeature * docs: relfect normalization_mapping and remove outdated --- src/lerobot/configs/policies.py | 12 ++++----- src/lerobot/policies/act/configuration_act.py | 23 +++++----------- .../diffusion/configuration_diffusion.py | 23 +++++----------- .../policies/tdmpc/configuration_tdmpc.py | 26 +++++-------------- .../policies/vqbet/configuration_vqbet.py | 23 +++++----------- 5 files changed, 34 insertions(+), 73 deletions(-) diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py index 7f326b70b..44b013c29 100644 --- a/src/lerobot/configs/policies.py +++ b/src/lerobot/configs/policies.py @@ -45,12 +45,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno Args: n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the current step and additional steps going back). - input_shapes: A dictionary defining the shapes of the input data for the policy. - output_shapes: A dictionary defining the shapes of the output data for the policy. - input_normalization_modes: A dictionary with key representing the modality and the value specifies the - normalization mode to apply. - output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to - the original scale. + input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents + the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents + the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to + a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX) """ n_obs_steps: int = 1 diff --git a/src/lerobot/policies/act/configuration_act.py b/src/lerobot/policies/act/configuration_act.py index 6f6c1c4be..bd89185fd 100644 --- a/src/lerobot/policies/act/configuration_act.py +++ b/src/lerobot/policies/act/configuration_act.py @@ -28,7 +28,7 @@ class ACTConfig(PreTrainedConfig): Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer". The parameters you will most likely need to change are the ones which depend on the environment / sensors. - Those are: `input_shapes` and 'output_shapes`. + Those are: `input_features` and `output_features`. Notes on the inputs and outputs: - Either: @@ -48,21 +48,12 @@ class ACTConfig(PreTrainedConfig): This should be no greater than the chunk size. For example, if the chunk size size 100, you may set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the environment, and throws the other 50 out. - input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents - the input data name, and the value is a list indicating the dimensions of the corresponding data. - For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96], - indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't - include batch dimension or temporal dimension. - output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents - the output data name, and the value is a list indicating the dimensions of the corresponding data. - For example, "action" refers to an output shape of [14], indicating 14-dimensional actions. - Importantly, `output_shapes` doesn't include batch dimension or temporal dimension. - input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), - and the value specifies the normalization mode to apply. The two available modes are "mean_std" - which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a - [-1, 1] range. - output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the - original scale. Note that this is also used for normalizing the training targets. + input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents + the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents + the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to + a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX) vision_backbone: Name of the torchvision resnet backbone to use for encoding images. pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone. `None` means no pretrained weights. diff --git a/src/lerobot/policies/diffusion/configuration_diffusion.py b/src/lerobot/policies/diffusion/configuration_diffusion.py index 54569434a..8322ca337 100644 --- a/src/lerobot/policies/diffusion/configuration_diffusion.py +++ b/src/lerobot/policies/diffusion/configuration_diffusion.py @@ -30,7 +30,7 @@ class DiffusionConfig(PreTrainedConfig): Defaults are configured for training with PushT providing proprioceptive and single camera observations. The parameters you will most likely need to change are the ones which depend on the environment / sensors. - Those are: `input_shapes` and `output_shapes`. + Those are: `input_features` and `output_features`. Notes on the inputs and outputs: - "observation.state" is required as an input key. @@ -48,21 +48,12 @@ class DiffusionConfig(PreTrainedConfig): horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`. n_action_steps: The number of action steps to run in the environment for one invocation of the policy. See `DiffusionPolicy.select_action` for more details. - input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents - the input data name, and the value is a list indicating the dimensions of the corresponding data. - For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96], - indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't - include batch dimension or temporal dimension. - output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents - the output data name, and the value is a list indicating the dimensions of the corresponding data. - For example, "action" refers to an output shape of [14], indicating 14-dimensional actions. - Importantly, `output_shapes` doesn't include batch dimension or temporal dimension. - input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), - and the value specifies the normalization mode to apply. The two available modes are "mean_std" - which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a - [-1, 1] range. - output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the - original scale. Note that this is also used for normalizing the training targets. + input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents + the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents + the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to + a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX) vision_backbone: Name of the torchvision resnet backbone to use for encoding images. crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit within the image size. If None, no cropping is done. diff --git a/src/lerobot/policies/tdmpc/configuration_tdmpc.py b/src/lerobot/policies/tdmpc/configuration_tdmpc.py index 3c1a29932..3ec493472 100644 --- a/src/lerobot/policies/tdmpc/configuration_tdmpc.py +++ b/src/lerobot/policies/tdmpc/configuration_tdmpc.py @@ -30,7 +30,7 @@ class TDMPCConfig(PreTrainedConfig): camera observations. The parameters you will most likely need to change are the ones which depend on the environment / sensors. - Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`. + Those are: `input_features`, `output_features`, and perhaps `max_random_shift_ratio`. Args: n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google @@ -40,24 +40,12 @@ class TDMPCConfig(PreTrainedConfig): is an alternative to using action repeats. If this is set to more than 1, then we require `n_action_repeats == 1`, `use_mpc == True` and `n_action_steps <= horizon`. Note that this approach of using multiple steps from the plan is not in the original implementation. - input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents - the input data name, and the value is a list indicating the dimensions of the corresponding data. - For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96], - indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't - include batch dimension or temporal dimension. - output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents - the output data name, and the value is a list indicating the dimensions of the corresponding data. - For example, "action" refers to an output shape of [14], indicating 14-dimensional actions. - Importantly, `output_shapes` doesn't include batch dimension or temporal dimension. - input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), - and the value specifies the normalization mode to apply. The two available modes are "mean_std" - which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a - [-1, 1] range. Note that here this defaults to None meaning inputs are not normalized. This is to - match the original implementation. - output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the - original scale. Note that this is also used for normalizing the training targets. NOTE: Clipping - to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with "min_max" - normalization mode here. + input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents + the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents + the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to + a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX) image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding. state_encoder_hidden_dim: Hidden dimension for MLP used for state vector encoding. latent_dim: Observation's latent embedding dimension. diff --git a/src/lerobot/policies/vqbet/configuration_vqbet.py b/src/lerobot/policies/vqbet/configuration_vqbet.py index 44ada9f17..32906e528 100644 --- a/src/lerobot/policies/vqbet/configuration_vqbet.py +++ b/src/lerobot/policies/vqbet/configuration_vqbet.py @@ -32,7 +32,7 @@ class VQBeTConfig(PreTrainedConfig): Defaults are configured for training with PushT providing proprioceptive and single camera observations. The parameters you will most likely need to change are the ones which depend on the environment / sensors. - Those are: `input_shapes` and `output_shapes`. + Those are: `input_features` and `output_features`. Notes on the inputs and outputs: - "observation.state" is required as an input key. @@ -46,21 +46,12 @@ class VQBeTConfig(PreTrainedConfig): current step and additional steps going back). n_action_pred_token: Total number of current token and future tokens that VQ-BeT predicts. action_chunk_size: Action chunk size of each action prediction token. - input_shapes: A dictionary defining the shapes of the input data for the policy. - The key represents the input data name, and the value is a list indicating the dimensions - of the corresponding data. For example, "observation.image" refers to an input from - a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution. - Importantly, shapes doesnt include batch dimension or temporal dimension. - output_shapes: A dictionary defining the shapes of the output data for the policy. - The key represents the output data name, and the value is a list indicating the dimensions - of the corresponding data. For example, "action" refers to an output shape of [14], indicating - 14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension. - input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), - and the value specifies the normalization mode to apply. The two available modes are "mean_std" - which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a - [-1, 1] range. - output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the - original scale. Note that this is also used for normalizing the training targets. + input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents + the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents + the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to + a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX) vision_backbone: Name of the torchvision resnet backbone to use for encoding images. crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit within the image size. If None, no cropping is done. From 14a15f90e762170209d283c3545523549841ca3d Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 2 Feb 2026 22:14:03 +0100 Subject: [PATCH 023/131] Add missing RL config options: add_ee_pose_to_observation and gripper_penalty_in_reward (#2873) * fix(RL) add missing config arguments * respond to copilot review * fix(revert penalty in reward): reverting gripper penalty addition in reward. This is already done in compute_loss_discrete_critic. --------- Co-authored-by: CarolinePascal --- src/lerobot/envs/configs.py | 1 + src/lerobot/processor/hil_processor.py | 22 ++++++++++++---------- src/lerobot/rl/gym_manipulator.py | 12 ++++++++++-- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index cd88b37bc..9c1c083a4 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -205,6 +205,7 @@ class ObservationConfig: add_joint_velocity_to_observation: bool = False add_current_to_observation: bool = False + add_ee_pose_to_observation: bool = False display_cameras: bool = False diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py index 6d44ed8cb..24b5628fa 100644 --- a/src/lerobot/processor/hil_processor.py +++ b/src/lerobot/processor/hil_processor.py @@ -314,7 +314,7 @@ class TimeLimitProcessorStep(TruncatedProcessorStep): @dataclass @ProcessorStepRegistry.register("gripper_penalty_processor") -class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep): +class GripperPenaltyProcessorStep(ProcessorStep): """ Applies a penalty for inefficient gripper usage. @@ -329,26 +329,27 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep): penalty: float = -0.01 max_gripper_pos: float = 30.0 - def complementary_data(self, complementary_data: dict) -> dict: + def __call__(self, transition: EnvTransition) -> EnvTransition: """ Calculates the gripper penalty and adds it to the complementary data. Args: - complementary_data: The incoming complementary data, which should contain - raw joint positions. + transition: The incoming environment transition. Returns: - A new complementary data dictionary with the `discrete_penalty` key added. + The modified transition with the penalty added to complementary data. """ - action = self.transition.get(TransitionKey.ACTION) + new_transition = transition.copy() + action = new_transition.get(TransitionKey.ACTION) + complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) raw_joint_positions = complementary_data.get("raw_joint_positions") if raw_joint_positions is None: - return complementary_data + return new_transition current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None) if current_gripper_pos is None: - return complementary_data + return new_transition # Gripper action is a PolicyAction at this stage gripper_action = action[-1].item() @@ -364,11 +365,12 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep): gripper_penalty = self.penalty * int(gripper_penalty_bool) - # Create new complementary data with penalty info + # Update complementary data with penalty info new_complementary_data = dict(complementary_data) new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty + new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data - return new_complementary_data + return new_transition def get_config(self) -> dict[str, Any]: """ diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index 3d58ae18f..1c1cb752f 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ b/src/lerobot/rl/gym_manipulator.py @@ -412,7 +412,10 @@ def make_processors( if cfg.processor.observation.add_current_to_observation: env_pipeline_steps.append(MotorCurrentProcessorStep(robot=env.robot)) - if kinematics_solver is not None: + add_ee_pose = ( + cfg.processor.observation is not None and cfg.processor.observation.add_ee_pose_to_observation + ) + if kinematics_solver is not None and add_ee_pose: env_pipeline_steps.append( ForwardKinematicsJointsToEEObservation( kinematics=kinematics_solver, @@ -435,7 +438,12 @@ def make_processors( ) # Add gripper penalty processor if gripper config exists and enabled - if cfg.processor.gripper is not None and cfg.processor.gripper.use_gripper: + # Only add if max_gripper_pos is explicitly configured (required for normalization) + if ( + cfg.processor.gripper is not None + and cfg.processor.gripper.use_gripper + and cfg.processor.max_gripper_pos is not None + ): env_pipeline_steps.append( GripperPenaltyProcessorStep( penalty=cfg.processor.gripper.gripper_penalty, From a6370dd783c1048096b9596853beccc08a7b0bbd Mon Sep 17 00:00:00 2001 From: Iori Yanokura Date: Tue, 3 Feb 2026 22:17:04 +0900 Subject: [PATCH 024/131] fix(wandb): truncate init tags to 64-character limit (#995) --- src/lerobot/rl/wandb_utils.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/lerobot/rl/wandb_utils.py b/src/lerobot/rl/wandb_utils.py index 7b7f8a57b..ee30b75df 100644 --- a/src/lerobot/rl/wandb_utils.py +++ b/src/lerobot/rl/wandb_utils.py @@ -26,8 +26,21 @@ from lerobot.configs.train import TrainPipelineConfig from lerobot.utils.constants import PRETRAINED_MODEL_DIR -def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str: +def cfg_to_group( + cfg: TrainPipelineConfig, return_list: bool = False, truncate_tags: bool = False, max_tag_length: int = 64 +) -> list[str] | str: """Return a group name for logging. Optionally returns group name as list.""" + + def _maybe_truncate(tag: str) -> str: + """Truncate tag to max_tag_length characters if required. + + wandb rejects tags longer than 64 characters. + See: https://github.com/wandb/wandb/blob/main/wandb/sdk/wandb_settings.py + """ + if len(tag) <= max_tag_length: + return tag + return tag[:max_tag_length] + lst = [ f"policy:{cfg.policy.type}", f"seed:{cfg.seed}", @@ -36,6 +49,8 @@ def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[st lst.append(f"dataset:{cfg.dataset.repo_id}") if cfg.env is not None: lst.append(f"env:{cfg.env.type}") + if truncate_tags: + lst = [_maybe_truncate(tag) for tag in lst] return lst if return_list else "-".join(lst) @@ -83,7 +98,7 @@ class WandBLogger: entity=self.cfg.entity, name=self.job_name, notes=self.cfg.notes, - tags=cfg_to_group(cfg, return_list=True), + tags=cfg_to_group(cfg, return_list=True, truncate_tags=True), dir=self.log_dir, config=cfg.to_dict(), # TODO(rcadene): try set to True From 0f392484458cb5ebca0310c0c4c47390a31c80ed Mon Sep 17 00:00:00 2001 From: jwang078 Date: Tue, 3 Feb 2026 13:19:00 -0500 Subject: [PATCH 025/131] Small docstring fix in diffusion configuration (#2847) --- src/lerobot/policies/diffusion/configuration_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/policies/diffusion/configuration_diffusion.py b/src/lerobot/policies/diffusion/configuration_diffusion.py index 8322ca337..8ac0920dd 100644 --- a/src/lerobot/policies/diffusion/configuration_diffusion.py +++ b/src/lerobot/policies/diffusion/configuration_diffusion.py @@ -64,7 +64,7 @@ class DiffusionConfig(PreTrainedConfig): use_group_norm: Whether to replace batch normalization with group normalization in the backbone. The group sizes are set to be about 16 (to be precise, feature_dim // 16). spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax. - use_separate_rgb_encoders_per_camera: Whether to use a separate RGB encoder for each camera view. + use_separate_rgb_encoder_per_camera: Whether to use a separate RGB encoder for each camera view. down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet. You may provide a variable number of dimensions, therefore also controlling the degree of downsampling. From 97e7e0f9ed8831daee04a6e5f67d777f689c87e4 Mon Sep 17 00:00:00 2001 From: Reece O'Mahoney <66252930+reeceomahoney@users.noreply.github.com> Date: Thu, 5 Feb 2026 14:39:58 +0000 Subject: [PATCH 026/131] feat(datasets): improve image transform support (#2885) * improve image transform support * add tests * Add stricter transform check and extra test * improve subclass check --- src/lerobot/datasets/transforms.py | 19 ++++++++++--------- tests/datasets/test_image_transforms.py | 24 ++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/src/lerobot/datasets/transforms.py b/src/lerobot/datasets/transforms.py index beacc48d9..5240619cb 100644 --- a/src/lerobot/datasets/transforms.py +++ b/src/lerobot/datasets/transforms.py @@ -216,16 +216,17 @@ class ImageTransformsConfig: def make_transform_from_config(cfg: ImageTransformConfig): - if cfg.type == "Identity": - return v2.Identity(**cfg.kwargs) - elif cfg.type == "ColorJitter": - return v2.ColorJitter(**cfg.kwargs) - elif cfg.type == "SharpnessJitter": + if cfg.type == "SharpnessJitter": return SharpnessJitter(**cfg.kwargs) - elif cfg.type == "RandomAffine": - return v2.RandomAffine(**cfg.kwargs) - else: - raise ValueError(f"Transform '{cfg.type}' is not valid.") + + transform_cls = getattr(v2, cfg.type, None) + if isinstance(transform_cls, type) and issubclass(transform_cls, Transform): + return transform_cls(**cfg.kwargs) + + raise ValueError( + f"Transform '{cfg.type}' is not valid. It must be a class in " + f"torchvision.transforms.v2 or 'SharpnessJitter'." + ) class ImageTransforms(Transform): diff --git a/tests/datasets/test_image_transforms.py b/tests/datasets/test_image_transforms.py index 8a66ceb24..ef7e8c395 100644 --- a/tests/datasets/test_image_transforms.py +++ b/tests/datasets/test_image_transforms.py @@ -390,6 +390,30 @@ def test_sharpness_jitter_invalid_range_max_smaller(): SharpnessJitter((2.0, 0.1)) +def test_make_transform_from_config_with_v2_resize(img_tensor_factory): + img_tensor = img_tensor_factory() + tf_cfg = ImageTransformConfig(type="Resize", kwargs={"size": (32, 32)}) + tf = make_transform_from_config(tf_cfg) + assert isinstance(tf, v2.Resize) + output = tf(img_tensor) + assert output.shape[-2:] == (32, 32) + + +def test_make_transform_from_config_with_v2_identity(img_tensor_factory): + img_tensor = img_tensor_factory() + tf_cfg = ImageTransformConfig(type="Identity", kwargs={}) + tf = make_transform_from_config(tf_cfg) + assert isinstance(tf, v2.Identity) + output = tf(img_tensor) + assert output.shape == img_tensor.shape + + +def test_make_transform_from_config_invalid_type(): + tf_cfg = ImageTransformConfig(type="NotARealTransform", kwargs={}) + with pytest.raises(ValueError, match="not valid"): + make_transform_from_config(tf_cfg) + + def test_save_all_transforms(img_tensor_factory, tmp_path): img_tensor = img_tensor_factory() tf_cfg = ImageTransformsConfig(enable=True) From e14bdf57d055e85ebc8a684efd2e4b9a4c7b6a37 Mon Sep 17 00:00:00 2001 From: Reece O'Mahoney <66252930+reeceomahoney@users.noreply.github.com> Date: Mon, 9 Feb 2026 13:46:12 +0000 Subject: [PATCH 027/131] Convert tensors to scalars (#2903) Co-authored-by: Steven Palma --- src/lerobot/policies/smolvla/modeling_smolvla.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index c611e9ba2..60b968a42 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -378,16 +378,16 @@ class SmolVLAPolicy(PreTrainedPolicy): actions_is_pad = batch.get("actions_id_pad") loss_dict = {} losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) - loss_dict["losses_after_forward"] = losses.clone() + loss_dict["losses_after_forward"] = losses.clone().mean().item() if actions_is_pad is not None: in_episode_bound = ~actions_is_pad losses = losses * in_episode_bound.unsqueeze(-1) - loss_dict["losses_after_in_ep_bound"] = losses.clone() + loss_dict["losses_after_in_ep_bound"] = losses.clone().mean().item() # Remove padding losses = losses[:, :, : self.config.max_action_dim] - loss_dict["losses_after_rm_padding"] = losses.clone() + loss_dict["losses_after_rm_padding"] = losses.clone().mean().item() if reduction == "none": # Return per-sample losses (B,) by averaging over time and action dims From 489cb7b6b9a39b569aaf02ff26df6725e9b36285 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 9 Feb 2026 16:58:32 +0100 Subject: [PATCH 028/131] fix(scripts): correct can import check (#2937) --- src/lerobot/scripts/lerobot_setup_can.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lerobot/scripts/lerobot_setup_can.py b/src/lerobot/scripts/lerobot_setup_can.py index 55de74724..a31727ea4 100644 --- a/src/lerobot/scripts/lerobot_setup_can.py +++ b/src/lerobot/scripts/lerobot_setup_can.py @@ -45,7 +45,7 @@ from dataclasses import dataclass, field import draccus -from lerobot.utils.import_utils import is_package_available +from lerobot.utils.import_utils import _can_available MOTOR_NAMES = { 0x01: "joint_1", @@ -336,7 +336,7 @@ def run_speed(cfg: CANSetupConfig): @draccus.wrap() def setup_can(cfg: CANSetupConfig): - if not is_package_available("can"): + if not _can_available: print("Error: python-can not installed. Install with: pip install python-can") sys.exit(1) From cca0296cd6f0f281c5fd4628e836403628b59a05 Mon Sep 17 00:00:00 2001 From: Stepan Feduniak Date: Tue, 10 Feb 2026 13:55:11 +0100 Subject: [PATCH 029/131] fix(pipeline): use FeatureType for STATE features in Libero processor (#2888) * fix the types * pre-commit --------- Co-authored-by: Jade Choghari Co-authored-by: Steven Palma --- src/lerobot/processor/env_processor.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/lerobot/processor/env_processor.py b/src/lerobot/processor/env_processor.py index 8d42bfdb7..a77e066cf 100644 --- a/src/lerobot/processor/env_processor.py +++ b/src/lerobot/processor/env_processor.py @@ -17,7 +17,7 @@ from dataclasses import dataclass import torch -from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR from .pipeline import ObservationProcessorStep, ProcessorStepRegistry @@ -92,7 +92,7 @@ class LiberoProcessorStep(ObservationProcessorStep): # copy over non-STATE features for ft, feats in features.items(): - if ft != PipelineFeatureType.STATE: + if ft != FeatureType.STATE: new_features[ft] = feats.copy() # rebuild STATE features @@ -100,13 +100,11 @@ class LiberoProcessorStep(ObservationProcessorStep): # add our new flattened state state_feats[OBS_STATE] = PolicyFeature( - key=OBS_STATE, + type=FeatureType.STATE, shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)] - dtype="float32", - description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."), ) - new_features[PipelineFeatureType.STATE] = state_feats + new_features[FeatureType.STATE] = state_feats return new_features From 5eba4ce6f453c2dfe4458b037bf3612df22f81ee Mon Sep 17 00:00:00 2001 From: Aoqun Jin Date: Tue, 10 Feb 2026 21:39:17 +0800 Subject: [PATCH 030/131] Change LIBERO init_state_id when reset. (#2899) * Change LIBERO init_state_id when reset. Signed-off-by: Aoqun Jin * Change LIBERO init_state_id when reset. Signed-off-by: Aoqun Jin * pre-commit run --------- Signed-off-by: Aoqun Jin Co-authored-by: Jade Choghari Co-authored-by: Steven Palma --- src/lerobot/envs/libero.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py index 96c5cf102..d20dae8ea 100644 --- a/src/lerobot/envs/libero.py +++ b/src/lerobot/envs/libero.py @@ -112,6 +112,7 @@ class LiberoEnv(gym.Env): visualization_height: int = 480, init_states: bool = True, episode_index: int = 0, + n_envs: int = 1, camera_name_mapping: dict[str, str] | None = None, num_steps_wait: int = 10, control_mode: str = "relative", @@ -145,7 +146,9 @@ class LiberoEnv(gym.Env): self.episode_length = episode_length # Load once and keep self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None - self._init_state_id = self.episode_index # tie each sub-env to a fixed init state + self._reset_stride = n_envs # when performing a reset, append `_reset_stride` to `init_state_id`. + + self.init_state_id = self.episode_index # tie each sub-env to a fixed init state self._env = self._make_envs_task(task_suite, self.task_id) default_steps = 500 @@ -295,7 +298,8 @@ class LiberoEnv(gym.Env): self._env.seed(seed) raw_obs = self._env.reset() if self.init_states and self._init_states is not None: - raw_obs = self._env.set_init_state(self._init_states[self._init_state_id]) + raw_obs = self._env.set_init_state(self._init_states[self.init_state_id % len(self._init_states)]) + self.init_state_id += self._reset_stride # Change init_state_id when reset # After reset, objects may be unstable (slightly floating, intersecting, etc.). # Step the simulator with a no-op action for a few frames so everything settles. @@ -373,6 +377,7 @@ def _make_env_fns( init_states=init_states, episode_length=episode_length, episode_index=episode_index, + n_envs=n_envs, control_mode=control_mode, **local_kwargs, ) From d2d01399d6773427347a37be401ec6ea35fa0e15 Mon Sep 17 00:00:00 2001 From: Jai Kumaar Ratadia Date: Tue, 10 Feb 2026 14:18:32 +0000 Subject: [PATCH 031/131] docs: clarify installation steps are sequential, not optional (#2925) * docs: clarify installation steps are sequential, not optional Add intro paragraph noting conda is one path (not the only one) and number the three sections as steps so readers understand miniforge and environment setup are prerequisites, not independent choices. * Update installation guide link for LeRobot Signed-off-by: Jai Kumaar Ratadia * Fix link formatting in installation guide again Signed-off-by: Jai Kumaar Ratadia --------- Signed-off-by: Jai Kumaar Ratadia Co-authored-by: Steven Palma --- docs/source/installation.mdx | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 44d8c7034..8cc83843e 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -1,13 +1,15 @@ # Installation -## Install [`miniforge`](https://conda-forge.org/download/) +This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.10 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-). + +## Step 1: Install [`miniforge`](https://conda-forge.org/download/) ```bash wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh" bash Miniforge3-$(uname)-$(uname -m).sh ``` -## Environment Setup +## Step 2: Environment Setup Create a virtual environment with Python 3.10, using conda: @@ -38,7 +40,7 @@ conda install ffmpeg -c conda-forge > > - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`. -## Install LeRobot 🤗 +## Step 3: Install LeRobot 🤗 ### From Source From 778db19a178b2fee539934419b1475b1f411bac9 Mon Sep 17 00:00:00 2001 From: whats2000 <60466660+whats2000@users.noreply.github.com> Date: Tue, 10 Feb 2026 22:21:40 +0800 Subject: [PATCH 032/131] [Bug Fix] fix(ci): prevent runner group error on fork pushes (#2911) * fix(ci): prevent runner group error on fork pushes Add repository check to unbound_deps_tests workflow to ensure aws-general-8-plus runner group is only used on main repository, preventing 'Required runner group not found' errors on forks. * fix(ci): use gating job to prevent runner allocation on forks The previous approach failed because GitHub evaluates runs-on before if conditions. Now using a check-repo job that runs on ubuntu-latest first, and all jobs with special runners depend on it and check its output before being scheduled. * fix(ci): add gating job to full_tests to prevent runner allocation on forks Apply the same gating pattern used in unbound_deps_tests to full_tests.yml to prevent GitHub from trying to allocate custom runners when workflows run on forks. The check-repo job runs first on ubuntu-latest and all jobs with custom runners depend on it and check its output. * fix(ci): add repository check to unbound_deps_tests workflow Add 'if: github.repository == huggingface/lerobot' check to build-and-push-docker job to prevent runner group access errors on forks, matching the pattern used in nightly.yml * fix(ci): add repository check to full_tests workflow Add 'if: github.repository == huggingface/lerobot' check to build-and-push-docker and gpu-tests jobs to prevent runner group access errors on forks * refactor(ci): remove redundant check from gpu-tests job gpu-tests depends on build-and-push-docker via needs, so it will automatically skip when the parent job is skipped * refactor(ci): remove unnecessary fork check from full-tests job full-tests runs on ubuntu-latest which is available to all forks, no need to restrict it --------- Co-authored-by: Steven Palma --- .github/workflows/full_tests.yml | 8 +++++--- .github/workflows/unbound_deps_tests.yml | 1 + 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/workflows/full_tests.yml b/.github/workflows/full_tests.yml index 4dce3121a..fd5e422b3 100644 --- a/.github/workflows/full_tests.yml +++ b/.github/workflows/full_tests.yml @@ -101,9 +101,11 @@ jobs: runs-on: group: aws-general-8-plus if: | - (github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) || - github.event_name == 'push' || - github.event_name == 'workflow_dispatch' + github.repository == 'huggingface/lerobot' && ( + (github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) || + github.event_name == 'push' || + github.event_name == 'workflow_dispatch' + ) outputs: image_tag: ${{ steps.set_tag.outputs.image_tag }} env: diff --git a/.github/workflows/unbound_deps_tests.yml b/.github/workflows/unbound_deps_tests.yml index a75ecc121..3f4ea3316 100644 --- a/.github/workflows/unbound_deps_tests.yml +++ b/.github/workflows/unbound_deps_tests.yml @@ -91,6 +91,7 @@ jobs: name: Build and Push Docker runs-on: group: aws-general-8-plus + if: github.repository == 'huggingface/lerobot' outputs: image_tag: ${{ env.DOCKER_IMAGE_NAME }} env: From 35363c5798d129d7667c2efa43ddfa342639a35a Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 10 Feb 2026 17:35:39 +0100 Subject: [PATCH 033/131] chore(linter): ensure motors module passes MyPy type checks (#2939) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: ensure motors module passes MyPy type checks This commit fixes 62 mypy type errors in the motors module by: - Updating Protocol classes (PortHandler, PacketHandler, GroupSyncRead, GroupSyncWrite) to use class-level attribute declarations instead of __init__ body declarations - Adding missing `broadcastPing` method to PacketHandler Protocol - Fixing return type annotations (e.g., `_get_motor_model` returns str, not int) - Fixing parameter types to use `Sequence` for covariant list parameters - Fixing `Mapping` for covariant dict value types in `_normalize` - Updating method signatures to be consistent across parent and child classes (disable_torque, enable_torque, _get_half_turn_homings) - Adding explicit `int()` casts for MotorCalibration arguments - Adding explicit `return None` for functions returning Optional types - Adding type annotations for variables like `data_list: dict[int, int]` - Using `# type: ignore[method-assign]` for intentional monkeypatch - Fixing variable references (using `self.groups` instead of `groups`) Fixes #1723 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 * chore(style): pre-commit after main merge * chore(linter): solve comments * chore(linter): apply pre-commit fixes to damiao * chore(linter): more fixes to damiao --------- Co-authored-by: yurekami Co-authored-by: Claude Opus 4.5 --- pyproject.toml | 6 +- src/lerobot/motors/calibration_gui.py | 10 +- src/lerobot/motors/damiao/damiao.py | 42 ++++- src/lerobot/motors/dynamixel/dynamixel.py | 16 +- src/lerobot/motors/feetech/feetech.py | 18 +-- src/lerobot/motors/motors_bus.py | 183 +++++++++++----------- 6 files changed, 157 insertions(+), 118 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 210d70b6b..c4b1c547e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -360,9 +360,9 @@ ignore_errors = false module = "lerobot.cameras.*" ignore_errors = false -# [[tool.mypy.overrides]] -# module = "lerobot.motors.*" -# ignore_errors = false +[[tool.mypy.overrides]] +module = "lerobot.motors.*" +ignore_errors = false # [[tool.mypy.overrides]] # module = "lerobot.robots.*" diff --git a/src/lerobot/motors/calibration_gui.py b/src/lerobot/motors/calibration_gui.py index 02bba454f..3410cb28a 100644 --- a/src/lerobot/motors/calibration_gui.py +++ b/src/lerobot/motors/calibration_gui.py @@ -221,7 +221,7 @@ class RangeFinderGUI: self.bus = bus self.groups = groups if groups is not None else {"all": list(bus.motors)} - self.group_names = list(groups) + self.group_names = list(self.groups) self.current_group = self.group_names[0] if not bus.is_connected: @@ -230,18 +230,20 @@ class RangeFinderGUI: self.calibration = bus.read_calibration() self.res_table = bus.model_resolution_table self.present_cache = { - m: bus.read("Present_Position", m, normalize=False) for motors in groups.values() for m in motors + m: bus.read("Present_Position", m, normalize=False) + for motors in self.groups.values() + for m in motors } pygame.init() self.font = pygame.font.Font(None, FONT_SIZE) - label_pad = max(self.font.size(m)[0] for ms in groups.values() for m in ms) + label_pad = max(self.font.size(m)[0] for ms in self.groups.values() for m in ms) self.label_pad = label_pad width = 40 + label_pad + BAR_LEN + 6 + BTN_W + 10 + SAVE_W + 10 self.controls_bottom = 10 + SAVE_H self.base_y = self.controls_bottom + TOP_GAP - height = self.base_y + PADDING_Y * len(groups[self.current_group]) + 40 + height = self.base_y + PADDING_Y * len(self.groups[self.current_group]) + 40 self.screen = pygame.display.set_mode((width, height)) pygame.display.set_caption("Motors range finder") diff --git a/src/lerobot/motors/damiao/damiao.py b/src/lerobot/motors/damiao/damiao.py index c79f8d17e..95a9e70d1 100644 --- a/src/lerobot/motors/damiao/damiao.py +++ b/src/lerobot/motors/damiao/damiao.py @@ -211,6 +211,9 @@ class DamiaoMotorsBus(MotorsBusBase): logger.info("Starting handshake with motors...") # Drain any pending messages + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + while self.canbus.recv(timeout=0.01): pass @@ -283,6 +286,10 @@ class DamiaoMotorsBus(MotorsBusBase): recv_id = self._get_motor_recv_id(motor) data = [0xFF] * 7 + [command_byte] msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd) + + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + self.canbus.send(msg) if msg := self._recv_motor_response(expected_recv_id=recv_id): self._process_response(motor_name, msg) @@ -341,6 +348,10 @@ class DamiaoMotorsBus(MotorsBusBase): recv_id = self._get_motor_recv_id(motor) data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0] msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd) + + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + self.canbus.send(msg) return self._recv_motor_response(expected_recv_id=recv_id) @@ -356,6 +367,10 @@ class DamiaoMotorsBus(MotorsBusBase): Returns: CAN message if received, None otherwise """ + + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + try: start_time = time.time() messages_seen = [] @@ -394,10 +409,13 @@ class DamiaoMotorsBus(MotorsBusBase): Returns: Dictionary mapping recv_id to CAN message """ - responses = {} + responses: dict[int, can.Message] = {} expected_set = set(expected_recv_ids) start_time = time.time() + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + try: while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout: # 100us poll timeout @@ -461,6 +479,9 @@ class DamiaoMotorsBus(MotorsBusBase): motor_name = self._get_motor_name(motor) motor_type = self._motor_types[motor_name] + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque) msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd) self.canbus.send(msg) @@ -488,6 +509,9 @@ class DamiaoMotorsBus(MotorsBusBase): recv_id_to_motor: dict[int, str] = {} + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + # Step 1: Send all MIT control commands for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items(): motor_id = self._get_motor_id(motor) @@ -656,6 +680,10 @@ class DamiaoMotorsBus(MotorsBusBase): def _batch_refresh(self, motors: list[str]) -> None: """Internal helper to refresh a list of motors and update cache.""" + + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + # Send refresh commands for motor in motors: motor_id = self._get_motor_id(motor) @@ -678,10 +706,14 @@ class DamiaoMotorsBus(MotorsBusBase): else: logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.") - def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None: + def sync_write(self, data_name: str, values: dict[str, Value]) -> None: """ Write values to multiple motors simultaneously. Positions are always in degrees. """ + + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + if data_name in ("Kp", "Kd"): key = data_name.lower() for motor, val in values.items(): @@ -690,6 +722,8 @@ class DamiaoMotorsBus(MotorsBusBase): elif data_name == "Goal_Position": # Step 1: Send all MIT control commands recv_id_to_motor: dict[int, str] = {} + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") for motor, value_degrees in values.items(): motor_id = self._get_motor_id(motor) motor_name = self._get_motor_name(motor) @@ -732,9 +766,9 @@ class DamiaoMotorsBus(MotorsBusBase): def record_ranges_of_motion( self, - motors: NameOrID | list[NameOrID] | None = None, + motors: str | list[str] | None = None, display_values: bool = True, - ) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: + ) -> tuple[dict[str, Value], dict[str, Value]]: """ Interactively record the min/max values of each motor in degrees. diff --git a/src/lerobot/motors/dynamixel/dynamixel.py b/src/lerobot/motors/dynamixel/dynamixel.py index c6752ee96..bca455dc5 100644 --- a/src/lerobot/motors/dynamixel/dynamixel.py +++ b/src/lerobot/motors/dynamixel/dynamixel.py @@ -181,10 +181,10 @@ class DynamixelMotorsBus(SerialMotorsBus): for motor, m in self.motors.items(): calibration[motor] = MotorCalibration( id=m.id, - drive_mode=drive_modes[motor], - homing_offset=offsets[motor], - range_min=mins[motor], - range_max=maxes[motor], + drive_mode=int(drive_modes[motor]), + homing_offset=int(offsets[motor]), + range_min=int(mins[motor]), + range_max=int(maxes[motor]), ) return calibration @@ -198,7 +198,7 @@ class DynamixelMotorsBus(SerialMotorsBus): if cache: self.calibration = calibration_dict - def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry) @@ -206,7 +206,7 @@ class DynamixelMotorsBus(SerialMotorsBus): addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable") self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry) - def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry) @@ -235,7 +235,7 @@ class DynamixelMotorsBus(SerialMotorsBus): On Dynamixel Motors: Present_Position = Actual_Position + Homing_Offset """ - half_turn_homings = {} + half_turn_homings: dict[NameOrID, Value] = {} for motor, pos in positions.items(): model = self._get_motor_model(motor) max_res = self.model_resolution_table[model] - 1 @@ -258,6 +258,6 @@ class DynamixelMotorsBus(SerialMotorsBus): if raise_on_error: raise ConnectionError(self.packet_handler.getTxRxResult(comm)) - return + return None return {id_: data[0] for id_, data in data_list.items()} diff --git a/src/lerobot/motors/feetech/feetech.py b/src/lerobot/motors/feetech/feetech.py index 7ce3388b6..58a65310d 100644 --- a/src/lerobot/motors/feetech/feetech.py +++ b/src/lerobot/motors/feetech/feetech.py @@ -126,7 +126,7 @@ class FeetechMotorsBus(SerialMotorsBus): self.port_handler = scs.PortHandler(self.port) # HACK: monkeypatch - self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( + self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( # type: ignore[method-assign] self.port_handler, scs.PortHandler ) self.packet_handler = scs.PacketHandler(protocol_version) @@ -262,9 +262,9 @@ class FeetechMotorsBus(SerialMotorsBus): calibration[motor] = MotorCalibration( id=m.id, drive_mode=0, - homing_offset=offsets[motor], - range_min=mins[motor], - range_max=maxes[motor], + homing_offset=int(offsets[motor]), + range_min=int(mins[motor]), + range_max=int(maxes[motor]), ) return calibration @@ -284,7 +284,7 @@ class FeetechMotorsBus(SerialMotorsBus): On Feetech Motors: Present_Position = Actual_Position - Homing_Offset """ - half_turn_homings = {} + half_turn_homings: dict[NameOrID, Value] = {} for motor, pos in positions.items(): model = self._get_motor_model(motor) max_res = self.model_resolution_table[model] - 1 @@ -292,7 +292,7 @@ class FeetechMotorsBus(SerialMotorsBus): return half_turn_homings - def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry) self.write("Lock", motor, 0, num_retry=num_retry) @@ -303,7 +303,7 @@ class FeetechMotorsBus(SerialMotorsBus): addr, length = get_address(self.model_ctrl_table, model, "Lock") self._write(addr, length, motor, 0, num_retry=num_retry) - def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry) self.write("Lock", motor, 1, num_retry=num_retry) @@ -334,7 +334,7 @@ class FeetechMotorsBus(SerialMotorsBus): def _broadcast_ping(self) -> tuple[dict[int, int], int]: import scservo_sdk as scs - data_list = {} + data_list: dict[int, int] = {} status_length = 6 @@ -414,7 +414,7 @@ class FeetechMotorsBus(SerialMotorsBus): if not self._is_comm_success(comm): if raise_on_error: raise ConnectionError(self.packet_handler.getTxRxResult(comm)) - return + return None ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)} if ids_errors: diff --git a/src/lerobot/motors/motors_bus.py b/src/lerobot/motors/motors_bus.py index c04f718b6..bc3ffb7e2 100644 --- a/src/lerobot/motors/motors_bus.py +++ b/src/lerobot/motors/motors_bus.py @@ -23,6 +23,7 @@ from __future__ import annotations import abc import logging +from collections.abc import Sequence from contextlib import contextmanager from dataclasses import dataclass from enum import Enum @@ -93,7 +94,7 @@ class MotorsBusBase(abc.ABC): pass @abc.abstractmethod - def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None: + def sync_write(self, data_name: str, values: dict[str, Value]) -> None: """Write values to multiple motors.""" pass @@ -179,15 +180,16 @@ class Motor: class PortHandler(Protocol): - def __init__(self, port_name): - self.is_open: bool - self.baudrate: int - self.packet_start_time: float - self.packet_timeout: float - self.tx_time_per_byte: float - self.is_using: bool - self.port_name: str - self.ser: serial.Serial + is_open: bool + baudrate: int + packet_start_time: float + packet_timeout: float + tx_time_per_byte: float + is_using: bool + port_name: str + ser: serial.Serial + + def __init__(self, port_name: str) -> None: ... def openPort(self): ... def closePort(self): ... @@ -240,19 +242,22 @@ class PacketHandler(Protocol): def regWriteTxRx(self, port, id, address, length, data): ... def syncReadTx(self, port, start_address, data_length, param, param_length): ... def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ... + def broadcastPing(self, port): ... class GroupSyncRead(Protocol): - def __init__(self, port, ph, start_address, data_length): - self.port: str - self.ph: PortHandler - self.start_address: int - self.data_length: int - self.last_result: bool - self.is_param_changed: bool - self.param: list - self.data_dict: dict + port: str + ph: PortHandler + start_address: int + data_length: int + last_result: bool + is_param_changed: bool + param: list + data_dict: dict + def __init__( + self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int + ) -> None: ... def makeParam(self): ... def addParam(self, id): ... def removeParam(self, id): ... @@ -265,15 +270,17 @@ class GroupSyncRead(Protocol): class GroupSyncWrite(Protocol): - def __init__(self, port, ph, start_address, data_length): - self.port: str - self.ph: PortHandler - self.start_address: int - self.data_length: int - self.is_param_changed: bool - self.param: list - self.data_dict: dict + port: str + ph: PortHandler + start_address: int + data_length: int + is_param_changed: bool + param: list + data_dict: dict + def __init__( + self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int + ) -> None: ... def makeParam(self): ... def addParam(self, id, data): ... def removeParam(self, id): ... @@ -400,7 +407,7 @@ class SerialMotorsBus(MotorsBusBase): else: raise TypeError(f"'{motor}' should be int, str.") - def _get_motor_model(self, motor: NameOrID) -> int: + def _get_motor_model(self, motor: NameOrID) -> str: if isinstance(motor, str): return self.motors[motor].model elif isinstance(motor, int): @@ -408,17 +415,19 @@ class SerialMotorsBus(MotorsBusBase): else: raise TypeError(f"'{motor}' should be int, str.") - def _get_motors_list(self, motors: str | list[str] | None) -> list[str]: + def _get_motors_list(self, motors: NameOrID | Sequence[NameOrID] | None) -> list[str]: if motors is None: return list(self.motors) elif isinstance(motors, str): return [motors] - elif isinstance(motors, list): - return motors.copy() + elif isinstance(motors, int): + return [self._id_to_name(motors)] + elif isinstance(motors, Sequence): + return [m if isinstance(m, str) else self._id_to_name(m) for m in motors] else: raise TypeError(motors) - def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]: + def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> dict[int, Value]: if isinstance(values, (int | float)): return dict.fromkeys(self.ids, values) elif isinstance(values, dict): @@ -640,18 +649,19 @@ class SerialMotorsBus(MotorsBusBase): pass @abc.abstractmethod - def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: """Enable torque on selected motors. Args: - motor (int): Same semantics as :pymeth:`disable_torque`. Defaults to `None`. + motors (int | str | list[str] | None, optional): Same semantics as :pymeth:`disable_torque`. + Defaults to `None`. num_retry (int, optional): Number of additional retry attempts on communication failure. Defaults to 0. """ pass @contextmanager - def torque_disabled(self, motors: int | str | list[str] | None = None): + def torque_disabled(self, motors: str | list[str] | None = None): """Context-manager that guarantees torque is re-enabled. This helper is useful to temporarily disable torque when configuring motors. @@ -728,24 +738,19 @@ class SerialMotorsBus(MotorsBusBase): """ pass - def reset_calibration(self, motors: NameOrID | list[NameOrID] | None = None) -> None: + def reset_calibration(self, motors: NameOrID | Sequence[NameOrID] | None = None) -> None: """Restore factory calibration for the selected motors. Homing offset is set to ``0`` and min/max position limits are set to the full usable range. The in-memory :pyattr:`calibration` is cleared. Args: - motors (NameOrID | list[NameOrID] | None, optional): Selection of motors. `None` (default) + motors (NameOrID | Sequence[NameOrID] | None, optional): Selection of motors. `None` (default) resets every motor. """ - if motors is None: - motors = list(self.motors) - elif isinstance(motors, (str | int)): - motors = [motors] - elif not isinstance(motors, list): - raise TypeError(motors) + motor_names = self._get_motors_list(motors) - for motor in motors: + for motor in motor_names: model = self._get_motor_model(motor) max_res = self.model_resolution_table[model] - 1 self.write("Homing_Offset", motor, 0, normalize=False) @@ -754,7 +759,9 @@ class SerialMotorsBus(MotorsBusBase): self.calibration = {} - def set_half_turn_homings(self, motors: NameOrID | list[NameOrID] | None = None) -> dict[NameOrID, Value]: + def set_half_turn_homings( + self, motors: NameOrID | Sequence[NameOrID] | None = None + ) -> dict[NameOrID, Value]: """Centre each motor range around its current position. The function computes and writes a homing offset such that the present position becomes exactly one @@ -764,17 +771,12 @@ class SerialMotorsBus(MotorsBusBase): motors (NameOrID | list[NameOrID] | None, optional): Motors to adjust. Defaults to all motors (`None`). Returns: - dict[NameOrID, Value]: Mapping *motor → written homing offset*. + dict[str, Value]: Mapping *motor name → written homing offset*. """ - if motors is None: - motors = list(self.motors) - elif isinstance(motors, (str | int)): - motors = [motors] - elif not isinstance(motors, list): - raise TypeError(motors) + motor_names = self._get_motors_list(motors) - self.reset_calibration(motors) - actual_positions = self.sync_read("Present_Position", motors, normalize=False) + self.reset_calibration(motor_names) + actual_positions = self.sync_read("Present_Position", motor_names, normalize=False) homing_offsets = self._get_half_turn_homings(actual_positions) for motor, offset in homing_offsets.items(): self.write("Homing_Offset", motor, offset) @@ -786,8 +788,8 @@ class SerialMotorsBus(MotorsBusBase): pass def record_ranges_of_motion( - self, motors: NameOrID | list[NameOrID] | None = None, display_values: bool = True - ) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: + self, motors: NameOrID | Sequence[NameOrID] | None = None, display_values: bool = True + ) -> tuple[dict[str, Value], dict[str, Value]]: """Interactively record the min/max encoder values of each motor. Move the joints by hand (with torque disabled) while the method streams live positions. Press @@ -799,30 +801,25 @@ class SerialMotorsBus(MotorsBusBase): display_values (bool, optional): When `True` (default) a live table is printed to the console. Returns: - tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: Two dictionaries *mins* and *maxes* with the + tuple[dict[str, Value], dict[str, Value]]: Two dictionaries *mins* and *maxes* with the extreme values observed for each motor. """ - if motors is None: - motors = list(self.motors) - elif isinstance(motors, (str | int)): - motors = [motors] - elif not isinstance(motors, list): - raise TypeError(motors) + motor_names = self._get_motors_list(motors) - start_positions = self.sync_read("Present_Position", motors, normalize=False) + start_positions = self.sync_read("Present_Position", motor_names, normalize=False) mins = start_positions.copy() maxes = start_positions.copy() user_pressed_enter = False while not user_pressed_enter: - positions = self.sync_read("Present_Position", motors, normalize=False) + positions = self.sync_read("Present_Position", motor_names, normalize=False) mins = {motor: min(positions[motor], min_) for motor, min_ in mins.items()} maxes = {motor: max(positions[motor], max_) for motor, max_ in maxes.items()} if display_values: print("\n-------------------------------------------") print(f"{'NAME':<15} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}") - for motor in motors: + for motor in motor_names: print(f"{motor:<15} | {mins[motor]:>6} | {positions[motor]:>6} | {maxes[motor]:>6}") if enter_pressed(): @@ -830,9 +827,9 @@ class SerialMotorsBus(MotorsBusBase): if display_values and not user_pressed_enter: # Move cursor up to overwrite the previous output - move_cursor_up(len(motors) + 3) + move_cursor_up(len(motor_names) + 3) - same_min_max = [motor for motor in motors if mins[motor] == maxes[motor]] + same_min_max = [motor for motor in motor_names if mins[motor] == maxes[motor]] if same_min_max: raise ValueError(f"Some motors have the same min and max values:\n{pformat(same_min_max)}") @@ -955,12 +952,12 @@ class SerialMotorsBus(MotorsBusBase): if raise_on_error: raise ConnectionError(self.packet_handler.getTxRxResult(comm)) else: - return + return None if self._is_error(error): if raise_on_error: raise RuntimeError(self.packet_handler.getRxPacketError(error)) else: - return + return None return model_number @@ -1007,12 +1004,13 @@ class SerialMotorsBus(MotorsBusBase): err_msg = f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries." value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) - id_value = self._decode_sign(data_name, {id_: value}) + decoded = self._decode_sign(data_name, {id_: value}) if normalize and data_name in self.normalized_data: - id_value = self._normalize(id_value) + normalized = self._normalize(decoded) + return normalized[id_] - return id_value[id_] + return decoded[id_] def _read( self, @@ -1023,7 +1021,7 @@ class SerialMotorsBus(MotorsBusBase): num_retry: int = 0, raise_on_error: bool = True, err_msg: str = "", - ) -> tuple[int, int]: + ) -> tuple[int, int, int]: if length == 1: read_fn = self.packet_handler.read1ByteTxRx elif length == 2: @@ -1073,13 +1071,14 @@ class SerialMotorsBus(MotorsBusBase): model = self.motors[motor].model addr, length = get_address(self.model_ctrl_table, model, data_name) + int_value = int(value) if normalize and data_name in self.normalized_data: - value = self._unnormalize({id_: value})[id_] + int_value = self._unnormalize({id_: value})[id_] - value = self._encode_sign(data_name, {id_: value})[id_] + int_value = self._encode_sign(data_name, {id_: int_value})[id_] - err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." - self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) + err_msg = f"Failed to write '{data_name}' on {id_=} with '{int_value}' after {num_retry + 1} tries." + self._write(addr, length, id_, int_value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) def _write( self, @@ -1113,7 +1112,7 @@ class SerialMotorsBus(MotorsBusBase): def sync_read( self, data_name: str, - motors: str | list[str] | None = None, + motors: NameOrID | Sequence[NameOrID] | None = None, *, normalize: bool = True, num_retry: int = 0, @@ -1122,7 +1121,7 @@ class SerialMotorsBus(MotorsBusBase): Args: data_name (str): Register name. - motors (str | list[str] | None, optional): Motors to query. `None` (default) reads every motor. + motors (NameOrID | Sequence[NameOrID] | None, optional): Motors to query. `None` (default) reads every motor. normalize (bool, optional): Normalisation flag. Defaults to `True`. num_retry (int, optional): Retry attempts. Defaults to `0`. @@ -1143,16 +1142,17 @@ class SerialMotorsBus(MotorsBusBase): addr, length = get_address(self.model_ctrl_table, model, data_name) err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries." - ids_values, _ = self._sync_read( + raw_ids_values, _ = self._sync_read( addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg ) - ids_values = self._decode_sign(data_name, ids_values) + decoded = self._decode_sign(data_name, raw_ids_values) if normalize and data_name in self.normalized_data: - ids_values = self._normalize(ids_values) + normalized = self._normalize(decoded) + return {self._id_to_name(id_): value for id_, value in normalized.items()} - return {self._id_to_name(id_): value for id_, value in ids_values.items()} + return {self._id_to_name(id_): value for id_, value in decoded.items()} def _sync_read( self, @@ -1224,21 +1224,24 @@ class SerialMotorsBus(MotorsBusBase): num_retry (int, optional): Retry attempts. Defaults to `0`. """ - ids_values = self._get_ids_values_dict(values) - models = [self._id_to_model(id_) for id_ in ids_values] + raw_ids_values = self._get_ids_values_dict(values) + models = [self._id_to_model(id_) for id_ in raw_ids_values] if self._has_different_ctrl_tables: assert_same_address(self.model_ctrl_table, models, data_name) model = next(iter(models)) addr, length = get_address(self.model_ctrl_table, model, data_name) + int_ids_values = {id_: int(val) for id_, val in raw_ids_values.items()} if normalize and data_name in self.normalized_data: - ids_values = self._unnormalize(ids_values) + int_ids_values = self._unnormalize(raw_ids_values) - ids_values = self._encode_sign(data_name, ids_values) + int_ids_values = self._encode_sign(data_name, int_ids_values) - err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries." - self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) + err_msg = f"Failed to sync write '{data_name}' with ids_values={int_ids_values} after {num_retry + 1} tries." + self._sync_write( + addr, length, int_ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg + ) def _sync_write( self, From 1ba3975020c8079630ff7dda8fe983ad473d7c12 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 10 Feb 2026 17:49:30 +0100 Subject: [PATCH 034/131] chore: use is_connected decorators (#2948) * chore: use is_connected decorators * chore(robots): add is_connected to bi setups too --- src/lerobot/cameras/opencv/camera_opencv.py | 19 ++++++--------- .../cameras/reachy2_camera/reachy2_camera.py | 14 ++++------- .../cameras/realsense/camera_realsense.py | 23 +++++++------------ src/lerobot/cameras/zmq/camera_zmq.py | 16 +++++-------- src/lerobot/motors/damiao/damiao.py | 16 ++++--------- .../bi_openarm_follower.py | 5 ++++ .../robots/bi_so_follower/bi_so_follower.py | 5 ++++ .../openarm_follower/openarm_follower.py | 15 ++++-------- .../bi_openarm_leader/bi_openarm_leader.py | 4 ++++ .../bi_so_leader/bi_so_leader.py | 4 +++- .../openarm_leader/openarm_leader.py | 11 ++++----- 11 files changed, 57 insertions(+), 75 deletions(-) diff --git a/src/lerobot/cameras/opencv/camera_opencv.py b/src/lerobot/cameras/opencv/camera_opencv.py index d581e1425..465ba7a1b 100644 --- a/src/lerobot/cameras/opencv/camera_opencv.py +++ b/src/lerobot/cameras/opencv/camera_opencv.py @@ -32,7 +32,8 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0" import cv2 # type: ignore # TODO: add type stubs for OpenCV -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected +from lerobot.utils.errors import DeviceNotConnectedError from ..camera import Camera from ..utils import get_cv2_backend, get_cv2_rotation @@ -132,6 +133,7 @@ class OpenCVCamera(Camera): """Checks if the camera is currently connected and opened.""" return isinstance(self.videocapture, cv2.VideoCapture) and self.videocapture.isOpened() + @check_if_already_connected def connect(self, warmup: bool = True) -> None: """ Connects to the OpenCV camera specified in the configuration. @@ -148,8 +150,6 @@ class OpenCVCamera(Camera): ConnectionError: If the specified camera index/path is not found or fails to open. RuntimeError: If the camera opens but fails to apply requested settings. """ - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} is already connected.") # Use 1 thread for OpenCV operations to avoid potential conflicts or # blocking in multi-threaded applications, especially during data collection. @@ -178,6 +178,7 @@ class OpenCVCamera(Camera): logger.info(f"{self} connected.") + @check_if_not_connected def _configure_capture_settings(self) -> None: """ Applies the specified FOURCC, FPS, width, and height settings to the connected camera. @@ -197,8 +198,6 @@ class OpenCVCamera(Camera): to the requested value. DeviceNotConnectedError: If the camera is not connected. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.") # Set FOURCC first (if specified) as it can affect available FPS/resolution options if self.config.fourcc is not None: @@ -348,6 +347,7 @@ class OpenCVCamera(Camera): return frame + @check_if_not_connected def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: """ Reads a single frame synchronously from the camera. @@ -374,9 +374,6 @@ class OpenCVCamera(Camera): f"{self} read() color_mode parameter is deprecated and will be removed in future versions." ) - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -490,6 +487,7 @@ class OpenCVCamera(Camera): self.latest_timestamp = None self.new_frame_event.clear() + @check_if_not_connected def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: """ Reads the latest available frame asynchronously. @@ -512,8 +510,6 @@ class OpenCVCamera(Camera): TimeoutError: If no frame becomes available within the specified timeout. RuntimeError: If an unexpected error occurs. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -533,6 +529,7 @@ class OpenCVCamera(Camera): return frame + @check_if_not_connected def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: """Return the most recent frame captured immediately (Peeking). @@ -548,8 +545,6 @@ class OpenCVCamera(Camera): DeviceNotConnectedError: If the camera is not connected. RuntimeError: If the camera is connected but has not captured any frames yet. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") diff --git a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py index 5cede466d..0c1dc43d8 100644 --- a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py +++ b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py @@ -32,6 +32,7 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" import cv2 # type: ignore # TODO: add type stubs for OpenCV import numpy as np # type: ignore # TODO: add type stubs for numpy +from lerobot.utils.decorators import check_if_not_connected from lerobot.utils.import_utils import _reachy2_sdk_available if TYPE_CHECKING or _reachy2_sdk_available: @@ -123,6 +124,7 @@ class Reachy2Camera(Camera): """ raise NotImplementedError("Camera detection is not implemented for Reachy2 cameras.") + @check_if_not_connected def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: """ Reads a single frame synchronously from the camera. @@ -136,9 +138,6 @@ class Reachy2Camera(Camera): """ start_time = time.perf_counter() - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if self.cam_manager is None: raise DeviceNotConnectedError(f"{self} is not connected.") @@ -184,6 +183,7 @@ class Reachy2Camera(Camera): return frame + @check_if_not_connected def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: """ Same as read() @@ -197,11 +197,10 @@ class Reachy2Camera(Camera): TimeoutError: If no frame becomes available within the specified timeout. RuntimeError: If an unexpected error occurs. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") return self.read() + @check_if_not_connected def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: """Return the most recent frame captured immediately (Peeking). @@ -219,8 +218,6 @@ class Reachy2Camera(Camera): DeviceNotConnectedError: If the camera is not connected. RuntimeError: If the camera is connected but has not captured any frames yet. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.latest_frame is None or self.latest_timestamp is None: raise RuntimeError(f"{self} has not captured any frames yet.") @@ -233,6 +230,7 @@ class Reachy2Camera(Camera): return self.latest_frame + @check_if_not_connected def disconnect(self) -> None: """ Stops the background read thread (if running). @@ -240,8 +238,6 @@ class Reachy2Camera(Camera): Raises: DeviceNotConnectedError: If the camera is already disconnected. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} not connected.") if self.cam_manager is not None: self.cam_manager.disconnect() diff --git a/src/lerobot/cameras/realsense/camera_realsense.py b/src/lerobot/cameras/realsense/camera_realsense.py index e47f25381..d599cdce0 100644 --- a/src/lerobot/cameras/realsense/camera_realsense.py +++ b/src/lerobot/cameras/realsense/camera_realsense.py @@ -30,7 +30,8 @@ try: except Exception as e: logging.info(f"Could not import realsense: {e}") -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected +from lerobot.utils.errors import DeviceNotConnectedError from ..camera import Camera from ..configs import ColorMode @@ -152,6 +153,7 @@ class RealSenseCamera(Camera): """Checks if the camera pipeline is started and streams are active.""" return self.rs_pipeline is not None and self.rs_profile is not None + @check_if_already_connected def connect(self, warmup: bool = True) -> None: """ Connects to the RealSense camera specified in the configuration. @@ -169,8 +171,6 @@ class RealSenseCamera(Camera): ConnectionError: If the camera is found but fails to start the pipeline or no RealSense devices are detected at all. RuntimeError: If the pipeline starts but fails to apply requested settings. """ - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} is already connected.") self.rs_pipeline = rs.pipeline() rs_config = rs.config() @@ -290,6 +290,7 @@ class RealSenseCamera(Camera): if self.use_depth: rs_config.enable_stream(rs.stream.depth) + @check_if_not_connected def _configure_capture_settings(self) -> None: """Sets fps, width, and height from device stream if not already configured. @@ -299,8 +300,6 @@ class RealSenseCamera(Camera): Raises: DeviceNotConnectedError: If device is not connected. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"Cannot validate settings for {self} as it is not connected.") if self.rs_profile is None: raise RuntimeError(f"{self}: rs_profile must be initialized before use.") @@ -320,6 +319,7 @@ class RealSenseCamera(Camera): self.width, self.height = actual_width, actual_height self.capture_width, self.capture_height = actual_width, actual_height + @check_if_not_connected def read_depth(self, timeout_ms: int = 200) -> NDArray[Any]: """ Reads a single frame (depth) synchronously from the camera. @@ -345,9 +345,6 @@ class RealSenseCamera(Camera): f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}." ) - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -374,6 +371,7 @@ class RealSenseCamera(Camera): return frame + @check_if_not_connected def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 0) -> NDArray[Any]: """ Reads a single frame (color) synchronously from the camera. @@ -403,9 +401,6 @@ class RealSenseCamera(Camera): f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions." ) - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -534,6 +529,7 @@ class RealSenseCamera(Camera): self.new_frame_event.clear() # NOTE(Steven): Missing implementation for depth for now + @check_if_not_connected def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: """ Reads the latest available frame data (color) asynchronously. @@ -556,8 +552,6 @@ class RealSenseCamera(Camera): TimeoutError: If no frame data becomes available within the specified timeout. RuntimeError: If the background thread died unexpectedly or another error occurs. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -578,6 +572,7 @@ class RealSenseCamera(Camera): return frame # NOTE(Steven): Missing implementation for depth for now + @check_if_not_connected def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: """Return the most recent (color) frame captured immediately (Peeking). @@ -593,8 +588,6 @@ class RealSenseCamera(Camera): DeviceNotConnectedError: If the camera is not connected. RuntimeError: If the camera is connected but has not captured any frames yet. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") diff --git a/src/lerobot/cameras/zmq/camera_zmq.py b/src/lerobot/cameras/zmq/camera_zmq.py index f29e16a28..16523b50a 100644 --- a/src/lerobot/cameras/zmq/camera_zmq.py +++ b/src/lerobot/cameras/zmq/camera_zmq.py @@ -34,7 +34,8 @@ import cv2 import numpy as np from numpy.typing import NDArray -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected +from lerobot.utils.errors import DeviceNotConnectedError from ..camera import Camera from ..configs import ColorMode @@ -104,6 +105,7 @@ class ZMQCamera(Camera): """Checks if the ZMQ socket is initialized and connected.""" return self._connected and self.context is not None and self.socket is not None + @check_if_already_connected def connect(self, warmup: bool = True) -> None: """Connect to ZMQ camera server. @@ -111,8 +113,6 @@ class ZMQCamera(Camera): warmup (bool): If True, waits for the camera to provide at least one valid frame before returning. Defaults to True. """ - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} is already connected.") logger.info(f"Connecting to {self}...") @@ -211,6 +211,7 @@ class ZMQCamera(Camera): return frame + @check_if_not_connected def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: """ Reads a single frame synchronously from the camera. @@ -228,9 +229,6 @@ class ZMQCamera(Camera): f"{self} read() color_mode parameter is deprecated and will be removed in future versions." ) - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -301,6 +299,7 @@ class ZMQCamera(Camera): self.latest_timestamp = None self.new_frame_event.clear() + @check_if_not_connected def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: """ Reads the latest available frame asynchronously. @@ -317,8 +316,6 @@ class ZMQCamera(Camera): TimeoutError: If no frame data becomes available within the specified timeout. RuntimeError: If the background thread is not running. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -335,6 +332,7 @@ class ZMQCamera(Camera): return frame + @check_if_not_connected def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: """Return the most recent frame captured immediately (Peeking). @@ -350,8 +348,6 @@ class ZMQCamera(Camera): DeviceNotConnectedError: If the camera is not connected. RuntimeError: If the camera is connected but has not captured any frames yet. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") diff --git a/src/lerobot/motors/damiao/damiao.py b/src/lerobot/motors/damiao/damiao.py index 95a9e70d1..a454130a6 100644 --- a/src/lerobot/motors/damiao/damiao.py +++ b/src/lerobot/motors/damiao/damiao.py @@ -23,6 +23,7 @@ from copy import deepcopy from functools import cached_property from typing import TYPE_CHECKING, Any, TypedDict +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.import_utils import _can_available if TYPE_CHECKING or _can_available: @@ -36,7 +37,6 @@ else: import numpy as np -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import enter_pressed, move_cursor_up @@ -155,6 +155,7 @@ class DamiaoMotorsBus(MotorsBusBase): """Check if the CAN bus is connected.""" return self._is_connected and self.canbus is not None + @check_if_already_connected def connect(self, handshake: bool = True) -> None: """ Open the CAN bus and initialize communication. @@ -162,10 +163,6 @@ class DamiaoMotorsBus(MotorsBusBase): Args: handshake: If True, ping all motors to verify they're present """ - if self.is_connected: - raise DeviceAlreadyConnectedError( - f"{self.__class__.__name__}('{self.port}') is already connected." - ) try: # Auto-detect interface type based on port name @@ -249,6 +246,7 @@ class DamiaoMotorsBus(MotorsBusBase): ) logger.info("Handshake successful. All motors ready.") + @check_if_not_connected def disconnect(self, disable_torque: bool = True) -> None: """ Close the CAN bus connection. @@ -256,8 +254,6 @@ class DamiaoMotorsBus(MotorsBusBase): Args: disable_torque: If True, disable torque on all motors before disconnecting """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self.__class__.__name__}('{self.port}') is not connected.") if disable_torque: try: @@ -586,10 +582,9 @@ class DamiaoMotorsBus(MotorsBusBase): except Exception as e: logger.warning(f"Failed to decode response from {motor}: {e}") + @check_if_not_connected def read(self, data_name: str, motor: str) -> Value: """Read a value from a single motor. Positions are always in degrees.""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") # Refresh motor to get latest state msg = self._refresh_motor(motor) @@ -619,6 +614,7 @@ class DamiaoMotorsBus(MotorsBusBase): raise ValueError(f"Unknown data_name: {data_name}") return mapping[data_name] + @check_if_not_connected def write( self, data_name: str, @@ -629,8 +625,6 @@ class DamiaoMotorsBus(MotorsBusBase): Write a value to a single motor. Positions are always in degrees. Can write 'Goal_Position', 'Kp', or 'Kd'. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if data_name in ("Kp", "Kd"): self._gains[motor][data_name.lower()] = float(value) diff --git a/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py b/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py index 466eb07e5..2e3885e67 100644 --- a/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py +++ b/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py @@ -19,6 +19,7 @@ from functools import cached_property from lerobot.processor import RobotAction, RobotObservation from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot from .config_bi_openarm_follower import BiOpenArmFollowerConfig @@ -112,6 +113,7 @@ class BiOpenArmFollower(Robot): def is_connected(self) -> bool: return self.left_arm.is_connected and self.right_arm.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: self.left_arm.connect(calibrate) self.right_arm.connect(calibrate) @@ -133,6 +135,7 @@ class BiOpenArmFollower(Robot): "Motor ID configuration is typically done via manufacturer tools for CAN motors." ) + @check_if_not_connected def get_observation(self) -> RobotObservation: obs_dict = {} @@ -146,6 +149,7 @@ class BiOpenArmFollower(Robot): return obs_dict + @check_if_not_connected def send_action( self, action: RobotAction, @@ -170,6 +174,7 @@ class BiOpenArmFollower(Robot): return {**prefixed_sent_action_left, **prefixed_sent_action_right} + @check_if_not_connected def disconnect(self): self.left_arm.disconnect() self.right_arm.disconnect() diff --git a/src/lerobot/robots/bi_so_follower/bi_so_follower.py b/src/lerobot/robots/bi_so_follower/bi_so_follower.py index 09f849772..28c58b898 100644 --- a/src/lerobot/robots/bi_so_follower/bi_so_follower.py +++ b/src/lerobot/robots/bi_so_follower/bi_so_follower.py @@ -19,6 +19,7 @@ from functools import cached_property from lerobot.processor import RobotAction, RobotObservation from lerobot.robots.so_follower import SOFollower, SOFollowerRobotConfig +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot from .config_bi_so_follower import BiSOFollowerConfig @@ -96,6 +97,7 @@ class BiSOFollower(Robot): def is_connected(self) -> bool: return self.left_arm.is_connected and self.right_arm.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: self.left_arm.connect(calibrate) self.right_arm.connect(calibrate) @@ -116,6 +118,7 @@ class BiSOFollower(Robot): self.left_arm.setup_motors() self.right_arm.setup_motors() + @check_if_not_connected def get_observation(self) -> RobotObservation: obs_dict = {} @@ -129,6 +132,7 @@ class BiSOFollower(Robot): return obs_dict + @check_if_not_connected def send_action(self, action: RobotAction) -> RobotAction: # Remove "left_" prefix left_action = { @@ -148,6 +152,7 @@ class BiSOFollower(Robot): return {**prefixed_sent_action_left, **prefixed_sent_action_right} + @check_if_not_connected def disconnect(self): self.left_arm.disconnect() self.right_arm.disconnect() diff --git a/src/lerobot/robots/openarm_follower/openarm_follower.py b/src/lerobot/robots/openarm_follower/openarm_follower.py index c221afd10..d6794a226 100644 --- a/src/lerobot/robots/openarm_follower/openarm_follower.py +++ b/src/lerobot/robots/openarm_follower/openarm_follower.py @@ -23,7 +23,7 @@ from lerobot.cameras.utils import make_cameras_from_configs from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.damiao import DamiaoMotorsBus from lerobot.processor import RobotAction, RobotObservation -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot from ..utils import ensure_safe_goal_position @@ -119,6 +119,7 @@ class OpenArmFollower(Robot): """Check if robot is connected.""" return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: """ Connect to the robot and optionally calibrate. @@ -126,8 +127,6 @@ class OpenArmFollower(Robot): We assume that at connection time, the arms are in a safe rest position, and torque can be safely disabled to run calibration if needed. """ - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") # Connect to CAN bus logger.info(f"Connecting arm on {self.config.port}...") @@ -219,6 +218,7 @@ class OpenArmFollower(Robot): "Motor ID configuration is typically done via manufacturer tools for CAN motors." ) + @check_if_not_connected def get_observation(self) -> RobotObservation: """ Get current observation from robot including position, velocity, and torque. @@ -228,9 +228,6 @@ class OpenArmFollower(Robot): """ start = time.perf_counter() - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - obs_dict: dict[str, Any] = {} states = self.bus.sync_read_all_states() @@ -253,6 +250,7 @@ class OpenArmFollower(Robot): return obs_dict + @check_if_not_connected def send_action( self, action: RobotAction, @@ -272,8 +270,6 @@ class OpenArmFollower(Robot): Returns: The action actually sent (potentially clipped) """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")} @@ -333,10 +329,9 @@ class OpenArmFollower(Robot): return {f"{motor}.pos": val for motor, val in goal_pos.items()} + @check_if_not_connected def disconnect(self): """Disconnect from robot.""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") # Disconnect CAN bus self.bus.disconnect(self.config.disable_torque_on_disconnect) diff --git a/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py b/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py index c4383293f..74b0c9b83 100644 --- a/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py +++ b/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py @@ -19,6 +19,7 @@ from functools import cached_property from lerobot.processor import RobotAction from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfig +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..openarm_leader import OpenArmLeader from ..teleoperator import Teleoperator @@ -88,6 +89,7 @@ class BiOpenArmLeader(Teleoperator): def is_connected(self) -> bool: return self.left_arm.is_connected and self.right_arm.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: self.left_arm.connect(calibrate) self.right_arm.connect(calibrate) @@ -109,6 +111,7 @@ class BiOpenArmLeader(Teleoperator): "Motor ID configuration is typically done via manufacturer tools for CAN motors." ) + @check_if_not_connected def get_action(self) -> RobotAction: action_dict = {} @@ -126,6 +129,7 @@ class BiOpenArmLeader(Teleoperator): # TODO: Implement force feedback raise NotImplementedError + @check_if_not_connected def disconnect(self) -> None: self.left_arm.disconnect() self.right_arm.disconnect() diff --git a/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py b/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py index 90bf2a92d..e84ac6f50 100644 --- a/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py +++ b/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py @@ -18,7 +18,7 @@ import logging from functools import cached_property from lerobot.teleoperators.so_leader import SOLeaderTeleopConfig -from lerobot.utils.decorators import check_if_not_connected +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..so_leader import SOLeader from ..teleoperator import Teleoperator @@ -72,6 +72,7 @@ class BiSOLeader(Teleoperator): def is_connected(self) -> bool: return self.left_arm.is_connected and self.right_arm.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: self.left_arm.connect(calibrate) self.right_arm.connect(calibrate) @@ -110,6 +111,7 @@ class BiSOLeader(Teleoperator): # TODO: Implement force feedback raise NotImplementedError + @check_if_not_connected def disconnect(self) -> None: self.left_arm.disconnect() self.right_arm.disconnect() diff --git a/src/lerobot/teleoperators/openarm_leader/openarm_leader.py b/src/lerobot/teleoperators/openarm_leader/openarm_leader.py index edf4d7090..d9eaabe0f 100644 --- a/src/lerobot/teleoperators/openarm_leader/openarm_leader.py +++ b/src/lerobot/teleoperators/openarm_leader/openarm_leader.py @@ -21,7 +21,7 @@ from typing import Any from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.damiao import DamiaoMotorsBus from lerobot.processor import RobotAction -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..teleoperator import Teleoperator from .config_openarm_leader import OpenArmLeaderConfig @@ -84,6 +84,7 @@ class OpenArmLeader(Teleoperator): """Check if teleoperator is connected.""" return self.bus.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: """ Connect to the teleoperator. @@ -91,8 +92,6 @@ class OpenArmLeader(Teleoperator): For manual control, we disable torque after connecting so the arm can be moved by hand. """ - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") # Connect to CAN bus logger.info(f"Connecting arm on {self.config.port}...") @@ -183,6 +182,7 @@ class OpenArmLeader(Teleoperator): "Motor ID configuration is typically done via manufacturer tools for CAN motors." ) + @check_if_not_connected def get_action(self) -> RobotAction: """ Get current action from the leader arm. @@ -193,8 +193,6 @@ class OpenArmLeader(Teleoperator): Reads all motor states (pos/vel/torque) in one CAN refresh cycle. """ start = time.perf_counter() - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") action_dict: dict[str, Any] = {} @@ -214,10 +212,9 @@ class OpenArmLeader(Teleoperator): def send_feedback(self, feedback: dict[str, float]) -> None: raise NotImplementedError("Feedback is not yet implemented for OpenArm leader.") + @check_if_not_connected def disconnect(self) -> None: """Disconnect from teleoperator.""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") # Disconnect CAN bus # For manual control, ensure torque is disabled before disconnecting From 3c84d271d53c9ca972cda8fce3b3f715ec813817 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 10 Feb 2026 18:40:50 +0100 Subject: [PATCH 035/131] fix(motors): use decorator to fix precommit (#2951) --- src/lerobot/motors/damiao/damiao.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/lerobot/motors/damiao/damiao.py b/src/lerobot/motors/damiao/damiao.py index a454130a6..ae619f159 100644 --- a/src/lerobot/motors/damiao/damiao.py +++ b/src/lerobot/motors/damiao/damiao.py @@ -700,14 +700,12 @@ class DamiaoMotorsBus(MotorsBusBase): else: logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.") + @check_if_not_connected def sync_write(self, data_name: str, values: dict[str, Value]) -> None: """ Write values to multiple motors simultaneously. Positions are always in degrees. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if data_name in ("Kp", "Kd"): key = data_name.lower() for motor, val in values.items(): From fc8a388a2538937992bd8b28bc7ac909ebd1b9a0 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 11 Feb 2026 13:57:25 +0100 Subject: [PATCH 036/131] feat(cameras): make backend configurable to the CLI (#2945) * feat(cameras): make backend configurable to the CLI * chore(cameras): address feedback * feat(Enum error messages): adding better instanciation error messages for Enum classes * chore(Enum error messages): propagating Enum error messages to all camera classes * chore(comments): removing superfluous comments * chore(format): applying ruff checks --------- Co-authored-by: CarolinePascal --- src/lerobot/cameras/__init__.py | 2 +- src/lerobot/cameras/configs.py | 23 +++++++++++++++++++ src/lerobot/cameras/opencv/camera_opencv.py | 4 ++-- .../cameras/opencv/configuration_opencv.py | 23 ++++++------------- .../configuration_reachy2_camera.py | 5 +--- .../realsense/configuration_realsense.py | 16 ++----------- src/lerobot/cameras/utils.py | 12 ---------- src/lerobot/cameras/zmq/configuration_zmq.py | 5 +--- 8 files changed, 37 insertions(+), 53 deletions(-) diff --git a/src/lerobot/cameras/__init__.py b/src/lerobot/cameras/__init__.py index 1488cd89e..cbf1f11bf 100644 --- a/src/lerobot/cameras/__init__.py +++ b/src/lerobot/cameras/__init__.py @@ -13,5 +13,5 @@ # limitations under the License. from .camera import Camera -from .configs import CameraConfig, ColorMode, Cv2Rotation +from .configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation from .utils import make_cameras_from_configs diff --git a/src/lerobot/cameras/configs.py b/src/lerobot/cameras/configs.py index 056eec314..987b74775 100644 --- a/src/lerobot/cameras/configs.py +++ b/src/lerobot/cameras/configs.py @@ -25,6 +25,10 @@ class ColorMode(str, Enum): RGB = "rgb" BGR = "bgr" + @classmethod + def _missing_(cls, value: object) -> None: + raise ValueError(f"`color_mode` is expected to be in {list(cls)}, but {value} is provided.") + class Cv2Rotation(int, Enum): NO_ROTATION = 0 @@ -32,6 +36,25 @@ class Cv2Rotation(int, Enum): ROTATE_180 = 180 ROTATE_270 = -90 + @classmethod + def _missing_(cls, value: object) -> None: + raise ValueError(f"`rotation` is expected to be in {list(cls)}, but {value} is provided.") + + +# Subset from https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html +class Cv2Backends(int, Enum): + ANY = 0 + V4L2 = 200 + DSHOW = 700 + PVAPI = 800 + ANDROID = 1000 + AVFOUNDATION = 1200 + MSMF = 1400 + + @classmethod + def _missing_(cls, value: object) -> None: + raise ValueError(f"`backend` is expected to be in {list(cls)}, but {value} is provided.") + @dataclass(kw_only=True) class CameraConfig(draccus.ChoiceRegistry, abc.ABC): # type: ignore # TODO: add type stubs for draccus diff --git a/src/lerobot/cameras/opencv/camera_opencv.py b/src/lerobot/cameras/opencv/camera_opencv.py index 465ba7a1b..10b3f21da 100644 --- a/src/lerobot/cameras/opencv/camera_opencv.py +++ b/src/lerobot/cameras/opencv/camera_opencv.py @@ -36,7 +36,7 @@ from lerobot.utils.decorators import check_if_already_connected, check_if_not_co from lerobot.utils.errors import DeviceNotConnectedError from ..camera import Camera -from ..utils import get_cv2_backend, get_cv2_rotation +from ..utils import get_cv2_rotation from .configuration_opencv import ColorMode, OpenCVCameraConfig # NOTE(Steven): The maximum opencv device index depends on your operating system. For instance, @@ -118,7 +118,7 @@ class OpenCVCamera(Camera): self.new_frame_event: Event = Event() self.rotation: int | None = get_cv2_rotation(config.rotation) - self.backend: int = get_cv2_backend() + self.backend: int = config.backend if self.height and self.width: self.capture_width, self.capture_height = self.width, self.height diff --git a/src/lerobot/cameras/opencv/configuration_opencv.py b/src/lerobot/cameras/opencv/configuration_opencv.py index 37a42861c..8ae57fe3c 100644 --- a/src/lerobot/cameras/opencv/configuration_opencv.py +++ b/src/lerobot/cameras/opencv/configuration_opencv.py @@ -15,9 +15,9 @@ from dataclasses import dataclass from pathlib import Path -from ..configs import CameraConfig, ColorMode, Cv2Rotation +from ..configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation -__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation"] +__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation", "Cv2Backends"] @CameraConfig.register_subclass("opencv") @@ -50,6 +50,7 @@ class OpenCVCameraConfig(CameraConfig): rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation. warmup_s: Time reading frames before returning from connect (in seconds) fourcc: FOURCC code for video format (e.g., "MJPG", "YUYV", "I420"). Defaults to None (auto-detect). + backend: OpenCV backend identifier (https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html). Defaults to ANY. Note: - Only 3-channel color output (RGB/BGR) is currently supported. @@ -62,22 +63,12 @@ class OpenCVCameraConfig(CameraConfig): rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION warmup_s: int = 1 fourcc: str | None = None + backend: Cv2Backends = Cv2Backends.ANY def __post_init__(self) -> None: - if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): - raise ValueError( - f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided." - ) - - if self.rotation not in ( - Cv2Rotation.NO_ROTATION, - Cv2Rotation.ROTATE_90, - Cv2Rotation.ROTATE_180, - Cv2Rotation.ROTATE_270, - ): - raise ValueError( - f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided." - ) + self.color_mode = ColorMode(self.color_mode) + self.rotation = Cv2Rotation(self.rotation) + self.backend = Cv2Backends(self.backend) if self.fourcc is not None and (not isinstance(self.fourcc, str) or len(self.fourcc) != 4): raise ValueError( diff --git a/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py b/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py index ca6db4f03..b40bfe71b 100644 --- a/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py +++ b/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py @@ -74,7 +74,4 @@ class Reachy2CameraConfig(CameraConfig): f"`image_type` is expected to be 'left' or 'right' for teleop camera, and 'rgb' or 'depth' for depth camera, but {self.image_type} is provided." ) - if self.color_mode not in ["rgb", "bgr"]: - raise ValueError( - f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided." - ) + self.color_mode = ColorMode(self.color_mode) diff --git a/src/lerobot/cameras/realsense/configuration_realsense.py b/src/lerobot/cameras/realsense/configuration_realsense.py index a094128bc..71b083b00 100644 --- a/src/lerobot/cameras/realsense/configuration_realsense.py +++ b/src/lerobot/cameras/realsense/configuration_realsense.py @@ -60,20 +60,8 @@ class RealSenseCameraConfig(CameraConfig): warmup_s: int = 1 def __post_init__(self) -> None: - if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): - raise ValueError( - f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided." - ) - - if self.rotation not in ( - Cv2Rotation.NO_ROTATION, - Cv2Rotation.ROTATE_90, - Cv2Rotation.ROTATE_180, - Cv2Rotation.ROTATE_270, - ): - raise ValueError( - f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided." - ) + self.color_mode = ColorMode(self.color_mode) + self.rotation = Cv2Rotation(self.rotation) values = (self.fps, self.width, self.height) if any(v is not None for v in values) and any(v is None for v in values): diff --git a/src/lerobot/cameras/utils.py b/src/lerobot/cameras/utils.py index c0e7b6284..7fb2c3bb1 100644 --- a/src/lerobot/cameras/utils.py +++ b/src/lerobot/cameras/utils.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import platform from typing import cast from lerobot.utils.import_utils import make_device_from_device_class @@ -68,14 +67,3 @@ def get_cv2_rotation(rotation: Cv2Rotation) -> int | None: return int(cv2.ROTATE_90_COUNTERCLOCKWISE) else: return None - - -def get_cv2_backend() -> int: - import cv2 - - if platform.system() == "Windows": - return int(cv2.CAP_MSMF) # Use MSMF for Windows instead of AVFOUNDATION - # elif platform.system() == "Darwin": # macOS - # return cv2.CAP_AVFOUNDATION - else: # Linux and others - return int(cv2.CAP_ANY) diff --git a/src/lerobot/cameras/zmq/configuration_zmq.py b/src/lerobot/cameras/zmq/configuration_zmq.py index 4e7732cfc..13690e14c 100644 --- a/src/lerobot/cameras/zmq/configuration_zmq.py +++ b/src/lerobot/cameras/zmq/configuration_zmq.py @@ -32,10 +32,7 @@ class ZMQCameraConfig(CameraConfig): warmup_s: int = 1 def __post_init__(self) -> None: - if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): - raise ValueError( - f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided." - ) + self.color_mode = ColorMode(self.color_mode) if self.timeout_ms <= 0: raise ValueError(f"`timeout_ms` must be positive, but {self.timeout_ms} is provided.") From 3615160d891f00a1cb8258ed8f81d327049b640e Mon Sep 17 00:00:00 2001 From: taken-yjyoon Date: Fri, 13 Feb 2026 02:13:51 +0900 Subject: [PATCH 037/131] fix(typo): Fixing wrong argparse examples in the comments (using 'True' not 'true') (#1040) Co-authored-by: juni <> --- src/lerobot/processor/pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 97ec716ff..8de376928 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -413,7 +413,7 @@ class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]): Args: save_directory: The directory where the pipeline will be saved. If None, saves to HF_LEROBOT_HOME/processors/{sanitized_pipeline_name}. - repo_id: ID of your repository on the Hub. Used only if `push_to_hub=True`. + repo_id: ID of your repository on the Hub. Used only if `push_to_hub=true`. push_to_hub: Whether or not to push your object to the Hugging Face Hub after saving it. card_kwargs: Additional arguments passed to the card template to customize the card. config_filename: The name of the JSON configuration file. If None, a name is From adebbcf090b47913d3f2e27bb27feccab174f2fc Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Thu, 12 Feb 2026 18:56:04 +0100 Subject: [PATCH 038/131] fix(dataset tools draccus): fixing draccus parsing for dataset edit operation type specification (#2949) * fix(edit dataset operation): fixing dataset tools CLI operation type specification * test(edit dataset operation): adding tests for dataset tools operation type specification * chore(format): running pre-commit * chore(backward compatibility): adding a type property in OperationConfig for backward compatibility Signed-off-by: Caroline Pascal --- src/lerobot/scripts/lerobot_edit_dataset.py | 49 +++++++------- tests/scripts/test_edit_dataset_parsing.py | 71 +++++++++++++++++++++ 2 files changed, 96 insertions(+), 24 deletions(-) create mode 100644 tests/scripts/test_edit_dataset_parsing.py diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py index 2ca9c520d..7c222ac6c 100644 --- a/src/lerobot/scripts/lerobot_edit_dataset.py +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -109,11 +109,14 @@ Using JSON config file: --config_path path/to/edit_config.json """ +import abc import logging import shutil from dataclasses import dataclass from pathlib import Path +import draccus + from lerobot.configs import parser from lerobot.datasets.dataset_tools import ( convert_image_to_video_dataset, @@ -129,39 +132,46 @@ from lerobot.utils.utils import init_logging @dataclass -class DeleteEpisodesConfig: - type: str = "delete_episodes" +class OperationConfig(draccus.ChoiceRegistry, abc.ABC): + @property + def type(self) -> str: + return self.get_choice_name(self.__class__) + + +@OperationConfig.register_subclass("delete_episodes") +@dataclass +class DeleteEpisodesConfig(OperationConfig): episode_indices: list[int] | None = None +@OperationConfig.register_subclass("split") @dataclass -class SplitConfig: - type: str = "split" +class SplitConfig(OperationConfig): splits: dict[str, float | list[int]] | None = None +@OperationConfig.register_subclass("merge") @dataclass -class MergeConfig: - type: str = "merge" +class MergeConfig(OperationConfig): repo_ids: list[str] | None = None +@OperationConfig.register_subclass("remove_feature") @dataclass -class RemoveFeatureConfig: - type: str = "remove_feature" +class RemoveFeatureConfig(OperationConfig): feature_names: list[str] | None = None +@OperationConfig.register_subclass("modify_tasks") @dataclass -class ModifyTasksConfig: - type: str = "modify_tasks" +class ModifyTasksConfig(OperationConfig): new_task: str | None = None episode_tasks: dict[str, str] | None = None +@OperationConfig.register_subclass("convert_image_to_video") @dataclass -class ConvertImageToVideoConfig: - type: str = "convert_image_to_video" +class ConvertImageToVideoConfig(OperationConfig): output_dir: str | None = None vcodec: str = "libsvtav1" pix_fmt: str = "yuv420p" @@ -177,14 +187,7 @@ class ConvertImageToVideoConfig: @dataclass class EditDatasetConfig: repo_id: str - operation: ( - DeleteEpisodesConfig - | SplitConfig - | MergeConfig - | RemoveFeatureConfig - | ModifyTasksConfig - | ConvertImageToVideoConfig - ) + operation: OperationConfig root: str | None = None new_repo_id: str | None = None push_to_hub: bool = False @@ -450,10 +453,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None: elif operation_type == "convert_image_to_video": handle_convert_image_to_video(cfg) else: - raise ValueError( - f"Unknown operation type: {operation_type}\n" - f"Available operations: delete_episodes, split, merge, remove_feature, modify_tasks, convert_image_to_video" - ) + available = ", ".join(OperationConfig.get_known_choices()) + raise ValueError(f"Unknown operation: {operation_type}\nAvailable operations: {available}") def main() -> None: diff --git a/tests/scripts/test_edit_dataset_parsing.py b/tests/scripts/test_edit_dataset_parsing.py new file mode 100644 index 000000000..bf7386b52 --- /dev/null +++ b/tests/scripts/test_edit_dataset_parsing.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import draccus +import pytest + +from lerobot.scripts.lerobot_edit_dataset import ( + ConvertImageToVideoConfig, + DeleteEpisodesConfig, + EditDatasetConfig, + MergeConfig, + ModifyTasksConfig, + OperationConfig, + RemoveFeatureConfig, + SplitConfig, +) + + +def parse_cfg(cli_args: list[str]) -> EditDatasetConfig: + """Helper to parse CLI args into an EditDatasetConfig via draccus.""" + return draccus.parse(EditDatasetConfig, args=cli_args) + + +class TestOperationTypeParsing: + """Test that --operation.type correctly selects the right config subclass.""" + + @pytest.mark.parametrize( + "type_name, expected_cls", + [ + ("delete_episodes", DeleteEpisodesConfig), + ("split", SplitConfig), + ("merge", MergeConfig), + ("remove_feature", RemoveFeatureConfig), + ("modify_tasks", ModifyTasksConfig), + ("convert_image_to_video", ConvertImageToVideoConfig), + ], + ) + def test_operation_type_resolves_correct_class(self, type_name, expected_cls): + cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", type_name]) + assert isinstance(cfg.operation, expected_cls), ( + f"Expected {expected_cls.__name__}, got {type(cfg.operation).__name__}" + ) + + @pytest.mark.parametrize( + "type_name, expected_cls", + [ + ("delete_episodes", DeleteEpisodesConfig), + ("split", SplitConfig), + ("merge", MergeConfig), + ("remove_feature", RemoveFeatureConfig), + ("modify_tasks", ModifyTasksConfig), + ("convert_image_to_video", ConvertImageToVideoConfig), + ], + ) + def test_get_choice_name_roundtrips(self, type_name, expected_cls): + cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", type_name]) + resolved_name = OperationConfig.get_choice_name(type(cfg.operation)) + assert resolved_name == type_name From 6600b60e7f5cc7476ddc34beaaf0e0692f82e4b6 Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Fri, 13 Feb 2026 13:49:01 +0100 Subject: [PATCH 039/131] always use degrees (#2968) --- src/lerobot/robots/so_follower/config_so_follower.py | 2 +- src/lerobot/teleoperators/so_leader/config_so_leader.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lerobot/robots/so_follower/config_so_follower.py b/src/lerobot/robots/so_follower/config_so_follower.py index e9ce27123..1ee589bda 100644 --- a/src/lerobot/robots/so_follower/config_so_follower.py +++ b/src/lerobot/robots/so_follower/config_so_follower.py @@ -40,7 +40,7 @@ class SOFollowerConfig: cameras: dict[str, CameraConfig] = field(default_factory=dict) # Set to `True` for backward compatibility with previous policies/dataset - use_degrees: bool = False + use_degrees: bool = True @RobotConfig.register_subclass("so101_follower") diff --git a/src/lerobot/teleoperators/so_leader/config_so_leader.py b/src/lerobot/teleoperators/so_leader/config_so_leader.py index dd55196d7..2b4f782a7 100644 --- a/src/lerobot/teleoperators/so_leader/config_so_leader.py +++ b/src/lerobot/teleoperators/so_leader/config_so_leader.py @@ -28,7 +28,7 @@ class SOLeaderConfig: port: str # Whether to use degrees for angles - use_degrees: bool = False + use_degrees: bool = True @TeleoperatorConfig.register_subclass("so101_leader") From 51d3822d75491507561b5f11db0e62d56b342d93 Mon Sep 17 00:00:00 2001 From: masato-ka Date: Wed, 18 Feb 2026 04:09:42 +0900 Subject: [PATCH 040/131] feat(datasets): Add info operation to lerobot-edit-dataset command (#2917) * Add New featrue to lerobot_edit_datset.py that show dataset information. * Fix to draccus error when happen give only --operation.type=info * Updating test and documents regarding lerobot-edit-dataset info function. * Updating documents regarding lerobot-edit-dataset extract function. option name in document is mistake. * feat(datasets): Update to align formatting with pre-commit.(#2917) Update to align formatting by pre-commit. --------- Co-authored-by: Caroline Pascal --- docs/source/using_dataset_tools.mdx | 25 ++++++++ src/lerobot/scripts/lerobot_edit_dataset.py | 65 +++++++++++++++++++++ tests/scripts/test_edit_dataset_parsing.py | 3 + 3 files changed, 93 insertions(+) diff --git a/docs/source/using_dataset_tools.mdx b/docs/source/using_dataset_tools.mdx index 9e662604e..f7fc9be20 100644 --- a/docs/source/using_dataset_tools.mdx +++ b/docs/source/using_dataset_tools.mdx @@ -12,6 +12,7 @@ LeRobot provides several utilities for manipulating datasets: 4. **Add Features** - Add new features to a dataset 5. **Remove Features** - Remove features from a dataset 6. **Convert to Video** - Convert image-based datasets to video format for efficient storage +7. **Show the Info of Datasets** - Show the summary of datasets information such as number of episode etc. The core implementation is in `lerobot.datasets.dataset_tools`. An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`. @@ -156,6 +157,30 @@ lerobot-edit-dataset \ **Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved. +### Show the information of datasets + +Show the information of datasets such as number of episode, number of frame, File size and so on. +No change will be made to the dataset + +```bash + +# Show dataset information without feature details +lerobot-edit-dataset \ + --repo_id lerobot/pusht_image \ + --operation.type info \ + +# Show dataset information with feature details +lerobot-edit-dataset \ + --repo_id lerobot/pusht_image \ + --operation.type info \ + --operation.show_features true + +``` + +**Parameters:** + +- `parameters`: The flag to control show or no show dataset information with feature details.(default=false) + ### Push to Hub Add the `--push_to_hub true` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub: diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py index 7c222ac6c..06e256fa2 100644 --- a/src/lerobot/scripts/lerobot_edit_dataset.py +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -104,6 +104,18 @@ Convert image dataset to video format and push to hub: --operation.type convert_image_to_video \ --push_to_hub true +Show dataset information: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht_image \ + --operation.type info \ + --operation.show_features true + +Show dataset information without feature details: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht_image \ + --operation.type info \ + --operation.show_features false + Using JSON config file: python -m lerobot.scripts.lerobot_edit_dataset \ --config_path path/to/edit_config.json @@ -112,6 +124,7 @@ Using JSON config file: import abc import logging import shutil +import sys from dataclasses import dataclass from pathlib import Path @@ -184,6 +197,13 @@ class ConvertImageToVideoConfig(OperationConfig): max_frames_per_batch: int | None = None +@OperationConfig.register_subclass("info") +@dataclass +class InfoConfig(OperationConfig): + type: str = "info" + show_features: bool = False + + @dataclass class EditDatasetConfig: repo_id: str @@ -436,6 +456,49 @@ def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None: logging.info("Dataset saved locally (not pushed to hub)") +def _get_dataset_size(repo_path): + import os + + total = 0 + with os.scandir(repo_path) as it: + for entry in it: + if entry.is_file(): + total += entry.stat().st_size + elif entry.is_dir(): + total += _get_dataset_size(entry.path) + return total + + +def handle_info(cfg: EditDatasetConfig): + if not isinstance(cfg.operation, InfoConfig): + raise ValueError("Operation config must be InfoConfig") + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) + sys.stdout.write(f"======Info {dataset.meta.repo_id}\n") + sys.stdout.write(f"Repository ID: {dataset.meta.repo_id} \n") + sys.stdout.write(f"Total episode: {dataset.meta.total_episodes} \n") + sys.stdout.write(f"Total task: {dataset.meta.total_tasks} \n") + sys.stdout.write(f"Total frame(Actual Count): {dataset.meta.total_frames}({len(dataset)}) \n") + sys.stdout.write( + f"Average frame per episode: {dataset.meta.total_frames / dataset.meta.total_episodes:.1f}\n" + ) + sys.stdout.write( + f"Average episode time(sec): {(dataset.meta.total_frames / dataset.meta.total_episodes) / dataset.meta.fps:.1f}\n" + ) + sys.stdout.write(f"FPS: {dataset.meta.fps}\n") + + total_file_size = _get_dataset_size(dataset.root) + sys.stdout.write(f"Size: {total_file_size / (1024 * 1024):.1f} MB\n") + if cfg.operation.show_features: + import json + + feature_dump_str = json.dumps( + dataset.meta.features, ensure_ascii=False, indent=4, sort_keys=True, separators=(",", ": ") + ) + sys.stdout.write("Features:\n") + sys.stdout.write(f"{feature_dump_str}\n") + + @parser.wrap() def edit_dataset(cfg: EditDatasetConfig) -> None: operation_type = cfg.operation.type @@ -452,6 +515,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None: handle_modify_tasks(cfg) elif operation_type == "convert_image_to_video": handle_convert_image_to_video(cfg) + elif operation_type == "info": + handle_info(cfg) else: available = ", ".join(OperationConfig.get_known_choices()) raise ValueError(f"Unknown operation: {operation_type}\nAvailable operations: {available}") diff --git a/tests/scripts/test_edit_dataset_parsing.py b/tests/scripts/test_edit_dataset_parsing.py index bf7386b52..8800b92ee 100644 --- a/tests/scripts/test_edit_dataset_parsing.py +++ b/tests/scripts/test_edit_dataset_parsing.py @@ -21,6 +21,7 @@ from lerobot.scripts.lerobot_edit_dataset import ( ConvertImageToVideoConfig, DeleteEpisodesConfig, EditDatasetConfig, + InfoConfig, MergeConfig, ModifyTasksConfig, OperationConfig, @@ -46,6 +47,7 @@ class TestOperationTypeParsing: ("remove_feature", RemoveFeatureConfig), ("modify_tasks", ModifyTasksConfig), ("convert_image_to_video", ConvertImageToVideoConfig), + ("info", InfoConfig), ], ) def test_operation_type_resolves_correct_class(self, type_name, expected_cls): @@ -63,6 +65,7 @@ class TestOperationTypeParsing: ("remove_feature", RemoveFeatureConfig), ("modify_tasks", ModifyTasksConfig), ("convert_image_to_video", ConvertImageToVideoConfig), + ("info", InfoConfig), ], ) def test_get_choice_name_roundtrips(self, type_name, expected_cls): From 1c388c0002c609ca783bf42729a1e41532a1fba0 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Tue, 17 Feb 2026 23:37:46 +0100 Subject: [PATCH 041/131] (Chore) Bump upper bound for torch version (#2897) * Bump upper torch version bound * Apply suggestion from @Copilot Signed-off-by: Vladislav Sovrasov * Update ref state dicts for schedulers * Support older than 2.8 torch versions * Fix precommit --------- Signed-off-by: Vladislav Sovrasov --- pyproject.toml | 6 +++--- tests/optim/test_schedulers.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c4b1c547e..e5431ada3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,9 +76,9 @@ dependencies = [ "pyserial>=3.5,<4.0", "wandb>=0.24.0,<0.25.0", - "torch>=2.2.1,<2.8.0", # TODO: Bumb dependency - "torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency - "torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency + "torch>=2.2.1,<2.11.0", # TODO: Bump dependency + "torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency + "torchvision>=0.21.0,<0.26.0", # TODO: Bump dependency "draccus==0.10.0", # TODO: Remove == "gymnasium>=1.1.1,<2.0.0", diff --git a/tests/optim/test_schedulers.py b/tests/optim/test_schedulers.py index 1e566a6ba..224613416 100644 --- a/tests/optim/test_schedulers.py +++ b/tests/optim/test_schedulers.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import torch +from packaging.version import Version from torch.optim.lr_scheduler import LambdaLR from lerobot.optim.schedulers import ( @@ -38,6 +40,10 @@ def test_diffuser_scheduler(optimizer): "last_epoch": 1, "lr_lambdas": [None], } + + if Version(torch.__version__) >= Version("2.8"): + expected_state_dict["_is_initial"] = False + assert scheduler.state_dict() == expected_state_dict @@ -56,6 +62,10 @@ def test_vqbet_scheduler(optimizer): "last_epoch": 1, "lr_lambdas": [None], } + + if Version(torch.__version__) >= Version("2.8"): + expected_state_dict["_is_initial"] = False + assert scheduler.state_dict() == expected_state_dict @@ -76,6 +86,10 @@ def test_cosine_decay_with_warmup_scheduler(optimizer): "last_epoch": 1, "lr_lambdas": [None], } + + if Version(torch.__version__) >= Version("2.8"): + expected_state_dict["_is_initial"] = False + assert scheduler.state_dict() == expected_state_dict From af036ce57e8ce2750f7fa57f4262c87a013bcdff Mon Sep 17 00:00:00 2001 From: Sota Nakamura <49087984+sotanakamura@users.noreply.github.com> Date: Wed, 18 Feb 2026 09:05:51 +0900 Subject: [PATCH 042/131] fix(scripts): serve grpc for a web viewer (#2881) * serve grpc for a web viewer * add help * remove ip detection * fix comment * pass grpc_port * fix(CLI): fixing CLI display-compressed-images argument 1/2 Co-authored-by: HUANG TZU-CHUN Signed-off-by: Caroline Pascal * fix(CLI): fixing CLI display-compressed-images argument 2/2 Co-authored-by: HUANG TZU-CHUN Signed-off-by: Caroline Pascal --------- Signed-off-by: Caroline Pascal Co-authored-by: Caroline Pascal Co-authored-by: HUANG TZU-CHUN Co-authored-by: Steven Palma --- src/lerobot/scripts/lerobot_dataset_viz.py | 37 +++++++++++++++------- 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/src/lerobot/scripts/lerobot_dataset_viz.py b/src/lerobot/scripts/lerobot_dataset_viz.py index 2cd48eab8..29d64554f 100644 --- a/src/lerobot/scripts/lerobot_dataset_viz.py +++ b/src/lerobot/scripts/lerobot_dataset_viz.py @@ -47,16 +47,14 @@ local$ rerun lerobot_pusht_episode_0.rrd ``` - Visualize data stored on a distant machine through streaming: -(You need to forward the websocket port to the distant machine, with -`ssh -L 9087:localhost:9087 username@remote-host`) ``` distant$ lerobot-dataset-viz \ --repo-id lerobot/pusht \ --episode-index 0 \ --mode distant \ - --ws-port 9087 + --grpc-port 9876 -local$ rerun ws://localhost:9087 +local$ rerun rerun+http://IP:GRPC_PORT/proxy ``` """ @@ -75,6 +73,7 @@ import tqdm from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD +from lerobot.utils.utils import init_logging def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray: @@ -93,10 +92,11 @@ def visualize_dataset( num_workers: int = 0, mode: str = "local", web_port: int = 9090, - ws_port: int = 9087, + grpc_port: int = 9876, save: bool = False, output_dir: Path | None = None, display_compressed_images: bool = False, + **kwargs, ) -> Path | None: if save: assert output_dir is not None, ( @@ -126,7 +126,9 @@ def visualize_dataset( gc.collect() if mode == "distant": - rr.serve_web_viewer(open_browser=False, web_port=web_port) + server_uri = rr.serve_grpc(grpc_port=grpc_port) + logging.info(f"Connect to a Rerun Server: rerun rerun+http://IP:{grpc_port}/proxy") + rr.serve_web_viewer(open_browser=False, web_port=web_port, connect_to=server_uri) logging.info("Logging to Rerun") @@ -226,7 +228,7 @@ def main(): "Mode of viewing between 'local' or 'distant'. " "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. " "'distant' creates a server on the distant machine where the data is stored. " - "Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine." + "Visualize the data by connecting to the server with `rerun rerun+http://IP:GRPC_PORT/proxy` on the local machine." ), ) parser.add_argument( @@ -238,8 +240,13 @@ def main(): parser.add_argument( "--ws-port", type=int, - default=9087, - help="Web socket port for rerun.io when `--mode distant` is set.", + help="deprecated, please use --grpc-port instead.", + ) + parser.add_argument( + "--grpc-port", + type=int, + default=9876, + help="gRPC port for rerun.io when `--mode distant` is set.", ) parser.add_argument( "--save", @@ -265,9 +272,7 @@ def main(): parser.add_argument( "--display-compressed-images", - type=bool, - required=True, - default=False, + action="store_true", help="If set, display compressed images in Rerun instead of uncompressed ones.", ) @@ -277,6 +282,14 @@ def main(): root = kwargs.pop("root") tolerance_s = kwargs.pop("tolerance_s") + if kwargs["ws_port"] is not None: + logging.warning( + "--ws-port is deprecated and will be removed in future versions. Please use --grpc-port instead." + ) + logging.warning("Setting grpc_port to ws_port value.") + kwargs["grpc_port"] = kwargs.pop("ws_port") + + init_logging() logging.info("Loading dataset") dataset = LeRobotDataset(repo_id, episodes=[args.episode_index], root=root, tolerance_s=tolerance_s) From fcbf550952b3794e425e20d01ec76475be54be4e Mon Sep 17 00:00:00 2001 From: HUANG TZU-CHUN Date: Wed, 18 Feb 2026 18:27:40 +0800 Subject: [PATCH 043/131] fix(docs): update environment variable name to HF_LEROBOT_HOME in docstring (#2973) Co-authored-by: Steven Palma --- src/lerobot/datasets/lerobot_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 36bffa190..360ed8d30 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -656,7 +656,7 @@ class LeRobotDataset(torch.utils.data.Dataset): repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset will be stored under root/repo_id. root (Path | None, optional): Local directory to use for downloading/writing files. You can also - set the LEROBOT_HOME environment variable to point to a different location. Defaults to + set the HF_LEROBOT_HOME environment variable to point to a different location. Defaults to '~/.cache/huggingface/lerobot'. episodes (list[int] | None, optional): If specified, this will only load episodes specified by their episode_index in this list. Defaults to None. From b22e0315b05447efbc9a0eb1d612192aad0337c2 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 18 Feb 2026 17:32:25 +0100 Subject: [PATCH 044/131] fix(utils): more conservative sleep_margin default value in precise_sleep (#2985) --- src/lerobot/utils/robot_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lerobot/utils/robot_utils.py b/src/lerobot/utils/robot_utils.py index 28c8e7c49..656dc2649 100644 --- a/src/lerobot/utils/robot_utils.py +++ b/src/lerobot/utils/robot_utils.py @@ -16,14 +16,14 @@ import platform import time -def precise_sleep(seconds: float, spin_threshold: float = 0.010, sleep_margin: float = 0.003): +def precise_sleep(seconds: float, spin_threshold: float = 0.010, sleep_margin: float = 0.005): """ Wait for `seconds` with better precision than time.sleep alone at the expense of more CPU usage. Parameters: - seconds: duration to wait - spin_threshold: if remaining <= spin_threshold -> spin; otherwise sleep (seconds). Default 10ms - - sleep_margin: when sleeping leave this much time before deadline to avoid oversleep. Default 3ms + - sleep_margin: when sleeping leave this much time before deadline to avoid oversleep. Default 5ms Note: The default parameters are chosen to prioritize timing accuracy over CPU usage for the common 30 FPS use case. From 89bd58a9a26ec5820df13866b6ebc1670ed8cd83 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 18 Feb 2026 18:22:35 +0100 Subject: [PATCH 045/131] chore(scripts): warn if we don't respect the target FPS (#2986) --- src/lerobot/scripts/lerobot_record.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 0b39e6fff..216ab22a6 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -398,7 +398,14 @@ def record_loop( ) dt_s = time.perf_counter() - start_loop_t - precise_sleep(max(1 / fps - dt_s, 0.0)) + + sleep_time_s: float = 1 / fps - dt_s + if sleep_time_s < 0: + logging.warning( + f"Record loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation" + ) + + precise_sleep(max(sleep_time_s, 0.0)) timestamp = time.perf_counter() - start_episode_t From aaf37070587581b3ffa8a28b6c134e846afe3a2e Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Wed, 18 Feb 2026 19:16:53 +0100 Subject: [PATCH 046/131] fix(filtering): fixing episodes filtering in load_nested_dataset to always use .from_parquet() (#2982) --- src/lerobot/datasets/utils.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 321ecedd5..da186bf30 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -122,19 +122,9 @@ def load_nested_dataset( raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}") with SuppressProgressBars(): - # When no filtering needed, Dataset uses memory-mapped loading for efficiency - # PyArrow loads the entire dataset into memory - if episodes is None: - return Dataset.from_parquet([str(path) for path in paths], features=features) - - arrow_dataset = pa_ds.dataset(paths, format="parquet") - filter_expr = pa_ds.field("episode_index").isin(episodes) - table = arrow_dataset.to_table(filter=filter_expr) - - if features is not None: - table = table.cast(features.arrow_schema) - - return Dataset(table) + # We use .from_parquet() memory-mapped loading for efficiency + filters = pa_ds.field("episode_index").isin(episodes) if episodes is not None else None + return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features) def get_parquet_num_frames(parquet_path: str | Path) -> int: From bc38261321f377621a05595914798023bc05d301 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 18 Feb 2026 20:05:15 +0100 Subject: [PATCH 047/131] feat(robots): use read_latest() camera (#2987) * feat(robots): use read_latest() camera * fix(test): add read_latest reachy cam mock --- src/lerobot/cameras/camera.py | 2 +- src/lerobot/cameras/opencv/camera_opencv.py | 2 +- src/lerobot/cameras/reachy2_camera/reachy2_camera.py | 2 +- src/lerobot/cameras/realsense/camera_realsense.py | 2 +- src/lerobot/robots/hope_jr/hope_jr_arm.py | 2 +- src/lerobot/robots/hope_jr/hope_jr_hand.py | 2 +- src/lerobot/robots/koch_follower/koch_follower.py | 2 +- src/lerobot/robots/lekiwi/lekiwi.py | 2 +- src/lerobot/robots/omx_follower/omx_follower.py | 2 +- src/lerobot/robots/openarm_follower/openarm_follower.py | 2 +- src/lerobot/robots/reachy2/robot_reachy2.py | 2 +- src/lerobot/robots/so_follower/so_follower.py | 2 +- src/lerobot/robots/unitree_g1/unitree_g1.py | 2 +- tests/robots/test_reachy2.py | 1 + 14 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/lerobot/cameras/camera.py b/src/lerobot/cameras/camera.py index 2894e0215..2a53d2544 100644 --- a/src/lerobot/cameras/camera.py +++ b/src/lerobot/cameras/camera.py @@ -150,7 +150,7 @@ class Camera(abc.ABC): """ pass - def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]: """Return the most recent frame captured immediately (Peeking). This method is non-blocking and returns whatever is currently in the diff --git a/src/lerobot/cameras/opencv/camera_opencv.py b/src/lerobot/cameras/opencv/camera_opencv.py index 10b3f21da..f3289ddc7 100644 --- a/src/lerobot/cameras/opencv/camera_opencv.py +++ b/src/lerobot/cameras/opencv/camera_opencv.py @@ -530,7 +530,7 @@ class OpenCVCamera(Camera): return frame @check_if_not_connected - def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]: """Return the most recent frame captured immediately (Peeking). This method is non-blocking and returns whatever is currently in the diff --git a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py index 0c1dc43d8..9bef957bc 100644 --- a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py +++ b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py @@ -201,7 +201,7 @@ class Reachy2Camera(Camera): return self.read() @check_if_not_connected - def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]: """Return the most recent frame captured immediately (Peeking). This method is non-blocking and returns whatever is currently in the diff --git a/src/lerobot/cameras/realsense/camera_realsense.py b/src/lerobot/cameras/realsense/camera_realsense.py index d599cdce0..d80ec8093 100644 --- a/src/lerobot/cameras/realsense/camera_realsense.py +++ b/src/lerobot/cameras/realsense/camera_realsense.py @@ -573,7 +573,7 @@ class RealSenseCamera(Camera): # NOTE(Steven): Missing implementation for depth for now @check_if_not_connected - def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]: """Return the most recent (color) frame captured immediately (Peeking). This method is non-blocking and returns whatever is currently in the diff --git a/src/lerobot/robots/hope_jr/hope_jr_arm.py b/src/lerobot/robots/hope_jr/hope_jr_arm.py index 5fd9c4d1d..e8269ae46 100644 --- a/src/lerobot/robots/hope_jr/hope_jr_arm.py +++ b/src/lerobot/robots/hope_jr/hope_jr_arm.py @@ -140,7 +140,7 @@ class HopeJrArm(Robot): # Capture images from cameras for cam_key, cam in self.cameras.items(): start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() + obs_dict[cam_key] = cam.read_latest() dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") diff --git a/src/lerobot/robots/hope_jr/hope_jr_hand.py b/src/lerobot/robots/hope_jr/hope_jr_hand.py index 1e5c72b72..a05c4bbcb 100644 --- a/src/lerobot/robots/hope_jr/hope_jr_hand.py +++ b/src/lerobot/robots/hope_jr/hope_jr_hand.py @@ -171,7 +171,7 @@ class HopeJrHand(Robot): # Capture images from cameras for cam_key, cam in self.cameras.items(): start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() + obs_dict[cam_key] = cam.read_latest() dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") diff --git a/src/lerobot/robots/koch_follower/koch_follower.py b/src/lerobot/robots/koch_follower/koch_follower.py index fee0adba9..53a32beed 100644 --- a/src/lerobot/robots/koch_follower/koch_follower.py +++ b/src/lerobot/robots/koch_follower/koch_follower.py @@ -193,7 +193,7 @@ class KochFollower(Robot): # Capture images from cameras for cam_key, cam in self.cameras.items(): start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() + obs_dict[cam_key] = cam.read_latest() dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") diff --git a/src/lerobot/robots/lekiwi/lekiwi.py b/src/lerobot/robots/lekiwi/lekiwi.py index 54848f49d..9d11a000f 100644 --- a/src/lerobot/robots/lekiwi/lekiwi.py +++ b/src/lerobot/robots/lekiwi/lekiwi.py @@ -360,7 +360,7 @@ class LeKiwi(Robot): # Capture images from cameras for cam_key, cam in self.cameras.items(): start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() + obs_dict[cam_key] = cam.read_latest() dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") diff --git a/src/lerobot/robots/omx_follower/omx_follower.py b/src/lerobot/robots/omx_follower/omx_follower.py index a171affbd..e0b612c60 100644 --- a/src/lerobot/robots/omx_follower/omx_follower.py +++ b/src/lerobot/robots/omx_follower/omx_follower.py @@ -176,7 +176,7 @@ class OmxFollower(Robot): # Capture images from cameras for cam_key, cam in self.cameras.items(): start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() + obs_dict[cam_key] = cam.read_latest() dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") diff --git a/src/lerobot/robots/openarm_follower/openarm_follower.py b/src/lerobot/robots/openarm_follower/openarm_follower.py index d6794a226..c865f1ec1 100644 --- a/src/lerobot/robots/openarm_follower/openarm_follower.py +++ b/src/lerobot/robots/openarm_follower/openarm_follower.py @@ -241,7 +241,7 @@ class OpenArmFollower(Robot): # Capture images from cameras for cam_key, cam in self.cameras.items(): start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() + obs_dict[cam_key] = cam.read_latest() dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") diff --git a/src/lerobot/robots/reachy2/robot_reachy2.py b/src/lerobot/robots/reachy2/robot_reachy2.py index 6f4eef56c..fb466f85b 100644 --- a/src/lerobot/robots/reachy2/robot_reachy2.py +++ b/src/lerobot/robots/reachy2/robot_reachy2.py @@ -180,7 +180,7 @@ class Reachy2Robot(Robot): # Capture images from cameras for cam_key, cam in self.cameras.items(): - obs_dict[cam_key] = cam.async_read() + obs_dict[cam_key] = cam.read_latest() return obs_dict diff --git a/src/lerobot/robots/so_follower/so_follower.py b/src/lerobot/robots/so_follower/so_follower.py index b4d11fe3f..bc72a2b6a 100644 --- a/src/lerobot/robots/so_follower/so_follower.py +++ b/src/lerobot/robots/so_follower/so_follower.py @@ -187,7 +187,7 @@ class SOFollower(Robot): # Capture images from cameras for cam_key, cam in self.cameras.items(): start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() + obs_dict[cam_key] = cam.read_latest() dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") diff --git a/src/lerobot/robots/unitree_g1/unitree_g1.py b/src/lerobot/robots/unitree_g1/unitree_g1.py index 01b4f330e..df0de8f19 100644 --- a/src/lerobot/robots/unitree_g1/unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/unitree_g1.py @@ -324,7 +324,7 @@ class UnitreeG1(Robot): # Cameras - read images from ZMQ cameras for cam_name, cam in self._cameras.items(): - obs[cam_name] = cam.async_read() + obs[cam_name] = cam.read_latest() return obs diff --git a/tests/robots/test_reachy2.py b/tests/robots/test_reachy2.py index d3c44bf5a..d3f32b1c2 100644 --- a/tests/robots/test_reachy2.py +++ b/tests/robots/test_reachy2.py @@ -142,6 +142,7 @@ def _make_reachy2_camera_mock(*args, **kwargs): cam.connect = MagicMock() cam.disconnect = MagicMock() cam.async_read = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8)) + cam.read_latest = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8)) return cam From 5f15232271a81ee6be16cec1960e300f55f25466 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 18 Feb 2026 22:46:12 +0100 Subject: [PATCH 048/131] chore: remove usernames + use entrypoints in docs, comments & sample commands (#2988) --- benchmarks/video/README.md | 84 +++++++++---------- docs/source/earthrover_mini_plus.mdx | 2 +- docs/source/hope_jr.mdx | 10 +-- docs/source/pi0.mdx | 2 +- docs/source/pi05.mdx | 2 +- docs/source/sarm.mdx | 8 +- docs/source/unitree_g1.mdx | 4 +- docs/source/walloss.mdx | 2 +- docs/source/xvla.mdx | 2 +- examples/backward_compatibility/replay.py | 2 +- examples/rtc/eval_dataset.py | 20 ++--- examples/rtc/eval_with_real_robot.py | 6 +- .../v30/convert_dataset_v21_to_v30.py | 2 +- .../policies/sarm/compute_rabc_weights.py | 10 +-- .../policies/smolvla/modeling_smolvla.py | 4 +- src/lerobot/scripts/lerobot_edit_dataset.py | 32 +++---- src/lerobot/scripts/lerobot_replay.py | 2 +- 17 files changed, 97 insertions(+), 97 deletions(-) diff --git a/benchmarks/video/README.md b/benchmarks/video/README.md index 490a4b495..1feee69c4 100644 --- a/benchmarks/video/README.md +++ b/benchmarks/video/README.md @@ -28,9 +28,9 @@ We don't expect the same optimal settings for a dataset of images from a simulat For these reasons, we run this benchmark on four representative datasets: - `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera. -- `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera. -- `aliberts/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera. -- `aliberts/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera. +- `lerobot/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera. +- `lerobot/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera. +- `lerobot/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera. Note: The datasets used for this benchmark need to be image datasets, not video datasets. @@ -179,7 +179,7 @@ python benchmark/video/run_video_benchmark.py \ --output-dir outputs/video_benchmark \ --repo-ids \ lerobot/pusht_image \ - aliberts/aloha_mobile_shrimp_image \ + lerobot/aloha_mobile_shrimp_image \ --vcodec libx264 libx265 \ --pix-fmt yuv444p yuv420p \ --g 2 20 None \ @@ -203,9 +203,9 @@ python benchmark/video/run_video_benchmark.py \ --output-dir outputs/video_benchmark \ --repo-ids \ lerobot/pusht_image \ - aliberts/aloha_mobile_shrimp_image \ - aliberts/paris_street \ - aliberts/kitchen \ + lerobot/aloha_mobile_shrimp_image \ + lerobot/paris_street \ + lerobot/kitchen \ --vcodec libx264 libx265 \ --pix-fmt yuv444p yuv420p \ --g 1 2 3 4 5 6 10 15 20 40 None \ @@ -221,9 +221,9 @@ python benchmark/video/run_video_benchmark.py \ --output-dir outputs/video_benchmark \ --repo-ids \ lerobot/pusht_image \ - aliberts/aloha_mobile_shrimp_image \ - aliberts/paris_street \ - aliberts/kitchen \ + lerobot/aloha_mobile_shrimp_image \ + lerobot/paris_street \ + lerobot/kitchen \ --vcodec libsvtav1 \ --pix-fmt yuv420p \ --g 1 2 3 4 5 6 10 15 20 40 None \ @@ -252,37 +252,37 @@ Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_read These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav` -| video_images_size_ratio | vcodec | pix_fmt | | | | -| ---------------------------------- | ---------- | ------- | --------- | --------- | --------- | -| | libx264 | | libx265 | | libsvtav1 | -| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p | -| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% | -| aliberts/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% | -| aliberts/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% | -| aliberts/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% | +| video_images_size_ratio | vcodec | pix_fmt | | | | +| --------------------------------- | ---------- | ------- | --------- | --------- | --------- | +| | libx264 | | libx265 | | libsvtav1 | +| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p | +| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% | +| lerobot/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% | +| lerobot/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% | +| lerobot/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% | -| video_images_load_time_ratio | vcodec | pix_fmt | | | | -| ---------------------------------- | ------- | ------- | -------- | ------- | --------- | -| | libx264 | | libx265 | | libsvtav1 | -| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p | -| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 | -| aliberts/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** | -| aliberts/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** | -| aliberts/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** | +| video_images_load_time_ratio | vcodec | pix_fmt | | | | +| --------------------------------- | ------- | ------- | -------- | ------- | --------- | +| | libx264 | | libx265 | | libsvtav1 | +| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p | +| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 | +| lerobot/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** | +| lerobot/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** | +| lerobot/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** | -| | | vcodec | pix_fmt | | | | -| ---------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ | -| | | libx264 | | libx265 | | libsvtav1 | -| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p | -| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 | -| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 | -| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% | -| aliberts/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** | -| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** | -| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** | -| aliberts/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** | -| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** | -| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** | -| aliberts/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** | -| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** | -| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** | +| | | vcodec | pix_fmt | | | | +| --------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ | +| | | libx264 | | libx265 | | libsvtav1 | +| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p | +| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 | +| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 | +| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% | +| lerobot/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** | +| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** | +| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** | +| lerobot/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** | +| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** | +| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** | +| lerobot/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** | +| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** | +| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** | diff --git a/docs/source/earthrover_mini_plus.mdx b/docs/source/earthrover_mini_plus.mdx index d8083336a..dd9c2ad2b 100644 --- a/docs/source/earthrover_mini_plus.mdx +++ b/docs/source/earthrover_mini_plus.mdx @@ -185,7 +185,7 @@ echo $HF_USER Use the standard recording command: ```bash -python src/lerobot/scripts/lerobot_record.py \ +lerobot-record \ --robot.type=earthrover_mini_plus \ --teleop.type=keyboard_rover \ --dataset.repo_id=your_username/dataset_name \ diff --git a/docs/source/hope_jr.mdx b/docs/source/hope_jr.mdx index 856febb95..026cd084a 100644 --- a/docs/source/hope_jr.mdx +++ b/docs/source/hope_jr.mdx @@ -224,7 +224,7 @@ lerobot-record \ --teleop.port=/dev/tty.usbmodem1201 \ --teleop.id=right \ --teleop.side=right \ - --dataset.repo_id=nepyope/hand_record_test_with_video_data \ + --dataset.repo_id=/hand_record_test_with_video_data \ --dataset.single_task="Hand recording test with video data" \ --dataset.num_episodes=1 \ --dataset.episode_time_s=5 \ @@ -241,7 +241,7 @@ lerobot-replay \ --robot.port=/dev/tty.usbmodem58760432281 \ --robot.id=right \ --robot.side=right \ - --dataset.repo_id=nepyope/hand_record_test_with_camera \ + --dataset.repo_id=/hand_record_test_with_camera \ --dataset.episode=0 ``` @@ -249,13 +249,13 @@ lerobot-replay \ ```bash lerobot-train \ - --dataset.repo_id=nepyope/hand_record_test_with_video_data \ + --dataset.repo_id=/hand_record_test_with_video_data \ --policy.type=act \ --output_dir=outputs/train/hopejr_hand \ --job_name=hopejr \ --policy.device=mps \ --wandb.enable=true \ - --policy.repo_id=nepyope/hand_test_policy + --policy.repo_id=/hand_test_policy ``` ### Evaluate @@ -270,7 +270,7 @@ lerobot-record \ --robot.side=right \ --robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \ --display_data=false \ - --dataset.repo_id=nepyope/eval_hopejr \ + --dataset.repo_id=/eval_hopejr \ --dataset.single_task="Evaluate hopejr hand policy" \ --dataset.num_episodes=10 \ --policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model diff --git a/docs/source/pi0.mdx b/docs/source/pi0.mdx index 93e0b4c88..879bbd16d 100644 --- a/docs/source/pi0.mdx +++ b/docs/source/pi0.mdx @@ -60,7 +60,7 @@ policy.type=pi0 For training π₀, you can use the standard LeRobot training script with the appropriate configuration: ```bash -python src/lerobot/scripts/lerobot_train.py \ +lerobot-train \ --dataset.repo_id=your_dataset \ --policy.type=pi0 \ --output_dir=./outputs/pi0_training \ diff --git a/docs/source/pi05.mdx b/docs/source/pi05.mdx index dbf118aa3..8abaca989 100644 --- a/docs/source/pi05.mdx +++ b/docs/source/pi05.mdx @@ -56,7 +56,7 @@ policy.type=pi05 Here's a complete training command for finetuning the base π₀.₅ model on your own dataset: ```bash -python src/lerobot/scripts/lerobot_train.py\ +lerobot-train \ --dataset.repo_id=your_dataset \ --policy.type=pi05 \ --output_dir=./outputs/pi05_training \ diff --git a/docs/source/sarm.mdx b/docs/source/sarm.mdx index 65e49792b..cd488fe1f 100644 --- a/docs/source/sarm.mdx +++ b/docs/source/sarm.mdx @@ -269,7 +269,7 @@ This generates visualizations showing video frames with subtask boundaries overl Train with **no annotations** - uses linear progress from 0 to 1: ```bash -python src/lerobot/scripts/lerobot_train.py \ +lerobot-train \ --dataset.repo_id=your-username/your-dataset \ --policy.type=sarm \ --policy.annotation_mode=single_stage \ @@ -288,7 +288,7 @@ python src/lerobot/scripts/lerobot_train.py \ Train with **dense annotations only** (sparse auto-generated): ```bash -python src/lerobot/scripts/lerobot_train.py \ +lerobot-train \ --dataset.repo_id=your-username/your-dataset \ --policy.type=sarm \ --policy.annotation_mode=dense_only \ @@ -307,7 +307,7 @@ python src/lerobot/scripts/lerobot_train.py \ Train with **both sparse and dense annotations**: ```bash -python src/lerobot/scripts/lerobot_train.py \ +lerobot-train \ --dataset.repo_id=your-username/your-dataset \ --policy.type=sarm \ --policy.annotation_mode=dual \ @@ -468,7 +468,7 @@ This script: Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC: ```bash -python src/lerobot/scripts/lerobot_train.py \ +lerobot-train \ --dataset.repo_id=your-username/your-dataset \ --policy.type=pi0 \ --use_rabc=true \ diff --git a/docs/source/unitree_g1.mdx b/docs/source/unitree_g1.mdx index ea6bf54ad..4c5d28924 100644 --- a/docs/source/unitree_g1.mdx +++ b/docs/source/unitree_g1.mdx @@ -216,7 +216,7 @@ lerobot-teleoperate \ ### Record Dataset in Simulation ```bash -python -m lerobot.scripts.lerobot_record \ +lerobot-record \ --robot.type=unitree_g1 \ --robot.is_simulation=true \ --robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \ @@ -266,7 +266,7 @@ lerobot-teleoperate \ ### Record Dataset on Real Robot ```bash -python -m lerobot.scripts.lerobot_record \ +lerobot-record \ --robot.type=unitree_g1 \ --robot.is_simulation=false \ --robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \ diff --git a/docs/source/walloss.mdx b/docs/source/walloss.mdx index c0756c087..e9785cc93 100644 --- a/docs/source/walloss.mdx +++ b/docs/source/walloss.mdx @@ -45,7 +45,7 @@ policy.type=wall_x For training WallX, you can use the standard LeRobot training script with the appropriate configuration: ```bash -python src/lerobot/scripts/lerobot_train.py \ +lerobot-train \ --dataset.repo_id=your_dataset \ --policy.type=wall_x \ --output_dir=./outputs/wallx_training \ diff --git a/docs/source/xvla.mdx b/docs/source/xvla.mdx index dd7d1ef57..97e04d4ec 100644 --- a/docs/source/xvla.mdx +++ b/docs/source/xvla.mdx @@ -154,7 +154,7 @@ lerobot-train \ ```bash lerobot-train \ - --dataset.repo_id=pepijn223/bimanual-so100-handover-cube \ + --dataset.repo_id=/bimanual-so100-handover-cube \ --output_dir=./outputs/xvla_bimanual \ --job_name=xvla_so101_training \ --policy.path="lerobot/xvla-base" \ diff --git a/examples/backward_compatibility/replay.py b/examples/backward_compatibility/replay.py index 8de5ba197..f7c47bec5 100644 --- a/examples/backward_compatibility/replay.py +++ b/examples/backward_compatibility/replay.py @@ -22,7 +22,7 @@ lerobot-replay \ --robot.type=so100_follower \ --robot.port=/dev/tty.usbmodem58760431541 \ --robot.id=black \ - --dataset.repo_id=aliberts/record-test \ + --dataset.repo_id=/record-test \ --dataset.episode=2 ``` """ diff --git a/examples/rtc/eval_dataset.py b/examples/rtc/eval_dataset.py index 4652df107..613fd67d7 100644 --- a/examples/rtc/eval_dataset.py +++ b/examples/rtc/eval_dataset.py @@ -27,8 +27,8 @@ measuring consistency and ground truth alignment. Usage: # Basic usage with smolvla policy uv run python examples/rtc/eval_dataset.py \ - --policy.path=helper2424/smolvla_check_rtc_last3 \ - --dataset.repo_id=helper2424/check_rtc \ + --policy.path=/smolvla_check_rtc_last3 \ + --dataset.repo_id=/check_rtc \ --rtc.execution_horizon=8 \ --device=mps \ --rtc.max_guidance_weight=10.0 \ @@ -58,16 +58,16 @@ Usage: --device=cuda uv run python examples/rtc/eval_dataset.py \ - --policy.path=lipsop/reuben_pi0 \ - --dataset.repo_id=ReubenLim/so101_cube_in_cup \ + --policy.path=/reuben_pi0 \ + --dataset.repo_id=/so101_cube_in_cup \ --rtc.execution_horizon=8 \ --device=cuda # With torch.compile for faster inference (PyTorch 2.0+) # Note: CUDA graphs disabled by default due to in-place ops in denoising loop uv run python examples/rtc/eval_dataset.py \ - --policy.path=helper2424/smolvla_check_rtc_last3 \ - --dataset.repo_id=helper2424/check_rtc \ + --policy.path=/smolvla_check_rtc_last3 \ + --dataset.repo_id=/check_rtc \ --rtc.execution_horizon=8 \ --device=mps \ --use_torch_compile=true \ @@ -75,8 +75,8 @@ Usage: # With torch.compile on CUDA (CUDA graphs disabled by default) uv run python examples/rtc/eval_dataset.py \ - --policy.path=helper2424/smolvla_check_rtc_last3 \ - --dataset.repo_id=helper2424/check_rtc \ + --policy.path=/smolvla_check_rtc_last3 \ + --dataset.repo_id=/check_rtc \ --rtc.execution_horizon=8 \ --device=cuda \ --use_torch_compile=true \ @@ -84,8 +84,8 @@ Usage: # Enable CUDA graphs (advanced - may cause tensor aliasing errors) uv run python examples/rtc/eval_dataset.py \ - --policy.path=helper2424/smolvla_check_rtc_last3 \ - --dataset.repo_id=helper2424/check_rtc \ + --policy.path=/smolvla_check_rtc_last3 \ + --dataset.repo_id=/check_rtc \ --use_torch_compile=true \ --torch_compile_backend=inductor \ --torch_compile_mode=max-autotune \ diff --git a/examples/rtc/eval_with_real_robot.py b/examples/rtc/eval_with_real_robot.py index 1470899d9..4c803eb7e 100644 --- a/examples/rtc/eval_with_real_robot.py +++ b/examples/rtc/eval_with_real_robot.py @@ -28,7 +28,7 @@ For simulation environments, see eval_with_simulation.py Usage: # Run RTC with Real robot with RTC uv run examples/rtc/eval_with_real_robot.py \ - --policy.path=helper2424/smolvla_check_rtc_last3 \ + --policy.path=/smolvla_check_rtc_last3 \ --policy.device=mps \ --rtc.enabled=true \ --rtc.execution_horizon=20 \ @@ -41,7 +41,7 @@ Usage: # Run RTC with Real robot without RTC uv run examples/rtc/eval_with_real_robot.py \ - --policy.path=helper2424/smolvla_check_rtc_last3 \ + --policy.path=/smolvla_check_rtc_last3 \ --policy.device=mps \ --rtc.enabled=false \ --robot.type=so100_follower \ @@ -53,7 +53,7 @@ Usage: # Run RTC with Real robot with pi0.5 policy uv run examples/rtc/eval_with_real_robot.py \ - --policy.path=helper2424/pi05_check_rtc \ + --policy.path=/pi05_check_rtc \ --policy.device=mps \ --rtc.enabled=true \ --rtc.execution_horizon=20 \ diff --git a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py index 74be6bfa4..7be37a1b1 100644 --- a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py +++ b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py @@ -529,7 +529,7 @@ if __name__ == "__main__": type=str, required=True, help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset " - "(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).", + "(e.g. `lerobot/pusht`, `/aloha_sim_insertion_human`).", ) parser.add_argument( "--branch", diff --git a/src/lerobot/policies/sarm/compute_rabc_weights.py b/src/lerobot/policies/sarm/compute_rabc_weights.py index 5b6ea6e9b..485c1096b 100644 --- a/src/lerobot/policies/sarm/compute_rabc_weights.py +++ b/src/lerobot/policies/sarm/compute_rabc_weights.py @@ -27,18 +27,18 @@ Usage: # Full RA-BC computation with visualizations python src/lerobot/policies/sarm/compute_rabc_weights.py \\ --dataset-repo-id lerobot/aloha_sim_insertion_human \\ - --reward-model-path pepijn223/sarm_single_uni4 + --reward-model-path /sarm_single_uni4 # Faster computation with stride (compute every 5 frames, interpolate the rest) python src/lerobot/policies/sarm/compute_rabc_weights.py \\ --dataset-repo-id lerobot/aloha_sim_insertion_human \\ - --reward-model-path pepijn223/sarm_single_uni4 \\ + --reward-model-path /sarm_single_uni4 \\ --stride 5 # Visualize predictions only (no RA-BC computation) python src/lerobot/policies/sarm/compute_rabc_weights.py \\ --dataset-repo-id lerobot/aloha_sim_insertion_human \\ - --reward-model-path pepijn223/sarm_single_uni4 \\ + --reward-model-path /sarm_single_uni4 \\ --visualize-only \\ --num-visualizations 5 @@ -714,12 +714,12 @@ Examples: # Full RA-BC computation with visualizations python src/lerobot/policies/sarm/compute_rabc_weights.py \\ --dataset-repo-id lerobot/aloha_sim_insertion_human \\ - --reward-model-path pepijn223/sarm_single_uni4 + --reward-model-path /sarm_single_uni4 # Visualize predictions only (no RA-BC computation) python src/lerobot/policies/sarm/compute_rabc_weights.py \\ --dataset-repo-id lerobot/aloha_sim_insertion_human \\ - --reward-model-path pepijn223/sarm_single_uni4 \\ + --reward-model-path /sarm_single_uni4 \\ --visualize-only \\ --num-visualizations 10 """, diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 60b968a42..10544a949 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -30,7 +30,7 @@ Example of finetuning the smolvla pretrained model (`smolvla_base`): ```bash lerobot-train \ --policy.path=lerobot/smolvla_base \ ---dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \ +--dataset.repo_id=/svla_so100_task1_v3 \ --batch_size=64 \ --steps=200000 ``` @@ -40,7 +40,7 @@ and an action expert. ```bash lerobot-train \ --policy.type=smolvla \ ---dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \ +--dataset.repo_id=/svla_so100_task1_v3 \ --batch_size=64 \ --steps=200000 ``` diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py index 06e256fa2..afdc95efd 100644 --- a/src/lerobot/scripts/lerobot_edit_dataset.py +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -24,100 +24,100 @@ When new_repo_id is specified, creates a new dataset. Usage Examples: Delete episodes 0, 2, and 5 from a dataset: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type delete_episodes \ --operation.episode_indices "[0, 2, 5]" Delete episodes and save to a new dataset: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht \ --new_repo_id lerobot/pusht_filtered \ --operation.type delete_episodes \ --operation.episode_indices "[0, 2, 5]" Split dataset by fractions: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type split \ --operation.splits '{"train": 0.8, "val": 0.2}' Split dataset by episode indices: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type split \ --operation.splits '{"train": [0, 1, 2, 3], "val": [4, 5]}' Split into more than two splits: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type split \ --operation.splits '{"train": 0.6, "val": 0.2, "test": 0.2}' Merge multiple datasets: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht_merged \ --operation.type merge \ --operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']" Remove camera feature: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type remove_feature \ --operation.feature_names "['observation.images.top']" Modify tasks - set a single task for all episodes (WARNING: modifies in-place): - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type modify_tasks \ --operation.new_task "Pick up the cube and place it" Modify tasks - set different tasks for specific episodes (WARNING: modifies in-place): - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type modify_tasks \ --operation.episode_tasks '{"0": "Task A", "1": "Task B", "2": "Task A"}' Modify tasks - set default task with overrides for specific episodes (WARNING: modifies in-place): - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type modify_tasks \ --operation.new_task "Default task" \ --operation.episode_tasks '{"5": "Special task for episode 5"}' Convert image dataset to video format and save locally: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ --operation.type convert_image_to_video \ --operation.output_dir /path/to/output/pusht_video Convert image dataset to video format and save with new repo_id: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ --new_repo_id lerobot/pusht_video \ --operation.type convert_image_to_video Convert image dataset to video format and push to hub: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ --new_repo_id lerobot/pusht_video \ --operation.type convert_image_to_video \ --push_to_hub true Show dataset information: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ --operation.type info \ --operation.show_features true Show dataset information without feature details: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ --operation.type info \ --operation.show_features false Using JSON config file: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --config_path path/to/edit_config.json """ diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index c9a559d07..8e2a394b9 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ b/src/lerobot/scripts/lerobot_replay.py @@ -22,7 +22,7 @@ lerobot-replay \ --robot.type=so100_follower \ --robot.port=/dev/tty.usbmodem58760431541 \ --robot.id=black \ - --dataset.repo_id=aliberts/record-test \ + --dataset.repo_id=/record-test \ --dataset.episode=0 ``` From 2dd366436ed30ed9729b4f18076a54fec7ec589b Mon Sep 17 00:00:00 2001 From: Khalil Date: Thu, 19 Feb 2026 14:35:02 +0100 Subject: [PATCH 049/131] Fix gym-hil integration with the new LeRobot pipeline. (#2482) * Add GymHILAdapterProcessorStep for gym-hil environment integration * Fix action features in control loop for None teleop device with gym-hil * Finalize dataset before pushing to hub for visualization on the hub * Fix neutral action for gripper * fix pre-commit --- src/lerobot/processor/__init__.py | 2 ++ src/lerobot/processor/gym_action_processor.py | 8 +++++ src/lerobot/processor/hil_processor.py | 31 +++++++++++++++++++ src/lerobot/rl/gym_manipulator.py | 15 +++++++-- 4 files changed, 54 insertions(+), 2 deletions(-) diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 164f7da03..0b63e1606 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -44,6 +44,7 @@ from .hil_processor import ( AddTeleopActionAsComplimentaryDataStep, AddTeleopEventsAsInfoStep, GripperPenaltyProcessorStep, + GymHILAdapterProcessorStep, ImageCropResizeProcessorStep, InterventionActionProcessorStep, RewardClassifierProcessorStep, @@ -87,6 +88,7 @@ __all__ = [ "DoneProcessorStep", "EnvAction", "EnvTransition", + "GymHILAdapterProcessorStep", "GripperPenaltyProcessorStep", "hotswap_stats", "IdentityProcessorStep", diff --git a/src/lerobot/processor/gym_action_processor.py b/src/lerobot/processor/gym_action_processor.py index 8fa8cfd86..4f225af92 100644 --- a/src/lerobot/processor/gym_action_processor.py +++ b/src/lerobot/processor/gym_action_processor.py @@ -20,6 +20,7 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature from .converters import to_tensor from .core import EnvAction, EnvTransition, PolicyAction +from .hil_processor import TELEOP_ACTION_KEY from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry @@ -89,6 +90,13 @@ class Numpy2TorchActionProcessorStep(ProcessorStep): torch_action = to_tensor(action, dtype=None) # Preserve original dtype new_transition[TransitionKey.ACTION] = torch_action + complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + if TELEOP_ACTION_KEY in complementary_data: + teleop_action = complementary_data[TELEOP_ACTION_KEY] + if isinstance(teleop_action, EnvAction): + complementary_data[TELEOP_ACTION_KEY] = to_tensor(teleop_action) + new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data + return new_transition def transform_features( diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py index 24b5628fa..34eaeed51 100644 --- a/src/lerobot/processor/hil_processor.py +++ b/src/lerobot/processor/hil_processor.py @@ -312,6 +312,37 @@ class TimeLimitProcessorStep(TruncatedProcessorStep): return features +@ProcessorStepRegistry.register("gym_hil_adapter_processor") +class GymHILAdapterProcessorStep(ProcessorStep): + """ + Adapts the output of the `gym-hil` environment to the format expected by `lerobot` processors. + + This step normalizes the `transition` object by: + 1. Copying `teleop_action` from `info` to `complementary_data`. + 2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key). + """ + + def __call__(self, transition: EnvTransition) -> EnvTransition: + info = transition.get(TransitionKey.INFO, {}) + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + + if TELEOP_ACTION_KEY in info: + complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY] + + if "is_intervention" in info: + info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"] + + transition[TransitionKey.INFO] = info + transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data + + return transition + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + @dataclass @ProcessorStepRegistry.register("gripper_penalty_processor") class GripperPenaltyProcessorStep(ProcessorStep): diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index 1c1cb752f..f5fcb7437 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ b/src/lerobot/rl/gym_manipulator.py @@ -36,6 +36,7 @@ from lerobot.processor import ( DeviceProcessorStep, EnvTransition, GripperPenaltyProcessorStep, + GymHILAdapterProcessorStep, ImageCropResizeProcessorStep, InterventionActionProcessorStep, MapDeltaActionToRobotActionStep, @@ -379,6 +380,7 @@ def make_processors( ] env_pipeline_steps = [ + GymHILAdapterProcessorStep(), Numpy2TorchActionProcessorStep(), VanillaObservationProcessorStep(), AddBatchDimensionProcessorStep(), @@ -608,7 +610,14 @@ def control_loop( dataset = None if cfg.mode == "record": - action_features = teleop_device.action_features + if teleop_device: + action_features = teleop_device.action_features + else: + action_features = { + "dtype": "float32", + "shape": (4,), + "names": ["delta_x", "delta_y", "delta_z", "gripper"], + } features = { ACTION: action_features, REWARD: {"dtype": "float32", "shape": (1,), "names": None}, @@ -656,7 +665,7 @@ def control_loop( # Create a neutral action (no movement) neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32) if use_gripper: - neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay + neutral_action = torch.cat([neutral_action, torch.tensor([0.0])]) # Gripper stay # Use the new step function transition = step_env_and_process_transition( @@ -725,6 +734,8 @@ def control_loop( precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0)) if dataset is not None and cfg.dataset.push_to_hub: + logging.info("Finalizing dataset before pushing to hub") + dataset.finalize() logging.info("Pushing dataset to hub") dataset.push_to_hub() From 5865170d36442b907bb35f946e837eee18aafdf1 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 20 Feb 2026 17:01:46 +0100 Subject: [PATCH 050/131] chore(deps): bump ceil datasets (#2946) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e5431ada3..0ca1f0432 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici dependencies = [ # Hugging Face dependencies - "datasets>=4.0.0,<4.2.0", + "datasets>=4.0.0,<5.0.0", "diffusers>=0.27.2,<0.36.0", "huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0", "accelerate>=1.10.0,<2.0.0", From e96339a3b49ac080e11664d2be4737987d861a57 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 23 Feb 2026 13:57:43 +0100 Subject: [PATCH 051/131] feat(dataset): add streaming video encoding + HW encoder support (#2974) * feat(dataset): init stream encoding * feat(dataset): use threads to fix frame pickle latency * refactor(dataset): remove HW encoded related changes * add lp (#2977) * feat(dataset): add Hw encoding + log drop frames (#2978) * chore(docs): add streaming video encoding guide * fix(dataset): style docs + testing * chore(docs): simplify sttreaming video encoding guide * chore(dataset): add commands + streaming encoding default false + print note if false + queue default is now 30 * chore(docs): add verification note advice * chore(dataset): adjusting defaults & docs for streaming encoding * docs(scripts): improve docstrings * test(dataset): polish streaming encoding tests * chore(dataset): move FYI log related to streaming * chore(dataset): add arg vcodec to suggestions * refactor(dataset): better handling for auto and available vcodec * chore(dataset): change log level * docs(dataset): add note related to training performance vcodec * docs(dataset): add more notes to streaming encoding --------- Co-authored-by: Caroline Pascal Co-authored-by: Pepijn --- docs/source/_toctree.yml | 2 + docs/source/act.mdx | 3 + docs/source/earthrover_mini_plus.mdx | 3 + docs/source/groot.mdx | 9 +- docs/source/hope_jr.mdx | 6 + docs/source/il_robots.mdx | 8 +- docs/source/lerobot-dataset-v3.mdx | 5 +- docs/source/reachy2.mdx | 6 + docs/source/smolvla.mdx | 3 + docs/source/streaming_video_encoding.mdx | 155 ++++ docs/source/unitree_g1.mdx | 10 +- src/lerobot/datasets/lerobot_dataset.py | 124 ++- src/lerobot/datasets/video_utils.py | 480 +++++++++++- src/lerobot/scripts/lerobot_record.py | 34 +- tests/datasets/test_datasets.py | 9 +- .../datasets/test_streaming_video_encoder.py | 730 ++++++++++++++++++ 16 files changed, 1532 insertions(+), 55 deletions(-) create mode 100644 docs/source/streaming_video_encoding.mdx create mode 100644 tests/datasets/test_streaming_video_encoder.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index d61aac9c1..1055975d7 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -29,6 +29,8 @@ title: Using the Dataset Tools - local: dataset_subtask title: Using Subtasks in the Dataset + - local: streaming_video_encoding + title: Streaming Video Encoding title: "Datasets" - sections: - local: act diff --git a/docs/source/act.mdx b/docs/source/act.mdx index e3294ca69..453bcbba8 100644 --- a/docs/source/act.mdx +++ b/docs/source/act.mdx @@ -88,5 +88,8 @@ lerobot-record \ --dataset.repo_id=${HF_USER}/eval_act_your_dataset \ --dataset.num_episodes=10 \ --dataset.single_task="Your task description" \ + --dataset.streaming_encoding=true \ + --dataset.encoder_threads=2 \ + # --dataset.vcodec=auto \ --policy.path=${HF_USER}/act_policy ``` diff --git a/docs/source/earthrover_mini_plus.mdx b/docs/source/earthrover_mini_plus.mdx index dd9c2ad2b..cfc3a2eef 100644 --- a/docs/source/earthrover_mini_plus.mdx +++ b/docs/source/earthrover_mini_plus.mdx @@ -192,6 +192,9 @@ lerobot-record \ --dataset.num_episodes=2 \ --dataset.fps=10 \ --dataset.single_task="Navigate around obstacles" \ + --dataset.streaming_encoding=true \ + --dataset.encoder_threads=2 \ + # --dataset.vcodec=auto \ --display_data=true ``` diff --git a/docs/source/groot.mdx b/docs/source/groot.mdx index 8bfc22996..0ef591466 100644 --- a/docs/source/groot.mdx +++ b/docs/source/groot.mdx @@ -120,9 +120,12 @@ lerobot-record \ --display_data=true \ --dataset.repo_id=/eval_groot-bimanual \ --dataset.num_episodes=10 \ - --dataset.single_task="Grab and handover the red cube to the other arm" - --policy.path=/groot-bimanual # your trained model - --dataset.episode_time_s=30 + --dataset.single_task="Grab and handover the red cube to the other arm" \ + --dataset.streaming_encoding=true \ + --dataset.encoder_threads=2 \ + # --dataset.vcodec=auto \ + --policy.path=/groot-bimanual \ # your trained model + --dataset.episode_time_s=30 \ --dataset.reset_time_s=10 ``` diff --git a/docs/source/hope_jr.mdx b/docs/source/hope_jr.mdx index 026cd084a..8826d9758 100644 --- a/docs/source/hope_jr.mdx +++ b/docs/source/hope_jr.mdx @@ -230,6 +230,9 @@ lerobot-record \ --dataset.episode_time_s=5 \ --dataset.push_to_hub=true \ --dataset.private=true \ + --dataset.streaming_encoding=true \ + --dataset.encoder_threads=2 \ + # --dataset.vcodec=auto \ --display_data=true ``` @@ -273,5 +276,8 @@ lerobot-record \ --dataset.repo_id=/eval_hopejr \ --dataset.single_task="Evaluate hopejr hand policy" \ --dataset.num_episodes=10 \ + --dataset.streaming_encoding=true \ + --dataset.encoder_threads=2 \ + # --dataset.vcodec=auto \ --policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model ``` diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index 84dc6f2f6..7fc770b0c 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -185,7 +185,10 @@ lerobot-record \ --display_data=true \ --dataset.repo_id=${HF_USER}/record-test \ --dataset.num_episodes=5 \ - --dataset.single_task="Grab the black cube" + --dataset.single_task="Grab the black cube" \ + --dataset.streaming_encoding=true \ + # --dataset.vcodec=auto \ + --dataset.encoder_threads=2 ``` @@ -515,6 +518,9 @@ lerobot-record \ --display_data=false \ --dataset.repo_id=${HF_USER}/eval_so100 \ --dataset.single_task="Put lego brick into the transparent box" \ + --dataset.streaming_encoding=true \ + --dataset.encoder_threads=2 \ + # --dataset.vcodec=auto \ # <- Teleop optional if you want to teleoperate in between episodes \ # --teleop.type=so100_leader \ # --teleop.port=/dev/ttyACM0 \ diff --git a/docs/source/lerobot-dataset-v3.mdx b/docs/source/lerobot-dataset-v3.mdx index 3521914f2..235a355bd 100644 --- a/docs/source/lerobot-dataset-v3.mdx +++ b/docs/source/lerobot-dataset-v3.mdx @@ -41,7 +41,10 @@ lerobot-record \ --display_data=true \ --dataset.repo_id=${HF_USER}/record-test \ --dataset.num_episodes=5 \ - --dataset.single_task="Grab the black cube" + --dataset.single_task="Grab the black cube" \ + --dataset.streaming_encoding=true \ + # --dataset.vcodec=auto \ + --dataset.encoder_threads=2 ``` See the [recording guide](./il_robots#record-a-dataset) for more details. diff --git a/docs/source/reachy2.mdx b/docs/source/reachy2.mdx index 51b09acd2..1b868711a 100644 --- a/docs/source/reachy2.mdx +++ b/docs/source/reachy2.mdx @@ -159,6 +159,9 @@ lerobot-record \ --dataset.fps=15 \ --dataset.push_to_hub=true \ --dataset.private=true \ + --dataset.streaming_encoding=true \ + --dataset.encoder_threads=2 \ + # --dataset.vcodec=auto \ --display_data=true ``` @@ -198,6 +201,9 @@ lerobot-record \ --dataset.fps=15 \ --dataset.push_to_hub=true \ --dataset.private=true \ + --dataset.streaming_encoding=true \ + --dataset.encoder_threads=2 \ + # --dataset.vcodec=auto \ --display_data=true ``` diff --git a/docs/source/smolvla.mdx b/docs/source/smolvla.mdx index a56298b5e..bf8a0d2f0 100644 --- a/docs/source/smolvla.mdx +++ b/docs/source/smolvla.mdx @@ -106,6 +106,9 @@ lerobot-record \ --dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub --dataset.episode_time_s=50 \ --dataset.num_episodes=10 \ + --dataset.streaming_encoding=true \ + --dataset.encoder_threads=2 \ + # --dataset.vcodec=auto \ # <- Teleop optional if you want to teleoperate in between episodes \ # --teleop.type=so100_leader \ # --teleop.port=/dev/ttyACM0 \ diff --git a/docs/source/streaming_video_encoding.mdx b/docs/source/streaming_video_encoding.mdx new file mode 100644 index 000000000..40004200e --- /dev/null +++ b/docs/source/streaming_video_encoding.mdx @@ -0,0 +1,155 @@ +# Streaming Video Encoding Guide + +## 1. Overview + +Streaming video encoding eliminates the traditional PNG round-trip during video dataset recording. Instead of: + +1. Capture frame -> write PNG to disk -> (at episode end) read PNG's -> encode to MP4 -> delete PNG's + +Frames can be encoded in real-time during capture: + +1. Capture frame -> queue to encoder thread -> encode to MP4 directly + +This makes `save_episode()` near-instant (the video is already encoded by the time the episode ends) and removes the blocking wait that previously occurred between episodes, especially with multiple cameras in long episodes. + +## 2. Tuning Parameters + +| Parameter | CLI Flag | Type | Default | Description | +| ----------------------- | --------------------------------- | ------------- | ------------- | ----------------------------------------------------------------- | +| `streaming_encoding` | `--dataset.streaming_encoding` | `bool` | `True` | Enable real-time encoding during capture | +| `vcodec` | `--dataset.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder | +| `encoder_threads` | `--dataset.encoder_threads` | `int \| None` | `None` (auto) | Threads per encoder instance. `None` will leave the vcoded decide | +| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `60` | Max buffered frames per camera (~2s at 30fps). Consumes RAM | + +## 3. Performance Considerations + +Streaming encoding means the CPU is encoding video **during** the capture loop, not after. This creates a CPU budget that must be shared between: + +- **Control loop** (reading cameras, control the robot, writing non-video data) +- **Encoder threads** (one pool per camera) +- **Rerun visualization** (if enabled) +- **OS and other processes** + +### Resolution & Number of Cameras Impact + +| Setup | Throughput (px/sec) | CPU Encoding Load | Notes | +| ------------------------- | ------------------- | ----------------- | ------------------------------ | +| 2camsx 640x480x3 @30fps | 55M | Low | Works on most systems | +| 2camsx 1280x720x3 @30fps | 165M | Moderate | Comfortable on modern systems | +| 2camsx 1920x1080x3 @30fps | 373M | High | Requires powerful high-end CPU | + +### `encoder_threads` Tuning + +This parameter controls how many threads each encoder instance uses internally: + +- **Higher values** (e.g., 4-5): Faster encoding, but uses more CPU cores per camera. Good for high-end systems with many cores. +- **Lower values** (e.g., 1-2): Less CPU per camera, freeing cores for capture and visualization. Good for low-res images and capable CPUs. +- **`None` (default)**: Lets the codec decide. Information available in the codec logs. + +### Backpressure and Frame Dropping + +Each camera has a bounded queue (`encoder_queue_maxsize`, default 60 frames). When the encoder can't keep up: + +1. The queue fills up (consuming RAM) +2. New frames are **dropped** (not blocked) — the capture loop continues uninterrupted +3. A warning is logged: `"Encoder queue full for {camera}, dropped N frame(s)"` +4. At episode end, total dropped frames per camera are reported + +### Symptoms of Encoder Falling Behind + +- **System feels laggy and freezes**: all CPUs are at 100% +- **Dropped frame warnings** in the log or lower frames/FPS than expected in the recorded dataset +- **Choppy robot movement**: If CPU is severely overloaded, even the capture loop may be affected +- **Accumulated rerun lag**: Visualization falls behind real-time + +## 4. Hardware-Accelerated Encoding + +### When to Use + +Use HW encoding when: + +- CPU is the bottleneck (dropped frames, choppy robot, rerun lag) +- You have compatible hardware (GPU or dedicated encoder) +- You're recording at high throughput (high resolution or with many cameras) + +### Choosing a Codec + +| Codec | CPU Usage | File Size | Quality | Notes | +| --------------------- | --------- | -------------- | ------- | ---------------------------------------------------------------- | +| `libsvtav1` (default) | High | Smallest | Best | Default. Best compression but most CPU-intensive | +| `h264` | Medium | ~30-50% larger | Good | Software H.264. Lower CPU | +| HW encoders | Very Low | Largest | Good | Offloads to dedicated hardware. Best for CPU-constrained systems | + +### Available HW Encoders + +| Encoder | Platform | Hardware | CLI Value | +| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | ------------------------------------ | +| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=h264_videotoolbox` | +| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=hevc_videotoolbox` | +| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=h264_nvenc` | +| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=hevc_nvenc` | +| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.vcodec=h264_vaapi` | +| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.vcodec=h264_qsv` | +| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.vcodec=auto` | + +> [!NOTE] +> In order to use the HW accelerated encoders you might need to upgrade your GPU drivers. + +> [!NOTE] +> `libsvtav1` is the default because it provides the best training performance; other vcodecs can reduce CPU usage and be faster, but they typically produce larger files and may affect training time. + +## 5. Troubleshooting + +| Symptom | Likely Cause | Fix | +| ------------------------------------------------------------------ | -------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.vcodec=auto`) | +| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.vcodec=auto`). | +| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding | +| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows | +| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` | +| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.vcodec=auto` | +| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. | + +## 6. Recommended Configurations + +These estimates are conservative; we recommend testing them on your setup—start with a low load and increase it gradually. + +### High-End Systems: modern 12+ cores (24+ threads) + +A throughput between ~250-500M px/sec should be comfortable in CPU. For even better results try HW encoding if available. + +```bash +# 3camsx 1280x720x3 @30fps: Defaults work well. Optionally increase encoder parallelism. +# 2camsx 1920x1080x3 @30fps: Defaults work well. Optionally increase encoder parallelism. +lerobot-record --dataset.encoder_threads=5 ... + +# 3camsx 1920x1080x3 @30fps: Might require some tuning. +``` + +### Mid-Range Systems: modern 8+ cores (16+ threads) or Apple Silicon + +A throughput between ~80-300M px/sec should be possible in CPU. + +```bash +# 3camsx 640x480x3 @30fps: Defaults work well. Optionally decrease encoder parallelism. +# 2camsx 1280x720x3 @30fps: Defaults work well. Optionally decrease encoder parallelism. +lerobot-record --dataset.encoder_threads=2 ... + +# 2camsx 1920x1080x3 @30fps: Might require some tuning. +``` + +### Low-Resource Systems: modern 4+ cores (8+ threads) or Raspberry Pi 5 + +On very constrained systems, streaming encoding may compete too heavily with the capture loop. Disabling it falls back to the PNG-based approach where encoding happens between episodes (blocking, but doesn't interfere with capture). Alternatively, record at a lower throughput to reduce both capture and encoding load. Consider also changing codec to `h264` and using batch encoding. + +```bash +# 2camsx 640x480x3 @30fps: Requires some tuning. + +# Use H.264, disable streaming, consider batching encoding +lerobot-record --dataset.vcodec=h264 --dataset.streaming_encoding=false ... +``` + +## 7. Closing note + +Performance ultimately depends on your exact setup — frames-per-second, resolution, CPU cores and load, available memory, episode length, and the encoder you choose. Always test with your target workload, be mindful about your CPU & system capabilities and tune `encoder_threads`, `encoder_queue_maxsize`, and +`vcodec` reasonably. That said, a common practical configuration (for many applications) is three cameras at 640×480x3 @30fps; this usually runs fine with the default streaming video encoding settings in modern systems. Always verify your recorded dataset is healthy by comparing the video duration to the CLI episode duration and confirming the row count equals FPS × CLI duration. diff --git a/docs/source/unitree_g1.mdx b/docs/source/unitree_g1.mdx index 4c5d28924..76e972dca 100644 --- a/docs/source/unitree_g1.mdx +++ b/docs/source/unitree_g1.mdx @@ -229,7 +229,10 @@ lerobot-record \ --dataset.num_episodes=2 \ --dataset.episode_time_s=5 \ --dataset.reset_time_s=5 \ - --dataset.push_to_hub=true + --dataset.push_to_hub=true \ + --dataset.streaming_encoding=true \ + # --dataset.vcodec=auto \ + --dataset.encoder_threads=2 ``` Example simulation dataset: [nepyope/teleop_test_sim](https://huggingface.co/datasets/nepyope/teleop_test_sim) @@ -279,7 +282,10 @@ lerobot-record \ --dataset.num_episodes=2 \ --dataset.episode_time_s=5 \ --dataset.reset_time_s=5 \ - --dataset.push_to_hub=true + --dataset.push_to_hub=true \ + --dataset.streaming_encoding=true \ + # --dataset.vcodec=auto \ + --dataset.encoder_threads=2 ``` **Note**: Update `server_address` to match your robot's camera server IP. diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 360ed8d30..65b475e26 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -68,6 +68,7 @@ from lerobot.datasets.utils import ( write_tasks, ) from lerobot.datasets.video_utils import ( + StreamingVideoEncoder, VideoFrame, concatenate_video_files, decode_video_frames, @@ -75,11 +76,11 @@ from lerobot.datasets.video_utils import ( get_safe_default_codec, get_video_duration_in_s, get_video_info, + resolve_vcodec, ) from lerobot.utils.constants import HF_LEROBOT_HOME CODEBASE_VERSION = "v3.0" -VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1"} class LeRobotDatasetMetadata: @@ -545,12 +546,19 @@ class LeRobotDatasetMetadata: def _encode_video_worker( - video_key: str, episode_index: int, root: Path, fps: int, vcodec: str = "libsvtav1" + video_key: str, + episode_index: int, + root: Path, + fps: int, + vcodec: str = "libsvtav1", + encoder_threads: int | None = None, ) -> Path: temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4" fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0) img_dir = (root / fpath).parent - encode_video_frames(img_dir, temp_path, fps, vcodec=vcodec, overwrite=True) + encode_video_frames( + img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads + ) shutil.rmtree(img_dir) return temp_path @@ -570,6 +578,9 @@ class LeRobotDataset(torch.utils.data.Dataset): video_backend: str | None = None, batch_encoding_size: int = 1, vcodec: str = "libsvtav1", + streaming_encoding: bool = False, + encoder_queue_maxsize: int = 30, + encoder_threads: int | None = None, ): """ 2 modes are available for instantiating this class, depending on 2 different use cases: @@ -683,12 +694,17 @@ class LeRobotDataset(torch.utils.data.Dataset): batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos. Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1. vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc', - 'libsvtav1'. Defaults to 'libsvtav1'. Use 'h264' for faster encoding on systems where AV1 - encoding is CPU-heavy. + 'libsvtav1', 'auto', or hardware-specific codecs like 'h264_videotoolbox', 'h264_nvenc'. + Defaults to 'libsvtav1'. Use 'auto' to auto-detect the best available hardware encoder. + streaming_encoding (bool, optional): If True, encode video frames in real-time during capture + instead of writing PNG images first. This makes save_episode() near-instant. Defaults to False. + encoder_queue_maxsize (int, optional): Maximum number of frames to buffer per camera when using + streaming encoding. Defaults to 30 (~1s at 30fps). + encoder_threads (int | None, optional): Number of threads per encoder instance. None lets the + codec auto-detect (default). Lower values reduce CPU usage per encoder. Maps to 'lp' (via svtav1-params) for + libsvtav1 and 'threads' for h264/hevc. """ super().__init__() - if vcodec not in VALID_VIDEO_CODECS: - raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}") self.repo_id = repo_id self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id self.image_transforms = image_transforms @@ -700,7 +716,8 @@ class LeRobotDataset(torch.utils.data.Dataset): self.delta_indices = None self.batch_encoding_size = batch_encoding_size self.episodes_since_last_encoding = 0 - self.vcodec = vcodec + self.vcodec = resolve_vcodec(vcodec) + self._encoder_threads = encoder_threads # Unused attributes self.image_writer = None @@ -708,6 +725,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.writer = None self.latest_episode = None self._current_file_start_frame = None # Track the starting frame index of the current parquet file + self._streaming_encoder = None self.root.mkdir(exist_ok=True, parents=True) @@ -749,6 +767,19 @@ class LeRobotDataset(torch.utils.data.Dataset): check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s) self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps) + # Initialize streaming encoder for resumed recording + if streaming_encoding and len(self.meta.video_keys) > 0: + self._streaming_encoder = StreamingVideoEncoder( + fps=self.meta.fps, + vcodec=self.vcodec, + pix_fmt="yuv420p", + g=2, + crf=30, + preset=None, + queue_maxsize=encoder_queue_maxsize, + encoder_threads=encoder_threads, + ) + def _close_writer(self) -> None: """Close and cleanup the parquet writer if it exists.""" writer = getattr(self, "writer", None) @@ -1104,6 +1135,8 @@ class LeRobotDataset(torch.utils.data.Dataset): """ self._close_writer() self.meta._close_writer() + if self._streaming_encoder is not None: + self._streaming_encoder.close() def create_episode_buffer(self, episode_index: int | None = None) -> dict: current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index @@ -1158,6 +1191,13 @@ class LeRobotDataset(torch.utils.data.Dataset): self.episode_buffer["timestamp"].append(timestamp) self.episode_buffer["task"].append(frame.pop("task")) # Remove task from frame after processing + # Start streaming encoder on first frame of episode (once, before iterating keys) + if frame_index == 0 and self._streaming_encoder is not None: + self._streaming_encoder.start_episode( + video_keys=list(self.meta.video_keys), + temp_dir=self.root, + ) + # Add frame features to episode_buffer for key in frame: if key not in self.features: @@ -1165,7 +1205,10 @@ class LeRobotDataset(torch.utils.data.Dataset): f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'." ) - if self.features[key]["dtype"] in ["image", "video"]: + if self.features[key]["dtype"] == "video" and self._streaming_encoder is not None: + self._streaming_encoder.feed_frame(key, frame[key]) + self.episode_buffer[key].append(None) # Placeholder (video keys are skipped in parquet) + elif self.features[key]["dtype"] in ["image", "video"]: img_path = self._get_image_file_path( episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index ) @@ -1226,13 +1269,38 @@ class LeRobotDataset(torch.utils.data.Dataset): # Wait for image writer to end, so that episode stats over images can be computed self._wait_image_writer() - ep_stats = compute_episode_stats(episode_buffer, self.features) - ep_metadata = self._save_episode_data(episode_buffer) has_video_keys = len(self.meta.video_keys) > 0 + use_streaming = self._streaming_encoder is not None and has_video_keys use_batched_encoding = self.batch_encoding_size > 1 - if has_video_keys and not use_batched_encoding: + if use_streaming: + # Compute stats for non-video features only (video stats come from encoder) + non_video_buffer = { + k: v + for k, v in episode_buffer.items() + if self.features.get(k, {}).get("dtype") not in ("video",) + } + non_video_features = {k: v for k, v in self.features.items() if v["dtype"] != "video"} + ep_stats = compute_episode_stats(non_video_buffer, non_video_features) + else: + ep_stats = compute_episode_stats(episode_buffer, self.features) + + ep_metadata = self._save_episode_data(episode_buffer) + + if use_streaming: + # Finish streaming encoding and collect results + streaming_results = self._streaming_encoder.finish_episode() + for video_key in self.meta.video_keys: + temp_path, video_stats = streaming_results[video_key] + if video_stats is not None: + # Format stats same as compute_episode_stats: normalize to [0,1], reshape to (C,1,1) + ep_stats[video_key] = { + k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0) + for k, v in video_stats.items() + } + ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path)) + elif has_video_keys and not use_batched_encoding: num_cameras = len(self.meta.video_keys) if parallel_encoding and num_cameras > 1: # TODO(Steven): Ideally we would like to control the number of threads per encoding such that: @@ -1246,6 +1314,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.root, self.fps, self.vcodec, + self._encoder_threads, ): video_key for video_key in self.meta.video_keys } @@ -1514,6 +1583,10 @@ class LeRobotDataset(torch.utils.data.Dataset): return metadata def clear_episode_buffer(self, delete_images: bool = True) -> None: + # Cancel streaming encoder if active + if self._streaming_encoder is not None: + self._streaming_encoder.cancel_episode() + # Clean up image files for the current episode buffer if delete_images: # Wait for the async image writer to finish @@ -1561,7 +1634,9 @@ class LeRobotDataset(torch.utils.data.Dataset): Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, since video encoding with ffmpeg is already using multithreading. """ - return _encode_video_worker(video_key, episode_index, self.root, self.fps, self.vcodec) + return _encode_video_worker( + video_key, episode_index, self.root, self.fps, self.vcodec, self._encoder_threads + ) @classmethod def create( @@ -1578,10 +1653,12 @@ class LeRobotDataset(torch.utils.data.Dataset): video_backend: str | None = None, batch_encoding_size: int = 1, vcodec: str = "libsvtav1", + streaming_encoding: bool = False, + encoder_queue_maxsize: int = 30, + encoder_threads: int | None = None, ) -> "LeRobotDataset": """Create a LeRobot Dataset from scratch in order to record data.""" - if vcodec not in VALID_VIDEO_CODECS: - raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}") + vcodec = resolve_vcodec(vcodec) obj = cls.__new__(cls) obj.meta = LeRobotDatasetMetadata.create( repo_id=repo_id, @@ -1599,6 +1676,7 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.batch_encoding_size = batch_encoding_size obj.episodes_since_last_encoding = 0 obj.vcodec = vcodec + obj._encoder_threads = encoder_threads if image_writer_processes or image_writer_threads: obj.start_image_writer(image_writer_processes, image_writer_threads) @@ -1620,6 +1698,22 @@ class LeRobotDataset(torch.utils.data.Dataset): obj._lazy_loading = False obj._recorded_frames = 0 obj._writer_closed_for_reading = False + + # Initialize streaming encoder + if streaming_encoding and len(obj.meta.video_keys) > 0: + obj._streaming_encoder = StreamingVideoEncoder( + fps=fps, + vcodec=vcodec, + pix_fmt="yuv420p", + g=2, + crf=30, + preset=None, + queue_maxsize=encoder_queue_maxsize, + encoder_threads=encoder_threads, + ) + else: + obj._streaming_encoder = None + return obj diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 84ce13772..acc24a9e0 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -13,25 +13,106 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import glob import importlib import logging +import queue import shutil import tempfile +import threading import warnings from dataclasses import dataclass, field +from fractions import Fraction from pathlib import Path from threading import Lock from typing import Any, ClassVar import av import fsspec +import numpy as np import pyarrow as pa import torch import torchvision from datasets.features.features import register_feature from PIL import Image +# List of hardware encoders to probe for auto-selection. Availability depends on the platform and FFmpeg build. +# Determines the order of preference for auto-selection when vcodec="auto" is used. +HW_ENCODERS = [ + "h264_videotoolbox", # macOS + "hevc_videotoolbox", # macOS + "h264_nvenc", # NVIDIA GPU + "hevc_nvenc", # NVIDIA GPU + "h264_vaapi", # Linux Intel/AMD + "h264_qsv", # Intel Quick Sync +] + +VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "auto"} | set(HW_ENCODERS) + + +def _get_codec_options( + vcodec: str, + g: int | None = 2, + crf: int | None = 30, + preset: int | None = None, +) -> dict: + """Build codec-specific options dict for video encoding.""" + options = {} + + # GOP size (keyframe interval) - supported by VideoToolbox and software encoders + if g is not None and (vcodec in ("h264_videotoolbox", "hevc_videotoolbox") or vcodec not in HW_ENCODERS): + options["g"] = str(g) + + # Quality control (codec-specific parameter names) + if crf is not None: + if vcodec in ("h264", "hevc", "libsvtav1"): + options["crf"] = str(crf) + elif vcodec in ("h264_videotoolbox", "hevc_videotoolbox"): + quality = max(1, min(100, int(100 - crf * 2))) + options["q:v"] = str(quality) + elif vcodec in ("h264_nvenc", "hevc_nvenc"): + options["rc"] = "constqp" + options["qp"] = str(crf) + elif vcodec in ("h264_vaapi",): + options["qp"] = str(crf) + elif vcodec in ("h264_qsv",): + options["global_quality"] = str(crf) + + # Preset (only for libsvtav1) + if vcodec == "libsvtav1": + options["preset"] = str(preset) if preset is not None else "12" + + return options + + +def detect_available_hw_encoders() -> list[str]: + """Probe PyAV/FFmpeg for available hardware video encoders.""" + available = [] + for codec_name in HW_ENCODERS: + try: + av.codec.Codec(codec_name, "w") + available.append(codec_name) + except Exception: # nosec B110 + pass # nosec B110 + return available + + +def resolve_vcodec(vcodec: str) -> str: + """Validate vcodec and resolve 'auto' to best available HW encoder, fallback to libsvtav1.""" + if vcodec not in VALID_VIDEO_CODECS: + raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}") + if vcodec != "auto": + logging.info(f"Using video codec: {vcodec}") + return vcodec + available = detect_available_hw_encoders() + for encoder in HW_ENCODERS: + if encoder in available: + logging.info(f"Auto-selected video codec: {encoder}") + return encoder + logging.info("No hardware encoder available, falling back to software encoder 'libsvtav1'") + return "libsvtav1" + def get_safe_default_codec(): if importlib.util.find_spec("torchcodec"): @@ -309,14 +390,13 @@ def encode_video_frames( g: int | None = 2, crf: int | None = 30, fast_decode: int = 0, - log_level: int | None = av.logging.ERROR, + log_level: int | None = av.logging.WARNING, overwrite: bool = False, preset: int | None = None, + encoder_threads: int | None = None, ) -> None: """More info on ffmpeg arguments tuning on `benchmark/video/README.md`""" - # Check encoder availability - if vcodec not in ["h264", "hevc", "libsvtav1"]: - raise ValueError(f"Unsupported video codec: {vcodec}. Supported codecs are: h264, hevc, libsvtav1.") + vcodec = resolve_vcodec(vcodec) video_path = Path(video_path) imgs_dir = Path(imgs_dir) @@ -347,21 +427,22 @@ def encode_video_frames( width, height = dummy_image.size # Define video codec options - video_options = {} - - if g is not None: - video_options["g"] = str(g) - - if crf is not None: - video_options["crf"] = str(crf) + video_options = _get_codec_options(vcodec, g, crf, preset) if fast_decode: key = "svtav1-params" if vcodec == "libsvtav1" else "tune" value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode" video_options[key] = value - if vcodec == "libsvtav1": - video_options["preset"] = str(preset) if preset is not None else "12" + if encoder_threads is not None: + if vcodec == "libsvtav1": + lp_param = f"lp={encoder_threads}" + if "svtav1-params" in video_options: + video_options["svtav1-params"] += f":{lp_param}" + else: + video_options["svtav1-params"] = lp_param + else: + video_options["threads"] = str(encoder_threads) # Set logging level if log_level is not None: @@ -480,6 +561,348 @@ def concatenate_video_files( Path(tmp_concatenate_path).unlink() +class _CameraEncoderThread(threading.Thread): + """A thread that encodes video frames streamed via a queue into an MP4 file. + + One instance is created per camera per episode. Frames are received as numpy arrays + from the main thread, encoded in real-time using PyAV (which releases the GIL during + encoding), and written to disk. Stats are computed incrementally using + RunningQuantileStats and returned via result_queue. + """ + + def __init__( + self, + video_path: Path, + fps: int, + vcodec: str, + pix_fmt: str, + g: int | None, + crf: int | None, + preset: int | None, + frame_queue: queue.Queue, + result_queue: queue.Queue, + stop_event: threading.Event, + encoder_threads: int | None = None, + ): + super().__init__(daemon=True) + self.video_path = video_path + self.fps = fps + self.vcodec = vcodec + self.pix_fmt = pix_fmt + self.g = g + self.crf = crf + self.preset = preset + self.frame_queue = frame_queue + self.result_queue = result_queue + self.stop_event = stop_event + self.encoder_threads = encoder_threads + + def run(self) -> None: + from lerobot.datasets.compute_stats import RunningQuantileStats, auto_downsample_height_width + + container = None + output_stream = None + stats_tracker = RunningQuantileStats() + frame_count = 0 + + try: + logging.getLogger("libav").setLevel(av.logging.WARNING) + + while True: + try: + frame_data = self.frame_queue.get(timeout=1) + except queue.Empty: + if self.stop_event.is_set(): + break + continue + + if frame_data is None: + # Sentinel: flush and close + break + + # Ensure HWC uint8 numpy array + if isinstance(frame_data, np.ndarray): + if frame_data.ndim == 3 and frame_data.shape[0] == 3: + # CHW -> HWC + frame_data = frame_data.transpose(1, 2, 0) + if frame_data.dtype != np.uint8: + frame_data = (frame_data * 255).astype(np.uint8) + + # Open container on first frame (to get width/height) + if container is None: + height, width = frame_data.shape[:2] + video_options = _get_codec_options(self.vcodec, self.g, self.crf, self.preset) + if self.encoder_threads is not None: + if self.vcodec == "libsvtav1": + lp_param = f"lp={self.encoder_threads}" + if "svtav1-params" in video_options: + video_options["svtav1-params"] += f":{lp_param}" + else: + video_options["svtav1-params"] = lp_param + else: + video_options["threads"] = str(self.encoder_threads) + Path(self.video_path).parent.mkdir(parents=True, exist_ok=True) + container = av.open(str(self.video_path), "w") + output_stream = container.add_stream(self.vcodec, self.fps, options=video_options) + output_stream.pix_fmt = self.pix_fmt + output_stream.width = width + output_stream.height = height + output_stream.time_base = Fraction(1, self.fps) + + # Encode frame with explicit timestamps + pil_img = Image.fromarray(frame_data) + video_frame = av.VideoFrame.from_image(pil_img) + video_frame.pts = frame_count + video_frame.time_base = Fraction(1, self.fps) + packet = output_stream.encode(video_frame) + if packet: + container.mux(packet) + + # Update stats with downsampled frame (per-channel stats like compute_episode_stats) + img_chw = frame_data.transpose(2, 0, 1) # HWC -> CHW + img_downsampled = auto_downsample_height_width(img_chw) + # Reshape CHW to (H*W, C) for per-channel stats + channels = img_downsampled.shape[0] + img_for_stats = img_downsampled.transpose(1, 2, 0).reshape(-1, channels) + stats_tracker.update(img_for_stats) + + frame_count += 1 + + # Flush encoder + if output_stream is not None: + packet = output_stream.encode() + if packet: + container.mux(packet) + + if container is not None: + container.close() + + av.logging.restore_default_callback() + + # Get stats and put on result queue + if frame_count >= 2: + stats = stats_tracker.get_statistics() + self.result_queue.put(("ok", stats)) + else: + self.result_queue.put(("ok", None)) + + except Exception as e: + logging.error(f"Encoder thread error: {e}") + if container is not None: + with contextlib.suppress(Exception): + container.close() + self.result_queue.put(("error", str(e))) + + +class StreamingVideoEncoder: + """Manages per-camera encoder threads for real-time video encoding during recording. + + Instead of writing frames as PNG images and then encoding to MP4 at episode end, + this class streams frames directly to encoder threads, eliminating the + PNG round-trip and making save_episode() near-instant. + + Uses threading instead of multiprocessing to avoid the overhead of pickling large + numpy arrays through multiprocessing.Queue. PyAV's encode() releases the GIL, + so encoding runs in parallel with the main recording loop. + """ + + def __init__( + self, + fps: int, + vcodec: str = "libsvtav1", + pix_fmt: str = "yuv420p", + g: int | None = 2, + crf: int | None = 30, + preset: int | None = None, + queue_maxsize: int = 30, + encoder_threads: int | None = None, + ): + self.fps = fps + self.vcodec = resolve_vcodec(vcodec) + self.pix_fmt = pix_fmt + self.g = g + self.crf = crf + self.preset = preset + self.queue_maxsize = queue_maxsize + self.encoder_threads = encoder_threads + + self._frame_queues: dict[str, queue.Queue] = {} + self._result_queues: dict[str, queue.Queue] = {} + self._threads: dict[str, _CameraEncoderThread] = {} + self._stop_events: dict[str, threading.Event] = {} + self._video_paths: dict[str, Path] = {} + self._dropped_frames: dict[str, int] = {} + self._episode_active = False + + def start_episode(self, video_keys: list[str], temp_dir: Path) -> None: + """Start encoder threads for a new episode. + + Args: + video_keys: List of video feature keys (e.g. ["observation.images.laptop"]) + temp_dir: Base directory for temporary MP4 files + """ + if self._episode_active: + self.cancel_episode() + + self._dropped_frames.clear() + + for video_key in video_keys: + frame_queue: queue.Queue = queue.Queue(maxsize=self.queue_maxsize) + result_queue: queue.Queue = queue.Queue(maxsize=1) + stop_event = threading.Event() + + temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir)) + video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4" + + encoder_thread = _CameraEncoderThread( + video_path=video_path, + fps=self.fps, + vcodec=self.vcodec, + pix_fmt=self.pix_fmt, + g=self.g, + crf=self.crf, + preset=self.preset, + frame_queue=frame_queue, + result_queue=result_queue, + stop_event=stop_event, + encoder_threads=self.encoder_threads, + ) + encoder_thread.start() + + self._frame_queues[video_key] = frame_queue + self._result_queues[video_key] = result_queue + self._threads[video_key] = encoder_thread + self._stop_events[video_key] = stop_event + self._video_paths[video_key] = video_path + + self._episode_active = True + + def feed_frame(self, video_key: str, image: np.ndarray) -> None: + """Feed a frame to the encoder for a specific camera. + + A copy of the image is made before enqueueing to prevent race conditions + with camera drivers that may reuse buffers. If the encoder queue is full + (encoder can't keep up), the frame is dropped with a warning instead of + crashing the recording session. + + Args: + video_key: The video feature key + image: numpy array in (H,W,C) or (C,H,W) format, uint8 or float + + Raises: + RuntimeError: If the encoder thread has crashed + """ + if not self._episode_active: + raise RuntimeError("No active episode. Call start_episode() first.") + + thread = self._threads[video_key] + if not thread.is_alive(): + # Check for error + try: + status, msg = self._result_queues[video_key].get_nowait() + if status == "error": + raise RuntimeError(f"Encoder thread for {video_key} crashed: {msg}") + except queue.Empty: + pass + raise RuntimeError(f"Encoder thread for {video_key} is not alive") + + try: + self._frame_queues[video_key].put(image.copy(), timeout=0.1) + except queue.Full: + self._dropped_frames[video_key] = self._dropped_frames.get(video_key, 0) + 1 + count = self._dropped_frames[video_key] + # Log periodically to avoid spam (1st, then every 10th) + if count == 1 or count % 10 == 0: + logging.warning( + f"Encoder queue full for {video_key}, dropped {count} frame(s). " + f"Consider using vcodec='auto' for hardware encoding or increasing encoder_queue_maxsize." + ) + + def finish_episode(self) -> dict[str, tuple[Path, dict | None]]: + """Finish encoding the current episode. + + Sends sentinel values, waits for encoder threads to complete, + and collects results. + + Returns: + Dict mapping video_key to (mp4_path, stats_dict_or_None) + """ + if not self._episode_active: + raise RuntimeError("No active episode to finish.") + + results = {} + + # Report dropped frames + for video_key, count in self._dropped_frames.items(): + if count > 0: + logging.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.") + + # Send sentinel to all queues + for video_key in self._frame_queues: + self._frame_queues[video_key].put(None) + + # Wait for all threads and collect results + for video_key in self._threads: + self._threads[video_key].join(timeout=120) + if self._threads[video_key].is_alive(): + logging.error(f"Encoder thread for {video_key} did not finish in time") + self._stop_events[video_key].set() + self._threads[video_key].join(timeout=5) + results[video_key] = (self._video_paths[video_key], None) + continue + + try: + status, data = self._result_queues[video_key].get(timeout=5) + if status == "error": + raise RuntimeError(f"Encoder thread for {video_key} failed: {data}") + results[video_key] = (self._video_paths[video_key], data) + except queue.Empty: + logging.error(f"No result from encoder thread for {video_key}") + results[video_key] = (self._video_paths[video_key], None) + + self._cleanup() + self._episode_active = False + return results + + def cancel_episode(self) -> None: + """Cancel the current episode, stopping encoder threads and cleaning up.""" + if not self._episode_active: + return + + # Signal all threads to stop + for video_key in self._stop_events: + self._stop_events[video_key].set() + + # Wait for threads to finish + for video_key in self._threads: + self._threads[video_key].join(timeout=5) + + # Clean up temp MP4 files + video_path = self._video_paths.get(video_key) + if video_path is not None and video_path.exists(): + shutil.rmtree(str(video_path.parent), ignore_errors=True) + + self._cleanup() + self._episode_active = False + + def close(self) -> None: + """Close the encoder, canceling any in-progress episode.""" + if self._episode_active: + self.cancel_episode() + + def _cleanup(self) -> None: + """Clean up queues and thread tracking dicts.""" + for q in self._frame_queues.values(): + with contextlib.suppress(Exception): + while not q.empty(): + q.get_nowait() + self._frame_queues.clear() + self._result_queues.clear() + self._threads.clear() + self._stop_events.clear() + self._video_paths.clear() + + @dataclass class VideoFrame: # TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo @@ -514,7 +937,7 @@ with warnings.catch_warnings(): def get_audio_info(video_path: Path | str) -> dict: # Set logging level - logging.getLogger("libav").setLevel(av.logging.ERROR) + logging.getLogger("libav").setLevel(av.logging.WARNING) # Getting audio stream information audio_info = {} @@ -546,7 +969,7 @@ def get_audio_info(video_path: Path | str) -> dict: def get_video_info(video_path: Path | str) -> dict: # Set logging level - logging.getLogger("libav").setLevel(av.logging.ERROR) + logging.getLogger("libav").setLevel(av.logging.WARNING) # Getting video stream information video_info = {} @@ -632,8 +1055,15 @@ class VideoEncodingManager: return self def __exit__(self, exc_type, exc_val, exc_tb): - # Handle any remaining episodes that haven't been batch encoded - if self.dataset.episodes_since_last_encoding > 0: + streaming_encoder = getattr(self.dataset, "_streaming_encoder", None) + + if streaming_encoder is not None: + # Handle streaming encoder cleanup + if exc_type is not None: + streaming_encoder.cancel_episode() + streaming_encoder.close() + elif self.dataset.episodes_since_last_encoding > 0: + # Handle any remaining episodes that haven't been batch encoded if exc_type is not None: logging.info("Exception occurred. Encoding remaining episodes before exit...") else: @@ -650,8 +1080,8 @@ class VideoEncodingManager: # Finalize the dataset to properly close all writers self.dataset.finalize() - # Clean up episode images if recording was interrupted - if exc_type is not None: + # Clean up episode images if recording was interrupted (only for non-streaming mode) + if exc_type is not None and streaming_encoder is None: interrupted_episode_index = self.dataset.num_episodes for key in self.dataset.meta.video_keys: img_dir = self.dataset._get_image_file_path( @@ -665,14 +1095,12 @@ class VideoEncodingManager: # Clean up any remaining images directory if it's empty img_dir = self.dataset.root / "images" - # Check for any remaining PNG files - png_files = list(img_dir.rglob("*.png")) - if len(png_files) == 0: - # Only remove the images directory if no PNG files remain - if img_dir.exists(): + if img_dir.exists(): + png_files = list(img_dir.rglob("*.png")) + if len(png_files) == 0: shutil.rmtree(img_dir) logging.debug("Cleaned up empty images directory") - else: - logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files") + else: + logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files") return False # Don't suppress the original exception diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 216ab22a6..ec04975d4 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -26,8 +26,10 @@ lerobot-record \ --dataset.repo_id=/ \ --dataset.num_episodes=2 \ --dataset.single_task="Grab the cube" \ + --dataset.streaming_encoding=true \ + --dataset.encoder_threads=2 \ --display_data=true - # <- Optional: specify video codec (h264, hevc, libsvtav1). Default is libsvtav1. \ + # <- Optional: specify video codec (auto, h264, hevc, libsvtav1). Default is libsvtav1. \ # --dataset.vcodec=h264 \ # <- Teleop optional if you want to teleoperate to record or in between episodes with a policy \ # --teleop.type=so100_leader \ @@ -58,7 +60,10 @@ lerobot-record \ --display_data=true \ --dataset.repo_id=${HF_USER}/bimanual-so-handover-cube \ --dataset.num_episodes=25 \ - --dataset.single_task="Grab and handover the red cube to the other arm" + --dataset.single_task="Grab and handover the red cube to the other arm" \ + --dataset.streaming_encoding=true \ + # --dataset.vcodec=auto \ + --dataset.encoder_threads=2 ``` """ @@ -179,9 +184,19 @@ class DatasetRecordConfig: # Number of episodes to record before batch encoding videos # Set to 1 for immediate encoding (default behavior), or higher for batched encoding video_encoding_batch_size: int = 1 - # Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1'. - # Use 'h264' for faster encoding on systems where AV1 encoding is CPU-heavy. + # Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto', + # or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'. + # Use 'auto' to auto-detect the best available hardware encoder. vcodec: str = "libsvtav1" + # Enable streaming video encoding: encode frames in real-time during capture instead + # of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding + streaming_encoding: bool = False + # Maximum number of frames to buffer per camera when using streaming encoding. + # ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up. + encoder_queue_maxsize: int = 30 + # Number of threads per encoder instance. None = auto (codec default). + # Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc.. + encoder_threads: int | None = None # Rename map for the observation to override the image and state keys rename_map: dict[str, str] = field(default_factory=dict) @@ -452,6 +467,9 @@ def record(cfg: RecordConfig) -> LeRobotDataset: root=cfg.dataset.root, batch_encoding_size=cfg.dataset.video_encoding_batch_size, vcodec=cfg.dataset.vcodec, + streaming_encoding=cfg.dataset.streaming_encoding, + encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize, + encoder_threads=cfg.dataset.encoder_threads, ) if hasattr(robot, "cameras") and len(robot.cameras) > 0: @@ -474,6 +492,9 @@ def record(cfg: RecordConfig) -> LeRobotDataset: image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras), batch_encoding_size=cfg.dataset.video_encoding_batch_size, vcodec=cfg.dataset.vcodec, + streaming_encoding=cfg.dataset.streaming_encoding, + encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize, + encoder_threads=cfg.dataset.encoder_threads, ) # Load pretrained policy @@ -497,6 +518,11 @@ def record(cfg: RecordConfig) -> LeRobotDataset: listener, events = init_keyboard_listener() + if not cfg.dataset.streaming_encoding: + logging.info( + "Streaming encoding is disabled. If you have capable hardware, consider enabling it for way faster episode saving. --dataset.streaming_encoding=true --dataset.encoder_threads=2 # --dataset.vcodec=auto. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding" + ) + with VideoEncodingManager(dataset): recorded_episodes = 0 while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]: diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 27c51b3c4..6f99eb301 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -31,7 +31,6 @@ from lerobot.configs.train import TrainPipelineConfig from lerobot.datasets.factory import make_dataset from lerobot.datasets.image_writer import image_array_to_pil_image from lerobot.datasets.lerobot_dataset import ( - VALID_VIDEO_CODECS, LeRobotDataset, MultiLeRobotDataset, _encode_video_worker, @@ -45,6 +44,7 @@ from lerobot.datasets.utils import ( hf_transform_to_torch, hw_to_dataset_features, ) +from lerobot.datasets.video_utils import VALID_VIDEO_CODECS from lerobot.envs.factory import make_env_config from lerobot.policies.factory import make_policy_config from lerobot.robots import make_robot_from_config @@ -393,7 +393,7 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory): vid_key: {"dtype": "video", "shape": DUMMY_HWC, "names": ["height", "width", "channels"]}, } ds_mixed = empty_lerobot_dataset_factory( - root=tmp_path / "mixed", features=features_mixed, batch_encoding_size=2 + root=tmp_path / "mixed", features=features_mixed, batch_encoding_size=2, streaming_encoding=False ) ds_mixed.add_frame( { @@ -1450,7 +1450,10 @@ def test_valid_video_codecs_constant(): assert "h264" in VALID_VIDEO_CODECS assert "hevc" in VALID_VIDEO_CODECS assert "libsvtav1" in VALID_VIDEO_CODECS - assert len(VALID_VIDEO_CODECS) == 3 + assert "auto" in VALID_VIDEO_CODECS + assert "h264_videotoolbox" in VALID_VIDEO_CODECS + assert "h264_nvenc" in VALID_VIDEO_CODECS + assert len(VALID_VIDEO_CODECS) == 10 def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory): diff --git a/tests/datasets/test_streaming_video_encoder.py b/tests/datasets/test_streaming_video_encoder.py new file mode 100644 index 000000000..a85db6a8d --- /dev/null +++ b/tests/datasets/test_streaming_video_encoder.py @@ -0,0 +1,730 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for streaming video encoding and hardware-accelerated encoding.""" + +import queue +import threading +from unittest.mock import patch + +import av +import numpy as np +import pytest + +from lerobot.datasets.video_utils import ( + VALID_VIDEO_CODECS, + StreamingVideoEncoder, + _CameraEncoderThread, + _get_codec_options, + detect_available_hw_encoders, + resolve_vcodec, +) +from lerobot.utils.constants import OBS_IMAGES + +# ─── _get_codec_options tests ─── + + +class TestGetCodecOptions: + def test_libsvtav1_defaults(self): + opts = _get_codec_options("libsvtav1") + assert opts["g"] == "2" + assert opts["crf"] == "30" + assert opts["preset"] == "12" + + def test_libsvtav1_custom_preset(self): + opts = _get_codec_options("libsvtav1", preset=8) + assert opts["preset"] == "8" + + def test_h264_options(self): + opts = _get_codec_options("h264", g=10, crf=23) + assert opts["g"] == "10" + assert opts["crf"] == "23" + assert "preset" not in opts + + def test_videotoolbox_options(self): + opts = _get_codec_options("h264_videotoolbox", g=2, crf=30) + assert opts["g"] == "2" + # CRF 30 maps to quality = max(1, min(100, 100 - 30*2)) = 40 + assert opts["q:v"] == "40" + assert "crf" not in opts + + def test_nvenc_options(self): + opts = _get_codec_options("h264_nvenc", g=2, crf=25) + assert opts["rc"] == "constqp" + assert opts["qp"] == "25" + assert "crf" not in opts + # NVENC doesn't support g + assert "g" not in opts + + def test_vaapi_options(self): + opts = _get_codec_options("h264_vaapi", crf=28) + assert opts["qp"] == "28" + + def test_qsv_options(self): + opts = _get_codec_options("h264_qsv", crf=25) + assert opts["global_quality"] == "25" + + def test_no_g_no_crf(self): + opts = _get_codec_options("h264", g=None, crf=None) + assert "g" not in opts + assert "crf" not in opts + + +# ─── HW encoder detection tests ─── + + +class TestHWEncoderDetection: + def test_detect_available_hw_encoders_returns_list(self): + result = detect_available_hw_encoders() + assert isinstance(result, list) + + def test_detect_available_hw_encoders_only_valid(self): + from lerobot.datasets.video_utils import HW_ENCODERS + + result = detect_available_hw_encoders() + for encoder in result: + assert encoder in HW_ENCODERS + + def test_resolve_vcodec_passthrough(self): + assert resolve_vcodec("libsvtav1") == "libsvtav1" + assert resolve_vcodec("h264") == "h264" + + def test_resolve_vcodec_auto_fallback(self): + """When no HW encoders are available, auto should fall back to libsvtav1.""" + with patch("lerobot.datasets.video_utils.detect_available_hw_encoders", return_value=[]): + assert resolve_vcodec("auto") == "libsvtav1" + + def test_resolve_vcodec_auto_picks_hw(self): + """When a HW encoder is available, auto should pick it.""" + with patch( + "lerobot.datasets.video_utils.detect_available_hw_encoders", + return_value=["h264_videotoolbox"], + ): + assert resolve_vcodec("auto") == "h264_videotoolbox" + + def test_resolve_vcodec_auto_returns_valid(self): + """Test that resolve_vcodec('auto') returns a known valid codec.""" + result = resolve_vcodec("auto") + assert result in VALID_VIDEO_CODECS + + def test_hw_encoder_names_accepted_in_validation(self): + """Test that HW encoder names pass validation in VALID_VIDEO_CODECS.""" + assert "auto" in VALID_VIDEO_CODECS + assert "h264_videotoolbox" in VALID_VIDEO_CODECS + assert "h264_nvenc" in VALID_VIDEO_CODECS + + def test_resolve_vcodec_invalid_raises(self): + """Test that resolve_vcodec raises ValueError for invalid codecs.""" + with pytest.raises(ValueError, match="Invalid vcodec"): + resolve_vcodec("not_a_real_codec") + + +# ─── _CameraEncoderThread tests ─── + + +class TestCameraEncoderThread: + def test_encodes_valid_mp4(self, tmp_path): + """Test that the encoder thread creates a valid MP4 file with correct frame count.""" + num_frames = 30 + height, width = 64, 96 + fps = 30 + video_path = tmp_path / "test_output" / "test.mp4" + + frame_queue: queue.Queue = queue.Queue(maxsize=60) + result_queue: queue.Queue = queue.Queue(maxsize=1) + stop_event = threading.Event() + + encoder_thread = _CameraEncoderThread( + video_path=video_path, + fps=fps, + vcodec="libsvtav1", + pix_fmt="yuv420p", + g=2, + crf=30, + preset=13, + frame_queue=frame_queue, + result_queue=result_queue, + stop_event=stop_event, + ) + encoder_thread.start() + + # Feed frames (HWC uint8) + for _ in range(num_frames): + frame = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8) + frame_queue.put(frame) + + # Send sentinel + frame_queue.put(None) + encoder_thread.join(timeout=60) + assert not encoder_thread.is_alive() + + # Check result + status, data = result_queue.get(timeout=5) + assert status == "ok" + assert data is not None # Stats should be returned + assert "mean" in data + assert "std" in data + assert "min" in data + assert "max" in data + assert "count" in data + + # Verify the MP4 file is valid + assert video_path.exists() + with av.open(str(video_path)) as container: + stream = container.streams.video[0] + # The frame count should match + total_frames = sum(1 for _ in container.decode(stream)) + assert total_frames == num_frames + + def test_handles_chw_input(self, tmp_path): + """Test that CHW format input is handled correctly.""" + num_frames = 5 + fps = 30 + video_path = tmp_path / "test_chw" / "test.mp4" + + frame_queue: queue.Queue = queue.Queue(maxsize=60) + result_queue: queue.Queue = queue.Queue(maxsize=1) + stop_event = threading.Event() + + encoder_thread = _CameraEncoderThread( + video_path=video_path, + fps=fps, + vcodec="libsvtav1", + pix_fmt="yuv420p", + g=2, + crf=30, + preset=13, + frame_queue=frame_queue, + result_queue=result_queue, + stop_event=stop_event, + ) + encoder_thread.start() + + # Feed CHW frames + for _ in range(num_frames): + frame = np.random.randint(0, 255, (3, 64, 96), dtype=np.uint8) + frame_queue.put(frame) + + frame_queue.put(None) + encoder_thread.join(timeout=60) + + status, _ = result_queue.get(timeout=5) + assert status == "ok" + assert video_path.exists() + + def test_stop_event_cancellation(self, tmp_path): + """Test that setting the stop event causes the thread to exit.""" + fps = 30 + video_path = tmp_path / "test_cancel" / "test.mp4" + + frame_queue: queue.Queue = queue.Queue(maxsize=60) + result_queue: queue.Queue = queue.Queue(maxsize=1) + stop_event = threading.Event() + + encoder_thread = _CameraEncoderThread( + video_path=video_path, + fps=fps, + vcodec="libsvtav1", + pix_fmt="yuv420p", + g=2, + crf=30, + preset=13, + frame_queue=frame_queue, + result_queue=result_queue, + stop_event=stop_event, + ) + encoder_thread.start() + + # Feed a few frames + for _ in range(3): + frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8) + frame_queue.put(frame) + + # Signal stop instead of sending sentinel + stop_event.set() + encoder_thread.join(timeout=10) + assert not encoder_thread.is_alive() + + +# ─── StreamingVideoEncoder tests ─── + + +class TestStreamingVideoEncoder: + def test_single_camera_episode(self, tmp_path): + """Test encoding a single camera episode.""" + encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13) + + video_keys = [f"{OBS_IMAGES}.laptop"] + encoder.start_episode(video_keys, tmp_path) + + num_frames = 20 + for _ in range(num_frames): + frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8) + encoder.feed_frame(f"{OBS_IMAGES}.laptop", frame) + + results = encoder.finish_episode() + assert f"{OBS_IMAGES}.laptop" in results + + mp4_path, stats = results[f"{OBS_IMAGES}.laptop"] + assert mp4_path.exists() + assert stats is not None + + # Verify frame count + with av.open(str(mp4_path)) as container: + stream = container.streams.video[0] + total_frames = sum(1 for _ in container.decode(stream)) + assert total_frames == num_frames + + encoder.close() + + def test_multi_camera_episode(self, tmp_path): + """Test encoding multiple cameras simultaneously.""" + encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30) + + video_keys = [f"{OBS_IMAGES}.laptop", f"{OBS_IMAGES}.phone"] + encoder.start_episode(video_keys, tmp_path) + + num_frames = 15 + for _ in range(num_frames): + frame0 = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8) + frame1 = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8) + encoder.feed_frame(video_keys[0], frame0) + encoder.feed_frame(video_keys[1], frame1) + + results = encoder.finish_episode() + + for key in video_keys: + assert key in results + mp4_path, stats = results[key] + assert mp4_path.exists() + assert stats is not None + + encoder.close() + + def test_sequential_episodes(self, tmp_path): + """Test that multiple sequential episodes work correctly.""" + encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30) + video_keys = [f"{OBS_IMAGES}.cam"] + + for ep in range(3): + encoder.start_episode(video_keys, tmp_path) + num_frames = 10 + ep * 5 + for _ in range(num_frames): + frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8) + encoder.feed_frame(f"{OBS_IMAGES}.cam", frame) + results = encoder.finish_episode() + + mp4_path, stats = results[f"{OBS_IMAGES}.cam"] + assert mp4_path.exists() + + with av.open(str(mp4_path)) as container: + stream = container.streams.video[0] + total_frames = sum(1 for _ in container.decode(stream)) + assert total_frames == num_frames + + encoder.close() + + def test_cancel_episode(self, tmp_path): + """Test that canceling an episode cleans up properly.""" + encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30) + video_keys = [f"{OBS_IMAGES}.cam"] + + encoder.start_episode(video_keys, tmp_path) + + for _ in range(5): + frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8) + encoder.feed_frame(f"{OBS_IMAGES}.cam", frame) + + encoder.cancel_episode() + + # Should be able to start a new episode after cancel + encoder.start_episode(video_keys, tmp_path) + for _ in range(5): + frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8) + encoder.feed_frame(f"{OBS_IMAGES}.cam", frame) + results = encoder.finish_episode() + + assert f"{OBS_IMAGES}.cam" in results + encoder.close() + + def test_feed_without_start_raises(self, tmp_path): + """Test that feeding frames without starting an episode raises.""" + encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p") + with pytest.raises(RuntimeError, match="No active episode"): + encoder.feed_frame("cam", np.zeros((64, 96, 3), dtype=np.uint8)) + encoder.close() + + def test_finish_without_start_raises(self, tmp_path): + """Test that finishing without starting raises.""" + encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p") + with pytest.raises(RuntimeError, match="No active episode"): + encoder.finish_episode() + encoder.close() + + def test_close_is_idempotent(self, tmp_path): + """Test that close() can be called multiple times safely.""" + encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p") + encoder.close() + encoder.close() # Should not raise + + def test_video_duration_matches_frame_count(self, tmp_path): + """Test that encoded video duration matches num_frames / fps.""" + encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13) + video_keys = [f"{OBS_IMAGES}.cam"] + encoder.start_episode(video_keys, tmp_path) + + num_frames = 90 # 3 seconds at 30fps + for _ in range(num_frames): + frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8) + encoder.feed_frame(f"{OBS_IMAGES}.cam", frame) + + results = encoder.finish_episode() + mp4_path, _ = results[f"{OBS_IMAGES}.cam"] + + expected_duration = num_frames / 30.0 # 3.0 seconds + + with av.open(str(mp4_path)) as container: + stream = container.streams.video[0] + total_frames = sum(1 for _ in container.decode(stream)) + if stream.duration is not None: + actual_duration = float(stream.duration * stream.time_base) + else: + actual_duration = float(container.duration / av.time_base) + + assert total_frames == num_frames + # Allow small tolerance for duration due to codec framing + assert abs(actual_duration - expected_duration) < 0.5, ( + f"Video duration {actual_duration:.2f}s != expected {expected_duration:.2f}s" + ) + + encoder.close() + + def test_multi_camera_start_episode_called_once(self, tmp_path): + """Test that with multiple cameras, no frames are lost due to double start_episode.""" + encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30) + + video_keys = [f"{OBS_IMAGES}.cam1", f"{OBS_IMAGES}.cam2"] + encoder.start_episode(video_keys, tmp_path) + + num_frames = 30 + for _ in range(num_frames): + frame0 = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8) + frame1 = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8) + encoder.feed_frame(video_keys[0], frame0) + encoder.feed_frame(video_keys[1], frame1) + + results = encoder.finish_episode() + + # Both cameras should have all frames + for key in video_keys: + mp4_path, stats = results[key] + assert mp4_path.exists() + with av.open(str(mp4_path)) as container: + stream = container.streams.video[0] + total_frames = sum(1 for _ in container.decode(stream)) + assert total_frames == num_frames, ( + f"Camera {key}: expected {num_frames} frames, got {total_frames}" + ) + + encoder.close() + + def test_encoder_threads_passed_to_thread(self, tmp_path): + """Test that encoder_threads is stored and passed through to encoder threads.""" + encoder = StreamingVideoEncoder( + fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, encoder_threads=2 + ) + assert encoder.encoder_threads == 2 + + video_keys = [f"{OBS_IMAGES}.cam"] + encoder.start_episode(video_keys, tmp_path) + + # Verify the thread received the encoder_threads value + thread = encoder._threads[f"{OBS_IMAGES}.cam"] + assert thread.encoder_threads == 2 + + # Feed some frames and finish to ensure it works end-to-end + num_frames = 10 + for _ in range(num_frames): + frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8) + encoder.feed_frame(f"{OBS_IMAGES}.cam", frame) + + results = encoder.finish_episode() + mp4_path, stats = results[f"{OBS_IMAGES}.cam"] + assert mp4_path.exists() + assert stats is not None + + with av.open(str(mp4_path)) as container: + stream = container.streams.video[0] + total_frames = sum(1 for _ in container.decode(stream)) + assert total_frames == num_frames + + encoder.close() + + def test_encoder_threads_none_by_default(self, tmp_path): + """Test that encoder_threads defaults to None (codec auto-detect).""" + encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p") + assert encoder.encoder_threads is None + encoder.close() + + def test_graceful_frame_dropping(self, tmp_path): + """Test that full queue drops frames instead of crashing.""" + encoder = StreamingVideoEncoder( + fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13, queue_maxsize=1 + ) + video_keys = [f"{OBS_IMAGES}.cam"] + encoder.start_episode(video_keys, tmp_path) + + # Feed many frames quickly - with queue_maxsize=1, some will be dropped + num_frames = 50 + for _ in range(num_frames): + frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8) + encoder.feed_frame(f"{OBS_IMAGES}.cam", frame) + + # Should not raise - frames are dropped gracefully + results = encoder.finish_episode() + assert f"{OBS_IMAGES}.cam" in results + + mp4_path, _ = results[f"{OBS_IMAGES}.cam"] + assert mp4_path.exists() + + # Some frames should have been dropped (queue was tiny) + dropped = encoder._dropped_frames.get(f"{OBS_IMAGES}.cam", 0) + # We can't guarantee drops but can verify no crash occurred + assert dropped >= 0 + + encoder.close() + + +# ─── Integration tests with LeRobotDataset ─── + + +class TestStreamingEncoderIntegration: + def test_add_frame_save_episode_streaming(self, tmp_path): + """Full integration test: add_frame -> save_episode with streaming encoding.""" + from lerobot.datasets.lerobot_dataset import LeRobotDataset + + features = { + "observation.images.cam": { + "dtype": "video", + "shape": (64, 96, 3), + "names": ["height", "width", "channels"], + }, + "action": {"dtype": "float32", "shape": (6,), "names": ["j1", "j2", "j3", "j4", "j5", "j6"]}, + } + + dataset = LeRobotDataset.create( + repo_id="test/streaming", + fps=30, + features=features, + root=tmp_path / "streaming_test", + use_videos=True, + streaming_encoding=True, + ) + + assert dataset._streaming_encoder is not None + + num_frames = 20 + for _ in range(num_frames): + frame = { + "observation.images.cam": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8), + "action": np.random.randn(6).astype(np.float32), + "task": "test task", + } + dataset.add_frame(frame) + + dataset.save_episode() + + # Verify dataset metadata + assert dataset.meta.total_episodes == 1 + assert dataset.meta.total_frames == num_frames + + # Verify stats exist for the video key + assert dataset.meta.stats is not None + assert "observation.images.cam" in dataset.meta.stats + assert "action" in dataset.meta.stats + + dataset.finalize() + + def test_streaming_disabled_creates_pngs(self, tmp_path): + """Test that disabling streaming encoding falls back to PNG path.""" + from lerobot.datasets.lerobot_dataset import LeRobotDataset + + features = { + "observation.images.cam": { + "dtype": "video", + "shape": (64, 96, 3), + "names": ["height", "width", "channels"], + }, + "action": {"dtype": "float32", "shape": (6,), "names": ["j1", "j2", "j3", "j4", "j5", "j6"]}, + } + + dataset = LeRobotDataset.create( + repo_id="test/no_streaming", + fps=30, + features=features, + root=tmp_path / "no_streaming_test", + use_videos=True, + streaming_encoding=False, + ) + + assert dataset._streaming_encoder is None + + num_frames = 5 + for _ in range(num_frames): + frame = { + "observation.images.cam": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8), + "action": np.random.randn(6).astype(np.float32), + "task": "test task", + } + dataset.add_frame(frame) + + # With streaming disabled, PNG files should be written + images_dir = dataset.root / "images" + assert images_dir.exists() + + dataset.save_episode() + dataset.finalize() + + def test_multi_episode_streaming(self, tmp_path): + """Test recording multiple episodes with streaming encoding.""" + from lerobot.datasets.lerobot_dataset import LeRobotDataset + + features = { + "observation.images.cam": { + "dtype": "video", + "shape": (64, 96, 3), + "names": ["height", "width", "channels"], + }, + "action": {"dtype": "float32", "shape": (2,), "names": ["j1", "j2"]}, + } + + dataset = LeRobotDataset.create( + repo_id="test/multi_ep", + fps=30, + features=features, + root=tmp_path / "multi_ep_test", + use_videos=True, + streaming_encoding=True, + ) + + for ep in range(3): + num_frames = 10 + ep * 5 + for _ in range(num_frames): + frame = { + "observation.images.cam": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8), + "action": np.random.randn(2).astype(np.float32), + "task": f"task_{ep}", + } + dataset.add_frame(frame) + dataset.save_episode() + + assert dataset.meta.total_episodes == 3 + assert dataset.meta.total_frames == 10 + 15 + 20 + + dataset.finalize() + + def test_clear_episode_buffer_cancels_streaming(self, tmp_path): + """Test that clearing episode buffer cancels streaming encoding.""" + from lerobot.datasets.lerobot_dataset import LeRobotDataset + + features = { + "observation.images.cam": { + "dtype": "video", + "shape": (64, 96, 3), + "names": ["height", "width", "channels"], + }, + "action": {"dtype": "float32", "shape": (2,), "names": ["j1", "j2"]}, + } + + dataset = LeRobotDataset.create( + repo_id="test/cancel", + fps=30, + features=features, + root=tmp_path / "cancel_test", + use_videos=True, + streaming_encoding=True, + ) + + # Add some frames + for _ in range(5): + frame = { + "observation.images.cam": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8), + "action": np.random.randn(2).astype(np.float32), + "task": "task", + } + dataset.add_frame(frame) + + # Cancel and re-record + dataset.clear_episode_buffer() + + # Record a new episode + for _ in range(10): + frame = { + "observation.images.cam": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8), + "action": np.random.randn(2).astype(np.float32), + "task": "task", + } + dataset.add_frame(frame) + dataset.save_episode() + + assert dataset.meta.total_episodes == 1 + assert dataset.meta.total_frames == 10 + + dataset.finalize() + + def test_multi_camera_add_frame_streaming(self, tmp_path): + """Test that start_episode is called once with multiple video keys.""" + from lerobot.datasets.lerobot_dataset import LeRobotDataset + + features = { + "observation.images.cam1": { + "dtype": "video", + "shape": (64, 96, 3), + "names": ["height", "width", "channels"], + }, + "observation.images.cam2": { + "dtype": "video", + "shape": (64, 96, 3), + "names": ["height", "width", "channels"], + }, + "action": {"dtype": "float32", "shape": (2,), "names": ["j1", "j2"]}, + } + + dataset = LeRobotDataset.create( + repo_id="test/multi_cam", + fps=30, + features=features, + root=tmp_path / "multi_cam_test", + use_videos=True, + streaming_encoding=True, + ) + + num_frames = 15 + for _ in range(num_frames): + frame = { + "observation.images.cam1": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8), + "observation.images.cam2": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8), + "action": np.random.randn(2).astype(np.float32), + "task": "test task", + } + dataset.add_frame(frame) + + dataset.save_episode() + + assert dataset.meta.total_episodes == 1 + assert dataset.meta.total_frames == num_frames + + dataset.finalize() From a0c5d193919cccc1125c7d03f4af94013d400b9d Mon Sep 17 00:00:00 2001 From: Yueci Deng Date: Mon, 23 Feb 2026 23:32:59 +0800 Subject: [PATCH 052/131] add metadata_buffer_size to dataset creation (#2998) Signed-off-by: Steven Palma Co-authored-by: Steven Palma --- src/lerobot/datasets/lerobot_dataset.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 65b475e26..83d452a44 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -1653,6 +1653,7 @@ class LeRobotDataset(torch.utils.data.Dataset): video_backend: str | None = None, batch_encoding_size: int = 1, vcodec: str = "libsvtav1", + metadata_buffer_size: int = 10, streaming_encoding: bool = False, encoder_queue_maxsize: int = 30, encoder_threads: int | None = None, @@ -1667,6 +1668,7 @@ class LeRobotDataset(torch.utils.data.Dataset): features=features, root=root, use_videos=use_videos, + metadata_buffer_size=metadata_buffer_size, ) obj.repo_id = obj.meta.repo_id obj.root = obj.meta.root From 544cbc5f3874d20f7d0f388ef318dc4838c0906f Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 23 Feb 2026 16:39:04 +0100 Subject: [PATCH 053/131] feat(motors): add RobStride CAN implementation (#2821) * feat(motors): add initial implementation of robstride Co-authored-by: Virgile * chore(motors): solve some linter * remove kp/kd attribute * code uniformisation between damiao and robstride * remove normalization warning * remove non valid baudrates and small docstring update * remove all useless files. Only keeping robstride.py and table.py * typing for mypy * reduce NameOrId usage * align signature with damiao * put the same helper than in the damiao implementation * bug correction : expect a response after each bus.send --------- Co-authored-by: Virgile --- pyproject.toml | 4 +- src/lerobot/motors/robstride/__init__.py | 18 + src/lerobot/motors/robstride/robstride.py | 1003 +++++++++++++++++++++ src/lerobot/motors/robstride/tables.py | 120 +++ src/lerobot/scripts/lerobot_setup_can.py | 1 + 5 files changed, 1145 insertions(+), 1 deletion(-) create mode 100644 src/lerobot/motors/robstride/__init__.py create mode 100644 src/lerobot/motors/robstride/robstride.py create mode 100644 src/lerobot/motors/robstride/tables.py diff --git a/pyproject.toml b/pyproject.toml index 0ca1f0432..ea3df4a6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,11 +98,13 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"] placo-dep = ["placo>=0.9.6,<0.10.0"] transformers-dep = ["transformers>=4.57.1,<5.0.0"] grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"] +can-dep = ["python-can>=4.2.0,<5.0.0"] # Motors feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"] dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"] -damiao = ["python-can>=4.2.0,<5.0.0"] +damiao = ["lerobot[can-dep]"] +robstride = ["lerobot[can-dep]"] # Robots openarms = ["lerobot[damiao]"] diff --git a/src/lerobot/motors/robstride/__init__.py b/src/lerobot/motors/robstride/__init__.py new file mode 100644 index 000000000..7933ac6fa --- /dev/null +++ b/src/lerobot/motors/robstride/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .robstride import RobstrideMotorsBus +from .tables import * diff --git a/src/lerobot/motors/robstride/robstride.py b/src/lerobot/motors/robstride/robstride.py new file mode 100644 index 000000000..f47e41509 --- /dev/null +++ b/src/lerobot/motors/robstride/robstride.py @@ -0,0 +1,1003 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO(Virgile) : Robustify mode control , only the MIT protocole is implemented for now + +import logging +import time +from contextlib import contextmanager +from copy import deepcopy +from functools import cached_property +from types import SimpleNamespace +from typing import TYPE_CHECKING, Any, TypedDict + +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected +from lerobot.utils.import_utils import _can_available + +if TYPE_CHECKING or _can_available: + import can +else: + can = SimpleNamespace(Message=object, interface=None) +import numpy as np + +from lerobot.utils.errors import DeviceNotConnectedError +from lerobot.utils.utils import enter_pressed, move_cursor_up + +from ..motors_bus import Motor, MotorCalibration, MotorsBusBase, NameOrID, Value +from .tables import ( + AVAILABLE_BAUDRATES, + CAN_CMD_CLEAR_FAULT, + CAN_CMD_DISABLE, + CAN_CMD_ENABLE, + CAN_CMD_SET_ZERO, + DEFAULT_BAUDRATE, + DEFAULT_TIMEOUT_MS, + MODEL_RESOLUTION, + MOTOR_LIMIT_PARAMS, + NORMALIZED_DATA, + PARAM_TIMEOUT, + RUNNING_TIMEOUT, + STATE_CACHE_TTL_S, + ControlMode, + MotorType, +) + +logger = logging.getLogger(__name__) + + +class MotorState(TypedDict): + position: float + velocity: float + torque: float + temp_mos: float + temp_rotor: float + + +class RobstrideMotorsBus(MotorsBusBase): + """ + The Robstride implementation for a MotorsBus using CAN bus communication. + + This class uses python-can for CAN bus communication with Robstride motors. + The motors need to be switched to MIT control mode to be compatible with this implementation. + More details on the protocol can be found in the documentation links below: + - python-can documentation: https://python-can.readthedocs.io/en/stable/ + - Robstride CAN protocol: https://github.com/RobStride/MotorStudio + """ + + # CAN-specific settings + available_baudrates = deepcopy(AVAILABLE_BAUDRATES) + default_baudrate = DEFAULT_BAUDRATE + default_timeout = DEFAULT_TIMEOUT_MS + + # Motor configuration + model_resolution_table = deepcopy(MODEL_RESOLUTION) + normalized_data = deepcopy(NORMALIZED_DATA) + + def __init__( + self, + port: str, + motors: dict[str, Motor], + calibration: dict[str, MotorCalibration] | None = None, + can_interface: str = "auto", + use_can_fd: bool = True, + bitrate: int = 1000000, + data_bitrate: int | None = 5000000, + ): + """ + Initialize the Robstride motors bus. + + Args: + port: CAN interface name (e.g., "can0" for Linux, "/dev/cu.usbmodem*" for macOS) + motors: Dictionary mapping motor names to Motor objects + calibration: Optional calibration data + can_interface: CAN interface type - "auto" (default), "socketcan" (Linux), or "slcan" (macOS/serial) + use_can_fd: Whether to use CAN FD mode (default: True for OpenArms) + bitrate: Nominal bitrate in bps (default: 1000000 = 1 Mbps) + data_bitrate: Data bitrate for CAN FD in bps (default: 5000000 = 5 Mbps), ignored if use_can_fd is False + """ + super().__init__(port, motors, calibration) + self.port = port + self.can_interface = can_interface + self.use_can_fd = use_can_fd + self.bitrate = bitrate + self.data_bitrate = data_bitrate + self.canbus: can.BusABC | None = None + self._is_connected = False + + # Map motor names to CAN IDs + self._motor_can_ids: dict[str, int] = {} + self._recv_id_to_motor: dict[int, str] = {} + + # Store motor types and recv IDs + self._motor_types: dict[str, MotorType] = {} + # Dynamic gains storage (Damiao-style update path via write/sync_write) + self._gains: dict[str, dict[str, float]] = {} + for name, motor in self.motors.items(): + if motor.motor_type_str is not None: + self._motor_types[name] = getattr(MotorType, motor.motor_type_str.upper()) + else: + # Default to O0if not specified + self._motor_types[name] = MotorType.O0 + + # Damiao-style defaults: fixed gains at startup for every motor. + self._gains[name] = {"kp": 10.0, "kd": 0.5} + + # Map recv_id to motor name for filtering responses + if motor.recv_id is not None: + self._recv_id_to_motor[motor.recv_id] = name + # Motor Mode + self.enabled: dict[str, bool] = {} + self.operation_mode: dict[str, ControlMode] = {} + self._last_known_states: dict[str, MotorState] = { + name: { + "position": 0.0, + "velocity": 0.0, + "torque": 0.0, + "temp_mos": 0.0, + "temp_rotor": 0.0, + } + for name in self.motors + } + self.last_feedback_time: dict[str, float | None] = {} + self._id_to_name: dict[int, str] = {} + for name in self.motors: + self.enabled[name] = False + self.operation_mode[name] = ControlMode.MIT # default mode + self.last_feedback_time[name] = None + + for name, motor in self.motors.items(): + key = motor.recv_id if motor.recv_id is not None else motor.id + self._id_to_name[key] = name + + @property + def is_connected(self) -> bool: + """Check if the CAN bus is connected.""" + return self._is_connected and self.canbus is not None + + def _bus(self) -> can.BusABC: + if self.canbus is None: + raise DeviceNotConnectedError(f"{self.__class__.__name__}('{self.port}') is not connected.") + return self.canbus + + @check_if_already_connected + def connect(self, handshake: bool = True) -> None: + """ + Open the CAN bus and initialize communication. + + Args: + handshake: If True, ping all motors to verify they're present + """ + try: + # Auto-detect interface type based on port name + if self.can_interface == "auto": + if self.port.startswith("/dev/"): + self.can_interface = "slcan" + logger.info(f"Auto-detected slcan interface for port {self.port}") + else: + self.can_interface = "socketcan" + logger.info(f"Auto-detected socketcan interface for port {self.port}") + + kwargs = { + "channel": self.port, + "bitrate": self.bitrate, + "interface": self.can_interface, + } + + if self.can_interface == "socketcan" and self.use_can_fd and self.data_bitrate is not None: + kwargs.update({"data_bitrate": self.data_bitrate, "fd": True}) + logger.info( + f"Connected to {self.port} with CAN FD (bitrate={self.bitrate}, data_bitrate={self.data_bitrate})" + ) + else: + logger.info(f"Connected to {self.port} with {self.can_interface} (bitrate={self.bitrate})") + + self.canbus = can.interface.Bus(**kwargs) + + self._is_connected = True + + if handshake: + self._handshake() + + logger.debug(f"{self.__class__.__name__} connected via {self.can_interface}.") + except Exception as e: + self._is_connected = False + raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e + + def _query_status_via_clear_fault(self, motor: NameOrID) -> tuple[bool, can.Message | None]: + motor_name = self._get_motor_name(motor) + motor_id = self._get_motor_id(motor_name) + recv_id = self._get_motor_recv_id(motor_name) + data = [0xFF] * 7 + [CAN_CMD_CLEAR_FAULT] + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self._bus().send(msg) + return self._recv_status_via_clear_fault(expected_recv_id=recv_id) + + def _recv_status_via_clear_fault( + self, expected_recv_id: int | None = None, timeout: float = RUNNING_TIMEOUT + ) -> tuple[bool, can.Message | None]: + """ + Poll the bus for a response to a fault-clear request. + + Args: + expected_recv_id: Only accept frames from this CAN ID when provided. + timeout: Maximum time spent polling the bus in seconds. + + Returns: + Tuple where the first element is True if a fault frame was received, + and the second element is the CAN message (or None on timeout). + """ + start_time = time.time() + + while time.time() - start_time < timeout: + msg = self._bus().recv(timeout=RUNNING_TIMEOUT / 10) + if not msg: + continue + + if expected_recv_id is not None and msg.data[0] != expected_recv_id: + continue + + # Fault-status frame heuristic (doc-based) + fault_bits = int.from_bytes(msg.data[1:5], "little") + if fault_bits != 0 and msg.data[5] == msg.data[6] == msg.data[7] == 0: + logger.error( + f"Motor fault received from CAN ID 0x{msg.arbitration_id:02X}: " + f"fault_bits=0x{fault_bits:08X}" + ) + return True, msg + + # Otherwise: valid normal response + return False, msg + + return False, None + + def update_motor_state(self, motor: NameOrID) -> bool: + has_fault, msg = self._query_status_via_clear_fault(motor) + if msg is None: + logger.warning(f"No response received from motor '{motor}' during state update.") + raise ConnectionError(f"No response received from motor '{motor}' during state update.") + if has_fault: + logger.error(f"Fault reported by motor '{motor}' during state update. msg={msg.data.hex()}") + raise RuntimeError(f"Fault reported by motor '{motor}' during state update.") + + self._decode_motor_state(msg.data) # updates cache + return True + + def _handshake(self) -> None: + logger.info("Starting handshake with motors...") + missing_motors = [] + faulted_motors = [] + + for motor_name in self.motors: + has_fault, msg = self._query_status_via_clear_fault(motor_name) + if msg is None: + missing_motors.append(motor_name) + elif has_fault: + faulted_motors.append(motor_name) + else: + # CLEAR_FAULT responses are not guaranteed to always match the MIT feedback layout + # on all firmware versions. Handshake should not fail just because cache warm-up fails. + try: + self._decode_motor_state(msg.data) + except Exception as e: + logger.debug( + "Handshake cache warm-up decode failed for motor '%s': %s", + motor_name, + e, + ) + time.sleep(0.01) + + if missing_motors or faulted_motors: + details = [] + if missing_motors: + details.append(f"did not respond: {missing_motors}") + if faulted_motors: + details.append(f"reported fault: {faulted_motors}") + raise ConnectionError("Handshake failed. " + "; ".join(details)) + + logger.info("Handshake successful. All motors ready.") + + def _switch_operation_mode(self, motor: NameOrID, mode: ControlMode) -> None: + """Switch the operation mode of a motor.""" + motor_name = self._get_motor_name(motor) + motor_id = self._get_motor_id(motor_name) + recv_id = self._get_motor_recv_id(motor_name) + data = [0xFF] * 8 + data[6] = mode.value + data[7] = 0xFC + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self._bus().send(msg) + msg = self._recv_motor_response(expected_recv_id=recv_id, timeout=PARAM_TIMEOUT) + if msg is not None: + self.operation_mode[motor_name] = mode + + @check_if_not_connected + def disconnect(self, disable_torque: bool = True) -> None: + """ + Close the CAN bus connection. + + Args: + disable_torque: If True, disable torque on all motors before disconnecting + """ + if disable_torque: + try: + self.disable_torque() + except Exception as e: + logger.warning(f"Failed to disable torque during disconnect: {e}") + + if self.canbus: + self.canbus.shutdown() + self.canbus = None + self._is_connected = False + logger.debug(f"{self.__class__.__name__} disconnected.") + + def configure_motors(self) -> None: + """Configure all motors with default settings.""" + # Robstride motors don't require much configuration in MIT mode + # Just ensure they're enabled + for motor in self.motors: + self._enable_motor(self._get_motor_name(motor)) + self._switch_operation_mode(motor, ControlMode.MIT) + time.sleep(0.01) + + def switch_to_mode(self, mode: ControlMode) -> None: + """Switch operation mode on selected motors.""" + for motor in self.motors: + self._switch_operation_mode(motor, mode) + time.sleep(0.01) + + def _enable_motor(self, motor: NameOrID) -> None: + """Enable a single motor.""" + motor_id = self._get_motor_id(motor) + recv_id = self._get_motor_recv_id(motor) + data = [0xFF] * 7 + [CAN_CMD_ENABLE] + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self._bus().send(msg) + self._recv_motor_response(expected_recv_id=recv_id, timeout=PARAM_TIMEOUT) + + def _disable_motor(self, motor: NameOrID) -> None: + """Disable a single motor.""" + motor_id = self._get_motor_id(motor) + recv_id = self._get_motor_recv_id(motor) + data = [0xFF] * 7 + [CAN_CMD_DISABLE] + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self._bus().send(msg) + self._recv_motor_response(expected_recv_id=recv_id) + + def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + """Enable torque on selected motors.""" + motors = self._get_motors_list(motors) + for motor in motors: + for _ in range(num_retry + 1): + try: + self._get_motor_name(motor) + self._enable_motor(self._get_motor_name(motor)) + break + except Exception as e: + if _ == num_retry: + raise e + time.sleep(0.01) + + def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + """Disable torque on selected motors.""" + motors = self._get_motors_list(motors) + for motor in motors: + for _ in range(num_retry + 1): + try: + self._disable_motor(self._get_motor_name(motor)) + break + except Exception as e: + if _ == num_retry: + raise e + time.sleep(0.01) + + @contextmanager + def torque_disabled(self, motors: str | list[str] | None = None): + """ + Context manager that guarantees torque is re-enabled. + + This helper is useful to temporarily disable torque when configuring motors. + + Examples: + >>> with bus.torque_disabled(): + ... # Safe operations here with torque disabled + ... pass + """ + self.disable_torque(motors) + try: + yield + finally: + self.enable_torque(motors) + + def set_zero_position(self, motors: str | list[str] | None = None) -> None: + """Set current position as zero for selected motors.""" + motors = self._get_motors_list(motors) + for motor in motors: + motor_id = self._get_motor_id(motor) + recv_id = self._get_motor_recv_id(motor) + data = [0xFF] * 7 + [CAN_CMD_SET_ZERO] + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self._bus().send(msg) + self._recv_motor_response(expected_recv_id=recv_id) + time.sleep(0.01) + + def _recv_motor_response( + self, expected_recv_id: int | None = None, timeout: float = 0.001 + ) -> can.Message | None: + """ + Receive a response from a motor. + + Args: + expected_recv_id: If provided, only return messages from this CAN ID + timeout: Timeout in seconds (default: 1ms for high-speed operation) + + Returns: + CAN message if received, None otherwise + """ + try: + start_time = time.time() + messages_seen = [] + while time.time() - start_time < timeout: + msg = self._bus().recv(timeout=RUNNING_TIMEOUT / 10) # 100us timeout for fast polling + if msg: + messages_seen.append(f"0x{msg.arbitration_id:02X}") + # If no filter specified, return any message + if expected_recv_id is None: + return msg + # Otherwise, only return if it matches the expected recv_id + if msg.data[0] == expected_recv_id: + return msg + else: + logger.debug( + f"Ignoring message from CAN ID 0x{msg.arbitration_id:02X}, expected 0x{expected_recv_id:02X}" + ) + + # Only log warnings if we're in debug mode to reduce overhead + if logger.isEnabledFor(logging.DEBUG): + if messages_seen: + logger.debug( + f"Received {len(messages_seen)} message(s) from IDs {set(messages_seen)}, but expected 0x{expected_recv_id:02X}" + ) + else: + logger.debug(f"No CAN messages received (expected from 0x{expected_recv_id:02X})") + except Exception as e: + logger.debug(f"Failed to receive CAN message: {e}") + return None + + def _recv_all_responses( + self, expected_recv_ids: list[int], timeout: float = 0.002 + ) -> dict[int, can.Message]: + """ + Efficiently receive responses from multiple motors at once. + Uses the OpenArms pattern: collect all available messages within timeout. + + Args: + expected_recv_ids: List of CAN IDs we expect responses from + timeout: Total timeout in seconds (default: 2ms) + + Returns: + Dictionary mapping recv_id to CAN message + """ + responses: dict[int, can.Message] = {} + expected_set = set(expected_recv_ids) + start_time = time.time() + + try: + while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout: + msg = self._bus().recv(timeout=RUNNING_TIMEOUT / 10) # 100us poll timeout + if msg and msg.data[0] in expected_set: + responses[msg.data[0]] = msg + if len(responses) == len(expected_recv_ids): + break # Got all responses, exit early + except Exception as e: + logger.debug(f"Error receiving responses: {e}") + + return responses + + def _speed_control( + self, + motor: NameOrID, + velocity_deg_per_sec: float, + current_limit_a: float, + ) -> None: + """ + Send a Velocity Mode Control Command (Command 11) to a single motor. + + Args: + motor: Motor name or CAN ID. + velocity_rad_per_sec: Target speed in rad/s (32-bit float). + current_limit_a: Current limit in A (32-bit float). + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + motor_id = self._get_motor_id(motor) + motor_name = self._get_motor_name(motor) + # Optional: ensure the motor is in velocity control mode + + if self.operation_mode[motor_name] != ControlMode.VEL: + raise RuntimeError(f"Motor '{motor_name}' is not in velocity control mode.") + # Convert to rad/s to match protocol specification + + velocity_rad_per_sec = np.radians(velocity_deg_per_sec) + + # Encode float32 little-endian without struct (byte list) + def _float32_to_le_bytes(x: float) -> list[int]: + b = np.float32(x).tobytes() # 4 bytes, little-endian + return [b[0], b[1], b[2], b[3]] + + speed_bytes = _float32_to_le_bytes(velocity_rad_per_sec) + limit_bytes = _float32_to_le_bytes(current_limit_a) + + data = speed_bytes + limit_bytes # 8 octets : [0–3]=speed, [4–7]=current limit + + msg = can.Message( + arbitration_id=motor_id, + data=data, + is_extended_id=False, + ) + self._bus().send(msg) + + # Si le proto renvoie une réponse type état, on peut la décoder comme pour MIT + recv_id = self._get_motor_recv_id(motor) + if recv_id is not None: + resp = self._recv_motor_response(expected_recv_id=recv_id) + if resp: + self._decode_motor_state(resp.data) + + def _mit_control( + self, + motor: NameOrID, + kp: float, + kd: float, + position_degrees: float, + velocity_deg_per_sec: float, + torque: float, + *, + wait_for_response: bool = True, + ) -> None: + """ + Send MIT control command to a motor. + + Args: + motor: Motor name or ID + kp: Position gain + kd: Velocity gain + position_degrees: Target position (degrees) + velocity_deg_per_sec: Target velocity (degrees/s) + torque: Target torque (N·m) + """ + motor_name = self._get_motor_name(motor) + motor_type = self._motor_types[motor_name] + if self.operation_mode[motor_name] != ControlMode.MIT: + raise RuntimeError(f"Motor '{motor_name}' is not in MIT control mode.") + motor_id = self._get_motor_id(motor) + data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque) + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self._bus().send(msg) + + if wait_for_response: + recv_id = self._get_motor_recv_id(motor) + msg = self._recv_motor_response(expected_recv_id=recv_id) + if msg: + self._process_response(motor_name, msg) + + def _encode_mit_packet( + self, + motor_type: MotorType, + kp: float, + kd: float, + position_degrees: float, + velocity_deg_per_sec: float, + torque: float, + ) -> list[int]: + """Encode an MIT control command payload from physical units.""" + position_rad = np.radians(position_degrees) + velocity_rad_per_sec = np.radians(velocity_deg_per_sec) + pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type] + + kp_uint = self._float_to_uint(kp, 0, 500, 12) + kd_uint = self._float_to_uint(kd, 0, 5, 12) + q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16) + dq_uint = self._float_to_uint(velocity_rad_per_sec, -vmax, vmax, 12) + tau_uint = self._float_to_uint(torque, -tmax, tmax, 12) + + data = [0] * 8 + data[0] = (q_uint >> 8) & 0xFF + data[1] = q_uint & 0xFF + data[2] = dq_uint >> 4 + data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF) + data[4] = kp_uint & 0xFF + data[5] = kd_uint >> 4 + data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF) + data[7] = tau_uint & 0xFF + return data + + def _mit_control_batch( + self, + commands: dict[NameOrID, tuple[float, float, float, float, float]], + ) -> None: + """Send MIT commands in batch and update cache from collected responses.""" + if not commands: + return + + recv_id_to_motor: dict[int, str] = {} + for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items(): + motor_name = self._get_motor_name(motor) + if self.operation_mode[motor_name] != ControlMode.MIT: + raise RuntimeError(f"Motor '{motor_name}' is not in MIT control mode.") + + motor_id = self._get_motor_id(motor) + motor_type = self._motor_types[motor_name] + data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque) + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self._bus().send(msg) + recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name + + responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=RUNNING_TIMEOUT) + for recv_id, motor_name in recv_id_to_motor.items(): + if msg := responses.get(recv_id): + self._process_response(motor_name, msg) + + def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int: + """Convert float to unsigned integer for CAN transmission.""" + x = max(x_min, min(x_max, x)) # Clamp to range + span = x_max - x_min + data_norm = (x - x_min) / span + return int(data_norm * ((1 << bits) - 1)) + + def _uint_to_float(self, x: int, x_min: float, x_max: float, bits: int) -> float: + """Convert unsigned integer from CAN to float.""" + span = x_max - x_min + data_norm = float(x) / ((1 << bits) - 1) + return data_norm * span + x_min + + def _decode_motor_state(self, data: bytearray | bytes) -> tuple[float, float, float, float]: + """ + Decode motor state from CAN data. + + Returns: + Tuple of (position_degrees, velocity_deg_per_sec, torque, temp_mos) + """ + if len(data) < 8: + raise ValueError("Invalid motor state data") + + # Extract encoded values + motor_id = data[0] + motor_name = self._id_to_name[motor_id] + q_uint = (data[1] << 8) | data[2] + dq_uint = (data[3] << 4) | (data[4] >> 4) + tau_uint = ((data[4] & 0x0F) << 8) | data[5] + t_mos = (data[6] << 8) | data[7] + + motor_type = self._motor_types[motor_name] + # Get motor limits + pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type] + + # Decode to physical values (radians) + position_rad = self._uint_to_float(q_uint, -pmax, pmax, 16) + velocity_rad_per_sec = self._uint_to_float(dq_uint, -vmax, vmax, 12) + torque = self._uint_to_float(tau_uint, -tmax, tmax, 12) + + # Convert to degrees + position_degrees = np.degrees(position_rad) + velocity_deg_per_sec = np.degrees(velocity_rad_per_sec) + + # Update cached state + self.last_feedback_time[motor_name] = time.time() + self._last_known_states[motor_name] = { + "position": position_degrees, + "velocity": velocity_deg_per_sec, + "torque": torque, + "temp_mos": t_mos / 10, + # Not available in Robstride MIT feedback. + "temp_rotor": 0.0, + } + return position_degrees, velocity_deg_per_sec, torque, t_mos / 10 + + def _process_response(self, motor: str, msg: can.Message) -> None: + """Decode a feedback frame and update the cache for one motor.""" + try: + self._decode_motor_state(msg.data) + except Exception as e: + logger.warning(f"Failed to decode response from {motor}: {e}") + + def _get_cached_value(self, motor: str, data_name: str) -> Value: + """Retrieve a specific value from the state cache.""" + state = self._last_known_states[motor] + mapping: dict[str, Any] = { + "Present_Position": state["position"], + "Present_Velocity": state["velocity"], + "Present_Torque": state["torque"], + "Temperature_MOS": state["temp_mos"], + } + if data_name == "Temperature_Rotor": + raise NotImplementedError("Rotor temperature reading not accessible.") + if data_name not in mapping: + raise ValueError(f"Unknown data_name: {data_name}") + return mapping[data_name] + + @check_if_not_connected + def read( + self, + data_name: str, + motor: str, + ) -> Value: + """Read a value from a single motor. Positions are always in degrees.""" + + # Refresh motor to get latest state + t_init = time.time() + if ( + self.last_feedback_time[motor] is None + or t_init - (self.last_feedback_time[motor] or 0) > STATE_CACHE_TTL_S + ): + self.update_motor_state(motor) + + return self._get_cached_value(motor, data_name) + + @check_if_not_connected + def write( + self, + data_name: str, + motor: str, + value: Value, + ) -> None: + """Write a value to a single motor. Positions are always in degrees.""" + motor_name = self._get_motor_name(motor) + + if data_name in ("Kp", "Kd"): + self._gains[motor_name][data_name.lower()] = float(value) + elif data_name == "Goal_Position": + # Use MIT control with position in degrees + kp = self._gains[motor_name]["kp"] + kd = self._gains[motor_name]["kd"] + self._mit_control(motor, kp, kd, value, 0, 0) + elif data_name == "Goal_Velocity": + # Use Velocity control mode + if self.operation_mode[motor_name] != ControlMode.VEL: + raise RuntimeError(f"Motor '{motor_name}' is not in velocity control mode.") + current_limit_a = 5.0 # Example current limit / not specified in doc. This mode is rarely used and primarily intended for diagnostics + self._speed_control(motor, value, current_limit_a) + else: + raise ValueError(f"Writing {data_name} not supported in MIT mode") + + def sync_read( + self, + data_name: str, + motors: str | list[str] | None = None, + ) -> dict[str, Value]: + """ + Read the same value from multiple motors simultaneously. + Uses batched operations: sends all refresh commands, then collects all responses. + This is MUCH faster than sequential reads (OpenArms pattern). + """ + target_motors = self._get_motors_list(motors) + self._batch_refresh(target_motors) + return {motor: self._get_cached_value(motor, data_name) for motor in target_motors} + + @check_if_not_connected + def sync_write( + self, + data_name: str, + values: dict[str, Value], + ) -> None: + """ + Write different values to multiple motors simultaneously. Positions are always in degrees. + Uses batched operations: sends all commands first, then collects responses when MIT mode is used, otherwise send cmd and wait for response for each motor). + """ + if data_name in ("Kp", "Kd"): + key = data_name.lower() + for motor, val in values.items(): + motor_name = self._get_motor_name(motor) + self._gains[motor_name][key] = float(val) + elif data_name == "Goal_Position": + commands: dict[NameOrID, tuple[float, float, float, float, float]] = {} + for motor, value_degrees in values.items(): + motor_name = self._get_motor_name(motor) + commands[motor] = ( + self._gains[motor_name]["kp"], + self._gains[motor_name]["kd"], + float(value_degrees), + 0.0, + 0.0, + ) + self._mit_control_batch(commands) + else: + # Fall back to individual writes for other data types + for motor, value in values.items(): + self.write(data_name, motor, value) + + def sync_read_all_states( + self, + motors: str | list[str] | None = None, + *, + num_retry: int = 0, + ) -> dict[str, MotorState]: + """ + Read ALL motor states (position, velocity, torque) with Robstride TTL refresh policy. + """ + target_motors = self._get_motors_list(motors) + self._batch_refresh(target_motors) + return {motor: self._last_known_states[motor].copy() for motor in target_motors} + + def _batch_refresh(self, motors: list[str]) -> None: + """Refresh a set of motors and update the feedback cache.""" + init_time = time.time() + updated_motors: list[str] = [] + + for motor in motors: + if ( + self.last_feedback_time[motor] is not None + and (init_time - (self.last_feedback_time[motor] or 0)) < STATE_CACHE_TTL_S + ): + continue + motor_id = self._get_motor_id(motor) + data = [0xFF] * 7 + [CAN_CMD_CLEAR_FAULT] + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self._bus().send(msg) + updated_motors.append(motor) + + expected_recv_ids = [self._get_motor_recv_id(motor) for motor in updated_motors] + responses = self._recv_all_responses(expected_recv_ids, timeout=RUNNING_TIMEOUT) + + for response in responses.values(): + payload_motor_name = self._recv_id_to_motor.get(response.data[0]) + if payload_motor_name is not None: + self._process_response(payload_motor_name, response) + else: + # Fallback: still attempt to decode based on payload byte0 mapping. + self._decode_motor_state(response.data) + + for motor in updated_motors: + recv_id = self._get_motor_recv_id(motor) + if recv_id not in responses: + logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.") + + def read_calibration(self) -> dict[str, MotorCalibration]: + """Read calibration data from motors.""" + # Robstride motors don't store calibration internally + # Return existing calibration or empty dict + return self.calibration if self.calibration else {} + + def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None: + """Write calibration data to motors.""" + # Robstride motors don't store calibration internally + # Just cache it in memory + if cache: + self.calibration = calibration_dict + + def record_ranges_of_motion( + self, motors: str | list[str] | None = None, display_values: bool = True + ) -> tuple[dict[str, Value], dict[str, Value]]: + """ + Interactively record the min/max values of each motor in degrees. + + Move the joints by hand (with torque disabled) while the method streams live positions. + Press Enter to finish. + """ + target_motors = self._get_motors_list(motors) + + # Disable torque for manual movement + self.disable_torque(target_motors) + time.sleep(0.1) + + # Get initial positions (already in degrees) + start_positions = self.sync_read("Present_Position", target_motors) + mins = start_positions.copy() + maxes = start_positions.copy() + + print("\nMove joints through their full range of motion. Press ENTER when done.") + user_pressed_enter = False + + while not user_pressed_enter: + positions = self.sync_read("Present_Position", target_motors) + + for motor in target_motors: + if motor in positions: + mins[motor] = int( + min( + positions[motor], + mins.get(motor, positions[motor]), + ) + ) + maxes[motor] = int( + max( + positions[motor], + maxes.get(motor, positions[motor]), + ) + ) + + if display_values: + print("\n" + "=" * 50) + print(f"{'MOTOR':<20} | {'MIN (deg)':>12} | {'POS (deg)':>12} | {'MAX (deg)':>12}") + print("-" * 50) + for motor in target_motors: + if motor in positions: + print( + f"{motor:<20} | {mins[motor]:>12.1f} | {positions[motor]:>12.1f} | {maxes[motor]:>12.1f}" + ) + + if enter_pressed(): + user_pressed_enter = True + + if display_values and not user_pressed_enter: + # Move cursor up to overwrite the previous output + move_cursor_up(len(target_motors) + 4) + + time.sleep(0.05) + + # Re-enable torque + self.enable_torque(target_motors) + + # Validate ranges + for motor in target_motors: + if (motor in mins) and (motor in maxes) and (abs(maxes[motor] - mins[motor]) < 5.0): + raise ValueError(f"Motor {motor} has insufficient range of motion (< 5 degrees)") + + return mins, maxes + + def _get_motors_list(self, motors: str | list[str] | None) -> list[str]: + """Convert motor specification to list of motor names.""" + if motors is None: + return list(self.motors.keys()) + elif isinstance(motors, str): + return [motors] + elif isinstance(motors, list): + return motors + else: + raise TypeError(f"Invalid motors type: {type(motors)}") + + def _get_motor_id(self, motor: NameOrID) -> int: + """Get CAN ID for a motor.""" + if isinstance(motor, str): + if motor in self.motors: + return self.motors[motor].id + else: + raise ValueError(f"Unknown motor: {motor}") + else: + return motor + + def _get_motor_name(self, motor: NameOrID) -> str: + """Get motor name from name or ID.""" + if isinstance(motor, str): + return motor + else: + for name, m in self.motors.items(): + if m.id == motor: + return name + raise ValueError(f"Unknown motor ID: {motor}") + + def _get_motor_recv_id(self, motor: NameOrID) -> int: + """Return the expected ID found in feedback payload byte0 for this motor. + + Robstride MIT feedback frames encode an ID in data[0]. Some setups expose it as + `motor.recv_id`; otherwise we fall back to the configured `motor.id`. + """ + motor_name = self._get_motor_name(motor) + motor_obj = self.motors[motor_name] + + recv_id = getattr(motor_obj, "recv_id", None) + if recv_id is None: + logger.debug( + "Motor '%s' has no recv_id; falling back to motor.id=%s for feedback demux.", + motor_name, + motor_obj.id, + ) + return motor_obj.id + + return recv_id + + @cached_property + def is_calibrated(self) -> bool: + """Check if motors are calibrated.""" + return bool(self.calibration) diff --git a/src/lerobot/motors/robstride/tables.py b/src/lerobot/motors/robstride/tables.py new file mode 100644 index 000000000..2fc1a97b0 --- /dev/null +++ b/src/lerobot/motors/robstride/tables.py @@ -0,0 +1,120 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration tables for Damiao motors.""" + +from enum import IntEnum + + +# Motor type definitions +class MotorType(IntEnum): + O0 = 0 + O1 = 1 + O2 = 2 + O3 = 3 + O4 = 4 + O5 = 5 + ELO5 = 6 + O6 = 7 + + +class CommMode(IntEnum): + PrivateProtocole = 0 + CANopen = 1 + MIT = 2 + + +# Control modes +class ControlMode(IntEnum): + MIT = 0 + POS_VEL = 1 + VEL = 2 + + +# Motor limit parameters [PMAX, VMAX, TMAX] +# PMAX: Maximum position (rad) +# VMAX: Maximum velocity (rad/s) +# TMAX: Maximum torque (N·m) +MOTOR_LIMIT_PARAMS: dict[MotorType, tuple[float, float, float]] = { + MotorType.O0: (12.57, 33, 14), + MotorType.O1: (12.57, 44, 17), + MotorType.O2: (12.57, 33, 20), + MotorType.O3: (12.57, 33, 60), + MotorType.O4: (12.57, 33, 120), + MotorType.O5: (12.57, 50, 5.5), + MotorType.ELO5: (12.57, 50, 6), + MotorType.O6: (112.5, 50, 36), +} + +# Motor model names +MODEL_NAMES = { + MotorType.O0: "O0", + MotorType.O1: "O1", + MotorType.O2: "O2", + MotorType.O3: "O3", + MotorType.O4: "O4", + MotorType.O5: "O5", + MotorType.ELO5: "ELO5", + MotorType.O6: "O6", +} + +# Motor resolution table (encoder counts per revolution) +MODEL_RESOLUTION = { + "O0": 65536, + "O1": 65536, + "O2": 65536, + "O3": 65536, + "O4": 65536, + "O5": 65536, + "ELO5": 65536, + "O6": 65536, +} + +# CAN baudrates supported by Robstride motors +AVAILABLE_BAUDRATES = [ + 1000000, # 4: 1 mbps (default) +] +DEFAULT_BAUDRATE = 1000000 + +# Default timeout in milliseconds +DEFAULT_TIMEOUT_MS = 0 # disabled by default, otherwise 20000 is 1s + + +# Data that should be normalized +NORMALIZED_DATA = ["Present_Position", "Goal_Position"] + + +# MIT control parameter ranges +MIT_KP_RANGE = (0.0, 500.0) +MIT_KD_RANGE = (0.0, 5.0) + +# CAN frame command IDs +CAN_CMD_ENABLE = 0xFC +CAN_CMD_DISABLE = 0xFD +CAN_CMD_SET_ZERO = 0xFE +CAN_CMD_CLEAR_FAULT = 0xFB + + +CAN_CMD_QUERY_PARAM = 0x33 +CAN_CMD_WRITE_PARAM = 0x55 +CAN_CMD_SAVE_PARAM = 0xAA + +# CAN ID for parameter operations +CAN_PARAM_ID = 0x7FF + + +RUNNING_TIMEOUT = 0.001 +PARAM_TIMEOUT = 0.01 + +STATE_CACHE_TTL_S = 0.02 diff --git a/src/lerobot/scripts/lerobot_setup_can.py b/src/lerobot/scripts/lerobot_setup_can.py index a31727ea4..b28fca44d 100644 --- a/src/lerobot/scripts/lerobot_setup_can.py +++ b/src/lerobot/scripts/lerobot_setup_can.py @@ -152,6 +152,7 @@ def test_motor(bus, motor_id: int, timeout: float, use_fd: bool): ) try: bus.send(disable_msg) + bus.recv(timeout=0.1) # Clear any pending responses except Exception: print(f"Error sending message to motor 0x{motor_id:02X}") From fcabfd32a5a45f0a8210179a6cc8f605815f0dab Mon Sep 17 00:00:00 2001 From: Yuta Nakagawa Date: Tue, 24 Feb 2026 01:11:46 +0900 Subject: [PATCH 054/131] chore(docs): update the document for Phone teleop to clarify how to use the examples (#2991) * update the document for Phone teleope to clarify how to use the examples * Update docs/source/phone_teleop.mdx Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Yuta Nakagawa --------- Signed-off-by: Yuta Nakagawa Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Steven Palma --- docs/source/phone_teleop.mdx | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/docs/source/phone_teleop.mdx b/docs/source/phone_teleop.mdx index 06e524975..678783e7b 100644 --- a/docs/source/phone_teleop.mdx +++ b/docs/source/phone_teleop.mdx @@ -66,12 +66,13 @@ Run on of the examples scripts to teleoperate, record a dataset, replay a datase All scripts assume you configured your robot (e.g., SO-100 follower) and set the correct serial port. -Additionally you need to **copy the urdf of the robot to the examples folder**. For the examples in this tutorial (Using SO100/SO101) it is highly recommended to use the urdf in the [SO-ARM100 repo](https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf) +Additionally you need to **copy the URDF of the robot into the examples folder**. For the examples in this tutorial (using SO100/SO101), copy the `SO101` folder from the [SO-ARM100 repo](https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101) into the `examples/phone_to_so100/` directory, so that the URDF file path becomes `examples/phone_to_so100/SO101/so101_new_calib.urdf`. - Run this example to teleoperate: ```bash - python examples/phone_to_so100/teleoperate.py + cd examples/phone_to_so100 + python teleoperate.py ``` After running the example: @@ -84,19 +85,22 @@ Additionally you can customize mapping or safety limits by editing the processor - Run this example to record a dataset, which saves absolute end effector observations and actions: ```bash - python examples/phone_to_so100/record.py + cd examples/phone_to_so100 + python record.py ``` - Run this example to replay recorded episodes: ```bash - python examples/phone_to_so100/replay.py + cd examples/phone_to_so100 + python replay.py ``` - Run this example to evaluate a pretrained policy: ```bash - python examples/phone_to_so100/evaluate.py + cd examples/phone_to_so100 + python evaluate.py ``` ### Important pipeline steps and options From 7dbbaa3727e8866ce6649e895d152f129a0c6d32 Mon Sep 17 00:00:00 2001 From: Guilherme Miotto Date: Mon, 23 Feb 2026 17:11:55 +0100 Subject: [PATCH 055/131] Small comment fix (#2990) Co-authored-by: Steven Palma --- src/lerobot/policies/smolvla/configuration_smolvla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/policies/smolvla/configuration_smolvla.py b/src/lerobot/policies/smolvla/configuration_smolvla.py index c32c8a60e..c696265f2 100644 --- a/src/lerobot/policies/smolvla/configuration_smolvla.py +++ b/src/lerobot/policies/smolvla/configuration_smolvla.py @@ -85,7 +85,7 @@ class SmolVLAConfig(PreTrainedConfig): scheduler_decay_lr: float = 2.5e-6 vlm_model_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct" # Select the VLM backbone. - load_vlm_weights: bool = False # Set to True in case of training the expert from scratch. True when init from pretrained SmolVLA weights + load_vlm_weights: bool = False # Set to False in case of training the expert from scratch. True when init from pretrained SmolVLA weights add_image_special_tokens: bool = False # Whether to use special image tokens around image features. From 0f44adbeecffca72e9b0f41c58b8222402ff4613 Mon Sep 17 00:00:00 2001 From: Yuan Haokuan <138340416+WilbertYuan@users.noreply.github.com> Date: Tue, 24 Feb 2026 00:51:13 +0800 Subject: [PATCH 056/131] docs: fix HF_USER export command to correctly parse username (#2932) * Fix HF_USER extraction command in documentation Updated command to extract the username from hf auth output. Signed-off-by: Yuan Haokuan <138340416+WilbertYuan@users.noreply.github.com> * Correct HF_USER variable assignment in documentation Fix the variable extraction from hf auth output. Signed-off-by: Yuan Haokuan <138340416+WilbertYuan@users.noreply.github.com> * Update docs/source/il_robots.mdx Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Yuan Haokuan <138340416+WilbertYuan@users.noreply.github.com> --------- Signed-off-by: Yuan Haokuan <138340416+WilbertYuan@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Steven Palma --- docs/source/il_robots.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index 7fc770b0c..bad88f88e 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -165,7 +165,7 @@ huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential Then store your Hugging Face repository name in a variable: ```bash -HF_USER=$(hf auth whoami | head -n 1) +HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}') echo $HF_USER ``` From 7fd71c83a3c5ba496b6a19aad96a048b7c42410f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dominik=20Pa=C4=BEo?= Date: Mon, 23 Feb 2026 20:41:20 +0100 Subject: [PATCH 057/131] docs: add WSL evdev installation note (#2855) Add a note in the installation guide explaining that users on WSL need to install evdev to avoid build issues. See: https://github.com/huggingface/lerobot/issues/2528 Signed-off-by: Steven Palma Co-authored-by: Steven Palma --- docs/source/installation.mdx | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 8cc83843e..a112377c1 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -40,6 +40,13 @@ conda install ffmpeg -c conda-forge > > - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`. +> [!NOTE] +> When installing LeRobot inside WSL (Windows Subsystem for Linux), make sure to install `evdev` with the following command: +> +> ```bash +> conda install evdev -c conda-forge +> ``` + ## Step 3: Install LeRobot 🤗 ### From Source From dac1efd13d2b224f73ecaff4064bb6ca863dde0c Mon Sep 17 00:00:00 2001 From: Jash Shah <49280550+jashshah999@users.noreply.github.com> Date: Tue, 24 Feb 2026 08:29:08 -0800 Subject: [PATCH 058/131] feat: Enable torch.compile for DiffusionPolicy inference (#2486) Co-authored-by: Steven Palma --- src/lerobot/policies/diffusion/configuration_diffusion.py | 4 ++++ src/lerobot/policies/diffusion/modeling_diffusion.py | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/src/lerobot/policies/diffusion/configuration_diffusion.py b/src/lerobot/policies/diffusion/configuration_diffusion.py index 8ac0920dd..3d30e0941 100644 --- a/src/lerobot/policies/diffusion/configuration_diffusion.py +++ b/src/lerobot/policies/diffusion/configuration_diffusion.py @@ -139,6 +139,10 @@ class DiffusionConfig(PreTrainedConfig): # Inference num_inference_steps: int | None = None + # Optimization + compile_model: bool = False + compile_mode: str = "reduce-overhead" + # Loss computation do_mask_loss_for_padding: bool = False diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index 1fdc76f10..7525c9252 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -182,6 +182,11 @@ class DiffusionModel(nn.Module): self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps) + if config.compile_model: + # Compile the U-Net. "reduce-overhead" is preferred for the small-batch repetitive loops + # common in diffusion inference. + self.unet = torch.compile(self.unet, mode=config.compile_mode) + self.noise_scheduler = _make_noise_scheduler( config.noise_scheduler_type, num_train_timesteps=config.num_train_timesteps, From 5095ab08451fc580a6194201f88b7005c19d19f6 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 24 Feb 2026 19:09:34 +0100 Subject: [PATCH 059/131] fix(ci): permissions triton (#3011) --- .github/workflows/full_tests.yml | 2 ++ docker/Dockerfile.internal | 2 ++ 2 files changed, 4 insertions(+) diff --git a/.github/workflows/full_tests.yml b/.github/workflows/full_tests.yml index fd5e422b3..d23b99de0 100644 --- a/.github/workflows/full_tests.yml +++ b/.github/workflows/full_tests.yml @@ -173,6 +173,8 @@ jobs: shell: bash working-directory: /lerobot steps: + - name: Fix ptxas permissions + run: chmod +x /lerobot/.venv/lib/python3.10/site-packages/triton/backends/nvidia/bin/ptxas - name: Run pytest on GPU run: pytest tests -vv --maxfail=10 - name: Run end-to-end tests diff --git a/docker/Dockerfile.internal b/docker/Dockerfile.internal index c1dfa1dae..ed7d10495 100644 --- a/docker/Dockerfile.internal +++ b/docker/Dockerfile.internal @@ -85,6 +85,8 @@ RUN if [ "$UNBOUND_DEPS" = "true" ]; then \ RUN uv pip install --no-cache ".[all]" +RUN chmod +x /lerobot/.venv/lib/python${PYTHON_VERSION}/site-packages/triton/backends/nvidia/bin/ptxas + # Copy the rest of the application source code # Make sure to have the git-LFS files for testing COPY --chown=user_lerobot:user_lerobot . . From 18d9cb5ac42a29427df7200671f585ebfda2d7b5 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 24 Feb 2026 19:10:43 +0100 Subject: [PATCH 060/131] feat(scripts): Integrate tqdm for training progress visualization (#3010) --- src/lerobot/scripts/lerobot_train.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 93b99e245..465cbf531 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -24,6 +24,7 @@ import torch from accelerate import Accelerator from termcolor import colored from torch.optim import Optimizer +from tqdm import tqdm from lerobot.configs import parser from lerobot.configs.train import TrainPipelineConfig @@ -51,6 +52,7 @@ from lerobot.utils.utils import ( format_big_number, has_method, init_logging, + inside_slurm, ) @@ -390,6 +392,14 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): ) if is_main_process: + progbar = tqdm( + total=cfg.steps - step, + desc="Training", + unit="step", + disable=inside_slurm(), + position=0, + leave=True, + ) logging.info( f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}" ) @@ -414,6 +424,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we # increment `step` here. step += 1 + if is_main_process: + progbar.update(1) train_tracker.step() is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps @@ -507,6 +519,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): accelerator.wait_for_everyone() + if is_main_process: + progbar.close() + if eval_env: close_envs(eval_env) From 8fef4ddab8d0c2cdbd811da780b36569572d7a7a Mon Sep 17 00:00:00 2001 From: Martin Kiefel Date: Wed, 25 Feb 2026 11:57:07 +0100 Subject: [PATCH 061/131] fix(dataset): Fix reindexing bug for videos on splits (#2548) * fix(dataset): Reindex videos based on frame and not on time Sometimes during split operations the frame timestamp floating precision leads to frame ending up in the wrong split. This changes fixes the issues by directly working with frame indices instead. * Fix formatting --- src/lerobot/datasets/dataset_tools.py | 41 +++++++++++++++------------ 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index 123d455c6..b62d7d959 100644 --- a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -567,20 +567,22 @@ def _copy_and_reindex_data( def _keep_episodes_from_video_with_av( input_path: Path, output_path: Path, - episodes_to_keep: list[tuple[float, float]], + episodes_to_keep: list[tuple[int, int]], fps: float, vcodec: str = "libsvtav1", pix_fmt: str = "yuv420p", ) -> None: """Keep only specified episodes from a video file using PyAV. - This function decodes frames from specified time ranges and re-encodes them with + This function decodes frames from specified frame ranges and re-encodes them with properly reset timestamps to ensure monotonic progression. Args: input_path: Source video file path. output_path: Destination video file path. - episodes_to_keep: List of (start_time, end_time) tuples for episodes to keep. + episodes_to_keep: List of (start_frame, end_frame) tuples for episodes to keep. + Ranges are half-open intervals: [start_frame, end_frame), where start_frame + is inclusive and end_frame is exclusive. fps: Frame rate of the video. vcodec: Video codec to use for encoding. pix_fmt: Pixel format for output video. @@ -622,9 +624,10 @@ def _keep_episodes_from_video_with_av( # Create set of (start, end) ranges for fast lookup. # Convert to a sorted list for efficient checking. - time_ranges = sorted(episodes_to_keep) + frame_ranges = sorted(episodes_to_keep) # Track frame index for setting PTS and current range being processed. + src_frame_count = 0 frame_count = 0 range_idx = 0 @@ -634,21 +637,20 @@ def _keep_episodes_from_video_with_av( if frame is None: continue - # Get frame timestamp. - frame_time = float(frame.pts * frame.time_base) if frame.pts is not None else 0.0 - - # Check if frame is in any of our desired time ranges. + # Check if frame is in any of our desired frame ranges. # Skip ranges that have already passed. - while range_idx < len(time_ranges) and frame_time >= time_ranges[range_idx][1]: + while range_idx < len(frame_ranges) and src_frame_count >= frame_ranges[range_idx][1]: range_idx += 1 # If we've passed all ranges, stop processing. - if range_idx >= len(time_ranges): + if range_idx >= len(frame_ranges): break # Check if frame is in current range. - start_ts, end_ts = time_ranges[range_idx] - if frame_time < start_ts: + start_frame = frame_ranges[range_idx][0] + + if src_frame_count < start_frame: + src_frame_count += 1 continue # Frame is in range - create a new frame with reset timestamps. @@ -661,6 +663,7 @@ def _keep_episodes_from_video_with_av( for pkt in v_out.encode(new_frame): out.mux(pkt) + src_frame_count += 1 frame_count += 1 # Flush encoder. @@ -749,15 +752,17 @@ def _copy_and_reindex_videos( f"videos/{video_key}/to_timestamp" ] else: - # Build list of time ranges to keep, in sorted order. + # Build list of frame ranges to keep, in sorted order. sorted_keep_episodes = sorted(episodes_in_file, key=lambda x: episode_mapping[x]) - episodes_to_keep_ranges: list[tuple[float, float]] = [] - + episodes_to_keep_ranges: list[tuple[int, int]] = [] for old_idx in sorted_keep_episodes: src_ep = src_dataset.meta.episodes[old_idx] - from_ts = src_ep[f"videos/{video_key}/from_timestamp"] - to_ts = src_ep[f"videos/{video_key}/to_timestamp"] - episodes_to_keep_ranges.append((from_ts, to_ts)) + from_frame = round(src_ep[f"videos/{video_key}/from_timestamp"] * src_dataset.meta.fps) + to_frame = round(src_ep[f"videos/{video_key}/to_timestamp"] * src_dataset.meta.fps) + assert src_ep["length"] == to_frame - from_frame, ( + f"Episode length mismatch: {src_ep['length']} vs {to_frame - from_frame}" + ) + episodes_to_keep_ranges.append((from_frame, to_frame)) # Use PyAV filters to efficiently re-encode only the desired segments. assert src_dataset.meta.video_path is not None From f138e5948a076bff69b2188d700c9c5d5a415e30 Mon Sep 17 00:00:00 2001 From: Jash Shah <49280550+jashshah999@users.noreply.github.com> Date: Wed, 25 Feb 2026 03:29:10 -0800 Subject: [PATCH 062/131] Fix metaworld_config.json not bundled in pip installs and AttributeError crash (#3017) 1. Include metaworld_config.json in package distributions by adding it to both MANIFEST.in (for sdist) and pyproject.toml package-data (for wheels). Without this, pip-installed lerobot raises FileNotFoundError when importing the metaworld environment. 2. Fix crash in sanity_check_dataset_name where the error message accesses policy_cfg.type when policy_cfg is None, raising AttributeError instead of the intended ValueError. Fixes #2958 --- MANIFEST.in | 1 + pyproject.toml | 3 +++ src/lerobot/utils/control_utils.py | 2 +- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/MANIFEST.in b/MANIFEST.in index c1fb2ea75..c1fce3b5a 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,3 @@ include src/lerobot/templates/lerobot_modelcard_template.md include src/lerobot/datasets/card_template.md +include src/lerobot/envs/metaworld_config.json diff --git a/pyproject.toml b/pyproject.toml index ea3df4a6d..b6d85b0f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -214,6 +214,9 @@ lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main" lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main" # ---------------- Tool Configurations ---------------- +[tool.setuptools.package-data] +lerobot = ["envs/*.json"] + [tool.setuptools.packages.find] where = ["src"] diff --git a/src/lerobot/utils/control_utils.py b/src/lerobot/utils/control_utils.py index 7cfe177ef..7c605af17 100644 --- a/src/lerobot/utils/control_utils.py +++ b/src/lerobot/utils/control_utils.py @@ -189,7 +189,7 @@ def sanity_check_dataset_name(repo_id, policy_cfg): # Check if dataset_name starts with "eval_" but policy is missing if dataset_name.startswith("eval_") and policy_cfg is None: raise ValueError( - f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided ({policy_cfg.type})." + f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided." ) # Check if dataset_name does not start with "eval_" but policy is provided From 0317a15bf11f655e62ab59f783392cfd8f640100 Mon Sep 17 00:00:00 2001 From: Jash Shah <49280550+jashshah999@users.noreply.github.com> Date: Wed, 25 Feb 2026 03:29:22 -0800 Subject: [PATCH 063/131] fix(video): replace assertions with proper exceptions in video frame decoding (#3016) Replaced assert statements with FrameTimestampError exceptions in decode_video_frames_torchvision and decode_video_frames_torchcodec. Assertions are unsuitable for runtime validation because they can be silently disabled with python -O, and they produce unhelpful AssertionError tracebacks. The codebase already defines FrameTimestampError for this exact purpose but it was only used in one of the three validation sites. Also removed AssertionError from the except clause in LeRobotDataset.__init__, which was masking video timestamp errors by silently triggering a dataset re-download instead of surfacing the actual problem. --- src/lerobot/datasets/lerobot_dataset.py | 2 +- src/lerobot/datasets/video_utils.py | 46 ++++++++++++++----------- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 83d452a44..8fa4f200b 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -747,7 +747,7 @@ class LeRobotDataset(torch.utils.data.Dataset): # Check if cached dataset contains all requested episodes if not self._check_cached_episodes_sufficient(): raise FileNotFoundError("Cached dataset doesn't contain all requested episodes") - except (AssertionError, FileNotFoundError, NotADirectoryError): + except (FileNotFoundError, NotADirectoryError): if is_valid_version(self.revision): self.revision = get_safe_version(self.repo_id, self.revision) self.download(download_videos) diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index acc24a9e0..8c8494b87 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -227,16 +227,17 @@ def decode_video_frames_torchvision( min_, argmin_ = dist.min(1) is_within_tol = min_ < tolerance_s - assert is_within_tol.all(), ( - f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})." - "It means that the closest frame that can be loaded from the video is too far away in time." - "This might be due to synchronization issues with timestamps during data collection." - "To be safe, we advise to ignore this item during training." - f"\nqueried timestamps: {query_ts}" - f"\nloaded timestamps: {loaded_ts}" - f"\nvideo: {video_path}" - f"\nbackend: {backend}" - ) + if not is_within_tol.all(): + raise FrameTimestampError( + f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})." + " It means that the closest frame that can be loaded from the video is too far away in time." + " This might be due to synchronization issues with timestamps during data collection." + " To be safe, we advise to ignore this item during training." + f"\nqueried timestamps: {query_ts}" + f"\nloaded timestamps: {loaded_ts}" + f"\nvideo: {video_path}" + f"\nbackend: {backend}" + ) # get closest frames to the query timestamps closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_]) @@ -248,7 +249,11 @@ def decode_video_frames_torchvision( # convert to the pytorch format which is float32 in [0,1] range (and channel first) closest_frames = closest_frames.type(torch.float32) / 255 - assert len(timestamps) == len(closest_frames) + if len(timestamps) != len(closest_frames): + raise FrameTimestampError( + f"Number of retrieved frames ({len(closest_frames)}) does not match " + f"number of queried timestamps ({len(timestamps)})" + ) return closest_frames @@ -353,15 +358,16 @@ def decode_video_frames_torchcodec( min_, argmin_ = dist.min(1) is_within_tol = min_ < tolerance_s - assert is_within_tol.all(), ( - f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})." - "It means that the closest frame that can be loaded from the video is too far away in time." - "This might be due to synchronization issues with timestamps during data collection." - "To be safe, we advise to ignore this item during training." - f"\nqueried timestamps: {query_ts}" - f"\nloaded timestamps: {loaded_ts}" - f"\nvideo: {video_path}" - ) + if not is_within_tol.all(): + raise FrameTimestampError( + f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})." + " It means that the closest frame that can be loaded from the video is too far away in time." + " This might be due to synchronization issues with timestamps during data collection." + " To be safe, we advise to ignore this item during training." + f"\nqueried timestamps: {query_ts}" + f"\nloaded timestamps: {loaded_ts}" + f"\nvideo: {video_path}" + ) # get closest frames to the query timestamps closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_]) From 7541d72130c66bd0ba9c24586e7fc773a83c48eb Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Wed, 25 Feb 2026 13:28:01 +0100 Subject: [PATCH 064/131] Fix SARM dense_only mode: always load episodes_df for target computation (#3021) * fix annotation mode check * fix: SARM dense_only mode always load episodes_df for target computation --------- Co-authored-by: John Newsom Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> --- src/lerobot/policies/sarm/processor_sarm.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/lerobot/policies/sarm/processor_sarm.py b/src/lerobot/policies/sarm/processor_sarm.py index 5c617282a..8f2bc23db 100644 --- a/src/lerobot/policies/sarm/processor_sarm.py +++ b/src/lerobot/policies/sarm/processor_sarm.py @@ -277,9 +277,7 @@ class SARMEncodingProcessorStep(ProcessorStep): # When language is perturbed, targets are zero so perturbed samples don't contribute to progress loss if self.dataset_meta is not None: - episodes_df = None - if self.sparse_subtask_names != ["task"]: - episodes_df = self.dataset_meta.episodes.to_pandas() + episodes_df = self.dataset_meta.episodes.to_pandas() # Generate sparse targets if self.sparse_temporal_proportions is not None: From 9a5ab8ffab730efc660d55a2b5213bb24365b8e0 Mon Sep 17 00:00:00 2001 From: Mishig Date: Wed, 25 Feb 2026 15:02:40 +0000 Subject: [PATCH 065/131] feat: add visualization badge to card template and update dataset card creation with repo_id (#3005) * feat: add visualization badge to card template and update dataset card creation with repo_id * Update src/lerobot/datasets/card_template.md * Update src/lerobot/datasets/card_template.md --------- Signed-off-by: Mishig Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/lerobot/datasets/card_template.md | 7 +++++++ src/lerobot/datasets/lerobot_dataset.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/lerobot/datasets/card_template.md b/src/lerobot/datasets/card_template.md index ee26a78f5..1eced9f4c 100644 --- a/src/lerobot/datasets/card_template.md +++ b/src/lerobot/datasets/card_template.md @@ -7,6 +7,13 @@ This dataset was created using [LeRobot](https://github.com/huggingface/lerobot). +{% if repo_id is defined and repo_id %} + + + + +{% endif %} + ## Dataset Description {{ dataset_description | default("", true) }} diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 8fa4f200b..b51f06a04 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -839,7 +839,7 @@ class LeRobotDataset(torch.utils.data.Dataset): hub_api.upload_folder(**upload_kwargs) card = create_lerobot_dataset_card( - tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs + tags=tags, dataset_info=self.meta.info, license=license, repo_id=self.repo_id, **card_kwargs ) card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch) From d0b58190dab254ea18f814c03129f102c8b68a66 Mon Sep 17 00:00:00 2001 From: Cotton Hu <1821141394@qq.com> Date: Thu, 26 Feb 2026 00:36:31 +0800 Subject: [PATCH 066/131] fix(policies): support dp train when n_obs_steps=1 (#2430) Co-authored-by: hukongtao Co-authored-by: Steven Palma --- src/lerobot/policies/diffusion/modeling_diffusion.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index 7525c9252..314ca369c 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -142,6 +142,9 @@ class DiffusionPolicy(PreTrainedPolicy): """Run the batch through the model and compute the loss for training or validation.""" if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + for key in self.config.image_features: + if self.config.n_obs_steps == 1 and batch[key].ndim == 4: + batch[key] = batch[key].unsqueeze(1) batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) loss = self.diffusion.compute_loss(batch) # no output_dict so returning None From 975dcad9187e9803cfc56545a377a99f61b31969 Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Wed, 25 Feb 2026 18:46:55 +0100 Subject: [PATCH 067/131] Feat(teleoperators): add OpenArm Mini teleoperator (#3022) * add OpenArm Mini config and module init * add OpenArm Mini teleoperator implementation * add OpenArm Mini into factory and setup motors --------- Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> --- src/lerobot/scripts/lerobot_setup_motors.py | 2 + .../teleoperators/openarm_mini/__init__.py | 20 ++ .../openarm_mini/config_openarm_mini.py | 30 ++ .../openarm_mini/openarm_mini.py | 296 ++++++++++++++++++ src/lerobot/teleoperators/utils.py | 4 + 5 files changed, 352 insertions(+) create mode 100644 src/lerobot/teleoperators/openarm_mini/__init__.py create mode 100644 src/lerobot/teleoperators/openarm_mini/config_openarm_mini.py create mode 100644 src/lerobot/teleoperators/openarm_mini/openarm_mini.py diff --git a/src/lerobot/scripts/lerobot_setup_motors.py b/src/lerobot/scripts/lerobot_setup_motors.py index 01af95b61..2c962a6e2 100644 --- a/src/lerobot/scripts/lerobot_setup_motors.py +++ b/src/lerobot/scripts/lerobot_setup_motors.py @@ -43,6 +43,7 @@ from lerobot.teleoperators import ( # noqa: F401 koch_leader, make_teleoperator_from_config, omx_leader, + openarm_mini, so_leader, ) @@ -51,6 +52,7 @@ COMPATIBLE_DEVICES = [ "koch_leader", "omx_follower", "omx_leader", + "openarm_mini", "so100_follower", "so100_leader", "so101_follower", diff --git a/src/lerobot/teleoperators/openarm_mini/__init__.py b/src/lerobot/teleoperators/openarm_mini/__init__.py new file mode 100644 index 000000000..8620af1d7 --- /dev/null +++ b/src/lerobot/teleoperators/openarm_mini/__init__.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config_openarm_mini import OpenArmMiniConfig +from .openarm_mini import OpenArmMini + +__all__ = ["OpenArmMini", "OpenArmMiniConfig"] diff --git a/src/lerobot/teleoperators/openarm_mini/config_openarm_mini.py b/src/lerobot/teleoperators/openarm_mini/config_openarm_mini.py new file mode 100644 index 000000000..7dc3e0212 --- /dev/null +++ b/src/lerobot/teleoperators/openarm_mini/config_openarm_mini.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from ..config import TeleoperatorConfig + + +@TeleoperatorConfig.register_subclass("openarm_mini") +@dataclass +class OpenArmMiniConfig(TeleoperatorConfig): + """Configuration for OpenArm Mini teleoperator with Feetech motors (dual arms).""" + + port_right: str = "/dev/ttyUSB0" + port_left: str = "/dev/ttyUSB1" + + use_degrees: bool = True diff --git a/src/lerobot/teleoperators/openarm_mini/openarm_mini.py b/src/lerobot/teleoperators/openarm_mini/openarm_mini.py new file mode 100644 index 000000000..3fbcecf24 --- /dev/null +++ b/src/lerobot/teleoperators/openarm_mini/openarm_mini.py @@ -0,0 +1,296 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from typing import Any + +from lerobot.motors import Motor, MotorCalibration, MotorNormMode +from lerobot.motors.feetech import ( + FeetechMotorsBus, + OperatingMode, +) +from lerobot.processor import RobotAction +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected + +from ..teleoperator import Teleoperator +from .config_openarm_mini import OpenArmMiniConfig + +logger = logging.getLogger(__name__) + +# Motors whose direction is inverted during readout +RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5"] +LEFT_MOTORS_TO_FLIP = ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"] + + +class OpenArmMini(Teleoperator): + """ + OpenArm Mini Teleoperator with dual Feetech-based arms (8 motors per arm). + + Each arm has 7 joints plus a gripper, using Feetech STS3215 servos. + """ + + config_class = OpenArmMiniConfig + name = "openarm_mini" + + def __init__(self, config: OpenArmMiniConfig): + super().__init__(config) + self.config = config + + norm_mode_body = MotorNormMode.DEGREES + + motors_right = { + "joint_1": Motor(1, "sts3215", norm_mode_body), + "joint_2": Motor(2, "sts3215", norm_mode_body), + "joint_3": Motor(3, "sts3215", norm_mode_body), + "joint_4": Motor(4, "sts3215", norm_mode_body), + "joint_5": Motor(5, "sts3215", norm_mode_body), + "joint_6": Motor(6, "sts3215", norm_mode_body), + "joint_7": Motor(7, "sts3215", norm_mode_body), + "gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100), + } + + motors_left = { + "joint_1": Motor(1, "sts3215", norm_mode_body), + "joint_2": Motor(2, "sts3215", norm_mode_body), + "joint_3": Motor(3, "sts3215", norm_mode_body), + "joint_4": Motor(4, "sts3215", norm_mode_body), + "joint_5": Motor(5, "sts3215", norm_mode_body), + "joint_6": Motor(6, "sts3215", norm_mode_body), + "joint_7": Motor(7, "sts3215", norm_mode_body), + "gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100), + } + + cal_right = { + k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_") + } + cal_left = { + k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_") + } + + self.bus_right = FeetechMotorsBus( + port=self.config.port_right, + motors=motors_right, + calibration=cal_right, + ) + + self.bus_left = FeetechMotorsBus( + port=self.config.port_left, + motors=motors_left, + calibration=cal_left, + ) + + @property + def action_features(self) -> dict[str, type]: + features: dict[str, type] = {} + for motor in self.bus_right.motors: + features[f"right_{motor}.pos"] = float + for motor in self.bus_left.motors: + features[f"left_{motor}.pos"] = float + return features + + @property + def feedback_features(self) -> dict[str, type]: + return {} + + @property + def is_connected(self) -> bool: + return self.bus_right.is_connected and self.bus_left.is_connected + + @check_if_already_connected + def connect(self, calibrate: bool = True) -> None: + logger.info(f"Connecting right arm on {self.config.port_right}...") + self.bus_right.connect() + logger.info(f"Connecting left arm on {self.config.port_left}...") + self.bus_left.connect() + + if calibrate: + self.calibrate() + + self.configure() + logger.info(f"{self} connected.") + + @property + def is_calibrated(self) -> bool: + return self.bus_right.is_calibrated and self.bus_left.is_calibrated + + def calibrate(self) -> None: + """ + Run calibration procedure for OpenArm Mini. + + 1. Disable torque + 2. Ask user to position arms in hanging position with grippers closed + 3. Set this as zero position via half-turn homing + 4. Interactive gripper calibration (open/close positions) + 5. Save calibration + """ + if self.calibration: + user_input = input( + f"Press ENTER to use existing calibration for {self.id}, " + f"or type 'c' and press ENTER to run new calibration: " + ) + if user_input.strip().lower() != "c": + logger.info(f"Using existing calibration for {self.id}") + cal_right = { + k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_") + } + cal_left = { + k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_") + } + self.bus_right.write_calibration(cal_right) + self.bus_left.write_calibration(cal_left) + return + + logger.info(f"\nRunning calibration for {self}") + + self._calibrate_arm("right", self.bus_right) + self._calibrate_arm("left", self.bus_left) + + self._save_calibration() + print(f"\nCalibration complete and saved to {self.calibration_fpath}") + + def _calibrate_arm(self, arm_name: str, bus: FeetechMotorsBus) -> None: + """Calibrate a single arm with Feetech motors.""" + logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===") + + bus.disable_torque() + + logger.info(f"Setting Phase to 12 for all motors in {arm_name.upper()} arm...") + for motor in bus.motors: + bus.write("Phase", motor, 12) + + for motor in bus.motors: + bus.write("Operating_Mode", motor, OperatingMode.POSITION.value) + + input( + f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n" + "Position the arm in the following configuration:\n" + " - Arm hanging straight down\n" + " - Gripper closed\n" + "Press ENTER when ready..." + ) + + homing_offsets = bus.set_half_turn_homings() + logger.info(f"{arm_name.capitalize()} arm zero position set.") + + print(f"\nSetting motor ranges for {arm_name.upper()} arm\n") + + if self.calibration is None: + self.calibration = {} + + motor_resolution = bus.model_resolution_table[list(bus.motors.values())[0].model] + max_res = motor_resolution - 1 + + for motor_name, motor in bus.motors.items(): + prefixed_name = f"{arm_name}_{motor_name}" + + if motor_name == "gripper": + input( + f"\nGripper Calibration ({arm_name.upper()} arm)\n" + f"Step 1: CLOSE the gripper fully\n" + f"Press ENTER when gripper is closed..." + ) + closed_pos = bus.read("Present_Position", motor_name, normalize=False) + logger.info(f" Gripper closed position recorded: {closed_pos}") + + input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...") + open_pos = bus.read("Present_Position", motor_name, normalize=False) + logger.info(f" Gripper open position recorded: {open_pos}") + + if closed_pos < open_pos: + range_min = int(closed_pos) + range_max = int(open_pos) + drive_mode = 0 + else: + range_min = int(open_pos) + range_max = int(closed_pos) + drive_mode = 1 + + logger.info( + f" {prefixed_name}: range set to [{range_min}, {range_max}] " + f"(0=closed, 100=open, drive_mode={drive_mode})" + ) + else: + range_min = 0 + range_max = max_res + drive_mode = 0 + logger.info(f" {prefixed_name}: range set to [0, {max_res}] (full motor range)") + + self.calibration[prefixed_name] = MotorCalibration( + id=motor.id, + drive_mode=drive_mode, + homing_offset=homing_offsets[motor_name], + range_min=range_min, + range_max=range_max, + ) + + cal_for_bus = { + k.replace(f"{arm_name}_", ""): v + for k, v in self.calibration.items() + if k.startswith(f"{arm_name}_") + } + bus.write_calibration(cal_for_bus) + + def configure(self) -> None: + self.bus_right.disable_torque() + self.bus_right.configure_motors() + for motor in self.bus_right.motors: + self.bus_right.write("Operating_Mode", motor, OperatingMode.POSITION.value) + + self.bus_left.disable_torque() + self.bus_left.configure_motors() + for motor in self.bus_left.motors: + self.bus_left.write("Operating_Mode", motor, OperatingMode.POSITION.value) + + def setup_motors(self) -> None: + print("\nSetting up RIGHT arm motors...") + for motor in reversed(self.bus_right.motors): + input(f"Connect the controller board to the RIGHT '{motor}' motor only and press enter.") + self.bus_right.setup_motor(motor) + print(f"RIGHT '{motor}' motor id set to {self.bus_right.motors[motor].id}") + + print("\nSetting up LEFT arm motors...") + for motor in reversed(self.bus_left.motors): + input(f"Connect the controller board to the LEFT '{motor}' motor only and press enter.") + self.bus_left.setup_motor(motor) + print(f"LEFT '{motor}' motor id set to {self.bus_left.motors[motor].id}") + + @check_if_not_connected + def get_action(self) -> RobotAction: + """Get current action from both arms (read positions from all motors).""" + start = time.perf_counter() + + right_positions = self.bus_right.sync_read("Present_Position") + left_positions = self.bus_left.sync_read("Present_Position") + + action: dict[str, Any] = {} + for motor, val in right_positions.items(): + action[f"right_{motor}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val + for motor, val in left_positions.items(): + action[f"left_{motor}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val + + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read action: {dt_ms:.1f}ms") + return action + + def send_feedback(self, feedback: dict[str, float]) -> None: + raise NotImplementedError("Feedback is not yet implemented for OpenArm Mini.") + + @check_if_not_connected + def disconnect(self) -> None: + self.bus_right.disconnect() + self.bus_left.disconnect() + logger.info(f"{self} disconnected.") diff --git a/src/lerobot/teleoperators/utils.py b/src/lerobot/teleoperators/utils.py index 16454d5ad..db685f396 100644 --- a/src/lerobot/teleoperators/utils.py +++ b/src/lerobot/teleoperators/utils.py @@ -95,6 +95,10 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator": from .bi_openarm_leader import BiOpenArmLeader return BiOpenArmLeader(config) + elif config.type == "openarm_mini": + from .openarm_mini import OpenArmMini + + return OpenArmMini(config) else: try: return cast("Teleoperator", make_device_from_device_class(config)) From 46044fed753f62fe54ced13adcd9af865ed36fb0 Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Thu, 26 Feb 2026 13:28:46 +0100 Subject: [PATCH 068/131] Fix: remove device_map from SmolVLA model loading (#3029) * Fix SmolVLA meta tensor error by removing device_map - Remove device_map parameter from VLM model loading - Change torch_dtype from string to torch.bfloat16 - Add explicit .to(device) calls after initialization This resolves NotImplementedError when training SmolVLA policy. Fixes meta tensor copy issue in factory.py:418. * fix: remove manual device movement logic and fix dtype handling --------- Co-authored-by: Highsky7 --- src/lerobot/policies/smolvla/smolvlm_with_expert.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/lerobot/policies/smolvla/smolvlm_with_expert.py b/src/lerobot/policies/smolvla/smolvlm_with_expert.py index 555c40773..caca41dab 100644 --- a/src/lerobot/policies/smolvla/smolvlm_with_expert.py +++ b/src/lerobot/policies/smolvla/smolvlm_with_expert.py @@ -77,7 +77,6 @@ class SmolVLMWithExpertModel(nn.Module): print(f"Loading {model_id} weights ...") self.vlm = AutoModelForImageTextToText.from_pretrained( model_id, - device_map=device, torch_dtype="bfloat16", low_cpu_mem_usage=True, ) From fde9d08281d00641ea13b560cbb731d5e09818cf Mon Sep 17 00:00:00 2001 From: Damien LaRocque Date: Thu, 26 Feb 2026 14:41:32 +0100 Subject: [PATCH 069/131] feat(async_inference) Enable plugins with async inference (#2425) * feat(async-inference) Try using async inference server with plugins * Fix import * Fix import error in Robot Client --------- Signed-off-by: Steven Palma Co-authored-by: Steven Palma --- src/lerobot/async_inference/robot_client.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/lerobot/async_inference/robot_client.py b/src/lerobot/async_inference/robot_client.py index e4d21652a..da576eb48 100644 --- a/src/lerobot/async_inference/robot_client.py +++ b/src/lerobot/async_inference/robot_client.py @@ -49,23 +49,18 @@ import torch from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 -from lerobot.robots import ( # noqa: F401 - Robot, - RobotConfig, - bi_so_follower, - koch_follower, +from lerobot.robots import ( + RobotConfig, # noqa: F401 make_robot_from_config, - omx_follower, - so_follower, ) from lerobot.transport import ( services_pb2, # type: ignore services_pb2_grpc, # type: ignore ) from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks +from lerobot.utils.import_utils import register_third_party_plugins from .configs import RobotClientConfig -from .constants import SUPPORTED_ROBOTS from .helpers import ( Action, FPSTracker, @@ -485,8 +480,9 @@ class RobotClient: def async_client(cfg: RobotClientConfig): logging.info(pformat(asdict(cfg))) - if cfg.robot.type not in SUPPORTED_ROBOTS: - raise ValueError(f"Robot {cfg.robot.type} not yet supported!") + # TODO: Assert if checking robot support is still needed with the plugin system + # if cfg.robot.type not in SUPPORTED_ROBOTS: + # raise ValueError(f"Robot {cfg.robot.type} not yet supported!") client = RobotClient(cfg) @@ -512,4 +508,5 @@ def async_client(cfg: RobotClientConfig): if __name__ == "__main__": + register_third_party_plugins() async_client() # run the client From 4e54be1334db8b3dc32e323b214192a2b4e6a297 Mon Sep 17 00:00:00 2001 From: Michio Sun <47138011+thatmich@users.noreply.github.com> Date: Fri, 27 Feb 2026 01:42:22 +0900 Subject: [PATCH 070/131] fix(datasets): skip warning when MultiLeRobotDataset features are identical (#3019) Co-authored-by: Steven Palma --- src/lerobot/datasets/lerobot_dataset.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index b51f06a04..bb526740e 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -1771,11 +1771,12 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): ) for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True): extra_keys = set(ds.features).difference(intersection_features) - logging.warning( - f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the " - "other datasets." - ) - self.disabled_features.update(extra_keys) + if extra_keys: + logging.warning( + f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the " + "other datasets." + ) + self.disabled_features.update(extra_keys) self.image_transforms = image_transforms self.delta_timestamps = delta_timestamps From c7c620533201535064cd4b84ba8c60624693c79d Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 27 Feb 2026 15:26:56 +0100 Subject: [PATCH 071/131] chore(scripts): no spam log when no action (#3042) --- src/lerobot/scripts/lerobot_record.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index ec04975d4..661d33c51 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -333,6 +333,7 @@ def record_loop( preprocessor.reset() postprocessor.reset() + no_action_count = 0 timestamp = 0 start_episode_t = time.perf_counter() while timestamp < control_time_s: @@ -380,11 +381,13 @@ def record_loop( act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action act_processed_teleop = teleop_action_processor((act, obs)) else: - logging.info( - "No policy or teleoperator provided, skipping action generation." - "This is likely to happen when resetting the environment without a teleop device." - "The robot won't be at its rest position at the start of the next episode." - ) + no_action_count += 1 + if no_action_count == 1 or no_action_count % 10 == 0: + logging.warning( + "No policy or teleoperator provided, skipping action generation. " + "This is likely to happen when resetting the environment without a teleop device. " + "The robot won't be at its rest position at the start of the next episode." + ) continue # Applies a pipeline to the action, default is IdentityProcessor From c085531b17d914fc9aea8f5c8bef0ad8497df079 Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Fri, 27 Feb 2026 15:46:31 +0100 Subject: [PATCH 072/131] fix: add missing openarm_mini import to CLI scripts (#3028) --- src/lerobot/scripts/lerobot_calibrate.py | 1 + src/lerobot/scripts/lerobot_find_joint_limits.py | 1 + src/lerobot/scripts/lerobot_record.py | 1 + src/lerobot/scripts/lerobot_teleoperate.py | 1 + 4 files changed, 4 insertions(+) diff --git a/src/lerobot/scripts/lerobot_calibrate.py b/src/lerobot/scripts/lerobot_calibrate.py index 1b30021dd..242067978 100644 --- a/src/lerobot/scripts/lerobot_calibrate.py +++ b/src/lerobot/scripts/lerobot_calibrate.py @@ -56,6 +56,7 @@ from lerobot.teleoperators import ( # noqa: F401 make_teleoperator_from_config, omx_leader, openarm_leader, + openarm_mini, so_leader, unitree_g1, ) diff --git a/src/lerobot/scripts/lerobot_find_joint_limits.py b/src/lerobot/scripts/lerobot_find_joint_limits.py index 082d11803..bcb93ba12 100644 --- a/src/lerobot/scripts/lerobot_find_joint_limits.py +++ b/src/lerobot/scripts/lerobot_find_joint_limits.py @@ -61,6 +61,7 @@ from lerobot.teleoperators import ( # noqa: F401 make_teleoperator_from_config, omx_leader, openarm_leader, + openarm_mini, so_leader, ) from lerobot.utils.robot_utils import precise_sleep diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 661d33c51..66e2c4228 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -125,6 +125,7 @@ from lerobot.teleoperators import ( # noqa: F401 make_teleoperator_from_config, omx_leader, openarm_leader, + openarm_mini, reachy2_teleoperator, so_leader, unitree_g1, diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index b6aa4a750..dad479b2e 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ b/src/lerobot/scripts/lerobot_teleoperate.py @@ -94,6 +94,7 @@ from lerobot.teleoperators import ( # noqa: F401 make_teleoperator_from_config, omx_leader, openarm_leader, + openarm_mini, reachy2_teleoperator, so_leader, unitree_g1, From a0fdbf037ac918d0f2cdfd540db72199e0b925d0 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Fri, 27 Feb 2026 18:58:36 +0300 Subject: [PATCH 073/131] feat(policies): add Smolvla torch compile support (#3043) * Change LIBERO init_state_id when reset. Signed-off-by: Aoqun Jin * Change LIBERO init_state_id when reset. Signed-off-by: Aoqun Jin * pre-commit run * Add torch.compile for smolvla Signed-off-by: Aoqun Jin * Add torch.compile for smolvla Add model compilation option for improved performance. Signed-off-by: Aoqun Jin * first --------- Signed-off-by: Aoqun Jin Co-authored-by: Aoqun Jin Co-authored-by: Steven Palma --- src/lerobot/policies/smolvla/configuration_smolvla.py | 3 +++ src/lerobot/policies/smolvla/modeling_smolvla.py | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/src/lerobot/policies/smolvla/configuration_smolvla.py b/src/lerobot/policies/smolvla/configuration_smolvla.py index c696265f2..b861b856b 100644 --- a/src/lerobot/policies/smolvla/configuration_smolvla.py +++ b/src/lerobot/policies/smolvla/configuration_smolvla.py @@ -106,6 +106,9 @@ class SmolVLAConfig(PreTrainedConfig): # Real-Time Chunking (RTC) configuration rtc_config: RTCConfig | None = None + compile_model: bool = False # Whether to use torch.compile for model optimization + compile_mode: str = "max-autotune" # Torch compile mode + def __post_init__(self): super().__post_init__() diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 10544a949..e49226d26 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -593,6 +593,12 @@ class VLAFlowMatching(nn.Module): self.prefix_length = self.config.prefix_length self.rtc_processor = rtc_processor + # Compile model if requested + if config.compile_model: + torch.set_float32_matmul_precision("high") + self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode) + self.forward = torch.compile(self.forward, mode=config.compile_mode) + def _rtc_enabled(self): return self.config.rtc_config is not None and self.config.rtc_config.enabled From baf9b5036586f3667c6f5310d30396b4b233a801 Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Fri, 27 Feb 2026 17:44:53 +0100 Subject: [PATCH 074/131] Fix(diffusion): enforce no-crop behavior when crop_ratio=1.0 (#3046) * refactor(diffusion): replace crop_shape with resize_shape and crop_ratio * fix(diffusion): address review feedback on resize/crop backward compat * test: regenerate diffusion artifacts for updated default config * fix: disable crop when resize path uses crop_ratio=1.0 --------- Co-authored-by: starlitxiling <1754165401@qq.com> --- .../diffusion/configuration_diffusion.py | 44 +++++++++++++++---- .../policies/diffusion/modeling_diffusion.py | 28 ++++++++---- .../pusht_diffusion_/actions.safetensors | 2 +- .../pusht_diffusion_/grad_stats.safetensors | 2 +- .../pusht_diffusion_/output_dict.safetensors | 2 +- .../pusht_diffusion_/param_stats.safetensors | 2 +- 6 files changed, 59 insertions(+), 21 deletions(-) diff --git a/src/lerobot/policies/diffusion/configuration_diffusion.py b/src/lerobot/policies/diffusion/configuration_diffusion.py index 3d30e0941..91b3df214 100644 --- a/src/lerobot/policies/diffusion/configuration_diffusion.py +++ b/src/lerobot/policies/diffusion/configuration_diffusion.py @@ -55,10 +55,16 @@ class DiffusionConfig(PreTrainedConfig): normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX) vision_backbone: Name of the torchvision resnet backbone to use for encoding images. - crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit - within the image size. If None, no cropping is done. - crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval - mode). + resize_shape: (H, W) shape to resize images to as a preprocessing step for the vision + backbone. If None, no resizing is done and the original image resolution is used. + crop_ratio: Ratio in (0, 1] used to derive the crop size from resize_shape + (crop_h = int(resize_shape[0] * crop_ratio), likewise for width). + Set to 1.0 to disable cropping. Only takes effect when resize_shape is not None. + crop_shape: (H, W) shape to crop images to. When resize_shape is set and crop_ratio < 1.0, + this is computed automatically. Can also be set directly for legacy configs that use + crop-only (without resize). If None and no derivation applies, no cropping is done. + crop_is_random: Whether the crop should be random at training time (it's always a center + crop in eval mode). pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone. `None` means no pretrained weights. use_group_norm: Whether to replace batch normalization with group normalization in the backbone. @@ -114,7 +120,9 @@ class DiffusionConfig(PreTrainedConfig): # Architecture / modeling. # Vision backbone. vision_backbone: str = "resnet18" - crop_shape: tuple[int, int] | None = (84, 84) + resize_shape: tuple[int, int] | None = None + crop_ratio: float = 1.0 + crop_shape: tuple[int, int] | None = None crop_is_random: bool = True pretrained_backbone_weights: str | None = None use_group_norm: bool = True @@ -175,6 +183,25 @@ class DiffusionConfig(PreTrainedConfig): f"Got {self.noise_scheduler_type}." ) + if self.resize_shape is not None and ( + len(self.resize_shape) != 2 or any(d <= 0 for d in self.resize_shape) + ): + raise ValueError(f"`resize_shape` must be a pair of positive integers. Got {self.resize_shape}.") + if not (0 < self.crop_ratio <= 1.0): + raise ValueError(f"`crop_ratio` must be in (0, 1]. Got {self.crop_ratio}.") + + if self.resize_shape is not None: + if self.crop_ratio < 1.0: + self.crop_shape = ( + int(self.resize_shape[0] * self.crop_ratio), + int(self.resize_shape[1] * self.crop_ratio), + ) + else: + # Explicitly disable cropping for resize+ratio path when crop_ratio == 1.0. + self.crop_shape = None + if self.crop_shape is not None and (self.crop_shape[0] <= 0 or self.crop_shape[1] <= 0): + raise ValueError(f"`crop_shape` must have positive dimensions. Got {self.crop_shape}.") + # Check that the horizon size and U-Net downsampling is compatible. # U-Net downsamples by 2 with each stage. downsampling_factor = 2 ** len(self.down_dims) @@ -202,13 +229,12 @@ class DiffusionConfig(PreTrainedConfig): if len(self.image_features) == 0 and self.env_state_feature is None: raise ValueError("You must provide at least one image or the environment state among the inputs.") - if self.crop_shape is not None: + if self.resize_shape is None and self.crop_shape is not None: for key, image_ft in self.image_features.items(): if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]: raise ValueError( - f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} " - f"for `crop_shape` and {image_ft.shape} for " - f"`{key}`." + f"`crop_shape` should fit within the image shapes. Got {self.crop_shape} " + f"for `crop_shape` and {image_ft.shape} for `{key}`." ) # Check that all input images have the same shape. diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index 314ca369c..aa8d5dd14 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -454,12 +454,18 @@ class DiffusionRgbEncoder(nn.Module): def __init__(self, config: DiffusionConfig): super().__init__() # Set up optional preprocessing. - if config.crop_shape is not None: + if config.resize_shape is not None: + self.resize = torchvision.transforms.Resize(config.resize_shape) + else: + self.resize = None + + crop_shape = config.crop_shape + if crop_shape is not None: self.do_crop = True # Always use center crop for eval - self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape) + self.center_crop = torchvision.transforms.CenterCrop(crop_shape) if config.crop_is_random: - self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape) + self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape) else: self.maybe_random_crop = self.center_crop else: @@ -485,13 +491,16 @@ class DiffusionRgbEncoder(nn.Module): # Set up pooling and final layers. # Use a dry run to get the feature map shape. - # The dummy input should take the number of image channels from `config.image_features` and it should - # use the height and width from `config.crop_shape` if it is provided, otherwise it should use the - # height and width from `config.image_features`. + # The dummy shape mirrors the runtime preprocessing order: resize -> crop. # Note: we have a check in the config class to make sure all images have the same shape. images_shape = next(iter(config.image_features.values())).shape - dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:] + if config.crop_shape is not None: + dummy_shape_h_w = config.crop_shape + elif config.resize_shape is not None: + dummy_shape_h_w = config.resize_shape + else: + dummy_shape_h_w = images_shape[1:] dummy_shape = (1, images_shape[0], *dummy_shape_h_w) feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:] @@ -507,7 +516,10 @@ class DiffusionRgbEncoder(nn.Module): Returns: (B, D) image feature. """ - # Preprocess: maybe crop (if it was set up in the __init__). + # Preprocess: resize if configured, then crop if configured. + + if self.resize is not None: + x = self.resize(x) if self.do_crop: if self.training: # noqa: SIM108 x = self.maybe_random_crop(x) diff --git a/tests/artifacts/policies/pusht_diffusion_/actions.safetensors b/tests/artifacts/policies/pusht_diffusion_/actions.safetensors index ef581727d..70b1411ab 100644 --- a/tests/artifacts/policies/pusht_diffusion_/actions.safetensors +++ b/tests/artifacts/policies/pusht_diffusion_/actions.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:19eaaa85f66ba4aa6388dbb83819ffad6ea4363247208f871a8dc385689f6fc8 +oid sha256:54aecbc1af72a4cd5e9261492f5e7601890517516257aacdf2a0ffb3ce281f1b size 992 diff --git a/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors b/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors index e00ed3238..bea7d4f19 100644 --- a/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors +++ b/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:227296eaeeb54acdc3dae2eb8af3d4d08fb87e245337624447140b1e91cfd002 +oid sha256:88a9c3775a2aa1e90a08850521970070a4fcf0f6b82aab43cd8ccc5cf77e0013 size 47424 diff --git a/tests/artifacts/policies/pusht_diffusion_/output_dict.safetensors b/tests/artifacts/policies/pusht_diffusion_/output_dict.safetensors index f29303992..20cc4f547 100644 --- a/tests/artifacts/policies/pusht_diffusion_/output_dict.safetensors +++ b/tests/artifacts/policies/pusht_diffusion_/output_dict.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:271b00cb2f0cd5fd26b1d53463638e3d1a6e92692ec625fcffb420ca190869e5 +oid sha256:91a2635e05a75fe187a5081504c5f35ce3417378813fa2deaf9ca4e8200e1819 size 68 diff --git a/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors b/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors index 614cc754e..365a453dd 100644 --- a/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors +++ b/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:778fddbbaa64248cee35cb377c02cc2b6076f7ce5855146de677128900617ddf +oid sha256:645bff922ac7bea63ad018ebf77c303c0e4cd2c1c0dc5ef3192865281bef3dc6 size 47424 From 04de49654718c6584d1e5561506180f6196b7e71 Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Fri, 27 Feb 2026 17:45:19 +0100 Subject: [PATCH 075/131] fix(logging): avoid double-counting samples across processes (#3045) --- src/lerobot/scripts/lerobot_train.py | 4 ++-- src/lerobot/utils/logging_utils.py | 6 +++-- tests/utils/test_logging_utils.py | 36 ++++++++++++++++++++++++++++ 3 files changed, 42 insertions(+), 4 deletions(-) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 465cbf531..04d43d91e 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -380,10 +380,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): "dataloading_s": AverageMeter("data_s", ":.3f"), } - # Use effective batch size for proper epoch calculation in distributed training + # Keep global batch size for logging; MetricsTracker handles world size internally. effective_batch_size = cfg.batch_size * accelerator.num_processes train_tracker = MetricsTracker( - effective_batch_size, + cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, diff --git a/src/lerobot/utils/logging_utils.py b/src/lerobot/utils/logging_utils.py index c4c1f42e0..1497c0585 100644 --- a/src/lerobot/utils/logging_utils.py +++ b/src/lerobot/utils/logging_utils.py @@ -104,9 +104,10 @@ class MetricsTracker: self.metrics = metrics self.steps = initial_step + world_size = accelerator.num_processes if accelerator else 1 # A sample is an (observation,action) pair, where observation and action # can be on multiple timestamps. In a batch, we have `batch_size` number of samples. - self.samples = self.steps * self._batch_size + self.samples = self.steps * self._batch_size * world_size self.episodes = self.samples / self._avg_samples_per_ep self.epochs = self.samples / self._num_frames self.accelerator = accelerator @@ -132,7 +133,8 @@ class MetricsTracker: Updates metrics that depend on 'step' for one step. """ self.steps += 1 - self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1) + world_size = self.accelerator.num_processes if self.accelerator else 1 + self.samples += self._batch_size * world_size self.episodes = self.samples / self._avg_samples_per_ep self.epochs = self.samples / self._num_frames diff --git a/tests/utils/test_logging_utils.py b/tests/utils/test_logging_utils.py index 560ba5701..1207534c0 100644 --- a/tests/utils/test_logging_utils.py +++ b/tests/utils/test_logging_utils.py @@ -24,6 +24,11 @@ def mock_metrics(): return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")} +class MockAccelerator: + def __init__(self, num_processes: int): + self.num_processes = num_processes + + def test_average_meter_initialization(): meter = AverageMeter("loss", ":.2f") assert meter.name == "loss" @@ -82,6 +87,37 @@ def test_metrics_tracker_step(mock_metrics): assert tracker.epochs == tracker.samples / 1000 +def test_metrics_tracker_initialization_with_accelerator(mock_metrics): + tracker = MetricsTracker( + batch_size=32, + num_frames=1000, + num_episodes=50, + metrics=mock_metrics, + initial_step=10, + accelerator=MockAccelerator(num_processes=2), + ) + assert tracker.steps == 10 + assert tracker.samples == 10 * 32 * 2 + assert tracker.episodes == tracker.samples / (1000 / 50) + assert tracker.epochs == tracker.samples / 1000 + + +def test_metrics_tracker_step_with_accelerator(mock_metrics): + tracker = MetricsTracker( + batch_size=32, + num_frames=1000, + num_episodes=50, + metrics=mock_metrics, + initial_step=5, + accelerator=MockAccelerator(num_processes=2), + ) + tracker.step() + assert tracker.steps == 6 + assert tracker.samples == (5 * 32 * 2) + (32 * 2) + assert tracker.episodes == tracker.samples / (1000 / 50) + assert tracker.epochs == tracker.samples / 1000 + + def test_metrics_tracker_getattr(mock_metrics): tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics) assert tracker.loss == mock_metrics["loss"] From 8fff0fde7c79f23a93d845d1a50e985de01f8b8a Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Fri, 27 Feb 2026 18:22:44 +0100 Subject: [PATCH 076/131] chore(docstrings): fixing deprecated `root` argument description in LeRobotDataset class (#3035) * chore(docstrings): fixing deprecated `root` argument docstrings in LeRobotDataset class * chore(draccus): updating draccus CLI help * chore(revert): reverting changes in lerobot_dataset_viz.py --------- Co-authored-by: Steven Palma --- examples/backward_compatibility/replay.py | 2 +- src/lerobot/configs/default.py | 2 +- src/lerobot/datasets/lerobot_dataset.py | 10 +++++----- src/lerobot/scripts/lerobot_record.py | 2 +- src/lerobot/scripts/lerobot_replay.py | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/backward_compatibility/replay.py b/examples/backward_compatibility/replay.py index f7c47bec5..13fdfd5f5 100644 --- a/examples/backward_compatibility/replay.py +++ b/examples/backward_compatibility/replay.py @@ -57,7 +57,7 @@ class DatasetReplayConfig: repo_id: str # Episode to replay. episode: int - # Root directory where the dataset will be stored (e.g. 'dataset/path'). + # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id. root: str | Path | None = None # Limit the frames per second. By default, uses the policy fps. fps: int = 30 diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py index f613b5251..dcb0cbd54 100644 --- a/src/lerobot/configs/default.py +++ b/src/lerobot/configs/default.py @@ -27,7 +27,7 @@ class DatasetConfig: # "dataset_index" into the returned item. The index mapping is made according to the order in which the # datasets are provided. repo_id: str - # Root directory where the dataset will be stored (e.g. 'dataset/path'). + # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id. root: str | None = None episodes: list[int] | None = None image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index bb526740e..76d44de07 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -664,11 +664,11 @@ class LeRobotDataset(torch.utils.data.Dataset): for the README). Args: - repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset - will be stored under root/repo_id. - root (Path | None, optional): Local directory to use for downloading/writing files. You can also - set the HF_LEROBOT_HOME environment variable to point to a different location. Defaults to - '~/.cache/huggingface/lerobot'. + repo_id (str): This is the repo id that will be used to fetch the dataset. + root (Path | None, optional): Local directory where the dataset will be downloaded and + stored. If set, all dataset files will be stored directly under this path. If not set, the + dataset files will be stored under $HF_LEROBOT_HOME/repo_id (configurable via the + HF_LEROBOT_HOME environment variable). episodes (list[int] | None, optional): If specified, this will only load episodes specified by their episode_index in this list. Defaults to None. image_transforms (Callable | None, optional): You can pass standard v2 image transforms from diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 66e2c4228..72708ba23 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -155,7 +155,7 @@ class DatasetRecordConfig: repo_id: str # A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.") single_task: str - # Root directory where the dataset will be stored (e.g. 'dataset/path'). + # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id. root: str | Path | None = None # Limit the frames per second. fps: int = 30 diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index 8e2a394b9..7c0b5b96b 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ b/src/lerobot/scripts/lerobot_replay.py @@ -80,7 +80,7 @@ class DatasetReplayConfig: repo_id: str # Episode to replay. episode: int - # Root directory where the dataset will be stored (e.g. 'dataset/path'). + # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id. root: str | Path | None = None # Limit the frames per second. By default, uses the policy fps. fps: int = 30 From 563f42bdb1db8f8a96d28d4b868c5961eefa4499 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 27 Feb 2026 19:29:35 +0100 Subject: [PATCH 077/131] chore(dependencies): Bump lerobot to 0.4.5 (#3051) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b6d85b0f6..f4fb7d249 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb" [project] name = "lerobot" -version = "0.4.4" +version = "0.4.5" description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch" dynamic = ["readme"] license = { text = "Apache-2.0" } From 095856b06af7e4bd6e79f9f741303701d052ba2d Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Sat, 28 Feb 2026 14:41:28 +0100 Subject: [PATCH 078/131] chore: add AI policy (#3055) --- AI_POLICY.md | 25 +++++++++++++++++++++++++ CONTRIBUTING.md | 2 +- 2 files changed, 26 insertions(+), 1 deletion(-) create mode 100644 AI_POLICY.md diff --git a/AI_POLICY.md b/AI_POLICY.md new file mode 100644 index 000000000..272ee8c12 --- /dev/null +++ b/AI_POLICY.md @@ -0,0 +1,25 @@ +# AI Usage Policy + +The LeRobot project welcomes contributions from everyone, and we have a few guidelines regarding AI usage to ensure high code quality, clear communication, and a healthy open-source ecosystem: + +- **Please disclose significant AI assistance.** If you used AI tools (e.g., Copilot, Claude, Cursor, ChatGPT) to generate a substantial portion of your code or text, let us know in your PR description. Transparency helps us review your changes more effectively. +- **Own your code (The Human-in-the-Loop).** You must fully understand all the changes you are proposing. If you cannot explain what your AI-assisted code does or how it interacts with LeRobot's broader architecture, please take the time to learn and test it before submitting. +- **Keep issues and discussions focused.** You are welcome to use AI to help draft issues or PR descriptions, but please review and edit them carefully before posting. AI can often be overly verbose; trimming the noise and getting straight to the point helps our maintainers address your needs faster. + +Our core maintainers also use AI tools to aid their workflows, but they do so while bringing deep contextual knowledge of the LeRobot codebase to validate the output. We ask all contributors to apply that same level of rigor. + +## Remember the Human Maintainers + +Please remember that LeRobot is maintained by a dedicated team of humans. + +Every discussion, issue, and pull request is read and reviewed by real people. While AI tools can generate thousands of lines of code in seconds, reviewing that code still takes human time and energy. Submitting unverified or low-effort AI output puts an unfair burden on our maintainers. + +Today, the quality of the AI output still heavily depends on the developer driving the tool. We ask that you respect our maintainers' time by thoroughly vetting, testing, and refining your submissions. + +## AI is Welcome Here + +LeRobot operates at the cutting edge of AI and robotics, and many of our maintainers actively embrace AI coding assistants as valuable productivity tools. We are a pro-AI project! + +Our reason for having an AI policy is not an anti-AI stance. Rather, it exists to ensure that AI is used to enhance human contributions, not replace them with unverified noise. It's about how the tools are used, not the tools themselves. + +We value the unique human insight you bring to the LeRobot community. Let AI empower your workflow, but always let your own judgment take the wheel. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c51a48831..82147d363 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,7 +2,7 @@ Everyone is welcome to contribute, and we value everybody's contribution. Code is not the only way to help the community. Answering questions, helping others, reaching out, and improving the documentation are immensely valuable. -Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md). +Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md) and our [AI policy](./AI_POLICY.md). ## Ways to Contribute From 8bb8ed48039e9f4595b105dd99f6bfff4b9aa8e7 Mon Sep 17 00:00:00 2001 From: Bernie Telles Date: Mon, 2 Mar 2026 06:35:15 -0800 Subject: [PATCH 079/131] Improve policy_device documentation for async.mdx (#3060) --- docs/source/async.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/async.mdx b/docs/source/async.mdx index 3244fc2a3..fcc3f1d1e 100644 --- a/docs/source/async.mdx +++ b/docs/source/async.mdx @@ -48,7 +48,7 @@ python -m lerobot.async_inference.robot_client \ --task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act` --policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc) --pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base) - --policy_device=mps \ # POLICY: the device to run the policy on, on the server + --policy_device=mps \ # POLICY: the device to run the policy on, on the server (cuda, mps, xpu, cpu) --actions_per_chunk=50 \ # POLICY: the number of actions to output at once --chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server --aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions From 8a0cc3d6645a30609de1eb085e6943263ed11141 Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Tue, 3 Mar 2026 11:55:09 +0100 Subject: [PATCH 080/131] fix(frame_index): making rerun's "frame_index" timeline compatible with behaviour1k datasets (#3068) * fix(frame_index): making rerun's "frame_index" timeline compatible with behaviour1k datasets * fix(segfault risk): removing segfault risk by calling batch["index"] in the dataloader loop --- src/lerobot/scripts/lerobot_dataset_viz.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/lerobot/scripts/lerobot_dataset_viz.py b/src/lerobot/scripts/lerobot_dataset_viz.py index 29d64554f..c4b676c67 100644 --- a/src/lerobot/scripts/lerobot_dataset_viz.py +++ b/src/lerobot/scripts/lerobot_dataset_viz.py @@ -132,10 +132,13 @@ def visualize_dataset( logging.info("Logging to Rerun") + first_index = None for batch in tqdm.tqdm(dataloader, total=len(dataloader)): + if first_index is None: + first_index = batch["index"][0].item() # iterate over the batch for i in range(len(batch["index"])): - rr.set_time("frame_index", sequence=batch["frame_index"][i].item()) + rr.set_time("frame_index", sequence=batch["index"][i].item() - first_index) rr.set_time("timestamp", timestamp=batch["timestamp"][i].item()) # display each camera image From 63dca86df86dbca04378590dd6d8618332dae0bb Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Tue, 3 Mar 2026 15:40:46 +0100 Subject: [PATCH 081/131] fix(dataset edit tools): clarifying `root` argument usage + adding related features (#3049) * fix(root): adding proper support for the root and new_root arguments * feat(roots): adding a roots agrument for the merge operation * chore(clean): cleaning up code * chore(doctrings): updating doctrings with new features * fix(repo_id): setting repo_id to None when not needed * fix(roots/repo_ids): making mypy happy by using repo_ids and roots for merge operation * fix(path): fixing path related issues * fix(repo_id): fixing issues related to repo_id * chore(doctrings): updating docstrings + fix typo * chore(clean): cleaning code * fix(split new_repo_id): reverting new_repo_id addition for split operation * docs(dosctrings): completing docstrings * fix(repo_ids/roots): improving checks for repo_ids/roots lengths * fix(repo_ids): making repo_ids optional in MergeConfig but raise if not given * fix(docstrings): fixing docstrings for split operation * fix(hints): updating get_output_path hints to accept paths as strings too * fix(y/N prompts): removing y/N prompts in lerobot_edit_dataset * fix(merge repo_id): fixing merge operation to use new_repo_id instead of repo_id * fix(typo): fixing typo in doctrings --- src/lerobot/datasets/dataset_tools.py | 29 +-- src/lerobot/scripts/lerobot_edit_dataset.py | 201 ++++++++++++++------ tests/scripts/test_edit_dataset_parsing.py | 19 +- 3 files changed, 173 insertions(+), 76 deletions(-) diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index b62d7d959..c900d7479 100644 --- a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -89,8 +89,8 @@ def delete_episodes( Args: dataset: The source LeRobotDataset. episode_indices: List of episode indices to delete. - output_dir: Directory to save the new dataset. If None, uses default location. - repo_id: Repository ID for the new dataset. If None, appends "_modified" to original. + output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig. + repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig. """ if not episode_indices: raise ValueError("No episodes to delete") @@ -152,7 +152,7 @@ def split_dataset( dataset: The source LeRobotDataset to split. splits: Either a dict mapping split names to episode indices, or a dict mapping split names to fractions (must sum to <= 1.0). - output_dir: Base directory for output datasets. If None, uses default location. + output_dir: Root directory where the split datasets will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Examples: Split by specific episodes @@ -243,8 +243,8 @@ def merge_datasets( Args: datasets: List of LeRobotDatasets to merge. - output_repo_id: Repository ID for the merged dataset. - output_dir: Directory to save the merged dataset. If None, uses default location. + output_repo_id: Merged dataset identifier. + output_dir: Root directory where the merged dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/output_repo_id. """ if not datasets: raise ValueError("No datasets to merge") @@ -288,8 +288,8 @@ def modify_features( dataset: The source LeRobotDataset. add_features: Optional dict mapping feature names to (feature_values, feature_info) tuples. remove_features: Optional feature name(s) to remove. Can be a single string or list. - output_dir: Directory to save the new dataset. If None, uses default location. - repo_id: Repository ID for the new dataset. If None, appends "_modified" to original. + output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig. + repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig. Returns: New dataset with features modified. @@ -390,8 +390,8 @@ def add_features( Args: dataset: The source LeRobotDataset. features: Dictionary mapping feature names to (feature_values, feature_info) tuples. - output_dir: Directory to save the new dataset. If None, uses default location. - repo_id: Repository ID for the new dataset. If None, appends "_modified" to original. + output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig. + repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig. Returns: New dataset with all features added. @@ -427,8 +427,8 @@ def remove_feature( Args: dataset: The source LeRobotDataset. feature_names: Name(s) of features to remove. Can be a single string or list. - output_dir: Directory to save the new dataset. If None, uses default location. - repo_id: Repository ID for the new dataset. If None, appends "_modified" to original. + output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig. + repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig. Returns: New dataset with features removed. @@ -1529,7 +1529,7 @@ def modify_tasks( def convert_image_to_video_dataset( dataset: LeRobotDataset, - output_dir: Path, + output_dir: Path | None = None, repo_id: str | None = None, vcodec: str = "libsvtav1", pix_fmt: str = "yuv420p", @@ -1548,8 +1548,8 @@ def convert_image_to_video_dataset( Args: dataset: The source LeRobot dataset with images - output_dir: Directory to save the new video dataset - repo_id: Repository ID for the new dataset (default: original_id + "_video") + output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig. + repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig. vcodec: Video codec (default: libsvtav1) pix_fmt: Pixel format (default: yuv420p) g: Group of pictures size (default: 2) @@ -1600,6 +1600,7 @@ def convert_image_to_video_dataset( # Video info will be updated after episodes are encoded # Create new metadata for video dataset + output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id new_meta = LeRobotDatasetMetadata.create( repo_id=repo_id, fps=dataset.meta.fps, diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py index afdc95efd..49825317d 100644 --- a/src/lerobot/scripts/lerobot_edit_dataset.py +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -21,6 +21,9 @@ This script allows you to delete episodes, split datasets, merge datasets, remove features, modify tasks, and convert image datasets to video format. When new_repo_id is specified, creates a new dataset. +Path semantics (v2): --root and --new_root are exact dataset folders containing +meta/, data/, videos/. When omitted, defaults to $HF_LEROBOT_HOME/{repo_id}. + Usage Examples: Delete episodes 0, 2, and 5 from a dataset: @@ -29,16 +32,31 @@ Delete episodes 0, 2, and 5 from a dataset: --operation.type delete_episodes \ --operation.episode_indices "[0, 2, 5]" -Delete episodes and save to a new dataset: +Delete episodes from a local dataset at a specific path: lerobot-edit-dataset \ --repo_id lerobot/pusht \ - --new_repo_id lerobot/pusht_filtered \ + --root /path/to/pusht \ --operation.type delete_episodes \ --operation.episode_indices "[0, 2, 5]" -Split dataset by fractions: +Delete episodes and save to a new dataset at a specific path and with a new repo_id: lerobot-edit-dataset \ --repo_id lerobot/pusht \ + --new_repo_id lerobot/pusht_filtered \ + --new_root /path/to/pusht_filtered \ + --operation.type delete_episodes \ + --operation.episode_indices "[0, 2, 5]" + +Split dataset by fractions (pusht_train, pusht_val): + lerobot-edit-dataset \ + --repo_id lerobot/pusht \ + --operation.type split \ + --operation.splits '{"train": 0.8, "val": 0.2}' + +Split dataset by fractions and save split datasets to a specific folder (base_folder/train, base_folder/val): + lerobot-edit-dataset \ + --repo_id lerobot/pusht \ + --new_root /path/to/base_folder \ --operation.type split \ --operation.splits '{"train": 0.8, "val": 0.2}' @@ -56,15 +74,29 @@ Split into more than two splits: Merge multiple datasets: lerobot-edit-dataset \ - --repo_id lerobot/pusht_merged \ + --new_repo_id lerobot/pusht_merged \ --operation.type merge \ --operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']" +Merge multiple datasets to a specific output path: + lerobot-edit-dataset \ + --new_repo_id lerobot/pusht_merged \ + --new_root /path/to/pusht_merged \ + --operation.type merge \ + --operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']" + +Merge multiple datasets from a list of local dataset paths: + lerobot-edit-dataset \ + --new_repo_id lerobot/pusht_merged \ + --operation.type merge \ + --operation.repo_ids "['pusht_train', 'pusht_val']" \ + --operation.roots "['/path/to/pusht_train', '/path/to/pusht_val']" + Remove camera feature: lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type remove_feature \ - --operation.feature_names "['observation.images.top']" + --operation.feature_names "['observation.image']" Modify tasks - set a single task for all episodes (WARNING: modifies in-place): lerobot-edit-dataset \ @@ -88,8 +120,8 @@ Modify tasks - set default task with overrides for specific episodes (WARNING: m Convert image dataset to video format and save locally: lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ - --operation.type convert_image_to_video \ - --operation.output_dir /path/to/output/pusht_video + --new_root /path/to/output/pusht_video \ + --operation.type convert_image_to_video Convert image dataset to video format and save with new repo_id: lerobot-edit-dataset \ @@ -167,6 +199,7 @@ class SplitConfig(OperationConfig): @dataclass class MergeConfig(OperationConfig): repo_ids: list[str] | None = None + roots: list[str] | None = None @OperationConfig.register_subclass("remove_feature") @@ -200,36 +233,46 @@ class ConvertImageToVideoConfig(OperationConfig): @OperationConfig.register_subclass("info") @dataclass class InfoConfig(OperationConfig): - type: str = "info" show_features: bool = False @dataclass class EditDatasetConfig: - repo_id: str + # Operation configuration. operation: OperationConfig + # Input dataset identifier. Always required unless for Merge operation. + repo_id: str | None = None + # Root directory where the input dataset is stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. root: str | None = None + # Edited dataset identifier. When both new_repo_id (resp. new_root) and repo_id (resp. root) are identical, modifications are applied in-place and a backup of the original dataset is created. Required for Merge operation. new_repo_id: str | None = None + # Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/new_repo_id. For Split operation, this is the base directory for the split datasets. + new_root: str | None = None + # Upload dataset to Hugging Face hub. push_to_hub: bool = False -def get_output_path(repo_id: str, new_repo_id: str | None, root: Path | None) -> tuple[str, Path]: - if new_repo_id: - output_repo_id = new_repo_id - output_dir = root / new_repo_id if root else HF_LEROBOT_HOME / new_repo_id - else: - output_repo_id = repo_id - dataset_path = root / repo_id if root else HF_LEROBOT_HOME / repo_id - old_path = Path(str(dataset_path) + "_old") +def get_output_path( + repo_id: str, + new_repo_id: str | None, + root: Path | str | None, + new_root: Path | str | None, +) -> tuple[str, Path]: + input_path = Path(root) if root else HF_LEROBOT_HOME / repo_id - if dataset_path.exists(): - if old_path.exists(): - shutil.rmtree(old_path) - shutil.move(str(dataset_path), str(old_path)) + output_repo_id = new_repo_id if new_repo_id else repo_id + output_path = Path(new_root) if new_root else HF_LEROBOT_HOME / output_repo_id - output_dir = dataset_path + # In case of in-place modification, create a backup of the original dataset (if it exists) + if output_path == input_path: + backup_path = input_path.with_name(input_path.name + "_old") - return output_repo_id, output_dir + if input_path.exists(): + if backup_path.exists(): + shutil.rmtree(backup_path) + shutil.move(input_path, backup_path) + + return output_repo_id, output_path def handle_delete_episodes(cfg: EditDatasetConfig) -> None: @@ -241,11 +284,15 @@ def handle_delete_episodes(cfg: EditDatasetConfig) -> None: dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) output_repo_id, output_dir = get_output_path( - cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None + cfg.repo_id, + new_repo_id=cfg.new_repo_id, + root=cfg.root, + new_root=cfg.new_root, ) - if cfg.new_repo_id is None: - dataset.root = Path(str(dataset.root) + "_old") + # In case of in-place modification, make the dataset point to the backup directory + if output_dir == dataset.root: + dataset.root = dataset.root.with_name(dataset.root.name + "_old") logging.info(f"Deleting episodes {cfg.operation.episode_indices} from {cfg.repo_id}") new_dataset = delete_episodes( @@ -272,19 +319,27 @@ def handle_split(cfg: EditDatasetConfig) -> None: "splits dict must be specified with split names as keys and fractions/episode lists as values" ) + if cfg.new_repo_id is not None: + logging.warning( + "split uses the original dataset identifier --repo_id to generate split names. The --new_repo_id parameter is ignored." + ) + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) logging.info(f"Splitting dataset {cfg.repo_id} with splits: {cfg.operation.splits}") - split_datasets = split_dataset(dataset, splits=cfg.operation.splits) + split_datasets = split_dataset( + dataset, + splits=cfg.operation.splits, + output_dir=cfg.new_root, + ) for split_name, split_ds in split_datasets.items(): - split_repo_id = f"{cfg.repo_id}_{split_name}" logging.info( f"{split_name}: {split_ds.meta.total_episodes} episodes, {split_ds.meta.total_frames} frames" ) if cfg.push_to_hub: - logging.info(f"Pushing {split_name} split to hub as {split_repo_id}") + logging.info(f"Pushing {split_name} split to hub as {split_ds.repo_id}") LeRobotDataset(split_ds.repo_id, root=split_ds.root).push_to_hub() @@ -295,18 +350,29 @@ def handle_merge(cfg: EditDatasetConfig) -> None: if not cfg.operation.repo_ids: raise ValueError("repo_ids must be specified for merge operation") - if not cfg.repo_id: - raise ValueError("repo_id must be specified as the output repository for merged dataset") + if cfg.repo_id is not None or cfg.root is not None: + logging.warning( + "merge uses --new_repo_id and --new_root for the merged dataset. The --repo_id and --root parameters are ignored." + ) - logging.info(f"Loading {len(cfg.operation.repo_ids)} datasets to merge") - datasets = [LeRobotDataset(repo_id, root=cfg.root) for repo_id in cfg.operation.repo_ids] + if cfg.operation.roots: + if len(cfg.operation.roots) != len(cfg.operation.repo_ids): + raise ValueError("repo_ids and roots must have the same length for merge operation") + logging.info(f"Loading {len(cfg.operation.roots)} datasets to merge") + datasets = [ + LeRobotDataset(repo_id=repo_id, root=root) + for repo_id, root in zip(cfg.operation.repo_ids, cfg.operation.roots, strict=True) + ] + else: + logging.info(f"Loading {len(cfg.operation.repo_ids)} datasets to merge") + datasets = [LeRobotDataset(repo_id) for repo_id in cfg.operation.repo_ids] - output_dir = Path(cfg.root) / cfg.repo_id if cfg.root else HF_LEROBOT_HOME / cfg.repo_id + output_dir = Path(cfg.new_root) if cfg.new_root else HF_LEROBOT_HOME / cfg.new_repo_id - logging.info(f"Merging datasets into {cfg.repo_id}") + logging.info(f"Merging datasets into {cfg.new_repo_id}") merged_dataset = merge_datasets( datasets, - output_repo_id=cfg.repo_id, + output_repo_id=cfg.new_repo_id, output_dir=output_dir, ) @@ -316,7 +382,7 @@ def handle_merge(cfg: EditDatasetConfig) -> None: ) if cfg.push_to_hub: - logging.info(f"Pushing to hub as {cfg.repo_id}") + logging.info(f"Pushing to hub as {cfg.new_repo_id}") LeRobotDataset(merged_dataset.repo_id, root=output_dir).push_to_hub() @@ -329,11 +395,15 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None: dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) output_repo_id, output_dir = get_output_path( - cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None + cfg.repo_id, + new_repo_id=cfg.new_repo_id, + root=cfg.root, + new_root=cfg.new_root, ) - if cfg.new_repo_id is None: - dataset.root = Path(str(dataset.root) + "_old") + # In case of in-place modification, make the dataset point to the backup directory + if output_dir == dataset.root: + dataset.root = dataset.root.with_name(dataset.root.name + "_old") logging.info(f"Removing features {cfg.operation.feature_names} from {cfg.repo_id}") new_dataset = remove_feature( @@ -361,9 +431,10 @@ def handle_modify_tasks(cfg: EditDatasetConfig) -> None: if new_task is None and episode_tasks_raw is None: raise ValueError("Must specify at least one of new_task or episode_tasks for modify_tasks operation") - # Warn about in-place modification behavior - if cfg.new_repo_id is not None: - logging.warning("modify_tasks modifies datasets in-place. The --new_repo_id parameter is ignored.") + if cfg.new_repo_id is not None or cfg.new_root is not None: + logging.warning( + "modify_tasks modifies datasets in-place. The --new_repo_id and --new_root parameters are ignored." + ) dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) logging.warning(f"Modifying dataset in-place at {dataset.root}. Original data will be overwritten.") @@ -399,32 +470,30 @@ def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None: dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) # Determine output directory and repo_id - # Priority: 1) new_repo_id, 2) operation.output_dir, 3) auto-generated name + # Priority: 1) new_root, 2) new_repo_id, 3) operation.output_dir, 4) auto-generated name output_dir_config = getattr(cfg.operation, "output_dir", None) + if output_dir_config: + logging.warning( + "--operation.output_dir is deprecated and will be removed in future versions. " + "Please use --new_root instead." + ) - if cfg.new_repo_id: - # Use new_repo_id for both local storage and hub push + if cfg.new_root: + output_dir = Path(cfg.new_root) + output_repo_id = cfg.new_repo_id or f"{cfg.repo_id}_video" + logging.info(f"Saving to new_root: {output_dir} as {output_repo_id}") + elif cfg.new_repo_id: output_repo_id = cfg.new_repo_id - # Place new dataset as a sibling to the original dataset - # Get the parent of the actual dataset root (not cfg.root which might be the lerobot cache dir) - # Extract just the dataset name (after last slash) for the local directory - local_dir_name = cfg.new_repo_id.split("/")[-1] - output_dir = dataset.root.parent / local_dir_name + output_dir = HF_LEROBOT_HOME / cfg.new_repo_id logging.info(f"Saving to new dataset: {cfg.new_repo_id} at {output_dir}") elif output_dir_config: - # Use custom output directory for local-only storage output_dir = Path(output_dir_config) - # Extract repo name from output_dir for the dataset output_repo_id = output_dir.name - logging.info(f"Saving to local directory: {output_dir}") + logging.info(f"Saving to local directory: {output_dir} as {output_repo_id}") else: - # Auto-generate name: append "_video" to original repo_id output_repo_id = f"{cfg.repo_id}_video" - # Place new dataset as a sibling to the original dataset - # Extract just the dataset name (after last slash) for the local directory - local_dir_name = output_repo_id.split("/")[-1] - output_dir = dataset.root.parent / local_dir_name - logging.info(f"Saving to auto-generated location: {output_dir}") + output_dir = HF_LEROBOT_HOME / output_repo_id + logging.info(f"Saving to auto-generated location: {output_dir} as {output_repo_id}") logging.info(f"Converting dataset {cfg.repo_id} to video format") @@ -499,8 +568,20 @@ def handle_info(cfg: EditDatasetConfig): sys.stdout.write(f"{feature_dump_str}\n") +def _validate_config(cfg: EditDatasetConfig) -> None: + if isinstance(cfg.operation, MergeConfig): + if not cfg.new_repo_id: + raise ValueError("--new_repo_id is required for merge operation (the merged dataset identifier)") + else: + if not cfg.repo_id: + raise ValueError( + f"--repo_id is required for {cfg.operation.type} operation (the input dataset identifier)" + ) + + @parser.wrap() def edit_dataset(cfg: EditDatasetConfig) -> None: + _validate_config(cfg) operation_type = cfg.operation.type if operation_type == "delete_episodes": diff --git a/tests/scripts/test_edit_dataset_parsing.py b/tests/scripts/test_edit_dataset_parsing.py index 8800b92ee..4d758ae35 100644 --- a/tests/scripts/test_edit_dataset_parsing.py +++ b/tests/scripts/test_edit_dataset_parsing.py @@ -27,6 +27,7 @@ from lerobot.scripts.lerobot_edit_dataset import ( OperationConfig, RemoveFeatureConfig, SplitConfig, + _validate_config, ) @@ -51,11 +52,23 @@ class TestOperationTypeParsing: ], ) def test_operation_type_resolves_correct_class(self, type_name, expected_cls): - cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", type_name]) + cfg = parse_cfg( + ["--repo_id", "test/repo", "--new_repo_id", "test/merged", "--operation.type", type_name] + ) assert isinstance(cfg.operation, expected_cls), ( f"Expected {expected_cls.__name__}, got {type(cfg.operation).__name__}" ) + def test_merge_requires_new_repo_id(self): + cfg = parse_cfg(["--operation.type", "merge"]) + with pytest.raises(ValueError, match="--new_repo_id is required for merge"): + _validate_config(cfg) + + def test_non_merge_requires_repo_id(self): + cfg = parse_cfg(["--operation.type", "delete_episodes"]) + with pytest.raises(ValueError, match="--repo_id is required for delete_episodes"): + _validate_config(cfg) + @pytest.mark.parametrize( "type_name, expected_cls", [ @@ -69,6 +82,8 @@ class TestOperationTypeParsing: ], ) def test_get_choice_name_roundtrips(self, type_name, expected_cls): - cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", type_name]) + cfg = parse_cfg( + ["--repo_id", "test/repo", "--new_repo_id", "test/merged", "--operation.type", type_name] + ) resolved_name = OperationConfig.get_choice_name(type(cfg.operation)) assert resolved_name == type_name From 4303b3c9308091032f567ce282491308a8a1ecb1 Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Wed, 4 Mar 2026 11:11:21 +0100 Subject: [PATCH 082/131] chore(root): fixing `root` semantics in convert_dataset script (#3073) * fix(root): fixing root semantincs in convert_dataset script * fix(\): fixing command syntax in dataset conversion script Signed-off-by: Caroline Pascal --------- Signed-off-by: Caroline Pascal --- src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py index 7be37a1b1..2a69945e1 100644 --- a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py +++ b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py @@ -36,8 +36,11 @@ Convert a local dataset (works in place): ```bash python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \ --repo-id=lerobot/pusht \ - --root=/path/to/local/dataset/directory + --root=/path/to/local/dataset/directory \ --push-to-hub=false + +N.B. Path semantics (v2): --root is the exact dataset folder containing +meta/, data/, videos/. When omitted, defaults to $HF_LEROBOT_HOME/{repo_id}. ``` """ @@ -469,7 +472,7 @@ def convert_dataset( # Set root based on whether local dataset path is provided use_local_dataset = False - root = HF_LEROBOT_HOME / repo_id if root is None else Path(root) / repo_id + root = HF_LEROBOT_HOME / repo_id if root is None else Path(root) if root.exists(): validate_local_dataset_version(root) use_local_dataset = True @@ -553,7 +556,7 @@ if __name__ == "__main__": "--root", type=str, default=None, - help="Local directory to use for downloading/writing the dataset.", + help="Local directory to use for downloading/writing the dataset. Defaults to $HF_LEROBOT_HOME/repo_id.", ) parser.add_argument( "--push-to-hub", From 96b7c212c44ed6c96518e7aa8d759bff98a77e5f Mon Sep 17 00:00:00 2001 From: Maxime Ellerbach Date: Wed, 4 Mar 2026 15:08:49 +0100 Subject: [PATCH 083/131] chore(docs): updating deprecated huggingface-cli to hf (#3071) * chore(docs): updating deprecated huggingface-cli to hf * small typo in my-org --- docs/source/earthrover_mini_plus.mdx | 4 ++-- docs/source/envhub.mdx | 4 ++-- docs/source/il_robots.mdx | 8 ++++---- docs/source/lekiwi.mdx | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/docs/source/earthrover_mini_plus.mdx b/docs/source/earthrover_mini_plus.mdx index cfc3a2eef..37986a7a2 100644 --- a/docs/source/earthrover_mini_plus.mdx +++ b/docs/source/earthrover_mini_plus.mdx @@ -170,13 +170,13 @@ Once you can drive the robot well, you can start recording data to train AI mode We use Hugging Face to store your data online. First, log in with your token from [Hugging Face settings](https://huggingface.co/settings/tokens): ```bash -huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential +hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential ``` Store your Hugging Face username: ```bash -HF_USER=$(huggingface-cli whoami | head -n 1) +HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}') echo $HF_USER ``` diff --git a/docs/source/envhub.mdx b/docs/source/envhub.mdx index df103d0dd..36c08a8b3 100644 --- a/docs/source/envhub.mdx +++ b/docs/source/envhub.mdx @@ -155,10 +155,10 @@ Upload your repository to Hugging Face: pip install huggingface_hub # Login to Hugging Face -huggingface-cli login +hf auth login # Create a new repository -huggingface-cli repo create my-custom-env --type space --org my-org +hf repo create my-org/my-custom-env # Initialize git and push git init diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index bad88f88e..e49132a8e 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -159,7 +159,7 @@ We use the Hugging Face hub features for uploading your dataset. If you haven't Add your token to the CLI by running this command: ```bash -huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential +hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential ``` Then store your Hugging Face repository name in a variable: @@ -327,7 +327,7 @@ You can look for other LeRobot datasets on the hub by searching for `LeRobot` [t You can also push your local dataset to the Hub manually, running: ```bash -huggingface-cli upload ${HF_USER}/record-test ~/.cache/huggingface/lerobot/{repo-id} --repo-type dataset +hf upload ${HF_USER}/record-test ~/.cache/huggingface/lerobot/{repo-id} --repo-type dataset ``` #### Record function @@ -491,7 +491,7 @@ If your local computer doesn't have a powerful GPU you could utilize Google Cola Once training is done, upload the latest checkpoint with: ```bash -huggingface-cli upload ${HF_USER}/act_so101_test \ +hf upload ${HF_USER}/act_so101_test \ outputs/train/act_so101_test/checkpoints/last/pretrained_model ``` @@ -499,7 +499,7 @@ You can also upload intermediate checkpoints with: ```bash CKPT=010000 -huggingface-cli upload ${HF_USER}/act_so101_test${CKPT} \ +hf upload ${HF_USER}/act_so101_test${CKPT} \ outputs/train/act_so101_test/checkpoints/${CKPT}/pretrained_model ``` diff --git a/docs/source/lekiwi.mdx b/docs/source/lekiwi.mdx index b339225d8..7e7c1a680 100644 --- a/docs/source/lekiwi.mdx +++ b/docs/source/lekiwi.mdx @@ -279,13 +279,13 @@ We use the Hugging Face hub features for uploading your dataset. If you haven't Add your token to the CLI by running this command: ```bash -huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential +hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential ``` Then store your Hugging Face repository name in a variable: ```bash -HF_USER=$(huggingface-cli whoami | head -n 1) +HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}') echo $HF_USER ``` From 0d1be72dc8309b8841c363cdd322174ed13a7c9f Mon Sep 17 00:00:00 2001 From: Paul Crook <37202747+skiingpacman@users.noreply.github.com> Date: Thu, 5 Mar 2026 00:53:34 +0900 Subject: [PATCH 084/131] Fixing metadata indexing when writing new Parquet file (#2941) * Fixing metadata indexing when writing new Parquet file Summary: - addressing this issue: https://github.com/huggingface/lerobot/issues/2401 - vibe-coded bugfix by Claude Sonnet 4.5 * Backing out changes to convert_videos_of_camera * Addressing Ruff pre-commit complaint Summary: - addressing "SIM113 Use `enumerate()` for index variable `ep_idx` in `for` loop" --------- Co-authored-by: Paul <238953601+pac-robotics@users.noreply.github.com> --- .../v30/convert_dataset_v21_to_v30.py | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py index 2a69945e1..5362c52f4 100644 --- a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py +++ b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py @@ -204,7 +204,6 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int): image_keys = get_image_keys(root) - ep_idx = 0 chunk_idx = 0 file_idx = 0 size_in_mb = 0 @@ -214,9 +213,24 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int): logging.info(f"Converting data files from {len(ep_paths)} episodes") - for ep_path in tqdm.tqdm(ep_paths, desc="convert data files"): + for ep_idx, ep_path in enumerate(tqdm.tqdm(ep_paths, desc="convert data files")): ep_size_in_mb = get_parquet_file_size_in_mb(ep_path) ep_num_frames = get_parquet_num_frames(ep_path) + + # Check if we need to start a new file BEFORE creating metadata + if size_in_mb + ep_size_in_mb >= data_file_size_in_mb and len(paths_to_cat) > 0: + # Write the accumulated data files + concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys) + + # Move to next file + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE) + + # Reset for the next file + size_in_mb = 0 + num_frames += ep_num_frames # Still need to accumulate total frames + paths_to_cat = [] + + # Now create metadata with correct chunk/file indices ep_metadata = { "episode_index": ep_idx, "data/chunk_index": chunk_idx, @@ -227,20 +241,7 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int): size_in_mb += ep_size_in_mb num_frames += ep_num_frames episodes_metadata.append(ep_metadata) - ep_idx += 1 - - if size_in_mb < data_file_size_in_mb: - paths_to_cat.append(ep_path) - continue - - if paths_to_cat: - concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys) - - # Reset for the next file - size_in_mb = ep_size_in_mb - paths_to_cat = [ep_path] - - chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE) + paths_to_cat.append(ep_path) # Write remaining data if any if paths_to_cat: From cbc8bfb2e618a16b7d1cb46bdc0f8ac6073c1b29 Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Wed, 4 Mar 2026 17:59:03 +0100 Subject: [PATCH 085/131] chore(docstrings): updating v2.1-v3.0 conversion script docstrings to match the new task label (#3077) * chore(docstrings): updating v2.1-v3.0 conversion script docstrings to match the new task label * chore(task): renamming the default index label in the tasks DataFrame to task * Revert "chore(docstrings): updating v2.1-v3.0 conversion script docstrings to match the new task label" This reverts commit f55de3255278f23f18b5d955565f6768d094951d. * chore(docstrings): updating docstrings to match dataset v3.0 architecture * chore(format): formatting code --- src/lerobot/datasets/aggregate.py | 4 +++- src/lerobot/datasets/dataset_tools.py | 4 +++- src/lerobot/datasets/lerobot_dataset.py | 2 +- src/lerobot/datasets/utils.py | 1 + .../datasets/v30/convert_dataset_v21_to_v30.py | 11 ++++++----- tests/fixtures/dataset_factories.py | 2 +- 6 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py index 7020545d2..b32116233 100644 --- a/src/lerobot/datasets/aggregate.py +++ b/src/lerobot/datasets/aggregate.py @@ -289,7 +289,9 @@ def aggregate_datasets( logging.info("Find all tasks") unique_tasks = pd.concat([m.tasks for m in all_metadata]).index.unique() - dst_meta.tasks = pd.DataFrame({"task_index": range(len(unique_tasks))}, index=unique_tasks) + dst_meta.tasks = pd.DataFrame( + {"task_index": range(len(unique_tasks))}, index=pd.Index(unique_tasks, name="task") + ) meta_idx = {"chunk": 0, "file": 0} data_idx = {"chunk": 0, "file": 0} diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index c900d7479..546b3d67f 100644 --- a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -1475,7 +1475,9 @@ def modify_tasks( # Collect all unique tasks and create new task mapping unique_tasks = sorted(set(episode_to_task.values())) - new_task_df = pd.DataFrame({"task_index": list(range(len(unique_tasks)))}, index=unique_tasks) + new_task_df = pd.DataFrame( + {"task_index": list(range(len(unique_tasks)))}, index=pd.Index(unique_tasks, name="task") + ) task_to_index = {task: idx for idx, task in enumerate(unique_tasks)} logging.info(f"Modifying tasks in {dataset.repo_id}") diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 76d44de07..26f0c769c 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -314,7 +314,7 @@ class LeRobotDatasetMetadata: if self.tasks is None: new_tasks = tasks task_indices = range(len(tasks)) - self.tasks = pd.DataFrame({"task_index": task_indices}, index=tasks) + self.tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(tasks, name="task")) else: new_tasks = [task for task in tasks if task not in self.tasks.index] new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks)) diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index da186bf30..a56740191 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -341,6 +341,7 @@ def write_tasks(tasks: pandas.DataFrame, local_dir: Path) -> None: def load_tasks(local_dir: Path) -> pandas.DataFrame: tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH) + tasks.index.name = "task" return tasks diff --git a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py index 5362c52f4..3ae9093b9 100644 --- a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py +++ b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py @@ -108,7 +108,7 @@ episodes.jsonl {"episode_index": 1, "tasks": ["Put the blue block in the green bowl"], "length": 266} NEW -meta/episodes/chunk-000/episodes_000.parquet +meta/episodes/chunk-000/file_000.parquet episode_index | video_chunk_index | video_file_index | data_chunk_index | data_file_index | tasks | length ------------------------- OLD @@ -116,15 +116,16 @@ tasks.jsonl {"task_index": 1, "task": "Put the blue block in the green bowl"} NEW -meta/tasks/chunk-000/file_000.parquet +meta/tasks.parquet task_index | task ------------------------- OLD episodes_stats.jsonl +{"episode_index": 1, "stats": {"feature_name": {"min": ..., "max": ..., "mean": ..., "std": ..., "count": ...}}} NEW -meta/episodes_stats/chunk-000/file_000.parquet -episode_index | mean | std | min | max +meta/episodes/chunk-000/file_000.parquet +episode_index | feature_name/min | feature_name/max | feature_name/mean | feature_name/std | feature_name/count ------------------------- UPDATE meta/info.json @@ -173,7 +174,7 @@ def convert_tasks(root, new_root): tasks, _ = legacy_load_tasks(root) task_indices = tasks.keys() task_strings = tasks.values() - df_tasks = pd.DataFrame({"task_index": task_indices}, index=task_strings) + df_tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(task_strings, name="task")) write_tasks(df_tasks, new_root) diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index c33fdcb72..f8dd01fec 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -222,7 +222,7 @@ def tasks_factory(): def _create_tasks(total_tasks: int = 3) -> pd.DataFrame: ids = list(range(total_tasks)) tasks = [f"Perform action {i}." for i in ids] - df = pd.DataFrame({"task_index": ids}, index=tasks) + df = pd.DataFrame({"task_index": ids}, index=pd.Index(tasks, name="task")) return df return _create_tasks From f0d2b37bebddf7e8852cb624712da0ad564601b6 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 5 Mar 2026 09:25:26 +0100 Subject: [PATCH 086/131] chore(dependencies): bump transformers v5 (#2964) * chore(dependencies): upgrade transformers + hggingface-hub + peft + scipy * chore(dependencies): bump pi0 family to transformers v5 * chore(dependencies): bump wall x to transformers v5 * chore(dependencies): bump gr00t to transformers v5 * chore(style): fix pre-commit * fix(policy): xvla forced_bos_token missing * test(rl): skip ci tests for resnet10 * Fix: full pi models support for transformer v5 (#2967) * fix(pi): remove loss truncation * fix(pi): remove state padding before tokenization * fix(pi): fix image padding value * fix from_pretrain * add transformer v5 changes * remove reference * more fixes * make it work * add support for rest of pi family * add pifast work * more changes * more changes * more cleanup * fix torch params * dtype fix * torch compile * embed mismatch fix * revert groot * more nit fixes * remove unused classes * more fixes * revert * nit * torch dtype warning fix * but back dynamic renaming * add tie embedding --------- Co-authored-by: Yufei Sun * chore: fix XVLA in transformers v5 (#3006) * test(policies): enable wall x CI testing * style(test): pre-commit check * style(test): pre-commit * fix wall x for transformer v5 (#3008) * tv5 fix * various wall x fixes * Delete tests/policies/pi0_pi05/print_pi05_output_logits.py Signed-off-by: Jade Choghari * sync modeling_florence2.py with chore/bump_transformers_v5 * more * more fixes * more * remove comment * more --------- Signed-off-by: Jade Choghari * chore(dependencies): adjust dependencies versioning after transformers v5 (#3034) * chore(dependecies): adjust dependecies versioning after transformers v5 * fix(policies): remove deprecated input_embeds * fix(policies): dict _tied_weights_keys * chore(depedencies): common qwen-vl-utils * chore(dependencies): bump transformers to 5.2 * Fix policy testing for tv5 (#3032) * fix ci logger * other fix * fix mypy * change logits to torch2.10 * skip wallx| * remove logging --------- Co-authored-by: Steven Palma * feat(ci): log into HF to unblock some CI tests (#3007) * feat(ci): log into HF to unblock some CI tests * chore(ci): change hf call + secret name * fix(ci): temp fix for pi0 rtc test * test(policies): require_cuda for unblocked tests * test(policies): require_cuda wall_x * fic(tests): require_cuda outter most for pi0 * fix(test): return instead of yield --------- Signed-off-by: Steven Palma * style(test): fix pre-commit * chore(deps): upgrade transformers (#3050) * chore(test): use lerobot model * fix(policies): change default action tokenizer for wall x * sample on cpu * Revert "Merge branch 'chore/bump_transformers_v5' of https://github.com/huggingface/lerobot into chore/bump_transformers_v5" This reverts commit d9b76755f7ec640cd6d52d29a7a3c09b815ef28c, reversing changes made to 89359cb0b678a6fe4867457f943d8b0b0de935f6. * Reapply "Merge branch 'chore/bump_transformers_v5' of https://github.com/huggingface/lerobot into chore/bump_transformers_v5" This reverts commit c9914db78b05653e885de15b5992b69fc701a0c2. --------- Signed-off-by: Jade Choghari Signed-off-by: Steven Palma Co-authored-by: Jade Choghari Co-authored-by: Yufei Sun Co-authored-by: Pepijn --- .github/workflows/fast_tests.yml | 6 + .github/workflows/full_tests.yml | 11 + docs/source/pi0fast.mdx | 20 +- pyproject.toml | 111 +----- .../image_processing_eagle2_5_vl_fast.py | 4 +- src/lerobot/policies/pi0/modeling_pi0.py | 119 +++--- src/lerobot/policies/pi05/modeling_pi05.py | 128 +++--- src/lerobot/policies/pi05/processor_pi05.py | 4 - .../pi0_fast/configuration_pi0_fast.py | 2 +- .../policies/pi0_fast/modeling_pi0_fast.py | 83 ++-- .../policies/pi0_fast/processor_pi0_fast.py | 4 - src/lerobot/policies/pi_gemma.py | 363 ++++++++++++++++++ .../reward_model/configuration_classifier.py | 2 +- .../policies/wall_x/configuration_wall_x.py | 2 +- .../policies/wall_x/modeling_wall_x.py | 14 +- .../qwen_model/configuration_qwen2_5_vl.py | 2 + .../wall_x/qwen_model/qwen2_5_vl_moe.py | 35 +- src/lerobot/policies/wall_x/utils.py | 4 +- .../policies/xvla/configuration_florence2.py | 2 + .../policies/xvla/modeling_florence2.py | 19 +- src/lerobot/processor/tokenizer_processor.py | 2 +- .../scripts/lerobot_train_tokenizer.py | 2 +- .../hilserl/test_modeling_classifier.py | 13 + .../test_pi0_fast_original_vs_lerobot.py | 23 +- tests/policies/pi0_pi05/test_pi0.py | 9 - tests/policies/pi0_pi05/test_pi05.py | 12 +- tests/policies/pi0_pi05/test_pi05_rtc.py | 3 +- tests/policies/pi0_pi05/test_pi0_rtc.py | 4 +- tests/policies/test_sac_policy.py | 3 + tests/policies/wall_x/test_wallx.py | 16 +- 30 files changed, 694 insertions(+), 328 deletions(-) create mode 100644 src/lerobot/policies/pi_gemma.py diff --git a/.github/workflows/fast_tests.yml b/.github/workflows/fast_tests.yml index 10ec91199..27a4043e7 100644 --- a/.github/workflows/fast_tests.yml +++ b/.github/workflows/fast_tests.yml @@ -61,6 +61,7 @@ jobs: MUJOCO_GL: egl HF_HOME: /mnt/cache/.cache/huggingface HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot + HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }} steps: - uses: actions/checkout@v6 with: @@ -89,5 +90,10 @@ jobs: - name: Install lerobot with test extras run: uv sync --extra "test" + - name: Login to Hugging Face + run: | + uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential + uv run hf auth whoami + - name: Run pytest run: uv run pytest tests -vv --maxfail=10 diff --git a/.github/workflows/full_tests.yml b/.github/workflows/full_tests.yml index d23b99de0..8dd1fcb1c 100644 --- a/.github/workflows/full_tests.yml +++ b/.github/workflows/full_tests.yml @@ -60,6 +60,7 @@ jobs: MUJOCO_GL: egl HF_HOME: /mnt/cache/.cache/huggingface HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot + HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }} steps: - uses: actions/checkout@v6 with: @@ -87,6 +88,11 @@ jobs: - name: Install lerobot with all extras run: uv sync --extra all # TODO(Steven): Make flash-attn optional + - name: Login to Hugging Face + run: | + uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential + uv run hf auth whoami + - name: Run pytest (all extras) run: uv run pytest tests -vv --maxfail=10 @@ -162,6 +168,7 @@ jobs: HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot TORCH_HOME: /home/user_lerobot/.cache/torch TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton + HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }} container: image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images] options: --gpus all --shm-size "16gb" @@ -173,6 +180,10 @@ jobs: shell: bash working-directory: /lerobot steps: + - name: Login to Hugging Face + run: | + hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential + hf auth whoami - name: Fix ptxas permissions run: chmod +x /lerobot/.venv/lib/python3.10/site-packages/triton/backends/nvidia/bin/ptxas - name: Run pytest on GPU diff --git a/docs/source/pi0fast.mdx b/docs/source/pi0fast.mdx index c4230fa79..85d975924 100644 --- a/docs/source/pi0fast.mdx +++ b/docs/source/pi0fast.mdx @@ -52,7 +52,7 @@ This approach can transform **any existing VLM** into a VLA by training it to pr You have two options for the FAST tokenizer: -1. **Use the pre-trained tokenizer**: The `physical-intelligence/fast` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer. +1. **Use the pre-trained tokenizer**: The `lerobot/fast-action-tokenizer` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer. 2. **Train your own tokenizer**: For maximum performance on your specific dataset, you can finetune the tokenizer on your own data. @@ -114,15 +114,15 @@ lerobot-train \ ### Key Training Parameters -| Parameter | Description | Default | -| -------------------------------------- | -------------------------------------------------- | ---------------------------- | -| `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` | -| `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` | -| `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` | -| `--policy.n_action_steps` | Number of action steps to execute | `50` | -| `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` | -| `--policy.action_tokenizer_name` | FAST tokenizer to use | `physical-intelligence/fast` | -| `--policy.compile_model=true` | Enable torch.compile for faster training | `false` | +| Parameter | Description | Default | +| -------------------------------------- | -------------------------------------------------- | ------------------------------- | +| `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` | +| `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` | +| `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` | +| `--policy.n_action_steps` | Number of action steps to execute | `50` | +| `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` | +| `--policy.action_tokenizer_name` | FAST tokenizer to use | `lerobot/fast-action-tokenizer` | +| `--policy.compile_model=true` | Enable torch.compile for faster training | `false` | ## Inference diff --git a/pyproject.toml b/pyproject.toml index f4fb7d249..f86184900 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ dependencies = [ # Hugging Face dependencies "datasets>=4.0.0,<5.0.0", "diffusers>=0.27.2,<0.36.0", - "huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0", + "huggingface-hub[cli]>=1.0.0,<2.0.0", "accelerate>=1.10.0,<2.0.0", # Core dependencies @@ -96,9 +96,12 @@ dependencies = [ # Common pygame-dep = ["pygame>=2.5.1,<2.7.0"] placo-dep = ["placo>=0.9.6,<0.10.0"] -transformers-dep = ["transformers>=4.57.1,<5.0.0"] +transformers-dep = ["transformers>=5.3.0,<6.0.0"] grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"] can-dep = ["python-can>=4.2.0,<5.0.0"] +peft-dep = ["peft>=0.18.0,<1.0.0"] +scipy-dep = ["scipy>=1.14.0,<2.0.0"] +qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"] # Motors feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"] @@ -129,17 +132,17 @@ phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"] # Policies wallx = [ - "transformers==4.49.0", - "peft==0.17.1", - "scipy==1.15.3", - "torchdiffeq==0.2.5", - "qwen_vl_utils==0.0.11" + "lerobot[transformers-dep]", + "lerobot[peft]", + "lerobot[scipy-dep]", + "torchdiffeq>=0.2.4,<0.3.0", + "lerobot[qwen-vl-utils-dep]", ] -pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi", "scipy>=1.10.1,<1.15"] +pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"] smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"] groot = [ "lerobot[transformers-dep]", - "peft>=0.13.0,<1.0.0", + "lerobot[peft]", "dm-tree>=0.1.8,<1.0.0", "timm>=1.0.0,<1.1.0", "safetensors>=0.4.3,<1.0.0", @@ -148,13 +151,13 @@ groot = [ "ninja>=1.11.1,<2.0.0", "flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'" ] -sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14,<0.1.0"] +sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "lerobot[qwen-vl-utils-dep]"] xvla = ["lerobot[transformers-dep]"] hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] # Features async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"] -peft = ["lerobot[transformers-dep]", "peft>=0.18.0,<1.0.0"] +peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"] # Development dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"] @@ -176,8 +179,8 @@ all = [ "lerobot[reachy2]", "lerobot[kinematics]", "lerobot[intelrealsense]", - # "lerobot[wallx]", - # "lerobot[pi]", TODO(Pepijn): Update pi to transformers v5 + "lerobot[wallx]", + "lerobot[pi]", "lerobot[smolvla]", # "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn "lerobot[xvla]", @@ -397,85 +400,3 @@ ignore_errors = false # [[tool.mypy.overrides]] # module = "lerobot.scripts.*" # ignore_errors = false - -[tool.uv] -# wallx requires transformers==4.49.0 which conflicts with other extras that need >=4.53.0 -conflicts = [ - [ - { extra = "wallx" }, - { extra = "transformers-dep" }, - ], - [ - { extra = "wallx" }, - { extra = "pi" }, - ], - [ - { extra = "wallx" }, - { extra = "smolvla" }, - ], - [ - { extra = "wallx" }, - { extra = "groot" }, - ], - [ - { extra = "wallx" }, - { extra = "xvla" }, - ], - [ - { extra = "wallx" }, - { extra = "sarm" }, - ], - [ - { extra = "wallx" }, - { extra = "hilserl" }, - ], - [ - { extra = "wallx" }, - { extra = "libero" }, - ], - [ - { extra = "wallx" }, - { extra = "peft" }, - ], - [ - { extra = "wallx" }, - { extra = "all" }, - ], - # pi uses custom branch which conflicts with transformers-dep - [ - { extra = "pi" }, - { extra = "transformers-dep" }, - ], - [ - { extra = "pi" }, - { extra = "smolvla" }, - ], - [ - { extra = "pi" }, - { extra = "groot" }, - ], - [ - { extra = "pi" }, - { extra = "xvla" }, - ], - [ - { extra = "pi" }, - { extra = "sarm" }, - ], - [ - { extra = "pi" }, - { extra = "hilserl" }, - ], - [ - { extra = "pi" }, - { extra = "libero" }, - ], - [ - { extra = "pi" }, - { extra = "peft" }, - ], - [ - { extra = "pi" }, - { extra = "all" }, - ], -] diff --git a/src/lerobot/policies/groot/eagle2_hg_model/image_processing_eagle2_5_vl_fast.py b/src/lerobot/policies/groot/eagle2_hg_model/image_processing_eagle2_5_vl_fast.py index 6b4f6d7ac..e01b9b839 100644 --- a/src/lerobot/policies/groot/eagle2_hg_model/image_processing_eagle2_5_vl_fast.py +++ b/src/lerobot/policies/groot/eagle2_hg_model/image_processing_eagle2_5_vl_fast.py @@ -14,7 +14,7 @@ from transformers.image_processing_utils import ( ) from transformers.image_processing_utils_fast import ( BaseImageProcessorFast, - DefaultFastImageProcessorKwargs, + ImagesKwargs, group_images_by_shape, reorder_images, ) @@ -77,7 +77,7 @@ def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> tor return img[:, top:bottom, left:right] -class Eagle25VLFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): +class Eagle25VLFastImageProcessorKwargs(ImagesKwargs): max_dynamic_tiles: int | None min_dynamic_tiles: int | None use_thumbnail: bool | None diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 58b5dc07b..2f77e9517 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -15,6 +15,7 @@ # limitations under the License. import builtins +import copy import logging import math from collections import deque @@ -32,13 +33,21 @@ from lerobot.utils.import_utils import _transformers_available if TYPE_CHECKING or _transformers_available: from transformers.models.auto import CONFIG_MAPPING from transformers.models.gemma import modeling_gemma - from transformers.models.gemma.modeling_gemma import GemmaForCausalLM - from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration + + from lerobot.policies.pi_gemma import ( + PaliGemmaForConditionalGenerationWithPiGemma, + PiGemmaForCausalLM, + _gated_residual, + layernorm_forward, + ) else: CONFIG_MAPPING = None modeling_gemma = None - GemmaForCausalLM = None - PaliGemmaForConditionalGeneration = None + PiGemmaForCausalLM = None + _gated_residual = None + layernorm_forward = None + PaliGemmaForConditionalGenerationWithPiGemma = None + from lerobot.configs.policies import PreTrainedConfig from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config @@ -191,7 +200,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) if images.dtype == torch.uint8: resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8) elif images.dtype == torch.float32: - resized_images = resized_images.clamp(-1.0, 1.0) + resized_images = resized_images.clamp(0.0, 1.0) else: raise ValueError(f"Unsupported image dtype: {images.dtype}") @@ -202,7 +211,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) pad_w1 = pad_w0 + remainder_w # Pad - constant_value = 0 if images.dtype == torch.uint8 else -1.0 + constant_value = 0 if images.dtype == torch.uint8 else 0.0 padded_images = F.pad( resized_images, (pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom @@ -221,14 +230,14 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) def compute_layer_complete( layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert ): - models = [paligemma.language_model, gemma_expert.model] + models = [paligemma.model.language_model, gemma_expert.model] query_states = [] key_states = [] value_states = [] gates = [] for i, hidden_states in enumerate(inputs_embeds): layer = models[i].layers[layer_idx] - hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901 + hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i]) gates.append(gate) input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) @@ -254,10 +263,10 @@ def compute_layer_complete( query_states, key_states, cos, sin, unsqueeze_dim=1 ) batch_size = query_states.shape[0] - scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling + scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling # Attention computation att_output, _ = modeling_gemma.eager_attention_forward( - paligemma.language_model.layers[layer_idx].self_attn, + paligemma.model.language_model.layers[layer_idx].self_attn, query_states, key_states, value_states, @@ -265,7 +274,7 @@ def compute_layer_complete( scaling, ) # Get head_dim from the current layer, not from the model - head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim + head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim) # Process layer outputs outputs_embeds = [] @@ -277,15 +286,15 @@ def compute_layer_complete( att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos]) # first residual - out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001 + out_emb = _gated_residual(hidden_states, out_emb, gates[i]) after_first_residual = out_emb.clone() - out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i]) + out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i]) # Convert to bfloat16 if the next layer (mlp) uses bfloat16 if layer.mlp.up_proj.weight.dtype == torch.bfloat16: out_emb = out_emb.to(dtype=torch.bfloat16) out_emb = layer.mlp(out_emb) # second residual - out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001 + out_emb = _gated_residual(after_first_residual, out_emb, gate) outputs_embeds.append(out_emb) start_pos = end_pos return outputs_embeds @@ -358,7 +367,7 @@ class PaliGemmaWithExpertModel( vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh" - vlm_config_hf.text_config.torch_dtype = "float32" + vlm_config_hf.text_config.dtype = "float32" vlm_config_hf.text_config.vocab_size = 257152 vlm_config_hf.text_config.use_adarms = use_adarms[0] vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None @@ -366,7 +375,7 @@ class PaliGemmaWithExpertModel( vlm_config_hf.vision_config.intermediate_size = 4304 vlm_config_hf.vision_config.projection_dim = 2048 vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" - vlm_config_hf.vision_config.torch_dtype = "float32" + vlm_config_hf.vision_config.dtype = "float32" action_expert_config_hf = CONFIG_MAPPING["gemma"]( head_dim=action_expert_config.head_dim, @@ -377,13 +386,13 @@ class PaliGemmaWithExpertModel( num_key_value_heads=action_expert_config.num_kv_heads, vocab_size=257152, hidden_activation="gelu_pytorch_tanh", - torch_dtype="float32", + dtype="float32", use_adarms=use_adarms[1], adarms_cond_dim=action_expert_config.width if use_adarms[1] else None, ) - self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf) - self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf) + self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf) + self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf) self.gemma_expert.model.embed_tokens = None self.to_bfloat16_for_selected_params(precision) @@ -398,10 +407,11 @@ class PaliGemmaWithExpertModel( else: raise ValueError(f"Invalid precision: {precision}") + # Keep full vision path in float32 so we never toggle (toggle causes optimizer + # "same dtype" error). Align with PI05. params_to_keep_float32 = [ - "vision_tower.vision_model.embeddings.patch_embedding.weight", - "vision_tower.vision_model.embeddings.patch_embedding.bias", - "vision_tower.vision_model.embeddings.position_embedding.weight", + "vision_tower", + "multi_modal_projector", "input_layernorm", "post_attention_layernorm", "model.norm", @@ -413,8 +423,8 @@ class PaliGemmaWithExpertModel( def _set_requires_grad(self): if self.freeze_vision_encoder: - self.paligemma.vision_tower.eval() - for param in self.paligemma.vision_tower.parameters(): + self.paligemma.model.vision_tower.eval() + for param in self.paligemma.model.vision_tower.parameters(): param.requires_grad = False if self.train_expert_only: self.paligemma.eval() @@ -424,15 +434,23 @@ class PaliGemmaWithExpertModel( def train(self, mode: bool = True): super().train(mode) if self.freeze_vision_encoder: - self.paligemma.vision_tower.eval() + self.paligemma.model.vision_tower.eval() if self.train_expert_only: self.paligemma.eval() def embed_image(self, image: torch.Tensor): - return self.paligemma.model.get_image_features(image) + # Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). Align with PI05. + out_dtype = image.dtype + if image.dtype != torch.float32: + image = image.to(torch.float32) + image_outputs = self.paligemma.model.get_image_features(image) + features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5 + if features.dtype != out_dtype: + features = features.to(out_dtype) + return features def embed_language_tokens(self, tokens: torch.Tensor): - return self.paligemma.language_model.embed_tokens(tokens) + return self.paligemma.model.language_model.embed_tokens(tokens) def forward( self, @@ -446,7 +464,7 @@ class PaliGemmaWithExpertModel( if adarms_cond is None: adarms_cond = [None, None] if inputs_embeds[1] is None: - prefix_output = self.paligemma.language_model.forward( + prefix_output = self.paligemma.model.language_model.forward( inputs_embeds=inputs_embeds[0], attention_mask=attention_mask, position_ids=position_ids, @@ -470,7 +488,7 @@ class PaliGemmaWithExpertModel( prefix_output = None prefix_past_key_values = None else: - models = [self.paligemma.language_model, self.gemma_expert.model] + models = [self.paligemma.model.language_model, self.gemma_expert.model] num_layers = self.paligemma.config.text_config.num_hidden_layers # Check if gradient checkpointing is enabled for any of the models @@ -510,7 +528,7 @@ class PaliGemmaWithExpertModel( def compute_final_norms(inputs_embeds, adarms_cond): outputs_embeds = [] for i, hidden_states in enumerate(inputs_embeds): - out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i]) + out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i]) outputs_embeds.append(out_emb) return outputs_embeds @@ -576,29 +594,19 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` # Also compile the main forward pass used during training self.forward = torch.compile(self.forward, mode=config.compile_mode) - msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues""" - - try: - from transformers.models.siglip import check - - if not check.check_whether_transformers_replace_is_installed_correctly(): - raise ValueError(msg) - except ImportError: - raise ValueError(msg) from None - def gradient_checkpointing_enable(self): """Enable gradient checkpointing for memory optimization.""" self.gradient_checkpointing_enabled = True - self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True - self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True + self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True + self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True logging.info("Enabled gradient checkpointing for PI0Pytorch model") def gradient_checkpointing_disable(self): """Disable gradient checkpointing.""" self.gradient_checkpointing_enabled = False - self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False - self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False + self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False + self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False logging.info("Disabled gradient checkpointing for PI0Pytorch model") @@ -760,7 +768,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time) if ( - self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype + self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16 ): suffix_embs = suffix_embs.to(dtype=torch.bfloat16) @@ -834,7 +842,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks) - self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001 + self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001 _, past_key_values = self.paligemma_with_expert.forward( attention_mask=prefix_att_2d_masks_4d, @@ -908,6 +916,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks) self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001 + past_key_values = copy.deepcopy(past_key_values) outputs_embeds, _ = self.paligemma_with_expert.forward( attention_mask=full_att_2d_masks_4d, position_ids=position_ids, @@ -997,14 +1006,12 @@ class PI0Policy(PreTrainedPolicy): # Check if dataset_stats were provided in kwargs model = cls(config, **kwargs) - # Now manually load and remap the state dict + # Load state dict (expects keys with "model." prefix) try: - # Try to load the pytorch_model.bin or model.safetensors file print(f"Loading model from: {pretrained_name_or_path}") try: from transformers.utils import cached_file - # Try safetensors first resolved_file = cached_file( pretrained_name_or_path, "model.safetensors", @@ -1012,7 +1019,7 @@ class PI0Policy(PreTrainedPolicy): force_download=kwargs.get("force_download", False), resume_download=kwargs.get("resume_download"), proxies=kwargs.get("proxies"), - use_auth_token=kwargs.get("use_auth_token"), + token=kwargs.get("token"), revision=kwargs.get("revision"), local_files_only=kwargs.get("local_files_only", False), ) @@ -1025,7 +1032,7 @@ class PI0Policy(PreTrainedPolicy): print("Returning model without loading pretrained weights") return model - # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys` + # First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys) fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config) # Then add "model." prefix for all keys that don't already have it @@ -1070,7 +1077,7 @@ class PI0Policy(PreTrainedPolicy): print("All keys loaded successfully!") except Exception as e: - print(f"Warning: Could not remap state dict keys: {e}") + print(f"Warning: Could not load state dict: {e}") return model @@ -1120,6 +1127,14 @@ class PI0Policy(PreTrainedPolicy): # Some checkpoints might have this, but current model expects different structure logging.warning(f"Vision embedding key might need handling: {key}") + if ( + key == "model.paligemma_with_expert.paligemma.lm_head.weight" + or key == "paligemma_with_expert.paligemma.lm_head.weight" + ): + fixed_state_dict[ + "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight" + ] = value.clone() + fixed_state_dict[new_key] = value return fixed_state_dict diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 104ec63bf..dc5eb20ec 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -15,6 +15,7 @@ # limitations under the License. import builtins +import copy import logging import math from collections import deque @@ -32,14 +33,20 @@ from lerobot.utils.import_utils import _transformers_available if TYPE_CHECKING or _transformers_available: from transformers.models.auto import CONFIG_MAPPING from transformers.models.gemma import modeling_gemma - from transformers.models.gemma.modeling_gemma import GemmaForCausalLM - from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration + + from lerobot.policies.pi_gemma import ( + PaliGemmaForConditionalGenerationWithPiGemma, + PiGemmaForCausalLM, + _gated_residual, + layernorm_forward, + ) else: CONFIG_MAPPING = None modeling_gemma = None - GemmaForCausalLM = None - PaliGemmaForConditionalGeneration = None - + PiGemmaForCausalLM = None + _gated_residual = None + layernorm_forward = None + PaliGemmaForConditionalGenerationWithPiGemma = None from lerobot.configs.policies import PreTrainedConfig from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config from lerobot.policies.pretrained import PreTrainedPolicy, T @@ -92,10 +99,11 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy) - alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device) - beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device) + # Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU + alpha_t = torch.tensor(alpha, dtype=torch.float32) + beta_t = torch.tensor(beta, dtype=torch.float32) dist = torch.distributions.Beta(alpha_t, beta_t) - return dist.sample((bsize,)) + return dist.sample((bsize,)).to(device) def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy) @@ -189,7 +197,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) if images.dtype == torch.uint8: resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8) elif images.dtype == torch.float32: - resized_images = resized_images.clamp(-1.0, 1.0) + resized_images = resized_images.clamp(0.0, 1.0) else: raise ValueError(f"Unsupported image dtype: {images.dtype}") @@ -200,7 +208,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) pad_w1 = pad_w0 + remainder_w # Pad - constant_value = 0 if images.dtype == torch.uint8 else -1.0 + constant_value = 0 if images.dtype == torch.uint8 else 0.0 padded_images = F.pad( resized_images, (pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom @@ -219,14 +227,14 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) def compute_layer_complete( layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert ): - models = [paligemma.language_model, gemma_expert.model] + models = [paligemma.model.language_model, gemma_expert.model] query_states = [] key_states = [] value_states = [] gates = [] for i, hidden_states in enumerate(inputs_embeds): layer = models[i].layers[layer_idx] - hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901 + hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i]) gates.append(gate) input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) @@ -252,10 +260,10 @@ def compute_layer_complete( query_states, key_states, cos, sin, unsqueeze_dim=1 ) batch_size = query_states.shape[0] - scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling + scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling # Attention computation att_output, _ = modeling_gemma.eager_attention_forward( - paligemma.language_model.layers[layer_idx].self_attn, + paligemma.model.language_model.layers[layer_idx].self_attn, query_states, key_states, value_states, @@ -263,7 +271,7 @@ def compute_layer_complete( scaling, ) # Get head_dim from the current layer, not from the model - head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim + head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim) # Process layer outputs outputs_embeds = [] @@ -275,15 +283,15 @@ def compute_layer_complete( att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos]) # first residual - out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001 + out_emb = _gated_residual(hidden_states, out_emb, gates[i]) after_first_residual = out_emb.clone() - out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i]) + out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i]) # Convert to bfloat16 if the next layer (mlp) uses bfloat16 if layer.mlp.up_proj.weight.dtype == torch.bfloat16: out_emb = out_emb.to(dtype=torch.bfloat16) out_emb = layer.mlp(out_emb) # second residual - out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001 + out_emb = _gated_residual(after_first_residual, out_emb, gate) outputs_embeds.append(out_emb) start_pos = end_pos return outputs_embeds @@ -356,7 +364,7 @@ class PaliGemmaWithExpertModel( vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh" - vlm_config_hf.text_config.torch_dtype = "float32" + vlm_config_hf.text_config.dtype = "float32" vlm_config_hf.text_config.vocab_size = 257152 vlm_config_hf.text_config.use_adarms = use_adarms[0] vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None @@ -364,7 +372,7 @@ class PaliGemmaWithExpertModel( vlm_config_hf.vision_config.intermediate_size = 4304 vlm_config_hf.vision_config.projection_dim = 2048 vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" - vlm_config_hf.vision_config.torch_dtype = "float32" + vlm_config_hf.vision_config.dtype = "float32" action_expert_config_hf = CONFIG_MAPPING["gemma"]( head_dim=action_expert_config.head_dim, @@ -375,13 +383,13 @@ class PaliGemmaWithExpertModel( num_key_value_heads=action_expert_config.num_kv_heads, vocab_size=257152, hidden_activation="gelu_pytorch_tanh", - torch_dtype="float32", + dtype="float32", use_adarms=use_adarms[1], adarms_cond_dim=action_expert_config.width if use_adarms[1] else None, ) - self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf) - self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf) + self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf) + self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf) self.gemma_expert.model.embed_tokens = None self.to_bfloat16_for_selected_params(precision) @@ -396,10 +404,11 @@ class PaliGemmaWithExpertModel( else: raise ValueError(f"Invalid precision: {precision}") + # Keep full vision path in float32 so we never toggle (toggle causes optimizer + # "same dtype" error). Saves memory vs full float32; more memory than only 3 params. params_to_keep_float32 = [ - "vision_tower.vision_model.embeddings.patch_embedding.weight", - "vision_tower.vision_model.embeddings.patch_embedding.bias", - "vision_tower.vision_model.embeddings.position_embedding.weight", + "vision_tower", + "multi_modal_projector", "input_layernorm", "post_attention_layernorm", "model.norm", @@ -411,8 +420,8 @@ class PaliGemmaWithExpertModel( def _set_requires_grad(self): if self.freeze_vision_encoder: - self.paligemma.vision_tower.eval() - for param in self.paligemma.vision_tower.parameters(): + self.paligemma.model.vision_tower.eval() + for param in self.paligemma.model.vision_tower.parameters(): param.requires_grad = False if self.train_expert_only: self.paligemma.eval() @@ -422,15 +431,23 @@ class PaliGemmaWithExpertModel( def train(self, mode: bool = True): super().train(mode) if self.freeze_vision_encoder: - self.paligemma.vision_tower.eval() + self.paligemma.model.vision_tower.eval() if self.train_expert_only: self.paligemma.eval() def embed_image(self, image: torch.Tensor): - return self.paligemma.model.get_image_features(image) + # Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). + out_dtype = image.dtype + if image.dtype != torch.float32: + image = image.to(torch.float32) + image_outputs = self.paligemma.model.get_image_features(image) + features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5 + if features.dtype != out_dtype: + features = features.to(out_dtype) + return features def embed_language_tokens(self, tokens: torch.Tensor): - return self.paligemma.language_model.embed_tokens(tokens) + return self.paligemma.model.language_model.embed_tokens(tokens) def forward( self, @@ -444,7 +461,7 @@ class PaliGemmaWithExpertModel( if adarms_cond is None: adarms_cond = [None, None] if inputs_embeds[1] is None: - prefix_output = self.paligemma.language_model.forward( + prefix_output = self.paligemma.model.language_model.forward( inputs_embeds=inputs_embeds[0], attention_mask=attention_mask, position_ids=position_ids, @@ -468,7 +485,7 @@ class PaliGemmaWithExpertModel( prefix_output = None prefix_past_key_values = None else: - models = [self.paligemma.language_model, self.gemma_expert.model] + models = [self.paligemma.model.language_model, self.gemma_expert.model] num_layers = self.paligemma.config.text_config.num_hidden_layers # Check if gradient checkpointing is enabled for any of the models @@ -508,7 +525,7 @@ class PaliGemmaWithExpertModel( def compute_final_norms(inputs_embeds, adarms_cond): outputs_embeds = [] for i, hidden_states in enumerate(inputs_embeds): - out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i]) + out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i]) outputs_embeds.append(out_emb) return outputs_embeds @@ -573,29 +590,19 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` # Also compile the main forward pass used during training self.forward = torch.compile(self.forward, mode=config.compile_mode) - msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues""" - - try: - from transformers.models.siglip import check - - if not check.check_whether_transformers_replace_is_installed_correctly(): - raise ValueError(msg) - except ImportError: - raise ValueError(msg) from None - def gradient_checkpointing_enable(self): """Enable gradient checkpointing for memory optimization.""" self.gradient_checkpointing_enabled = True - self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True - self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True + self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True + self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True logging.info("Enabled gradient checkpointing for PI05Pytorch model") def gradient_checkpointing_disable(self): """Disable gradient checkpointing.""" self.gradient_checkpointing_enabled = False - self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False - self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False + self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False + self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False logging.info("Disabled gradient checkpointing for PI05Pytorch model") @@ -737,7 +744,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time) if ( - self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype + self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16 ): suffix_embs = suffix_embs.to(dtype=torch.bfloat16) @@ -808,7 +815,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks) - self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001 + self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001 _, past_key_values = self.paligemma_with_expert.forward( attention_mask=prefix_att_2d_masks_4d, @@ -880,6 +887,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks) self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001 + past_key_values = copy.deepcopy(past_key_values) outputs_embeds, _ = self.paligemma_with_expert.forward( attention_mask=full_att_2d_masks_4d, position_ids=position_ids, @@ -969,14 +977,12 @@ class PI05Policy(PreTrainedPolicy): # Check if dataset_stats were provided in kwargs model = cls(config, **kwargs) - # Now manually load and remap the state dict + # Load state dict (expects keys with "model." prefix) try: - # Try to load the pytorch_model.bin or model.safetensors file print(f"Loading model from: {pretrained_name_or_path}") try: from transformers.utils import cached_file - # Try safetensors first resolved_file = cached_file( pretrained_name_or_path, "model.safetensors", @@ -984,7 +990,7 @@ class PI05Policy(PreTrainedPolicy): force_download=kwargs.get("force_download", False), resume_download=kwargs.get("resume_download"), proxies=kwargs.get("proxies"), - use_auth_token=kwargs.get("use_auth_token"), + token=kwargs.get("token"), revision=kwargs.get("revision"), local_files_only=kwargs.get("local_files_only", False), ) @@ -997,7 +1003,7 @@ class PI05Policy(PreTrainedPolicy): print("Returning model without loading pretrained weights") return model - # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys` + # First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys) fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config) # Then add "model." prefix for all keys that don't already have it @@ -1009,8 +1015,6 @@ class PI05Policy(PreTrainedPolicy): new_key = f"model.{key}" remapped_state_dict[new_key] = value remap_count += 1 - if remap_count <= 10: # Only print first 10 to avoid spam - print(f"Remapped: {key} -> {new_key}") else: remapped_state_dict[key] = value @@ -1044,7 +1048,7 @@ class PI05Policy(PreTrainedPolicy): print("All keys loaded successfully!") except Exception as e: - print(f"Warning: Could not remap state dict keys: {e}") + print(f"Warning: Could not load state dict: {e}") return model @@ -1098,6 +1102,14 @@ class PI05Policy(PreTrainedPolicy): # Some checkpoints might have this, but current model expects different structure logging.warning(f"Vision embedding key might need handling: {key}") + if ( + key == "model.paligemma_with_expert.paligemma.lm_head.weight" + or key == "paligemma_with_expert.paligemma.lm_head.weight" + ): + fixed_state_dict[ + "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight" + ] = value.clone() + fixed_state_dict[new_key] = value return fixed_state_dict diff --git a/src/lerobot/policies/pi05/processor_pi05.py b/src/lerobot/policies/pi05/processor_pi05.py index e29bc4c23..6e01a4e16 100644 --- a/src/lerobot/policies/pi05/processor_pi05.py +++ b/src/lerobot/policies/pi05/processor_pi05.py @@ -23,7 +23,6 @@ import torch from lerobot.configs.types import PipelineFeatureType, PolicyFeature from lerobot.policies.pi05.configuration_pi05 import PI05Config -from lerobot.policies.pi05.modeling_pi05 import pad_vector from lerobot.processor import ( AddBatchDimensionProcessorStep, DeviceProcessorStep, @@ -68,9 +67,6 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep): # TODO: check if this necessary state = deepcopy(state) - # Prepare state (pad to max_state_dim) - state = pad_vector(state, self.max_state_dim) - # State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`) state_np = state.cpu().numpy() diff --git a/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py b/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py index 96137e91f..e12522833 100644 --- a/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py +++ b/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py @@ -54,7 +54,7 @@ class PI0FastConfig(PreTrainedConfig): tokenizer_max_length: int = 200 # see openpi `__post_init__` text_tokenizer_name: str = "google/paligemma-3b-pt-224" - action_tokenizer_name: str = "physical-intelligence/fast" + action_tokenizer_name: str = "lerobot/fast-action-tokenizer" temperature: float = 0.0 max_decoding_steps: int = 256 fast_skip_tokens: int = 128 diff --git a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py index b4bc7ba22..52fc2504d 100644 --- a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py +++ b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py @@ -38,11 +38,16 @@ else: if TYPE_CHECKING or _transformers_available: from transformers import AutoTokenizer from transformers.models.auto import CONFIG_MAPPING - from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration + + from lerobot.policies.pi_gemma import ( + PaliGemmaForConditionalGenerationWithPiGemma, + PiGemmaModel, + ) else: CONFIG_MAPPING = None - PaliGemmaForConditionalGeneration = None AutoTokenizer = None + PiGemmaModel = None + PaliGemmaForConditionalGenerationWithPiGemma = None from lerobot.configs.policies import PreTrainedConfig from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig @@ -121,7 +126,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) if images.dtype == torch.uint8: resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8) elif images.dtype == torch.float32: - resized_images = resized_images.clamp(-1.0, 1.0) + resized_images = resized_images.clamp(0.0, 1.0) else: raise ValueError(f"Unsupported image dtype: {images.dtype}") @@ -132,7 +137,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) pad_w1 = pad_w0 + remainder_w # Pad - constant_value = 0 if images.dtype == torch.uint8 else -1.0 + constant_value = 0 if images.dtype == torch.uint8 else 0.0 padded_images = F.pad( resized_images, (pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom @@ -206,16 +211,22 @@ class PI0FastPaliGemma(nn.Module): vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh" - vlm_config_hf.text_config.torch_dtype = "float32" + vlm_config_hf.text_config.dtype = "float32" vlm_config_hf.text_config.vocab_size = 257152 vlm_config_hf.text_config.use_adarms = use_adarms[0] vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None vlm_config_hf.vision_config.intermediate_size = 4304 vlm_config_hf.vision_config.projection_dim = 2048 vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" - vlm_config_hf.vision_config.torch_dtype = "float32" + vlm_config_hf.vision_config.dtype = "float32" - self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf) + self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf) + + # Use PI Gemma (AdaRMS) as language model when use_adarms[0] is True so that + # forward(..., adarms_cond=...) is supported (same as pi0/pi05). + if use_adarms[0]: + text_config = self.paligemma.config.text_config + self.paligemma.model.language_model = PiGemmaModel(text_config) self.to_bfloat16_for_selected_params(precision) @@ -228,10 +239,11 @@ class PI0FastPaliGemma(nn.Module): else: raise ValueError(f"Invalid precision: {precision}") + # Keep full vision path in float32 so we never toggle (toggle causes optimizer + # "same dtype" error). Align with PI05. params_to_keep_float32 = [ - "vision_tower.vision_model.embeddings.patch_embedding.weight", - "vision_tower.vision_model.embeddings.patch_embedding.bias", - "vision_tower.vision_model.embeddings.position_embedding.weight", + "vision_tower", + "multi_modal_projector", "input_layernorm", "post_attention_layernorm", "model.norm", @@ -242,10 +254,18 @@ class PI0FastPaliGemma(nn.Module): param.data = param.data.to(dtype=torch.float32) def embed_image(self, image: torch.Tensor): - return self.paligemma.model.get_image_features(image) + # Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). Align with PI05. + out_dtype = image.dtype + if image.dtype != torch.float32: + image = image.to(torch.float32) + image_outputs = self.paligemma.model.get_image_features(image) + features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5 + if features.dtype != out_dtype: + features = features.to(out_dtype) + return features def embed_language_tokens(self, tokens: torch.Tensor): - return self.paligemma.language_model.embed_tokens(tokens) + return self.paligemma.model.language_model.embed_tokens(tokens) def forward( self, @@ -259,7 +279,7 @@ class PI0FastPaliGemma(nn.Module): if adarms_cond is None: adarms_cond = [None, None] if inputs_embeds[1] is None: - prefix_output = self.paligemma.language_model.forward( + prefix_output = self.paligemma.model.language_model.forward( inputs_embeds=inputs_embeds[0], attention_mask=attention_mask, position_ids=position_ids, @@ -306,24 +326,14 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` self.sample_actions_fast = torch.compile(self.sample_actions_fast, mode=config.compile_mode) self.forward = torch.compile(self.forward, mode=config.compile_mode) - msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues""" - - try: - from transformers.models.siglip import check - - if not check.check_whether_transformers_replace_is_installed_correctly(): - raise ValueError(msg) - except ImportError: - raise ValueError(msg) from None - def gradient_checkpointing_enable(self): """Enable gradient checkpointing for memory optimization.""" self.gradient_checkpointing_enabled = True # Call the proper gradient_checkpointing_enable() method with use_reentrant=False for better memory efficiency - self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_enable( + self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing_enable( gradient_checkpointing_kwargs={"use_reentrant": False} ) - self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_enable( + self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing_enable( gradient_checkpointing_kwargs={"use_reentrant": False} ) logging.info("Enabled gradient checkpointing for PI0FastPytorch model") @@ -332,8 +342,8 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` """Disable gradient checkpointing.""" self.gradient_checkpointing_enabled = False # Call the proper gradient_checkpointing_disable() method - self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_disable() - self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_disable() + self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing_disable() + self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing_disable() logging.info("Disabled gradient checkpointing for PI0FastPytorch model") def _apply_checkpoint(self, func, *args, **kwargs): @@ -523,7 +533,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` # Convert embeddings to bfloat16 if needed if ( - self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype + self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16 ): prefix_embs = prefix_embs.to(dtype=torch.bfloat16) @@ -616,7 +626,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` ) if ( - self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype + self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16 ): prefix_embs = prefix_embs.to(dtype=torch.bfloat16) @@ -714,7 +724,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` # Ensure correct precision (bfloat16/float32) if ( - self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype + self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16 ): prefix_embs = prefix_embs.to(dtype=torch.bfloat16) @@ -897,14 +907,12 @@ class PI0FastPolicy(PreTrainedPolicy): # Check if dataset_stats were provided in kwargs model = cls(config, **kwargs) - # Now manually load and remap the state dict + # Load state dict (expects keys with "model." prefix) try: - # Try to load the pytorch_model.bin or model.safetensors file print(f"Loading model from: {pretrained_name_or_path}") try: from transformers.utils import cached_file - # Try safetensors first resolved_file = cached_file( pretrained_name_or_path, "model.safetensors", @@ -912,7 +920,7 @@ class PI0FastPolicy(PreTrainedPolicy): force_download=kwargs.get("force_download", False), resume_download=kwargs.get("resume_download"), proxies=kwargs.get("proxies"), - use_auth_token=kwargs.get("use_auth_token"), + token=kwargs.get("token"), revision=kwargs.get("revision"), local_files_only=kwargs.get("local_files_only", False), ) @@ -925,8 +933,9 @@ class PI0FastPolicy(PreTrainedPolicy): print("Returning model without loading pretrained weights") return model - # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys` + # First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys) fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config) + # Then add "model." prefix for all keys that don't already have it remapped_state_dict = {} remap_count = 0 @@ -936,8 +945,6 @@ class PI0FastPolicy(PreTrainedPolicy): new_key = f"model.{key}" remapped_state_dict[new_key] = value remap_count += 1 - if remap_count <= 10: # Only print first 10 to avoid spam - print(f"Remapped: {key} -> {new_key}") else: remapped_state_dict[key] = value @@ -971,7 +978,7 @@ class PI0FastPolicy(PreTrainedPolicy): print("All keys loaded successfully!") except Exception as e: - print(f"Warning: Could not remap state dict keys: {e}") + print(f"Warning: Could not load state dict: {e}") return model diff --git a/src/lerobot/policies/pi0_fast/processor_pi0_fast.py b/src/lerobot/policies/pi0_fast/processor_pi0_fast.py index 0d9dac673..fde7d5c80 100644 --- a/src/lerobot/policies/pi0_fast/processor_pi0_fast.py +++ b/src/lerobot/policies/pi0_fast/processor_pi0_fast.py @@ -23,7 +23,6 @@ import torch from lerobot.configs.types import PipelineFeatureType, PolicyFeature from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig -from lerobot.policies.pi0_fast.modeling_pi0_fast import pad_vector from lerobot.processor import ( ActionTokenizerProcessorStep, AddBatchDimensionProcessorStep, @@ -69,9 +68,6 @@ class Pi0FastPrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep): # TODO: check if this necessary state = deepcopy(state) - # Prepare state (pad to max_state_dim) - state = pad_vector(state, self.max_state_dim) - # State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`) state_np = state.cpu().numpy() diff --git a/src/lerobot/policies/pi_gemma.py b/src/lerobot/policies/pi_gemma.py new file mode 100644 index 000000000..05f031d08 --- /dev/null +++ b/src/lerobot/policies/pi_gemma.py @@ -0,0 +1,363 @@ +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from torch import nn + +from lerobot.utils.import_utils import _transformers_available + +if TYPE_CHECKING or _transformers_available: + from transformers.cache_utils import DynamicCache + from transformers.masking_utils import create_causal_mask + from transformers.modeling_layers import GradientCheckpointingLayer + from transformers.modeling_outputs import BaseModelOutputWithPast + from transformers.models.gemma.modeling_gemma import ( + GemmaAttention, + GemmaConfig, + GemmaForCausalLM, + GemmaMLP, + GemmaModel, + ) + from transformers.models.paligemma.modeling_paligemma import ( + PaliGemmaForConditionalGeneration, + PaliGemmaModel, + ) +else: + GemmaAttention = None + GemmaConfig = None + GemmaForCausalLM = None + GemmaMLP = None + GemmaModel = None + PaliGemmaModel = None + PaliGemmaForConditionalGeneration = None + DynamicCache = None + GradientCheckpointingLayer = None + BaseModelOutputWithPast = None + create_causal_mask = None + + +def _gated_residual( + x: torch.Tensor | None, + y: torch.Tensor | None, + gate: torch.Tensor | None, +) -> torch.Tensor | None: + """Gated residual: x + y when gate is None, else x + y * gate.""" + if x is None and y is None: + return None + if x is None or y is None: + return x if x is not None else y + if gate is None: + return x + y + return x + y * gate + + +def layernorm_forward( + layernorm: nn.Module, + x: torch.Tensor, + cond: torch.Tensor | None = None, +): + """ + call layernorm and return hidden states and gate + if cond is not None, use conditional norm + otherwise, use normal gemma norm + """ + if cond is not None: + return layernorm(x, cond=cond) + else: + return layernorm(x) + + +class PiGemmaRMSNorm(nn.Module): + """ + Adaptive RMSNorm for PI Gemma (AdaRMS). + When cond_dim is set, uses cond to modulate scale/shift/gate; otherwise behaves like standard GemmaRMSNorm. + forward(x, cond=None) returns (output, gate) for use with _gated_residual. + """ + + def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None): + super().__init__() + self.eps = eps + self.dim = dim + self.cond_dim = cond_dim + if cond_dim is not None: + self.dense = nn.Linear(cond_dim, dim * 3, bias=True) + nn.init.zeros_(self.dense.weight) + else: + self.weight = nn.Parameter(torch.zeros(dim)) + self.dense = None + + def _norm(self, x): + # Compute variance in float32 (like the source implementation) + var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True) + # Compute normalization in float32 + normed_inputs = x * torch.rsqrt(var + self.eps) + return normed_inputs + + def forward( + self, + x: torch.Tensor, + cond: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + dtype = x.dtype + normed = self._norm(x) + if cond is None or self.dense is None: + normed = normed * (1.0 + self.weight.float()) + return normed.type_as(x), None + if cond.shape[-1] != self.cond_dim: + raise ValueError(f"Expected cond dim {self.cond_dim}, got {cond.shape[-1]}") + modulation = self.dense(cond) + if len(x.shape) == 3: + modulation = modulation.unsqueeze(1) + scale, shift, gate = modulation.chunk(3, dim=-1) + normed = normed * (1 + scale.float()) + shift.float() + return normed.to(dtype), gate.to(dtype) + + def extra_repr(self) -> str: + if self.dense is not None: + return f"dim={self.dim}, eps={self.eps}, adaptive=True, cond_dim={self.cond_dim}" + return f"dim={self.dim}, eps={self.eps}" + + +def _get_pi_gemma_decoder_layer_base(): + """base for PiGemmaDecoderLayer""" + + class _PiGemmaDecoderLayerBase(GradientCheckpointingLayer): + """Decoder layer that uses PiGemmaRMSNorm and _gated_residual, compatible with v5 Gemma.""" + + def __init__(self, config: GemmaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx) + self.mlp = GemmaMLP(config) + cond_dim = ( + getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None + ) + self.input_layernorm = PiGemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim + ) + self.post_attention_layernorm = PiGemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values=None, + use_cache: bool = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + adarms_cond: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + residual = hidden_states + hidden_states, gate = self.input_layernorm(hidden_states, cond=adarms_cond) + hidden_states, _ = self.self_attn( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = _gated_residual(residual, hidden_states, gate) + + residual = hidden_states + hidden_states, gate = self.post_attention_layernorm(hidden_states, cond=adarms_cond) + hidden_states = self.mlp(hidden_states) + hidden_states = _gated_residual(residual, hidden_states, gate) + return hidden_states + + return _PiGemmaDecoderLayerBase + + +class PiGemmaModel(GemmaModel): # type: ignore[misc] + """ + GemmaModel extended with AdaRMS (adaptive RMSNorm) and gated residuals when config.use_adarms is True. + """ + + def __init__(self, config: GemmaConfig, **kwargs): + super().__init__(config, **kwargs) + # if not getattr(config, "use_adarms", False): + # return + cond_dim = getattr(config, "adarms_cond_dim", None) + pi_gemma_decoder_layer_base = _get_pi_gemma_decoder_layer_base() + self.layers = nn.ModuleList( + [pi_gemma_decoder_layer_base(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = PiGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim) + + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: DynamicCache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + cache_position: torch.LongTensor | None = None, + adarms_cond: torch.Tensor | None = None, + **kwargs, + ) -> BaseModelOutputWithPast: + """ + adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): + Condition for ADARMS. + """ + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + import logging + + logging.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + # embed positions + hidden_states = inputs_embeds + # Convert to bfloat16 if the first layer uses bfloat16 + if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.bfloat16) + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # normalized + # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 + # See https://github.com/huggingface/transformers/pull/29402 + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + adarms_cond=adarms_cond, + **kwargs, + ) + + hidden_states = layer_outputs + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states, _ = self.norm(hidden_states, adarms_cond) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class PiGemmaForCausalLM(GemmaForCausalLM): # type: ignore[misc] + """ + Causal LM wrapper using PiGemmaModel as the backbone, for consistency with GemmaForCausalLM + and the language model used in pi0_fast. Use this for the action expert in pi0/pi05. + """ + + def __init__(self, config: GemmaConfig, **kwargs): + super().__init__(config, **kwargs) + self.model = PiGemmaModel(config) + + +class PaliGemmaModelWithPiGemma(PaliGemmaModel): + """PaliGemmaModel whose language_model is PiGemmaModel (custom decoder with PiGemmaRMSNorm and gated residuals).""" + + def __init__(self, config): + super().__init__(config) + self.language_model = PiGemmaModel(config.text_config) + + +class PaliGemmaForConditionalGenerationWithPiGemma(PaliGemmaForConditionalGeneration): + """PaliGemmaForConditionalGeneration using PiGemma decoder for the language model.""" + + def __init__(self, config): + super().__init__(config) + self.model = PaliGemmaModelWithPiGemma(config) + + # Make modules available through conditional class for BC + @property + def language_model(self): + return self.model.language_model + + +__all__ = [ + "PiGemmaModel", + "PiGemmaForCausalLM", + "PiGemmaRMSNorm", + "_gated_residual", + "layernorm_forward", + "PaliGemmaModelWithPiGemma", + "PaliGemmaForConditionalGenerationWithPiGemma", +] diff --git a/src/lerobot/policies/sac/reward_model/configuration_classifier.py b/src/lerobot/policies/sac/reward_model/configuration_classifier.py index 9b76b8037..879e3c1af 100644 --- a/src/lerobot/policies/sac/reward_model/configuration_classifier.py +++ b/src/lerobot/policies/sac/reward_model/configuration_classifier.py @@ -33,7 +33,7 @@ class RewardClassifierConfig(PreTrainedConfig): latent_dim: int = 256 image_embedding_pooling_dim: int = 8 dropout_rate: float = 0.1 - model_name: str = "helper2424/resnet10" + model_name: str = "helper2424/resnet10" # TODO: This needs to be updated. The model on the Hub doesn't call self.post_init() in its __init__, which is required by transformers v5 to set all_tied_weights_keys. The from_pretrained call fails when it tries to access this attribute during _finalize_model_loading. device: str = "cpu" model_type: str = "cnn" # "transformer" or "cnn" num_cameras: int = 2 diff --git a/src/lerobot/policies/wall_x/configuration_wall_x.py b/src/lerobot/policies/wall_x/configuration_wall_x.py index 3962b56f6..5269c4e10 100644 --- a/src/lerobot/policies/wall_x/configuration_wall_x.py +++ b/src/lerobot/policies/wall_x/configuration_wall_x.py @@ -55,7 +55,7 @@ class WallXConfig(PreTrainedConfig): pretrained_name_or_path: str = "x-square-robot/wall-oss-flow" # Tokenizer settings - action_tokenizer_path: str | None = "physical-intelligence/fast" + action_tokenizer_path: str | None = "lerobot/fast-action-tokenizer" # Action prediction mode: "diffusion" or "fast" prediction_mode: str = "diffusion" diff --git a/src/lerobot/policies/wall_x/modeling_wall_x.py b/src/lerobot/policies/wall_x/modeling_wall_x.py index ef99bad89..84ee05743 100644 --- a/src/lerobot/policies/wall_x/modeling_wall_x.py +++ b/src/lerobot/policies/wall_x/modeling_wall_x.py @@ -261,10 +261,15 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): and optional LoRA fine-tuning support. """ - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} config_class = Qwen2_5_VLConfig _no_split_modules = ["Qwen2_5_VLDecoderLayer_with_MoE", "Qwen2_5_VLVisionBlock"] + def init_weights(self): + if getattr(self.model, "language_model", None) is not None: + return + super().init_weights() + @classmethod def from_pretrained( cls, @@ -312,6 +317,11 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): processor.action_processor = action_tokenizer else: action_tokenizer = None + + # add pad_token_id to config + config.pad_token_id = processor.tokenizer.pad_token_id + config.text_config.pad_token_id = processor.tokenizer.pad_token_id + # Initialize model with configuration and processor model = cls(config, processor=processor, action_tokenizer=action_tokenizer, **kwargs) @@ -331,7 +341,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): force_download=kwargs.get("force_download", False), resume_download=kwargs.get("resume_download"), proxies=kwargs.get("proxies"), - use_auth_token=kwargs.get("use_auth_token"), + token=kwargs.get("token"), revision=kwargs.get("revision"), local_files_only=kwargs.get("local_files_only", False), ) diff --git a/src/lerobot/policies/wall_x/qwen_model/configuration_qwen2_5_vl.py b/src/lerobot/policies/wall_x/qwen_model/configuration_qwen2_5_vl.py index 731ef3b3e..19874b6ff 100644 --- a/src/lerobot/policies/wall_x/qwen_model/configuration_qwen2_5_vl.py +++ b/src/lerobot/policies/wall_x/qwen_model/configuration_qwen2_5_vl.py @@ -21,6 +21,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig): window_size=112, out_hidden_size=3584, fullatt_block_indexes=[7, 15, 23, 31], + initializer_range=0.02, **kwargs, ): super().__init__(**kwargs) @@ -38,6 +39,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig): self.window_size = window_size self.fullatt_block_indexes = fullatt_block_indexes self.out_hidden_size = out_hidden_size + self.initializer_range = initializer_range class Qwen2_5_VLConfig(PretrainedConfig): diff --git a/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py b/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py index 490e25095..ecf3eb371 100644 --- a/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py +++ b/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py @@ -11,7 +11,6 @@ from transformers.activations import ACT2FN from transformers.cache_utils import ( Cache, DynamicCache, - SlidingWindowCache, StaticCache, ) from transformers.generation import GenerationMixin @@ -31,6 +30,15 @@ from transformers.utils import ( from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig + +# TODO(Steven): SlidingWindowCache was removed in transformers v5. Define a placeholder so isinstance checks +# always return False (which is the correct behavior when no sliding window cache is in use). +class _SlidingWindowCachePlaceholder: + pass + + +SlidingWindowCache = _SlidingWindowCachePlaceholder + if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.layers.rotary import apply_rotary_emb @@ -594,19 +602,40 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): return hidden_states +def _compute_default_rope_parameters_qwen2_5_vl(config, device=None): + """ + compute default rope parameters for Qwen2_5_VL + """ + base = config.text_config.rope_parameters["rope_theta"] + dim = config.hidden_size // config.num_attention_heads + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, 1.0 + + class Qwen2_5_VLRotaryEmbedding(nn.Module): def __init__(self, config: Qwen2_5_VLConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + elif hasattr(config, "rope_parameters") and config.rope_parameters is not None: + self.rope_type = config.rope_parameters.get("rope_type", "default") else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + if self.rope_type == "default": + self.rope_init_fn = _compute_default_rope_parameters_qwen2_5_vl + self.rope_kwargs = {} + else: + rope_type_key = "linear" if self.rope_type == "linear" else self.rope_type + self.rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type_key] + self.rope_kwargs = {} inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) @@ -1567,7 +1596,7 @@ QWEN2_5_VL_INPUTS_DOCSTRING = r""" class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} config_class = Qwen2_5_VLConfig _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] diff --git a/src/lerobot/policies/wall_x/utils.py b/src/lerobot/policies/wall_x/utils.py index 2ea40b377..e08ef69d5 100644 --- a/src/lerobot/policies/wall_x/utils.py +++ b/src/lerobot/policies/wall_x/utils.py @@ -144,7 +144,7 @@ def preprocesser_call( """ # Process image inputs if images is not None and len(images) > 0: - image_inputs = processor.image_processor(images=images, videos=None, return_tensors=return_tensors) + image_inputs = processor.image_processor(images=images, return_tensors=return_tensors) image_grid_thw = image_inputs["image_grid_thw"] else: image_inputs = {} @@ -152,7 +152,7 @@ def preprocesser_call( # Process video inputs if videos is not None: - videos_inputs = processor.image_processor(images=None, videos=videos, return_tensors=return_tensors) + videos_inputs = processor.image_processor(videos=videos, return_tensors=return_tensors) video_grid_thw = videos_inputs["video_grid_thw"] else: videos_inputs = {} diff --git a/src/lerobot/policies/xvla/configuration_florence2.py b/src/lerobot/policies/xvla/configuration_florence2.py index 35c006ee0..77f1b3a1d 100644 --- a/src/lerobot/policies/xvla/configuration_florence2.py +++ b/src/lerobot/policies/xvla/configuration_florence2.py @@ -276,6 +276,8 @@ class Florence2LanguageConfig(PretrainedConfig): ) # ensure backward compatibility for BART CNN models + if not hasattr(self, "forced_bos_token_id"): + self.forced_bos_token_id = None if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False): self.forced_bos_token_id = self.bos_token_id warnings.warn( diff --git a/src/lerobot/policies/xvla/modeling_florence2.py b/src/lerobot/policies/xvla/modeling_florence2.py index 2b5316fae..e33efe5c3 100644 --- a/src/lerobot/policies/xvla/modeling_florence2.py +++ b/src/lerobot/policies/xvla/modeling_florence2.py @@ -1951,7 +1951,10 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel): class Florence2LanguageModel(Florence2LanguagePreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config: Florence2LanguageConfig): super().__init__(config) @@ -2076,7 +2079,10 @@ class Florence2LanguageModel(Florence2LanguagePreTrainedModel): class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel, GenerationMixin): base_model_prefix = "model" - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = { + "model.encoder.embed_tokens.weight": "model.shared.weight", + "model.decoder.embed_tokens.weight": "model.shared.weight", + } _keys_to_ignore_on_load_missing = ["final_logits_bias"] def __init__(self, config: Florence2LanguageConfig): @@ -2436,11 +2442,10 @@ FLORENCE2_INPUTS_DOCSTRING = r""" FLORENCE2_START_DOCSTRING, ) class Florence2ForConditionalGeneration(Florence2PreTrainedModel): - _tied_weights_keys = [ - "language_model.encoder.embed_tokens.weight", - "language_model.decoder.embed_tokens.weight", - "language_model.lm_head.weight", - ] + _tied_weights_keys = { + "language_model.model.encoder.embed_tokens.weight": "language_model.model.shared.weight", + "language_model.model.decoder.embed_tokens.weight": "language_model.model.shared.weight", + } def __init__(self, config: Florence2Config): super().__init__(config) diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index df559555a..da6e600af 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -336,7 +336,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep): Requires the `transformers` library to be installed. Attributes: - tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast"). + tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "lerobot/fast-action-tokenizer"). tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored. trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers). action_tokenizer: The internal tokenizer/processor instance, loaded during initialization. diff --git a/src/lerobot/scripts/lerobot_train_tokenizer.py b/src/lerobot/scripts/lerobot_train_tokenizer.py index 1d8f4644b..807d48333 100644 --- a/src/lerobot/scripts/lerobot_train_tokenizer.py +++ b/src/lerobot/scripts/lerobot_train_tokenizer.py @@ -306,7 +306,7 @@ def train_fast_tokenizer( # download the tokenizer source code (not pretrained weights) # we'll train a new tokenizer on our own data - base_tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True) + base_tokenizer = AutoProcessor.from_pretrained("lerobot/fast-action-tokenizer", trust_remote_code=True) # convert action_chunks array to list of arrays (expected by .fit()) action_data_list = [action_chunks[i] for i in range(len(action_chunks))] diff --git a/tests/policies/hilserl/test_modeling_classifier.py b/tests/policies/hilserl/test_modeling_classifier.py index a572ea9e1..a62ef3ebb 100644 --- a/tests/policies/hilserl/test_modeling_classifier.py +++ b/tests/policies/hilserl/test_modeling_classifier.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature @@ -37,6 +38,9 @@ def test_classifier_output(): @require_package("transformers") +@pytest.mark.skip( + reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" +) def test_binary_classifier_with_default_params(): from lerobot.policies.sac.reward_model.modeling_classifier import Classifier @@ -78,6 +82,9 @@ def test_binary_classifier_with_default_params(): @require_package("transformers") +@pytest.mark.skip( + reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" +) def test_multiclass_classifier(): from lerobot.policies.sac.reward_model.modeling_classifier import Classifier @@ -117,6 +124,9 @@ def test_multiclass_classifier(): @require_package("transformers") +@pytest.mark.skip( + reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" +) def test_default_device(): from lerobot.policies.sac.reward_model.modeling_classifier import Classifier @@ -129,6 +139,9 @@ def test_default_device(): @require_package("transformers") +@pytest.mark.skip( + reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" +) def test_explicit_device_setup(): from lerobot.policies.sac.reward_model.modeling_classifier import Classifier diff --git a/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py b/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py index 9ebc4ba89..9de781464 100644 --- a/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py +++ b/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py @@ -17,7 +17,6 @@ """Test script to verify PI0Fast policy integration with LeRobot vs the original implementation""" # ruff: noqa: E402 -import os import random from copy import deepcopy from typing import Any @@ -28,10 +27,6 @@ import torch pytest.importorskip("transformers") pytest.importorskip("scipy") -pytestmark = pytest.mark.skipif( - os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", - reason="This test requires accepting the model license", -) from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig from lerobot.policies.pi0_fast.modeling_pi0_fast import PI0FastPolicy @@ -53,22 +48,23 @@ DUMMY_STATE_DIM = 20 IMAGE_HEIGHT = 224 IMAGE_WIDTH = 224 NUM_VIEWS = 2 # Number of camera views -DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +DEVICE = "cuda" MODEL_PATH_LEROBOT = "lerobot/pi0fast-base" # Expected action token shape: (batch_size, max_decoding_steps) EXPECTED_ACTION_TOKENS_SHAPE = (1, 2) # Expected first 5 action tokens (for reproducibility check) -EXPECTED_ACTION_TOKENS_FIRST_5 = torch.tensor([255657, 255362]) +EXPECTED_ACTION_TOKENS_FIRST_5 = torch.tensor([255020, 255589]) # Expected actions after detokenization EXPECTED_ACTIONS_SHAPE = (1, 2, 32) # (batch_size, n_action_steps, action_dim) -EXPECTED_ACTIONS_MEAN = 0.04419417306780815 -EXPECTED_ACTIONS_STD = 0.26231569051742554 -EXPECTED_ACTIONS_FIRST_5 = torch.tensor([0.0000, 1.4849, 0.0000, 0.0000, 0.0000]) +EXPECTED_ACTIONS_MEAN = 0.046403881162405014 +EXPECTED_ACTIONS_STD = 0.2607129216194153 +EXPECTED_ACTIONS_FIRST_5 = torch.tensor([0.0000, 0.3536, 0.0707, 0.0000, 0.0000]) +@require_cuda def set_seed_all(seed: int): """Set random seed for all RNG sources to ensure reproducibility.""" random.seed(seed) @@ -85,6 +81,7 @@ def set_seed_all(seed: int): torch.use_deterministic_algorithms(True, warn_only=True) +@require_cuda def instantiate_lerobot_pi0_fast( from_pretrained: bool = False, model_path: str = MODEL_PATH_LEROBOT, @@ -127,6 +124,7 @@ def instantiate_lerobot_pi0_fast( return policy, preprocessor, postprocessor +@require_cuda def create_dummy_data(device=DEVICE): """Create dummy data for testing both implementations.""" batch_size = 1 @@ -158,22 +156,25 @@ def create_dummy_data(device=DEVICE): # Pytest fixtures @pytest.fixture(scope="module") +@require_cuda def pi0_fast_components(): """Fixture to instantiate and provide all PI0Fast components for tests.""" print(f"\nTesting with DEVICE='{DEVICE}'") print("\n[Setup] Instantiating LeRobot PI0Fast policy...") policy_obj, preprocessor_obj, postprocessor_obj = instantiate_lerobot_pi0_fast(from_pretrained=True) print("Model loaded successfully") - yield policy_obj, preprocessor_obj, postprocessor_obj + return policy_obj, preprocessor_obj, postprocessor_obj @pytest.fixture(scope="module") +@require_cuda def policy(pi0_fast_components): """Fixture to provide the PI0Fast policy for tests.""" return pi0_fast_components[0] @pytest.fixture(scope="module") +@require_cuda def preprocessor(pi0_fast_components): """Fixture to provide the PI0Fast preprocessor for tests.""" return pi0_fast_components[1] diff --git a/tests/policies/pi0_pi05/test_pi0.py b/tests/policies/pi0_pi05/test_pi0.py index b580310eb..e83abf57d 100644 --- a/tests/policies/pi0_pi05/test_pi0.py +++ b/tests/policies/pi0_pi05/test_pi0.py @@ -16,17 +16,8 @@ """Test script to verify PI0 policy integration with LeRobot, only meant to be run locally!""" -import os - -import pytest import torch -# Skip this entire module in CI -pytestmark = pytest.mark.skipif( - os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", - reason="This test requires local OpenPI installation and is not meant for CI", -) - from lerobot.policies.factory import make_policy_config # noqa: E402 from lerobot.policies.pi0 import ( # noqa: E402 PI0Config, diff --git a/tests/policies/pi0_pi05/test_pi05.py b/tests/policies/pi0_pi05/test_pi05.py index 964539446..595191689 100644 --- a/tests/policies/pi0_pi05/test_pi05.py +++ b/tests/policies/pi0_pi05/test_pi05.py @@ -16,25 +16,15 @@ """Test script to verify PI0.5 (pi05) support in PI0 policy, only meant to be run locally!""" -import os - -import pytest import torch -from lerobot.utils.random_utils import set_seed - -# Skip this entire module in CI -pytestmark = pytest.mark.skipif( - os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", - reason="This test requires local OpenPI installation and is not meant for CI", -) - from lerobot.policies.factory import make_policy_config # noqa: E402 from lerobot.policies.pi05 import ( # noqa: E402 PI05Config, PI05Policy, make_pi05_pre_post_processors, # noqa: E402 ) +from lerobot.utils.random_utils import set_seed from tests.utils import require_cuda # noqa: E402 diff --git a/tests/policies/pi0_pi05/test_pi05_rtc.py b/tests/policies/pi0_pi05/test_pi05_rtc.py index 3a753031f..0dc240638 100644 --- a/tests/policies/pi0_pi05/test_pi05_rtc.py +++ b/tests/policies/pi0_pi05/test_pi05_rtc.py @@ -24,9 +24,10 @@ import torch # Skip this entire module in CI pytestmark = pytest.mark.skipif( os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", - reason="This test requires local OpenPI installation and is not meant for CI", + reason="TODO: This test seems to hang the CI", ) + from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule # noqa: E402 from lerobot.policies.pi05 import PI05Config, PI05Policy, make_pi05_pre_post_processors # noqa: E402 from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402 diff --git a/tests/policies/pi0_pi05/test_pi0_rtc.py b/tests/policies/pi0_pi05/test_pi0_rtc.py index 68e94dd94..4105e2068 100644 --- a/tests/policies/pi0_pi05/test_pi0_rtc.py +++ b/tests/policies/pi0_pi05/test_pi0_rtc.py @@ -24,9 +24,10 @@ import torch # Skip this entire module in CI pytestmark = pytest.mark.skipif( os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", - reason="This test requires local OpenPI installation and is not meant for CI", + reason="TODO: This test seems to hang the CI", ) + from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule # noqa: E402 from lerobot.policies.pi0 import PI0Config, PI0Policy, make_pi0_pre_post_processors # noqa: E402 from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402 @@ -88,6 +89,7 @@ def test_pi0_rtc_initialization_without_rtc_config(): print("✓ PI0 RTC initialization without RTC config: Test passed") +@require_cuda def test_pi0_rtc_inference_with_prev_chunk(): """Test PI0 policy inference with RTC and previous chunk.""" set_seed(42) diff --git a/tests/policies/test_sac_policy.py b/tests/policies/test_sac_policy.py index 6fad2979e..11499ce30 100644 --- a/tests/policies/test_sac_policy.py +++ b/tests/policies/test_sac_policy.py @@ -305,6 +305,9 @@ def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_di [(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")], ) @pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed") +@pytest.mark.skip( + reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" +) def test_sac_policy_with_pretrained_encoder( batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str ): diff --git a/tests/policies/wall_x/test_wallx.py b/tests/policies/wall_x/test_wallx.py index e5f124123..3514fccd1 100644 --- a/tests/policies/wall_x/test_wallx.py +++ b/tests/policies/wall_x/test_wallx.py @@ -16,8 +16,6 @@ """Test script to verify Wall-X policy integration with LeRobot, only meant to be run locally!""" -import os - import pytest import torch @@ -26,19 +24,15 @@ pytest.importorskip("peft") pytest.importorskip("transformers") pytest.importorskip("torchdiffeq") -# Skip this entire module in CI -pytestmark = pytest.mark.skipif( - os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", - reason="This test requires local Wall-X installation and is not meant for CI", -) - from lerobot.policies.factory import make_policy_config # noqa: E402 from lerobot.policies.wall_x import WallXConfig # noqa: E402 from lerobot.policies.wall_x.modeling_wall_x import WallXPolicy # noqa: E402 from lerobot.policies.wall_x.processor_wall_x import make_wall_x_pre_post_processors # noqa: E402 from lerobot.utils.random_utils import set_seed # noqa: E402 +from tests.utils import require_cuda # noqa: E402 +@require_cuda def test_policy_instantiation(): # Create config set_seed(42) @@ -123,6 +117,7 @@ def test_policy_instantiation(): raise +@require_cuda def test_config_creation(): """Test policy config creation through factory.""" try: @@ -134,8 +129,3 @@ def test_config_creation(): except Exception as e: print(f"Config creation failed: {e}") raise - - -if __name__ == "__main__": - test_policy_instantiation() - test_config_creation() From 3e45120272fa30c2ddd02e2134a356f21abe294b Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 5 Mar 2026 13:22:37 +0100 Subject: [PATCH 087/131] fix(ci): log in HF for gated repo in nightly workflows (#3089) * fix(ci): log in HF for gated repo in nightly workflows * fix(ci): add env var * fix(ci): remove 10 min limit for multi-gpu nightly --- .github/workflows/nightly.yml | 16 +++++++++++++++- .github/workflows/unbound_deps_tests.yml | 11 ++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 45bfb9bd5..563b5957d 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -119,6 +119,7 @@ jobs: HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot TORCH_HOME: /home/user_lerobot/.cache/torch TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton + HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }} container: image: ${{ needs.build-docker-cpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images] options: --shm-size "16gb" @@ -130,6 +131,10 @@ jobs: shell: bash working-directory: /lerobot steps: + - name: Login to Hugging Face + run: | + hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential + hf auth whoami - name: Run pytest on CPU run: pytest tests -vv --maxfail=10 - name: Run end-to-end tests @@ -146,6 +151,7 @@ jobs: HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot TORCH_HOME: /home/user_lerobot/.cache/torch TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton + HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }} container: image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images] options: --gpus all --shm-size "16gb" @@ -157,6 +163,10 @@ jobs: shell: bash working-directory: /lerobot steps: + - name: Login to Hugging Face + run: | + hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential + hf auth whoami - name: Run pytest on GPU run: pytest tests -vv --maxfail=10 - name: Run end-to-end tests @@ -174,6 +184,7 @@ jobs: TORCH_HOME: /home/user_lerobot/.cache/torch TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton CUDA_VISIBLE_DEVICES: "0,1,2,3" + HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }} container: image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images] options: --gpus all --shm-size "16gb" @@ -185,6 +196,10 @@ jobs: shell: bash working-directory: /lerobot steps: + - name: Login to Hugging Face + run: | + hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential + hf auth whoami - name: Verify GPU availability run: | nvidia-smi @@ -193,4 +208,3 @@ jobs: - name: Run multi-GPU training tests # TODO(Steven): Investigate why motors tests are failing in multi-GPU setup run: pytest tests -vv --maxfail=10 --ignore=tests/motors/ - timeout-minutes: 10 diff --git a/.github/workflows/unbound_deps_tests.yml b/.github/workflows/unbound_deps_tests.yml index 3f4ea3316..19de38e3b 100644 --- a/.github/workflows/unbound_deps_tests.yml +++ b/.github/workflows/unbound_deps_tests.yml @@ -48,6 +48,7 @@ jobs: MUJOCO_GL: egl HF_HOME: /mnt/cache/.cache/huggingface HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot + HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }} steps: - uses: actions/checkout@v6 with: @@ -79,7 +80,10 @@ jobs: - name: Install lerobot with all extras run: uv sync --extra all # TODO(Steven): Make flash-attn optional - + - name: Login to Hugging Face + run: | + uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential + uv run hf auth whoami - name: Run pytest (all extras) run: uv run pytest tests -vv @@ -137,6 +141,7 @@ jobs: HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot TORCH_HOME: /home/user_lerobot/.cache/torch TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton + HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }} container: image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images] options: --gpus all --shm-size "16gb" @@ -148,6 +153,10 @@ jobs: shell: bash working-directory: /lerobot steps: + - name: Login to Hugging Face + run: | + hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential + hf auth whoami - name: Run pytest on GPU run: pytest tests -vv - name: Run end-to-end tests From 92fba372257dd86c924052b327478f4df84ffbbd Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Thu, 5 Mar 2026 15:49:50 +0100 Subject: [PATCH 088/131] fix(num_frames): fixing redundant frames count in conversion script (#3091) --- src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py index 3ae9093b9..81de05686 100644 --- a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py +++ b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py @@ -228,7 +228,6 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int): # Reset for the next file size_in_mb = 0 - num_frames += ep_num_frames # Still need to accumulate total frames paths_to_cat = [] # Now create metadata with correct chunk/file indices From 1a24f770d310cfca3ac69aeef9871ef235989cbe Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Thu, 5 Mar 2026 18:27:58 +0100 Subject: [PATCH 089/131] Feat/slurm compute rabc script (#3041) * Add SLURM SARM progress annotation script. Provide a standalone two-stage compute/aggregate pipeline for RA-BC progress generation so large datasets can be processed in parallel and optionally uploaded to the Hub. Made-with: Cursor * fix pr comments * remove comments --- examples/dataset/slurm_compute_rabc.py | 490 +++++++++++++++++++++++++ 1 file changed, 490 insertions(+) create mode 100644 examples/dataset/slurm_compute_rabc.py diff --git a/examples/dataset/slurm_compute_rabc.py b/examples/dataset/slurm_compute_rabc.py new file mode 100644 index 000000000..2ddf84d07 --- /dev/null +++ b/examples/dataset/slurm_compute_rabc.py @@ -0,0 +1,490 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +SLURM-distributed SARM RA-BC annotation pipeline. + +Computes SARM progress values for all frames in a dataset, distributed across +SLURM workers, then merges the shards into a single sarm_progress.parquet. + +Two subcommands, each a separate SLURM submission: + + compute – N workers, each computes progress for a subset of episodes + aggregate – 1 worker, merges N shards into sarm_progress.parquet, pushes to hub + +Usage: + python slurm_compute_rabc.py compute \\ + --repo-id user/dataset --reward-model-path user/sarm_model \\ + --stride 10 --device cpu --workers 50 --partition cpu + + python slurm_compute_rabc.py aggregate \\ + --repo-id user/dataset --reward-model-path user/sarm_model \\ + --partition cpu --push-to-hub +""" + +import argparse +from pathlib import Path + +from datatrove.executor import LocalPipelineExecutor +from datatrove.executor.slurm import SlurmPipelineExecutor +from datatrove.pipeline.base import PipelineStep + + +class ComputeProgressShards(PipelineStep): + """Each worker computes SARM progress for its assigned episodes.""" + + def __init__( + self, repo_id, reward_model_path, stride=1, head_mode="sparse", device="cpu", shard_dir="rabc_shards" + ): + super().__init__() + if stride < 1: + raise ValueError(f"stride must be >= 1, got {stride}") + self.repo_id = repo_id + self.reward_model_path = reward_model_path + self.stride = stride + self.head_mode = head_mode + self.device = device + self.shard_dir = shard_dir + + def run(self, data=None, rank: int = 0, world_size: int = 1): + import logging + from pathlib import Path + + import numpy as np + import pyarrow as pa + import pyarrow.parquet as pq + import torch + from tqdm import tqdm + + from lerobot.policies.sarm.compute_rabc_weights import ( + generate_all_frame_indices, + interpolate_progress, + load_sarm_resources, + ) + from lerobot.utils.utils import init_logging + + init_logging() + + dataset, reward_model, preprocess = load_sarm_resources( + self.repo_id, + self.reward_model_path, + self.device, + ) + + if hasattr(preprocess, "eval"): + preprocess.eval() + for step in preprocess.steps: + if hasattr(step, "eval"): + step.eval() + + image_key = reward_model.config.image_key + state_key = reward_model.config.state_key + frame_gap = reward_model.config.frame_gap + center_idx = reward_model.config.n_obs_steps // 2 + + dual_mode = reward_model.config.uses_dual_heads + compute_sparse = self.head_mode in ("sparse", "both") or not dual_mode + compute_dense = self.head_mode in ("dense", "both") and dual_mode + + my_episodes = list(range(dataset.num_episodes))[rank::world_size] + if not my_episodes: + logging.info(f"Rank {rank}: no episodes assigned") + return + logging.info(f"Rank {rank}: {len(my_episodes)} / {dataset.num_episodes} episodes") + + all_rows = [] + + for ep_idx in tqdm(my_episodes, desc=f"Rank {rank}"): + ep = dataset.meta.episodes[ep_idx] + ep_start, ep_end = ep["dataset_from_index"], ep["dataset_to_index"] + task = dataset[ep_start].get("task", "perform the task") + + all_ep_indices = generate_all_frame_indices(ep_start, ep_end, frame_gap) + if self.stride > 1: + compute_indices = [i for i in all_ep_indices if (i - ep_start) % self.stride == 0] + if (ep_end - 1) not in compute_indices: + compute_indices.append(ep_end - 1) + compute_indices = sorted(set(compute_indices)) + else: + compute_indices = all_ep_indices + + frame_results = {} + for qi in tqdm(compute_indices, desc=f" Ep {ep_idx}", leave=False): + try: + sample = dataset[qi] + batch = { + image_key: sample[image_key], + "task": task, + "index": qi, + "episode_index": ep_idx, + } + if state_key in sample: + batch[state_key] = sample[state_key] + + with torch.no_grad(): + processed = preprocess(batch) + vf = processed["video_features"].to(self.device) + tf = processed["text_features"].to(self.device) + sf = processed.get("state_features") + if sf is not None: + sf = sf.to(self.device) + lengths = processed.get("lengths") + + sparse_val = dense_val = np.nan + if compute_sparse: + r = reward_model.calculate_rewards( + text_embeddings=tf, + video_embeddings=vf, + state_features=sf, + lengths=lengths, + return_all_frames=True, + head_mode="sparse", + ) + sparse_val = float(r[0, center_idx] if r.ndim == 2 else r[center_idx]) + if compute_dense: + r = reward_model.calculate_rewards( + text_embeddings=tf, + video_embeddings=vf, + state_features=sf, + lengths=lengths, + return_all_frames=True, + head_mode="dense", + ) + dense_val = float(r[0, center_idx] if r.ndim == 2 else r[center_idx]) + + frame_results[qi] = (sparse_val, dense_val) + except Exception as e: + logging.warning(f"Failed frame {qi}: {e}") + + if not frame_results: + logging.warning(f"Episode {ep_idx}: all frames failed, skipping") + continue + + # Interpolate to all frames in this episode + computed_idx = np.array(sorted(frame_results.keys())) + all_frame_arr = np.arange(ep_start, ep_end) + + sparse_vals = np.array([frame_results[i][0] for i in computed_idx]) if compute_sparse else None + dense_vals = np.array([frame_results[i][1] for i in computed_idx]) if compute_dense else None + + if self.stride > 1 and len(computed_idx) > 1: + if compute_sparse: + sparse_vals = interpolate_progress(computed_idx, sparse_vals, all_frame_arr) + if compute_dense: + dense_vals = interpolate_progress(computed_idx, dense_vals, all_frame_arr) + output_frames = all_frame_arr + else: + # Use only successfully computed frames to avoid indexing mismatch on failures + output_frames = computed_idx + + for i, fi in enumerate(output_frames): + row = {"index": int(fi), "episode_index": ep_idx, "frame_index": int(fi - ep_start)} + if compute_sparse: + row["progress_sparse"] = float(sparse_vals[i]) + if compute_dense: + row["progress_dense"] = float(dense_vals[i]) + all_rows.append(row) + + if all_rows: + import pandas as pd + + df = pd.DataFrame(all_rows).sort_values("index").reset_index(drop=True) + table = pa.Table.from_pandas(df, preserve_index=False) + table = table.replace_schema_metadata({b"reward_model_path": self.reward_model_path.encode()}) + shard_dir = Path(self.shard_dir) + shard_dir.mkdir(parents=True, exist_ok=True) + out = shard_dir / f"shard_{rank:05d}.parquet" + pq.write_table(table, out) + logging.info(f"Rank {rank}: saved {len(df)} rows to {out}") + + +class AggregateProgress(PipelineStep): + """Merge all shard parquets into final sarm_progress.parquet.""" + + def __init__(self, repo_id, reward_model_path, shard_dir="rabc_shards", push_to_hub=False): + super().__init__() + self.repo_id = repo_id + self.reward_model_path = reward_model_path + self.shard_dir = shard_dir + self.push_to_hub = push_to_hub + + def run(self, data=None, rank: int = 0, world_size: int = 1): + import datetime + import logging + import os + from pathlib import Path + + import pandas as pd + import pyarrow as pa + import pyarrow.parquet as pq + + from lerobot.datasets.lerobot_dataset import LeRobotDataset + from lerobot.utils.utils import init_logging + + init_logging() + if rank != 0: + return + + shard_dir = Path(self.shard_dir) + shards = sorted(shard_dir.glob("shard_*.parquet")) + if not shards: + raise FileNotFoundError(f"No shards found in {shard_dir}") + + # Log shard modification time range to help detect stale files + mtimes = [os.path.getmtime(s) for s in shards] + oldest = datetime.datetime.fromtimestamp(min(mtimes)).isoformat(timespec="seconds") + newest = datetime.datetime.fromtimestamp(max(mtimes)).isoformat(timespec="seconds") + logging.info(f"Aggregating {len(shards)} shards (oldest: {oldest}, newest: {newest})") + + df = pd.concat([pd.read_parquet(s) for s in shards], ignore_index=True) + df = df.sort_values("index").reset_index(drop=True) + + table = pa.Table.from_pandas(df, preserve_index=False) + table = table.replace_schema_metadata({b"reward_model_path": self.reward_model_path.encode()}) + + temp_ds = LeRobotDataset(self.repo_id, download_videos=False) + out_path = Path(temp_ds.root) / "sarm_progress.parquet" + out_path.parent.mkdir(parents=True, exist_ok=True) + pq.write_table(table, out_path) + logging.info(f"Saved {len(df)} rows to {out_path}") + + for col in ["progress_sparse", "progress_dense"]: + if col in df.columns: + v = df[col].dropna() + logging.info( + f"{col}: mean={v.mean():.4f} std={v.std():.4f} min={v.min():.4f} max={v.max():.4f}" + ) + + if self.push_to_hub: + from huggingface_hub import HfApi + + api = HfApi() + hub_path = "sarm_progress.parquet" + logging.info(f"Uploading to {self.repo_id}/{hub_path}") + api.upload_file( + path_or_fileobj=str(out_path), + path_in_repo=hub_path, + repo_id=self.repo_id, + repo_type="dataset", + ) + logging.info(f"Uploaded: https://huggingface.co/datasets/{self.repo_id}/blob/main/{hub_path}") + + +def make_compute_executor( + repo_id, + reward_model_path, + stride, + head_mode, + device, + shard_dir, + logs_dir, + job_name, + slurm, + workers, + partition, + cpus_per_task, + mem_per_cpu, +): + kwargs = { + "pipeline": [ + ComputeProgressShards(repo_id, reward_model_path, stride, head_mode, device, str(shard_dir)), + ], + "logging_dir": str(logs_dir / job_name), + } + + if slurm: + kwargs.update( + { + "job_name": job_name, + "tasks": workers, + "workers": workers, + "time": "24:00:00", + "partition": partition, + "cpus_per_task": cpus_per_task, + "sbatch_args": {"mem-per-cpu": mem_per_cpu}, + } + ) + return SlurmPipelineExecutor(**kwargs) + + kwargs.update({"tasks": workers, "workers": 1}) + return LocalPipelineExecutor(**kwargs) + + +def make_aggregate_executor( + repo_id, + reward_model_path, + shard_dir, + logs_dir, + job_name, + slurm, + partition, + cpus_per_task, + mem_per_cpu, + push_to_hub, +): + kwargs = { + "pipeline": [ + AggregateProgress(repo_id, reward_model_path, str(shard_dir), push_to_hub), + ], + "logging_dir": str(logs_dir / job_name), + } + + if slurm: + kwargs.update( + { + "job_name": job_name, + "tasks": 1, + "workers": 1, + "time": "02:00:00", + "partition": partition, + "cpus_per_task": cpus_per_task, + "sbatch_args": {"mem-per-cpu": mem_per_cpu}, + } + ) + return SlurmPipelineExecutor(**kwargs) + + kwargs.update({"tasks": 1, "workers": 1}) + return LocalPipelineExecutor(**kwargs) + + +def _add_shared_args(p): + p.add_argument( + "--repo-id", + type=str, + required=True, + help="Hugging Face repository identifier, e.g. 'user/dataset'.", + ) + p.add_argument( + "--shard-dir", + type=Path, + default=Path("rabc_shards"), + help="Directory to read/write per-rank parquet shards.", + ) + p.add_argument( + "--logs-dir", + type=Path, + default=Path("logs"), + help="Directory for datatrove logs.", + ) + p.add_argument( + "--job-name", + type=str, + default=None, + help="SLURM job name (defaults to rabc_).", + ) + p.add_argument( + "--slurm", + type=int, + default=1, + help="1 = submit via SLURM; 0 = run locally (useful for debugging).", + ) + p.add_argument( + "--partition", + type=str, + default=None, + help="SLURM partition to submit to.", + ) + p.add_argument( + "--cpus-per-task", + type=int, + default=4, + help="Number of CPUs per SLURM task.", + ) + p.add_argument( + "--mem-per-cpu", + type=str, + default="4G", + help="Memory per CPU, e.g. '4G' or '1950M'.", + ) + + +def main(): + parser = argparse.ArgumentParser( + description="SLURM-distributed SARM RA-BC annotation pipeline", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + sub = parser.add_subparsers(dest="command", required=True) + + # compute subcommand + cp = sub.add_parser( + "compute", + help="Distribute progress computation across SLURM workers.", + ) + _add_shared_args(cp) + cp.add_argument( + "--reward-model-path", + type=str, + required=True, + help="Path or HF repo id of the SARM reward model.", + ) + cp.add_argument( + "--stride", + type=int, + default=1, + help="Compute every Nth frame; intermediate frames are interpolated (must be >= 1).", + ) + cp.add_argument( + "--head-mode", + type=str, + default="sparse", + choices=["sparse", "dense", "both"], + help="Which reward head(s) to compute.", + ) + cp.add_argument( + "--device", + type=str, + default="cpu", + help="Device for reward model inference, e.g. 'cpu' or 'cuda'.", + ) + cp.add_argument( + "--workers", + type=int, + default=50, + help="Number of parallel SLURM tasks (one shard per worker).", + ) + + # aggregate subcommand + ap = sub.add_parser( + "aggregate", + help="Merge per-rank shards into a single sarm_progress.parquet.", + ) + _add_shared_args(ap) + ap.add_argument( + "--reward-model-path", + type=str, + required=True, + help="Path or HF repo id of the SARM reward model (stored in parquet metadata).", + ) + ap.add_argument( + "--push-to-hub", + action="store_true", + help="Upload sarm_progress.parquet to the Hugging Face Hub after aggregation.", + ) + + args = parser.parse_args() + job_name = args.job_name or f"rabc_{args.command}" + kwargs = vars(args) + kwargs["slurm"] = kwargs.pop("slurm") == 1 + kwargs["job_name"] = job_name + command = kwargs.pop("command") + + executor = make_compute_executor(**kwargs) if command == "compute" else make_aggregate_executor(**kwargs) + + executor.run() + + +if __name__ == "__main__": + main() From d324ffe810d17264a0b1e628698aa1fa09aa639c Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 5 Mar 2026 19:53:40 +0100 Subject: [PATCH 090/131] fix(ci): test only multi-gpu tests in multi-gpu runner (#3092) --- .github/workflows/nightly.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 563b5957d..cd1c6f9b7 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -206,5 +206,4 @@ jobs: python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')" - name: Run multi-GPU training tests - # TODO(Steven): Investigate why motors tests are failing in multi-GPU setup - run: pytest tests -vv --maxfail=10 --ignore=tests/motors/ + run: pytest -vv tests/training/ From e489ba24fc212eb24c45fca2056328e5ea5bf8a5 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 6 Mar 2026 10:15:13 +0100 Subject: [PATCH 091/131] feat(dependencies): require Python 3.12+ as minimum version (#3023) * feat(dependecies): upgrade to python3.12 * fix(test): processor regex message * fix(test): processor regex message * fix(dependecies): resolve all tags in python 3.12 * fix(dependecies): add more hints to faster resolve * chore(dependecies): remove cli tag huggingface-hub dep * refactor(policy): update eagle for python3.12 * chore(docs): update policy creation for python 3.12 * chore(test): skip failing tests in macos --- .github/workflows/fast_tests.yml | 2 +- .github/workflows/full_tests.yml | 4 +- .github/workflows/nightly.yml | 2 +- .github/workflows/quality.yml | 2 +- .github/workflows/release.yml | 4 +- .github/workflows/unbound_deps_tests.yml | 2 +- .pre-commit-config.yaml | 4 +- docker/Dockerfile.internal | 2 +- docker/Dockerfile.user | 2 +- docs/source/bring_your_own_policies.mdx | 8 +-- docs/source/earthrover_mini_plus.mdx | 2 +- docs/source/installation.mdx | 6 +-- docs/source/unitree_g1.mdx | 4 +- pyproject.toml | 53 +++++++++++-------- src/lerobot/datasets/utils.py | 6 +-- src/lerobot/motors/motors_bus.py | 8 +-- src/lerobot/policies/factory.py | 3 +- .../image_processing_eagle2_5_vl_fast.py | 27 +++++----- src/lerobot/policies/pi0/modeling_pi0.py | 3 +- src/lerobot/policies/pi05/modeling_pi05.py | 3 +- .../policies/pi0_fast/modeling_pi0_fast.py | 3 +- src/lerobot/policies/pretrained.py | 3 +- .../policies/smolvla/modeling_smolvla.py | 3 +- src/lerobot/processor/core.py | 10 ++-- src/lerobot/processor/pipeline.py | 8 +-- .../robots/so_follower/config_so_follower.py | 5 +- src/lerobot/robots/so_follower/so_follower.py | 5 +- .../so_leader/config_so_leader.py | 5 +- .../teleoperators/so_leader/so_leader.py | 5 +- src/lerobot/utils/io_utils.py | 4 +- tests/policies/test_policies.py | 7 +++ tests/processor/test_pipeline.py | 4 +- tests/utils.py | 3 +- 33 files changed, 106 insertions(+), 106 deletions(-) diff --git a/.github/workflows/fast_tests.yml b/.github/workflows/fast_tests.yml index 27a4043e7..7715823ff 100644 --- a/.github/workflows/fast_tests.yml +++ b/.github/workflows/fast_tests.yml @@ -44,7 +44,7 @@ permissions: # Sets up the environment variables env: UV_VERSION: "0.8.0" - PYTHON_VERSION: "3.10" + PYTHON_VERSION: "3.12" # Ensures that only the latest commit for a PR or branch is built, canceling older runs. concurrency: diff --git a/.github/workflows/full_tests.yml b/.github/workflows/full_tests.yml index 8dd1fcb1c..0e50de879 100644 --- a/.github/workflows/full_tests.yml +++ b/.github/workflows/full_tests.yml @@ -37,7 +37,7 @@ permissions: # Sets up the environment variables env: UV_VERSION: "0.8.0" - PYTHON_VERSION: "3.10" + PYTHON_VERSION: "3.12" DOCKER_IMAGE_NAME: huggingface/lerobot-gpu # Ensures that only the latest action is built, canceling older runs. @@ -185,7 +185,7 @@ jobs: hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential hf auth whoami - name: Fix ptxas permissions - run: chmod +x /lerobot/.venv/lib/python3.10/site-packages/triton/backends/nvidia/bin/ptxas + run: chmod +x /lerobot/.venv/lib/python3.12/site-packages/triton/backends/nvidia/bin/ptxas - name: Run pytest on GPU run: pytest tests -vv --maxfail=10 - name: Run end-to-end tests diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index cd1c6f9b7..95c6702cd 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -28,7 +28,7 @@ on: # Sets up the environment variables env: UV_VERSION: "0.8.0" - PYTHON_VERSION: "3.10" + PYTHON_VERSION: "3.12" DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-cpu:latest DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-gpu:latest diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml index 0dc94cdd4..a84e9c17e 100644 --- a/.github/workflows/quality.yml +++ b/.github/workflows/quality.yml @@ -50,7 +50,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v6 with: - python-version: '3.10' + python-version: '3.12' - name: Run pre-commit hooks uses: pre-commit/action@v3.0.1 # zizmor: ignore[unpinned-uses] diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index bcab4c262..e95d6cef6 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -22,7 +22,7 @@ on: # Sets up the environment variables env: UV_VERSION: "0.8.0" - PYTHON_VERSION: "3.10" + PYTHON_VERSION: "3.12" jobs: # This job builds the Python package and publishes it to PyPI @@ -45,7 +45,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v6 with: - python-version: '3.10' + python-version: '3.12' - name: Extract Version id: extract_info diff --git a/.github/workflows/unbound_deps_tests.yml b/.github/workflows/unbound_deps_tests.yml index 19de38e3b..9ce44152a 100644 --- a/.github/workflows/unbound_deps_tests.yml +++ b/.github/workflows/unbound_deps_tests.yml @@ -29,7 +29,7 @@ permissions: # Sets up the environment variables env: UV_VERSION: "0.8.0" - PYTHON_VERSION: "3.10" + PYTHON_VERSION: "3.12" DOCKER_IMAGE_NAME: huggingface/lerobot-gpu:unbound # Ensures that only the latest action is built, canceling older runs. diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bfa3340d4..dff7416f4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ # limitations under the License. default_language_version: - python: python3.10 + python: python3.12 exclude: "tests/artifacts/.*\\.safetensors$" @@ -55,7 +55,7 @@ repos: rev: v3.21.0 hooks: - id: pyupgrade - args: [--py310-plus] + args: [--py312-plus] ##### Markdown Quality ##### - repo: https://github.com/rbubley/mirrors-prettier diff --git a/docker/Dockerfile.internal b/docker/Dockerfile.internal index ed7d10495..b385fc51c 100644 --- a/docker/Dockerfile.internal +++ b/docker/Dockerfile.internal @@ -24,7 +24,7 @@ ARG OS_VERSION=22.04 FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION} # Define Python version argument -ARG PYTHON_VERSION=3.10 +ARG PYTHON_VERSION=3.12 # Configure environment variables ENV DEBIAN_FRONTEND=noninteractive \ diff --git a/docker/Dockerfile.user b/docker/Dockerfile.user index 031165930..d43d12816 100644 --- a/docker/Dockerfile.user +++ b/docker/Dockerfile.user @@ -19,7 +19,7 @@ # docker run -it --rm lerobot-user # Configure the base image -ARG PYTHON_VERSION=3.10 +ARG PYTHON_VERSION=3.12 FROM python:${PYTHON_VERSION}-slim # Configure environment variables diff --git a/docs/source/bring_your_own_policies.mdx b/docs/source/bring_your_own_policies.mdx index 0ff098708..9266c9e5b 100644 --- a/docs/source/bring_your_own_policies.mdx +++ b/docs/source/bring_your_own_policies.mdx @@ -32,7 +32,7 @@ version = "0.1.0" dependencies = [ # your policy-specific dependencies ] -requires-python = ">= 3.11" +requires-python = ">= 3.12" [build-system] build-backend = # your-build-backend @@ -82,7 +82,7 @@ Create your policy implementation by inheriting from LeRobot's base `PreTrainedP # modeling_my_custom_policy.py import torch import torch.nn as nn -from typing import Dict, Any +from typing import Any from lerobot.policies.pretrained import PreTrainedPolicy from .configuration_my_custom_policy import MyCustomPolicyConfig @@ -91,7 +91,7 @@ class MyCustomPolicy(PreTrainedPolicy): config_class = MyCustomPolicyConfig name = "my_custom_policy" - def __init__(self, config: MyCustomPolicyConfig, dataset_stats: Dict[str, Any] = None): + def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None): super().__init__(config, dataset_stats) ... ``` @@ -102,7 +102,7 @@ Create processor functions: ```python # processor_my_custom_policy.py -from typing import Dict, Any +from typing import Any import torch diff --git a/docs/source/earthrover_mini_plus.mdx b/docs/source/earthrover_mini_plus.mdx index 37986a7a2..7b739ecc1 100644 --- a/docs/source/earthrover_mini_plus.mdx +++ b/docs/source/earthrover_mini_plus.mdx @@ -13,7 +13,7 @@ The EarthRover Mini Plus is a fully open source mobile robot that connects throu ### Hardware - EarthRover Mini robot -- Computer with Python 3.10 or newer +- Computer with Python 3.12 or newer - Internet connection ### Setting Up the Frodobots SDK diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index a112377c1..26c88a0a2 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -1,6 +1,6 @@ # Installation -This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.10 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-). +This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.12 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-). ## Step 1: Install [`miniforge`](https://conda-forge.org/download/) @@ -11,10 +11,10 @@ bash Miniforge3-$(uname)-$(uname -m).sh ## Step 2: Environment Setup -Create a virtual environment with Python 3.10, using conda: +Create a virtual environment with Python 3.12, using conda: ```bash -conda create -y -n lerobot python=3.10 +conda create -y -n lerobot python=3.12 ``` Then activate your conda environment, you have to do this each time you open a shell to use lerobot: diff --git a/docs/source/unitree_g1.mdx b/docs/source/unitree_g1.mdx index 76e972dca..fa7159154 100644 --- a/docs/source/unitree_g1.mdx +++ b/docs/source/unitree_g1.mdx @@ -123,7 +123,7 @@ SSH into the robot and install LeRobot: ```bash ssh unitree@ -conda create -y -n lerobot python=3.10 +conda create -y -n lerobot python=3.12 conda activate lerobot git clone https://github.com/huggingface/lerobot.git cd lerobot @@ -153,7 +153,7 @@ With the robot server running, you can now control the robot remotely. Let's lau ### Step 1: Install LeRobot on your machine ```bash -conda create -y -n lerobot python=3.10 +conda create -y -n lerobot python=3.12 conda activate lerobot git clone https://github.com/huggingface/lerobot.git cd lerobot diff --git a/pyproject.toml b/pyproject.toml index f86184900..7cd83591f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ version = "0.4.5" description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch" dynamic = ["readme"] license = { text = "Apache-2.0" } -requires-python = ">=3.10" +requires-python = ">=3.12" authors = [ { name = "Rémi Cadène", email = "re.cadene@gmail.com" }, { name = "Simon Alibert", email = "alibert.sim@gmail.com" }, @@ -50,7 +50,8 @@ classifiers = [ "Intended Audience :: Education", "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Topic :: Software Development :: Build Tools", "Topic :: Scientific/Engineering :: Artificial Intelligence", ] @@ -61,26 +62,28 @@ dependencies = [ # Hugging Face dependencies "datasets>=4.0.0,<5.0.0", "diffusers>=0.27.2,<0.36.0", - "huggingface-hub[cli]>=1.0.0,<2.0.0", + "huggingface-hub>=1.0.0,<2.0.0", "accelerate>=1.10.0,<2.0.0", # Core dependencies + "numpy>=2.0.0,<2.3.0", # TODO: upper bound imposed by opencv-python-headless "setuptools>=71.0.0,<81.0.0", "cmake>=3.29.0.1,<4.2.0", + "packaging>=24.2,<26.0", + + "torch>=2.2.1,<2.11.0", + "torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", + "torchvision>=0.21.0,<0.26.0", + "einops>=0.8.0,<0.9.0", "opencv-python-headless>=4.9.0,<4.13.0", "av>=15.0.0,<16.0.0", "jsonlines>=4.0.0,<5.0.0", - "packaging>=24.2,<26.0", - "pynput>=1.7.7,<1.9.0", + "pynput>=1.7.8,<1.9.0", "pyserial>=3.5,<4.0", + "wandb>=0.24.0,<0.25.0", - - "torch>=2.2.1,<2.11.0", # TODO: Bump dependency - "torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency - "torchvision>=0.21.0,<0.26.0", # TODO: Bump dependency - - "draccus==0.10.0", # TODO: Remove == + "draccus==0.10.0", # TODO: Relax version constraint "gymnasium>=1.1.1,<2.0.0", "rerun-sdk>=0.24.0,<0.27.0", @@ -95,13 +98,14 @@ dependencies = [ # Common pygame-dep = ["pygame>=2.5.1,<2.7.0"] -placo-dep = ["placo>=0.9.6,<0.10.0"] +placo-dep = ["placo>=0.9.6,<0.9.17"] transformers-dep = ["transformers>=5.3.0,<6.0.0"] grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"] can-dep = ["python-can>=4.2.0,<5.0.0"] peft-dep = ["peft>=0.18.0,<1.0.0"] scipy-dep = ["scipy>=1.14.0,<2.0.0"] qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"] +matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"] # Motors feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"] @@ -119,7 +123,7 @@ unitree_g1 = [ "onnxruntime>=1.16.0,<2.0.0", "pin>=3.0.0,<4.0.0", "meshcat>=0.3.0,<0.4.0", - "matplotlib>=3.9.0,<4.0.0", + "lerobot[matplotlib-dep]", "casadi>=3.6.0,<4.0.0", ] reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"] @@ -128,7 +132,7 @@ intelrealsense = [ "pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'", "pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'", ] -phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"] +phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0", "lerobot[scipy-dep]"] # Policies wallx = [ @@ -151,12 +155,12 @@ groot = [ "ninja>=1.11.1,<2.0.0", "flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'" ] -sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "lerobot[qwen-vl-utils-dep]"] +sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"] xvla = ["lerobot[transformers-dep]"] hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] # Features -async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"] +async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"] peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"] # Development @@ -165,13 +169,18 @@ test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0 video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"] # Simulation -aloha = ["gym-aloha>=0.1.2,<0.2.0"] +aloha = ["gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"] pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead -libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0"] -metaworld = ["metaworld==3.0.0"] +libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"] +metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"] # All all = [ + # Resolver hint: scipy is pulled in transitively via lerobot[scipy-dep] through + # multiple extras below (aloha, metaworld, pi, wallx, phone). Listing it explicitly + # helps pip's resolver converge by constraining scipy early, before it encounters + # the loose scipy requirements from transitive deps like dm-control and metaworld. + "scipy>=1.14.0,<2.0.0", "lerobot[dynamixel]", "lerobot[gamepad]", "lerobot[hopejr]", @@ -192,7 +201,7 @@ all = [ "lerobot[aloha]", "lerobot[pusht]", "lerobot[phone]", - "lerobot[libero]", + "lerobot[libero]; sys_platform == 'linux'", "lerobot[metaworld]", "lerobot[sarm]", "lerobot[peft]", @@ -224,7 +233,7 @@ lerobot = ["envs/*.json"] where = ["src"] [tool.ruff] -target-version = "py310" +target-version = "py312" line-length = 110 exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"] @@ -316,7 +325,7 @@ default.extend-ignore-identifiers-re = [ # Uncomment [tool.mypy] first, then uncomment individual module overrides as they get proper type annotations [tool.mypy] -python_version = "3.10" +python_version = "3.12" ignore_missing_imports = true follow_imports = "skip" # warn_return_any = true diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index a56740191..8bc56a1bd 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -21,7 +21,7 @@ from collections import deque from collections.abc import Iterable, Iterator from pathlib import Path from pprint import pformat -from typing import Any, Generic, TypeVar +from typing import Any import datasets import numpy as np @@ -78,8 +78,6 @@ DEFAULT_FEATURES = { "task_index": {"dtype": "int64", "shape": (1,), "names": None}, } -T = TypeVar("T") - def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float: metadata = pq.read_metadata(parquet_path) @@ -1234,7 +1232,7 @@ class LookAheadError(Exception): pass -class Backtrackable(Generic[T]): +class Backtrackable[T]: """ Wrap any iterator/iterable so you can step back up to `history` items and look ahead up to `lookahead` items. diff --git a/src/lerobot/motors/motors_bus.py b/src/lerobot/motors/motors_bus.py index bc3ffb7e2..509f5e95f 100644 --- a/src/lerobot/motors/motors_bus.py +++ b/src/lerobot/motors/motors_bus.py @@ -29,7 +29,7 @@ from dataclasses import dataclass from enum import Enum from functools import cached_property from pprint import pformat -from typing import Protocol, TypeAlias +from typing import Protocol import serial from deepdiff import DeepDiff @@ -38,8 +38,8 @@ from tqdm import tqdm from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.utils import enter_pressed, move_cursor_up -NameOrID: TypeAlias = str | int -Value: TypeAlias = int | float +type NameOrID = str | int +type Value = int | float logger = logging.getLogger(__name__) @@ -1277,4 +1277,4 @@ class SerialMotorsBus(MotorsBusBase): # Backward compatibility alias -MotorsBus: TypeAlias = SerialMotorsBus +MotorsBus = SerialMotorsBus diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index a593e5bcb..d50d8652a 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -18,10 +18,9 @@ from __future__ import annotations import importlib import logging -from typing import Any, TypedDict +from typing import Any, TypedDict, Unpack import torch -from typing_extensions import Unpack from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType diff --git a/src/lerobot/policies/groot/eagle2_hg_model/image_processing_eagle2_5_vl_fast.py b/src/lerobot/policies/groot/eagle2_hg_model/image_processing_eagle2_5_vl_fast.py index e01b9b839..90e9dcecc 100644 --- a/src/lerobot/policies/groot/eagle2_hg_model/image_processing_eagle2_5_vl_fast.py +++ b/src/lerobot/policies/groot/eagle2_hg_model/image_processing_eagle2_5_vl_fast.py @@ -4,10 +4,9 @@ # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- +from __future__ import annotations # copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py -from typing import Optional - from transformers.image_processing_utils import ( BatchFeature, get_patch_output_size, @@ -165,11 +164,11 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast): def _resize_for_patching( self, - image: "torch.Tensor", + image: torch.Tensor, target_resolution: tuple, - interpolation: "F.InterpolationMode", + interpolation: F.InterpolationMode, input_data_format: ChannelDimension, - ) -> "torch.Tensor": + ) -> torch.Tensor: """ Resizes an image to a target resolution while maintaining aspect ratio. @@ -219,8 +218,8 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast): return best_ratio def _pad_for_patching( - self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension - ) -> "torch.Tensor": + self, image: torch.Tensor, target_resolution: tuple, input_data_format: ChannelDimension + ) -> torch.Tensor: """ Pad an image to a target resolution while maintaining aspect ratio. """ @@ -236,15 +235,15 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast): def _get_image_patches( self, - image: "torch.Tensor", + image: torch.Tensor, min_num: int, max_num: int, size: tuple, tile_size: int, use_thumbnail: bool, - interpolation: "F.InterpolationMode", + interpolation: F.InterpolationMode, pad_during_tiling: bool, - ) -> list["torch.Tensor"]: + ) -> list[torch.Tensor]: image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST) orig_height, orig_width = image_size aspect_ratio = orig_width / orig_height @@ -305,8 +304,8 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast): def _pad_for_batching( self, - pixel_values: list["torch.Tensor"], - ) -> list["torch.Tensor"]: + pixel_values: list[torch.Tensor], + ) -> list[torch.Tensor]: """ Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches. @@ -327,14 +326,14 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast): def _preprocess( self, - images: list["torch.Tensor"], + images: list[torch.Tensor], do_resize: bool, size: SizeDict, max_dynamic_tiles: int, min_dynamic_tiles: int, use_thumbnail: bool, pad_during_tiling: bool, - interpolation: Optional["F.InterpolationMode"], + interpolation: F.InterpolationMode | None, do_center_crop: bool, crop_size: SizeDict, do_rescale: bool, diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 2f77e9517..aebf32964 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -20,12 +20,11 @@ import logging import math from collections import deque from pathlib import Path -from typing import TYPE_CHECKING, Literal, TypedDict +from typing import TYPE_CHECKING, Literal, TypedDict, Unpack import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn -from typing_extensions import Unpack from lerobot.utils.import_utils import _transformers_available diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index dc5eb20ec..96c4002f2 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -20,12 +20,11 @@ import logging import math from collections import deque from pathlib import Path -from typing import TYPE_CHECKING, Literal, TypedDict +from typing import TYPE_CHECKING, Literal, TypedDict, Unpack import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn -from typing_extensions import Unpack from lerobot.utils.import_utils import _transformers_available diff --git a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py index 52fc2504d..1bcf9794c 100644 --- a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py +++ b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py @@ -19,13 +19,12 @@ import logging import math from collections import deque from pathlib import Path -from typing import TYPE_CHECKING, Literal, TypedDict +from typing import TYPE_CHECKING, Literal, TypedDict, Unpack import numpy as np import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn -from typing_extensions import Unpack from lerobot.utils.import_utils import _scipy_available, _transformers_available diff --git a/src/lerobot/policies/pretrained.py b/src/lerobot/policies/pretrained.py index e730b78a7..70efeba6f 100644 --- a/src/lerobot/policies/pretrained.py +++ b/src/lerobot/policies/pretrained.py @@ -19,7 +19,7 @@ import os from importlib.resources import files from pathlib import Path from tempfile import TemporaryDirectory -from typing import TypedDict, TypeVar +from typing import TypedDict, TypeVar, Unpack import packaging import safetensors @@ -28,7 +28,6 @@ from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE from huggingface_hub.errors import HfHubHTTPError from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor from torch import Tensor, nn -from typing_extensions import Unpack from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.train import TrainPipelineConfig diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index e49226d26..430c85481 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -54,12 +54,11 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base") import math from collections import deque -from typing import TypedDict +from typing import TypedDict, Unpack import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn -from typing_extensions import Unpack from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.rtc.modeling_rtc import RTCProcessor diff --git a/src/lerobot/processor/core.py b/src/lerobot/processor/core.py index 0b293c9b0..d9b8166c5 100644 --- a/src/lerobot/processor/core.py +++ b/src/lerobot/processor/core.py @@ -17,7 +17,7 @@ from __future__ import annotations from enum import Enum -from typing import Any, TypeAlias, TypedDict +from typing import Any, TypedDict import numpy as np import torch @@ -36,10 +36,10 @@ class TransitionKey(str, Enum): COMPLEMENTARY_DATA = "complementary_data" -PolicyAction: TypeAlias = torch.Tensor -RobotAction: TypeAlias = dict[str, Any] -EnvAction: TypeAlias = np.ndarray -RobotObservation: TypeAlias = dict[str, Any] +PolicyAction = torch.Tensor +RobotAction = dict[str, Any] +EnvAction = np.ndarray +RobotObservation = dict[str, Any] EnvTransition = TypedDict( diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 8de376928..db1c3015c 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -39,7 +39,7 @@ from collections.abc import Callable, Iterable, Sequence from copy import deepcopy from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, cast +from typing import Any, TypedDict, TypeVar, cast import torch from huggingface_hub import hf_hub_download @@ -251,7 +251,7 @@ class ProcessorMigrationError(Exception): @dataclass -class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]): +class DataProcessorPipeline[TInput, TOutput](HubMixin): """A sequential pipeline for processing data, integrated with the Hugging Face Hub. This class chains together multiple `ProcessorStep` instances to form a complete @@ -1432,8 +1432,8 @@ class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]): # Type aliases for semantic clarity. -RobotProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput] -PolicyProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput] +RobotProcessorPipeline = DataProcessorPipeline[TInput, TOutput] +PolicyProcessorPipeline = DataProcessorPipeline[TInput, TOutput] class ObservationProcessorStep(ProcessorStep, ABC): diff --git a/src/lerobot/robots/so_follower/config_so_follower.py b/src/lerobot/robots/so_follower/config_so_follower.py index 1ee589bda..52f7953de 100644 --- a/src/lerobot/robots/so_follower/config_so_follower.py +++ b/src/lerobot/robots/so_follower/config_so_follower.py @@ -15,7 +15,6 @@ # limitations under the License. from dataclasses import dataclass, field -from typing import TypeAlias from lerobot.cameras import CameraConfig @@ -50,5 +49,5 @@ class SOFollowerRobotConfig(RobotConfig, SOFollowerConfig): pass -SO100FollowerConfig: TypeAlias = SOFollowerRobotConfig -SO101FollowerConfig: TypeAlias = SOFollowerRobotConfig +SO100FollowerConfig = SOFollowerRobotConfig +SO101FollowerConfig = SOFollowerRobotConfig diff --git a/src/lerobot/robots/so_follower/so_follower.py b/src/lerobot/robots/so_follower/so_follower.py index bc72a2b6a..c898e9137 100644 --- a/src/lerobot/robots/so_follower/so_follower.py +++ b/src/lerobot/robots/so_follower/so_follower.py @@ -17,7 +17,6 @@ import logging import time from functools import cached_property -from typing import TypeAlias from lerobot.cameras.utils import make_cameras_from_configs from lerobot.motors import Motor, MotorCalibration, MotorNormMode @@ -230,5 +229,5 @@ class SOFollower(Robot): logger.info(f"{self} disconnected.") -SO100Follower: TypeAlias = SOFollower -SO101Follower: TypeAlias = SOFollower +SO100Follower = SOFollower +SO101Follower = SOFollower diff --git a/src/lerobot/teleoperators/so_leader/config_so_leader.py b/src/lerobot/teleoperators/so_leader/config_so_leader.py index 2b4f782a7..189303088 100644 --- a/src/lerobot/teleoperators/so_leader/config_so_leader.py +++ b/src/lerobot/teleoperators/so_leader/config_so_leader.py @@ -15,7 +15,6 @@ # limitations under the License. from dataclasses import dataclass -from typing import TypeAlias from ..config import TeleoperatorConfig @@ -38,5 +37,5 @@ class SOLeaderTeleopConfig(TeleoperatorConfig, SOLeaderConfig): pass -SO100LeaderConfig: TypeAlias = SOLeaderTeleopConfig -SO101LeaderConfig: TypeAlias = SOLeaderTeleopConfig +SO100LeaderConfig = SOLeaderTeleopConfig +SO101LeaderConfig = SOLeaderTeleopConfig diff --git a/src/lerobot/teleoperators/so_leader/so_leader.py b/src/lerobot/teleoperators/so_leader/so_leader.py index a10e3a61f..04ce0f21f 100644 --- a/src/lerobot/teleoperators/so_leader/so_leader.py +++ b/src/lerobot/teleoperators/so_leader/so_leader.py @@ -16,7 +16,6 @@ import logging import time -from typing import TypeAlias from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.feetech import ( @@ -156,5 +155,5 @@ class SOLeader(Teleoperator): logger.info(f"{self} disconnected.") -SO100Leader: TypeAlias = SOLeader -SO101Leader: TypeAlias = SOLeader +SO100Leader = SOLeader +SO101Leader = SOLeader diff --git a/src/lerobot/utils/io_utils.py b/src/lerobot/utils/io_utils.py index da0be1c77..d70ea8b6a 100644 --- a/src/lerobot/utils/io_utils.py +++ b/src/lerobot/utils/io_utils.py @@ -16,12 +16,10 @@ import json import warnings from pathlib import Path -from typing import TypeVar import imageio JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...] -T = TypeVar("T", bound=JsonLike) def write_video(video_path, stacked_frames, fps): @@ -33,7 +31,7 @@ def write_video(video_path, stacked_frames, fps): imageio.mimsave(video_path, stacked_frames, fps=fps) -def deserialize_json_into_object(fpath: Path, obj: T) -> T: +def deserialize_json_into_object[T: JsonLike](fpath: Path, obj: T) -> T: """ Loads the JSON data from `fpath` and recursively fills `obj` with the corresponding values (strictly matching structure and types). diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index 345526d90..1ba82ffd0 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -143,12 +143,18 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs): Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive, and for now we add tests as we see fit. """ + if policy_name == "vqbet" and DEVICE == "mps": + pytest.skip("VQBet does not support MPS backend") + if policy_name == "act" and "aloha" in ds_repo_id and DEVICE == "mps": + pytest.skip("ACT with aloha has batch mutation issues on MPS") + train_cfg = TrainPipelineConfig( # TODO(rcadene, aliberts): remove dataset download dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]), policy=make_policy_config(policy_name, push_to_hub=False, **policy_kwargs), env=make_env_config(env_name, **env_kwargs), ) + train_cfg.policy.device = DEVICE train_cfg.validate() # Check that we can make the policy object. @@ -227,6 +233,7 @@ def test_act_backbone_lr(): dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]), policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001, push_to_hub=False), ) + cfg.policy.device = DEVICE cfg.validate() # Needed for auto-setting some parameters assert cfg.policy.optimizer_lr == 0.01 diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 58a83fe69..a335c2b4b 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -1870,9 +1870,7 @@ class NonCallableStep(ProcessorStep): def test_construction_rejects_step_without_call(): """Test that DataProcessorPipeline rejects steps that don't inherit from ProcessorStep.""" - with pytest.raises( - TypeError, match=r"Can't instantiate abstract class NonCallableStep with abstract method __call_" - ): + with pytest.raises(TypeError, match=r"Can't instantiate abstract class NonCallableStep"): DataProcessorPipeline([NonCallableStep()]) with pytest.raises(TypeError, match=r"must inherit from ProcessorStep"): diff --git a/tests/utils.py b/tests/utils.py index 38841db02..d9bdc7f93 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -22,8 +22,9 @@ import torch from lerobot import available_cameras, available_motors, available_robots from lerobot.utils.import_utils import is_package_available +from lerobot.utils.utils import auto_select_torch_device -DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu" +DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", str(auto_select_torch_device())) TEST_ROBOT_TYPES = [] for robot_type in available_robots: From a225127527becc4fe299919e1c4b240c43cdc37a Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 6 Mar 2026 10:50:46 +0100 Subject: [PATCH 092/131] chore(dependencies): sync intelrealsense + added notes (#3094) --- pyproject.toml | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7cd83591f..d75d6b788 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,7 +66,7 @@ dependencies = [ "accelerate>=1.10.0,<2.0.0", # Core dependencies - "numpy>=2.0.0,<2.3.0", # TODO: upper bound imposed by opencv-python-headless + "numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless. "setuptools>=71.0.0,<81.0.0", "cmake>=3.29.0.1,<4.2.0", "packaging>=24.2,<26.0", @@ -105,7 +105,7 @@ can-dep = ["python-can>=4.2.0,<5.0.0"] peft-dep = ["peft>=0.18.0,<1.0.0"] scipy-dep = ["scipy>=1.14.0,<2.0.0"] qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"] -matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"] +matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"] # NOTE: Explicitly listing contourpy helps the resolver converge faster. # Motors feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"] @@ -130,7 +130,7 @@ reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"] kinematics = ["lerobot[placo-dep]"] intelrealsense = [ "pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'", - "pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'", + "pyrealsense2-macosx>=2.54,<2.57.0 ; sys_platform == 'darwin'", ] phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0", "lerobot[scipy-dep]"] @@ -169,6 +169,7 @@ test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0 video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"] # Simulation +# NOTE: Explicitly listing scipy helps flatten the dependecy tree. aloha = ["gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"] pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"] @@ -176,8 +177,8 @@ metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"] # All all = [ - # Resolver hint: scipy is pulled in transitively via lerobot[scipy-dep] through - # multiple extras below (aloha, metaworld, pi, wallx, phone). Listing it explicitly + # NOTE(resolver hint): scipy is pulled in transitively via lerobot[scipy-dep] through + # multiple extras (aloha, metaworld, pi, wallx, phone). Listing it explicitly # helps pip's resolver converge by constraining scipy early, before it encounters # the loose scipy requirements from transitive deps like dm-control and metaworld. "scipy>=1.14.0,<2.0.0", From a4c66e530bd1edc805db5ca69640f0ee88e7ea3c Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 6 Mar 2026 15:52:54 +0100 Subject: [PATCH 093/131] chore(docs): remove pi installation note (#3095) --- .github/workflows/release.yml | 8 -------- docs/source/installation.mdx | 3 --- docs/source/pi0.mdx | 5 ----- docs/source/pi05.mdx | 5 ----- docs/source/pi0fast.mdx | 5 ----- 5 files changed, 26 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index e95d6cef6..f7bd2be6c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -83,14 +83,6 @@ jobs: exit 1 fi - - name: Remove Tags with Git dependencies - # TODO(Steven): Temporary patch to remove pi from PyPi 0.4.0 release due to its reliance on git dependencies. - run: | - echo "::info:: Checking for Git dependencies to remove from pyproject.toml..." - grep -E '@ git\+https|lerobot\[pi\]' pyproject.toml | sed 's/^/::warning:: Removing line: /' || true - sed -E -i '/@ git\+https|lerobot\[pi\]/d' pyproject.toml - echo "::info:: Git dependencies removed. Proceeding with build." - - name: Install build dependencies run: python -m pip install build diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 26c88a0a2..6d29215a0 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -90,9 +90,6 @@ _Replace `[...]` with your desired features._ For a full list of optional dependencies, see: https://pypi.org/project/lerobot/ -> [!NOTE] -> For lerobot 0.4.0, if you want to install pi, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"` - ### Troubleshooting If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`. diff --git a/docs/source/pi0.mdx b/docs/source/pi0.mdx index 879bbd16d..be7792b28 100644 --- a/docs/source/pi0.mdx +++ b/docs/source/pi0.mdx @@ -34,11 +34,6 @@ As described by Physical Intelligence, while AI has achieved remarkable success pip install -e ".[pi]" ``` - > [!NOTE] - > For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`. - > - > This will be solved in the next patch release - ## Training Data and Capabilities π₀ is trained on the largest robot interaction dataset to date, combining three key data sources: diff --git a/docs/source/pi05.mdx b/docs/source/pi05.mdx index 8abaca989..f586f0dc1 100644 --- a/docs/source/pi05.mdx +++ b/docs/source/pi05.mdx @@ -36,11 +36,6 @@ This diverse training mixture creates a "curriculum" that enables generalization pip install -e ".[pi]" ``` - > [!NOTE] - > For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`. - > - > This will be solved in the next patch release - ## Usage To use π₀.₅ in your LeRobot configuration, specify the policy type as: diff --git a/docs/source/pi0fast.mdx b/docs/source/pi0fast.mdx index 85d975924..f7272acc5 100644 --- a/docs/source/pi0fast.mdx +++ b/docs/source/pi0fast.mdx @@ -43,11 +43,6 @@ This approach can transform **any existing VLM** into a VLA by training it to pr pip install -e ".[pi]" ``` - > [!NOTE] - > For lerobot 0.4.0, if you want to install the pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`. - > - > This will be solved in the next patch release - ## Training a Custom FAST Tokenizer You have two options for the FAST tokenizer: From 85de893fa73fa939f1d57d209b2c04be11a0a71e Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 6 Mar 2026 16:33:43 +0100 Subject: [PATCH 094/131] fix(ci): skip HF log in (and tests) in forks and community PRs (#3097) * fix(ci): skip HF log in (and tests) in forks and community PRs * chore(test): remove comment about test meant to be only run locally * fix(tests): no hf log in decorator for xvla * fix(test): no decorator in yield --- .github/workflows/fast_tests.yml | 1 + .github/workflows/full_tests.yml | 2 ++ .github/workflows/nightly.yml | 3 +++ .github/workflows/unbound_deps_tests.yml | 2 ++ .../test_pi0_fast_original_vs_lerobot.py | 14 +++++++++++++- tests/policies/pi0_pi05/test_pi0.py | 9 +++++++-- tests/policies/pi0_pi05/test_pi05.py | 9 +++++++-- .../pi0_pi05/test_pi05_original_vs_lerobot.py | 2 +- .../pi0_pi05/test_pi0_original_vs_lerobot.py | 2 +- tests/policies/wall_x/test_wallx.py | 6 ++++-- .../xvla/test_xvla_original_vs_lerobot.py | 2 +- tests/utils.py | 16 ++++++++++++++++ 12 files changed, 58 insertions(+), 10 deletions(-) diff --git a/.github/workflows/fast_tests.yml b/.github/workflows/fast_tests.yml index 7715823ff..fc169e253 100644 --- a/.github/workflows/fast_tests.yml +++ b/.github/workflows/fast_tests.yml @@ -91,6 +91,7 @@ jobs: run: uv sync --extra "test" - name: Login to Hugging Face + if: env.HF_USER_TOKEN != '' run: | uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential uv run hf auth whoami diff --git a/.github/workflows/full_tests.yml b/.github/workflows/full_tests.yml index 0e50de879..8b7d28123 100644 --- a/.github/workflows/full_tests.yml +++ b/.github/workflows/full_tests.yml @@ -89,6 +89,7 @@ jobs: run: uv sync --extra all # TODO(Steven): Make flash-attn optional - name: Login to Hugging Face + if: env.HF_USER_TOKEN != '' run: | uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential uv run hf auth whoami @@ -181,6 +182,7 @@ jobs: working-directory: /lerobot steps: - name: Login to Hugging Face + if: env.HF_USER_TOKEN != '' run: | hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential hf auth whoami diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 95c6702cd..5bc86857a 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -132,6 +132,7 @@ jobs: working-directory: /lerobot steps: - name: Login to Hugging Face + if: env.HF_USER_TOKEN != '' run: | hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential hf auth whoami @@ -164,6 +165,7 @@ jobs: working-directory: /lerobot steps: - name: Login to Hugging Face + if: env.HF_USER_TOKEN != '' run: | hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential hf auth whoami @@ -197,6 +199,7 @@ jobs: working-directory: /lerobot steps: - name: Login to Hugging Face + if: env.HF_USER_TOKEN != '' run: | hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential hf auth whoami diff --git a/.github/workflows/unbound_deps_tests.yml b/.github/workflows/unbound_deps_tests.yml index 9ce44152a..404816c52 100644 --- a/.github/workflows/unbound_deps_tests.yml +++ b/.github/workflows/unbound_deps_tests.yml @@ -81,6 +81,7 @@ jobs: - name: Install lerobot with all extras run: uv sync --extra all # TODO(Steven): Make flash-attn optional - name: Login to Hugging Face + if: env.HF_USER_TOKEN != '' run: | uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential uv run hf auth whoami @@ -154,6 +155,7 @@ jobs: working-directory: /lerobot steps: - name: Login to Hugging Face + if: env.HF_USER_TOKEN != '' run: | hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential hf auth whoami diff --git a/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py b/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py index 9de781464..d24bb11d7 100644 --- a/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py +++ b/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py @@ -40,7 +40,7 @@ from lerobot.utils.constants import ( OBS_LANGUAGE_TOKENS, OBS_STATE, ) # noqa: E402 -from tests.utils import require_cuda # noqa: E402 +from tests.utils import require_cuda, require_hf_token # noqa: E402 # Constants DUMMY_ACTION_DIM = 7 @@ -65,6 +65,7 @@ EXPECTED_ACTIONS_FIRST_5 = torch.tensor([0.0000, 0.3536, 0.0707, 0.0000, 0.0000] @require_cuda +@require_hf_token def set_seed_all(seed: int): """Set random seed for all RNG sources to ensure reproducibility.""" random.seed(seed) @@ -82,6 +83,7 @@ def set_seed_all(seed: int): @require_cuda +@require_hf_token def instantiate_lerobot_pi0_fast( from_pretrained: bool = False, model_path: str = MODEL_PATH_LEROBOT, @@ -125,6 +127,7 @@ def instantiate_lerobot_pi0_fast( @require_cuda +@require_hf_token def create_dummy_data(device=DEVICE): """Create dummy data for testing both implementations.""" batch_size = 1 @@ -157,6 +160,7 @@ def create_dummy_data(device=DEVICE): # Pytest fixtures @pytest.fixture(scope="module") @require_cuda +@require_hf_token def pi0_fast_components(): """Fixture to instantiate and provide all PI0Fast components for tests.""" print(f"\nTesting with DEVICE='{DEVICE}'") @@ -168,6 +172,7 @@ def pi0_fast_components(): @pytest.fixture(scope="module") @require_cuda +@require_hf_token def policy(pi0_fast_components): """Fixture to provide the PI0Fast policy for tests.""" return pi0_fast_components[0] @@ -175,12 +180,14 @@ def policy(pi0_fast_components): @pytest.fixture(scope="module") @require_cuda +@require_hf_token def preprocessor(pi0_fast_components): """Fixture to provide the PI0Fast preprocessor for tests.""" return pi0_fast_components[1] @require_cuda +@require_hf_token def test_pi0_fast_preprocessor_alignment(policy, preprocessor): """Test that LeRobot PI0Fast preprocessor produces expected outputs.""" print("\n" + "=" * 80) @@ -228,6 +235,7 @@ def test_pi0_fast_preprocessor_alignment(policy, preprocessor): @require_cuda +@require_hf_token def test_pi0_fast_action_generation(policy, preprocessor): """Test PI0Fast LeRobot implementation generates expected actions.""" print("\n" + "=" * 80) @@ -306,6 +314,7 @@ def test_pi0_fast_action_generation(policy, preprocessor): @require_cuda +@require_hf_token def test_pi0_fast_inference_reproducibility(policy, preprocessor): """Test that PI0Fast inference is reproducible with the same seed.""" print("\n" + "=" * 80) @@ -347,6 +356,7 @@ def test_pi0_fast_inference_reproducibility(policy, preprocessor): @require_cuda +@require_hf_token def test_pi0_fast_forward_pass_logits(policy, preprocessor): """Test PI0Fast forward pass and compare logits against expected values.""" print("\n" + "=" * 80) @@ -396,6 +406,7 @@ def test_pi0_fast_forward_pass_logits(policy, preprocessor): @require_cuda +@require_hf_token def test_pi0_fast_action_token_sampling(policy, preprocessor): """Test PI0Fast action token sampling (autoregressive decoding).""" print("\n" + "=" * 80) @@ -452,6 +463,7 @@ def test_pi0_fast_action_token_sampling(policy, preprocessor): @require_cuda +@require_hf_token def test_pi0_fast_detokenization(policy, preprocessor): """Test PI0Fast action detokenization (FAST decoding).""" print("\n" + "=" * 80) diff --git a/tests/policies/pi0_pi05/test_pi0.py b/tests/policies/pi0_pi05/test_pi0.py index e83abf57d..5a985e03c 100644 --- a/tests/policies/pi0_pi05/test_pi0.py +++ b/tests/policies/pi0_pi05/test_pi0.py @@ -14,10 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Test script to verify PI0 policy integration with LeRobot, only meant to be run locally!""" +"""Test script to verify PI0 policy integration with LeRobot""" +import pytest import torch +pytest.importorskip("transformers") + from lerobot.policies.factory import make_policy_config # noqa: E402 from lerobot.policies.pi0 import ( # noqa: E402 PI0Config, @@ -25,10 +28,11 @@ from lerobot.policies.pi0 import ( # noqa: E402 make_pi0_pre_post_processors, # noqa: E402 ) from lerobot.utils.random_utils import set_seed # noqa: E402 -from tests.utils import require_cuda # noqa: E402 +from tests.utils import require_cuda, require_hf_token # noqa: E402 @require_cuda +@require_hf_token def test_policy_instantiation(): # Create config set_seed(42) @@ -105,6 +109,7 @@ def test_policy_instantiation(): @require_cuda +@require_hf_token def test_config_creation(): """Test policy config creation through factory.""" try: diff --git a/tests/policies/pi0_pi05/test_pi05.py b/tests/policies/pi0_pi05/test_pi05.py index 595191689..f0da2971b 100644 --- a/tests/policies/pi0_pi05/test_pi05.py +++ b/tests/policies/pi0_pi05/test_pi05.py @@ -14,10 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Test script to verify PI0.5 (pi05) support in PI0 policy, only meant to be run locally!""" +"""Test script to verify PI0.5 (pi05) support in PI0 policy""" +import pytest import torch +pytest.importorskip("transformers") + from lerobot.policies.factory import make_policy_config # noqa: E402 from lerobot.policies.pi05 import ( # noqa: E402 PI05Config, @@ -25,10 +28,11 @@ from lerobot.policies.pi05 import ( # noqa: E402 make_pi05_pre_post_processors, # noqa: E402 ) from lerobot.utils.random_utils import set_seed -from tests.utils import require_cuda # noqa: E402 +from tests.utils import require_cuda, require_hf_token # noqa: E402 @require_cuda +@require_hf_token def test_policy_instantiation(): # Create config set_seed(42) @@ -141,6 +145,7 @@ def test_policy_instantiation(): @require_cuda +@require_hf_token def test_config_creation(): """Test policy config creation through factory.""" try: diff --git a/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py b/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py index 0d5244e1c..f70707262 100644 --- a/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py +++ b/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation, only meant to be run locally!""" +"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation""" import os from copy import deepcopy diff --git a/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py b/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py index 41db2dceb..d3d1c1908 100644 --- a/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py +++ b/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Test script to verify PI0 policy integration with LeRobot vs the original implementation, only meant to be run locally!""" +"""Test script to verify PI0 policy integration with LeRobot vs the original implementation""" import os from copy import deepcopy diff --git a/tests/policies/wall_x/test_wallx.py b/tests/policies/wall_x/test_wallx.py index 3514fccd1..85656eca2 100644 --- a/tests/policies/wall_x/test_wallx.py +++ b/tests/policies/wall_x/test_wallx.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Test script to verify Wall-X policy integration with LeRobot, only meant to be run locally!""" +"""Test script to verify Wall-X policy integration with LeRobot""" import pytest import torch @@ -29,10 +29,11 @@ from lerobot.policies.wall_x import WallXConfig # noqa: E402 from lerobot.policies.wall_x.modeling_wall_x import WallXPolicy # noqa: E402 from lerobot.policies.wall_x.processor_wall_x import make_wall_x_pre_post_processors # noqa: E402 from lerobot.utils.random_utils import set_seed # noqa: E402 -from tests.utils import require_cuda # noqa: E402 +from tests.utils import require_cuda, require_hf_token # noqa: E402 @require_cuda +@require_hf_token def test_policy_instantiation(): # Create config set_seed(42) @@ -118,6 +119,7 @@ def test_policy_instantiation(): @require_cuda +@require_hf_token def test_config_creation(): """Test policy config creation through factory.""" try: diff --git a/tests/policies/xvla/test_xvla_original_vs_lerobot.py b/tests/policies/xvla/test_xvla_original_vs_lerobot.py index a9603fdb0..e36d14d01 100644 --- a/tests/policies/xvla/test_xvla_original_vs_lerobot.py +++ b/tests/policies/xvla/test_xvla_original_vs_lerobot.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Test script to verify XVLA policy integration with LeRobot vs the original implementation, only meant to be run locally!""" +"""Test script to verify XVLA policy integration with LeRobot vs the original implementation""" # ruff: noqa: E402 import random diff --git a/tests/utils.py b/tests/utils.py index d9bdc7f93..a77082ea9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -108,6 +108,22 @@ def require_cuda(func): return wrapper +def require_hf_token(func): + """ + Decorator that skips the test if no Hugging Face Hub token is available. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + from huggingface_hub import get_token + + if get_token() is None: + pytest.skip("requires HF token for gated model access") + return func(*args, **kwargs) + + return wrapper + + def require_env(func): """ Decorator that skips the test if the required environment package is not installed. From 6139b133caa1fc3116fb7508464dff57228d7d0e Mon Sep 17 00:00:00 2001 From: "Shun.Sasaki" Date: Sat, 7 Mar 2026 01:14:14 +0900 Subject: [PATCH 095/131] fix(async_inference): restore robot module imports in robot_client.py (#3081) --- src/lerobot/async_inference/robot_client.py | 9 ++++-- tests/async_inference/test_robot_client.py | 36 +++++++++++++++++++++ 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/src/lerobot/async_inference/robot_client.py b/src/lerobot/async_inference/robot_client.py index da576eb48..0ee70a0e6 100644 --- a/src/lerobot/async_inference/robot_client.py +++ b/src/lerobot/async_inference/robot_client.py @@ -49,9 +49,14 @@ import torch from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 -from lerobot.robots import ( - RobotConfig, # noqa: F401 +from lerobot.robots import ( # noqa: F401 + Robot, + RobotConfig, + bi_so_follower, + koch_follower, make_robot_from_config, + omx_follower, + so_follower, ) from lerobot.transport import ( services_pb2, # type: ignore diff --git a/tests/async_inference/test_robot_client.py b/tests/async_inference/test_robot_client.py index 5b138d91b..d7ef5b350 100644 --- a/tests/async_inference/test_robot_client.py +++ b/tests/async_inference/test_robot_client.py @@ -231,3 +231,39 @@ def test_ready_to_send_observation_with_varying_threshold(robot_client, g_thresh robot_client.action_queue.put(act) assert robot_client._ready_to_send_observation() is expected + + +# ----------------------------------------------------------------------------- +# Regression test: robot type registry populated by robot_client imports +# ----------------------------------------------------------------------------- + + +def test_robot_client_registers_builtin_robot_types(): + """Importing robot_client must populate RobotConfig's ChoiceRegistry. + + This is a regression test for a bug introduced in #2425, where removing + robot module imports from robot_client.py caused RobotConfig's registry to + be empty, breaking CLI argument parsing with: + error: argument --robot.type: invalid choice: 'so101_follower' (choose from ) + + Robot types are registered via @RobotConfig.register_subclass() decorators + at import time, so all supported modules must be explicitly imported. + """ + import lerobot.async_inference.robot_client # noqa: F401 + from lerobot.robots.config import RobotConfig + + known_choices = RobotConfig.get_known_choices() + + expected_robot_types = [ + "so100_follower", + "so101_follower", + "koch_follower", + "omx_follower", + "bi_so_follower", + ] + for robot_type in expected_robot_types: + assert robot_type in known_choices, ( + f"Robot type '{robot_type}' is not registered in RobotConfig's ChoiceRegistry. " + f"Ensure the corresponding module is imported in robot_client.py. " + f"Known choices: {sorted(known_choices)}" + ) From 4f2ef024d847695c61b552981ff942235dca5bea Mon Sep 17 00:00:00 2001 From: Martino Russi <77496684+nepyope@users.noreply.github.com> Date: Sun, 8 Mar 2026 11:33:24 +0100 Subject: [PATCH 096/131] feat(robots): Unitree G1 WBC implementation (#2876) * move locomotion from examples to robot, move controller to teleoperator class * modify teleoperate to send back actions to robot * whole body controller * add holosoma to locomotros * various updates * update joint zeroing etc * ensure safefail with locomotion * add unitree locomotion * launch camera from g1 server * publish at varying framerates * fix async read in camera * attempting to fix camera lag * test camera speedup * training * inference works * remove logging from pi0 * remove logging * push local changes * testing * final changes * revert control_utils * revert utils * revert * revert g1 * revert again: * revert utils * push recents * remove examples * remove junk * remove mjlog * revergt edit_dataset * Update lerobot_edit_dataset.py Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com> * undo teleop changes * revert logging * remove loggings * remove loogs * revert dataset tools * Update dataset_tools.py Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com> * move gravity to utils * revert changes * remove matplotlib viewer (rerun works fine) * factory revert * send policy action directly * recent changes * implement flexible action space * send empty command if arms are missing * rename locomotion to controller * add init * implement feedback * add feedback for teleoperator * fix ruff * fix ruff * use read_latest * fix zmq camera * revert exo_serial * simplify PR * revert exo_changes * revert camera_zmq * Update camera_zmq.py Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com> * remove frame duplication from zmq server * revert channerfactoryinitialize * keep channelfactoryinitialize * remove zeroing out logic * fix typo * refactor teleop class * simplify teleop further * import armindex at the top * fix visualizer again * revert ik helper * push stuff * simplify image_server * update image_server * asd * add threading logic * simplify ik helper stuff * simplify holosoma * fix names * fix docs * revert leg override * clean connect * fix controller * fix ruff * clean teleoperator * set_from_wireless * avoid double initializations * refactor robot class * fix pre-commit * update docs * update docs format * add teleop instructions * unitree_g1 specific exception in record/teleoperate * add thumbnail to docs * add thumbnail to doc * refactor(unitree): multiple improvements (#3103) * refactor(unitree): multiple improvements * test(unitree): added tests + improved installation instructions * refactor(robots): minor changes unitree robot kinematic * chore(robots): rename g1 kinematics file --------- Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com> Signed-off-by: Steven Palma Co-authored-by: Steven Palma Co-authored-by: Steven Palma --- docs/source/unitree_g1.mdx | 326 +++++++-------- examples/rtc/eval_with_real_robot.py | 2 + pyproject.toml | 3 + src/lerobot/cameras/zmq/camera_zmq.py | 2 +- src/lerobot/cameras/zmq/image_server.py | 76 +++- src/lerobot/robots/unitree_g1/__init__.py | 2 + .../robots/unitree_g1/config_unitree_g1.py | 9 +- ...inematic_processor.py => g1_kinematics.py} | 66 +-- src/lerobot/robots/unitree_g1/g1_utils.py | 41 +- .../robots}/unitree_g1/gr00t_locomotion.py | 157 +++---- .../robots}/unitree_g1/holosoma_locomotion.py | 174 +++----- .../robots/unitree_g1/run_g1_server.py | 31 ++ src/lerobot/robots/unitree_g1/unitree_g1.py | 384 ++++++++++++------ .../robots/unitree_g1/unitree_sdk2_socket.py | 11 +- src/lerobot/scripts/lerobot_record.py | 6 +- src/lerobot/scripts/lerobot_teleoperate.py | 5 +- .../teleoperators/unitree_g1/__init__.py | 10 + .../teleoperators/unitree_g1/exo_calib.py | 24 +- .../teleoperators/unitree_g1/exo_ik.py | 2 +- .../teleoperators/unitree_g1/exo_serial.py | 37 +- .../teleoperators/unitree_g1/unitree_g1.py | 195 ++++++++- src/lerobot/utils/import_utils.py | 2 + tests/robots/test_unitree_g1.py | 267 ++++++++++++ .../test_unitree_g1_teleoperator.py | 309 ++++++++++++++ 24 files changed, 1504 insertions(+), 637 deletions(-) rename src/lerobot/robots/unitree_g1/{robot_kinematic_processor.py => g1_kinematics.py} (87%) rename {examples => src/lerobot/robots}/unitree_g1/gr00t_locomotion.py (59%) rename {examples => src/lerobot/robots}/unitree_g1/holosoma_locomotion.py (53%) create mode 100644 tests/robots/test_unitree_g1.py create mode 100644 tests/teleoperators/test_unitree_g1_teleoperator.py diff --git a/docs/source/unitree_g1.mdx b/docs/source/unitree_g1.mdx index fa7159154..39bd7832b 100644 --- a/docs/source/unitree_g1.mdx +++ b/docs/source/unitree_g1.mdx @@ -1,23 +1,49 @@ # Unitree G1 -This guide covers the complete setup process for the Unitree G1 humanoid, from initial connection to running gr00t_wbc locomotion. +Unitree G1 locomanipulation demo -## About - -We support both 29 and 23 DOF G1 EDU version. We introduce: - -- **`unitree g1` robot class, handling low level read/write from/to the humanoid** -- **ZMQ socket bridge** for remote communication and camera streaming, allowing for remote policy deployment over wlan, eth or directly on the robot -- **Locomotion policies** from NVIDIA gr00t and Amazon FAR Holosoma -- **Simulation mode** for testing policies without the physical robot in mujoco +The Unitree G1 humanoid is now supported in LeRobot! You can teleoperate, train locomanipulation policies, test in sim, and more. Both 29 and 23 DoF variants are supported. --- -## Connection guide +## Part 1: Getting Started -### Step 1: Configure Ethernet Interface +### Install LeRobot on Your Machine -Set a static IP on the same subnet as the robot: +```bash +conda create -y -n lerobot python=3.12 +conda activate lerobot +git clone https://github.com/unitreerobotics/unitree_sdk2_python.git +cd unitree_sdk2_python && pip install -e . +git clone https://github.com/huggingface/lerobot.git +cd lerobot +pip install -e '.[unitree_g1]' +``` + +### Test the Installation (Simulation) + +```bash +lerobot-teleoperate \ + --robot.type=unitree_g1 \ + --robot.is_simulation=true \ + --teleop.type=unitree_g1 \ + --teleop.id=wbc_unitree \ + --robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \ + --display_data=true +``` + +This will launch a [MuJoCo sim instance](https://huggingface.co/lerobot/unitree-g1-mujoco/tree/main) for the G1. + +- Press `9` to release the robot +- Press `7` / `8` to increase / decrease waist height + +### Connect to the Robot + +The G1's Ethernet IP is fixed at `192.168.123.164`. Your machine must have a static IP on the same subnet: `192.168.123.x` where `x ≠ 164`. ```bash # Replace 'enp131s0' with your ethernet interface name (check with `ip a`) @@ -26,272 +52,200 @@ sudo ip addr add 192.168.123.200/24 dev enp131s0 sudo ip link set enp131s0 up ``` -**Note**: The G1's Ethernet IP is fixed at `192.168.123.164`. Your computer must use `192.168.123.x` with x ≠ 164. - -### Step 2: SSH into the Robot +### SSH into the Robot ```bash ssh unitree@192.168.123.164 # Password: 123 ``` -You should now be connected to the G1's Orin. +### Install LeRobot on the G1 + +From the robot: + +```bash +conda create -y -n lerobot python=3.12 +conda activate lerobot +git clone https://github.com/unitreerobotics/unitree_sdk2_python.git +cd unitree_sdk2_python && pip install -e . +git clone https://github.com/huggingface/lerobot.git +cd lerobot +pip install -e '.[unitree_g1]' +``` + +> **Note:** The Unitree SDK requires CycloneDDS v0.10.2. See the [Unitree SDK docs](https://github.com/unitreerobotics/unitree_sdk2_python) for details. --- ## Part 2: Enable WiFi on the Robot -Wlan0 is disabled by default on the G1. To enable it: - -### Step 1: Enable WiFi Hardware +Wi-Fi connectivity is blocked by default on the G1. To activate: ```bash -sudo rfkill unblock wifi sudo rfkill unblock all - -# Bring up wlan0 sudo ip link set wlan0 up - -# Enable NetworkManager control of wlan0 sudo nmcli radio wifi on sudo nmcli device set wlan0 managed yes sudo systemctl restart NetworkManager ``` -### Step 2: Enable Internet Forwarding - -**On your laptop:** +**On your laptop** (share internet via Ethernet): ```bash -# Enable IP forwarding sudo sysctl -w net.ipv4.ip_forward=1 -# Set up NAT (replace wlp132s0f0 with your WiFi interface) +# Replace wlp132s0f0 with your WiFi interface name sudo iptables -t nat -A POSTROUTING -o wlp132s0f0 -s 192.168.123.0/24 -j MASQUERADE sudo iptables -A FORWARD -i wlp132s0f0 -o enp131s0 -m state --state RELATED,ESTABLISHED -j ACCEPT sudo iptables -A FORWARD -i enp131s0 -o wlp132s0f0 -j ACCEPT ``` -**On the G1:** +**On the G1** (set default route through your laptop): ```bash -# Add laptop as default gateway sudo ip route del default 2>/dev/null || true sudo ip route add default via 192.168.123.200 dev eth0 echo "nameserver 8.8.8.8" | sudo tee /etc/resolv.conf -# Test connection +# Verify ping -c 3 8.8.8.8 ``` -### Step 3: Connect to WiFi Network +**Connect to a WiFi network:** ```bash -# List available networks nmcli device wifi list -# Connect to your WiFi (example) sudo nmcli connection add type wifi ifname wlan0 con-name "YourNetwork" ssid "YourNetwork" sudo nmcli connection modify "YourNetwork" wifi-sec.key-mgmt wpa-psk sudo nmcli connection modify "YourNetwork" wifi-sec.psk "YourPassword" sudo nmcli connection modify "YourNetwork" connection.autoconnect yes sudo nmcli connection up "YourNetwork" -# Check WiFi IP address ip a show wlan0 ``` -### Step 4: SSH Over WiFi - -Once connected to WiFi, note the robot's IP address and disconnect the Ethernet cable. You can now SSH over WiFi: +You can now SSH over WiFi: ```bash -ssh unitree@ +ssh unitree@ # Password: 123 ``` -Replace `` with your robot's actual WiFi IP address. - --- -## Part 3: Robot Server Setup +## Part 3: Teleoperation & Locomotion -### Step 1: Install LeRobot on the Orin - -SSH into the robot and install LeRobot: - -```bash -ssh unitree@ - -conda create -y -n lerobot python=3.12 -conda activate lerobot -git clone https://github.com/huggingface/lerobot.git -cd lerobot -pip install -e '.[unitree_g1]' -git clone https://github.com/unitreerobotics/unitree_sdk2_python.git -cd unitree_sdk2_python && pip install -e . -``` - -**Note**: The Unitree SDK requires CycloneDDS v0.10.2 to be installed. See the [Unitree SDK documentation](https://github.com/unitreerobotics/unitree_sdk2_python) for details. - -### Step 2: Run the Robot Server +### Run the Robot Server On the robot: ```bash -python src/lerobot/robots/unitree_g1/run_g1_server.py +python src/lerobot/robots/unitree_g1/run_g1_server.py --camera ``` -**Important**: Keep this terminal running. The server must be active for remote control. +### Run the Locomotion Policy + +```bash +lerobot-teleoperate \ + --robot.type=unitree_g1 \ + --robot.is_simulation=false \ + --robot.robot_ip= \ + --teleop.type=unitree_g1 \ + --teleop.id=wbc_unitree \ + --robot.cameras='{"global_view": {"type": "zmq", "server_address": "", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \ + --display_data=true \ + --robot.controller=HolosomaLocomotionController +``` + +We support both [HolosomaLocomotionController](https://github.com/amazon-far/holosoma) and [GrootLocomotionController](https://github.com/NVlabs/GR00T-WholeBodyControl). --- -## Part 4: Controlling the robot +## Part 4: Loco-Manipulation with the Homunculus Exoskeleton -With the robot server running, you can now control the robot remotely. Let's launch a locomotion policy +We provide a loco-manipulation solution via the Homunculus Exoskeleton — an open-source 7 DoF exoskeleton for whole-body control. Assembly instructions [here](https://github.com/nepyope/hmc_exo). -### Step 1: Install LeRobot on your machine - -```bash -conda create -y -n lerobot python=3.12 -conda activate lerobot -git clone https://github.com/huggingface/lerobot.git -cd lerobot -pip install -e '.[unitree_g1]' -git clone https://github.com/unitreerobotics/unitree_sdk2_python.git -cd unitree_sdk2_python && pip install -e . -``` - -### Step 2: Update Robot IP in Config - -Edit the config file to match your robot's WiFi IP: - -```python -# In src/lerobot/robots/unitree_g1/config_unitree_g1.py -robot_ip: str = "" # Replace with your robot's WiFi IP. -``` - -### Step 3: Run the Locomotion Policy - -```bash -# Run GR00T locomotion controller -python examples/unitree_g1/gr00t_locomotion.py --repo-id "nepyope/GR00T-WholeBodyControl_g1" - -# Run Holosoma locomotion controller -python examples/unitree_g1/holosoma_locomotion.py - -``` - -Press `Ctrl+C` to stop the policy. - ---- - -## Running in Simulation Mode (MuJoCo) - -You can test policies before deploying on the physical robot using MuJoCo simulation. Set `is_simulation=True` in config or pass `--robot.is_simulation=true` via CLI. - -### Calibrate Exoskeleton Teleoperator +### Calibrate ```bash lerobot-calibrate \ - --teleop.type=unitree_g1 \ - --teleop.left_arm_config.port=/dev/ttyACM1 \ - --teleop.right_arm_config.port=/dev/ttyACM0 \ - --teleop.id=exo + --teleop.type=unitree_g1 \ + --teleop.left_arm_config.port=/dev/ttyACM1 \ + --teleop.right_arm_config.port=/dev/ttyACM0 \ + --teleop.id=exo ``` -### Teleoperate in Simulation +During calibration move each joint through its entire range. After fitting, move the joint in a neutral position and press `n` to advance. -```bash -lerobot-teleoperate \ - --robot.type=unitree_g1 \ - --robot.is_simulation=true \ - --teleop.type=unitree_g1 \ - --teleop.left_arm_config.port=/dev/ttyACM1 \ - --teleop.right_arm_config.port=/dev/ttyACM0 \ - --teleop.id=exo \ - --fps=100 -``` - -### Record Dataset in Simulation +### Record a Dataset ```bash lerobot-record \ - --robot.type=unitree_g1 \ - --robot.is_simulation=true \ - --robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \ - --teleop.type=unitree_g1 \ - --teleop.left_arm_config.port=/dev/ttyACM1 \ - --teleop.right_arm_config.port=/dev/ttyACM0 \ - --teleop.id=exo \ - --dataset.repo_id=your-username/dataset-name \ - --dataset.single_task="Test" \ - --dataset.num_episodes=2 \ - --dataset.episode_time_s=5 \ - --dataset.reset_time_s=5 \ - --dataset.push_to_hub=true \ - --dataset.streaming_encoding=true \ - # --dataset.vcodec=auto \ - --dataset.encoder_threads=2 + --robot.type=unitree_g1 \ + --robot.is_simulation=true \ + --robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \ + --teleop.type=unitree_g1 \ + --teleop.left_arm_config.port=/dev/ttyACM1 \ + --teleop.right_arm_config.port=/dev/ttyACM0 \ + --teleop.id=exo \ + --dataset.repo_id=your-username/dataset-name \ + --dataset.single_task="Test" \ + --dataset.num_episodes=2 \ + --dataset.episode_time_s=5 \ + --dataset.reset_time_s=5 \ + --dataset.push_to_hub=true \ + --dataset.streaming_encoding=true \ + --dataset.encoder_threads=2 ``` -Example simulation dataset: [nepyope/teleop_test_sim](https://huggingface.co/datasets/nepyope/teleop_test_sim) +> **Note:** Omit `--teleop.left_arm_config.port` and `--teleop.right_arm_config.port` if you're only using the joystick. + +Example dataset: [nepyope/unitree_box_move_blue_full](https://huggingface.co/datasets/nepyope/unitree_box_move_blue_full) --- -## Running on Real Robot +## Part 5: Training & Inference -Once the robot server is running on the G1 (see Part 3), you can teleoperate and record on the real robot. - -### Start the Camera Server - -On the robot, start the ZMQ image server: +### Train ```bash -python src/lerobot/cameras/zmq/image_server.py +python src/lerobot/scripts/lerobot_train.py \ + --dataset.repo_id=your-username/dataset-name \ + --policy.type=pi05 \ + --output_dir=./outputs/pi05_training \ + --job_name=pi05_training \ + --policy.repo_id=your-username/your-repo-id \ + --policy.pretrained_path=lerobot/pi05_base \ + --policy.compile_model=true \ + --policy.gradient_checkpointing=true \ + --wandb.enable=true \ + --policy.dtype=bfloat16 \ + --policy.freeze_vision_encoder=false \ + --policy.train_expert_only=false \ + --steps=3000 \ + --policy.device=cuda \ + --batch_size=32 ``` -Keep this running in a separate terminal for camera streaming during recording. +### Inference with RTC -### Teleoperate Real Robot +Once trained, we recommend deploying policies using inference-time RTC: ```bash -lerobot-teleoperate \ - --robot.type=unitree_g1 \ - --robot.is_simulation=false \ - --teleop.type=unitree_g1 \ - --teleop.left_arm_config.port=/dev/ttyACM1 \ - --teleop.right_arm_config.port=/dev/ttyACM0 \ - --teleop.id=exo \ - --fps=100 +python examples/rtc/eval_with_real_robot.py \ + --policy.path=your-username/your-repo-id \ + --policy.device=cuda \ + --robot.type=unitree_g1 \ + --robot.is_simulation=false \ + --robot.controller=HolosomaLocomotionController \ + --robot.cameras='{"global_view": {"type": "zmq", "server_address": "", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \ + --task="task_description" \ + --duration=1000 \ + --fps=30 \ + --rtc.enabled=true ``` -### Record Dataset on Real Robot - -```bash -lerobot-record \ - --robot.type=unitree_g1 \ - --robot.is_simulation=false \ - --robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \ - --teleop.type=unitree_g1 \ - --teleop.left_arm_config.port=/dev/ttyACM1 \ - --teleop.right_arm_config.port=/dev/ttyACM0 \ - --teleop.id=exo \ - --dataset.repo_id=your-username/dataset-name \ - --dataset.single_task="Test" \ - --dataset.num_episodes=2 \ - --dataset.episode_time_s=5 \ - --dataset.reset_time_s=5 \ - --dataset.push_to_hub=true \ - --dataset.streaming_encoding=true \ - # --dataset.vcodec=auto \ - --dataset.encoder_threads=2 -``` - -**Note**: Update `server_address` to match your robot's camera server IP. - -Example real robot dataset: [nepyope/teleop_test_real](https://huggingface.co/datasets/nepyope/teleop_test_real) - --- ## Additional Resources @@ -300,8 +254,8 @@ Example real robot dataset: [nepyope/teleop_test_real](https://huggingface.co/da - [GR00T-WholeBodyControl](https://github.com/NVlabs/GR00T-WholeBodyControl) - [Holosoma](https://github.com/amazon-far/holosoma) - [LeRobot Documentation](https://github.com/huggingface/lerobot) -- [Unitree_IL_Lerobot](https://github.com/unitreerobotics/unitree_IL_lerobot) +- [Unitree IL LeRobot](https://github.com/unitreerobotics/unitree_IL_lerobot) --- -_Last updated: December 2025_ +_Last updated: March 2026_ diff --git a/examples/rtc/eval_with_real_robot.py b/examples/rtc/eval_with_real_robot.py index 4c803eb7e..9d9e1364a 100644 --- a/examples/rtc/eval_with_real_robot.py +++ b/examples/rtc/eval_with_real_robot.py @@ -78,6 +78,7 @@ from torch import Tensor from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 +from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401 from lerobot.configs import parser from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import RTCAttentionSchedule @@ -97,6 +98,7 @@ from lerobot.robots import ( # noqa: F401 bi_so_follower, koch_follower, so_follower, + unitree_g1, ) from lerobot.robots.utils import make_robot_from_config from lerobot.utils.constants import OBS_IMAGES diff --git a/pyproject.toml b/pyproject.toml index d75d6b788..696d8597d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,11 +119,13 @@ gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"] hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"] lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"] unitree_g1 = [ + "unitree-sdk2==1.0.1", "pyzmq>=26.2.1,<28.0.0", "onnxruntime>=1.16.0,<2.0.0", "pin>=3.0.0,<4.0.0", "meshcat>=0.3.0,<0.4.0", "lerobot[matplotlib-dep]", + "lerobot[pygame-dep]", "casadi>=3.6.0,<4.0.0", ] reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"] @@ -206,6 +208,7 @@ all = [ "lerobot[metaworld]", "lerobot[sarm]", "lerobot[peft]", + # "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2 ] [project.scripts] diff --git a/src/lerobot/cameras/zmq/camera_zmq.py b/src/lerobot/cameras/zmq/camera_zmq.py index 16523b50a..2fbe50d8b 100644 --- a/src/lerobot/cameras/zmq/camera_zmq.py +++ b/src/lerobot/cameras/zmq/camera_zmq.py @@ -181,7 +181,7 @@ class ZMQCamera(Camera): try: message = self.socket.recv_string() except Exception as e: - # Check for ZMQ timeout (EAGAIN/Again) without requiring global zmq import + # zmq is lazy-imported in connect(), so check by name to avoid a top-level import if type(e).__name__ == "Again": raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e raise diff --git a/src/lerobot/cameras/zmq/image_server.py b/src/lerobot/cameras/zmq/image_server.py index 2da366cef..8222b9fee 100644 --- a/src/lerobot/cameras/zmq/image_server.py +++ b/src/lerobot/cameras/zmq/image_server.py @@ -23,6 +23,7 @@ import base64 import contextlib import json import logging +import threading import time from collections import deque @@ -42,10 +43,57 @@ def encode_image(image: np.ndarray, quality: int = 80) -> str: return base64.b64encode(buffer).decode("utf-8") +class CameraCaptureThread: + """Background thread that continuously captures and encodes frames from a camera.""" + + def __init__(self, camera: OpenCVCamera, name: str): + self.camera = camera + self.name = name + self.latest_encoded: str | None = None # Pre-encoded JPEG as base64 + self.latest_timestamp: float = 0.0 + self.frame_lock = threading.Lock() + self.running = False + self.thread: threading.Thread | None = None + + def start(self): + """Start the capture thread.""" + self.running = True + self.thread = threading.Thread(target=self._capture_loop, daemon=True) + self.thread.start() + + def stop(self): + """Stop the capture thread.""" + self.running = False + if self.thread: + self.thread.join(timeout=1.0) + + def _capture_loop(self): + """Continuously capture and encode frames at the camera's native rate.""" + while self.running: + try: + frame = self.camera.read() # Blocks at camera's native rate + timestamp = time.time() + # Encode immediately in capture thread (this is the slow part) + encoded = encode_image(frame) + with self.frame_lock: + self.latest_encoded = encoded + self.latest_timestamp = timestamp + except Exception as e: + logger.warning(f"Camera {self.name} capture error: {e}") + time.sleep(0.01) + + def get_latest(self) -> tuple[str | None, float]: + """Get the latest encoded frame and its timestamp.""" + with self.frame_lock: + return self.latest_encoded, self.latest_timestamp + + class ImageServer: def __init__(self, config: dict, port: int = 5555): + # fps controls the publish loop rate (how often frames are sent over ZMQ), not the camera capture rate self.fps = config.get("fps", 30) self.cameras: dict[str, OpenCVCamera] = {} + self.capture_threads: dict[str, CameraCaptureThread] = {} for name, cfg in config.get("cameras", {}).items(): shape = cfg.get("shape", [480, 640]) @@ -61,6 +109,10 @@ class ImageServer: self.cameras[name] = camera logger.info(f"Camera {name}: {shape[1]}x{shape[0]}") + # Create capture thread for this camera + capture_thread = CameraCaptureThread(camera, name) + self.capture_threads[name] = capture_thread + # ZMQ PUB socket self.context = zmq.Context() self.socket = self.context.socket(zmq.PUB) @@ -73,6 +125,18 @@ class ImageServer: def run(self): frame_count = 0 frame_times = deque(maxlen=60) + last_published_ts: dict[str, float] = {} + + # Start all capture threads + for capture_thread in self.capture_threads.values(): + capture_thread.start() + + # Wait for first frames to be captured and encoded + logger.info("Waiting for cameras to start capturing...") + for name, capture_thread in self.capture_threads.items(): + while capture_thread.get_latest()[0] is None: + time.sleep(0.01) + logger.info(f"Camera {name} ready (capture + encode in background)") try: while True: @@ -80,10 +144,12 @@ class ImageServer: # Build message message = {"timestamps": {}, "images": {}} - for name, cam in self.cameras.items(): - frame = cam.read() # Returns RGB - message["timestamps"][name] = time.time() - message["images"][name] = encode_image(frame) + for name, capture_thread in self.capture_threads.items(): + encoded, timestamp = capture_thread.get_latest() + if encoded is not None and timestamp > last_published_ts.get(name, 0.0): + message["timestamps"][name] = timestamp + message["images"][name] = encoded + last_published_ts[name] = timestamp # Send as JSON string (suppress if buffer full) with contextlib.suppress(zmq.Again): @@ -102,6 +168,8 @@ class ImageServer: except KeyboardInterrupt: pass finally: + for capture_thread in self.capture_threads.values(): + capture_thread.stop() for cam in self.cameras.values(): cam.disconnect() self.socket.close() diff --git a/src/lerobot/robots/unitree_g1/__init__.py b/src/lerobot/robots/unitree_g1/__init__.py index d91be150f..ef3a9d05e 100644 --- a/src/lerobot/robots/unitree_g1/__init__.py +++ b/src/lerobot/robots/unitree_g1/__init__.py @@ -16,3 +16,5 @@ from .config_unitree_g1 import UnitreeG1Config from .unitree_g1 import UnitreeG1 + +__all__ = ["UnitreeG1", "UnitreeG1Config"] diff --git a/src/lerobot/robots/unitree_g1/config_unitree_g1.py b/src/lerobot/robots/unitree_g1/config_unitree_g1.py index 1b81214a6..b786c2a33 100644 --- a/src/lerobot/robots/unitree_g1/config_unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/config_unitree_g1.py @@ -27,11 +27,10 @@ _GAINS: dict[str, dict[str, list[float]]] = { }, # pitch, roll, yaw, knee, ankle_pitch, ankle_roll "right_leg": {"kp": [150, 150, 150, 300, 40, 40], "kd": [2, 2, 2, 4, 2, 2]}, "waist": {"kp": [250, 250, 250], "kd": [5, 5, 5]}, # yaw, roll, pitch - "left_arm": {"kp": [80, 80, 80, 80], "kd": [3, 3, 3, 3]}, # shoulder_pitch/roll/yaw, elbow + "left_arm": {"kp": [50, 50, 80, 80], "kd": [3, 3, 3, 3]}, # shoulder_pitch/roll/yaw, elbow "left_wrist": {"kp": [40, 40, 40], "kd": [1.5, 1.5, 1.5]}, # roll, pitch, yaw - "right_arm": {"kp": [80, 80, 80, 80], "kd": [3, 3, 3, 3]}, + "right_arm": {"kp": [50, 50, 80, 80], "kd": [3, 3, 3, 3]}, "right_wrist": {"kp": [40, 40, 40], "kd": [1.5, 1.5, 1.5]}, - "other": {"kp": [80, 80, 80, 80, 80, 80], "kd": [3, 3, 3, 3, 3, 3]}, } @@ -68,3 +67,7 @@ class UnitreeG1Config(RobotConfig): # Compensates for gravity on the unitree's arms using the arm ik solver gravity_compensation: bool = False + + # Lower-body controller class name, e.g. "GrootLocomotionController" or + # "HolosomaLocomotionController". None disables it. + controller: str | None = None diff --git a/src/lerobot/robots/unitree_g1/robot_kinematic_processor.py b/src/lerobot/robots/unitree_g1/g1_kinematics.py similarity index 87% rename from src/lerobot/robots/unitree_g1/robot_kinematic_processor.py rename to src/lerobot/robots/unitree_g1/g1_kinematics.py index d086a9986..f57320a11 100644 --- a/src/lerobot/robots/unitree_g1/robot_kinematic_processor.py +++ b/src/lerobot/robots/unitree_g1/g1_kinematics.py @@ -16,13 +16,11 @@ import logging import os -import sys +from collections import deque import numpy as np logger = logging.getLogger(__name__) -parent2_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -sys.path.append(parent2_dir) class WeightedMovingFilter: @@ -31,18 +29,14 @@ class WeightedMovingFilter: self._weights = np.array(weights) self._data_size = data_size self._filtered_data = np.zeros(self._data_size) - self._data_queue = [] + self._data_queue = deque(maxlen=self._window_size) def _apply_filter(self): if len(self._data_queue) < self._window_size: return self._data_queue[-1] data_array = np.array(self._data_queue) - temp_filtered_data = np.zeros(self._data_size) - for i in range(self._data_size): - temp_filtered_data[i] = np.convolve(data_array[:, i], self._weights, mode="valid")[-1] - - return temp_filtered_data + return data_array.T @ self._weights def add_data(self, new_data): assert len(new_data) == self._data_size @@ -52,9 +46,6 @@ class WeightedMovingFilter: ): # skip duplicate data return - if len(self._data_queue) >= self._window_size: - self._data_queue.pop(0) - self._data_queue.append(new_data) self._filtered_data = self._apply_filter() @@ -71,8 +62,6 @@ class G1_29_ArmIK: # noqa: N801 from pinocchio import casadi as cpin self._pin = pin - np.set_printoptions(precision=5, suppress=True, linewidth=200) - self.unit_test = unit_test self.repo_path = snapshot_download("lerobot/unitree-g1-mujoco") @@ -249,50 +238,35 @@ class G1_29_ArmIK: # noqa: N801 self.opti.set_value(self.param_tf_r, right_wrist) self.opti.set_value(self.var_q_last, self.init_data) # for smooth + converged = True try: self.opti.solve() - sol_q = self.opti.value(self.var_q) - self.smooth_filter.add_data(sol_q) - sol_q = self.smooth_filter.filtered_data - - if current_lr_arm_motor_dq is not None: - v = current_lr_arm_motor_dq * 0.0 - else: - v = (sol_q - self.init_data) * 0.0 - - self.init_data = sol_q - - sol_tauff = self._pin.rnea( - self.reduced_robot.model, - self.reduced_robot.data, - sol_q, - v, - np.zeros(self.reduced_robot.model.nv), - ) - - return sol_q, sol_tauff - except Exception as e: - logger.error(f"ERROR in convergence, plotting debug info.{e}") - + converged = False + logger.error(f"IK convergence error: {e}") sol_q = self.opti.debug.value(self.var_q) - self.smooth_filter.add_data(sol_q) - sol_q = self.smooth_filter.filtered_data - if current_lr_arm_motor_dq is not None: - v = current_lr_arm_motor_dq * 0.0 - else: - v = (sol_q - self.init_data) * 0.0 - - self.init_data = sol_q + self.smooth_filter.add_data(sol_q) + sol_q = self.smooth_filter.filtered_data + self.init_data = sol_q + if not converged: logger.error( f"sol_q:{sol_q} \nmotorstate: \n{current_lr_arm_motor_q} \nleft_pose: \n{left_wrist} \nright_pose: \n{right_wrist}" ) - return current_lr_arm_motor_q, np.zeros(self.reduced_robot.model.nv) + sol_tauff = self._pin.rnea( + self.reduced_robot.model, + self.reduced_robot.data, + sol_q, + np.zeros(self.reduced_robot.model.nv), + np.zeros(self.reduced_robot.model.nv), + ) + + return sol_q, sol_tauff + def solve_tau(self, current_lr_arm_motor_q=None, current_lr_arm_motor_dq=None): try: q_g1 = np.array(current_lr_arm_motor_q, dtype=float) diff --git a/src/lerobot/robots/unitree_g1/g1_utils.py b/src/lerobot/robots/unitree_g1/g1_utils.py index 4e37bdcef..91f009b26 100644 --- a/src/lerobot/robots/unitree_g1/g1_utils.py +++ b/src/lerobot/robots/unitree_g1/g1_utils.py @@ -14,12 +14,34 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib from enum import IntEnum +import numpy as np + # ruff: noqa: N801, N815 NUM_MOTORS = 29 +REMOTE_AXES = ("remote.lx", "remote.ly", "remote.rx", "remote.ry") +REMOTE_BUTTONS = tuple(f"remote.button.{i}" for i in range(16)) +REMOTE_KEYS = REMOTE_AXES + REMOTE_BUTTONS + + +def default_remote_input() -> dict[str, float]: + """Return a zeroed-out remote input dict (axes + buttons).""" + return dict.fromkeys(REMOTE_KEYS, 0.0) + + +def get_gravity_orientation(quaternion: list[float] | np.ndarray) -> np.ndarray: + """Get gravity orientation from quaternion [w, x, y, z].""" + qw, qx, qy, qz = quaternion + gravity_orientation = np.zeros(3, dtype=np.float32) + gravity_orientation[0] = 2 * (-qz * qx + qw * qy) + gravity_orientation[1] = -2 * (qz * qy + qw * qx) + gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz) + return gravity_orientation + class G1_29_JointArmIndex(IntEnum): # Left arm @@ -29,7 +51,7 @@ class G1_29_JointArmIndex(IntEnum): kLeftElbow = 18 kLeftWristRoll = 19 kLeftWristPitch = 20 - kLeftWristyaw = 21 + kLeftWristYaw = 21 # Right arm kRightShoulderPitch = 22 @@ -41,6 +63,21 @@ class G1_29_JointArmIndex(IntEnum): kRightWristYaw = 28 +def make_locomotion_controller(name: str | None): + """Instantiate a locomotion controller by class name. Returns None if name is None.""" + if name is None: + return None + controllers = { + "GrootLocomotionController": "lerobot.robots.unitree_g1.gr00t_locomotion", + "HolosomaLocomotionController": "lerobot.robots.unitree_g1.holosoma_locomotion", + } + module_path = controllers.get(name) + if module_path is None: + raise ValueError(f"Unknown controller: {name!r}. Available: {list(controllers)}") + module = importlib.import_module(module_path) + return getattr(module, name)() + + class G1_29_JointIndex(IntEnum): # Left leg kLeftHipPitch = 0 @@ -69,7 +106,7 @@ class G1_29_JointIndex(IntEnum): kLeftElbow = 18 kLeftWristRoll = 19 kLeftWristPitch = 20 - kLeftWristyaw = 21 + kLeftWristYaw = 21 # Right arm kRightShoulderPitch = 22 diff --git a/examples/unitree_g1/gr00t_locomotion.py b/src/lerobot/robots/unitree_g1/gr00t_locomotion.py similarity index 59% rename from examples/unitree_g1/gr00t_locomotion.py rename to src/lerobot/robots/unitree_g1/gr00t_locomotion.py index 0123b5206..31166e123 100644 --- a/examples/unitree_g1/gr00t_locomotion.py +++ b/src/lerobot/robots/unitree_g1/gr00t_locomotion.py @@ -14,20 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import argparse import logging -import time from collections import deque import numpy as np import onnxruntime as ort from huggingface_hub import hf_hub_download -from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config -from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex -from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1 +from lerobot.robots.unitree_g1.g1_utils import ( + REMOTE_AXES, + REMOTE_BUTTONS, + G1_29_JointIndex, + get_gravity_orientation, +) -logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -36,18 +36,13 @@ GROOT_DEFAULT_ANGLES[[0, 6]] = -0.1 # Hip pitch GROOT_DEFAULT_ANGLES[[3, 9]] = 0.3 # Knee GROOT_DEFAULT_ANGLES[[4, 10]] = -0.2 # Ankle pitch -MISSING_JOINTS = [] -G1_MODEL = "g1_23" # Or "g1_29" -if G1_MODEL == "g1_23": - MISSING_JOINTS = [12, 14, 20, 21, 27, 28] # Waist yaw/pitch, wrist pitch/yaw - # Control parameters ACTION_SCALE = 0.25 CONTROL_DT = 0.02 # 50Hz ANG_VEL_SCALE: float = 0.25 DOF_POS_SCALE: float = 1.0 DOF_VEL_SCALE: float = 0.05 -CMD_SCALE: list = [2.0, 2.0, 0.25] +CMD_SCALE: list[float] = [2.0, 2.0, 0.25] DEFAULT_GROOT_REPO_ID = "nepyope/GR00T-WholeBodyControl_g1" @@ -85,11 +80,11 @@ def load_groot_policies( class GrootLocomotionController: """GR00T lower-body locomotion controller for the Unitree G1.""" - def __init__(self, policy_balance, policy_walk, robot, config): - self.policy_balance = policy_balance - self.policy_walk = policy_walk - self.robot = robot - self.config = config + control_dt = CONTROL_DT # Expose for unitree_g1.py + + def __init__(self): + # Load policies + self.policy_balance, self.policy_walk = load_groot_policies() self.cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) # vx, vy, theta_dot @@ -109,45 +104,60 @@ class GrootLocomotionController: logger.info("GrootLocomotionController initialized") - def run_step(self): - # Get current observation - obs = self.robot.get_observation() + def reset(self) -> None: + """Reset internal state for a new episode.""" + self.cmd[:] = 0.0 + self.groot_qj_all[:] = 0.0 + self.groot_dqj_all[:] = 0.0 + self.groot_action[:] = 0.0 + self.groot_obs_single[:] = 0.0 + self.groot_obs_stacked[:] = 0.0 + self.groot_height_cmd = 0.74 + self.groot_orientation_cmd[:] = 0.0 + self.groot_obs_history.clear() + for _ in range(6): + self.groot_obs_history.append(np.zeros(86, dtype=np.float32)) - if not obs: - return + def run_step(self, action: dict, lowstate) -> dict: + """Run one step of the locomotion controller. - # Get command from remote controller - if obs["remote.buttons"][0]: # R1 - raise waist + Args: + action: Action dict containing remote.lx/ly/rx/ry and buttons + lowstate: Robot lowstate containing motor positions/velocities and IMU + + Returns: + Action dict for lower body joints (0-14) + """ + if lowstate is None: + return {} + + buttons = [int(action.get(k, 0)) for k in REMOTE_BUTTONS] + if buttons[0]: # R1 - raise waist self.groot_height_cmd += 0.001 self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00) - if obs["remote.buttons"][4]: # R2 - lower waist + if buttons[4]: # R2 - lower waist self.groot_height_cmd -= 0.001 self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00) - self.cmd[0] = obs["remote.ly"] # Forward/backward - self.cmd[1] = obs["remote.lx"] * -1 # Left/right - self.cmd[2] = obs["remote.rx"] * -1 # Rotation rate + lx, ly, rx, _ry = (action.get(k, 0.0) for k in REMOTE_AXES) + self.cmd[0] = ly # Forward/backward + self.cmd[1] = -lx # Left/right (negated) + self.cmd[2] = -rx # Rotation rate (negated) - # Get joint positions and velocities from flat dict + # Get joint positions and velocities from lowstate for motor in G1_29_JointIndex: - name = motor.name idx = motor.value - self.groot_qj_all[idx] = obs[f"{name}.q"] - self.groot_dqj_all[idx] = obs[f"{name}.dq"] - - # Adapt observation for g1_23dof - for idx in MISSING_JOINTS: - self.groot_qj_all[idx] = 0.0 - self.groot_dqj_all[idx] = 0.0 + self.groot_qj_all[idx] = lowstate.motor_state[idx].q + self.groot_dqj_all[idx] = lowstate.motor_state[idx].dq # Scale joint positions and velocities qj_obs = self.groot_qj_all.copy() dqj_obs = self.groot_dqj_all.copy() # Express IMU data in gravity frame of reference - quat = [obs["imu.quat.w"], obs["imu.quat.x"], obs["imu.quat.y"], obs["imu.quat.z"]] - ang_vel = np.array([obs["imu.gyro.x"], obs["imu.gyro.y"], obs["imu.gyro.z"]], dtype=np.float32) - gravity_orientation = self.robot.get_gravity_orientation(quat) + quat = lowstate.imu_state.quaternion + ang_vel = np.array(lowstate.imu_state.gyroscope, dtype=np.float32) + gravity_orientation = get_gravity_orientation(quat) # Scale joint positions and velocities before policy inference qj_obs = (qj_obs - GROOT_DEFAULT_ANGLES) * DOF_POS_SCALE @@ -186,73 +196,10 @@ class GrootLocomotionController: # Transform action back to target joint positions target_dof_pos_15 = GROOT_DEFAULT_ANGLES[:15] + self.groot_action * ACTION_SCALE - # Build action dict (only first 15 joints for GR00T) + # Build action dict action_dict = {} for i in range(15): motor_name = G1_29_JointIndex(i).name action_dict[f"{motor_name}.q"] = float(target_dof_pos_15[i]) - # Zero out missing joints for g1_23dof - for joint_idx in MISSING_JOINTS: - motor_name = G1_29_JointIndex(joint_idx).name - action_dict[f"{motor_name}.q"] = 0.0 - - # Send action to robot - self.robot.send_action(action_dict) - - -def run(repo_id: str = DEFAULT_GROOT_REPO_ID) -> None: - """Main function to run the GR00T locomotion controller. - - Args: - repo_id: Hugging Face Hub repository ID for GR00T policies. - """ - # Load policies - policy_balance, policy_walk = load_groot_policies(repo_id=repo_id) - - # Initialize robot - config = UnitreeG1Config() - robot = UnitreeG1(config) - - robot.connect() - - # Initialize gr00T locomotion controller - groot_controller = GrootLocomotionController( - policy_balance=policy_balance, - policy_walk=policy_walk, - robot=robot, - config=config, - ) - - try: - robot.reset(CONTROL_DT, GROOT_DEFAULT_ANGLES) - - logger.info("Use joystick: LY=fwd/back, LX=left/right, RX=rotate, R1=raise waist, R2=lower waist") - logger.info("Press Ctrl+C to stop") - - # Run step - while not robot._shutdown_event.is_set(): - start_time = time.time() - groot_controller.run_step() - elapsed = time.time() - start_time - sleep_time = max(0, CONTROL_DT - elapsed) - time.sleep(sleep_time) - except KeyboardInterrupt: - logger.info("Stopping locomotion...") - finally: - if robot.is_connected: - robot.disconnect() - logger.info("Done!") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="GR00T Locomotion Controller for Unitree G1") - parser.add_argument( - "--repo-id", - type=str, - default=DEFAULT_GROOT_REPO_ID, - help=f"Hugging Face Hub repo ID for GR00T policies (default: {DEFAULT_GROOT_REPO_ID})", - ) - args = parser.parse_args() - - run(repo_id=args.repo_id) + return action_dict diff --git a/examples/unitree_g1/holosoma_locomotion.py b/src/lerobot/robots/unitree_g1/holosoma_locomotion.py similarity index 53% rename from examples/unitree_g1/holosoma_locomotion.py rename to src/lerobot/robots/unitree_g1/holosoma_locomotion.py index 3a07023de..857bb97bc 100644 --- a/examples/unitree_g1/holosoma_locomotion.py +++ b/src/lerobot/robots/unitree_g1/holosoma_locomotion.py @@ -14,21 +14,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -import argparse import json import logging -import time import numpy as np import onnx import onnxruntime as ort from huggingface_hub import hf_hub_download -from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config -from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex -from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1 +from lerobot.robots.unitree_g1.g1_utils import ( + REMOTE_AXES, + G1_29_JointArmIndex, + G1_29_JointIndex, + get_gravity_orientation, +) -logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) DEFAULT_ANGLES = np.zeros(29, dtype=np.float32) @@ -40,18 +40,13 @@ DEFAULT_ANGLES[16] = 0.2 # Left shoulder roll DEFAULT_ANGLES[23] = -0.2 # Right shoulder roll DEFAULT_ANGLES[[18, 25]] = 0.6 # Elbow -MISSING_JOINTS = [] -G1_MODEL = "g1_23" # Or "g1_29" -if G1_MODEL == "g1_23": - MISSING_JOINTS = [12, 14, 20, 21, 27, 28] # Waist yaw/pitch, wrist pitch/yaw - # Control parameters ACTION_SCALE = 0.25 -CONTROL_DT = 0.02 # 50Hz +CONTROL_DT = 0.005 # 200Hz ANG_VEL_SCALE = 0.25 DOF_POS_SCALE = 1.0 DOF_VEL_SCALE = 0.05 -GAIT_PERIOD = 1.0 +GAIT_PERIOD = 0.5 DEFAULT_HOLOSOMA_REPO_ID = "nepyope/holosoma_locomotion" @@ -87,7 +82,7 @@ def load_policy( logger.info(f"Policy loaded: {policy.get_inputs()[0].shape} → {policy.get_outputs()[0].shape}") # Extract KP/KD from ONNX metadata - model = onnx.load(policy_path) + model = onnx.load(policy_path, load_external_data=False) metadata = {prop.key: prop.value for prop in model.metadata_props} if "kp" not in metadata or "kd" not in metadata: @@ -101,15 +96,13 @@ def load_policy( class HolosomaLocomotionController: - """Holosoma whole-body locomotion controller for Unitree G1.""" + """Holosoma lower-body locomotion controller for Unitree G1.""" - def __init__(self, policy, robot, kp: np.ndarray, kd: np.ndarray): - self.policy = policy - self.robot = robot + control_dt = CONTROL_DT # Expose for unitree_g1.py - # Override robot's PD gains with policy gains - self.robot.kp = kp - self.robot.kd = kd + def __init__(self): + # Load policy and gains + self.policy, self.kp, self.kd = load_policy() self.cmd = np.zeros(3, dtype=np.float32) @@ -124,35 +117,55 @@ class HolosomaLocomotionController: self.phase_dt = 2 * np.pi / ((1.0 / CONTROL_DT) * GAIT_PERIOD) self.is_standing = True - def run_step(self): - # Get current observation - obs = self.robot.get_observation() + logger.info("HolosomaLocomotionController initialized") - if not obs: - return + def reset(self) -> None: + """Reset internal state for a new episode.""" + self.cmd[:] = 0.0 + self.qj[:] = 0.0 + self.dqj[:] = 0.0 + self.obs[:] = 0.0 + self.last_action[:] = 0.0 + self.phase = np.array([[0.0, np.pi]], dtype=np.float32) + self.is_standing = True - # Get command from remote controller - ly = obs["remote.ly"] if abs(obs["remote.ly"]) > 0.1 else 0.0 - lx = obs["remote.lx"] if abs(obs["remote.lx"]) > 0.1 else 0.0 - rx = obs["remote.rx"] if abs(obs["remote.rx"]) > 0.1 else 0.0 + def run_step(self, action: dict, lowstate) -> dict: + """Run one step of the locomotion controller. + + Args: + action: Action dict containing remote.lx/ly/rx/ry + lowstate: Robot lowstate containing motor positions/velocities and IMU + + Returns: + Action dict for lower body joints (0-14) + """ + if lowstate is None: + return {} + + lx, ly, rx, _ry = (action.get(k, 0.0) for k in REMOTE_AXES) + ly = ly if abs(ly) > 0.1 else 0.0 + lx = lx if abs(lx) > 0.1 else 0.0 + rx = rx if abs(rx) > 0.1 else 0.0 + ly = np.clip(ly, -0.3, 0.3) + lx = np.clip(lx, -0.3, 0.3) self.cmd[:] = [ly, -lx, -rx] - # Get joint positions and velocities + # Get joint positions and velocities from lowstate for motor in G1_29_JointIndex: - name = motor.name idx = motor.value - self.qj[idx] = obs[f"{name}.q"] - self.dqj[idx] = obs[f"{name}.dq"] + self.qj[idx] = lowstate.motor_state[idx].q + self.dqj[idx] = lowstate.motor_state[idx].dq - # Adapt observation for g1_23dof - for idx in MISSING_JOINTS: - self.qj[idx] = 0.0 - self.dqj[idx] = 0.0 + # Hide arm positions from policy (show DEFAULT_ANGLES instead) + # This prevents policy from reacting to teleop arm movements + for arm_joint in G1_29_JointArmIndex: + self.qj[arm_joint.value] = DEFAULT_ANGLES[arm_joint.value] + self.dqj[arm_joint.value] = 0.0 # Express IMU data in gravity frame of reference - quat = [obs["imu.quat.w"], obs["imu.quat.x"], obs["imu.quat.y"], obs["imu.quat.z"]] - ang_vel = np.array([obs["imu.gyro.x"], obs["imu.gyro.y"], obs["imu.gyro.z"]], dtype=np.float32) - gravity = self.robot.get_gravity_orientation(quat) + quat = lowstate.imu_state.quaternion + ang_vel = np.array(lowstate.imu_state.gyroscope, dtype=np.float32) + gravity = get_gravity_orientation(quat) # Scale joint positions and velocities before policy inference qj_obs = (self.qj - DEFAULT_ANGLES) * DOF_POS_SCALE @@ -186,79 +199,16 @@ class HolosomaLocomotionController: # Run policy inference ort_in = {self.policy.get_inputs()[0].name: self.obs.reshape(1, -1).astype(np.float32)} raw_action = self.policy.run(None, ort_in)[0].squeeze() - action = np.clip(raw_action, -100.0, 100.0) - self.last_action = action.copy() + policy_action = np.clip(raw_action, -100.0, 100.0) + self.last_action = policy_action.copy() # Transform action back to target joint positions - target = DEFAULT_ANGLES + action * ACTION_SCALE + target = DEFAULT_ANGLES + policy_action * ACTION_SCALE - # Build action dict + # Build action dict (first 15 joints only) action_dict = {} - for motor in G1_29_JointIndex: - action_dict[f"{motor.name}.q"] = float(target[motor.value]) + for i in range(15): + motor_name = G1_29_JointIndex(i).name + action_dict[f"{motor_name}.q"] = float(target[i]) - # Zero out missing joints for g1_23dof - for joint_idx in MISSING_JOINTS: - motor_name = G1_29_JointIndex(joint_idx).name - action_dict[f"{motor_name}.q"] = 0.0 - - # Send action to robot - self.robot.send_action(action_dict) - - -def run(repo_id: str = DEFAULT_HOLOSOMA_REPO_ID, policy_type: str = "fastsac") -> None: - """Main function to run the Holosoma locomotion controller. - - Args: - repo_id: Hugging Face Hub repository ID for Holosoma policies. - policy_type: Policy type to use ('fastsac' or 'ppo'). - """ - # Load policy and gains - policy, kp, kd = load_policy(repo_id=repo_id, policy_type=policy_type) - - # Initialize robot - config = UnitreeG1Config() - robot = UnitreeG1(config) - robot.connect() - - holosoma_controller = HolosomaLocomotionController(policy, robot, kp, kd) - - try: - robot.reset(CONTROL_DT, DEFAULT_ANGLES) - - logger.info("Use joystick: LY=fwd/back, LX=left/right, RX=rotate") - logger.info("Press Ctrl+C to stop") - - # Run step - while not robot._shutdown_event.is_set(): - start_time = time.time() - holosoma_controller.run_step() - elapsed = time.time() - start_time - sleep_time = max(0, CONTROL_DT - elapsed) - time.sleep(sleep_time) - except KeyboardInterrupt: - logger.info("Stopping locomotion...") - finally: - if robot.is_connected: - robot.disconnect() - logger.info("Done!") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Holosoma Locomotion Controller for Unitree G1") - parser.add_argument( - "--repo-id", - type=str, - default=DEFAULT_HOLOSOMA_REPO_ID, - help=f"Hugging Face Hub repo ID for Holosoma policies (default: {DEFAULT_HOLOSOMA_REPO_ID})", - ) - parser.add_argument( - "--policy", - type=str, - choices=["fastsac", "ppo"], - default="fastsac", - help="Policy type to use: 'fastsac' (default) or 'ppo'", - ) - args = parser.parse_args() - - run(repo_id=args.repo_id, policy_type=args.policy) + return action_dict diff --git a/src/lerobot/robots/unitree_g1/run_g1_server.py b/src/lerobot/robots/unitree_g1/run_g1_server.py index 70166b590..b5bd0baf8 100644 --- a/src/lerobot/robots/unitree_g1/run_g1_server.py +++ b/src/lerobot/robots/unitree_g1/run_g1_server.py @@ -24,6 +24,7 @@ This server runs on the robot and forwards: Uses JSON for secure serialization instead of pickle. """ +import argparse import base64 import contextlib import json @@ -38,6 +39,8 @@ from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_ from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState from unitree_sdk2py.utils.crc import CRC +from lerobot.cameras.zmq.image_server import ImageServer + # DDS topic names follow Unitree SDK naming conventions # ruff: noqa: N816 kTopicLowCommand_Debug = "rt/lowcmd" # action to robot @@ -150,6 +153,32 @@ def cmd_forward_loop( def main() -> None: """Main entry point for the robot server bridge.""" + parser = argparse.ArgumentParser(description="DDS-to-ZMQ bridge server for Unitree G1") + parser.add_argument("--camera", action="store_true", help="Also launch camera server") + parser.add_argument("--camera-device", type=int, default=4, help="Camera device ID (default: 4)") + parser.add_argument("--camera-fps", type=int, default=30, help="Camera FPS (default: 30)") + parser.add_argument("--camera-width", type=int, default=640, help="Camera width (default: 640)") + parser.add_argument("--camera-height", type=int, default=480, help="Camera height (default: 480)") + parser.add_argument("--camera-port", type=int, default=5555, help="Camera ZMQ port (default: 5555)") + args = parser.parse_args() + + # Optionally start camera server in background thread + camera_thread = None + if args.camera: + camera_config = { + "fps": args.camera_fps, + "cameras": { + "head_camera": { + "device_id": args.camera_device, + "shape": [args.camera_height, args.camera_width], + } + }, + } + camera_server = ImageServer(camera_config, port=args.camera_port) + camera_thread = threading.Thread(target=camera_server.run, daemon=True) + camera_thread.start() + print(f"Camera server started on port {args.camera_port} (device {args.camera_device})") + # initialize DDS ChannelFactoryInitialize(0) @@ -206,6 +235,8 @@ def main() -> None: shutdown_event.set() ctx.term() # terminates blocking zmq.recv() calls t_state.join(timeout=2.0) + if camera_thread is not None: + camera_thread.join(timeout=2.0) if __name__ == "__main__": diff --git a/src/lerobot/robots/unitree_g1/unitree_g1.py b/src/lerobot/robots/unitree_g1/unitree_g1.py index df0de8f19..41146ebe6 100644 --- a/src/lerobot/robots/unitree_g1/unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/unitree_g1.py @@ -14,27 +14,67 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging -import struct import threading import time from dataclasses import dataclass, field from functools import cached_property -from typing import Any +from typing import TYPE_CHECKING, Protocol, runtime_checkable import numpy as np from lerobot.cameras.utils import make_cameras_from_configs from lerobot.envs.factory import make_env from lerobot.processor import RobotAction, RobotObservation -from lerobot.robots.unitree_g1.g1_utils import G1_29_JointArmIndex, G1_29_JointIndex -from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK +from lerobot.robots.unitree_g1.g1_kinematics import G1_29_ArmIK +from lerobot.robots.unitree_g1.g1_utils import ( + REMOTE_AXES, + REMOTE_KEYS, + G1_29_JointArmIndex, + G1_29_JointIndex, + default_remote_input, + make_locomotion_controller, +) +from lerobot.utils.import_utils import _unitree_sdk_available from ..robot import Robot from .config_unitree_g1 import UnitreeG1Config +if TYPE_CHECKING or _unitree_sdk_available: + from unitree_sdk2py.core.channel import ( + ChannelFactoryInitialize as _SDKChannelFactoryInitialize, + ChannelPublisher as _SDKChannelPublisher, + ChannelSubscriber as _SDKChannelSubscriber, + ) + from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_ + from unitree_sdk2py.idl.unitree_hg.msg.dds_ import ( + LowCmd_ as hg_LowCmd, + LowState_ as hg_LowState, + ) + from unitree_sdk2py.utils.crc import CRC +else: + _SDKChannelFactoryInitialize = None + _SDKChannelPublisher = None + _SDKChannelSubscriber = None + unitree_hg_msg_dds__LowCmd_ = None + hg_LowCmd = None + hg_LowState = None + CRC = None + logger = logging.getLogger(__name__) + +@runtime_checkable +class LocomotionController(Protocol): + control_dt: float + + def run_step(self, action: dict, lowstate) -> dict: ... + + def reset(self) -> None: ... + + # DDS topic names follow Unitree SDK naming conventions # ruff: noqa: N816 kTopicLowCommand_Debug = "rt/lowcmd" @@ -63,7 +103,7 @@ class IMUState: class G1_29_LowState: # noqa: N801 motor_state: list[MotorState] = field(default_factory=lambda: [MotorState() for _ in G1_29_JointIndex]) imu_state: IMUState = field(default_factory=IMUState) - wireless_remote: Any = None # Raw wireless remote data + wireless_remote: bytes | None = None # Raw wireless remote data mode_machine: int = 0 # Robot mode @@ -71,25 +111,6 @@ class UnitreeG1(Robot): config_class = UnitreeG1Config name = "unitree_g1" - # unitree remote controller - class RemoteController: - def __init__(self): - self.lx = 0 - self.ly = 0 - self.rx = 0 - self.ry = 0 - self.button = [0] * 16 - - def set(self, data): - # wireless_remote - keys = struct.unpack("H", data[2:4])[0] - for i in range(16): - self.button[i] = (keys & (1 << i)) >> i - self.lx = struct.unpack("f", data[4:8])[0] - self.rx = struct.unpack("f", data[8:12])[0] - self.ry = struct.unpack("f", data[12:16])[0] - self.ly = struct.unpack("f", data[20:24])[0] - def __init__(self, config: UnitreeG1Config): super().__init__(config) @@ -103,11 +124,9 @@ class UnitreeG1(Robot): # Import channel classes based on mode if config.is_simulation: - from unitree_sdk2py.core.channel import ( - ChannelFactoryInitialize, - ChannelPublisher, - ChannelSubscriber, - ) + self._ChannelFactoryInitialize = _SDKChannelFactoryInitialize + self._ChannelPublisher = _SDKChannelPublisher + self._ChannelSubscriber = _SDKChannelSubscriber else: from lerobot.robots.unitree_g1.unitree_sdk2_socket import ( ChannelFactoryInitialize, @@ -115,22 +134,30 @@ class UnitreeG1(Robot): ChannelSubscriber, ) - # Store for use in connect() - self._ChannelFactoryInitialize = ChannelFactoryInitialize - self._ChannelPublisher = ChannelPublisher - self._ChannelSubscriber = ChannelSubscriber + self._ChannelFactoryInitialize = ChannelFactoryInitialize + self._ChannelPublisher = ChannelPublisher + self._ChannelSubscriber = ChannelSubscriber # Initialize state variables self.sim_env = None self._env_wrapper = None self._lowstate = None + self._lowstate_lock = threading.Lock() self._shutdown_event = threading.Event() self.subscribe_thread = None - self.remote_controller = self.RemoteController() - self.arm_ik = G1_29_ArmIK() + self.arm_ik = G1_29_ArmIK() if config.gravity_compensation else None - def _subscribe_motor_state(self): # polls robot state @ 250Hz + # Lower-body controller loaded dynamically + self.controller: LocomotionController | None = make_locomotion_controller(config.controller) + + # Controller thread state + self._controller_thread = None + self._controller_action_lock = threading.Lock() + self.controller_input = default_remote_input() + self.controller_output = {} + + def _subscribe_lowstate(self): # polls robot state @ 250Hz while not self._shutdown_event.is_set(): start_time = time.time() @@ -143,11 +170,11 @@ class UnitreeG1(Robot): lowstate = G1_29_LowState() # Capture motor states using jointindex - for id in G1_29_JointIndex: - lowstate.motor_state[id].q = msg.motor_state[id].q - lowstate.motor_state[id].dq = msg.motor_state[id].dq - lowstate.motor_state[id].tau_est = msg.motor_state[id].tau_est - lowstate.motor_state[id].temperature = msg.motor_state[id].temperature + for joint in G1_29_JointIndex: + lowstate.motor_state[joint].q = msg.motor_state[joint].q + lowstate.motor_state[joint].dq = msg.motor_state[joint].dq + lowstate.motor_state[joint].tau_est = msg.motor_state[joint].tau_est + lowstate.motor_state[joint].temperature = msg.motor_state[joint].temperature # Capture IMU state lowstate.imu_state.quaternion = list(msg.imu_state.quaternion) @@ -162,31 +189,106 @@ class UnitreeG1(Robot): # Capture mode_machine lowstate.mode_machine = msg.mode_machine - self._lowstate = lowstate + with self._lowstate_lock: + self._lowstate = lowstate current_time = time.time() all_t_elapsed = current_time - start_time sleep_time = max(0, (self.control_dt - all_t_elapsed)) # maintain constant control dt time.sleep(sleep_time) + def publish_lowcmd( + self, + action: RobotAction, + kp: np.ndarray | list[float] | None = None, + kd: np.ndarray | list[float] | None = None, + tau: np.ndarray | list[float] | None = None, + ) -> None: # writes robot command whenever requested + for motor in G1_29_JointIndex: + key = f"{motor.name}.q" + if key in action: + self.msg.motor_cmd[motor.value].q = action[key] + self.msg.motor_cmd[motor.value].qd = 0 + self.msg.motor_cmd[motor.value].kp = ( + kp[motor.value] if kp is not None else self.kp[motor.value] + ) + self.msg.motor_cmd[motor.value].kd = ( + kd[motor.value] if kd is not None else self.kd[motor.value] + ) + self.msg.motor_cmd[motor.value].tau = tau[motor.value] if tau is not None else 0.0 + + self.msg.crc = self.crc.Crc(self.msg) + self.lowcmd_publisher.Write(self.msg) + + @property + def _cameras_ft(self) -> dict[str, tuple]: + return { + cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras + } + + @cached_property + def observation_features(self) -> dict[str, type | tuple]: + return {**self._motors_ft, **self._cameras_ft} + @cached_property def action_features(self) -> dict[str, type]: - return {f"{G1_29_JointIndex(motor).name}.q": float for motor in G1_29_JointIndex} + if self.controller is None: + return {f"{G1_29_JointIndex(motor).name}.q": float for motor in G1_29_JointIndex} - def calibrate(self) -> None: # robot is already calibrated + arm_features = {f"{G1_29_JointArmIndex(motor).name}.q": float for motor in G1_29_JointArmIndex} + remote_features = dict.fromkeys(REMOTE_AXES, float) + return {**arm_features, **remote_features} + + def _controller_loop(self): + """Background thread that runs controller at policy's control_dt.""" + control_dt = self.controller.control_dt + logger.info(f"Controller loop starting with control_dt={control_dt} ({1.0 / control_dt:.1f}Hz)") + + loop_count = 0 + last_log_time = time.time() + + while not self._shutdown_event.is_set(): + start_time = time.time() + + with self._lowstate_lock: + lowstate = self._lowstate + + if lowstate is not None and self.controller is not None: + loop_count += 1 + if time.time() - last_log_time >= 5.0: # Log every 5 seconds + actual_hz = loop_count / (time.time() - last_log_time) + logger.info( + f"Controller actual rate: {actual_hz:.1f}Hz (target: {1.0 / control_dt:.1f}Hz)" + ) + loop_count = 0 + last_log_time = time.time() + # Read controller input snapshot + with self._controller_action_lock: + controller_input = dict(self.controller_input) + + # Run controller step + controller_action = self.controller.run_step(controller_input, lowstate) + + # Write controller output snapshot + with self._controller_action_lock: + self.controller_output = dict(controller_action) + + ctrl_kp = self.controller.kp if hasattr(self.controller, "kp") else None + ctrl_kd = self.controller.kd if hasattr(self.controller, "kd") else None + self.publish_lowcmd(controller_action, kp=ctrl_kp, kd=ctrl_kd) + + elapsed = time.time() - start_time + sleep_time = max(0, control_dt - elapsed) + time.sleep(sleep_time) + + def calibrate(self) -> None: + # TODO: implement g1_29 calibration pass def configure(self) -> None: pass def connect(self, calibrate: bool = True) -> None: # connect to DDS - from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_ - from unitree_sdk2py.idl.unitree_hg.msg.dds_ import ( - LowCmd_ as hg_LowCmd, - LowState_ as hg_LowState, - ) - from unitree_sdk2py.utils.crc import CRC - # Initialize DDS channel and simulation environment if self.config.is_simulation: self._ChannelFactoryInitialize(0, "lo") @@ -194,7 +296,7 @@ class UnitreeG1(Robot): # Extract the actual gym env from the dict structure self.sim_env = self._env_wrapper["hub_env"][0].envs[0] else: - self._ChannelFactoryInitialize(0) + self._ChannelFactoryInitialize(0, config=self.config) # Initialize direct motor control interface self.lowcmd_publisher = self._ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd) @@ -203,7 +305,7 @@ class UnitreeG1(Robot): self.lowstate_subscriber.Init() # Start subscribe thread to read robot state - self.subscribe_thread = threading.Thread(target=self._subscribe_motor_state) + self.subscribe_thread = threading.Thread(target=self._subscribe_lowstate) self.subscribe_thread.start() # Connect cameras @@ -220,25 +322,53 @@ class UnitreeG1(Robot): # Wait for first state message to arrive lowstate = None + deadline = time.time() + 10.0 while lowstate is None: - lowstate = self._lowstate + with self._lowstate_lock: + lowstate = self._lowstate if lowstate is None: + if time.time() > deadline: + raise TimeoutError("Timed out waiting for robot state (10s)") + logger.warning("[UnitreeG1] Waiting for robot state...") time.sleep(0.01) - logger.warning("[UnitreeG1] Waiting for robot state...") - logger.warning("[UnitreeG1] Connected to robot.") + logger.info("[UnitreeG1] Connected to robot.") self.msg.mode_machine = lowstate.mode_machine - # Initialize all motors with unified kp/kd from config self.kp = np.array(self.config.kp, dtype=np.float32) self.kd = np.array(self.config.kd, dtype=np.float32) - for id in G1_29_JointIndex: - self.msg.motor_cmd[id].mode = 1 - self.msg.motor_cmd[id].kp = self.kp[id.value] - self.msg.motor_cmd[id].kd = self.kd[id.value] - self.msg.motor_cmd[id].q = lowstate.motor_state[id.value].q + for joint in G1_29_JointIndex: + self.msg.motor_cmd[joint].mode = 1 + self.msg.motor_cmd[joint].kp = self.kp[joint.value] + self.msg.motor_cmd[joint].kd = self.kd[joint.value] + self.msg.motor_cmd[joint].q = lowstate.motor_state[joint.value].q + + # Start controller thread if enabled + if self.controller is not None: + self._controller_thread = threading.Thread(target=self._controller_loop, daemon=True) + self._controller_thread.start() + fps = int(1.0 / self.controller.control_dt) + logger.info(f"Controller thread started ({fps}Hz)") + + def _send_zero_torque(self) -> None: + """Send a zero-gain command to make joints passive before shutting down.""" + try: + with self._lowstate_lock: + lowstate = self._lowstate + if lowstate is None: + return + action = {f"{motor.name}.q": lowstate.motor_state[motor.value].q for motor in G1_29_JointIndex} + zero_gains = np.zeros(29, dtype=np.float32) + self.publish_lowcmd(action, kp=zero_gains, kd=zero_gains, tau=zero_gains) + logger.info("Sent zero-torque command for safe shutdown") + except Exception as e: + logger.warning(f"Failed to send zero-torque on disconnect: {e}") def disconnect(self): + # Put robot in passive mode before stopping threads + if not self.config.is_simulation: + self._send_zero_torque() + # Signal thread to stop and unblock any waits self._shutdown_event.set() @@ -248,6 +378,12 @@ class UnitreeG1(Robot): if self.subscribe_thread.is_alive(): logger.warning("Subscribe thread did not stop cleanly") + # Wait for controller thread to finish + if self._controller_thread is not None: + self._controller_thread.join(timeout=2.0) + if self._controller_thread.is_alive(): + logger.warning("Controller thread did not stop cleanly") + # Close simulation environment if self.config.is_simulation and self.sim_env is not None: try: @@ -274,7 +410,8 @@ class UnitreeG1(Robot): cam.disconnect() def get_observation(self) -> RobotObservation: - lowstate = self._lowstate + with self._lowstate_lock: + lowstate = self._lowstate if lowstate is None: return {} @@ -313,14 +450,9 @@ class UnitreeG1(Robot): obs["imu.rpy.pitch"] = lowstate.imu_state.rpy[1] obs["imu.rpy.yaw"] = lowstate.imu_state.rpy[2] - # Controller - parse wireless_remote and add to obs - if lowstate.wireless_remote and len(lowstate.wireless_remote) >= 24: - self.remote_controller.set(lowstate.wireless_remote) - obs["remote.buttons"] = self.remote_controller.button.copy() - obs["remote.lx"] = self.remote_controller.lx - obs["remote.ly"] = self.remote_controller.ly - obs["remote.rx"] = self.remote_controller.rx - obs["remote.ry"] = self.remote_controller.ry + # Wireless remote (raw bytes for teleoperator) + if lowstate.wireless_remote: + obs["wireless_remote"] = lowstate.wireless_remote # Cameras - read images from ZMQ cameras for cam_name, cam in self._cameras.items(): @@ -328,73 +460,63 @@ class UnitreeG1(Robot): return obs + def send_action(self, action: RobotAction) -> RobotAction: + action_to_publish = action + if self.controller is not None: + # Controller thread owns legs/waist. Here we only update joystick inputs + # and publish arm targets from the teleoperator. + self._update_controller_action(action) + arm_prefixes = tuple(j.name for j in G1_29_JointArmIndex) + action_to_publish = { + key: value + for key, value in action.items() + if key.endswith(".q") and key.startswith(arm_prefixes) + } + + tau = None + if self.config.gravity_compensation and self.arm_ik is not None: + tau = np.zeros(29, dtype=np.float32) + action_np = np.array( + [ + action_to_publish.get(f"{joint.name}.q", self.msg.motor_cmd[joint.value].q) + for joint in G1_29_JointArmIndex + ], + dtype=np.float32, + ) + arm_tau = self.arm_ik.solve_tau(action_np) + arm_start_idx = G1_29_JointArmIndex.kLeftShoulderPitch.value + for joint in G1_29_JointArmIndex: + local_idx = joint.value - arm_start_idx + tau[joint.value] = arm_tau[local_idx] + + self.publish_lowcmd(action_to_publish, tau=tau) + return action + + def _update_controller_action(self, action: RobotAction) -> None: + """Update controller input state from incoming teleop action.""" + with self._controller_action_lock: + for key in REMOTE_KEYS: + if key in action: + self.controller_input[key] = action[key] + @property def is_calibrated(self) -> bool: return True @property def is_connected(self) -> bool: - return self._lowstate is not None + with self._lowstate_lock: + return self._lowstate is not None @property def _motors_ft(self) -> dict[str, type]: + """Joint positions for all 29 joints.""" return {f"{G1_29_JointIndex(motor).name}.q": float for motor in G1_29_JointIndex} @property def cameras(self) -> dict: return self._cameras - @property - def _cameras_ft(self) -> dict[str, tuple]: - return { - cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras - } - - @cached_property - def observation_features(self) -> dict[str, type | tuple]: - return {**self._motors_ft, **self._cameras_ft} - - def send_action(self, action: RobotAction) -> RobotAction: - for motor in G1_29_JointIndex: - key = f"{motor.name}.q" - if key in action: - self.msg.motor_cmd[motor.value].q = action[key] - self.msg.motor_cmd[motor.value].qd = 0 - self.msg.motor_cmd[motor.value].kp = self.kp[motor.value] - self.msg.motor_cmd[motor.value].kd = self.kd[motor.value] - self.msg.motor_cmd[motor.value].tau = 0 - - if self.config.gravity_compensation: - # Build action_np from motor commands (arm joints are indices 15-28, local indices 0-13) - action_np = np.zeros(14) - arm_start_idx = G1_29_JointArmIndex.kLeftShoulderPitch.value # 15 - for joint in G1_29_JointArmIndex: - local_idx = joint.value - arm_start_idx - action_np[local_idx] = self.msg.motor_cmd[joint.value].q - tau = self.arm_ik.solve_tau(action_np) - - # Apply tau back to motor commands - for joint in G1_29_JointArmIndex: - local_idx = joint.value - arm_start_idx - self.msg.motor_cmd[joint.value].tau = tau[local_idx] - - self.msg.crc = self.crc.Crc(self.msg) - self.lowcmd_publisher.Write(self.msg) - return action - - def get_gravity_orientation(self, quaternion): # get gravity orientation from quaternion - """Get gravity orientation from quaternion.""" - qw = quaternion[0] - qx = quaternion[1] - qy = quaternion[2] - qz = quaternion[3] - - gravity_orientation = np.zeros(3) - gravity_orientation[0] = 2 * (-qz * qx + qw * qy) - gravity_orientation[1] = -2 * (qz * qy + qw * qx) - gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz) - return gravity_orientation - def reset( self, control_dt: float | None = None, @@ -407,15 +529,9 @@ class UnitreeG1(Robot): if self.config.is_simulation and self.sim_env is not None: self.sim_env.reset() - - for motor in G1_29_JointIndex: - self.msg.motor_cmd[motor.value].q = default_positions[motor.value] - self.msg.motor_cmd[motor.value].qd = 0 - self.msg.motor_cmd[motor.value].kp = self.kp[motor.value] - self.msg.motor_cmd[motor.value].kd = self.kd[motor.value] - self.msg.motor_cmd[motor.value].tau = 0 - self.msg.crc = self.crc.Crc(self.msg) - self.lowcmd_publisher.Write(self.msg) + self.publish_lowcmd( + {f"{motor.name}.q": float(default_positions[motor.value]) for motor in G1_29_JointIndex} + ) else: total_time = 3.0 num_steps = int(total_time / control_dt) @@ -446,4 +562,8 @@ class UnitreeG1(Robot): sleep_time = max(0, control_dt - elapsed) time.sleep(sleep_time) + # Reset controller internal state (gait phase, obs history, etc.) + if self.controller is not None and hasattr(self.controller, "reset"): + self.controller.reset() + logger.info("Reached default position") diff --git a/src/lerobot/robots/unitree_g1/unitree_sdk2_socket.py b/src/lerobot/robots/unitree_g1/unitree_sdk2_socket.py index ad96df965..0f1f8f8d6 100644 --- a/src/lerobot/robots/unitree_g1/unitree_sdk2_socket.py +++ b/src/lerobot/robots/unitree_g1/unitree_sdk2_socket.py @@ -22,6 +22,8 @@ import zmq from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config +# Module-level ZMQ state mirrors the Unitree SDK's global ChannelFactory Singleton. +# Only one robot connection per process is supported. _ctx: zmq.Context | None = None _lowcmd_sock: zmq.Socket | None = None _lowstate_sock: zmq.Socket | None = None @@ -97,17 +99,22 @@ def lowcmd_to_dict(topic: str, msg: Any) -> dict[str, Any]: } -def ChannelFactoryInitialize(*args: Any, **kwargs: Any) -> None: # noqa: N802 +def ChannelFactoryInitialize(domain_id: int = 0, config: Any = None) -> None: # noqa: N802 """ Initialize ZMQ sockets for robot communication. This function mimics the Unitree SDK's ChannelFactoryInitialize but uses ZMQ sockets to connect to the robot server bridge instead of DDS. + + Args: + domain_id: Ignored (for API compatibility with Unitree SDK) + config: UnitreeG1Config instance with robot_ip """ global _ctx, _lowcmd_sock, _lowstate_sock # read socket config - config = UnitreeG1Config() + if config is None: + config = UnitreeG1Config() robot_ip = config.robot_ip ctx = zmq.Context.instance() diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 72708ba23..dc682fe6f 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -369,6 +369,8 @@ def record_loop( act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features) elif policy is None and isinstance(teleop, Teleoperator): + if robot.name == "unitree_g1": + teleop.send_feedback(obs) act = teleop.get_action() # Applies a pipeline to the raw teleop action, default is IdentityProcessor @@ -556,10 +558,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset: ): log_say("Reset the environment", cfg.play_sounds) - # reset g1 robot - if robot.name == "unitree_g1": - robot.reset() - record_loop( robot=robot, events=events, diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index dad479b2e..f050d572a 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ b/src/lerobot/scripts/lerobot_teleoperate.py @@ -60,6 +60,7 @@ import rerun as rr from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 +from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401 from lerobot.configs import parser from lerobot.processor import ( RobotAction, @@ -153,7 +154,6 @@ def teleop_loop( display_len = max(len(key) for key in robot.action_features) start = time.perf_counter() - while True: loop_start = time.perf_counter() @@ -163,6 +163,9 @@ def teleop_loop( # given that it is the identity processor as default obs = robot.get_observation() + if robot.name == "unitree_g1": + teleop.send_feedback(obs) + # Get teleop action raw_action = teleop.get_action() diff --git a/src/lerobot/teleoperators/unitree_g1/__init__.py b/src/lerobot/teleoperators/unitree_g1/__init__.py index 45955a0e2..5e67538b8 100644 --- a/src/lerobot/teleoperators/unitree_g1/__init__.py +++ b/src/lerobot/teleoperators/unitree_g1/__init__.py @@ -19,3 +19,13 @@ from .exo_calib import ExoskeletonCalibration, ExoskeletonJointCalibration from .exo_ik import ExoskeletonIKHelper from .exo_serial import ExoskeletonArm from .unitree_g1 import UnitreeG1Teleoperator + +__all__ = [ + "ExoskeletonArmPortConfig", + "ExoskeletonCalibration", + "ExoskeletonIKHelper", + "ExoskeletonJointCalibration", + "ExoskeletonArm", + "UnitreeG1Teleoperator", + "UnitreeG1TeleoperatorConfig", +] diff --git a/src/lerobot/teleoperators/unitree_g1/exo_calib.py b/src/lerobot/teleoperators/unitree_g1/exo_calib.py index 2927a1b55..b90e8fd7e 100644 --- a/src/lerobot/teleoperators/unitree_g1/exo_calib.py +++ b/src/lerobot/teleoperators/unitree_g1/exo_calib.py @@ -35,6 +35,9 @@ import serial logger = logging.getLogger(__name__) +ADC_MAX = 2**12 - 1 +ADC_HALF = ADC_MAX / 2 + # exoskeleton joint names -> ADC channel pairs. TODO: add wrist pitch and wrist yaw JOINTS = { "shoulder_pitch": (0, 1), @@ -59,7 +62,7 @@ class ExoskeletonCalibration: version: int = 2 side: str = "" - adc_max: int = 2**12 - 1 + adc_max: int = ADC_MAX joints: list[ExoskeletonJointCalibration] = field(default_factory=list) def to_dict(self) -> dict: @@ -92,7 +95,7 @@ class ExoskeletonCalibration: return cls( version=data.get("version", 2), side=data.get("side", ""), - adc_max=data.get("adc_max", 2**12 - 1), + adc_max=data.get("adc_max", ADC_MAX), joints=joints, ) @@ -112,11 +115,8 @@ class CalibParams: def normalize_angle(angle: float) -> float: - while angle > np.pi: - angle -= 2 * np.pi - while angle < -np.pi: - angle += 2 * np.pi - return angle + """Normalize angle to [-pi, pi].""" + return float(np.arctan2(np.sin(angle), np.cos(angle))) def joint_z_and_angle(raw16: list[int], j: ExoskeletonJointCalibration) -> tuple[np.ndarray, float]: @@ -125,7 +125,7 @@ def joint_z_and_angle(raw16: list[int], j: ExoskeletonJointCalibration) -> tuple """ pair = JOINTS[j.name] s, c = raw16[pair[0]], raw16[pair[1]] # get sin and cos - p = np.array([float(c) - (2**12 - 1) / 2, float(s) - (2**12 - 1) / 2]) # center the raw values + p = np.array([float(c) - ADC_HALF, float(s) - ADC_HALF]) # center the raw values z = np.asarray(j.T) @ ( p - np.asarray(j.center_fit) ) # center the ellipse and invert the transformation matrix to get unit circle coords @@ -167,7 +167,7 @@ def run_exo_calibration( def read_joint_point(raw16: list[int], pair: tuple[int, int]): s, c = raw16[pair[0]], raw16[pair[1]] - return float(c) - (2**12 - 1) / 2, float(s) - (2**12 - 1) / 2, float(s), float(c) + return float(c) - ADC_HALF, float(s) - ADC_HALF, float(s), float(c) def select_fit_subset(xs, ys): """Select and filter points for ellipse fitting. Trims outliers by radius and downsamples.""" @@ -317,7 +317,7 @@ def run_exo_calibration( calib = ExoskeletonCalibration( version=2, side=side, - adc_max=2**12 - 1, + adc_max=ADC_MAX, joints=[ ExoskeletonJointCalibration( name=j["name"], @@ -367,8 +367,8 @@ def run_exo_calibration( state["win_s"].append(s_raw) state["win_c"].append(c_raw) if len(state["win_s"]) >= max(3, params.median_window): - state["ys"].append(running_median(state["win_s"]) - (2**12 - 1) / 2) - state["xs"].append(running_median(state["win_c"]) - (2**12 - 1) / 2) + state["ys"].append(running_median(state["win_s"]) - ADC_HALF) + state["xs"].append(running_median(state["win_c"]) - ADC_HALF) else: jdata = joints_out[-1] z = np.array(jdata["T"]) @ (np.array([x_raw, y_raw]) - np.array(jdata["center_fit"])) diff --git a/src/lerobot/teleoperators/unitree_g1/exo_ik.py b/src/lerobot/teleoperators/unitree_g1/exo_ik.py index 92519540f..3fd18d2f8 100644 --- a/src/lerobot/teleoperators/unitree_g1/exo_ik.py +++ b/src/lerobot/teleoperators/unitree_g1/exo_ik.py @@ -25,8 +25,8 @@ from dataclasses import dataclass import numpy as np +from lerobot.robots.unitree_g1.g1_kinematics import G1_29_ArmIK from lerobot.robots.unitree_g1.g1_utils import G1_29_JointArmIndex -from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK from .exo_calib import JOINTS diff --git a/src/lerobot/teleoperators/unitree_g1/exo_serial.py b/src/lerobot/teleoperators/unitree_g1/exo_serial.py index 1211c57cc..4f45997c0 100644 --- a/src/lerobot/teleoperators/unitree_g1/exo_serial.py +++ b/src/lerobot/teleoperators/unitree_g1/exo_serial.py @@ -32,25 +32,29 @@ def parse_raw16(line: bytes) -> list[int] | None: if len(parts) < 16: return None return [int(x) for x in parts[:16]] - except Exception: + except (ValueError, IndexError): return None def read_raw_from_serial(ser) -> list[int] | None: """Read latest sample from serial; if buffer is backed up, keep only the newest.""" - last = None - while ser.in_waiting > 0: - b = ser.readline() - if not b: - break - raw16 = parse_raw16(b) - if raw16 is not None: - last = raw16 - if last is None: - b = ser.readline() - if b: - last = parse_raw16(b) - return last + try: + last = None + while ser.in_waiting > 0: + b = ser.readline() + if not b: + break + raw16 = parse_raw16(b) + if raw16 is not None: + last = raw16 + if last is None: + b = ser.readline() + if b: + last = parse_raw16(b) + return last + except serial.SerialException as e: + logger.warning(f"Serial read error: {e}") + return None @dataclass @@ -115,5 +119,6 @@ class ExoskeletonArm: return {} if raw is None else exo_raw_to_angles(raw, self.calibration) def calibrate(self) -> None: - ser = self._ser - self.calibration = run_exo_calibration(ser, self.side, self.calibration_fpath) + if not self.is_connected: + raise RuntimeError("Cannot calibrate: exoskeleton not connected") + self.calibration = run_exo_calibration(self._ser, self.side, self.calibration_fpath) diff --git a/src/lerobot/teleoperators/unitree_g1/unitree_g1.py b/src/lerobot/teleoperators/unitree_g1/unitree_g1.py index 3779d83ec..242613e7e 100644 --- a/src/lerobot/teleoperators/unitree_g1/unitree_g1.py +++ b/src/lerobot/teleoperators/unitree_g1/unitree_g1.py @@ -17,9 +17,22 @@ import logging import time from functools import cached_property +from typing import TYPE_CHECKING, Any -from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex +from lerobot.robots.unitree_g1.g1_utils import REMOTE_AXES, G1_29_JointArmIndex from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, TELEOPERATORS +from lerobot.utils.import_utils import _unitree_sdk_available + +if TYPE_CHECKING or _unitree_sdk_available: + from unitree_sdk2py.utils.joystick import Joystick +else: + + class Joystick: + def __init__(self): + raise ImportError( + "unitree_sdk2py is required for RemoteController. Install with: pip install unitree_sdk2py" + ) + from ..teleoperator import Teleoperator from .config_unitree_g1 import UnitreeG1TeleoperatorConfig @@ -29,6 +42,120 @@ from .exo_serial import ExoskeletonArm logger = logging.getLogger(__name__) +class RemoteController: + """Unitree remote controller data parser for joystick and button state.""" + + # ADC parameters for exoskeleton joystick (12-bit ADC) + ADC_MAX = 4095 + ADC_HALF = ADC_MAX / 2 + JOYSTICK_X_IDX = 11 # X axis in raw ADC array + JOYSTICK_BTN_IDX = 12 # Button in raw ADC array + JOYSTICK_Y_IDX = 13 # Y axis in raw ADC array + + # Map SDK named buttons to positional indices matching the wireless_remote + # byte layout (little-endian uint16 from bytes 2-3). + _BUTTON_MAP: list[str] = [ + "RB", + "LB", + "start", + "back", + "RT", + "LT", + "", + "", + "A", + "B", + "X", + "Y", + "up", + "right", + "down", + "left", + ] + + def __init__(self): + self.lx = 0.0 + self.ly = 0.0 + self.rx = 0.0 + self.ry = 0.0 + self.button = [0] * 16 + self.remote_action = dict.fromkeys(REMOTE_AXES, 0.0) + + # SDK joystick parser for wireless remote bytes + self._joystick = Joystick() + # Disable axis smoothing and deadzone to preserve raw values + for axis in (self._joystick.lx, self._joystick.ly, self._joystick.rx, self._joystick.ry): + axis.smooth = 1.0 + axis.deadzone = 0.0 + + # Joystick center calibration (read at connect time) + self.left_center_x = self.ADC_HALF + self.left_center_y = self.ADC_HALF + self.right_center_x = self.ADC_HALF + self.right_center_y = self.ADC_HALF + + # Whether to use exo joystick (detected at connect time) + self.use_left_exo_joystick = False + self.use_right_exo_joystick = False + + def _sync_remote_action(self) -> None: + self.remote_action.update(zip(REMOTE_AXES, (self.lx, self.ly, self.rx, self.ry), strict=True)) + + def calibrate_center(self, raw16: list[int] | None, side: str) -> None: + if raw16 is None or len(raw16) < 16: + logger.info(f"{side.capitalize()} exo joystick: no data available") + return + + btn_val = raw16[self.JOYSTICK_BTN_IDX] + logger.info(f"{side.capitalize()} exo joystick button ADC: {btn_val} (threshold: {self.ADC_HALF})") + if btn_val <= self.ADC_HALF: + logger.info(f"{side.capitalize()} exo joystick not detected (button below threshold)") + return + + x = raw16[self.JOYSTICK_X_IDX] + y = raw16[self.JOYSTICK_Y_IDX] + if side == "left": + self.use_left_exo_joystick = True + self.left_center_x, self.left_center_y = x, y + else: + self.use_right_exo_joystick = True + self.right_center_x, self.right_center_y = x, y + logger.info(f"{side.capitalize()} exo joystick enabled, center: x={x}, y={y}") + + def set_from_exo(self, raw16: list[int] | None, side: str) -> None: + if raw16 is None or len(raw16) < 16: + return + + if side == "left": + if not self.use_left_exo_joystick: + return + self.lx = (raw16[self.JOYSTICK_X_IDX] - self.left_center_x) / self.ADC_HALF + self.ly = (raw16[self.JOYSTICK_Y_IDX] - self.left_center_y) / self.ADC_HALF + self.button[4] = 1 if raw16[self.JOYSTICK_BTN_IDX] < self.ADC_HALF else 0 + return + + if not self.use_right_exo_joystick: + return + self.rx = (raw16[self.JOYSTICK_X_IDX] - self.right_center_x) / self.ADC_HALF + self.ry = (raw16[self.JOYSTICK_Y_IDX] - self.right_center_y) / self.ADC_HALF + self.button[0] = 1 if raw16[self.JOYSTICK_BTN_IDX] < self.ADC_HALF else 0 + + def set_from_wireless(self, wireless_remote: bytes) -> None: + """Parse Unitree wireless remote raw bytes into joystick + button state.""" + if len(wireless_remote) < 24: + return + self._joystick.extract(wireless_remote) + + self.lx = self._joystick.lx.data + self.ly = self._joystick.ly.data + self.rx = self._joystick.rx.data + self.ry = self._joystick.ry.data + + for i, name in enumerate(self._BUTTON_MAP): + if name: + self.button[i] = getattr(self._joystick, name).data + + class UnitreeG1Teleoperator(Teleoperator): """ Bimanual exoskeleton arms teleoperator for Unitree G1 arms. @@ -43,6 +170,13 @@ class UnitreeG1Teleoperator(Teleoperator): def __init__(self, config: UnitreeG1TeleoperatorConfig): super().__init__(config) self.config = config + left_exo_enabled = bool(config.left_arm_config.port.strip()) + right_exo_enabled = bool(config.right_arm_config.port.strip()) + if left_exo_enabled != right_exo_enabled: + raise ValueError( + "Invalid exo config: set both left/right exo ports, or leave both empty for remote-only mode." + ) + self._arm_control_enabled = left_exo_enabled and right_exo_enabled # Setup calibration directory self.calibration_dir = ( @@ -70,24 +204,37 @@ class UnitreeG1Teleoperator(Teleoperator): ) self.ik_helper: ExoskeletonIKHelper | None = None + self.remote_controller = RemoteController() @cached_property def action_features(self) -> dict[str, type]: - return {f"{name}.q": float for name in self._g1_joint_names} + remote_features = dict.fromkeys(self.remote_controller.remote_action, float) + if not self._arm_control_enabled: + return remote_features + joint_features = {f"{name}.q": float for name in self._g1_arm_joint_names} + return {**joint_features, **remote_features} @cached_property def feedback_features(self) -> dict[str, type]: - return {} + return {"wireless_remote": bytes} @property def is_connected(self) -> bool: + if not self._arm_control_enabled: + return True return self.left_arm.is_connected and self.right_arm.is_connected @property def is_calibrated(self) -> bool: + if not self._arm_control_enabled: + return True return self.left_arm.is_calibrated and self.right_arm.is_calibrated def connect(self, calibrate: bool = True) -> None: + if not self._arm_control_enabled: + logger.warning("Exo ports not fully configured; teleop will send joystick only (no arm actions)") + return + self.left_arm.connect(calibrate) self.right_arm.connect(calibrate) @@ -95,6 +242,13 @@ class UnitreeG1Teleoperator(Teleoperator): self.ik_helper = ExoskeletonIKHelper(frozen_joints=frozen_joints) logger.info("IK helper initialized") + time.sleep(0.1) # Give serial time to populate buffer + + left_raw = self.left_arm.read_raw() + right_raw = self.right_arm.read_raw() + self.remote_controller.calibrate_center(left_raw, "left") + self.remote_controller.calibrate_center(right_raw, "right") + def calibrate(self) -> None: if not self.left_arm.is_calibrated: logger.info("Starting calibration for left arm...") @@ -115,12 +269,33 @@ class UnitreeG1Teleoperator(Teleoperator): pass def get_action(self) -> dict[str, float]: - left_angles = self.left_arm.get_angles() - right_angles = self.right_arm.get_angles() - return self.ik_helper.compute_g1_joints_from_exo(left_angles, right_angles) + joint_action = {} + left_raw = None + right_raw = None + if self._arm_control_enabled: + left_raw = self.left_arm.read_raw() + right_raw = self.right_arm.read_raw() - def send_feedback(self, feedback: dict[str, float]) -> None: - raise NotImplementedError("Exoskeleton arms do not support feedback") + left_angles = self.left_arm.get_angles() + right_angles = self.right_arm.get_angles() + joint_action = self.ik_helper.compute_g1_joints_from_exo(left_angles, right_angles) + + # Wireless remote has priority when non-zero; otherwise, use exo joystick. + rc = self.remote_controller + wireless_active = ( + abs(rc.lx) > 1e-3 or abs(rc.ly) > 1e-3 or abs(rc.rx) > 1e-3 or abs(rc.ry) > 1e-3 + ) or any(rc.button) + if self._arm_control_enabled and not wireless_active: + rc.set_from_exo(left_raw, "left") + rc.set_from_exo(right_raw, "right") + + rc._sync_remote_action() + return {**joint_action, **rc.remote_action} + + def send_feedback(self, feedback: dict[str, Any]) -> None: + wireless_remote = feedback.get("wireless_remote") + if wireless_remote is not None: + self.remote_controller.set_from_wireless(wireless_remote) def disconnect(self) -> None: self.left_arm.disconnect() @@ -153,5 +328,5 @@ class UnitreeG1Teleoperator(Teleoperator): print("\n\nVisualization stopped.") @cached_property - def _g1_joint_names(self) -> list[str]: - return [joint.name for joint in G1_29_JointIndex] + def _g1_arm_joint_names(self) -> list[str]: + return [joint.name for joint in G1_29_JointArmIndex] diff --git a/src/lerobot/utils/import_utils.py b/src/lerobot/utils/import_utils.py index c33a73589..cae445e06 100644 --- a/src/lerobot/utils/import_utils.py +++ b/src/lerobot/utils/import_utils.py @@ -74,6 +74,8 @@ _peft_available = is_package_available("peft") _scipy_available = is_package_available("scipy") _reachy2_sdk_available = is_package_available("reachy2_sdk") _can_available = is_package_available("python-can", "can") +_unitree_sdk_available = is_package_available("unitree-sdk2", "unitree_sdk2py") +_pygame_available = is_package_available("pygame") def make_device_from_device_class(config: ChoiceRegistry) -> Any: diff --git a/tests/robots/test_unitree_g1.py b/tests/robots/test_unitree_g1.py new file mode 100644 index 000000000..8cc85b572 --- /dev/null +++ b/tests/robots/test_unitree_g1.py @@ -0,0 +1,267 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Unitree G1 robot. Meant to be run in an environment where the Unitree SDK is installed.""" + +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +from lerobot.utils.import_utils import _unitree_sdk_available + +if not _unitree_sdk_available: + pytest.skip("Unitree SDK not available", allow_module_level=True) + +from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config +from lerobot.robots.unitree_g1.g1_utils import ( + NUM_MOTORS, + REMOTE_AXES, + REMOTE_BUTTONS, + REMOTE_KEYS, + G1_29_JointArmIndex, + G1_29_JointIndex, + default_remote_input, + get_gravity_orientation, +) + +# --------------------------------------------------------------------------- +# Unit tests for g1_utils (no SDK needed) +# --------------------------------------------------------------------------- + + +class TestG1Utils: + def test_num_motors(self): + assert NUM_MOTORS == 29 + + def test_joint_index_count(self): + assert len(G1_29_JointIndex) == 29 + + def test_joint_arm_index_count(self): + assert len(G1_29_JointArmIndex) == 14 + + def test_arm_indices_are_subset_of_full(self): + full_values = {j.value for j in G1_29_JointIndex} + arm_values = {j.value for j in G1_29_JointArmIndex} + assert arm_values.issubset(full_values) + + def test_arm_indices_start_at_15(self): + assert min(j.value for j in G1_29_JointArmIndex) == 15 + assert max(j.value for j in G1_29_JointArmIndex) == 28 + + def test_enum_naming_consistency(self): + """Verify all wrist joints use consistent PascalCase naming.""" + wrist_joints = [j for j in G1_29_JointIndex if "Wrist" in j.name] + for j in wrist_joints: + # Should be "WristYaw", "WristPitch", "WristRoll" — no lowercase after "Wrist" + after_wrist = j.name.split("Wrist")[1] + assert after_wrist[0].isupper(), f"{j.name} has inconsistent casing after 'Wrist'" + + def test_remote_keys_structure(self): + assert len(REMOTE_AXES) == 4 + assert len(REMOTE_BUTTONS) == 16 + assert len(REMOTE_KEYS) == 20 + assert REMOTE_KEYS == REMOTE_AXES + REMOTE_BUTTONS + + def test_default_remote_input(self): + d = default_remote_input() + assert len(d) == 20 + assert all(v == 0.0 for v in d.values()) + assert set(d.keys()) == set(REMOTE_KEYS) + + def test_gravity_orientation_identity(self): + """Quaternion [1, 0, 0, 0] (no rotation) should give gravity along -z.""" + g = get_gravity_orientation([1.0, 0.0, 0.0, 0.0]) + assert g.shape == (3,) + assert g.dtype == np.float32 + np.testing.assert_allclose(g, [0.0, 0.0, -1.0], atol=1e-6) + + def test_gravity_orientation_dtype(self): + g = get_gravity_orientation(np.array([1.0, 0.0, 0.0, 0.0])) + assert g.dtype == np.float32 + + +# --------------------------------------------------------------------------- +# Unit tests for UnitreeG1Config (no SDK needed) +# --------------------------------------------------------------------------- + + +class TestUnitreeG1Config: + def test_default_config(self): + cfg = UnitreeG1Config() + assert len(cfg.kp) == 29 + assert len(cfg.kd) == 29 + assert len(cfg.default_positions) == 29 + assert cfg.is_simulation is True + assert cfg.controller is None + assert cfg.gravity_compensation is False + + def test_gains_are_positive(self): + cfg = UnitreeG1Config() + assert all(v > 0 for v in cfg.kp) + assert all(v > 0 for v in cfg.kd) + + def test_config_copies_gains(self): + """Each config instance should have its own copy of gains.""" + cfg1 = UnitreeG1Config() + cfg2 = UnitreeG1Config() + cfg1.kp[0] = 999.0 + assert cfg2.kp[0] != 999.0 + + +# --------------------------------------------------------------------------- +# Robot mock and integration tests +# --------------------------------------------------------------------------- + + +def _make_lowstate_msg_mock(): + """Create a mock that mimics the SDK LowState_ message.""" + msg = MagicMock() + for i in range(29): + motor = MagicMock() + motor.q = float(i) * 0.1 + motor.dq = float(i) * 0.01 + motor.tau_est = float(i) * 0.001 + motor.temperature = 30.0 + i + msg.motor_state.__getitem__ = lambda self, idx, _motors={}: _motors.setdefault( + idx, MagicMock(q=idx * 0.1, dq=idx * 0.01, tau_est=idx * 0.001, temperature=30.0 + idx) + ) + + msg.imu_state.quaternion = [1.0, 0.0, 0.0, 0.0] + msg.imu_state.gyroscope = [0.1, 0.2, 0.3] + msg.imu_state.accelerometer = [0.0, 0.0, 9.81] + msg.imu_state.rpy = [0.0, 0.0, 0.0] + msg.imu_state.temperature = 25.0 + msg.wireless_remote = b"\x00" * 40 + msg.mode_machine = 0 + return msg + + +def _make_sdk_mocks(): + """Create mocks for the Unitree SDK modules used by UnitreeG1.""" + lowcmd_default = MagicMock() + lowcmd_default.mode_pr = 0 + lowcmd_default.motor_cmd = [MagicMock() for _ in range(35)] + + crc_mock = MagicMock() + crc_mock.Crc.return_value = 0 + + lowstate_msg = _make_lowstate_msg_mock() + + subscriber_mock = MagicMock() + subscriber_mock.Read.return_value = lowstate_msg + + publisher_mock = MagicMock() + + return { + "lowcmd_default": lowcmd_default, + "crc_mock": crc_mock, + "subscriber_mock": subscriber_mock, + "publisher_mock": publisher_mock, + "lowstate_msg": lowstate_msg, + } + + +@pytest.fixture +def unitree_g1(): + """Create a UnitreeG1 robot with all SDK dependencies mocked.""" + mocks = _make_sdk_mocks() + + mock_channel_init = MagicMock() + mock_channel_pub = MagicMock(return_value=mocks["publisher_mock"]) + mock_channel_sub = MagicMock(return_value=mocks["subscriber_mock"]) + + with ( + patch( + "lerobot.robots.unitree_g1.unitree_g1.make_cameras_from_configs", + return_value={}, + ), + patch( + "lerobot.robots.unitree_g1.unitree_g1.G1_29_ArmIK", + return_value=MagicMock(), + ), + patch( + "lerobot.robots.unitree_g1.unitree_g1._SDKChannelFactoryInitialize", + mock_channel_init, + ), + patch( + "lerobot.robots.unitree_g1.unitree_g1._SDKChannelPublisher", + mock_channel_pub, + ), + patch( + "lerobot.robots.unitree_g1.unitree_g1._SDKChannelSubscriber", + mock_channel_sub, + ), + patch( + "lerobot.robots.unitree_g1.unitree_g1.unitree_hg_msg_dds__LowCmd_", + MagicMock(return_value=mocks["lowcmd_default"]), + ), + patch( + "lerobot.robots.unitree_g1.unitree_g1.hg_LowCmd", + MagicMock, + ), + patch( + "lerobot.robots.unitree_g1.unitree_g1.hg_LowState", + MagicMock, + ), + patch( + "lerobot.robots.unitree_g1.unitree_g1.CRC", + MagicMock(return_value=mocks["crc_mock"]), + ), + ): + from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1 + + cfg = UnitreeG1Config(is_simulation=True, gravity_compensation=False) + robot = UnitreeG1(cfg) + yield robot, mocks + if robot.is_connected: + robot.disconnect() + + +def test_init_state(unitree_g1): + robot, _ = unitree_g1 + assert not robot.is_connected + assert robot.controller is None + + +def test_observation_features(unitree_g1): + robot, _ = unitree_g1 + features = robot.observation_features + # Should have .q for all 29 joints (no cameras configured) + assert len(features) == 29 + for joint in G1_29_JointIndex: + assert f"{joint.name}.q" in features + + +def test_action_features_no_controller(unitree_g1): + robot, _ = unitree_g1 + features = robot.action_features + # Without controller: all 29 joints + assert len(features) == 29 + for joint in G1_29_JointIndex: + assert f"{joint.name}.q" in features + + +def test_get_observation_before_connect(unitree_g1): + robot, _ = unitree_g1 + obs = robot.get_observation() + assert obs == {} + + +def test_disconnect_idempotent(unitree_g1): + robot, _ = unitree_g1 + # Should not raise even when not connected + robot.disconnect() diff --git a/tests/teleoperators/test_unitree_g1_teleoperator.py b/tests/teleoperators/test_unitree_g1_teleoperator.py new file mode 100644 index 000000000..52f4a8482 --- /dev/null +++ b/tests/teleoperators/test_unitree_g1_teleoperator.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Unitree G1 teleoperator. Meant to be run in an environment where the Unitree SDK is installed.""" + +from unittest.mock import MagicMock + +import pytest + +from lerobot.utils.import_utils import _unitree_sdk_available + +if not _unitree_sdk_available: + pytest.skip("Unitree SDK not available", allow_module_level=True) + +from lerobot.robots.unitree_g1.g1_utils import REMOTE_AXES +from lerobot.teleoperators.unitree_g1.config_unitree_g1 import ( + ExoskeletonArmPortConfig, + UnitreeG1TeleoperatorConfig, +) +from lerobot.teleoperators.unitree_g1.unitree_g1 import RemoteController, UnitreeG1Teleoperator + +# --------------------------------------------------------------------------- +# Tests for RemoteController +# --------------------------------------------------------------------------- + + +def _make_joystick_mock(): + """Create a mock Joystick class matching the SDK interface.""" + joystick = MagicMock() + # Axes are Axis objects with .data attribute + joystick.lx = MagicMock(data=0.0, smooth=0.03, deadzone=0.01) + joystick.ly = MagicMock(data=0.0, smooth=0.03, deadzone=0.01) + joystick.rx = MagicMock(data=0.0, smooth=0.03, deadzone=0.01) + joystick.ry = MagicMock(data=0.0, smooth=0.03, deadzone=0.01) + # Buttons are Button objects with .data attribute + for name in ["RB", "LB", "start", "back", "RT", "LT", "A", "B", "X", "Y", "up", "right", "down", "left"]: + setattr(joystick, name, MagicMock(data=0)) + return joystick + + +@pytest.fixture +def remote_controller(): + """Create a RemoteController with a mocked Joystick.""" + mock_joystick = _make_joystick_mock() + + rc = RemoteController() + rc._joystick = mock_joystick + yield rc, mock_joystick + + +def test_remote_controller_init(remote_controller): + rc, _ = remote_controller + assert rc.lx == 0.0 + assert rc.ly == 0.0 + assert rc.rx == 0.0 + assert rc.ry == 0.0 + assert len(rc.button) == 16 + assert all(b == 0 for b in rc.button) + + +def test_sync_remote_action(remote_controller): + rc, _ = remote_controller + rc.lx = 0.5 + rc.ly = -0.3 + rc.rx = 0.1 + rc.ry = 0.0 + rc._sync_remote_action() + + assert rc.remote_action["remote.lx"] == 0.5 + assert rc.remote_action["remote.ly"] == -0.3 + assert rc.remote_action["remote.rx"] == 0.1 + assert rc.remote_action["remote.ry"] == 0.0 + + +def test_set_from_wireless_calls_extract(remote_controller): + rc, mock_joystick = remote_controller + # Set up the mock to populate data after extract + mock_joystick.lx.data = 0.5 + mock_joystick.ly.data = -0.3 + mock_joystick.rx.data = 0.1 + mock_joystick.ry.data = 0.0 + + wireless_data = b"\x00" * 40 + rc.set_from_wireless(wireless_data) + + mock_joystick.extract.assert_called_once_with(wireless_data) + assert rc.lx == 0.5 + assert rc.ly == -0.3 + + +def test_set_from_wireless_short_data(remote_controller): + rc, mock_joystick = remote_controller + rc.set_from_wireless(b"\x00" * 10) # Too short + mock_joystick.extract.assert_not_called() + + +def test_set_from_wireless_buttons(remote_controller): + rc, mock_joystick = remote_controller + # Simulate RB pressed + mock_joystick.RB.data = 1 + mock_joystick.lx.data = 0.0 + mock_joystick.ly.data = 0.0 + mock_joystick.rx.data = 0.0 + mock_joystick.ry.data = 0.0 + + rc.set_from_wireless(b"\x00" * 40) + assert rc.button[0] == 1 # RB maps to button[0] + + +def test_set_from_exo_left(remote_controller): + rc, _ = remote_controller + rc.use_left_exo_joystick = True + rc.left_center_x = 2048 + rc.left_center_y = 2048 + + raw16 = [0] * 16 + raw16[11] = 3048 # X axis: (3048 - 2048) / 2047.5 ≈ 0.488 + raw16[13] = 1048 # Y axis: (1048 - 2048) / 2047.5 ≈ -0.488 + raw16[12] = 0 # Button pressed (below ADC_HALF) + + rc.set_from_exo(raw16, "left") + assert rc.lx == pytest.approx((3048 - 2048) / 2047.5, abs=1e-3) + assert rc.ly == pytest.approx((1048 - 2048) / 2047.5, abs=1e-3) + assert rc.button[4] == 1 # Left button maps to button[4] + + +def test_set_from_exo_clears_button(remote_controller): + rc, _ = remote_controller + rc.use_left_exo_joystick = True + rc.button[4] = 1 # Pre-set + + raw16 = [0] * 16 + raw16[12] = 4000 # Button NOT pressed (above ADC_HALF) + + rc.set_from_exo(raw16, "left") + assert rc.button[4] == 0 # Should be cleared + + +def test_set_from_exo_ignored_when_not_enabled(remote_controller): + rc, _ = remote_controller + rc.use_left_exo_joystick = False + raw16 = [0] * 16 + raw16[11] = 3000 + + rc.set_from_exo(raw16, "left") + assert rc.lx == 0.0 # Unchanged + + +# --------------------------------------------------------------------------- +# Tests for UnitreeG1TeleoperatorConfig (no SDK needed) +# --------------------------------------------------------------------------- + + +class TestTeleoperatorConfig: + def test_default_config(self): + cfg = UnitreeG1TeleoperatorConfig() + assert cfg.left_arm_config.port == "" + assert cfg.right_arm_config.port == "" + assert cfg.frozen_joints == "" + + def test_config_with_ports(self): + cfg = UnitreeG1TeleoperatorConfig( + left_arm_config=ExoskeletonArmPortConfig(port="/dev/ttyACM0"), + right_arm_config=ExoskeletonArmPortConfig(port="/dev/ttyACM1"), + ) + assert cfg.left_arm_config.port == "/dev/ttyACM0" + assert cfg.right_arm_config.port == "/dev/ttyACM1" + + +# --------------------------------------------------------------------------- +# Tests for UnitreeG1Teleoperator +# --------------------------------------------------------------------------- + + +@pytest.fixture +def teleop_remote_only(): + """Create a UnitreeG1Teleoperator in remote-only mode (no exo arms).""" + cfg = UnitreeG1TeleoperatorConfig() # No ports = remote-only mode + teleop = UnitreeG1Teleoperator(cfg) + yield teleop + + +def test_remote_only_connect(teleop_remote_only): + """Remote-only mode should connect immediately without serial ports.""" + teleop = teleop_remote_only + teleop.connect() + assert teleop.is_connected + assert not teleop._arm_control_enabled + + +def test_remote_only_action_features(teleop_remote_only): + teleop = teleop_remote_only + features = teleop.action_features + # Remote-only: just the 4 remote axes + assert set(features.keys()) == set(REMOTE_AXES) + + +def test_feedback_features(teleop_remote_only): + teleop = teleop_remote_only + features = teleop.feedback_features + assert "wireless_remote" in features + assert features["wireless_remote"] is bytes + + +def test_remote_only_get_action(teleop_remote_only): + teleop = teleop_remote_only + teleop.connect() + action = teleop.get_action() + assert set(action.keys()) == set(REMOTE_AXES) + assert all(isinstance(v, float) for v in action.values()) + + +def test_send_feedback(teleop_remote_only): + teleop = teleop_remote_only + teleop.connect() + # Should not raise + teleop.send_feedback({"wireless_remote": b"\x00" * 40}) + + +def test_send_feedback_missing_key(teleop_remote_only): + teleop = teleop_remote_only + teleop.connect() + # Should not raise even with missing key + teleop.send_feedback({"other_key": 42}) + + +def test_asymmetric_exo_ports_raises(): + """Configuring only one exo port should raise ValueError.""" + cfg = UnitreeG1TeleoperatorConfig( + left_arm_config=ExoskeletonArmPortConfig(port="/dev/ttyACM0"), + # right_arm_config left empty + ) + with pytest.raises(ValueError, match="set both left/right"): + UnitreeG1Teleoperator(cfg) + + +# --------------------------------------------------------------------------- +# Tests for ExoskeletonArm (needs serial mock) +# --------------------------------------------------------------------------- + + +class TestExoskeletonArm: + def test_parse_raw16_valid(self): + from lerobot.teleoperators.unitree_g1.exo_serial import parse_raw16 + + line = b"100 200 300 400 500 600 700 800 900 1000 1100 1200 1300 1400 1500 1600\n" + result = parse_raw16(line) + assert result is not None + assert len(result) == 16 + assert result[0] == 100 + assert result[15] == 1600 + + def test_parse_raw16_too_short(self): + from lerobot.teleoperators.unitree_g1.exo_serial import parse_raw16 + + line = b"100 200 300\n" + assert parse_raw16(line) is None + + def test_parse_raw16_garbage(self): + from lerobot.teleoperators.unitree_g1.exo_serial import parse_raw16 + + assert parse_raw16(b"not numbers at all\n") is None + assert parse_raw16(b"\xff\xfe\xfd\n") is None + assert parse_raw16(b"") is None + + def test_calibrate_requires_connection(self): + from lerobot.teleoperators.unitree_g1.exo_serial import ExoskeletonArm + + arm = ExoskeletonArm( + port="/dev/null", + calibration_fpath=MagicMock(is_file=MagicMock(return_value=False)), + side="left", + ) + with pytest.raises(RuntimeError, match="not connected"): + arm.calibrate() + + def test_is_connected_false_by_default(self): + from lerobot.teleoperators.unitree_g1.exo_serial import ExoskeletonArm + + arm = ExoskeletonArm( + port="/dev/null", + calibration_fpath=MagicMock(is_file=MagicMock(return_value=False)), + side="left", + ) + assert not arm.is_connected + assert not arm.is_calibrated + + def test_read_raw_when_disconnected(self): + from lerobot.teleoperators.unitree_g1.exo_serial import ExoskeletonArm + + arm = ExoskeletonArm( + port="/dev/null", + calibration_fpath=MagicMock(is_file=MagicMock(return_value=False)), + side="left", + ) + assert arm.read_raw() is None From 2fb5c7add07e0afe19f020399481f736f7d64d24 Mon Sep 17 00:00:00 2001 From: Ignat Georgiev Date: Sun, 8 Mar 2026 13:29:33 +0200 Subject: [PATCH 097/131] feat(train): add cudnn_deterministic option for reproducible training (#3102) Add a `cudnn_deterministic` flag to `TrainPipelineConfig` (default: False) that sets `torch.backends.cudnn.deterministic = True` and disables benchmark mode, eliminating CUDA floating-point non-determinism at the cost of ~10-20% training speed. When False (default) the existing benchmark=True behaviour is preserved. --- src/lerobot/configs/train.py | 3 +++ src/lerobot/scripts/lerobot_train.py | 6 +++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index 7a5eee77d..9d20afc68 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -50,6 +50,9 @@ class TrainPipelineConfig(HubMixin): # `seed` is used for training (eg: model initialization, dataset shuffling) # AND for the evaluation environments. seed: int | None = 1000 + # Set to True to use deterministic cuDNN algorithms for reproducibility. + # This disables cudnn.benchmark and may reduce training speed by ~10-20%. + cudnn_deterministic: bool = False # Number of workers for the dataloader. num_workers: int = 4 batch_size: int = 8 diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 04d43d91e..1fed3bee4 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -209,7 +209,11 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): # Use accelerator's device device = accelerator.device - torch.backends.cudnn.benchmark = True + if cfg.cudnn_deterministic: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + else: + torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True # Dataset loading synchronization: main process downloads first to avoid race conditions From 1e131f93f8260b3d32f035bcc253057ca25c9bb8 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Sun, 8 Mar 2026 13:00:06 +0100 Subject: [PATCH 098/131] chore(docs): add uv installation instructions (#3105) * chore(docs): add uv installation instructions * fix(docs): format tabs * chore(docs): small details * chore(docs): last details uv installation instructions * chore(docs): last detail --- Co-authored-by: sahilmaniyar888 <156301258+sahilmaniyar888@users.noreply.github.com> --- docs/source/installation.mdx | 70 ++++++++++++++++++++++++++++++------ 1 file changed, 60 insertions(+), 10 deletions(-) diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 6d29215a0..80f705e88 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -1,8 +1,8 @@ # Installation -This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.12 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-). +This guide uses `conda` (via miniforge) to manage environments (recommended). If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.12 and `ffmpeg` installed with the `libsvtav1` encoder, then skip ahead to [Environment Setup](#step-2-environment-setup). -## Step 1: Install [`miniforge`](https://conda-forge.org/download/) +## Step 1 (`conda` only): Install [`miniforge`](https://conda-forge.org/download/) ```bash wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh" @@ -11,22 +11,47 @@ bash Miniforge3-$(uname)-$(uname -m).sh ## Step 2: Environment Setup -Create a virtual environment with Python 3.12, using conda: +Create a virtual environment with Python 3.12: + + + ```bash conda create -y -n lerobot python=3.12 ``` - -Then activate your conda environment, you have to do this each time you open a shell to use lerobot: - + + ```bash -conda activate lerobot +uv python install 3.12 +uv venv --python 3.12 ``` + + + + +Then activate your virtual environment, you have to do this each time you open a shell to use lerobot: + + + +```bash +conda activate lerobot +``` + +```bash +# Linux/macOSsource +source .venv/bin/activate +# Windows PowerShell +source .venv\Scripts\Activate.ps1 +``` + + + When using `conda`, install `ffmpeg` in your environment: ```bash conda install ffmpeg -c conda-forge +ffmpeg -version # ffmpeg 8.X is not yet supported ! ``` > [!TIP] @@ -47,6 +72,9 @@ conda install ffmpeg -c conda-forge > conda install evdev -c conda-forge > ``` +> [!IMPORTANT] +> If you are using `uv` you will have to install `ffmpeg` system-wide (outside of the virtual environment). You rely on `uv` and `torchcodec` ability to dynamically link to the system `ffmpeg`. + ## Step 3: Install LeRobot 🤗 ### From Source @@ -60,23 +88,45 @@ cd lerobot Then, install the library in editable mode. This is useful if you plan to contribute to the code. + + + ```bash pip install -e . ``` + + +```bash +uv pip install -e . +``` + + + ### Installation from PyPI **Core Library:** Install the base package with: + + + ```bash pip install lerobot ``` + + +```bash +uv pip install lerobot +``` + + + _This installs only the default dependencies._ **Extra Features:** -To install additional functionality, use one of the following: +To install additional functionality, use one of the following (If you are using `uv`, replace `pip install` with `uv pip install` in the commands below.): ```bash pip install 'lerobot[all]' # All available features @@ -93,7 +143,7 @@ https://pypi.org/project/lerobot/ ### Troubleshooting If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`. -To install these for linux run: +To install these for Linux run: ```bash sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev @@ -103,7 +153,7 @@ For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/ ## Optional dependencies -LeRobot provides optional extras for specific functionalities. Multiple extras can be combined (e.g., `.[aloha,feetech]`). For all available extras, refer to `pyproject.toml`. +LeRobot provides optional extras for specific functionalities. Multiple extras can be combined (e.g., `.[aloha,feetech]`). For all available extras, refer to `pyproject.toml`. If you are using `uv`, replace `pip install` with `uv pip install` in the commands below. ### Simulations From c17d9495310fa792d1ac8e3bed24c53d829224a3 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Sun, 8 Mar 2026 14:01:43 +0100 Subject: [PATCH 099/131] chore(readme): update citation with ICLR26 paper (#3107) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * peer reviewed citation 🎉 Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> * add iclr year Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> * fix quentin's spelling name Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> * docs(readme): update citation --------- Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> Co-authored-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> --- README.md | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index d60cd35a9..e273a4de8 100644 --- a/README.md +++ b/README.md @@ -135,7 +135,7 @@ Learn how to implement your own simulation environment or benchmark and distribu ## Citation -If you use LeRobot in your research, please cite: +If you use LeRobot in your project, please cite the GitHub repository to acknowledge the ongoing development and contributors: ```bibtex @misc{cadene2024lerobot, @@ -146,6 +146,23 @@ If you use LeRobot in your research, please cite: } ``` +If you are referencing our research or the academic paper, please also cite our ICLR publication: + +
+ICLR 2026 Paper + +```bibtex +@inproceedings{cadenelerobot, + title={LeRobot: An Open-Source Library for End-to-End Robot Learning}, + author={Cadene, Remi and Alibert, Simon and Capuano, Francesco and Aractingi, Michel and Zouitine, Adil and Kooijmans, Pepijn and Choghari, Jade and Russi, Martino and Pascal, Caroline and Palma, Steven and Shukor, Mustafa and Moss, Jess and Soare, Alexander and Aubakirova, Dana and Lhoest, Quentin and Gallou\'edec, Quentin and Wolf, Thomas}, + booktitle={The Fourteenth International Conference on Learning Representations}, + year={2026}, + url={https://arxiv.org/abs/2602.22818} +} +``` + +
+ ## Contribute We welcome contributions from everyone in the community! To get started, please read our [CONTRIBUTING.md](./CONTRIBUTING.md) guide. Whether you're adding a new feature, improving documentation, or fixing a bug, your help and feedback are invaluable. We're incredibly excited about the future of open-source robotics and can't wait to work with you on what's next—thank you for your support! From db8547e35df2e47928f51bbdead70b771587d2fe Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Sun, 8 Mar 2026 14:02:33 +0100 Subject: [PATCH 100/131] test(cameras): skip flaky async_read test (#3106) --- tests/cameras/test_opencv.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/cameras/test_opencv.py b/tests/cameras/test_opencv.py index feb700631..720d0c9b3 100644 --- a/tests/cameras/test_opencv.py +++ b/tests/cameras/test_opencv.py @@ -170,6 +170,7 @@ def test_async_read(index_or_path): assert isinstance(img, np.ndarray) +@pytest.mark.skip("Skipping test: async_read 0 timeout behavior may be flaky/non-deterministic.") def test_async_read_timeout(): config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0) From 5c51a744840b524d83e474ee71b0fceec49ae66f Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 9 Mar 2026 11:18:05 +0100 Subject: [PATCH 101/131] chore(deps): update requirements file (#3114) --- requirements-macos.txt | 441 ++++++++++++++++------------------------ requirements-ubuntu.txt | 397 +++++++++++++++++++----------------- requirements.in | 8 +- 3 files changed, 383 insertions(+), 463 deletions(-) diff --git a/requirements-macos.txt b/requirements-macos.txt index dc90416a3..c5bbe1c8a 100644 --- a/requirements-macos.txt +++ b/requirements-macos.txt @@ -1,76 +1,73 @@ # -# This file is autogenerated by pip-compile with Python 3.10 +# This file is autogenerated by pip-compile with Python 3.12 # by the following command: # # pip-compile --output-file=requirements-macos.txt requirements.in # -e .[all] # via -[all] -absl-py==2.3.1 +absl-py==2.4.0 # via # dm-control # dm-env # dm-tree # labmaze # mujoco - # tensorboard -accelerate==1.11.0 +accelerate==1.13.0 # via # lerobot # peft aiohappyeyeballs==2.6.1 # via aiohttp -aiohttp==3.13.1 +aiohttp==3.13.3 # via fsspec aiosignal==1.4.0 # via aiohttp +annotated-doc==0.0.4 + # via + # fastapi + # typer annotated-types==0.7.0 # via pydantic -antlr4-python3-runtime==4.9.3 - # via - # hydra-core - # omegaconf -anyio==4.11.0 +anyio==4.12.1 # via + # httpx # starlette # watchfiles -asttokens==3.0.0 +asttokens==3.0.1 # via stack-data -async-timeout==5.0.1 - # via aiohttp attrs==25.4.0 # via # aiohttp # dm-tree # jsonlines - # jsonschema - # referencing # rerun-sdk av==15.1.0 - # via lerobot -bddl==1.0.1 - # via libero -certifi==2025.10.5 # via + # lerobot + # qwen-vl-utils +certifi==2026.2.25 + # via + # httpcore + # httpx # requests # sentry-sdk cffi==2.0.0 # via pymunk -cfgv==3.4.0 +cfgv==3.5.0 # via pre-commit -charset-normalizer==3.4.4 +charset-normalizer==3.4.5 # via requests -click==8.3.0 +click==8.3.1 # via + # typer # uvicorn # wandb -cloudpickle==3.1.1 - # via - # gymnasium - # libero -cmake==4.1.0 +cloudpickle==3.1.2 + # via gymnasium +cmake==4.1.3 # via lerobot -cmeel==0.57.3 +cmeel==0.59.0 # via # cmeel-assimp # cmeel-boost @@ -108,15 +105,17 @@ cmeel-zlib==1.3.1 # via cmeel-assimp coal-library==3.0.1 # via pin -contourpy==1.3.2 - # via matplotlib -coverage[toml]==7.11.0 +contourpy==1.3.3 + # via + # lerobot + # matplotlib +coverage[toml]==7.13.4 # via pytest-cov cycler==0.12.1 # via matplotlib -datasets==4.1.1 +datasets==4.6.1 # via lerobot -debugpy==1.8.17 +debugpy==1.8.20 # via lerobot decorator==5.2.1 # via ipython @@ -130,7 +129,7 @@ dill==0.4.0 # multiprocess distlib==0.4.0 # via virtualenv -dm-control==1.0.34 +dm-control==1.0.37 # via gym-aloha dm-env==1.6 # via dm-control @@ -138,69 +137,55 @@ dm-tree==0.1.9 # via # dm-control # dm-env - # lerobot docopt==0.6.2 # via num2words draccus==0.10.0 # via lerobot dynamixel-sdk==3.8.4 # via lerobot -easydict==1.13 - # via libero -egl-probe @ git+https://github.com/huggingface/egl_probe.git - # via - # libero - # robomimic eigenpy==3.10.3 # via coal-library -einops==0.8.1 - # via - # lerobot - # libero +einops==0.8.2 + # via lerobot eiquadprog==1.2.9 # via placo -etils[epath,epy]==1.13.0 +etils[epath,epy]==1.14.0 # via mujoco -exceptiongroup==1.3.0 - # via - # anyio - # ipython - # pytest executing==2.2.1 # via stack-data +faker==34.0.2 + # via lerobot farama-notifications==0.0.4 # via gymnasium -fastapi==0.119.1 - # via teleop -fastjsonschema==2.21.2 - # via nbformat +fastapi==0.135.1 + # via + # lerobot + # teleop feetech-servo-sdk==1.0.0 # via lerobot -filelock==3.20.0 +filelock==3.25.0 # via # datasets # diffusers # huggingface-hub + # python-discovery # torch - # transformers # virtualenv -fonttools==4.60.1 +fonttools==4.61.1 # via matplotlib frozenlist==1.8.0 # via # aiohttp # aiosignal -fsspec[http]==2025.9.0 +fsspec[http]==2026.2.0 # via # datasets # etils # huggingface-hub # torch -future==1.0.0 - # via libero gitdb==4.0.12 # via gitpython -gitpython==3.1.45 +gitpython==3.1.46 # via wandb glfw==2.10.0 # via @@ -212,7 +197,6 @@ grpcio==1.73.1 # lerobot # reachy2-sdk # reachy2-sdk-api - # tensorboard grpcio-tools==1.73.1 # via # lerobot @@ -223,71 +207,67 @@ gym-hil==0.1.13 # via lerobot gym-pusht==0.1.6 # via lerobot -gymnasium==1.2.1 +gymnasium==1.2.3 # via # gym-aloha # gym-hil # gym-pusht # lerobot - # libero # metaworld h11==0.16.0 - # via uvicorn -h5py==3.15.1 - # via robomimic + # via + # httpcore + # uvicorn hebi-py==2.11.0 # via lerobot -hf-transfer==0.1.9 - # via huggingface-hub -hf-xet==1.1.10 +hf-xet==1.3.2 # via huggingface-hub hidapi==0.14.0.post4 # via # gym-hil # lerobot +httpcore==1.0.9 + # via httpx httptools==0.7.1 # via uvicorn -huggingface-hub[cli,hf-transfer]==0.35.3 +httpx==0.28.1 + # via + # datasets + # huggingface-hub +huggingface-hub==1.6.0 # via # accelerate # datasets # diffusers # lerobot # peft - # timm # tokenizers # transformers -hydra-core==1.3.2 - # via libero -identify==2.6.15 +identify==2.6.17 # via pre-commit idna==3.11 # via # anyio + # httpx # requests # yarl -imageio[ffmpeg]==2.37.0 +imageio[ffmpeg]==2.37.2 # via # gym-aloha # gym-hil # lerobot # metaworld - # robomimic # scikit-image imageio-ffmpeg==0.6.0 - # via - # imageio - # robomimic -importlib-metadata==8.7.0 + # via imageio +importlib-metadata==8.7.1 # via diffusers -importlib-resources==6.5.2 - # via etils iniconfig==2.3.0 # via pytest -inquirerpy==0.3.4 - # via huggingface-hub -ipython==8.37.0 +ipython==9.11.0 # via meshcat +ipython-pygments-lexers==1.1.1 + # via ipython ischedule==1.2.7 # via placo jedi==0.19.2 @@ -296,44 +276,24 @@ jinja2==3.1.6 # via torch jsonlines==4.0.0 # via lerobot -jsonschema==4.25.1 - # via nbformat -jsonschema-specifications==2025.9.1 - # via jsonschema -jupyter-core==5.9.1 - # via nbformat -jupytext==1.18.1 - # via bddl kiwisolver==1.4.9 # via matplotlib labmaze==1.0.6 # via dm-control -lazy-loader==0.4 +lazy-loader==0.5 # via scikit-image -libero @ git+https://github.com/huggingface/lerobot-libero.git@main - # via lerobot -llvmlite==0.45.1 - # via numba +librt==0.8.1 + # via mypy lxml==6.0.2 # via dm-control -markdown==3.9 - # via tensorboard markdown-it-py==4.0.0 - # via - # jupytext - # mdit-py-plugins + # via rich markupsafe==3.0.3 - # via - # jinja2 - # werkzeug -matplotlib==3.10.7 - # via - # lerobot - # libero + # via jinja2 +matplotlib==3.10.8 + # via lerobot matplotlib-inline==0.2.1 # via ipython -mdit-py-plugins==0.5.0 - # via jupytext mdurl==0.1.2 # via markdown-it-py mergedeep==1.3.4 @@ -346,41 +306,35 @@ mock-serial==0.0.1 # via lerobot mpmath==1.3.0 # via sympy -mujoco==3.3.7 +mujoco==3.5.0 # via # dm-control # gym-aloha # gym-hil - # libero # metaworld - # robosuite -multidict==6.7.0 +multidict==6.7.1 # via # aiohttp # yarl -multiprocess==0.70.16 +multiprocess==0.70.18 # via datasets +mypy==1.19.1 + # via lerobot mypy-extensions==1.1.0 - # via typing-inspect -nbformat==5.10.4 - # via jupytext -networkx==3.4.2 # via - # bddl + # mypy + # typing-inspect +networkx==3.6.1 + # via # scikit-image # torch -ninja==1.13.0 - # via lerobot -nodeenv==1.9.1 +nodeenv==1.10.0 # via pre-commit num2words==0.5.14 # via lerobot -numba==0.62.1 - # via robosuite numpy==2.2.6 # via # accelerate - # bddl # cmeel-boost # contourpy # datasets @@ -389,16 +343,14 @@ numpy==2.2.6 # dm-env # dm-tree # gymnasium - # h5py # hebi-py # imageio # labmaze - # libero + # lerobot # matplotlib # meshcat # metaworld # mujoco - # numba # opencv-python # opencv-python-headless # pandas @@ -406,26 +358,18 @@ numpy==2.2.6 # pyquaternion # reachy2-sdk # rerun-sdk - # robomimic - # robosuite # scikit-image # scipy # shapely # teleop - # tensorboard - # tensorboardx # tifffile # torchvision # transformers # transforms3d -omegaconf==2.3.0 - # via hydra-core -opencv-python==4.12.0.88 +opencv-python==4.13.0.92 # via # gym-pusht - # libero # reachy2-sdk - # robosuite opencv-python-headless==4.12.0.88 # via lerobot orderly-set==5.5.0 @@ -435,97 +379,87 @@ packaging==25.0 # accelerate # datasets # huggingface-hub - # hydra-core - # jupytext # lazy-loader # lerobot # matplotlib # peft # pytest + # qwen-vl-utils # reachy2-sdk # scikit-image - # tensorboard - # tensorboardx # transformers # wandb pandas==2.3.3 # via # datasets # lerobot -parso==0.8.5 +parso==0.8.6 # via jedi -peft==0.17.1 +pathspec==1.0.4 + # via mypy +peft==0.18.1 # via lerobot pexpect==4.9.0 # via ipython -pfzy==0.3.4 - # via inquirerpy -pillow==12.0.0 +pillow==12.1.1 # via # diffusers # imageio - # lerobot # matplotlib # meshcat + # qwen-vl-utils # rerun-sdk - # robosuite # scikit-image - # tensorboard # torchvision pin==3.4.0 # via placo -placo==0.9.14 +placo==0.9.16 # via lerobot -platformdirs==4.5.0 +platformdirs==4.9.4 # via - # jupyter-core + # python-discovery # virtualenv # wandb pluggy==1.6.0 # via # pytest # pytest-cov -pre-commit==4.3.0 +pre-commit==4.5.1 # via lerobot prompt-toolkit==3.0.52 - # via - # inquirerpy - # ipython + # via ipython propcache==0.4.1 # via # aiohttp # yarl -protobuf==6.31.0 +protobuf==6.31.1 # via # dm-control # grpcio-tools # lerobot # reachy2-sdk # reachy2-sdk-api - # tensorboard - # tensorboardx # wandb -psutil==7.1.1 +psutil==7.2.2 # via # accelerate # imageio # peft - # robomimic ptyprocess==0.7.0 # via pexpect pure-eval==0.2.3 # via stack-data -pyarrow==21.0.0 +pyarrow==23.0.1 # via # datasets # rerun-sdk -pycparser==2.23 +pycparser==3.0 # via cffi -pydantic==2.12.3 +pydantic==2.12.5 # via # fastapi # wandb -pydantic-core==2.41.4 +pydantic-core==2.41.5 # via pydantic pygame==2.6.1 # via @@ -535,33 +469,35 @@ pygame==2.6.1 pygments==2.19.2 # via # ipython + # ipython-pygments-lexers # pytest + # rich pymunk==6.11.1 # via # gym-pusht # lerobot -pyngrok==7.4.1 +pyngrok==7.5.1 # via meshcat pynput==1.8.1 # via # gym-hil # lerobot -pyobjc-core==12.0 +pyobjc-core==12.1 # via # pyobjc-framework-applicationservices # pyobjc-framework-cocoa # pyobjc-framework-coretext # pyobjc-framework-quartz -pyobjc-framework-applicationservices==12.0 +pyobjc-framework-applicationservices==12.1 # via pynput -pyobjc-framework-cocoa==12.0 +pyobjc-framework-cocoa==12.1 # via # pyobjc-framework-applicationservices # pyobjc-framework-coretext # pyobjc-framework-quartz -pyobjc-framework-coretext==12.0 +pyobjc-framework-coretext==12.1 # via pyobjc-framework-applicationservices -pyobjc-framework-quartz==12.0 +pyobjc-framework-quartz==12.1 # via # pynput # pyobjc-framework-applicationservices @@ -570,13 +506,13 @@ pyopengl==3.1.10 # via # dm-control # mujoco -pyparsing==3.2.5 +pyparsing==3.3.2 # via # dm-control # matplotlib pyquaternion==0.9.9 # via reachy2-sdk -pyrealsense2-macosx==2.54.2 +pyrealsense2-macosx==2.56.5 # via lerobot pyserial==3.5 # via @@ -585,7 +521,6 @@ pyserial==3.5 # lerobot pytest==8.4.2 # via - # bddl # lerobot # pytest-cov # pytest-timeout @@ -596,11 +531,14 @@ pytest-timeout==2.4.0 # via lerobot python-dateutil==2.9.0.post0 # via + # faker # matplotlib # pandas -python-dotenv==1.1.1 +python-discovery==1.1.1 + # via virtualenv +python-dotenv==1.2.2 # via uvicorn -pytz==2025.2 +pytz==2026.1.post1 # via pandas pyyaml==6.0.3 # via @@ -609,13 +547,10 @@ pyyaml==6.0.3 # draccus # hebi-py # huggingface-hub - # jupytext - # omegaconf # peft # pre-commit # pyngrok # pyyaml-include - # timm # transformers # uvicorn # wandb @@ -625,15 +560,13 @@ pyzmq==27.1.0 # via # lerobot # meshcat -reachy2-sdk==1.0.14 +qwen-vl-utils==0.0.14 + # via lerobot +reachy2-sdk==1.0.15 # via lerobot reachy2-sdk-api==1.0.21 # via reachy2-sdk -referencing==0.37.0 - # via - # jsonschema - # jsonschema-specifications -regex==2025.10.23 +regex==2026.2.28 # via # diffusers # transformers @@ -642,184 +575,150 @@ requests==2.32.5 # datasets # diffusers # dm-control - # huggingface-hub + # qwen-vl-utils # teleop - # transformers # wandb -rerun-sdk==0.26.1 +rerun-sdk==0.26.2 # via lerobot rhoban-cmeel-jsoncpp==1.9.4.9 # via placo -robomimic==0.2.0 - # via libero -robosuite==1.4.0 - # via libero -rpds-py==0.28.0 - # via - # jsonschema - # referencing -safetensors==0.6.2 +rich==14.3.3 + # via typer +safetensors==0.7.0 # via # accelerate # diffusers # lerobot # peft - # timm # transformers scikit-image==0.25.2 # via # gym-pusht # lerobot -scipy==1.15.3 +scipy==1.17.1 # via # dm-control + # lerobot # metaworld - # robosuite # scikit-image -sentry-sdk==2.42.1 + # torchdiffeq +sentry-sdk==2.54.0 # via wandb shapely==2.1.2 # via gym-pusht +shellingham==1.5.4 + # via typer six==1.17.0 # via # pynput # python-dateutil -smmap==5.0.2 +smmap==5.0.3 # via gitdb -sniffio==1.3.1 - # via anyio stack-data==0.6.3 # via ipython -starlette==0.48.0 +starlette==0.52.1 # via fastapi sympy==1.14.0 # via torch -teleop==0.1.2 +teleop==0.1.4 # via lerobot -tensorboard==2.20.0 - # via robomimic -tensorboard-data-server==0.7.2 - # via tensorboard -tensorboardx==2.6.4 - # via robomimic -termcolor==3.1.0 - # via - # lerobot - # robomimic -thop==0.1.1.post2209072238 - # via libero -tifffile==2025.5.10 +termcolor==3.3.0 + # via lerobot +tifffile==2026.3.3 # via scikit-image -timm==1.0.20 - # via lerobot -tokenizers==0.22.1 +tokenizers==0.22.2 # via transformers toml==0.10.2 # via draccus -tomli==2.3.0 - # via - # cmeel - # coverage - # jupytext - # pytest -torch==2.7.1 +torch==2.10.0 # via # accelerate # lerobot # peft - # robomimic - # thop - # timm + # torchdiffeq # torchvision -torchcodec==0.5 +torchcodec==0.10.0 # via lerobot -torchvision==0.22.1 - # via - # lerobot - # robomimic - # timm -tornado==6.5.2 +torchdiffeq==0.2.5 + # via lerobot +torchvision==0.25.0 + # via lerobot +tornado==6.5.4 # via meshcat -tqdm==4.67.1 +tqdm==4.67.3 # via # datasets # dm-control # huggingface-hub # peft - # robomimic # transformers traitlets==5.14.3 # via # ipython - # jupyter-core # matplotlib-inline - # nbformat -transformers==4.57.1 +transformers==5.3.0 # via # lerobot - # libero # peft transforms3d==0.4.2 # via teleop +typer==0.24.1 + # via + # huggingface-hub + # transformers typing-extensions==4.15.0 # via # aiosignal # anyio # etils - # exceptiongroup + # faker # fastapi # gymnasium # huggingface-hub - # ipython - # multidict + # mypy # pydantic # pydantic-core - # referencing # rerun-sdk # starlette # torch # typing-inspect # typing-inspection - # uvicorn - # virtualenv # wandb typing-inspect==0.9.0 # via draccus typing-inspection==0.4.2 - # via pydantic -tzdata==2025.2 + # via + # fastapi + # pydantic +tzdata==2025.3 # via pandas u-msgpack-python==2.8.0 # via meshcat -urllib3==2.5.0 +urllib3==2.6.3 # via # requests # sentry-sdk -uvicorn[standard]==0.38.0 +uvicorn[standard]==0.41.0 # via teleop uvloop==0.22.1 # via uvicorn -virtualenv==20.35.3 +virtualenv==21.1.0 # via pre-commit -wandb==0.21.4 - # via - # lerobot - # libero +wandb==0.24.2 + # via lerobot watchfiles==1.1.1 # via uvicorn -wcwidth==0.2.14 +wcwidth==0.6.0 # via prompt-toolkit websocket-client==1.9.0 # via teleop -websockets==15.0.1 +websockets==16.0 # via uvicorn -werkzeug==3.1.3 - # via tensorboard -wrapt==2.0.0 +wrapt==2.1.2 # via dm-tree xxhash==3.6.0 # via datasets -yarl==1.22.0 +yarl==1.23.0 # via aiohttp zipp==3.23.0 # via diff --git a/requirements-ubuntu.txt b/requirements-ubuntu.txt index 8413feac3..0cdc54190 100644 --- a/requirements-ubuntu.txt +++ b/requirements-ubuntu.txt @@ -1,12 +1,12 @@ # -# This file is autogenerated by pip-compile with Python 3.10 +# This file is autogenerated by pip-compile with Python 3.12 # by the following command: # # pip-compile --output-file=requirements-ubuntu.txt requirements.in # -e .[all] # via -[all] -absl-py==2.3.1 +absl-py==2.4.0 # via # dm-control # dm-env @@ -14,30 +14,33 @@ absl-py==2.3.1 # labmaze # mujoco # tensorboard -accelerate==1.11.0 +accelerate==1.13.0 # via # lerobot # peft aiohappyeyeballs==2.6.1 # via aiohttp -aiohttp==3.13.1 +aiohttp==3.13.3 # via fsspec aiosignal==1.4.0 # via aiohttp +annotated-doc==0.0.4 + # via + # fastapi + # typer annotated-types==0.7.0 # via pydantic antlr4-python3-runtime==4.9.3 # via # hydra-core # omegaconf -anyio==4.11.0 +anyio==4.12.1 # via + # httpx # starlette # watchfiles -asttokens==3.0.0 +asttokens==3.0.1 # via stack-data -async-timeout==5.0.1 - # via aiohttp attrs==25.4.0 # via # aiohttp @@ -47,30 +50,35 @@ attrs==25.4.0 # referencing # rerun-sdk av==15.1.0 - # via lerobot -bddl==1.0.1 - # via libero -certifi==2025.10.5 # via + # lerobot + # qwen-vl-utils +bddl==1.0.1 + # via hf-libero +certifi==2026.2.25 + # via + # httpcore + # httpx # requests # sentry-sdk cffi==2.0.0 # via pymunk -cfgv==3.4.0 +cfgv==3.5.0 # via pre-commit -charset-normalizer==3.4.4 +charset-normalizer==3.4.5 # via requests -click==8.3.0 +click==8.3.1 # via + # typer # uvicorn # wandb -cloudpickle==3.1.1 +cloudpickle==3.1.2 # via # gymnasium - # libero -cmake==4.1.0 + # hf-libero +cmake==4.1.3 # via lerobot -cmeel==0.57.3 +cmeel==0.59.0 # via # cmeel-assimp # cmeel-boost @@ -108,20 +116,24 @@ cmeel-zlib==1.3.1 # via cmeel-assimp coal-library==3.0.1 # via pin -contourpy==1.3.2 - # via matplotlib -coverage[toml]==7.11.0 +contourpy==1.3.3 + # via + # lerobot + # matplotlib +coverage[toml]==7.13.4 # via pytest-cov +cuda-bindings==12.9.4 + # via torch +cuda-pathfinder==1.4.1 + # via cuda-bindings cycler==0.12.1 # via matplotlib -datasets==4.1.1 +datasets==4.6.1 # via lerobot -debugpy==1.8.17 +debugpy==1.8.20 # via lerobot decorator==5.2.1 # via ipython -decord==0.6.0 - # via lerobot deepdiff==8.6.1 # via lerobot diffusers==0.35.2 @@ -132,7 +144,7 @@ dill==0.4.0 # multiprocess distlib==0.4.0 # via virtualenv -dm-control==1.0.34 +dm-control==1.0.37 # via gym-aloha dm-env==1.6 # via dm-control @@ -140,7 +152,6 @@ dm-tree==0.1.9 # via # dm-control # dm-env - # lerobot docopt==0.6.2 # via num2words draccus==0.10.0 @@ -148,66 +159,60 @@ draccus==0.10.0 dynamixel-sdk==3.8.4 # via lerobot easydict==1.13 - # via libero -egl-probe @ git+https://github.com/huggingface/egl_probe.git - # via - # libero - # robomimic + # via hf-libero +egl-probe==1.0.2 + # via robomimic eigenpy==3.10.3 # via coal-library -einops==0.8.1 +einops==0.8.2 # via - # flash-attn + # hf-libero # lerobot - # libero eiquadprog==1.2.9 # via placo -etils[epath,epy]==1.13.0 +etils[epath,epy]==1.14.0 # via mujoco -evdev==1.9.2 +evdev==1.9.3 # via pynput -exceptiongroup==1.3.0 - # via - # anyio - # ipython - # pytest executing==2.2.1 # via stack-data +faker==34.0.2 + # via lerobot farama-notifications==0.0.4 # via gymnasium -fastapi==0.119.1 - # via teleop +fastapi==0.135.1 + # via + # lerobot + # teleop fastjsonschema==2.21.2 # via nbformat feetech-servo-sdk==1.0.0 # via lerobot -filelock==3.20.0 +filelock==3.25.0 # via # datasets # diffusers # huggingface-hub + # python-discovery # torch - # transformers # virtualenv -flash-attn==2.8.3 - # via lerobot -fonttools==4.60.1 +fonttools==4.61.1 # via matplotlib frozenlist==1.8.0 # via # aiohttp # aiosignal -fsspec[http]==2025.9.0 +fsspec[http]==2026.2.0 # via # datasets # etils # huggingface-hub # torch future==1.0.0 - # via libero + # via hf-libero gitdb==4.0.12 # via gitpython -gitpython==3.1.45 +gitpython==3.1.46 # via wandb glfw==2.10.0 # via @@ -230,50 +235,60 @@ gym-hil==0.1.13 # via lerobot gym-pusht==0.1.6 # via lerobot -gymnasium==1.2.1 +gymnasium==1.2.3 # via # gym-aloha # gym-hil # gym-pusht + # hf-libero # lerobot - # libero # metaworld h11==0.16.0 - # via uvicorn -h5py==3.15.1 + # via + # httpcore + # uvicorn +h5py==3.16.0 # via robomimic hebi-py==2.11.0 # via lerobot -hf-transfer==0.1.9 - # via huggingface-hub -hf-xet==1.1.10 +hf-egl-probe==1.0.2 + # via hf-libero +hf-libero==0.1.3 + # via lerobot +hf-xet==1.3.2 # via huggingface-hub hidapi==0.14.0.post4 # via # gym-hil # lerobot +httpcore==1.0.9 + # via httpx httptools==0.7.1 # via uvicorn -huggingface-hub[cli,hf-transfer]==0.35.3 +httpx==0.28.1 + # via + # datasets + # huggingface-hub +huggingface-hub==1.6.0 # via # accelerate # datasets # diffusers # lerobot # peft - # timm # tokenizers # transformers hydra-core==1.3.2 - # via libero -identify==2.6.15 + # via hf-libero +identify==2.6.17 # via pre-commit idna==3.11 # via # anyio + # httpx # requests # yarl -imageio[ffmpeg]==2.37.0 +imageio[ffmpeg]==2.37.2 # via # gym-aloha # gym-hil @@ -285,16 +300,14 @@ imageio-ffmpeg==0.6.0 # via # imageio # robomimic -importlib-metadata==8.7.0 +importlib-metadata==8.7.1 # via diffusers -importlib-resources==6.5.2 - # via etils iniconfig==2.3.0 # via pytest -inquirerpy==0.3.4 - # via huggingface-hub -ipython==8.37.0 +ipython==9.11.0 # via meshcat +ipython-pygments-lexers==1.1.1 + # via ipython ischedule==1.2.7 # via placo jedi==0.19.2 @@ -303,40 +316,41 @@ jinja2==3.1.6 # via torch jsonlines==4.0.0 # via lerobot -jsonschema==4.25.1 +jsonschema==4.26.0 # via nbformat jsonschema-specifications==2025.9.1 # via jsonschema jupyter-core==5.9.1 # via nbformat -jupytext==1.18.1 +jupytext==1.19.1 # via bddl kiwisolver==1.4.9 # via matplotlib labmaze==1.0.6 # via dm-control -lazy-loader==0.4 +lazy-loader==0.5 # via scikit-image -libero @ git+https://github.com/huggingface/lerobot-libero.git@main - # via lerobot -llvmlite==0.45.1 +librt==0.8.1 + # via mypy +llvmlite==0.46.0 # via numba lxml==6.0.2 # via dm-control -markdown==3.9 +markdown==3.10.2 # via tensorboard markdown-it-py==4.0.0 # via # jupytext # mdit-py-plugins + # rich markupsafe==3.0.3 # via # jinja2 # werkzeug -matplotlib==3.10.7 +matplotlib==3.10.8 # via + # hf-libero # lerobot - # libero matplotlib-inline==0.2.1 # via ipython mdit-py-plugins==0.5.0 @@ -353,36 +367,38 @@ mock-serial==0.0.1 # via lerobot mpmath==1.3.0 # via sympy -mujoco==3.3.7 +mujoco==3.5.0 # via # dm-control # gym-aloha # gym-hil - # libero + # hf-libero # metaworld # robosuite -multidict==6.7.0 +multidict==6.7.1 # via # aiohttp # yarl -multiprocess==0.70.16 +multiprocess==0.70.18 # via datasets +mypy==1.19.1 + # via lerobot mypy-extensions==1.1.0 - # via typing-inspect + # via + # mypy + # typing-inspect nbformat==5.10.4 # via jupytext -networkx==3.4.2 +networkx==3.6.1 # via # bddl # scikit-image # torch -ninja==1.13.0 - # via lerobot -nodeenv==1.9.1 +nodeenv==1.10.0 # via pre-commit num2words==0.5.14 # via lerobot -numba==0.62.1 +numba==0.64.0 # via robosuite numpy==2.2.6 # via @@ -391,7 +407,6 @@ numpy==2.2.6 # cmeel-boost # contourpy # datasets - # decord # diffusers # dm-control # dm-env @@ -399,9 +414,10 @@ numpy==2.2.6 # gymnasium # h5py # hebi-py + # hf-libero # imageio # labmaze - # libero + # lerobot # matplotlib # meshcat # metaworld @@ -426,49 +442,51 @@ numpy==2.2.6 # torchvision # transformers # transforms3d -nvidia-cublas-cu12==12.6.4.1 +nvidia-cublas-cu12==12.8.4.1 # via # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 # torch -nvidia-cuda-cupti-cu12==12.6.80 +nvidia-cuda-cupti-cu12==12.8.90 # via torch -nvidia-cuda-nvrtc-cu12==12.6.77 +nvidia-cuda-nvrtc-cu12==12.8.93 # via torch -nvidia-cuda-runtime-cu12==12.6.77 +nvidia-cuda-runtime-cu12==12.8.90 # via torch -nvidia-cudnn-cu12==9.5.1.17 +nvidia-cudnn-cu12==9.10.2.21 # via torch -nvidia-cufft-cu12==11.3.0.4 +nvidia-cufft-cu12==11.3.3.83 # via torch -nvidia-cufile-cu12==1.11.1.6 +nvidia-cufile-cu12==1.13.1.3 # via torch -nvidia-curand-cu12==10.3.7.77 +nvidia-curand-cu12==10.3.9.90 # via torch -nvidia-cusolver-cu12==11.7.1.2 +nvidia-cusolver-cu12==11.7.3.90 # via torch -nvidia-cusparse-cu12==12.5.4.2 +nvidia-cusparse-cu12==12.5.8.93 # via # nvidia-cusolver-cu12 # torch -nvidia-cusparselt-cu12==0.6.3 +nvidia-cusparselt-cu12==0.7.1 # via torch -nvidia-nccl-cu12==2.26.2 +nvidia-nccl-cu12==2.27.5 # via torch -nvidia-nvjitlink-cu12==12.6.85 +nvidia-nvjitlink-cu12==12.8.93 # via # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 # torch -nvidia-nvtx-cu12==12.6.77 +nvidia-nvshmem-cu12==3.4.5 + # via torch +nvidia-nvtx-cu12==12.8.90 # via torch omegaconf==2.3.0 # via hydra-core -opencv-python==4.12.0.88 +opencv-python==4.13.0.92 # via # gym-pusht - # libero + # hf-libero # reachy2-sdk # robosuite opencv-python-headless==4.12.0.88 @@ -487,6 +505,7 @@ packaging==25.0 # matplotlib # peft # pytest + # qwen-vl-utils # reachy2-sdk # scikit-image # tensorboard @@ -497,21 +516,21 @@ pandas==2.3.3 # via # datasets # lerobot -parso==0.8.5 +parso==0.8.6 # via jedi -peft==0.17.1 +pathspec==1.0.4 + # via mypy +peft==0.18.1 # via lerobot pexpect==4.9.0 # via ipython -pfzy==0.3.4 - # via inquirerpy -pillow==12.0.0 +pillow==12.1.1 # via # diffusers # imageio - # lerobot # matplotlib # meshcat + # qwen-vl-utils # rerun-sdk # robosuite # scikit-image @@ -519,28 +538,27 @@ pillow==12.0.0 # torchvision pin==3.4.0 # via placo -placo==0.9.14 +placo==0.9.16 # via lerobot -platformdirs==4.5.0 +platformdirs==4.9.4 # via # jupyter-core + # python-discovery # virtualenv # wandb pluggy==1.6.0 # via # pytest # pytest-cov -pre-commit==4.3.0 +pre-commit==4.5.1 # via lerobot prompt-toolkit==3.0.52 - # via - # inquirerpy - # ipython + # via ipython propcache==0.4.1 # via # aiohttp # yarl -protobuf==6.31.0 +protobuf==6.31.1 # via # dm-control # grpcio-tools @@ -550,7 +568,7 @@ protobuf==6.31.0 # tensorboard # tensorboardx # wandb -psutil==7.1.1 +psutil==7.2.2 # via # accelerate # imageio @@ -560,17 +578,17 @@ ptyprocess==0.7.0 # via pexpect pure-eval==0.2.3 # via stack-data -pyarrow==21.0.0 +pyarrow==23.0.1 # via # datasets # rerun-sdk -pycparser==2.23 +pycparser==3.0 # via cffi -pydantic==2.12.3 +pydantic==2.12.5 # via # fastapi # wandb -pydantic-core==2.41.4 +pydantic-core==2.41.5 # via pydantic pygame==2.6.1 # via @@ -580,12 +598,14 @@ pygame==2.6.1 pygments==2.19.2 # via # ipython + # ipython-pygments-lexers # pytest + # rich pymunk==6.11.1 # via # gym-pusht # lerobot -pyngrok==7.4.1 +pyngrok==7.5.1 # via meshcat pynput==1.8.1 # via @@ -595,7 +615,7 @@ pyopengl==3.1.10 # via # dm-control # mujoco -pyparsing==3.2.5 +pyparsing==3.3.2 # via # dm-control # matplotlib @@ -621,13 +641,16 @@ pytest-timeout==2.4.0 # via lerobot python-dateutil==2.9.0.post0 # via + # faker # matplotlib # pandas -python-dotenv==1.1.1 +python-discovery==1.1.1 + # via virtualenv +python-dotenv==1.2.2 # via uvicorn python-xlib==0.33 # via pynput -pytz==2025.2 +pytz==2026.1.post1 # via pandas pyyaml==6.0.3 # via @@ -642,7 +665,6 @@ pyyaml==6.0.3 # pre-commit # pyngrok # pyyaml-include - # timm # transformers # uvicorn # wandb @@ -652,7 +674,9 @@ pyzmq==27.1.0 # via # lerobot # meshcat -reachy2-sdk==1.0.14 +qwen-vl-utils==0.0.14 + # via lerobot +reachy2-sdk==1.0.15 # via lerobot reachy2-sdk-api==1.0.21 # via reachy2-sdk @@ -660,7 +684,7 @@ referencing==0.37.0 # via # jsonschema # jsonschema-specifications -regex==2025.10.23 +regex==2026.2.28 # via # diffusers # transformers @@ -669,60 +693,62 @@ requests==2.32.5 # datasets # diffusers # dm-control - # huggingface-hub + # qwen-vl-utils # teleop - # transformers # wandb -rerun-sdk==0.26.1 +rerun-sdk==0.26.2 # via lerobot rhoban-cmeel-jsoncpp==1.9.4.9 # via placo +rich==14.3.3 + # via typer robomimic==0.2.0 - # via libero + # via hf-libero robosuite==1.4.0 - # via libero -rpds-py==0.28.0 + # via hf-libero +rpds-py==0.30.0 # via # jsonschema # referencing -safetensors==0.6.2 +safetensors==0.7.0 # via # accelerate # diffusers # lerobot # peft - # timm # transformers scikit-image==0.25.2 # via # gym-pusht # lerobot -scipy==1.15.3 +scipy==1.17.1 # via # dm-control + # lerobot # metaworld # robosuite # scikit-image -sentry-sdk==2.42.1 + # torchdiffeq +sentry-sdk==2.54.0 # via wandb shapely==2.1.2 # via gym-pusht +shellingham==1.5.4 + # via typer six==1.17.0 # via # pynput # python-dateutil # python-xlib -smmap==5.0.2 +smmap==5.0.3 # via gitdb -sniffio==1.3.1 - # via anyio stack-data==0.6.3 # via ipython -starlette==0.48.0 +starlette==0.52.1 # via fastapi sympy==1.14.0 # via torch -teleop==0.1.2 +teleop==0.1.4 # via lerobot tensorboard==2.20.0 # via robomimic @@ -730,46 +756,38 @@ tensorboard-data-server==0.7.2 # via tensorboard tensorboardx==2.6.4 # via robomimic -termcolor==3.1.0 +termcolor==3.3.0 # via # lerobot # robomimic thop==0.1.1.post2209072238 - # via libero -tifffile==2025.5.10 + # via hf-libero +tifffile==2026.3.3 # via scikit-image -timm==1.0.20 - # via lerobot -tokenizers==0.22.1 +tokenizers==0.22.2 # via transformers toml==0.10.2 # via draccus -tomli==2.3.0 - # via - # cmeel - # coverage - # jupytext - # pytest -torch==2.7.1 +torch==2.10.0 # via # accelerate - # flash-attn # lerobot # peft # robomimic # thop - # timm + # torchdiffeq # torchvision -torchcodec==0.5 +torchcodec==0.10.0 # via lerobot -torchvision==0.22.1 +torchdiffeq==0.2.5 + # via lerobot +torchvision==0.25.0 # via # lerobot # robomimic - # timm -tornado==6.5.2 +tornado==6.5.4 # via meshcat -tqdm==4.67.1 +tqdm==4.67.3 # via # datasets # dm-control @@ -783,26 +801,29 @@ traitlets==5.14.3 # jupyter-core # matplotlib-inline # nbformat -transformers==4.57.1 +transformers==5.3.0 # via + # hf-libero # lerobot - # libero # peft transforms3d==0.4.2 # via teleop -triton==3.3.1 +triton==3.6.0 # via torch +typer==0.24.1 + # via + # huggingface-hub + # transformers typing-extensions==4.15.0 # via # aiosignal # anyio # etils - # exceptiongroup + # faker # fastapi # gymnasium # huggingface-hub - # ipython - # multidict + # mypy # pydantic # pydantic-core # referencing @@ -811,46 +832,46 @@ typing-extensions==4.15.0 # torch # typing-inspect # typing-inspection - # uvicorn - # virtualenv # wandb typing-inspect==0.9.0 # via draccus typing-inspection==0.4.2 - # via pydantic -tzdata==2025.2 + # via + # fastapi + # pydantic +tzdata==2025.3 # via pandas u-msgpack-python==2.8.0 # via meshcat -urllib3==2.5.0 +urllib3==2.6.3 # via # requests # sentry-sdk -uvicorn[standard]==0.38.0 +uvicorn[standard]==0.41.0 # via teleop uvloop==0.22.1 # via uvicorn -virtualenv==20.35.3 +virtualenv==21.1.0 # via pre-commit -wandb==0.21.4 +wandb==0.24.2 # via + # hf-libero # lerobot - # libero watchfiles==1.1.1 # via uvicorn -wcwidth==0.2.14 +wcwidth==0.6.0 # via prompt-toolkit websocket-client==1.9.0 # via teleop -websockets==15.0.1 +websockets==16.0 # via uvicorn -werkzeug==3.1.3 +werkzeug==3.1.6 # via tensorboard -wrapt==2.0.0 +wrapt==2.1.2 # via dm-tree xxhash==3.6.0 # via datasets -yarl==1.22.0 +yarl==1.23.0 # via aiohttp zipp==3.23.0 # via diff --git a/requirements.in b/requirements.in index df2a07d67..b39632f71 100644 --- a/requirements.in +++ b/requirements.in @@ -1,9 +1,9 @@ # requirements.in -# requirements-macos.txt was generated on macOS and is platform-specific (macOS 26.0.1 25A362 arm64). -# Darwin MacBook-Pro.local 25.0.0 Darwin Kernel Version 25.0.0: Wed Sep 17 21:42:08 PDT 2025; root:xnu-12377.1.9~141/RELEASE_ARM64_T8132 arm64 +# requirements-macos.txt was generated on macOS and is platform-specific (macOS 26.3.1 25D2128 arm64). +# Darwin MacBook-Pro.local 25.3.0 Darwin Kernel Version 25.3.0: Wed Jan 28 20:54:55 PST 2026; root:xnu-12377.91.3~2/RELEASE_ARM64_T8132 arm64 -# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.3 LTS x86_64). -# Linux mlerobot-linux 6.14.0-33-generic #33~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 19 17:02:30 UTC 2 x86_64 x86_64 x86_64 GNU/Linux +# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.4 LTS x86_64). +# Linux lerobot-linux 6.17.0-14-generic #14~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Thu Jan 15 15:52:10 UTC 2 x86_64 x86_64 x86_64 GNU/Linux -e .[all] From 00b662de02734a6972ec674b8792696ecd1cb28e Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 9 Mar 2026 11:34:52 +0100 Subject: [PATCH 102/131] chore(dependencies): Bump lerobot to 0.5.0 (#3117) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 696d8597d..b59d169c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb" [project] name = "lerobot" -version = "0.4.5" +version = "0.5.0" description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch" dynamic = ["readme"] license = { text = "Apache-2.0" } From b0efa73520845dbe507eab0049f8304bedf69e96 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 9 Mar 2026 12:43:32 +0100 Subject: [PATCH 103/131] chore(dependencies): Bump lerobot to 0.5.1 (#3118) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b59d169c7..aed846f43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb" [project] name = "lerobot" -version = "0.5.0" +version = "0.5.1" description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch" dynamic = ["readme"] license = { text = "Apache-2.0" } From 885ef91892f1d190a145a11966869ccc52f7964a Mon Sep 17 00:00:00 2001 From: Martino Russi <77496684+nepyope@users.noreply.github.com> Date: Mon, 9 Mar 2026 18:47:12 +0100 Subject: [PATCH 104/131] fix(unitree_g1): correct SDK detection and update installation docs (#3115) * update docs * update toml / docs * update docs * fix joystick * Update pyproject.toml Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com> * update toml and docs * update docs * clarify robot * update docs * update docs * update pinocchio deps * final touches * Update docs/source/unitree_g1.mdx Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com> * move envhub dependencies to docs * point to unitree_sdk docs * upper bound on onnx * chore(docs): small details unitree docs * chore(deps): add version pin and unitree_sdk hint --------- Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com> Co-authored-by: Steven Palma --- docs/source/unitree_g1.mdx | 127 ++++++++++++++++++++---------- pyproject.toml | 5 +- src/lerobot/utils/import_utils.py | 2 +- 3 files changed, 87 insertions(+), 47 deletions(-) diff --git a/docs/source/unitree_g1.mdx b/docs/source/unitree_g1.mdx index 39bd7832b..2e615085e 100644 --- a/docs/source/unitree_g1.mdx +++ b/docs/source/unitree_g1.mdx @@ -12,36 +12,59 @@ The Unitree G1 humanoid is now supported in LeRobot! You can teleoperate, train ## Part 1: Getting Started -### Install LeRobot on Your Machine +### Install the Unitree SDK + +Follow the [unitree_sdk2_python installation guide](https://github.com/unitreerobotics/unitree_sdk2_python#installation). Tested with `unitree_sdk2py==1.0.1` and `cyclonedds==0.10.2`: ```bash conda create -y -n lerobot python=3.12 conda activate lerobot git clone https://github.com/unitreerobotics/unitree_sdk2_python.git -cd unitree_sdk2_python && pip install -e . +cd unitree_sdk2_python +pip install -e . +cd .. +``` + +### Install LeRobot + +```bash +conda install ffmpeg -c conda-forge +conda install -c conda-forge "pinocchio>=3.0.0,<4.0.0" git clone https://github.com/huggingface/lerobot.git cd lerobot pip install -e '.[unitree_g1]' ``` + + For now, pinocchio must be installed from conda-forge (not pip) to include the + CasADi bindings needed for arm IK. + + ### Test the Installation (Simulation) +The simulation environment has its own dependencies. Check the Simulation environment dependencies: [Unitree G1 Mujoco EnvHub](https://huggingface.co/lerobot/unitree-g1-mujoco/tree/main). + +```bash +pip install mujoco loguru msgpack msgpack-numpy +``` + ```bash lerobot-teleoperate \ --robot.type=unitree_g1 \ --robot.is_simulation=true \ --teleop.type=unitree_g1 \ --teleop.id=wbc_unitree \ - --robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \ - --display_data=true + --robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30, "warmup_s": 5}}' \ + --display_data=true \ + --robot.controller=GrootLocomotionController ``` -This will launch a [MuJoCo sim instance](https://huggingface.co/lerobot/unitree-g1-mujoco/tree/main) for the G1. +This will launch a [MuJoCo sim instance](https://huggingface.co/lerobot/unitree-g1-mujoco/tree/main) for the G1. You can connect a gamepad to your machine before launching in order to control the robot's locomotion in sim. We support both [HolosomaLocomotionController](https://github.com/amazon-far/holosoma) and [GrootLocomotionController](https://github.com/NVlabs/GR00T-WholeBodyControl) via `--robot.controller`. - Press `9` to release the robot - Press `7` / `8` to increase / decrease waist height -### Connect to the Robot +### Connect to the Physical Robot The G1's Ethernet IP is fixed at `192.168.123.164`. Your machine must have a static IP on the same subnet: `192.168.123.x` where `x ≠ 164`. @@ -59,37 +82,11 @@ ssh unitree@192.168.123.164 # Password: 123 ``` -### Install LeRobot on the G1 +### Share Internet via Ethernet -From the robot: +The G1 needs internet access to clone repos and install packages. Share your laptop's connection over Ethernet: -```bash -conda create -y -n lerobot python=3.12 -conda activate lerobot -git clone https://github.com/unitreerobotics/unitree_sdk2_python.git -cd unitree_sdk2_python && pip install -e . -git clone https://github.com/huggingface/lerobot.git -cd lerobot -pip install -e '.[unitree_g1]' -``` - -> **Note:** The Unitree SDK requires CycloneDDS v0.10.2. See the [Unitree SDK docs](https://github.com/unitreerobotics/unitree_sdk2_python) for details. - ---- - -## Part 2: Enable WiFi on the Robot - -Wi-Fi connectivity is blocked by default on the G1. To activate: - -```bash -sudo rfkill unblock all -sudo ip link set wlan0 up -sudo nmcli radio wifi on -sudo nmcli device set wlan0 managed yes -sudo systemctl restart NetworkManager -``` - -**On your laptop** (share internet via Ethernet): +**On your laptop:** ```bash sudo sysctl -w net.ipv4.ip_forward=1 @@ -100,7 +97,7 @@ sudo iptables -A FORWARD -i wlp132s0f0 -o enp131s0 -m state --state RELATED,ESTA sudo iptables -A FORWARD -i enp131s0 -o wlp132s0f0 -j ACCEPT ``` -**On the G1** (set default route through your laptop): +**On the G1:** ```bash sudo ip route del default 2>/dev/null || true @@ -111,6 +108,45 @@ echo "nameserver 8.8.8.8" | sudo tee /etc/resolv.conf ping -c 3 8.8.8.8 ``` +### Install the Unitree SDK on the G1 + +Follow the [unitree_sdk2_python installation guide](https://github.com/unitreerobotics/unitree_sdk2_python#installation): + +```bash +conda create -y -n lerobot python=3.12 +conda activate lerobot +git clone https://github.com/unitreerobotics/unitree_sdk2_python.git +cd unitree_sdk2_python +python -m pip install -e . +cd .. +``` + +### Install LeRobot on the G1 + +```bash +git clone https://github.com/huggingface/lerobot.git +cd lerobot +conda install -c conda-forge "pinocchio>=3.0.0,<4.0.0" +python -m pip install -e '.[unitree_g1]' +``` + + + For now, pinocchio must be installed from conda-forge (not pip) to include the + CasADi bindings needed for arm IK. + + +### (Optional) Enable WiFi on the Robot + +For wireless SSH access, you can enable WiFi on the G1 (it's blocked by default): + +```bash +sudo rfkill unblock all +sudo ip link set wlan0 up +sudo nmcli radio wifi on +sudo nmcli device set wlan0 managed yes +sudo systemctl restart NetworkManager +``` + **Connect to a WiFi network:** ```bash @@ -125,7 +161,7 @@ sudo nmcli connection up "YourNetwork" ip a show wlan0 ``` -You can now SSH over WiFi: +You can then SSH over WiFi instead of Ethernet: ```bash ssh unitree@ @@ -134,18 +170,23 @@ ssh unitree@ --- -## Part 3: Teleoperation & Locomotion +## Part 2: Teleoperation & Locomotion ### Run the Robot Server -On the robot: +On the robot (from `~/lerobot`): ```bash +cd ~/lerobot python src/lerobot/robots/unitree_g1/run_g1_server.py --camera ``` ### Run the Locomotion Policy +You can run the teleoperation client from your laptop over Ethernet, over WiFi (experimental), or directly on the robot itself. Mind potential latency introduced by your network. + +**From your laptop:** + ```bash lerobot-teleoperate \ --robot.type=unitree_g1 \ @@ -158,13 +199,13 @@ lerobot-teleoperate \ --robot.controller=HolosomaLocomotionController ``` -We support both [HolosomaLocomotionController](https://github.com/amazon-far/holosoma) and [GrootLocomotionController](https://github.com/NVlabs/GR00T-WholeBodyControl). +We support both [GrootLocomotionController](https://github.com/NVlabs/GR00T-WholeBodyControl) and [HolosomaLocomotionController](https://github.com/amazon-far/holosoma) via `--robot.controller`. --- -## Part 4: Loco-Manipulation with the Homunculus Exoskeleton +## Part 3: Loco-Manipulation with the Homunculus Exoskeleton -We provide a loco-manipulation solution via the Homunculus Exoskeleton — an open-source 7 DoF exoskeleton for whole-body control. Assembly instructions [here](https://github.com/nepyope/hmc_exo). +We provide a loco-manipulation solution via the Homunculus Exoskeleton — an open-source 7 DoF exoskeleton for whole-body control. Check it out [here](https://github.com/nepyope/hmc_exo). ### Calibrate @@ -205,7 +246,7 @@ Example dataset: [nepyope/unitree_box_move_blue_full](https://huggingface.co/dat --- -## Part 5: Training & Inference +## Part 4: Training & Inference ### Train diff --git a/pyproject.toml b/pyproject.toml index aed846f43..e85d695df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,14 +119,13 @@ gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"] hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"] lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"] unitree_g1 = [ - "unitree-sdk2==1.0.1", + # "unitree-sdk2==1.0.1", "pyzmq>=26.2.1,<28.0.0", "onnxruntime>=1.16.0,<2.0.0", - "pin>=3.0.0,<4.0.0", + "onnx>=1.16.0,<2.0.0", "meshcat>=0.3.0,<0.4.0", "lerobot[matplotlib-dep]", "lerobot[pygame-dep]", - "casadi>=3.6.0,<4.0.0", ] reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"] kinematics = ["lerobot[placo-dep]"] diff --git a/src/lerobot/utils/import_utils.py b/src/lerobot/utils/import_utils.py index cae445e06..2b26b2302 100644 --- a/src/lerobot/utils/import_utils.py +++ b/src/lerobot/utils/import_utils.py @@ -74,7 +74,7 @@ _peft_available = is_package_available("peft") _scipy_available = is_package_available("scipy") _reachy2_sdk_available = is_package_available("reachy2_sdk") _can_available = is_package_available("python-can", "can") -_unitree_sdk_available = is_package_available("unitree-sdk2", "unitree_sdk2py") +_unitree_sdk_available = is_package_available("unitree-sdk2py", "unitree_sdk2py") _pygame_available = is_package_available("pygame") From 96b7f3dae0e79029add57aa8cc99cc854cf33b5a Mon Sep 17 00:00:00 2001 From: Johnson Sun <20457146+j3soon@users.noreply.github.com> Date: Tue, 10 Mar 2026 01:47:58 +0800 Subject: [PATCH 105/131] Parse HF_USER with NO_COLOR to avoid incorrectly capturing bash ANSI codes (#3119) --- docs/source/il_robots.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index e49132a8e..245634382 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -165,7 +165,7 @@ hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential Then store your Hugging Face repository name in a variable: ```bash -HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}') +HF_USER=$(NO_COLOR=1 hf auth whoami | awk -F': *' 'NR==1 {print $2}') echo $HF_USER ``` From 19c6adef8521d27561590904abc8e5e1d99f24af Mon Sep 17 00:00:00 2001 From: Silvio Traversaro Date: Mon, 9 Mar 2026 23:27:18 +0100 Subject: [PATCH 106/131] chore(dependencies): Increase opencv-python-headless upper bound (#3120) Signed-off-by: Silvio Traversaro --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e85d695df..5f45626c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,7 @@ dependencies = [ "torchvision>=0.21.0,<0.26.0", "einops>=0.8.0,<0.9.0", - "opencv-python-headless>=4.9.0,<4.13.0", + "opencv-python-headless>=4.9.0,<4.14.0", "av>=15.0.0,<16.0.0", "jsonlines>=4.0.0,<5.0.0", "pynput>=1.7.8,<1.9.0", From f311ca3dcee0969966cd3ff174f33becf4e0d850 Mon Sep 17 00:00:00 2001 From: "H.Yamada" <3405876+ymd-h@users.noreply.github.com> Date: Thu, 12 Mar 2026 04:12:21 +0900 Subject: [PATCH 107/131] Fix action padding key at SmolVLA (#1717) Issue https://github.com/huggingface/lerobot/issues/1707 Action padding mask is set at LeRobotDataset as f"{key}_is_pad". Wrong key doesn't raise any errors, however, padding mask is ignored, resulting wrong attention at around the edges of an episode when multi step actions is enabled (aka. action horizon is greater than 1). Co-authored-by: Steven Palma --- src/lerobot/policies/smolvla/modeling_smolvla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 430c85481..048d974af 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -374,7 +374,7 @@ class SmolVLAPolicy(PreTrainedPolicy): lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"] lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] actions = self.prepare_action(batch) - actions_is_pad = batch.get("actions_id_pad") + actions_is_pad = batch.get("action_is_pad") loss_dict = {} losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) loss_dict["losses_after_forward"] = losses.clone().mean().item() From c15b75e3dad73d9993e58453baf06d61544c4c10 Mon Sep 17 00:00:00 2001 From: Heuzef Date: Thu, 12 Mar 2026 00:45:43 +0100 Subject: [PATCH 108/131] Update Dockerfile.user (#1633) Instruction for USB ports access with container Signed-off-by: Heuzef Co-authored-by: Steven Palma --- docker/Dockerfile.user | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docker/Dockerfile.user b/docker/Dockerfile.user index d43d12816..f267be7f2 100644 --- a/docker/Dockerfile.user +++ b/docker/Dockerfile.user @@ -18,6 +18,8 @@ # docker build -f docker/Dockerfile.user -t lerobot-user . # docker run -it --rm lerobot-user +# With USB physical access : docker run -it --device=/dev/ -v /dev/:/dev/ --rm lerobot-user + # Configure the base image ARG PYTHON_VERSION=3.12 FROM python:${PYTHON_VERSION}-slim From efee611403e6d2a3b500900faf651d7937199282 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 12 Mar 2026 00:51:31 +0100 Subject: [PATCH 109/131] fix(policies): crop losses based on the action dof (#3133) Co-authored-by: Chenning Yu --- src/lerobot/policies/smolvla/modeling_smolvla.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 048d974af..32165eba8 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -377,6 +377,8 @@ class SmolVLAPolicy(PreTrainedPolicy): actions_is_pad = batch.get("action_is_pad") loss_dict = {} losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) + original_action_dim = self.config.action_feature.shape[0] + losses = losses[:, :, :original_action_dim] loss_dict["losses_after_forward"] = losses.clone().mean().item() if actions_is_pad is not None: From 0db5f66ddae6afd62454c418771286fa64dbea20 Mon Sep 17 00:00:00 2001 From: Bruno Machado <72039033+brunomachado37@users.noreply.github.com> Date: Thu, 12 Mar 2026 00:54:08 +0100 Subject: [PATCH 110/131] Add option to disable tags on WandB (#1339) Signed-off-by: Steven Palma Co-authored-by: Steven Palma --- src/lerobot/configs/default.py | 1 + src/lerobot/rl/wandb_utils.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py index dcb0cbd54..3fb0c6c4e 100644 --- a/src/lerobot/configs/default.py +++ b/src/lerobot/configs/default.py @@ -47,6 +47,7 @@ class WandBConfig: notes: str | None = None run_id: str | None = None mode: str | None = None # Allowed values: 'online', 'offline' 'disabled'. Defaults to 'online' + add_tags: bool = True # If True, save configuration as tags in the WandB run. @dataclass diff --git a/src/lerobot/rl/wandb_utils.py b/src/lerobot/rl/wandb_utils.py index ee30b75df..e3190b6ce 100644 --- a/src/lerobot/rl/wandb_utils.py +++ b/src/lerobot/rl/wandb_utils.py @@ -98,7 +98,7 @@ class WandBLogger: entity=self.cfg.entity, name=self.job_name, notes=self.cfg.notes, - tags=cfg_to_group(cfg, return_list=True, truncate_tags=True), + tags=cfg_to_group(cfg, return_list=True, truncate_tags=True) if self.cfg.add_tags else None, dir=self.log_dir, config=cfg.to_dict(), # TODO(rcadene): try set to True From 2d6259156bd8702c328bd9d91bfb324ade293c85 Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Fri, 13 Mar 2026 04:46:05 +0100 Subject: [PATCH 111/131] fix(links): replacing relative links with absolute links in the contribution guide (#3141) * fix(links): replacing relative links with absolute links in the contribution guide * fix(links): replacing relative link in the README --- CONTRIBUTING.md | 8 ++++---- README.md | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 82147d363..60df93b27 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,7 +2,7 @@ Everyone is welcome to contribute, and we value everybody's contribution. Code is not the only way to help the community. Answering questions, helping others, reaching out, and improving the documentation are immensely valuable. -Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md) and our [AI policy](./AI_POLICY.md). +Whichever way you choose to contribute, please be mindful to respect our [code of conduct](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md) and our [AI policy](https://github.com/huggingface/lerobot/blob/main/AI_POLICY.md). ## Ways to Contribute @@ -32,7 +32,7 @@ git remote add upstream https://github.com/huggingface/lerobot.git ### 2. Environment Installation -Please follow our [Installation Guide](./docs/source/installation.mdx) for the environment setup & installation from source. +Please follow our [Installation Guide](https://huggingface.co/docs/lerobot/installation) for the environment setup & installation from source. ## Running Tests & Quality Checks @@ -75,8 +75,8 @@ pytest -sv tests/test_specific_feature.py Use the templates for required fields and examples. -- **Issues:** Follow the [ticket template](./.github/ISSUE_TEMPLATE/bug-report.yml). -- **Pull requests:** Rebase on `upstream/main`, use a descriptive branch (don't work on `main`), run `pre-commit` and tests locally, and follow the [PR template](./.github/PULL_REQUEST_TEMPLATE.md). +- **Issues:** Follow the [ticket template](https://github.com/huggingface/lerobot/blob/main/.github/ISSUE_TEMPLATE/bug-report.yml). +- **Pull requests:** Rebase on `upstream/main`, use a descriptive branch (don't work on `main`), run `pre-commit` and tests locally, and follow the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md). One member of the LeRobot team will then review your contribution. diff --git a/README.md b/README.md index e273a4de8..f58b337b3 100644 --- a/README.md +++ b/README.md @@ -165,7 +165,7 @@ If you are referencing our research or the academic paper, please also cite our ## Contribute -We welcome contributions from everyone in the community! To get started, please read our [CONTRIBUTING.md](./CONTRIBUTING.md) guide. Whether you're adding a new feature, improving documentation, or fixing a bug, your help and feedback are invaluable. We're incredibly excited about the future of open-source robotics and can't wait to work with you on what's next—thank you for your support! +We welcome contributions from everyone in the community! To get started, please read our [CONTRIBUTING.md](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md) guide. Whether you're adding a new feature, improving documentation, or fixing a bug, your help and feedback are invaluable. We're incredibly excited about the future of open-source robotics and can't wait to work with you on what's next—thank you for your support!

SO101 Video From 2ec1dafcc28b8737d8c606e29f05b44f5cf99ddf Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Sat, 14 Mar 2026 18:49:53 +0100 Subject: [PATCH 112/131] fix(lerobot-train): fixing lerobot-train --help by removing % in the docstrings (draccus does not support the character) (#3161) --- src/lerobot/configs/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index 9d20afc68..8b8aedb26 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -51,7 +51,7 @@ class TrainPipelineConfig(HubMixin): # AND for the evaluation environments. seed: int | None = 1000 # Set to True to use deterministic cuDNN algorithms for reproducibility. - # This disables cudnn.benchmark and may reduce training speed by ~10-20%. + # This disables cudnn.benchmark and may reduce training speed by ~10-20 percent. cudnn_deterministic: bool = False # Number of workers for the dataloader. num_workers: int = 4 From a07b1d76f17b359f55ffcfceed1a3b532448ab25 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Sun, 15 Mar 2026 20:26:06 -0700 Subject: [PATCH 113/131] chore(dependecies): untangle dependecies across internal modules (#3149) --- examples/phone_to_so100/evaluate.py | 3 +- examples/phone_to_so100/record.py | 3 +- examples/phone_to_so100/replay.py | 3 +- examples/phone_to_so100/teleoperate.py | 3 +- examples/so100_to_so100_EE/evaluate.py | 3 +- examples/so100_to_so100_EE/record.py | 3 +- examples/so100_to_so100_EE/replay.py | 3 +- examples/so100_to_so100_EE/teleoperate.py | 3 +- src/lerobot/async_inference/policy_server.py | 6 +- src/lerobot/configs/policies.py | 2 +- src/lerobot/datasets/pipeline_features.py | 3 +- src/lerobot/envs/libero.py | 2 +- src/lerobot/envs/metaworld.py | 2 +- src/lerobot/envs/utils.py | 2 +- src/lerobot/policies/factory.py | 3 +- src/lerobot/policies/groot/processor_groot.py | 2 +- src/lerobot/policies/pi05/processor_pi05.py | 2 +- .../policies/pi0_fast/processor_pi0_fast.py | 2 +- src/lerobot/policies/sarm/processor_sarm.py | 2 +- .../policies/smolvla/modeling_smolvla.py | 2 +- src/lerobot/policies/utils.py | 2 +- src/lerobot/policies/xvla/processor_xvla.py | 2 +- src/lerobot/processor/__init__.py | 15 +-- src/lerobot/processor/batch_processor.py | 2 +- src/lerobot/processor/converters.py | 3 +- .../processor/delta_action_processor.py | 2 +- src/lerobot/processor/device_processor.py | 4 +- src/lerobot/processor/factory.py | 3 +- src/lerobot/processor/gym_action_processor.py | 4 +- src/lerobot/processor/hil_processor.py | 3 +- src/lerobot/processor/normalize_processor.py | 2 +- src/lerobot/processor/pipeline.py | 2 +- src/lerobot/processor/tokenizer_processor.py | 2 +- src/lerobot/rl/actor.py | 4 +- src/lerobot/rl/learner.py | 2 +- .../bi_openarm_follower.py | 2 +- .../robots/bi_so_follower/bi_so_follower.py | 2 +- .../robot_earthrover_mini_plus.py | 2 +- src/lerobot/robots/hope_jr/hope_jr_arm.py | 2 +- src/lerobot/robots/hope_jr/hope_jr_hand.py | 2 +- .../robots/koch_follower/koch_follower.py | 2 +- src/lerobot/robots/lekiwi/lekiwi.py | 2 +- src/lerobot/robots/lekiwi/lekiwi_client.py | 2 +- .../robots/omx_follower/omx_follower.py | 2 +- .../openarm_follower/openarm_follower.py | 2 +- src/lerobot/robots/reachy2/robot_reachy2.py | 2 +- src/lerobot/robots/robot.py | 2 +- src/lerobot/robots/so_follower/so_follower.py | 2 +- src/lerobot/robots/unitree_g1/unitree_g1.py | 5 +- src/lerobot/scripts/lerobot_eval.py | 5 +- src/lerobot/scripts/lerobot_record.py | 2 +- .../bi_openarm_leader/bi_openarm_leader.py | 2 +- .../teleoperators/gamepad/teleop_gamepad.py | 2 +- .../teleoperators/keyboard/teleop_keyboard.py | 2 +- .../openarm_leader/openarm_leader.py | 2 +- .../openarm_mini/openarm_mini.py | 2 +- .../teleoperators/phone/phone_processor.py | 3 +- src/lerobot/teleoperators/teleoperator.py | 2 +- src/lerobot/{processor/core.py => types.py} | 0 src/lerobot/utils/control_utils.py | 3 +- src/lerobot/utils/device_utils.py | 109 ++++++++++++++++++ src/lerobot/utils/utils.py | 103 ++--------------- src/lerobot/utils/visualization_utils.py | 2 +- tests/mocks/mock_robot.py | 2 +- tests/mocks/mock_teleop.py | 2 +- tests/policies/groot/test_groot_lerobot.py | 5 +- .../policies/groot/test_groot_vs_original.py | 3 +- .../test_pi0_fast_original_vs_lerobot.py | 3 +- .../pi0_pi05/test_pi05_original_vs_lerobot.py | 3 +- .../pi0_pi05/test_pi0_original_vs_lerobot.py | 3 +- tests/policies/test_sarm_processor.py | 2 +- .../xvla/test_xvla_original_vs_lerobot.py | 3 +- tests/processor/test_batch_conversion.py | 3 +- tests/processor/test_converters.py | 2 +- tests/processor/test_device_processor.py | 3 +- tests/processor/test_normalize_processor.py | 2 +- tests/processor/test_observation_processor.py | 3 +- tests/processor/test_tokenizer_processor.py | 3 +- tests/training/test_visual_validation.py | 2 +- tests/utils.py | 2 +- tests/utils/test_visualization_utils.py | 2 +- 81 files changed, 235 insertions(+), 189 deletions(-) rename src/lerobot/{processor/core.py => types.py} (100%) create mode 100644 src/lerobot/utils/device_utils.py diff --git a/examples/phone_to_so100/evaluate.py b/examples/phone_to_so100/evaluate.py index 837217eda..c1291d101 100644 --- a/examples/phone_to_so100/evaluate.py +++ b/examples/phone_to_so100/evaluate.py @@ -23,8 +23,6 @@ from lerobot.model.kinematics import RobotKinematics from lerobot.policies.act.modeling_act import ACTPolicy from lerobot.policies.factory import make_pre_post_processors from lerobot.processor import ( - RobotAction, - RobotObservation, RobotProcessorPipeline, make_default_teleop_action_processor, ) @@ -40,6 +38,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import ( InverseKinematicsEEToJoints, ) from lerobot.scripts.lerobot_record import record_loop +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun diff --git a/examples/phone_to_so100/record.py b/examples/phone_to_so100/record.py index 1f5005db9..756c6f42d 100644 --- a/examples/phone_to_so100/record.py +++ b/examples/phone_to_so100/record.py @@ -19,7 +19,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features from lerobot.datasets.utils import combine_feature_dicts from lerobot.model.kinematics import RobotKinematics -from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor import RobotProcessorPipeline from lerobot.processor.converters import ( observation_to_transition, robot_action_observation_to_transition, @@ -38,6 +38,7 @@ from lerobot.scripts.lerobot_record import record_loop from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction from lerobot.teleoperators.phone.teleop_phone import Phone +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun diff --git a/examples/phone_to_so100/replay.py b/examples/phone_to_so100/replay.py index 9d7806cf4..7b955cdb7 100644 --- a/examples/phone_to_so100/replay.py +++ b/examples/phone_to_so100/replay.py @@ -18,7 +18,7 @@ import time from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.model.kinematics import RobotKinematics -from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor import RobotProcessorPipeline from lerobot.processor.converters import ( robot_action_observation_to_transition, transition_to_robot_action, @@ -27,6 +27,7 @@ from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig from lerobot.robots.so_follower.robot_kinematic_processor import ( InverseKinematicsEEToJoints, ) +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.constants import ACTION from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import log_say diff --git a/examples/phone_to_so100/teleoperate.py b/examples/phone_to_so100/teleoperate.py index 6eaaec806..7242c39ce 100644 --- a/examples/phone_to_so100/teleoperate.py +++ b/examples/phone_to_so100/teleoperate.py @@ -16,7 +16,7 @@ import time from lerobot.model.kinematics import RobotKinematics -from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor import RobotProcessorPipeline from lerobot.processor.converters import ( robot_action_observation_to_transition, transition_to_robot_action, @@ -31,6 +31,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import ( from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction from lerobot.teleoperators.phone.teleop_phone import Phone +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.visualization_utils import init_rerun, log_rerun_data diff --git a/examples/so100_to_so100_EE/evaluate.py b/examples/so100_to_so100_EE/evaluate.py index b614b89f2..45a87ebad 100644 --- a/examples/so100_to_so100_EE/evaluate.py +++ b/examples/so100_to_so100_EE/evaluate.py @@ -23,8 +23,6 @@ from lerobot.model.kinematics import RobotKinematics from lerobot.policies.act.modeling_act import ACTPolicy from lerobot.policies.factory import make_pre_post_processors from lerobot.processor import ( - RobotAction, - RobotObservation, RobotProcessorPipeline, make_default_teleop_action_processor, ) @@ -40,6 +38,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import ( InverseKinematicsEEToJoints, ) from lerobot.scripts.lerobot_record import record_loop +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun diff --git a/examples/so100_to_so100_EE/record.py b/examples/so100_to_so100_EE/record.py index d85a1c5cc..8fa862d6e 100644 --- a/examples/so100_to_so100_EE/record.py +++ b/examples/so100_to_so100_EE/record.py @@ -20,7 +20,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features from lerobot.datasets.utils import combine_feature_dicts from lerobot.model.kinematics import RobotKinematics -from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor import RobotProcessorPipeline from lerobot.processor.converters import ( observation_to_transition, robot_action_observation_to_transition, @@ -35,6 +35,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import ( ) from lerobot.scripts.lerobot_record import record_loop from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun diff --git a/examples/so100_to_so100_EE/replay.py b/examples/so100_to_so100_EE/replay.py index 47a2f6635..b042e02dd 100644 --- a/examples/so100_to_so100_EE/replay.py +++ b/examples/so100_to_so100_EE/replay.py @@ -19,7 +19,7 @@ import time from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.model.kinematics import RobotKinematics -from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor import RobotProcessorPipeline from lerobot.processor.converters import ( robot_action_observation_to_transition, transition_to_robot_action, @@ -28,6 +28,7 @@ from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig from lerobot.robots.so_follower.robot_kinematic_processor import ( InverseKinematicsEEToJoints, ) +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.constants import ACTION from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import log_say diff --git a/examples/so100_to_so100_EE/teleoperate.py b/examples/so100_to_so100_EE/teleoperate.py index 71d2899de..af21f079b 100644 --- a/examples/so100_to_so100_EE/teleoperate.py +++ b/examples/so100_to_so100_EE/teleoperate.py @@ -17,7 +17,7 @@ import time from lerobot.model.kinematics import RobotKinematics -from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor import RobotProcessorPipeline from lerobot.processor.converters import ( robot_action_observation_to_transition, robot_action_to_transition, @@ -30,6 +30,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import ( InverseKinematicsEEToJoints, ) from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.visualization_utils import init_rerun, log_rerun_data diff --git a/src/lerobot/async_inference/policy_server.py b/src/lerobot/async_inference/policy_server.py index aedce2a74..3f63929df 100644 --- a/src/lerobot/async_inference/policy_server.py +++ b/src/lerobot/async_inference/policy_server.py @@ -39,15 +39,13 @@ import grpc import torch from lerobot.policies.factory import get_policy_class, make_pre_post_processors -from lerobot.processor import ( - PolicyAction, - PolicyProcessorPipeline, -) +from lerobot.processor import PolicyProcessorPipeline from lerobot.transport import ( services_pb2, # type: ignore services_pb2_grpc, # type: ignore ) from lerobot.transport.utils import receive_bytes_in_chunks +from lerobot.types import PolicyAction from .configs import PolicyServerConfig from .constants import SUPPORTED_POLICIES diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py index 44b013c29..ce567b8f5 100644 --- a/src/lerobot/configs/policies.py +++ b/src/lerobot/configs/policies.py @@ -30,8 +30,8 @@ from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.optim.optimizers import OptimizerConfig from lerobot.optim.schedulers import LRSchedulerConfig from lerobot.utils.constants import ACTION, OBS_STATE +from lerobot.utils.device_utils import auto_select_torch_device, is_amp_available, is_torch_device_available from lerobot.utils.hub import HubMixin -from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available T = TypeVar("T", bound="PreTrainedConfig") logger = getLogger(__name__) diff --git a/src/lerobot/datasets/pipeline_features.py b/src/lerobot/datasets/pipeline_features.py index 161633f26..f824eb9bc 100644 --- a/src/lerobot/datasets/pipeline_features.py +++ b/src/lerobot/datasets/pipeline_features.py @@ -18,7 +18,8 @@ from typing import Any from lerobot.configs.types import PipelineFeatureType from lerobot.datasets.utils import hw_to_dataset_features -from lerobot.processor import DataProcessorPipeline, RobotAction, RobotObservation +from lerobot.processor import DataProcessorPipeline +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py index d20dae8ea..6d3589fed 100644 --- a/src/lerobot/envs/libero.py +++ b/src/lerobot/envs/libero.py @@ -29,7 +29,7 @@ from gymnasium import spaces from libero.libero import benchmark, get_libero_path from libero.libero.envs import OffScreenRenderEnv -from lerobot.processor import RobotObservation +from lerobot.types import RobotObservation def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]: diff --git a/src/lerobot/envs/metaworld.py b/src/lerobot/envs/metaworld.py index 4d91e002d..e9e29f304 100644 --- a/src/lerobot/envs/metaworld.py +++ b/src/lerobot/envs/metaworld.py @@ -25,7 +25,7 @@ import metaworld.policies as policies import numpy as np from gymnasium import spaces -from lerobot.processor import RobotObservation +from lerobot.types import RobotObservation # ---- Load configuration data from the external JSON file ---- CONFIG_PATH = Path(__file__).parent / "metaworld_config.json" diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index 09431a18d..fd17a6762 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -29,7 +29,7 @@ from torch import Tensor from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.envs.configs import EnvConfig -from lerobot.processor import RobotObservation +from lerobot.types import RobotObservation from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR from lerobot.utils.utils import get_channel_first_image_shape diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index d50d8652a..9515d5b82 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -43,13 +43,14 @@ from lerobot.policies.utils import validate_visual_features_consistency from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig from lerobot.policies.wall_x.configuration_wall_x import WallXConfig from lerobot.policies.xvla.configuration_xvla import XVLAConfig -from lerobot.processor import PolicyAction, PolicyProcessorPipeline +from lerobot.processor import PolicyProcessorPipeline from lerobot.processor.converters import ( batch_to_transition, policy_action_to_transition, transition_to_batch, transition_to_policy_action, ) +from lerobot.types import PolicyAction from lerobot.utils.constants import ( ACTION, POLICY_POSTPROCESSOR_DEFAULT_NAME, diff --git a/src/lerobot/policies/groot/processor_groot.py b/src/lerobot/policies/groot/processor_groot.py index 14149cf2f..8bf9dabca 100644 --- a/src/lerobot/policies/groot/processor_groot.py +++ b/src/lerobot/policies/groot/processor_groot.py @@ -49,7 +49,7 @@ from lerobot.processor.converters import ( policy_action_to_transition, transition_to_policy_action, ) -from lerobot.processor.core import EnvTransition, TransitionKey +from lerobot.types import EnvTransition, TransitionKey from lerobot.utils.constants import ( ACTION, HF_LEROBOT_HOME, diff --git a/src/lerobot/policies/pi05/processor_pi05.py b/src/lerobot/policies/pi05/processor_pi05.py index 6e01a4e16..425a85577 100644 --- a/src/lerobot/policies/pi05/processor_pi05.py +++ b/src/lerobot/policies/pi05/processor_pi05.py @@ -36,7 +36,7 @@ from lerobot.processor import ( UnnormalizerProcessorStep, ) from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action -from lerobot.processor.core import EnvTransition, TransitionKey +from lerobot.types import EnvTransition, TransitionKey from lerobot.utils.constants import ( OBS_STATE, POLICY_POSTPROCESSOR_DEFAULT_NAME, diff --git a/src/lerobot/policies/pi0_fast/processor_pi0_fast.py b/src/lerobot/policies/pi0_fast/processor_pi0_fast.py index fde7d5c80..46e54432a 100644 --- a/src/lerobot/policies/pi0_fast/processor_pi0_fast.py +++ b/src/lerobot/policies/pi0_fast/processor_pi0_fast.py @@ -37,7 +37,7 @@ from lerobot.processor import ( UnnormalizerProcessorStep, ) from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action -from lerobot.processor.core import EnvTransition, TransitionKey +from lerobot.types import EnvTransition, TransitionKey from lerobot.utils.constants import ( OBS_STATE, POLICY_POSTPROCESSOR_DEFAULT_NAME, diff --git a/src/lerobot/policies/sarm/processor_sarm.py b/src/lerobot/policies/sarm/processor_sarm.py index 8f2bc23db..f377a7ffa 100644 --- a/src/lerobot/policies/sarm/processor_sarm.py +++ b/src/lerobot/policies/sarm/processor_sarm.py @@ -48,8 +48,8 @@ from lerobot.processor.converters import ( policy_action_to_transition, transition_to_policy_action, ) -from lerobot.processor.core import EnvTransition, TransitionKey from lerobot.processor.pipeline import PipelineFeatureType +from lerobot.types import EnvTransition, TransitionKey from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 32165eba8..7110ba7d2 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -68,7 +68,7 @@ from lerobot.policies.utils import ( populate_queues, ) from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE -from lerobot.utils.utils import get_safe_dtype +from lerobot.utils.device_utils import get_safe_dtype class ActionSelectKwargs(TypedDict, total=False): diff --git a/src/lerobot/policies/utils.py b/src/lerobot/policies/utils.py index 1a14b2925..9ad5dac4a 100644 --- a/src/lerobot/policies/utils.py +++ b/src/lerobot/policies/utils.py @@ -24,7 +24,7 @@ from torch import nn from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.datasets.utils import build_dataset_frame -from lerobot.processor import PolicyAction, RobotAction, RobotObservation +from lerobot.types import PolicyAction, RobotAction, RobotObservation from lerobot.utils.constants import ACTION, OBS_STR diff --git a/src/lerobot/policies/xvla/processor_xvla.py b/src/lerobot/policies/xvla/processor_xvla.py index c4e3f2d6f..0fa9ffe3f 100644 --- a/src/lerobot/policies/xvla/processor_xvla.py +++ b/src/lerobot/policies/xvla/processor_xvla.py @@ -38,7 +38,7 @@ from lerobot.processor import ( UnnormalizerProcessorStep, ) from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action -from lerobot.processor.core import EnvTransition, TransitionKey +from lerobot.types import EnvTransition, TransitionKey from lerobot.utils.constants import ( OBS_IMAGES, OBS_PREFIX, diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 0b63e1606..12dcf0c6d 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -14,13 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .batch_processor import AddBatchDimensionProcessorStep -from .converters import ( - batch_to_transition, - create_transition, - transition_to_batch, -) -from .core import ( +from lerobot.types import ( EnvAction, EnvTransition, PolicyAction, @@ -28,6 +22,13 @@ from .core import ( RobotObservation, TransitionKey, ) + +from .batch_processor import AddBatchDimensionProcessorStep +from .converters import ( + batch_to_transition, + create_transition, + transition_to_batch, +) from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep from .device_processor import DeviceProcessorStep from .factory import ( diff --git a/src/lerobot/processor/batch_processor.py b/src/lerobot/processor/batch_processor.py index e1a90421f..c904acf84 100644 --- a/src/lerobot/processor/batch_processor.py +++ b/src/lerobot/processor/batch_processor.py @@ -25,9 +25,9 @@ from dataclasses import dataclass, field from torch import Tensor from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.types import EnvTransition, PolicyAction from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE -from .core import EnvTransition, PolicyAction from .pipeline import ( ComplementaryDataProcessorStep, ObservationProcessorStep, diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index 18c7b0220..ffdf0098c 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -23,10 +23,9 @@ from typing import Any import numpy as np import torch +from lerobot.types import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey from lerobot.utils.constants import ACTION, DONE, INFO, OBS_PREFIX, REWARD, TRUNCATED -from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey - @singledispatch def to_tensor( diff --git a/src/lerobot/processor/delta_action_processor.py b/src/lerobot/processor/delta_action_processor.py index a8395637c..f7f5676ac 100644 --- a/src/lerobot/processor/delta_action_processor.py +++ b/src/lerobot/processor/delta_action_processor.py @@ -17,8 +17,8 @@ from dataclasses import dataclass from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.types import PolicyAction, RobotAction -from .core import PolicyAction, RobotAction from .pipeline import ActionProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep diff --git a/src/lerobot/processor/device_processor.py b/src/lerobot/processor/device_processor.py index 2d0dd0880..36c80e58e 100644 --- a/src/lerobot/processor/device_processor.py +++ b/src/lerobot/processor/device_processor.py @@ -25,9 +25,9 @@ from typing import Any import torch from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.utils.utils import get_safe_torch_device +from lerobot.types import EnvTransition, PolicyAction, TransitionKey +from lerobot.utils.device_utils import get_safe_torch_device -from .core import EnvTransition, PolicyAction, TransitionKey from .pipeline import ProcessorStep, ProcessorStepRegistry diff --git a/src/lerobot/processor/factory.py b/src/lerobot/processor/factory.py index 5a0c41072..5028122f1 100644 --- a/src/lerobot/processor/factory.py +++ b/src/lerobot/processor/factory.py @@ -14,13 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lerobot.types import RobotAction, RobotObservation + from .converters import ( observation_to_transition, robot_action_observation_to_transition, transition_to_observation, transition_to_robot_action, ) -from .core import RobotAction, RobotObservation from .pipeline import IdentityProcessorStep, RobotProcessorPipeline diff --git a/src/lerobot/processor/gym_action_processor.py b/src/lerobot/processor/gym_action_processor.py index 4f225af92..e756ded7f 100644 --- a/src/lerobot/processor/gym_action_processor.py +++ b/src/lerobot/processor/gym_action_processor.py @@ -17,9 +17,9 @@ from dataclasses import dataclass from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.types import EnvAction, EnvTransition, PolicyAction from .converters import to_tensor -from .core import EnvAction, EnvTransition, PolicyAction from .hil_processor import TELEOP_ACTION_KEY from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry @@ -75,7 +75,7 @@ class Numpy2TorchActionProcessorStep(ProcessorStep): def __call__(self, transition: EnvTransition) -> EnvTransition: """Converts numpy action to torch tensor if action exists, otherwise passes through.""" - from .core import TransitionKey + from lerobot.types import TransitionKey self._current_transition = transition.copy() new_transition = self._current_transition diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py index 34eaeed51..0b8521c2b 100644 --- a/src/lerobot/processor/hil_processor.py +++ b/src/lerobot/processor/hil_processor.py @@ -30,7 +30,8 @@ from lerobot.teleoperators.utils import TeleopEvents if TYPE_CHECKING: from lerobot.teleoperators.teleoperator import Teleoperator -from .core import EnvTransition, PolicyAction, TransitionKey +from lerobot.types import EnvTransition, PolicyAction, TransitionKey + from .pipeline import ( ComplementaryDataProcessorStep, InfoProcessorStep, diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 4769b91ac..8a7a1176a 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -26,10 +26,10 @@ from torch import Tensor from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.types import EnvTransition, PolicyAction, TransitionKey from lerobot.utils.constants import ACTION from .converters import from_tensor_to_numpy, to_tensor -from .core import EnvTransition, PolicyAction, TransitionKey from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry, RobotObservation diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index db1c3015c..abfb31421 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -46,10 +46,10 @@ from huggingface_hub import hf_hub_download from safetensors.torch import load_file, save_file from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.types import EnvAction, EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey from lerobot.utils.hub import HubMixin from .converters import batch_to_transition, create_transition, transition_to_batch -from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey # Generic type variables for pipeline input and output. TInput = TypeVar("TInput") diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index da6e600af..2a972ecc8 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -30,6 +30,7 @@ from typing import TYPE_CHECKING, Any import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.types import EnvTransition, RobotObservation, TransitionKey from lerobot.utils.constants import ( ACTION_TOKEN_MASK, ACTION_TOKENS, @@ -40,7 +41,6 @@ from lerobot.utils.constants import ( ) from lerobot.utils.import_utils import _transformers_available -from .core import EnvTransition, RobotObservation, TransitionKey from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry # Conditional import for type checking and lazy loading diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py index 7427633d2..18c0ca1ea 100644 --- a/src/lerobot/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -62,7 +62,6 @@ from lerobot.configs import parser from lerobot.configs.train import TrainRLServerPipelineConfig from lerobot.policies.factory import make_policy from lerobot.policies.sac.modeling_sac import SACPolicy -from lerobot.processor import TransitionKey from lerobot.rl.process import ProcessSignalHandler from lerobot.rl.queue import get_last_item_from_queue from lerobot.robots import so_follower # noqa: F401 @@ -77,6 +76,8 @@ from lerobot.transport.utils import ( send_bytes_in_chunks, transitions_to_bytes, ) +from lerobot.types import TransitionKey +from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.random_utils import set_seed from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.transition import ( @@ -86,7 +87,6 @@ from lerobot.utils.transition import ( ) from lerobot.utils.utils import ( TimerManager, - get_safe_torch_device, init_logging, ) diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index ee09ac9ac..2853fbcb3 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -86,6 +86,7 @@ from lerobot.utils.constants import ( PRETRAINED_MODEL_DIR, TRAINING_STATE_DIR, ) +from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.random_utils import set_seed from lerobot.utils.train_utils import ( get_step_checkpoint_dir, @@ -96,7 +97,6 @@ from lerobot.utils.train_utils import ( from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device from lerobot.utils.utils import ( format_big_number, - get_safe_torch_device, init_logging, ) diff --git a/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py b/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py index 2e3885e67..7f5e92271 100644 --- a/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py +++ b/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py @@ -17,8 +17,8 @@ import logging from functools import cached_property -from lerobot.processor import RobotAction, RobotObservation from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git a/src/lerobot/robots/bi_so_follower/bi_so_follower.py b/src/lerobot/robots/bi_so_follower/bi_so_follower.py index 28c58b898..ba1826e29 100644 --- a/src/lerobot/robots/bi_so_follower/bi_so_follower.py +++ b/src/lerobot/robots/bi_so_follower/bi_so_follower.py @@ -17,8 +17,8 @@ import logging from functools import cached_property -from lerobot.processor import RobotAction, RobotObservation from lerobot.robots.so_follower import SOFollower, SOFollowerRobotConfig +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git a/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py b/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py index cdf6efde1..299206a1e 100644 --- a/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py +++ b/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py @@ -23,7 +23,7 @@ import cv2 import numpy as np import requests -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.errors import DeviceNotConnectedError diff --git a/src/lerobot/robots/hope_jr/hope_jr_arm.py b/src/lerobot/robots/hope_jr/hope_jr_arm.py index e8269ae46..7f6492ef0 100644 --- a/src/lerobot/robots/hope_jr/hope_jr_arm.py +++ b/src/lerobot/robots/hope_jr/hope_jr_arm.py @@ -24,7 +24,7 @@ from lerobot.motors.calibration_gui import RangeFinderGUI from lerobot.motors.feetech import ( FeetechMotorsBus, ) -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git a/src/lerobot/robots/hope_jr/hope_jr_hand.py b/src/lerobot/robots/hope_jr/hope_jr_hand.py index a05c4bbcb..784804836 100644 --- a/src/lerobot/robots/hope_jr/hope_jr_hand.py +++ b/src/lerobot/robots/hope_jr/hope_jr_hand.py @@ -24,7 +24,7 @@ from lerobot.motors.calibration_gui import RangeFinderGUI from lerobot.motors.feetech import ( FeetechMotorsBus, ) -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git a/src/lerobot/robots/koch_follower/koch_follower.py b/src/lerobot/robots/koch_follower/koch_follower.py index 53a32beed..44e83f6a3 100644 --- a/src/lerobot/robots/koch_follower/koch_follower.py +++ b/src/lerobot/robots/koch_follower/koch_follower.py @@ -24,7 +24,7 @@ from lerobot.motors.dynamixel import ( DynamixelMotorsBus, OperatingMode, ) -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git a/src/lerobot/robots/lekiwi/lekiwi.py b/src/lerobot/robots/lekiwi/lekiwi.py index 9d11a000f..60fac89e5 100644 --- a/src/lerobot/robots/lekiwi/lekiwi.py +++ b/src/lerobot/robots/lekiwi/lekiwi.py @@ -28,7 +28,7 @@ from lerobot.motors.feetech import ( FeetechMotorsBus, OperatingMode, ) -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git a/src/lerobot/robots/lekiwi/lekiwi_client.py b/src/lerobot/robots/lekiwi/lekiwi_client.py index 1d5ea64a6..fd43e84fe 100644 --- a/src/lerobot/robots/lekiwi/lekiwi_client.py +++ b/src/lerobot/robots/lekiwi/lekiwi_client.py @@ -22,7 +22,7 @@ from functools import cached_property import cv2 import numpy as np -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.constants import ACTION, OBS_STATE from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.errors import DeviceNotConnectedError diff --git a/src/lerobot/robots/omx_follower/omx_follower.py b/src/lerobot/robots/omx_follower/omx_follower.py index e0b612c60..5d161daa2 100644 --- a/src/lerobot/robots/omx_follower/omx_follower.py +++ b/src/lerobot/robots/omx_follower/omx_follower.py @@ -25,7 +25,7 @@ from lerobot.motors.dynamixel import ( DynamixelMotorsBus, OperatingMode, ) -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git a/src/lerobot/robots/openarm_follower/openarm_follower.py b/src/lerobot/robots/openarm_follower/openarm_follower.py index c865f1ec1..99e8b920b 100644 --- a/src/lerobot/robots/openarm_follower/openarm_follower.py +++ b/src/lerobot/robots/openarm_follower/openarm_follower.py @@ -22,7 +22,7 @@ from typing import Any from lerobot.cameras.utils import make_cameras_from_configs from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.damiao import DamiaoMotorsBus -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git a/src/lerobot/robots/reachy2/robot_reachy2.py b/src/lerobot/robots/reachy2/robot_reachy2.py index fb466f85b..5227a096a 100644 --- a/src/lerobot/robots/reachy2/robot_reachy2.py +++ b/src/lerobot/robots/reachy2/robot_reachy2.py @@ -19,7 +19,7 @@ import time from typing import TYPE_CHECKING, Any from lerobot.cameras.utils import make_cameras_from_configs -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.import_utils import _reachy2_sdk_available from ..robot import Robot diff --git a/src/lerobot/robots/robot.py b/src/lerobot/robots/robot.py index d165886b9..1b556f963 100644 --- a/src/lerobot/robots/robot.py +++ b/src/lerobot/robots/robot.py @@ -19,7 +19,7 @@ from pathlib import Path import draccus from lerobot.motors import MotorCalibration -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, ROBOTS from .config import RobotConfig diff --git a/src/lerobot/robots/so_follower/so_follower.py b/src/lerobot/robots/so_follower/so_follower.py index c898e9137..ca132d102 100644 --- a/src/lerobot/robots/so_follower/so_follower.py +++ b/src/lerobot/robots/so_follower/so_follower.py @@ -24,7 +24,7 @@ from lerobot.motors.feetech import ( FeetechMotorsBus, OperatingMode, ) -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git a/src/lerobot/robots/unitree_g1/unitree_g1.py b/src/lerobot/robots/unitree_g1/unitree_g1.py index 41146ebe6..9e373c05f 100644 --- a/src/lerobot/robots/unitree_g1/unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/unitree_g1.py @@ -26,8 +26,6 @@ from typing import TYPE_CHECKING, Protocol, runtime_checkable import numpy as np from lerobot.cameras.utils import make_cameras_from_configs -from lerobot.envs.factory import make_env -from lerobot.processor import RobotAction, RobotObservation from lerobot.robots.unitree_g1.g1_kinematics import G1_29_ArmIK from lerobot.robots.unitree_g1.g1_utils import ( REMOTE_AXES, @@ -37,6 +35,7 @@ from lerobot.robots.unitree_g1.g1_utils import ( default_remote_input, make_locomotion_controller, ) +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.import_utils import _unitree_sdk_available from ..robot import Robot @@ -291,6 +290,8 @@ class UnitreeG1(Robot): def connect(self, calibrate: bool = True) -> None: # connect to DDS # Initialize DDS channel and simulation environment if self.config.is_simulation: + from lerobot.envs.factory import make_env + self._ChannelFactoryInitialize(0, "lo") self._env_wrapper = make_env("lerobot/unitree-g1-mujoco", trust_remote_code=True) # Extract the actual gym env from the dict structure diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index e32b80404..6d814f498 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -80,13 +80,14 @@ from lerobot.envs.utils import ( ) from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.processor import PolicyAction, PolicyProcessorPipeline +from lerobot.processor import PolicyProcessorPipeline +from lerobot.types import PolicyAction from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD +from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.io_utils import write_video from lerobot.utils.random_utils import set_seed from lerobot.utils.utils import ( - get_safe_torch_device, init_logging, inside_slurm, ) diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index dc682fe6f..345d18f23 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -139,10 +139,10 @@ from lerobot.utils.control_utils import ( sanity_check_dataset_name, sanity_check_dataset_robot_compatibility, ) +from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import ( - get_safe_torch_device, init_logging, log_say, ) diff --git a/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py b/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py index 74b0c9b83..b44f1fbea 100644 --- a/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py +++ b/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py @@ -17,8 +17,8 @@ import logging from functools import cached_property -from lerobot.processor import RobotAction from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfig +from lerobot.types import RobotAction from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..openarm_leader import OpenArmLeader diff --git a/src/lerobot/teleoperators/gamepad/teleop_gamepad.py b/src/lerobot/teleoperators/gamepad/teleop_gamepad.py index 69cb0f971..8c1796e45 100644 --- a/src/lerobot/teleoperators/gamepad/teleop_gamepad.py +++ b/src/lerobot/teleoperators/gamepad/teleop_gamepad.py @@ -20,7 +20,7 @@ from typing import Any import numpy as np -from lerobot.processor import RobotAction +from lerobot.types import RobotAction from lerobot.utils.decorators import check_if_not_connected from ..teleoperator import Teleoperator diff --git a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py index 919f463d3..6c1ef7492 100644 --- a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py +++ b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py @@ -21,7 +21,7 @@ import time from queue import Queue from typing import Any -from lerobot.processor import RobotAction +from lerobot.types import RobotAction from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..teleoperator import Teleoperator diff --git a/src/lerobot/teleoperators/openarm_leader/openarm_leader.py b/src/lerobot/teleoperators/openarm_leader/openarm_leader.py index d9eaabe0f..65da7416a 100644 --- a/src/lerobot/teleoperators/openarm_leader/openarm_leader.py +++ b/src/lerobot/teleoperators/openarm_leader/openarm_leader.py @@ -20,7 +20,7 @@ from typing import Any from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.damiao import DamiaoMotorsBus -from lerobot.processor import RobotAction +from lerobot.types import RobotAction from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..teleoperator import Teleoperator diff --git a/src/lerobot/teleoperators/openarm_mini/openarm_mini.py b/src/lerobot/teleoperators/openarm_mini/openarm_mini.py index 3fbcecf24..23594caa9 100644 --- a/src/lerobot/teleoperators/openarm_mini/openarm_mini.py +++ b/src/lerobot/teleoperators/openarm_mini/openarm_mini.py @@ -23,7 +23,7 @@ from lerobot.motors.feetech import ( FeetechMotorsBus, OperatingMode, ) -from lerobot.processor import RobotAction +from lerobot.types import RobotAction from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..teleoperator import Teleoperator diff --git a/src/lerobot/teleoperators/phone/phone_processor.py b/src/lerobot/teleoperators/phone/phone_processor.py index 67e64c7d5..c498bed7d 100644 --- a/src/lerobot/teleoperators/phone/phone_processor.py +++ b/src/lerobot/teleoperators/phone/phone_processor.py @@ -17,8 +17,9 @@ from dataclasses import dataclass, field from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature -from lerobot.processor import ProcessorStepRegistry, RobotAction, RobotActionProcessorStep +from lerobot.processor import ProcessorStepRegistry, RobotActionProcessorStep from lerobot.teleoperators.phone.config_phone import PhoneOS +from lerobot.types import RobotAction @ProcessorStepRegistry.register("map_phone_action_to_robot_action") diff --git a/src/lerobot/teleoperators/teleoperator.py b/src/lerobot/teleoperators/teleoperator.py index 847b88b7f..f47904423 100644 --- a/src/lerobot/teleoperators/teleoperator.py +++ b/src/lerobot/teleoperators/teleoperator.py @@ -20,7 +20,7 @@ from typing import Any import draccus from lerobot.motors.motors_bus import MotorCalibration -from lerobot.processor import RobotAction +from lerobot.types import RobotAction from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, TELEOPERATORS from .config import TeleoperatorConfig diff --git a/src/lerobot/processor/core.py b/src/lerobot/types.py similarity index 100% rename from src/lerobot/processor/core.py rename to src/lerobot/types.py diff --git a/src/lerobot/utils/control_utils.py b/src/lerobot/utils/control_utils.py index 7c605af17..94cd82fa1 100644 --- a/src/lerobot/utils/control_utils.py +++ b/src/lerobot/utils/control_utils.py @@ -32,8 +32,9 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import DEFAULT_FEATURES from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import prepare_observation_for_inference -from lerobot.processor import PolicyAction, PolicyProcessorPipeline +from lerobot.processor import PolicyProcessorPipeline from lerobot.robots import Robot +from lerobot.types import PolicyAction @cache diff --git a/src/lerobot/utils/device_utils.py b/src/lerobot/utils/device_utils.py new file mode 100644 index 000000000..37981f07f --- /dev/null +++ b/src/lerobot/utils/device_utils.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import torch + + +def auto_select_torch_device() -> torch.device: + """Tries to select automatically a torch device.""" + if torch.cuda.is_available(): + logging.info("Cuda backend detected, using cuda.") + return torch.device("cuda") + elif torch.backends.mps.is_available(): + logging.info("Metal backend detected, using mps.") + return torch.device("mps") + elif torch.xpu.is_available(): + logging.info("Intel XPU backend detected, using xpu.") + return torch.device("xpu") + else: + logging.warning("No accelerated backend detected. Using default cpu, this will be slow.") + return torch.device("cpu") + + +# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level +def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device: + """Given a string, return a torch.device with checks on whether the device is available.""" + try_device = str(try_device) + if try_device.startswith("cuda"): + assert torch.cuda.is_available() + device = torch.device(try_device) + elif try_device == "mps": + assert torch.backends.mps.is_available() + device = torch.device("mps") + elif try_device == "xpu": + assert torch.xpu.is_available() + device = torch.device("xpu") + elif try_device == "cpu": + device = torch.device("cpu") + if log: + logging.warning("Using CPU, this will be slow.") + else: + device = torch.device(try_device) + if log: + logging.warning(f"Using custom {try_device} device.") + return device + + +def get_safe_dtype(dtype: torch.dtype, device: str | torch.device): + """ + mps is currently not compatible with float64 + """ + if isinstance(device, torch.device): + device = device.type + if device == "mps" and dtype == torch.float64: + return torch.float32 + if device == "xpu" and dtype == torch.float64: + if hasattr(torch.xpu, "get_device_capability"): + device_capability = torch.xpu.get_device_capability() + # NOTE: Some Intel XPU devices do not support double precision (FP64). + # The `has_fp64` flag is returned by `torch.xpu.get_device_capability()` + # when available; if False, we fall back to float32 for compatibility. + if not device_capability.get("has_fp64", False): + logging.warning(f"Device {device} does not support float64, using float32 instead.") + return torch.float32 + else: + logging.warning( + f"Device {device} capability check failed. Assuming no support for float64, using float32 instead." + ) + return torch.float32 + return dtype + else: + return dtype + + +def is_torch_device_available(try_device: str) -> bool: + try_device = str(try_device) # Ensure try_device is a string + if try_device.startswith("cuda"): + return torch.cuda.is_available() + elif try_device == "mps": + return torch.backends.mps.is_available() + elif try_device == "xpu": + return torch.xpu.is_available() + elif try_device == "cpu": + return True + else: + raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps, xpu or cpu.") + + +def is_amp_available(device: str): + if device in ["cuda", "xpu", "cpu"]: + return True + elif device == "mps": + return False + else: + raise ValueError(f"Unknown device '{device}.") diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index c7ad2bbdb..b9f8441d6 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -13,6 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging import os import platform @@ -24,11 +26,12 @@ from copy import copy, deepcopy from datetime import datetime from pathlib import Path from statistics import mean +from typing import TYPE_CHECKING import numpy as np -import torch -from accelerate import Accelerator -from datasets.utils.logging import disable_progress_bar, enable_progress_bar + +if TYPE_CHECKING: + from accelerate import Accelerator def inside_slurm(): @@ -37,96 +40,6 @@ def inside_slurm(): return "SLURM_JOB_ID" in os.environ -def auto_select_torch_device() -> torch.device: - """Tries to select automatically a torch device.""" - if torch.cuda.is_available(): - logging.info("Cuda backend detected, using cuda.") - return torch.device("cuda") - elif torch.backends.mps.is_available(): - logging.info("Metal backend detected, using mps.") - return torch.device("mps") - elif torch.xpu.is_available(): - logging.info("Intel XPU backend detected, using xpu.") - return torch.device("xpu") - else: - logging.warning("No accelerated backend detected. Using default cpu, this will be slow.") - return torch.device("cpu") - - -# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level -def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device: - """Given a string, return a torch.device with checks on whether the device is available.""" - try_device = str(try_device) - if try_device.startswith("cuda"): - assert torch.cuda.is_available() - device = torch.device(try_device) - elif try_device == "mps": - assert torch.backends.mps.is_available() - device = torch.device("mps") - elif try_device == "xpu": - assert torch.xpu.is_available() - device = torch.device("xpu") - elif try_device == "cpu": - device = torch.device("cpu") - if log: - logging.warning("Using CPU, this will be slow.") - else: - device = torch.device(try_device) - if log: - logging.warning(f"Using custom {try_device} device.") - return device - - -def get_safe_dtype(dtype: torch.dtype, device: str | torch.device): - """ - mps is currently not compatible with float64 - """ - if isinstance(device, torch.device): - device = device.type - if device == "mps" and dtype == torch.float64: - return torch.float32 - if device == "xpu" and dtype == torch.float64: - if hasattr(torch.xpu, "get_device_capability"): - device_capability = torch.xpu.get_device_capability() - # NOTE: Some Intel XPU devices do not support double precision (FP64). - # The `has_fp64` flag is returned by `torch.xpu.get_device_capability()` - # when available; if False, we fall back to float32 for compatibility. - if not device_capability.get("has_fp64", False): - logging.warning(f"Device {device} does not support float64, using float32 instead.") - return torch.float32 - else: - logging.warning( - f"Device {device} capability check failed. Assuming no support for float64, using float32 instead." - ) - return torch.float32 - return dtype - else: - return dtype - - -def is_torch_device_available(try_device: str) -> bool: - try_device = str(try_device) # Ensure try_device is a string - if try_device.startswith("cuda"): - return torch.cuda.is_available() - elif try_device == "mps": - return torch.backends.mps.is_available() - elif try_device == "xpu": - return torch.xpu.is_available() - elif try_device == "cpu": - return True - else: - raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps, xpu or cpu.") - - -def is_amp_available(device: str): - if device in ["cuda", "xpu", "cpu"]: - return True - elif device == "mps": - return False - else: - raise ValueError(f"Unknown device '{device}.") - - def init_logging( log_file: Path | None = None, display_pid: bool = False, @@ -297,9 +210,13 @@ class SuppressProgressBars: """ def __enter__(self): + from datasets.utils.logging import disable_progress_bar + disable_progress_bar() def __exit__(self, exc_type, exc_val, exc_tb): + from datasets.utils.logging import enable_progress_bar + enable_progress_bar() diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index 31ca8d247..782358c9e 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -18,7 +18,7 @@ import os import numpy as np import rerun as rr -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR diff --git a/tests/mocks/mock_robot.py b/tests/mocks/mock_robot.py index f69a2c02a..5504b30bf 100644 --- a/tests/mocks/mock_robot.py +++ b/tests/mocks/mock_robot.py @@ -20,8 +20,8 @@ from functools import cached_property from lerobot.cameras import CameraConfig, make_cameras_from_configs from lerobot.motors.motors_bus import Motor, MotorNormMode -from lerobot.processor import RobotAction, RobotObservation from lerobot.robots import Robot, RobotConfig +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from tests.mocks.mock_motors_bus import MockMotorsBus diff --git a/tests/mocks/mock_teleop.py b/tests/mocks/mock_teleop.py index 89174dadf..b84b2b891 100644 --- a/tests/mocks/mock_teleop.py +++ b/tests/mocks/mock_teleop.py @@ -19,8 +19,8 @@ from dataclasses import dataclass from functools import cached_property from typing import Any -from lerobot.processor import RobotAction from lerobot.teleoperators import Teleoperator, TeleoperatorConfig +from lerobot.types import RobotAction from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected diff --git a/tests/policies/groot/test_groot_lerobot.py b/tests/policies/groot/test_groot_lerobot.py index 760f13a5f..e299a34e2 100644 --- a/tests/policies/groot/test_groot_lerobot.py +++ b/tests/policies/groot/test_groot_lerobot.py @@ -28,8 +28,9 @@ import torch from lerobot.policies.groot.configuration_groot import GrootConfig from lerobot.policies.groot.modeling_groot import GrootPolicy from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors -from lerobot.processor import PolicyAction, PolicyProcessorPipeline -from lerobot.utils.utils import auto_select_torch_device +from lerobot.processor import PolicyProcessorPipeline +from lerobot.types import PolicyAction +from lerobot.utils.device_utils import auto_select_torch_device from tests.utils import require_cuda # noqa: E402 pytest.importorskip("transformers") diff --git a/tests/policies/groot/test_groot_vs_original.py b/tests/policies/groot/test_groot_vs_original.py index e9dd1df00..0adad96ca 100644 --- a/tests/policies/groot/test_groot_vs_original.py +++ b/tests/policies/groot/test_groot_vs_original.py @@ -28,7 +28,8 @@ import torch from lerobot.policies.groot.configuration_groot import GrootConfig from lerobot.policies.groot.modeling_groot import GrootPolicy from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors -from lerobot.processor import PolicyAction, PolicyProcessorPipeline +from lerobot.processor import PolicyProcessorPipeline +from lerobot.types import PolicyAction pytest.importorskip("gr00t") pytest.importorskip("transformers") diff --git a/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py b/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py index d24bb11d7..b757d5a94 100644 --- a/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py +++ b/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py @@ -31,7 +31,8 @@ pytest.importorskip("scipy") from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig from lerobot.policies.pi0_fast.modeling_pi0_fast import PI0FastPolicy from lerobot.policies.pi0_fast.processor_pi0_fast import make_pi0_fast_pre_post_processors -from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402 +from lerobot.processor import PolicyProcessorPipeline # noqa: E402 +from lerobot.types import PolicyAction # noqa: E402 from lerobot.utils.constants import ( ACTION_TOKEN_MASK, ACTION_TOKENS, diff --git a/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py b/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py index f70707262..a965132b0 100644 --- a/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py +++ b/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py @@ -42,7 +42,8 @@ from transformers import AutoTokenizer # noqa: E402 from lerobot.policies.pi05 import PI05Config, PI05Policy # noqa: E402 from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402 -from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402 +from lerobot.processor import PolicyProcessorPipeline # noqa: E402 +from lerobot.types import PolicyAction # noqa: E402 # TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG DUMMY_ACTION_DIM = 32 diff --git a/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py b/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py index d3d1c1908..62e34b70d 100644 --- a/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py +++ b/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py @@ -41,7 +41,8 @@ from transformers import AutoTokenizer # noqa: E402 from lerobot.policies.pi0 import PI0Config, PI0Policy # noqa: E402 from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402 -from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402 +from lerobot.processor import PolicyProcessorPipeline # noqa: E402 +from lerobot.types import PolicyAction # noqa: E402 # TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG DUMMY_ACTION_DIM = 32 diff --git a/tests/policies/test_sarm_processor.py b/tests/policies/test_sarm_processor.py index 66404f663..5b90784a6 100644 --- a/tests/policies/test_sarm_processor.py +++ b/tests/policies/test_sarm_processor.py @@ -25,7 +25,7 @@ import pandas as pd import pytest import torch -from lerobot.processor.core import TransitionKey +from lerobot.types import TransitionKey class MockDatasetMeta: diff --git a/tests/policies/xvla/test_xvla_original_vs_lerobot.py b/tests/policies/xvla/test_xvla_original_vs_lerobot.py index e36d14d01..3cea11329 100644 --- a/tests/policies/xvla/test_xvla_original_vs_lerobot.py +++ b/tests/policies/xvla/test_xvla_original_vs_lerobot.py @@ -30,7 +30,8 @@ pytest.importorskip("transformers") from lerobot.policies.xvla.configuration_xvla import XVLAConfig from lerobot.policies.xvla.modeling_xvla import XVLAPolicy from lerobot.policies.xvla.processor_xvla import make_xvla_pre_post_processors -from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402 +from lerobot.processor import PolicyProcessorPipeline # noqa: E402 +from lerobot.types import PolicyAction # noqa: E402 from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402 from tests.utils import require_cuda # noqa: E402 diff --git a/tests/processor/test_batch_conversion.py b/tests/processor/test_batch_conversion.py index 477381618..d589b6c5e 100644 --- a/tests/processor/test_batch_conversion.py +++ b/tests/processor/test_batch_conversion.py @@ -16,8 +16,9 @@ import torch -from lerobot.processor import DataProcessorPipeline, TransitionKey +from lerobot.processor import DataProcessorPipeline from lerobot.processor.converters import batch_to_transition, transition_to_batch +from lerobot.types import TransitionKey from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_PREFIX, OBS_STATE, REWARD, TRUNCATED diff --git a/tests/processor/test_converters.py b/tests/processor/test_converters.py index 47a6eea18..91afdd0e5 100644 --- a/tests/processor/test_converters.py +++ b/tests/processor/test_converters.py @@ -18,13 +18,13 @@ import numpy as np import pytest import torch -from lerobot.processor import TransitionKey from lerobot.processor.converters import ( batch_to_transition, create_transition, to_tensor, transition_to_batch, ) +from lerobot.types import TransitionKey from lerobot.utils.constants import ACTION, DONE, OBS_STATE, OBS_STR, REWARD diff --git a/tests/processor/test_device_processor.py b/tests/processor/test_device_processor.py index bb7d467bf..57b923076 100644 --- a/tests/processor/test_device_processor.py +++ b/tests/processor/test_device_processor.py @@ -19,8 +19,9 @@ import pytest import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature -from lerobot.processor import DataProcessorPipeline, DeviceProcessorStep, TransitionKey +from lerobot.processor import DataProcessorPipeline, DeviceProcessorStep from lerobot.processor.converters import create_transition, identity_transition +from lerobot.types import TransitionKey from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index 208a6b5c5..cd5c75005 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -30,7 +30,7 @@ from lerobot.processor import ( ) from lerobot.processor.converters import create_transition, identity_transition, to_tensor from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE, OBS_STR -from lerobot.utils.utils import auto_select_torch_device +from lerobot.utils.device_utils import auto_select_torch_device def test_numpy_conversion(): diff --git a/tests/processor/test_observation_processor.py b/tests/processor/test_observation_processor.py index 11b58a66c..923059210 100644 --- a/tests/processor/test_observation_processor.py +++ b/tests/processor/test_observation_processor.py @@ -19,8 +19,9 @@ import pytest import torch from lerobot.configs.types import FeatureType, PipelineFeatureType -from lerobot.processor import TransitionKey, VanillaObservationProcessorStep +from lerobot.processor import VanillaObservationProcessorStep from lerobot.processor.converters import create_transition +from lerobot.types import TransitionKey from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from tests.conftest import assert_contract_is_typed diff --git a/tests/processor/test_tokenizer_processor.py b/tests/processor/test_tokenizer_processor.py index 64cc8aac8..2f1c4cc9c 100644 --- a/tests/processor/test_tokenizer_processor.py +++ b/tests/processor/test_tokenizer_processor.py @@ -25,8 +25,9 @@ import pytest import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature -from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey +from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep from lerobot.processor.converters import create_transition, identity_transition +from lerobot.types import TransitionKey from lerobot.utils.constants import ( ACTION, OBS_IMAGE, diff --git a/tests/training/test_visual_validation.py b/tests/training/test_visual_validation.py index af693fe5e..89351e3c2 100644 --- a/tests/training/test_visual_validation.py +++ b/tests/training/test_visual_validation.py @@ -37,7 +37,7 @@ from lerobot.configs.train import TrainPipelineConfig from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.policies.factory import make_policy_config from lerobot.scripts.lerobot_train import train -from lerobot.utils.utils import auto_select_torch_device +from lerobot.utils.device_utils import auto_select_torch_device pytest.importorskip("transformers") diff --git a/tests/utils.py b/tests/utils.py index a77082ea9..33c554804 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -21,8 +21,8 @@ import pytest import torch from lerobot import available_cameras, available_motors, available_robots +from lerobot.utils.device_utils import auto_select_torch_device from lerobot.utils.import_utils import is_package_available -from lerobot.utils.utils import auto_select_torch_device DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", str(auto_select_torch_device())) diff --git a/tests/utils/test_visualization_utils.py b/tests/utils/test_visualization_utils.py index 408f636cb..c8e5a92a8 100644 --- a/tests/utils/test_visualization_utils.py +++ b/tests/utils/test_visualization_utils.py @@ -21,7 +21,7 @@ from types import SimpleNamespace import numpy as np import pytest -from lerobot.processor import TransitionKey +from lerobot.types import TransitionKey from lerobot.utils.constants import OBS_STATE From 7c2ec31793da193e03853699cb5040db6ce1caa5 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Sun, 15 Mar 2026 20:42:15 -0700 Subject: [PATCH 114/131] refactor(datasets): module cleanup (#3169) --- .../datasets/backward_compatibility.py | 2 +- src/lerobot/datasets/lerobot_dataset.py | 3 +- src/lerobot/datasets/online_buffer.py | 382 ------------------ .../datasets/push_dataset_to_hub/utils.py | 73 ---- .../augment_dataset_quantile_stats.py | 2 +- .../convert_dataset_v21_to_v30.py | 4 +- tests/datasets/test_dataset_utils.py | 17 +- tests/datasets/test_online_buffer.py | 282 ------------- tests/datasets/test_sampler.py | 18 +- 9 files changed, 38 insertions(+), 745 deletions(-) delete mode 100644 src/lerobot/datasets/online_buffer.py delete mode 100644 src/lerobot/datasets/push_dataset_to_hub/utils.py rename src/lerobot/{datasets/v30 => scripts}/augment_dataset_quantile_stats.py (99%) rename src/lerobot/{datasets/v30 => scripts}/convert_dataset_v21_to_v30.py (99%) delete mode 100644 tests/datasets/test_online_buffer.py diff --git a/src/lerobot/datasets/backward_compatibility.py b/src/lerobot/datasets/backward_compatibility.py index ae95c5f7b..aefbfd55b 100644 --- a/src/lerobot/datasets/backward_compatibility.py +++ b/src/lerobot/datasets/backward_compatibility.py @@ -20,7 +20,7 @@ The dataset you requested ({repo_id}) is in {version} format. We introduced a new format since v3.0 which is not backward compatible with v2.1. Please, update your dataset to the new format using this command: ``` -python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id={repo_id} +python -m lerobot.scripts.convert_dataset_v21_to_v30 --repo-id={repo_id} ``` If you already have a converted version uploaded to the hub, then this error might be because of diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 26f0c769c..11c10f493 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -596,7 +596,7 @@ class LeRobotDataset(torch.utils.data.Dataset): the dataset from that address and load it, pending your dataset is compliant with codebase_version v3.0. If your dataset has been created before this new format, you will be prompted to convert it using our conversion script from v2.1 to v3.0, which you can find at - lerobot/datasets/v30/convert_dataset_v21_to_v30.py. + lerobot/scripts/convert_dataset_v21_to_v30.py. 2. Your dataset doesn't already exists (either on local disk or on the Hub): you can create an empty @@ -1683,7 +1683,6 @@ class LeRobotDataset(torch.utils.data.Dataset): if image_writer_processes or image_writer_threads: obj.start_image_writer(image_writer_processes, image_writer_threads) - # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer obj.episode_buffer = obj.create_episode_buffer() obj.episodes = None diff --git a/src/lerobot/datasets/online_buffer.py b/src/lerobot/datasets/online_buffer.py deleted file mode 100644 index 563d800b9..000000000 --- a/src/lerobot/datasets/online_buffer.py +++ /dev/null @@ -1,382 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""An online buffer for the online training loop in train.py - -Note to maintainers: This duplicates some logic from LeRobotDataset and EpisodeAwareSampler. We should -consider converging to one approach. Here we have opted to use numpy.memmap to back the data buffer. It's much -faster than using HuggingFace Datasets as there's no conversion to an intermediate non-python object. Also it -supports in-place slicing and mutation which is very handy for a dynamic buffer. -""" - -import os -from pathlib import Path -from typing import Any - -import numpy as np -import torch - -from lerobot.datasets.lerobot_dataset import LeRobotDataset - - -def _make_memmap_safe(**kwargs) -> np.memmap: - """Make a numpy memmap with checks on available disk space first. - - Expected kwargs are: "filename", "dtype" (must by np.dtype), "mode" and "shape" - - For information on dtypes: - https://numpy.org/doc/stable/reference/arrays.dtypes.html#arrays-dtypes-constructing - """ - if kwargs["mode"].startswith("w"): - required_space = kwargs["dtype"].itemsize * np.prod(kwargs["shape"]) # bytes - stats = os.statvfs(Path(kwargs["filename"]).parent) - available_space = stats.f_bavail * stats.f_frsize # bytes - if required_space >= available_space * 0.8: - raise RuntimeError( - f"You're about to take up {required_space} of {available_space} bytes available." - ) - return np.memmap(**kwargs) - - -class OnlineBuffer(torch.utils.data.Dataset): - """FIFO data buffer for the online training loop in train.py. - - Follows the protocol of LeRobotDataset as much as is required to have it be used by the online training - loop in the same way that a LeRobotDataset would be used. - - The underlying data structure will have data inserted in a circular fashion. Always insert after the - last index, and when you reach the end, wrap around to the start. - - The data is stored in a numpy memmap. - """ - - NEXT_INDEX_KEY = "_next_index" - OCCUPANCY_MASK_KEY = "_occupancy_mask" - INDEX_KEY = "index" - FRAME_INDEX_KEY = "frame_index" - EPISODE_INDEX_KEY = "episode_index" - TIMESTAMP_KEY = "timestamp" - IS_PAD_POSTFIX = "_is_pad" - - def __init__( - self, - write_dir: str | Path, - data_spec: dict[str, Any] | None, - buffer_capacity: int | None, - fps: float | None = None, - delta_timestamps: dict[str, list[float]] | dict[str, np.ndarray] | None = None, - ): - """ - The online buffer can be provided from scratch or you can load an existing online buffer by passing - a `write_dir` associated with an existing buffer. - - Args: - write_dir: Where to keep the numpy memmap files. One memmap file will be stored for each data key. - Note that if the files already exist, they are opened in read-write mode (used for training - resumption.) - data_spec: A mapping from data key to data specification, like {data_key: {"shape": tuple[int], - "dtype": np.dtype}}. This should include all the data that you wish to record into the buffer, - but note that "index", "frame_index" and "episode_index" are already accounted for by this - class, so you don't need to include them. - buffer_capacity: How many frames should be stored in the buffer as a maximum. Be aware of your - system's available disk space when choosing this. - fps: Same as the fps concept in LeRobot dataset. Here it needs to be provided for the - delta_timestamps logic. You can pass None if you are not using delta_timestamps. - delta_timestamps: Same as the delta_timestamps concept in LeRobotDataset. This is internally - converted to dict[str, np.ndarray] for optimization purposes. - - """ - self.set_delta_timestamps(delta_timestamps) - self._fps = fps - # Tolerance in seconds used to discard loaded frames when their timestamps are not close enough from - # the requested frames. It is only used when `delta_timestamps` is provided. - # minus 1e-4 to account for possible numerical error - self.tolerance_s = 1 / self.fps - 1e-4 if fps is not None else None - self._buffer_capacity = buffer_capacity - data_spec = self._make_data_spec(data_spec, buffer_capacity) - Path(write_dir).mkdir(parents=True, exist_ok=True) - self._data = {} - for k, v in data_spec.items(): - self._data[k] = _make_memmap_safe( - filename=Path(write_dir) / k, - dtype=v["dtype"] if v is not None else None, - mode="r+" if (Path(write_dir) / k).exists() else "w+", - shape=tuple(v["shape"]) if v is not None else None, - ) - - @property - def delta_timestamps(self) -> dict[str, np.ndarray] | None: - return self._delta_timestamps - - def set_delta_timestamps(self, value: dict[str, list[float]] | None): - """Set delta_timestamps converting the values to numpy arrays. - - The conversion is for an optimization in the __getitem__. The loop is much slower if the arrays - need to be converted into numpy arrays. - """ - if value is not None: - self._delta_timestamps = {k: np.array(v) for k, v in value.items()} - else: - self._delta_timestamps = None - - def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]: - """Makes the data spec for np.memmap.""" - if any(k.startswith("_") for k in data_spec): - raise ValueError( - "data_spec keys should not start with '_'. This prefix is reserved for internal logic." - ) - preset_keys = { - OnlineBuffer.INDEX_KEY, - OnlineBuffer.FRAME_INDEX_KEY, - OnlineBuffer.EPISODE_INDEX_KEY, - OnlineBuffer.TIMESTAMP_KEY, - } - if len(intersection := set(data_spec).intersection(preset_keys)) > 0: - raise ValueError( - f"data_spec should not contain any of {preset_keys} as these are handled internally. " - f"The provided data_spec has {intersection}." - ) - complete_data_spec = { - # _next_index will be a pointer to the next index that we should start filling from when we add - # more data. - OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()}, - # Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied - # with real data rather than the dummy initialization. - OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)}, - OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)}, - OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)}, - OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)}, - OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)}, - } - for k, v in data_spec.items(): - complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])} - return complete_data_spec - - def add_data(self, data: dict[str, np.ndarray]): - """Add new data to the buffer, which could potentially mean shifting old data out. - - The new data should contain all the frames (in order) of any number of episodes. The indices should - start from 0 (note to the developer: this can easily be generalized). See the `rollout` and - `eval_policy` functions in `eval.py` for more information on how the data is constructed. - - Shift the incoming data index and episode_index to continue on from the last frame. Note that this - will be done in place! - """ - if len(missing_keys := (set(self.data_keys).difference(set(data)))) > 0: - raise ValueError(f"Missing data keys: {missing_keys}") - new_data_length = len(data[self.data_keys[0]]) - if not all(len(data[k]) == new_data_length for k in self.data_keys): - raise ValueError("All data items should have the same length") - - next_index = self._data[OnlineBuffer.NEXT_INDEX_KEY] - - # Sanity check to make sure that the new data indices start from 0. - assert data[OnlineBuffer.EPISODE_INDEX_KEY][0].item() == 0 - assert data[OnlineBuffer.INDEX_KEY][0].item() == 0 - - # Shift the incoming indices if necessary. - if self.num_frames > 0: - last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1] - last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1] - data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1 - data[OnlineBuffer.INDEX_KEY] += last_data_index + 1 - - # Insert the new data starting from next_index. It may be necessary to wrap around to the start. - n_surplus = max(0, new_data_length - (self._buffer_capacity - next_index)) - for k in self.data_keys: - if n_surplus == 0: - slc = slice(next_index, next_index + new_data_length) - self._data[k][slc] = data[k] - self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][slc] = True - else: - self._data[k][next_index:] = data[k][:-n_surplus] - self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][next_index:] = True - self._data[k][:n_surplus] = data[k][-n_surplus:] - if n_surplus == 0: - self._data[OnlineBuffer.NEXT_INDEX_KEY] = next_index + new_data_length - else: - self._data[OnlineBuffer.NEXT_INDEX_KEY] = n_surplus - - @property - def data_keys(self) -> list[str]: - keys = set(self._data) - keys.remove(OnlineBuffer.OCCUPANCY_MASK_KEY) - keys.remove(OnlineBuffer.NEXT_INDEX_KEY) - return sorted(keys) - - @property - def fps(self) -> float | None: - return self._fps - - @property - def num_episodes(self) -> int: - return len( - np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]) - ) - - @property - def num_frames(self) -> int: - return np.count_nonzero(self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]) - - def __len__(self): - return self.num_frames - - def _item_to_tensors(self, item: dict) -> dict: - item_ = {} - for k, v in item.items(): - if isinstance(v, torch.Tensor): - item_[k] = v - elif isinstance(v, np.ndarray): - item_[k] = torch.from_numpy(v) - else: - item_[k] = torch.tensor(v) - return item_ - - def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: - if idx >= len(self) or idx < -len(self): - raise IndexError - - item = {k: v[idx] for k, v in self._data.items() if not k.startswith("_")} - - if self.delta_timestamps is None: - return self._item_to_tensors(item) - - episode_index = item[OnlineBuffer.EPISODE_INDEX_KEY] - current_ts = item[OnlineBuffer.TIMESTAMP_KEY] - episode_data_indices = np.where( - np.bitwise_and( - self._data[OnlineBuffer.EPISODE_INDEX_KEY] == episode_index, - self._data[OnlineBuffer.OCCUPANCY_MASK_KEY], - ) - )[0] - episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices] - - for data_key in self.delta_timestamps: - # Note: The logic in this loop is copied from `load_previous_and_future_frames`. - # Get timestamps used as query to retrieve data of previous/future frames. - query_ts = current_ts + self.delta_timestamps[data_key] - - # Compute distances between each query timestamp and all timestamps of all the frames belonging to - # the episode. - dist = np.abs(query_ts[:, None] - episode_timestamps[None, :]) - argmin_ = np.argmin(dist, axis=1) - min_ = dist[np.arange(dist.shape[0]), argmin_] - - is_pad = min_ > self.tolerance_s - - # Check violated query timestamps are all outside the episode range. - assert ( - (query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad]) - ).all(), ( - f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}" - ") inside the episode range." - ) - - # Load frames for this data key. - item[data_key] = self._data[data_key][episode_data_indices[argmin_]] - - item[f"{data_key}{OnlineBuffer.IS_PAD_POSTFIX}"] = is_pad - - return self._item_to_tensors(item) - - def get_data_by_key(self, key: str) -> torch.Tensor: - """Returns all data for a given data key as a Tensor.""" - return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]) - - -def compute_sampler_weights( - offline_dataset: LeRobotDataset, - offline_drop_n_last_frames: int = 0, - online_dataset: OnlineBuffer | None = None, - online_sampling_ratio: float | None = None, - online_drop_n_last_frames: int = 0, -) -> torch.Tensor: - """Compute the sampling weights for the online training dataloader in train.py. - - Args: - offline_dataset: The LeRobotDataset used for offline pre-training. - online_drop_n_last_frames: Number of frames to drop from the end of each offline dataset episode. - online_dataset: The OnlineBuffer used in online training. - online_sampling_ratio: The proportion of data that should be sampled from the online dataset. If an - online dataset is provided, this value must also be provided. - online_drop_n_first_frames: See `offline_drop_n_last_frames`. This is the same, but for the online - dataset. - Returns: - Tensor of weights for [offline_dataset; online_dataset], normalized to 1. - - Notes to maintainers: - - This duplicates some logic from EpisodeAwareSampler. We should consider converging to one approach. - - When used with `torch.utils.data.WeightedRandomSampler`, it could completely replace - `EpisodeAwareSampler` as the online dataset related arguments are optional. The only missing feature - is the ability to turn shuffling off. - - Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not - included here to avoid adding complexity. - """ - if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0): - raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.") - if (online_dataset is None) ^ (online_sampling_ratio is None): - raise ValueError( - "`online_dataset` and `online_sampling_ratio` must be provided together or not at all." - ) - offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio - - weights = [] - - if len(offline_dataset) > 0: - offline_data_mask_indices = [] - for start_index, end_index in zip( - offline_dataset.meta.episodes["dataset_from_index"], - offline_dataset.meta.episodes["dataset_to_index"], - strict=True, - ): - offline_data_mask_indices.extend(range(start_index, end_index - offline_drop_n_last_frames)) - offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool) - offline_data_mask[torch.tensor(offline_data_mask_indices)] = True - weights.append( - torch.full( - size=(len(offline_dataset),), - fill_value=offline_sampling_ratio / offline_data_mask.sum(), - ) - * offline_data_mask - ) - - if online_dataset is not None and len(online_dataset) > 0: - online_data_mask_indices = [] - episode_indices = online_dataset.get_data_by_key("episode_index") - for episode_idx in torch.unique(episode_indices): - where_episode = torch.where(episode_indices == episode_idx) - start_index = where_episode[0][0] - end_index = where_episode[0][-1] + 1 - online_data_mask_indices.extend( - range(start_index.item(), end_index.item() - online_drop_n_last_frames) - ) - online_data_mask = torch.zeros(len(online_dataset), dtype=torch.bool) - online_data_mask[torch.tensor(online_data_mask_indices)] = True - weights.append( - torch.full( - size=(len(online_dataset),), - fill_value=online_sampling_ratio / online_data_mask.sum(), - ) - * online_data_mask - ) - - weights = torch.cat(weights) - - if weights.sum() == 0: - weights += 1 / len(weights) - else: - weights /= weights.sum() - - return weights diff --git a/src/lerobot/datasets/push_dataset_to_hub/utils.py b/src/lerobot/datasets/push_dataset_to_hub/utils.py deleted file mode 100644 index 48214e1bf..000000000 --- a/src/lerobot/datasets/push_dataset_to_hub/utils.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import datasets -import torch - - -# TODO(aliberts): remove -def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]: - """ - Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset. - - Parameters: - - hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index. - - Returns: - - episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys: - - "from": A tensor containing the starting index of each episode. - - "to": A tensor containing the ending index of each episode. - """ - episode_data_index = {"from": [], "to": []} - - current_episode = None - """ - The episode_index is a list of integers, each representing the episode index of the corresponding example. - For instance, the following is a valid episode_index: - [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2] - - Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and - ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this: - { - "from": [0, 3, 7], - "to": [3, 7, 12] - } - """ - if len(hf_dataset) == 0: - episode_data_index = { - "from": torch.tensor([]), - "to": torch.tensor([]), - } - return episode_data_index - for idx, episode_idx in enumerate(hf_dataset["episode_index"]): - if episode_idx != current_episode: - # We encountered a new episode, so we append its starting location to the "from" list - episode_data_index["from"].append(idx) - # If this is not the first episode, we append the ending location of the previous episode to the "to" list - if current_episode is not None: - episode_data_index["to"].append(idx) - # Let's keep track of the current episode index - current_episode = episode_idx - else: - # We are still in the same episode, so there is nothing for us to do here - pass - # We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list - episode_data_index["to"].append(idx + 1) - - for k in ["from", "to"]: - episode_data_index[k] = torch.tensor(episode_data_index[k]) - - return episode_data_index diff --git a/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py b/src/lerobot/scripts/augment_dataset_quantile_stats.py similarity index 99% rename from src/lerobot/datasets/v30/augment_dataset_quantile_stats.py rename to src/lerobot/scripts/augment_dataset_quantile_stats.py index 900a43a4f..e6ab6867e 100644 --- a/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py +++ b/src/lerobot/scripts/augment_dataset_quantile_stats.py @@ -28,7 +28,7 @@ quantile statistics (q01, q10, q50, q90, q99) in their metadata. This script: Usage: ```bash -python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \ +python src/lerobot/scripts/augment_dataset_quantile_stats.py \ --repo-id=lerobot/pusht \ ``` """ diff --git a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py b/src/lerobot/scripts/convert_dataset_v21_to_v30.py similarity index 99% rename from src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py rename to src/lerobot/scripts/convert_dataset_v21_to_v30.py index 81de05686..dc81cc51c 100644 --- a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py +++ b/src/lerobot/scripts/convert_dataset_v21_to_v30.py @@ -28,13 +28,13 @@ Usage: Convert a dataset from the hub: ```bash -python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \ +python src/lerobot/scripts/convert_dataset_v21_to_v30.py \ --repo-id=lerobot/pusht ``` Convert a local dataset (works in place): ```bash -python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \ +python src/lerobot/scripts/convert_dataset_v21_to_v30.py \ --repo-id=lerobot/pusht \ --root=/path/to/local/dataset/directory \ --push-to-hub=false diff --git a/tests/datasets/test_dataset_utils.py b/tests/datasets/test_dataset_utils.py index 99b832e55..d40ee238f 100644 --- a/tests/datasets/test_dataset_utils.py +++ b/tests/datasets/test_dataset_utils.py @@ -19,11 +19,26 @@ import torch from datasets import Dataset from huggingface_hub import DatasetCard -from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch from lerobot.utils.constants import ACTION, OBS_IMAGES +def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]: + """Calculate episode data index for testing. Returns {"from": Tensor, "to": Tensor}.""" + episode_data_index: dict[str, list[int]] = {"from": [], "to": []} + current_episode = None + if len(hf_dataset) == 0: + return {"from": torch.tensor([]), "to": torch.tensor([])} + for idx, episode_idx in enumerate(hf_dataset["episode_index"]): + if episode_idx != current_episode: + episode_data_index["from"].append(idx) + if current_episode is not None: + episode_data_index["to"].append(idx) + current_episode = episode_idx + episode_data_index["to"].append(idx + 1) + return {k: torch.tensor(v) for k, v in episode_data_index.items()} + + def test_default_parameters(): card = create_lerobot_dataset_card() assert isinstance(card, DatasetCard) diff --git a/tests/datasets/test_online_buffer.py b/tests/datasets/test_online_buffer.py deleted file mode 100644 index 887da6041..000000000 --- a/tests/datasets/test_online_buffer.py +++ /dev/null @@ -1,282 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.d -from copy import deepcopy -from uuid import uuid4 - -import numpy as np -import pytest -import torch - -from lerobot.datasets.online_buffer import OnlineBuffer, compute_sampler_weights - -# Some constants for OnlineBuffer tests. -data_key = "data" -data_shape = (2, 3) # just some arbitrary > 1D shape -buffer_capacity = 100 -fps = 10 - - -def make_new_buffer( - write_dir: str | None = None, delta_timestamps: dict[str, list[float]] | None = None -) -> tuple[OnlineBuffer, str]: - if write_dir is None: - write_dir = f"/tmp/online_buffer_{uuid4().hex}" - buffer = OnlineBuffer( - write_dir, - data_spec={data_key: {"shape": data_shape, "dtype": np.dtype("float32")}}, - buffer_capacity=buffer_capacity, - fps=fps, - delta_timestamps=delta_timestamps, - ) - return buffer, write_dir - - -def make_spoof_data_frames(n_episodes: int, n_frames_per_episode: int) -> dict[str, np.ndarray]: - new_data = { - data_key: np.arange(n_frames_per_episode * n_episodes * np.prod(data_shape)).reshape(-1, *data_shape), - OnlineBuffer.INDEX_KEY: np.arange(n_frames_per_episode * n_episodes), - OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(np.arange(n_episodes), n_frames_per_episode), - OnlineBuffer.FRAME_INDEX_KEY: np.tile(np.arange(n_frames_per_episode), n_episodes), - OnlineBuffer.TIMESTAMP_KEY: np.tile(np.arange(n_frames_per_episode) / fps, n_episodes), - } - return new_data - - -def test_non_mutate(): - """Checks that the data provided to the add_data method is copied rather than passed by reference. - - This means that mutating the data in the buffer does not mutate the original data. - - NOTE: If this test fails, it means some of the other tests may be compromised. For example, we can't trust - a success case for `test_write_read`. - """ - buffer, _ = make_new_buffer() - new_data = make_spoof_data_frames(2, buffer_capacity // 4) - new_data_copy = deepcopy(new_data) - buffer.add_data(new_data) - buffer._data[data_key][:] += 1 - assert all(np.array_equal(new_data[k], new_data_copy[k]) for k in new_data) - - -def test_index_error_no_data(): - buffer, _ = make_new_buffer() - with pytest.raises(IndexError): - buffer[0] - - -def test_index_error_with_data(): - buffer, _ = make_new_buffer() - n_frames = buffer_capacity // 2 - new_data = make_spoof_data_frames(1, n_frames) - buffer.add_data(new_data) - with pytest.raises(IndexError): - buffer[n_frames] - with pytest.raises(IndexError): - buffer[-n_frames - 1] - - -@pytest.mark.parametrize("do_reload", [False, True]) -def test_write_read(do_reload: bool): - """Checks that data can be added to the buffer and read back. - - If do_reload we delete the buffer object and load the buffer back from disk before reading. - """ - buffer, write_dir = make_new_buffer() - n_episodes = 2 - n_frames_per_episode = buffer_capacity // 4 - new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode) - buffer.add_data(new_data) - - if do_reload: - del buffer - buffer, _ = make_new_buffer(write_dir) - - assert len(buffer) == n_frames_per_episode * n_episodes - for i, item in enumerate(buffer): - assert all(isinstance(item[k], torch.Tensor) for k in item) - assert np.array_equal(item[data_key].numpy(), new_data[data_key][i]) - - -def test_read_data_key(): - """Tests that data can be added to a buffer and all data for a. specific key can be read back.""" - buffer, _ = make_new_buffer() - n_episodes = 2 - n_frames_per_episode = buffer_capacity // 4 - new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode) - buffer.add_data(new_data) - - data_from_buffer = buffer.get_data_by_key(data_key) - assert isinstance(data_from_buffer, torch.Tensor) - assert np.array_equal(data_from_buffer.numpy(), new_data[data_key]) - - -def test_fifo(): - """Checks that if data is added beyond the buffer capacity, we discard the oldest data first.""" - buffer, _ = make_new_buffer() - n_frames_per_episode = buffer_capacity // 4 - n_episodes = 3 - new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode) - buffer.add_data(new_data) - n_more_episodes = 2 - # Developer sanity check (in case someone changes the global `buffer_capacity`). - assert (n_episodes + n_more_episodes) * n_frames_per_episode > buffer_capacity, ( - "Something went wrong with the test code." - ) - more_new_data = make_spoof_data_frames(n_more_episodes, n_frames_per_episode) - buffer.add_data(more_new_data) - assert len(buffer) == buffer_capacity, "The buffer should be full." - - expected_data = {} - for k in new_data: - # Concatenate, left-truncate, then roll, to imitate the cyclical FIFO pattern in OnlineBuffer. - expected_data[k] = np.roll( - np.concatenate([new_data[k], more_new_data[k]])[-buffer_capacity:], - shift=len(new_data[k]) + len(more_new_data[k]) - buffer_capacity, - axis=0, - ) - - for i, item in enumerate(buffer): - assert all(isinstance(item[k], torch.Tensor) for k in item) - assert np.array_equal(item[data_key].numpy(), expected_data[data_key][i]) - - -def test_delta_timestamps_within_tolerance(): - """Check that getting an item with delta_timestamps within tolerance succeeds. - - Note: Copied from `test_datasets.py::test_load_previous_and_future_frames_within_tolerance`. - """ - # Sanity check on global fps as we are assuming it is 10 here. - assert fps == 10, "This test assumes fps==10" - buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.2, 0, 0.139]}) - new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5) - buffer.add_data(new_data) - buffer.tolerance_s = 0.04 - item = buffer[2] - data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"] - torch.testing.assert_close(data, torch.tensor([0, 2, 3]), msg="Data does not match expected values") - assert not is_pad.any(), "Unexpected padding detected" - - -def test_delta_timestamps_outside_tolerance_inside_episode_range(): - """Check that getting an item with delta_timestamps outside of tolerance fails. - - We expect it to fail if and only if the requested timestamps are within the episode range. - - Note: Copied from - `test_datasets.py::test_load_previous_and_future_frames_outside_tolerance_inside_episode_range` - """ - # Sanity check on global fps as we are assuming it is 10 here. - assert fps == 10, "This test assumes fps==10" - buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.2, 0, 0.141]}) - new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5) - buffer.add_data(new_data) - buffer.tolerance_s = 0.04 - with pytest.raises(AssertionError): - buffer[2] - - -def test_delta_timestamps_outside_tolerance_outside_episode_range(): - """Check that copy-padding of timestamps outside of the episode range works. - - Note: Copied from - `test_datasets.py::test_load_previous_and_future_frames_outside_tolerance_outside_episode_range` - """ - # Sanity check on global fps as we are assuming it is 10 here. - assert fps == 10, "This test assumes fps==10" - buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.3, -0.24, 0, 0.26, 0.3]}) - new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5) - buffer.add_data(new_data) - buffer.tolerance_s = 0.04 - item = buffer[2] - data, is_pad = item["index"], item["index_is_pad"] - assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values" - assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), ( - "Padding does not match expected values" - ) - - -# Arbitrarily set small dataset sizes, making sure to have uneven sizes. -@pytest.mark.parametrize("offline_dataset_size", [1, 6]) -@pytest.mark.parametrize("online_dataset_size", [0, 4]) -@pytest.mark.parametrize("online_sampling_ratio", [0.0, 1.0]) -def test_compute_sampler_weights_trivial( - lerobot_dataset_factory, - tmp_path, - offline_dataset_size: int, - online_dataset_size: int, - online_sampling_ratio: float, -): - offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size) - online_dataset, _ = make_new_buffer() - if online_dataset_size > 0: - online_dataset.add_data( - make_spoof_data_frames(n_episodes=2, n_frames_per_episode=online_dataset_size // 2) - ) - - weights = compute_sampler_weights( - offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio - ) - if offline_dataset_size == 0 or online_dataset_size == 0: - expected_weights = torch.ones(offline_dataset_size + online_dataset_size) - elif online_sampling_ratio == 0: - expected_weights = torch.cat([torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)]) - elif online_sampling_ratio == 1: - expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)]) - expected_weights /= expected_weights.sum() - torch.testing.assert_close(weights, expected_weights) - - -def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path): - # Arbitrarily set small dataset sizes, making sure to have uneven sizes. - offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4) - online_dataset, _ = make_new_buffer() - online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)) - online_sampling_ratio = 0.8 - weights = compute_sampler_weights( - offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio - ) - torch.testing.assert_close( - weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]) - ) - - -def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path): - # Arbitrarily set small dataset sizes, making sure to have uneven sizes. - offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4) - online_dataset, _ = make_new_buffer() - online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)) - weights = compute_sampler_weights( - offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1 - ) - torch.testing.assert_close( - weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0]) - ) - - -def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path): - """Note: test copied from test_sampler.""" - offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2) - online_dataset, _ = make_new_buffer() - online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)) - - weights = compute_sampler_weights( - offline_dataset, - offline_drop_n_last_frames=1, - online_dataset=online_dataset, - online_sampling_ratio=0.5, - online_drop_n_last_frames=1, - ) - torch.testing.assert_close(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0])) diff --git a/tests/datasets/test_sampler.py b/tests/datasets/test_sampler.py index fd7a6e380..e5b35e426 100644 --- a/tests/datasets/test_sampler.py +++ b/tests/datasets/test_sampler.py @@ -13,15 +13,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import torch from datasets import Dataset -from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index from lerobot.datasets.sampler import EpisodeAwareSampler from lerobot.datasets.utils import ( hf_transform_to_torch, ) +def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]: + """Calculate episode data index for testing. Returns {"from": Tensor, "to": Tensor}.""" + episode_data_index: dict[str, list[int]] = {"from": [], "to": []} + current_episode = None + if len(hf_dataset) == 0: + return {"from": torch.tensor([]), "to": torch.tensor([])} + for idx, episode_idx in enumerate(hf_dataset["episode_index"]): + if episode_idx != current_episode: + episode_data_index["from"].append(idx) + if current_episode is not None: + episode_data_index["to"].append(idx) + current_episode = episode_idx + episode_data_index["to"].append(idx + 1) + return {k: torch.tensor(v) for k, v in episode_data_index.items()} + + def test_drop_n_first_frames(): dataset = Dataset.from_dict( { From 9d3b62aa613ec82dfb17839ed1ef50b948a5ea80 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Sun, 15 Mar 2026 22:12:09 -0700 Subject: [PATCH 115/131] chore(dataset): basic house-keeping (#3170) --- src/lerobot/configs/default.py | 10 +++++ src/lerobot/datasets/image_writer.py | 10 +++-- src/lerobot/datasets/lerobot_dataset.py | 17 +++++--- src/lerobot/datasets/pipeline_features.py | 10 +++-- src/lerobot/datasets/sampler.py | 25 ++++++++++++ src/lerobot/datasets/video_utils.py | 48 ++++++++++++----------- tests/configs/test_default.py | 38 ++++++++++++++++++ tests/datasets/test_image_writer.py | 8 ++-- tests/datasets/test_sampler.py | 28 +++++++++++++ 9 files changed, 153 insertions(+), 41 deletions(-) create mode 100644 tests/configs/test_default.py diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py index 3fb0c6c4e..7f481b9ca 100644 --- a/src/lerobot/configs/default.py +++ b/src/lerobot/configs/default.py @@ -36,6 +36,16 @@ class DatasetConfig: video_backend: str = field(default_factory=get_safe_default_codec) streaming: bool = False + def __post_init__(self) -> None: + if self.episodes is not None: + if any(ep < 0 for ep in self.episodes): + raise ValueError( + f"Episode indices must be non-negative, got: {[ep for ep in self.episodes if ep < 0]}" + ) + if len(self.episodes) != len(set(self.episodes)): + duplicates = sorted({ep for ep in self.episodes if self.episodes.count(ep) > 1}) + raise ValueError(f"Episode indices contain duplicates: {duplicates}") + @dataclass class WandBConfig: diff --git a/src/lerobot/datasets/image_writer.py b/src/lerobot/datasets/image_writer.py index 23bc2efb8..9f40394de 100644 --- a/src/lerobot/datasets/image_writer.py +++ b/src/lerobot/datasets/image_writer.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import multiprocessing import queue import threading @@ -22,6 +23,8 @@ import numpy as np import PIL.Image import torch +logger = logging.getLogger(__name__) + def safe_stop_image_writer(func): def wrapper(*args, **kwargs): @@ -31,7 +34,7 @@ def safe_stop_image_writer(func): dataset = kwargs.get("dataset") image_writer = getattr(dataset, "image_writer", None) if dataset else None if image_writer is not None: - print("Waiting for image writer to terminate...") + logger.warning("Waiting for image writer to terminate...") image_writer.stop() raise e @@ -89,8 +92,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level PIL.Image.Image object. Side Effects: - Prints an error message to the console if the image writing process - fails for any reason. + Logs an error message if the image writing process fails for any reason. """ try: if isinstance(image, np.ndarray): @@ -101,7 +103,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level raise TypeError(f"Unsupported image type: {type(image)}") img.save(fpath, compress_level=compress_level) except Exception as e: - print(f"Error writing image {fpath}: {e}") + logger.error("Error writing image %s: %s", fpath, e) def worker_thread_loop(queue: queue.Queue): diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 11c10f493..5d1b5d042 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -80,6 +80,8 @@ from lerobot.datasets.video_utils import ( ) from lerobot.utils.constants import HF_LEROBOT_HOME +logger = logging.getLogger(__name__) + CODEBASE_VERSION = "v3.0" @@ -535,7 +537,10 @@ class LeRobotDatasetMetadata: video_files_size_in_mb, ) if len(obj.video_keys) > 0 and not use_videos: - raise ValueError() + raise ValueError( + f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. " + "Either remove video features from the features dict, or set 'use_videos=True'." + ) write_json(obj.info, obj.root / INFO_PATH) obj.revision = None obj.writer = None @@ -1326,7 +1331,7 @@ class LeRobotDataset(torch.utils.data.Dataset): temp_path = future.result() results[video_key] = temp_path except Exception as exc: - logging.error(f"Video encoding failed for {video_key}: {exc}") + logger.error(f"Video encoding failed for {video_key}: {exc}") raise exc for video_key in self.meta.video_keys: @@ -1365,7 +1370,7 @@ class LeRobotDataset(torch.utils.data.Dataset): if end_episode is None: end_episode = self.num_episodes - logging.info( + logger.info( f"Batch encoding {self.batch_encoding_size} videos for episodes {start_episode} to {end_episode - 1}" ) @@ -1375,7 +1380,7 @@ class LeRobotDataset(torch.utils.data.Dataset): episode_df = pd.read_parquet(episode_df_path) for ep_idx in range(start_episode, end_episode): - logging.info(f"Encoding videos for episode {ep_idx}") + logger.info(f"Encoding videos for episode {ep_idx}") if ( self.meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx @@ -1605,7 +1610,7 @@ class LeRobotDataset(torch.utils.data.Dataset): def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None: if isinstance(self.image_writer, AsyncImageWriter): - logging.warning( + logger.warning( "You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset." ) @@ -1771,7 +1776,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True): extra_keys = set(ds.features).difference(intersection_features) if extra_keys: - logging.warning( + logger.warning( f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the " "other datasets." ) diff --git a/src/lerobot/datasets/pipeline_features.py b/src/lerobot/datasets/pipeline_features.py index f824eb9bc..fe8cabbeb 100644 --- a/src/lerobot/datasets/pipeline_features.py +++ b/src/lerobot/datasets/pipeline_features.py @@ -44,11 +44,11 @@ def create_initial_features( return features -# Helper to filter state/action keys based on regex patterns. -def should_keep(key: str, patterns: tuple[str]) -> bool: +# Helper to filter state/action keys based on compiled regex patterns. +def should_keep(key: str, patterns: tuple[re.Pattern] | None) -> bool: if patterns is None: return True - return any(re.search(pat, key) for pat in patterns) + return any(pat.search(key) for pat in patterns) def strip_prefix(key: str, prefixes_to_strip: tuple[str]) -> str: @@ -89,6 +89,8 @@ def aggregate_pipeline_dataset_features( Returns: A dictionary of features formatted for a Hugging Face LeRobot Dataset. """ + compiled_patterns = tuple(re.compile(p) for p in patterns) if patterns is not None else None + all_features = pipeline.transform_features(initial_features) # Intermediate storage for categorized and filtered features. @@ -120,7 +122,7 @@ def aggregate_pipeline_dataset_features( # 2. Apply filtering rules. if is_image and not use_videos: continue - if not is_image and not should_keep(key, patterns): + if not is_image and not should_keep(key, compiled_patterns): continue # 3. Add the feature to the appropriate group with a clean name. diff --git a/src/lerobot/datasets/sampler.py b/src/lerobot/datasets/sampler.py index d0bb20c27..2bf7ab922 100644 --- a/src/lerobot/datasets/sampler.py +++ b/src/lerobot/datasets/sampler.py @@ -13,10 +13,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging from collections.abc import Iterator import torch +logger = logging.getLogger(__name__) + class EpisodeAwareSampler: def __init__( @@ -39,13 +42,35 @@ class EpisodeAwareSampler: drop_n_last_frames: Number of frames to drop from the end of each episode. shuffle: Whether to shuffle the indices. """ + if drop_n_first_frames < 0: + raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}") + if drop_n_last_frames < 0: + raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}") + indices = [] for episode_idx, (start_index, end_index) in enumerate( zip(dataset_from_indices, dataset_to_indices, strict=True) ): if episode_indices_to_use is None or episode_idx in episode_indices_to_use: + ep_length = end_index - start_index + if drop_n_first_frames + drop_n_last_frames >= ep_length: + logger.warning( + "Episode %d has %d frames but drop_n_first_frames=%d and " + "drop_n_last_frames=%d removes all frames. Skipping.", + episode_idx, + ep_length, + drop_n_first_frames, + drop_n_last_frames, + ) + continue indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames)) + if not indices: + raise ValueError( + "No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. " + "All episodes were either filtered out or had too few frames." + ) + self.indices = indices self.shuffle = shuffle diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 8c8494b87..e465b79b4 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -37,6 +37,8 @@ import torchvision from datasets.features.features import register_feature from PIL import Image +logger = logging.getLogger(__name__) + # List of hardware encoders to probe for auto-selection. Availability depends on the platform and FFmpeg build. # Determines the order of preference for auto-selection when vcodec="auto" is used. HW_ENCODERS = [ @@ -94,7 +96,7 @@ def detect_available_hw_encoders() -> list[str]: av.codec.Codec(codec_name, "w") available.append(codec_name) except Exception: # nosec B110 - pass # nosec B110 + logger.debug("HW encoder '%s' not available", codec_name) # nosec B110 return available @@ -103,14 +105,14 @@ def resolve_vcodec(vcodec: str) -> str: if vcodec not in VALID_VIDEO_CODECS: raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}") if vcodec != "auto": - logging.info(f"Using video codec: {vcodec}") + logger.info(f"Using video codec: {vcodec}") return vcodec available = detect_available_hw_encoders() for encoder in HW_ENCODERS: if encoder in available: - logging.info(f"Auto-selected video codec: {encoder}") + logger.info(f"Auto-selected video codec: {encoder}") return encoder - logging.info("No hardware encoder available, falling back to software encoder 'libsvtav1'") + logger.info("No hardware encoder available, falling back to software encoder 'libsvtav1'") return "libsvtav1" @@ -118,7 +120,7 @@ def get_safe_default_codec(): if importlib.util.find_spec("torchcodec"): return "torchcodec" else: - logging.warning( + logger.warning( "'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder" ) return "pyav" @@ -208,7 +210,7 @@ def decode_video_frames_torchvision( for frame in reader: current_ts = frame["pts"] if log_loaded_timestamps: - logging.info(f"frame loaded at timestamp={current_ts:.4f}") + logger.info(f"frame loaded at timestamp={current_ts:.4f}") loaded_frames.append(frame["data"]) loaded_ts.append(current_ts) if current_ts >= last_ts: @@ -244,7 +246,7 @@ def decode_video_frames_torchvision( closest_ts = loaded_ts[argmin_] if log_loaded_timestamps: - logging.info(f"{closest_ts=}") + logger.info(f"{closest_ts=}") # convert to the pytorch format which is float32 in [0,1] range (and channel first) closest_frames = closest_frames.type(torch.float32) / 255 @@ -348,7 +350,7 @@ def decode_video_frames_torchcodec( loaded_frames.append(frame) loaded_ts.append(pts.item()) if log_loaded_timestamps: - logging.info(f"Frame loaded at timestamp={pts:.4f}") + logger.info(f"Frame loaded at timestamp={pts:.4f}") query_ts = torch.tensor(timestamps) loaded_ts = torch.tensor(loaded_ts) @@ -374,7 +376,7 @@ def decode_video_frames_torchcodec( closest_ts = loaded_ts[argmin_] if log_loaded_timestamps: - logging.info(f"{closest_ts=}") + logger.info(f"{closest_ts=}") # convert to float32 in [0,1] range closest_frames = (closest_frames / 255.0).type(torch.float32) @@ -408,14 +410,14 @@ def encode_video_frames( imgs_dir = Path(imgs_dir) if video_path.exists() and not overwrite: - logging.warning(f"Video file already exists: {video_path}. Skipping encoding.") + logger.warning(f"Video file already exists: {video_path}. Skipping encoding.") return video_path.parent.mkdir(parents=True, exist_ok=True) # Encoders/pixel formats incompatibility check if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p": - logging.warning( + logger.warning( f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'" ) pix_fmt = "yuv420p" @@ -508,7 +510,7 @@ def concatenate_video_files( output_video_path = Path(output_video_path) if output_video_path.exists() and not overwrite: - logging.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.") + logger.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.") return output_video_path.parent.mkdir(parents=True, exist_ok=True) @@ -693,7 +695,7 @@ class _CameraEncoderThread(threading.Thread): self.result_queue.put(("ok", None)) except Exception as e: - logging.error(f"Encoder thread error: {e}") + logger.error(f"Encoder thread error: {e}") if container is not None: with contextlib.suppress(Exception): container.close() @@ -819,7 +821,7 @@ class StreamingVideoEncoder: count = self._dropped_frames[video_key] # Log periodically to avoid spam (1st, then every 10th) if count == 1 or count % 10 == 0: - logging.warning( + logger.warning( f"Encoder queue full for {video_key}, dropped {count} frame(s). " f"Consider using vcodec='auto' for hardware encoding or increasing encoder_queue_maxsize." ) @@ -841,7 +843,7 @@ class StreamingVideoEncoder: # Report dropped frames for video_key, count in self._dropped_frames.items(): if count > 0: - logging.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.") + logger.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.") # Send sentinel to all queues for video_key in self._frame_queues: @@ -851,7 +853,7 @@ class StreamingVideoEncoder: for video_key in self._threads: self._threads[video_key].join(timeout=120) if self._threads[video_key].is_alive(): - logging.error(f"Encoder thread for {video_key} did not finish in time") + logger.error(f"Encoder thread for {video_key} did not finish in time") self._stop_events[video_key].set() self._threads[video_key].join(timeout=5) results[video_key] = (self._video_paths[video_key], None) @@ -863,7 +865,7 @@ class StreamingVideoEncoder: raise RuntimeError(f"Encoder thread for {video_key} failed: {data}") results[video_key] = (self._video_paths[video_key], data) except queue.Empty: - logging.error(f"No result from encoder thread for {video_key}") + logger.error(f"No result from encoder thread for {video_key}") results[video_key] = (self._video_paths[video_key], None) self._cleanup() @@ -1071,13 +1073,13 @@ class VideoEncodingManager: elif self.dataset.episodes_since_last_encoding > 0: # Handle any remaining episodes that haven't been batch encoded if exc_type is not None: - logging.info("Exception occurred. Encoding remaining episodes before exit...") + logger.info("Exception occurred. Encoding remaining episodes before exit...") else: - logging.info("Recording stopped. Encoding remaining episodes...") + logger.info("Recording stopped. Encoding remaining episodes...") start_ep = self.dataset.num_episodes - self.dataset.episodes_since_last_encoding end_ep = self.dataset.num_episodes - logging.info( + logger.info( f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, " f"from episode {start_ep} to {end_ep - 1}" ) @@ -1094,7 +1096,7 @@ class VideoEncodingManager: episode_index=interrupted_episode_index, image_key=key, frame_index=0 ).parent if img_dir.exists(): - logging.debug( + logger.debug( f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}" ) shutil.rmtree(img_dir) @@ -1105,8 +1107,8 @@ class VideoEncodingManager: png_files = list(img_dir.rglob("*.png")) if len(png_files) == 0: shutil.rmtree(img_dir) - logging.debug("Cleaned up empty images directory") + logger.debug("Cleaned up empty images directory") else: - logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files") + logger.debug(f"Images directory is not empty, containing {len(png_files)} PNG files") return False # Don't suppress the original exception diff --git a/tests/configs/test_default.py b/tests/configs/test_default.py new file mode 100644 index 000000000..238b8bacd --- /dev/null +++ b/tests/configs/test_default.py @@ -0,0 +1,38 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from lerobot.configs.default import DatasetConfig + + +def test_dataset_config_valid(): + DatasetConfig(repo_id="user/repo", episodes=[0, 1, 2]) + + +def test_dataset_config_negative_episodes(): + with pytest.raises(ValueError, match="non-negative"): + DatasetConfig(repo_id="user/repo", episodes=[0, -1, 2]) + + +def test_dataset_config_duplicate_episodes(): + with pytest.raises(ValueError, match="duplicates"): + DatasetConfig(repo_id="user/repo", episodes=[0, 1, 1, 2]) + + +def test_dataset_config_none_episodes_ok(): + DatasetConfig(repo_id="user/repo", episodes=None) + + +def test_dataset_config_empty_episodes_ok(): + DatasetConfig(repo_id="user/repo", episodes=[]) diff --git a/tests/datasets/test_image_writer.py b/tests/datasets/test_image_writer.py index 99c8b24fc..e02755171 100644 --- a/tests/datasets/test_image_writer.py +++ b/tests/datasets/test_image_writer.py @@ -142,9 +142,9 @@ def test_write_image_image(tmp_path, img_factory): def test_write_image_exception(tmp_path): image_array = "invalid data" fpath = tmp_path / DUMMY_IMAGE - with patch("builtins.print") as mock_print: + with patch("lerobot.datasets.image_writer.logger") as mock_logger: write_image(image_array, fpath) - mock_print.assert_called() + mock_logger.error.assert_called() assert not fpath.exists() @@ -243,10 +243,10 @@ def test_save_image_invalid_data(tmp_path): image_array = "invalid data" fpath = tmp_path / DUMMY_IMAGE fpath.parent.mkdir(parents=True, exist_ok=True) - with patch("builtins.print") as mock_print: + with patch("lerobot.datasets.image_writer.logger") as mock_logger: writer.save_image(image_array, fpath) writer.wait_until_done() - mock_print.assert_called() + mock_logger.error.assert_called() assert not fpath.exists() finally: writer.stop() diff --git a/tests/datasets/test_sampler.py b/tests/datasets/test_sampler.py index e5b35e426..a5d463349 100644 --- a/tests/datasets/test_sampler.py +++ b/tests/datasets/test_sampler.py @@ -13,6 +13,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging + +import pytest import torch from datasets import Dataset @@ -106,3 +109,28 @@ def test_shuffle(): assert sampler.indices == [0, 1, 2, 3, 4, 5] assert len(sampler) == 6 assert set(sampler) == {0, 1, 2, 3, 4, 5} + + +def test_negative_drop_first_frames_raises(): + with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"): + EpisodeAwareSampler([0], [10], drop_n_first_frames=-1) + + +def test_negative_drop_last_frames_raises(): + with pytest.raises(ValueError, match="drop_n_last_frames must be >= 0"): + EpisodeAwareSampler([0], [10], drop_n_last_frames=-1) + + +def test_all_episodes_dropped_raises(): + # All episodes have 1 frame, drop_n_first_frames=1 removes all + with pytest.raises(ValueError, match="No valid frames remain"): + EpisodeAwareSampler([0, 1, 2], [1, 2, 3], drop_n_first_frames=1) + + +def test_partial_episode_drop_warns(caplog): + # Episode 0: 1 frame (dropped), Episode 1: 5 frames (kept) + with caplog.at_level(logging.WARNING, logger="lerobot.datasets.sampler"): + sampler = EpisodeAwareSampler([0, 1], [1, 6], drop_n_first_frames=1) + # Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5 + assert sampler.indices == [2, 3, 4, 5] + assert "Episode 0" in caplog.text From d90e4bcfd33ecf89f1db45ad7315923c21da57bd Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Sun, 15 Mar 2026 23:58:09 -0700 Subject: [PATCH 116/131] refactor(dataset): modular files (#3171) * refactor(dataset): modular files * refactor(dataset): update imports across the codebase --- examples/dataset/load_lerobot_dataset.py | 3 +- examples/lekiwi/evaluate.py | 2 +- examples/lekiwi/record.py | 2 +- examples/phone_to_so100/evaluate.py | 2 +- examples/phone_to_so100/record.py | 2 +- examples/port_datasets/port_droid.py | 3 +- examples/port_datasets/slurm_upload.py | 4 +- examples/rtc/eval_dataset.py | 3 +- examples/rtc/eval_with_real_robot.py | 2 +- examples/so100_to_so100_EE/evaluate.py | 2 +- examples/so100_to_so100_EE/record.py | 2 +- examples/training/train_policy.py | 5 +- examples/training/train_with_streaming.py | 4 +- examples/tutorial/act/act_training_example.py | 5 +- examples/tutorial/act/act_using_example.py | 2 +- .../diffusion/diffusion_training_example.py | 5 +- .../diffusion/diffusion_using_example.py | 2 +- examples/tutorial/pi0/using_pi0_example.py | 2 +- examples/tutorial/rl/hilserl_example.py | 2 +- .../tutorial/smolvla/using_smolvla_example.py | 2 +- src/lerobot/async_inference/helpers.py | 2 +- .../sarm_annotations/subtask_annotation.py | 5 +- src/lerobot/datasets/aggregate.py | 18 +- .../datasets/backward_compatibility.py | 56 - src/lerobot/datasets/compute_stats.py | 2 +- src/lerobot/datasets/dataset_metadata.py | 517 ++++++++ src/lerobot/datasets/dataset_tools.py | 18 +- src/lerobot/datasets/factory.py | 8 +- src/lerobot/datasets/feature_utils.py | 552 +++++++++ src/lerobot/datasets/io_utils.py | 342 ++++++ src/lerobot/datasets/lerobot_dataset.py | 698 +---------- src/lerobot/datasets/multi_dataset.py | 210 ++++ src/lerobot/datasets/pipeline_features.py | 2 +- src/lerobot/datasets/streaming_dataset.py | 170 ++- src/lerobot/datasets/utils.py | 1037 +---------------- src/lerobot/optim/optimizers.py | 3 +- src/lerobot/optim/schedulers.py | 2 +- src/lerobot/policies/factory.py | 4 +- src/lerobot/policies/utils.py | 2 +- .../scripts/augment_dataset_quantile_stats.py | 5 +- .../scripts/convert_dataset_v21_to_v30.py | 23 +- src/lerobot/scripts/lerobot_record.py | 2 +- src/lerobot/utils/train_utils.py | 2 +- tests/datasets/test_aggregate.py | 24 +- tests/datasets/test_dataset_tools.py | 118 +- tests/datasets/test_dataset_utils.py | 4 +- tests/datasets/test_datasets.py | 9 +- tests/datasets/test_delta_timestamps.py | 2 +- tests/datasets/test_sampler.py | 4 +- tests/fixtures/dataset_factories.py | 11 +- tests/fixtures/files.py | 12 +- tests/policies/test_policies.py | 3 +- tests/test_control_robot.py | 8 +- 53 files changed, 2030 insertions(+), 1901 deletions(-) delete mode 100644 src/lerobot/datasets/backward_compatibility.py create mode 100644 src/lerobot/datasets/dataset_metadata.py create mode 100644 src/lerobot/datasets/feature_utils.py create mode 100644 src/lerobot/datasets/io_utils.py create mode 100644 src/lerobot/datasets/multi_dataset.py diff --git a/examples/dataset/load_lerobot_dataset.py b/examples/dataset/load_lerobot_dataset.py index 4fda25884..ea3516710 100644 --- a/examples/dataset/load_lerobot_dataset.py +++ b/examples/dataset/load_lerobot_dataset.py @@ -32,7 +32,8 @@ import torch from huggingface_hub import HfApi import lerobot -from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.lerobot_dataset import LeRobotDataset def main(): diff --git a/examples/lekiwi/evaluate.py b/examples/lekiwi/evaluate.py index a3144a442..ef98640aa 100644 --- a/examples/lekiwi/evaluate.py +++ b/examples/lekiwi/evaluate.py @@ -14,8 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lerobot.datasets.feature_utils import hw_to_dataset_features from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.utils import hw_to_dataset_features from lerobot.policies.act.modeling_act import ACTPolicy from lerobot.policies.factory import make_pre_post_processors from lerobot.processor import make_default_processors diff --git a/examples/lekiwi/record.py b/examples/lekiwi/record.py index 9292157f7..ace2e35b8 100644 --- a/examples/lekiwi/record.py +++ b/examples/lekiwi/record.py @@ -14,8 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lerobot.datasets.feature_utils import hw_to_dataset_features from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.utils import hw_to_dataset_features from lerobot.processor import make_default_processors from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient diff --git a/examples/phone_to_so100/evaluate.py b/examples/phone_to_so100/evaluate.py index c1291d101..9cd7a98c2 100644 --- a/examples/phone_to_so100/evaluate.py +++ b/examples/phone_to_so100/evaluate.py @@ -16,9 +16,9 @@ from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.datasets.feature_utils import combine_feature_dicts from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features -from lerobot.datasets.utils import combine_feature_dicts from lerobot.model.kinematics import RobotKinematics from lerobot.policies.act.modeling_act import ACTPolicy from lerobot.policies.factory import make_pre_post_processors diff --git a/examples/phone_to_so100/record.py b/examples/phone_to_so100/record.py index 756c6f42d..f2a17cd33 100644 --- a/examples/phone_to_so100/record.py +++ b/examples/phone_to_so100/record.py @@ -15,9 +15,9 @@ # limitations under the License. from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.datasets.feature_utils import combine_feature_dicts from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features -from lerobot.datasets.utils import combine_feature_dicts from lerobot.model.kinematics import RobotKinematics from lerobot.processor import RobotProcessorPipeline from lerobot.processor.converters import ( diff --git a/examples/port_datasets/port_droid.py b/examples/port_datasets/port_droid.py index a1fb50914..f58bacbe0 100644 --- a/examples/port_datasets/port_droid.py +++ b/examples/port_datasets/port_droid.py @@ -22,7 +22,8 @@ from pathlib import Path import numpy as np import tensorflow_datasets as tfds -from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds DROID_SHARDS = 2048 diff --git a/examples/port_datasets/slurm_upload.py b/examples/port_datasets/slurm_upload.py index 55002c0be..7fb01c11b 100644 --- a/examples/port_datasets/slurm_upload.py +++ b/examples/port_datasets/slurm_upload.py @@ -26,7 +26,7 @@ from huggingface_hub import HfApi from huggingface_hub.constants import REPOCARD_NAME from port_droid import DROID_SHARDS -from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata +from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata from lerobot.datasets.utils import create_lerobot_dataset_card from lerobot.utils.utils import init_logging @@ -155,7 +155,7 @@ class UploadDataset(PipelineStep): from datasets.utils.tqdm import disable_progress_bars from huggingface_hub import CommitOperationAdd, preupload_lfs_files - from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata + from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata from lerobot.utils.utils import init_logging init_logging() diff --git a/examples/rtc/eval_dataset.py b/examples/rtc/eval_dataset.py index 613fd67d7..a94d4da48 100644 --- a/examples/rtc/eval_dataset.py +++ b/examples/rtc/eval_dataset.py @@ -113,8 +113,9 @@ from lerobot.configs import parser from lerobot.configs.default import DatasetConfig from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import RTCAttentionSchedule +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata from lerobot.datasets.factory import resolve_delta_timestamps -from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata +from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.policies.factory import get_policy_class, make_pre_post_processors from lerobot.policies.rtc.configuration_rtc import RTCConfig from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer diff --git a/examples/rtc/eval_with_real_robot.py b/examples/rtc/eval_with_real_robot.py index 9d9e1364a..36da88e1b 100644 --- a/examples/rtc/eval_with_real_robot.py +++ b/examples/rtc/eval_with_real_robot.py @@ -82,7 +82,7 @@ from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401 from lerobot.configs import parser from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import RTCAttentionSchedule -from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features +from lerobot.datasets.feature_utils import build_dataset_frame, hw_to_dataset_features from lerobot.policies.factory import get_policy_class, make_pre_post_processors from lerobot.policies.rtc.action_queue import ActionQueue from lerobot.policies.rtc.configuration_rtc import RTCConfig diff --git a/examples/so100_to_so100_EE/evaluate.py b/examples/so100_to_so100_EE/evaluate.py index 45a87ebad..638591021 100644 --- a/examples/so100_to_so100_EE/evaluate.py +++ b/examples/so100_to_so100_EE/evaluate.py @@ -16,9 +16,9 @@ from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.datasets.feature_utils import combine_feature_dicts from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features -from lerobot.datasets.utils import combine_feature_dicts from lerobot.model.kinematics import RobotKinematics from lerobot.policies.act.modeling_act import ACTPolicy from lerobot.policies.factory import make_pre_post_processors diff --git a/examples/so100_to_so100_EE/record.py b/examples/so100_to_so100_EE/record.py index 8fa862d6e..634bd891a 100644 --- a/examples/so100_to_so100_EE/record.py +++ b/examples/so100_to_so100_EE/record.py @@ -16,9 +16,9 @@ from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.datasets.feature_utils import combine_feature_dicts from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features -from lerobot.datasets.utils import combine_feature_dicts from lerobot.model.kinematics import RobotKinematics from lerobot.processor import RobotProcessorPipeline from lerobot.processor.converters import ( diff --git a/examples/training/train_policy.py b/examples/training/train_policy.py index 16f2a4d87..07ec10c92 100644 --- a/examples/training/train_policy.py +++ b/examples/training/train_policy.py @@ -19,8 +19,9 @@ from pathlib import Path import torch from lerobot.configs.types import FeatureType -from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata -from lerobot.datasets.utils import dataset_to_policy_features +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.feature_utils import dataset_to_policy_features +from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.policies.factory import make_pre_post_processors diff --git a/examples/training/train_with_streaming.py b/examples/training/train_with_streaming.py index 185be5b13..973698e74 100644 --- a/examples/training/train_with_streaming.py +++ b/examples/training/train_with_streaming.py @@ -20,9 +20,9 @@ from pathlib import Path import torch from lerobot.configs.types import FeatureType -from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.feature_utils import dataset_to_policy_features from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset -from lerobot.datasets.utils import dataset_to_policy_features from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.act.modeling_act import ACTPolicy from lerobot.policies.factory import make_pre_post_processors diff --git a/examples/tutorial/act/act_training_example.py b/examples/tutorial/act/act_training_example.py index fe70f3023..b62c49cac 100644 --- a/examples/tutorial/act/act_training_example.py +++ b/examples/tutorial/act/act_training_example.py @@ -5,8 +5,9 @@ from pathlib import Path import torch from lerobot.configs.types import FeatureType -from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata -from lerobot.datasets.utils import dataset_to_policy_features +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.feature_utils import dataset_to_policy_features +from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.act.modeling_act import ACTPolicy from lerobot.policies.factory import make_pre_post_processors diff --git a/examples/tutorial/act/act_using_example.py b/examples/tutorial/act/act_using_example.py index 60bc802d8..15254d8eb 100644 --- a/examples/tutorial/act/act_using_example.py +++ b/examples/tutorial/act/act_using_example.py @@ -1,7 +1,7 @@ import torch from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig -from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata from lerobot.policies.act.modeling_act import ACTPolicy from lerobot.policies.factory import make_pre_post_processors from lerobot.policies.utils import build_inference_frame, make_robot_action diff --git a/examples/tutorial/diffusion/diffusion_training_example.py b/examples/tutorial/diffusion/diffusion_training_example.py index 6db081450..dc6ca68a3 100644 --- a/examples/tutorial/diffusion/diffusion_training_example.py +++ b/examples/tutorial/diffusion/diffusion_training_example.py @@ -5,8 +5,9 @@ from pathlib import Path import torch from lerobot.configs.types import FeatureType -from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata -from lerobot.datasets.utils import dataset_to_policy_features +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.feature_utils import dataset_to_policy_features +from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.policies.factory import make_pre_post_processors diff --git a/examples/tutorial/diffusion/diffusion_using_example.py b/examples/tutorial/diffusion/diffusion_using_example.py index d8ac75cfe..9b31cf359 100644 --- a/examples/tutorial/diffusion/diffusion_using_example.py +++ b/examples/tutorial/diffusion/diffusion_using_example.py @@ -1,7 +1,7 @@ import torch from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig -from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.policies.factory import make_pre_post_processors from lerobot.policies.utils import build_inference_frame, make_robot_action diff --git a/examples/tutorial/pi0/using_pi0_example.py b/examples/tutorial/pi0/using_pi0_example.py index 056c3d81a..d8cf9dbff 100644 --- a/examples/tutorial/pi0/using_pi0_example.py +++ b/examples/tutorial/pi0/using_pi0_example.py @@ -1,7 +1,7 @@ import torch from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig -from lerobot.datasets.utils import hw_to_dataset_features +from lerobot.datasets.feature_utils import hw_to_dataset_features from lerobot.policies.factory import make_pre_post_processors from lerobot.policies.pi0.modeling_pi0 import PI0Policy from lerobot.policies.utils import build_inference_frame, make_robot_action diff --git a/examples/tutorial/rl/hilserl_example.py b/examples/tutorial/rl/hilserl_example.py index 980ac7985..d367a01ce 100644 --- a/examples/tutorial/rl/hilserl_example.py +++ b/examples/tutorial/rl/hilserl_example.py @@ -6,8 +6,8 @@ from queue import Empty, Full import torch import torch.optim as optim +from lerobot.datasets.feature_utils import hw_to_dataset_features from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.utils import hw_to_dataset_features from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.modeling_sac import SACPolicy diff --git a/examples/tutorial/smolvla/using_smolvla_example.py b/examples/tutorial/smolvla/using_smolvla_example.py index ce3aa7bca..b99126efa 100644 --- a/examples/tutorial/smolvla/using_smolvla_example.py +++ b/examples/tutorial/smolvla/using_smolvla_example.py @@ -1,7 +1,7 @@ import torch from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig -from lerobot.datasets.utils import hw_to_dataset_features +from lerobot.datasets.feature_utils import hw_to_dataset_features from lerobot.policies.factory import make_pre_post_processors from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy from lerobot.policies.utils import build_inference_frame, make_robot_action diff --git a/src/lerobot/async_inference/helpers.py b/src/lerobot/async_inference/helpers.py index 8b12920d9..9dd44eb44 100644 --- a/src/lerobot/async_inference/helpers.py +++ b/src/lerobot/async_inference/helpers.py @@ -23,7 +23,7 @@ from typing import Any import torch from lerobot.configs.types import PolicyFeature -from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features +from lerobot.datasets.feature_utils import build_dataset_frame, hw_to_dataset_features # NOTE: Configs need to be loaded for the client to be able to instantiate the policy config from lerobot.policies import ( # noqa: F401 diff --git a/src/lerobot/data_processing/sarm_annotations/subtask_annotation.py b/src/lerobot/data_processing/sarm_annotations/subtask_annotation.py index 67e37bab8..8f3a65e39 100644 --- a/src/lerobot/data_processing/sarm_annotations/subtask_annotation.py +++ b/src/lerobot/data_processing/sarm_annotations/subtask_annotation.py @@ -746,7 +746,8 @@ def save_annotations_to_dataset( dataset_path: Path, annotations: dict[int, SubtaskAnnotation], fps: int, prefix: str = "sparse" ): """Save annotations to LeRobot dataset parquet format.""" - from lerobot.datasets.utils import DEFAULT_EPISODES_PATH, load_episodes + from lerobot.datasets.io_utils import load_episodes + from lerobot.datasets.utils import DEFAULT_EPISODES_PATH episodes_dataset = load_episodes(dataset_path) if not episodes_dataset or len(episodes_dataset) == 0: @@ -840,7 +841,7 @@ def generate_auto_sparse_annotations( def load_annotations_from_dataset(dataset_path: Path, prefix: str = "sparse") -> dict[int, SubtaskAnnotation]: """Load annotations from LeRobot dataset parquet files.""" - from lerobot.datasets.utils import load_episodes + from lerobot.datasets.io_utils import load_episodes episodes_dataset = load_episodes(dataset_path) if not episodes_dataset or len(episodes_dataset) == 0: diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py index b32116233..66f055f04 100644 --- a/src/lerobot/datasets/aggregate.py +++ b/src/lerobot/datasets/aggregate.py @@ -24,7 +24,16 @@ import pandas as pd import tqdm from lerobot.datasets.compute_stats import aggregate_stats -from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.feature_utils import get_hf_features_from_features +from lerobot.datasets.io_utils import ( + get_file_size_in_mb, + get_parquet_file_size_in_mb, + to_parquet_with_hf_images, + write_info, + write_stats, + write_tasks, +) from lerobot.datasets.utils import ( DEFAULT_CHUNK_SIZE, DEFAULT_DATA_FILE_SIZE_IN_MB, @@ -32,14 +41,7 @@ from lerobot.datasets.utils import ( DEFAULT_EPISODES_PATH, DEFAULT_VIDEO_FILE_SIZE_IN_MB, DEFAULT_VIDEO_PATH, - get_file_size_in_mb, - get_hf_features_from_features, - get_parquet_file_size_in_mb, - to_parquet_with_hf_images, update_chunk_file_indices, - write_info, - write_stats, - write_tasks, ) from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s diff --git a/src/lerobot/datasets/backward_compatibility.py b/src/lerobot/datasets/backward_compatibility.py deleted file mode 100644 index aefbfd55b..000000000 --- a/src/lerobot/datasets/backward_compatibility.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import packaging.version - -V30_MESSAGE = """ -The dataset you requested ({repo_id}) is in {version} format. - -We introduced a new format since v3.0 which is not backward compatible with v2.1. -Please, update your dataset to the new format using this command: -``` -python -m lerobot.scripts.convert_dataset_v21_to_v30 --repo-id={repo_id} -``` - -If you already have a converted version uploaded to the hub, then this error might be because of -an older version in your local cache. Consider deleting the cached version and retrying. - -If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb) -or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose). -""" - -FUTURE_MESSAGE = """ -The dataset you requested ({repo_id}) is only available in {version} format. -As we cannot ensure forward compatibility with it, please update your current version of lerobot. -""" - - -class CompatibilityError(Exception): ... - - -class BackwardCompatibilityError(CompatibilityError): - def __init__(self, repo_id: str, version: packaging.version.Version): - if version.major == 2 and version.minor == 1: - message = V30_MESSAGE.format(repo_id=repo_id, version=version) - else: - raise NotImplementedError( - "Contact the maintainer on [Discord](https://discord.com/invite/s3KuuzsPFb)." - ) - super().__init__(message) - - -class ForwardCompatibilityError(CompatibilityError): - def __init__(self, repo_id: str, version: packaging.version.Version): - message = FUTURE_MESSAGE.format(repo_id=repo_id, version=version) - super().__init__(message) diff --git a/src/lerobot/datasets/compute_stats.py b/src/lerobot/datasets/compute_stats.py index 61e174d5c..5bd95810b 100644 --- a/src/lerobot/datasets/compute_stats.py +++ b/src/lerobot/datasets/compute_stats.py @@ -15,7 +15,7 @@ # limitations under the License. import numpy as np -from lerobot.datasets.utils import load_image_as_numpy +from lerobot.datasets.io_utils import load_image_as_numpy DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99] diff --git a/src/lerobot/datasets/dataset_metadata.py b/src/lerobot/datasets/dataset_metadata.py new file mode 100644 index 000000000..560a90a6e --- /dev/null +++ b/src/lerobot/datasets/dataset_metadata.py @@ -0,0 +1,517 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pathlib import Path + +import numpy as np +import packaging.version +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq +from huggingface_hub import snapshot_download + +from lerobot.datasets.compute_stats import aggregate_stats +from lerobot.datasets.feature_utils import _validate_feature_names, create_empty_dataset_info +from lerobot.datasets.io_utils import ( + get_file_size_in_mb, + load_episodes, + load_info, + load_stats, + load_subtasks, + load_tasks, + write_info, + write_json, + write_stats, + write_tasks, +) +from lerobot.datasets.utils import ( + DEFAULT_EPISODES_PATH, + DEFAULT_FEATURES, + INFO_PATH, + check_version_compatibility, + flatten_dict, + get_safe_version, + is_valid_version, + update_chunk_file_indices, +) +from lerobot.datasets.video_utils import get_video_info +from lerobot.utils.constants import HF_LEROBOT_HOME + +CODEBASE_VERSION = "v3.0" + + +class LeRobotDatasetMetadata: + def __init__( + self, + repo_id: str, + root: str | Path | None = None, + revision: str | None = None, + force_cache_sync: bool = False, + metadata_buffer_size: int = 10, + ): + self.repo_id = repo_id + self.revision = revision if revision else CODEBASE_VERSION + self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id + self.writer = None + self.latest_episode = None + self.metadata_buffer: list[dict] = [] + self.metadata_buffer_size = metadata_buffer_size + + try: + if force_cache_sync: + raise FileNotFoundError + self.load_metadata() + except (FileNotFoundError, NotADirectoryError): + if is_valid_version(self.revision): + self.revision = get_safe_version(self.repo_id, self.revision) + + (self.root / "meta").mkdir(exist_ok=True, parents=True) + self.pull_from_repo(allow_patterns="meta/") + self.load_metadata() + + def _flush_metadata_buffer(self) -> None: + """Write all buffered episode metadata to parquet file.""" + if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0: + return + + combined_dict = {} + for episode_dict in self.metadata_buffer: + for key, value in episode_dict.items(): + if key not in combined_dict: + combined_dict[key] = [] + # Extract value and serialize numpy arrays + # because PyArrow's from_pydict function doesn't support numpy arrays + val = value[0] if isinstance(value, list) else value + combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val) + + first_ep = self.metadata_buffer[0] + chunk_idx = first_ep["meta/episodes/chunk_index"][0] + file_idx = first_ep["meta/episodes/file_index"][0] + + table = pa.Table.from_pydict(combined_dict) + + if not self.writer: + path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)) + path.parent.mkdir(parents=True, exist_ok=True) + self.writer = pq.ParquetWriter( + path, schema=table.schema, compression="snappy", use_dictionary=True + ) + + self.writer.write_table(table) + + self.latest_episode = self.metadata_buffer[-1] + self.metadata_buffer.clear() + + def _close_writer(self) -> None: + """Close and cleanup the parquet writer if it exists.""" + self._flush_metadata_buffer() + + writer = getattr(self, "writer", None) + if writer is not None: + writer.close() + self.writer = None + + def __del__(self): + """ + Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor + """ + self._close_writer() + + def load_metadata(self): + self.info = load_info(self.root) + check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) + self.tasks = load_tasks(self.root) + self.subtasks = load_subtasks(self.root) + self.episodes = load_episodes(self.root) + self.stats = load_stats(self.root) + + def pull_from_repo( + self, + allow_patterns: list[str] | str | None = None, + ignore_patterns: list[str] | str | None = None, + ) -> None: + snapshot_download( + self.repo_id, + repo_type="dataset", + revision=self.revision, + local_dir=self.root, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + ) + + @property + def url_root(self) -> str: + return f"hf://datasets/{self.repo_id}" + + @property + def _version(self) -> packaging.version.Version: + """Codebase version used to create this dataset.""" + return packaging.version.parse(self.info["codebase_version"]) + + def get_data_file_path(self, ep_index: int) -> Path: + if self.episodes is None: + self.episodes = load_episodes(self.root) + if ep_index >= len(self.episodes): + raise IndexError( + f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}" + ) + ep = self.episodes[ep_index] + chunk_idx = ep["data/chunk_index"] + file_idx = ep["data/file_index"] + fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx) + return Path(fpath) + + def get_video_file_path(self, ep_index: int, vid_key: str) -> Path: + if self.episodes is None: + self.episodes = load_episodes(self.root) + if ep_index >= len(self.episodes): + raise IndexError( + f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}" + ) + ep = self.episodes[ep_index] + chunk_idx = ep[f"videos/{vid_key}/chunk_index"] + file_idx = ep[f"videos/{vid_key}/file_index"] + fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx) + return Path(fpath) + + @property + def data_path(self) -> str: + """Formattable string for the parquet files.""" + return self.info["data_path"] + + @property + def video_path(self) -> str | None: + """Formattable string for the video files.""" + return self.info["video_path"] + + @property + def robot_type(self) -> str | None: + """Robot type used in recording this dataset.""" + return self.info["robot_type"] + + @property + def fps(self) -> int: + """Frames per second used during data collection.""" + return self.info["fps"] + + @property + def features(self) -> dict[str, dict]: + """All features contained in the dataset.""" + return self.info["features"] + + @property + def image_keys(self) -> list[str]: + """Keys to access visual modalities stored as images.""" + return [key for key, ft in self.features.items() if ft["dtype"] == "image"] + + @property + def video_keys(self) -> list[str]: + """Keys to access visual modalities stored as videos.""" + return [key for key, ft in self.features.items() if ft["dtype"] == "video"] + + @property + def camera_keys(self) -> list[str]: + """Keys to access visual modalities (regardless of their storage method).""" + return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]] + + @property + def names(self) -> dict[str, list | dict]: + """Names of the various dimensions of vector modalities.""" + return {key: ft["names"] for key, ft in self.features.items()} + + @property + def shapes(self) -> dict: + """Shapes for the different features.""" + return {key: tuple(ft["shape"]) for key, ft in self.features.items()} + + @property + def total_episodes(self) -> int: + """Total number of episodes available.""" + return self.info["total_episodes"] + + @property + def total_frames(self) -> int: + """Total number of frames saved in this dataset.""" + return self.info["total_frames"] + + @property + def total_tasks(self) -> int: + """Total number of different tasks performed in this dataset.""" + return self.info["total_tasks"] + + @property + def chunks_size(self) -> int: + """Max number of files per chunk.""" + return self.info["chunks_size"] + + @property + def data_files_size_in_mb(self) -> int: + """Max size of data file in mega bytes.""" + return self.info["data_files_size_in_mb"] + + @property + def video_files_size_in_mb(self) -> int: + """Max size of video file in mega bytes.""" + return self.info["video_files_size_in_mb"] + + def get_task_index(self, task: str) -> int | None: + """ + Given a task in natural language, returns its task_index if the task already exists in the dataset, + otherwise return None. + """ + if task in self.tasks.index: + return int(self.tasks.loc[task].task_index) + else: + return None + + def save_episode_tasks(self, tasks: list[str]): + if len(set(tasks)) != len(tasks): + raise ValueError(f"Tasks are not unique: {tasks}") + + if self.tasks is None: + new_tasks = tasks + task_indices = range(len(tasks)) + self.tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(tasks, name="task")) + else: + new_tasks = [task for task in tasks if task not in self.tasks.index] + new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks)) + for task_idx, task in zip(new_task_indices, new_tasks, strict=False): + self.tasks.loc[task] = task_idx + + if len(new_tasks) > 0: + # Update on disk + write_tasks(self.tasks, self.root) + + def _save_episode_metadata(self, episode_dict: dict) -> None: + """Buffer episode metadata and write to parquet in batches for efficiency. + + This function accumulates episode metadata in a buffer and flushes it when the buffer + reaches the configured size. This reduces I/O overhead by writing multiple episodes + at once instead of one row at a time. + + Notes: We both need to update parquet files and HF dataset: + - `pandas` loads parquet file in RAM + - `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk, + or loads directly from pyarrow cache. + """ + # Convert to list format for each value + episode_dict = {key: [value] for key, value in episode_dict.items()} + num_frames = episode_dict["length"][0] + + if self.latest_episode is None: + # Initialize indices and frame count for a new dataset made of the first episode data + chunk_idx, file_idx = 0, 0 + if self.episodes is not None and len(self.episodes) > 0: + # It means we are resuming recording, so we need to load the latest episode + # Update the indices to avoid overwriting the latest episode + chunk_idx = self.episodes[-1]["meta/episodes/chunk_index"] + file_idx = self.episodes[-1]["meta/episodes/file_index"] + latest_num_frames = self.episodes[-1]["dataset_to_index"] + episode_dict["dataset_from_index"] = [latest_num_frames] + episode_dict["dataset_to_index"] = [latest_num_frames + num_frames] + + # When resuming, move to the next file + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size) + else: + episode_dict["dataset_from_index"] = [0] + episode_dict["dataset_to_index"] = [num_frames] + + episode_dict["meta/episodes/chunk_index"] = [chunk_idx] + episode_dict["meta/episodes/file_index"] = [file_idx] + else: + chunk_idx = self.latest_episode["meta/episodes/chunk_index"][0] + file_idx = self.latest_episode["meta/episodes/file_index"][0] + + latest_path = ( + self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + if self.writer is None + else self.writer.where + ) + + if Path(latest_path).exists(): + latest_size_in_mb = get_file_size_in_mb(Path(latest_path)) + latest_num_frames = self.latest_episode["episode_index"][0] + + av_size_per_frame = latest_size_in_mb / latest_num_frames if latest_num_frames > 0 else 0.0 + + if latest_size_in_mb + av_size_per_frame * num_frames >= self.data_files_size_in_mb: + # Size limit is reached, flush buffer and prepare new parquet file + self._flush_metadata_buffer() + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size) + self._close_writer() + + # Update the existing pandas dataframe with new row + episode_dict["meta/episodes/chunk_index"] = [chunk_idx] + episode_dict["meta/episodes/file_index"] = [file_idx] + episode_dict["dataset_from_index"] = [self.latest_episode["dataset_to_index"][0]] + episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames] + + # Add to buffer + self.metadata_buffer.append(episode_dict) + self.latest_episode = episode_dict + + if len(self.metadata_buffer) >= self.metadata_buffer_size: + self._flush_metadata_buffer() + + def save_episode( + self, + episode_index: int, + episode_length: int, + episode_tasks: list[str], + episode_stats: dict[str, dict], + episode_metadata: dict, + ) -> None: + episode_dict = { + "episode_index": episode_index, + "tasks": episode_tasks, + "length": episode_length, + } + episode_dict.update(episode_metadata) + episode_dict.update(flatten_dict({"stats": episode_stats})) + self._save_episode_metadata(episode_dict) + + # Update info + self.info["total_episodes"] += 1 + self.info["total_frames"] += episode_length + self.info["total_tasks"] = len(self.tasks) + self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"} + + write_info(self.info, self.root) + + self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats + write_stats(self.stats, self.root) + + def update_video_info(self, video_key: str | None = None) -> None: + """ + Warning: this function writes info from first episode videos, implicitly assuming that all videos have + been encoded the same way. Also, this means it assumes the first episode exists. + """ + if video_key is not None and video_key not in self.video_keys: + raise ValueError(f"Video key {video_key} not found in dataset") + + video_keys = [video_key] if video_key is not None else self.video_keys + for key in video_keys: + if not self.features[key].get("info", None): + video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0) + self.info["features"][key]["info"] = get_video_info(video_path) + + def update_chunk_settings( + self, + chunks_size: int | None = None, + data_files_size_in_mb: int | None = None, + video_files_size_in_mb: int | None = None, + ) -> None: + """Update chunk and file size settings after dataset creation. + + This allows users to customize storage organization without modifying the constructor. + These settings control how episodes are chunked and how large files can grow before + creating new ones. + + Args: + chunks_size: Maximum number of files per chunk directory. If None, keeps current value. + data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value. + video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value. + """ + if chunks_size is not None: + if chunks_size <= 0: + raise ValueError(f"chunks_size must be positive, got {chunks_size}") + self.info["chunks_size"] = chunks_size + + if data_files_size_in_mb is not None: + if data_files_size_in_mb <= 0: + raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}") + self.info["data_files_size_in_mb"] = data_files_size_in_mb + + if video_files_size_in_mb is not None: + if video_files_size_in_mb <= 0: + raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}") + self.info["video_files_size_in_mb"] = video_files_size_in_mb + + # Update the info file on disk + write_info(self.info, self.root) + + def get_chunk_settings(self) -> dict[str, int]: + """Get current chunk and file size settings. + + Returns: + Dict containing chunks_size, data_files_size_in_mb, and video_files_size_in_mb. + """ + return { + "chunks_size": self.chunks_size, + "data_files_size_in_mb": self.data_files_size_in_mb, + "video_files_size_in_mb": self.video_files_size_in_mb, + } + + def __repr__(self): + feature_keys = list(self.features) + return ( + f"{self.__class__.__name__}({{\n" + f" Repository ID: '{self.repo_id}',\n" + f" Total episodes: '{self.total_episodes}',\n" + f" Total frames: '{self.total_frames}',\n" + f" Features: '{feature_keys}',\n" + "})',\n" + ) + + @classmethod + def create( + cls, + repo_id: str, + fps: int, + features: dict, + robot_type: str | None = None, + root: str | Path | None = None, + use_videos: bool = True, + metadata_buffer_size: int = 10, + chunks_size: int | None = None, + data_files_size_in_mb: int | None = None, + video_files_size_in_mb: int | None = None, + ) -> "LeRobotDatasetMetadata": + """Creates metadata for a LeRobotDataset.""" + obj = cls.__new__(cls) + obj.repo_id = repo_id + obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id + + obj.root.mkdir(parents=True, exist_ok=False) + + features = {**features, **DEFAULT_FEATURES} + _validate_feature_names(features) + + obj.tasks = None + obj.subtasks = None + obj.episodes = None + obj.stats = None + obj.info = create_empty_dataset_info( + CODEBASE_VERSION, + fps, + features, + use_videos, + robot_type, + chunks_size, + data_files_size_in_mb, + video_files_size_in_mb, + ) + if len(obj.video_keys) > 0 and not use_videos: + raise ValueError( + f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. " + "Either remove video features from the features dict, or set 'use_videos=True'." + ) + write_json(obj.info, obj.root / INFO_PATH) + obj.revision = None + obj.writer = None + obj.latest_episode = None + obj.metadata_buffer = [] + obj.metadata_buffer_size = metadata_buffer_size + return obj diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index 546b3d67f..87cdc18e5 100644 --- a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -38,19 +38,22 @@ from tqdm import tqdm from lerobot.datasets.aggregate import aggregate_datasets from lerobot.datasets.compute_stats import aggregate_stats -from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.io_utils import ( + get_parquet_file_size_in_mb, + load_episodes, + write_info, + write_stats, + write_tasks, +) +from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import ( DATA_DIR, DEFAULT_CHUNK_SIZE, DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_DATA_PATH, DEFAULT_EPISODES_PATH, - get_parquet_file_size_in_mb, - load_episodes, update_chunk_file_indices, - write_info, - write_stats, - write_tasks, ) from lerobot.datasets.video_utils import encode_video_frames, get_video_info from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE @@ -915,7 +918,8 @@ def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) - This ensures images are properly embedded and the file can be loaded correctly by HF datasets. """ - from lerobot.datasets.utils import embed_images, get_hf_features_from_features + from lerobot.datasets.feature_utils import get_hf_features_from_features + from lerobot.datasets.io_utils import embed_images hf_features = get_hf_features_from_features(meta.features) ep_dataset = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=hf_features, split="train") diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py index 31e939809..76ece8961 100644 --- a/src/lerobot/datasets/factory.py +++ b/src/lerobot/datasets/factory.py @@ -20,11 +20,9 @@ import torch from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.train import TrainPipelineConfig -from lerobot.datasets.lerobot_dataset import ( - LeRobotDataset, - LeRobotDatasetMetadata, - MultiLeRobotDataset, -) +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.multi_dataset import MultiLeRobotDataset from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset from lerobot.datasets.transforms import ImageTransforms from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD diff --git a/src/lerobot/datasets/feature_utils.py b/src/lerobot/datasets/feature_utils.py new file mode 100644 index 000000000..d9a3c6301 --- /dev/null +++ b/src/lerobot/datasets/feature_utils.py @@ -0,0 +1,552 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pprint import pformat +from typing import Any + +import datasets +import numpy as np +from PIL import Image as PILImage + +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.datasets.utils import ( + DEFAULT_CHUNK_SIZE, + DEFAULT_DATA_FILE_SIZE_IN_MB, + DEFAULT_DATA_PATH, + DEFAULT_FEATURES, + DEFAULT_VIDEO_FILE_SIZE_IN_MB, + DEFAULT_VIDEO_PATH, +) +from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR +from lerobot.utils.utils import is_valid_numpy_dtype_string + + +def get_hf_features_from_features(features: dict) -> datasets.Features: + """Convert a LeRobot features dictionary to a `datasets.Features` object. + + Args: + features (dict): A LeRobot-style feature dictionary. + + Returns: + datasets.Features: The corresponding Hugging Face `datasets.Features` object. + + Raises: + ValueError: If a feature has an unsupported shape. + """ + hf_features = {} + for key, ft in features.items(): + if ft["dtype"] == "video": + continue + elif ft["dtype"] == "image": + hf_features[key] = datasets.Image() + elif ft["shape"] == (1,): + hf_features[key] = datasets.Value(dtype=ft["dtype"]) + elif len(ft["shape"]) == 1: + hf_features[key] = datasets.Sequence( + length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"]) + ) + elif len(ft["shape"]) == 2: + hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"]) + elif len(ft["shape"]) == 3: + hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"]) + elif len(ft["shape"]) == 4: + hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"]) + elif len(ft["shape"]) == 5: + hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"]) + else: + raise ValueError(f"Corresponding feature is not valid: {ft}") + + return datasets.Features(hf_features) + + +def _validate_feature_names(features: dict[str, dict]) -> None: + """Validate that feature names do not contain invalid characters. + + Args: + features (dict): The LeRobot features dictionary. + + Raises: + ValueError: If any feature name contains '/'. + """ + invalid_features = {name: ft for name, ft in features.items() if "/" in name} + if invalid_features: + raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.") + + +def hw_to_dataset_features( + hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True +) -> dict[str, dict]: + """Convert hardware-specific features to a LeRobot dataset feature dictionary. + + This function takes a dictionary describing hardware outputs (like joint states + or camera image shapes) and formats it into the standard LeRobot feature + specification. + + Args: + hw_features (dict): Dictionary mapping feature names to their type (float for + joints) or shape (tuple for images). + prefix (str): The prefix to add to the feature keys (e.g., "observation" + or "action"). + use_video (bool): If True, image features are marked as "video", otherwise "image". + + Returns: + dict: A LeRobot features dictionary. + """ + features = {} + joint_fts = { + key: ftype + for key, ftype in hw_features.items() + if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL) + } + cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)} + + if joint_fts and prefix == ACTION: + features[prefix] = { + "dtype": "float32", + "shape": (len(joint_fts),), + "names": list(joint_fts), + } + + if joint_fts and prefix == OBS_STR: + features[f"{prefix}.state"] = { + "dtype": "float32", + "shape": (len(joint_fts),), + "names": list(joint_fts), + } + + for key, shape in cam_fts.items(): + features[f"{prefix}.images.{key}"] = { + "dtype": "video" if use_video else "image", + "shape": shape, + "names": ["height", "width", "channels"], + } + + _validate_feature_names(features) + return features + + +def build_dataset_frame( + ds_features: dict[str, dict], values: dict[str, Any], prefix: str +) -> dict[str, np.ndarray]: + """Construct a single data frame from raw values based on dataset features. + + A "frame" is a dictionary containing all the data for a single timestep, + formatted as numpy arrays according to the feature specification. + + Args: + ds_features (dict): The LeRobot dataset features dictionary. + values (dict): A dictionary of raw values from the hardware/environment. + prefix (str): The prefix to filter features by (e.g., "observation" + or "action"). + + Returns: + dict: A dictionary representing a single frame of data. + """ + frame = {} + for key, ft in ds_features.items(): + if key in DEFAULT_FEATURES or not key.startswith(prefix): + continue + elif ft["dtype"] == "float32" and len(ft["shape"]) == 1: + frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32) + elif ft["dtype"] in ["image", "video"]: + frame[key] = values[key.removeprefix(f"{prefix}.images.")] + + return frame + + +def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]: + """Convert dataset features to policy features. + + This function transforms the dataset's feature specification into a format + that a policy can use, classifying features by type (e.g., visual, state, + action) and ensuring correct shapes (e.g., channel-first for images). + + Args: + features (dict): The LeRobot dataset features dictionary. + + Returns: + dict: A dictionary mapping feature keys to `PolicyFeature` objects. + + Raises: + ValueError: If an image feature does not have a 3D shape. + """ + # TODO(aliberts): Implement "type" in dataset features and simplify this + policy_features = {} + for key, ft in features.items(): + shape = ft["shape"] + if ft["dtype"] in ["image", "video"]: + type = FeatureType.VISUAL + if len(shape) != 3: + raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})") + + names = ft["names"] + # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets. + if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w) + shape = (shape[2], shape[0], shape[1]) + elif key == OBS_ENV_STATE: + type = FeatureType.ENV + elif key.startswith(OBS_STR): + type = FeatureType.STATE + elif key.startswith(ACTION): + type = FeatureType.ACTION + else: + continue + + policy_features[key] = PolicyFeature( + type=type, + shape=shape, + ) + + return policy_features + + +def combine_feature_dicts(*dicts: dict) -> dict: + """Merge LeRobot grouped feature dicts. + + - For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape. + - For others (e.g. `observation.images.*`), the last one wins (if they are identical). + + Args: + *dicts: A variable number of LeRobot feature dictionaries to merge. + + Returns: + dict: A single merged feature dictionary. + + Raises: + ValueError: If there's a dtype mismatch for a feature being merged. + """ + out: dict = {} + for d in dicts: + for key, value in d.items(): + if not isinstance(value, dict): + out[key] = value + continue + + dtype = value.get("dtype") + shape = value.get("shape") + is_vector = ( + dtype not in ("image", "video", "string") + and isinstance(shape, tuple) + and len(shape) == 1 + and "names" in value + ) + + if is_vector: + # Initialize or retrieve the accumulating dict for this feature key + target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)}) + # Ensure consistent data types across merged entries + if "dtype" in target and dtype != target["dtype"]: + raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}") + + # Merge feature names: append only new ones to preserve order without duplicates + seen = set(target["names"]) + for n in value["names"]: + if n not in seen: + target["names"].append(n) + seen.add(n) + # Recompute the shape to reflect the updated number of features + target["shape"] = (len(target["names"]),) + else: + # For images/videos and non-1D entries: override with the latest definition + out[key] = value + return out + + +def create_empty_dataset_info( + codebase_version: str, + fps: int, + features: dict, + use_videos: bool, + robot_type: str | None = None, + chunks_size: int | None = None, + data_files_size_in_mb: int | None = None, + video_files_size_in_mb: int | None = None, +) -> dict: + """Create a template dictionary for a new dataset's `info.json`. + + Args: + codebase_version (str): The version of the LeRobot codebase. + fps (int): The frames per second of the data. + features (dict): The LeRobot features dictionary for the dataset. + use_videos (bool): Whether the dataset will store videos. + robot_type (str | None): The type of robot used, if any. + + Returns: + dict: A dictionary with the initial dataset metadata. + """ + return { + "codebase_version": codebase_version, + "robot_type": robot_type, + "total_episodes": 0, + "total_frames": 0, + "total_tasks": 0, + "chunks_size": chunks_size or DEFAULT_CHUNK_SIZE, + "data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB, + "video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB, + "fps": fps, + "splits": {}, + "data_path": DEFAULT_DATA_PATH, + "video_path": DEFAULT_VIDEO_PATH if use_videos else None, + "features": features, + } + + +def check_delta_timestamps( + delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True +) -> bool: + """Check if delta timestamps are multiples of 1/fps +/- tolerance. + + This ensures that adding these delta timestamps to any existing timestamp in + the dataset will result in a value that aligns with the dataset's frame rate. + + Args: + delta_timestamps (dict): A dictionary where values are lists of time + deltas in seconds. + fps (int): The frames per second of the dataset. + tolerance_s (float): The allowed tolerance in seconds. + raise_value_error (bool): If True, raises an error on failure. + + Returns: + bool: True if all deltas are valid, False otherwise. + + Raises: + ValueError: If any delta is outside the tolerance and `raise_value_error` is True. + """ + outside_tolerance = {} + for key, delta_ts in delta_timestamps.items(): + within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts] + if not all(within_tolerance): + outside_tolerance[key] = [ + ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within + ] + + if len(outside_tolerance) > 0: + if raise_value_error: + raise ValueError( + f""" + The following delta_timestamps are found outside of tolerance range. + Please make sure they are multiples of 1/{fps} +/- tolerance and adjust + their values accordingly. + \n{pformat(outside_tolerance)} + """ + ) + return False + + return True + + +def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]: + """Convert delta timestamps in seconds to delta indices in frames. + + Args: + delta_timestamps (dict): A dictionary of time deltas in seconds. + fps (int): The frames per second of the dataset. + + Returns: + dict: A dictionary of frame delta indices. + """ + delta_indices = {} + for key, delta_ts in delta_timestamps.items(): + delta_indices[key] = [round(d * fps) for d in delta_ts] + + return delta_indices + + +def validate_frame(frame: dict, features: dict) -> None: + expected_features = set(features) - set(DEFAULT_FEATURES) + actual_features = set(frame) + + # task is a special required field that's not part of regular features + if "task" not in actual_features: + raise ValueError("Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n") + + # Remove task from actual_features for regular feature validation + actual_features_for_validation = actual_features - {"task"} + + error_message = validate_features_presence(actual_features_for_validation, expected_features) + + common_features = actual_features_for_validation & expected_features + for name in common_features: + error_message += validate_feature_dtype_and_shape(name, features[name], frame[name]) + + if error_message: + raise ValueError(error_message) + + +def validate_features_presence(actual_features: set[str], expected_features: set[str]) -> str: + """Check for missing or extra features in a frame. + + Args: + actual_features (set[str]): The set of feature names present in the frame. + expected_features (set[str]): The set of feature names expected in the frame. + + Returns: + str: An error message string if there's a mismatch, otherwise an empty string. + """ + error_message = "" + missing_features = expected_features - actual_features + extra_features = actual_features - expected_features + + if missing_features or extra_features: + error_message += "Feature mismatch in `frame` dictionary:\n" + if missing_features: + error_message += f"Missing features: {missing_features}\n" + if extra_features: + error_message += f"Extra features: {extra_features}\n" + + return error_message + + +def validate_feature_dtype_and_shape( + name: str, feature: dict, value: np.ndarray | PILImage.Image | str +) -> str: + """Validate the dtype and shape of a single feature's value. + + Args: + name (str): The name of the feature. + feature (dict): The feature specification from the LeRobot features dictionary. + value: The value of the feature to validate. + + Returns: + str: An error message if validation fails, otherwise an empty string. + + Raises: + NotImplementedError: If the feature dtype is not supported for validation. + """ + expected_dtype = feature["dtype"] + expected_shape = feature["shape"] + if is_valid_numpy_dtype_string(expected_dtype): + return validate_feature_numpy_array(name, expected_dtype, expected_shape, value) + elif expected_dtype in ["image", "video"]: + return validate_feature_image_or_video(name, expected_shape, value) + elif expected_dtype == "string": + return validate_feature_string(name, value) + else: + raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.") + + +def validate_feature_numpy_array( + name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray +) -> str: + """Validate a feature that is expected to be a numpy array. + + Args: + name (str): The name of the feature. + expected_dtype (str): The expected numpy dtype as a string. + expected_shape (list[int]): The expected shape. + value (np.ndarray): The numpy array to validate. + + Returns: + str: An error message if validation fails, otherwise an empty string. + """ + error_message = "" + if isinstance(value, np.ndarray): + actual_dtype = value.dtype + actual_shape = value.shape + + if actual_dtype != np.dtype(expected_dtype): + error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n" + + if actual_shape != expected_shape: + error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n" + else: + error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n" + + return error_message + + +def validate_feature_image_or_video( + name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image +) -> str: + """Validate a feature that is expected to be an image or video frame. + + Accepts `np.ndarray` (channel-first or channel-last) or `PIL.Image.Image`. + + Args: + name (str): The name of the feature. + expected_shape (list[str]): The expected shape (C, H, W). + value: The image data to validate. + + Returns: + str: An error message if validation fails, otherwise an empty string. + """ + # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads. + error_message = "" + if isinstance(value, np.ndarray): + actual_shape = value.shape + c, h, w = expected_shape + if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)): + error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n" + elif isinstance(value, PILImage.Image): + pass + else: + error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n" + + return error_message + + +def validate_feature_string(name: str, value: str) -> str: + """Validate a feature that is expected to be a string. + + Args: + name (str): The name of the feature. + value (str): The value to validate. + + Returns: + str: An error message if validation fails, otherwise an empty string. + """ + if not isinstance(value, str): + return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n" + return "" + + +def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None: + """Validate the episode buffer before it's written to disk. + + Ensures the buffer has the required keys, contains at least one frame, and + has features consistent with the dataset's specification. + + Args: + episode_buffer (dict): The buffer containing data for a single episode. + total_episodes (int): The current total number of episodes in the dataset. + features (dict): The LeRobot features dictionary for the dataset. + + Raises: + ValueError: If the buffer is invalid. + NotImplementedError: If the episode index is manually set and doesn't match. + """ + if "size" not in episode_buffer: + raise ValueError("size key not found in episode_buffer") + + if "task" not in episode_buffer: + raise ValueError("task key not found in episode_buffer") + + if episode_buffer["episode_index"] != total_episodes: + # TODO(aliberts): Add option to use existing episode_index + raise NotImplementedError( + "You might have manually provided the episode_buffer with an episode_index that doesn't " + "match the total number of episodes already in the dataset. This is not supported for now." + ) + + if episode_buffer["size"] == 0: + raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.") + + buffer_keys = set(episode_buffer.keys()) - {"task", "size"} + if not buffer_keys == set(features): + raise ValueError( + f"Features from `episode_buffer` don't match the ones in `features`." + f"In episode_buffer not in features: {buffer_keys - set(features)}" + f"In features not in episode_buffer: {set(features) - buffer_keys}" + ) diff --git a/src/lerobot/datasets/io_utils.py b/src/lerobot/datasets/io_utils.py new file mode 100644 index 000000000..cee6cfba8 --- /dev/null +++ b/src/lerobot/datasets/io_utils.py @@ -0,0 +1,342 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from pathlib import Path +from typing import Any + +import datasets +import numpy as np +import pandas +import pandas as pd +import pyarrow.dataset as pa_ds +import pyarrow.parquet as pq +import torch +from datasets import Dataset +from datasets.table import embed_table_storage +from PIL import Image as PILImage +from torchvision import transforms + +from lerobot.datasets.utils import ( + DEFAULT_DATA_FILE_SIZE_IN_MB, + DEFAULT_EPISODES_PATH, + DEFAULT_SUBTASKS_PATH, + DEFAULT_TASKS_PATH, + EPISODES_DIR, + INFO_PATH, + STATS_PATH, + flatten_dict, + serialize_dict, + unflatten_dict, +) +from lerobot.utils.utils import SuppressProgressBars + + +def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float: + metadata = pq.read_metadata(parquet_path) + total_uncompressed_size = 0 + for row_group in range(metadata.num_row_groups): + rg_metadata = metadata.row_group(row_group) + for column in range(rg_metadata.num_columns): + col_metadata = rg_metadata.column(column) + total_uncompressed_size += col_metadata.total_uncompressed_size + return total_uncompressed_size / (1024**2) + + +def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int: + return hf_ds.data.nbytes // (1024**2) + + +def load_nested_dataset( + pq_dir: Path, features: datasets.Features | None = None, episodes: list[int] | None = None +) -> Dataset: + """Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet + Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage + Concatenate all pyarrow references to return HF Dataset format + + Args: + pq_dir: Directory containing parquet files + features: Optional features schema to ensure consistent loading of complex types like images + episodes: Optional list of episode indices to filter. Uses PyArrow predicate pushdown for efficiency. + """ + paths = sorted(pq_dir.glob("*/*.parquet")) + if len(paths) == 0: + raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}") + + with SuppressProgressBars(): + # We use .from_parquet() memory-mapped loading for efficiency + filters = pa_ds.field("episode_index").isin(episodes) if episodes is not None else None + return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features) + + +def get_parquet_num_frames(parquet_path: str | Path) -> int: + metadata = pq.read_metadata(parquet_path) + return metadata.num_rows + + +def get_file_size_in_mb(file_path: Path) -> float: + """Get file size on disk in megabytes. + + Args: + file_path (Path): Path to the file. + """ + file_size_bytes = file_path.stat().st_size + return file_size_bytes / (1024**2) + + +def embed_images(dataset: datasets.Dataset) -> datasets.Dataset: + """Embed image bytes into the dataset table before saving to Parquet. + + This function prepares a Hugging Face dataset for serialization by converting + image objects into an embedded format that can be stored in Arrow/Parquet. + + Args: + dataset (datasets.Dataset): The input dataset, possibly containing image features. + + Returns: + datasets.Dataset: The dataset with images embedded in the table storage. + """ + # Embed image bytes into the table before saving to parquet + format = dataset.format + dataset = dataset.with_format("arrow") + dataset = dataset.map(embed_table_storage, batched=False) + dataset = dataset.with_format(**format) + return dataset + + +def load_json(fpath: Path) -> Any: + """Load data from a JSON file. + + Args: + fpath (Path): Path to the JSON file. + + Returns: + Any: The data loaded from the JSON file. + """ + with open(fpath) as f: + return json.load(f) + + +def write_json(data: dict, fpath: Path) -> None: + """Write data to a JSON file. + + Creates parent directories if they don't exist. + + Args: + data (dict): The dictionary to write. + fpath (Path): The path to the output JSON file. + """ + fpath.parent.mkdir(exist_ok=True, parents=True) + with open(fpath, "w") as f: + json.dump(data, f, indent=4, ensure_ascii=False) + + +def write_info(info: dict, local_dir: Path) -> None: + write_json(info, local_dir / INFO_PATH) + + +def load_info(local_dir: Path) -> dict: + """Load dataset info metadata from its standard file path. + + Also converts shape lists to tuples for consistency. + + Args: + local_dir (Path): The root directory of the dataset. + + Returns: + dict: The dataset information dictionary. + """ + info = load_json(local_dir / INFO_PATH) + for ft in info["features"].values(): + ft["shape"] = tuple(ft["shape"]) + return info + + +def write_stats(stats: dict, local_dir: Path) -> None: + """Serialize and write dataset statistics to their standard file path. + + Args: + stats (dict): The statistics dictionary (can contain tensors/numpy arrays). + local_dir (Path): The root directory of the dataset. + """ + serialized_stats = serialize_dict(stats) + write_json(serialized_stats, local_dir / STATS_PATH) + + +def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]: + """Recursively cast numerical values in a stats dictionary to numpy arrays. + + Args: + stats (dict): The statistics dictionary. + + Returns: + dict: The statistics dictionary with values cast to numpy arrays. + """ + stats = {key: np.array(value) for key, value in flatten_dict(stats).items()} + return unflatten_dict(stats) + + +def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]] | None: + """Load dataset statistics and cast numerical values to numpy arrays. + + Returns None if the stats file doesn't exist. + + Args: + local_dir (Path): The root directory of the dataset. + + Returns: + A dictionary of statistics or None if the file is not found. + """ + if not (local_dir / STATS_PATH).exists(): + return None + stats = load_json(local_dir / STATS_PATH) + return cast_stats_to_numpy(stats) + + +def write_tasks(tasks: pandas.DataFrame, local_dir: Path) -> None: + path = local_dir / DEFAULT_TASKS_PATH + path.parent.mkdir(parents=True, exist_ok=True) + tasks.to_parquet(path) + + +def load_tasks(local_dir: Path) -> pandas.DataFrame: + tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH) + tasks.index.name = "task" + return tasks + + +def load_subtasks(local_dir: Path) -> pandas.DataFrame | None: + """Load subtasks from subtasks.parquet if it exists.""" + subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH + if subtasks_path.exists(): + return pd.read_parquet(subtasks_path) + return None + + +def write_episodes(episodes: Dataset, local_dir: Path) -> None: + """Write episode metadata to a parquet file in the LeRobot v3.0 format. + This function writes episode-level metadata to a single parquet file. + Used primarily during dataset conversion (v2.1 → v3.0) and in test fixtures. + + Args: + episodes: HuggingFace Dataset containing episode metadata + local_dir: Root directory where the dataset will be stored + """ + episode_size_mb = get_hf_dataset_size_in_mb(episodes) + if episode_size_mb > DEFAULT_DATA_FILE_SIZE_IN_MB: + raise NotImplementedError( + f"Episodes dataset is too large ({episode_size_mb} MB) to write to a single file. " + f"The current limit is {DEFAULT_DATA_FILE_SIZE_IN_MB} MB. " + "This function only supports single-file episode metadata. " + ) + + fpath = local_dir / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0) + fpath.parent.mkdir(parents=True, exist_ok=True) + episodes.to_parquet(fpath) + + +def load_episodes(local_dir: Path) -> datasets.Dataset: + episodes = load_nested_dataset(local_dir / EPISODES_DIR) + # Select episode features/columns containing references to episode data and videos + # (e.g. tasks, dataset_from_index, dataset_to_index, data/chunk_index, data/file_index, etc.) + # This is to speedup access to these data, instead of having to load episode stats. + episodes = episodes.select_columns([key for key in episodes.features if not key.startswith("stats/")]) + return episodes + + +def load_image_as_numpy( + fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True +) -> np.ndarray: + """Load an image from a file into a numpy array. + + Args: + fpath (str | Path): Path to the image file. + dtype (np.dtype): The desired data type of the output array. If floating, + pixels are scaled to [0, 1]. + channel_first (bool): If True, converts the image to (C, H, W) format. + Otherwise, it remains in (H, W, C) format. + + Returns: + np.ndarray: The image as a numpy array. + """ + img = PILImage.open(fpath).convert("RGB") + img_array = np.array(img, dtype=dtype) + if channel_first: # (H, W, C) -> (C, H, W) + img_array = np.transpose(img_array, (2, 0, 1)) + if np.issubdtype(dtype, np.floating): + img_array /= 255.0 + return img_array + + +def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]: + """Convert a batch from a Hugging Face dataset to torch tensors. + + This transform function converts items from Hugging Face dataset format (pyarrow) + to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8) + to a torch image representation (C, H, W, float32) in the range [0, 1]. Other + types are converted to torch.tensor. + + Args: + items_dict (dict): A dictionary representing a batch of data from a + Hugging Face dataset. + + Returns: + dict: The batch with items converted to torch tensors. + """ + for key in items_dict: + first_item = items_dict[key][0] + if isinstance(first_item, PILImage.Image): + to_tensor = transforms.ToTensor() + items_dict[key] = [to_tensor(img) for img in items_dict[key]] + elif first_item is None: + pass + else: + items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]] + return items_dict + + +def to_parquet_with_hf_images( + df: pandas.DataFrame, path: Path, features: datasets.Features | None = None +) -> None: + """This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset. + This way, it can be loaded by HF dataset and correctly formatted images are returned. + + Args: + df: DataFrame to write to parquet. + path: Path to write the parquet file. + features: Optional HuggingFace Features schema. If provided, ensures image columns + are properly typed as Image() in the parquet schema. + """ + # TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only + ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features) + ds.to_parquet(path) + + +def item_to_torch(item: dict) -> dict: + """Convert all items in a dictionary to PyTorch tensors where appropriate. + + This function is used to convert an item from a streaming dataset to PyTorch tensors. + + Args: + item (dict): Dictionary of items from a dataset. + + Returns: + dict: Dictionary with all tensor-like items converted to torch.Tensor. + """ + for key, val in item.items(): + if isinstance(val, (np.ndarray | list)) and key not in ["task"]: + # Convert numpy arrays and lists to torch tensors + item[key] = torch.tensor(val) + return item diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 5d1b5d042..8f0600ba8 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -23,532 +23,53 @@ from pathlib import Path import datasets import numpy as np -import packaging.version import pandas as pd import PIL.Image -import pyarrow as pa import pyarrow.parquet as pq import torch import torch.utils from huggingface_hub import HfApi, snapshot_download from huggingface_hub.errors import RevisionNotFoundError -from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats -from lerobot.datasets.image_writer import AsyncImageWriter, write_image -from lerobot.datasets.utils import ( - DEFAULT_EPISODES_PATH, - DEFAULT_FEATURES, - DEFAULT_IMAGE_PATH, - INFO_PATH, - _validate_feature_names, +from lerobot.datasets.compute_stats import compute_episode_stats +from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata +from lerobot.datasets.feature_utils import ( check_delta_timestamps, - check_version_compatibility, - create_empty_dataset_info, - create_lerobot_dataset_card, - embed_images, - flatten_dict, get_delta_indices, - get_file_size_in_mb, get_hf_features_from_features, - get_safe_version, - hf_transform_to_torch, - is_valid_version, - load_episodes, - load_info, - load_nested_dataset, - load_stats, - load_subtasks, - load_tasks, - update_chunk_file_indices, validate_episode_buffer, validate_frame, +) +from lerobot.datasets.image_writer import AsyncImageWriter, write_image +from lerobot.datasets.io_utils import ( + embed_images, + get_file_size_in_mb, + hf_transform_to_torch, + load_episodes, + load_nested_dataset, write_info, - write_json, - write_stats, - write_tasks, +) +from lerobot.datasets.utils import ( + DEFAULT_EPISODES_PATH, + DEFAULT_IMAGE_PATH, + create_lerobot_dataset_card, + get_safe_version, + is_valid_version, + update_chunk_file_indices, ) from lerobot.datasets.video_utils import ( StreamingVideoEncoder, - VideoFrame, concatenate_video_files, decode_video_frames, encode_video_frames, get_safe_default_codec, get_video_duration_in_s, - get_video_info, resolve_vcodec, ) from lerobot.utils.constants import HF_LEROBOT_HOME logger = logging.getLogger(__name__) -CODEBASE_VERSION = "v3.0" - - -class LeRobotDatasetMetadata: - def __init__( - self, - repo_id: str, - root: str | Path | None = None, - revision: str | None = None, - force_cache_sync: bool = False, - metadata_buffer_size: int = 10, - ): - self.repo_id = repo_id - self.revision = revision if revision else CODEBASE_VERSION - self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id - self.writer = None - self.latest_episode = None - self.metadata_buffer: list[dict] = [] - self.metadata_buffer_size = metadata_buffer_size - - try: - if force_cache_sync: - raise FileNotFoundError - self.load_metadata() - except (FileNotFoundError, NotADirectoryError): - if is_valid_version(self.revision): - self.revision = get_safe_version(self.repo_id, self.revision) - - (self.root / "meta").mkdir(exist_ok=True, parents=True) - self.pull_from_repo(allow_patterns="meta/") - self.load_metadata() - - def _flush_metadata_buffer(self) -> None: - """Write all buffered episode metadata to parquet file.""" - if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0: - return - - combined_dict = {} - for episode_dict in self.metadata_buffer: - for key, value in episode_dict.items(): - if key not in combined_dict: - combined_dict[key] = [] - # Extract value and serialize numpy arrays - # because PyArrow's from_pydict function doesn't support numpy arrays - val = value[0] if isinstance(value, list) else value - combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val) - - first_ep = self.metadata_buffer[0] - chunk_idx = first_ep["meta/episodes/chunk_index"][0] - file_idx = first_ep["meta/episodes/file_index"][0] - - table = pa.Table.from_pydict(combined_dict) - - if not self.writer: - path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)) - path.parent.mkdir(parents=True, exist_ok=True) - self.writer = pq.ParquetWriter( - path, schema=table.schema, compression="snappy", use_dictionary=True - ) - - self.writer.write_table(table) - - self.latest_episode = self.metadata_buffer[-1] - self.metadata_buffer.clear() - - def _close_writer(self) -> None: - """Close and cleanup the parquet writer if it exists.""" - self._flush_metadata_buffer() - - writer = getattr(self, "writer", None) - if writer is not None: - writer.close() - self.writer = None - - def __del__(self): - """ - Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor - """ - self._close_writer() - - def load_metadata(self): - self.info = load_info(self.root) - check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) - self.tasks = load_tasks(self.root) - self.subtasks = load_subtasks(self.root) - self.episodes = load_episodes(self.root) - self.stats = load_stats(self.root) - - def pull_from_repo( - self, - allow_patterns: list[str] | str | None = None, - ignore_patterns: list[str] | str | None = None, - ) -> None: - snapshot_download( - self.repo_id, - repo_type="dataset", - revision=self.revision, - local_dir=self.root, - allow_patterns=allow_patterns, - ignore_patterns=ignore_patterns, - ) - - @property - def url_root(self) -> str: - return f"hf://datasets/{self.repo_id}" - - @property - def _version(self) -> packaging.version.Version: - """Codebase version used to create this dataset.""" - return packaging.version.parse(self.info["codebase_version"]) - - def get_data_file_path(self, ep_index: int) -> Path: - if self.episodes is None: - self.episodes = load_episodes(self.root) - if ep_index >= len(self.episodes): - raise IndexError( - f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}" - ) - ep = self.episodes[ep_index] - chunk_idx = ep["data/chunk_index"] - file_idx = ep["data/file_index"] - fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx) - return Path(fpath) - - def get_video_file_path(self, ep_index: int, vid_key: str) -> Path: - if self.episodes is None: - self.episodes = load_episodes(self.root) - if ep_index >= len(self.episodes): - raise IndexError( - f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}" - ) - ep = self.episodes[ep_index] - chunk_idx = ep[f"videos/{vid_key}/chunk_index"] - file_idx = ep[f"videos/{vid_key}/file_index"] - fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx) - return Path(fpath) - - @property - def data_path(self) -> str: - """Formattable string for the parquet files.""" - return self.info["data_path"] - - @property - def video_path(self) -> str | None: - """Formattable string for the video files.""" - return self.info["video_path"] - - @property - def robot_type(self) -> str | None: - """Robot type used in recording this dataset.""" - return self.info["robot_type"] - - @property - def fps(self) -> int: - """Frames per second used during data collection.""" - return self.info["fps"] - - @property - def features(self) -> dict[str, dict]: - """All features contained in the dataset.""" - return self.info["features"] - - @property - def image_keys(self) -> list[str]: - """Keys to access visual modalities stored as images.""" - return [key for key, ft in self.features.items() if ft["dtype"] == "image"] - - @property - def video_keys(self) -> list[str]: - """Keys to access visual modalities stored as videos.""" - return [key for key, ft in self.features.items() if ft["dtype"] == "video"] - - @property - def camera_keys(self) -> list[str]: - """Keys to access visual modalities (regardless of their storage method).""" - return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]] - - @property - def names(self) -> dict[str, list | dict]: - """Names of the various dimensions of vector modalities.""" - return {key: ft["names"] for key, ft in self.features.items()} - - @property - def shapes(self) -> dict: - """Shapes for the different features.""" - return {key: tuple(ft["shape"]) for key, ft in self.features.items()} - - @property - def total_episodes(self) -> int: - """Total number of episodes available.""" - return self.info["total_episodes"] - - @property - def total_frames(self) -> int: - """Total number of frames saved in this dataset.""" - return self.info["total_frames"] - - @property - def total_tasks(self) -> int: - """Total number of different tasks performed in this dataset.""" - return self.info["total_tasks"] - - @property - def chunks_size(self) -> int: - """Max number of files per chunk.""" - return self.info["chunks_size"] - - @property - def data_files_size_in_mb(self) -> int: - """Max size of data file in mega bytes.""" - return self.info["data_files_size_in_mb"] - - @property - def video_files_size_in_mb(self) -> int: - """Max size of video file in mega bytes.""" - return self.info["video_files_size_in_mb"] - - def get_task_index(self, task: str) -> int | None: - """ - Given a task in natural language, returns its task_index if the task already exists in the dataset, - otherwise return None. - """ - if task in self.tasks.index: - return int(self.tasks.loc[task].task_index) - else: - return None - - def save_episode_tasks(self, tasks: list[str]): - if len(set(tasks)) != len(tasks): - raise ValueError(f"Tasks are not unique: {tasks}") - - if self.tasks is None: - new_tasks = tasks - task_indices = range(len(tasks)) - self.tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(tasks, name="task")) - else: - new_tasks = [task for task in tasks if task not in self.tasks.index] - new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks)) - for task_idx, task in zip(new_task_indices, new_tasks, strict=False): - self.tasks.loc[task] = task_idx - - if len(new_tasks) > 0: - # Update on disk - write_tasks(self.tasks, self.root) - - def _save_episode_metadata(self, episode_dict: dict) -> None: - """Buffer episode metadata and write to parquet in batches for efficiency. - - This function accumulates episode metadata in a buffer and flushes it when the buffer - reaches the configured size. This reduces I/O overhead by writing multiple episodes - at once instead of one row at a time. - - Notes: We both need to update parquet files and HF dataset: - - `pandas` loads parquet file in RAM - - `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk, - or loads directly from pyarrow cache. - """ - # Convert to list format for each value - episode_dict = {key: [value] for key, value in episode_dict.items()} - num_frames = episode_dict["length"][0] - - if self.latest_episode is None: - # Initialize indices and frame count for a new dataset made of the first episode data - chunk_idx, file_idx = 0, 0 - if self.episodes is not None and len(self.episodes) > 0: - # It means we are resuming recording, so we need to load the latest episode - # Update the indices to avoid overwriting the latest episode - chunk_idx = self.episodes[-1]["meta/episodes/chunk_index"] - file_idx = self.episodes[-1]["meta/episodes/file_index"] - latest_num_frames = self.episodes[-1]["dataset_to_index"] - episode_dict["dataset_from_index"] = [latest_num_frames] - episode_dict["dataset_to_index"] = [latest_num_frames + num_frames] - - # When resuming, move to the next file - chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size) - else: - episode_dict["dataset_from_index"] = [0] - episode_dict["dataset_to_index"] = [num_frames] - - episode_dict["meta/episodes/chunk_index"] = [chunk_idx] - episode_dict["meta/episodes/file_index"] = [file_idx] - else: - chunk_idx = self.latest_episode["meta/episodes/chunk_index"][0] - file_idx = self.latest_episode["meta/episodes/file_index"][0] - - latest_path = ( - self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) - if self.writer is None - else self.writer.where - ) - - if Path(latest_path).exists(): - latest_size_in_mb = get_file_size_in_mb(Path(latest_path)) - latest_num_frames = self.latest_episode["episode_index"][0] - - av_size_per_frame = latest_size_in_mb / latest_num_frames if latest_num_frames > 0 else 0.0 - - if latest_size_in_mb + av_size_per_frame * num_frames >= self.data_files_size_in_mb: - # Size limit is reached, flush buffer and prepare new parquet file - self._flush_metadata_buffer() - chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size) - self._close_writer() - - # Update the existing pandas dataframe with new row - episode_dict["meta/episodes/chunk_index"] = [chunk_idx] - episode_dict["meta/episodes/file_index"] = [file_idx] - episode_dict["dataset_from_index"] = [self.latest_episode["dataset_to_index"][0]] - episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames] - - # Add to buffer - self.metadata_buffer.append(episode_dict) - self.latest_episode = episode_dict - - if len(self.metadata_buffer) >= self.metadata_buffer_size: - self._flush_metadata_buffer() - - def save_episode( - self, - episode_index: int, - episode_length: int, - episode_tasks: list[str], - episode_stats: dict[str, dict], - episode_metadata: dict, - ) -> None: - episode_dict = { - "episode_index": episode_index, - "tasks": episode_tasks, - "length": episode_length, - } - episode_dict.update(episode_metadata) - episode_dict.update(flatten_dict({"stats": episode_stats})) - self._save_episode_metadata(episode_dict) - - # Update info - self.info["total_episodes"] += 1 - self.info["total_frames"] += episode_length - self.info["total_tasks"] = len(self.tasks) - self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"} - - write_info(self.info, self.root) - - self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats - write_stats(self.stats, self.root) - - def update_video_info(self, video_key: str | None = None) -> None: - """ - Warning: this function writes info from first episode videos, implicitly assuming that all videos have - been encoded the same way. Also, this means it assumes the first episode exists. - """ - if video_key is not None and video_key not in self.video_keys: - raise ValueError(f"Video key {video_key} not found in dataset") - - video_keys = [video_key] if video_key is not None else self.video_keys - for key in video_keys: - if not self.features[key].get("info", None): - video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0) - self.info["features"][key]["info"] = get_video_info(video_path) - - def update_chunk_settings( - self, - chunks_size: int | None = None, - data_files_size_in_mb: int | None = None, - video_files_size_in_mb: int | None = None, - ) -> None: - """Update chunk and file size settings after dataset creation. - - This allows users to customize storage organization without modifying the constructor. - These settings control how episodes are chunked and how large files can grow before - creating new ones. - - Args: - chunks_size: Maximum number of files per chunk directory. If None, keeps current value. - data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value. - video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value. - """ - if chunks_size is not None: - if chunks_size <= 0: - raise ValueError(f"chunks_size must be positive, got {chunks_size}") - self.info["chunks_size"] = chunks_size - - if data_files_size_in_mb is not None: - if data_files_size_in_mb <= 0: - raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}") - self.info["data_files_size_in_mb"] = data_files_size_in_mb - - if video_files_size_in_mb is not None: - if video_files_size_in_mb <= 0: - raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}") - self.info["video_files_size_in_mb"] = video_files_size_in_mb - - # Update the info file on disk - write_info(self.info, self.root) - - def get_chunk_settings(self) -> dict[str, int]: - """Get current chunk and file size settings. - - Returns: - Dict containing chunks_size, data_files_size_in_mb, and video_files_size_in_mb. - """ - return { - "chunks_size": self.chunks_size, - "data_files_size_in_mb": self.data_files_size_in_mb, - "video_files_size_in_mb": self.video_files_size_in_mb, - } - - def __repr__(self): - feature_keys = list(self.features) - return ( - f"{self.__class__.__name__}({{\n" - f" Repository ID: '{self.repo_id}',\n" - f" Total episodes: '{self.total_episodes}',\n" - f" Total frames: '{self.total_frames}',\n" - f" Features: '{feature_keys}',\n" - "})',\n" - ) - - @classmethod - def create( - cls, - repo_id: str, - fps: int, - features: dict, - robot_type: str | None = None, - root: str | Path | None = None, - use_videos: bool = True, - metadata_buffer_size: int = 10, - chunks_size: int | None = None, - data_files_size_in_mb: int | None = None, - video_files_size_in_mb: int | None = None, - ) -> "LeRobotDatasetMetadata": - """Creates metadata for a LeRobotDataset.""" - obj = cls.__new__(cls) - obj.repo_id = repo_id - obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id - - obj.root.mkdir(parents=True, exist_ok=False) - - features = {**features, **DEFAULT_FEATURES} - _validate_feature_names(features) - - obj.tasks = None - obj.subtasks = None - obj.episodes = None - obj.stats = None - obj.info = create_empty_dataset_info( - CODEBASE_VERSION, - fps, - features, - use_videos, - robot_type, - chunks_size, - data_files_size_in_mb, - video_files_size_in_mb, - ) - if len(obj.video_keys) > 0 and not use_videos: - raise ValueError( - f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. " - "Either remove video features from the features dict, or set 'use_videos=True'." - ) - write_json(obj.info, obj.root / INFO_PATH) - obj.revision = None - obj.writer = None - obj.latest_episode = None - obj.metadata_buffer = [] - obj.metadata_buffer_size = metadata_buffer_size - return obj - def _encode_video_worker( video_key: str, @@ -1721,184 +1242,3 @@ class LeRobotDataset(torch.utils.data.Dataset): obj._streaming_encoder = None return obj - - -class MultiLeRobotDataset(torch.utils.data.Dataset): - """A dataset consisting of multiple underlying `LeRobotDataset`s. - - The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API - structure of `LeRobotDataset`. - """ - - def __init__( - self, - repo_ids: list[str], - root: str | Path | None = None, - episodes: dict | None = None, - image_transforms: Callable | None = None, - delta_timestamps: dict[str, list[float]] | None = None, - tolerances_s: dict | None = None, - download_videos: bool = True, - video_backend: str | None = None, - ): - super().__init__() - self.repo_ids = repo_ids - self.root = Path(root) if root else HF_LEROBOT_HOME - self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001) - # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which - # are handled by this class. - self._datasets = [ - LeRobotDataset( - repo_id, - root=self.root / repo_id, - episodes=episodes[repo_id] if episodes else None, - image_transforms=image_transforms, - delta_timestamps=delta_timestamps, - tolerance_s=self.tolerances_s[repo_id], - download_videos=download_videos, - video_backend=video_backend, - ) - for repo_id in repo_ids - ] - - # Disable any data keys that are not common across all of the datasets. Note: we may relax this - # restriction in future iterations of this class. For now, this is necessary at least for being able - # to use PyTorch's default DataLoader collate function. - self.disabled_features = set() - intersection_features = set(self._datasets[0].features) - for ds in self._datasets: - intersection_features.intersection_update(ds.features) - if len(intersection_features) == 0: - raise RuntimeError( - "Multiple datasets were provided but they had no keys common to all of them. " - "The multi-dataset functionality currently only keeps common keys." - ) - for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True): - extra_keys = set(ds.features).difference(intersection_features) - if extra_keys: - logger.warning( - f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the " - "other datasets." - ) - self.disabled_features.update(extra_keys) - - self.image_transforms = image_transforms - self.delta_timestamps = delta_timestamps - # TODO(rcadene, aliberts): We should not perform this aggregation for datasets - # with multiple robots of different ranges. Instead we should have one normalization - # per robot. - self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets]) - - @property - def repo_id_to_index(self): - """Return a mapping from dataset repo_id to a dataset index automatically created by this class. - - This index is incorporated as a data key in the dictionary returned by `__getitem__`. - """ - return {repo_id: i for i, repo_id in enumerate(self.repo_ids)} - - @property - def fps(self) -> int: - """Frames per second used during data collection. - - NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info. - """ - return self._datasets[0].meta.info["fps"] - - @property - def video(self) -> bool: - """Returns True if this dataset loads video frames from mp4 files. - - Returns False if it only loads images from png files. - - NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info. - """ - return self._datasets[0].meta.info.get("video", False) - - @property - def features(self) -> datasets.Features: - features = {} - for dataset in self._datasets: - features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features}) - return features - - @property - def camera_keys(self) -> list[str]: - """Keys to access image and video stream from cameras.""" - keys = [] - for key, feats in self.features.items(): - if isinstance(feats, (datasets.Image | VideoFrame)): - keys.append(key) - return keys - - @property - def video_frame_keys(self) -> list[str]: - """Keys to access video frames that requires to be decoded into images. - - Note: It is empty if the dataset contains images only, - or equal to `self.cameras` if the dataset contains videos only, - or can even be a subset of `self.cameras` in a case of a mixed image/video dataset. - """ - video_frame_keys = [] - for key, feats in self.features.items(): - if isinstance(feats, VideoFrame): - video_frame_keys.append(key) - return video_frame_keys - - @property - def num_frames(self) -> int: - """Number of samples/frames.""" - return sum(d.num_frames for d in self._datasets) - - @property - def num_episodes(self) -> int: - """Number of episodes.""" - return sum(d.num_episodes for d in self._datasets) - - @property - def tolerance_s(self) -> float: - """Tolerance in seconds used to discard loaded frames when their timestamps - are not close enough from the requested frames. It is only used when `delta_timestamps` - is provided or when loading video frames from mp4 files. - """ - # 1e-4 to account for possible numerical error - return 1 / self.fps - 1e-4 - - def __len__(self): - return self.num_frames - - def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: - if idx >= len(self): - raise IndexError(f"Index {idx} out of bounds.") - # Determine which dataset to get an item from based on the index. - start_idx = 0 - dataset_idx = 0 - for dataset in self._datasets: - if idx >= start_idx + dataset.num_frames: - start_idx += dataset.num_frames - dataset_idx += 1 - continue - break - else: - raise AssertionError("We expect the loop to break out as long as the index is within bounds.") - item = self._datasets[dataset_idx][idx - start_idx] - item["dataset_index"] = torch.tensor(dataset_idx) - for data_key in self.disabled_features: - if data_key in item: - del item[data_key] - - return item - - def __repr__(self): - return ( - f"{self.__class__.__name__}(\n" - f" Repository IDs: '{self.repo_ids}',\n" - f" Number of Samples: {self.num_frames},\n" - f" Number of Episodes: {self.num_episodes},\n" - f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n" - f" Recorded Frames per Second: {self.fps},\n" - f" Camera Keys: {self.camera_keys},\n" - f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n" - f" Transformations: {self.image_transforms},\n" - f")" - ) diff --git a/src/lerobot/datasets/multi_dataset.py b/src/lerobot/datasets/multi_dataset.py new file mode 100644 index 000000000..917d5c5eb --- /dev/null +++ b/src/lerobot/datasets/multi_dataset.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from collections.abc import Callable +from pathlib import Path + +import datasets +import torch +import torch.utils + +from lerobot.datasets.compute_stats import aggregate_stats +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.video_utils import VideoFrame +from lerobot.utils.constants import HF_LEROBOT_HOME + +logger = logging.getLogger(__name__) + + +class MultiLeRobotDataset(torch.utils.data.Dataset): + """A dataset consisting of multiple underlying `LeRobotDataset`s. + + The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API + structure of `LeRobotDataset`. + """ + + def __init__( + self, + repo_ids: list[str], + root: str | Path | None = None, + episodes: dict | None = None, + image_transforms: Callable | None = None, + delta_timestamps: dict[str, list[float]] | None = None, + tolerances_s: dict | None = None, + download_videos: bool = True, + video_backend: str | None = None, + ): + super().__init__() + self.repo_ids = repo_ids + self.root = Path(root) if root else HF_LEROBOT_HOME + self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001) + # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which + # are handled by this class. + self._datasets = [ + LeRobotDataset( + repo_id, + root=self.root / repo_id, + episodes=episodes[repo_id] if episodes else None, + image_transforms=image_transforms, + delta_timestamps=delta_timestamps, + tolerance_s=self.tolerances_s[repo_id], + download_videos=download_videos, + video_backend=video_backend, + ) + for repo_id in repo_ids + ] + + # Disable any data keys that are not common across all of the datasets. Note: we may relax this + # restriction in future iterations of this class. For now, this is necessary at least for being able + # to use PyTorch's default DataLoader collate function. + self.disabled_features = set() + intersection_features = set(self._datasets[0].features) + for ds in self._datasets: + intersection_features.intersection_update(ds.features) + if len(intersection_features) == 0: + raise RuntimeError( + "Multiple datasets were provided but they had no keys common to all of them. " + "The multi-dataset functionality currently only keeps common keys." + ) + for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True): + extra_keys = set(ds.features).difference(intersection_features) + if extra_keys: + logger.warning( + f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the " + "other datasets." + ) + self.disabled_features.update(extra_keys) + + self.image_transforms = image_transforms + self.delta_timestamps = delta_timestamps + # TODO(rcadene, aliberts): We should not perform this aggregation for datasets + # with multiple robots of different ranges. Instead we should have one normalization + # per robot. + self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets]) + + @property + def repo_id_to_index(self): + """Return a mapping from dataset repo_id to a dataset index automatically created by this class. + + This index is incorporated as a data key in the dictionary returned by `__getitem__`. + """ + return {repo_id: i for i, repo_id in enumerate(self.repo_ids)} + + @property + def fps(self) -> int: + """Frames per second used during data collection. + + NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info. + """ + return self._datasets[0].meta.info["fps"] + + @property + def video(self) -> bool: + """Returns True if this dataset loads video frames from mp4 files. + + Returns False if it only loads images from png files. + + NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info. + """ + return self._datasets[0].meta.info.get("video", False) + + @property + def features(self) -> datasets.Features: + features = {} + for dataset in self._datasets: + features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features}) + return features + + @property + def camera_keys(self) -> list[str]: + """Keys to access image and video stream from cameras.""" + keys = [] + for key, feats in self.features.items(): + if isinstance(feats, (datasets.Image | VideoFrame)): + keys.append(key) + return keys + + @property + def video_frame_keys(self) -> list[str]: + """Keys to access video frames that requires to be decoded into images. + + Note: It is empty if the dataset contains images only, + or equal to `self.cameras` if the dataset contains videos only, + or can even be a subset of `self.cameras` in a case of a mixed image/video dataset. + """ + video_frame_keys = [] + for key, feats in self.features.items(): + if isinstance(feats, VideoFrame): + video_frame_keys.append(key) + return video_frame_keys + + @property + def num_frames(self) -> int: + """Number of samples/frames.""" + return sum(d.num_frames for d in self._datasets) + + @property + def num_episodes(self) -> int: + """Number of episodes.""" + return sum(d.num_episodes for d in self._datasets) + + @property + def tolerance_s(self) -> float: + """Tolerance in seconds used to discard loaded frames when their timestamps + are not close enough from the requested frames. It is only used when `delta_timestamps` + is provided or when loading video frames from mp4 files. + """ + # 1e-4 to account for possible numerical error + return 1 / self.fps - 1e-4 + + def __len__(self): + return self.num_frames + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + if idx >= len(self): + raise IndexError(f"Index {idx} out of bounds.") + # Determine which dataset to get an item from based on the index. + start_idx = 0 + dataset_idx = 0 + for dataset in self._datasets: + if idx >= start_idx + dataset.num_frames: + start_idx += dataset.num_frames + dataset_idx += 1 + continue + break + else: + raise AssertionError("We expect the loop to break out as long as the index is within bounds.") + item = self._datasets[dataset_idx][idx - start_idx] + item["dataset_index"] = torch.tensor(dataset_idx) + for data_key in self.disabled_features: + if data_key in item: + del item[data_key] + + return item + + def __repr__(self): + return ( + f"{self.__class__.__name__}(\n" + f" Repository IDs: '{self.repo_ids}',\n" + f" Number of Samples: {self.num_frames},\n" + f" Number of Episodes: {self.num_episodes},\n" + f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n" + f" Recorded Frames per Second: {self.fps},\n" + f" Camera Keys: {self.camera_keys},\n" + f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n" + f" Transformations: {self.image_transforms},\n" + f")" + ) diff --git a/src/lerobot/datasets/pipeline_features.py b/src/lerobot/datasets/pipeline_features.py index fe8cabbeb..96779fdc6 100644 --- a/src/lerobot/datasets/pipeline_features.py +++ b/src/lerobot/datasets/pipeline_features.py @@ -17,7 +17,7 @@ from collections.abc import Sequence from typing import Any from lerobot.configs.types import PipelineFeatureType -from lerobot.datasets.utils import hw_to_dataset_features +from lerobot.datasets.feature_utils import hw_to_dataset_features from lerobot.processor import DataProcessorPipeline from lerobot.types import RobotAction, RobotObservation from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py index 454389d46..62e00558a 100644 --- a/src/lerobot/datasets/streaming_dataset.py +++ b/src/lerobot/datasets/streaming_dataset.py @@ -13,7 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Callable, Generator, Iterator +from collections import deque +from collections.abc import Callable, Generator, Iterable, Iterator from pathlib import Path import datasets @@ -21,16 +22,13 @@ import numpy as np import torch from datasets import load_dataset -from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata +from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata +from lerobot.datasets.feature_utils import get_delta_indices +from lerobot.datasets.io_utils import item_to_torch from lerobot.datasets.utils import ( - Backtrackable, - LookAheadError, - LookBackError, check_version_compatibility, find_float_index, - get_delta_indices, is_float_in_list, - item_to_torch, safe_shard, ) from lerobot.datasets.video_utils import ( @@ -40,6 +38,164 @@ from lerobot.datasets.video_utils import ( from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE +class LookBackError(Exception): + """ + Exception raised when trying to look back in the history of a Backtrackable object. + """ + + pass + + +class LookAheadError(Exception): + """ + Exception raised when trying to look ahead in the future of a Backtrackable object. + """ + + pass + + +class Backtrackable[T]: + """ + Wrap any iterator/iterable so you can step back up to `history` items + and look ahead up to `lookahead` items. + + This is useful for streaming datasets where you need to access previous and future items + but can't load the entire dataset into memory. + + Example: + ------- + ```python + ds = load_dataset("c4", "en", streaming=True, split="train") + rev = Backtrackable(ds, history=3, lookahead=2) + + x0 = next(rev) # forward + x1 = next(rev) + x2 = next(rev) + + # Look ahead + x3_peek = rev.peek_ahead(1) # next item without moving cursor + x4_peek = rev.peek_ahead(2) # two items ahead + + # Look back + x1_again = rev.peek_back(1) # previous item without moving cursor + x0_again = rev.peek_back(2) # two items back + + # Move backward + x1_back = rev.prev() # back one step + next(rev) # returns x2, continues forward from where we were + ``` + """ + + __slots__ = ("_source", "_back_buf", "_ahead_buf", "_cursor", "_history", "_lookahead") + + def __init__(self, iterable: Iterable[T], *, history: int = 1, lookahead: int = 0): + if history < 1: + raise ValueError("history must be >= 1") + if lookahead <= 0: + raise ValueError("lookahead must be > 0") + + self._source: Iterator[T] = iter(iterable) + self._back_buf: deque[T] = deque(maxlen=history) + self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque() + self._cursor: int = 0 + self._history = history + self._lookahead = lookahead + + def __iter__(self) -> "Backtrackable[T]": + return self + + def __next__(self) -> T: + # If we've stepped back, consume from back buffer first + if self._cursor < 0: # -1 means "last item", etc. + self._cursor += 1 + return self._back_buf[self._cursor] + + # If we have items in the ahead buffer, use them first + item = self._ahead_buf.popleft() if self._ahead_buf else next(self._source) + + # Add current item to back buffer and reset cursor + self._back_buf.append(item) + self._cursor = 0 + return item + + def prev(self) -> T: + """ + Step one item back in history and return it. + Raises IndexError if already at the oldest buffered item. + """ + if len(self._back_buf) + self._cursor <= 1: + raise LookBackError("At start of history") + + self._cursor -= 1 + return self._back_buf[self._cursor] + + def peek_back(self, n: int = 1) -> T: + """ + Look `n` items back (n=1 == previous item) without moving the cursor. + """ + if n < 0 or n + 1 > len(self._back_buf) + self._cursor: + raise LookBackError("peek_back distance out of range") + + return self._back_buf[self._cursor - (n + 1)] + + def peek_ahead(self, n: int = 1) -> T: + """ + Look `n` items ahead (n=1 == next item) without moving the cursor. + Fills the ahead buffer if necessary. + """ + if n < 1: + raise LookAheadError("peek_ahead distance must be 1 or more") + elif n > self._lookahead: + raise LookAheadError("peek_ahead distance exceeds lookahead limit") + + # Fill ahead buffer if we don't have enough items + while len(self._ahead_buf) < n: + try: + item = next(self._source) + self._ahead_buf.append(item) + + except StopIteration as err: + raise LookAheadError("peek_ahead: not enough items in source") from err + + return self._ahead_buf[n - 1] + + def history(self) -> list[T]: + """ + Return a copy of the buffered history (most recent last). + The list length ≤ `history` argument passed at construction. + """ + if self._cursor == 0: + return list(self._back_buf) + + # When cursor<0, slice so the order remains chronological + return list(self._back_buf)[: self._cursor or None] + + def can_peek_back(self, steps: int = 1) -> bool: + """ + Check if we can go back `steps` items without raising an IndexError. + """ + return steps <= len(self._back_buf) + self._cursor + + def can_peek_ahead(self, steps: int = 1) -> bool: + """ + Check if we can peek ahead `steps` items. + This may involve trying to fill the ahead buffer. + """ + if self._lookahead > 0 and steps > self._lookahead: + return False + + # Try to fill ahead buffer to check if we can peek that far + try: + while len(self._ahead_buf) < steps: + if self._lookahead > 0 and len(self._ahead_buf) >= self._lookahead: + return False + item = next(self._source) + self._ahead_buf.append(item) + return True + except StopIteration: + return False + + class StreamingLeRobotDataset(torch.utils.data.IterableDataset): """LeRobotDataset with streaming capabilities. diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 8bc56a1bd..2e1d360f9 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -17,35 +17,57 @@ import contextlib import importlib.resources import json import logging -from collections import deque -from collections.abc import Iterable, Iterator -from pathlib import Path -from pprint import pformat +from collections.abc import Iterator from typing import Any import datasets import numpy as np import packaging.version -import pandas -import pandas as pd -import pyarrow.dataset as pa_ds -import pyarrow.parquet as pq import torch -from datasets import Dataset -from datasets.table import embed_table_storage from huggingface_hub import DatasetCard, DatasetCardData, HfApi from huggingface_hub.errors import RevisionNotFoundError -from PIL import Image as PILImage -from torchvision import transforms -from lerobot.configs.types import FeatureType, PolicyFeature -from lerobot.datasets.backward_compatibility import ( - FUTURE_MESSAGE, - BackwardCompatibilityError, - ForwardCompatibilityError, -) -from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR -from lerobot.utils.utils import SuppressProgressBars, is_valid_numpy_dtype_string +V30_MESSAGE = """ +The dataset you requested ({repo_id}) is in {version} format. + +We introduced a new format since v3.0 which is not backward compatible with v2.1. +Please, update your dataset to the new format using this command: +``` +python -m lerobot.scripts.convert_dataset_v21_to_v30 --repo-id={repo_id} +``` + +If you already have a converted version uploaded to the hub, then this error might be because of +an older version in your local cache. Consider deleting the cached version and retrying. + +If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb) +or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose). +""" + +FUTURE_MESSAGE = """ +The dataset you requested ({repo_id}) is only available in {version} format. +As we cannot ensure forward compatibility with it, please update your current version of lerobot. +""" + + +class CompatibilityError(Exception): ... + + +class BackwardCompatibilityError(CompatibilityError): + def __init__(self, repo_id: str, version: packaging.version.Version): + if version.major == 2 and version.minor == 1: + message = V30_MESSAGE.format(repo_id=repo_id, version=version) + else: + raise NotImplementedError( + "Contact the maintainer on [Discord](https://discord.com/invite/s3KuuzsPFb)." + ) + super().__init__(message) + + +class ForwardCompatibilityError(CompatibilityError): + def __init__(self, repo_id: str, version: packaging.version.Version): + message = FUTURE_MESSAGE.format(repo_id=repo_id, version=version) + super().__init__(message) + DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file @@ -79,21 +101,6 @@ DEFAULT_FEATURES = { } -def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float: - metadata = pq.read_metadata(parquet_path) - total_uncompressed_size = 0 - for row_group in range(metadata.num_row_groups): - rg_metadata = metadata.row_group(row_group) - for column in range(rg_metadata.num_columns): - col_metadata = rg_metadata.column(column) - total_uncompressed_size += col_metadata.total_uncompressed_size - return total_uncompressed_size / (1024**2) - - -def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int: - return hf_ds.data.nbytes // (1024**2) - - def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]: if file_idx == chunks_size - 1: file_idx = 0 @@ -103,43 +110,6 @@ def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) - return chunk_idx, file_idx -def load_nested_dataset( - pq_dir: Path, features: datasets.Features | None = None, episodes: list[int] | None = None -) -> Dataset: - """Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet - Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage - Concatenate all pyarrow references to return HF Dataset format - - Args: - pq_dir: Directory containing parquet files - features: Optional features schema to ensure consistent loading of complex types like images - episodes: Optional list of episode indices to filter. Uses PyArrow predicate pushdown for efficiency. - """ - paths = sorted(pq_dir.glob("*/*.parquet")) - if len(paths) == 0: - raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}") - - with SuppressProgressBars(): - # We use .from_parquet() memory-mapped loading for efficiency - filters = pa_ds.field("episode_index").isin(episodes) if episodes is not None else None - return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features) - - -def get_parquet_num_frames(parquet_path: str | Path) -> int: - metadata = pq.read_metadata(parquet_path) - return metadata.num_rows - - -def get_file_size_in_mb(file_path: Path) -> float: - """Get file size on disk in megabytes. - - Args: - file_path (Path): Path to the file. - """ - file_size_bytes = file_path.stat().st_size - return file_size_bytes / (1024**2) - - def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict: """Flatten a nested dictionary by joining keys with a separator. @@ -222,217 +192,6 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict: return unflatten_dict(serialized_dict) -def embed_images(dataset: datasets.Dataset) -> datasets.Dataset: - """Embed image bytes into the dataset table before saving to Parquet. - - This function prepares a Hugging Face dataset for serialization by converting - image objects into an embedded format that can be stored in Arrow/Parquet. - - Args: - dataset (datasets.Dataset): The input dataset, possibly containing image features. - - Returns: - datasets.Dataset: The dataset with images embedded in the table storage. - """ - # Embed image bytes into the table before saving to parquet - format = dataset.format - dataset = dataset.with_format("arrow") - dataset = dataset.map(embed_table_storage, batched=False) - dataset = dataset.with_format(**format) - return dataset - - -def load_json(fpath: Path) -> Any: - """Load data from a JSON file. - - Args: - fpath (Path): Path to the JSON file. - - Returns: - Any: The data loaded from the JSON file. - """ - with open(fpath) as f: - return json.load(f) - - -def write_json(data: dict, fpath: Path) -> None: - """Write data to a JSON file. - - Creates parent directories if they don't exist. - - Args: - data (dict): The dictionary to write. - fpath (Path): The path to the output JSON file. - """ - fpath.parent.mkdir(exist_ok=True, parents=True) - with open(fpath, "w") as f: - json.dump(data, f, indent=4, ensure_ascii=False) - - -def write_info(info: dict, local_dir: Path) -> None: - write_json(info, local_dir / INFO_PATH) - - -def load_info(local_dir: Path) -> dict: - """Load dataset info metadata from its standard file path. - - Also converts shape lists to tuples for consistency. - - Args: - local_dir (Path): The root directory of the dataset. - - Returns: - dict: The dataset information dictionary. - """ - info = load_json(local_dir / INFO_PATH) - for ft in info["features"].values(): - ft["shape"] = tuple(ft["shape"]) - return info - - -def write_stats(stats: dict, local_dir: Path) -> None: - """Serialize and write dataset statistics to their standard file path. - - Args: - stats (dict): The statistics dictionary (can contain tensors/numpy arrays). - local_dir (Path): The root directory of the dataset. - """ - serialized_stats = serialize_dict(stats) - write_json(serialized_stats, local_dir / STATS_PATH) - - -def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]: - """Recursively cast numerical values in a stats dictionary to numpy arrays. - - Args: - stats (dict): The statistics dictionary. - - Returns: - dict: The statistics dictionary with values cast to numpy arrays. - """ - stats = {key: np.array(value) for key, value in flatten_dict(stats).items()} - return unflatten_dict(stats) - - -def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]] | None: - """Load dataset statistics and cast numerical values to numpy arrays. - - Returns None if the stats file doesn't exist. - - Args: - local_dir (Path): The root directory of the dataset. - - Returns: - A dictionary of statistics or None if the file is not found. - """ - if not (local_dir / STATS_PATH).exists(): - return None - stats = load_json(local_dir / STATS_PATH) - return cast_stats_to_numpy(stats) - - -def write_tasks(tasks: pandas.DataFrame, local_dir: Path) -> None: - path = local_dir / DEFAULT_TASKS_PATH - path.parent.mkdir(parents=True, exist_ok=True) - tasks.to_parquet(path) - - -def load_tasks(local_dir: Path) -> pandas.DataFrame: - tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH) - tasks.index.name = "task" - return tasks - - -def load_subtasks(local_dir: Path) -> pandas.DataFrame | None: - """Load subtasks from subtasks.parquet if it exists.""" - subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH - if subtasks_path.exists(): - return pd.read_parquet(subtasks_path) - return None - - -def write_episodes(episodes: Dataset, local_dir: Path) -> None: - """Write episode metadata to a parquet file in the LeRobot v3.0 format. - This function writes episode-level metadata to a single parquet file. - Used primarily during dataset conversion (v2.1 → v3.0) and in test fixtures. - - Args: - episodes: HuggingFace Dataset containing episode metadata - local_dir: Root directory where the dataset will be stored - """ - episode_size_mb = get_hf_dataset_size_in_mb(episodes) - if episode_size_mb > DEFAULT_DATA_FILE_SIZE_IN_MB: - raise NotImplementedError( - f"Episodes dataset is too large ({episode_size_mb} MB) to write to a single file. " - f"The current limit is {DEFAULT_DATA_FILE_SIZE_IN_MB} MB. " - "This function only supports single-file episode metadata. " - ) - - fpath = local_dir / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0) - fpath.parent.mkdir(parents=True, exist_ok=True) - episodes.to_parquet(fpath) - - -def load_episodes(local_dir: Path) -> datasets.Dataset: - episodes = load_nested_dataset(local_dir / EPISODES_DIR) - # Select episode features/columns containing references to episode data and videos - # (e.g. tasks, dataset_from_index, dataset_to_index, data/chunk_index, data/file_index, etc.) - # This is to speedup access to these data, instead of having to load episode stats. - episodes = episodes.select_columns([key for key in episodes.features if not key.startswith("stats/")]) - return episodes - - -def load_image_as_numpy( - fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True -) -> np.ndarray: - """Load an image from a file into a numpy array. - - Args: - fpath (str | Path): Path to the image file. - dtype (np.dtype): The desired data type of the output array. If floating, - pixels are scaled to [0, 1]. - channel_first (bool): If True, converts the image to (C, H, W) format. - Otherwise, it remains in (H, W, C) format. - - Returns: - np.ndarray: The image as a numpy array. - """ - img = PILImage.open(fpath).convert("RGB") - img_array = np.array(img, dtype=dtype) - if channel_first: # (H, W, C) -> (C, H, W) - img_array = np.transpose(img_array, (2, 0, 1)) - if np.issubdtype(dtype, np.floating): - img_array /= 255.0 - return img_array - - -def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]: - """Convert a batch from a Hugging Face dataset to torch tensors. - - This transform function converts items from Hugging Face dataset format (pyarrow) - to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8) - to a torch image representation (C, H, W, float32) in the range [0, 1]. Other - types are converted to torch.tensor. - - Args: - items_dict (dict): A dictionary representing a batch of data from a - Hugging Face dataset. - - Returns: - dict: The batch with items converted to torch tensors. - """ - for key in items_dict: - first_item = items_dict[key][0] - if isinstance(first_item, PILImage.Image): - to_tensor = transforms.ToTensor() - items_dict[key] = [to_tensor(img) for img in items_dict[key]] - elif first_item is None: - pass - else: - items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]] - return items_dict - - def is_valid_version(version: str) -> bool: """Check if a string is a valid PEP 440 version. @@ -560,337 +319,6 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> raise ForwardCompatibilityError(repo_id, min(upper_versions)) -def get_hf_features_from_features(features: dict) -> datasets.Features: - """Convert a LeRobot features dictionary to a `datasets.Features` object. - - Args: - features (dict): A LeRobot-style feature dictionary. - - Returns: - datasets.Features: The corresponding Hugging Face `datasets.Features` object. - - Raises: - ValueError: If a feature has an unsupported shape. - """ - hf_features = {} - for key, ft in features.items(): - if ft["dtype"] == "video": - continue - elif ft["dtype"] == "image": - hf_features[key] = datasets.Image() - elif ft["shape"] == (1,): - hf_features[key] = datasets.Value(dtype=ft["dtype"]) - elif len(ft["shape"]) == 1: - hf_features[key] = datasets.Sequence( - length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"]) - ) - elif len(ft["shape"]) == 2: - hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"]) - elif len(ft["shape"]) == 3: - hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"]) - elif len(ft["shape"]) == 4: - hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"]) - elif len(ft["shape"]) == 5: - hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"]) - else: - raise ValueError(f"Corresponding feature is not valid: {ft}") - - return datasets.Features(hf_features) - - -def _validate_feature_names(features: dict[str, dict]) -> None: - """Validate that feature names do not contain invalid characters. - - Args: - features (dict): The LeRobot features dictionary. - - Raises: - ValueError: If any feature name contains '/'. - """ - invalid_features = {name: ft for name, ft in features.items() if "/" in name} - if invalid_features: - raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.") - - -def hw_to_dataset_features( - hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True -) -> dict[str, dict]: - """Convert hardware-specific features to a LeRobot dataset feature dictionary. - - This function takes a dictionary describing hardware outputs (like joint states - or camera image shapes) and formats it into the standard LeRobot feature - specification. - - Args: - hw_features (dict): Dictionary mapping feature names to their type (float for - joints) or shape (tuple for images). - prefix (str): The prefix to add to the feature keys (e.g., "observation" - or "action"). - use_video (bool): If True, image features are marked as "video", otherwise "image". - - Returns: - dict: A LeRobot features dictionary. - """ - features = {} - joint_fts = { - key: ftype - for key, ftype in hw_features.items() - if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL) - } - cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)} - - if joint_fts and prefix == ACTION: - features[prefix] = { - "dtype": "float32", - "shape": (len(joint_fts),), - "names": list(joint_fts), - } - - if joint_fts and prefix == OBS_STR: - features[f"{prefix}.state"] = { - "dtype": "float32", - "shape": (len(joint_fts),), - "names": list(joint_fts), - } - - for key, shape in cam_fts.items(): - features[f"{prefix}.images.{key}"] = { - "dtype": "video" if use_video else "image", - "shape": shape, - "names": ["height", "width", "channels"], - } - - _validate_feature_names(features) - return features - - -def build_dataset_frame( - ds_features: dict[str, dict], values: dict[str, Any], prefix: str -) -> dict[str, np.ndarray]: - """Construct a single data frame from raw values based on dataset features. - - A "frame" is a dictionary containing all the data for a single timestep, - formatted as numpy arrays according to the feature specification. - - Args: - ds_features (dict): The LeRobot dataset features dictionary. - values (dict): A dictionary of raw values from the hardware/environment. - prefix (str): The prefix to filter features by (e.g., "observation" - or "action"). - - Returns: - dict: A dictionary representing a single frame of data. - """ - frame = {} - for key, ft in ds_features.items(): - if key in DEFAULT_FEATURES or not key.startswith(prefix): - continue - elif ft["dtype"] == "float32" and len(ft["shape"]) == 1: - frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32) - elif ft["dtype"] in ["image", "video"]: - frame[key] = values[key.removeprefix(f"{prefix}.images.")] - - return frame - - -def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]: - """Convert dataset features to policy features. - - This function transforms the dataset's feature specification into a format - that a policy can use, classifying features by type (e.g., visual, state, - action) and ensuring correct shapes (e.g., channel-first for images). - - Args: - features (dict): The LeRobot dataset features dictionary. - - Returns: - dict: A dictionary mapping feature keys to `PolicyFeature` objects. - - Raises: - ValueError: If an image feature does not have a 3D shape. - """ - # TODO(aliberts): Implement "type" in dataset features and simplify this - policy_features = {} - for key, ft in features.items(): - shape = ft["shape"] - if ft["dtype"] in ["image", "video"]: - type = FeatureType.VISUAL - if len(shape) != 3: - raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})") - - names = ft["names"] - # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets. - if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w) - shape = (shape[2], shape[0], shape[1]) - elif key == OBS_ENV_STATE: - type = FeatureType.ENV - elif key.startswith(OBS_STR): - type = FeatureType.STATE - elif key.startswith(ACTION): - type = FeatureType.ACTION - else: - continue - - policy_features[key] = PolicyFeature( - type=type, - shape=shape, - ) - - return policy_features - - -def combine_feature_dicts(*dicts: dict) -> dict: - """Merge LeRobot grouped feature dicts. - - - For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape. - - For others (e.g. `observation.images.*`), the last one wins (if they are identical). - - Args: - *dicts: A variable number of LeRobot feature dictionaries to merge. - - Returns: - dict: A single merged feature dictionary. - - Raises: - ValueError: If there's a dtype mismatch for a feature being merged. - """ - out: dict = {} - for d in dicts: - for key, value in d.items(): - if not isinstance(value, dict): - out[key] = value - continue - - dtype = value.get("dtype") - shape = value.get("shape") - is_vector = ( - dtype not in ("image", "video", "string") - and isinstance(shape, tuple) - and len(shape) == 1 - and "names" in value - ) - - if is_vector: - # Initialize or retrieve the accumulating dict for this feature key - target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)}) - # Ensure consistent data types across merged entries - if "dtype" in target and dtype != target["dtype"]: - raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}") - - # Merge feature names: append only new ones to preserve order without duplicates - seen = set(target["names"]) - for n in value["names"]: - if n not in seen: - target["names"].append(n) - seen.add(n) - # Recompute the shape to reflect the updated number of features - target["shape"] = (len(target["names"]),) - else: - # For images/videos and non-1D entries: override with the latest definition - out[key] = value - return out - - -def create_empty_dataset_info( - codebase_version: str, - fps: int, - features: dict, - use_videos: bool, - robot_type: str | None = None, - chunks_size: int | None = None, - data_files_size_in_mb: int | None = None, - video_files_size_in_mb: int | None = None, -) -> dict: - """Create a template dictionary for a new dataset's `info.json`. - - Args: - codebase_version (str): The version of the LeRobot codebase. - fps (int): The frames per second of the data. - features (dict): The LeRobot features dictionary for the dataset. - use_videos (bool): Whether the dataset will store videos. - robot_type (str | None): The type of robot used, if any. - - Returns: - dict: A dictionary with the initial dataset metadata. - """ - return { - "codebase_version": codebase_version, - "robot_type": robot_type, - "total_episodes": 0, - "total_frames": 0, - "total_tasks": 0, - "chunks_size": chunks_size or DEFAULT_CHUNK_SIZE, - "data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB, - "video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB, - "fps": fps, - "splits": {}, - "data_path": DEFAULT_DATA_PATH, - "video_path": DEFAULT_VIDEO_PATH if use_videos else None, - "features": features, - } - - -def check_delta_timestamps( - delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True -) -> bool: - """Check if delta timestamps are multiples of 1/fps +/- tolerance. - - This ensures that adding these delta timestamps to any existing timestamp in - the dataset will result in a value that aligns with the dataset's frame rate. - - Args: - delta_timestamps (dict): A dictionary where values are lists of time - deltas in seconds. - fps (int): The frames per second of the dataset. - tolerance_s (float): The allowed tolerance in seconds. - raise_value_error (bool): If True, raises an error on failure. - - Returns: - bool: True if all deltas are valid, False otherwise. - - Raises: - ValueError: If any delta is outside the tolerance and `raise_value_error` is True. - """ - outside_tolerance = {} - for key, delta_ts in delta_timestamps.items(): - within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts] - if not all(within_tolerance): - outside_tolerance[key] = [ - ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within - ] - - if len(outside_tolerance) > 0: - if raise_value_error: - raise ValueError( - f""" - The following delta_timestamps are found outside of tolerance range. - Please make sure they are multiples of 1/{fps} +/- tolerance and adjust - their values accordingly. - \n{pformat(outside_tolerance)} - """ - ) - return False - - return True - - -def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]: - """Convert delta timestamps in seconds to delta indices in frames. - - Args: - delta_timestamps (dict): A dictionary of time deltas in seconds. - fps (int): The frames per second of the dataset. - - Returns: - dict: A dictionary of frame delta indices. - """ - delta_indices = {} - for key, delta_ts in delta_timestamps.items(): - delta_indices[key] = [round(d * fps) for d in delta_ts] - - return delta_indices - - def cycle(iterable: Any) -> Iterator[Any]: """Create a dataloader-safe cyclical iterator. @@ -982,229 +410,6 @@ def create_lerobot_dataset_card( ) -def validate_frame(frame: dict, features: dict) -> None: - expected_features = set(features) - set(DEFAULT_FEATURES) - actual_features = set(frame) - - # task is a special required field that's not part of regular features - if "task" not in actual_features: - raise ValueError("Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n") - - # Remove task from actual_features for regular feature validation - actual_features_for_validation = actual_features - {"task"} - - error_message = validate_features_presence(actual_features_for_validation, expected_features) - - common_features = actual_features_for_validation & expected_features - for name in common_features: - error_message += validate_feature_dtype_and_shape(name, features[name], frame[name]) - - if error_message: - raise ValueError(error_message) - - -def validate_features_presence(actual_features: set[str], expected_features: set[str]) -> str: - """Check for missing or extra features in a frame. - - Args: - actual_features (set[str]): The set of feature names present in the frame. - expected_features (set[str]): The set of feature names expected in the frame. - - Returns: - str: An error message string if there's a mismatch, otherwise an empty string. - """ - error_message = "" - missing_features = expected_features - actual_features - extra_features = actual_features - expected_features - - if missing_features or extra_features: - error_message += "Feature mismatch in `frame` dictionary:\n" - if missing_features: - error_message += f"Missing features: {missing_features}\n" - if extra_features: - error_message += f"Extra features: {extra_features}\n" - - return error_message - - -def validate_feature_dtype_and_shape( - name: str, feature: dict, value: np.ndarray | PILImage.Image | str -) -> str: - """Validate the dtype and shape of a single feature's value. - - Args: - name (str): The name of the feature. - feature (dict): The feature specification from the LeRobot features dictionary. - value: The value of the feature to validate. - - Returns: - str: An error message if validation fails, otherwise an empty string. - - Raises: - NotImplementedError: If the feature dtype is not supported for validation. - """ - expected_dtype = feature["dtype"] - expected_shape = feature["shape"] - if is_valid_numpy_dtype_string(expected_dtype): - return validate_feature_numpy_array(name, expected_dtype, expected_shape, value) - elif expected_dtype in ["image", "video"]: - return validate_feature_image_or_video(name, expected_shape, value) - elif expected_dtype == "string": - return validate_feature_string(name, value) - else: - raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.") - - -def validate_feature_numpy_array( - name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray -) -> str: - """Validate a feature that is expected to be a numpy array. - - Args: - name (str): The name of the feature. - expected_dtype (str): The expected numpy dtype as a string. - expected_shape (list[int]): The expected shape. - value (np.ndarray): The numpy array to validate. - - Returns: - str: An error message if validation fails, otherwise an empty string. - """ - error_message = "" - if isinstance(value, np.ndarray): - actual_dtype = value.dtype - actual_shape = value.shape - - if actual_dtype != np.dtype(expected_dtype): - error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n" - - if actual_shape != expected_shape: - error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n" - else: - error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n" - - return error_message - - -def validate_feature_image_or_video( - name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image -) -> str: - """Validate a feature that is expected to be an image or video frame. - - Accepts `np.ndarray` (channel-first or channel-last) or `PIL.Image.Image`. - - Args: - name (str): The name of the feature. - expected_shape (list[str]): The expected shape (C, H, W). - value: The image data to validate. - - Returns: - str: An error message if validation fails, otherwise an empty string. - """ - # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads. - error_message = "" - if isinstance(value, np.ndarray): - actual_shape = value.shape - c, h, w = expected_shape - if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)): - error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n" - elif isinstance(value, PILImage.Image): - pass - else: - error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n" - - return error_message - - -def validate_feature_string(name: str, value: str) -> str: - """Validate a feature that is expected to be a string. - - Args: - name (str): The name of the feature. - value (str): The value to validate. - - Returns: - str: An error message if validation fails, otherwise an empty string. - """ - if not isinstance(value, str): - return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n" - return "" - - -def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None: - """Validate the episode buffer before it's written to disk. - - Ensures the buffer has the required keys, contains at least one frame, and - has features consistent with the dataset's specification. - - Args: - episode_buffer (dict): The buffer containing data for a single episode. - total_episodes (int): The current total number of episodes in the dataset. - features (dict): The LeRobot features dictionary for the dataset. - - Raises: - ValueError: If the buffer is invalid. - NotImplementedError: If the episode index is manually set and doesn't match. - """ - if "size" not in episode_buffer: - raise ValueError("size key not found in episode_buffer") - - if "task" not in episode_buffer: - raise ValueError("task key not found in episode_buffer") - - if episode_buffer["episode_index"] != total_episodes: - # TODO(aliberts): Add option to use existing episode_index - raise NotImplementedError( - "You might have manually provided the episode_buffer with an episode_index that doesn't " - "match the total number of episodes already in the dataset. This is not supported for now." - ) - - if episode_buffer["size"] == 0: - raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.") - - buffer_keys = set(episode_buffer.keys()) - {"task", "size"} - if not buffer_keys == set(features): - raise ValueError( - f"Features from `episode_buffer` don't match the ones in `features`." - f"In episode_buffer not in features: {buffer_keys - set(features)}" - f"In features not in episode_buffer: {set(features) - buffer_keys}" - ) - - -def to_parquet_with_hf_images( - df: pandas.DataFrame, path: Path, features: datasets.Features | None = None -) -> None: - """This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset. - This way, it can be loaded by HF dataset and correctly formatted images are returned. - - Args: - df: DataFrame to write to parquet. - path: Path to write the parquet file. - features: Optional HuggingFace Features schema. If provided, ensures image columns - are properly typed as Image() in the parquet schema. - """ - # TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only - ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features) - ds.to_parquet(path) - - -def item_to_torch(item: dict) -> dict: - """Convert all items in a dictionary to PyTorch tensors where appropriate. - - This function is used to convert an item from a streaming dataset to PyTorch tensors. - - Args: - item (dict): Dictionary of items from a dataset. - - Returns: - dict: Dictionary with all tensor-like items converted to torch.Tensor. - """ - for key, val in item.items(): - if isinstance(val, (np.ndarray | list)) and key not in ["task"]: - # Convert numpy arrays and lists to torch tensors - item[key] = torch.tensor(val) - return item - - def is_float_in_list(target, float_list, threshold=1e-6): return any(abs(target - x) <= threshold for x in float_list) @@ -1216,164 +421,6 @@ def find_float_index(target, float_list, threshold=1e-6): return -1 -class LookBackError(Exception): - """ - Exception raised when trying to look back in the history of a Backtrackable object. - """ - - pass - - -class LookAheadError(Exception): - """ - Exception raised when trying to look ahead in the future of a Backtrackable object. - """ - - pass - - -class Backtrackable[T]: - """ - Wrap any iterator/iterable so you can step back up to `history` items - and look ahead up to `lookahead` items. - - This is useful for streaming datasets where you need to access previous and future items - but can't load the entire dataset into memory. - - Example: - ------- - ```python - ds = load_dataset("c4", "en", streaming=True, split="train") - rev = Backtrackable(ds, history=3, lookahead=2) - - x0 = next(rev) # forward - x1 = next(rev) - x2 = next(rev) - - # Look ahead - x3_peek = rev.peek_ahead(1) # next item without moving cursor - x4_peek = rev.peek_ahead(2) # two items ahead - - # Look back - x1_again = rev.peek_back(1) # previous item without moving cursor - x0_again = rev.peek_back(2) # two items back - - # Move backward - x1_back = rev.prev() # back one step - next(rev) # returns x2, continues forward from where we were - ``` - """ - - __slots__ = ("_source", "_back_buf", "_ahead_buf", "_cursor", "_history", "_lookahead") - - def __init__(self, iterable: Iterable[T], *, history: int = 1, lookahead: int = 0): - if history < 1: - raise ValueError("history must be >= 1") - if lookahead <= 0: - raise ValueError("lookahead must be > 0") - - self._source: Iterator[T] = iter(iterable) - self._back_buf: deque[T] = deque(maxlen=history) - self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque() - self._cursor: int = 0 - self._history = history - self._lookahead = lookahead - - def __iter__(self) -> "Backtrackable[T]": - return self - - def __next__(self) -> T: - # If we've stepped back, consume from back buffer first - if self._cursor < 0: # -1 means "last item", etc. - self._cursor += 1 - return self._back_buf[self._cursor] - - # If we have items in the ahead buffer, use them first - item = self._ahead_buf.popleft() if self._ahead_buf else next(self._source) - - # Add current item to back buffer and reset cursor - self._back_buf.append(item) - self._cursor = 0 - return item - - def prev(self) -> T: - """ - Step one item back in history and return it. - Raises IndexError if already at the oldest buffered item. - """ - if len(self._back_buf) + self._cursor <= 1: - raise LookBackError("At start of history") - - self._cursor -= 1 - return self._back_buf[self._cursor] - - def peek_back(self, n: int = 1) -> T: - """ - Look `n` items back (n=1 == previous item) without moving the cursor. - """ - if n < 0 or n + 1 > len(self._back_buf) + self._cursor: - raise LookBackError("peek_back distance out of range") - - return self._back_buf[self._cursor - (n + 1)] - - def peek_ahead(self, n: int = 1) -> T: - """ - Look `n` items ahead (n=1 == next item) without moving the cursor. - Fills the ahead buffer if necessary. - """ - if n < 1: - raise LookAheadError("peek_ahead distance must be 1 or more") - elif n > self._lookahead: - raise LookAheadError("peek_ahead distance exceeds lookahead limit") - - # Fill ahead buffer if we don't have enough items - while len(self._ahead_buf) < n: - try: - item = next(self._source) - self._ahead_buf.append(item) - - except StopIteration as err: - raise LookAheadError("peek_ahead: not enough items in source") from err - - return self._ahead_buf[n - 1] - - def history(self) -> list[T]: - """ - Return a copy of the buffered history (most recent last). - The list length ≤ `history` argument passed at construction. - """ - if self._cursor == 0: - return list(self._back_buf) - - # When cursor<0, slice so the order remains chronological - return list(self._back_buf)[: self._cursor or None] - - def can_peek_back(self, steps: int = 1) -> bool: - """ - Check if we can go back `steps` items without raising an IndexError. - """ - return steps <= len(self._back_buf) + self._cursor - - def can_peek_ahead(self, steps: int = 1) -> bool: - """ - Check if we can peek ahead `steps` items. - This may involve trying to fill the ahead buffer. - """ - if self._lookahead > 0 and steps > self._lookahead: - return False - - # Try to fill ahead buffer to check if we can peek that far - try: - while len(self._ahead_buf) < steps: - if self._lookahead > 0 and len(self._ahead_buf) >= self._lookahead: - return False - item = next(self._source) - self._ahead_buf.append(item) - return True - except StopIteration: - return False - - def safe_shard(dataset: datasets.IterableDataset, index: int, num_shards: int) -> datasets.Dataset: """ Safe shards the dataset. diff --git a/src/lerobot/optim/optimizers.py b/src/lerobot/optim/optimizers.py index 2b75353d9..e2e3d8937 100644 --- a/src/lerobot/optim/optimizers.py +++ b/src/lerobot/optim/optimizers.py @@ -23,7 +23,8 @@ import draccus import torch from safetensors.torch import load_file, save_file -from lerobot.datasets.utils import flatten_dict, unflatten_dict, write_json +from lerobot.datasets.io_utils import write_json +from lerobot.datasets.utils import flatten_dict, unflatten_dict from lerobot.utils.constants import ( OPTIMIZER_PARAM_GROUPS, OPTIMIZER_STATE, diff --git a/src/lerobot/optim/schedulers.py b/src/lerobot/optim/schedulers.py index 4af7f0802..19c3fd7bd 100644 --- a/src/lerobot/optim/schedulers.py +++ b/src/lerobot/optim/schedulers.py @@ -23,7 +23,7 @@ import draccus from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR, LRScheduler -from lerobot.datasets.utils import write_json +from lerobot.datasets.io_utils import write_json from lerobot.utils.constants import SCHEDULER_STATE from lerobot.utils.io_utils import deserialize_json_into_object diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 9515d5b82..2320cd624 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -24,8 +24,8 @@ import torch from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType -from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata -from lerobot.datasets.utils import dataset_to_policy_features +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.feature_utils import dataset_to_policy_features from lerobot.envs.configs import EnvConfig from lerobot.envs.utils import env_to_policy_features from lerobot.policies.act.configuration_act import ACTConfig diff --git a/src/lerobot/policies/utils.py b/src/lerobot/policies/utils.py index 9ad5dac4a..82ab51005 100644 --- a/src/lerobot/policies/utils.py +++ b/src/lerobot/policies/utils.py @@ -23,7 +23,7 @@ from torch import nn from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType, PolicyFeature -from lerobot.datasets.utils import build_dataset_frame +from lerobot.datasets.feature_utils import build_dataset_frame from lerobot.types import PolicyAction, RobotAction, RobotObservation from lerobot.utils.constants import ACTION, OBS_STR diff --git a/src/lerobot/scripts/augment_dataset_quantile_stats.py b/src/lerobot/scripts/augment_dataset_quantile_stats.py index e6ab6867e..4d80c9332 100644 --- a/src/lerobot/scripts/augment_dataset_quantile_stats.py +++ b/src/lerobot/scripts/augment_dataset_quantile_stats.py @@ -45,8 +45,9 @@ from requests import HTTPError from tqdm import tqdm from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats -from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset -from lerobot.datasets.utils import write_stats +from lerobot.datasets.dataset_metadata import CODEBASE_VERSION +from lerobot.datasets.io_utils import write_stats +from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.utils.utils import init_logging diff --git a/src/lerobot/scripts/convert_dataset_v21_to_v30.py b/src/lerobot/scripts/convert_dataset_v21_to_v30.py index dc81cc51c..2b6dcf732 100644 --- a/src/lerobot/scripts/convert_dataset_v21_to_v30.py +++ b/src/lerobot/scripts/convert_dataset_v21_to_v30.py @@ -60,7 +60,19 @@ from huggingface_hub import HfApi, snapshot_download from requests import HTTPError from lerobot.datasets.compute_stats import aggregate_stats -from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset +from lerobot.datasets.dataset_metadata import CODEBASE_VERSION +from lerobot.datasets.io_utils import ( + cast_stats_to_numpy, + get_file_size_in_mb, + get_parquet_file_size_in_mb, + get_parquet_num_frames, + load_info, + write_episodes, + write_info, + write_stats, + write_tasks, +) +from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import ( DEFAULT_CHUNK_SIZE, DEFAULT_DATA_FILE_SIZE_IN_MB, @@ -70,17 +82,8 @@ from lerobot.datasets.utils import ( LEGACY_EPISODES_PATH, LEGACY_EPISODES_STATS_PATH, LEGACY_TASKS_PATH, - cast_stats_to_numpy, flatten_dict, - get_file_size_in_mb, - get_parquet_file_size_in_mb, - get_parquet_num_frames, - load_info, update_chunk_file_indices, - write_episodes, - write_info, - write_stats, - write_tasks, ) from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s from lerobot.utils.constants import HF_LEROBOT_HOME diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 345d18f23..819634ba2 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -83,10 +83,10 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401 from lerobot.configs import parser from lerobot.configs.policies import PreTrainedConfig +from lerobot.datasets.feature_utils import build_dataset_frame, combine_feature_dicts from lerobot.datasets.image_writer import safe_stop_image_writer from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features -from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts from lerobot.datasets.video_utils import VideoEncodingManager from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy diff --git a/src/lerobot/utils/train_utils.py b/src/lerobot/utils/train_utils.py index d8481f4b9..02f6aebb3 100644 --- a/src/lerobot/utils/train_utils.py +++ b/src/lerobot/utils/train_utils.py @@ -19,7 +19,7 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler from lerobot.configs.train import TrainPipelineConfig -from lerobot.datasets.utils import load_json, write_json +from lerobot.datasets.io_utils import load_json, write_json from lerobot.optim.optimizers import load_optimizer_state, save_optimizer_state from lerobot.optim.schedulers import load_scheduler_state, save_scheduler_state from lerobot.policies.pretrained import PreTrainedPolicy diff --git a/tests/datasets/test_aggregate.py b/tests/datasets/test_aggregate.py index 3609bac24..4ac7e001a 100644 --- a/tests/datasets/test_aggregate.py +++ b/tests/datasets/test_aggregate.py @@ -260,8 +260,8 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory): # Mock the revision to prevent Hub calls during dataset loading with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "test_aggr") @@ -311,8 +311,8 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory): # Mock the revision to prevent Hub calls during dataset loading with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "small_aggr") @@ -367,8 +367,8 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory): ) with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "regression_aggr") @@ -492,8 +492,8 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory): # Load the aggregated dataset with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "image_aggr") @@ -562,8 +562,8 @@ def test_aggregate_already_merged_dataset(tmp_path, lerobot_dataset_factory): ) with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "ds_ab") @@ -590,8 +590,8 @@ def test_aggregate_already_merged_dataset(tmp_path, lerobot_dataset_factory): ) with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "ds_abc") diff --git a/tests/datasets/test_dataset_tools.py b/tests/datasets/test_dataset_tools.py index 1de199630..5ed7aa1a3 100644 --- a/tests/datasets/test_dataset_tools.py +++ b/tests/datasets/test_dataset_tools.py @@ -67,8 +67,8 @@ def test_delete_single_episode(sample_dataset, tmp_path): output_dir = tmp_path / "filtered" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(output_dir) @@ -93,8 +93,8 @@ def test_delete_multiple_episodes(sample_dataset, tmp_path): output_dir = tmp_path / "filtered" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(output_dir) @@ -150,8 +150,8 @@ def test_split_by_episodes(sample_dataset, tmp_path): } with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" @@ -193,8 +193,8 @@ def test_split_by_fractions(sample_dataset, tmp_path): } with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" @@ -270,8 +270,8 @@ def test_merge_two_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_fact dataset2.finalize() with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "merged_dataset") @@ -310,8 +310,8 @@ def test_add_features_with_values(sample_dataset, tmp_path): } with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "with_reward") @@ -346,8 +346,8 @@ def test_add_features_with_callable(sample_dataset, tmp_path): "reward": (compute_reward, feature_info), } with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "with_reward") @@ -401,8 +401,8 @@ def test_modify_features_add_and_remove(sample_dataset, tmp_path): feature_info = {"dtype": "float32", "shape": (1,), "names": None} with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "modified") @@ -434,8 +434,8 @@ def test_modify_features_only_add(sample_dataset, tmp_path): feature_info = {"dtype": "float32", "shape": (1,), "names": None} with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "modified") @@ -457,8 +457,8 @@ def test_modify_features_only_remove(sample_dataset, tmp_path): feature_info = {"dtype": "float32", "shape": (1,), "names": None} with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) @@ -494,8 +494,8 @@ def test_remove_single_feature(sample_dataset, tmp_path): "reward": (np.random.randn(50, 1).astype(np.float32), feature_info), } with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) @@ -521,8 +521,8 @@ def test_remove_single_feature(sample_dataset, tmp_path): def test_remove_multiple_features(sample_dataset, tmp_path): """Test removing multiple features at once.""" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) @@ -576,8 +576,8 @@ def test_remove_camera_feature(sample_dataset, tmp_path): camera_to_remove = camera_keys[0] with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "without_camera") @@ -598,8 +598,8 @@ def test_remove_camera_feature(sample_dataset, tmp_path): def test_complex_workflow_integration(sample_dataset, tmp_path): """Test a complex workflow combining multiple operations.""" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) @@ -647,8 +647,8 @@ def test_delete_episodes_preserves_stats(sample_dataset, tmp_path): output_dir = tmp_path / "filtered" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(output_dir) @@ -671,8 +671,8 @@ def test_delete_episodes_preserves_tasks(sample_dataset, tmp_path): output_dir = tmp_path / "filtered" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(output_dir) @@ -699,8 +699,8 @@ def test_split_three_ways(sample_dataset, tmp_path): } with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" @@ -732,8 +732,8 @@ def test_split_preserves_stats(sample_dataset, tmp_path): splits = {"train": [0, 1, 2], "val": [3, 4]} with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" @@ -790,8 +790,8 @@ def test_merge_three_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_fa datasets.append(dataset) with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "merged_dataset") @@ -832,8 +832,8 @@ def test_merge_preserves_stats(sample_dataset, tmp_path, empty_lerobot_dataset_f dataset2.finalize() with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "merged_dataset") @@ -866,8 +866,8 @@ def test_add_features_preserves_existing_stats(sample_dataset, tmp_path): } with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "with_reward") @@ -890,8 +890,8 @@ def test_remove_feature_updates_stats(sample_dataset, tmp_path): feature_info = {"dtype": "float32", "shape": (1,), "names": None} with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) @@ -919,8 +919,8 @@ def test_delete_consecutive_episodes(sample_dataset, tmp_path): output_dir = tmp_path / "filtered" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(output_dir) @@ -943,8 +943,8 @@ def test_delete_first_and_last_episodes(sample_dataset, tmp_path): output_dir = tmp_path / "filtered" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(output_dir) @@ -971,8 +971,8 @@ def test_split_all_episodes_assigned(sample_dataset, tmp_path): } with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" @@ -999,8 +999,8 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path): feature_info = {"dtype": "float32", "shape": (1,), "names": None} with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" @@ -1020,7 +1020,7 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path): # Get original chunk/file indices from first episode if train_dataset.meta.episodes is None: - from lerobot.datasets.utils import load_episodes + from lerobot.datasets.io_utils import load_episodes train_dataset.meta.episodes = load_episodes(train_dataset.meta.root) original_chunk_indices = [ep["data/chunk_index"] for ep in train_dataset.meta.episodes] @@ -1040,7 +1040,7 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path): # Check that chunk/file indices are preserved if modified_dataset.meta.episodes is None: - from lerobot.datasets.utils import load_episodes + from lerobot.datasets.io_utils import load_episodes modified_dataset.meta.episodes = load_episodes(modified_dataset.meta.root) new_chunk_indices = [ep["data/chunk_index"] for ep in modified_dataset.meta.episodes] @@ -1194,7 +1194,7 @@ def test_modify_tasks_in_place(sample_dataset): def test_modify_tasks_keeps_original_when_not_overridden(sample_dataset): """Test that original tasks are kept when using episode_tasks without new_task.""" - from lerobot.datasets.utils import load_episodes + from lerobot.datasets.io_utils import load_episodes # Ensure episodes metadata is loaded if sample_dataset.meta.episodes is None: @@ -1229,8 +1229,8 @@ def test_convert_image_to_video_dataset(tmp_path): output_dir = tmp_path / "pusht_video" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(output_dir) @@ -1292,8 +1292,8 @@ def test_convert_image_to_video_dataset_subset_episodes(tmp_path): output_dir = tmp_path / "pusht_video_subset" with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(output_dir) diff --git a/tests/datasets/test_dataset_utils.py b/tests/datasets/test_dataset_utils.py index d40ee238f..874099e2b 100644 --- a/tests/datasets/test_dataset_utils.py +++ b/tests/datasets/test_dataset_utils.py @@ -19,7 +19,9 @@ import torch from datasets import Dataset from huggingface_hub import DatasetCard -from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch +from lerobot.datasets.feature_utils import combine_feature_dicts +from lerobot.datasets.io_utils import hf_transform_to_torch +from lerobot.datasets.utils import create_lerobot_dataset_card from lerobot.utils.constants import ACTION, OBS_IMAGES diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 6f99eb301..67878d8f6 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -29,20 +29,19 @@ import lerobot from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig from lerobot.datasets.factory import make_dataset +from lerobot.datasets.feature_utils import get_hf_features_from_features, hw_to_dataset_features from lerobot.datasets.image_writer import image_array_to_pil_image +from lerobot.datasets.io_utils import hf_transform_to_torch from lerobot.datasets.lerobot_dataset import ( LeRobotDataset, - MultiLeRobotDataset, _encode_video_worker, ) +from lerobot.datasets.multi_dataset import MultiLeRobotDataset from lerobot.datasets.utils import ( DEFAULT_CHUNK_SIZE, DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_VIDEO_FILE_SIZE_IN_MB, create_branch, - get_hf_features_from_features, - hf_transform_to_torch, - hw_to_dataset_features, ) from lerobot.datasets.video_utils import VALID_VIDEO_CODECS from lerobot.envs.factory import make_env_config @@ -1329,7 +1328,7 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact dataset.finalize() - from lerobot.datasets.utils import load_episodes + from lerobot.datasets.io_utils import load_episodes dataset.meta.episodes = load_episodes(dataset.root) assert dataset.meta.episodes is not None diff --git a/tests/datasets/test_delta_timestamps.py b/tests/datasets/test_delta_timestamps.py index 72f69bc72..8d9529f68 100644 --- a/tests/datasets/test_delta_timestamps.py +++ b/tests/datasets/test_delta_timestamps.py @@ -13,7 +13,7 @@ # limitations under the License. import pytest -from lerobot.datasets.utils import ( +from lerobot.datasets.feature_utils import ( check_delta_timestamps, get_delta_indices, ) diff --git a/tests/datasets/test_sampler.py b/tests/datasets/test_sampler.py index a5d463349..18fb1c8ac 100644 --- a/tests/datasets/test_sampler.py +++ b/tests/datasets/test_sampler.py @@ -19,10 +19,10 @@ import pytest import torch from datasets import Dataset -from lerobot.datasets.sampler import EpisodeAwareSampler -from lerobot.datasets.utils import ( +from lerobot.datasets.io_utils import ( hf_transform_to_torch, ) +from lerobot.datasets.sampler import EpisodeAwareSampler def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]: diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index f8dd01fec..5ecb52145 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -26,7 +26,10 @@ import pytest import torch from datasets import Dataset -from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata +from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata +from lerobot.datasets.feature_utils import get_hf_features_from_features +from lerobot.datasets.io_utils import hf_transform_to_torch +from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import ( DEFAULT_CHUNK_SIZE, DEFAULT_DATA_FILE_SIZE_IN_MB, @@ -35,8 +38,6 @@ from lerobot.datasets.utils import ( DEFAULT_VIDEO_FILE_SIZE_IN_MB, DEFAULT_VIDEO_PATH, flatten_dict, - get_hf_features_from_features, - hf_transform_to_torch, ) from lerobot.datasets.video_utils import encode_video_frames from tests.fixtures.constants import ( @@ -453,8 +454,8 @@ def lerobot_dataset_metadata_factory( episodes=episodes, ) with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download_patch, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version_patch, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download_patch, ): mock_get_safe_version_patch.side_effect = lambda repo_id, version: version mock_snapshot_download_patch.side_effect = mock_snapshot_download diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py index 11f3fa94a..92d9ca1e2 100644 --- a/tests/fixtures/files.py +++ b/tests/fixtures/files.py @@ -20,17 +20,19 @@ import pandas as pd import pytest from datasets import Dataset -from lerobot.datasets.utils import ( - DEFAULT_CHUNK_SIZE, - DEFAULT_DATA_FILE_SIZE_IN_MB, - DEFAULT_DATA_PATH, +from lerobot.datasets.io_utils import ( get_hf_dataset_size_in_mb, - update_chunk_file_indices, write_episodes, write_info, write_stats, write_tasks, ) +from lerobot.datasets.utils import ( + DEFAULT_CHUNK_SIZE, + DEFAULT_DATA_FILE_SIZE_IN_MB, + DEFAULT_DATA_PATH, + update_chunk_file_indices, +) def write_hf_dataset( diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index 1ba82ffd0..1aae3fcc8 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -28,7 +28,8 @@ from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.datasets.factory import make_dataset -from lerobot.datasets.utils import cycle, dataset_to_policy_features +from lerobot.datasets.feature_utils import dataset_to_policy_features +from lerobot.datasets.utils import cycle from lerobot.envs.factory import make_env, make_env_config from lerobot.envs.utils import preprocess_observation from lerobot.optim.factory import make_optimizer_and_scheduler diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py index ace0aea49..772588467 100644 --- a/tests/test_control_robot.py +++ b/tests/test_control_robot.py @@ -71,8 +71,8 @@ def test_record_and_resume(tmp_path): cfg.resume = True # Mock the revision to prevent Hub calls during resume with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "record") @@ -115,8 +115,8 @@ def test_record_and_replay(tmp_path): # Mock the revision to prevent Hub calls during replay with ( - patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, - patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download, ): mock_get_safe_version.return_value = "v3.0" mock_snapshot_download.return_value = str(tmp_path / "record_and_replay") From d9ec3a6fa266546b41326c989cca8c1314ecfc2e Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Tue, 17 Mar 2026 18:33:53 +0100 Subject: [PATCH 117/131] Fix/earth rover dataset features (#3088) * docs(earthrover): update EarthRover Mini Plus dataset features and descriptions * refactor(teleop): rename rover action keys to linear_velocity/angular_velocity * fix(earthrover): align observation and action features with frodobots/berkeley-frodobots-lerobot-7k * chore: address PR review comments * ci: retrigger checks --- docs/source/earthrover_mini_plus.mdx | 18 +- .../robot_earthrover_mini_plus.py | 232 ++++++++++++------ .../teleoperators/keyboard/teleop_keyboard.py | 10 +- 3 files changed, 178 insertions(+), 82 deletions(-) diff --git a/docs/source/earthrover_mini_plus.mdx b/docs/source/earthrover_mini_plus.mdx index 7b739ecc1..884e84d8c 100644 --- a/docs/source/earthrover_mini_plus.mdx +++ b/docs/source/earthrover_mini_plus.mdx @@ -204,22 +204,26 @@ Replace `your_username/dataset_name` with your Hugging Face username and a name Your dataset includes: -**Your Actions (2 things)**: +**Your Actions (2 features)**: -- How much you moved forward/backward -- How much you turned left/right +- `linear_velocity`: How much you moved forward/backward +- `angular_velocity`: How much you turned left/right -**Robot Observations (12 things)**: +**Robot Observations (24 features)**: - Front camera video - Rear camera video - Current speed - Battery level -- Which way the robot is facing -- GPS location (latitude, longitude, signal strength) +- Orientation +- GPS (latitude, longitude, signal strength) - Network signal strength - Vibration level -- Lamp status (on/off) +- Lamp state (on/off) +- Accelerometer (x, y, z) +- Gyroscope (x, y, z) +- Magnetometer (x, y, z) +- Wheel RPMs (4 wheels) ### Where Your Data Goes diff --git a/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py b/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py index 299206a1e..76707a80c 100644 --- a/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py +++ b/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py @@ -33,21 +33,40 @@ from .config_earthrover_mini_plus import EarthRoverMiniPlusConfig logger = logging.getLogger(__name__) # Action feature keys -ACTION_LINEAR_VEL = "linear.vel" -ACTION_ANGULAR_VEL = "angular.vel" +ACTION_LINEAR_VEL = "linear_velocity" +ACTION_ANGULAR_VEL = "angular_velocity" -# Observation feature keys +# Observation feature keys — cameras OBS_FRONT = "front" OBS_REAR = "rear" -OBS_LINEAR_VEL = "linear.vel" -OBS_BATTERY_LEVEL = "battery.level" -OBS_ORIENTATION_DEG = "orientation.deg" -OBS_GPS_LATITUDE = "gps.latitude" -OBS_GPS_LONGITUDE = "gps.longitude" -OBS_GPS_SIGNAL = "gps.signal" -OBS_SIGNAL_LEVEL = "signal.level" + +# Observation feature keys — telemetry +OBS_SPEED = "speed" +OBS_BATTERY_LEVEL = "battery_level" +OBS_ORIENTATION = "orientation" +OBS_GPS_LATITUDE = "gps_latitude" +OBS_GPS_LONGITUDE = "gps_longitude" +OBS_GPS_SIGNAL = "gps_signal" +OBS_SIGNAL_LEVEL = "signal_level" OBS_VIBRATION = "vibration" -OBS_LAMP_STATE = "lamp.state" +OBS_LAMP = "lamp" + +# Observation feature keys — IMU sensors +OBS_ACCELEROMETER_X = "accelerometer_x" +OBS_ACCELEROMETER_Y = "accelerometer_y" +OBS_ACCELEROMETER_Z = "accelerometer_z" +OBS_GYROSCOPE_X = "gyroscope_x" +OBS_GYROSCOPE_Y = "gyroscope_y" +OBS_GYROSCOPE_Z = "gyroscope_z" +OBS_MAGNETOMETER_X = "magnetometer_filtered_x" +OBS_MAGNETOMETER_Y = "magnetometer_filtered_y" +OBS_MAGNETOMETER_Z = "magnetometer_filtered_z" + +# Observation feature keys — wheel RPMs +OBS_WHEEL_RPM_0 = "wheel_rpm_0" +OBS_WHEEL_RPM_1 = "wheel_rpm_1" +OBS_WHEEL_RPM_2 = "wheel_rpm_2" +OBS_WHEEL_RPM_3 = "wheel_rpm_3" class EarthRoverMiniPlus(Robot): @@ -154,33 +173,60 @@ class EarthRoverMiniPlus(Robot): dict: Observation features with types/shapes: - front: (480, 640, 3) - Front camera RGB image - rear: (480, 640, 3) - Rear camera RGB image - - linear.vel: float - Current speed (0-1, SDK reports only positive speeds) - - battery.level: float - Battery level (0-1, normalized from 0-100) - - orientation.deg: float - Robot orientation (0-1, normalized from raw value) - - gps.latitude: float - GPS latitude coordinate - - gps.longitude: float - GPS longitude coordinate - - gps.signal: float - GPS signal strength (0-1, normalized from percentage) - - signal.level: float - Network signal level (0-1, normalized from 0-5) + - speed: float - Current speed (raw SDK value) + - battery_level: float - Battery level (0-100) + - orientation: float - Robot orientation in degrees + - gps_latitude: float - GPS latitude coordinate + - gps_longitude: float - GPS longitude coordinate + - gps_signal: float - GPS signal strength (percentage) + - signal_level: float - Network signal level (0-5) - vibration: float - Vibration sensor reading - - lamp.state: float - Lamp state (0=off, 1=on) + - lamp: float - Lamp state (0=off, 1=on) + - accelerometer_x: float - Accelerometer X axis (raw SDK value) + - accelerometer_y: float - Accelerometer Y axis (raw SDK value) + - accelerometer_z: float - Accelerometer Z axis (raw SDK value) + - gyroscope_x: float - Gyroscope X axis (raw SDK value) + - gyroscope_y: float - Gyroscope Y axis (raw SDK value) + - gyroscope_z: float - Gyroscope Z axis (raw SDK value) + - magnetometer_filtered_x: float - Magnetometer X axis (raw SDK value) + - magnetometer_filtered_y: float - Magnetometer Y axis (raw SDK value) + - magnetometer_filtered_z: float - Magnetometer Z axis (raw SDK value) + - wheel_rpm_0: float - Wheel 0 RPM + - wheel_rpm_1: float - Wheel 1 RPM + - wheel_rpm_2: float - Wheel 2 RPM + - wheel_rpm_3: float - Wheel 3 RPM """ return { # Cameras (height, width, channels) OBS_FRONT: (480, 640, 3), OBS_REAR: (480, 640, 3), - # Motion state - OBS_LINEAR_VEL: float, - # Robot state + # Telemetry + OBS_SPEED: float, OBS_BATTERY_LEVEL: float, - OBS_ORIENTATION_DEG: float, - # GPS + OBS_ORIENTATION: float, OBS_GPS_LATITUDE: float, OBS_GPS_LONGITUDE: float, OBS_GPS_SIGNAL: float, - # Sensors OBS_SIGNAL_LEVEL: float, OBS_VIBRATION: float, - OBS_LAMP_STATE: float, + OBS_LAMP: float, + # IMU — accelerometer + OBS_ACCELEROMETER_X: float, + OBS_ACCELEROMETER_Y: float, + OBS_ACCELEROMETER_Z: float, + # IMU — gyroscope + OBS_GYROSCOPE_X: float, + OBS_GYROSCOPE_Y: float, + OBS_GYROSCOPE_Z: float, + # IMU — magnetometer + OBS_MAGNETOMETER_X: float, + OBS_MAGNETOMETER_Y: float, + OBS_MAGNETOMETER_Z: float, + # Wheel RPMs + OBS_WHEEL_RPM_0: float, + OBS_WHEEL_RPM_1: float, + OBS_WHEEL_RPM_2: float, + OBS_WHEEL_RPM_3: float, } @cached_property @@ -189,8 +235,8 @@ class EarthRoverMiniPlus(Robot): Returns: dict: Action features with types: - - linear.vel: float - Target linear velocity - - angular.vel: float - Target angular velocity + - linear_velocity: float - Target linear velocity (-1 to 1) + - angular_velocity: float - Target angular velocity (-1 to 1) """ return { ACTION_LINEAR_VEL: float, @@ -201,19 +247,29 @@ class EarthRoverMiniPlus(Robot): def get_observation(self) -> RobotObservation: """Get current robot observation from SDK. + Camera frames are retrieved from SDK endpoints /v2/front and /v2/rear. + Frames are decoded from base64 and converted from BGR to RGB format. + Robot telemetry is retrieved from /data endpoint. + Sensor arrays (accels, gyros, mags, rpms) each contain entries of + [values..., timestamp]; the latest reading from each array is used. + Returns: RobotObservation: Observation containing: - front: Front camera image (480, 640, 3) in RGB format - rear: Rear camera image (480, 640, 3) in RGB format - - linear.vel: Current speed (0-1, SDK reports only positive speeds) - - battery.level: Battery level (0-1, normalized from 0-100) - - orientation.deg: Robot orientation (0-1, normalized from raw value) - - gps.latitude: GPS latitude coordinate - - gps.longitude: GPS longitude coordinate - - gps.signal: GPS signal strength (0-1, normalized from percentage) - - signal.level: Network signal level (0-1, normalized from 0-5) - - vibration: Vibration sensor reading - - lamp.state: Lamp state (0=off, 1=on) + - speed: float - Current speed (raw SDK value) + - battery_level: float - Battery level (0-100) + - orientation: float - Robot orientation in degrees + - gps_latitude: float - GPS latitude coordinate + - gps_longitude: float - GPS longitude coordinate + - gps_signal: float - GPS signal strength (percentage) + - signal_level: float - Network signal level (0-5) + - vibration: float - Vibration sensor reading + - lamp: float - Lamp state (0=off, 1=on) + - accelerometer_x/y/z: float - Accelerometer axes (raw SDK value) + - gyroscope_x/y/z: float - Gyroscope axes (raw SDK value) + - magnetometer_filtered_x/y/z: float - Magnetometer axes (raw SDK value) + - wheel_rpm_0/1/2/3: float - Wheel RPMs Raises: DeviceNotConnectedError: If robot is not connected @@ -235,22 +291,41 @@ class EarthRoverMiniPlus(Robot): # Get robot state from SDK robot_data = self._get_robot_data() - # Motion state - observation[OBS_LINEAR_VEL] = robot_data["speed"] / 100.0 # Normalize 0-100 to 0-1 + # Telemetry + observation[OBS_SPEED] = float(robot_data["speed"]) + observation[OBS_BATTERY_LEVEL] = float(robot_data["battery"]) + observation[OBS_ORIENTATION] = float(robot_data["orientation"]) + observation[OBS_GPS_LATITUDE] = float(robot_data["latitude"]) + observation[OBS_GPS_LONGITUDE] = float(robot_data["longitude"]) + observation[OBS_GPS_SIGNAL] = float(robot_data["gps_signal"]) + observation[OBS_SIGNAL_LEVEL] = float(robot_data["signal_level"]) + observation[OBS_VIBRATION] = float(robot_data["vibration"]) + observation[OBS_LAMP] = float(robot_data["lamp"]) - # Robot state - observation[OBS_BATTERY_LEVEL] = robot_data["battery"] / 100.0 # Normalize 0-100 to 0-1 - observation[OBS_ORIENTATION_DEG] = robot_data["orientation"] / 360.0 # Normalize to 0-1 + # Accelerometer — latest reading from accels array [x, y, z, ts] + accel = self._latest_sensor_reading(robot_data, "accels", n_values=3) + observation[OBS_ACCELEROMETER_X] = accel[0] + observation[OBS_ACCELEROMETER_Y] = accel[1] + observation[OBS_ACCELEROMETER_Z] = accel[2] - # GPS data - observation[OBS_GPS_LATITUDE] = robot_data["latitude"] - observation[OBS_GPS_LONGITUDE] = robot_data["longitude"] - observation[OBS_GPS_SIGNAL] = robot_data["gps_signal"] / 100.0 # Normalize percentage to 0-1 + # Gyroscope — latest reading from gyros array [x, y, z, ts] + gyro = self._latest_sensor_reading(robot_data, "gyros", n_values=3) + observation[OBS_GYROSCOPE_X] = gyro[0] + observation[OBS_GYROSCOPE_Y] = gyro[1] + observation[OBS_GYROSCOPE_Z] = gyro[2] - # Sensors - observation[OBS_SIGNAL_LEVEL] = robot_data["signal_level"] / 5.0 # Normalize 0-5 to 0-1 - observation[OBS_VIBRATION] = robot_data["vibration"] - observation[OBS_LAMP_STATE] = float(robot_data["lamp"]) # 0 or 1 + # Magnetometer — latest reading from mags array [x, y, z, ts] + mag = self._latest_sensor_reading(robot_data, "mags", n_values=3) + observation[OBS_MAGNETOMETER_X] = mag[0] + observation[OBS_MAGNETOMETER_Y] = mag[1] + observation[OBS_MAGNETOMETER_Z] = mag[2] + + # Wheel RPMs — latest reading from rpms array [w0, w1, w2, w3, ts] + rpm = self._latest_sensor_reading(robot_data, "rpms", n_values=4) + observation[OBS_WHEEL_RPM_0] = rpm[0] + observation[OBS_WHEEL_RPM_1] = rpm[1] + observation[OBS_WHEEL_RPM_2] = rpm[2] + observation[OBS_WHEEL_RPM_3] = rpm[3] return observation @@ -260,11 +335,12 @@ class EarthRoverMiniPlus(Robot): Args: action: Action dict with keys: - - linear.vel: Target linear velocity (-1 to 1) - - angular.vel: Target angular velocity (-1 to 1) + - linear_velocity: Target linear velocity (-1 to 1) + - angular_velocity: Target angular velocity (-1 to 1) Returns: RobotAction: The action that was sent (matches action_features keys) + Raises: DeviceNotConnectedError: If robot is not connected @@ -272,18 +348,14 @@ class EarthRoverMiniPlus(Robot): Actions are sent to SDK via POST /control endpoint. SDK expects commands in range [-1, 1]. """ - - # Extract action values and convert to float linear = float(action.get(ACTION_LINEAR_VEL, 0.0)) angular = float(action.get(ACTION_ANGULAR_VEL, 0.0)) - # Send command to SDK try: self._send_command_to_sdk(linear, angular) except Exception as e: logger.error(f"Error sending action: {e}") - # Return action in format matching action_features return { ACTION_LINEAR_VEL: linear, ACTION_ANGULAR_VEL: angular, @@ -394,11 +466,27 @@ class EarthRoverMiniPlus(Robot): logger.error(f"Error decoding image: {e}") return None + @staticmethod + def _latest_sensor_reading(robot_data: dict, key: str, n_values: int) -> list[float]: + """Extract the latest sensor reading from an SDK sensor array. + + The SDK returns sensor arrays like ``accels``, ``gyros``, ``mags``, + ``rpms`` where each entry is ``[value_0, ..., value_n, timestamp]``. + This helper returns the *n_values* leading floats from the last entry, + falling back to zeros when the key is missing or the array is empty. + """ + readings = robot_data.get(key) + if readings and len(readings) > 0: + latest = readings[-1] + return [float(v) for v in latest[:n_values]] + return [0.0] * n_values + def _get_robot_data(self) -> dict: """Get robot telemetry data from SDK. Returns: - dict: Robot telemetry data including battery, speed, orientation, GPS, etc: + dict: Robot telemetry data including battery, speed, orientation, GPS, + and sensor arrays (accels, gyros, mags, rpms): - Current data (if request succeeds) - Cached data (if request fails but cache exists) - Default values (if request fails and no cache exists yet) @@ -420,19 +508,23 @@ class EarthRoverMiniPlus(Robot): # Fallback: use cache or default values if self._last_robot_data is not None: return self._last_robot_data - else: - # Return dict with default values (used only on first failure before any cache exists) - return { - "speed": 0, - "battery": 0, - "orientation": 0, - "latitude": 0.0, - "longitude": 0.0, - "gps_signal": 0, - "signal_level": 0, - "vibration": 0.0, - "lamp": 0, - } + + # Return dict with default values (used only on first failure before any cache exists) + return { + "speed": 0, + "battery": 0, + "orientation": 0, + "latitude": 0.0, + "longitude": 0.0, + "gps_signal": 0, + "signal_level": 0, + "vibration": 0.0, + "lamp": 0, + "accels": [], + "gyros": [], + "mags": [], + "rpms": [], + } def _send_command_to_sdk(self, linear: float, angular: float, lamp: int = 0) -> bool: """Send control command to SDK. diff --git a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py index 6c1ef7492..090aa7fae 100644 --- a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py +++ b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py @@ -341,8 +341,8 @@ class KeyboardRoverTeleop(KeyboardTeleop): def action_features(self) -> dict: """Return action format for rover (linear and angular velocities).""" return { - "linear.vel": float, - "angular.vel": float, + "linear_velocity": float, + "angular_velocity": float, } @property @@ -366,7 +366,7 @@ class KeyboardRoverTeleop(KeyboardTeleop): Get the current action based on pressed keys. Returns: - RobotAction with 'linear.vel' and 'angular.vel' keys + RobotAction with 'linear_velocity' and 'angular_velocity' keys. """ before_read_t = time.perf_counter() @@ -427,6 +427,6 @@ class KeyboardRoverTeleop(KeyboardTeleop): self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t return { - "linear.vel": linear_velocity, - "angular.vel": angular_velocity, + "linear_velocity": linear_velocity, + "angular_velocity": angular_velocity, } From e64fa667c3e3b8db8ad0bd16853d293d6de9c50a Mon Sep 17 00:00:00 2001 From: Altman <64389901+Altman-conquer@users.noreply.github.com> Date: Wed, 18 Mar 2026 20:24:07 +0800 Subject: [PATCH 118/131] fix(vqbet): use in-place fill_ to avoid overwriting DDP GPU buffers with CPU tensors (#3128) * fix(vqbet): use in-place fill_ to avoid overwriting DDP GPU buffers with CPU tensors When VQ discretization phase completes, the code was overwriting register_buffer('discretized') and register_buffer('freeze_codebook') with torch.tensor(True), which is created on CPU. DDP then fails in _sync_buffers() with: RuntimeError: No backend type associated with device type cpu. Fix by updating the buffers in-place with .fill_(True) so device and registration are preserved. Made-with: Cursor * test(vqbet): add regression test for in-place buffer update during discretization Verifies that discretize() updates the 'discretized' and 'freeze_codebook' registered buffers in-place (via fill_()) rather than replacing them with new CPU tensors. The test checks data_ptr() identity and that the tensors remain registered buffers after the call. This prevents regressions of the DDP fix. Made-with: Cursor * test(vqbet): add GPU regression test to verify buffers stay on CUDA after discretize() Directly catches the original DDP failure mode: when buffers are replaced with torch.tensor(True) they land on CPU, causing NCCL to raise 'No backend type associated with device type cpu' in _sync_buffers(). The GPU test places the model on cuda:0 and asserts both buffers remain on CUDA after discretization. Made-with: Cursor * test(vqbet): simplify to single device-check test in test_policies.py Per reviewer feedback: remove the separate test file and replace the two CPU/GPU tests (with data_ptr checks) with a single focused test in tests/policies/test_policies.py that only asserts the registered buffers remain on the model device after discretize(). Uses DEVICE from tests/utils.py so it runs on whatever device the CI/user selects (cpu, cuda, mps). Made-with: Cursor * style: fix import order in test_policies.py to pass ruff/pre-commit checks Made-with: Cursor --------- Co-authored-by: Zhan DiJia <2476100824@example.com> Co-authored-by: Khalil Meftah --- src/lerobot/policies/vqbet/modeling_vqbet.py | 4 +- tests/policies/test_policies.py | 44 ++++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/src/lerobot/policies/vqbet/modeling_vqbet.py b/src/lerobot/policies/vqbet/modeling_vqbet.py index 359b4fdb1..6d3976b79 100644 --- a/src/lerobot/policies/vqbet/modeling_vqbet.py +++ b/src/lerobot/policies/vqbet/modeling_vqbet.py @@ -467,8 +467,8 @@ class VQBeTHead(nn.Module): self.vqvae_model.optimized_steps += 1 # if we updated RVQ more than `n_vqvae_training_steps` steps, we freeze the RVQ part. if self.vqvae_model.optimized_steps >= n_vqvae_training_steps: - self.vqvae_model.discretized = torch.tensor(True) - self.vqvae_model.vq_layer.freeze_codebook = torch.tensor(True) + self.vqvae_model.discretized.fill_(True) + self.vqvae_model.vq_layer.freeze_codebook.fill_(True) print("Finished discretizing action data!") self.vqvae_model.eval() for param in self.vqvae_model.vq_layer.parameters(): diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index 1aae3fcc8..77a74d60e 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -42,6 +42,8 @@ from lerobot.policies.factory import ( make_pre_post_processors, ) from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig +from lerobot.policies.vqbet.modeling_vqbet import VQBeTHead from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE from lerobot.utils.random_utils import seeded_context from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats @@ -460,3 +462,45 @@ def test_act_temporal_ensembler(): assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max")) # Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error. torch.testing.assert_close(online_avg, offline_avg, rtol=1e-4, atol=1e-4) + + +def test_vqbet_discretize_keeps_buffers_on_device(): + """Regression test: VQBeTHead.discretize() must not move registered buffers off the model device. + + Previously, `self.vqvae_model.discretized = torch.tensor(True)` replaced the + registered buffer with a new CPU tensor, causing DDP to crash with: + RuntimeError: No backend type associated with device type cpu + The fix uses `.fill_(True)` to update in-place, preserving device placement. + """ + config = VQBeTConfig() + config.input_features = { + OBS_IMAGES: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 96, 96)), + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(6,)), + } + config.output_features = { + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)), + } + # Tiny sizes for fast CPU/GPU execution. + config.n_vqvae_training_steps = 3 + config.vqvae_n_embed = 8 + config.vqvae_embedding_dim = 32 + config.vqvae_enc_hidden_dim = 32 + config.action_chunk_size = 2 + config.crop_shape = (84, 84) + + head = VQBeTHead(config).to(DEVICE) + vqvae = head.vqvae_model + + dummy_actions = torch.randn(4, config.action_chunk_size, config.action_feature.shape[0], device=DEVICE) + n_steps = config.n_vqvae_training_steps + for _ in range(n_steps): + head.discretize(n_steps, dummy_actions) + + assert vqvae.discretized.device.type == torch.device(DEVICE).type, ( + "vqvae_model.discretized was moved off the model device after discretize(). " + "Use .fill_(True) instead of = torch.tensor(True) to keep the buffer on device." + ) + assert vqvae.vq_layer.freeze_codebook.device.type == torch.device(DEVICE).type, ( + "vq_layer.freeze_codebook was moved off the model device after discretize(). " + "Use .fill_(True) instead of = torch.tensor(True) to keep the buffer on device." + ) From f90db58c15a998a8ae37780ccd255bbde0130e00 Mon Sep 17 00:00:00 2001 From: Praedico Date: Fri, 20 Mar 2026 06:32:07 +0100 Subject: [PATCH 119/131] docs(async): fix GitHub issues link (#3186) --- docs/source/async.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/async.mdx b/docs/source/async.mdx index fcc3f1d1e..a46408a0d 100644 --- a/docs/source/async.mdx +++ b/docs/source/async.mdx @@ -310,4 +310,4 @@ Asynchronous inference represents a significant advancement in real-time robotic - **Universal Compatibility**: Works with all LeRobot-supported policies, from lightweight ACT models to vision-language models like SmolVLA Start experimenting with the default parameters, monitor your action queue sizes, and iteratively refine your setup to achieve optimal performance for your specific use case. -If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/lerobot/lerobot/issues). +If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/huggingface/lerobot/issues). From 017ff73fbfe46bf9a673cd9b402988dcb79151f7 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Mon, 23 Mar 2026 13:57:53 -0700 Subject: [PATCH 120/131] chore(docs): add rename map and empty cam guide (#3065) * add blog/guide * add to tree * chore(docs): rephrase rename_map docs for clarity and simplicity --------- Co-authored-by: Steven Palma Co-authored-by: Steven Palma --- docs/source/_toctree.yml | 2 + docs/source/rename_map.mdx | 114 +++++++++++++++++++++++++++++++++++++ 2 files changed, 116 insertions(+) create mode 100644 docs/source/rename_map.mdx diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 1055975d7..09d94d28c 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -19,6 +19,8 @@ title: Multi GPU training - local: peft_training title: Training with PEFT (e.g., LoRA) + - local: rename_map + title: Using Rename Map and Empty Cameras title: "Tutorials" - sections: - local: lerobot-dataset-v3 diff --git a/docs/source/rename_map.mdx b/docs/source/rename_map.mdx new file mode 100644 index 000000000..6249faaca --- /dev/null +++ b/docs/source/rename_map.mdx @@ -0,0 +1,114 @@ +# Rename Map and Empty Cameras + +When you train, evaluate, or record with a robot policy, your **dataset** or **environment** provides observations under one set of keys (e.g. `observation.images.front`, `observation.images.eagle`), while your **policy** expects another (e.g. `observation.images.image`, `observation.images.image2`). The **rename map** bridges that gap without changing the policy or data source. + +> **Scope:** The rename map only renames **observation** keys (images and state). Action keys are not affected. + +## Why observation keys don't always match + +Policies have a fixed set of **input feature names** baked into their pretrained config. For example: + +- [pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero) expects `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb`. +- [xvla-base](https://huggingface.co/lerobot/xvla-base) expects `observation.images.image`, `observation.images.image2`, and `observation.images.image3`. + +Your dataset might use different names entirely (e.g. `observation.images.front`, `observation.images.eagle`, `observation.images.glove`), and your eval environment might use yet another set. Rather than editing the policy config or renaming columns in the dataset, you pass a **rename map**: a JSON dictionary that maps source keys to the keys the policy expects. Renaming happens inside the preprocessor pipeline, so the policy always sees its expected keys. + +## Using the rename map + +Pass the mapping as a JSON string on the command line. The convention is always: + +``` +--rename_map='{"source_key": "policy_key", ...}' +``` + +where **source_key** is what the dataset or environment provides, and **policy_key** is what the policy expects. + +Only listed keys are renamed; everything else passes through unchanged. Order of entries doesn't matter. + +Supported policies: **PI0**, **PI05**, **PI0Fast**, **SmolVLA**, and **XVLA**. + +### Training + +Suppose you fine-tune [lerobot/xvla-base](https://huggingface.co/lerobot/xvla-base) on a dataset with images under `observation.images.front`, `observation.images.eagle`, and `observation.images.glove`. XVLA expects `observation.images.image`, `observation.images.image2`, and `observation.images.image3`: + +```bash +lerobot-train \ + --dataset.repo_id=YOUR_DATASET \ + --output_dir=./outputs/xvla_training \ + --job_name=xvla_training \ + --policy.path="lerobot/xvla-base" \ + --policy.repo_id="HF_USER/xvla-your-robot" \ + --policy.dtype=bfloat16 \ + --policy.action_mode=auto \ + --steps=20000 \ + --policy.device=cuda \ + --policy.freeze_vision_encoder=false \ + --policy.freeze_language_encoder=false \ + --policy.train_policy_transformer=true \ + --policy.train_soft_prompts=true \ + --rename_map='{"observation.images.front": "observation.images.image", "observation.images.eagle": "observation.images.image2", "observation.images.glove": "observation.images.image3"}' +``` + +### Evaluation + +A policy that expects `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb` (e.g. [pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero)), but the LIBERO environment returns `observation.images.image` and `observation.images.image2`: + +```bash +lerobot-eval \ + --policy.path=lerobot/pi0fast-libero \ + --env.type=libero \ + ... \ + --rename_map='{"observation.images.image": "observation.images.base_0_rgb", "observation.images.image2": "observation.images.left_wrist_0_rgb"}' +``` + +### Recording + +`lerobot-record` also supports rename maps, nested under the dataset config: + +```bash +lerobot-record \ # When running inference + --policy.path="/smolVLA_finetuned" \ + ... \ + --dataset.rename_map='{"observation.images.glove2": "observation.images.image"}' +``` + +## Alternative: edit the policy config directly + +If you always use the same dataset or environment, you can **edit the policy's `config.json`** so its observation keys match your data source. Then no rename map is needed. + +The tradeoff: modifying the policy config ties it to one data source. A rename map keeps one policy usable across many datasets and environments. + +## Empty cameras: fewer views than the policy expects + +Some policies are built for a fixed number of image inputs. If your dataset has fewer cameras, you can set **`empty_cameras`** in the policy config instead of modifying the model architecture. + +### How it works + +Setting `empty_cameras=N` adds N placeholder image features to the policy config, named: + +``` +observation.images.empty_camera_0 +observation.images.empty_camera_1 +... +``` + +At runtime, these keys have no corresponding data in the batch. The policy fills them with masked dummy tensors (padded with `-1` for SigLIP-based vision encoders, with a zero attention mask), so the extra image slots are effectively ignored during training and inference. + +### Example + +XVLA-base has three visual inputs and `empty_cameras=0` by default. Your dataset only has two cameras: + +1. Set `--policy.empty_cameras=1`. +2. The config adds a third key: `observation.images.empty_camera_0`. +3. Use the rename map for your two real cameras as usual. +4. The third slot is masked out — no fake images needed in your dataset. + +## Quick reference + +| Goal | What to do | +| ----------------------------------------- | --------------------------------------------------------------------------- | +| Dataset keys ≠ policy keys | `--rename_map='{"dataset_key": "policy_key", ...}'` | +| Env keys ≠ policy keys (eval) | `--rename_map='{"env_key": "policy_key", ...}'` | +| Recording with different keys (inference) | `--dataset.rename_map='{"source_key": "policy_key", ...}'`. | +| Fewer cameras than policy expects | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) | +| Avoid passing a rename map | Edit the policy's `config.json` so its keys match your data source | From 123495250b029f5f4bc4d8c91f8ac705a7e18426 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 26 Mar 2026 19:09:25 +0100 Subject: [PATCH 121/131] refactor(dataset): split LeRobotDataset into DatasetReader & DatasetWriter (+ API cleanup) (#3180) * refactor(dataset): split reader and writer * chore(dataset): remove proxys * refactor(dataset): better reader & writer encapsulation * refactor(datasets): clean API + reduce leaky implementations * refactor(dataset): API cleaning for writer, reader and meta * refactor(dataset): expose writer & reader + other minor improvements * refactor(dataset): improve teardown routine * refactor(dataset): add hf_dataset property at the facade level * chore(dataset): add init for datasset module * docs(dataset): add docstrings for public API of the dataset classes * tests(dataset): add tests for new classes * fix(dataset): remove circular dependecy --- docs/source/il_robots.mdx | 2 +- examples/backward_compatibility/replay.py | 2 +- examples/dataset/load_lerobot_dataset.py | 5 +- examples/lekiwi/replay.py | 6 +- examples/phone_to_so100/replay.py | 6 +- examples/so100_to_so100_EE/replay.py | 6 +- src/lerobot/datasets/__init__.py | 33 + src/lerobot/datasets/dataset_metadata.py | 172 ++- src/lerobot/datasets/dataset_reader.py | 288 ++++ src/lerobot/datasets/dataset_tools.py | 2 +- src/lerobot/datasets/dataset_writer.py | 625 ++++++++ src/lerobot/datasets/image_writer.py | 6 +- src/lerobot/datasets/lerobot_dataset.py | 1375 ++++++----------- src/lerobot/datasets/multi_dataset.py | 9 +- src/lerobot/datasets/video_utils.py | 50 +- src/lerobot/rl/buffer.py | 6 +- src/lerobot/rl/gym_manipulator.py | 3 +- src/lerobot/scripts/lerobot_record.py | 13 +- src/lerobot/scripts/lerobot_replay.py | 6 +- .../scripts/lerobot_train_tokenizer.py | 8 +- .../policies/save_policy_to_safetensors.py | 2 +- tests/datasets/test_dataset_metadata.py | 385 +++++ tests/datasets/test_dataset_reader.py | 168 ++ tests/datasets/test_dataset_writer.py | 226 +++ tests/datasets/test_datasets.py | 170 +- tests/datasets/test_image_writer.py | 8 +- tests/datasets/test_lerobot_dataset.py | 314 ++++ .../datasets/test_streaming_video_encoder.py | 4 +- 28 files changed, 2742 insertions(+), 1158 deletions(-) create mode 100644 src/lerobot/datasets/__init__.py create mode 100644 src/lerobot/datasets/dataset_reader.py create mode 100644 src/lerobot/datasets/dataset_writer.py create mode 100644 tests/datasets/test_dataset_metadata.py create mode 100644 tests/datasets/test_dataset_reader.py create mode 100644 tests/datasets/test_dataset_writer.py create mode 100644 tests/datasets/test_lerobot_dataset.py diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index 245634382..8e50a2aec 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -424,7 +424,7 @@ robot = SO100Follower(robot_config) robot.connect() dataset = LeRobotDataset("/", episodes=[episode_idx]) -actions = dataset.hf_dataset.select_columns("action") +actions = dataset.select_columns("action") log_say(f"Replaying episode {episode_idx}") for idx in range(dataset.num_frames): diff --git a/examples/backward_compatibility/replay.py b/examples/backward_compatibility/replay.py index 13fdfd5f5..e999b5913 100644 --- a/examples/backward_compatibility/replay.py +++ b/examples/backward_compatibility/replay.py @@ -78,7 +78,7 @@ def replay(cfg: ReplayConfig): robot = make_robot_from_config(cfg.robot) dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode]) - actions = dataset.hf_dataset.select_columns(ACTION) + actions = dataset.select_columns(ACTION) robot.connect() try: diff --git a/examples/dataset/load_lerobot_dataset.py b/examples/dataset/load_lerobot_dataset.py index ea3516710..44ae21a11 100644 --- a/examples/dataset/load_lerobot_dataset.py +++ b/examples/dataset/load_lerobot_dataset.py @@ -88,9 +88,8 @@ def main(): # The previous metadata class is contained in the 'meta' attribute of the dataset: print(dataset.meta) - # LeRobotDataset actually wraps an underlying Hugging Face dataset - # (see https://huggingface.co/docs/datasets for more information). - print(dataset.hf_dataset) + # You can inspect the dataset using its repr: + print(dataset) # LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working # with the latter, like iterating through the dataset. diff --git a/examples/lekiwi/replay.py b/examples/lekiwi/replay.py index cf89aea16..0cfd4811c 100644 --- a/examples/lekiwi/replay.py +++ b/examples/lekiwi/replay.py @@ -35,9 +35,7 @@ def main(): # Fetch the dataset to replay dataset = LeRobotDataset("/", episodes=[EPISODE_IDX]) - # Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0 - episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX) - actions = episode_frames.select_columns(ACTION) + actions = dataset.select_columns(ACTION) # Connect to the robot robot.connect() @@ -48,7 +46,7 @@ def main(): print("Starting replay loop...") log_say(f"Replaying episode {EPISODE_IDX}") - for idx in range(len(episode_frames)): + for idx in range(dataset.num_frames): t0 = time.perf_counter() # Get recorded action from dataset diff --git a/examples/phone_to_so100/replay.py b/examples/phone_to_so100/replay.py index 7b955cdb7..c544614a7 100644 --- a/examples/phone_to_so100/replay.py +++ b/examples/phone_to_so100/replay.py @@ -67,9 +67,7 @@ def main(): # Fetch the dataset to replay dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX]) - # Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0 - episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX) - actions = episode_frames.select_columns(ACTION) + actions = dataset.select_columns(ACTION) # Connect to the robot robot.connect() @@ -80,7 +78,7 @@ def main(): print("Starting replay loop...") log_say(f"Replaying episode {EPISODE_IDX}") - for idx in range(len(episode_frames)): + for idx in range(dataset.num_frames): t0 = time.perf_counter() # Get recorded action from dataset diff --git a/examples/so100_to_so100_EE/replay.py b/examples/so100_to_so100_EE/replay.py index b042e02dd..7caa560f0 100644 --- a/examples/so100_to_so100_EE/replay.py +++ b/examples/so100_to_so100_EE/replay.py @@ -68,9 +68,7 @@ def main(): # Fetch the dataset to replay dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX]) - # Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0 - episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX) - actions = episode_frames.select_columns(ACTION) + actions = dataset.select_columns(ACTION) # Connect to the robot robot.connect() @@ -81,7 +79,7 @@ def main(): print("Starting replay loop...") log_say(f"Replaying episode {EPISODE_IDX}") - for idx in range(len(episode_frames)): + for idx in range(dataset.num_frames): t0 = time.perf_counter() # Get recorded action from dataset diff --git a/src/lerobot/datasets/__init__.py b/src/lerobot/datasets/__init__.py new file mode 100644 index 000000000..42c4ab810 --- /dev/null +++ b/src/lerobot/datasets/__init__.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.multi_dataset import MultiLeRobotDataset +from lerobot.datasets.sampler import EpisodeAwareSampler +from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset +from lerobot.datasets.transforms import ImageTransforms, ImageTransformsConfig + +__all__ = [ + "EpisodeAwareSampler", + "ImageTransforms", + "ImageTransformsConfig", + "LeRobotDataset", + "LeRobotDatasetMetadata", + "MultiLeRobotDataset", + "StreamingLeRobotDataset", +] diff --git a/src/lerobot/datasets/dataset_metadata.py b/src/lerobot/datasets/dataset_metadata.py index 560a90a6e..a43ba07b4 100644 --- a/src/lerobot/datasets/dataset_metadata.py +++ b/src/lerobot/datasets/dataset_metadata.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib from pathlib import Path import numpy as np @@ -53,6 +54,13 @@ CODEBASE_VERSION = "v3.0" class LeRobotDatasetMetadata: + """Metadata container for a LeRobot dataset. + + Manages the ``info.json``, ``stats.json``, ``tasks.parquet``, and + ``episodes/`` parquet files that describe a dataset's structure, content, + and statistics. + """ + def __init__( self, repo_id: str, @@ -61,33 +69,51 @@ class LeRobotDatasetMetadata: force_cache_sync: bool = False, metadata_buffer_size: int = 10, ): + """Load or download metadata for an existing LeRobot dataset. + + Attempts to load metadata from local disk. If files are missing or + ``force_cache_sync`` is ``True``, downloads the ``meta/`` directory from + the Hub. + + Args: + repo_id: Repository identifier (e.g. ``'lerobot/aloha_sim'``). + root: Local directory for the dataset. Defaults to + ``$HF_LEROBOT_HOME/{repo_id}``. + revision: Git revision (branch, tag, or commit hash). Defaults to + the current codebase version. + force_cache_sync: If ``True``, re-download metadata from the Hub + even when local files exist. + metadata_buffer_size: Number of episode metadata records to buffer + in memory before flushing to parquet. + """ self.repo_id = repo_id self.revision = revision if revision else CODEBASE_VERSION self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id - self.writer = None + self._pq_writer = None self.latest_episode = None - self.metadata_buffer: list[dict] = [] - self.metadata_buffer_size = metadata_buffer_size + self._metadata_buffer: list[dict] = [] + self._metadata_buffer_size = metadata_buffer_size + self._finalized = False try: if force_cache_sync: raise FileNotFoundError - self.load_metadata() + self._load_metadata() except (FileNotFoundError, NotADirectoryError): if is_valid_version(self.revision): self.revision = get_safe_version(self.repo_id, self.revision) (self.root / "meta").mkdir(exist_ok=True, parents=True) - self.pull_from_repo(allow_patterns="meta/") - self.load_metadata() + self._pull_from_repo(allow_patterns="meta/") + self._load_metadata() def _flush_metadata_buffer(self) -> None: """Write all buffered episode metadata to parquet file.""" - if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0: + if not hasattr(self, "_metadata_buffer") or len(self._metadata_buffer) == 0: return combined_dict = {} - for episode_dict in self.metadata_buffer: + for episode_dict in self._metadata_buffer: for key, value in episode_dict.items(): if key not in combined_dict: combined_dict[key] = [] @@ -96,40 +122,50 @@ class LeRobotDatasetMetadata: val = value[0] if isinstance(value, list) else value combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val) - first_ep = self.metadata_buffer[0] + first_ep = self._metadata_buffer[0] chunk_idx = first_ep["meta/episodes/chunk_index"][0] file_idx = first_ep["meta/episodes/file_index"][0] table = pa.Table.from_pydict(combined_dict) - if not self.writer: + if not self._pq_writer: path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)) path.parent.mkdir(parents=True, exist_ok=True) - self.writer = pq.ParquetWriter( + self._pq_writer = pq.ParquetWriter( path, schema=table.schema, compression="snappy", use_dictionary=True ) - self.writer.write_table(table) + self._pq_writer.write_table(table) - self.latest_episode = self.metadata_buffer[-1] - self.metadata_buffer.clear() + self.latest_episode = self._metadata_buffer[-1] + self._metadata_buffer.clear() def _close_writer(self) -> None: """Close and cleanup the parquet writer if it exists.""" self._flush_metadata_buffer() - writer = getattr(self, "writer", None) + writer = getattr(self, "_pq_writer", None) if writer is not None: writer.close() - self.writer = None + self._pq_writer = None + + def finalize(self) -> None: + """Flush metadata buffer and close the parquet writer. + + Idempotent — safe to call multiple times. + """ + if getattr(self, "_finalized", False): + return + self._close_writer() + self._finalized = True def __del__(self): - """ - Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor - """ - self._close_writer() + """Safety net: flush and close parquet writer on garbage collection.""" + # During interpreter shutdown, referenced objects may already be collected. + with contextlib.suppress(Exception): + self.finalize() - def load_metadata(self): + def _load_metadata(self): self.info = load_info(self.root) check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) self.tasks = load_tasks(self.root) @@ -137,7 +173,7 @@ class LeRobotDatasetMetadata: self.episodes = load_episodes(self.root) self.stats = load_stats(self.root) - def pull_from_repo( + def _pull_from_repo( self, allow_patterns: list[str] | str | None = None, ignore_patterns: list[str] | str | None = None, @@ -153,6 +189,7 @@ class LeRobotDatasetMetadata: @property def url_root(self) -> str: + """Hugging Face Hub URL root for this dataset.""" return f"hf://datasets/{self.repo_id}" @property @@ -161,6 +198,17 @@ class LeRobotDatasetMetadata: return packaging.version.parse(self.info["codebase_version"]) def get_data_file_path(self, ep_index: int) -> Path: + """Return the relative parquet file path for the given episode index. + + Args: + ep_index: Zero-based episode index. + + Returns: + Path to the parquet file containing this episode's data. + + Raises: + IndexError: If ``ep_index`` is out of range. + """ if self.episodes is None: self.episodes = load_episodes(self.root) if ep_index >= len(self.episodes): @@ -174,6 +222,19 @@ class LeRobotDatasetMetadata: return Path(fpath) def get_video_file_path(self, ep_index: int, vid_key: str) -> Path: + """Return the relative video file path for the given episode and video key. + + Args: + ep_index: Zero-based episode index. + vid_key: Feature key identifying the video stream + (e.g. ``'observation.images.laptop'``). + + Returns: + Path to the video file containing this episode's frames. + + Raises: + IndexError: If ``ep_index`` is out of range. + """ if self.episodes is None: self.episodes = load_episodes(self.root) if ep_index >= len(self.episodes): @@ -277,6 +338,17 @@ class LeRobotDatasetMetadata: return None def save_episode_tasks(self, tasks: list[str]): + """Register tasks for the current episode and persist to disk. + + New tasks that do not already exist in the dataset are assigned + sequential task indices and appended to the tasks parquet file. + + Args: + tasks: List of unique task descriptions in natural language. + + Raises: + ValueError: If ``tasks`` contains duplicates. + """ if len(set(tasks)) != len(tasks): raise ValueError(f"Tasks are not unique: {tasks}") @@ -336,8 +408,8 @@ class LeRobotDatasetMetadata: latest_path = ( self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) - if self.writer is None - else self.writer.where + if self._pq_writer is None + else self._pq_writer.where ) if Path(latest_path).exists(): @@ -359,10 +431,10 @@ class LeRobotDatasetMetadata: episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames] # Add to buffer - self.metadata_buffer.append(episode_dict) + self._metadata_buffer.append(episode_dict) self.latest_episode = episode_dict - if len(self.metadata_buffer) >= self.metadata_buffer_size: + if len(self._metadata_buffer) >= self._metadata_buffer_size: self._flush_metadata_buffer() def save_episode( @@ -373,6 +445,20 @@ class LeRobotDatasetMetadata: episode_stats: dict[str, dict], episode_metadata: dict, ) -> None: + """Persist episode metadata, update dataset info, and aggregate stats. + + Writes the episode's metadata to the buffered parquet writer, increments + the total episode/frame counters in ``info.json``, and merges the + episode's statistics into the running dataset statistics. + + Args: + episode_index: Zero-based index of the episode being saved. + episode_length: Number of frames in this episode. + episode_tasks: List of task descriptions for this episode. + episode_stats: Per-feature statistics for this episode. + episode_metadata: Additional metadata (chunk/file indices, frame + ranges, video timestamps, etc.). + """ episode_dict = { "episode_index": episode_index, "tasks": episode_tasks, @@ -479,7 +565,32 @@ class LeRobotDatasetMetadata: data_files_size_in_mb: int | None = None, video_files_size_in_mb: int | None = None, ) -> "LeRobotDatasetMetadata": - """Creates metadata for a LeRobotDataset.""" + """Create metadata for a new LeRobot dataset from scratch. + + Initializes the ``info.json`` file on disk with the provided feature + schema and dataset settings. No episode data is written yet. + + Args: + repo_id: Repository identifier (e.g. ``'user/my_dataset'``). + fps: Frames per second used during data collection. + features: Feature specification dict mapping feature names to their + type/shape metadata. + robot_type: Optional robot type string stored in metadata. + root: Local directory for the dataset. Defaults to + ``$HF_LEROBOT_HOME/{repo_id}``. Must not already exist. + use_videos: If ``True``, visual modalities are encoded as MP4 videos. + metadata_buffer_size: Number of episode metadata records to buffer + before flushing to parquet. + chunks_size: Max number of files per chunk directory. ``None`` uses + the default. + data_files_size_in_mb: Max parquet file size in MB. ``None`` uses the + default. + video_files_size_in_mb: Max video file size in MB. ``None`` uses the + default. + + Returns: + A new :class:`LeRobotDatasetMetadata` instance. + """ obj = cls.__new__(cls) obj.repo_id = repo_id obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id @@ -510,8 +621,9 @@ class LeRobotDatasetMetadata: ) write_json(obj.info, obj.root / INFO_PATH) obj.revision = None - obj.writer = None + obj._pq_writer = None obj.latest_episode = None - obj.metadata_buffer = [] - obj.metadata_buffer_size = metadata_buffer_size + obj._metadata_buffer = [] + obj._metadata_buffer_size = metadata_buffer_size + obj._finalized = False return obj diff --git a/src/lerobot/datasets/dataset_reader.py b/src/lerobot/datasets/dataset_reader.py new file mode 100644 index 000000000..0233a3cf6 --- /dev/null +++ b/src/lerobot/datasets/dataset_reader.py @@ -0,0 +1,288 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Private reader component for LeRobotDataset. Handles random-access reading (HF dataset, delta indices, video decoding).""" + +from collections.abc import Callable +from pathlib import Path + +import datasets +import torch + +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.feature_utils import ( + check_delta_timestamps, + get_delta_indices, + get_hf_features_from_features, +) +from lerobot.datasets.io_utils import ( + hf_transform_to_torch, + load_nested_dataset, +) +from lerobot.datasets.video_utils import decode_video_frames + + +class DatasetReader: + """Encapsulates read-side state and methods for LeRobotDataset. + + Owns: hf_dataset, _absolute_to_relative_idx, delta_indices. + """ + + def __init__( + self, + meta: LeRobotDatasetMetadata, + root: Path, + episodes: list[int] | None, + tolerance_s: float, + video_backend: str, + delta_timestamps: dict[str, list[float]] | None, + image_transforms: Callable | None, + ): + """Initialize the reader with metadata, filtering, and transform config. + + The HF dataset is not loaded here — call :meth:`try_load` or + :meth:`load_and_activate` afterward. + + Args: + meta: Dataset metadata instance. + root: Local dataset root directory. + episodes: Optional list of episode indices to select. ``None`` + means all episodes. + tolerance_s: Timestamp synchronization tolerance in seconds. + video_backend: Video decoding backend identifier. + delta_timestamps: Optional dict mapping feature keys to lists of + relative timestamp offsets for temporal context windows. + image_transforms: Optional torchvision v2 transform applied to + visual features. + """ + self._meta = meta + self._root = root + self.episodes = episodes + self._tolerance_s = tolerance_s + self._video_backend = video_backend + self._image_transforms = image_transforms + + self.hf_dataset: datasets.Dataset | None = None + self._absolute_to_relative_idx: dict[int, int] | None = None + + # Setup delta_indices (doesn't depend on hf_dataset) + self.delta_indices = None + if delta_timestamps is not None: + check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s) + self.delta_indices = get_delta_indices(delta_timestamps, meta.fps) + + def try_load(self) -> bool: + """Attempt to load from local cache. Returns True if data is sufficient.""" + try: + self.hf_dataset = self._load_hf_dataset() + except (FileNotFoundError, NotADirectoryError): + self.hf_dataset = None + return False + if not self._check_cached_episodes_sufficient(): + self.hf_dataset = None + return False + self._build_index_mapping() + return True + + def load_and_activate(self) -> None: + """Load HF dataset from disk and build index mapping. Call after data is on disk.""" + self.hf_dataset = self._load_hf_dataset() + self._build_index_mapping() + + def _build_index_mapping(self) -> None: + """Build absolute-to-relative index mapping from loaded hf_dataset.""" + self._absolute_to_relative_idx = None + if self.episodes is not None and self.hf_dataset is not None: + self._absolute_to_relative_idx = { + abs_idx.item() if isinstance(abs_idx, torch.Tensor) else abs_idx: rel_idx + for rel_idx, abs_idx in enumerate(self.hf_dataset["index"]) + } + + @property + def num_frames(self) -> int: + """Number of frames in selected episodes.""" + if self.episodes is not None and self.hf_dataset is not None: + return len(self.hf_dataset) + return self._meta.total_frames + + @property + def num_episodes(self) -> int: + """Number of episodes selected.""" + return len(self.episodes) if self.episodes is not None else self._meta.total_episodes + + def _load_hf_dataset(self) -> datasets.Dataset: + """hf_dataset contains all the observations, states, actions, rewards, etc.""" + features = get_hf_features_from_features(self._meta.features) + hf_dataset = load_nested_dataset(self._root / "data", features=features, episodes=self.episodes) + hf_dataset.set_transform(hf_transform_to_torch) + return hf_dataset + + def _check_cached_episodes_sufficient(self) -> bool: + """Check if the cached dataset contains all requested episodes and their video files.""" + if self.hf_dataset is None or len(self.hf_dataset) == 0: + return False + + available_episodes = { + ep_idx.item() if isinstance(ep_idx, torch.Tensor) else ep_idx + for ep_idx in self.hf_dataset.unique("episode_index") + } + + if self.episodes is None: + requested_episodes = set(range(self._meta.total_episodes)) + else: + requested_episodes = set(self.episodes) + + if not requested_episodes.issubset(available_episodes): + return False + + if len(self._meta.video_keys) > 0: + for ep_idx in requested_episodes: + for vid_key in self._meta.video_keys: + video_path = self._root / self._meta.get_video_file_path(ep_idx, vid_key) + if not video_path.exists(): + return False + + return True + + def get_episodes_file_paths(self) -> list[Path]: + """Return deduplicated file paths (data + video) for selected episodes. + + Used to build the ``allow_patterns`` list for ``snapshot_download``. + """ + episodes = self.episodes if self.episodes is not None else list(range(self._meta.total_episodes)) + fpaths = [str(self._meta.get_data_file_path(ep_idx)) for ep_idx in episodes] + if len(self._meta.video_keys) > 0: + video_files = [ + str(self._meta.get_video_file_path(ep_idx, vid_key)) + for vid_key in self._meta.video_keys + for ep_idx in episodes + ] + fpaths += video_files + # episodes are stored in the same files, so we return unique paths only + fpaths = list(set(fpaths)) + return fpaths + + def _get_query_indices( + self, abs_idx: int, ep_idx: int + ) -> tuple[dict[str, list[int]], dict[str, torch.Tensor]]: + """Compute query indices for delta timestamps.""" + ep = self._meta.episodes[ep_idx] + ep_start = ep["dataset_from_index"] + ep_end = ep["dataset_to_index"] + query_indices = { + key: [max(ep_start, min(ep_end - 1, abs_idx + delta)) for delta in delta_idx] + for key, delta_idx in self.delta_indices.items() + } + padding = { + f"{key}_is_pad": torch.BoolTensor( + [(abs_idx + delta < ep_start) | (abs_idx + delta >= ep_end) for delta in delta_idx] + ) + for key, delta_idx in self.delta_indices.items() + } + return query_indices, padding + + def _get_query_timestamps( + self, + current_ts: float, + query_indices: dict[str, list[int]] | None = None, + ) -> dict[str, list[float]]: + query_timestamps = {} + for key in self._meta.video_keys: + if query_indices is not None and key in query_indices: + if self._absolute_to_relative_idx is not None: + relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]] + timestamps = self.hf_dataset[relative_indices]["timestamp"] + else: + timestamps = self.hf_dataset[query_indices[key]]["timestamp"] + query_timestamps[key] = torch.stack(timestamps).tolist() + else: + query_timestamps[key] = [current_ts] + + return query_timestamps + + def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict: + """Query dataset for indices across keys, skipping video keys.""" + result: dict = {} + for key, q_idx in query_indices.items(): + if key in self._meta.video_keys: + continue + relative_indices = ( + q_idx + if self._absolute_to_relative_idx is None + else [self._absolute_to_relative_idx[idx] for idx in q_idx] + ) + try: + result[key] = torch.stack(self.hf_dataset[key][relative_indices]) + except (KeyError, TypeError, IndexError): + result[key] = torch.stack(self.hf_dataset[relative_indices][key]) + return result + + def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]: + """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function + in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a + Segmentation Fault. + """ + ep = self._meta.episodes[ep_idx] + item = {} + for vid_key, query_ts in query_timestamps.items(): + from_timestamp = ep[f"videos/{vid_key}/from_timestamp"] + shifted_query_ts = [from_timestamp + ts for ts in query_ts] + + video_path = self._root / self._meta.get_video_file_path(ep_idx, vid_key) + frames = decode_video_frames(video_path, shifted_query_ts, self._tolerance_s, self._video_backend) + item[vid_key] = frames.squeeze(0) + + return item + + def get_item(self, idx) -> dict: + """Core __getitem__ logic. Assumes hf_dataset is loaded. + + ``idx`` is a *relative* index into the (possibly episode-filtered) + HF dataset, **not** the absolute frame index stored in the ``index`` + column. The absolute index is retrieved from the row itself. + """ + item = self.hf_dataset[idx] + ep_idx = item["episode_index"].item() + abs_idx = item["index"].item() + + query_indices = None + if self.delta_indices is not None: + query_indices, padding = self._get_query_indices(abs_idx, ep_idx) + query_result = self._query_hf_dataset(query_indices) + item = {**item, **padding} + for key, val in query_result.items(): + item[key] = val + + if len(self._meta.video_keys) > 0: + current_ts = item["timestamp"].item() + query_timestamps = self._get_query_timestamps(current_ts, query_indices) + video_frames = self._query_videos(query_timestamps, ep_idx) + item = {**video_frames, **item} + + if self._image_transforms is not None: + image_keys = self._meta.camera_keys + for cam in image_keys: + item[cam] = self._image_transforms(item[cam]) + + # Add task as a string + task_idx = item["task_index"].item() + item["task"] = self._meta.tasks.iloc[task_idx].name + + # add subtask information if available + if "subtask_index" in self._meta.features and self._meta.subtasks is not None: + subtask_idx = item["subtask_index"].item() + item["subtask"] = self._meta.subtasks.iloc[subtask_idx].name + + return item diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index 87cdc18e5..cd2b9fc7c 100644 --- a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -891,7 +891,7 @@ def _copy_and_reindex_episodes_metadata( total_frames += src_episode["length"] - dst_meta._close_writer() + dst_meta.finalize() dst_meta.info.update( { diff --git a/src/lerobot/datasets/dataset_writer.py b/src/lerobot/datasets/dataset_writer.py new file mode 100644 index 000000000..b74b18e0c --- /dev/null +++ b/src/lerobot/datasets/dataset_writer.py @@ -0,0 +1,625 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Private writer component for LeRobotDataset. Handles sequential recording (episode buffer, ParquetWriter, image writer, video encoding).""" + +from __future__ import annotations + +import concurrent.futures +import contextlib +import logging +import shutil +import tempfile +from pathlib import Path + +import datasets +import numpy as np +import pandas as pd +import PIL.Image +import pyarrow.parquet as pq +import torch + +from lerobot.datasets.compute_stats import compute_episode_stats +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.feature_utils import ( + get_hf_features_from_features, + validate_episode_buffer, + validate_frame, +) +from lerobot.datasets.image_writer import AsyncImageWriter, write_image +from lerobot.datasets.io_utils import ( + embed_images, + get_file_size_in_mb, + load_episodes, + write_info, +) +from lerobot.datasets.utils import ( + DEFAULT_EPISODES_PATH, + DEFAULT_IMAGE_PATH, + update_chunk_file_indices, +) +from lerobot.datasets.video_utils import ( + StreamingVideoEncoder, + concatenate_video_files, + encode_video_frames, + get_video_duration_in_s, +) + +logger = logging.getLogger(__name__) + + +def _encode_video_worker( + video_key: str, + episode_index: int, + root: Path, + fps: int, + vcodec: str = "libsvtav1", + encoder_threads: int | None = None, +) -> Path: + temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4" + fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0) + img_dir = (root / fpath).parent + encode_video_frames( + img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads + ) + shutil.rmtree(img_dir) + return temp_path + + +class DatasetWriter: + """Encapsulates write-side state and methods for LeRobotDataset. + + Owns: episode_buffer, image_writer, _pq_writer (ParquetWriter), _latest_episode, + _current_file_start_frame, _streaming_encoder, _episodes_since_last_encoding, _recorded_frames. + """ + + def __init__( + self, + meta: LeRobotDatasetMetadata, + root: Path, + vcodec: str, + encoder_threads: int | None, + batch_encoding_size: int, + streaming_encoder: StreamingVideoEncoder | None = None, + initial_frames: int = 0, + ): + """Initialize the writer with metadata, codec, and encoding config. + + Args: + meta: Dataset metadata instance (used for feature schema, chunk + settings, and episode persistence). + root: Local dataset root directory. + vcodec: Video codec for encoding (e.g. ``'libsvtav1'``, ``'h264'``). + encoder_threads: Threads per encoder instance. ``None`` for auto. + batch_encoding_size: Number of episodes to accumulate before + batch-encoding videos. + streaming_encoder: Optional pre-built :class:`StreamingVideoEncoder` + for real-time encoding. ``None`` disables streaming mode. + initial_frames: Starting frame count (non-zero when resuming). + """ + self._meta = meta + self._root = root + self._vcodec = vcodec + self._encoder_threads = encoder_threads + self._batch_encoding_size = batch_encoding_size + self._streaming_encoder = streaming_encoder + + # Writer state + self.image_writer: AsyncImageWriter | None = None + self.episode_buffer: dict = self._create_episode_buffer() + self._pq_writer: pq.ParquetWriter | None = None + self._latest_episode: dict | None = None + self._current_file_start_frame: int | None = None + self._episodes_since_last_encoding: int = 0 + self._recorded_frames: int = initial_frames + self._finalized = False + + def _create_episode_buffer(self, episode_index: int | None = None) -> dict: + current_ep_idx = self._meta.total_episodes if episode_index is None else episode_index + ep_buffer = {} + ep_buffer["size"] = 0 + ep_buffer["task"] = [] + for key in self._meta.features: + ep_buffer[key] = current_ep_idx if key == "episode_index" else [] + return ep_buffer + + def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path: + fpath = DEFAULT_IMAGE_PATH.format( + image_key=image_key, episode_index=episode_index, frame_index=frame_index + ) + return self._root / fpath + + def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path: + return self._get_image_file_path(episode_index, image_key, frame_index=0).parent + + def _save_image( + self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1 + ) -> None: + if self.image_writer is None: + if isinstance(image, torch.Tensor): + image = image.cpu().numpy() + write_image(image, fpath, compress_level=compress_level) + else: + self.image_writer.save_image(image=image, fpath=fpath, compress_level=compress_level) + + def add_frame(self, frame: dict) -> None: + """Add a frame to the episode_buffer. Images are written to a temporary directory.""" + # Convert torch to numpy if needed + for name in frame: + if isinstance(frame[name], torch.Tensor): + frame[name] = frame[name].numpy() + + validate_frame(frame, self._meta.features) + + if self.episode_buffer is None: + self.episode_buffer = self._create_episode_buffer() + + # Automatically add frame_index and timestamp to episode buffer + frame_index = self.episode_buffer["size"] + timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self._meta.fps + self.episode_buffer["frame_index"].append(frame_index) + self.episode_buffer["timestamp"].append(timestamp) + self.episode_buffer["task"].append(frame.pop("task")) + + # Start streaming encoder on first frame of episode + if frame_index == 0 and self._streaming_encoder is not None: + self._streaming_encoder.start_episode( + video_keys=list(self._meta.video_keys), + temp_dir=self._root, + ) + + # Add frame features to episode_buffer + for key in frame: + if key not in self._meta.features: + raise ValueError( + f"An element of the frame is not in the features. '{key}' not in '{self._meta.features.keys()}'." + ) + + if self._meta.features[key]["dtype"] == "video" and self._streaming_encoder is not None: + self._streaming_encoder.feed_frame(key, frame[key]) + self.episode_buffer[key].append(None) + elif self._meta.features[key]["dtype"] in ["image", "video"]: + img_path = self._get_image_file_path( + episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index + ) + if frame_index == 0: + img_path.parent.mkdir(parents=True, exist_ok=True) + compress_level = 1 if self._meta.features[key]["dtype"] == "video" else 6 + self._save_image(frame[key], img_path, compress_level) + self.episode_buffer[key].append(str(img_path)) + else: + self.episode_buffer[key].append(frame[key]) + + self.episode_buffer["size"] += 1 + + def save_episode( + self, + episode_data: dict | None = None, + parallel_encoding: bool = True, + ) -> None: + """Save the current episode in self.episode_buffer to disk.""" + episode_buffer = episode_data if episode_data is not None else self.episode_buffer + + validate_episode_buffer(episode_buffer, self._meta.total_episodes, self._meta.features) + + # size and task are special cases that won't be added to hf_dataset + episode_length = episode_buffer.pop("size") + tasks = episode_buffer.pop("task") + episode_tasks = list(set(tasks)) + episode_index = episode_buffer["episode_index"] + + episode_buffer["index"] = np.arange(self._meta.total_frames, self._meta.total_frames + episode_length) + episode_buffer["episode_index"] = np.full((episode_length,), episode_index) + + # Update tasks and task indices with new tasks if any + self._meta.save_episode_tasks(episode_tasks) + + # Given tasks in natural language, find their corresponding task indices + episode_buffer["task_index"] = np.array([self._meta.get_task_index(task) for task in tasks]) + + for key, ft in self._meta.features.items(): + if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]: + continue + episode_buffer[key] = np.stack(episode_buffer[key]) + + # Wait for image writer to end, so that episode stats over images can be computed + self._wait_image_writer() + + has_video_keys = len(self._meta.video_keys) > 0 + use_streaming = self._streaming_encoder is not None and has_video_keys + use_batched_encoding = self._batch_encoding_size > 1 + + if use_streaming: + non_video_buffer = { + k: v + for k, v in episode_buffer.items() + if self._meta.features.get(k, {}).get("dtype") not in ("video",) + } + non_video_features = {k: v for k, v in self._meta.features.items() if v["dtype"] != "video"} + ep_stats = compute_episode_stats(non_video_buffer, non_video_features) + else: + ep_stats = compute_episode_stats(episode_buffer, self._meta.features) + + ep_metadata = self._save_episode_data(episode_buffer) + + if use_streaming: + streaming_results = self._streaming_encoder.finish_episode() + for video_key in self._meta.video_keys: + temp_path, video_stats = streaming_results[video_key] + if video_stats is not None: + ep_stats[video_key] = { + k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0) + for k, v in video_stats.items() + } + ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path)) + elif has_video_keys and not use_batched_encoding: + num_cameras = len(self._meta.video_keys) + if parallel_encoding and num_cameras > 1: + with concurrent.futures.ProcessPoolExecutor(max_workers=num_cameras) as executor: + future_to_key = { + executor.submit( + _encode_video_worker, + video_key, + episode_index, + self._root, + self._meta.fps, + self._vcodec, + self._encoder_threads, + ): video_key + for video_key in self._meta.video_keys + } + + results = {} + for future in concurrent.futures.as_completed(future_to_key): + video_key = future_to_key[future] + try: + temp_path = future.result() + results[video_key] = temp_path + except Exception as exc: + logger.error(f"Video encoding failed for {video_key}: {exc}") + raise exc + + for video_key in self._meta.video_keys: + temp_path = results[video_key] + ep_metadata.update( + self._save_episode_video(video_key, episode_index, temp_path=temp_path) + ) + else: + for video_key in self._meta.video_keys: + ep_metadata.update(self._save_episode_video(video_key, episode_index)) + + # `meta.save_episode` need to be executed after encoding the videos + self._meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata) + + if has_video_keys and use_batched_encoding: + self._episodes_since_last_encoding += 1 + if self._episodes_since_last_encoding == self._batch_encoding_size: + start_ep = self._meta.total_episodes - self._batch_encoding_size + end_ep = self._meta.total_episodes + self._batch_save_episode_video(start_ep, end_ep) + self._episodes_since_last_encoding = 0 + + if episode_data is None: + self.clear_episode_buffer(delete_images=len(self._meta.image_keys) > 0) + + def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None: + """Batch save videos for multiple episodes.""" + if end_episode is None: + end_episode = self._meta.total_episodes + + logger.info( + f"Batch encoding {self._batch_encoding_size} videos for episodes {start_episode} to {end_episode - 1}" + ) + + chunk_idx = self._meta.episodes[start_episode]["data/chunk_index"] + file_idx = self._meta.episodes[start_episode]["data/file_index"] + episode_df_path = self._root / DEFAULT_EPISODES_PATH.format( + chunk_index=chunk_idx, file_index=file_idx + ) + episode_df = pd.read_parquet(episode_df_path) + + for ep_idx in range(start_episode, end_episode): + logger.info(f"Encoding videos for episode {ep_idx}") + + if ( + self._meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx + or self._meta.episodes[ep_idx]["data/file_index"] != file_idx + ): + episode_df.to_parquet(episode_df_path) + self._meta.episodes = load_episodes(self._root) + + chunk_idx = self._meta.episodes[ep_idx]["data/chunk_index"] + file_idx = self._meta.episodes[ep_idx]["data/file_index"] + episode_df_path = self._root / DEFAULT_EPISODES_PATH.format( + chunk_index=chunk_idx, file_index=file_idx + ) + episode_df = pd.read_parquet(episode_df_path) + + video_ep_metadata = {} + for video_key in self._meta.video_keys: + video_ep_metadata.update(self._save_episode_video(video_key, ep_idx)) + video_ep_metadata.pop("episode_index") + video_ep_df = pd.DataFrame(video_ep_metadata, index=[ep_idx]).convert_dtypes( + dtype_backend="pyarrow" + ) + + episode_df = episode_df.combine_first(video_ep_df) + episode_df.to_parquet(episode_df_path) + self._meta.episodes = load_episodes(self._root) + + def _save_episode_data(self, episode_buffer: dict) -> dict: + """Save episode data to a parquet file.""" + # Use metadata features as the authoritative schema + hf_features = get_hf_features_from_features(self._meta.features) + ep_dict = {key: episode_buffer[key] for key in hf_features} + ep_dataset = datasets.Dataset.from_dict(ep_dict, features=hf_features, split="train") + ep_dataset = embed_images(ep_dataset) + ep_num_frames = len(ep_dataset) + + if self._latest_episode is None: + chunk_idx, file_idx = 0, 0 + global_frame_index = 0 + self._current_file_start_frame = 0 + if self._meta.episodes is not None and len(self._meta.episodes) > 0: + latest_ep = self._meta.episodes[-1] + global_frame_index = latest_ep["dataset_to_index"] + chunk_idx = latest_ep["data/chunk_index"] + file_idx = latest_ep["data/file_index"] + + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self._meta.chunks_size) + self._current_file_start_frame = global_frame_index + else: + latest_ep = self._latest_episode + chunk_idx = latest_ep["data/chunk_index"] + file_idx = latest_ep["data/file_index"] + global_frame_index = latest_ep["index"][-1] + 1 + + latest_path = self._root / self._meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) + latest_size_in_mb = get_file_size_in_mb(latest_path) + + frames_in_current_file = global_frame_index - self._current_file_start_frame + av_size_per_frame = ( + latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0 + ) + + if latest_size_in_mb + av_size_per_frame * ep_num_frames >= self._meta.data_files_size_in_mb: + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self._meta.chunks_size) + self.close_writer() + self._current_file_start_frame = global_frame_index + + ep_dict["data/chunk_index"] = chunk_idx + ep_dict["data/file_index"] = file_idx + + path = self._root / self._meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) + path.parent.mkdir(parents=True, exist_ok=True) + + table = ep_dataset.with_format("arrow")[:] + if not self._pq_writer: + self._pq_writer = pq.ParquetWriter( + path, schema=table.schema, compression="snappy", use_dictionary=True + ) + self._pq_writer.write_table(table) + + metadata = { + "data/chunk_index": chunk_idx, + "data/file_index": file_idx, + "dataset_from_index": global_frame_index, + "dataset_to_index": global_frame_index + ep_num_frames, + } + + self._latest_episode = {**ep_dict, **metadata} + self._recorded_frames += ep_num_frames + + return metadata + + def _save_episode_video( + self, + video_key: str, + episode_index: int, + temp_path: Path | None = None, + ) -> dict: + if temp_path is None: + ep_path = self._encode_temporary_episode_video(video_key, episode_index) + else: + ep_path = temp_path + + ep_size_in_mb = get_file_size_in_mb(ep_path) + ep_duration_in_s = get_video_duration_in_s(ep_path) + + if ( + episode_index == 0 + or self._meta.latest_episode is None + or f"videos/{video_key}/chunk_index" not in self._meta.latest_episode + ): + chunk_idx, file_idx = 0, 0 + if self._meta.episodes is not None and len(self._meta.episodes) > 0: + old_chunk_idx = self._meta.episodes[-1][f"videos/{video_key}/chunk_index"] + old_file_idx = self._meta.episodes[-1][f"videos/{video_key}/file_index"] + chunk_idx, file_idx = update_chunk_file_indices( + old_chunk_idx, old_file_idx, self._meta.chunks_size + ) + latest_duration_in_s = 0.0 + new_path = self._root / self._meta.video_path.format( + video_key=video_key, chunk_index=chunk_idx, file_index=file_idx + ) + new_path.parent.mkdir(parents=True, exist_ok=True) + shutil.move(str(ep_path), str(new_path)) + else: + latest_ep = self._meta.latest_episode + chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"][0] + file_idx = latest_ep[f"videos/{video_key}/file_index"][0] + + latest_path = self._root / self._meta.video_path.format( + video_key=video_key, chunk_index=chunk_idx, file_index=file_idx + ) + latest_size_in_mb = get_file_size_in_mb(latest_path) + latest_duration_in_s = latest_ep[f"videos/{video_key}/to_timestamp"][0] + + if latest_size_in_mb + ep_size_in_mb >= self._meta.video_files_size_in_mb: + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self._meta.chunks_size) + new_path = self._root / self._meta.video_path.format( + video_key=video_key, chunk_index=chunk_idx, file_index=file_idx + ) + new_path.parent.mkdir(parents=True, exist_ok=True) + shutil.move(str(ep_path), str(new_path)) + latest_duration_in_s = 0.0 + else: + concatenate_video_files( + [latest_path, ep_path], + latest_path, + ) + + # Remove temporary directory + shutil.rmtree(str(ep_path.parent)) + + # Update video info (only needed when first episode is encoded) + if episode_index == 0: + self._meta.update_video_info(video_key) + write_info(self._meta.info, self._meta.root) + + metadata = { + "episode_index": episode_index, + f"videos/{video_key}/chunk_index": chunk_idx, + f"videos/{video_key}/file_index": file_idx, + f"videos/{video_key}/from_timestamp": latest_duration_in_s, + f"videos/{video_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s, + } + return metadata + + def clear_episode_buffer(self, delete_images: bool = True) -> None: + """Discard the current episode buffer and optionally delete temp images. + + Args: + delete_images: If ``True``, remove temporary image directories + written for the current episode. + """ + # Cancel streaming encoder if active + if self._streaming_encoder is not None: + self._streaming_encoder.cancel_episode() + + if delete_images: + if self.image_writer is not None: + self._wait_image_writer() + episode_index = self.episode_buffer["episode_index"] + # episode_index is `int` when freshly created, but becomes `np.ndarray` after + # save_episode() mutates the buffer. Handle both types here. + if isinstance(episode_index, np.ndarray): + episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0] + for cam_key in self._meta.image_keys: + img_dir = self._get_image_file_dir(episode_index, cam_key) + if img_dir.is_dir(): + shutil.rmtree(img_dir) + + self.episode_buffer = self._create_episode_buffer() + + def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None: + """Start an :class:`AsyncImageWriter` for background image persistence. + + Args: + num_processes: Number of subprocesses. ``0`` means threads only. + num_threads: Number of threads per process. + """ + if isinstance(self.image_writer, AsyncImageWriter): + logger.warning( + "You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset." + ) + + self.image_writer = AsyncImageWriter( + num_processes=num_processes, + num_threads=num_threads, + ) + + def stop_image_writer(self) -> None: + """Stop the image writer (needed before pickling the dataset for DataLoader).""" + if self.image_writer is not None: + self.image_writer.stop() + self.image_writer = None + + def _wait_image_writer(self) -> None: + """Wait for asynchronous image writer to finish.""" + if self.image_writer is not None: + self.image_writer.wait_until_done() + + def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path: + """Use ffmpeg to convert frames stored as png into mp4 videos.""" + return _encode_video_worker( + video_key, episode_index, self._root, self._meta.fps, self._vcodec, self._encoder_threads + ) + + def close_writer(self) -> None: + """Close and cleanup the parquet writer if it exists.""" + if self._pq_writer is not None: + self._pq_writer.close() + self._pq_writer = None + + def flush_pending_videos(self) -> None: + """Flush any pending video encoding (streaming or batch). + + For streaming encoding: closes the encoder. + For batch encoding: encodes any remaining episodes that haven't been batch-encoded yet. + """ + if self._streaming_encoder is not None: + self._streaming_encoder.close() + elif self._episodes_since_last_encoding > 0: + start_ep = self._meta.total_episodes - self._episodes_since_last_encoding + end_ep = self._meta.total_episodes + logger.info( + f"Encoding remaining {self._episodes_since_last_encoding} episodes, " + f"from episode {start_ep} to {end_ep - 1}" + ) + self._batch_save_episode_video(start_ep, end_ep) + + def cancel_pending_videos(self) -> None: + """Cancel any in-progress streaming encoding without flushing.""" + if self._streaming_encoder is not None: + self._streaming_encoder.cancel_episode() + + def cleanup_interrupted_episode(self, episode_index: int) -> None: + """Remove temporary image directories for an interrupted episode.""" + for key in self._meta.video_keys: + img_dir = self._get_image_file_path( + episode_index=episode_index, image_key=key, frame_index=0 + ).parent + if img_dir.exists(): + logger.debug( + f"Cleaning up interrupted episode images for episode {episode_index}, camera {key}" + ) + shutil.rmtree(img_dir) + + def finalize(self) -> None: + """Flush all pending work and release all resources. + + Idempotent — safe to call multiple times. + """ + if getattr(self, "_finalized", False): + return + # 1. Wait for async image writes to complete, then stop + if self.image_writer is not None: + self.image_writer.wait_until_done() + self.image_writer.stop() + self.image_writer = None + # 2. Flush pending video encoding (streaming or batch) + self.flush_pending_videos() + # 3. Close own parquet writer + self.close_writer() + # 4. Finalize metadata (idempotent) + self._meta.finalize() + self._finalized = True + + def __del__(self): + """Safety net: release resources on garbage collection.""" + # During interpreter shutdown, referenced objects may already be collected. + with contextlib.suppress(Exception): + self.finalize() diff --git a/src/lerobot/datasets/image_writer.py b/src/lerobot/datasets/image_writer.py index 9f40394de..603067757 100644 --- a/src/lerobot/datasets/image_writer.py +++ b/src/lerobot/datasets/image_writer.py @@ -32,10 +32,10 @@ def safe_stop_image_writer(func): return func(*args, **kwargs) except Exception as e: dataset = kwargs.get("dataset") - image_writer = getattr(dataset, "image_writer", None) if dataset else None - if image_writer is not None: + writer = getattr(dataset, "writer", None) if dataset else None + if writer is not None and writer.image_writer is not None: logger.warning("Waiting for image writer to terminate...") - image_writer.stop() + writer.image_writer.stop() raise e return wrapper diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 8f0600ba8..cba0c1cba 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -13,57 +13,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import concurrent.futures import contextlib import logging -import shutil -import tempfile from collections.abc import Callable from pathlib import Path import datasets -import numpy as np -import pandas as pd -import PIL.Image -import pyarrow.parquet as pq import torch import torch.utils from huggingface_hub import HfApi, snapshot_download from huggingface_hub.errors import RevisionNotFoundError -from lerobot.datasets.compute_stats import compute_episode_stats from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata -from lerobot.datasets.feature_utils import ( - check_delta_timestamps, - get_delta_indices, - get_hf_features_from_features, - validate_episode_buffer, - validate_frame, -) -from lerobot.datasets.image_writer import AsyncImageWriter, write_image -from lerobot.datasets.io_utils import ( - embed_images, - get_file_size_in_mb, - hf_transform_to_torch, - load_episodes, - load_nested_dataset, - write_info, -) +from lerobot.datasets.dataset_reader import DatasetReader +from lerobot.datasets.dataset_writer import DatasetWriter from lerobot.datasets.utils import ( - DEFAULT_EPISODES_PATH, - DEFAULT_IMAGE_PATH, create_lerobot_dataset_card, get_safe_version, is_valid_version, - update_chunk_file_indices, ) from lerobot.datasets.video_utils import ( StreamingVideoEncoder, - concatenate_video_files, - decode_video_frames, - encode_video_frames, get_safe_default_codec, - get_video_duration_in_s, resolve_vcodec, ) from lerobot.utils.constants import HF_LEROBOT_HOME @@ -71,24 +42,6 @@ from lerobot.utils.constants import HF_LEROBOT_HOME logger = logging.getLogger(__name__) -def _encode_video_worker( - video_key: str, - episode_index: int, - root: Path, - fps: int, - vcodec: str = "libsvtav1", - encoder_threads: int | None = None, -) -> Path: - temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4" - fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0) - img_dir = (root / fpath).parent - encode_video_frames( - img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads - ) - shutil.rmtree(img_dir) - return temp_path - - class LeRobotDataset(torch.utils.data.Dataset): def __init__( self, @@ -136,7 +89,7 @@ class LeRobotDataset(torch.utils.data.Dataset): - stats stores the dataset statistics of the different modalities for normalization - tasks contains the prompts for each task of the dataset, which can be used for task-conditioned training. - - hf_dataset (from datasets.Dataset), which will read any values from parquet files. + - data (backed by datasets.Dataset), which reads values from parquet files. - videos (optional) from which frames are loaded to be synchronous with data from parquet files. A typical LeRobotDataset looks like this from its root path: @@ -229,6 +182,11 @@ class LeRobotDataset(torch.utils.data.Dataset): encoder_threads (int | None, optional): Number of threads per encoder instance. None lets the codec auto-detect (default). Lower values reduce CPU usage per encoder. Maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc. + + Note: + Write-mode parameters (``streaming_encoding``, ``batch_encoding_size``) passed to + ``__init__`` are deprecated. Use :meth:`create` for new datasets or :meth:`resume` + to append to existing ones. """ super().__init__() self.repo_id = repo_id @@ -238,21 +196,11 @@ class LeRobotDataset(torch.utils.data.Dataset): self.episodes = episodes self.tolerance_s = tolerance_s self.revision = revision if revision else CODEBASE_VERSION - self.video_backend = video_backend if video_backend else get_safe_default_codec() - self.delta_indices = None - self.batch_encoding_size = batch_encoding_size - self.episodes_since_last_encoding = 0 - self.vcodec = resolve_vcodec(vcodec) + self._video_backend = video_backend if video_backend else get_safe_default_codec() + self._batch_encoding_size = batch_encoding_size + self._vcodec = resolve_vcodec(vcodec) self._encoder_threads = encoder_threads - # Unused attributes - self.image_writer = None - self.episode_buffer = None - self.writer = None - self.latest_episode = None - self._current_file_start_frame = None # Track the starting frame index of the current parquet file - self._streaming_encoder = None - self.root.mkdir(exist_ok=True, parents=True) # Load metadata @@ -260,64 +208,270 @@ class LeRobotDataset(torch.utils.data.Dataset): self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync ) - # Track dataset state for efficient incremental writing - self._lazy_loading = False - self._recorded_frames = self.meta.total_frames - self._writer_closed_for_reading = False + # Create reader (hf_dataset loaded below) + self.reader = DatasetReader( + meta=self.meta, + root=self.root, + episodes=episodes, + tolerance_s=tolerance_s, + video_backend=self._video_backend, + delta_timestamps=delta_timestamps, + image_transforms=image_transforms, + ) # Load actual data - try: - if force_cache_sync: - raise FileNotFoundError - self.hf_dataset = self.load_hf_dataset() - # Check if cached dataset contains all requested episodes - if not self._check_cached_episodes_sufficient(): - raise FileNotFoundError("Cached dataset doesn't contain all requested episodes") - except (FileNotFoundError, NotADirectoryError): + if force_cache_sync or not self.reader.try_load(): if is_valid_version(self.revision): self.revision = get_safe_version(self.repo_id, self.revision) - self.download(download_videos) - self.hf_dataset = self.load_hf_dataset() + self._download(download_videos) + self.reader.load_and_activate() - # Create mapping from absolute indices to relative indices when only a subset of the episodes are loaded - # Build a mapping: absolute_index -> relative_index_in_filtered_dataset - self._absolute_to_relative_idx = None - if self.episodes is not None: - self._absolute_to_relative_idx = { - abs_idx.item() if isinstance(abs_idx, torch.Tensor) else abs_idx: rel_idx - for rel_idx, abs_idx in enumerate(self.hf_dataset["index"]) - } + # Detect write-mode params for backward compatibility + _has_write_params = streaming_encoding or batch_encoding_size != 1 + if _has_write_params: + import warnings - # Setup delta_indices - if self.delta_timestamps is not None: - check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s) - self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps) - - # Initialize streaming encoder for resumed recording - if streaming_encoding and len(self.meta.video_keys) > 0: - self._streaming_encoder = StreamingVideoEncoder( - fps=self.meta.fps, - vcodec=self.vcodec, - pix_fmt="yuv420p", - g=2, - crf=30, - preset=None, - queue_maxsize=encoder_queue_maxsize, - encoder_threads=encoder_threads, + warnings.warn( + "Passing write-mode parameters (streaming_encoding, batch_encoding_size) to " + "LeRobotDataset.__init__() is deprecated. Use LeRobotDataset.resume() instead.", + DeprecationWarning, + stacklevel=2, ) - - def _close_writer(self) -> None: - """Close and cleanup the parquet writer if it exists.""" - writer = getattr(self, "writer", None) - if writer is not None: - writer.close() + streaming_enc = None + if streaming_encoding and len(self.meta.video_keys) > 0: + streaming_enc = self._build_streaming_encoder( + self.meta.fps, self._vcodec, encoder_queue_maxsize, encoder_threads + ) + self.writer = DatasetWriter( + meta=self.meta, + root=self.root, + vcodec=self._vcodec, + encoder_threads=encoder_threads, + batch_encoding_size=batch_encoding_size, + streaming_encoder=streaming_enc, + initial_frames=self.meta.total_frames, + ) + else: self.writer = None - def __del__(self): + self._is_finalized = False + + # ── Writer guard ────────────────────────────────────────────────── + + def _require_writer(self, method_name: str) -> None: + if self.writer is None: + raise RuntimeError( + f"Cannot call '{method_name}()' on a read-only dataset. " + f"Use LeRobotDataset.create() for new recording or " + f"LeRobotDataset.resume() for resume recording." + ) + + # ── Reader guard ────────────────────────────────────────────────── + + def _ensure_reader(self) -> DatasetReader: + """Lazily create the reader on first access.""" + if self.reader is None: + self.reader = DatasetReader( + meta=self.meta, + root=self.root, + episodes=self.episodes, + tolerance_s=self.tolerance_s, + video_backend=self._video_backend, + delta_timestamps=self.delta_timestamps, + image_transforms=self.image_transforms, + ) + return self.reader + + @staticmethod + def _build_streaming_encoder( + fps: int, + vcodec: str, + encoder_queue_maxsize: int, + encoder_threads: int | None, + ) -> StreamingVideoEncoder: + return StreamingVideoEncoder( + fps=fps, + vcodec=vcodec, + pix_fmt="yuv420p", + g=2, + crf=30, + preset=None, + queue_maxsize=encoder_queue_maxsize, + encoder_threads=encoder_threads, + ) + + # ── Metadata properties ─────────────────────────────────────────── + + @property + def fps(self) -> int: + """Frames per second used during data collection.""" + return self.meta.fps + + @property + def num_frames(self) -> int: + """Number of frames in selected episodes.""" + # Check directly instead of using _ensure_reader(): in write-only mode + # (create/resume) we rely on metadata rather than initializing a reader. + if self.reader is None: + return self.meta.total_frames + return self.reader.num_frames + + @property + def num_episodes(self) -> int: + """Number of episodes selected.""" + # Check directly instead of using _ensure_reader(): in write-only mode + # (create/resume) we rely on metadata rather than initializing a reader. + if self.reader is None: + return self.meta.total_episodes + return self.reader.num_episodes + + @property + def features(self) -> dict[str, dict]: + """Feature specification dict mapping feature names to their type/shape metadata.""" + return self.meta.features + + @property + def hf_dataset(self) -> datasets.Dataset: + """The underlying Hugging Face Dataset object""" + self.reader = self._ensure_reader() + if self.reader.hf_dataset is None: + self.reader.load_and_activate() + return self.reader.hf_dataset + + # ── Writer-delegated methods ────────────────────────────────────── + + def add_frame(self, frame: dict) -> None: + """Add a single frame to the current episode buffer. + + Delegates to :meth:`DatasetWriter.add_frame`. The dataset must be in + write mode (created via :meth:`create` or :meth:`resume`). + + Args: + frame: Dict mapping feature names to their values for this frame. + Must include a ``'task'`` key. Torch tensors are converted to numpy. + + Raises: + RuntimeError: If the dataset is read-only (no writer). """ - Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor + self._require_writer("add_frame") + self.writer.add_frame(frame) + + def save_episode(self, episode_data: dict | None = None, parallel_encoding: bool = True) -> None: + """Save the current episode buffer to disk. + + Delegates to :meth:`DatasetWriter.save_episode`. Encodes videos, writes + parquet data, and updates metadata. The episode buffer is reset afterward. + + Args: + episode_data: Optional pre-built episode dict. If ``None``, uses the + internal episode buffer populated by :meth:`add_frame`. + parallel_encoding: If ``True`` and multiple cameras exist, encode + videos in parallel using a process pool. + + Raises: + RuntimeError: If the dataset is read-only (no writer). """ - self._close_writer() + self._require_writer("save_episode") + self.writer.save_episode(episode_data, parallel_encoding) + + def clear_episode_buffer(self, delete_images: bool = True) -> None: + """Discard the current episode buffer without saving. + + Delegates to :meth:`DatasetWriter.clear_episode_buffer`. Useful for + discarding a failed or interrupted recording episode. + + Args: + delete_images: If ``True``, also remove temporary image files written + to disk for the current episode. + + Raises: + RuntimeError: If the dataset is read-only (no writer). + """ + self._require_writer("clear_episode_buffer") + self.writer.clear_episode_buffer(delete_images) + + def has_pending_frames(self) -> bool: + """Check if there are unsaved frames in the episode buffer.""" + if self.writer is None: + return False + return self.writer.episode_buffer is not None and self.writer.episode_buffer["size"] > 0 + + def finalize(self): + """Flush all pending work and close writers. + + Must be called after data collection/conversion, otherwise footer metadata + won't be written to the parquet files and the dataset will be invalid. + + Idempotent — safe to call multiple times. DatasetWriter.__del__ acts as a + safety net if this is never called explicitly. + """ + if self._is_finalized: + return + if self.writer is not None: + self.writer.finalize() + self._is_finalized = True + + # ── Core Dataset methods ────────────────────────────────────────── + + def __len__(self): + """Return the number of frames in the selected episodes.""" + return self.num_frames + + def __getitem__(self, idx) -> dict: + """Return a single frame by index, with all transforms applied. + + Loads the frame from the underlying HF dataset, expands delta-timestamp + windows, decodes video frames, and applies image transforms. Delegates + the core logic to :meth:`DatasetReader.get_item`. + + Args: + idx: Index into the (possibly episode-filtered) dataset. + + Returns: + Dict mapping feature names to their tensor values for this frame. + + Raises: + RuntimeError: If the dataset is currently being recorded and + :meth:`finalize` has not been called yet. + """ + if self.writer is not None and not self._is_finalized: + raise RuntimeError( + "Cannot read from a dataset that is being recorded. Call finalize() first, then access items." + ) + reader = self._ensure_reader() + if reader.hf_dataset is None: + # One-shot load after finalize() + reader.load_and_activate() + return reader.get_item(idx) + + def select_columns(self, column_names: str | list[str]): + """Select specific columns from the underlying dataset. + + Useful for extracting action sequences during replay without loading all features. + Returns a ``datasets.Dataset`` containing only the requested columns. + """ + return self.hf_dataset.select_columns(column_names) + + def get_raw_item(self, idx) -> dict: + """Get a raw frame without image transforms applied. + + Unlike ``__getitem__``, this returns the raw HF dataset row at the given + index with no delta-timestamp expansion, video decoding, or image transforms. + """ + return self.hf_dataset[idx] + + def __repr__(self): + feature_keys = list(self.features) + return ( + f"{self.__class__.__name__}({{\n" + f" Repository ID: '{self.repo_id}',\n" + f" Number of selected episodes: '{self.num_episodes}',\n" + f" Number of selected samples: '{self.num_frames}',\n" + f" Features: '{feature_keys}',\n" + f"}})" + ) + + # ── Hub methods (stay on facade) ────────────────────────────────── def push_to_hub( self, @@ -331,6 +485,27 @@ class LeRobotDataset(torch.utils.data.Dataset): upload_large_folder: bool = False, **card_kwargs, ) -> None: + """Upload the dataset to the Hugging Face Hub. + + Creates the repository if it does not exist, uploads all dataset files + (optionally excluding videos), generates a dataset card, and tags the + revision with the current codebase version. + + Args: + branch: Optional branch to push to. Created from the current + revision if it does not exist. + tags: Optional list of tags for the dataset card. + license: License identifier for the dataset card. + tag_version: If ``True``, create a Git tag for the current codebase + version. + push_videos: If ``False``, skip uploading the ``videos/`` directory. + private: If ``True``, create a private repository. + allow_patterns: Glob pattern(s) restricting which files to upload. + upload_large_folder: If ``True``, use ``upload_large_folder`` instead + of ``upload_folder`` for very large datasets. + **card_kwargs: Additional keyword arguments forwarded to dataset card + creation. + """ ignore_patterns = ["images/"] if not push_videos: ignore_patterns.append("videos/") @@ -374,795 +549,23 @@ class LeRobotDataset(torch.utils.data.Dataset): hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset") hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") - def pull_from_repo( - self, - allow_patterns: list[str] | str | None = None, - ignore_patterns: list[str] | str | None = None, - ) -> None: + def _download(self, download_videos: bool = True) -> None: + """Downloads the dataset from the given 'repo_id' at the provided version.""" + ignore_patterns = None if download_videos else "videos/" + files = None + if self.episodes is not None: + # Reader is guaranteed to exist here (created in __init__ before _download) + files = self.reader.get_episodes_file_paths() snapshot_download( self.repo_id, repo_type="dataset", revision=self.revision, local_dir=self.root, - allow_patterns=allow_patterns, + allow_patterns=files, ignore_patterns=ignore_patterns, ) - def download(self, download_videos: bool = True) -> None: - """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this - will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole - dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present - in 'local_dir', they won't be downloaded again. - """ - # TODO(rcadene, aliberts): implement faster transfer - # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads - ignore_patterns = None if download_videos else "videos/" - files = None - if self.episodes is not None: - files = self.get_episodes_file_paths() - self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns) - - def get_episodes_file_paths(self) -> list[Path]: - episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes)) - fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes] - if len(self.meta.video_keys) > 0: - video_files = [ - str(self.meta.get_video_file_path(ep_idx, vid_key)) - for vid_key in self.meta.video_keys - for ep_idx in episodes - ] - fpaths += video_files - # episodes are stored in the same files, so we return unique paths only - fpaths = list(set(fpaths)) - return fpaths - - def load_hf_dataset(self) -> datasets.Dataset: - """hf_dataset contains all the observations, states, actions, rewards, etc.""" - features = get_hf_features_from_features(self.features) - hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes) - hf_dataset.set_transform(hf_transform_to_torch) - return hf_dataset - - def _check_cached_episodes_sufficient(self) -> bool: - """Check if the cached dataset contains all requested episodes and their video files.""" - if self.hf_dataset is None or len(self.hf_dataset) == 0: - return False - - # Get available episode indices from cached dataset - available_episodes = { - ep_idx.item() if isinstance(ep_idx, torch.Tensor) else ep_idx - for ep_idx in self.hf_dataset.unique("episode_index") - } - - # Determine requested episodes - if self.episodes is None: - requested_episodes = set(range(self.meta.total_episodes)) - else: - requested_episodes = set(self.episodes) - - # Check if all requested episodes are available in cached data - if not requested_episodes.issubset(available_episodes): - return False - - # Check if all required video files exist - if len(self.meta.video_keys) > 0: - for ep_idx in requested_episodes: - for vid_key in self.meta.video_keys: - video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key) - if not video_path.exists(): - return False - - return True - - def create_hf_dataset(self) -> datasets.Dataset: - features = get_hf_features_from_features(self.features) - ft_dict = {col: [] for col in features} - hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train") - hf_dataset.set_transform(hf_transform_to_torch) - return hf_dataset - - @property - def fps(self) -> int: - """Frames per second used during data collection.""" - return self.meta.fps - - @property - def num_frames(self) -> int: - """Number of frames in selected episodes. - - Note: When episodes a subset of the full dataset is requested, we must return the - actual loaded data length (len(self.hf_dataset)) rather than metadata total_frames. - self.meta.total_frames is the total number of frames in the full dataset. - """ - if self.episodes is not None and self.hf_dataset is not None: - return len(self.hf_dataset) - return self.meta.total_frames - - @property - def num_episodes(self) -> int: - """Number of episodes selected.""" - return len(self.episodes) if self.episodes is not None else self.meta.total_episodes - - @property - def features(self) -> dict[str, dict]: - return self.meta.features - - @property - def hf_features(self) -> datasets.Features: - """Features of the hf_dataset.""" - if self.hf_dataset is not None: - return self.hf_dataset.features - else: - return get_hf_features_from_features(self.features) - - def _get_query_indices( - self, abs_idx: int, ep_idx: int - ) -> tuple[dict[str, list[int]], dict[str, torch.Tensor]]: - """Compute query indices for delta timestamps. - - Args: - abs_idx: The absolute index in the full dataset (not the relative index in filtered episodes). - ep_idx: The episode index. - - Returns: - A tuple of (query_indices, padding) where: - - query_indices: Dict mapping keys to lists of absolute indices to query - - padding: Dict mapping "{key}_is_pad" to boolean tensors indicating padded positions - """ - ep = self.meta.episodes[ep_idx] - ep_start = ep["dataset_from_index"] - ep_end = ep["dataset_to_index"] - query_indices = { - key: [max(ep_start, min(ep_end - 1, abs_idx + delta)) for delta in delta_idx] - for key, delta_idx in self.delta_indices.items() - } - padding = { # Pad values outside of current episode range - f"{key}_is_pad": torch.BoolTensor( - [(abs_idx + delta < ep_start) | (abs_idx + delta >= ep_end) for delta in delta_idx] - ) - for key, delta_idx in self.delta_indices.items() - } - return query_indices, padding - - def _get_query_timestamps( - self, - current_ts: float, - query_indices: dict[str, list[int]] | None = None, - ) -> dict[str, list[float]]: - query_timestamps = {} - for key in self.meta.video_keys: - if query_indices is not None and key in query_indices: - if self._absolute_to_relative_idx is not None: - relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]] - timestamps = self.hf_dataset[relative_indices]["timestamp"] - else: - timestamps = self.hf_dataset[query_indices[key]]["timestamp"] - query_timestamps[key] = torch.stack(timestamps).tolist() - else: - query_timestamps[key] = [current_ts] - - return query_timestamps - - def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict: - """ - Query dataset for indices across keys, skipping video keys. - - Tries column-first [key][indices] for speed, falls back to row-first. - - Args: - query_indices: Dict mapping keys to index lists to retrieve - - Returns: - Dict with stacked tensors of queried data (video keys excluded) - """ - result: dict = {} - for key, q_idx in query_indices.items(): - if key in self.meta.video_keys: - continue - # Map absolute indices to relative indices if needed - relative_indices = ( - q_idx - if self._absolute_to_relative_idx is None - else [self._absolute_to_relative_idx[idx] for idx in q_idx] - ) - try: - result[key] = torch.stack(self.hf_dataset[key][relative_indices]) - except (KeyError, TypeError, IndexError): - result[key] = torch.stack(self.hf_dataset[relative_indices][key]) - return result - - def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]: - """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function - in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a - Segmentation Fault. This probably happens because a memory reference to the video loader is created in - the main process and a subprocess fails to access it. - """ - ep = self.meta.episodes[ep_idx] - item = {} - for vid_key, query_ts in query_timestamps.items(): - # Episodes are stored sequentially on a single mp4 to reduce the number of files. - # Thus we load the start timestamp of the episode on this mp4 and, - # shift the query timestamp accordingly. - from_timestamp = ep[f"videos/{vid_key}/from_timestamp"] - shifted_query_ts = [from_timestamp + ts for ts in query_ts] - - video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key) - frames = decode_video_frames(video_path, shifted_query_ts, self.tolerance_s, self.video_backend) - item[vid_key] = frames.squeeze(0) - - return item - - def _ensure_hf_dataset_loaded(self): - """Lazy load the HF dataset only when needed for reading.""" - if self._lazy_loading or self.hf_dataset is None: - # Close the writer before loading to ensure parquet file is properly finalized - if self.writer is not None: - self._close_writer() - self._writer_closed_for_reading = True - self.hf_dataset = self.load_hf_dataset() - self._lazy_loading = False - - def __len__(self): - return self.num_frames - - def __getitem__(self, idx) -> dict: - # Ensure dataset is loaded when we actually need to read from it - self._ensure_hf_dataset_loaded() - item = self.hf_dataset[idx] - ep_idx = item["episode_index"].item() - # Use the absolute index from the dataset for delta timestamp calculations - abs_idx = item["index"].item() - - query_indices = None - if self.delta_indices is not None: - query_indices, padding = self._get_query_indices(abs_idx, ep_idx) - query_result = self._query_hf_dataset(query_indices) - item = {**item, **padding} - for key, val in query_result.items(): - item[key] = val - - if len(self.meta.video_keys) > 0: - current_ts = item["timestamp"].item() - query_timestamps = self._get_query_timestamps(current_ts, query_indices) - video_frames = self._query_videos(query_timestamps, ep_idx) - item = {**video_frames, **item} - - if self.image_transforms is not None: - image_keys = self.meta.camera_keys - for cam in image_keys: - item[cam] = self.image_transforms(item[cam]) - - # Add task as a string - task_idx = item["task_index"].item() - item["task"] = self.meta.tasks.iloc[task_idx].name - - # add subtask information if available - if "subtask_index" in self.features and self.meta.subtasks is not None: - subtask_idx = item["subtask_index"].item() - item["subtask"] = self.meta.subtasks.iloc[subtask_idx].name - - return item - - def __repr__(self): - feature_keys = list(self.features) - return ( - f"{self.__class__.__name__}({{\n" - f" Repository ID: '{self.repo_id}',\n" - f" Number of selected episodes: '{self.num_episodes}',\n" - f" Number of selected samples: '{self.num_frames}',\n" - f" Features: '{feature_keys}',\n" - "})',\n" - ) - - def finalize(self): - """ - Close the parquet writers. This function needs to be called after data collection/conversion, else footer metadata won't be written to the parquet files. - The dataset won't be valid and can't be loaded as ds = LeRobotDataset(repo_id=repo, root=HF_LEROBOT_HOME.joinpath(repo)) - """ - self._close_writer() - self.meta._close_writer() - if self._streaming_encoder is not None: - self._streaming_encoder.close() - - def create_episode_buffer(self, episode_index: int | None = None) -> dict: - current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index - ep_buffer = {} - # size and task are special cases that are not in self.features - ep_buffer["size"] = 0 - ep_buffer["task"] = [] - for key in self.features: - ep_buffer[key] = current_ep_idx if key == "episode_index" else [] - return ep_buffer - - # TODO(Steven): consider move this to utils - def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path: - fpath = DEFAULT_IMAGE_PATH.format( - image_key=image_key, episode_index=episode_index, frame_index=frame_index - ) - return self.root / fpath - - def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path: - return self._get_image_file_path(episode_index, image_key, frame_index=0).parent - - def _save_image( - self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1 - ) -> None: - if self.image_writer is None: - if isinstance(image, torch.Tensor): - image = image.cpu().numpy() - write_image(image, fpath, compress_level=compress_level) - else: - self.image_writer.save_image(image=image, fpath=fpath, compress_level=compress_level) - - def add_frame(self, frame: dict) -> None: - """ - This function only adds the frame to the episode_buffer. Apart from images — which are written in a - temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method - then needs to be called. - """ - # Convert torch to numpy if needed - for name in frame: - if isinstance(frame[name], torch.Tensor): - frame[name] = frame[name].numpy() - - validate_frame(frame, self.features) - - if self.episode_buffer is None: - self.episode_buffer = self.create_episode_buffer() - - # Automatically add frame_index and timestamp to episode buffer - frame_index = self.episode_buffer["size"] - timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps - self.episode_buffer["frame_index"].append(frame_index) - self.episode_buffer["timestamp"].append(timestamp) - self.episode_buffer["task"].append(frame.pop("task")) # Remove task from frame after processing - - # Start streaming encoder on first frame of episode (once, before iterating keys) - if frame_index == 0 and self._streaming_encoder is not None: - self._streaming_encoder.start_episode( - video_keys=list(self.meta.video_keys), - temp_dir=self.root, - ) - - # Add frame features to episode_buffer - for key in frame: - if key not in self.features: - raise ValueError( - f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'." - ) - - if self.features[key]["dtype"] == "video" and self._streaming_encoder is not None: - self._streaming_encoder.feed_frame(key, frame[key]) - self.episode_buffer[key].append(None) # Placeholder (video keys are skipped in parquet) - elif self.features[key]["dtype"] in ["image", "video"]: - img_path = self._get_image_file_path( - episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index - ) - if frame_index == 0: - img_path.parent.mkdir(parents=True, exist_ok=True) - compress_level = 1 if self.features[key]["dtype"] == "video" else 6 - self._save_image(frame[key], img_path, compress_level) - self.episode_buffer[key].append(str(img_path)) - else: - self.episode_buffer[key].append(frame[key]) - - self.episode_buffer["size"] += 1 - - def save_episode( - self, - episode_data: dict | None = None, - parallel_encoding: bool = True, - ) -> None: - """ - This will save to disk the current episode in self.episode_buffer. - - Video encoding is handled automatically based on batch_encoding_size: - - If batch_encoding_size == 1: Videos are encoded immediately after each episode - - If batch_encoding_size > 1: Videos are encoded in batches. - - Args: - episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will - save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to - None. - parallel_encoding (bool, optional): If True, encode videos in parallel using ProcessPoolExecutor. - Defaults to True on Linux, False on macOS as it tends to use all the CPU available already. - """ - episode_buffer = episode_data if episode_data is not None else self.episode_buffer - - validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features) - - # size and task are special cases that won't be added to hf_dataset - episode_length = episode_buffer.pop("size") - tasks = episode_buffer.pop("task") - episode_tasks = list(set(tasks)) - episode_index = episode_buffer["episode_index"] - - episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length) - episode_buffer["episode_index"] = np.full((episode_length,), episode_index) - - # Update tasks and task indices with new tasks if any - self.meta.save_episode_tasks(episode_tasks) - - # Given tasks in natural language, find their corresponding task indices - episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks]) - - for key, ft in self.features.items(): - # index, episode_index, task_index are already processed above, and image and video - # are processed separately by storing image path and frame info as meta data - if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]: - continue - episode_buffer[key] = np.stack(episode_buffer[key]) - - # Wait for image writer to end, so that episode stats over images can be computed - self._wait_image_writer() - - has_video_keys = len(self.meta.video_keys) > 0 - use_streaming = self._streaming_encoder is not None and has_video_keys - use_batched_encoding = self.batch_encoding_size > 1 - - if use_streaming: - # Compute stats for non-video features only (video stats come from encoder) - non_video_buffer = { - k: v - for k, v in episode_buffer.items() - if self.features.get(k, {}).get("dtype") not in ("video",) - } - non_video_features = {k: v for k, v in self.features.items() if v["dtype"] != "video"} - ep_stats = compute_episode_stats(non_video_buffer, non_video_features) - else: - ep_stats = compute_episode_stats(episode_buffer, self.features) - - ep_metadata = self._save_episode_data(episode_buffer) - - if use_streaming: - # Finish streaming encoding and collect results - streaming_results = self._streaming_encoder.finish_episode() - for video_key in self.meta.video_keys: - temp_path, video_stats = streaming_results[video_key] - if video_stats is not None: - # Format stats same as compute_episode_stats: normalize to [0,1], reshape to (C,1,1) - ep_stats[video_key] = { - k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0) - for k, v in video_stats.items() - } - ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path)) - elif has_video_keys and not use_batched_encoding: - num_cameras = len(self.meta.video_keys) - if parallel_encoding and num_cameras > 1: - # TODO(Steven): Ideally we would like to control the number of threads per encoding such that: - # num_cameras * num_threads = (total_cpu -1) - with concurrent.futures.ProcessPoolExecutor(max_workers=num_cameras) as executor: - future_to_key = { - executor.submit( - _encode_video_worker, - video_key, - episode_index, - self.root, - self.fps, - self.vcodec, - self._encoder_threads, - ): video_key - for video_key in self.meta.video_keys - } - - results = {} - for future in concurrent.futures.as_completed(future_to_key): - video_key = future_to_key[future] - try: - temp_path = future.result() - results[video_key] = temp_path - except Exception as exc: - logger.error(f"Video encoding failed for {video_key}: {exc}") - raise exc - - for video_key in self.meta.video_keys: - temp_path = results[video_key] - ep_metadata.update( - self._save_episode_video(video_key, episode_index, temp_path=temp_path) - ) - else: - for video_key in self.meta.video_keys: - ep_metadata.update(self._save_episode_video(video_key, episode_index)) - - # `meta.save_episode` need to be executed after encoding the videos - self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata) - - if has_video_keys and use_batched_encoding: - # Check if we should trigger batch encoding - self.episodes_since_last_encoding += 1 - if self.episodes_since_last_encoding == self.batch_encoding_size: - start_ep = self.num_episodes - self.batch_encoding_size - end_ep = self.num_episodes - self._batch_save_episode_video(start_ep, end_ep) - self.episodes_since_last_encoding = 0 - - if not episode_data: - # Reset episode buffer and clean up temporary images (if not already deleted during video encoding) - self.clear_episode_buffer(delete_images=len(self.meta.image_keys) > 0) - - def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None: - """ - Batch save videos for multiple episodes. - - Args: - start_episode: Starting episode index (inclusive) - end_episode: Ending episode index (exclusive). If None, encodes all episodes from start_episode to the current episode. - """ - if end_episode is None: - end_episode = self.num_episodes - - logger.info( - f"Batch encoding {self.batch_encoding_size} videos for episodes {start_episode} to {end_episode - 1}" - ) - - chunk_idx = self.meta.episodes[start_episode]["data/chunk_index"] - file_idx = self.meta.episodes[start_episode]["data/file_index"] - episode_df_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) - episode_df = pd.read_parquet(episode_df_path) - - for ep_idx in range(start_episode, end_episode): - logger.info(f"Encoding videos for episode {ep_idx}") - - if ( - self.meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx - or self.meta.episodes[ep_idx]["data/file_index"] != file_idx - ): - # The current episode is in a new chunk or file. - # Save previous episode dataframe and update the Hugging Face dataset by reloading it. - episode_df.to_parquet(episode_df_path) - self.meta.episodes = load_episodes(self.root) - - # Load new episode dataframe - chunk_idx = self.meta.episodes[ep_idx]["data/chunk_index"] - file_idx = self.meta.episodes[ep_idx]["data/file_index"] - episode_df_path = self.root / DEFAULT_EPISODES_PATH.format( - chunk_index=chunk_idx, file_index=file_idx - ) - episode_df = pd.read_parquet(episode_df_path) - - # Save the current episode's video metadata to the dataframe - video_ep_metadata = {} - for video_key in self.meta.video_keys: - video_ep_metadata.update(self._save_episode_video(video_key, ep_idx)) - video_ep_metadata.pop("episode_index") - video_ep_df = pd.DataFrame(video_ep_metadata, index=[ep_idx]).convert_dtypes( - dtype_backend="pyarrow" - ) # allows NaN values along with integers - - episode_df = episode_df.combine_first(video_ep_df) - episode_df.to_parquet(episode_df_path) - self.meta.episodes = load_episodes(self.root) - - def _save_episode_data(self, episode_buffer: dict) -> dict: - """Save episode data to a parquet file and update the Hugging Face dataset of frames data. - - This function processes episodes data from a buffer, converts it into a Hugging Face dataset, - and saves it as a parquet file. It handles both the creation of new parquet files and the - updating of existing ones based on size constraints. After saving the data, it reloads - the Hugging Face dataset to ensure it is up-to-date. - - Notes: We both need to update parquet files and HF dataset: - - `pandas` loads parquet file in RAM - - `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk, - or loads directly from pyarrow cache. - """ - # Convert buffer into HF Dataset - ep_dict = {key: episode_buffer[key] for key in self.hf_features} - ep_dataset = datasets.Dataset.from_dict(ep_dict, features=self.hf_features, split="train") - ep_dataset = embed_images(ep_dataset) - ep_num_frames = len(ep_dataset) - - if self.latest_episode is None: - # Initialize indices and frame count for a new dataset made of the first episode data - chunk_idx, file_idx = 0, 0 - global_frame_index = 0 - self._current_file_start_frame = 0 - # However, if the episodes already exists - # It means we are resuming recording, so we need to load the latest episode - # Update the indices to avoid overwriting the latest episode - if self.meta.episodes is not None and len(self.meta.episodes) > 0: - latest_ep = self.meta.episodes[-1] - global_frame_index = latest_ep["dataset_to_index"] - chunk_idx = latest_ep["data/chunk_index"] - file_idx = latest_ep["data/file_index"] - - # When resuming, move to the next file - chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size) - self._current_file_start_frame = global_frame_index - else: - # Retrieve information from the latest parquet file - latest_ep = self.latest_episode - chunk_idx = latest_ep["data/chunk_index"] - file_idx = latest_ep["data/file_index"] - global_frame_index = latest_ep["index"][-1] + 1 - - latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) - latest_size_in_mb = get_file_size_in_mb(latest_path) - - frames_in_current_file = global_frame_index - self._current_file_start_frame - av_size_per_frame = ( - latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0 - ) - - # Determine if a new parquet file is needed - if ( - latest_size_in_mb + av_size_per_frame * ep_num_frames >= self.meta.data_files_size_in_mb - or self._writer_closed_for_reading - ): - # Size limit is reached or writer was closed for reading, prepare new parquet file - chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size) - self._close_writer() - self._writer_closed_for_reading = False - self._current_file_start_frame = global_frame_index - - ep_dict["data/chunk_index"] = chunk_idx - ep_dict["data/file_index"] = file_idx - - # Write the resulting dataframe from RAM to disk - path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) - path.parent.mkdir(parents=True, exist_ok=True) - - table = ep_dataset.with_format("arrow")[:] - if not self.writer: - self.writer = pq.ParquetWriter( - path, schema=table.schema, compression="snappy", use_dictionary=True - ) - self.writer.write_table(table) - - metadata = { - "data/chunk_index": chunk_idx, - "data/file_index": file_idx, - "dataset_from_index": global_frame_index, - "dataset_to_index": global_frame_index + ep_num_frames, - } - - # Store metadata with episode data for next episode - self.latest_episode = {**ep_dict, **metadata} - - # Mark that the HF dataset needs reloading (lazy loading approach) - # This avoids expensive reloading during sequential recording - self._lazy_loading = True - # Update recorded frames count for efficient length tracking - self._recorded_frames += ep_num_frames - - return metadata - - def _save_episode_video( - self, - video_key: str, - episode_index: int, - temp_path: Path | None = None, - ) -> dict: - # Encode episode frames into a temporary video - if temp_path is None: - ep_path = self._encode_temporary_episode_video(video_key, episode_index) - else: - ep_path = temp_path - - ep_size_in_mb = get_file_size_in_mb(ep_path) - ep_duration_in_s = get_video_duration_in_s(ep_path) - - if ( - episode_index == 0 - or self.meta.latest_episode is None - or f"videos/{video_key}/chunk_index" not in self.meta.latest_episode - ): - # Initialize indices for a new dataset made of the first episode data - chunk_idx, file_idx = 0, 0 - if self.meta.episodes is not None and len(self.meta.episodes) > 0: - # It means we are resuming recording, so we need to load the latest episode - # Update the indices to avoid overwriting the latest episode - old_chunk_idx = self.meta.episodes[-1][f"videos/{video_key}/chunk_index"] - old_file_idx = self.meta.episodes[-1][f"videos/{video_key}/file_index"] - chunk_idx, file_idx = update_chunk_file_indices( - old_chunk_idx, old_file_idx, self.meta.chunks_size - ) - latest_duration_in_s = 0.0 - new_path = self.root / self.meta.video_path.format( - video_key=video_key, chunk_index=chunk_idx, file_index=file_idx - ) - new_path.parent.mkdir(parents=True, exist_ok=True) - shutil.move(str(ep_path), str(new_path)) - else: - # Retrieve information from the latest updated video file using latest_episode - latest_ep = self.meta.latest_episode - chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"][0] - file_idx = latest_ep[f"videos/{video_key}/file_index"][0] - - latest_path = self.root / self.meta.video_path.format( - video_key=video_key, chunk_index=chunk_idx, file_index=file_idx - ) - latest_size_in_mb = get_file_size_in_mb(latest_path) - latest_duration_in_s = latest_ep[f"videos/{video_key}/to_timestamp"][0] - - if latest_size_in_mb + ep_size_in_mb >= self.meta.video_files_size_in_mb: - # Move temporary episode video to a new video file in the dataset - chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size) - new_path = self.root / self.meta.video_path.format( - video_key=video_key, chunk_index=chunk_idx, file_index=file_idx - ) - new_path.parent.mkdir(parents=True, exist_ok=True) - shutil.move(str(ep_path), str(new_path)) - latest_duration_in_s = 0.0 - else: - # Update latest video file - concatenate_video_files( - [latest_path, ep_path], - latest_path, - ) - - # Remove temporary directory - shutil.rmtree(str(ep_path.parent)) - - # Update video info (only needed when first episode is encoded since it reads from episode 0) - if episode_index == 0: - self.meta.update_video_info(video_key) - write_info(self.meta.info, self.meta.root) # ensure video info always written properly - - metadata = { - "episode_index": episode_index, - f"videos/{video_key}/chunk_index": chunk_idx, - f"videos/{video_key}/file_index": file_idx, - f"videos/{video_key}/from_timestamp": latest_duration_in_s, - f"videos/{video_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s, - } - return metadata - - def clear_episode_buffer(self, delete_images: bool = True) -> None: - # Cancel streaming encoder if active - if self._streaming_encoder is not None: - self._streaming_encoder.cancel_episode() - - # Clean up image files for the current episode buffer - if delete_images: - # Wait for the async image writer to finish - if self.image_writer is not None: - self._wait_image_writer() - episode_index = self.episode_buffer["episode_index"] - if isinstance(episode_index, np.ndarray): - episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0] - for cam_key in self.meta.image_keys: - img_dir = self._get_image_file_dir(episode_index, cam_key) - if img_dir.is_dir(): - shutil.rmtree(img_dir) - - # Reset the buffer - self.episode_buffer = self.create_episode_buffer() - - def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None: - if isinstance(self.image_writer, AsyncImageWriter): - logger.warning( - "You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset." - ) - - self.image_writer = AsyncImageWriter( - num_processes=num_processes, - num_threads=num_threads, - ) - - def stop_image_writer(self) -> None: - """ - Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to - remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized. - """ - if self.image_writer is not None: - self.image_writer.stop() - self.image_writer = None - - def _wait_image_writer(self) -> None: - """Wait for asynchronous image writer to finish.""" - if self.image_writer is not None: - self.image_writer.wait_until_done() - - def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path: - """ - Use ffmpeg to convert frames stored as png into mp4 videos. - Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, - since video encoding with ffmpeg is already using multithreading. - """ - return _encode_video_worker( - video_key, episode_index, self.root, self.fps, self.vcodec, self._encoder_threads - ) + # ── Class constructors ──────────────────────────────────────────── @classmethod def create( @@ -1184,7 +587,42 @@ class LeRobotDataset(torch.utils.data.Dataset): encoder_queue_maxsize: int = 30, encoder_threads: int | None = None, ) -> "LeRobotDataset": - """Create a LeRobot Dataset from scratch in order to record data.""" + """Create a new LeRobotDataset from scratch for recording data. + + Returns a write-mode dataset with an active :class:`DatasetWriter`. Use + :meth:`add_frame` / :meth:`save_episode` to populate it, then + :meth:`finalize` when done. + + Args: + repo_id: Repository identifier, typically ``'{hf_user}/{dataset_name}'``. + fps: Frames per second used during data collection. + features: Feature specification dict mapping feature names to their + type/shape metadata. + root: Local directory for dataset storage. Defaults to + ``$HF_LEROBOT_HOME/{repo_id}``. + robot_type: Optional robot type string stored in metadata. + use_videos: If ``True``, visual modalities are stored as MP4 videos. + If ``False``, they are stored as images. + tolerance_s: Timestamp synchronization tolerance in seconds. + image_writer_processes: Number of subprocesses for async image + writing. ``0`` means use threads only. + image_writer_threads: Number of threads for async image writing. + video_backend: Video decoding backend (used when reading back). + batch_encoding_size: Number of episodes to accumulate before + batch-encoding videos. ``1`` means encode immediately. + vcodec: Video codec for encoding. Options include ``'libsvtav1'``, + ``'h264'``, ``'hevc'``, ``'auto'``. + metadata_buffer_size: Number of episode metadata records to buffer + before flushing to parquet. + streaming_encoding: If ``True``, encode video frames in real-time + during capture instead of writing images first. + encoder_queue_maxsize: Max buffered frames per camera when using + streaming encoding. + encoder_threads: Threads per encoder instance. ``None`` for auto. + + Returns: + A new :class:`LeRobotDataset` in write mode. + """ vcodec = resolve_vcodec(vcodec) obj = cls.__new__(cls) obj.meta = LeRobotDatasetMetadata.create( @@ -1200,45 +638,126 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.root = obj.meta.root obj.revision = None obj.tolerance_s = tolerance_s - obj.image_writer = None - obj.batch_encoding_size = batch_encoding_size - obj.episodes_since_last_encoding = 0 - obj.vcodec = vcodec - obj._encoder_threads = encoder_threads - - if image_writer_processes or image_writer_threads: - obj.start_image_writer(image_writer_processes, image_writer_threads) - - obj.episode_buffer = obj.create_episode_buffer() - - obj.episodes = None - obj.hf_dataset = obj.create_hf_dataset() obj.image_transforms = None obj.delta_timestamps = None - obj.delta_indices = None - obj._absolute_to_relative_idx = None - obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec() - obj.writer = None - obj.latest_episode = None - obj._current_file_start_frame = None - # Initialize tracking for incremental recording - obj._lazy_loading = False - obj._recorded_frames = 0 - obj._writer_closed_for_reading = False + obj.episodes = None + obj._video_backend = video_backend if video_backend is not None else get_safe_default_codec() + obj._batch_encoding_size = batch_encoding_size + obj._vcodec = vcodec + obj._encoder_threads = encoder_threads - # Initialize streaming encoder + # Reader is lazily created on first access (write-only mode) + obj.reader = None + + # Create writer + streaming_enc = None if streaming_encoding and len(obj.meta.video_keys) > 0: - obj._streaming_encoder = StreamingVideoEncoder( - fps=fps, - vcodec=vcodec, - pix_fmt="yuv420p", - g=2, - crf=30, - preset=None, - queue_maxsize=encoder_queue_maxsize, - encoder_threads=encoder_threads, - ) - else: - obj._streaming_encoder = None + streaming_enc = cls._build_streaming_encoder(fps, vcodec, encoder_queue_maxsize, encoder_threads) + obj.writer = DatasetWriter( + meta=obj.meta, + root=obj.root, + vcodec=vcodec, + encoder_threads=encoder_threads, + batch_encoding_size=batch_encoding_size, + streaming_encoder=streaming_enc, + ) + + if image_writer_processes or image_writer_threads: + obj.writer.start_image_writer(image_writer_processes, image_writer_threads) + + obj._is_finalized = False + + return obj + + @classmethod + def resume( + cls, + repo_id: str, + root: str | Path | None = None, + tolerance_s: float = 1e-4, + revision: str | None = None, + force_cache_sync: bool = False, + video_backend: str | None = None, + batch_encoding_size: int = 1, + vcodec: str = "libsvtav1", + image_writer_processes: int = 0, + image_writer_threads: int = 0, + streaming_encoding: bool = False, + encoder_queue_maxsize: int = 30, + encoder_threads: int | None = None, + ) -> "LeRobotDataset": + """Resume recording on an existing dataset. + + Loads metadata from an existing dataset (local or Hub) and creates a + :class:`DatasetWriter` for appending new episodes. The underlying HF + dataset is not loaded until :meth:`finalize` is called and data is + subsequently read. + + Args: + repo_id: Repository identifier of the existing dataset. + root: Local directory of the dataset. Defaults to + ``$HF_LEROBOT_HOME/{repo_id}``. + tolerance_s: Timestamp synchronization tolerance in seconds. + revision: Git revision (branch, tag, or commit hash). Defaults to + current codebase version tag. + force_cache_sync: If ``True``, re-download metadata from the Hub even + if a local cache exists. + video_backend: Video decoding backend for reading back data. + batch_encoding_size: Number of episodes to accumulate before + batch-encoding videos. + vcodec: Video codec for encoding. + image_writer_processes: Subprocesses for async image writing. + image_writer_threads: Threads for async image writing. + streaming_encoding: If ``True``, encode video in real-time during + capture. + encoder_queue_maxsize: Max buffered frames per camera for streaming. + encoder_threads: Threads per encoder instance. ``None`` for auto. + + Returns: + A :class:`LeRobotDataset` in write mode, ready to append episodes. + """ + vcodec = resolve_vcodec(vcodec) + obj = cls.__new__(cls) + obj.repo_id = repo_id + obj.root = Path(root) if root else HF_LEROBOT_HOME / repo_id + obj.root.mkdir(exist_ok=True, parents=True) + obj.revision = revision if revision else CODEBASE_VERSION + obj.tolerance_s = tolerance_s + obj.image_transforms = None + obj.delta_timestamps = None + obj.episodes = None + obj._video_backend = video_backend if video_backend else get_safe_default_codec() + obj._batch_encoding_size = batch_encoding_size + obj._vcodec = vcodec + obj._encoder_threads = encoder_threads + + # Load metadata + obj.meta = LeRobotDatasetMetadata( + obj.repo_id, obj.root, obj.revision, force_cache_sync=force_cache_sync + ) + + # Reader is lazily created on first access (write-only mode) + obj.reader = None + + # Create writer for appending + streaming_enc = None + if streaming_encoding and len(obj.meta.video_keys) > 0: + streaming_enc = cls._build_streaming_encoder( + obj.meta.fps, vcodec, encoder_queue_maxsize, encoder_threads + ) + obj.writer = DatasetWriter( + meta=obj.meta, + root=obj.root, + vcodec=vcodec, + encoder_threads=encoder_threads, + batch_encoding_size=batch_encoding_size, + streaming_encoder=streaming_enc, + initial_frames=obj.meta.total_frames, + ) + + if image_writer_processes or image_writer_threads: + obj.writer.start_image_writer(image_writer_processes, image_writer_threads) + + obj._is_finalized = False return obj diff --git a/src/lerobot/datasets/multi_dataset.py b/src/lerobot/datasets/multi_dataset.py index 917d5c5eb..d16c5bb07 100644 --- a/src/lerobot/datasets/multi_dataset.py +++ b/src/lerobot/datasets/multi_dataset.py @@ -22,6 +22,7 @@ import torch import torch.utils from lerobot.datasets.compute_stats import aggregate_stats +from lerobot.datasets.feature_utils import get_hf_features_from_features from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.video_utils import VideoFrame from lerobot.utils.constants import HF_LEROBOT_HOME @@ -125,7 +126,13 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): def features(self) -> datasets.Features: features = {} for dataset in self._datasets: - features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features}) + features.update( + { + k: v + for k, v in get_hf_features_from_features(dataset.features).items() + if k not in self.disabled_features + } + ) return features @property diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index e465b79b4..59c8c7d3e 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -741,6 +741,7 @@ class StreamingVideoEncoder: self._video_paths: dict[str, Path] = {} self._dropped_frames: dict[str, int] = {} self._episode_active = False + self._closed = False def start_episode(self, video_keys: list[str], temp_dir: Path) -> None: """Start encoder threads for a new episode. @@ -895,8 +896,11 @@ class StreamingVideoEncoder: def close(self) -> None: """Close the encoder, canceling any in-progress episode.""" + if self._closed: + return if self._episode_active: self.cancel_episode() + self._closed = True def _cleanup(self) -> None: """Clean up queues and thread tracking dicts.""" @@ -1063,43 +1067,19 @@ class VideoEncodingManager: return self def __exit__(self, exc_type, exc_val, exc_tb): - streaming_encoder = getattr(self.dataset, "_streaming_encoder", None) + writer = self.dataset.writer + if writer is not None: + if exc_type is not None and writer._streaming_encoder is not None: + writer.cancel_pending_videos() - if streaming_encoder is not None: - # Handle streaming encoder cleanup - if exc_type is not None: - streaming_encoder.cancel_episode() - streaming_encoder.close() - elif self.dataset.episodes_since_last_encoding > 0: - # Handle any remaining episodes that haven't been batch encoded - if exc_type is not None: - logger.info("Exception occurred. Encoding remaining episodes before exit...") - else: - logger.info("Recording stopped. Encoding remaining episodes...") + # finalize() handles flush_pending_videos + parquet + metadata + self.dataset.finalize() - start_ep = self.dataset.num_episodes - self.dataset.episodes_since_last_encoding - end_ep = self.dataset.num_episodes - logger.info( - f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, " - f"from episode {start_ep} to {end_ep - 1}" - ) - self.dataset._batch_save_episode_video(start_ep, end_ep) - - # Finalize the dataset to properly close all writers - self.dataset.finalize() - - # Clean up episode images if recording was interrupted (only for non-streaming mode) - if exc_type is not None and streaming_encoder is None: - interrupted_episode_index = self.dataset.num_episodes - for key in self.dataset.meta.video_keys: - img_dir = self.dataset._get_image_file_path( - episode_index=interrupted_episode_index, image_key=key, frame_index=0 - ).parent - if img_dir.exists(): - logger.debug( - f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}" - ) - shutil.rmtree(img_dir) + # Clean up episode images if recording was interrupted (only for non-streaming mode) + if exc_type is not None and writer._streaming_encoder is None: + writer.cleanup_interrupted_episode(self.dataset.num_episodes) + else: + self.dataset.finalize() # Clean up any remaining images directory if it's empty img_dir = self.dataset.root / "images" diff --git a/src/lerobot/rl/buffer.py b/src/lerobot/rl/buffer.py index 81aa29c48..68954162d 100644 --- a/src/lerobot/rl/buffer.py +++ b/src/lerobot/rl/buffer.py @@ -563,7 +563,7 @@ class ReplayBuffer: ) # Start writing images if needed - lerobot_dataset.start_image_writer(num_processes=0, num_threads=3) + lerobot_dataset.writer.start_image_writer(num_processes=0, num_threads=3) # Convert transitions into episodes and frames @@ -603,10 +603,10 @@ class ReplayBuffer: lerobot_dataset.save_episode() # Save any remaining frames in the buffer - if lerobot_dataset.episode_buffer["size"] > 0: + if lerobot_dataset.has_pending_frames(): lerobot_dataset.save_episode() - lerobot_dataset.stop_image_writer() + lerobot_dataset.writer.stop_image_writer() lerobot_dataset.finalize() return lerobot_dataset diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index f5fcb7437..bd64d205f 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ b/src/lerobot/rl/gym_manipulator.py @@ -752,8 +752,7 @@ def replay_trajectory( episodes=[cfg.dataset.replay_episode], download_videos=False, ) - episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.replay_episode) - actions = episode_frames.select_columns(ACTION) + actions = dataset.select_columns(ACTION) _, info = env.reset() diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 819634ba2..ac01c9319 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -468,7 +468,8 @@ def record(cfg: RecordConfig) -> LeRobotDataset: try: if cfg.resume: - dataset = LeRobotDataset( + num_cameras = len(robot.cameras) if hasattr(robot, "cameras") else 0 + dataset = LeRobotDataset.resume( cfg.dataset.repo_id, root=cfg.dataset.root, batch_encoding_size=cfg.dataset.video_encoding_batch_size, @@ -476,13 +477,11 @@ def record(cfg: RecordConfig) -> LeRobotDataset: streaming_encoding=cfg.dataset.streaming_encoding, encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize, encoder_threads=cfg.dataset.encoder_threads, + image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0, + image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras + if num_cameras > 0 + else 0, ) - - if hasattr(robot, "cameras") and len(robot.cameras) > 0: - dataset.start_image_writer( - num_processes=cfg.dataset.num_image_writer_processes, - num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras), - ) sanity_check_dataset_robot_compatibility(dataset, robot, cfg.dataset.fps, dataset_features) else: # Create empty dataset or load existing saved episodes diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index 7c0b5b96b..09e7d4e8b 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ b/src/lerobot/scripts/lerobot_replay.py @@ -104,15 +104,13 @@ def replay(cfg: ReplayConfig): robot = make_robot_from_config(cfg.robot) dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode]) - # Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0 - episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.episode) - actions = episode_frames.select_columns(ACTION) + actions = dataset.select_columns(ACTION) robot.connect() try: log_say("Replaying episode", cfg.play_sounds, blocking=True) - for idx in range(len(episode_frames)): + for idx in range(dataset.num_frames): start_episode_t = time.perf_counter() action_array = actions[idx][ACTION] diff --git a/src/lerobot/scripts/lerobot_train_tokenizer.py b/src/lerobot/scripts/lerobot_train_tokenizer.py index 807d48333..70185fc51 100644 --- a/src/lerobot/scripts/lerobot_train_tokenizer.py +++ b/src/lerobot/scripts/lerobot_train_tokenizer.py @@ -204,15 +204,15 @@ def process_episode(args): for abs_idx in range(from_idx, to_idx): # map absolute index to relative index if needed - if dataset._absolute_to_relative_idx is not None: - if abs_idx not in dataset._absolute_to_relative_idx: + if dataset.reader._absolute_to_relative_idx is not None: + if abs_idx not in dataset.reader._absolute_to_relative_idx: # this episode's frames aren't in the filtered dataset return None - rel_idx = dataset._absolute_to_relative_idx[abs_idx] + rel_idx = dataset.reader._absolute_to_relative_idx[abs_idx] else: rel_idx = abs_idx - frame = dataset.hf_dataset[rel_idx] + frame = dataset.get_raw_item(rel_idx) # get state (could be from observation.state or other state key) if state_key in frame: diff --git a/tests/artifacts/policies/save_policy_to_safetensors.py b/tests/artifacts/policies/save_policy_to_safetensors.py index 64b125cc9..7359f6169 100644 --- a/tests/artifacts/policies/save_policy_to_safetensors.py +++ b/tests/artifacts/policies/save_policy_to_safetensors.py @@ -80,7 +80,7 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): # HACK: We reload a batch with no delta_indices as `select_action` won't expect a timestamps dimension # We simulate having an environment using a dataset by setting delta_indices to None and dropping tensors # indicating padding (those ending with "_is_pad") - dataset.delta_indices = None + dataset.reader.delta_indices = None batch = next(iter(dataloader)) obs = {} for k in batch: diff --git a/tests/datasets/test_dataset_metadata.py b/tests/datasets/test_dataset_metadata.py new file mode 100644 index 000000000..3f3971e15 --- /dev/null +++ b/tests/datasets/test_dataset_metadata.py @@ -0,0 +1,385 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contract tests for LeRobotDatasetMetadata.""" + +import json + +import numpy as np +import pytest + +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.utils import INFO_PATH +from tests.fixtures.constants import DEFAULT_FPS, DUMMY_ROBOT_TYPE + +# ── helpers ────────────────────────────────────────────────────────── + +SIMPLE_FEATURES = { + "state": {"dtype": "float32", "shape": (6,), "names": None}, + "action": {"dtype": "float32", "shape": (6,), "names": None}, +} + +VIDEO_FEATURES = { + **SIMPLE_FEATURES, + "observation.images.laptop": { + "dtype": "video", + "shape": (64, 96, 3), + "names": ["height", "width", "channels"], + "info": None, + }, +} + +IMAGE_FEATURES = { + **SIMPLE_FEATURES, + "observation.images.laptop": { + "dtype": "image", + "shape": (64, 96, 3), + "names": ["height", "width", "channels"], + "info": None, + }, +} + + +def _make_dummy_stats(features: dict) -> dict: + """Create minimal episode stats matching the given features.""" + stats = {} + for key, ft in features.items(): + if ft["dtype"] in ("image", "video"): + stats[key] = { + "max": np.ones((3, 1, 1), dtype=np.float32), + "mean": np.full((3, 1, 1), 0.5, dtype=np.float32), + "min": np.zeros((3, 1, 1), dtype=np.float32), + "std": np.full((3, 1, 1), 0.25, dtype=np.float32), + "count": np.array([5]), + } + elif ft["dtype"] in ("float32", "float64", "int64"): + stats[key] = { + "max": np.ones(ft["shape"], dtype=np.float32), + "mean": np.full(ft["shape"], 0.5, dtype=np.float32), + "min": np.zeros(ft["shape"], dtype=np.float32), + "std": np.full(ft["shape"], 0.25, dtype=np.float32), + "count": np.array([5]), + } + return stats + + +# ── Construction contracts ─────────────────────────────────────────── + + +def test_create_produces_valid_info_on_disk(tmp_path): + """create() writes info.json and the returned object reflects the provided settings.""" + root = tmp_path / "new_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/meta", + fps=DEFAULT_FPS, + features=SIMPLE_FEATURES, + robot_type=DUMMY_ROBOT_TYPE, + root=root, + use_videos=False, + ) + + # info.json was written to disk + assert (root / INFO_PATH).exists() + with open(root / INFO_PATH) as f: + info_on_disk = json.load(f) + + assert meta.fps == DEFAULT_FPS + assert meta.robot_type == DUMMY_ROBOT_TYPE + assert "state" in meta.features + assert "action" in meta.features + assert info_on_disk["fps"] == DEFAULT_FPS + + +def test_create_starts_with_zero_counts(tmp_path): + """A freshly created metadata has zero episode/frame/task counts.""" + root = tmp_path / "empty_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/empty", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + + assert meta.total_episodes == 0 + assert meta.total_frames == 0 + assert meta.total_tasks == 0 + assert meta.tasks is None + assert meta.episodes is None + assert meta.stats is None + + +def test_create_with_videos_sets_video_path(tmp_path): + """When features include video-dtype keys, create() produces a non-None video_path.""" + root = tmp_path / "video_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/video", fps=DEFAULT_FPS, features=VIDEO_FEATURES, root=root, use_videos=True + ) + + assert meta.video_path is not None + assert len(meta.video_keys) == 1 + assert "observation.images.laptop" in meta.video_keys + + +def test_create_without_videos_has_no_video_path(tmp_path): + """When use_videos=False and no video features, video_path is None.""" + root = tmp_path / "no_video" + meta = LeRobotDatasetMetadata.create( + repo_id="test/novid", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + + assert meta.video_path is None + assert meta.video_keys == [] + + +def test_create_raises_on_existing_directory(tmp_path): + """create() raises if root directory already exists.""" + root = tmp_path / "existing" + root.mkdir() + + with pytest.raises(FileExistsError): + LeRobotDatasetMetadata.create( + repo_id="test/exists", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + + +def test_init_loads_existing_metadata(tmp_path, lerobot_dataset_metadata_factory, info_factory): + """When metadata files exist on disk, __init__ loads them correctly.""" + root = tmp_path / "load_test" + info = info_factory(total_episodes=3, total_frames=150, total_tasks=1, use_videos=False) + meta = lerobot_dataset_metadata_factory(root=root, info=info) + + assert meta.total_episodes == 3 + assert meta.total_frames == 150 + assert meta.fps == info["fps"] + + +# ── Property accessors ─────────────────────────────────────────────── + + +def test_property_accessors_reflect_info(tmp_path): + """Properties return values consistent with the info dict.""" + root = tmp_path / "props_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/props", + fps=DEFAULT_FPS, + features=IMAGE_FEATURES, + robot_type=DUMMY_ROBOT_TYPE, + root=root, + use_videos=False, + ) + + assert meta.fps == DEFAULT_FPS + assert meta.robot_type == DUMMY_ROBOT_TYPE + # shapes should be tuples + for _key, shape in meta.shapes.items(): + assert isinstance(shape, tuple) + # image_keys should contain the image feature + assert "observation.images.laptop" in meta.image_keys + # camera_keys is a superset of image_keys and video_keys + assert set(meta.image_keys + meta.video_keys) == set(meta.camera_keys) + + +def test_data_path_is_formattable(tmp_path): + """data_path contains format placeholders that can be .format()-ed.""" + root = tmp_path / "fmt_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/fmt", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + + formatted = meta.data_path.format(chunk_index=0, file_index=0) + assert "chunk" in formatted.lower() or "0" in formatted + + +# ── Task management ────────────────────────────────────────────────── + + +def test_save_episode_tasks_creates_tasks_dataframe(tmp_path): + """On a fresh metadata, save_episode_tasks() creates the tasks DataFrame.""" + root = tmp_path / "task_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/task", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + assert meta.tasks is None + + meta.save_episode_tasks(["Pick up the cube"]) + + assert meta.tasks is not None + assert len(meta.tasks) == 1 + assert "Pick up the cube" in meta.tasks.index + + +def test_save_episode_tasks_is_additive(tmp_path): + """New tasks are added; existing tasks keep their original index.""" + root = tmp_path / "additive_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/add", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + + meta.save_episode_tasks(["Task A"]) + idx_a = meta.get_task_index("Task A") + + meta.save_episode_tasks(["Task A", "Task B"]) + assert meta.get_task_index("Task A") == idx_a # unchanged + assert meta.get_task_index("Task B") is not None + assert len(meta.tasks) == 2 + + +def test_get_task_index_returns_none_for_unknown(tmp_path): + """get_task_index() returns None for an unknown task.""" + root = tmp_path / "unknown_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/unknown", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + meta.save_episode_tasks(["Known task"]) + + assert meta.get_task_index("Known task") == 0 + assert meta.get_task_index("Unknown task") is None + + +def test_save_episode_tasks_rejects_duplicates(tmp_path): + """save_episode_tasks() raises ValueError on duplicate task strings.""" + root = tmp_path / "dup_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/dup", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + + with pytest.raises(ValueError): + meta.save_episode_tasks(["Same task", "Same task"]) + + +# ── Episode saving ─────────────────────────────────────────────────── + + +def test_save_episode_increments_counters(tmp_path): + """After save_episode(), total_episodes and total_frames increase.""" + root = tmp_path / "ep_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/ep", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + meta.save_episode_tasks(["Task 1"]) + stats = _make_dummy_stats(meta.features) + + meta.save_episode( + episode_index=0, + episode_length=10, + episode_tasks=["Task 1"], + episode_stats=stats, + episode_metadata={}, + ) + + assert meta.total_episodes == 1 + assert meta.total_frames == 10 + + +def test_save_episode_updates_stats(tmp_path): + """After save_episode(), .stats is non-None and has feature keys.""" + root = tmp_path / "stats_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/stats", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + meta.save_episode_tasks(["Task 1"]) + stats = _make_dummy_stats(meta.features) + + meta.save_episode( + episode_index=0, + episode_length=5, + episode_tasks=["Task 1"], + episode_stats=stats, + episode_metadata={}, + ) + + assert meta.stats is not None + # Stats should contain at least the user-defined feature keys + for key in SIMPLE_FEATURES: + assert key in meta.stats + + +# ── Chunk settings ─────────────────────────────────────────────────── + + +def test_update_chunk_settings_persists(tmp_path): + """update_chunk_settings() changes values and writes info.json.""" + root = tmp_path / "chunk_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/chunk", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + original = meta.get_chunk_settings() + + meta.update_chunk_settings(chunks_size=500) + assert meta.chunks_size == 500 + assert meta.chunks_size != original["chunks_size"] or original["chunks_size"] == 500 + + # Verify persisted + with open(root / INFO_PATH) as f: + info_on_disk = json.load(f) + assert info_on_disk["chunks_size"] == 500 + + +def test_update_chunk_settings_rejects_non_positive(tmp_path): + """update_chunk_settings() raises ValueError for <= 0 values.""" + root = tmp_path / "bad_chunk" + meta = LeRobotDatasetMetadata.create( + repo_id="test/bad", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + + with pytest.raises(ValueError): + meta.update_chunk_settings(chunks_size=0) + with pytest.raises(ValueError): + meta.update_chunk_settings(data_files_size_in_mb=-1) + + +# ── Finalization ───────────────────────────────────────────────────── + + +def test_finalize_is_idempotent(tmp_path): + """Calling finalize() multiple times does not raise.""" + root = tmp_path / "fin_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/fin", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + + meta.finalize() + meta.finalize() # second call should not raise + + +def test_finalize_flushes_buffered_metadata(tmp_path): + """Episodes saved before finalize() are written to parquet.""" + root = tmp_path / "flush_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/flush", + fps=DEFAULT_FPS, + features=SIMPLE_FEATURES, + root=root, + use_videos=False, + metadata_buffer_size=100, # large buffer so nothing auto-flushes + ) + meta.save_episode_tasks(["Task 1"]) + stats = _make_dummy_stats(meta.features) + + # Save a few episodes (won't auto-flush since buffer_size=100) + for i in range(3): + meta.save_episode( + episode_index=i, + episode_length=5, + episode_tasks=["Task 1"], + episode_stats=stats, + episode_metadata={}, + ) + + # Before finalize, the parquet might not exist yet + meta.finalize() + + # After finalize, episodes parquet should exist + episodes_dir = root / "meta" / "episodes" + assert episodes_dir.exists() + parquet_files = list(episodes_dir.rglob("*.parquet")) + assert len(parquet_files) > 0 diff --git a/tests/datasets/test_dataset_reader.py b/tests/datasets/test_dataset_reader.py new file mode 100644 index 000000000..4c8a8b23f --- /dev/null +++ b/tests/datasets/test_dataset_reader.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contract tests for DatasetReader.""" + +from lerobot.datasets.dataset_reader import DatasetReader +from lerobot.datasets.video_utils import get_safe_default_codec + +# ── Loading ────────────────────────────────────────────────────────── + + +def test_try_load_returns_true_when_data_exists(tmp_path, lerobot_dataset_factory): + """Given a fully written dataset, try_load() returns True.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=2, total_frames=20, use_videos=False + ) + reader = DatasetReader( + meta=dataset.meta, + root=dataset.root, + episodes=None, + tolerance_s=1e-4, + video_backend=get_safe_default_codec(), + delta_timestamps=None, + image_transforms=None, + ) + assert reader.try_load() is True + assert reader.hf_dataset is not None + + +def test_try_load_returns_false_when_no_data(tmp_path): + """When only metadata exists (no data/ parquets), try_load() returns False.""" + from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata + + root = tmp_path / "meta_only" + features = {"state": {"dtype": "float32", "shape": (2,), "names": None}} + meta = LeRobotDatasetMetadata.create( + repo_id="test/meta_only", fps=30, features=features, root=root, use_videos=False + ) + + reader = DatasetReader( + meta=meta, + root=meta.root, + episodes=None, + tolerance_s=1e-4, + video_backend=get_safe_default_codec(), + delta_timestamps=None, + image_transforms=None, + ) + assert reader.try_load() is False + assert reader.hf_dataset is None + + +# ── Counts ─────────────────────────────────────────────────────────── + + +def test_num_frames_without_filter(tmp_path, lerobot_dataset_factory): + """With episodes=None, num_frames equals total_frames.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=3, total_frames=60, use_videos=False + ) + assert dataset.reader.num_frames == dataset.meta.total_frames + + +def test_num_episodes_without_filter(tmp_path, lerobot_dataset_factory): + """With episodes=None, num_episodes equals total_episodes.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=3, total_frames=60, use_videos=False + ) + assert dataset.reader.num_episodes == dataset.meta.total_episodes + + +def test_num_frames_with_episode_filter(tmp_path, lerobot_dataset_factory): + """When filtering to a subset, only those episodes' frames are counted.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=5, total_frames=100, episodes=[0, 2], use_videos=False + ) + # Filtered frames should be less than total + assert dataset.reader.num_frames <= dataset.meta.total_frames + assert dataset.reader.num_episodes == 2 + + +# ── get_item ───────────────────────────────────────────────────────── + + +def test_get_item_returns_expected_keys(tmp_path, lerobot_dataset_factory): + """get_item(0) returns a dict with expected keys.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=1, total_frames=10, use_videos=False + ) + item = dataset.reader.get_item(0) + + # Standard keys that must always be present + for key in ["index", "episode_index", "frame_index", "timestamp", "task_index", "task"]: + assert key in item, f"Missing key: {key}" + + +def test_get_item_values_are_correct(tmp_path, lerobot_dataset_factory): + """get_item() returns correct index and episode_index.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=2, total_frames=20, use_videos=False + ) + item_0 = dataset.reader.get_item(0) + + assert item_0["index"].item() == 0 + assert item_0["episode_index"].item() == 0 + + +# ── Transforms ─────────────────────────────────────────────────────── + + +def test_image_transforms_are_applied(tmp_path, lerobot_dataset_factory): + """When image_transforms is provided, get_item() applies it to camera keys.""" + transform_called = {"count": 0} + + def sentinel_transform(img): + transform_called["count"] += 1 + return img + + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", + total_episodes=1, + total_frames=5, + use_videos=False, + image_transforms=sentinel_transform, + ) + item = dataset[0] # noqa: F841 + + # Should have been called once per camera key per frame + num_cameras = len(dataset.meta.camera_keys) + if num_cameras > 0: + assert transform_called["count"] >= 1 + + +# ── File paths ─────────────────────────────────────────────────────── + + +def test_get_episodes_file_paths_returns_data_paths(tmp_path, lerobot_dataset_factory): + """get_episodes_file_paths() returns paths including data/ paths.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=2, total_frames=20, use_videos=False + ) + paths = dataset.reader.get_episodes_file_paths() + + assert len(paths) > 0 + assert any("data/" in str(p) for p in paths) + + +def test_get_episodes_file_paths_includes_video_paths(tmp_path, lerobot_dataset_factory): + """When dataset has video keys, file paths include video/ paths.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=2, total_frames=20, use_videos=True + ) + + if len(dataset.meta.video_keys) > 0: + paths = dataset.reader.get_episodes_file_paths() + assert any("video" in str(p).lower() for p in paths) diff --git a/tests/datasets/test_dataset_writer.py b/tests/datasets/test_dataset_writer.py new file mode 100644 index 000000000..8c6ee68bd --- /dev/null +++ b/tests/datasets/test_dataset_writer.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contract tests for DatasetWriter.""" + +from pathlib import Path +from unittest.mock import patch + +import numpy as np +import pytest +import torch +from PIL import Image + +from lerobot.datasets.dataset_writer import _encode_video_worker +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.utils import DEFAULT_IMAGE_PATH +from tests.fixtures.constants import DEFAULT_FPS, DUMMY_REPO_ID + +SIMPLE_FEATURES = { + "state": {"dtype": "float32", "shape": (6,), "names": None}, + "action": {"dtype": "float32", "shape": (6,), "names": None}, +} + + +def _make_frame(features: dict, task: str = "Dummy task") -> dict: + """Build a valid frame dict for the given features.""" + frame = {"task": task} + for key, ft in features.items(): + if ft["dtype"] in ("image", "video"): + frame[key] = np.random.randint(0, 256, size=ft["shape"], dtype=np.uint8) + elif ft["dtype"] in ("float32", "float64"): + frame[key] = torch.randn(ft["shape"]) + elif ft["dtype"] == "int64": + frame[key] = torch.zeros(ft["shape"], dtype=torch.int64) + return frame + + +# ── Existing encode_video_worker tests ─────────────────────────────── + + +def test_encode_video_worker_forwards_vcodec(tmp_path): + """_encode_video_worker correctly forwards the vcodec parameter.""" + video_key = "observation.images.laptop" + fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0) + img_dir = tmp_path / Path(fpath).parent + img_dir.mkdir(parents=True, exist_ok=True) + Image.new("RGB", (64, 64), color="red").save(img_dir / "frame-000000.png") + + captured_kwargs = {} + + def mock_encode(imgs_dir, video_path, fps, **kwargs): + captured_kwargs.update(kwargs) + Path(video_path).parent.mkdir(parents=True, exist_ok=True) + Path(video_path).touch() + + with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode): + _encode_video_worker(video_key, 0, tmp_path, fps=30, vcodec="h264") + + assert captured_kwargs["vcodec"] == "h264" + + +def test_encode_video_worker_default_vcodec(tmp_path): + """_encode_video_worker uses libsvtav1 as the default codec.""" + video_key = "observation.images.laptop" + fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0) + img_dir = tmp_path / Path(fpath).parent + img_dir.mkdir(parents=True, exist_ok=True) + Image.new("RGB", (64, 64), color="red").save(img_dir / "frame-000000.png") + + captured_kwargs = {} + + def mock_encode(imgs_dir, video_path, fps, **kwargs): + captured_kwargs.update(kwargs) + Path(video_path).parent.mkdir(parents=True, exist_ok=True) + Path(video_path).touch() + + with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode): + _encode_video_worker(video_key, 0, tmp_path, fps=30) + + assert captured_kwargs["vcodec"] == "libsvtav1" + + +# ── add_frame contracts ────────────────────────────────────────────── + + +def test_add_frame_increments_buffer_size(tmp_path): + """Each add_frame() call increases episode_buffer['size'] by 1.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + assert dataset.writer.episode_buffer["size"] == 0 + + dataset.add_frame(_make_frame(SIMPLE_FEATURES)) + assert dataset.writer.episode_buffer["size"] == 1 + + dataset.add_frame(_make_frame(SIMPLE_FEATURES)) + assert dataset.writer.episode_buffer["size"] == 2 + + +def test_add_frame_rejects_missing_feature(tmp_path): + """add_frame() raises ValueError when a required feature is missing.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + with pytest.raises(ValueError, match="Missing features"): + dataset.add_frame({"task": "Dummy task", "state": torch.randn(6)}) + # missing 'action' + + +# ── save_episode contracts ─────────────────────────────────────────── + + +def test_save_episode_writes_parquet(tmp_path): + """After save_episode(), at least one .parquet file exists under data/.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + for _ in range(3): + dataset.add_frame(_make_frame(SIMPLE_FEATURES)) + dataset.save_episode() + + parquet_files = list((tmp_path / "ds" / "data").rglob("*.parquet")) + assert len(parquet_files) > 0 + + +def test_save_episode_updates_counters(tmp_path): + """After save_episode(), metadata counters are updated.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + for _ in range(5): + dataset.add_frame(_make_frame(SIMPLE_FEATURES)) + dataset.save_episode() + + assert dataset.meta.total_episodes == 1 + assert dataset.meta.total_frames == 5 + + +def test_save_episode_resets_buffer(tmp_path): + """After save_episode(), the episode buffer is reset.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + for _ in range(3): + dataset.add_frame(_make_frame(SIMPLE_FEATURES)) + dataset.save_episode() + + assert dataset.writer.episode_buffer["size"] == 0 + + +def test_save_multiple_episodes(tmp_path): + """Recording 3 episodes results in correct total counts.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + total_frames = 0 + for ep in range(3): + n_frames = ep + 2 # 2, 3, 4 + for _ in range(n_frames): + dataset.add_frame(_make_frame(SIMPLE_FEATURES)) + dataset.save_episode() + total_frames += n_frames + + assert dataset.meta.total_episodes == 3 + assert dataset.meta.total_frames == total_frames + + +# ── clear / lifecycle ──────────────────────────────────────────────── + + +def test_clear_resets_buffer(tmp_path): + """clear_episode_buffer() resets the buffer size to 0.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + dataset.add_frame(_make_frame(SIMPLE_FEATURES)) + assert dataset.writer.episode_buffer["size"] == 1 + + dataset.clear_episode_buffer() + assert dataset.writer.episode_buffer["size"] == 0 + + +def test_finalize_is_idempotent(tmp_path): + """Calling finalize() twice does not raise.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + for _ in range(3): + dataset.add_frame(_make_frame(SIMPLE_FEATURES)) + dataset.save_episode() + + dataset.finalize() + dataset.finalize() # second call should not raise + + +def test_finalize_then_read_roundtrip(tmp_path): + """Write data, finalize, re-open, and verify data matches.""" + root = tmp_path / "roundtrip" + features = {"state": {"dtype": "float32", "shape": (2,), "names": None}} + dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=features, root=root) + + # Record known values + known_states = [] + for i in range(5): + state = torch.tensor([float(i), float(i * 10)]) + known_states.append(state) + dataset.add_frame({"task": "Test task", "state": state}) + dataset.save_episode() + dataset.finalize() + + # Read back + for i in range(5): + item = dataset[i] + assert torch.allclose(item["state"], known_states[i], atol=1e-5) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 67878d8f6..b2518149f 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -32,10 +32,7 @@ from lerobot.datasets.factory import make_dataset from lerobot.datasets.feature_utils import get_hf_features_from_features, hw_to_dataset_features from lerobot.datasets.image_writer import image_array_to_pil_image from lerobot.datasets.io_utils import hf_transform_to_torch -from lerobot.datasets.lerobot_dataset import ( - LeRobotDataset, - _encode_video_worker, -) +from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.multi_dataset import MultiLeRobotDataset from lerobot.datasets.utils import ( DEFAULT_CHUNK_SIZE, @@ -72,7 +69,7 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory): def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): """ Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated - objects have the same sets of attributes defined. + objects have the same sets of facade-level attributes defined. """ # Instantiate both ways robot = make_robot_from_config(MockRobotConfig()) @@ -87,6 +84,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): root_init = tmp_path / "init" dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1) + # Facade-level attributes should match between __init__ and create() init_attr = set(vars(dataset_init).keys()) create_attr = set(vars(dataset_create).keys()) @@ -214,6 +212,7 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert len(dataset) == 1 assert dataset[0]["task"] == "Dummy task" @@ -226,6 +225,7 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["state"].shape == torch.Size([2]) @@ -235,6 +235,7 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["state"].shape == torch.Size([2, 4]) @@ -244,6 +245,7 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["state"].shape == torch.Size([2, 4, 3]) @@ -253,6 +255,7 @@ def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5]) @@ -262,6 +265,7 @@ def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1]) @@ -271,6 +275,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["state"].ndim == 0 @@ -280,6 +285,7 @@ def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["caption"] == "Dummy caption" @@ -315,6 +321,7 @@ def test_add_frame_image(image_dataset): dataset = image_dataset dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -323,6 +330,7 @@ def test_add_frame_image_h_w_c(image_dataset): dataset = image_dataset dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -332,6 +340,7 @@ def test_add_frame_image_uint8(image_dataset): image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8) dataset.add_frame({"image": image, "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -341,6 +350,7 @@ def test_add_frame_image_pil(image_dataset): image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8) dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -361,7 +371,7 @@ def test_tmp_image_deletion(tmp_path, empty_lerobot_dataset_factory): ds_img = empty_lerobot_dataset_factory(root=tmp_path / "img", features=features_image) ds_img.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"}) ds_img.save_episode() - img_dir = ds_img._get_image_file_dir(0, image_key) + img_dir = ds_img.writer._get_image_file_dir(0, image_key) assert not img_dir.exists(), "Temporary image directory should be removed for image features" @@ -374,10 +384,10 @@ def test_tmp_video_deletion(tmp_path, empty_lerobot_dataset_factory): } ds_vid = empty_lerobot_dataset_factory(root=tmp_path / "vid", features=features_video) - ds_vid.batch_encoding_size = 1 + ds_vid.writer._batch_encoding_size = 1 ds_vid.add_frame({vid_key: np.random.rand(*DUMMY_CHW), "task": "Dummy task"}) ds_vid.save_episode() - vid_img_dir = ds_vid._get_image_file_dir(0, vid_key) + vid_img_dir = ds_vid.writer._get_image_file_dir(0, vid_key) assert not vid_img_dir.exists(), ( "Temporary image directory should be removed when batch_encoding_size == 1" ) @@ -402,8 +412,8 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory): } ) ds_mixed.save_episode() - img_dir = ds_mixed._get_image_file_dir(0, image_key) - vid_img_dir = ds_mixed._get_image_file_dir(0, vid_key) + img_dir = ds_mixed.writer._get_image_file_dir(0, image_key) + vid_img_dir = ds_mixed.writer._get_image_file_dir(0, vid_key) assert not img_dir.exists(), "Temporary image directory should be removed for image features" assert vid_img_dir.exists(), ( "Temporary image directory should not be removed for video features when batch_encoding_size == 2" @@ -631,29 +641,29 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory): ) # Test hf_dataset is None - dataset.hf_dataset = None - assert dataset._check_cached_episodes_sufficient() is False + dataset.reader.hf_dataset = None + assert dataset.reader._check_cached_episodes_sufficient() is False # Test hf_dataset is empty import datasets empty_features = get_hf_features_from_features(dataset.features) - dataset.hf_dataset = datasets.Dataset.from_dict( + dataset.reader.hf_dataset = datasets.Dataset.from_dict( {key: [] for key in empty_features}, features=empty_features ) - dataset.hf_dataset.set_transform(hf_transform_to_torch) - assert dataset._check_cached_episodes_sufficient() is False + dataset.reader.hf_dataset.set_transform(hf_transform_to_torch) + assert dataset.reader._check_cached_episodes_sufficient() is False # Restore the original dataset for remaining tests - dataset.hf_dataset = dataset.load_hf_dataset() + dataset.reader.hf_dataset = dataset.reader._load_hf_dataset() # Test all episodes requested (self.episodes = None) and all are available - dataset.episodes = None - assert dataset._check_cached_episodes_sufficient() is True + dataset.reader.episodes = None + assert dataset.reader._check_cached_episodes_sufficient() is True # Test specific episodes requested that are all available - dataset.episodes = [0, 2, 4] - assert dataset._check_cached_episodes_sufficient() is True + dataset.reader.episodes = [0, 2, 4] + assert dataset.reader._check_cached_episodes_sufficient() is True # Test request episodes that don't exist in the cached dataset # Create a dataset with only episodes 0, 1, 2 @@ -665,8 +675,8 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory): ) # Request episodes that include non-existent ones - limited_dataset.episodes = [0, 1, 2, 3, 4] - assert limited_dataset._check_cached_episodes_sufficient() is False + limited_dataset.reader.episodes = [0, 1, 2, 3, 4] + assert limited_dataset.reader._check_cached_episodes_sufficient() is False # Test create a dataset with sparse episodes (e.g., only episodes 0, 2, 4) # First create the full dataset structure @@ -702,22 +712,22 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory): filtered_data[key] = filtered_values - sparse_dataset.hf_dataset = datasets.Dataset.from_dict( + sparse_dataset.reader.hf_dataset = datasets.Dataset.from_dict( filtered_data, features=get_hf_features_from_features(sparse_dataset.features) ) - sparse_dataset.hf_dataset.set_transform(hf_transform_to_torch) + sparse_dataset.reader.hf_dataset.set_transform(hf_transform_to_torch) # Test requesting all episodes when only some are cached - sparse_dataset.episodes = None - assert sparse_dataset._check_cached_episodes_sufficient() is False + sparse_dataset.reader.episodes = None + assert sparse_dataset.reader._check_cached_episodes_sufficient() is False # Test requesting only the available episodes - sparse_dataset.episodes = [0, 2, 4] - assert sparse_dataset._check_cached_episodes_sufficient() is True + sparse_dataset.reader.episodes = [0, 2, 4] + assert sparse_dataset.reader._check_cached_episodes_sufficient() is True # Test requesting a mix of available and unavailable episodes - sparse_dataset.episodes = [0, 1, 2] - assert sparse_dataset._check_cached_episodes_sufficient() is False + sparse_dataset.reader.episodes = [0, 1, 2] + assert sparse_dataset.reader._check_cached_episodes_sufficient() is False def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory): @@ -1189,13 +1199,13 @@ def test_dataset_resume_recording(tmp_path, empty_lerobot_dataset_factory): del dataset_verify # Phase 3: Resume recording - add more episodes - dataset_resumed = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0") + dataset_resumed = LeRobotDataset.resume(initial_repo_id, root=initial_root, revision="v3.0") assert dataset_resumed.meta.total_episodes == initial_episodes assert dataset_resumed.meta.total_frames == initial_episodes * frames_per_episode - assert dataset_resumed.latest_episode is None # Not recording yet - assert dataset_resumed.writer is None - assert dataset_resumed.meta.writer is None + assert dataset_resumed.writer._latest_episode is None # Not recording yet + assert dataset_resumed.writer._pq_writer is None + assert dataset_resumed.meta._pq_writer is None additional_episodes = 2 for ep_idx in range(initial_episodes, initial_episodes + additional_episodes): @@ -1271,7 +1281,7 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False) dataset.meta.update_chunk_settings(data_files_size_in_mb=100) - assert dataset._current_file_start_frame is None + assert dataset.writer._current_file_start_frame is None frames_per_episode = 10 for _ in range(frames_per_episode): @@ -1284,7 +1294,7 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact ) dataset.save_episode() - assert dataset._current_file_start_frame == 0 + assert dataset.writer._current_file_start_frame == 0 assert dataset.meta.total_episodes == 1 assert dataset.meta.total_frames == frames_per_episode @@ -1298,12 +1308,12 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact ) dataset.save_episode() - assert dataset._current_file_start_frame == 0 + assert dataset.writer._current_file_start_frame == 0 assert dataset.meta.total_episodes == 2 assert dataset.meta.total_frames == 2 * frames_per_episode - ep1_chunk = dataset.latest_episode["data/chunk_index"] - ep1_file = dataset.latest_episode["data/file_index"] + ep1_chunk = dataset.writer._latest_episode["data/chunk_index"] + ep1_file = dataset.writer._latest_episode["data/file_index"] assert ep1_chunk == 0 assert ep1_file == 0 @@ -1317,12 +1327,12 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact ) dataset.save_episode() - assert dataset._current_file_start_frame == 0 + assert dataset.writer._current_file_start_frame == 0 assert dataset.meta.total_episodes == 3 assert dataset.meta.total_frames == 3 * frames_per_episode - ep2_chunk = dataset.latest_episode["data/chunk_index"] - ep2_file = dataset.latest_episode["data/file_index"] + ep2_chunk = dataset.writer._latest_episode["data/chunk_index"] + ep2_file = dataset.writer._latest_episode["data/file_index"] assert ep2_chunk == 0 assert ep2_file == 0 @@ -1354,82 +1364,6 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact assert frame["episode_index"].item() == expected_ep -def test_encode_video_worker_forwards_vcodec(tmp_path): - """Test that _encode_video_worker correctly forwards the vcodec parameter to encode_video_frames.""" - from unittest.mock import patch - - from lerobot.datasets.utils import DEFAULT_IMAGE_PATH - - # Create the expected directory structure - video_key = "observation.images.laptop" - episode_index = 0 - frame_index = 0 - - fpath = DEFAULT_IMAGE_PATH.format( - image_key=video_key, episode_index=episode_index, frame_index=frame_index - ) - img_dir = tmp_path / Path(fpath).parent - img_dir.mkdir(parents=True, exist_ok=True) - - # Create a dummy image file - dummy_img = Image.new("RGB", (64, 64), color="red") - dummy_img.save(img_dir / "frame-000000.png") - - # Track what vcodec was passed to encode_video_frames - captured_kwargs = {} - - def mock_encode_video_frames(imgs_dir, video_path, fps, **kwargs): - captured_kwargs.update(kwargs) - # Create a dummy output file so the worker doesn't fail - Path(video_path).parent.mkdir(parents=True, exist_ok=True) - Path(video_path).touch() - - with patch("lerobot.datasets.lerobot_dataset.encode_video_frames", side_effect=mock_encode_video_frames): - # Test with h264 codec - _encode_video_worker(video_key, episode_index, tmp_path, fps=30, vcodec="h264") - - assert "vcodec" in captured_kwargs - assert captured_kwargs["vcodec"] == "h264" - - -def test_encode_video_worker_default_vcodec(tmp_path): - """Test that _encode_video_worker uses libsvtav1 as the default codec.""" - from unittest.mock import patch - - from lerobot.datasets.utils import DEFAULT_IMAGE_PATH - - # Create the expected directory structure - video_key = "observation.images.laptop" - episode_index = 0 - frame_index = 0 - - fpath = DEFAULT_IMAGE_PATH.format( - image_key=video_key, episode_index=episode_index, frame_index=frame_index - ) - img_dir = tmp_path / Path(fpath).parent - img_dir.mkdir(parents=True, exist_ok=True) - - # Create a dummy image file - dummy_img = Image.new("RGB", (64, 64), color="red") - dummy_img.save(img_dir / "frame-000000.png") - - # Track what vcodec was passed to encode_video_frames - captured_kwargs = {} - - def mock_encode_video_frames(imgs_dir, video_path, fps, **kwargs): - captured_kwargs.update(kwargs) - # Create a dummy output file so the worker doesn't fail - Path(video_path).parent.mkdir(parents=True, exist_ok=True) - Path(video_path).touch() - - with patch("lerobot.datasets.lerobot_dataset.encode_video_frames", side_effect=mock_encode_video_frames): - # Test with default codec (no vcodec specified) - _encode_video_worker(video_key, episode_index, tmp_path, fps=30) - - assert "vcodec" in captured_kwargs - assert captured_kwargs["vcodec"] == "libsvtav1" - - def test_lerobot_dataset_vcodec_validation(): """Test that LeRobotDataset validates the vcodec parameter.""" # Test that invalid vcodec raises ValueError diff --git a/tests/datasets/test_image_writer.py b/tests/datasets/test_image_writer.py index e02755171..55419473f 100644 --- a/tests/datasets/test_image_writer.py +++ b/tests/datasets/test_image_writer.py @@ -352,10 +352,14 @@ def test_with_different_image_formats(tmp_path, img_array_factory): def test_safe_stop_image_writer_decorator(): - class MockDataset: + class MockWriter: def __init__(self): self.image_writer = MagicMock(spec=AsyncImageWriter) + class MockDataset: + def __init__(self): + self.writer = MockWriter() + @safe_stop_image_writer def function_that_raises_exception(dataset=None): raise Exception("Test exception") @@ -366,7 +370,7 @@ def test_safe_stop_image_writer_decorator(): function_that_raises_exception(dataset=dataset) assert str(exc_info.value) == "Test exception" - dataset.image_writer.stop.assert_called_once() + dataset.writer.image_writer.stop.assert_called_once() def test_main_process_time(tmp_path, img_tensor_factory): diff --git a/tests/datasets/test_lerobot_dataset.py b/tests/datasets/test_lerobot_dataset.py new file mode 100644 index 000000000..d7ce54a15 --- /dev/null +++ b/tests/datasets/test_lerobot_dataset.py @@ -0,0 +1,314 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contract tests for the LeRobotDataset facade. + +Tests focus on mode contracts (read-only, write-only, resume), guards, +property delegation, and the full create-record-finalize-read lifecycle. +""" + +import pytest +import torch + +from lerobot.datasets.dataset_reader import DatasetReader +from lerobot.datasets.dataset_writer import DatasetWriter +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from tests.fixtures.constants import DEFAULT_FPS, DUMMY_REPO_ID + +SIMPLE_FEATURES = { + "state": {"dtype": "float32", "shape": (2,), "names": None}, +} + + +def _make_frame(task: str = "Dummy task") -> dict: + return {"task": task, "state": torch.randn(2)} + + +# ── Read-only mode (via __init__) ──────────────────────────────────── + + +def test_init_creates_reader_no_writer(tmp_path, lerobot_dataset_factory): + """__init__() sets reader to a DatasetReader and writer to None.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=1, total_frames=10, use_videos=False + ) + assert isinstance(dataset.reader, DatasetReader) + assert dataset.writer is None + + +def test_init_loads_data(tmp_path, lerobot_dataset_factory): + """After __init__(), the dataset has data and len > 0.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=1, total_frames=10, use_videos=False + ) + assert len(dataset) > 0 + + +def test_getitem_works_in_read_mode(tmp_path, lerobot_dataset_factory): + """dataset[0] returns a dict with expected keys in read-only mode.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=1, total_frames=10, use_videos=False + ) + item = dataset[0] + assert isinstance(item, dict) + assert "index" in item + assert "task" in item + + +def test_len_matches_num_frames(tmp_path, lerobot_dataset_factory): + """len(dataset) equals dataset.num_frames.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=2, total_frames=30, use_videos=False + ) + assert len(dataset) == dataset.num_frames + + +# ── Write-only mode (via create()) ────────────────────────────────── + + +def test_create_sets_writer_no_reader(tmp_path): + """create() sets writer to a DatasetWriter and reader to None.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + assert isinstance(dataset.writer, DatasetWriter) + assert dataset.reader is None + + +def test_create_initial_counts_zero(tmp_path): + """After create(), num_episodes == 0 and num_frames == 0.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + assert dataset.num_episodes == 0 + assert dataset.num_frames == 0 + + +def test_add_frame_works_in_write_mode(tmp_path): + """add_frame() succeeds on a dataset created via create().""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + dataset.add_frame(_make_frame()) # should not raise + + +# ── Resume mode ────────────────────────────────────────────────────── + + +def test_resume_creates_writer(tmp_path): + """After resume(), writer is a DatasetWriter.""" + root = tmp_path / "resume_ds" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root + ) + for _ in range(3): + dataset.add_frame(_make_frame()) + dataset.save_episode() + dataset.finalize() + + resumed = LeRobotDataset.resume(repo_id=DUMMY_REPO_ID, root=root) + assert isinstance(resumed.writer, DatasetWriter) + + +def test_resume_preserves_episode_count(tmp_path): + """After resume(), existing episodes are counted.""" + root = tmp_path / "resume_ds" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root + ) + for _ in range(3): + dataset.add_frame(_make_frame()) + dataset.save_episode() + dataset.finalize() + + resumed = LeRobotDataset.resume(repo_id=DUMMY_REPO_ID, root=root) + assert resumed.meta.total_episodes == 1 + + +def test_resume_can_add_more_episodes(tmp_path): + """After resume(), new episodes can be added.""" + root = tmp_path / "resume_ds" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root + ) + for _ in range(3): + dataset.add_frame(_make_frame()) + dataset.save_episode() + dataset.finalize() + + resumed = LeRobotDataset.resume(repo_id=DUMMY_REPO_ID, root=root) + for _ in range(2): + resumed.add_frame(_make_frame()) + resumed.save_episode() + + assert resumed.meta.total_episodes == 2 + + +# ── Writer guard ───────────────────────────────────────────────────── + + +def test_add_frame_raises_without_writer(tmp_path, lerobot_dataset_factory): + """add_frame() raises RuntimeError on a read-only dataset.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=1, total_frames=5, use_videos=False + ) + with pytest.raises(RuntimeError, match="read-only"): + dataset.add_frame(_make_frame()) + + +def test_save_episode_raises_without_writer(tmp_path, lerobot_dataset_factory): + """save_episode() raises RuntimeError on a read-only dataset.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=1, total_frames=5, use_videos=False + ) + with pytest.raises(RuntimeError, match="read-only"): + dataset.save_episode() + + +def test_clear_episode_buffer_raises_without_writer(tmp_path, lerobot_dataset_factory): + """clear_episode_buffer() raises RuntimeError on a read-only dataset.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=1, total_frames=5, use_videos=False + ) + with pytest.raises(RuntimeError, match="read-only"): + dataset.clear_episode_buffer() + + +# ── Reader guard ───────────────────────────────────────────────────── + + +def test_getitem_raises_before_finalize(tmp_path): + """dataset[0] raises RuntimeError while recording (before finalize).""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + for _ in range(3): + dataset.add_frame(_make_frame()) + dataset.save_episode() + + with pytest.raises(RuntimeError, match="finalize"): + dataset[0] + + +def test_getitem_works_after_finalize(tmp_path): + """After finalize(), dataset[0] returns data.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + for _ in range(3): + dataset.add_frame(_make_frame()) + dataset.save_episode() + dataset.finalize() + + item = dataset[0] + assert "state" in item + assert "task" in item + + +# ── Property delegation ────────────────────────────────────────────── + + +def test_fps_delegates_to_meta(tmp_path, lerobot_dataset_factory): + """dataset.fps == dataset.meta.fps.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=1, total_frames=5, use_videos=False + ) + assert dataset.fps == dataset.meta.fps + + +def test_features_delegates_to_meta(tmp_path, lerobot_dataset_factory): + """dataset.features is dataset.meta.features.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=1, total_frames=5, use_videos=False + ) + assert dataset.features is dataset.meta.features + + +def test_num_frames_uses_meta_in_write_mode(tmp_path): + """In write-only mode (reader=None), num_frames comes from metadata.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + assert dataset.reader is None + assert dataset.num_frames == dataset.meta.total_frames + + +# ── Lifecycle ──────────────────────────────────────────────────────── + + +def test_finalize_is_idempotent(tmp_path): + """Calling finalize() twice does not raise.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + dataset.finalize() + dataset.finalize() + + +def test_has_pending_frames_lifecycle(tmp_path): + """has_pending_frames: False -> True (add_frame) -> False (save_episode).""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + assert dataset.has_pending_frames() is False + + dataset.add_frame(_make_frame()) + assert dataset.has_pending_frames() is True + + dataset.save_episode() + assert dataset.has_pending_frames() is False + + +def test_create_record_finalize_read_roundtrip(tmp_path): + """End-to-end: create, record 2 episodes, finalize, re-open, verify data.""" + root = tmp_path / "roundtrip" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root + ) + + # Episode 0: 3 frames with known values + ep0_states = [] + for i in range(3): + state = torch.tensor([float(i), float(i * 2)]) + ep0_states.append(state) + dataset.add_frame({"task": "Task A", "state": state}) + dataset.save_episode() + + # Episode 1: 2 frames + ep1_states = [] + for i in range(2): + state = torch.tensor([float(i + 100), float(i + 200)]) + ep1_states.append(state) + dataset.add_frame({"task": "Task B", "state": state}) + dataset.save_episode() + + dataset.finalize() + + # Re-open as read-only + reopened = LeRobotDataset(repo_id=DUMMY_REPO_ID, root=root) + assert len(reopened) == 5 + assert reopened.num_episodes == 2 + + # Verify episode 0 + for i in range(3): + item = reopened[i] + assert torch.allclose(item["state"], ep0_states[i], atol=1e-5) + assert item["episode_index"].item() == 0 + + # Verify episode 1 + for i in range(2): + item = reopened[3 + i] + assert torch.allclose(item["state"], ep1_states[i], atol=1e-5) + assert item["episode_index"].item() == 1 diff --git a/tests/datasets/test_streaming_video_encoder.py b/tests/datasets/test_streaming_video_encoder.py index a85db6a8d..f7e63b06f 100644 --- a/tests/datasets/test_streaming_video_encoder.py +++ b/tests/datasets/test_streaming_video_encoder.py @@ -534,7 +534,7 @@ class TestStreamingEncoderIntegration: streaming_encoding=True, ) - assert dataset._streaming_encoder is not None + assert dataset.writer._streaming_encoder is not None num_frames = 20 for _ in range(num_frames): @@ -580,7 +580,7 @@ class TestStreamingEncoderIntegration: streaming_encoding=False, ) - assert dataset._streaming_encoder is None + assert dataset.writer._streaming_encoder is None num_frames = 5 for _ in range(num_frames): From aa9cc9bd43e9eba92a32aaadde2a09c90b5836cf Mon Sep 17 00:00:00 2001 From: Reece O'Mahoney <66252930+reeceomahoney@users.noreply.github.com> Date: Thu, 26 Mar 2026 20:05:15 +0000 Subject: [PATCH 122/131] fix(logging): suppress noisy httpx INFO logs (#3173) Set httpx logger level to WARNING in init_logging to prevent HTTP request traces from flooding the terminal during train and eval scripts. Co-authored-by: Steven Palma --- src/lerobot/utils/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index b9f8441d6..f6aa93bea 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -95,6 +95,8 @@ def init_logging( file_handler.setLevel(file_level.upper()) logger.addHandler(file_handler) + logging.getLogger("httpx").setLevel(logging.WARNING) + def format_big_number(num, precision=0): suffixes = ["", "K", "M", "B", "T", "Q"] From 07502868e58095b437e5dd5a480fecc65a6f29dc Mon Sep 17 00:00:00 2001 From: Maxime Ellerbach Date: Fri, 27 Mar 2026 21:25:12 +0100 Subject: [PATCH 123/131] fix(deps): breaking change from transformers 5.4.0 (#3231) * fix(deps): breaking change from transformers 5.4.0 * Update src/lerobot/policies/xvla/modeling_florence2.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Maxime Ellerbach * Update src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Maxime Ellerbach * removing dataclass * bumping transformers 5.4.0 --------- Signed-off-by: Maxime Ellerbach Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- pyproject.toml | 2 +- .../policies/groot/action_head/flow_matching_action_head.py | 3 +-- src/lerobot/policies/groot/groot_n1.py | 3 +-- src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py | 4 ++-- src/lerobot/policies/xvla/modeling_florence2.py | 4 ++-- 5 files changed, 7 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5f45626c0..7e4f24eb6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,7 +99,7 @@ dependencies = [ # Common pygame-dep = ["pygame>=2.5.1,<2.7.0"] placo-dep = ["placo>=0.9.6,<0.9.17"] -transformers-dep = ["transformers>=5.3.0,<6.0.0"] +transformers-dep = ["transformers>=5.4.0,<6.0.0"] grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"] can-dep = ["python-can>=4.2.0,<5.0.0"] peft-dep = ["peft>=0.18.0,<1.0.0"] diff --git a/src/lerobot/policies/groot/action_head/flow_matching_action_head.py b/src/lerobot/policies/groot/action_head/flow_matching_action_head.py index bfc456ba0..74d922988 100644 --- a/src/lerobot/policies/groot/action_head/flow_matching_action_head.py +++ b/src/lerobot/policies/groot/action_head/flow_matching_action_head.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass, field +from dataclasses import field from typing import TYPE_CHECKING import torch @@ -110,7 +110,6 @@ class MultiEmbodimentActionEncoder(nn.Module): return x -@dataclass class FlowmatchingActionHeadConfig(PretrainedConfig): """NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head""" diff --git a/src/lerobot/policies/groot/groot_n1.py b/src/lerobot/policies/groot/groot_n1.py index 06ff5a04d..38512b8a8 100644 --- a/src/lerobot/policies/groot/groot_n1.py +++ b/src/lerobot/policies/groot/groot_n1.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass, field +from dataclasses import field from pathlib import Path from typing import TYPE_CHECKING @@ -173,7 +173,6 @@ N_COLOR_CHANNELS = 3 # config -@dataclass class GR00TN15Config(PretrainedConfig): model_type = "gr00t_n1_5" backbone_cfg: dict = field(init=False, metadata={"help": "Backbone configuration."}) diff --git a/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py b/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py index ecf3eb371..a80096514 100644 --- a/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py +++ b/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py @@ -22,7 +22,7 @@ from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, + is_flash_attn_greater_or_equal, is_torchdynamo_compiling, logging, replace_return_docstrings, @@ -890,7 +890,7 @@ class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal("2.1.0") def forward( self, diff --git a/src/lerobot/policies/xvla/modeling_florence2.py b/src/lerobot/policies/xvla/modeling_florence2.py index e33efe5c3..81f9c8234 100644 --- a/src/lerobot/policies/xvla/modeling_florence2.py +++ b/src/lerobot/policies/xvla/modeling_florence2.py @@ -45,7 +45,7 @@ from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, + is_flash_attn_greater_or_equal, logging, replace_return_docstrings, ) @@ -909,7 +909,7 @@ class Florence2FlashAttention2(Florence2Attention): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal("2.1.0") def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) From 975d89b38d1091589f4232c0449c32e75e27b2e5 Mon Sep 17 00:00:00 2001 From: Maxime Ellerbach Date: Fri, 27 Mar 2026 21:25:37 +0100 Subject: [PATCH 124/131] chore(docs): add more guidance to bring your own policies tutorial (#3230) * chore(docs): add more guidance to bring your own policies tutorial * removing normalization to avoid confusion with processors * trailing whitespace * Update docs/source/bring_your_own_policies.mdx Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Maxime Ellerbach * Update docs/source/bring_your_own_policies.mdx Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Maxime Ellerbach * adding get optim params and predict_action chunk * removing extra quotes --------- Signed-off-by: Maxime Ellerbach --- docs/source/bring_your_own_policies.mdx | 98 +++++++++++++++++++++---- 1 file changed, 85 insertions(+), 13 deletions(-) diff --git a/docs/source/bring_your_own_policies.mdx b/docs/source/bring_your_own_policies.mdx index 9266c9e5b..38c32aa71 100644 --- a/docs/source/bring_your_own_policies.mdx +++ b/docs/source/bring_your_own_policies.mdx @@ -41,13 +41,15 @@ requires = # your-build-system ## Step 2: Define the Policy Configuration -Create a configuration class that inherits from `PreTrainedConfig` and registers your policy type: +Create a configuration class that inherits from [`PreTrainedConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/configs/policies.py) and registers your policy type: +Here is a template to get you started, customize the parameters and methods as needed for your policy's architecture and training requirements. ```python # configuration_my_custom_policy.py from dataclasses import dataclass, field from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import NormalizationMode +from lerobot.optim.optimizers import AdamWConfig +from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig @PreTrainedConfig.register_subclass("my_custom_policy") @dataclass @@ -61,22 +63,56 @@ class MyCustomPolicyConfig(PreTrainedConfig): hidden_dim: Hidden dimension for the policy network # Add your policy-specific parameters here """ - # ...PreTrainedConfig fields... - pass + + horizon: int = 50 + n_action_steps: int = 50 + hidden_dim: int = 256 + + optimizer_lr: float = 1e-4 + optimizer_weight_decay: float = 1e-4 def __post_init__(self): super().__post_init__() - # Add any validation logic here + if self.n_action_steps > self.horizon: + raise ValueError("n_action_steps cannot exceed horizon") def validate_features(self) -> None: """Validate input/output feature compatibility.""" - # Implement validation logic for your policy's requirements - pass + if not self.image_features: + raise ValueError("MyCustomPolicy requires at least one image feature.") + if self.action_feature is None: + raise ValueError("MyCustomPolicy requires 'action' in output_features.") + + def get_optimizer_preset(self) -> AdamWConfig: + return AdamWConfig(lr=self.optimizer_lr, weight_decay=self.optimizer_weight_decay) + + def get_scheduler_preset(self): + return None + + @property + def observation_delta_indices(self) -> list[int] | None: + """Relative timestep offsets the dataset loader provides per observation. + + Return `None` for single-frame policies. For temporal policies that consume + multiple past or future frames, return a list of offsets, e.g. `[-20, -10, 0, 10]` for + 3 past frames at stride 10 and 1 future frame at stride 10. + """ + return None + + @property + def action_delta_indices(self) -> list[int]: + """Relative timestep offsets for the action chunk the dataset loader returns. + """ + return list(range(self.horizon)) + + @property + def reward_delta_indices(self) -> None: + return None ``` ## Step 3: Implement the Policy Class -Create your policy implementation by inheriting from LeRobot's base `PreTrainedPolicy` class: +Create your policy implementation by inheriting from [`PreTrainedPolicy`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pretrained.py): ```python # modeling_my_custom_policy.py @@ -85,38 +121,74 @@ import torch.nn as nn from typing import Any from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.utils.constants import ACTION from .configuration_my_custom_policy import MyCustomPolicyConfig class MyCustomPolicy(PreTrainedPolicy): - config_class = MyCustomPolicyConfig + config_class = MyCustomPolicyConfig # must match the string in @register_subclass name = "my_custom_policy" def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None): super().__init__(config, dataset_stats) + config.validate_features() # not called automatically by the base class + self.config = config + self.model = ... # your nn.Module here + + def reset(self): + """Reset episode state.""" ... + + def get_optim_params(self) -> dict: + """Return parameters to pass to the optimizer (e.g. with per-group lr/wd).""" + return {"params": self.parameters()} + + def predict_action_chunk(self, batch: dict[str, torch.Tensor], **kwargs) -> torch.Tensor: + """Return the full action chunk (B, chunk_size, action_dim) for the current observation.""" + ... + + def select_action(self, batch: dict[str, torch.Tensor], **kwargs) -> torch.Tensor: + """Return a single action for the current timestep (called at inference).""" + ... + + def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """Compute the training loss. + + `batch["action_is_pad"]` is a bool mask of shape (B, horizon) that marks + timesteps padded because the episode ended before `horizon` steps, you + can exclude those from your loss. + """ + actions = batch[ACTION] + action_is_pad = batch.get("action_is_pad") + ... + return {"loss": ...} ``` ## Step 4: Add Data Processors -Create processor functions: +Create processor functions. For a concrete reference, see [processor_act.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/processor_act.py) or [processor_diffusion.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/processor_diffusion.py). ```python # processor_my_custom_policy.py from typing import Any import torch +from lerobot.processor import PolicyAction, PolicyProcessorPipeline + def make_my_custom_policy_pre_post_processors( config, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, ) -> tuple[ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], PolicyProcessorPipeline[PolicyAction, PolicyAction], ]: - """Create preprocessing and postprocessing functions for your policy.""" - pass # Define your preprocessing and postprocessing logic here - + preprocessor = ... # build your PolicyProcessorPipeline for inputs + postprocessor = ... # build your PolicyProcessorPipeline for outputs + return preprocessor, postprocessor ``` +**Important - function naming:** LeRobot discovers your processor by name. The function **must** be called `make_{policy_name}_pre_post_processors` (matching the string you passed to `@PreTrainedConfig.register_subclass`). + ## Step 5: Package Initialization Expose your classes in the package's `__init__.py`: From 4e45acca52679745f9c7d4b80984ef4c59fe9a57 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 27 Mar 2026 22:21:55 +0100 Subject: [PATCH 125/131] fix(dataset): use revision-safe Hub cache for downloaded datasets (#3233) * refactor(dataset): enhance dataset root directory handling and introduce hub cache support - Updated DatasetConfig and LeRobotDatasetMetadata to clarify root directory behavior and introduce a dedicated hub cache for downloads. - Refactored LeRobotDataset and StreamingLeRobotDataset to utilize the new hub cache and improve directory management. - Added tests to ensure correct behavior when using the hub cache and handling different revisions without a specified root directory. * refactor(dataset): improve root directory handling in LeRobotDataset - Updated LeRobotDataset to store the requested root path separately from the actual root path. - Adjusted metadata loading to use the requested root, enhancing clarity and consistency in directory management. * refactor(dataset): minor improvements for hub cache support * chore(datasets): guard in resume + assertion test --------- Co-authored-by: AdilZouitine Co-authored-by: mickaelChen --- src/lerobot/configs/default.py | 3 +- src/lerobot/datasets/dataset_metadata.py | 39 ++- src/lerobot/datasets/dataset_reader.py | 8 +- src/lerobot/datasets/lerobot_dataset.py | 81 ++++-- src/lerobot/datasets/streaming_dataset.py | 14 +- src/lerobot/datasets/utils.py | 13 + src/lerobot/utils/constants.py | 4 + tests/datasets/test_lerobot_dataset.py | 318 ++++++++++++++++++++++ 8 files changed, 440 insertions(+), 40 deletions(-) diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py index 7f481b9ca..58ed64420 100644 --- a/src/lerobot/configs/default.py +++ b/src/lerobot/configs/default.py @@ -27,7 +27,8 @@ class DatasetConfig: # "dataset_index" into the returned item. The index mapping is made according to the order in which the # datasets are provided. repo_id: str - # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id. + # Root directory for a concrete local dataset tree (e.g. 'dataset/path'). If None, local datasets are + # looked up under $HF_LEROBOT_HOME/repo_id and Hub downloads use a revision-safe cache under $HF_LEROBOT_HOME/hub. root: str | None = None episodes: list[int] | None = None image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig) diff --git a/src/lerobot/datasets/dataset_metadata.py b/src/lerobot/datasets/dataset_metadata.py index a43ba07b4..65dbc9c4a 100644 --- a/src/lerobot/datasets/dataset_metadata.py +++ b/src/lerobot/datasets/dataset_metadata.py @@ -44,11 +44,12 @@ from lerobot.datasets.utils import ( check_version_compatibility, flatten_dict, get_safe_version, + has_legacy_hub_download_metadata, is_valid_version, update_chunk_file_indices, ) from lerobot.datasets.video_utils import get_video_info -from lerobot.utils.constants import HF_LEROBOT_HOME +from lerobot.utils.constants import HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE CODEBASE_VERSION = "v3.0" @@ -77,8 +78,12 @@ class LeRobotDatasetMetadata: Args: repo_id: Repository identifier (e.g. ``'lerobot/aloha_sim'``). - root: Local directory for the dataset. Defaults to - ``$HF_LEROBOT_HOME/{repo_id}``. + root: Local directory for the dataset. When provided, Hub downloads + are materialized directly into this directory. When omitted, + existing local datasets are still looked up under + ``$HF_LEROBOT_HOME/{repo_id}``, but Hub downloads use a + revision-safe snapshot cache under + ``$HF_LEROBOT_HOME/hub``. revision: Git revision (branch, tag, or commit hash). Defaults to the current codebase version. force_cache_sync: If ``True``, re-download metadata from the Hub @@ -88,7 +93,8 @@ class LeRobotDatasetMetadata: """ self.repo_id = repo_id self.revision = revision if revision else CODEBASE_VERSION - self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id + self._requested_root = Path(root) if root is not None else None + self.root = self._requested_root if self._requested_root is not None else HF_LEROBOT_HOME / repo_id self._pq_writer = None self.latest_episode = None self._metadata_buffer: list[dict] = [] @@ -96,14 +102,15 @@ class LeRobotDatasetMetadata: self._finalized = False try: - if force_cache_sync: + if force_cache_sync or ( + self._requested_root is None and has_legacy_hub_download_metadata(self.root) + ): raise FileNotFoundError self._load_metadata() except (FileNotFoundError, NotADirectoryError): if is_valid_version(self.revision): self.revision = get_safe_version(self.repo_id, self.revision) - (self.root / "meta").mkdir(exist_ok=True, parents=True) self._pull_from_repo(allow_patterns="meta/") self._load_metadata() @@ -178,14 +185,29 @@ class LeRobotDatasetMetadata: allow_patterns: list[str] | str | None = None, ignore_patterns: list[str] | str | None = None, ) -> None: + if self._requested_root is None: + self.root = Path( + snapshot_download( + self.repo_id, + repo_type="dataset", + revision=self.revision, + cache_dir=HF_LEROBOT_HUB_CACHE, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + ) + ) + return + + self._requested_root.mkdir(exist_ok=True, parents=True) snapshot_download( self.repo_id, repo_type="dataset", revision=self.revision, - local_dir=self.root, + local_dir=self._requested_root, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, ) + self.root = self._requested_root @property def url_root(self) -> str: @@ -593,7 +615,8 @@ class LeRobotDatasetMetadata: """ obj = cls.__new__(cls) obj.repo_id = repo_id - obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id + obj._requested_root = Path(root) if root is not None else None + obj.root = obj._requested_root if obj._requested_root is not None else HF_LEROBOT_HOME / repo_id obj.root.mkdir(parents=True, exist_ok=False) diff --git a/src/lerobot/datasets/dataset_reader.py b/src/lerobot/datasets/dataset_reader.py index 0233a3cf6..3720a5084 100644 --- a/src/lerobot/datasets/dataset_reader.py +++ b/src/lerobot/datasets/dataset_reader.py @@ -68,7 +68,7 @@ class DatasetReader: visual features. """ self._meta = meta - self._root = root + self.root = root self.episodes = episodes self._tolerance_s = tolerance_s self._video_backend = video_backend @@ -125,7 +125,7 @@ class DatasetReader: def _load_hf_dataset(self) -> datasets.Dataset: """hf_dataset contains all the observations, states, actions, rewards, etc.""" features = get_hf_features_from_features(self._meta.features) - hf_dataset = load_nested_dataset(self._root / "data", features=features, episodes=self.episodes) + hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes) hf_dataset.set_transform(hf_transform_to_torch) return hf_dataset @@ -150,7 +150,7 @@ class DatasetReader: if len(self._meta.video_keys) > 0: for ep_idx in requested_episodes: for vid_key in self._meta.video_keys: - video_path = self._root / self._meta.get_video_file_path(ep_idx, vid_key) + video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key) if not video_path.exists(): return False @@ -240,7 +240,7 @@ class DatasetReader: from_timestamp = ep[f"videos/{vid_key}/from_timestamp"] shifted_query_ts = [from_timestamp + ts for ts in query_ts] - video_path = self._root / self._meta.get_video_file_path(ep_idx, vid_key) + video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key) frames = decode_video_frames(video_path, shifted_query_ts, self._tolerance_s, self._video_backend) item[vid_key] = frames.squeeze(0) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index cba0c1cba..f719222fd 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -37,7 +37,7 @@ from lerobot.datasets.video_utils import ( get_safe_default_codec, resolve_vcodec, ) -from lerobot.utils.constants import HF_LEROBOT_HOME +from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE logger = logging.getLogger(__name__) @@ -144,10 +144,11 @@ class LeRobotDataset(torch.utils.data.Dataset): Args: repo_id (str): This is the repo id that will be used to fetch the dataset. - root (Path | None, optional): Local directory where the dataset will be downloaded and - stored. If set, all dataset files will be stored directly under this path. If not set, the - dataset files will be stored under $HF_LEROBOT_HOME/repo_id (configurable via the - HF_LEROBOT_HOME environment variable). + root (Path | None, optional): Local directory where the dataset will be read from or downloaded + into. If set, all dataset files are materialized directly under this path. If not set, + existing local datasets are still looked up under ``$HF_LEROBOT_HOME/{repo_id}``, but Hub + downloads use a revision-safe snapshot cache under + ``$HF_LEROBOT_HOME/hub``. episodes (list[int] | None, optional): If specified, this will only load episodes specified by their episode_index in this list. Defaults to None. image_transforms (Callable | None, optional): You can pass standard v2 image transforms from @@ -190,7 +191,7 @@ class LeRobotDataset(torch.utils.data.Dataset): """ super().__init__() self.repo_id = repo_id - self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id + self._requested_root = Path(root) if root else None self.image_transforms = image_transforms self.delta_timestamps = delta_timestamps self.episodes = episodes @@ -201,12 +202,15 @@ class LeRobotDataset(torch.utils.data.Dataset): self._vcodec = resolve_vcodec(vcodec) self._encoder_threads = encoder_threads - self.root.mkdir(exist_ok=True, parents=True) + if self._requested_root is not None: + self._requested_root.mkdir(exist_ok=True, parents=True) - # Load metadata + # Load metadata (sets self.root once from the resolved metadata root) self.meta = LeRobotDatasetMetadata( - self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync + self.repo_id, self._requested_root, self.revision, force_cache_sync=force_cache_sync ) + self.root = self.meta.root + self.revision = self.meta.revision # Create reader (hf_dataset loaded below) self.reader = DatasetReader( @@ -556,14 +560,33 @@ class LeRobotDataset(torch.utils.data.Dataset): if self.episodes is not None: # Reader is guaranteed to exist here (created in __init__ before _download) files = self.reader.get_episodes_file_paths() - snapshot_download( - self.repo_id, - repo_type="dataset", - revision=self.revision, - local_dir=self.root, - allow_patterns=files, - ignore_patterns=ignore_patterns, - ) + + if self._requested_root is None: + self.meta.root = Path( + snapshot_download( + self.repo_id, + repo_type="dataset", + revision=self.revision, + cache_dir=HF_LEROBOT_HUB_CACHE, + allow_patterns=files, + ignore_patterns=ignore_patterns, + ) + ) + else: + self._requested_root.mkdir(exist_ok=True, parents=True) + snapshot_download( + self.repo_id, + repo_type="dataset", + revision=self.revision, + local_dir=self._requested_root, + allow_patterns=files, + ignore_patterns=ignore_patterns, + ) + self.meta.root = self._requested_root + + # Propagate resolved root from metadata (single source of truth) + self.root = self.meta.root + self.reader.root = self.meta.root # ── Class constructors ──────────────────────────────────────────── @@ -635,6 +658,7 @@ class LeRobotDataset(torch.utils.data.Dataset): metadata_buffer_size=metadata_buffer_size, ) obj.repo_id = obj.meta.repo_id + obj._requested_root = obj.meta.root obj.root = obj.meta.root obj.revision = None obj.tolerance_s = tolerance_s @@ -695,8 +719,10 @@ class LeRobotDataset(torch.utils.data.Dataset): Args: repo_id: Repository identifier of the existing dataset. - root: Local directory of the dataset. Defaults to - ``$HF_LEROBOT_HOME/{repo_id}``. + root: Local directory of the dataset. When provided, Hub downloads + are materialized directly into this directory. When omitted, + Hub downloads use a revision-safe snapshot cache under + ``$HF_LEROBOT_HOME/hub``. tolerance_s: Timestamp synchronization tolerance in seconds. revision: Git revision (branch, tag, or commit hash). Defaults to current codebase version tag. @@ -716,11 +742,16 @@ class LeRobotDataset(torch.utils.data.Dataset): Returns: A :class:`LeRobotDataset` in write mode, ready to append episodes. """ + if not root: + raise ValueError( + "resume() requires an explicit 'root' directory because it creates a DatasetWriter. " + "Writing into the revision-safe Hub snapshot cache (used when root=None) would corrupt " + "the shared cache. Please provide a local directory path." + ) vcodec = resolve_vcodec(vcodec) obj = cls.__new__(cls) obj.repo_id = repo_id - obj.root = Path(root) if root else HF_LEROBOT_HOME / repo_id - obj.root.mkdir(exist_ok=True, parents=True) + obj._requested_root = Path(root) obj.revision = revision if revision else CODEBASE_VERSION obj.tolerance_s = tolerance_s obj.image_transforms = None @@ -731,10 +762,14 @@ class LeRobotDataset(torch.utils.data.Dataset): obj._vcodec = vcodec obj._encoder_threads = encoder_threads - # Load metadata + if obj._requested_root is not None: + obj._requested_root.mkdir(exist_ok=True, parents=True) + + # Load metadata (revision-safe when root is not provided) obj.meta = LeRobotDatasetMetadata( - obj.repo_id, obj.root, obj.revision, force_cache_sync=force_cache_sync + obj.repo_id, obj._requested_root, obj.revision, force_cache_sync=force_cache_sync ) + obj.root = obj.meta.root # Reader is lazily created on first access (write-only mode) obj.reader = None diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py index 62e00558a..1767cc79d 100644 --- a/src/lerobot/datasets/streaming_dataset.py +++ b/src/lerobot/datasets/streaming_dataset.py @@ -255,7 +255,9 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): Args: repo_id (str): This is the repo id that will be used to fetch the dataset. - root (Path | None, optional): Local directory to use for downloading/writing files. + root (Path | None, optional): Local directory to use for local datasets. When omitted, Hub + metadata is resolved through a revision-safe snapshot cache under + ``$HF_LEROBOT_HOME/hub``. episodes (list[int] | None, optional): If specified, this will only load episodes specified by their episode_index in this list. image_transforms (Callable | None, optional): Transform to apply to image data. @@ -271,7 +273,8 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): """ super().__init__() self.repo_id = repo_id - self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id + self._requested_root = Path(root) if root else None + self.root = self._requested_root if self._requested_root is not None else HF_LEROBOT_HOME / repo_id self.streaming_from_local = root is not None self.image_transforms = image_transforms @@ -288,12 +291,15 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): # We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown) self.video_decoder_cache = None - self.root.mkdir(exist_ok=True, parents=True) + if self._requested_root is not None: + self.root.mkdir(exist_ok=True, parents=True) # Load metadata self.meta = LeRobotDatasetMetadata( - self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync + self.repo_id, self._requested_root, self.revision, force_cache_sync=force_cache_sync ) + self.root = self.meta.root + self.revision = self.meta.revision # Check version check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION) diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 2e1d360f9..36e7934ed 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -18,6 +18,7 @@ import importlib.resources import json import logging from collections.abc import Iterator +from pathlib import Path from typing import Any import datasets @@ -101,6 +102,18 @@ DEFAULT_FEATURES = { } +def has_legacy_hub_download_metadata(root: Path) -> bool: + """Return ``True`` when *root* looks like a legacy Hub ``local_dir`` mirror. + + ``snapshot_download(local_dir=...)`` stores lightweight metadata under + ``/.cache/huggingface/download/``. The presence of this + directory is a reliable indicator that the dataset was downloaded with + the old non-revision-safe ``local_dir`` mode and should be re-fetched + through the snapshot cache instead. + """ + return (root / ".cache" / "huggingface" / "download").exists() + + def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]: if file_idx == chunks_size - 1: file_idx = 0 diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index ecd54844c..fd10cab35 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -65,6 +65,10 @@ if "LEROBOT_HOME" in os.environ: # cache dir default_cache_path = Path(HF_HOME) / "lerobot" HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser() +# LeRobot's own revision-safe Hub cache (NOT the system-wide ~/.cache/huggingface/hub/). +# Used as the ``cache_dir`` argument to ``snapshot_download`` so that different +# dataset revisions are stored in isolated snapshot directories. +HF_LEROBOT_HUB_CACHE = HF_LEROBOT_HOME / "hub" # calibration dir default_calibration_path = HF_LEROBOT_HOME / "calibration" diff --git a/tests/datasets/test_lerobot_dataset.py b/tests/datasets/test_lerobot_dataset.py index d7ce54a15..a8aa47ed2 100644 --- a/tests/datasets/test_lerobot_dataset.py +++ b/tests/datasets/test_lerobot_dataset.py @@ -19,9 +19,15 @@ Tests focus on mode contracts (read-only, write-only, resume), guards, property delegation, and the full create-record-finalize-read lifecycle. """ +from pathlib import Path +from unittest.mock import Mock + import pytest import torch +import lerobot.datasets.dataset_metadata as dataset_metadata_module +import lerobot.datasets.lerobot_dataset as lerobot_dataset_module +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata from lerobot.datasets.dataset_reader import DatasetReader from lerobot.datasets.dataset_writer import DatasetWriter from lerobot.datasets.lerobot_dataset import LeRobotDataset @@ -30,12 +36,69 @@ from tests.fixtures.constants import DEFAULT_FPS, DUMMY_REPO_ID SIMPLE_FEATURES = { "state": {"dtype": "float32", "shape": (2,), "names": None}, } +SNAPSHOT_MAIN_FEATURES = { + **SIMPLE_FEATURES, + "test": {"dtype": "float32", "shape": (2,), "names": None}, +} def _make_frame(task: str = "Dummy task") -> dict: return {"task": task, "state": torch.randn(2)} +def _set_default_cache_root(monkeypatch: pytest.MonkeyPatch, cache_root: Path) -> None: + monkeypatch.setattr(dataset_metadata_module, "HF_LEROBOT_HOME", cache_root) + monkeypatch.setattr(dataset_metadata_module, "HF_LEROBOT_HUB_CACHE", cache_root / "hub") + monkeypatch.setattr(lerobot_dataset_module, "HF_LEROBOT_HUB_CACHE", cache_root / "hub") + + +def _write_dataset_tree( + root: Path, + *, + motor_features: dict[str, dict], + info_factory, + stats_factory, + tasks_factory, + episodes_factory, + hf_dataset_factory, + create_info, + create_stats, + create_tasks, + create_episodes, + create_hf_dataset, +) -> None: + root.mkdir(parents=True, exist_ok=True) + info = info_factory( + total_episodes=1, + total_frames=3, + total_tasks=1, + use_videos=False, + motor_features=motor_features, + camera_features={}, + ) + tasks = tasks_factory(total_tasks=1) + episodes = episodes_factory( + features=info["features"], + fps=info["fps"], + total_episodes=1, + total_frames=3, + tasks=tasks, + ) + stats = stats_factory(features=info["features"]) + hf_dataset = hf_dataset_factory( + features=info["features"], + tasks=tasks, + episodes=episodes, + fps=info["fps"], + ) + + create_info(root, info) + create_stats(root, stats) + create_tasks(root, tasks) + create_episodes(root, episodes) + create_hf_dataset(root, hf_dataset) + + # ── Read-only mode (via __init__) ──────────────────────────────────── @@ -75,6 +138,261 @@ def test_len_matches_num_frames(tmp_path, lerobot_dataset_factory): assert len(dataset) == dataset.num_frames +def test_metadata_without_root_uses_hub_cache_snapshot_download( + tmp_path, + info_factory, + stats_factory, + tasks_factory, + episodes_factory, + hf_dataset_factory, + create_info, + create_stats, + create_tasks, + create_episodes, + create_hf_dataset, + monkeypatch, +): + """Metadata refresh uses the dedicated Hub cache instead of a shared local_dir mirror.""" + repo_id = DUMMY_REPO_ID + cache_root = tmp_path / "lerobot_cache" + snapshot_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main" + _write_dataset_tree( + snapshot_root, + motor_features=SNAPSHOT_MAIN_FEATURES, + info_factory=info_factory, + stats_factory=stats_factory, + tasks_factory=tasks_factory, + episodes_factory=episodes_factory, + hf_dataset_factory=hf_dataset_factory, + create_info=create_info, + create_stats=create_stats, + create_tasks=create_tasks, + create_episodes=create_episodes, + create_hf_dataset=create_hf_dataset, + ) + + _set_default_cache_root(monkeypatch, cache_root) + snapshot_download = Mock(return_value=str(snapshot_root)) + monkeypatch.setattr(dataset_metadata_module, "snapshot_download", snapshot_download) + + meta = LeRobotDatasetMetadata(repo_id=repo_id, revision="main", force_cache_sync=True) + + assert meta.root == snapshot_root + assert snapshot_download.call_count == 1 + assert snapshot_download.call_args.args == (repo_id,) + assert snapshot_download.call_args.kwargs == { + "repo_type": "dataset", + "revision": "main", + "cache_dir": cache_root / "hub", + "allow_patterns": "meta/", + "ignore_patterns": None, + } + + +def test_without_root_reads_different_revisions_from_distinct_snapshot_roots( + tmp_path, + info_factory, + stats_factory, + tasks_factory, + episodes_factory, + hf_dataset_factory, + create_info, + create_stats, + create_tasks, + create_episodes, + create_hf_dataset, + monkeypatch, +): + """Different revisions resolve to different on-disk snapshot roots.""" + repo_id = DUMMY_REPO_ID + old_revision = "b59010db93eb6cc3cf06ef2f7cae1bbe62b726d9" + cache_root = tmp_path / "lerobot_cache" + main_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main" + old_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-old" + + _write_dataset_tree( + main_root, + motor_features=SNAPSHOT_MAIN_FEATURES, + info_factory=info_factory, + stats_factory=stats_factory, + tasks_factory=tasks_factory, + episodes_factory=episodes_factory, + hf_dataset_factory=hf_dataset_factory, + create_info=create_info, + create_stats=create_stats, + create_tasks=create_tasks, + create_episodes=create_episodes, + create_hf_dataset=create_hf_dataset, + ) + _write_dataset_tree( + old_root, + motor_features=SIMPLE_FEATURES, + info_factory=info_factory, + stats_factory=stats_factory, + tasks_factory=tasks_factory, + episodes_factory=episodes_factory, + hf_dataset_factory=hf_dataset_factory, + create_info=create_info, + create_stats=create_stats, + create_tasks=create_tasks, + create_episodes=create_episodes, + create_hf_dataset=create_hf_dataset, + ) + + _set_default_cache_root(monkeypatch, cache_root) + snapshot_roots = { + "main": main_root, + old_revision: old_root, + } + meta_snapshot_download = Mock( + side_effect=lambda repo_id, **kwargs: str(snapshot_roots[kwargs["revision"]]) + ) + data_snapshot_download = Mock( + side_effect=lambda repo_id, **kwargs: str(snapshot_roots[kwargs["revision"]]) + ) + monkeypatch.setattr(dataset_metadata_module, "snapshot_download", meta_snapshot_download) + monkeypatch.setattr(lerobot_dataset_module, "snapshot_download", data_snapshot_download) + + main_dataset = LeRobotDataset( + repo_id=repo_id, revision="main", download_videos=False, force_cache_sync=True + ) + old_dataset = LeRobotDataset( + repo_id=repo_id, revision=old_revision, download_videos=False, force_cache_sync=True + ) + + assert main_dataset.root == main_root + assert old_dataset.root == old_root + assert "test" in main_dataset.hf_dataset.column_names + assert "test" not in old_dataset.hf_dataset.column_names + + # Metadata downloads use cache_dir, not local_dir + assert meta_snapshot_download.call_count == 2 + for download_call in meta_snapshot_download.call_args_list: + assert download_call.kwargs["cache_dir"] == cache_root / "hub" + assert "local_dir" not in download_call.kwargs + + # Data downloads also use cache_dir, not local_dir + assert data_snapshot_download.call_count == 2 + for download_call in data_snapshot_download.call_args_list: + assert download_call.kwargs["cache_dir"] == cache_root / "hub" + assert "local_dir" not in download_call.kwargs + + +def test_metadata_without_root_ignores_legacy_local_dir_cache( + tmp_path, + info_factory, + stats_factory, + tasks_factory, + episodes_factory, + hf_dataset_factory, + create_info, + create_stats, + create_tasks, + create_episodes, + create_hf_dataset, + monkeypatch, +): + """Legacy local-dir mirrors are bypassed in favor of revision-safe snapshots.""" + repo_id = DUMMY_REPO_ID + cache_root = tmp_path / "lerobot_cache" + legacy_root = cache_root / repo_id + snapshot_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main" + + _write_dataset_tree( + legacy_root, + motor_features=SIMPLE_FEATURES, + info_factory=info_factory, + stats_factory=stats_factory, + tasks_factory=tasks_factory, + episodes_factory=episodes_factory, + hf_dataset_factory=hf_dataset_factory, + create_info=create_info, + create_stats=create_stats, + create_tasks=create_tasks, + create_episodes=create_episodes, + create_hf_dataset=create_hf_dataset, + ) + (legacy_root / ".cache" / "huggingface" / "download").mkdir(parents=True, exist_ok=True) + _write_dataset_tree( + snapshot_root, + motor_features=SNAPSHOT_MAIN_FEATURES, + info_factory=info_factory, + stats_factory=stats_factory, + tasks_factory=tasks_factory, + episodes_factory=episodes_factory, + hf_dataset_factory=hf_dataset_factory, + create_info=create_info, + create_stats=create_stats, + create_tasks=create_tasks, + create_episodes=create_episodes, + create_hf_dataset=create_hf_dataset, + ) + + _set_default_cache_root(monkeypatch, cache_root) + snapshot_download = Mock(return_value=str(snapshot_root)) + monkeypatch.setattr(dataset_metadata_module, "snapshot_download", snapshot_download) + + meta = LeRobotDatasetMetadata(repo_id=repo_id, revision="main") + + assert meta.root == snapshot_root + assert "test" in meta.features + assert snapshot_download.call_count == 1 + + +def test_download_without_root_uses_hub_cache( + tmp_path, + info_factory, + stats_factory, + tasks_factory, + episodes_factory, + hf_dataset_factory, + create_info, + create_stats, + create_tasks, + create_episodes, + create_hf_dataset, + monkeypatch, +): + """LeRobotDataset._download() uses cache_dir (not local_dir) when root is not provided.""" + repo_id = DUMMY_REPO_ID + cache_root = tmp_path / "lerobot_cache" + snapshot_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main" + + # Pre-populate snapshot directory so metadata loads succeed, but leave + # data absent so that _download() is triggered. + _write_dataset_tree( + snapshot_root, + motor_features=SIMPLE_FEATURES, + info_factory=info_factory, + stats_factory=stats_factory, + tasks_factory=tasks_factory, + episodes_factory=episodes_factory, + hf_dataset_factory=hf_dataset_factory, + create_info=create_info, + create_stats=create_stats, + create_tasks=create_tasks, + create_episodes=create_episodes, + create_hf_dataset=create_hf_dataset, + ) + + _set_default_cache_root(monkeypatch, cache_root) + meta_snapshot_download = Mock(return_value=str(snapshot_root)) + monkeypatch.setattr(dataset_metadata_module, "snapshot_download", meta_snapshot_download) + + # Mock the data snapshot_download to return the same root (data already + # exists there from _write_dataset_tree). + data_snapshot_download = Mock(return_value=str(snapshot_root)) + monkeypatch.setattr(lerobot_dataset_module, "snapshot_download", data_snapshot_download) + + LeRobotDataset(repo_id=repo_id, revision="main", force_cache_sync=True) + + # _download() should have called snapshot_download with cache_dir + assert data_snapshot_download.call_count == 1 + call_kwargs = data_snapshot_download.call_args.kwargs + assert call_kwargs["cache_dir"] == cache_root / "hub" + assert "local_dir" not in call_kwargs + + # ── Write-only mode (via create()) ────────────────────────────────── From 2e069b1c4769e40371d226763191d68123282d4d Mon Sep 17 00:00:00 2001 From: Bryson Jones <63133702+brysonjones@users.noreply.github.com> Date: Fri, 27 Mar 2026 16:41:26 -0700 Subject: [PATCH 126/131] Feature/add multitask diffusion transformer policy implementation (#2545) * Add multitask diffusion transformer policy Add multitask diffusion transformer policy * expand the observation encoder to support differnt size encoders for vision and text * add RoPE attention module as this is shown to help training dynamics and generation quality for DiTs * update readme and citations for multitask dit policy * remove dino vision encoder and simplify text and vision encoders by removing inheritance structure * adjust factory comment * update docstring for multitask dit policy processor file * simplify config for multitask dit by merging and flattening everything, then adding comments to denote where some parameters are only used for specific objectives * add references to the modeling file comments * merge all modules files into the main modeling file * add torch.no_grad decorators * split up select action return statement * remove redundant asserts * add tutorial to training with multi_task_dit * fix bugs when testing on hardware * remove environment state conditioning * update typo in test instruction comment * add processor tests to multitask dit tests * move policy to top of file * use constants for indexing into batches and remove env state references * remove the base classes since we don't need to be able to extend * fix nit formatting in generate actions fcn * reformat and clean up tutorial for multitask dit policy * add more descriptions and depth to multitask dit tutorial * note origins of each training objective * rename config param for multiple vision encoders * refactor code to perform task tokenization in the processor instead of in the modeling code for multitask dit * add multitask dit to toc for docs * add conditional transformers import to match all other policies that use transformers lib * add test handling for multitask dit when transformers isnt available * skip tests without transformers * remove cropping of images smaller than the crop size * add kwargs arg to multitask dit constructor * add wallx dep conflict management for multitask dit policy * use hyphens for cleanliness in pyproject.toml * add conflict management to pyproject toml for pi conflict for mtdp as well * update tests script to not use unnecessary uv sync call which resolves dependencies that do not need to run. This drastically reduces CI run time * revert fast tests edits * update docs and readme files, fixing some typos and adding multitask dit to readme * chore(dependencies): upgrade transformers + hggingface-hub + peft + scipy * chore(dependencies): bump pi0 family to transformers v5 * chore(dependencies): bump wall x to transformers v5 * chore(dependencies): bump gr00t to transformers v5 * chore(style): fix pre-commit * fix(policy): xvla forced_bos_token missing * test(rl): skip ci tests for resnet10 * Fix: full pi models support for transformer v5 (#2967) * fix(pi): remove loss truncation * fix(pi): remove state padding before tokenization * fix(pi): fix image padding value * fix from_pretrain * add transformer v5 changes * remove reference * more fixes * make it work * add support for rest of pi family * add pifast work * more changes * more changes * more cleanup * fix torch params * dtype fix * torch compile * embed mismatch fix * revert groot * more nit fixes * remove unused classes * more fixes * revert * nit * torch dtype warning fix * but back dynamic renaming * add tie embedding --------- Co-authored-by: Yufei Sun * chore: fix XVLA in transformers v5 (#3006) * test(policies): enable wall x CI testing * style(test): pre-commit check * style(test): pre-commit --------- Signed-off-by: Bryson Jones <63133702+brysonjones@users.noreply.github.com> Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Steven Palma Co-authored-by: Jade Choghari Co-authored-by: Yufei Sun Co-authored-by: Steven Palma --- README.md | 10 +- docs/source/_toctree.yml | 2 + docs/source/multi_task_dit.mdx | 340 ++++++++ docs/source/policy_multi_task_dit_README.md | 37 + pyproject.toml | 1 + src/lerobot/policies/__init__.py | 2 + src/lerobot/policies/factory.py | 24 +- src/lerobot/policies/multi_task_dit/README.md | 37 + .../policies/multi_task_dit/__init__.py | 21 + .../configuration_multi_task_dit.py | 256 ++++++ .../multi_task_dit/modeling_multi_task_dit.py | 803 ++++++++++++++++++ .../processor_multi_task_dit.py | 105 +++ .../multi_task_dit/test_multi_task_dit.py | 624 ++++++++++++++ 13 files changed, 2253 insertions(+), 9 deletions(-) create mode 100644 docs/source/multi_task_dit.mdx create mode 100644 docs/source/policy_multi_task_dit_README.md create mode 100644 src/lerobot/policies/multi_task_dit/README.md create mode 100644 src/lerobot/policies/multi_task_dit/__init__.py create mode 100644 src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py create mode 100644 src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py create mode 100644 src/lerobot/policies/multi_task_dit/processor_multi_task_dit.py create mode 100644 tests/policies/multi_task_dit/test_multi_task_dit.py diff --git a/README.md b/README.md index f58b337b3..f67d9103c 100644 --- a/README.md +++ b/README.md @@ -100,11 +100,11 @@ lerobot-train \ --dataset.repo_id=lerobot/aloha_mobile_cabinet ``` -| Category | Models | -| -------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md) | -| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) | -| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) | +| Category | Models | +| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) | +| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) | +| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) | Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 09d94d28c..650a21184 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -49,6 +49,8 @@ title: NVIDIA GR00T N1.5 - local: xvla title: X-VLA + - local: multi_task_dit + title: Multitask DiT Policy - local: walloss title: WALL-OSS title: "Policies" diff --git a/docs/source/multi_task_dit.mdx b/docs/source/multi_task_dit.mdx new file mode 100644 index 000000000..c3cced708 --- /dev/null +++ b/docs/source/multi_task_dit.mdx @@ -0,0 +1,340 @@ +# Multitask DiT Policy + +Multitask Diffusion Transformer (DiT) Policy is an evolution of the original Diffusion Policy architecture, which leverages a large DiT with text and vision conditioning for multitask robot learning. This implementation supports both diffusion and flow matching objectives for action generation, enabling robots to perform diverse manipulation tasks conditioned on language instructions. + +## Model Overview + +The model uses: + +- **CLIP Vision Encoder**: Processes RGB images from multiple camera views +- **CLIP Text Encoder**: Encodes language task instructions (frozen weights with learnable projection) +- **Diffusion Transformer**: Predicts action sequences conditioned on observations and language +- **Two Objectives**: Supports both diffusion (DDPM/DDIM) and flow matching for action generation + +This model is exciting because you can achieve extremely high dexterity, competitive with multi-billion parameter +VLAs, with only ~450M parameters and significantly less training. + +## Installation Requirements + +Multitask DiT Policy has additional dependencies. Install it with: + +```bash +pip install lerobot[multi_task_dit] +``` + +This will install all necessary dependencies including the HuggingFace Transformers library for CLIP models. + +## Usage + +To use Multitask DiT in your LeRobot configuration, specify the policy type as: + +```python +policy.type=multi_task_dit +``` + +## Training + +### Basic Training Command + +Here's a complete training command for training Multitask DiT on your dataset: + +```bash +lerobot-train \ + --dataset.repo_id=YOUR_DATASET \ + --output_dir=./outputs/multitask_dit_training \ + --batch_size=32 \ + --steps=5000 \ + --save_freq=500 \ + --log_freq=100 \ + --policy.type=multi_task_dit \ + --policy.device=cuda \ + --policy.repo_id="HF_USER/multitask-dit-your-robot" \ + --wandb.enable=true +``` + +### Recommended Hyperparameters and Dataset Details (30Hz Control Frequency) + +For reliable performance, start with these suggested default hyperparameters: + +```bash +lerobot-train \ + --dataset.repo_id=YOUR_DATASET \ + --output_dir=./outputs/mutitask_dit_training \ + --batch_size=320 \ + --steps=30000 \ + --policy.type=multi_task_dit \ + --policy.device=cuda \ + --policy.horizon=32 \ + --policy.n_action_steps=24 \ + --policy.objective=diffusion \ + --policy.noise_scheduler_type=DDPM \ + --policy.num_train_timesteps=100 \ + --policy.repo_id="HF_USER/multitask-dit-your-robot" \ + --wandb.enable=true +``` + +**Key Parameters:** + +- **Batch Size**: 192-320 - If you have access to a GPU that can support this, you will get the best training dynamics +- **Horizon**: 32 - number of action steps to predict, ~1.0 sec at 30Hz +- **n_action_steps**: 24 - ~0.8 seconds at 30Hz +- **Objective**: `diffusion` - start with diffusion and experiment with flow matching if generation quality is poor +- **Training Steps**: >30k steps recommended for a single task + +### Training Configuration Parameters + +#### Objective Selection + +Choose between diffusion and flow matching: + +```bash +# Diffusion objective (default) +--policy.objective=diffusion \ +--policy.noise_scheduler_type=DDPM \ # or "DDIM" +--policy.num_train_timesteps=100 \ +--policy.num_inference_steps=10 \ # For faster inference +--policy.beta_schedule=squaredcos_cap_v2 \ # Noise schedule type +--policy.prediction_type=epsilon \ # "epsilon" (predict noise) or "sample" (predict clean) +--policy.clip_sample=true \ # Clip samples during denoising +--policy.clip_sample_range=1.0 # Clipping range [-x, x] + +# Flow matching objective +--policy.objective=flow_matching \ +--policy.timestep_sampling_strategy=beta \ # or "uniform" | the beta sampling strategy performance appears much better in practice +--policy.num_integration_steps=100 \ +--policy.integration_method=euler \ # or "rk4" +--policy.sigma_min=0.0 # Minimum noise in flow interpolation path +``` + +#### Transformer Architecture + +Adjust model capacity based on dataset size: + +```bash +# Small datasets (< 100 examples) +--policy.num_layers=4 \ +--policy.hidden_dim=512 \ +--policy.num_heads=8 # should ideally be hidden_dim // 64 + +# Medium datasets (100-5k examples) - default +--policy.num_layers=6 \ +--policy.hidden_dim=512 \ +--policy.num_heads=8 # should ideally be hidden_dim // 64 + +# Large datasets (> 5k examples) +--policy.num_layers=8 \ +--policy.hidden_dim=512 \ +--policy.num_heads=8 # should ideally be hidden_dim // 64 +``` + +**Positional Encoding Options:** + +The model supports two positional encoding methods for action sequences: + +```bash +# Rotary Position Embedding (RoPE) - default, recommended +--policy.use_rope=true \ +--policy.rope_base=10000.0 # Base frequency for RoPE + +# Absolute positional encoding +--policy.use_positional_encoding=true # Disables RoPE when true +``` + +**Other Transformer Parameters:** + +```bash +--policy.dropout=0.1 # Dropout rate for DiT blocks (0.0-1.0) +--policy.timestep_embed_dim=256 # Timestep embedding dimension +``` + +#### Vision Encoder Configuration + +```bash +# Use different CLIP model for more expressivity at the cost of inference time +# experiment with larger or smaller models depending on the complexity of your tasks and size of dataset +--policy.vision_encoder_name=openai/clip-vit-large-patch14 + +# Use separate vision encoder per camera +# This may be useful when cameras have significantly different characteristics, but +# be wary of increased VRAM footprint. +--policy.use_separate_rgb_encoder_per_camera=true + +# Image preprocessing +--policy.image_resize_shape=[XXX,YYY] \ # you may need to resize your images for inference speed ups +--policy.image_crop_shape=[224,224] \ +--policy.image_crop_is_random=true # Random during training, center at inference +``` + +#### Text Encoder Configuration + +```bash +# Use different CLIP text encoder model +# same as vision: experiment with larger or smaller models depending on the +# complexity of your tasks and size of dataset +--policy.text_encoder_name=openai/clip-vit-large-patch14 +``` + +#### Learning Rate Configuration + +The vision encoder uses a separate learning rate multiplier, where 1/10th is suggested to be the ideal staritng point: + +```bash +--policy.optimizer_lr=2e-5 \ +--policy.vision_encoder_lr_multiplier=0.1 # Vision encoder LR = 0.1 * optimizer_lr +``` + +### Training Tuning Guidelines + +#### 1. Flow Matching with Beta Sampling + +The original diffusion implementation here is based on the work described in [TRI's LBM paper](https://arxiv.org/abs/2507.05331) + +Additionally, we have implemented a flow-matching objective, which is described at a high-level in [Boston Dynamics blog post](https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/). + +Consider testing the flow-matching objective and evaluating performance differences for your task: + +```bash +--policy.objective=flow_matching \ +--policy.timestep_sampling_strategy=beta \ +--policy.timestep_sampling_alpha=1.5 \ +--policy.timestep_sampling_beta=1.0 \ +--policy.timestep_sampling_s=0.999 +``` + +This hasn't been shown to be a silver bullet across every user case, but it occasionally results in smoother and more consistent actions. + +#### 2. Number of Transformer Layers + +Match model capacity to your dataset size: + +- **Small datasets** (< 100 examples): Reduce to 4 layers +- **Large datasets** (> 5k examples): Increase to 8 layers + +#### 3. `horizon` Tuning + +The model can be sensitive to the horizon you choose. Start with around a 1 second horizon based on your control frequency: + +- **30 Hz frequency**: `horizon=30` +- **10 Hz frequency**: `horizon=10` + +Then experiment with increasing from there. The horizon determines how far into the future the model predicts actions. + +#### 4. `n_action_steps` Sensitivity + +The model can also be very sensitive to `n_action_steps`. Start with it being around 0.8 seconds based on your control frequency and tune from there: + +- **Lower values**: More reactive but potentially less stable for long-horizon tasks +- **Higher values**: Better for long-horizon execution but open-loop failures are limited in their recovery + +### Inference Tuning + +For faster inference, use DDIM with fewer sampling steps: + +```bash +--policy.noise_scheduler_type=DDIM \ +--policy.num_inference_steps=10 +``` + +### Resuming Training + +To resume training from a checkpoint: + +```bash +lerobot-train \ + --config_path=./outputs/mutitask_dit_training/checkpoints/last/pretrained_model/train_config.json \ + --resume=true +``` + +The checkpoint directory should contain `model.safetensors` and `config.json` files (saved automatically during training). When resuming, the configuration is loaded from the checkpoint, so you don't need to specify other parameters. + +## Common Failure Modes and Debugging + +Training these models can be finicky. Here are common failure modes and debugging approaches: + +### Idling / No Motion + +The model may "collapse" during inference, resulting in static or no motion. This can occur when: + +1. **Insufficient training data**: If you only have 20-50 examples, try to roughly double your dataset size. Once you have above 300 examples, if you're still seeing this, the task may be too complex. + +2. **Multiple similar tasks**: When your dataset contains multiple similar tasks (e.g., picking up 2 different objects), the model may rely too heavily on language conditioning which might not be rich enough. + +**Debugging tips:** + +- Increase dataset size (double until you get to over 300 examples) +- Train for longer, up to 100k steps, even when the loss flatlines +- Check if the model is receiving proper language instructions or increase diversity of instruction + +### Executing the Wrong Task + +Sometimes the robot will completely ignore your instruction and perform some other task. This generally only happens if you have trained on multiple tasks. + +**Potential causes:** + +- Language instruction ambiguity +- Insufficient task-specific training data +- Model confusion between similar tasks in the multitask dataset + +**Debugging tips:** + +- Verify language instruction specificity, especially if descriptions are similar between multiple tasks +- Check task distribution in your training dataset and add weighting to the failing/ignored task +- Consider task-specific fine-tuning + +### Training Instability + +If training loss is unstable or diverging: + +- Try adjusting learning rate between `1e-5` and `3e-4` +- Increase batch size if possible +- Check that your dataset normalization is correct +- Verify image preprocessing is working correctly + +## Performance Considerations + +### GPU Requirements + +- **Inference**: At least an RTX 5070 Ti (or equivalent GPU) is recommended for reasonable speed performance +- **Training**: A GPU with enough VRAM to load batch sizes of >64 is ideal, which will vary depending on the number of image observations, etc + +### Batch Size Recommendations + +- **Minimum**: 64 (less than this may result in unstable training) +- **Recommended**: 256-320 (best performance, requires larger GPU) + +## Example: Training on Custom Dataset + +Here's a complete example training on a custom dataset: + +```bash +lerobot-train \ + --dataset.repo_id=YOUR_DATASET \ + --output_dir=./outputs/mutitask_dit_training \ + --batch_size=320 \ + --steps=30000 \ + --save_freq=1000 \ + --log_freq=100 \ + --eval_freq=1000 \ + --policy.type=multi_task_dit \ + --policy.device=cuda \ + --policy.horizon=32 \ + --policy.n_action_steps=24 \ + --policy.objective=diffusion \ + --policy.noise_scheduler_type=DDPM \ + --policy.num_layers=6 \ + --policy.hidden_dim=512 \ + --policy.vision_encoder_name=openai/clip-vit-base-patch16 \ + --policy.image_resize_shape=[320,240] \ + --policy.image_crop_shape=[224,224] \ + --policy.repo_id="HF_USER/multitask-dit-your-robot" \ + --wandb.enable=true \ + --wandb.project=multitask_dit +``` + +## References + +For more details on the technical implementation and architecture, see: + +- [A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation](https://arxiv.org/abs/2507.05331) +- [Large Behavior Models and Atlas Find New Footing](https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/) +- [Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy](https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy) diff --git a/docs/source/policy_multi_task_dit_README.md b/docs/source/policy_multi_task_dit_README.md new file mode 100644 index 000000000..f24fa927e --- /dev/null +++ b/docs/source/policy_multi_task_dit_README.md @@ -0,0 +1,37 @@ +# Multitask DiT Policy + +## Citation + +If you use this work, please cite the following works: + +```bibtex +@misc{jones2025multitaskditpolicy, + author = {Bryson Jones}, + title = {Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy}, + year = {2025}, + url = {https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy}, + note = {Blog post} +} +``` + +```bibtex +@misc{trilbmteam2025carefulexaminationlargebehaviormodels, + author = {TRI LBM Team}, + title = {A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation}, + year = {2025}, + eprint = {arXiv:2507.05331}, + archivePrefix = {arXiv}, + primaryClass = {cs.RO}, + url = {https://arxiv.org/abs/2507.05331} +} +``` + +```bibtex +@misc{bostondynamics2025largebehaviormodelsatlas, + author = {Boston Dynamics and TRI Research Team}, + title = {Large Behavior Models and Atlas Find New Footing}, + year = {2025}, + url = {https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/}, + note = {Blog post} +} +``` diff --git a/pyproject.toml b/pyproject.toml index 7e4f24eb6..bed22a507 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,6 +145,7 @@ wallx = [ ] pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"] smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"] +multi_task_dit = ["lerobot[transformers-dep]"] groot = [ "lerobot[transformers-dep]", "lerobot[peft]", diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index c7951f028..55ce09cf9 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -15,6 +15,7 @@ from .act.configuration_act import ACTConfig as ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig from .groot.configuration_groot import GrootConfig as GrootConfig +from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig from .pi0.configuration_pi0 import PI0Config as PI0Config from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig from .pi05.configuration_pi05 import PI05Config as PI05Config @@ -28,6 +29,7 @@ from .xvla.configuration_xvla import XVLAConfig as XVLAConfig __all__ = [ "ACTConfig", "DiffusionConfig", + "MultiTaskDiTConfig", "PI0Config", "PI05Config", "PI0FastConfig", diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 2320cd624..146924502 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -31,6 +31,7 @@ from lerobot.envs.utils import env_to_policy_features from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.policies.groot.configuration_groot import GrootConfig +from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pi05.configuration_pi05 import PI05Config from lerobot.policies.pretrained import PreTrainedPolicy @@ -67,8 +68,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: Args: name: The name of the policy. Supported names are "tdmpc", "diffusion", "act", - "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x". - + "multi_task_dit", "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x". Returns: The policy class corresponding to the given name. @@ -87,6 +87,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from lerobot.policies.act.modeling_act import ACTPolicy return ACTPolicy + elif name == "multi_task_dit": + from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy + + return MultiTaskDiTPolicy elif name == "vqbet": from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy @@ -147,8 +151,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: Args: policy_type: The type of the policy. Supported types include "tdmpc", - "diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla", - "reward_classifier", "wall_x". + "multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac", + "smolvla", "reward_classifier", "wall_x". **kwargs: Keyword arguments to be passed to the configuration class constructor. Returns: @@ -163,6 +167,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return DiffusionConfig(**kwargs) elif policy_type == "act": return ACTConfig(**kwargs) + elif policy_type == "multi_task_dit": + return MultiTaskDiTConfig(**kwargs) elif policy_type == "vqbet": return VQBeTConfig(**kwargs) elif policy_type == "pi0": @@ -309,6 +315,16 @@ def make_pre_post_processors( dataset_stats=kwargs.get("dataset_stats"), ) + elif isinstance(policy_cfg, MultiTaskDiTConfig): + from lerobot.policies.multi_task_dit.processor_multi_task_dit import ( + make_multi_task_dit_pre_post_processors, + ) + + processors = make_multi_task_dit_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + elif isinstance(policy_cfg, VQBeTConfig): from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors diff --git a/src/lerobot/policies/multi_task_dit/README.md b/src/lerobot/policies/multi_task_dit/README.md new file mode 100644 index 000000000..f24fa927e --- /dev/null +++ b/src/lerobot/policies/multi_task_dit/README.md @@ -0,0 +1,37 @@ +# Multitask DiT Policy + +## Citation + +If you use this work, please cite the following works: + +```bibtex +@misc{jones2025multitaskditpolicy, + author = {Bryson Jones}, + title = {Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy}, + year = {2025}, + url = {https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy}, + note = {Blog post} +} +``` + +```bibtex +@misc{trilbmteam2025carefulexaminationlargebehaviormodels, + author = {TRI LBM Team}, + title = {A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation}, + year = {2025}, + eprint = {arXiv:2507.05331}, + archivePrefix = {arXiv}, + primaryClass = {cs.RO}, + url = {https://arxiv.org/abs/2507.05331} +} +``` + +```bibtex +@misc{bostondynamics2025largebehaviormodelsatlas, + author = {Boston Dynamics and TRI Research Team}, + title = {Large Behavior Models and Atlas Find New Footing}, + year = {2025}, + url = {https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/}, + note = {Blog post} +} +``` diff --git a/src/lerobot/policies/multi_task_dit/__init__.py b/src/lerobot/policies/multi_task_dit/__init__.py new file mode 100644 index 000000000..52a209d47 --- /dev/null +++ b/src/lerobot/policies/multi_task_dit/__init__.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python + +# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_multi_task_dit import MultiTaskDiTConfig +from .modeling_multi_task_dit import MultiTaskDiTPolicy +from .processor_multi_task_dit import make_multi_task_dit_pre_post_processors + +__all__ = ["MultiTaskDiTConfig", "MultiTaskDiTPolicy", "make_multi_task_dit_pre_post_processors"] diff --git a/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py new file mode 100644 index 000000000..061230687 --- /dev/null +++ b/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python + +# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from dataclasses import dataclass, field + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import NormalizationMode +from lerobot.optim.optimizers import AdamConfig +from lerobot.optim.schedulers import DiffuserSchedulerConfig + + +@PreTrainedConfig.register_subclass("multi_task_dit") +@dataclass +class MultiTaskDiTConfig(PreTrainedConfig): + """Configuration for the Multi-Task Diffusion Transformer (DiT) policy. + + A transformer-based policy that supports both diffusion and flow matching objectives + for multi-task robot learning with text and vision conditioning. + """ + + n_obs_steps: int = 2 # Number of observation steps for temporal context + horizon: int = 32 # Number of action steps to predict + n_action_steps: int = 24 # Actions executed per policy call (~0.8s at 30Hz) + + # Objective Selection + objective: str = "diffusion" # "diffusion" or "flow_matching" + + # --- Diffusion-specific (used when objective="diffusion") --- + noise_scheduler_type: str = "DDPM" # "DDPM" or "DDIM" + num_train_timesteps: int = 100 # Number of diffusion timesteps + beta_schedule: str = "squaredcos_cap_v2" # Noise schedule type + beta_start: float = 0.0001 # Starting noise level + beta_end: float = 0.02 # Ending noise level + prediction_type: str = "epsilon" # "epsilon" (predict noise) or "sample" (predict clean) + clip_sample: bool = True # Clip samples during denoising + clip_sample_range: float = 1.0 # Clipping range [-x, x] + num_inference_steps: int | None = None # Denoising steps at inference (defaults to num_train_timesteps) + + # --- Flow Matching-specific (used when objective="flow_matching") --- + sigma_min: float = 0.0 # Minimum noise in flow interpolation path + num_integration_steps: int = 100 # ODE integration steps at inference + integration_method: str = "euler" # ODE solver: "euler" or "rk4" + timestep_sampling_strategy: str = "beta" # "uniform" or "beta" + + timestep_sampling_s: float = 0.999 # (beta only) Max timestep threshold + timestep_sampling_alpha: float = 1.5 # (beta only) Beta distribution alpha + timestep_sampling_beta: float = 1.0 # (beta only) Beta distribution beta + + # Transformer Architecture + hidden_dim: int = 512 # Transformer hidden dimension + num_layers: int = 6 # Number of transformer layers + num_heads: int = 8 # Number of attention heads + dropout: float = 0.1 # Dropout rate + use_positional_encoding: bool = False # Use absolute positional encoding + timestep_embed_dim: int = 256 # Timestep embedding dimension + use_rope: bool = True # Use Rotary Position Embedding + rope_base: float = 10000.0 # RoPE base frequency + + # Vision Encoder (CLIP) + vision_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model + use_separate_rgb_encoder_per_camera: bool = False # Separate encoder per camera view + vision_encoder_lr_multiplier: float = 0.1 # LR multiplier for vision encoder + image_resize_shape: tuple[int, int] | None = None # Resize images before crop + image_crop_shape: tuple[int, int] | None = (224, 224) # Crop shape (CLIP default) + image_crop_is_random: bool = True # Random crop during training, center at inference + + # Text Encoder (CLIP) + text_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model + tokenizer_max_length: int = 77 # Max length for tokenized text (CLIP default is 77) + tokenizer_padding: str = "max_length" # Padding strategy: "max_length" or "longest" + tokenizer_padding_side: str = "right" # Padding side: "left" or "right" + tokenizer_truncation: bool = True # Whether to truncate sequences longer than max_length + + # Normalization + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.MEAN_STD, + "STATE": NormalizationMode.MIN_MAX, + "ACTION": NormalizationMode.MIN_MAX, + } + ) + + # Training/Optimizer + optimizer_lr: float = 2e-5 + optimizer_betas: tuple = (0.95, 0.999) + optimizer_eps: float = 1e-8 + optimizer_weight_decay: float = 0.0 + scheduler_name: str = "cosine" + scheduler_warmup_steps: int = 0 + do_mask_loss_for_padding: bool = False + + # Auto-calculated + drop_n_last_frames: int | None = None + + def __post_init__(self): + super().__post_init__() + + if self.drop_n_last_frames is None: + self.drop_n_last_frames = self.horizon - self.n_action_steps - self.n_obs_steps + 1 + + self._validate() + + def _validate(self): + """Validate configuration parameters.""" + # Objective validation + if self.objective not in ["diffusion", "flow_matching"]: + raise ValueError(f"objective must be 'diffusion' or 'flow_matching', got '{self.objective}'") + + # Transformer validation + if self.hidden_dim <= 0: + raise ValueError("hidden_dim must be positive") + if self.num_layers <= 0: + raise ValueError("num_layers must be positive") + if self.num_heads <= 0: + raise ValueError("num_heads must be positive") + if self.hidden_dim % self.num_heads != 0: + raise ValueError("hidden_dim must be divisible by num_heads") + if not (0.0 <= self.dropout <= 1.0): + raise ValueError("dropout must be between 0.0 and 1.0") + + # Vision encoder validation + if "clip" not in self.vision_encoder_name.lower(): + raise ValueError( + f"vision_encoder_name must be a CLIP model (contain 'clip'), got '{self.vision_encoder_name}'" + ) + if ( + self.image_resize_shape + and self.image_crop_shape + and ( + self.image_crop_shape[0] > self.image_resize_shape[0] + or self.image_crop_shape[1] > self.image_resize_shape[1] + ) + ): + logging.warning( + "image_crop_shape %s must be <= image_resize_shape %s; disabling cropping.", + self.image_crop_shape, + self.image_resize_shape, + ) + self.image_crop_shape = None + + # Text encoder validation + if "clip" not in self.text_encoder_name.lower(): + raise ValueError( + f"text_encoder_name must be a CLIP model (contain 'clip'), got '{self.text_encoder_name}'" + ) + + # Objective-specific validation + if self.objective == "diffusion": + if self.noise_scheduler_type not in ["DDPM", "DDIM"]: + raise ValueError( + f"noise_scheduler_type must be 'DDPM' or 'DDIM', got {self.noise_scheduler_type}" + ) + if self.prediction_type not in ["epsilon", "sample"]: + raise ValueError(f"prediction_type must be 'epsilon' or 'sample', got {self.prediction_type}") + if self.num_train_timesteps <= 0: + raise ValueError(f"num_train_timesteps must be positive, got {self.num_train_timesteps}") + if not (0.0 <= self.beta_start <= self.beta_end <= 1.0): + raise ValueError(f"Invalid beta values: {self.beta_start}, {self.beta_end}") + + elif self.objective == "flow_matching": + if not (0.0 <= self.sigma_min <= 1.0): + raise ValueError(f"sigma_min must be in [0, 1], got {self.sigma_min}") + if self.num_integration_steps <= 0: + raise ValueError(f"num_integration_steps must be positive, got {self.num_integration_steps}") + if self.integration_method not in ["euler", "rk4"]: + raise ValueError( + f"integration_method must be 'euler' or 'rk4', got {self.integration_method}" + ) + if self.timestep_sampling_strategy not in ["uniform", "beta"]: + raise ValueError("timestep_sampling_strategy must be 'uniform' or 'beta'") + if self.timestep_sampling_strategy == "beta": + if not (0.0 < self.timestep_sampling_s <= 1.0): + raise ValueError(f"timestep_sampling_s must be in (0, 1], got {self.timestep_sampling_s}") + if self.timestep_sampling_alpha <= 0: + raise ValueError("timestep_sampling_alpha must be positive") + if self.timestep_sampling_beta <= 0: + raise ValueError("timestep_sampling_beta must be positive") + + def get_optimizer_preset(self) -> AdamConfig: + return AdamConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + ) + + def get_scheduler_preset(self) -> DiffuserSchedulerConfig: + return DiffuserSchedulerConfig( + name=self.scheduler_name, + num_warmup_steps=self.scheduler_warmup_steps, + ) + + def validate_features(self) -> None: + """Validate that required input features are present and properly configured.""" + # If the configured crop doesn't fit, disable cropping instead of erroring. + # Note: if image_resize_shape is set, cropping is applied *after* resizing. + if self.image_crop_shape is not None: + for key, image_ft in self.image_features.items(): + # image_ft.shape is (C, H, W) + effective_h, effective_w = ( + self.image_resize_shape + if self.image_resize_shape is not None + else (image_ft.shape[1], image_ft.shape[2]) + ) + if self.image_crop_shape[0] > effective_h or self.image_crop_shape[1] > effective_w: + logging.warning( + "image_crop_shape %s doesn't fit within effective image shape (%s, %s) for '%s'; disabling cropping.", + self.image_crop_shape, + effective_h, + effective_w, + key, + ) + self.image_crop_shape = None + break + + if len(self.image_features) > 0: + first_key, first_ft = next(iter(self.image_features.items())) + for key, image_ft in self.image_features.items(): + if image_ft.shape != first_ft.shape: + raise ValueError( + f"Image '{key}' shape {image_ft.shape} != '{first_key}' shape {first_ft.shape}" + ) + + @property + def is_diffusion(self) -> bool: + return self.objective == "diffusion" + + @property + def is_flow_matching(self) -> bool: + return self.objective == "flow_matching" + + @property + def observation_delta_indices(self) -> list: + return list(range(1 - self.n_obs_steps, 1)) + + @property + def action_delta_indices(self) -> list: + return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon)) + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py new file mode 100644 index 000000000..4fee851e0 --- /dev/null +++ b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py @@ -0,0 +1,803 @@ +#!/usr/bin/env python + +# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multi-Task Diffusion Transformer (DiT) Policy + +Transformer-based diffusion policy for multi-task robot learning with text and vision conditioning. +Supports both diffusion and flow matching objectives for action generation. + +References: +- https://arxiv.org/abs/2507.05331 +- https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/ +- https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy +""" + +import math +from collections import deque +from typing import TYPE_CHECKING + +import einops +import torch +import torch.nn as nn +import torch.nn.functional as F # noqa: N812 +import torchvision +from diffusers.schedulers.scheduling_ddim import DDIMScheduler +from diffusers.schedulers.scheduling_ddpm import DDPMScheduler +from torch import Tensor + +from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig +from lerobot.utils.import_utils import _transformers_available + +# Conditional import for type checking and lazy loading +if TYPE_CHECKING or _transformers_available: + from transformers import CLIPTextModel, CLIPVisionModel +else: + CLIPTextModel = None + CLIPVisionModel = None +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.utils import populate_queues +from lerobot.utils.constants import ( + ACTION, + OBS_IMAGES, + OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_TOKENS, + OBS_STATE, +) + +# -- Policy -- + + +class MultiTaskDiTPolicy(PreTrainedPolicy): + config_class = MultiTaskDiTConfig + name = "multi_task_dit" + + def __init__(self, config: MultiTaskDiTConfig, **kwargs): + super().__init__(config) + config.validate_features() + self.config = config + + self._queues = None + + self.observation_encoder = ObservationEncoder(config) + conditioning_dim = self.observation_encoder.conditioning_dim + self.noise_predictor = DiffusionTransformer(config, conditioning_dim=conditioning_dim) + + action_dim = config.action_feature.shape[0] + horizon = config.horizon + + if config.is_diffusion: + self.objective = DiffusionObjective( + config, + action_dim=action_dim, + horizon=horizon, + do_mask_loss_for_padding=config.do_mask_loss_for_padding, + ) + elif config.is_flow_matching: + self.objective = FlowMatchingObjective( + config, + action_dim=action_dim, + horizon=horizon, + do_mask_loss_for_padding=config.do_mask_loss_for_padding, + ) + else: + raise ValueError(f"Unsupported objective: {config.objective}") + + self.reset() + + def get_optim_params(self) -> list: + """Returns parameter groups with different learning rates for vision vs non-vision parameters""" + non_vision_params = [] + vision_encoder_params = [] + + for name, param in self.named_parameters(): + if not param.requires_grad: + continue + + if "observation_encoder.vision_encoder" in name: + vision_encoder_params.append(param) + else: + non_vision_params.append(param) + + return [ + {"params": non_vision_params}, + { + "params": vision_encoder_params, + "lr": self.config.optimizer_lr * self.config.vision_encoder_lr_multiplier, + }, + ] + + def _generate_actions(self, batch: dict[str, Tensor]) -> Tensor: + batch_size, n_obs_steps = batch[OBS_STATE].shape[:2] + assert n_obs_steps == self.config.n_obs_steps + + conditioning_vec = self.observation_encoder.encode(batch) + actions = self.objective.conditional_sample(self.noise_predictor, batch_size, conditioning_vec) + + start = n_obs_steps - 1 + end = start + self.config.n_action_steps + actions = actions[:, start:end] + return actions + + def reset(self): + """Clear observation and action queues. Should be called on `env.reset()`""" + self._queues = { + OBS_STATE: deque(maxlen=self.config.n_obs_steps), + ACTION: deque(maxlen=self.config.n_action_steps), + } + + if self.config.image_features: + self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps) + + @torch.no_grad() + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Predict a chunk of actions given environment observations""" + self.eval() + + for k in batch: + if k in self._queues: + batch[k] = torch.stack(list(self._queues[k]), dim=1) + + actions = self._generate_actions(batch) + return actions + + def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + """Prepare batch by stacking image features if needed.""" + if self.config.image_features: + batch = dict(batch) # shallow copy to avoid modifying original + batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) + + return batch + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select a single action given environment observations""" + if ACTION in batch: + batch = dict(batch) # shallow copy to avoid modifying original + batch.pop(ACTION) + + batch = self._prepare_batch(batch) + + self._queues = populate_queues(self._queues, batch) + + if len(self._queues[ACTION]) == 0: + actions = self.predict_action_chunk(batch) + self._queues[ACTION].extend(actions.transpose(0, 1)) + + action = self._queues[ACTION].popleft() + return action + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]: + """Run the batch through the model and compute the loss for training""" + batch = self._prepare_batch(batch) + + conditioning_vec = self.observation_encoder.encode(batch) + loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec) + + return loss, None + + +# -- Observation Encoders -- + + +class CLIPVisionEncoder(nn.Module): + """CLIP vision encoder using the CLS token for global image representation.""" + + def __init__(self, model_name: str): + super().__init__() + self.model_name = model_name + self.model = CLIPVisionModel.from_pretrained(self.model_name) + self.num_non_spatial_tokens = 1 + self.embed_dim = self.model.config.hidden_size + + def forward(self, x: Tensor) -> Tensor: + """Encode RGB image to CLS token.""" + outputs = self.model(pixel_values=x, output_hidden_states=False) + cls_token = outputs.last_hidden_state[:, 0] + b, embed_dim = cls_token.shape + return cls_token.reshape(b, embed_dim, 1, 1) + + def get_output_shape(self) -> tuple: + return (self.embed_dim, 1, 1) + + +class CLIPTextEncoder(nn.Module): + """CLIP text encoder with frozen weights and a learnable projection layer. + + Accepts pre-tokenized inputs (input_ids and attention_mask) from the processor pipeline. See the processor + pipeline to see how the tokenization is handled. + """ + + def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512): + super().__init__() + self.model_name = model_name + self.projection_dim = projection_dim + self.text_encoder = CLIPTextModel.from_pretrained(model_name) + + for param in self.text_encoder.parameters(): + param.requires_grad = False + + self.text_embed_dim = self.text_encoder.config.hidden_size + self.projection = nn.Linear(self.text_embed_dim, projection_dim) + + def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor: + """Encode pre-tokenized text to feature vectors.""" + # Ensure inputs are on the same device as the model + device = next(self.parameters()).device + input_ids = input_ids.to(device) + attention_mask = attention_mask.to(device) + + with torch.no_grad(): + outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask) + clip_features = outputs.pooler_output + + return self.projection(clip_features) + + +class ObservationEncoder(nn.Module): + """Handles all observation processing for the conditioning vector.""" + + def __init__(self, config): + super().__init__() + self.config = config + self._setup_preprocessing(config) + + if config.image_features: + self.num_cameras = len(config.image_features) + self.camera_names = list(config.image_features.keys()) + + if config.use_separate_rgb_encoder_per_camera: + self.vision_encoders = nn.ModuleList( + [CLIPVisionEncoder(model_name=config.vision_encoder_name) for _ in self.camera_names] + ) + self.vision_encoder = None + else: + self.vision_encoder = CLIPVisionEncoder(model_name=config.vision_encoder_name) + self.vision_encoders = None + else: + self.vision_encoder = None + self.vision_encoders = None + self.camera_names = [] + self.num_cameras = 0 + + if hasattr(config, "robot_state_feature") and config.robot_state_feature: + self.robot_state_dim = config.robot_state_feature.shape[0] + else: + self.robot_state_dim = 0 + + self.text_dim = config.hidden_dim + self.text_encoder = CLIPTextEncoder(model_name=config.text_encoder_name, projection_dim=self.text_dim) + + self._setup_vector_output() + + def _apply_preprocessing(self, images: Tensor) -> Tensor: + if self.do_resize: + images = self.resize(images) + if self.do_crop: + images = self.maybe_random_crop(images) if self.training else self.center_crop(images) + return images + + def _setup_preprocessing(self, config): + if config.image_resize_shape is not None: + self.do_resize = True + self.resize = torchvision.transforms.Resize( + size=config.image_resize_shape, + interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + antialias=True, + ) + else: + self.do_resize = False + + if config.image_crop_shape is not None: + self.do_crop = True + self.center_crop = torchvision.transforms.CenterCrop(config.image_crop_shape) + if config.image_crop_is_random: + self.maybe_random_crop = torchvision.transforms.RandomCrop(config.image_crop_shape) + else: + self.maybe_random_crop = self.center_crop + else: + self.do_crop = False + + def _setup_vector_output(self): + total_dim = 0 + + if self.vision_encoder is not None or self.vision_encoders is not None: + encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders)) + feature_map_shape = encoder_to_check.get_output_shape() + c, h, w = feature_map_shape + spatial_feature_dim = c * h * w + total_dim += spatial_feature_dim * self.num_cameras + + total_dim += self.robot_state_dim + total_dim += self.text_dim + + self.conditioning_dim = total_dim * self.config.n_obs_steps + + def encode(self, batch: dict) -> Tensor: + """Encode observations to vector format.""" + batch_size, n_obs_steps = batch[OBS_STATE].shape[:2] + conditioning_feats = [] + + conditioning_feats.append(batch[OBS_STATE]) + + if self.vision_encoder is not None or self.vision_encoders is not None: + images = batch[OBS_IMAGES] + + if len(images.shape) == 5: + images = images.unsqueeze(1) + + if self.config.use_separate_rgb_encoder_per_camera: + camera_features = [] + for cam_idx in range(self.num_cameras): + cam_images = images[:, :, cam_idx] + cam_images_flat = einops.rearrange(cam_images, "b s c h w -> (b s) c h w") + cam_images_flat = self._apply_preprocessing(cam_images_flat) + cam_features = self.vision_encoders[cam_idx](cam_images_flat) + cam_visual_features = cam_features.flatten(start_dim=1) + cam_features_reshaped = einops.rearrange( + cam_visual_features, "(b s) f -> b s f", b=batch_size, s=n_obs_steps + ) + camera_features.append(cam_features_reshaped) + img_features = torch.cat(camera_features, dim=-1) + conditioning_feats.append(img_features) + else: + images_flat = einops.rearrange(images, "b s n c h w -> (b s n) c h w") + images_flat = self._apply_preprocessing(images_flat) + visual_features = self.vision_encoder(images_flat).flatten(start_dim=1) + img_features = einops.rearrange( + visual_features, "(b s n) f -> b s (n f)", b=batch_size, s=n_obs_steps, n=self.num_cameras + ) + conditioning_feats.append(img_features) + + if self.text_encoder is not None and OBS_LANGUAGE_TOKENS in batch: + input_ids = batch[OBS_LANGUAGE_TOKENS] # [batch_size, seq_length] + attention_mask = batch[OBS_LANGUAGE_ATTENTION_MASK] # [batch_size, seq_length] + + text_features = self.text_encoder(input_ids, attention_mask) + + text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1) + conditioning_feats.append(text_features) + + combined_features = torch.cat(conditioning_feats, dim=-1) + return combined_features.flatten(start_dim=1) + + +# -- Transformer Components -- + + +def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor: + """Modulate input with shift and scale for AdaLN-Zero.""" + return x * (1 + scale) + shift + + +class SinusoidalPosEmb(nn.Module): + """Sinusoidal positional embeddings for timesteps.""" + + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: Tensor) -> Tensor: + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class RotaryPositionalEmbedding(nn.Module): + """Rotary Position Embedding (RoPE) for transformers.""" + + def __init__(self, head_dim: int, max_seq_len: int = 512, base: float = 10000.0): + super().__init__() + assert head_dim % 2 == 0, "head_dim must be even for RoPE" + + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.base = base + + inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._precompute_cache(max_seq_len) + + def _precompute_cache(self, seq_len: int): + t = torch.arange(seq_len, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("_cos_cached", emb.cos()[None, None, :, :], persistent=False) + self.register_buffer("_sin_cached", emb.sin()[None, None, :, :], persistent=False) + + def _rotate_half(self, x: Tensor) -> Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def forward(self, q: Tensor, k: Tensor) -> tuple[Tensor, Tensor]: + seq_len = q.shape[2] + if seq_len > self.max_seq_len: + raise ValueError(f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}.") + + cos = self._cos_cached[:, :, :seq_len, :].to(q.dtype) + sin = self._sin_cached[:, :, :seq_len, :].to(q.dtype) + + q_rotated = (q * cos) + (self._rotate_half(q) * sin) + k_rotated = (k * cos) + (self._rotate_half(k) * sin) + return q_rotated, k_rotated + + +class RoPEAttention(nn.Module): + """Multi-head self-attention with Rotary Position Embedding (RoPE).""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + dropout: float = 0.0, + max_seq_len: int = 512, + rope_base: float = 10000.0, + ): + super().__init__() + assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads" + + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.scale = self.head_dim**-0.5 + + self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=True) + self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True) + self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + self.rope = RotaryPositionalEmbedding(head_dim=self.head_dim, max_seq_len=max_seq_len, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + B, T, _ = x.shape # noqa: N806 + + qkv = self.qkv_proj(x) + qkv = qkv.reshape(B, T, 3, self.num_heads, self.head_dim) + qkv = qkv.permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q, k = self.rope(q, k) + + attn_out = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.dropout.p if isinstance(self.dropout, nn.Dropout) and self.training else 0.0, + ) + + attn_out = attn_out.transpose(1, 2).reshape(B, T, self.hidden_size) + return self.out_proj(attn_out) + + +class TransformerBlock(nn.Module): + """DiT-style transformer block with AdaLN-Zero.""" + + def __init__( + self, + hidden_size: int = 128, + num_heads: int = 4, + num_features: int = 128, + dropout: float = 0.0, + use_rope: bool = False, + max_seq_len: int = 512, + rope_base: float = 10000.0, + ): + super().__init__() + self.use_rope = use_rope + + if use_rope: + self.attn = RoPEAttention( + hidden_size=hidden_size, + num_heads=num_heads, + dropout=dropout, + max_seq_len=max_seq_len, + rope_base=rope_base, + ) + else: + self.multihead_attn = nn.MultiheadAttention( + hidden_size, num_heads=num_heads, batch_first=True, dropout=dropout + ) + + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp = nn.Sequential( + nn.Linear(hidden_size, hidden_size * 4), + nn.GELU(approximate="tanh"), + nn.Linear(hidden_size * 4, hidden_size), + ) + + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(num_features, 6 * hidden_size, bias=True)) + + def forward(self, x: Tensor, features: Tensor) -> Tensor: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation( + features + ).chunk(6, dim=1) + + attn_input = modulate(self.norm1(x), shift_msa.unsqueeze(1), scale_msa.unsqueeze(1)) + + if self.use_rope: + attn_out = self.attn(attn_input) + else: + attn_out, _ = self.multihead_attn(attn_input, attn_input, attn_input) + + x = x + gate_msa.unsqueeze(1) * attn_out + + mlp_input = modulate(self.norm2(x), shift_mlp.unsqueeze(1), scale_mlp.unsqueeze(1)) + mlp_out = self.mlp(mlp_input) + x = x + gate_mlp.unsqueeze(1) * mlp_out + + return x + + +class DiffusionTransformer(nn.Module): + """Transformer-based diffusion noise prediction model.""" + + def __init__(self, config, conditioning_dim: int): + super().__init__() + self.config = config + self.conditioning_dim = conditioning_dim + + self.action_dim = config.action_feature.shape[0] + self.horizon = config.horizon + self.hidden_size = config.hidden_dim + self.num_layers = config.num_layers + self.num_heads = config.num_heads + self.dropout = config.dropout + self.use_rope = config.use_rope + + self.timestep_embed_dim = config.timestep_embed_dim + self.time_mlp = nn.Sequential( + SinusoidalPosEmb(self.timestep_embed_dim), + nn.Linear(self.timestep_embed_dim, 2 * self.timestep_embed_dim), + nn.GELU(), + nn.Linear(2 * self.timestep_embed_dim, self.timestep_embed_dim), + nn.GELU(), + ) + + self.cond_dim = self.timestep_embed_dim + conditioning_dim + self.input_proj = nn.Linear(self.action_dim, self.hidden_size) + + if config.use_positional_encoding: + self.pos_embedding = nn.Parameter( + torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02) + ) + else: + self.pos_embedding = None + + self.transformer_blocks = nn.ModuleList( + [ + TransformerBlock( + hidden_size=self.hidden_size, + num_heads=self.num_heads, + num_features=self.cond_dim, + dropout=self.dropout, + use_rope=self.use_rope, + max_seq_len=self.horizon, + rope_base=config.rope_base, + ) + for _ in range(self.num_layers) + ] + ) + + self.output_proj = nn.Linear(self.hidden_size, self.action_dim) + self._initialize_weights() + + def _initialize_weights(self): + for block in self.transformer_blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + def forward(self, x: Tensor, timestep: Tensor, conditioning_vec: Tensor) -> Tensor: + _, seq_len, _ = x.shape + + timestep_features = self.time_mlp(timestep) + cond_features = torch.cat([timestep_features, conditioning_vec], dim=-1) + + hidden_seq = self.input_proj(x) + + if self.pos_embedding is not None: + hidden_seq = hidden_seq + self.pos_embedding[:, :seq_len, :] + + for block in self.transformer_blocks: + hidden_seq = block(hidden_seq, cond_features) + + return self.output_proj(hidden_seq) + + +# -- Objectives -- + + +class DiffusionObjective(nn.Module): + """Standard diffusion (DDPM/DDIM) objective implementation.""" + + def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False): + super().__init__() + self.config = config + self.action_dim = action_dim + self.horizon = horizon + self.do_mask_loss_for_padding = do_mask_loss_for_padding + + scheduler_kwargs = { + "num_train_timesteps": config.num_train_timesteps, + "beta_start": config.beta_start, + "beta_end": config.beta_end, + "beta_schedule": config.beta_schedule, + "clip_sample": config.clip_sample, + "clip_sample_range": config.clip_sample_range, + "prediction_type": config.prediction_type, + } + + if config.noise_scheduler_type == "DDPM": + self.noise_scheduler: DDPMScheduler | DDIMScheduler = DDPMScheduler(**scheduler_kwargs) + elif config.noise_scheduler_type == "DDIM": + self.noise_scheduler = DDIMScheduler(**scheduler_kwargs) + else: + raise ValueError(f"Unsupported noise scheduler type {config.noise_scheduler_type}") + + self.num_inference_steps = ( + config.num_inference_steps + if config.num_inference_steps is not None + else self.noise_scheduler.config.num_train_timesteps + ) + + def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor: + clean_actions = batch[ACTION] + noise = torch.randn_like(clean_actions) + timesteps = torch.randint( + low=0, + high=self.noise_scheduler.config.num_train_timesteps, + size=(clean_actions.shape[0],), + device=clean_actions.device, + ).long() + noisy_actions = self.noise_scheduler.add_noise(clean_actions, noise, timesteps) + + prediction_type = self.noise_scheduler.config.prediction_type + if prediction_type == "epsilon": + target = noise + elif prediction_type == "sample": + target = clean_actions + else: + raise ValueError(f"Unsupported prediction type: {prediction_type}") + + predicted = model(noisy_actions, timesteps, conditioning_vec=conditioning_vec) + loss = F.mse_loss(predicted, target, reduction="none") + + if self.do_mask_loss_for_padding and "action_is_pad" in batch: + valid_actions = ~batch["action_is_pad"] + loss = loss * valid_actions.unsqueeze(-1) + + return loss.mean() + + def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor: + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + + sample = torch.randn( + size=(batch_size, self.horizon, self.action_dim), + dtype=dtype, + device=device, + ) + + self.noise_scheduler.set_timesteps(self.num_inference_steps) + for t in self.noise_scheduler.timesteps: + model_output = model( + sample, + torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device), + conditioning_vec=conditioning_vec, + ) + sample = self.noise_scheduler.step(model_output, t, sample).prev_sample + + return sample + + +class FlowMatchingObjective(nn.Module): + """Flow matching objective: trains a model to predict velocity fields.""" + + def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False): + super().__init__() + self.config = config + self.action_dim = action_dim + self.horizon = horizon + self.do_mask_loss_for_padding = do_mask_loss_for_padding + + def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor: + if self.config.timestep_sampling_strategy == "uniform": + return torch.rand(batch_size, device=device) + elif self.config.timestep_sampling_strategy == "beta": + beta_dist = torch.distributions.Beta( + self.config.timestep_sampling_alpha, self.config.timestep_sampling_beta + ) + u = beta_dist.sample((batch_size,)).to(device) + return self.config.timestep_sampling_s * (1.0 - u) + else: + raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling_strategy}") + + def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor: + data = batch[ACTION] + batch_size = data.shape[0] + device = data.device + + noise = torch.randn_like(data) + t = self._sample_timesteps(batch_size, device) + t_expanded = t.view(-1, 1, 1) + x_t = t_expanded * data + (1 - (1 - self.config.sigma_min) * t_expanded) * noise + + target_velocity = data - (1 - self.config.sigma_min) * noise + predicted_velocity = model(x_t, t, conditioning_vec=conditioning_vec) + loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none") + + if self.do_mask_loss_for_padding and "action_is_pad" in batch: + valid_mask = ~batch["action_is_pad"] + loss = loss * valid_mask.unsqueeze(-1) + + return loss.mean() + + def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor: + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + + x = torch.randn((batch_size, self.horizon, self.action_dim), dtype=dtype, device=device) + + num_steps = self.config.num_integration_steps + time_grid = torch.linspace(0, 1, num_steps + 1, device=device) + + if self.config.integration_method == "euler": + x = self._euler_integrate(model, x, time_grid, conditioning_vec) + elif self.config.integration_method == "rk4": + x = self._rk4_integrate(model, x, time_grid, conditioning_vec) + else: + raise ValueError(f"Unknown integration method: {self.config.integration_method}") + + return x + + def _euler_integrate( + self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor + ) -> Tensor: + x = x_init + for i in range(len(time_grid) - 1): + t_scalar = time_grid[i].item() + dt = (time_grid[i + 1] - time_grid[i]).item() + t_batch = torch.full((x.shape[0],), t_scalar, dtype=x.dtype, device=x.device) + with torch.no_grad(): + velocity = model(x, t_batch, conditioning_vec=conditioning_vec) + x = x + dt * velocity + return x + + def _rk4_integrate( + self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor + ) -> Tensor: + x = x_init + + def dynamics(x_val: Tensor, t_scalar: float) -> Tensor: + t_batch = torch.full((x_val.shape[0],), t_scalar, dtype=x_val.dtype, device=x_val.device) + with torch.no_grad(): + return model(x_val, t_batch, conditioning_vec=conditioning_vec) + + for i in range(len(time_grid) - 1): + t = time_grid[i].item() + dt = (time_grid[i + 1] - time_grid[i]).item() + + k1 = dynamics(x, t) + k2 = dynamics(x + dt * k1 / 2, t + dt / 2) + k3 = dynamics(x + dt * k2 / 2, t + dt / 2) + k4 = dynamics(x + dt * k3, t + dt) + + x = x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4) + + return x diff --git a/src/lerobot/policies/multi_task_dit/processor_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/processor_multi_task_dit.py new file mode 100644 index 000000000..fc94599c2 --- /dev/null +++ b/src/lerobot/policies/multi_task_dit/processor_multi_task_dit.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python + +# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + RenameObservationsProcessorStep, + TokenizerProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME + + +def make_multi_task_dit_pre_post_processors( + config: MultiTaskDiTConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Constructs pre-processor and post-processor pipelines for a Multi-Task DiT policy. + + The pre-processing pipeline prepares the input data for the model by: + 1. Renaming features. + 2. Adding a batch dimension. + 3. Tokenizing the language task description (if present). + 4. Moving the data to the specified device. + 5. Normalizing the input and output features based on dataset statistics. + + The post-processing pipeline handles the model's output by: + 1. Unnormalizing the output features to their original scale. + 2. Moving the data to the CPU. + + Args: + config: The configuration object for the Multi-Task DiT policy, + containing feature definitions, normalization mappings, and device information. + dataset_stats: A dictionary of statistics used for normalization. + Defaults to None. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ + + input_steps = [ + RenameObservationsProcessorStep(rename_map={}), + AddBatchDimensionProcessorStep(), + TokenizerProcessorStep( + tokenizer_name=config.text_encoder_name, + padding=config.tokenizer_padding, + padding_side=config.tokenizer_padding_side, + max_length=config.tokenizer_max_length, + truncation=config.tokenizer_truncation, + ), + DeviceProcessorStep(device=config.device), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + device=config.device, + ), + ] + output_steps = [ + UnnormalizerProcessorStep( + features=config.output_features, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + DeviceProcessorStep(device="cpu"), + ] + + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/tests/policies/multi_task_dit/test_multi_task_dit.py b/tests/policies/multi_task_dit/test_multi_task_dit.py new file mode 100644 index 000000000..5b70422d4 --- /dev/null +++ b/tests/policies/multi_task_dit/test_multi_task_dit.py @@ -0,0 +1,624 @@ +#!/usr/bin/env python + +# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: E402 + +"""Test script for Multi-Task DiT policy. + +To run tests locally: + python -m pytest tests/policies/multi_task_dit/test_multi_task_dit.py -v +""" + +import os + +import pytest +import torch +from torch import Tensor + +pytest.importorskip("transformers") + +pytestmark = pytest.mark.skipif( + os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", + reason="This test requires local transformers installation and is not meant for CI", +) + +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig +from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy +from lerobot.policies.multi_task_dit.processor_multi_task_dit import ( + make_multi_task_dit_pre_post_processors, +) +from lerobot.utils.constants import ( + ACTION, + OBS_IMAGES, + OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_TOKENS, + OBS_STATE, +) +from lerobot.utils.random_utils import seeded_context, set_seed + + +@pytest.fixture(autouse=True) +def set_random_seed(): + seed = 17 + set_seed(seed) + + +def create_train_batch( + batch_size: int = 2, + n_obs_steps: int = 2, + horizon: int = 16, + state_dim: int = 10, + action_dim: int = 10, + height: int = 224, + width: int = 224, +) -> dict[str, Tensor]: + """Create a training batch with visual input and text.""" + return { + "observation.state": torch.randn(batch_size, n_obs_steps, state_dim), + f"{OBS_IMAGES}.laptop": torch.rand(batch_size, n_obs_steps, 3, height, width), + ACTION: torch.randn(batch_size, horizon, action_dim), + "task": ["pick up the cube"] * batch_size, + } + + +def create_observation_batch( + batch_size: int = 2, state_dim: int = 10, height: int = 224, width: int = 224 +) -> dict: + """Create observation batch for inference for a single timestep.""" + return { + "observation.state": torch.randn(batch_size, state_dim), + f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, height, width), + "task": ["pick up the red cube"] * batch_size, + } + + +def create_config( + state_dim: int = 10, + action_dim: int = 10, + n_obs_steps: int = 2, + horizon: int = 16, + n_action_steps: int = 8, + with_visual: bool = True, + height: int = 224, + width: int = 224, +) -> MultiTaskDiTConfig: + """Create a MultiTaskDiT config for testing. + + Args: + state_dim: Dimension of state observations + action_dim: Dimension of actions + n_obs_steps: Number of observation steps + horizon: Action prediction horizon + n_action_steps: Number of action steps to execute + with_visual: Whether to include visual input (default: True) + height: Image height (only used if with_visual=True) + width: Image width (only used if with_visual=True) + """ + input_features = {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))} + + if with_visual: + input_features[f"{OBS_IMAGES}.laptop"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(3, height, width) + ) + + config = MultiTaskDiTConfig( + input_features=input_features, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}, + n_obs_steps=n_obs_steps, + horizon=horizon, + n_action_steps=n_action_steps, + # Use smaller model for faster tests + hidden_dim=128, + num_layers=2, + num_heads=4, + ) + + config.validate_features() + return config + + +@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 10, 10), (1, 6, 6)]) +def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_dim: int): + """Test forward pass (training mode).""" + n_obs_steps = 2 + horizon = 16 + n_action_steps = 8 + + config = create_config( + state_dim=state_dim, + action_dim=action_dim, + n_obs_steps=n_obs_steps, + horizon=horizon, + n_action_steps=n_action_steps, + ) + + policy = MultiTaskDiTPolicy(config=config) + policy.train() + + # Use preprocessor to handle tokenization + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.IDENTITY, + } + preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None) + + batch = create_train_batch( + batch_size=batch_size, + n_obs_steps=n_obs_steps, + horizon=horizon, + state_dim=state_dim, + action_dim=action_dim, + ) + + # Process batch through preprocessor to tokenize task text + processed_batch = preprocessor(batch) + + # Test forward pass + loss, _ = policy.forward(processed_batch) + assert loss is not None + assert loss.item() is not None + assert loss.shape == () + + # Test backward pass + loss.backward() + + +def test_multi_task_dit_pre_post_processors(): + """Test pre and post processors for Multi-Task DiT policy.""" + state_dim = 10 + action_dim = 8 + n_obs_steps = 2 + horizon = 16 + + config = create_config( + state_dim=state_dim, + action_dim=action_dim, + n_obs_steps=n_obs_steps, + horizon=horizon, + n_action_steps=8, + ) + config.device = "cpu" + + # Set normalization mode to match the stats we're providing + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MEAN_STD, # Use MEAN_STD since we provide mean/std stats + "ACTION": NormalizationMode.MIN_MAX, + } + + # Create dataset stats for normalization + dataset_stats = { + "observation.state": { + "mean": torch.zeros(state_dim), + "std": torch.ones(state_dim), + }, + "action": { + "min": torch.full((action_dim,), -1.0), + "max": torch.ones(action_dim), + }, + } + + # Create processors + preprocessor, postprocessor = make_multi_task_dit_pre_post_processors( + config=config, dataset_stats=dataset_stats + ) + + # Test preprocessor with sample data + batch = { + "observation.state": torch.randn(state_dim), + f"{OBS_IMAGES}.laptop": torch.rand(3, 224, 224), + ACTION: torch.randn(action_dim), + "task": "pick up the cube", + } + + processed_batch = preprocessor(batch) + + # Check that data is batched + assert processed_batch["observation.state"].shape == (1, state_dim) + assert processed_batch[f"{OBS_IMAGES}.laptop"].shape == (1, 3, 224, 224) + assert processed_batch[ACTION].shape == (1, action_dim) + # Check that task text was tokenized + assert OBS_LANGUAGE_TOKENS in processed_batch + assert OBS_LANGUAGE_ATTENTION_MASK in processed_batch + assert processed_batch[OBS_LANGUAGE_TOKENS].shape[0] == 1 # batch dimension + assert processed_batch[OBS_LANGUAGE_ATTENTION_MASK].shape[0] == 1 # batch dimension + + # Check that data is on correct device + assert processed_batch["observation.state"].device.type == "cpu" + assert processed_batch[f"{OBS_IMAGES}.laptop"].device.type == "cpu" + assert processed_batch[ACTION].device.type == "cpu" + + # Test postprocessor with sample action (PolicyAction is just a torch.Tensor) + action = torch.randn(1, action_dim) + processed_action = postprocessor(action) + + # Check that action is unnormalized and on CPU + assert processed_action.shape == (1, action_dim) + assert processed_action.device.type == "cpu" + + +def test_multi_task_dit_pre_post_processors_normalization(): + """Test that normalization and unnormalization work correctly with simple sanity check numbers.""" + state_dim = 3 + action_dim = 2 + + config = create_config( + state_dim=state_dim, + action_dim=action_dim, + n_obs_steps=2, + horizon=16, + n_action_steps=8, + ) + config.device = "cpu" + + # Set normalization mode to match the stats we're providing + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MEAN_STD, # Use MEAN_STD since we provide mean/std stats + "ACTION": NormalizationMode.MIN_MAX, + } + + # Use simple stats that will actually transform the values + dataset_stats = { + "observation.state": { + "mean": torch.full((state_dim,), 5.0), + "std": torch.full((state_dim,), 2.0), + }, + "action": { + "min": torch.zeros(action_dim), + "max": torch.full((action_dim,), 2.0), + }, + } + + # Create processors + preprocessor, postprocessor = make_multi_task_dit_pre_post_processors( + config=config, dataset_stats=dataset_stats + ) + + # Use simple input values + input_state = torch.tensor([7.0, 5.0, 3.0]) # Will normalize to [1.0, 0.0, -1.0] + input_action = torch.tensor([1.0, 2.0]) # Will normalize to [0.0, 1.0] + + batch = { + "observation.state": input_state, + f"{OBS_IMAGES}.laptop": torch.rand(3, 224, 224), + ACTION: input_action, + "task": "test task", + } + + # Process through preprocessor + processed_batch = preprocessor(batch) + + # State normalization: (x - mean) / std + expected_normalized_state = torch.tensor([1.0, 0.0, -1.0]) + assert torch.allclose(processed_batch["observation.state"][0], expected_normalized_state, atol=1e-5) + + # Action normalization: (x - min) / (max - min) * 2 - 1 + expected_normalized_action = torch.tensor([0.0, 1.0]) + assert torch.allclose(processed_batch[ACTION][0], expected_normalized_action, atol=1e-5) + + # Test unnormalization: should recover original values + normalized_action_tensor = processed_batch[ACTION][0:1] # Keep batch dimension + unnormalized_action = postprocessor(normalized_action_tensor) + + # Should recover original action values + assert torch.allclose(unnormalized_action[0], input_action, atol=1e-4) + + +@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 10, 10), (1, 6, 6)]) +def test_multi_task_dit_policy_select_action(batch_size: int, state_dim: int, action_dim: int): + """Test select_action (inference mode).""" + n_obs_steps = 2 + horizon = 16 + n_action_steps = 8 + + config = create_config( + state_dim=state_dim, + action_dim=action_dim, + n_obs_steps=n_obs_steps, + horizon=horizon, + n_action_steps=n_action_steps, + ) + + policy = MultiTaskDiTPolicy(config=config) + policy.eval() + policy.reset() # Reset queues before inference + + # Create processors - use IDENTITY normalization when no stats provided + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.IDENTITY, + } + preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None) + + with torch.no_grad(): + observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + # Process observation through preprocessor + processed_obs = preprocessor(observation_batch) + selected_action = policy.select_action(processed_obs) + # Process action through postprocessor (PolicyAction is just a torch.Tensor) + processed_action = postprocessor(selected_action) + assert processed_action.shape == (batch_size, action_dim) + + +def test_multi_task_dit_policy_diffusion_objective(): + """Test policy with diffusion objective.""" + batch_size = 2 + state_dim = 10 + action_dim = 10 + n_obs_steps = 2 + horizon = 16 + n_action_steps = 8 + + input_features = { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)), + f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + + config = MultiTaskDiTConfig( + input_features=input_features, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}, + n_obs_steps=n_obs_steps, + horizon=horizon, + n_action_steps=n_action_steps, + # Use diffusion objective + objective="diffusion", + noise_scheduler_type="DDPM", + num_train_timesteps=100, + num_inference_steps=10, + # Smaller model for tests + hidden_dim=128, + num_layers=2, + num_heads=4, + ) + config.validate_features() + + policy = MultiTaskDiTPolicy(config=config) + policy.train() + + # Use preprocessor to handle tokenization + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.IDENTITY, + } + preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None) + + batch = create_train_batch( + batch_size=batch_size, + n_obs_steps=n_obs_steps, + horizon=horizon, + state_dim=state_dim, + action_dim=action_dim, + ) + + # Process batch through preprocessor to tokenize task text + processed_batch = preprocessor(batch) + + # Test forward pass + loss, _ = policy.forward(processed_batch) + assert loss is not None + assert loss.item() is not None + + # Test inference + policy.eval() + # Use IDENTITY normalization when no stats provided + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.IDENTITY, + } + preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None) + with torch.no_grad(): + observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + # Process observation through preprocessor + processed_obs = preprocessor(observation_batch) + selected_action = policy.select_action(processed_obs) + # Process action through postprocessor (PolicyAction is just a torch.Tensor) + processed_action = postprocessor(selected_action) + assert processed_action.shape == (batch_size, action_dim) + + +def test_multi_task_dit_policy_flow_matching_objective(): + """Test policy with flow matching objective.""" + batch_size = 2 + state_dim = 10 + action_dim = 10 + n_obs_steps = 2 + horizon = 16 + n_action_steps = 8 + + input_features = { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)), + f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + + config = MultiTaskDiTConfig( + input_features=input_features, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}, + n_obs_steps=n_obs_steps, + horizon=horizon, + n_action_steps=n_action_steps, + # Use flow matching objective + objective="flow_matching", + sigma_min=0.0, + num_integration_steps=10, # Fewer steps for faster tests + integration_method="euler", + # Smaller model for tests + hidden_dim=128, + num_layers=2, + num_heads=4, + ) + config.validate_features() + + policy = MultiTaskDiTPolicy(config=config) + policy.train() + + # Use preprocessor to handle tokenization + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.IDENTITY, + } + preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None) + + batch = create_train_batch( + batch_size=batch_size, + n_obs_steps=n_obs_steps, + horizon=horizon, + state_dim=state_dim, + action_dim=action_dim, + ) + + # Process batch through preprocessor to tokenize task text + processed_batch = preprocessor(batch) + + # Test forward pass + loss, _ = policy.forward(processed_batch) + assert loss is not None + assert loss.item() is not None + + # Test inference + policy.eval() + # Use IDENTITY normalization when no stats provided + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.IDENTITY, + } + preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None) + with torch.no_grad(): + observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + # Process observation through preprocessor + processed_obs = preprocessor(observation_batch) + selected_action = policy.select_action(processed_obs) + # Process action through postprocessor (PolicyAction is just a torch.Tensor) + processed_action = postprocessor(selected_action) + assert processed_action.shape == (batch_size, action_dim) + + +def test_multi_task_dit_policy_save_and_load(tmp_path): + """Test that the policy can be saved and loaded correctly.""" + root = tmp_path / "test_multi_task_dit_save_and_load" + + state_dim = 10 + action_dim = 10 + batch_size = 2 + n_obs_steps = 2 + horizon = 16 + n_action_steps = 8 + + config = create_config( + state_dim=state_dim, + action_dim=action_dim, + n_obs_steps=n_obs_steps, + horizon=horizon, + n_action_steps=n_action_steps, + ) + + policy = MultiTaskDiTPolicy(config=config) + policy.eval() + + # Get device before saving + device = next(policy.parameters()).device + + policy.save_pretrained(root) + loaded_policy = MultiTaskDiTPolicy.from_pretrained(root, config=config) + + # Explicitly move loaded_policy to the same device + loaded_policy.to(device) + loaded_policy.eval() + + batch = create_train_batch( + batch_size=batch_size, + n_obs_steps=n_obs_steps, + horizon=horizon, + state_dim=state_dim, + action_dim=action_dim, + ) + + # Use preprocessor to handle tokenization + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.IDENTITY, + } + preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None) + + with torch.no_grad(): + with seeded_context(12): + # Process batch through preprocessor + processed_batch = preprocessor(batch) + # Move batch to the same device as the policy + for key in processed_batch: + if isinstance(processed_batch[key], torch.Tensor): + processed_batch[key] = processed_batch[key].to(device) + # Collect policy values before saving + loss, _ = policy.forward(processed_batch) + + observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + # Process observation through preprocessor + processed_obs = preprocessor(observation_batch) + actions = policy.select_action(processed_obs) + + with seeded_context(12): + # Process batch through preprocessor + processed_batch = preprocessor(batch) + # Collect policy values after loading + loaded_loss, _ = loaded_policy.forward(processed_batch) + + loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + processed_obs = preprocessor(loaded_observation_batch) + loaded_actions = loaded_policy.select_action(processed_obs) + + # Compare state dicts + assert policy.state_dict().keys() == loaded_policy.state_dict().keys() + for k in policy.state_dict(): + assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6) + + # Compare values before and after saving and loading + assert torch.allclose(loss, loaded_loss) + assert torch.allclose(actions, loaded_actions) + + +def test_multi_task_dit_policy_get_optim_params(): + """Test that the policy returns correct optimizer parameter groups.""" + config = create_config( + state_dim=10, + action_dim=10, + n_obs_steps=2, + horizon=16, + n_action_steps=8, + ) + + policy = MultiTaskDiTPolicy(config=config) + param_groups = policy.get_optim_params() + + # Should have 2 parameter groups: non-vision and vision encoder + assert len(param_groups) == 2 + + # First group is non-vision params (no lr specified, will use default) + assert "params" in param_groups[0] + assert len(param_groups[0]["params"]) > 0 + + # Second group is vision encoder params with different lr + assert "params" in param_groups[1] + assert "lr" in param_groups[1] + expected_lr = config.optimizer_lr * config.vision_encoder_lr_multiplier + assert param_groups[1]["lr"] == expected_lr From 3b185f7f9d8faf16bbdf3a833e670feac08f881e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=9B=9B=E4=B8=83?= <41624527+SevenFo@users.noreply.github.com> Date: Sat, 28 Mar 2026 18:37:57 +0800 Subject: [PATCH 127/131] fix(datasets): remove unreachable timestamp branch in add_frame (#3163) * fix(datasets): remove unreachable timestamp branch in add_frame and document caller contract - Remove dead code: frame.pop("timestamp") branch in add_frame() could never execute because validate_frame() raises ValueError for any DEFAULT_FEATURES key (including timestamp) before we reach that line. - Expand add_frame() docstring: explicitly document that timestamp and frame_index must NOT be passed by the caller. - Add explanatory comment in validate_frame(): clarifies why DEFAULT_FEATURES are excluded from expected_features, preventing future re-introduction of the dead branch. The dead branch originated in #1200, which fixed a shape-(1,) mismatch for a code path that was subsequently made unreachable by a refactor of validate_frame. * chore(datasets): narrow PR scope * fix(datasets): move add_frame timestamp cleanup to dataset_writer --- src/lerobot/datasets/dataset_writer.py | 13 +++++++++++-- src/lerobot/datasets/feature_utils.py | 4 ++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/lerobot/datasets/dataset_writer.py b/src/lerobot/datasets/dataset_writer.py index b74b18e0c..787ecd337 100644 --- a/src/lerobot/datasets/dataset_writer.py +++ b/src/lerobot/datasets/dataset_writer.py @@ -155,7 +155,16 @@ class DatasetWriter: self.image_writer.save_image(image=image, fpath=fpath, compress_level=compress_level) def add_frame(self, frame: dict) -> None: - """Add a frame to the episode_buffer. Images are written to a temporary directory.""" + """ + Add a single frame to the current episode buffer. + + Apart from images written to a temporary directory, nothing is written to disk + until ``save_episode()`` is called. + + The caller must provide all user-defined features plus ``"task"``, and must + not provide ``"timestamp"`` or ``"frame_index"``; those are computed + automatically. + """ # Convert torch to numpy if needed for name in frame: if isinstance(frame[name], torch.Tensor): @@ -168,7 +177,7 @@ class DatasetWriter: # Automatically add frame_index and timestamp to episode buffer frame_index = self.episode_buffer["size"] - timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self._meta.fps + timestamp = frame_index / self._meta.fps self.episode_buffer["frame_index"].append(frame_index) self.episode_buffer["timestamp"].append(timestamp) self.episode_buffer["task"].append(frame.pop("task")) diff --git a/src/lerobot/datasets/feature_utils.py b/src/lerobot/datasets/feature_utils.py index d9a3c6301..46154d92a 100644 --- a/src/lerobot/datasets/feature_utils.py +++ b/src/lerobot/datasets/feature_utils.py @@ -365,6 +365,10 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic def validate_frame(frame: dict, features: dict) -> None: + # DEFAULT_FEATURES (timestamp, frame_index, episode_index, index, task_index) are + # auto-populated by the recording pipeline (add_frame / save_episode) and must not + # be supplied by the caller. Excluding them here means any frame dict that contains + # these keys will be rejected as extra features. expected_features = set(features) - set(DEFAULT_FEATURES) actual_features = set(frame) From 5d4fdf5088ed86aa6d0d85c426525a4b1d2e213d Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 30 Mar 2026 16:33:17 +0200 Subject: [PATCH 128/131] feat(scripts): add transformers version (#3248) * feat(scripts): add transformers and torch version * chore(scripts): remove pytorch --- src/lerobot/scripts/lerobot_info.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lerobot/scripts/lerobot_info.py b/src/lerobot/scripts/lerobot_info.py index 879d392be..2092db48b 100644 --- a/src/lerobot/scripts/lerobot_info.py +++ b/src/lerobot/scripts/lerobot_info.py @@ -65,6 +65,7 @@ def get_sys_info() -> dict[str, str]: "Platform": platform.platform(), "Python version": platform.python_version(), "Huggingface Hub version": get_package_version("huggingface_hub"), + "Transformers version": get_package_version("transformers"), "Datasets version": get_package_version("datasets"), "Numpy version": get_package_version("numpy"), "FFmpeg version": get_ffmpeg_version(), From 720cf8e3a09f62fa95260cc49a7a30e5d0f7473a Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 30 Mar 2026 19:11:41 +0200 Subject: [PATCH 129/131] Revert "fix(deps): breaking change from transformers 5.4.0" (#3249) * Revert "fix(deps): breaking change from transformers 5.4.0 (#3231)" This reverts commit 07502868e58095b437e5dd5a480fecc65a6f29dc. * chore(dependecies): pin transformers to 5.3.0 temporarily --- pyproject.toml | 2 +- .../policies/groot/action_head/flow_matching_action_head.py | 3 ++- src/lerobot/policies/groot/groot_n1.py | 3 ++- src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py | 4 ++-- src/lerobot/policies/xvla/modeling_florence2.py | 4 ++-- 5 files changed, 9 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bed22a507..4a1efab30 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,7 +99,7 @@ dependencies = [ # Common pygame-dep = ["pygame>=2.5.1,<2.7.0"] placo-dep = ["placo>=0.9.6,<0.9.17"] -transformers-dep = ["transformers>=5.4.0,<6.0.0"] +transformers-dep = ["transformers==5.3.0"] # TODO(Steven): https://github.com/huggingface/lerobot/pull/3249 grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"] can-dep = ["python-can>=4.2.0,<5.0.0"] peft-dep = ["peft>=0.18.0,<1.0.0"] diff --git a/src/lerobot/policies/groot/action_head/flow_matching_action_head.py b/src/lerobot/policies/groot/action_head/flow_matching_action_head.py index 74d922988..bfc456ba0 100644 --- a/src/lerobot/policies/groot/action_head/flow_matching_action_head.py +++ b/src/lerobot/policies/groot/action_head/flow_matching_action_head.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import field +from dataclasses import dataclass, field from typing import TYPE_CHECKING import torch @@ -110,6 +110,7 @@ class MultiEmbodimentActionEncoder(nn.Module): return x +@dataclass class FlowmatchingActionHeadConfig(PretrainedConfig): """NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head""" diff --git a/src/lerobot/policies/groot/groot_n1.py b/src/lerobot/policies/groot/groot_n1.py index 38512b8a8..06ff5a04d 100644 --- a/src/lerobot/policies/groot/groot_n1.py +++ b/src/lerobot/policies/groot/groot_n1.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import field +from dataclasses import dataclass, field from pathlib import Path from typing import TYPE_CHECKING @@ -173,6 +173,7 @@ N_COLOR_CHANNELS = 3 # config +@dataclass class GR00TN15Config(PretrainedConfig): model_type = "gr00t_n1_5" backbone_cfg: dict = field(init=False, metadata={"help": "Backbone configuration."}) diff --git a/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py b/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py index a80096514..ecf3eb371 100644 --- a/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py +++ b/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py @@ -22,7 +22,7 @@ from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, - is_flash_attn_greater_or_equal, + is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, replace_return_docstrings, @@ -890,7 +890,7 @@ class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal("2.1.0") + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() def forward( self, diff --git a/src/lerobot/policies/xvla/modeling_florence2.py b/src/lerobot/policies/xvla/modeling_florence2.py index 81f9c8234..e33efe5c3 100644 --- a/src/lerobot/policies/xvla/modeling_florence2.py +++ b/src/lerobot/policies/xvla/modeling_florence2.py @@ -45,7 +45,7 @@ from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, - is_flash_attn_greater_or_equal, + is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -909,7 +909,7 @@ class Florence2FlashAttention2(Florence2Attention): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal("2.1.0") + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) From 9300352876f68ed7c12726d6fb8ff45773023b7c Mon Sep 17 00:00:00 2001 From: Jai Kumaar Ratadia Date: Tue, 31 Mar 2026 11:16:34 +0100 Subject: [PATCH 130/131] Fix SO-101 assembly instruction order to match videos (#3242) * Fix SO-101 assembly instruction order to match videos Motor horn installation steps were listed after placing motors into the housing, but the assembly videos show installing horns first. Reordered steps to match the videos, which is also the easier approach since horns are harder to attach once the motor is seated. Added missing detail that bottom horns don't require screws. * Update docs/source/so101.mdx Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Jai Kumaar Ratadia --------- Signed-off-by: Jai Kumaar Ratadia Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> --- docs/source/so101.mdx | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/source/so101.mdx b/docs/source/so101.mdx index 7c9df588a..1274b8282 100644 --- a/docs/source/so101.mdx +++ b/docs/source/so101.mdx @@ -236,10 +236,10 @@ It is advisable to install one 3-pin cable in the motor after placing them befor ### Joint 1 +- Install both motor horns. Secure the top horn with a M3x6mm screw. No screws are required for the bottom horn. - Place the first motor into the base. - Fasten the motor with 4 M2x6mm screws (smallest screws). Two from the top and two from the bottom. - Slide over the first motor holder and fasten it using two M2x6mm screws (one on each side). -- Install both motor horns, securing the top horn with a M3x6mm screw. - Attach the shoulder part. - Tighten the shoulder part with 4 M3x6mm screws on top and 4 M3x6mm screws on the bottom - Add the shoulder motor holder. @@ -255,9 +255,9 @@ It is advisable to install one 3-pin cable in the motor after placing them befor ### Joint 2 +- Install both motor horns. Secure the top horn with a M3x6mm screw. No screws are required for the bottom horn. - Slide the second motor in from the top. - Fasten the second motor with 4 M2x6mm screws. -- Attach both motor horns to motor 2, again use the M3x6mm horn screw. - Attach the upper arm with 4 M3x6mm screws on each side.

@@ -271,8 +271,8 @@ It is advisable to install one 3-pin cable in the motor after placing them befor ### Joint 3 -- Insert motor 3 and fasten using 4 M2x6mm screws -- Attach both motor horns to motor 3 and secure one again with a M3x6mm horn screw. +- Install both motor horns. Secure the top horn with a M3x6mm screw. No screws are required for the bottom horn. +- Insert motor 3 and fasten using 4 M2x6mm screws. - Connect the forearm to motor 3 using 4 M3x6mm screws on each side.
@@ -286,9 +286,10 @@ It is advisable to install one 3-pin cable in the motor after placing them befor ### Joint 4 +- Install both motor horns. Secure the top horn with a M3x6mm screw. No screws are required for the bottom horn. - Slide over motor holder 4. - Slide in motor 4. -- Fasten motor 4 with 4 M2x6mm screws and attach its motor horns, use a M3x6mm horn screw. +- Fasten motor 4 with 4 M2x6mm screws.