fix: ensure action tensors are moved to client_device in async training (#2792)

* feat(async_inference): server always sends CPU tensors, client handles device conversion

* fix:fix the type annotation of RawObservation in src/lerobot/async_inference/helpers.py

* update the import of robot_client

---------

Co-authored-by: Sato shinji <wwwsatoshinji@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: KB <kevin-brian.n-diaye@epita.fr>
This commit is contained in:
sato_shinji
2026-01-20 14:17:38 +00:00
committed by GitHub
parent d36dfcdf71
commit 9919b16b36
7 changed files with 43 additions and 6 deletions
+10
View File
@@ -126,6 +126,12 @@ class RobotClientConfig:
# Device configuration
policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"})
client_device: str = field(
default="cpu",
metadata={
"help": "Device to move actions to after receiving from server (e.g., for downstream planners)"
},
)
# Control behavior configuration
chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"})
@@ -161,6 +167,9 @@ class RobotClientConfig:
if not self.policy_device:
raise ValueError("policy_device cannot be empty")
if not self.client_device:
raise ValueError("client_device cannot be empty")
if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1:
raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}")
@@ -184,6 +193,7 @@ class RobotClientConfig:
"policy_type": self.policy_type,
"pretrained_name_or_path": self.pretrained_name_or_path,
"policy_device": self.policy_device,
"client_device": self.client_device,
"chunk_size_threshold": self.chunk_size_threshold,
"fps": self.fps,
"actions_per_chunk": self.actions_per_chunk,
+3 -2
View File
@@ -18,6 +18,7 @@ import os
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import torch
@@ -39,8 +40,8 @@ from lerobot.utils.utils import init_logging
Action = torch.Tensor
# observation as received from the robot
RawObservation = dict[str, torch.Tensor]
# observation as received from the robot (can be numpy arrays, floats, etc.)
RawObservation = dict[str, Any]
# observation as those recorded in LeRobot dataset (keys are different)
LeRobotObservation = dict[str, torch.Tensor]
@@ -381,6 +381,8 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
action_tensor = torch.stack(processed_actions, dim=1).squeeze(0)
self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}")
action_tensor = action_tensor.detach().cpu()
"""5. Convert to TimedAction list"""
action_chunk = self._time_action_chunk(
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
+18 -2
View File
@@ -25,6 +25,7 @@ python src/lerobot/async_inference/robot_client.py \
--policy_type=act \
--pretrained_name_or_path=user/model \
--policy_device=mps \
--client_device=cpu \
--actions_per_chunk=50 \
--chunk_size_threshold=0.5 \
--aggregate_fn_name=weighted_average \
@@ -40,6 +41,7 @@ from collections.abc import Callable
from dataclasses import asdict
from pprint import pformat
from queue import Queue
from typing import Any
import draccus
import grpc
@@ -47,7 +49,6 @@ import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.processor import RobotAction
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
@@ -285,6 +286,21 @@ class RobotClient:
timed_actions = pickle.loads(actions_chunk.data) # nosec
deserialize_time = time.perf_counter() - deserialize_start
# Log device type of received actions
if len(timed_actions) > 0:
received_device = timed_actions[0].get_action().device.type
self.logger.debug(f"Received actions on device: {received_device}")
# Move actions to client_device (e.g., for downstream planners that need GPU)
client_device = self.config.client_device
if client_device != "cpu":
for timed_action in timed_actions:
if timed_action.get_action().device.type != client_device:
timed_action.action = timed_action.get_action().to(client_device)
self.logger.debug(f"Converted actions to device: {client_device}")
else:
self.logger.debug(f"Actions kept on device: {client_device}")
self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))
# Calculate network latency if we have matching observations
@@ -351,7 +367,7 @@ class RobotClient:
action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)}
return action
def control_loop_action(self, verbose: bool = False) -> RobotAction:
def control_loop_action(self, verbose: bool = False) -> dict[str, Any]:
"""Reading and performing actions in local queue"""
# Lock only for queue operations