mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
fix: ensure action tensors are moved to client_device in async inference (#2792)
* feat(async_inference): server always sends CPU tensors; client handles device conversion
* fix: correct the type annotation of RawObservation in src/lerobot/async_inference/helpers.py
* update the imports in robot_client

---------

Co-authored-by: Sato shinji <wwwsatoshinji@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: KB <kevin-brian.n-diaye@epita.fr>
This commit is contained in:
@@ -195,6 +195,7 @@ client_cfg = RobotClientConfig(
|
|||||||
robot=robot_cfg,
|
robot=robot_cfg,
|
||||||
server_address="localhost:8080",
|
server_address="localhost:8080",
|
||||||
policy_device="mps",
|
policy_device="mps",
|
||||||
|
client_device="cpu",
|
||||||
policy_type="smolvla",
|
policy_type="smolvla",
|
||||||
pretrained_name_or_path="<user>/smolvla_async",
|
pretrained_name_or_path="<user>/smolvla_async",
|
||||||
chunk_size_threshold=0.5,
|
chunk_size_threshold=0.5,
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ def main():
|
|||||||
robot=robot_cfg,
|
robot=robot_cfg,
|
||||||
server_address=server_address,
|
server_address=server_address,
|
||||||
policy_device="mps",
|
policy_device="mps",
|
||||||
|
client_device="cpu",
|
||||||
policy_type="act",
|
policy_type="act",
|
||||||
pretrained_name_or_path="<user>/robot_learning_tutorial_act",
|
pretrained_name_or_path="<user>/robot_learning_tutorial_act",
|
||||||
chunk_size_threshold=0.5, # g
|
chunk_size_threshold=0.5, # g
|
||||||
|
|||||||
@@ -126,6 +126,12 @@ class RobotClientConfig:
|
|||||||
|
|
||||||
# Device configuration
|
# Device configuration
|
||||||
policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"})
|
policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"})
|
||||||
|
client_device: str = field(
|
||||||
|
default="cpu",
|
||||||
|
metadata={
|
||||||
|
"help": "Device to move actions to after receiving from server (e.g., for downstream planners)"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Control behavior configuration
|
# Control behavior configuration
|
||||||
chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"})
|
chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"})
|
||||||
@@ -161,6 +167,9 @@ class RobotClientConfig:
|
|||||||
if not self.policy_device:
|
if not self.policy_device:
|
||||||
raise ValueError("policy_device cannot be empty")
|
raise ValueError("policy_device cannot be empty")
|
||||||
|
|
||||||
|
if not self.client_device:
|
||||||
|
raise ValueError("client_device cannot be empty")
|
||||||
|
|
||||||
if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1:
|
if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1:
|
||||||
raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}")
|
raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}")
|
||||||
|
|
||||||
@@ -184,6 +193,7 @@ class RobotClientConfig:
|
|||||||
"policy_type": self.policy_type,
|
"policy_type": self.policy_type,
|
||||||
"pretrained_name_or_path": self.pretrained_name_or_path,
|
"pretrained_name_or_path": self.pretrained_name_or_path,
|
||||||
"policy_device": self.policy_device,
|
"policy_device": self.policy_device,
|
||||||
|
"client_device": self.client_device,
|
||||||
"chunk_size_threshold": self.chunk_size_threshold,
|
"chunk_size_threshold": self.chunk_size_threshold,
|
||||||
"fps": self.fps,
|
"fps": self.fps,
|
||||||
"actions_per_chunk": self.actions_per_chunk,
|
"actions_per_chunk": self.actions_per_chunk,
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import os
|
|||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -39,8 +40,8 @@ from lerobot.utils.utils import init_logging
|
|||||||
|
|
||||||
Action = torch.Tensor
|
Action = torch.Tensor
|
||||||
|
|
||||||
# observation as received from the robot
|
# observation as received from the robot (can be numpy arrays, floats, etc.)
|
||||||
RawObservation = dict[str, torch.Tensor]
|
RawObservation = dict[str, Any]
|
||||||
|
|
||||||
# observation as those recorded in LeRobot dataset (keys are different)
|
# observation as those recorded in LeRobot dataset (keys are different)
|
||||||
LeRobotObservation = dict[str, torch.Tensor]
|
LeRobotObservation = dict[str, torch.Tensor]
|
||||||
|
|||||||
@@ -381,6 +381,8 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
|||||||
action_tensor = torch.stack(processed_actions, dim=1).squeeze(0)
|
action_tensor = torch.stack(processed_actions, dim=1).squeeze(0)
|
||||||
self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}")
|
self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}")
|
||||||
|
|
||||||
|
action_tensor = action_tensor.detach().cpu()
|
||||||
|
|
||||||
"""5. Convert to TimedAction list"""
|
"""5. Convert to TimedAction list"""
|
||||||
action_chunk = self._time_action_chunk(
|
action_chunk = self._time_action_chunk(
|
||||||
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
|
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ python src/lerobot/async_inference/robot_client.py \
|
|||||||
--policy_type=act \
|
--policy_type=act \
|
||||||
--pretrained_name_or_path=user/model \
|
--pretrained_name_or_path=user/model \
|
||||||
--policy_device=mps \
|
--policy_device=mps \
|
||||||
|
--client_device=cpu \
|
||||||
--actions_per_chunk=50 \
|
--actions_per_chunk=50 \
|
||||||
--chunk_size_threshold=0.5 \
|
--chunk_size_threshold=0.5 \
|
||||||
--aggregate_fn_name=weighted_average \
|
--aggregate_fn_name=weighted_average \
|
||||||
@@ -40,6 +41,7 @@ from collections.abc import Callable
|
|||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import draccus
|
import draccus
|
||||||
import grpc
|
import grpc
|
||||||
@@ -47,7 +49,6 @@ import torch
|
|||||||
|
|
||||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||||
from lerobot.processor import RobotAction
|
|
||||||
from lerobot.robots import ( # noqa: F401
|
from lerobot.robots import ( # noqa: F401
|
||||||
Robot,
|
Robot,
|
||||||
RobotConfig,
|
RobotConfig,
|
||||||
@@ -285,6 +286,21 @@ class RobotClient:
|
|||||||
timed_actions = pickle.loads(actions_chunk.data) # nosec
|
timed_actions = pickle.loads(actions_chunk.data) # nosec
|
||||||
deserialize_time = time.perf_counter() - deserialize_start
|
deserialize_time = time.perf_counter() - deserialize_start
|
||||||
|
|
||||||
|
# Log device type of received actions
|
||||||
|
if len(timed_actions) > 0:
|
||||||
|
received_device = timed_actions[0].get_action().device.type
|
||||||
|
self.logger.debug(f"Received actions on device: {received_device}")
|
||||||
|
|
||||||
|
# Move actions to client_device (e.g., for downstream planners that need GPU)
|
||||||
|
client_device = self.config.client_device
|
||||||
|
if client_device != "cpu":
|
||||||
|
for timed_action in timed_actions:
|
||||||
|
if timed_action.get_action().device.type != client_device:
|
||||||
|
timed_action.action = timed_action.get_action().to(client_device)
|
||||||
|
self.logger.debug(f"Converted actions to device: {client_device}")
|
||||||
|
else:
|
||||||
|
self.logger.debug(f"Actions kept on device: {client_device}")
|
||||||
|
|
||||||
self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))
|
self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))
|
||||||
|
|
||||||
# Calculate network latency if we have matching observations
|
# Calculate network latency if we have matching observations
|
||||||
@@ -351,7 +367,7 @@ class RobotClient:
|
|||||||
action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)}
|
action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)}
|
||||||
return action
|
return action
|
||||||
|
|
||||||
def control_loop_action(self, verbose: bool = False) -> RobotAction:
|
def control_loop_action(self, verbose: bool = False) -> dict[str, Any]:
|
||||||
"""Reading and performing actions in local queue"""
|
"""Reading and performing actions in local queue"""
|
||||||
|
|
||||||
# Lock only for queue operations
|
# Lock only for queue operations
|
||||||
|
|||||||
@@ -144,12 +144,18 @@ def test_async_inference_e2e(monkeypatch):
|
|||||||
client = RobotClient(client_config)
|
client = RobotClient(client_config)
|
||||||
assert client.start(), "Client failed initial handshake with the server"
|
assert client.start(), "Client failed initial handshake with the server"
|
||||||
|
|
||||||
# Track action chunks received without modifying RobotClient
|
# Track action chunks received and verify device type
|
||||||
action_chunks_received = {"count": 0}
|
action_chunks_received = {"count": 0, "actions_on_cpu": True}
|
||||||
original_aggregate = client._aggregate_action_queues
|
original_aggregate = client._aggregate_action_queues
|
||||||
|
|
||||||
def counting_aggregate(*args, **kwargs):
|
def counting_aggregate(*args, **kwargs):
|
||||||
action_chunks_received["count"] += 1
|
action_chunks_received["count"] += 1
|
||||||
|
# Check that all received actions are on CPU
|
||||||
|
if args:
|
||||||
|
for timed_action in args[0]: # args[0] is the list of TimedAction
|
||||||
|
action_tensor = timed_action.get_action()
|
||||||
|
if action_tensor.device.type != "cpu":
|
||||||
|
action_chunks_received["actions_on_cpu"] = False
|
||||||
return original_aggregate(*args, **kwargs)
|
return original_aggregate(*args, **kwargs)
|
||||||
|
|
||||||
monkeypatch.setattr(client, "_aggregate_action_queues", counting_aggregate)
|
monkeypatch.setattr(client, "_aggregate_action_queues", counting_aggregate)
|
||||||
|
|||||||
Reference in New Issue
Block a user