Merge branch 'main' into user/michel-aractingi/2025_06_30_dataset_v3

This commit is contained in:
Michel Aractingi
2025-07-30 00:43:53 +02:00
committed by GitHub
36 changed files with 455 additions and 604 deletions
+2
View File
@@ -206,3 +206,5 @@ jobs:
echo "::error::Failed to delete Docker image. HTTP status: $HTTP_RESPONSE"
exit 1
fi
# TODO(Steven): Check dockerimages pull in ubuntu
+61 -16
View File
@@ -20,12 +20,12 @@ on:
- 'v*.*.*' # Trigger on tags like v0.1.0, v1.0.0
jobs:
# TODO(Steven): Publish draft/pre-release and to test pypi
# TODO(Steven): Tag documentation with the same version as the package
# TODO(Steven): Define entry points for main CLI scripts
# This job builds the Python package and publishes it to PyPI
build-and-publish:
name: Build and publish Python distributions
runs-on: ubuntu-latest
outputs:
version: ${{ steps.extract_info.outputs.tag_version }}
permissions:
contents: write
id-token: write
@@ -41,26 +41,34 @@ jobs:
with:
python-version: '3.10'
- name: Extract Version and Package Name
- name: Extract Version
id: extract_info
# Extract version from tag (e.g., v0.1.0 -> 0.1.0)
# zizmor: ignore[template-injection]
run: |
# Extract version from tag (e.g., v0.1.0 -> 0.1.0)
VERSION=${{ github.ref_name }}
VERSION_NUMBER=${VERSION#v}
echo "tag_version=$VERSION_NUMBER" >> $GITHUB_OUTPUT
- name: Check if version matches pyproject.toml
# zizmor: ignore[template-injection]
run: |
TAG_VERSION=${{ steps.extract_info.outputs.tag_version }}
# Extract package name from pyproject.toml
PACKAGE_NAME=$(grep -oP 'name = "\K[^"]+' pyproject.toml)
echo "package_name=$PACKAGE_NAME" >> $GITHUB_OUTPUT
PYPROJECT_VERSION=$(grep '^version = ' pyproject.toml | awk -F' = ' '{print $2}' | tr -d '"')
if [[ "$TAG_VERSION" != "$PYPROJECT_VERSION" ]]; then
echo "Error: Tag version ($TAG_VERSION) does not match pyproject.toml version ($PYPROJECT_VERSION)." >&2
exit 1
else
echo "Tag version matches pyproject.toml version: $TAG_VERSION. Proceeding with release."
fi
- name: Check if version exists on PyPI
# zizmor: ignore[template-injection]
run: |
PACKAGE_NAME=${{ steps.extract_info.outputs.package_name }}
NEW_VERSION=${{ steps.extract_info.outputs.tag_version }}
response=$(curl -s "https://pypi.org/pypi/$PACKAGE_NAME/$NEW_VERSION/json")
response=$(curl -s "https://pypi.org/pypi/lerobot/$NEW_VERSION/json")
if echo "$response" | grep -q "message"; then
echo "Version $NEW_VERSION is available on PyPI. Proceeding with release."
else
@@ -80,9 +88,46 @@ jobs:
# zizmor: ignore[template-injection]
run: gh release create ${{ github.ref_name }} --release-name "Release ${{ github.ref_name }}" --generate-notes ./dist/*
# TODO(Steven): Uncomment when ready to publish to PyPI
# - name: Publish to PyPI
# if: startsWith(github.ref, 'refs/tags/v')
# uses: pypa/gh-action-pypi-publish@v1.12.4
# with:
# password: ${{ secrets.PYPI_API_TOKEN }}
- name: Publish to PyPI
if: startsWith(github.ref, 'refs/tags/v')
uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
with:
password: ${{ secrets.PYPI_API_TOKEN }}
# This job runs end-to-end tests on the release
test-release:
name: Test Release
needs: [build-and-publish]
runs-on: ubuntu-latest
permissions:
contents: read
env:
MUJOCO_GL: egl
steps:
- uses: actions/checkout@v4
with:
lfs: true
persist-credentials: false
- name: Install apt dependencies
run: |
sudo apt-get update && sudo apt-get install -y build-essential \
git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
speech-dispatcher libgeos-dev portaudio19-dev
- name: Setup uv and Python
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
with:
enable-cache: true
version: ${{ env.UV_VERSION }}
python-version: ${{ env.PYTHON_VERSION }}
- name: Install lerobot release
run: uv run pip install lerobot==${{ needs.build-and-publish.outputs.version }} # zizmor: ignore[template-injection]
- name: Check lerobot version
run: uv run lerobot --version
- name: Run end-to-end tests
run: uv run make test-end-to-end
# TODO(Steven): Publish draft/pre-release and to test pypi
# TODO(Steven): Tag documentation with the same version as the package
+4 -4
View File
@@ -39,13 +39,13 @@ repos:
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.13
rev: v0.12.4
hooks:
- id: ruff-format
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/crate-ci/typos
- repo: https://github.com/adhtruong/mirrors-typos
rev: v1.34.0
hooks:
- id: typos
@@ -58,8 +58,8 @@ repos:
args: [--py310-plus]
##### Markdown Quality #####
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v4.0.0-alpha.8
- repo: https://github.com/rbubley/mirrors-prettier
rev: v3.6.2
hooks:
- id: prettier
name: Format Markdown with Prettier
+2 -2
View File
@@ -28,7 +28,7 @@ This guide provides step-by-step instructions for training a robot policy using
- A gamepad (recommended) or keyboard to control the robot
- A Nvidia GPU
- A real robot with a follower and leader arm (optional if you use the keyboard or the gamepad)
- A URDF file for the robot for the kinematics package (check `lerobot/common/model/kinematics.py`)
- A URDF file for the robot for the kinematics package (check `lerobot/model/kinematics.py`)
## What kind of tasks can I train?
@@ -477,7 +477,7 @@ Create a training configuration file (example available [here](https://huggingfa
1. Configure the policy settings (`type="sac"`, `device`, etc.)
2. Set `dataset` to your cropped dataset
3. Configure environment settings with crop parameters
4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/19bb621a7d0a31c20cd3cc08b1dbab68d3031454/lerobot/policies/sac/configuration_sac.py#L79).
4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/sac/configuration_sac.py#L79).
5. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task.
**Starting the Learner**
+2 -2
View File
@@ -323,7 +323,7 @@ The `record` function provides a suite of tools for capturing and managing data
##### 2. Checkpointing and Resuming
- Checkpoints are automatically created during recording.
- If an issue occurs, you can resume by re-running the same command with `--resume=true`.
- If an issue occurs, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset !
- To start recording from scratch, **manually delete** the dataset directory.
##### 3. Recording Parameters
@@ -485,7 +485,7 @@ huggingface-cli upload ${HF_USER}/act_so101_test${CKPT} \
## Run inference and evaluate your policy
You can use the `record` script from [`lerobot/record.py`](https://github.com/huggingface/lerobot/blob/main/lerobot/record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes:
You can use the `record` script from [`lerobot/record.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes:
<hfoptions id="eval">
<hfoption id="Command">
+9 -9
View File
@@ -2,23 +2,23 @@
This tutorial will explain how to integrate your own robot design into the LeRobot ecosystem and have it access all of our tools (data collection, control pipelines, policy training and inference).
To that end, we provide the [`Robot`](https://github.com/huggingface/lerobot/blob/main/lerobot/robots/robot.py) base class in the LeRobot which specifies a standard interface for physical robot integration. Let's see how to implement it.
To that end, we provide the [`Robot`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/robots/robot.py) base class in the LeRobot which specifies a standard interface for physical robot integration. Let's see how to implement it.
## Prerequisites
- Your own robot which exposes a communication interface (e.g. serial, CAN, TCP)
- A way to read sensor data and send motor commands programmatically, e.g. manufacturer's SDK or API, or your own protocol implementation.
- LeRobot installed in your environment. Follow our [Installation Guide](./installation).
- LeRobot installed in your environment. Follow our [Installation Guide](./installation.mdx).
## Choose your motors
If you're using Feetech or Dynamixel motors, LeRobot provides built-in bus interfaces:
- [`FeetechMotorsBus`](https://github.com/huggingface/lerobot/blob/main/lerobot/motors/feetech/feetech.py) for controlling Feetech servos
- [`DynamixelMotorsBus`](https://github.com/huggingface/lerobot/blob/main/lerobot/motors/dynamixel/dynamixel.py) for controlling Dynamixel servos
- [`FeetechMotorsBus`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/motors/feetech/feetech.py) for controlling Feetech servos
- [`DynamixelMotorsBus`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/motors/dynamixel/dynamixel.py) for controlling Dynamixel servos
Please refer to the [`MotorsBus`](https://github.com/huggingface/lerobot/blob/main/lerobot/motors/motors_bus.py) abstract class to learn about its API.
For a good example of how it can be used, you can have a look at our own [SO101 follower implementation](https://github.com/huggingface/lerobot/blob/main/lerobot/robots/so101_follower/so101_follower.py)
Please refer to the [`MotorsBus`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/motors/motors_bus.py) abstract class to learn about its API.
For a good example of how it can be used, you can have a look at our own [SO101 follower implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/robots/so101_follower/so101_follower.py)
Use these if compatible. Otherwise, you'll need to find or write a Python interface (not covered in this tutorial):
@@ -29,7 +29,7 @@ You're not alone—many community contributions use custom boards or firmware!
For Feetech and Dynamixel, we currently support these servos: - Feetech: - STS & SMS series (protocol 0): `sts3215`, `sts3250`, `sm8512bl` - SCS series (protocol 1): `scs0009` - Dynamixel (protocol 2.0 only): `xl330-m077`, `xl330-m288`, `xl430-w250`, `xm430-w350`, `xm540-w270`, `xc430-w150`
If you are using Feetech or Dynamixel servos that are not in this list, you can add those in the [Feetech table](https://github.com/huggingface/lerobot/blob/main/lerobot/motors/feetech/tables.py) or [Dynamixel table](https://github.com/huggingface/lerobot/blob/main/lerobot/motors/dynamixel/tables.py). Depending on the model, this will require you to add model-specific information. In most cases though, there shouldn't be a lot of additions to do.
If you are using Feetech or Dynamixel servos that are not in this list, you can add those in the [Feetech table](https://github.com/huggingface/lerobot/blob/main/src/lerobot/motors/feetech/tables.py) or [Dynamixel table](https://github.com/huggingface/lerobot/blob/main/src/lerobot/motors/dynamixel/tables.py). Depending on the model, this will require you to add model-specific information. In most cases though, there shouldn't be a lot of additions to do.
In the next sections, we'll use a `FeetechMotorsBus` as the motors interface for the examples. Replace it and adapt to your motors if necessary.
@@ -65,7 +65,7 @@ class MyCoolRobotConfig(RobotConfig):
```
<!-- prettier-ignore-end -->
Have a look at our [Cameras tutorial](./cameras) to understand how to detect and add your camera.
[Cameras tutorial](./cameras.mdx) to understand how to detect and add your camera.
Next, we'll create our actual robot class which inherits from `Robot`. This abstract class defines a contract you must follow for your robot to be usable with the rest of the LeRobot tools.
@@ -331,7 +331,7 @@ def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
## Adding a Teleoperator
For implementing teleoperation devices, we also provide a [`Teleoperator`](https://github.com/huggingface/lerobot/blob/main/lerobot/teleoperators/teleoperator.py) base class. This class is very similar to the `Robot` base class and also doesn't assume anything on form factor.
For implementing teleoperation devices, we also provide a [`Teleoperator`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/teleoperators/teleoperator.py) base class. This class is very similar to the `Robot` base class and also doesn't assume anything on form factor.
The main differences are in the I/O functions: a teleoperator allows you to produce action via `get_action` and can receive feedback actions via `send_feedback`. Feedback could be anything controllable on the teleoperation device that could help the person controlling it understand the consequences of the actions sent. Think motion/force feedback on a leader arm, vibrations on a gamepad controller for example. To implement a teleoperator, you can follow this same tutorial and adapt it for these two methods.
+1 -1
View File
@@ -258,7 +258,7 @@ You should see on your laptop something like this: `[INFO] Connected to remote r
| F | Decrease speed |
> [!TIP]
> If you use a different keyboard, you can change the keys for each command in the [`LeKiwiConfig`](../src/lerobot/robot_devices/robots/configs.py).
> If you use a different keyboard, you can change the keys for each command in the [`LeKiwiClientConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/robots/lekiwi/config_lekiwi.py).
### Wired version
+16 -5
View File
@@ -61,7 +61,7 @@ dependencies = [
# Hugging Face dependencies
"datasets>=2.19.0,<=3.6.0", # TODO: Bumb dependency
"diffusers>=0.27.2",
"huggingface-hub[hf-transfer,cli]>=0.27.1",
"huggingface-hub[hf-transfer,cli]>=0.34.2",
# Core dependencies
"cmake>=3.29.0.1",
@@ -75,7 +75,7 @@ dependencies = [
"packaging>=24.2",
"pynput>=1.7.7",
"pyserial>=3.5",
"wandb>=0.16.3",
"wandb>=0.20.0",
"draccus==0.10.0", # TODO: Remove ==
"gymnasium>=0.29.1,<1.0.0", # TODO: Bumb dependency
@@ -95,7 +95,7 @@ dependencies = [
pygame-dep = ["pygame>=2.5.1"]
placo-dep = ["placo>=0.9.6"]
transformers-dep = ["transformers>=4.50.3,<4.52.0"] # TODO: Bumb dependency
grpcio-dep = ["grpcio==1.71.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]
# Motors
feetech = ["feetech-servo-sdk>=1.0.0"]
@@ -119,14 +119,14 @@ intelrealsense = [
# Policies
pi0 = ["lerobot[transformers-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.9", "protobuf>=5.29.3", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.9", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3"]
# Development
docs = ["hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main", "watchdog >= 6.0.0"]
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "grpcio-tools==1.71.0"]
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
test = ["pytest>=8.1.0", "pytest-timeout>=2.4.0", "pytest-cov>=5.0.0", "mock-serial>=0.0.1 ; sys_platform != 'win32'"]
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
@@ -156,6 +156,17 @@ all = [
"lerobot[xarm]"
]
[project.scripts]
lerobot-calibrate="lerobot.calibrate:main"
lerobot-find-cameras="lerobot.find_cameras:main"
lerobot-find-port="lerobot.find_port:main"
lerobot-record="lerobot.record:main"
lerobot-replay="lerobot.replay:main"
lerobot-setup-motors="lerobot.setup_motors:main"
lerobot-teleoperate="lerobot.teleoperate:main"
lerobot-eval="lerobot.scripts.eval:main"
lerobot-train="lerobot.scripts.train:main"
# ---------------- Tool Configurations ----------------
[tool.setuptools.packages.find]
where = ["src"]
+3 -3
View File
@@ -170,7 +170,7 @@ available_datasets = sorted(
# lists all available policies from `lerobot/policies`
available_policies = ["act", "diffusion", "tdmpc", "vqbet"]
# lists all available robots from `lerobot/robot_devices/robots`
# lists all available robots from `lerobot/robots`
available_robots = [
"koch",
"koch_bimanual",
@@ -179,13 +179,13 @@ available_robots = [
"so101",
]
# lists all available cameras from `lerobot/robot_devices/cameras`
# lists all available cameras from `lerobot/cameras`
available_cameras = [
"opencv",
"intelrealsense",
]
# lists all available motors from `lerobot/robot_devices/motors`
# lists all available motors from `lerobot/motors`
available_motors = [
"dynamixel",
"feetech",
+1 -1
View File
@@ -368,7 +368,7 @@ class OpenCVCamera(Camera):
if requested_color_mode == ColorMode.RGB:
processed_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]:
if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]:
processed_image = cv2.rotate(processed_image, self.rotation)
return processed_image
@@ -434,7 +434,7 @@ class RealSenseCamera(Camera):
if self.color_mode == ColorMode.BGR:
processed_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]:
if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]:
processed_image = cv2.rotate(processed_image, self.rotation)
return processed_image
+1 -1
View File
@@ -514,7 +514,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
multiples of 1/fps. Defaults to 1e-4.
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
commit hash. Defaults to current codebase version tag.
sync_cache_first (bool, optional): Flag to sync and refresh local files first. If True and files
force_cache_sync (bool, optional): Flag to sync and refresh local files first. If True and files
are already present in the local cache, this will be faster. However, files loaded might not
be in sync with the version on the hub, especially if you specified 'revision'. Defaults to
False.
+8 -8
View File
@@ -44,7 +44,7 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
@EnvConfig.register_subclass("aloha")
@dataclass
class AlohaEnv(EnvConfig):
task: str = "AlohaInsertion-v0"
task: str | None = "AlohaInsertion-v0"
fps: int = 50
episode_length: int = 400
obs_type: str = "pixels_agent_pos"
@@ -82,7 +82,7 @@ class AlohaEnv(EnvConfig):
@EnvConfig.register_subclass("pusht")
@dataclass
class PushtEnv(EnvConfig):
task: str = "PushT-v0"
task: str | None = "PushT-v0"
fps: int = 10
episode_length: int = 300
obs_type: str = "pixels_agent_pos"
@@ -124,7 +124,7 @@ class PushtEnv(EnvConfig):
@EnvConfig.register_subclass("xarm")
@dataclass
class XarmEnv(EnvConfig):
task: str = "XarmLift-v0"
task: str | None = "XarmLift-v0"
fps: int = 15
episode_length: int = 200
obs_type: str = "pixels_agent_pos"
@@ -200,10 +200,10 @@ class HILSerlRobotEnvConfig(EnvConfig):
wrapper: EnvTransformConfig | None = None
fps: int = 10
name: str = "real_robot"
mode: str = None # Either "record", "replay", None
mode: str | None = None # Either "record", "replay", None
repo_id: str | None = None
dataset_root: str | None = None
task: str = ""
task: str | None = ""
num_episodes: int = 10 # only for record mode
episode: int = 0
device: str = "cuda"
@@ -213,6 +213,7 @@ class HILSerlRobotEnvConfig(EnvConfig):
# For the reward classifier, to record more positive examples after a success
number_of_steps_after_success: int = 0
@property
def gym_kwargs(self) -> dict:
return {}
@@ -222,9 +223,8 @@ class HILSerlRobotEnvConfig(EnvConfig):
class HILEnvConfig(EnvConfig):
"""Configuration for the HIL environment."""
type: str = "hil"
name: str = "PandaPickCube"
task: str = "PandaPickCubeKeyboard-v0"
task: str | None = "PandaPickCubeKeyboard-v0"
use_viewer: bool = True
gripper_penalty: float = 0.0
use_gamepad: bool = True
@@ -252,7 +252,7 @@ class HILEnvConfig(EnvConfig):
robot_config: RobotConfig | None = None
teleop_config: TeleoperatorConfig | None = None
wrapper: EnvTransformConfig | None = None
mode: str = None # Either "record", "replay", None
mode: str | None = None # Either "record", "replay", None
repo_id: str | None = None
dataset_root: str | None = None
num_episodes: int = 10 # only for record mode
+1 -1
View File
@@ -420,7 +420,7 @@ class ACT(nn.Module):
batch_size = batch["observation.environment_state"].shape[0]
# Prepare the latent for input to the transformer encoder.
if self.config.use_vae and "action" in batch:
if self.config.use_vae and "action" in batch and self.training:
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
@@ -217,6 +217,7 @@ class DiffusionConfig(PreTrainedConfig):
)
# Check that all input images have the same shape.
if len(self.image_features) > 0:
first_image_key, first_image_ft = next(iter(self.image_features.items()))
for key, image_ft in self.image_features.items():
if image_ft.shape != first_image_ft.shape:
@@ -133,11 +133,15 @@ class DiffusionPolicy(PreTrainedPolicy):
"horizon" may not the best name to describe what the variable actually means, because this period is
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
"""
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
if ACTION in batch:
batch.pop(ACTION)
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
# Note: It's important that this happens after stacking the images into a single key.
# NOTE: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)
if len(self._queues[ACTION]) == 0:
@@ -284,7 +288,7 @@ class DiffusionModel(nn.Module):
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
AND/OR
"observation.environment_state": (B, environment_dim)
"observation.environment_state": (B, n_obs_steps, environment_dim)
}
"""
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
@@ -311,7 +315,7 @@ class DiffusionModel(nn.Module):
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
AND/OR
"observation.environment_state": (B, environment_dim)
"observation.environment_state": (B, n_obs_steps, environment_dim)
"action": (B, horizon, action_dim)
"action_is_pad": (B, horizon)
+3 -8
View File
@@ -90,12 +90,6 @@ def create_sinusoidal_pos_embedding(
return pos_emb
def sample_beta(alpha, beta, bsize, device):
gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
return gamma1 / (gamma1 + gamma2)
def make_att_2d_masks(pad_masks, att_masks):
"""Copied from big_vision.
@@ -515,9 +509,10 @@ class PI0FlowMatching(nn.Module):
return noise
def sample_time(self, bsize, device):
time_beta = sample_beta(1.5, 1.0, bsize, device)
beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32)
time = time_beta * 0.999 + 0.001
return time.to(dtype=torch.float32, device=device)
return time
def embed_prefix(
self, images, img_masks, lang_tokens, lang_masks
@@ -488,6 +488,8 @@ class PI0FAST(nn.Module):
param.data = param.data.to(dtype=torch_precision)
self.set_requires_grad()
self.image_keys = self.config.image_features.keys()
# TODO: Remove this once we bump transformers to >4.52.0 because the attribute will be removed
# AttributeError: 'PaliGemmaConfig' object has no attribute 'ignore_index'
self.ignore_index = self.pi0_paligemma.config.ignore_index
self.padding_side = self.config.padding_side
@@ -194,12 +194,6 @@ def create_sinusoidal_pos_embedding(
return pos_emb
def sample_beta(alpha, beta, bsize, device):
gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
return gamma1 / (gamma1 + gamma2)
def make_att_2d_masks(pad_masks, att_masks):
"""Copied from big_vision.
@@ -384,8 +378,13 @@ class SmolVLAPolicy(PreTrainedPolicy):
return self.parameters()
def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
# TODO: Check if this for loop is needed.
# Context: In fact, self.queues contains only ACTION field, and in inference, we don't have action in the batch
# In the case of offline inference, we have the action in the batch
# that why without the k != ACTION check, it will raise an error because we are trying to stack
# on an empty container.
for k in batch:
if k in self._queues:
if k in self._queues and k != ACTION:
batch[k] = torch.stack(list(self._queues[k]), dim=1)
images, img_masks = self.prepare_images(batch)
@@ -631,7 +630,7 @@ class VLAFlowMatching(nn.Module):
"""
def __init__(self, config):
def __init__(self, config: SmolVLAConfig):
super().__init__()
self.config = config
@@ -685,9 +684,10 @@ class VLAFlowMatching(nn.Module):
return noise
def sample_time(self, bsize, device):
time_beta = sample_beta(1.5, 1.0, bsize, device)
beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32)
time = time_beta * 0.999 + 0.001
return time.to(dtype=torch.float32, device=device)
return time
def embed_prefix(
self, images, img_masks, lang_tokens, lang_masks, state: torch.Tensor = None
@@ -143,7 +143,12 @@ class TDMPCPolicy(PreTrainedPolicy):
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
if ACTION in batch:
batch.pop(ACTION)
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
+5 -2
View File
@@ -139,11 +139,14 @@ class VQBeTPolicy(PreTrainedPolicy):
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
if ACTION in batch:
batch.pop(ACTION)
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
# NOTE: It's important that this happens after stacking the images into a single key.
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
# Note: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)
if not self.vqbet.action_head.vqvae_model.discretized.item():
+1 -1
View File
@@ -63,7 +63,7 @@ python lerobot/scripts/control_robot.py \
--control.type=teleoperate
```
By adding `--robot.max_relative_target=5`, we override the default value for `max_relative_target` defined in [`AlohaRobotConfig`](lerobot/robot_devices/robots/configs.py). It is expected to be `5` to limit the magnitude of the movement for more safety, but the teleoperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot.max_relative_target=null` to the command line:
By adding `--robot.max_relative_target=5`, we override the default value for `max_relative_target` defined in [`ViperXConfig`](./config_viperx.py). It is expected to be `5` to limit the magnitude of the movement for more safety, but the teleoperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot.max_relative_target=null` to the command line:
```bash
python lerobot/scripts/control_robot.py \
-86
View File
@@ -12,15 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import logging
import logging.handlers
import os
import time
from dataclasses import dataclass
from pathlib import Path
from threading import Event
from typing import Any
import torch
@@ -31,8 +28,6 @@ from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401
from lerobot.robots.robot import Robot
from lerobot.transport import async_inference_pb2
from lerobot.transport.utils import bytes_buffer_size
from lerobot.utils.utils import init_logging
Action = torch.Tensor
@@ -303,84 +298,3 @@ def observations_similar(
)
return _compare_observation_states(obs1_state, obs2_state, atol=atol)
def send_bytes_in_chunks(
buffer: bytes,
message_class: Any,
log_prefix: str = "",
silent: bool = True,
chunk_size: int = 3 * 1024 * 1024,
):
# NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.send_bytes_in_chunks. Duplication can't be avoided if we
# don't use a unique class for messages sent (due to the different transfer states sent). Also, I'd want more control over the
# chunk size as I am using it to send image observations.
buffer = io.BytesIO(buffer)
size_in_bytes = bytes_buffer_size(buffer)
sent_bytes = 0
logging_method = logging.info if not silent else logging.debug
logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with")
while sent_bytes < size_in_bytes:
transfer_state = async_inference_pb2.TransferState.TRANSFER_MIDDLE
if sent_bytes + chunk_size >= size_in_bytes:
transfer_state = async_inference_pb2.TransferState.TRANSFER_END
elif sent_bytes == 0:
transfer_state = async_inference_pb2.TransferState.TRANSFER_BEGIN
size_to_read = min(chunk_size, size_in_bytes - sent_bytes)
chunk = buffer.read(size_to_read)
yield message_class(transfer_state=transfer_state, data=chunk)
sent_bytes += size_to_read
logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}")
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
def receive_bytes_in_chunks(
iterator, continue_receiving: Event, logger: logging.Logger, log_prefix: str = ""
): # type: ignore
# NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.receive_bytes_in_chunks. Duplication can't be avoided if we
# don't use a unique class for messages sent (due to the different transfer states sent). Also, on the server side the logic for receiving
# is opposite then the HIL-SERL design (my event showcases keeping on running instead of shutdown)
bytes_buffer = io.BytesIO()
step = 0
logger.info(f"{log_prefix} Starting receiver")
for item in iterator:
logger.debug(f"{log_prefix} Received item")
if not continue_receiving.is_set():
logger.info(f"{log_prefix} Shutting down receiver")
return
if item.transfer_state == async_inference_pb2.TransferState.TRANSFER_BEGIN:
bytes_buffer.seek(0)
bytes_buffer.truncate(0)
bytes_buffer.write(item.data)
logger.debug(f"{log_prefix} Received data at step 0")
elif item.transfer_state == async_inference_pb2.TransferState.TRANSFER_MIDDLE:
bytes_buffer.write(item.data)
step += 1
logger.debug(f"{log_prefix} Received data at step {step}")
elif item.transfer_state == async_inference_pb2.TransferState.TRANSFER_END:
bytes_buffer.write(item.data)
logger.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
complete_bytes = bytes_buffer.getvalue()
bytes_buffer.seek(0)
bytes_buffer.truncate(0)
logger.debug(f"{log_prefix} Queue updated")
return complete_bytes
else:
logger.warning(f"{log_prefix} Received unknown transfer state {item.transfer_state}")
raise ValueError(f"Received unknown transfer state {item.transfer_state}")
+17 -17
View File
@@ -49,21 +49,21 @@ from lerobot.scripts.server.helpers import (
get_logger,
observations_similar,
raw_observation_to_observation,
receive_bytes_in_chunks,
)
from lerobot.transport import (
async_inference_pb2, # type: ignore
async_inference_pb2_grpc, # type: ignore
services_pb2, # type: ignore
services_pb2_grpc, # type: ignore
)
from lerobot.transport.utils import receive_bytes_in_chunks
class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
prefix = "policy_server"
logger = get_logger(prefix)
def __init__(self, config: PolicyServerConfig):
self.config = config
self._running_event = threading.Event()
self.shutdown_event = threading.Event()
# FPS measurement
self.fps_tracker = FPSTracker(target_fps=config.fps)
@@ -84,7 +84,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
@property
def running(self):
return self._running_event.is_set()
return not self.shutdown_event.is_set()
@property
def policy_image_features(self):
@@ -93,7 +93,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
def _reset_server(self) -> None:
"""Flushes server state when new client connects."""
# only running inference on the latest observation received by the server
self._running_event.clear()
self.shutdown_event.set()
self.observation_queue = Queue(maxsize=1)
with self._predicted_timesteps_lock:
@@ -103,16 +103,16 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
client_id = context.peer()
self.logger.info(f"Client {client_id} connected and ready")
self._reset_server()
self._running_event.set()
self.shutdown_event.clear()
return async_inference_pb2.Empty()
return services_pb2.Empty()
def SendPolicyInstructions(self, request, context): # noqa: N802
"""Receive policy instructions from the robot client"""
if not self.running:
self.logger.warning("Server is not running. Ignoring policy instructions.")
return async_inference_pb2.Empty()
return services_pb2.Empty()
client_id = context.peer()
@@ -149,7 +149,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")
return async_inference_pb2.Empty()
return services_pb2.Empty()
def SendObservations(self, request_iterator, context): # noqa: N802
"""Receive observations from the robot client"""
@@ -159,7 +159,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
receive_time = time.time() # comparing timestamps so need time.time()
start_deserialize = time.perf_counter()
received_bytes = receive_bytes_in_chunks(
request_iterator, self._running_event, self.logger
request_iterator, None, self.shutdown_event, self.logger
) # blocking call while looping over request_iterator
timed_observation = pickle.loads(received_bytes) # nosec
deserialize_time = time.perf_counter() - start_deserialize
@@ -190,7 +190,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
):
self.logger.info(f"Observation #{obs_timestep} has been filtered out")
return async_inference_pb2.Empty()
return services_pb2.Empty()
def GetActions(self, request, context): # noqa: N802
"""Returns actions to the robot client. Actions are sent as a single
@@ -218,7 +218,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
serialize_time = time.perf_counter() - start_time
# Create and return the action chunk
actions = async_inference_pb2.Actions(data=actions_bytes)
actions = services_pb2.Actions(data=actions_bytes)
self.logger.info(
f"Action chunk #{obs.get_timestep()} generated | "
@@ -239,12 +239,12 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
return actions
except Empty: # no observation added to queue in obs_queue_timeout
return async_inference_pb2.Empty()
return services_pb2.Empty()
except Exception as e:
self.logger.error(f"Error in StreamActions: {e}")
return async_inference_pb2.Empty()
return services_pb2.Empty()
def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool:
"""Check if the observation is valid to be processed by the policy"""
@@ -388,7 +388,7 @@ def serve(cfg: PolicyServerConfig):
# Setup and start gRPC server
server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
server.add_insecure_port(f"{cfg.host}:{cfg.port}")
policy_server.logger.info(f"PolicyServer started on {cfg.host}:{cfg.port}")
+12 -13
View File
@@ -69,15 +69,14 @@ from lerobot.scripts.server.helpers import (
TimedObservation,
get_logger,
map_robot_keys_to_lerobot_features,
send_bytes_in_chunks,
validate_robot_cameras_for_policy,
visualize_action_queue_size,
)
from lerobot.transport import (
async_inference_pb2, # type: ignore
async_inference_pb2_grpc, # type: ignore
services_pb2, # type: ignore
services_pb2_grpc, # type: ignore
)
from lerobot.transport.utils import grpc_channel_options
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
class RobotClient:
@@ -118,10 +117,10 @@ class RobotClient:
self.channel = grpc.insecure_channel(
self.server_address, grpc_channel_options(initial_backoff=f"{config.environment_dt:.4f}s")
)
self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
self.stub = services_pb2_grpc.AsyncInferenceStub(self.channel)
self.logger.info(f"Initializing client to connect to server at {self.server_address}")
self._running_event = threading.Event()
self.shutdown_event = threading.Event()
# Initialize client side variables
self.latest_action_lock = threading.Lock()
@@ -146,20 +145,20 @@ class RobotClient:
@property
def running(self):
return self._running_event.is_set()
return not self.shutdown_event.is_set()
def start(self):
"""Start the robot client and connect to the policy server"""
try:
# client-server handshake
start_time = time.perf_counter()
self.stub.Ready(async_inference_pb2.Empty())
self.stub.Ready(services_pb2.Empty())
end_time = time.perf_counter()
self.logger.debug(f"Connected to policy server in {end_time - start_time:.4f}s")
# send policy instructions
policy_config_bytes = pickle.dumps(self.policy_config)
policy_setup = async_inference_pb2.PolicySetup(data=policy_config_bytes)
policy_setup = services_pb2.PolicySetup(data=policy_config_bytes)
self.logger.info("Sending policy instructions to policy server")
self.logger.debug(
@@ -170,7 +169,7 @@ class RobotClient:
self.stub.SendPolicyInstructions(policy_setup)
self._running_event.set()
self.shutdown_event.clear()
return True
@@ -180,7 +179,7 @@ class RobotClient:
def stop(self):
"""Stop the robot client"""
self._running_event.clear()
self.shutdown_event.set()
self.robot.disconnect()
self.logger.debug("Robot disconnected")
@@ -208,7 +207,7 @@ class RobotClient:
try:
observation_iterator = send_bytes_in_chunks(
observation_bytes,
async_inference_pb2.Observation,
services_pb2.Observation,
log_prefix="[CLIENT] Observation",
silent=True,
)
@@ -283,7 +282,7 @@ class RobotClient:
while self.running:
try:
# Use StreamActions to get a stream of actions from the server
actions_chunk = self.stub.GetActions(async_inference_pb2.Empty())
actions_chunk = self.stub.GetActions(services_pb2.Empty())
if len(actions_chunk.data) == 0:
continue # received `Empty` from server, wait for next call
@@ -295,8 +295,8 @@ class GamepadController(InputController):
try:
# Read joystick axes
# Left stick X and Y (typically axes 0 and 1)
y_input = self.joystick.get_axis(0) # Left/Right
x_input = self.joystick.get_axis(1) # Up/Down (often inverted)
x_input = self.joystick.get_axis(0) # Left/Right
y_input = self.joystick.get_axis(1) # Up/Down (often inverted)
# Right stick Y (typically axis 3 or 4)
z_input = self.joystick.get_axis(3) # Up/Down for Z
@@ -308,7 +308,7 @@ class GamepadController(InputController):
# Calculate deltas (note: may need to invert axes depending on controller)
delta_x = -x_input * self.x_step_size # Forward/backward
delta_y = -y_input * self.y_step_size # Left/right
delta_y = y_input * self.y_step_size # Left/right
delta_z = -z_input * self.z_step_size # Up/down
return delta_x, delta_y, delta_z
@@ -1,59 +0,0 @@
// fmt: off
// flake8: noqa
// !/usr/bin/env python
// Copyright 2024 The HuggingFace Inc. team.
// All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package async_inference;
// AsyncInference: from Robot perspective
// Robot send observations to & executes action received from a remote Policy server
service AsyncInference {
// Robot -> Policy to share observations with a remote inference server
// Policy -> Robot to share actions predicted for given observations
rpc SendObservations(stream Observation) returns (Empty);
rpc GetActions(Empty) returns (Actions);
rpc SendPolicyInstructions(PolicySetup) returns (Empty);
rpc Ready(Empty) returns (Empty);
rpc Stop(Empty) returns (Empty);
}
enum TransferState {
TRANSFER_UNKNOWN = 0;
TRANSFER_BEGIN = 1;
TRANSFER_MIDDLE = 2;
TRANSFER_END = 3;
}
// Messages
message Observation {
// sent by Robot, to remote Policy
TransferState transfer_state = 1; // Observations can be streamed exceeding 4MB of size
bytes data = 2;
}
message Actions {
// sent by remote Policy, to Robot
bytes data = 1;
}
message PolicySetup {
// sent by Robot to remote server, to init Policy
bytes data = 1;
}
message Empty {}
@@ -1,45 +0,0 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: async_inference.proto
# Protobuf Python Version: 5.29.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
29,
0,
'',
'async_inference.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x61sync_inference.proto\x12\x0f\x61sync_inference\"S\n\x0bObservation\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x17\n\x07\x41\x63tions\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x1b\n\x0bPolicySetup\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xdd\x02\n\x0e\x41syncInference\x12J\n\x10SendObservations\x12\x1c.async_inference.Observation\x1a\x16.async_inference.Empty(\x01\x12>\n\nGetActions\x12\x16.async_inference.Empty\x1a\x18.async_inference.Actions\x12N\n\x16SendPolicyInstructions\x12\x1c.async_inference.PolicySetup\x1a\x16.async_inference.Empty\x12\x37\n\x05Ready\x12\x16.async_inference.Empty\x1a\x16.async_inference.Empty\x12\x36\n\x04Stop\x12\x16.async_inference.Empty\x1a\x16.async_inference.Emptyb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'async_inference_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_TRANSFERSTATE']._serialized_start=190
_globals['_TRANSFERSTATE']._serialized_end=286
_globals['_OBSERVATION']._serialized_start=42
_globals['_OBSERVATION']._serialized_end=125
_globals['_ACTIONS']._serialized_start=127
_globals['_ACTIONS']._serialized_end=150
_globals['_POLICYSETUP']._serialized_start=152
_globals['_POLICYSETUP']._serialized_end=179
_globals['_EMPTY']._serialized_start=181
_globals['_EMPTY']._serialized_end=188
_globals['_ASYNCINFERENCE']._serialized_start=289
_globals['_ASYNCINFERENCE']._serialized_end=638
# @@protoc_insertion_point(module_scope)
@@ -1,277 +0,0 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
from lerobot.transport import async_inference_pb2 as async__inference__pb2
GRPC_GENERATED_VERSION = '1.71.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in async_inference_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
)
class AsyncInferenceStub:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.SendObservations = channel.stream_unary(
'/async_inference.AsyncInference/SendObservations',
request_serializer=async__inference__pb2.Observation.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
self.GetActions = channel.unary_unary(
'/async_inference.AsyncInference/GetActions',
request_serializer=async__inference__pb2.Empty.SerializeToString,
response_deserializer=async__inference__pb2.Actions.FromString,
_registered_method=True)
self.SendPolicyInstructions = channel.unary_unary(
'/async_inference.AsyncInference/SendPolicyInstructions',
request_serializer=async__inference__pb2.PolicySetup.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
self.Ready = channel.unary_unary(
'/async_inference.AsyncInference/Ready',
request_serializer=async__inference__pb2.Empty.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
self.Stop = channel.unary_unary(
'/async_inference.AsyncInference/Stop',
request_serializer=async__inference__pb2.Empty.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
class AsyncInferenceServicer:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
def SendObservations(self, request_iterator, context):
"""Robot -> Policy to share observations with a remote inference server
Policy -> Robot to share actions predicted for given observations
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def GetActions(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendPolicyInstructions(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Ready(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Stop(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_AsyncInferenceServicer_to_server(servicer, server):
rpc_method_handlers = {
'SendObservations': grpc.stream_unary_rpc_method_handler(
servicer.SendObservations,
request_deserializer=async__inference__pb2.Observation.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
'GetActions': grpc.unary_unary_rpc_method_handler(
servicer.GetActions,
request_deserializer=async__inference__pb2.Empty.FromString,
response_serializer=async__inference__pb2.Actions.SerializeToString,
),
'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler(
servicer.SendPolicyInstructions,
request_deserializer=async__inference__pb2.PolicySetup.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
'Ready': grpc.unary_unary_rpc_method_handler(
servicer.Ready,
request_deserializer=async__inference__pb2.Empty.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
'Stop': grpc.unary_unary_rpc_method_handler(
servicer.Stop,
request_deserializer=async__inference__pb2.Empty.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'async_inference.AsyncInference', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('async_inference.AsyncInference', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class AsyncInference:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
@staticmethod
def SendObservations(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(
request_iterator,
target,
'/async_inference.AsyncInference/SendObservations',
async__inference__pb2.Observation.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def GetActions(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/GetActions',
async__inference__pb2.Empty.SerializeToString,
async__inference__pb2.Actions.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SendPolicyInstructions(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/SendPolicyInstructions',
async__inference__pb2.PolicySetup.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def Ready(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/Ready',
async__inference__pb2.Empty.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def Stop(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/Stop',
async__inference__pb2.Empty.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
+28
View File
@@ -33,6 +33,17 @@ service LearnerService {
rpc Ready(Empty) returns (Empty);
}
// AsyncInference: from Robot perspective
// Robot send observations to & executes action received from a remote Policy server
service AsyncInference {
// Robot -> Policy to share observations with a remote inference server
// Policy -> Robot to share actions predicted for given observations
rpc SendObservations(stream Observation) returns (Empty);
rpc GetActions(Empty) returns (Actions);
rpc SendPolicyInstructions(PolicySetup) returns (Empty);
rpc Ready(Empty) returns (Empty);
}
enum TransferState {
TRANSFER_UNKNOWN = 0;
TRANSFER_BEGIN = 1;
@@ -56,4 +67,21 @@ message InteractionMessage {
bytes data = 2;
}
// Messages
message Observation {
// sent by Robot, to remote Policy
TransferState transfer_state = 1; // Observations can be streamed exceeding 4MB of size
bytes data = 2;
}
message Actions {
// sent by remote Policy, to Robot
bytes data = 1;
}
message PolicySetup {
// sent by Robot to remote server, to init Policy
bytes data = 1;
}
message Empty {}
+18 -10
View File
@@ -1,7 +1,7 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: lerobot/transport/services.proto
# Protobuf Python Version: 5.29.0
# Protobuf Python Version: 6.31.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
@@ -10,8 +10,8 @@ from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
29,
6,
31,
0,
'',
'lerobot/transport/services.proto'
@@ -23,23 +23,31 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"M\n\x0bObservation\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x17\n\x07\x41\x63tions\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x1b\n\x0bPolicySetup\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Empty2\xf5\x01\n\x0e\x41syncInference\x12>\n\x10SendObservations\x12\x16.transport.Observation\x1a\x10.transport.Empty(\x01\x12\x32\n\nGetActions\x12\x10.transport.Empty\x1a\x12.transport.Actions\x12\x42\n\x16SendPolicyInstructions\x12\x16.transport.PolicySetup\x1a\x10.transport.Empty\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'lerobot.transport.services_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_TRANSFERSTATE']._serialized_start=298
_globals['_TRANSFERSTATE']._serialized_end=394
_globals['_TRANSFERSTATE']._serialized_start=431
_globals['_TRANSFERSTATE']._serialized_end=527
_globals['_TRANSITION']._serialized_start=47
_globals['_TRANSITION']._serialized_end=123
_globals['_PARAMETERS']._serialized_start=125
_globals['_PARAMETERS']._serialized_end=201
_globals['_INTERACTIONMESSAGE']._serialized_start=203
_globals['_INTERACTIONMESSAGE']._serialized_end=287
_globals['_EMPTY']._serialized_start=289
_globals['_EMPTY']._serialized_end=296
_globals['_LEARNERSERVICE']._serialized_start=397
_globals['_LEARNERSERVICE']._serialized_end=654
_globals['_OBSERVATION']._serialized_start=289
_globals['_OBSERVATION']._serialized_end=366
_globals['_ACTIONS']._serialized_start=368
_globals['_ACTIONS']._serialized_end=391
_globals['_POLICYSETUP']._serialized_start=393
_globals['_POLICYSETUP']._serialized_end=420
_globals['_EMPTY']._serialized_start=422
_globals['_EMPTY']._serialized_end=429
_globals['_LEARNERSERVICE']._serialized_start=530
_globals['_LEARNERSERVICE']._serialized_end=787
_globals['_ASYNCINFERENCE']._serialized_start=790
_globals['_ASYNCINFERENCE']._serialized_end=1035
# @@protoc_insertion_point(module_scope)
+210 -1
View File
@@ -5,7 +5,7 @@ import warnings
from lerobot.transport import services_pb2 as lerobot_dot_transport_dot_services__pb2
GRPC_GENERATED_VERSION = '1.71.0'
GRPC_GENERATED_VERSION = '1.73.1'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
@@ -231,3 +231,212 @@ class LearnerService:
timeout,
metadata,
_registered_method=True)
class AsyncInferenceStub:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.SendObservations = channel.stream_unary(
'/transport.AsyncInference/SendObservations',
request_serializer=lerobot_dot_transport_dot_services__pb2.Observation.SerializeToString,
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
_registered_method=True)
self.GetActions = channel.unary_unary(
'/transport.AsyncInference/GetActions',
request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
response_deserializer=lerobot_dot_transport_dot_services__pb2.Actions.FromString,
_registered_method=True)
self.SendPolicyInstructions = channel.unary_unary(
'/transport.AsyncInference/SendPolicyInstructions',
request_serializer=lerobot_dot_transport_dot_services__pb2.PolicySetup.SerializeToString,
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
_registered_method=True)
self.Ready = channel.unary_unary(
'/transport.AsyncInference/Ready',
request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
_registered_method=True)
class AsyncInferenceServicer:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
def SendObservations(self, request_iterator, context):
"""Robot -> Policy to share observations with a remote inference server
Policy -> Robot to share actions predicted for given observations
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def GetActions(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendPolicyInstructions(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Ready(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_AsyncInferenceServicer_to_server(servicer, server):
rpc_method_handlers = {
'SendObservations': grpc.stream_unary_rpc_method_handler(
servicer.SendObservations,
request_deserializer=lerobot_dot_transport_dot_services__pb2.Observation.FromString,
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
),
'GetActions': grpc.unary_unary_rpc_method_handler(
servicer.GetActions,
request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
response_serializer=lerobot_dot_transport_dot_services__pb2.Actions.SerializeToString,
),
'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler(
servicer.SendPolicyInstructions,
request_deserializer=lerobot_dot_transport_dot_services__pb2.PolicySetup.FromString,
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
),
'Ready': grpc.unary_unary_rpc_method_handler(
servicer.Ready,
request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'transport.AsyncInference', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('transport.AsyncInference', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class AsyncInference:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
@staticmethod
def SendObservations(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(
request_iterator,
target,
'/transport.AsyncInference/SendObservations',
lerobot_dot_transport_dot_services__pb2.Observation.SerializeToString,
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def GetActions(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/transport.AsyncInference/GetActions',
lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
lerobot_dot_transport_dot_services__pb2.Actions.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SendPolicyInstructions(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/transport.AsyncInference/SendPolicyInstructions',
lerobot_dot_transport_dot_services__pb2.PolicySetup.SerializeToString,
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def Ready(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/transport.AsyncInference/Ready',
lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
+6 -2
View File
@@ -19,7 +19,8 @@ import io
import json
import logging
import pickle # nosec B403: Safe usage for internal serialization only
from multiprocessing import Event, Queue
from multiprocessing import Event
from queue import Queue
from typing import Any
import torch
@@ -66,7 +67,7 @@ def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = ""
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""): # type: ignore
def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: Event, log_prefix: str = ""):
bytes_buffer = io.BytesIO()
step = 0
@@ -91,7 +92,10 @@ def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_p
bytes_buffer.write(item.data)
logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
if queue is not None:
queue.put(bytes_buffer.getvalue())
else:
return bytes_buffer.getvalue()
bytes_buffer.seek(0)
bytes_buffer.truncate(0)
+4 -4
View File
@@ -54,8 +54,8 @@ def test_async_inference_e2e(monkeypatch):
from lerobot.scripts.server.policy_server import PolicyServer
from lerobot.scripts.server.robot_client import RobotClient
from lerobot.transport import (
async_inference_pb2, # type: ignore
async_inference_pb2_grpc, # type: ignore
services_pb2, # type: ignore
services_pb2_grpc, # type: ignore
)
from tests.mocks.mock_robot import MockRobotConfig
@@ -113,13 +113,13 @@ def test_async_inference_e2e(monkeypatch):
# Bypass potentially heavy model loading inside SendPolicyInstructions
def _fake_send_policy_instructions(self, request, context): # noqa: N802
return async_inference_pb2.Empty()
return services_pb2.Empty()
monkeypatch.setattr(PolicyServer, "SendPolicyInstructions", _fake_send_policy_instructions, raising=True)
# Build gRPC server running a PolicyServer
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="policy_server"))
async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
# Use the host/port specified in the fixture's config
server_address = f"{policy_server.config.host}:{policy_server.config.port}"
+1 -1
View File
@@ -13,7 +13,7 @@
# limitations under the License.
"""Unit-tests for the `RobotClient` action-queue logic (pure Python, no gRPC).
We monkey-patch `lerobot.common.robot_devices.robots.utils.make_robot` so that
We monkey-patch `lerobot.robots.utils.make_robot_from_config` so that
no real hardware is accessed. Only the queue-update mechanism is verified.
"""
+3 -1
View File
@@ -104,12 +104,14 @@ def test_read():
assert isinstance(img, np.ndarray)
# TODO(Steven): Fix this test for the latest version of pyrealsense2.
@pytest.mark.skip("Skipping test: pyrealsense2 version > 2.55.1.6486")
def test_read_depth():
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, use_depth=True)
camera = RealSenseCamera(config)
camera.connect(warmup=False)
img = camera.read_depth(timeout_ms=1000) # NOTE(Steven): Reading depth takes longer
img = camera.read_depth(timeout_ms=2000) # NOTE(Steven): Reading depth takes longer in CI environments.
assert isinstance(img, np.ndarray)