From 9919b16b3693a756388dd2cccdd04396308fd0ae Mon Sep 17 00:00:00 2001 From: sato_shinji Date: Tue, 20 Jan 2026 14:17:38 +0000 Subject: [PATCH] fix: ensure action tensors are moved to client_device in async training (#2792) * feat(async_inference): server always sends CPU tensors, client handles device conversion * fix: fix the type annotation of RawObservation in src/lerobot/async_inference/helpers.py * update the import of robot_client --------- Co-authored-by: Sato shinji Co-authored-by: Steven Palma Co-authored-by: KB --- docs/source/async.mdx | 1 + examples/tutorial/async-inf/robot_client.py | 1 + src/lerobot/async_inference/configs.py | 10 ++++++++++ src/lerobot/async_inference/helpers.py | 5 +++-- src/lerobot/async_inference/policy_server.py | 2 ++ src/lerobot/async_inference/robot_client.py | 20 ++++++++++++++++++-- tests/async_inference/test_e2e.py | 10 ++++++++-- 7 files changed, 43 insertions(+), 6 deletions(-) diff --git a/docs/source/async.mdx b/docs/source/async.mdx index 1d3e0edbf..3244fc2a3 100644 --- a/docs/source/async.mdx +++ b/docs/source/async.mdx @@ -195,6 +195,7 @@ client_cfg = RobotClientConfig( robot=robot_cfg, server_address="localhost:8080", policy_device="mps", + client_device="cpu", policy_type="smolvla", pretrained_name_or_path="/smolvla_async", chunk_size_threshold=0.5, diff --git a/examples/tutorial/async-inf/robot_client.py b/examples/tutorial/async-inf/robot_client.py index eb3751169..db6ead3fe 100644 --- a/examples/tutorial/async-inf/robot_client.py +++ b/examples/tutorial/async-inf/robot_client.py @@ -30,6 +30,7 @@ def main(): robot=robot_cfg, server_address=server_address, policy_device="mps", + client_device="cpu", policy_type="act", pretrained_name_or_path="/robot_learning_tutorial_act", chunk_size_threshold=0.5, # g diff --git a/src/lerobot/async_inference/configs.py b/src/lerobot/async_inference/configs.py index d1768a323..2e3fe576d 100644 --- a/src/lerobot/async_inference/configs.py +++ 
b/src/lerobot/async_inference/configs.py @@ -126,6 +126,12 @@ class RobotClientConfig: # Device configuration policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"}) + client_device: str = field( + default="cpu", + metadata={ + "help": "Device to move actions to after receiving from server (e.g., for downstream planners)" + }, + ) # Control behavior configuration chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"}) @@ -161,6 +167,9 @@ class RobotClientConfig: if not self.policy_device: raise ValueError("policy_device cannot be empty") + if not self.client_device: + raise ValueError("client_device cannot be empty") + if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1: raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}") @@ -184,6 +193,7 @@ class RobotClientConfig: "policy_type": self.policy_type, "pretrained_name_or_path": self.pretrained_name_or_path, "policy_device": self.policy_device, + "client_device": self.client_device, "chunk_size_threshold": self.chunk_size_threshold, "fps": self.fps, "actions_per_chunk": self.actions_per_chunk, diff --git a/src/lerobot/async_inference/helpers.py b/src/lerobot/async_inference/helpers.py index 2158f51ac..8b12920d9 100644 --- a/src/lerobot/async_inference/helpers.py +++ b/src/lerobot/async_inference/helpers.py @@ -18,6 +18,7 @@ import os import time from dataclasses import dataclass, field from pathlib import Path +from typing import Any import torch @@ -39,8 +40,8 @@ from lerobot.utils.utils import init_logging Action = torch.Tensor -# observation as received from the robot -RawObservation = dict[str, torch.Tensor] +# observation as received from the robot (can be numpy arrays, floats, etc.) 
+RawObservation = dict[str, Any] # observation as those recorded in LeRobot dataset (keys are different) LeRobotObservation = dict[str, torch.Tensor] diff --git a/src/lerobot/async_inference/policy_server.py b/src/lerobot/async_inference/policy_server.py index ab2e6bcd8..aedce2a74 100644 --- a/src/lerobot/async_inference/policy_server.py +++ b/src/lerobot/async_inference/policy_server.py @@ -381,6 +381,8 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer): action_tensor = torch.stack(processed_actions, dim=1).squeeze(0) self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}") + action_tensor = action_tensor.detach().cpu() + """5. Convert to TimedAction list""" action_chunk = self._time_action_chunk( observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep() diff --git a/src/lerobot/async_inference/robot_client.py b/src/lerobot/async_inference/robot_client.py index f26639dc1..e4d21652a 100644 --- a/src/lerobot/async_inference/robot_client.py +++ b/src/lerobot/async_inference/robot_client.py @@ -25,6 +25,7 @@ python src/lerobot/async_inference/robot_client.py \ --policy_type=act \ --pretrained_name_or_path=user/model \ --policy_device=mps \ + --client_device=cpu \ --actions_per_chunk=50 \ --chunk_size_threshold=0.5 \ --aggregate_fn_name=weighted_average \ @@ -40,6 +41,7 @@ from collections.abc import Callable from dataclasses import asdict from pprint import pformat from queue import Queue +from typing import Any import draccus import grpc @@ -47,7 +49,6 @@ import torch from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 -from lerobot.processor import RobotAction from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -285,6 +286,21 @@ class RobotClient: timed_actions = pickle.loads(actions_chunk.data) # nosec deserialize_time = time.perf_counter() - deserialize_start + # Log 
device type of received actions + if len(timed_actions) > 0: + received_device = timed_actions[0].get_action().device.type + self.logger.debug(f"Received actions on device: {received_device}") + + # Move actions to client_device (e.g., for downstream planners that need GPU) + client_device = self.config.client_device + if client_device != "cpu": + for timed_action in timed_actions: + if timed_action.get_action().device.type != client_device: + timed_action.action = timed_action.get_action().to(client_device) + self.logger.debug(f"Converted actions to device: {client_device}") + else: + self.logger.debug(f"Actions kept on device: {client_device}") + self.action_chunk_size = max(self.action_chunk_size, len(timed_actions)) # Calculate network latency if we have matching observations @@ -351,7 +367,7 @@ class RobotClient: action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)} return action - def control_loop_action(self, verbose: bool = False) -> RobotAction: + def control_loop_action(self, verbose: bool = False) -> dict[str, Any]: """Reading and performing actions in local queue""" # Lock only for queue operations diff --git a/tests/async_inference/test_e2e.py b/tests/async_inference/test_e2e.py index 11941ce32..54ca29b48 100644 --- a/tests/async_inference/test_e2e.py +++ b/tests/async_inference/test_e2e.py @@ -144,12 +144,18 @@ def test_async_inference_e2e(monkeypatch): client = RobotClient(client_config) assert client.start(), "Client failed initial handshake with the server" - # Track action chunks received without modifying RobotClient - action_chunks_received = {"count": 0} + # Track action chunks received and verify device type + action_chunks_received = {"count": 0, "actions_on_cpu": True} original_aggregate = client._aggregate_action_queues def counting_aggregate(*args, **kwargs): action_chunks_received["count"] += 1 + # Check that all received actions are on CPU + if args: + for timed_action in args[0]: # args[0] is the 
list of TimedAction + action_tensor = timed_action.get_action() + if action_tensor.device.type != "cpu": + action_chunks_received["actions_on_cpu"] = False return original_aggregate(*args, **kwargs) monkeypatch.setattr(client, "_aggregate_action_queues", counting_aggregate)