Compare commits

...

4 Commits

Author SHA1 Message Date
Steven Palma 0b067df57d feat(robots): add context managers (#2828) 2026-01-20 18:02:38 +01:00
Tommy in Tongji 9ca680dce2 Update README.md (#2827)
Add Chinese doc link.

Signed-off-by: Tommy in Tongji <36354458+TommyZihao@users.noreply.github.com>
2026-01-20 17:54:24 +01:00
sato_shinji 9919b16b36 fix: ensure action tensors are moved to client_device in async training (#2792)
* feat(async_inference): server always sends CPU tensors, client handles device conversion

* fix:fix the type annotation of RawObservation in src/lerobot/async_inference/helpers.py

* update the import of robot_client

---------

Co-authored-by: Sato shinji <wwwsatoshinji@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: KB <kevin-brian.n-diaye@epita.fr>
2026-01-20 15:17:38 +01:00
Caroline Pascal d36dfcdf71 fix(discord link): fixing discord link in CONTRIBUTING.md (#2826)
Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>
2026-01-20 15:00:45 +01:00
11 changed files with 97 additions and 7 deletions
+1 -1
View File
@@ -14,7 +14,7 @@ You can contribute in many ways:
- **Documentation:** Improve examples, guides, and docstrings.
- **Feedback:** Submit tickets related to bugs or desired new features.
If you are unsure where to start, join our [Discord Channel](https://discord.gg/JkrYNdmw).
If you are unsure where to start, join our [Discord Channel](https://discord.gg/q8Dzzpym3f).
## Development Setup
+1
View File
@@ -128,6 +128,7 @@ Learn how to implement your own simulation environment or benchmark and distribu
## Resources
- **[Documentation](https://huggingface.co/docs/lerobot/index):** The complete guide to tutorials & API.
- **[Chinese Tutorials: LeRobot+SO-ARM101中文教程-同济子豪兄](https://zihao-ai.feishu.cn/wiki/space/7589642043471924447)** Detailed doc for assembling, teleoperate, dataset, train, deploy. Verified by Seed Studio and 5 global hackathon players.
- **[Discord](https://discord.gg/q8Dzzpym3f):** Join the `LeRobot` server to discuss with the community.
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
+1
View File
@@ -195,6 +195,7 @@ client_cfg = RobotClientConfig(
robot=robot_cfg,
server_address="localhost:8080",
policy_device="mps",
client_device="cpu",
policy_type="smolvla",
pretrained_name_or_path="<user>/smolvla_async",
chunk_size_threshold=0.5,
@@ -30,6 +30,7 @@ def main():
robot=robot_cfg,
server_address=server_address,
policy_device="mps",
client_device="cpu",
policy_type="act",
pretrained_name_or_path="<user>/robot_learning_tutorial_act",
chunk_size_threshold=0.5, # g
+10
View File
@@ -126,6 +126,12 @@ class RobotClientConfig:
# Device configuration
policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"})
client_device: str = field(
default="cpu",
metadata={
"help": "Device to move actions to after receiving from server (e.g., for downstream planners)"
},
)
# Control behavior configuration
chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"})
@@ -161,6 +167,9 @@ class RobotClientConfig:
if not self.policy_device:
raise ValueError("policy_device cannot be empty")
if not self.client_device:
raise ValueError("client_device cannot be empty")
if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1:
raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}")
@@ -184,6 +193,7 @@ class RobotClientConfig:
"policy_type": self.policy_type,
"pretrained_name_or_path": self.pretrained_name_or_path,
"policy_device": self.policy_device,
"client_device": self.client_device,
"chunk_size_threshold": self.chunk_size_threshold,
"fps": self.fps,
"actions_per_chunk": self.actions_per_chunk,
+3 -2
View File
@@ -18,6 +18,7 @@ import os
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import torch
@@ -39,8 +40,8 @@ from lerobot.utils.utils import init_logging
Action = torch.Tensor
# observation as received from the robot
RawObservation = dict[str, torch.Tensor]
# observation as received from the robot (can be numpy arrays, floats, etc.)
RawObservation = dict[str, Any]
# observation as those recorded in LeRobot dataset (keys are different)
LeRobotObservation = dict[str, torch.Tensor]
@@ -381,6 +381,8 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
action_tensor = torch.stack(processed_actions, dim=1).squeeze(0)
self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}")
action_tensor = action_tensor.detach().cpu()
"""5. Convert to TimedAction list"""
action_chunk = self._time_action_chunk(
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
+18 -2
View File
@@ -25,6 +25,7 @@ python src/lerobot/async_inference/robot_client.py \
--policy_type=act \
--pretrained_name_or_path=user/model \
--policy_device=mps \
--client_device=cpu \
--actions_per_chunk=50 \
--chunk_size_threshold=0.5 \
--aggregate_fn_name=weighted_average \
@@ -40,6 +41,7 @@ from collections.abc import Callable
from dataclasses import asdict
from pprint import pformat
from queue import Queue
from typing import Any
import draccus
import grpc
@@ -47,7 +49,6 @@ import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.processor import RobotAction
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
@@ -285,6 +286,21 @@ class RobotClient:
timed_actions = pickle.loads(actions_chunk.data) # nosec
deserialize_time = time.perf_counter() - deserialize_start
# Log device type of received actions
if len(timed_actions) > 0:
received_device = timed_actions[0].get_action().device.type
self.logger.debug(f"Received actions on device: {received_device}")
# Move actions to client_device (e.g., for downstream planners that need GPU)
client_device = self.config.client_device
if client_device != "cpu":
for timed_action in timed_actions:
if timed_action.get_action().device.type != client_device:
timed_action.action = timed_action.get_action().to(client_device)
self.logger.debug(f"Converted actions to device: {client_device}")
else:
self.logger.debug(f"Actions kept on device: {client_device}")
self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))
# Calculate network latency if we have matching observations
@@ -351,7 +367,7 @@ class RobotClient:
action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)}
return action
def control_loop_action(self, verbose: bool = False) -> RobotAction:
def control_loop_action(self, verbose: bool = False) -> dict[str, Any]:
"""Reading and performing actions in local queue"""
# Lock only for queue operations
+26
View File
@@ -58,6 +58,32 @@ class Robot(abc.ABC):
def __str__(self) -> str:
return f"{self.id} {self.__class__.__name__}"
def __enter__(self):
"""
Context manager entry.
Automatically connects to the camera.
"""
self.connect()
return self
def __exit__(self, exc_type, exc_value, traceback) -> None:
"""
Context manager exit.
Automatically disconnects, ensuring resources are released even on error.
"""
self.disconnect()
def __del__(self) -> None:
"""
Destructor safety net.
Attempts to disconnect if the object is garbage collected without cleanup.
"""
try:
if self.is_connected:
self.disconnect()
except Exception: # nosec B110
pass
# TODO(aliberts): create a proper Feature class for this that links with datasets
@property
@abc.abstractmethod
+26
View File
@@ -58,6 +58,32 @@ class Teleoperator(abc.ABC):
def __str__(self) -> str:
return f"{self.id} {self.__class__.__name__}"
def __enter__(self):
"""
Context manager entry.
Automatically connects to the camera.
"""
self.connect()
return self
def __exit__(self, exc_type, exc_value, traceback) -> None:
"""
Context manager exit.
Automatically disconnects, ensuring resources are released even on error.
"""
self.disconnect()
def __del__(self) -> None:
"""
Destructor safety net.
Attempts to disconnect if the object is garbage collected without cleanup.
"""
try:
if self.is_connected:
self.disconnect()
except Exception: # nosec B110
pass
@property
@abc.abstractmethod
def action_features(self) -> dict:
+8 -2
View File
@@ -144,12 +144,18 @@ def test_async_inference_e2e(monkeypatch):
client = RobotClient(client_config)
assert client.start(), "Client failed initial handshake with the server"
# Track action chunks received without modifying RobotClient
action_chunks_received = {"count": 0}
# Track action chunks received and verify device type
action_chunks_received = {"count": 0, "actions_on_cpu": True}
original_aggregate = client._aggregate_action_queues
def counting_aggregate(*args, **kwargs):
action_chunks_received["count"] += 1
# Check that all received actions are on CPU
if args:
for timed_action in args[0]: # args[0] is the list of TimedAction
action_tensor = timed_action.get_action()
if action_tensor.device.type != "cpu":
action_chunks_received["actions_on_cpu"] = False
return original_aggregate(*args, **kwargs)
monkeypatch.setattr(client, "_aggregate_action_queues", counting_aggregate)