fix: ensure action tensors are moved to client_device in async inference (#2792)

* feat(async_inference): server always sends CPU tensors, client handles device conversion

* fix: fix the type annotation of RawObservation in src/lerobot/async_inference/helpers.py

* update the import of robot_client

---------

Co-authored-by: Sato shinji <wwwsatoshinji@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: KB <kevin-brian.n-diaye@epita.fr>
This commit is contained in:
sato_shinji
2026-01-20 14:17:38 +00:00
committed by GitHub
parent d36dfcdf71
commit 9919b16b36
7 changed files with 43 additions and 6 deletions
+1
View File
@@ -195,6 +195,7 @@ client_cfg = RobotClientConfig(
robot=robot_cfg, robot=robot_cfg,
server_address="localhost:8080", server_address="localhost:8080",
policy_device="mps", policy_device="mps",
client_device="cpu",
policy_type="smolvla", policy_type="smolvla",
pretrained_name_or_path="<user>/smolvla_async", pretrained_name_or_path="<user>/smolvla_async",
chunk_size_threshold=0.5, chunk_size_threshold=0.5,
@@ -30,6 +30,7 @@ def main():
robot=robot_cfg, robot=robot_cfg,
server_address=server_address, server_address=server_address,
policy_device="mps", policy_device="mps",
client_device="cpu",
policy_type="act", policy_type="act",
pretrained_name_or_path="<user>/robot_learning_tutorial_act", pretrained_name_or_path="<user>/robot_learning_tutorial_act",
chunk_size_threshold=0.5, # g chunk_size_threshold=0.5, # g
+10
View File
@@ -126,6 +126,12 @@ class RobotClientConfig:
# Device configuration # Device configuration
policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"}) policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"})
client_device: str = field(
default="cpu",
metadata={
"help": "Device to move actions to after receiving from server (e.g., for downstream planners)"
},
)
# Control behavior configuration # Control behavior configuration
chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"}) chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"})
@@ -161,6 +167,9 @@ class RobotClientConfig:
if not self.policy_device: if not self.policy_device:
raise ValueError("policy_device cannot be empty") raise ValueError("policy_device cannot be empty")
if not self.client_device:
raise ValueError("client_device cannot be empty")
if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1: if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1:
raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}") raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}")
@@ -184,6 +193,7 @@ class RobotClientConfig:
"policy_type": self.policy_type, "policy_type": self.policy_type,
"pretrained_name_or_path": self.pretrained_name_or_path, "pretrained_name_or_path": self.pretrained_name_or_path,
"policy_device": self.policy_device, "policy_device": self.policy_device,
"client_device": self.client_device,
"chunk_size_threshold": self.chunk_size_threshold, "chunk_size_threshold": self.chunk_size_threshold,
"fps": self.fps, "fps": self.fps,
"actions_per_chunk": self.actions_per_chunk, "actions_per_chunk": self.actions_per_chunk,
+3 -2
View File
@@ -18,6 +18,7 @@ import os
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any
import torch import torch
@@ -39,8 +40,8 @@ from lerobot.utils.utils import init_logging
Action = torch.Tensor Action = torch.Tensor
# observation as received from the robot # observation as received from the robot (can be numpy arrays, floats, etc.)
RawObservation = dict[str, torch.Tensor] RawObservation = dict[str, Any]
# observation as those recorded in LeRobot dataset (keys are different) # observation as those recorded in LeRobot dataset (keys are different)
LeRobotObservation = dict[str, torch.Tensor] LeRobotObservation = dict[str, torch.Tensor]
@@ -381,6 +381,8 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
action_tensor = torch.stack(processed_actions, dim=1).squeeze(0) action_tensor = torch.stack(processed_actions, dim=1).squeeze(0)
self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}") self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}")
action_tensor = action_tensor.detach().cpu()
"""5. Convert to TimedAction list""" """5. Convert to TimedAction list"""
action_chunk = self._time_action_chunk( action_chunk = self._time_action_chunk(
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep() observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
+18 -2
View File
@@ -25,6 +25,7 @@ python src/lerobot/async_inference/robot_client.py \
--policy_type=act \ --policy_type=act \
--pretrained_name_or_path=user/model \ --pretrained_name_or_path=user/model \
--policy_device=mps \ --policy_device=mps \
--client_device=cpu \
--actions_per_chunk=50 \ --actions_per_chunk=50 \
--chunk_size_threshold=0.5 \ --chunk_size_threshold=0.5 \
--aggregate_fn_name=weighted_average \ --aggregate_fn_name=weighted_average \
@@ -40,6 +41,7 @@ from collections.abc import Callable
from dataclasses import asdict from dataclasses import asdict
from pprint import pformat from pprint import pformat
from queue import Queue from queue import Queue
from typing import Any
import draccus import draccus
import grpc import grpc
@@ -47,7 +49,6 @@ import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.processor import RobotAction
from lerobot.robots import ( # noqa: F401 from lerobot.robots import ( # noqa: F401
Robot, Robot,
RobotConfig, RobotConfig,
@@ -285,6 +286,21 @@ class RobotClient:
timed_actions = pickle.loads(actions_chunk.data) # nosec timed_actions = pickle.loads(actions_chunk.data) # nosec
deserialize_time = time.perf_counter() - deserialize_start deserialize_time = time.perf_counter() - deserialize_start
# Log device type of received actions
if len(timed_actions) > 0:
received_device = timed_actions[0].get_action().device.type
self.logger.debug(f"Received actions on device: {received_device}")
# Move actions to client_device (e.g., for downstream planners that need GPU)
client_device = self.config.client_device
if client_device != "cpu":
for timed_action in timed_actions:
if timed_action.get_action().device.type != client_device:
timed_action.action = timed_action.get_action().to(client_device)
self.logger.debug(f"Converted actions to device: {client_device}")
else:
self.logger.debug(f"Actions kept on device: {client_device}")
self.action_chunk_size = max(self.action_chunk_size, len(timed_actions)) self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))
# Calculate network latency if we have matching observations # Calculate network latency if we have matching observations
@@ -351,7 +367,7 @@ class RobotClient:
action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)} action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)}
return action return action
def control_loop_action(self, verbose: bool = False) -> RobotAction: def control_loop_action(self, verbose: bool = False) -> dict[str, Any]:
"""Reading and performing actions in local queue""" """Reading and performing actions in local queue"""
# Lock only for queue operations # Lock only for queue operations
+8 -2
View File
@@ -144,12 +144,18 @@ def test_async_inference_e2e(monkeypatch):
client = RobotClient(client_config) client = RobotClient(client_config)
assert client.start(), "Client failed initial handshake with the server" assert client.start(), "Client failed initial handshake with the server"
# Track action chunks received without modifying RobotClient # Track action chunks received and verify device type
action_chunks_received = {"count": 0} action_chunks_received = {"count": 0, "actions_on_cpu": True}
original_aggregate = client._aggregate_action_queues original_aggregate = client._aggregate_action_queues
def counting_aggregate(*args, **kwargs): def counting_aggregate(*args, **kwargs):
action_chunks_received["count"] += 1 action_chunks_received["count"] += 1
# Check that all received actions are on CPU
if args:
for timed_action in args[0]: # args[0] is the list of TimedAction
action_tensor = timed_action.get_action()
if action_tensor.device.type != "cpu":
action_chunks_received["actions_on_cpu"] = False
return original_aggregate(*args, **kwargs) return original_aggregate(*args, **kwargs)
monkeypatch.setattr(client, "_aggregate_action_queues", counting_aggregate) monkeypatch.setattr(client, "_aggregate_action_queues", counting_aggregate)