mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-16 07:49:48 +00:00
Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b968020ec4 | |||
| fc019d3902 | |||
| 87242cfced | |||
| 1edc83a0ef | |||
| 6fbcf67249 | |||
| 41166b39fb | |||
| 79c6821407 | |||
| 507083249f | |||
| bd22407d93 | |||
| 49755a3d9e | |||
| 09808183ca |
Binary file not shown.
|
After Width: | Height: | Size: 445 KiB |
@@ -58,7 +58,7 @@ action = model.select_action(obs)
|
||||
robot.send_action(action)
|
||||
```
|
||||
|
||||
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1.
|
||||
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1, reBot B601.
|
||||
|
||||
While these devices are natively integrated into the LeRobot codebase, the library is designed to be extensible. You can easily implement the Robot interface to utilize LeRobot's data collection, training, and visualization tools for your own custom robot.
|
||||
|
||||
@@ -101,11 +101,13 @@ lerobot-train \
|
||||
--dataset.repo_id=lerobot/aloha_mobile_cabinet
|
||||
```
|
||||
|
||||
| Category | Models |
|
||||
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
|
||||
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
||||
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.7](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
|
||||
| Category | Models |
|
||||
| -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
|
||||
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
||||
| **VLAs Models** | [Pi0](./docs/source/pi0.mdx), [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx), [EO-1](./docs/source/eo1.mdx), [MolmoAct2](./docs/source/molmoact2.mdx), [WALL-OSS](./docs/source/walloss.mdx) |
|
||||
| **World Models** | [VLA-JEPA](./docs/source/vla_jepa.mdx) (more coming soon) |
|
||||
| **Reward Models** | [SARM](./docs/source/sarm.mdx), [TOPReward](./docs/source/topreward.mdx), [Robometer](./docs/source/robometer.mdx) |
|
||||
|
||||
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
|
||||
|
||||
@@ -133,6 +135,7 @@ Learn how to implement your own simulation environment or benchmark and distribu
|
||||
- **[Discord](https://discord.gg/q8Dzzpym3f):** Join the `LeRobot` server to discuss with the community.
|
||||
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
|
||||
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
|
||||
- **[T-Shirt Folding Experiment](https://huggingface.co/spaces/lerobot/robot-folding):** An end-to-end demonstration of folding t-shirts with LeRobot.
|
||||
|
||||
## Citation
|
||||
|
||||
@@ -140,7 +143,7 @@ If you use LeRobot in your project, please cite the GitHub repository to acknowl
|
||||
|
||||
```bibtex
|
||||
@misc{cadene2024lerobot,
|
||||
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
|
||||
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Meftah, Khalil and Ellerbach, Maxime and Moss, Jess and Wolf, Thomas},
|
||||
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
|
||||
howpublished = "\url{https://github.com/huggingface/lerobot}",
|
||||
year = {2024}
|
||||
|
||||
@@ -0,0 +1,417 @@
|
||||
# Decoupled VLA Inference & Edge Control: System Design Proposal
|
||||
|
||||
## 1. Executive Summary
|
||||
|
||||
This document proposes a production-grade system for decoupling GPU-bound VLA (Vision-Language-Action) policy inference from high-frequency, CPU-bound robot control in LeRobot. The system adopts a **Model-as-a-Service (MaaS)** paradigm using **Zenoh** as the sole transport protocol, enabling multiple edge devices to be served by centralized GPU servers with minimal latency and high reliability.
|
||||
|
||||
An initial prototype exists in `src/lerobot/async_inference/` (gRPC-based, single-client). This proposal defines the target architecture, identifies gaps between the prototype and production requirements, documents known bugs, and establishes the design for the new system.
|
||||
|
||||
---
|
||||
|
||||
## 2. Motivation
|
||||
|
||||
LeRobot's standard control loop runs policy inference and robot I/O in the same process. This works for lightweight policies on local GPUs, but breaks down when:
|
||||
|
||||
- **The policy is too large for edge hardware** (e.g., Pi0 at ~3B parameters requires a dedicated GPU).
|
||||
- **Multiple robots need the same policy** (redundant GPU allocation per robot).
|
||||
- **Inference latency exceeds the control deadline** (e.g., 200ms inference on a 33ms control loop at 30 FPS).
|
||||
|
||||
Decoupling inference from control solves all three: the edge device runs a tight I/O loop on a CPU, while a GPU server handles inference for one or more clients.
|
||||
|
||||
---
|
||||
|
||||
## 3. Core Architectural Principles
|
||||
|
||||
### 3.1 Model-as-a-Service (MaaS)
|
||||
|
||||
Servers initialize models **once at startup** from a configuration manifest. Edge devices do **not** trigger dynamic model loading — they route to pre-warmed servers and validate compatibility via a status endpoint.
|
||||
|
||||
### 3.2 Multi-Tenant & Stateless Inference
|
||||
|
||||
A single GPU server handles multiple edge devices executing the same task. The server is stateless per inference call — `predict_action_chunk()` is a pure function with no side effects on the model. Client isolation is achieved through per-client observation slots and Zenoh key-expression routing.
|
||||
|
||||
> **Invariant**: `predict_action_chunk()` must remain a pure function (no mutation of `self`) for all supported policies. This is what enables safe multi-tenant sharing of a single model instance. This invariant must be documented and tested.
|
||||
|
||||
### 3.3 Zenoh as primary Transport
|
||||
|
||||
The system uses Zenoh's pub/sub model, replacing the current gRPC implementation. Zenoh provides:
|
||||
|
||||
- **Hierarchical key expressions** for routing (natural fit for the cluster/experiment/model/task topology).
|
||||
- **Built-in discovery** (no external service discovery needed).
|
||||
- **Non-blocking publish** for observations (fire-and-forget with best-effort QoS).
|
||||
- **Reliable delivery** configurable per-topic (required for action chunks).
|
||||
- **Shared-memory transport** for same-machine deployments (zero-copy) (if available).
|
||||
|
||||
### 3.4 Local Edge CPU
|
||||
|
||||
Edge devices rely on standard CPUs for sensor polling, image compression, payload serialization, motor control, and data logging. No edge-GPU dependency.
|
||||
|
||||
---
|
||||
|
||||
## 4. System Topology
|
||||
|
||||

|
||||
|
||||
- **Cluster**: A set of GPU machines. Identified by `cluster_uuid`.
|
||||
- **Experiment**: A logical grouping of servers and clients. Identified by `experiment_tag`.
|
||||
- **Server**: One model + one task, pre-warmed. Serves N clients for that model/task combination.
|
||||
- **Client**: One robot, one task. Publishes observations, subscribes to actions.
|
||||
|
||||
The number of clients a single server can handle is a **user decision** based on model inference time and acceptable latency.
|
||||
|
||||
---
|
||||
|
||||
## 5. Component Specifications
|
||||
|
||||
### 5.1 The Edge Device (Client)
|
||||
|
||||
**Responsibilities:**
|
||||
|
||||
1. **Observation capture**: Read sensors (cameras, motors) at the control loop frequency.
|
||||
2. **Image compression**: JPEG-encode RGB images before transmission.
|
||||
3. **Observation publishing**: Non-blocking Zenoh put to the observation topic.
|
||||
4. **Action subscription**: Zenoh callback receives action chunks, deposits into local buffer.
|
||||
5. **Action execution**: Pop actions from buffer, send to robot at control frequency.
|
||||
6. **Action blending**: When a new action chunk overlaps with the current buffer, blend via configurable aggregation function (weighted average, latest-only, etc.).
|
||||
7. **Latency compensation**: Calculate one-way latency from RTT, discard expired initial steps of incoming action chunks.
|
||||
8. **Fail-safe**: If action buffer empties, logs a warning.
|
||||
9. **Data logging**: Record raw observations and executed actions to local `LeRobotDataset` storage for deferred upload.
|
||||
|
||||
**Threading model:**
|
||||
|
||||
- **Control loop thread** (main): Capture observation → deposit in outbox → pop action from buffer → send to robot → sleep to maintain frequency.
|
||||
- **Zenoh action callback** (Zenoh-managed): Receives action chunks, processes RTT, trims stale steps, deposits into action buffer.
|
||||
- **Observation publisher thread**: Drains the outbox, compresses images, serializes, publishes via Zenoh.
|
||||
|
||||
> **Design note**: The current prototype blocks on `send_observation` inside the control loop (BUG-1, see Section 9). The new design decouples observation publishing from the control loop entirely, using a separate thread and Zenoh's non-blocking put.
|
||||
|
||||
### 5.2 The Inference Server (GPU Pod)
|
||||
|
||||
**Responsibilities:**
|
||||
|
||||
1. **Model pre-warming**: Load model and processor pipelines at startup from config manifest (including expected clients & policy parameters).
|
||||
2. **Status publishing**: Expose model capabilities (policy type, expected camera names, resolutions, action dimensions) via Zenoh queryable.
|
||||
3. **Observation subscription**: Subscribe to observation topics for all clients of this model/task. Maintain per-client observation slots (newest-only semantics).
|
||||
4. **Inference**: Single inference thread processes observations sequentially (round-robin across clients). Calls `policy.predict_action_chunk()`.
|
||||
5. **Action publishing**: Publish action chunks to per-client action topics with reliable QoS.
|
||||
|
||||
> **Thread safety**: PyTorch's `model.forward()` is not guaranteed thread-safe. Inference will be sequential, latency is mostly about the capabilities of the server to serve multiple requests.
|
||||
|
||||
---
|
||||
|
||||
## 6. Zenoh Routing & Key Expressions
|
||||
|
||||
### 6.1 Key Expression Schema
|
||||
|
||||
```
|
||||
[cluster_uuid] / [experiment_tag] / [model_id] / [model_version] / [application_tag] / [client_uuid] / [topic]
|
||||
```
|
||||
|
||||
**Example key expressions:**
|
||||
|
||||
| Key Expression | Direction | Purpose |
|
||||
| ------------------------------------------------ | ----------------- | ---------------------------------- |
|
||||
| `jupiter/fabio2/pi0/v1/cookie/robot_a4b9/obs` | Client → Server | Observation payload |
|
||||
| `jupiter/fabio2/pi0/v1/cookie/robot_a4b9/action` | Server → Client | Action chunk |
|
||||
| `jupiter/fabio2/pi0/v1/cookie/*/obs` | Server subscribes | All observations for pi0/v1/cookie |
|
||||
| `jupiter/fabio2/pi0/v1/cookie/status` | Server publishes | Model capabilities (queryable) |
|
||||
|
||||
### 6.2 QoS Configuration
|
||||
|
||||
| Topic | Reliability | Rationale |
|
||||
| -------- | ----------- | -------------------------------------------------------------------- |
|
||||
| `obs` | Best-effort | Dropping stale observations is expected behavior. |
|
||||
| `action` | Reliable | Every action chunk must be delivered; loss causes action starvation. |
|
||||
| `status` | Reliable | Client needs accurate capability info before starting. |
|
||||
|
||||
### 6.3 Discovery Flow
|
||||
|
||||
0. Server goes up with the static configuration.
|
||||
1. Client constructs its target key prefix: `cluster/experiment/model/version/task/`.
|
||||
2. Client queries `cluster/experiment/model/version/task/status` (Zenoh queryable).
|
||||
3. Server responds with its capabilities (expected camera names, image resolutions, action dimensions, model metadata).
|
||||
4. Client validates its own configuration against server capabilities.
|
||||
5. On match: client starts publishing observations and subscribing to actions.
|
||||
6. On mismatch: client logs an error and refuses to start.
|
||||
|
||||
No dynamic client discovery for now.
|
||||
|
||||
---
|
||||
|
||||
## 7. Message Schema
|
||||
|
||||
### 7.1 Observation Payload (Client → Server)
|
||||
|
||||
| Field | Type | Purpose |
|
||||
| ------------- | ------------------ | ----------------------------------------------------------- |
|
||||
| `seq_id` | `uint64` | Incrementing ID for causality tracking and RTT computation. |
|
||||
| `client_uuid` | `string` | Identifies the sending client. |
|
||||
| `state` | `bytes` | Proprioceptive state vector (`numpy.tobytes()`). |
|
||||
| `images` | `dict[str, bytes]` | JPEG-compressed camera images, keyed by camera name. |
|
||||
| `task` | `string` | Natural-language task instruction (for VLA conditioning). |
|
||||
|
||||
### 7.2 Action Payload (Server → Client)
|
||||
|
||||
| Field | Type | Purpose |
|
||||
| -------------------- | --------- | --------------------------------------------------------------- |
|
||||
| `response_to_seq_id` | `uint64` | Echoes the observation `seq_id` this action corresponds to. |
|
||||
| `inference_time_ms` | `float32` | Server-side compute duration (for edge RTT math). |
|
||||
| `actions` | `bytes` | Action chunk as numpy array bytes (`(chunk_size, action_dim)`). |
|
||||
|
||||
### 7.3 Status Payload (Server, Queryable)
|
||||
|
||||
| Field | Type | Purpose |
|
||||
| ----------------------- | ------------------- | ------------------------------------------ |
|
||||
| `model_id` | `string` | Policy identifier (e.g., `pi0`). |
|
||||
| `model_version` | `string` | Model version or checkpoint path. |
|
||||
| `expected_cameras` | `dict[str, (H, W)]` | Expected camera names and shapes. |
|
||||
| `action_dim` | `int` | Dimensionality of the action space. |
|
||||
| `max_actions_per_chunk` | `int` | Maximum chunk size the model supports. |
|
||||
| `observation_features` | `dict` | Full feature specification for validation. |
|
||||
|
||||
### 7.4 Serialization Format
|
||||
|
||||
**MessagePack** for all structured metadata (compact, fast, cross-language). Image payloads are raw JPEG bytes embedded in the MessagePack structure. State vectors use `numpy.tobytes()` with shape/dtype metadata for zero-copy reconstruction.
|
||||
|
||||
**No pickle.** The current prototype uses `pickle.dumps`/`pickle.loads` throughout, which allows arbitrary code execution. This is replaced entirely.
|
||||
|
||||
---
|
||||
|
||||
## 8. Latency Compensation
|
||||
|
||||
### 8.1 RTT Calculation
|
||||
|
||||
The edge device tracks in-flight observations:
|
||||
|
||||
```python
|
||||
in_flight: dict[int, float] = {} # seq_id -> time.perf_counter() at send
|
||||
|
||||
# On send:
|
||||
in_flight[seq_id] = time.perf_counter()
|
||||
|
||||
# On receive action chunk:
|
||||
rtt = time.perf_counter() - in_flight[response_to_seq_id]
|
||||
# delete older keys than the one received
|
||||
```
|
||||
|
||||
> **Important**: Delete only the exact `response_to_seq_id` key from `in_flight`, not all keys `<= response_to_seq_id`. With Zenoh's best-effort transport, messages can arrive out of order. Clearing earlier keys would make their RTT unmeasurable.
|
||||
|
||||
### 8.2 Stale Action Trimming
|
||||
|
||||
When an action chunk arrives, the edge calculates how many initial steps have already expired:
|
||||
|
||||
```python
|
||||
expired_steps = int(rtt / environment_dt)
|
||||
valid_actions = action_chunk[expired_steps:]
|
||||
```
|
||||
|
||||
The valid actions are then blended into the action buffer using the configured aggregation function.
|
||||
|
||||
### 8.3 Edge Cases
|
||||
|
||||
| Scenario | Behavior |
|
||||
| -------------------------------------- | -------------------------------------------------------------------------------------- |
|
||||
| **First observation** (no RTT history) | Apply all action steps without trimming. |
|
||||
| **Dropped observations** | Server infers on next received observation. No special handling needed. |
|
||||
| **Dropped action chunks** | Edge continues executing current buffer. If buffer empties, warn & hold last position. |
|
||||
| **Server crash** | Edge exhausts buffer, holds position, warns & re-validates via status query. |
|
||||
|
||||
> **Assumption**: All currently supported robots are position-controlled (SO100, SO101, OMX). For velocity-controlled robots, the fail-safe must send zero-velocity instead of holding position. This should be configurable per-robot.
|
||||
|
||||
---
|
||||
|
||||
## 9. Known Bugs in Current Prototype
|
||||
|
||||
These issues exist in `src/lerobot/async_inference/` and must be addressed in the new implementation.
|
||||
|
||||
### BUG-1: `send_observation` Blocks the Control Loop (Critical)
|
||||
|
||||
**Location**: `robot_client.py:207`
|
||||
|
||||
`self.stub.SendObservations(observation_iterator)` is a synchronous gRPC call inside the 33ms control loop. For multi-camera observations (several MB after pickle), this consumes 10-20ms on the network, leaving no headroom for sensor capture and motor commands. The robot stutters.
|
||||
|
||||
**Resolution in new design**: Observation publishing is moved to a dedicated thread. Zenoh's `session.put()` is non-blocking by default. The control loop only deposits observations into a local outbox.
|
||||
|
||||
### BUG-2: Race Condition in Action Queue Aggregation (Correctness)
|
||||
|
||||
**Location**: `robot_client.py:236-267`
|
||||
|
||||
The lock on `self.action_queue` is acquired to read `internal_queue = self.action_queue.queue` (a reference to the internal deque), then **released** at line 238. The aggregation logic iterates over this reference outside the lock. Meanwhile, the control loop thread can `get_nowait()` from the same queue, mutating the deque during iteration. At line 267, the entire queue is replaced, but actions popped between 238-267 are silently lost.
|
||||
|
||||
**Fix**: Either hold the lock for the entire aggregation, or `list(self.action_queue.queue)` to copy contents before releasing.
|
||||
|
||||
### BUG-3: No RPC Deadlines (Reliability)
|
||||
|
||||
**Location**: `robot_client.py:278`
|
||||
|
||||
`GetActions` blocks indefinitely if the server hangs (GPU OOM, deadlock). The retry policy handles `UNAVAILABLE` but not a hung connection.
|
||||
|
||||
**Resolution in new design**: The polling `GetActions` pattern is replaced by Zenoh subscription callbacks. The client needs a watchdog timer or check when action queue is empty: if no actions are received for `T` seconds, trigger re-validation via the status service.
|
||||
|
||||
### BUG-4: Similarity Check Ignores Images (Correctness for VLAs)
|
||||
|
||||
**Location**: `helpers.py:280-297`
|
||||
|
||||
`observations_similar()` + `must_go` is a workaround for current architecure limitations to avoid filling up the server queue the first seconds of the task & the robot remaining idle.
|
||||
|
||||
**Resolution in new design**: the server always processes the latest observation per client in its inference loop, and doesn't need similarity gating at all. The client can always push.
|
||||
|
||||
---
|
||||
|
||||
## 10. Gaps Between Prototype and Target Architecture
|
||||
|
||||
### 10.1 Critical (Must Address)
|
||||
|
||||
| # | Gap | Current State | Target State |
|
||||
| --- | ------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| G1 | **Single-client server** | One `observation_queue(maxsize=1)`, one `last_processed_obs`, one `_predicted_timesteps`. `_reset_server()` flushes all state on any new connection. | Per-client state (`ClientState` dataclass) keyed by `client_uuid`. Zenoh key-expression routing provides client isolation. |
|
||||
| G2 | **Dynamic model loading** | Client sends `RemotePolicyConfig` → server calls `from_pretrained()` on demand. | Server loads models at startup from config manifest. `SendPolicyInstructions` RPC eliminated. Client validates via status query. |
|
||||
| G3 | **gRPC transport** | Entire `transport/` directory: proto definitions, generated stubs, chunking utils. 4 RPCs: `Ready`, `SendPolicyInstructions`, `SendObservations`, `GetActions`. | Zenoh pub/sub. Client publishes obs, subscribes to actions. Server subscribes to obs, publishes actions. Dispatching via key expressions. |
|
||||
| G4 | **Pickle serialization** | `pickle.dumps`/`pickle.loads` throughout (arbitrary code execution risk, `# nosec` suppression). | MessagePack for structured metadata + raw JPEG bytes for images + `numpy.tobytes()` for state vectors. |
|
||||
|
||||
### 10.2 Important
|
||||
|
||||
| # | Gap | Current State | Target State |
|
||||
| --- | -------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| G5 | **No RTT/latency compensation** | No `seq_id`, no `response_to_seq_id`, no `inference_time_ms`. Timestamps use `time.time()` (unreliable across machines). | Edge-local `perf_counter` + echoed `seq_id` + server inference duration. Stale action step trimming. |
|
||||
| G6 | **No hierarchical routing** | Direct gRPC channel to `host:port`. | Zenoh key expressions: `cluster/experiment/model/version/task/client/topic`. |
|
||||
| G7 | **No data logging** | `control_loop` has access to obs and actions but doesn't persist them. | Edge records via `LeRobotDataset` (`build_dataset_frame` + `dataset.add_frame`). |
|
||||
| G8 | **No authentication** | `grpc.insecure_channel`. | Zenoh TLS + access control lists on key expressions. |
|
||||
| G9 | **ProcessorPipeline divergence** | Server reimplements observation prep in `helpers.py` (custom `resize_robot_observation_image` with `F.interpolate` bilinear). Diverges from standard `RobotProcessorPipeline`. | Use the standard `RobotProcessorPipeline` + `build_dataset_frame` to ensure behavioral equivalence between record and async inference. |
|
||||
|
||||
### 10.3 Nice-to-Have
|
||||
|
||||
| # | Gap | Current State | Target State |
|
||||
| --- | ------------------------------------- | --------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| G11 | **No status/discovery service** | Bare `Ready()` ping. | Zenoh queryable at `cluster/exp/model/version/task/status`. |
|
||||
| G12 | **No monitoring** | `FPSTracker` + `logging.debug`. | Structured metrics via Zenoh telemetry topics. Wildcard subscriptions for centralized monitoring. |
|
||||
| G13 | **No entry points** | Module-level `__main__`. | `lerobot-policy-server` and `lerobot-robot-client` console scripts in `pyproject.toml`. |
|
||||
| G14 | **Ratio-based observation threshold** | `chunk_size_threshold` (0-1 ratio of queue fill). Scales oddly with different `actions_per_chunk` values. | Absolute time threshold: `buffer_time_s` calibrated to observed RTT. Send observation when `queue_size * environment_dt < buffer_time_s`. |
|
||||
|
||||
---
|
||||
|
||||
## 11. Design Decisions & Rationale
|
||||
|
||||
### 11.1 Why Zenoh Over gRPC
|
||||
|
||||
| Aspect | Zenoh | gRPC |
|
||||
| ------------------------- | -------------------------------------------------------------------------- | ---------------------------------------------------------------------------------- |
|
||||
| Communication model | Pub/sub — natural fit for "client publishes obs, server publishes actions" | Request/response — requires polling (`GetActions` loop) or bidirectional streaming |
|
||||
| Multi-tenant routing | Hierarchical key expressions provide built-in per-client topic isolation | Requires manual per-client channel/stream management |
|
||||
| Discovery | Built-in discovery | Requires external service (mDNS, Consul, etc.) |
|
||||
| Observation publishing | Non-blocking put (fire-and-forget) — resolves BUG-1 automatically | Synchronous stream-unary call — blocks the control loop |
|
||||
| Same-machine optimization | Shared-memory transport (zero-copy) | Loopback TCP |
|
||||
| Telemetry | Wildcard subscriptions (`+/+/+/+/+/metrics`) | Requires separate monitoring infrastructure |
|
||||
|
||||
**Tradeoffs of going Zenoh-only:**
|
||||
|
||||
- Smaller community, less tooling for monitoring/tracing vs. gRPC's mature ecosystem.
|
||||
- No built-in schema enforcement (Zenoh sends raw bytes) — serialization correctness is entirely on us.
|
||||
- Default QoS is best-effort (like UDP). Must explicitly configure reliable delivery for action chunks.
|
||||
- `zenoh-python` bindings are less battle-tested than `grpcio`. Needs integration testing under network stress.
|
||||
|
||||
### 11.2 Why Single Inference Thread (Not Batching)
|
||||
|
||||
True GPU batching across clients requires collecting observations from multiple clients and running a single forward pass. This is difficult because:
|
||||
|
||||
- Clients send observations at different times — waiting to batch adds latency.
|
||||
- Different clients may have slightly different image resolutions.
|
||||
- Error in one client's observation shouldn't affect others.
|
||||
|
||||
**Decision**: Start with sequential processing (single inference thread, round-robin across clients). Profile GPU utilization.
|
||||
|
||||
### 11.4 Why MessagePack (Not Protobuf, Not FlatBuffers)
|
||||
|
||||
- **Protobuf**: Strong schema enforcement but heavier toolchain (proto compilation, generated code). Since we're dropping gRPC, the protobuf dependency becomes unnecessary overhead.
|
||||
- **MessagePack**: Fast, compact, schema-less (enforced by application), excellent Python support (`msgpack` package), good for nested dicts with mixed types. Natural fit for observation/action payloads.
|
||||
|
||||
Images are embedded as raw JPEG bytes within the MessagePack structure. State vectors use `numpy.tobytes()` with shape/dtype metadata for zero-copy reconstruction.
|
||||
|
||||
### 11.5 Action Aggregation Strategy
|
||||
|
||||
When a new action chunk overlaps with the existing buffer, the overlapping timesteps must be blended. The current prototype supports configurable aggregation functions:
|
||||
|
||||
| Function | Formula | Character |
|
||||
| ------------------ | ----------------------- | ------------------------------------------ |
|
||||
| `weighted_average` | `0.3 * old + 0.7 * new` | Smooth transitions, favors new predictions |
|
||||
| `latest_only` | `new` | Most responsive, can cause discontinuities |
|
||||
| `average` | `0.5 * old + 0.5 * new` | Equal weight |
|
||||
| `conservative` | `0.7 * old + 0.3 * new` | Smooth, slow to adapt |
|
||||
|
||||
Ultimately, this should be the user's decision. Default to `weighted_average`. The goal of async is not to do temporal ensembling, but to provide a solution when we want to decouple inference and execution.
|
||||
|
||||
---
|
||||
|
||||
## 12. Configuration
|
||||
|
||||
### 12.1 Server Configuration (Manifest)
|
||||
|
||||
Servers are configured via a YAML manifest that declares which models to pre-warm & clients to serve:
|
||||
|
||||
```yaml
|
||||
cluster_uuid: jupiter
|
||||
experiment_tag: fabio2
|
||||
server:
|
||||
- model_id: pi0
|
||||
model_version: v1
|
||||
pretrained_path: lerobot/pi0-cookie-v1
|
||||
application_tag: cookie
|
||||
device: cuda:0
|
||||
fps: 30
|
||||
endpoint: tcp/192.168.1.50:7447
|
||||
clients:
|
||||
- client_uuid: cookie-worker-4269
|
||||
```
|
||||
|
||||
### 12.2 Client Configuration
|
||||
|
||||
Clients are configured via draccus dataclass (CLI-compatible):
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class AsyncClientConfig:
|
||||
# Zenoh routing
|
||||
cluster_uuid: str
|
||||
experiment_tag: str
|
||||
model_id: str
|
||||
model_version: str
|
||||
application_tag: str
|
||||
client_uuid: str
|
||||
endpoint: str
|
||||
|
||||
# Robot
|
||||
robot: RobotConfig
|
||||
|
||||
# Control
|
||||
fps: int = 30
|
||||
actions_per_chunk: int = 50
|
||||
aggregate_fn_name: str = "weighted_average"
|
||||
jpeg_quality: int = 90
|
||||
|
||||
# Fail-safe
|
||||
max_empty_cycles_before_warning: int = 10
|
||||
|
||||
# Datset recording
|
||||
dataset_repo_id: str | None = None # None = no logging
|
||||
|
||||
# Task
|
||||
task: str = ""
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 14. Data Logging Integration
|
||||
|
||||
The client records observations and executed actions into a local `LeRobotDataset` for deferred upload to the training dataset:
|
||||
|
||||
```python
|
||||
# In control_loop, after executing an action:
|
||||
if self.dataset is not None:
|
||||
frame = build_dataset_frame(
|
||||
self.dataset.features,
|
||||
processed_observation,
|
||||
prefix=OBS_STR,
|
||||
)
|
||||
frame["action"] = executed_action_tensor
|
||||
self.dataset.add_frame(frame)
|
||||
```
|
||||
@@ -0,0 +1,498 @@
|
||||
# Decoupled VLA Inference & Edge Control v2: Async Network Inference for `lerobot-rollout`
|
||||
|
||||
> **Status**: supersedes the v1 proposal in full. v1 was written against the standalone `src/lerobot/async_inference/` prototype, before `lerobot-rollout` existed. This revision re-grounds the design in the current codebase, keeps v1's decisions that survived contact with it (marked **KEPT** throughout), reverses the ones that didn't, and adds the safety, multi-tenancy, and operations specifications v1 lacked.
|
||||
|
||||
## 1. Executive Summary
|
||||
|
||||
This document specifies a production-grade system for decoupling GPU-bound policy inference from high-frequency robot control, targeting power users running **hundreds of robots** against centralized GPU clusters. The system keeps v1's **Model-as-a-Service (MaaS)** paradigm and **Zenoh** transport, but changes the integration architecture fundamentally:
|
||||
|
||||
- **The client is not a standalone CLI.** It is `--inference.type=remote`, a new `InferenceEngine` backend inside `lerobot-rollout` (`src/lerobot/rollout/inference/`). Every rollout strategy (base, sentry, highlight, dagger, episodic) gets network inference for free — including dataset recording, DAgger pause/resume, Rerun visualization, and safe teardown.
|
||||
- **The client is weightless.** No policy weights, no policy processors on the edge. `--policy.path` resolves to a config-only `PreTrainedConfig` (no weight download) used for pre-flight validation and action ordering.
|
||||
- **The server is stateless per request.** All RTC chunk state (leftover prefixes, latency tracking, delay computation) lives client-side in the existing `ActionQueue`/`LatencyTracker` machinery — the client ships prefixes + a delay hint with each observation. A server crash loses zero control state; reconnects and horizontal scaling are trivial.
|
||||
- **Multi-tenancy is engineered, not assumed.** The real hazards are stateful processor pipelines and episode-scoped policy state — not `predict_action_chunk` purity (which holds for ACT/Pi0/Pi0.5/SmolVLA but _not_ diffusion). The server uses per-session processor instances, a chunk-stateless allowlist, and an exclusive serving mode for policies that need it.
|
||||
- **The legacy module dies.** `src/lerobot/async_inference/` (~1,900 lines, pickle-over-gRPC, single-client, four confirmed bugs) is deleted in the same PR that lands the new backend. No deprecation cycle: the module is experimental, its CLI undocumented in the main flow, and every config field has a mapped successor (§13.4).
|
||||
|
||||
---
|
||||
|
||||
## 2. Motivation (unchanged from v1) — **KEPT**
|
||||
|
||||
LeRobot's standard control loop runs policy inference and robot I/O in the same process. This breaks down when:
|
||||
|
||||
- **The policy is too large for edge hardware** (Pi0-class models need a dedicated GPU).
|
||||
- **Multiple robots need the same policy** (redundant GPU allocation per robot).
|
||||
- **Inference latency exceeds the control deadline** (e.g. 150 ms inference on a 33 ms control tick).
|
||||
|
||||
Decoupling solves all three: the edge runs a tight CPU loop; a GPU server performs inference for N clients.
|
||||
|
||||
What changed since v1: the _local_ version of this decoupling already shipped. `RTCInferenceEngine` (`src/lerobot/rollout/inference/rtc.py`) runs inference in a background thread against a thread-safe `ActionQueue` with latency-aware chunk merging. **The network system is that same architecture with the thread boundary replaced by a network boundary.** This is the design's central simplification: reuse, don't reinvent.
|
||||
|
||||
---
|
||||
|
||||
## 3. Gap Analysis: v1 Proposal vs. Modern Codebase
|
||||
|
||||
| Topic | v1 assumed | Modern reality | Verdict |
|
||||
| ----------------------------------------- | --------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------- | --------------------------------------- |
|
||||
| Client architecture | Standalone robot-client CLI (§5.1 of v1) | `InferenceEngine` ABC seam in `lerobot-rollout` (`rollout/inference/base.py`); strategies are backend-agnostic | **Superseded** — backend, not CLI |
|
||||
| Chunk blending | Configurable aggregation zoo (`weighted_average`, …) | `ActionQueue` replace-with-delay-trim (RTC) / append (non-RTC) (`policies/rtc/action_queue.py:147-217`) | **Superseded** — drop blending entirely |
|
||||
| Latency compensation | Hand-rolled RTT trim (`expired_steps = int(rtt/dt)`, v1 §8.2) | `ActionQueue.merge(..., real_delay, idx_before)` + `LatencyTracker` already do this, validated | **Superseded** |
|
||||
| Multi-tenancy invariant | "`predict_action_chunk()` pure ⇒ safe to share" | Processor state + episode-scoped policy state are the real hazards (§7) | **Incomplete** — fixed in §8.3 |
|
||||
| Data logging | Client-side `build_dataset_frame` + `add_frame` sketch (v1 §14) | Recording strategies (sentry/episodic/dagger) already log obs + executed actions | **Superseded** — free via rollout |
|
||||
| MaaS pre-warm, no dynamic loading | ✓ | Still right; legacy `SendPolicyInstructions` is a pickle/RCE + capacity-planning disaster | **KEPT** |
|
||||
| JPEG observation compression | ✓ | Still right (§10.1) | **KEPT** |
|
||||
| Status/capability validation before start | ✓ (Zenoh queryable) | Still right; extended into a hard sync-safety contract (§8.4) | **KEPT, extended** |
|
||||
| Time-based send threshold (v1 G14) | ✓ | Adopted as `buffer_time_s` | **KEPT** |
|
||||
| Zenoh pub/sub data plane | ✓ | Confirmed; QoS corrected (§6.3), control plane moved to queryables, liveliness added | **KEPT, hardened** |
|
||||
| MessagePack serialization | ✓ | Endorsed (zenoh's `ext` serializer cannot encode numpy); must be version-gated (§10.4) | **KEPT, with schema discipline** |
|
||||
| QoS table (v1 §6.2) | "obs best-effort, actions reliable" | Conflates transport reliability with congestion control; BLOCK on actions is dangerous | **Revised** (§6.3) |
|
||||
| Bugs BUG-1…BUG-4, gaps G1…G14 | Listed as work items | Every one resolved _structurally_ by this design (§13.5 mapping) | **Resolved by design** |
|
||||
|
||||
---
|
||||
|
||||
## 4. Critical Pushbacks on v1
|
||||
|
||||
Each pushback: claim → evidence → consequence for this design.
|
||||
|
||||
**P1 — A standalone client duplicates `lerobot-rollout`.**
|
||||
v1 §5.1 assigns the client: observation capture, action execution at frequency, fail-safe, data logging. Every one of those is already owned by rollout strategies and `send_next_action` (`rollout/strategies/core.py:269-304`), which tolerates `None` actions, runs the interpolator, and routes through the canonical robot processors. A standalone client re-implements loop timing, recording, DAgger UX, Rerun, and teardown safety — and then drifts. _Consequence_: the client is `RemoteInferenceEngine`, registered as `--inference.type=remote` next to `sync` and `rtc`.
|
||||
|
||||
**P2 — The aggregation-function zoo fabricates actions no policy predicted.**
|
||||
`0.3*old + 0.7*new` produces hybrid actions that exist in no policy's output distribution; the logged action becomes unexplainable (bad for the reproducibility story) and the implementation hosted a real lock-release race (BUG-2, `async_inference/robot_client.py:236-267`). RTC's prefix-conditioned chunk generation is the principled mechanism for smooth chunk transitions; plain append covers non-RTC chunking. _Consequence_: `ActionQueue` replace/append are the only two merge semantics. The zoo is deleted.
|
||||
|
||||
**P3 — "predict_action_chunk pure ⇒ multi-tenant safe" is incomplete.**
|
||||
Verified in-tree: (a) `RelativeActionsProcessorStep` caches `_last_state` at preprocess (`processor/relative_action_processor.py:131`) and the postprocessor reads it back (`:189`) — a shared pipeline across clients is a race; (b) `DiffusionPolicy.predict_action_chunk` reads `self._queues`, which only `select_action` populates (`policies/diffusion/modeling_diffusion.py:90-108`) — it is **not** chunk-stateless; (c) SAC/SARM have no `predict_action_chunk` at all. _Consequence_: per-session processor instances (mandatory), a chunk-stateless allowlist, `serving_mode: exclusive` for diffusion-family, refusal at startup for SAC/SARM, and `policy.reset()` is **never** called in shared mode (§8.3).
|
||||
|
||||
**P4 — v1 re-derives latency compensation that already exists, on top of broken clocks.**
|
||||
v1 §8 specifies an in-flight RTT dict and manual stale-step trimming. `ActionQueue.merge(original, processed, real_delay, idx_before)` already trims `real_delay` stale steps and cross-validates against actions consumed in flight (`action_queue.py:219-246`). Worse, the legacy code compares wall clocks across machines (`robot_client.py:420` stamps `time.time()` "to compare timestamps across client and server"; `policy_server.py:178` compares it) — NTP skew is the same order as the latencies being measured. _Consequence_: the **monotonic iron rule** (§11): instants never cross machines; client timestamps are opaque echoed tokens; servers report only durations. `delay_steps = ceil((rtt + inference)/dt)` is computed client-side from client-local `perf_counter` samples and shipped per request.
|
||||
|
||||
**P5 — One-in-flight per client is a correctness requirement, not a tuning choice.**
|
||||
At send time the client snapshots `idx_before = queue.get_action_index()` and the leftover prefixes; `merge` validates against them. Two in-flight requests carry conflicting snapshots — the second merge corrupts both RTC replace mode and append mode. The local RTC thread is also strictly one-inference-at-a-time; one-in-flight preserves exact parity. _Consequence_: the worker publishes one observation, waits for its chunk (or timeout), then sends the next. v1 §8.1's out-of-order in-flight dict is dead weight; a late chunk is accepted only if it answers the _latest_ outstanding `seq_id`, otherwise dropped.
|
||||
|
||||
**P6 — v1's QoS table conflates transport reliability with congestion behavior.**
|
||||
"Reliable delivery for actions" sounds right but the dangerous knob is congestion control: a publisher configured `BLOCK` on the action topic can stall the **server's** publish path on one robot's dead uplink (Zenoh blocks up to `wait_before_close`, then may close the transport). A dropped action chunk is _recoverable by design_ — the client's queue keeps the robot moving and the next chunk replaces it. _Consequence_ (§6.3): actions = `reliability=RELIABLE` (hop-level) + `congestion_control=DROP` + `express=True` + `priority=INTERACTIVE_HIGH`; observations = `DROP` + `DATA`. If WAN loss proves material, upgrade the action topic to Zenoh Advanced Pub/Sub (cache + recovery, zenoh ≥ 1.5) rather than BLOCK.
|
||||
|
||||
**P7 — Schema-less MessagePack invites silent version drift across a 300-robot fleet.**
|
||||
msgpack stays (zenoh's `ext` serializer cannot encode numpy/dataclasses, and the team's choice stands), but naked msgpack dicts across heterogeneous fleet versions fail at runtime, on the robot. _Consequence_ (§10.4): a packed little-endian **attachment header** (`schema_version`, `seq_id`, `episode_id`, `client_mono_ns` — the rmw_zenoh pattern) so routing/correlation never deserializes the body; `schema_version` negotiated at the session handshake; additive-only evolution; golden codec tests. Protobuf-over-ZBytes is the documented fallback if drift bites in practice.
|
||||
|
||||
**P8 — "Deterministic rollout reproducibility" is unattainable on real robots.**
|
||||
No seed controls hardware, sensor noise, or network jitter; RTC's latency-driven trimming is inherently timing-dependent. _Consequence_: the contract is **fully logged + replayable** (§12): recording strategies already persist observations and executed actions; the remote engine adds `(session_id, seq_id, episode_id)` provenance so client datasets join server audit logs mechanically.
|
||||
|
||||
**P9 — v1 has no safety specification.**
|
||||
"Log a warning when the buffer empties" is not a fail-safe for a 300-robot fleet. _Consequence_ (§9): a staleness bound (`max_action_age_s` — never execute an action older than X relative to its source observation), an explicit fallback ladder (`hold` / `repeat_last` / `zero` — zero-command required for future velocity-controlled robots), and a DEAD state that triggers the existing strategy shutdown path (return-to-initial-pose, disconnect) via the same `shutdown_event` mechanism RTC uses (`rtc.py:359-360`).
|
||||
|
||||
**P10 — Capacity must be formula-driven, not "a user decision".**
|
||||
v1 §4 says clients-per-server "is a user decision". With `t` = server time per request, `r` = per-client request rate, `H` = RTC execution horizon, `dt` = control period:
|
||||
`N_max = min( 0.8 / (r·t), (H·dt/2 − RTT_net) / t )`
|
||||
→ ACT @ 20 ms, 1 Hz: ~40 clients/GPU. Pi0 @ 150 ms, 1 Hz: ~5 clients/GPU. 300 robots on Pi0 ≈ 60 GPU pods. _Consequence_: the manifest carries `max_sessions`; the server rejects session opens beyond it (with current load in the reply) so clients retry another replica. Micro-batching is deferred — blocked on a real API issue (`predict_action_chunk` takes a _scalar_ `inference_delay`; batched clients have different delays) — behind a `Scheduler` seam so it can land later without redesign (§8.5).
|
||||
|
||||
**P11 — Discovery ≠ multicast.**
|
||||
Zenoh's multicast scouting does not cross WAN, NAT, or most k8s CNIs. _Consequence_: multicast scouting disabled; clients use static `connect.endpoints` (DNS name of the router) + gossip; presence and liveness come from Zenoh **liveliness tokens** (§6.4), not discovery. "Discovery" for a robot fleet is configuration.
|
||||
|
||||
---
|
||||
|
||||
## 5. System Topology
|
||||
|
||||

|
||||
_(Diagram unchanged from v1 — the topology survives; transport/QoS/session details in it are superseded by §6.)_
|
||||
|
||||
- **Router tier**: one or more `zenohd` routers (k8s Deployment + Service, TLS on 7447). Robots **dial out** to the router (NAT-friendly: labs only need outbound 7447/443). GPU servers join as peers via cluster DNS.
|
||||
- **Server**: one process = one `(model_repo, revision, dtype, device)` on one GPU, pre-warmed from a YAML manifest (**KEPT** from v1, amended: `pin_task: bool` — VLA prompts may vary per session unless pinned).
|
||||
- **Client**: one robot running `lerobot-rollout --inference.type=remote`. Weightless: config-only policy metadata.
|
||||
- **Identity**: `client_uuid` per robot; `session_id` per connection epoch; both in every log line on both sides.
|
||||
|
||||
---
|
||||
|
||||
## 6. Zenoh Design
|
||||
|
||||
All Zenoh claims below were verified against zenoh / zenoh-python 1.x (eclipse-zenoh 1.9.0). Pin: `eclipse-zenoh>=1.9,<2.0`; keep `zenohd` on the same minor as the Python binding. Wheels cover manylinux x86_64/aarch64/armv7l/armv6l + macOS — Raspberry Pi edge clients are covered.
|
||||
|
||||
### 6.1 Key-expression schema
|
||||
|
||||
```
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/<client_uuid>/obs client → server
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/<client_uuid>/action server → client
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/status queryable (capabilities)
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/session queryable (open/validate)
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/<client_uuid>/reset queryable (episode boundary)
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/<client_uuid>/alive liveliness token (client)
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/server/alive liveliness token (server)
|
||||
```
|
||||
|
||||
Rules (hard, enforced by a `sanitize_keyexpr()` helper):
|
||||
|
||||
- Root at the **verbatim chunk** `@lerobot` — verbatim chunks are only matched by identical chunks, so third-party `**` subscribers on a shared router can never scrape the tree.
|
||||
- Sanitize every user-supplied segment (model ids, task strings, uuids): non-empty, no `* $ ? # /`, no leading/trailing/double `/`. A task string containing `/` must be slugified before it becomes a key chunk.
|
||||
- Server subscribes with a **single-depth** wildcard (`.../*/obs`) — never `**` (it would also match `status`, `alive`, …).
|
||||
- v1's `cluster/experiment` prefix segments are dropped from the key schema; they return as free-form `tags` metadata in the session handshake (telemetry/labeling, not routing). Routing topology belongs to deployment (which router you dial), not to key depth.
|
||||
|
||||
### 6.2 Data plane vs. control plane (the rmw_zenoh split)
|
||||
|
||||
- **Data plane = pub/sub** (KEPT from v1): observations up, action chunks down, correlated by `seq_id` in **attachments** (§10.4). Pub/sub rather than query-per-inference because: a timed-out query's late reply is _dropped by the transport_ (wasted inference), whereas a late pub/sub chunk is still mergeable if it answers the latest outstanding seq; and pub/sub leaves room for server-initiated messages (drain notices). The one-in-flight discipline (P5) is enforced in the client worker, not by the transport.
|
||||
- **Control plane = queryables** (request/reply with explicit timeouts; the pattern rmw*zenoh uses for ROS 2 services): `status` (pre-flight capability fetch, 2 s timeout), `session` (open/validate → ack with capabilities + `session_id`), `reset` (episode boundary — \_acknowledged*, so episodic strategies know the server-side episode state is clean). Always pass an explicit `timeout` to `session.get()` — the config default is 10 s, far too long for our watchdogs.
|
||||
- **Episode ordering**: under one-in-flight there is no obs/reset race window in the data plane, but as belt-and-braces the first observation of each episode also carries `episode_start=True` + the new `episode_id` in its header.
|
||||
|
||||
### 6.3 QoS (revised from v1 §6.2 — see P6)
|
||||
|
||||
| Topic | reliability | congestion_control | express | priority | Why |
|
||||
| ------------------ | ----------- | ---------------------- | -------- | ---------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| `obs` | default | **DROP** | false | DATA | Intentional drop already happened at the client's one-slot holder; if the uplink stalls, dropping a frame protects the control loop. |
|
||||
| `action` | RELIABLE | **DROP** (never BLOCK) | **true** | INTERACTIVE_HIGH | Hop-level reliability over TCP; express skips batching for the small (4–50 KB) latency-critical payload; DROP so one dead robot uplink can never stall the server's publish path. Chunk loss is recoverable: the client buffer rides through it. |
|
||||
| control queryables | RELIABLE | default | — | — | Correctness over latency; explicit timeouts bound them. |
|
||||
|
||||
Upgrade path if WAN chunk loss proves material: `AdvancedPublisher`/`AdvancedSubscriber` (zenoh ≥ 1.5) with a small cache + heartbeat-based recovery **on the action topic only**. Hop-by-hop RELIABLE is not end-to-end reliability — Zenoh has no broker persistence; a disconnected subscriber's data is gone. The design assumes this (client state machine, §9).
|
||||
|
||||
### 6.4 Liveliness (presence + watchdogs)
|
||||
|
||||
- Client declares a liveliness token on `.../<client_uuid>/alive`. The server liveliness-subscribes with `history=True`: token appear → ensure session state; token drop → GC the session (mailbox, processor instances) after a grace period.
|
||||
- Server declares `.../server/alive`. The client liveliness-subscribes: on drop → treat as RECONNECTING (§9), hold/fallback per config, re-run the `status`/`session` handshake when the token reappears.
|
||||
- Tune the transport lease down from its default so ungraceful-death detection is seconds, not tens of seconds (verify the default in the pinned version; it is config `transport/link/tx/lease`).
|
||||
- Liveliness cannot detect a _hung-but-connected_ server. The client's per-request timeout (`request_timeout_s`) is the authoritative watchdog — this is the structural fix for legacy BUG-3 (no deadlines on `GetActions`).
|
||||
|
||||
### 6.5 Threading constraints (zenoh-python facts that shape both processes)
|
||||
|
||||
- **No asyncio API** in zenoh-python — both client and server are thread-based. This matches the existing RTC engine pattern exactly.
|
||||
- Each callback-based subscriber spawns a dedicated Python thread; **blocking Zenoh calls inside callbacks are disallowed**. Callbacks must be deposit-only (write a slot, set an event, return).
|
||||
- Channel handlers (`FifoChannel`, `RingChannel`) are Rust-side; `try_recv()` polls without spawning Python threads. `RingChannel(1)` is native latest-only semantics.
|
||||
- No zero-copy path for our payloads (SHM API is `@_unstable` and same-host-only; `ZBytes` copy behavior undocumented). At ~200 KB × a few Hz per robot, one memcpy is irrelevant.
|
||||
|
||||
### 6.6 Router deployment
|
||||
|
||||
- `zenohd` official image as a k8s Deployment (1–N replicas; routers mesh and reroute around failures) behind a `LoadBalancer`/`NodePort` Service exposing TLS 7447. No official Helm chart exists — roll-your-own manifests.
|
||||
- `scouting.multicast.enabled: false`; `scouting.gossip.enabled: true`; clients/servers use static `connect.endpoints`.
|
||||
- **Auth**: mTLS per robot (`transport.link.tls` with `enable_mtls`) + router **ACL** keyed on `cert_common_names`: a robot's cert may only `put` to `@lerobot/**/<its-uuid>/obs` and receive on `.../<its-uuid>/action`. Caveat (flagged): ACL config reloads require a router restart — plan cert/ACL changes as rolling router restarts.
|
||||
- Security review input: the third-party Zenoh protocol security analysis (Census Labs, 2025) should be read before exposing 7447 publicly.
|
||||
|
||||
---
|
||||
|
||||
## 7. The Statelessness Boundary (the load-bearing section)
|
||||
|
||||
**Where the network cut goes.** The local RTC pipeline is:
|
||||
|
||||
```
|
||||
obs (robot-processed dict)
|
||||
→ build_dataset_frame(hw_features, obs, "observation") CLIENT (cheap, hardware-coupled)
|
||||
─────────────────────────── network ───────────────────────────
|
||||
→ prepare_observation_for_inference(...) SERVER (policy-coupled, heavy)
|
||||
→ per-session preprocessor(...) SERVER (stateful within the request)
|
||||
→ policy.predict_action_chunk(obs, inference_delay, prefix) SERVER (pure for allowlisted policies)
|
||||
→ per-session postprocessor(...) SERVER (reads state cached at preprocess)
|
||||
─────────────────────────── network ───────────────────────────
|
||||
→ ActionQueue.merge(original, processed, real_delay, idx_before) CLIENT
|
||||
```
|
||||
|
||||
Three consequences:
|
||||
|
||||
1. **The server needs no cross-request state.** `RelativeActionsProcessorStep` writes `_last_state` at preprocess and the postprocessor reads it back _within the same request_. Per-session pipeline instances + one-request-at-a-time-per-session give correctness with zero persistent state.
|
||||
2. **RTC state stays client-side**, exactly where `RTCInferenceEngine` already keeps it. Each request ships: `inference_delay_steps = ceil(L_max/dt)` (from the client `LatencyTracker`, whose samples are full network-inclusive cycle times — RTT compensation falls out for free), `prefix_model = queue.get_left_over()[:H]`, and `prefix_robot = queue.get_processed_left_over()[:H]` (needed for server-side relative-prefix re-anchoring, mirroring `rtc.py:287-305`). The response returns **both** the model-space and robot-space chunks because `merge` needs both. ≤ `execution_horizon × action_dim` float32 each — a few hundred bytes.
|
||||
3. **G9 dies structurally.** No bespoke client resize (`F.interpolate` in legacy `helpers.py`), no client-side normalization. Clients ship native camera resolution; the server's canonical processor path does everything — serve-time preprocessing is byte-identical to train-time.
|
||||
|
||||
**What the server _does_ hold** (and what it means):
|
||||
|
||||
- Per-session processor instances (cheap; normalization stat tensors shared read-only).
|
||||
- Per-session episode counter + stats. Episode reset = reset the session's pipelines, clear its mailbox. **`policy.reset()` is never called in shared mode** — it is global to the shared policy instance and unnecessary for chunk-pure policies (ACT's ensembler and Pi0/SmolVLA's queues live in `select_action`, not `predict_action_chunk` — verified).
|
||||
- Policies that are _not_ chunk-pure get `serving_mode: exclusive` (§8.3).
|
||||
|
||||
---
|
||||
|
||||
## 8. The Inference Server: `lerobot-policy-server`
|
||||
|
||||
New package `src/lerobot/policy_server/`; console script `lerobot-policy-server --manifest manifest.yaml`.
|
||||
|
||||
### 8.1 Process model — **KEPT** from v1, amended
|
||||
|
||||
One process = one model+task on one GPU, loaded and warmed at startup (`warmup_inferences` dummy forwards; covers torch.compile). Multi-GPU nodes run N processes (`CUDA_VISIBLE_DEVICES` pinning). Dynamic model loading (`SendPolicyInstructions`) is **rejected**: pickle/RCE surface, arbitrary-download surface, and it destroys capacity planning. Amendment: `pin_task: false` (default) lets VLA clients set the task per session; `pin_task: true` rejects mismatched tasks at session open.
|
||||
|
||||
### 8.2 Concurrency (pure threads — no asyncio in zenoh-python)
|
||||
|
||||
```
|
||||
zenoh subscriber (.../*/obs) inference worker (1 thread, owns GPU)
|
||||
deposit-only callback: loop:
|
||||
slots[client_uuid] = sample ──► pick next session with pending obs (RR ring)
|
||||
(per-client latest-only) decode JPEG → per-session preprocess
|
||||
predict_action_chunk(delay, prefix)
|
||||
control queryables (status/session/ per-session postprocess → encode
|
||||
reset): validate, mutate session publisher.put(.../<uuid>/action)
|
||||
registry, reply (publishing from the worker thread is fine)
|
||||
```
|
||||
|
||||
- **Per-client latest-only mailbox**: a wildcard subscriber with a deposit-only callback writing per-client slots (scales to dynamic fleets), or — when the manifest enumerates clients — one `RingChannel(1)` subscriber per client polled via `try_recv()`. Either way: newest observation wins; a superseded request is counted (`superseded_seqs` in the next response) so drops are visible. This deletes legacy BUG-4 (`observations_similar` + `must_go`) by construction — the **client** decides when to request; the server never second-guesses observation content.
|
||||
- **Single inference worker**: torch releases the GIL inside `forward`, callbacks stay responsive. Strict round-robin over sessions with pending observations: each gets exactly one inference per cycle; starvation is structurally impossible. Overload degrades into longer cycle times → larger (but correct) client `delay_steps` → eventually the client staleness bound trips and the robot holds — safe by construction.
|
||||
|
||||
### 8.3 Chunk-stateless allowlist and serving modes
|
||||
|
||||
At startup the server classifies the loaded policy:
|
||||
|
||||
| Class | Policies (verified) | Mode |
|
||||
| --------------- | ------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| chunk-stateless | ACT, Pi0, Pi0.5, SmolVLA (and any policy whose `predict_action_chunk` touches no instance state) | `shared`: N sessions, per-session pipelines, `policy.reset()` never called |
|
||||
| chunk-stateful | Diffusion family (`predict_action_chunk` reads `select_action`-fed `self._queues`) | `exclusive`: `max_sessions=1` enforced; episode reset additionally calls `policy.reset()`; second session open → rejected with a self-explanatory error |
|
||||
| no chunk API | SAC, SARM | refused at startup |
|
||||
|
||||
Implemented as a registry in `policy_server/validation.py`; the cleaner follow-up is a `supports_stateless_chunking` class attribute on `PreTrainedPolicy` (needs a pass over policy families — roadmap §14).
|
||||
|
||||
### 8.4 Session open & capability validation (fail fast, fail loud)
|
||||
|
||||
`session` queryable payload: `client_uuid`, `policy_type`, `fps`, feature summary (post-rename observation feature names + shapes, ordered action keys), `schema_version`, RTC intent, `tags`. Checks:
|
||||
|
||||
| Check | Rule | On mismatch |
|
||||
| -------------------------- | --------------------------------------------------------------- | ---------------------------------------------------------------------------------- |
|
||||
| Action names **and order** | must equal server's `action_feature_names` exactly | **hard reject** — this is the sync-safety contract mapping chunk columns to motors |
|
||||
| Camera names | client set must cover `policy.config.input_features` image keys | hard reject |
|
||||
| Resolution | any H×W accepted (server resizes canonically) | warn if aspect ratio differs from training |
|
||||
| State dim | flattened dim must match | hard reject |
|
||||
| `schema_version` | client within server's supported range | hard reject |
|
||||
| fps | vs. manifest `trained_fps` | warn (reject only when `strict_fps: true`) |
|
||||
| Task | when `pin_task: true`, must equal `default_task` | reject |
|
||||
| RTC | client RTC requires policy RTC kwargs support | downgrade to append mode + warning |
|
||||
| Capacity | `active_sessions < max_sessions` | reject with current load → client retries another replica |
|
||||
|
||||
Reply: `session_id`, model info (repo, revision — consider a checkpoint hash, §15), `action_feature_names`, `chunk_size`, `trained_fps`, `supports_rtc`, `serving_mode`, `warmed_up`, `schema_version`, warnings. **rename_map is applied client-side** so the wire format is canonical policy-feature keys across heterogeneous robots (also a prerequisite for future batching).
|
||||
|
||||
### 8.5 Scheduler seam (micro-batching later, not in v1)
|
||||
|
||||
The worker calls a `Scheduler.select(ready: list[Session]) -> list[Session]`; v1 ships `RoundRobin` (`return ready[:1]`). Cross-session batching is blocked on the policy API (`inference_delay` is scalar; batched clients have different delays/prefixes) — when that lands, a `MicroBatch` scheduler groups same-shape sessions. The seam costs nothing now and prevents a redesign later.
|
||||
|
||||
### 8.6 Manifest
|
||||
|
||||
```yaml
|
||||
model:
|
||||
{
|
||||
repo_or_path: lerobot/pi0_towels,
|
||||
revision: main,
|
||||
dtype: bfloat16,
|
||||
device: cuda,
|
||||
}
|
||||
default_task: "fold the towel"
|
||||
pin_task: false
|
||||
serving_mode: shared # forced to exclusive for chunk-stateful policies
|
||||
max_sessions: 5 # from the §P10 formula: Pi0 @150ms, 1 Hz refresh
|
||||
warmup_inferences: 2
|
||||
strict_fps: false
|
||||
zenoh:
|
||||
connect_endpoints: ["tls/router.gpu-cluster.internal:7447"]
|
||||
tls:
|
||||
{
|
||||
connect_certificate: ...,
|
||||
connect_private_key: ...,
|
||||
root_ca_certificate: ...,
|
||||
}
|
||||
health_port: 9100 # HTTP health + Prometheus metrics
|
||||
debug: { capture_dir: null, capture_max: 256 }
|
||||
```
|
||||
|
||||
Draccus dataclass in `policy_server/manifest.py`; YAML via `--manifest`, individual overrides via CLI.
|
||||
|
||||
---
|
||||
|
||||
## 9. The Edge Client: `RemoteInferenceEngine`
|
||||
|
||||
New file `src/lerobot/rollout/inference/remote.py`, registered `@InferenceEngineConfig.register_subclass("remote")`.
|
||||
|
||||
### 9.1 Threading model
|
||||
|
||||
| Thread | Role |
|
||||
| -------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Main (strategy loop) | `notify_observation(obs)` → lock-protected latest-only slot (identical to `rtc.py` `_obs_holder`). `get_action()` → `ActionQueue.get()` + staleness check. **Never any I/O.** Structurally fixes legacy BUG-1 (blocking send inside the 33 ms loop). |
|
||||
| Network worker (1 daemon thread) | Cycle: wait until `queue_remaining·dt ≤ buffer_time_s` and active → snapshot `idx_before`, prefixes, `delay_steps = ceil(L_max/dt)` → encode (JPEG q=`jpeg_quality`) → `publisher.put(obs, attachment=header)` → await chunk on the action subscriber channel (timeout `request_timeout_s`) → `merge(original, processed, ceil(L/dt), idx_before)` → `latency_tracker.add(L)`. Owns the state machine, reconnects, and control queries. One-in-flight (P5). |
|
||||
| Zenoh action subscriber | `FifoChannel(2)` handler drained by the worker (no Python callback thread on the hot path); liveliness subscriber callback is deposit-only (sets an event). |
|
||||
|
||||
Reused unchanged: `ActionQueue` (`policies/rtc/action_queue.py`), `LatencyTracker`, `ActionInterpolator` (lives in strategies — `interpolation_multiplier` works with remote for free). Deleted concepts: aggregation zoo, `observations_similar`, `must_go`, `TimedObservation`/`TimedAction` pickles.
|
||||
|
||||
### 9.2 Fail-safe state machine
|
||||
|
||||
```
|
||||
ok no chunk for degraded_after_s
|
||||
CONNECTING ─────► STREAMING ───────────────────────────────► DEGRADED
|
||||
│ ▲ ▲ │ queue empty OR max_action_age_s hit │
|
||||
│ │ backoff, │ └───────────────────────────────────► STALLED ◄──┘
|
||||
│ │ re-handshake │ first successful merge │
|
||||
│ └─ RECONNECTING ◄── timeout streak / server liveliness drop ◄─┘
|
||||
│ │ offline > max_offline_s, capability/schema mismatch, auth failure
|
||||
└──────► DEAD (failed=True → shutdown_event → strategy teardown: return-to-initial-pose)
|
||||
```
|
||||
|
||||
- **DEGRADED**: requests failing but the queue still holds actions — the robot keeps executing; chunks _are_ the fault-tolerance buffer (1–3 s of coverage makes blips and clean server drains invisible).
|
||||
- **STALLED**: queue empty or staleness bound hit → apply `fallback`: `hold` (`get_action` → `None`; `send_next_action` already tolerates it), `repeat_last`, or `zero` (required for velocity-controlled robots, where "send nothing" means "keep last velocity").
|
||||
- **Staleness bound** (sync safety): every merge records `(chunk_start_index, t_send)`; `get_action` refuses any action whose source observation is older than `max_action_age_s` (default 3.0 s ≈ 90 steps @ 30 fps). Bounds open-loop execution after a network stall.
|
||||
- **DEAD**: only after `max_offline_s` (default 60 s) or a hard contract violation (capability/schema mismatch on reconnect — e.g. the server restarted with a different model; never execute wrong-model chunks). Uses the exact mechanism RTC uses (`failed=True` + global `shutdown_event`) so existing teardown runs unchanged.
|
||||
- **Watchdog layering**: per-request timeout (hung server — the BUG-3 fix) → server liveliness token (dead server/router) → staleness bound (the robot-side invariant that holds regardless of why data stopped).
|
||||
- **Pause/resume (DAgger)**: `pause()` stops the worker publishing (slot keeps refreshing, ignored); queue intact — parity with `RTCInferenceEngine.pause`. DAgger's existing `interpolator.reset(); engine.reset(); engine.resume()` sequence works unchanged.
|
||||
- **`reset()` (episode boundary)**: clear `ActionQueue` + staleness bookkeeping, bump `episode_id`, fire the acked `reset` query (1 s timeout, failure logged — the server has nothing it _must_ do thanks to per-request statelessness), flag `episode_start` on the next observation. `LatencyTracker` intentionally survives reset (latency is episode-invariant; parity with local RTC).
|
||||
- **`ready`** = session opened ∧ capabilities validated ∧ server `warmed_up`. First-chunk gating is implicit (`get_action` → `None` until the first merge).
|
||||
|
||||
### 9.3 Weightless client — exact integration changes
|
||||
|
||||
- `rollout/context.py`: `PolicyContext.{policy, preprocessor, postprocessor}` become `| None`. For remote configs, skip step 1 (weight load / PEFT / `.to(device)` / torch.compile / `init_rtc_processor`) and step 6 (`make_pre_post_processors`). Verified safe: strategies only consume `ctx.policy.inference`. Keep steps 2–5 (robot processors, hardware, features, dataset) — they are robot-derived. Keep the visual pre-flight check (`context.py:309-324`): `--policy.path` already loads config-only (`rollout/configs.py:324-328`, no weight download) and failing before dialing the server is free. `use_torch_compile` / explicit `--device` → warn-and-ignore for remote.
|
||||
- `rollout/inference/factory.py`: signature loosens to `policy: PreTrainedPolicy | None` (+ `policy_config: PreTrainedConfig`); `sync`/`rtc` branches guard `policy is None`; the `remote` branch lazy-imports (`eclipse-zenoh` stays an optional extra).
|
||||
- The authoritative validation moves to session open (§8.4); the local check becomes a fast-fail convenience.
|
||||
|
||||
### 9.4 Config
|
||||
|
||||
```python
|
||||
@InferenceEngineConfig.register_subclass("remote")
|
||||
@dataclass
|
||||
class RemoteInferenceConfig(InferenceEngineConfig):
|
||||
connect_endpoint: str = "tls/localhost:7447" # zenoh router endpoint
|
||||
tls_cert: str | None = None; tls_key: str | None = None; tls_ca: str | None = None
|
||||
client_uuid: str = "" # "" → uuid4 at start()
|
||||
jpeg_quality: int = 90 # 0 = raw (LAN/debug)
|
||||
buffer_time_s: float = 0.5 # send next obs when queue playback ≤ this (v1 G14) — KEPT
|
||||
max_action_age_s: float = 3.0 # staleness bound (safety)
|
||||
degraded_after_s: float = 1.0
|
||||
request_timeout_s: float = 5.0
|
||||
reconnect_initial_backoff_s: float = 0.5
|
||||
reconnect_max_backoff_s: float = 10.0
|
||||
max_offline_s: float = 60.0
|
||||
fallback: FallbackBehavior = FallbackBehavior.HOLD # hold | repeat_last | zero
|
||||
rtc: RTCConfig = field(default_factory=RTCConfig) # enabled → replace mode; horizon caps prefix
|
||||
tags: dict[str, str] = field(default_factory=dict) # ex-cluster/experiment labels
|
||||
```
|
||||
|
||||
```bash
|
||||
# Remote RTC + sentry recording (the reproducibility path)
|
||||
lerobot-rollout \
|
||||
--strategy.type=sentry \
|
||||
--policy.path=lerobot/pi0_towels \ # config-only: no weights downloaded
|
||||
--inference.type=remote \
|
||||
--inference.connect_endpoint=tls/router.gpu-cluster.internal:7447 \
|
||||
--inference.rtc.execution_horizon=10 \
|
||||
--robot.type=so100_follower --robot.port=/dev/ttyACM0 \
|
||||
--robot.cameras="{front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--dataset.repo_id=user/rollout_fleet_a --dataset.single_task="fold the towel"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 10. Wire Schema
|
||||
|
||||
### 10.1 Payload anatomy & rates — **KEPT** (JPEG) with numbers
|
||||
|
||||
Upstream per request: joints (24–128 B) + JPEG frames (480p q90 ≈ 40–90 KB each; 720p ≈ 110–230 KB) + RTC prefixes (≤ a few KB) → 60–450 KB depending on cameras. Downstream: `2 × chunk_size × action_dim × 4 B` + metadata → 3–50 KB. Effective request rate is self-clocked by `buffer_time_s` to ~1–4 Hz per robot (not the 30 Hz control rate). 300 robots ≈ 0.3–10 Mbps each — the wire is never the bottleneck; bandwidth budgeting is about camera count/resolution, and each GPU pod only ever sees its own ≤ `max_sessions` clients. Zenoh fragments >64 KiB payloads transparently; multi-MB messages are fine.
|
||||
|
||||
### 10.2 Attachment header (fixed-layout, packed little-endian — parsed without touching the body)
|
||||
|
||||
| Field | Type | Notes |
|
||||
| ---------------- | ---- | -------------------------------------------------------------- |
|
||||
| `schema_version` | u16 | negotiated at session open |
|
||||
| `msg_type` | u8 | OBS / CHUNK / EVENT |
|
||||
| `seq_id` | u64 | per-session monotonic; echoed in the chunk |
|
||||
| `episode_id` | u32 | bumped by `reset()` |
|
||||
| `client_mono_ns` | i64 | client `monotonic_ns()`; **opaque to the server, echoed back** |
|
||||
| `session_epoch` | u32 | bumped per (re)connect; stale-epoch chunks dropped |
|
||||
|
||||
### 10.3 msgpack bodies
|
||||
|
||||
**ObservationMsg** (client → server): `state: {names_ref, data: f32 LE bytes}`, `images: {name: {codec: jpeg|raw, bytes, (h,w,c) if raw}}`, `task: str`, `inference_delay_steps: int`, `prefix_model: tensor?`, `prefix_robot: tensor?` (tensors = raw LE bytes + dtype + shape), `episode_start: bool`.
|
||||
**ActionChunkMsg** (server → client): `seq_id_echo`, `client_mono_ns_echo`, `chunk_model: tensor`, `chunk_robot: tensor`, `queue_wait_ms: f32`, `inference_ms: f32`, `superseded_seqs: u32`, `server_load: f32`.
|
||||
**Status / SessionOpen / SessionAck / ResetMsg**: as specified in §8.4.
|
||||
|
||||
### 10.4 Schema discipline (P7)
|
||||
|
||||
`schema_version` gates at handshake; evolution is additive-only (new optional msgpack keys; unknown keys ignored); attachment layout changes require a version bump; golden codec round-trip tests (tensor exactness, JPEG RGB-channel-order regression — a silent BGR swap poisons every VLA in the fleet) are part of the test suite. **No pickle anywhere** — KEPT from v1 and now structural: nothing in the schema can carry code.
|
||||
|
||||
---
|
||||
|
||||
## 11. Latency Budget & the Clock Iron Rule
|
||||
|
||||
| Stage | LAN | WAN (50 ms RTT) |
|
||||
| ------------------------------ | --------------- | --------------- |
|
||||
| JPEG encode ×3 (edge CPU) | 2–9 ms | 2–9 ms |
|
||||
| Serialize | <1 ms | <1 ms |
|
||||
| Uplink (tx + ½RTT) | ~2 ms | ~54 ms |
|
||||
| Server queue wait | 0 → 1×inference | 0 → 1×inference |
|
||||
| Decode + canonical preprocess | 4–10 ms | 4–10 ms |
|
||||
| **Inference** | **15–150 ms** | **15–150 ms** |
|
||||
| Postprocess + downlink + merge | ~2 ms | ~27 ms |
|
||||
| **Total (Pi0-class)** | **~110–175 ms** | **~190–250 ms** |
|
||||
|
||||
Inference is 60–85 % of end-to-end on LAN; the entire transport+serialization stack is <10 ms. WAN adds propagation + uplink bandwidth — identical under any transport. At 30 fps this lands `delay_steps` ≈ 4–8, comfortably inside RTC execution horizons: WAN degrades smoothness parameters, never correctness. _This table is the standing answer to transport-performance bikeshedding._
|
||||
|
||||
**Clock iron rule** (P4): wall-clock instants never cross machines. Client stamps `monotonic_ns`, the server echoes it opaquely; `RTT = now − echo`. The server reports only **durations** (`queue_wait_ms`, `inference_ms`) measured on its own monotonic clock; `network_time = RTT − queue_wait − inference` for diagnostics. The schema has no field in which a foreign wall-clock instant can be compared — the legacy `time.time()` bug is unrepresentable.
|
||||
|
||||
---
|
||||
|
||||
## 12. Reproducibility & Audit (P8)
|
||||
|
||||
The contract is **fully logged + replayable**, not "deterministic":
|
||||
|
||||
- **Client = source of truth.** Recording strategies already persist observations + executed actions to `LeRobotDataset`. The remote engine logs, per executed action, the `(session_id, seq_id, episode_id)` of its source chunk plus the echoed `queue_wait_ms`/`inference_ms` (dataset-extras columns are a follow-up; client logs in v1).
|
||||
- **Server audit line per request** (structured JSON): `{ts, session_id, client_uuid, seq_id, episode_id, queue_wait_ms, inference_ms, chunk_range, superseded_seqs, outcome}`.
|
||||
- **Optional bounded capture**: `debug.capture_dir` writes a ring of request/response pairs (safetensors) for byte-exact offline replay through the same server pipeline.
|
||||
- **Runbook — "robot #217 stuttered at 14:03"**: (1) Grafana `session_staleness{client="217"}` — spike ⇒ server side, flat ⇒ client/network. (2) Server side: audit lines — `queue_wait_ms` rising across _all_ sessions ⇒ overloaded replica (check `active_sessions` vs `max_sessions`); `superseded_seqs` streak on 217 only ⇒ that client over-requesting; `outcome=error` ⇒ adjacent stack trace. (3) Client side: state-machine transitions + reconnects in the client log; dataset rows show which seq's chunk was executing and where `None` ticks occurred. Every hop shares `(session_id, seq_id)` — the join is mechanical.
|
||||
|
||||
---
|
||||
|
||||
## 13. Integration & Migration Plan
|
||||
|
||||
### 13.1 New
|
||||
|
||||
| Path | Content |
|
||||
| --------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `src/lerobot/policy_server/{__init__,schema,codec,manifest,session,scheduler,validation,server}.py` | wire schema constants, msgpack/attachment codecs, manifest dataclasses, `Session` + mailbox, `Scheduler` seam, capability rules + chunk-stateless registry, zenoh servicer + inference worker + drain + HTTP health/metrics |
|
||||
| `src/lerobot/rollout/inference/remote.py` | `RemoteInferenceEngine` (~600 lines; mirrors `rtc.py` structure) |
|
||||
| `src/lerobot/scripts/lerobot_policy_server.py` + `[project.scripts]` entry | thin `main()` |
|
||||
| `docker/Dockerfile.policy-server` | CUDA runtime base + uv; manifest via ConfigMap |
|
||||
| `docs/source/remote_inference.mdx` (+ `_toctree.yml`) | replaces `async.mdx` |
|
||||
|
||||
### 13.2 Modified
|
||||
|
||||
`rollout/inference/factory.py` (config + Optional-typed signature + lazy import) · `rollout/context.py` (weightless branch) · `rollout/inference/__init__.py` · `scripts/lerobot_rollout.py` docstring · `pyproject.toml`: `[async]` extra becomes `eclipse-zenoh>=1.9,<2.0` + `msgpack` (grpcio/matplotlib leave it; grpcio remains under `[hilserl]`/`dev` for the RL stack).
|
||||
|
||||
### 13.3 Removed — same landing PR
|
||||
|
||||
`src/lerobot/async_inference/` · `tests/async_inference/` · `docs/source/async.mdx` + its `_toctree.yml` entry · the `AsyncInference` service + `Observation`/`Actions`/`PolicySetup` messages from `src/lerobot/transport/services.proto` (regenerate pb2; **`LearnerService` untouched** — `transport/` is shared with HIL-SERL (`src/lerobot/rl/`); the RL test suite gates this change).
|
||||
|
||||
### 13.4 Legacy config → successor mapping
|
||||
|
||||
| Legacy (`RobotClientConfig`/`PolicyServerConfig`) | Successor |
|
||||
| ------------------------------------------------- | ---------------------------------------------------------- |
|
||||
| `server_address` | `--inference.connect_endpoint` (zenoh router) |
|
||||
| `policy_type`, `pretrained_name_or_path` | `--policy.path` (config-only) + server manifest |
|
||||
| `chunk_size_threshold` (0–1 ratio) | `--inference.buffer_time_s` (seconds) |
|
||||
| `actions_per_chunk` | server manifest (validated at session open) |
|
||||
| `aggregate_fn_name` + `AGGREGATE_FUNCTIONS` | **dropped** — `ActionQueue` replace/append |
|
||||
| `policy_device`, `client_device` | **dropped** — server concern / chunks arrive CPU f32 |
|
||||
| `debug_visualize_queue_size` | **dropped** — Rerun (`--display_data`) + engine stats |
|
||||
| `PolicyServerConfig.{host,port}` | manifest `zenoh.connect_endpoints` |
|
||||
| `inference_latency`, `obs_queue_timeout` | **dropped** — latency client-measured; no server obs queue |
|
||||
| `SendPolicyInstructions` | **dropped** — MaaS manifest + session validation |
|
||||
| `observations_similar` / `must_go` | **dropped** — latest-only slots + client send gate |
|
||||
| pickle envelopes | **dropped** — msgpack + attachment headers |
|
||||
|
||||
### 13.5 Legacy bugs/gaps → structural resolution
|
||||
|
||||
BUG-1 → worker thread owns all I/O. BUG-2 → aggregation deleted; `ActionQueue` is internally locked. BUG-3 → per-request timeout + liveliness. BUG-4 → client-side send gating; server newest-wins. G1 → per-session registry. G2 → manifest. G4 → msgpack+attachments. G5 → monotonic echo + `delay_steps`. G7 → recording strategies. G8 → mTLS + ACL. G9 → server-side canonical processors. G11 → `status` queryable. G12 → Prometheus + audit logs. G13 → `lerobot-policy-server` console script. G14 → `buffer_time_s`.
|
||||
|
||||
### 13.6 Tests
|
||||
|
||||
- **Unit**: codec round-trips (tensor exact; JPEG RGB-order regression), capability-validation matrix (§8.4 as parametrized cases), scheduler fairness + newest-wins supersession (mock policy with configurable sleep), manifest parsing, key-expr sanitization.
|
||||
- **Loopback integration** (CPU, fast CI): client+server in one process over zenoh peer-to-peer (or a localhost `zenohd` started by the fixture), tiny-ACT, fake 2-camera robot, N=8 concurrent sessions. The headline regression: two sessions with different joint states must not cross-contaminate `RelativeActionsProcessorStep` postprocessing — the test that proves the multi-tenancy claim.
|
||||
- **Chaos**: kill the server mid-episode → client returns `None`, never raises into the control loop, `failed` stays False within `max_offline_s`, resumes on restart; `docker kill zenohd` → liveliness flap → safe state → re-handshake (explicitly tests re-declaration behavior, flagged unverified upstream); SIGTERM drain → in-flight chunk completes, clients reconnect invisibly.
|
||||
- **Golden parity**: remote RTC vs local `RTCInferenceEngine` on identical observation sequences → byte-identical merged queues (the re-anchoring contract test). Gate for any real-robot remote-RTC use.
|
||||
|
||||
---
|
||||
|
||||
## 14. Roadmap
|
||||
|
||||
1. **PR1 — schema & codecs** (no torch deps): `policy_server/{schema,codec,manifest}.py`, key-expr sanitizer, golden codec tests.
|
||||
2. **PR2 — server core**: session registry, scheduler, validation/allowlist, inference worker with mock policy, loopback harness.
|
||||
3. **PR3 — client engine**: `RemoteInferenceEngine`, factory/context weightless integration, loopback integration + chaos + golden-parity tests.
|
||||
4. **PR4 — ops & docs**: Dockerfile, health/metrics, drain, ACL examples, `remote_inference.mdx`, rollout docstring.
|
||||
5. **Landing PR — legacy deletion**: remove `async_inference/` + tests + docs + proto service (RL suite gates), `[async]` extra swap.
|
||||
6. **Pre-release field validation**: one real robot on a lossy network (watchdog default tuning); JPEG q90 vs raw A/B on one policy (train/serve shift).
|
||||
7. **Future**: micro-batching (needs per-sample `inference_delay` across policy families), client-side downscale-to-policy-resolution (config-only shapes make it possible), Advanced Pub/Sub on the action topic, per-robot quotas, dataset provenance columns, `supports_stateless_chunking` attribute upstreamed to policy classes.
|
||||
|
||||
---
|
||||
|
||||
## 15. Open Risks
|
||||
|
||||
| Risk | Mitigation / decision needed |
|
||||
| ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Re-anchoring parity (server-side relative-prefix re-anchor vs `rtc.py`) | Golden parity test (§13.6) is a hard gate before robot use; likely failure mode is normalizer dtype/device drift |
|
||||
| First-chunk over-trim when idle: `merge` trims `ceil(L/dt)` even when nothing was consumed (queue empty at episode start) — wasteful at network latencies (600 ms ⇒ 18 steps) | Proposed clamp `real_delay = min(real_delay, last_index - idx_before)` touches the shared `ActionQueue` used by local RTC — needs sign-off + regression tests |
|
||||
| JPEG train/serve distribution shift | Unmeasured; A/B before locking q90 default (roadmap §14.6) |
|
||||
| Watchdog defaults untuned (`request_timeout_s=5`, `degraded_after_s=1`, `max_action_age_s=3`) | Field validation on wired and Wi-Fi; consider named profiles |
|
||||
| Capability check can pass while semantics differ (different finetune, different normalization stats, identical feature names) | Add checkpoint hash/revision pinning to SessionAck — decide in PR2 |
|
||||
| zenoh-python long-session maturity: re-declaration after router restart partially verified; SHM unstable; no asyncio | Chaos tests own this; thread-based design avoids the asyncio gap entirely |
|
||||
| Router ACL reload requires restart | Operational runbook: cert/ACL changes = rolling router restart |
|
||||
| `fallback=zero` has no consumer until velocity actions land in rollout (only `.pos` features routed today) | Validate the enum against robot capabilities when velocity support lands |
|
||||
| Per-client mailbox memory under fleet-scale wildcard subscription | One decoded-obs slot per client is small; add an LRU GC tied to liveliness drops |
|
||||
@@ -0,0 +1,82 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This Dockerfile builds a GPU inference pod for `lerobot-policy-server`
|
||||
# (remote inference over Zenoh). It starts from an NVIDIA CUDA base image;
|
||||
# the cu128 PyTorch wheels bundle their own CUDA runtime (driver floor 570.86,
|
||||
# see pyproject.toml [tool.uv]).
|
||||
|
||||
# docker build -f docker/Dockerfile.policy-server -t lerobot-policy-server .
|
||||
# docker run --gpus all -v ./server.yaml:/etc/lerobot/server.yaml lerobot-policy-server
|
||||
#
|
||||
# Extra policy-family dependencies (e.g. pi0/smolvla need transformers) can be
|
||||
# added at build time:
|
||||
# docker build -f docker/Dockerfile.policy-server \
|
||||
# --build-arg LEROBOT_EXTRAS="async pi0" -t lerobot-policy-server .
|
||||
|
||||
# Configure the base image (same CUDA family as Dockerfile.internal)
|
||||
ARG CUDA_VERSION=12.8.1
|
||||
ARG OS_VERSION=24.04
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
|
||||
|
||||
# Define Python version and lerobot extras arguments
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ARG LEROBOT_EXTRAS="async"
|
||||
|
||||
# Configure environment variables
|
||||
ENV DEBIAN_FRONTEND=noninteractive \
|
||||
PATH=/lerobot/.venv/bin:$PATH
|
||||
|
||||
# Install system dependencies and uv (as root).
|
||||
# Kept lean: no hardware/teleop libraries — this image only serves policies.
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
git curl ca-certificates libglib2.0-0 ffmpeg \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
|
||||
&& mv /root/.local/bin/uv /usr/local/bin/uv \
|
||||
&& useradd --create-home --shell /bin/bash user_lerobot \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create application directory and set permissions
|
||||
WORKDIR /lerobot
|
||||
RUN chown -R user_lerobot:user_lerobot /lerobot
|
||||
|
||||
# Switch to the non-root user
|
||||
USER user_lerobot
|
||||
|
||||
# Model checkpoints are cached under HF_HOME — mount it as a volume
|
||||
# (or a PVC in Kubernetes) so warm restarts skip the Hub download.
|
||||
ENV HOME=/home/user_lerobot \
|
||||
HF_HOME=/home/user_lerobot/.cache/huggingface \
|
||||
HF_LEROBOT_HOME=/home/user_lerobot/.cache/huggingface/lerobot \
|
||||
TORCH_HOME=/home/user_lerobot/.cache/torch \
|
||||
TRITON_CACHE_DIR=/home/user_lerobot/.cache/triton
|
||||
|
||||
# Create the virtual environment (Python provisioned by uv)
|
||||
RUN uv venv --python ${PYTHON_VERSION}
|
||||
|
||||
# Install lerobot from the build context with the async extra
|
||||
# (eclipse-zenoh + msgpack — see pyproject.toml [project.optional-dependencies])
|
||||
COPY --chown=user_lerobot:user_lerobot setup.py pyproject.toml uv.lock README.md MANIFEST.in ./
|
||||
COPY --chown=user_lerobot:user_lerobot src/ src/
|
||||
|
||||
RUN uv sync --locked --no-cache $(printf -- '--extra %s ' ${LEROBOT_EXTRAS})
|
||||
|
||||
# HTTP health + Prometheus metrics (manifest `health_port`, 0 disables)
|
||||
EXPOSE 9100
|
||||
|
||||
# The manifest is typically mounted as a ConfigMap (Kubernetes) or a bind
|
||||
# mount (docker run -v) at /etc/lerobot/server.yaml; any field can also be
|
||||
# overridden on the command line, e.g. --model.repo_or_path=lerobot/pi0_towels
|
||||
ENTRYPOINT ["lerobot-policy-server"]
|
||||
CMD ["--manifest", "/etc/lerobot/server.yaml"]
|
||||
@@ -68,7 +68,7 @@
|
||||
- local: eo1
|
||||
title: EO-1
|
||||
- local: groot
|
||||
title: NVIDIA GR00T
|
||||
title: NVIDIA GR00T N1.5
|
||||
- local: xvla
|
||||
title: X-VLA
|
||||
- local: multi_task_dit
|
||||
@@ -87,8 +87,8 @@
|
||||
- sections:
|
||||
- local: inference
|
||||
title: Policy Deployment (lerobot-rollout)
|
||||
- local: async
|
||||
title: Use Async Inference
|
||||
- local: remote_inference
|
||||
title: Remote Inference (lerobot-policy-server)
|
||||
- local: rtc
|
||||
title: Real-Time Chunking (RTC)
|
||||
title: "Inference"
|
||||
|
||||
@@ -1,313 +0,0 @@
|
||||
# Asynchronous Inference
|
||||
|
||||
With our [SmolVLA](https://huggingface.co/papers/2506.01844) we introduced a new way to run inference on real-world robots, **decoupling action prediction from action execution**.
|
||||
In this tutorial, we'll show how to use asynchronous inference (_async inference_) using a finetuned version of SmolVLA, and all the policies supported by LeRobot.
|
||||
**Try async inference with all the policies** supported by LeRobot!
|
||||
|
||||
**What you'll learn:**
|
||||
|
||||
1. Why asynchronous inference matters and how it compares to, more traditional, sequential inference.
|
||||
2. How to spin-up a `PolicyServer` and connect a `RobotClient` from the same machine, and even over the network.
|
||||
3. How to tune key parameters (`actions_per_chunk`, `chunk_size_threshold`) for your robot and policy.
|
||||
|
||||
If you get stuck, hop into our [Discord community](https://discord.gg/s3KuuzsPFb)!
|
||||
|
||||
In a nutshell: with _async inference_, your robot keeps acting while the policy server is already busy computing the next chunk of actions---eliminating "wait-for-inference" lags and unlocking smoother, more reactive behaviours.
|
||||
This is fundamentally different from synchronous inference (sync), where the robot stays idle while the policy computes the next chunk of actions.
|
||||
|
||||
---
|
||||
|
||||
## Getting started with async inference
|
||||
|
||||
You can read more information on asynchronous inference in our [blogpost](https://huggingface.co/blog/async-robot-inference). This guide is designed to help you quickly set up and run asynchronous inference in your environment.
|
||||
|
||||
First, install `lerobot` with the `async` tag, to install the extra dependencies required to run async inference.
|
||||
|
||||
```shell
|
||||
pip install -e ".[async]"
|
||||
```
|
||||
|
||||
Then, spin up a policy server (in one terminal, or in a separate machine) specifying the host address and port for the client to connect to.
|
||||
You can spin up a policy server running:
|
||||
|
||||
```shell
|
||||
python -m lerobot.async_inference.policy_server \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080
|
||||
```
|
||||
|
||||
This will start a policy server listening on `127.0.0.1:8080` (`localhost`, port 8080). At this stage, the policy server is empty, as all information related to which policy to run and with which parameters are specified during the first handshake with the client. Spin up a client with:
|
||||
|
||||
```shell
|
||||
python -m lerobot.async_inference.robot_client \
|
||||
--server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
|
||||
--robot.type=so100_follower \ # ROBOT: your robot type
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
|
||||
--robot.id=follower_so100 \ # ROBOT: your robot id, to load calibration file
|
||||
--robot.cameras="{ laptop: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}, phone: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ # POLICY: the cameras used to acquire frames, with keys matching the keys expected by the policy
|
||||
--task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
|
||||
--policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
|
||||
--pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
|
||||
--policy_device=mps \ # POLICY: the device to run the policy on, on the server (cuda, mps, xpu, cpu)
|
||||
--actions_per_chunk=50 \ # POLICY: the number of actions to output at once
|
||||
--chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
|
||||
--aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
|
||||
--debug_visualize_queue_size=True # CLIENT: whether to visualize the queue size at runtime
|
||||
```
|
||||
|
||||
In summary, you need to specify instructions for:
|
||||
|
||||
- `SERVER`: the address and port of the policy server
|
||||
- `ROBOT`: the type of robot to connect to, the port to connect to, and the local `id` of the robot
|
||||
- `POLICY`: the type of policy to run, and the model name/path on server to the checkpoint to run. You also need to specify which device should the sever be using, and how many actions to output at once (capped at the policy max actions value).
|
||||
- `CLIENT`: the threshold for the chunk size before sending a new observation to the server, and the function to aggregate actions on overlapping portions. Optionally, you can also visualize the queue size at runtime, to help you tune the `CLIENT` parameters.
|
||||
|
||||
Importantly,
|
||||
|
||||
- `actions_per_chunk` and `chunk_size_threshold` are key parameters to tune for your setup.
|
||||
- `aggregate_fn_name` is the function to aggregate actions on overlapping portions. You can either add a new one to a registry of functions, or add your own in `robot_client.py` (see [here](NOTE:addlinktoLOC))
|
||||
- `debug_visualize_queue_size` is a useful tool to tune the `CLIENT` parameters.
|
||||
|
||||
## Done! You should see your robot moving around by now 😉
|
||||
|
||||
## Async vs. synchronous inference
|
||||
|
||||
Synchronous inference relies on interleaving action chunk prediction and action execution. This inherently results in _idle frames_, frames where the robot awaits idle the policy's output: a new action chunk.
|
||||
In turn, inference is plagued by evident real-time lags, where the robot simply stops acting due to the lack of available actions.
|
||||
With robotics models increasing in size, this problem risks becoming only more severe.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/async-inference/sync.png"
|
||||
width="80%"
|
||||
></img>
|
||||
</p>
|
||||
<p align="center">
|
||||
<i>Synchronous inference</i> makes the robot idle while the policy is
|
||||
computing the next chunk of actions.
|
||||
</p>
|
||||
|
||||
To overcome this, we design async inference, a paradigm where action planning and execution are decoupled, resulting in (1) higher adaptability and, most importantly, (2) no idle frames.
|
||||
Crucially, with async inference, the next action chunk is computed _before_ the current one is exhausted, resulting in no idleness.
|
||||
Higher adaptability is ensured by aggregating the different action chunks on overlapping portions, obtaining an up-to-date plan and a tighter control loop.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/async-inference/async.png"
|
||||
width="80%"
|
||||
></img>
|
||||
</p>
|
||||
<p align="center">
|
||||
<i>Asynchronous inference</i> results in no idleness because the next chunk is
|
||||
computed before the current chunk is exhausted.
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
## Start the Policy Server
|
||||
|
||||
Policy servers are wrappers around a `PreTrainedPolicy` interfacing them with observations coming from a robot client.
|
||||
Policy servers are initialized as empty containers which are populated with the requested policy specified in the initial handshake between the robot client and the policy server.
|
||||
As such, spinning up a policy server is as easy as specifying the host address and port. If you're running the policy server on the same machine as the robot client, you can use `localhost` as the host address.
|
||||
|
||||
<hfoptions id="start_policy_server">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.async_inference.policy_server \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.async_inference.configs import PolicyServerConfig
|
||||
from lerobot.async_inference.policy_server import serve
|
||||
|
||||
config = PolicyServerConfig(
|
||||
host="localhost",
|
||||
port=8080,
|
||||
)
|
||||
serve(config)
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
This listens on `localhost:8080` for an incoming connection from the associated`RobotClient`, which will communicate which policy to run during the first client-server handshake.
|
||||
|
||||
---
|
||||
|
||||
## Launch the Robot Client
|
||||
|
||||
`RobotClient` is a wrapper around a `Robot` instance, which `RobotClient` connects to the (possibly remote) `PolicyServer`.
|
||||
The `RobotClient` streams observations to the `PolicyServer`, and receives action chunks obtained running inference on the server (which we assume to have better computational resources than the robot controller).
|
||||
|
||||
<hfoptions id="start_robot_client">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.async_inference.robot_client \
|
||||
--server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
|
||||
--robot.type=so100_follower \ # ROBOT: your robot type
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
|
||||
--robot.id=follower_so100 \ # ROBOT: your robot id, to load calibration file
|
||||
--robot.cameras="{ laptop: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}, phone: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ # POLICY: the cameras used to acquire frames, with keys matching the keys expected by the policy
|
||||
--task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
|
||||
--policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
|
||||
--pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
|
||||
--policy_device=mps \ # POLICY: the device to run the policy on, on the server
|
||||
--actions_per_chunk=50 \ # POLICY: the number of actions to output at once
|
||||
--chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
|
||||
--aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
|
||||
--debug_visualize_queue_size=True # CLIENT: whether to visualize the queue size at runtime
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
import threading
|
||||
from lerobot.robots.so_follower import SO100FollowerConfig
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.async_inference.configs import RobotClientConfig
|
||||
from lerobot.async_inference.robot_client import RobotClient
|
||||
from lerobot.async_inference.helpers import visualize_action_queue_size
|
||||
|
||||
# 1. Create the robot instance
|
||||
"""Check out the cameras available in your setup by running `python lerobot/find_cameras.py`"""
|
||||
# these cameras must match the ones expected by the policy
|
||||
# check the config.json on the Hub for the policy you are using
|
||||
camera_cfg = {
|
||||
"top": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem585A0076841",
|
||||
id="follower_so100",
|
||||
cameras=camera_cfg
|
||||
)
|
||||
|
||||
# 3. Create client configuration
|
||||
client_cfg = RobotClientConfig(
|
||||
robot=robot_cfg,
|
||||
server_address="localhost:8080",
|
||||
policy_device="mps",
|
||||
client_device="cpu",
|
||||
policy_type="smolvla",
|
||||
pretrained_name_or_path="<user>/smolvla_async",
|
||||
chunk_size_threshold=0.5,
|
||||
actions_per_chunk=50, # make sure this is less than the max actions of the policy
|
||||
)
|
||||
|
||||
# 4. Create and start client
|
||||
client = RobotClient(client_cfg)
|
||||
|
||||
# 5. Specify the task
|
||||
task = "Don't do anything, stay still"
|
||||
|
||||
if client.start():
|
||||
# Start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
action_receiver_thread.start()
|
||||
|
||||
try:
|
||||
# Run the control loop
|
||||
client.control_loop(task)
|
||||
except KeyboardInterrupt:
|
||||
client.stop()
|
||||
action_receiver_thread.join()
|
||||
# (Optionally) plot the action queue size
|
||||
visualize_action_queue_size(client.action_queue_size)
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
The following two parameters are key in every setup:
|
||||
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Hyperparameter</th>
|
||||
<th>Default</th>
|
||||
<th>What it does</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td>
|
||||
<code>actions_per_chunk</code>
|
||||
</td>
|
||||
<td>50</td>
|
||||
<td>
|
||||
How many actions the policy outputs at once. Typical values: 10-50.
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<code>chunk_size_threshold</code>
|
||||
</td>
|
||||
<td>0.7</td>
|
||||
<td>
|
||||
When the queue is ≤ 50% full, the client sends a fresh observation.
|
||||
Value in [0, 1].
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
<Tip>
|
||||
Different values of `actions_per_chunk` and `chunk_size_threshold` do result
|
||||
in different behaviours.
|
||||
</Tip>
|
||||
|
||||
On the one hand, increasing the value of `actions_per_chunk` will result in reducing the likelihood of ending up with no actions to execute, as more actions will be available when the new chunk is computed.
|
||||
However, larger values of `actions_per_chunk` might also result in less precise actions, due to the compounding errors consequent to predicting actions over longer timespans.
|
||||
|
||||
On the other hand, increasing the value of `chunk_size_threshold` will result in sending out to the `PolicyServer` observations for inference more often, resulting in a larger number of updates action chunks, overlapping on significant portions. This results in high adaptability, in the limit predicting one action chunk for each observation, which is in turn only marginally consumed while a new one is produced.
|
||||
This option does also put more pressure on the inference pipeline, as a consequence of the many requests. Conversely, values of `chunk_size_threshold` close to 0.0 collapse to the synchronous edge case, whereby new observations are only sent out whenever the current chunk is exhausted.
|
||||
|
||||
We found the default values of `actions_per_chunk` and `chunk_size_threshold` to work well in the experiments we developed for the [SmolVLA paper](https://huggingface.co/papers/2506.01844), but recommend experimenting with different values to find the best fit for your setup.
|
||||
|
||||
### Tuning async inference for your setup
|
||||
|
||||
1. **Choose your computational resources carefully.** [PI0](https://huggingface.co/lerobot/pi0) occupies 14GB of memory at inference time, while [SmolVLA](https://huggingface.co/lerobot/smolvla_base) requires only ~2GB. You should identify the best computational resource for your use case keeping in mind smaller policies require less computational resources. The combination of policy and device used (CPU-intensive, using MPS, or the number of CUDA cores on a given NVIDIA GPU) directly impacts the average inference latency you should expect.
|
||||
2. **Adjust your `fps` based on inference latency.** While the server generates a new action chunk, the client is not idle and is stepping through its current action queue. If the two processes happen at fundamentally different speeds, the client might end up with an empty queue. As such, you should reduce your fps if you consistently run out of actions in queue.
|
||||
3. **Adjust `chunk_size_threshold`**.
|
||||
- Values closer to `0.0` result in almost sequential behavior. Values closer to `1.0` → send observation every step (more bandwidth, relies on good world-model).
|
||||
- We found values around 0.5-0.6 to work well. If you want to tweak this, spin up a `RobotClient` setting the `--debug_visualize_queue_size` to `True`. This will plot the action queue size evolution at runtime, and you can use it to find the value of `chunk_size_threshold` that works best for your setup.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/async-inference/queues.png"
|
||||
width="80%"
|
||||
></img>
|
||||
</p>
|
||||
<p align="center">
|
||||
<i>
|
||||
The action queue size is plotted at runtime when the
|
||||
`--debug_visualize_queue_size` flag is passed, for various levels of
|
||||
`chunk_size_threshold` (`g` in the SmolVLA paper).
|
||||
</i>
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
## Conclusion
|
||||
|
||||
Asynchronous inference represents a significant advancement in real-time robotics control, addressing the fundamental challenge of inference latency that has long plagued robotics applications. Through this tutorial, you've learned how to implement a complete async inference pipeline that eliminates idle frames and enables smoother, more reactive robot behaviors.
|
||||
|
||||
**Key Takeaways:**
|
||||
|
||||
- **Paradigm Shift**: Async inference decouples action prediction from execution, allowing robots to continue acting while new action chunks are computed in parallel
|
||||
- **Performance Benefits**: Eliminates "wait-for-inference" lags that are inherent in synchronous approaches, becoming increasingly important as policy models grow larger
|
||||
- **Flexible Architecture**: The server-client design enables distributed computing, where inference can run on powerful remote hardware while maintaining real-time robot control
|
||||
- **Tunable Parameters**: Success depends on properly configuring `actions_per_chunk` and `chunk_size_threshold` for your specific hardware, policy, and task requirements
|
||||
- **Universal Compatibility**: Works with all LeRobot-supported policies, from lightweight ACT models to vision-language models like SmolVLA
|
||||
|
||||
Start experimenting with the default parameters, monitor your action queue sizes, and iteratively refine your setup to achieve optimal performance for your specific use case.
|
||||
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/huggingface/lerobot/issues).
|
||||
@@ -193,7 +193,7 @@ To learn more about training policies with LeRobot, please refer to the training
|
||||
|
||||
- [SmolVLA](./smolvla)
|
||||
- [Pi0.5](./pi05)
|
||||
- [GR00T N1.7](./groot)
|
||||
- [GR00T N1.5](./groot)
|
||||
|
||||
Sample IsaacLab Arena datasets are available on HuggingFace Hub for experimentation:
|
||||
|
||||
|
||||
+30
-79
@@ -1,19 +1,16 @@
|
||||
# GR00T Policy
|
||||
# GR00T N1.5 Policy
|
||||
|
||||
GR00T is an NVIDIA foundation model family for generalized humanoid robot reasoning and skills. It is a cross-embodiment policy that accepts multimodal input, including language, images, and proprioception, to perform manipulation tasks in diverse environments.
|
||||
GR00T N1.5 is an open foundation model from NVIDIA designed for generalized humanoid robot reasoning and skills. It is a cross-embodiment model that accepts multimodal input, including language and images, to perform manipulation tasks in diverse environments.
|
||||
|
||||
LeRobot integrates GR00T N1.7 through the `groot` policy type.
|
||||
|
||||
> [!WARNING]
|
||||
> **Breaking change:** GR00T N1.5 support was removed from LeRobot, and current releases support GR00T N1.7 only. N1.5 checkpoints, configs, and `--policy.model_version=n1.5` are rejected with a clear error. To keep using an N1.5 checkpoint, pin the last release that supports it: `pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 (`model_version='n1.7'`, base model [`nvidia/GR00T-N1.7-3B`](https://huggingface.co/nvidia/GR00T-N1.7-3B)).
|
||||
This document outlines the specifics of its integration and usage within the LeRobot framework.
|
||||
|
||||
## Model Overview
|
||||
|
||||
GR00T N1.7 uses a Cosmos-Reason2/Qwen3-VL backbone and provides checkpoints for SimplerEnv, DROID, and LIBERO.
|
||||
NVIDIA Isaac GR00T N1.5 is an upgraded version of the GR00T N1 foundation model. It is built to improve generalization and language-following abilities for humanoid robots.
|
||||
|
||||
Developers and researchers can post-train GR00T with their own real or synthetic data to adapt it for specific humanoid robots or tasks.
|
||||
Developers and researchers can post-train GR00T N1.5 with their own real or synthetic data to adapt it for specific humanoid robots or tasks.
|
||||
|
||||
GR00T uses pre-trained vision and language encoders with a flow matching action transformer to model a chunk of actions conditioned on vision, language, and proprioception.
|
||||
GR00T N1.5 (specifically the GR00T-N1.5-3B model) is built using pre-trained vision and language encoders. It utilizes a flow matching action transformer to model a chunk of actions, conditioned on vision, language, and proprioception.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-groot-paper1%20(1).png"
|
||||
@@ -31,46 +28,33 @@ This approach allows the model to be highly adaptable through post-training for
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
GR00T is intended for NVIDIA GPU-accelerated systems. The `groot` extra still includes Flash Attention on non-macOS platforms, and Flash Attention needs a compatible PyTorch/CUDA environment before it is installed. Install the dependencies in this order:
|
||||
As of today, GR00T N1.5 requires flash attention for it's internal working.
|
||||
|
||||
1. Follow the Environment Setup in the [Installation Guide](./installation). Do not install `lerobot` yet.
|
||||
2. Install PyTorch, TorchVision, and the build dependencies used by Flash Attention:
|
||||
|
||||
```bash
|
||||
# Check https://pytorch.org/get-started/locally/ for the right CUDA wheel index for your system.
|
||||
pip install "torch>=2.7,<2.12.0" "torchvision>=0.22.0,<0.27.0" \
|
||||
--index-url https://download.pytorch.org/whl/cu128
|
||||
pip install "ninja>=1.11.1,<2.0.0" "packaging>=24.2,<26.0"
|
||||
```
|
||||
|
||||
3. Install and verify Flash Attention:
|
||||
We are working on making this optional, but in the meantime that means that we require an extra installation step and it can only be used in CUDA enabled devices.
|
||||
|
||||
1. Following the Environment Setup of our [Installation Guide](./installation). **Attention** don't install `lerobot` in this step.
|
||||
2. Install [Flash Attention](https://github.com/Dao-AILab/flash-attention) by running:
|
||||
|
||||
```bash
|
||||
# Check https://pytorch.org/get-started/locally/ for your system
|
||||
pip install "torch>=2.2.1,<2.8.0" "torchvision>=0.21.0,<0.23.0" # --index-url https://download.pytorch.org/whl/cu1XX
|
||||
pip install ninja "packaging>=24.2,<26.0" # flash attention dependencies
|
||||
pip install "flash-attn>=2.5.9,<3.0.0" --no-build-isolation
|
||||
python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} imported successfully')"
|
||||
```
|
||||
|
||||
4. Install LeRobot with the GR00T extra:
|
||||
3. Install LeRobot by running:
|
||||
|
||||
```bash
|
||||
pip install "lerobot[groot]"
|
||||
pip install lerobot[groot]
|
||||
```
|
||||
|
||||
For a source checkout, use the same order, then install the local package with:
|
||||
|
||||
```bash
|
||||
pip install -e ".[groot]"
|
||||
```
|
||||
|
||||
If your CUDA/PyTorch build needs a different Flash Attention wheel or source build, follow the [Flash Attention project](https://github.com/Dao-AILab/flash-attention) instructions, but keep the same ordering: PyTorch first, Flash Attention next, then `lerobot[groot]`.
|
||||
|
||||
## Usage
|
||||
|
||||
To use GR00T N1.7:
|
||||
To use GR00T in your LeRobot configuration, specify the policy type as:
|
||||
|
||||
```bash
|
||||
--policy.type=groot \
|
||||
--policy.model_version=n1.7
|
||||
```python
|
||||
policy.type=groot
|
||||
```
|
||||
|
||||
## Training
|
||||
@@ -103,54 +87,21 @@ accelerate launch \
|
||||
|
||||
## Performance Results
|
||||
|
||||
### LIBERO Benchmark Results
|
||||
### Libero Benchmark Results
|
||||
|
||||
> [!NOTE]
|
||||
> Follow the [LIBERO](./libero) setup instructions before running `lerobot-eval`.
|
||||
> Follow our instructions for Libero usage: [Libero](./libero)
|
||||
|
||||
GR00T N1.7 has demonstrated strong performance on the LIBERO benchmark suite. To reproduce LeRobot results, follow the instructions in the [LIBERO](./libero) section.
|
||||
GR00T has demonstrated strong performance on the Libero benchmark suite. To compare and test its LeRobot implementation, we finetuned the GR00T N1.5 model for 30k steps on the Libero dataset and compared the results to the GR00T reference results.
|
||||
|
||||
### GR00T N1.7 LIBERO Checkpoints
|
||||
| Benchmark | LeRobot Implementation | GR00T Reference |
|
||||
| ------------------ | ---------------------- | --------------- |
|
||||
| **Libero Spatial** | 82.0% | 92.0% |
|
||||
| **Libero Object** | 99.0% | 92.0% |
|
||||
| **Libero Long** | 82.0% | 76.0% |
|
||||
| **Average** | 87.0% | 87.0% |
|
||||
|
||||
NVIDIA publishes GR00T N1.7 LIBERO checkpoints at [`nvidia/GR00T-N1.7-LIBERO`](https://huggingface.co/nvidia/GR00T-N1.7-LIBERO), with one subdirectory per LIBERO suite:
|
||||
|
||||
| Suite | Checkpoint subdirectory |
|
||||
| -------------- | ----------------------- |
|
||||
| LIBERO Spatial | `libero_spatial` |
|
||||
| LIBERO Object | `libero_object` |
|
||||
| LIBERO Goal | `libero_goal` |
|
||||
| LIBERO 10 | `libero_10` |
|
||||
|
||||
Preliminary LeRobot integration results:
|
||||
|
||||
| Suite | Status | Success rate | n_episodes |
|
||||
| -------------- | ------ | -----------: | ---------: |
|
||||
| LIBERO Spatial | ✓ | ~95% | XX |
|
||||
| LIBERO Object | ✓ | XX% | XX |
|
||||
| LIBERO Goal | ✓ | XX% | XX |
|
||||
| LIBERO 10 | ✓ | XX% | XX |
|
||||
| **Average** | ✓ | **XX%** | **XX** |
|
||||
|
||||
Replace the `XX` placeholders with final eval artifacts before merge.
|
||||
|
||||
Download the suite checkpoint locally, then point `--policy.base_model_path` at the downloaded subdirectory. `--policy.path` is reserved for LeRobot checkpoints that contain a LeRobot `config.json` with a `type` field.
|
||||
|
||||
```bash
|
||||
hf download nvidia/GR00T-N1.7-LIBERO \
|
||||
--include "libero_spatial/*" \
|
||||
--local-dir ./GR00T-N1.7-LIBERO
|
||||
|
||||
lerobot-eval \
|
||||
--policy.type=groot \
|
||||
--policy.model_version=n1.7 \
|
||||
--policy.base_model_path=./GR00T-N1.7-LIBERO/libero_spatial \
|
||||
--policy.embodiment_tag=libero_sim \
|
||||
--env.type=libero \
|
||||
--env.task=libero_spatial \
|
||||
--eval.n_episodes=50
|
||||
```
|
||||
|
||||
Use `eval.n_episodes >= 50` per suite when reporting success rates.
|
||||
These results demonstrate GR00T's strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section.
|
||||
|
||||
### Evaluate in your hardware setup
|
||||
|
||||
@@ -180,4 +131,4 @@ lerobot-rollout\
|
||||
|
||||
## License
|
||||
|
||||
GR00T N1.7 is released under the [NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/).
|
||||
This model follows NVIDIA's proprietary license, consistent with the original [GR00T repository](https://github.com/NVIDIA/Isaac-GR00T). Future versions (starting from N1.7) will follow **Apache 2.0 License**.
|
||||
|
||||
@@ -647,5 +647,6 @@ The `--strategy.type` flag selects the execution mode:
|
||||
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
|
||||
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
|
||||
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
|
||||
- `episodic`: Episode-oriented policy recording with reset phases between episodes
|
||||
|
||||
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
|
||||
|
||||
@@ -157,6 +157,44 @@ Foot pedal input is also supported via `--strategy.input_device=pedal`. Configur
|
||||
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
|
||||
| `--teleop.type` | **Required.** Teleoperator type |
|
||||
|
||||
### Episodic (`--strategy.type=episodic`)
|
||||
|
||||
Episode-oriented recording that mirrors the behavior of `lerobot-record`. The policy drives the robot for each episode; an optional teleoperator can drive the robot during the reset phase between episodes.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=episodic \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/ttyACM1 \
|
||||
--dataset.repo_id=${HF_USER}/my_eval_data \
|
||||
--dataset.num_episodes=20 \
|
||||
--dataset.episode_time_s=30 \
|
||||
--dataset.reset_time_s=10 \
|
||||
--dataset.single_task="Pick up the red cube"
|
||||
```
|
||||
|
||||
Teleop is optional — if omitted the robot holds its position during the reset phase.
|
||||
|
||||
**Keyboard controls:**
|
||||
|
||||
| Key | Action |
|
||||
| ----------- | -------------------------------- |
|
||||
| `→` (right) | End the current episode early |
|
||||
| `←` (left) | Discard episode and re-record it |
|
||||
| `ESC` | Stop the recording session |
|
||||
|
||||
| Flag | Description |
|
||||
| ----------------------------------------------- | -------------------------------------------------------------------------- |
|
||||
| `--dataset.num_episodes` | Number of episodes to record |
|
||||
| `--dataset.episode_time_s` | Duration of each recording episode in seconds |
|
||||
| `--dataset.reset_time_s` | Duration of the reset phase between episodes in seconds |
|
||||
| `--teleop.type` | Optional. Teleoperator to drive the robot during resets |
|
||||
| `--strategy.reset_to_initial_position` | Whether to reset the robot to its initial position between episodes |
|
||||
| `--strategy.smooth_leader_to_follower_handover` | Whether to turn on or off the leader -> follower smooth handover behavior. |
|
||||
|
||||
---
|
||||
|
||||
## Inference Backends
|
||||
|
||||
@@ -1,13 +1,6 @@
|
||||
## Research Paper
|
||||
|
||||
GR00T N1 technical report (covers the GR00T N1.x family, including N1.7): https://arxiv.org/abs/2503.14734
|
||||
|
||||
GR00T N1.7 model card: https://huggingface.co/nvidia/GR00T-N1.7-3B
|
||||
|
||||
GR00T N1.5 research page (earlier version): https://research.nvidia.com/labs/gear/gr00t-n1_5/
|
||||
|
||||
> GR00T N1.5 support was removed from LeRobot; the last release supporting it is `lerobot==0.5.1`.
|
||||
> Current releases support GR00T N1.7 only.
|
||||
Paper: https://research.nvidia.com/labs/gear/gr00t-n1_5/
|
||||
|
||||
## Repository
|
||||
|
||||
@@ -31,103 +24,4 @@ Code: https://github.com/NVIDIA/Isaac-GR00T
|
||||
|
||||
Blog: https://developer.nvidia.com/isaac/gr00t
|
||||
|
||||
Hugging Face Models:
|
||||
|
||||
- GR00T N1.7: https://huggingface.co/nvidia/GR00T-N1.7-3B
|
||||
- GR00T N1.7 LIBERO checkpoints: https://huggingface.co/nvidia/GR00T-N1.7-LIBERO
|
||||
|
||||
## Original-vs-LeRobot parity test
|
||||
|
||||
`tests/policies/groot/test_groot_vs_original.py` verifies this LeRobot
|
||||
reimplementation of GR00T N1.7 (Qwen3-VL backbone + flow-matching action head)
|
||||
against NVIDIA's original `gr00t` package with two comparisons, each parametrized
|
||||
over every embodiment tag present in the checkpoint:
|
||||
|
||||
1. **Model parity** — given byte-identical pre-processed inputs and the same
|
||||
flow-matching seed (recorded in each artifact), both implementations must produce
|
||||
the **same raw model output** (`get_action(...)["action_pred"]`, the normalized
|
||||
flow-matching prediction). Output shapes must match exactly; any action-horizon
|
||||
or action-dim mismatch fails the test.
|
||||
2. **Preprocessor parity** — given the identical raw observations (per-camera
|
||||
frames, state vectors, language instruction), LeRobot's own preprocessor pipeline
|
||||
(real Qwen3-VL chat template / tokenizer / image packing + checkpoint-driven
|
||||
state normalization, no mocks) must produce the **same collated model inputs**
|
||||
(`input_ids`, `attention_mask`, `pixel_values`, `image_grid_thw`, `state`,
|
||||
`embodiment_id`) as the original package's processor.
|
||||
|
||||
### Why two environments
|
||||
|
||||
The original `gr00t` package pins `transformers==4.57.3` (Python 3.10); this
|
||||
integration requires `transformers>=5.x` (Qwen3-VL). Under 5.x, `PretrainedConfig`
|
||||
is itself a defaulted dataclass, so the original config dataclasses fail to import
|
||||
(`non-default argument follows default argument`). The two implementations therefore
|
||||
**cannot be imported in the same Python process**.
|
||||
|
||||
So the test uses a **producer / consumer** split across two venvs:
|
||||
|
||||
1. **Producer** — `tests/policies/groot/utils/dump_original_n1_7.py`, run in the _original_
|
||||
gr00t venv. For each embodiment it builds dummy inputs generically from the
|
||||
checkpoint metadata (state dims from `statistics.json`; camera/language keys from
|
||||
the processor modality configs), runs the original model, and saves to one `.npz`
|
||||
per tag: the raw observations (`raw::` keys), the exact collated inputs
|
||||
(`in::` keys), the seed, and the raw `action_pred`.
|
||||
2. **Consumer** — the pytest above, run in the _LeRobot_ venv. It discovers every
|
||||
`.npz`; the model-parity case replays the byte-identical collated inputs through
|
||||
the LeRobot model with the recorded seed and asserts the outputs match, and the
|
||||
preprocessor-parity case replays the raw observations through LeRobot's full
|
||||
preprocessor pipeline and asserts the collated tensors match.
|
||||
|
||||
> Artifacts generated by older versions of the dump script contain no `raw::`
|
||||
> fields; the preprocessor-parity case then **skips** with a regeneration hint.
|
||||
> Re-run the producer to refresh them.
|
||||
|
||||
### Fairness controls
|
||||
|
||||
- **Same pre-processed inputs (model parity)** — the original processor's `input_ids`,
|
||||
`pixel_values`, `image_grid_thw`, `attention_mask`, `state`, `embodiment_id` are
|
||||
fed verbatim to the LeRobot model (no re-tokenization / re-normalization), so the
|
||||
model comparison isolates the model. LeRobot's own tokenization / image packing is
|
||||
covered separately by the preprocessor-parity case, which compares its output
|
||||
against those same collated tensors from identical raw observations.
|
||||
- **Same precision + attention kernel** — both sides run **fp32 + SDPA**. The
|
||||
original defaults to `use_flash_attention=True` (flash_attention_2 + bf16); the
|
||||
producer forces SDPA + fp32. (With the defaults the gap is ~3e-2 — pure
|
||||
kernel/rounding noise, not an implementation difference.)
|
||||
- **Same flow-matching seed** — fixed right before sampling on both sides; the
|
||||
producer records it in each artifact (`--seed`, default 42) and the consumer
|
||||
replays the recorded value.
|
||||
|
||||
### How to run
|
||||
|
||||
```bash
|
||||
# Resolve a local checkpoint (GR00T-N1.7-LIBERO / libero_10)
|
||||
CKPT=$(python - <<'PY'
|
||||
import os
|
||||
from huggingface_hub import snapshot_download
|
||||
print(os.path.join(snapshot_download("nvidia/GR00T-N1.7-LIBERO",
|
||||
allow_patterns=["libero_10/*"]), "libero_10"))
|
||||
PY
|
||||
)
|
||||
|
||||
# 1) Produce the original-side artifacts for all embodiments (original gr00t venv, CUDA)
|
||||
CUDA_VISIBLE_DEVICES=0 /path/to/Isaac-GR00T/.venv-original/bin/python \
|
||||
tests/policies/groot/utils/dump_original_n1_7.py \
|
||||
--ckpt "$CKPT" --out-dir tests/policies/groot/artifacts --device cuda --seed 42
|
||||
|
||||
# 2) Run the parity test (LeRobot venv) — one parametrized case per embodiment
|
||||
CUDA_VISIBLE_DEVICES=0 GROOT_PARITY_DEVICE=cuda \
|
||||
uv run pytest tests/policies/groot/test_groot_vs_original.py -v -s
|
||||
```
|
||||
|
||||
The `.npz` artifacts are local-only (gitignored, ~6–10 MB each) and are regenerated by
|
||||
the producer; they are never committed. The tests **skip** (do not fail) on CI or
|
||||
when the checkpoint / artifacts are absent.
|
||||
|
||||
#### Env knobs (all optional)
|
||||
|
||||
| Var | Default | Purpose |
|
||||
| ----------------------------------------- | -------------------------------- | ------------------------------------- |
|
||||
| `GROOT_N1_7_PARITY_DIR` | `tests/policies/groot/artifacts` | directory of per-tag `.npz` artifacts |
|
||||
| `GROOT_N1_7_LIBERO_CKPT` | auto (HF cache) | override checkpoint dir |
|
||||
| `GROOT_PARITY_DEVICE` | `cuda` if available | `cpu` or `cuda` |
|
||||
| `GROOT_PARITY_ATOL` / `GROOT_PARITY_RTOL` | `1e-3` | comparison tolerance |
|
||||
Hugging Face Model: https://huggingface.co/nvidia/GR00T-N1.5-3B
|
||||
|
||||
@@ -0,0 +1,250 @@
|
||||
# Remote Inference (lerobot-policy-server)
|
||||
|
||||
Remote inference decouples GPU policy inference from robot control. A `lerobot-policy-server` process runs the policy on a GPU machine; the robot runs `lerobot-rollout --inference.type=remote` as a **weightless edge client** — no policy weights, no GPU, no policy processors on the robot. One GPU server can serve several robots at once, and the remote backend works with every rollout strategy (`base`, `sentry`, `highlight`, `dagger`, `episodic`).
|
||||
|
||||
Use remote inference when:
|
||||
|
||||
- The policy is too large or too slow for the machine attached to the robot (e.g. Pi0/Pi0.5 on a Raspberry Pi or laptop edge).
|
||||
- You want one GPU to serve a fleet of robots running the same policy.
|
||||
- You want to update or restart the inference side without touching the robots.
|
||||
|
||||
<Tip>
|
||||
|
||||
Remote inference requires the `async` extra on **both** sides: `pip install 'lerobot[async]'` (installs `eclipse-zenoh` and `msgpack`). The server additionally needs the extras of the policy it serves (e.g. `lerobot[pi]`, `lerobot[smolvla]`).
|
||||
|
||||
</Tip>
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
robot (edge, weightless) GPU machine
|
||||
┌───────────────────────────┐ ┌────────────────────────────┐
|
||||
│ lerobot-rollout │ │ lerobot-policy-server │
|
||||
│ --inference.type=remote │ zenoh │ one process = one │
|
||||
│ │ router │ (model, revision, GPU) │
|
||||
│ control loop @ fps │ ┌────────┐ │ │
|
||||
│ └─ pops local action ◄──┼───┤ zenohd ├─────┼─► inference worker thread │
|
||||
│ buffer (chunks) │ └────────┘ │ (round-robin over │
|
||||
│ │ observations ► │ client sessions) │
|
||||
│ network worker thread ───┼──► ◄ action │ │
|
||||
│ (publishes obs, merges │ chunks │ stateless per request │
|
||||
│ chunks into buffer) │ │ │
|
||||
└───────────────────────────┘ └────────────────────────────┘
|
||||
```
|
||||
|
||||
The client keeps a local **action buffer** filled with chunks of future actions, so the control loop never blocks on the network: short network blips are absorbed by the buffer and the robot keeps moving. The client self-clocks — it requests a new chunk whenever the buffer holds less than `--inference.buffer_time_s` seconds of playback.
|
||||
|
||||
The server is **stateless per request**: clients ship their RTC prefixes and a delay hint with every observation, so a server crash or restart loses zero control state and reconnects are trivial. In production both robots and servers _dial out_ to a `zenohd` router (NAT-friendly: nothing on the robot network needs an open inbound port).
|
||||
|
||||
## Quickstart on a LAN (peer mode, no router)
|
||||
|
||||
For a quick test on one network you can skip the router: the server listens directly and the robot connects to it.
|
||||
|
||||
On the GPU machine:
|
||||
|
||||
```bash
|
||||
lerobot-policy-server \
|
||||
--model.repo_or_path=${HF_USER}/my_pi0_policy \
|
||||
--default_task="pick up the cube" \
|
||||
--zenoh.mode=peer \
|
||||
--zenoh.listen_endpoints='["tcp/0.0.0.0:7447"]'
|
||||
```
|
||||
|
||||
Wait for `Policy server up: ...` (the model is downloaded, loaded, and warmed up first).
|
||||
|
||||
On the robot machine (replace `192.168.1.42` with the GPU machine's IP):
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=${HF_USER}/my_pi0_policy \
|
||||
--inference.type=remote \
|
||||
--inference.zenoh_mode=peer \
|
||||
--inference.connect_endpoint=tcp/192.168.1.42:7447 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="pick up the cube" \
|
||||
--duration=60
|
||||
```
|
||||
|
||||
`--policy.path` on the client resolves to a config-only download (no weights): it is used for pre-flight validation and action ordering, and doubles as the default service address. The client's `--policy.path` and `--task` must match the server's `--model.repo_or_path` and `--default_task` — that pair is the namespace the service is published under (see [Troubleshooting](#troubleshooting)).
|
||||
|
||||
## Production deployment (router)
|
||||
|
||||
In production, run a [zenoh router](https://zenoh.io/docs/getting-started/installation/) (`zenohd`) somewhere both sides can reach, and have robots and servers dial out to it:
|
||||
|
||||
```bash
|
||||
zenohd # listens on tcp/0.0.0.0:7447 by default
|
||||
```
|
||||
|
||||
Configure the server with a YAML manifest:
|
||||
|
||||
```yaml
|
||||
# server.yaml
|
||||
model:
|
||||
repo_or_path: lerobot/pi0_towels
|
||||
revision: main
|
||||
dtype: bfloat16 # optional cast after load
|
||||
device: cuda
|
||||
default_task: "fold the towel"
|
||||
serving_mode: auto # shared for verified chunk-stateless policies, exclusive otherwise
|
||||
max_sessions: 5
|
||||
warmup_inferences: 2
|
||||
trained_fps: 30.0
|
||||
rtc:
|
||||
enabled: true
|
||||
execution_horizon: 10
|
||||
max_guidance_weight: 10.0
|
||||
health_port: 9100 # /healthz + /metrics; 0 disables
|
||||
zenoh:
|
||||
mode: client
|
||||
connect_endpoints: ["tcp/router.gpu-cluster.internal:7447"]
|
||||
```
|
||||
|
||||
```bash
|
||||
lerobot-policy-server --manifest server.yaml
|
||||
```
|
||||
|
||||
Everything in the manifest can also be set directly on the CLI (`--model.repo_or_path=...`, `--max_sessions=...`, etc.). One process serves exactly one `(model, revision, dtype, device)` — to serve two models, or one model on two GPUs, run two processes. Dynamic model loading is deliberately unsupported: pre-warmed processes keep capacity planning honest.
|
||||
|
||||
On the robot, only the endpoint changes (the default `--inference.zenoh_mode=client` is already router mode):
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=lerobot/pi0_towels \
|
||||
--inference.type=remote \
|
||||
--inference.connect_endpoint=tcp/router.gpu-cluster.internal:7447 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="fold the towel" \
|
||||
--duration=600
|
||||
```
|
||||
|
||||
### TLS / mTLS
|
||||
|
||||
For traffic that leaves a trusted network, terminate TLS at the router and give both sides client certificates (all three PEM paths are required together):
|
||||
|
||||
```yaml
|
||||
# server.yaml (zenoh section)
|
||||
zenoh:
|
||||
mode: client
|
||||
connect_endpoints: ["tls/router.gpu-cluster.internal:7447"]
|
||||
tls_root_ca_certificate: /etc/lerobot/ca.pem
|
||||
tls_connect_certificate: /etc/lerobot/server.pem
|
||||
tls_connect_private_key: /etc/lerobot/server.key
|
||||
```
|
||||
|
||||
On the robot the equivalent flags are `--inference.tls_ca`, `--inference.tls_cert`, and `--inference.tls_key`, with `--inference.connect_endpoint=tls/...`.
|
||||
|
||||
<Tip>
|
||||
|
||||
Multicast scouting is always disabled: discovery is configuration, not protocol magic. If nothing connects, check the endpoints — there is no fallback discovery mechanism.
|
||||
|
||||
</Tip>
|
||||
|
||||
## RTC over the network
|
||||
|
||||
The remote engine reuses the [Real-Time Chunking](./rtc) machinery: the client keeps the chunk leftover and latency tracking locally and ships an action prefix plus a delay hint with every observation; the server runs prefix-conditioned chunk generation. This gives the same smooth chunk-to-chunk transitions as local RTC, with network latency folded into the delay computation.
|
||||
|
||||
RTC is enabled by default on both sides (`rtc.enabled: true`). Tune it from the client:
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
... \
|
||||
--inference.type=remote \
|
||||
--inference.rtc.execution_horizon=10 \
|
||||
--inference.rtc.max_guidance_weight=10.0
|
||||
```
|
||||
|
||||
If the server or its policy does not support RTC (only `pi0`, `pi05`, and `smolvla` are RTC-capable, and the server manifest must have `rtc.enabled: true`), the session is **downgraded to plain chunk-append** and the client logs:
|
||||
|
||||
```
|
||||
RTC downgraded to chunk-append (server does not support RTC)
|
||||
```
|
||||
|
||||
The robot still runs — chunks are simply appended to the buffer without prefix blending, which can produce visible seams between chunks on slow policies.
|
||||
|
||||
## Fail-safe behavior
|
||||
|
||||
The client runs a fail-safe state machine (`CONNECTING → STREAMING → DEGRADED → STALLED → RECONNECTING → DEAD`). A bad initial deployment fails fast: `lerobot-rollout` aborts before the robot moves if the handshake or validation fails. Once streaming, faults degrade in stages:
|
||||
|
||||
| Condition | Behavior |
|
||||
| -------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Short network blip / late chunk | The robot rides its action buffer; state goes `DEGRADED` after `--inference.degraded_after_s` (default 1.0 s) without a fresh chunk |
|
||||
| Buffered actions older than `max_action_age_s` | Stale actions are dropped (never executed); default `--inference.max_action_age_s=3.0` |
|
||||
| Buffer runs dry (`STALLED`) | Fallback per `--inference.fallback`: `hold` (default — robot holds its last commanded position), `repeat_last`, or `zero` |
|
||||
| Server liveliness lost / repeated request timeouts | `RECONNECTING`: re-handshake with exponential backoff (`reconnect_initial_backoff_s=0.5` doubling up to `reconnect_max_backoff_s=10.0`) |
|
||||
| Reconnected server runs a different model/revision | Hard refusal (`DEAD`) — the client never executes wrong-model chunks |
|
||||
| Offline longer than `max_offline_s` (default 60 s) | `DEAD`: the engine signals the rollout's shutdown event for a clean stop |
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
`--inference.fallback=zero` is required for velocity-controlled robots: for them "send nothing" means "keep the last velocity", so an explicit zero command is the only safe stop. For position-controlled arms the default `hold` is safe.
|
||||
|
||||
</Tip>
|
||||
|
||||
Server restarts are equally graceful: on SIGTERM the server drops its liveliness token first (clients ride their buffers through the drain), finishes the in-flight inference, and exits. Clients reconnect when the replacement comes up.
|
||||
|
||||
## Serving multiple robots
|
||||
|
||||
`max_sessions` caps concurrent clients per server process. A single inference worker thread serializes GPU access and round-robins over sessions with a pending observation; per-client newest-wins mailboxes mean overload degrades into longer cycle times (larger but correct client-side delays), never into queue buildup.
|
||||
|
||||
A rough capacity estimate, keeping ~20% headroom:
|
||||
|
||||
```
|
||||
N_robots ≈ 0.8 / (rate × inference_time)
|
||||
```
|
||||
|
||||
where `rate` is each robot's chunk-request rate in Hz (how often the client's buffer dips below `buffer_time_s`) and `inference_time` is the server's seconds per chunk. For example, at 100 ms per chunk and ~2 chunk requests per second per robot: `N ≈ 0.8 / (2 × 0.1) = 4` robots.
|
||||
|
||||
The actual serving mode is classified per policy family, never inferred:
|
||||
|
||||
- **shared** — verified chunk-stateless policies (`act`, `pi0`, `pi05`, and `smolvla` with `n_obs_steps=1`) serve up to `max_sessions` clients from one policy instance.
|
||||
- **exclusive** — stateful families (diffusion-family policies, `smolvla` with observation history, and any unverified policy) are forced to `max_sessions=1`. Run one server process per robot for these.
|
||||
|
||||
`serving_mode: auto` (the default) resolves this automatically; you may force `exclusive`, but `shared` can never override a stateful classification.
|
||||
|
||||
## Observability
|
||||
|
||||
With `health_port` set (default 9100), the server exposes:
|
||||
|
||||
- `GET /healthz` — `200 ok` while the inference worker is alive, `503` otherwise. Wire this to your orchestrator's liveness probe.
|
||||
- `GET /metrics` — Prometheus text format: `lerobot_policy_server_requests_total`, `errors_total`, `superseded_total`, `dropped_unknown_client_total`, `sessions_opened_total`, `sessions_closed_total`, `active_sessions`, `server_load`.
|
||||
|
||||
Every inference request also emits one structured audit line on the `lerobot.policy_server.audit` logger:
|
||||
|
||||
```json
|
||||
{
|
||||
"session_id": "9f2c...",
|
||||
"client_uuid": "robot-07",
|
||||
"seq_id": 412,
|
||||
"episode_id": 3,
|
||||
"queue_wait_ms": 1.8,
|
||||
"inference_ms": 93.2,
|
||||
"superseded": 0,
|
||||
"outcome": "ok"
|
||||
}
|
||||
```
|
||||
|
||||
`(session_id, seq_id)` correlates a server-side audit line with the client's request. Set a stable `--inference.client_uuid` per robot (instead of the default fresh UUID per run) for fleet-wide log correlation, and use `--inference.tags` to forward free-form labels in the handshake.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**`No policy server answered status query at '@lerobot/...'`**
|
||||
|
||||
The client found no server under the key it dialed. Either the endpoint is wrong (check `--inference.connect_endpoint`, the router, and firewalls), or the **service namespace** does not match. The namespace is the `(model_id, revision, task)` triple: on the client it comes from `--inference.service_model_id` (default: `--policy.path`), `--inference.service_revision` (default: `main`), and `--inference.service_task` (default: the rollout `--task`); on the server from `model.repo_or_path`, `model.revision`, and `service_name` (default: a slug of `default_task`). A robot task string that differs from the server's `default_task` is the most common cause — fix the task, or pin the namespace explicitly with `--inference.service_task` on the client / `service_name` in the manifest.
|
||||
|
||||
**`Action name/order mismatch between server policy and this robot`**
|
||||
|
||||
The hard sync-safety contract: chunk columns map to motors **by order**, so the robot's ordered action keys must exactly equal the policy's `action_feature_names`. This fires when the robot type, motor naming, or rename map differs from the training setup. Use the same robot type (and rename map) the policy was trained with.
|
||||
|
||||
**`RTC requested but this server/policy does not support it — downgrading to chunk-append`**
|
||||
|
||||
Informational, not fatal. Enable RTC in the server manifest (`rtc.enabled: true`) and make sure the policy family is RTC-capable (`pi0`, `pi05`, `smolvla`). Otherwise, expect chunk-append behavior (see [RTC over the network](#rtc-over-the-network)).
|
||||
|
||||
**`server full: N/N sessions active`**
|
||||
|
||||
The session-open was rejected at capacity. Raise `max_sessions` (shared mode only), or point the robot at another server replica — the rejection includes the current load so orchestration can retry elsewhere.
|
||||
+9
-9
@@ -151,18 +151,18 @@ lerobot-rollout \
|
||||
--device=cuda
|
||||
```
|
||||
|
||||
## How It Differs from the Async Inference in LeRobot
|
||||
## How It Relates to Remote Inference
|
||||
|
||||
Both RTC and [async inference](./async) improve real-time robot control, but they solve different problems.
|
||||
Both RTC and [remote inference](./remote_inference) improve real-time robot control, but they solve different problems.
|
||||
|
||||
| Aspect | Async Inference | RTC |
|
||||
| ------------- | -------------------------------------------------------------------------- | --------------------------------------------------- |
|
||||
| **Problem** | Idle frames while waiting for inference | Discontinuities between action chunks |
|
||||
| **Solution** | Decouple prediction from execution | Guide new chunks to continue smoothly from previous |
|
||||
| **Benefit** | No waiting, continuous action | Smooth transitions, natural motion |
|
||||
| **Best Used** | Async inference is best used with large models with high inference latency | Flow-matching based policies |
|
||||
| Aspect | Remote Inference | RTC |
|
||||
| ------------- | ------------------------------------------------------------------------ | --------------------------------------------------- |
|
||||
| **Problem** | The policy is too large (or too slow) for the edge machine | Discontinuities between action chunks |
|
||||
| **Solution** | Run inference on a GPU server; the robot executes buffered action chunks | Guide new chunks to continue smoothly from previous |
|
||||
| **Benefit** | Weightless edge clients, one GPU serves many robots | Smooth transitions, natural motion |
|
||||
| **Best Used** | Large models with high inference latency, robot fleets | Flow-matching based policies |
|
||||
|
||||
**Use both together** for maximum smoothness and reactivity!
|
||||
**Use both together** (`--inference.type=remote` with `--inference.rtc.execution_horizon=...`) for maximum smoothness and reactivity: the remote engine reuses RTC's chunk-merging machinery client-side while the server runs prefix-conditioned chunk generation.
|
||||
|
||||
## Advanced: Debug Tracking
|
||||
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Example manifest for `lerobot-policy-server --manifest server.yaml`.
|
||||
#
|
||||
# One process = one (model, revision, dtype, device) on one GPU. Dynamic
|
||||
# model loading is deliberately unsupported: pre-warmed processes keep
|
||||
# capacity planning honest. Every field below can also be overridden on
|
||||
# the command line via draccus, e.g. --model.repo_or_path=... or
|
||||
# --zenoh.connect_endpoints='["tcp/other-router:7447"]'.
|
||||
#
|
||||
# Field names mirror the dataclasses in src/lerobot/policy_server/manifest.py.
|
||||
|
||||
# --- Which policy this process serves, and where it runs ------------------
|
||||
model:
|
||||
# Hub repo id (org/name) or a local checkpoint directory. Required.
|
||||
repo_or_path: lerobot/pi0_towels
|
||||
# Hub revision: branch, tag, or commit sha.
|
||||
revision: main
|
||||
# Optional torch dtype cast applied after load (e.g. "bfloat16",
|
||||
# "float16"). null keeps the checkpoint's native dtype.
|
||||
dtype: bfloat16
|
||||
# Inference device, e.g. "cuda", "cuda:1", "cpu".
|
||||
device: cuda
|
||||
|
||||
# --- Task namespace --------------------------------------------------------
|
||||
# The task this service is published under. VLA clients may override the
|
||||
# task per session unless `pin_task` is true, in which case session opens
|
||||
# with a different task string are rejected.
|
||||
default_task: "fold the towel"
|
||||
pin_task: false
|
||||
# Optional override for the <task_slug> key segment of the Zenoh prefix
|
||||
# (defaults to a slug of `default_task`).
|
||||
service_name: ""
|
||||
|
||||
# --- Serving mode & capacity ------------------------------------------------
|
||||
# "auto" resolves from the policy classification: shared for verified
|
||||
# chunk-stateless policies (act/pi0/pi05, smolvla with n_obs_steps=1),
|
||||
# exclusive otherwise. Chunk-stateful policies — e.g. diffusion, whose
|
||||
# predict_action_chunk reads select_action-fed queues — are always forced
|
||||
# to "exclusive" (max_sessions=1); "shared" cannot override that.
|
||||
serving_mode: auto
|
||||
|
||||
# Capacity rule-of-thumb: with t = server seconds per inference, r = each
|
||||
# client's request rate (self-clocked to ~1-4 Hz, not the control rate),
|
||||
# H = RTC execution horizon, and dt = control period:
|
||||
# max_sessions ~= min( 0.8 / (r*t), (H*dt/2 - network RTT) / t )
|
||||
# e.g. ACT @ 20 ms, 1 Hz refresh -> ~40 clients/GPU; Pi0 @ 150 ms -> ~5.
|
||||
# Session opens beyond this are rejected with the current load in the
|
||||
# reply, so clients retry another replica.
|
||||
max_sessions: 5
|
||||
|
||||
# Dummy inferences run at startup so the first real request does not pay
|
||||
# for CUDA graph/kernel warmup.
|
||||
warmup_inferences: 2
|
||||
|
||||
# --- FPS contract -----------------------------------------------------------
|
||||
# Control rate the policy was trained at. Clients reporting a different
|
||||
# fps get a warning — or a hard reject when `strict_fps` is true.
|
||||
trained_fps: 30.0
|
||||
strict_fps: false
|
||||
|
||||
# --- Real Time Chunking (RTC) -----------------------------------------------
|
||||
# Global to this process: init_rtc_processor mutates the policy instance,
|
||||
# so RTC is a per-process decision, not per-session. Only rtc-capable
|
||||
# families (pi0/pi05/smolvla) honor it; others are downgraded to plain
|
||||
# chunk-append at session open.
|
||||
rtc:
|
||||
enabled: true
|
||||
# Number of actions executed from each chunk before the next chunk is
|
||||
# blended in (the H in the capacity formula above).
|
||||
execution_horizon: 10
|
||||
|
||||
# --- Housekeeping ------------------------------------------------------------
|
||||
# Sessions with no liveliness token and no traffic for this long are
|
||||
# garbage-collected (belt-and-braces behind liveliness GC).
|
||||
session_idle_timeout_s: 300.0
|
||||
|
||||
# --- Transport ----------------------------------------------------------------
|
||||
# Robots and servers both *dial out* to a zenohd router in production
|
||||
# (mode: client). mode: peer + listen_endpoints supports router-less LAN
|
||||
# and loopback test deployments. Multicast scouting is always disabled:
|
||||
# fleet discovery is configuration, not protocol magic.
|
||||
zenoh:
|
||||
mode: client
|
||||
connect_endpoints:
|
||||
- tcp/router.gpu-cluster.internal:7447
|
||||
listen_endpoints: []
|
||||
# mTLS material (PEM paths). All three are required for tls/ endpoints;
|
||||
# leave them null for plain tcp/ inside a trusted network.
|
||||
# tls_root_ca_certificate: /etc/lerobot/tls/ca.pem
|
||||
# tls_connect_certificate: /etc/lerobot/tls/server.pem
|
||||
# tls_connect_private_key: /etc/lerobot/tls/server.key
|
||||
# Escape hatch: raw JSON5 merged into the zenoh config last.
|
||||
# extra_config_json5: '{transport: {link: {tx: {queue: {size: {data: 4}}}}}}'
|
||||
|
||||
# --- Observability -------------------------------------------------------------
|
||||
# HTTP health + Prometheus metrics port; 0 disables the endpoint.
|
||||
health_port: 9100
|
||||
|
||||
# Optional bounded request/response capture for offline replay.
|
||||
debug:
|
||||
capture_dir: null
|
||||
capture_max: 256
|
||||
@@ -1,17 +0,0 @@
|
||||
from lerobot.async_inference.configs import PolicyServerConfig
|
||||
from lerobot.async_inference.policy_server import serve
|
||||
|
||||
|
||||
def main():
|
||||
host = ... # something like "127.0.0.1" if you're exposing to localhost
|
||||
port = ... # something like 8080
|
||||
|
||||
config = PolicyServerConfig(
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
serve(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,62 +0,0 @@
|
||||
import threading
|
||||
|
||||
from lerobot.async_inference.configs import RobotClientConfig
|
||||
from lerobot.async_inference.helpers import visualize_action_queue_size
|
||||
from lerobot.async_inference.robot_client import RobotClient
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.robots.so_follower import SO100FollowerConfig
|
||||
|
||||
|
||||
def main():
|
||||
# these cameras must match the ones expected by the policy - find your cameras with lerobot-find-cameras
|
||||
# check the config.json on the Hub for the policy you are using to see the expected camera specs
|
||||
camera_cfg = {
|
||||
"up": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_cfg)
|
||||
|
||||
server_address = ... # something like "127.0.0.1:8080" if using localhost
|
||||
|
||||
# 3. Create client configuration
|
||||
client_cfg = RobotClientConfig(
|
||||
robot=robot_cfg,
|
||||
server_address=server_address,
|
||||
policy_device="mps",
|
||||
client_device="cpu",
|
||||
policy_type="act",
|
||||
pretrained_name_or_path="<user>/robot_learning_tutorial_act",
|
||||
chunk_size_threshold=0.5, # g
|
||||
actions_per_chunk=50, # make sure this is less than the max actions of the policy
|
||||
)
|
||||
|
||||
# 4. Create and start client
|
||||
client = RobotClient(client_cfg)
|
||||
|
||||
# 5. Provide a textual description of the task
|
||||
task = ...
|
||||
|
||||
if client.start():
|
||||
# Start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
action_receiver_thread.start()
|
||||
|
||||
try:
|
||||
# Run the control loop
|
||||
client.control_loop(task)
|
||||
except KeyboardInterrupt:
|
||||
client.stop()
|
||||
action_receiver_thread.join()
|
||||
# (Optionally) plot the action queue size
|
||||
visualize_action_queue_size(client.action_queue_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+19
-15
@@ -115,8 +115,8 @@ dataset = [
|
||||
]
|
||||
training = [
|
||||
"lerobot[dataset]",
|
||||
"accelerate>=1.10.0,<2.0.0",
|
||||
"wandb>=0.24.0,<0.25.0",
|
||||
"wandb>=0.24.0,<0.28.0",
|
||||
"lerobot[accelerate-dep]",
|
||||
]
|
||||
hardware = [
|
||||
"lerobot[pynput-dep]",
|
||||
@@ -142,7 +142,8 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
|
||||
placo-dep = ["placo>=0.9.6,<0.9.16"]
|
||||
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
grpcio-dep = ["grpcio>=1.73.1,<2.0.0", "protobuf>=6.31.1,<8.0.0"]
|
||||
accelerate-dep = ["accelerate>=1.14.0,<2.0.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
|
||||
@@ -177,7 +178,12 @@ unitree_g1 = [
|
||||
"lerobot[matplotlib-dep]",
|
||||
"lerobot[pygame-dep]",
|
||||
]
|
||||
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
||||
# reachy2-sdk caps grpcio<=1.73.1 and protobuf<=6.32.0; quarantined here so downstream users aren't held back. reachy2-sdk is unlikely to release new versions.
|
||||
reachy2 = [
|
||||
"reachy2_sdk>=1.0.15,<1.1.0",
|
||||
"grpcio<=1.73.1",
|
||||
"protobuf<=6.32.0",
|
||||
]
|
||||
# Seeed Studio reBot B601-DM follower (motorbridge / CAN) + StarArm102 / reBot Arm 102
|
||||
# leader (motorbridge-smart-servo / FashionStar UART servos).
|
||||
rebot = ["lerobot[motorbridge-dep]", "lerobot[motorbridge-smart-servo-dep]"]
|
||||
@@ -199,7 +205,7 @@ wallx = [
|
||||
]
|
||||
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
|
||||
molmoact2 = ["lerobot[transformers-dep]", "lerobot[peft-dep]", "lerobot[scipy-dep]"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "lerobot[accelerate-dep]"]
|
||||
multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
|
||||
groot = [
|
||||
"lerobot[transformers-dep]",
|
||||
@@ -216,24 +222,26 @@ robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot
|
||||
topreward = ["lerobot[transformers-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
# Remote inference over Zenoh: lerobot-policy-server + lerobot-rollout --inference.type=remote.
|
||||
# Keep zenohd routers on the same minor version as the Python binding.
|
||||
async = ["eclipse-zenoh>=1.9,<2.0", "msgpack>=1.0.0,<2.0.0"]
|
||||
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools>=1.73.1,<2.0.0", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
||||
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
|
||||
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
|
||||
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
|
||||
# Simulation
|
||||
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
|
||||
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
|
||||
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.4,<0.2.0", "lerobot[scipy-dep]"]
|
||||
pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.4,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"]
|
||||
# NOTE: vlabench is NOT exposed as a `lerobot` extra. Its only distribution
|
||||
# is the OpenMOSS/VLABench GitHub repo (package name `VLABench`, no PyPI
|
||||
@@ -318,6 +326,7 @@ lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
|
||||
lerobot-policy-server="lerobot.scripts.lerobot_policy_server:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
|
||||
@@ -510,11 +519,6 @@ ignore_errors = false
|
||||
# module = "lerobot.rl.*"
|
||||
# ignore_errors = false
|
||||
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.async_inference.*"
|
||||
# ignore_errors = false
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.transport.*"
|
||||
ignore_errors = false
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Async inference server/client.
|
||||
|
||||
Requires: ``pip install 'lerobot[async]'``
|
||||
|
||||
Available modules (import directly)::
|
||||
|
||||
from lerobot.async_inference.policy_server import ...
|
||||
from lerobot.async_inference.robot_client import ...
|
||||
"""
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
require_package("grpcio", extra="async", import_name="grpc")
|
||||
|
||||
__all__: list[str] = []
|
||||
@@ -1,203 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.robots.config import RobotConfig
|
||||
|
||||
from .constants import (
|
||||
DEFAULT_FPS,
|
||||
DEFAULT_INFERENCE_LATENCY,
|
||||
DEFAULT_OBS_QUEUE_TIMEOUT,
|
||||
)
|
||||
|
||||
# Aggregate function registry for CLI usage
|
||||
AGGREGATE_FUNCTIONS = {
|
||||
"weighted_average": lambda old, new: 0.3 * old + 0.7 * new,
|
||||
"latest_only": lambda old, new: new,
|
||||
"average": lambda old, new: 0.5 * old + 0.5 * new,
|
||||
"conservative": lambda old, new: 0.7 * old + 0.3 * new,
|
||||
}
|
||||
|
||||
|
||||
def get_aggregate_function(name: str) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
|
||||
"""Get aggregate function by name from registry."""
|
||||
if name not in AGGREGATE_FUNCTIONS:
|
||||
available = list(AGGREGATE_FUNCTIONS.keys())
|
||||
raise ValueError(f"Unknown aggregate function '{name}'. Available: {available}")
|
||||
return AGGREGATE_FUNCTIONS[name]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyServerConfig:
|
||||
"""Configuration for PolicyServer.
|
||||
|
||||
This class defines all configurable parameters for the PolicyServer,
|
||||
including networking settings and action chunking specifications.
|
||||
"""
|
||||
|
||||
# Networking configuration
|
||||
host: str = field(default="localhost", metadata={"help": "Host address to bind the server to"})
|
||||
port: int = field(default=8080, metadata={"help": "Port number to bind the server to"})
|
||||
|
||||
# Timing configuration
|
||||
fps: int = field(default=DEFAULT_FPS, metadata={"help": "Frames per second"})
|
||||
inference_latency: float = field(
|
||||
default=DEFAULT_INFERENCE_LATENCY, metadata={"help": "Target inference latency in seconds"}
|
||||
)
|
||||
|
||||
obs_queue_timeout: float = field(
|
||||
default=DEFAULT_OBS_QUEUE_TIMEOUT, metadata={"help": "Timeout for observation queue in seconds"}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate configuration after initialization."""
|
||||
if self.port < 1 or self.port > 65535:
|
||||
raise ValueError(f"Port must be between 1 and 65535, got {self.port}")
|
||||
|
||||
if self.environment_dt <= 0:
|
||||
raise ValueError(f"environment_dt must be positive, got {self.environment_dt}")
|
||||
|
||||
if self.inference_latency < 0:
|
||||
raise ValueError(f"inference_latency must be non-negative, got {self.inference_latency}")
|
||||
|
||||
if self.obs_queue_timeout < 0:
|
||||
raise ValueError(f"obs_queue_timeout must be non-negative, got {self.obs_queue_timeout}")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: dict) -> "PolicyServerConfig":
|
||||
"""Create a PolicyServerConfig from a dictionary."""
|
||||
return cls(**config_dict)
|
||||
|
||||
@property
|
||||
def environment_dt(self) -> float:
|
||||
"""Environment time step, in seconds"""
|
||||
return 1 / self.fps
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert the configuration to a dictionary."""
|
||||
return {
|
||||
"host": self.host,
|
||||
"port": self.port,
|
||||
"fps": self.fps,
|
||||
"environment_dt": self.environment_dt,
|
||||
"inference_latency": self.inference_latency,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class RobotClientConfig:
|
||||
"""Configuration for RobotClient.
|
||||
|
||||
This class defines all configurable parameters for the RobotClient,
|
||||
including network connection, policy settings, and control behavior.
|
||||
"""
|
||||
|
||||
# Policy configuration
|
||||
policy_type: str = field(metadata={"help": "Type of policy to use"})
|
||||
pretrained_name_or_path: str = field(metadata={"help": "Pretrained model name or path"})
|
||||
|
||||
# Robot configuration (for CLI usage - robot instance will be created from this)
|
||||
robot: RobotConfig = field(metadata={"help": "Robot configuration"})
|
||||
|
||||
# Policies typically output K actions at max, but we can use less to avoid wasting bandwidth (as actions
|
||||
# would be aggregated on the client side anyway, depending on the value of `chunk_size_threshold`)
|
||||
actions_per_chunk: int = field(metadata={"help": "Number of actions per chunk"})
|
||||
|
||||
# Task instruction for the robot to execute (e.g., 'fold my tshirt')
|
||||
task: str = field(default="", metadata={"help": "Task instruction for the robot to execute"})
|
||||
|
||||
# Network configuration
|
||||
server_address: str = field(default="localhost:8080", metadata={"help": "Server address to connect to"})
|
||||
|
||||
# Device configuration
|
||||
policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"})
|
||||
client_device: str = field(
|
||||
default="cpu",
|
||||
metadata={
|
||||
"help": "Device to move actions to after receiving from server (e.g., for downstream planners)"
|
||||
},
|
||||
)
|
||||
|
||||
# Control behavior configuration
|
||||
chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"})
|
||||
fps: int = field(default=DEFAULT_FPS, metadata={"help": "Frames per second"})
|
||||
|
||||
# Aggregate function configuration (CLI-compatible)
|
||||
aggregate_fn_name: str = field(
|
||||
default="weighted_average",
|
||||
metadata={"help": f"Name of aggregate function to use. Options: {list(AGGREGATE_FUNCTIONS.keys())}"},
|
||||
)
|
||||
|
||||
# Debug configuration
|
||||
debug_visualize_queue_size: bool = field(
|
||||
default=False, metadata={"help": "Visualize the action queue size"}
|
||||
)
|
||||
|
||||
@property
|
||||
def environment_dt(self) -> float:
|
||||
"""Environment time step, in seconds"""
|
||||
return 1 / self.fps
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate configuration after initialization."""
|
||||
if not self.server_address:
|
||||
raise ValueError("server_address cannot be empty")
|
||||
|
||||
if not self.policy_type:
|
||||
raise ValueError("policy_type cannot be empty")
|
||||
|
||||
if not self.pretrained_name_or_path:
|
||||
raise ValueError("pretrained_name_or_path cannot be empty")
|
||||
|
||||
if not self.policy_device:
|
||||
raise ValueError("policy_device cannot be empty")
|
||||
|
||||
if not self.client_device:
|
||||
raise ValueError("client_device cannot be empty")
|
||||
|
||||
if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1:
|
||||
raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}")
|
||||
|
||||
if self.fps <= 0:
|
||||
raise ValueError(f"fps must be positive, got {self.fps}")
|
||||
|
||||
if self.actions_per_chunk <= 0:
|
||||
raise ValueError(f"actions_per_chunk must be positive, got {self.actions_per_chunk}")
|
||||
|
||||
self.aggregate_fn = get_aggregate_function(self.aggregate_fn_name)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: dict) -> "RobotClientConfig":
|
||||
"""Create a RobotClientConfig from a dictionary."""
|
||||
return cls(**config_dict)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert the configuration to a dictionary."""
|
||||
return {
|
||||
"server_address": self.server_address,
|
||||
"policy_type": self.policy_type,
|
||||
"pretrained_name_or_path": self.pretrained_name_or_path,
|
||||
"policy_device": self.policy_device,
|
||||
"client_device": self.client_device,
|
||||
"chunk_size_threshold": self.chunk_size_threshold,
|
||||
"fps": self.fps,
|
||||
"actions_per_chunk": self.actions_per_chunk,
|
||||
"task": self.task,
|
||||
"debug_visualize_queue_size": self.debug_visualize_queue_size,
|
||||
"aggregate_fn_name": self.aggregate_fn_name,
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Client side: The environment evolves with a time resolution equal to 1/fps"""
|
||||
|
||||
DEFAULT_FPS = 30
|
||||
|
||||
"""Server side: Running inference on (at most) 1/fps"""
|
||||
DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS
|
||||
|
||||
"""Server side: Timeout for observation queue in seconds"""
|
||||
DEFAULT_OBS_QUEUE_TIMEOUT = 2
|
||||
|
||||
# All action chunking policies
|
||||
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05", "groot"]
|
||||
|
||||
# TODO: Add all other robots
|
||||
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so_follower", "omx_follower"]
|
||||
@@ -1,297 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs import PolicyFeature
|
||||
|
||||
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
|
||||
from lerobot.policies import ( # noqa: F401
|
||||
ACTConfig,
|
||||
DiffusionConfig,
|
||||
PI0Config,
|
||||
PI05Config,
|
||||
SmolVLAConfig,
|
||||
VQBeTConfig,
|
||||
)
|
||||
from lerobot.robots.robot import Robot
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
Action = torch.Tensor
|
||||
|
||||
# observation as received from the robot (can be numpy arrays, floats, etc.)
|
||||
RawObservation = dict[str, Any]
|
||||
|
||||
# observation as those recorded in LeRobot dataset (keys are different)
|
||||
LeRobotObservation = dict[str, torch.Tensor]
|
||||
|
||||
# observation, ready for policy inference (image keys resized)
|
||||
Observation = dict[str, torch.Tensor]
|
||||
|
||||
|
||||
def visualize_action_queue_size(action_queue_size: list[int]) -> None:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
_, ax = plt.subplots()
|
||||
ax.set_title("Action Queue Size Over Time")
|
||||
ax.set_xlabel("Environment steps")
|
||||
ax.set_ylabel("Action Queue Size")
|
||||
ax.set_ylim(0, max(action_queue_size) * 1.1)
|
||||
ax.grid(True, alpha=0.3)
|
||||
ax.plot(range(len(action_queue_size)), action_queue_size)
|
||||
plt.show()
|
||||
|
||||
|
||||
def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]:
|
||||
return hw_to_dataset_features(robot.observation_features, OBS_STR, use_video=False)
|
||||
|
||||
|
||||
def is_image_key(k: str) -> bool:
|
||||
return k.startswith(OBS_IMAGES)
|
||||
|
||||
|
||||
def resize_robot_observation_image(image: torch.tensor, resize_dims: tuple[int, int, int]) -> torch.tensor:
|
||||
assert image.ndim == 3, f"Image must be (C, H, W)! Received {image.shape}"
|
||||
# (H, W, C) -> (C, H, W) for resizing from robot obsevation resolution to policy image resolution
|
||||
image = image.permute(2, 0, 1)
|
||||
dims = (resize_dims[1], resize_dims[2])
|
||||
# Add batch dimension for interpolate: (C, H, W) -> (1, C, H, W)
|
||||
image_batched = image.unsqueeze(0)
|
||||
# Interpolate and remove batch dimension: (1, C, H, W) -> (C, H, W)
|
||||
resized = torch.nn.functional.interpolate(image_batched, size=dims, mode="bilinear", align_corners=False)
|
||||
|
||||
return resized.squeeze(0)
|
||||
|
||||
|
||||
# TODO(Steven): Consider implementing a pipeline step for this
|
||||
def raw_observation_to_observation(
|
||||
raw_observation: RawObservation,
|
||||
lerobot_features: dict[str, dict],
|
||||
policy_image_features: dict[str, PolicyFeature],
|
||||
) -> Observation:
|
||||
observation = {}
|
||||
|
||||
observation = prepare_raw_observation(raw_observation, lerobot_features, policy_image_features)
|
||||
for k, v in observation.items():
|
||||
if isinstance(v, torch.Tensor): # VLAs present natural-language instructions in observations
|
||||
if "image" in k:
|
||||
# Policy expects images in shape (B, C, H, W)
|
||||
observation[k] = prepare_image(v).unsqueeze(0)
|
||||
else:
|
||||
observation[k] = v
|
||||
|
||||
return observation
|
||||
|
||||
|
||||
def prepare_image(image: torch.Tensor) -> torch.Tensor:
|
||||
"""Minimal preprocessing to turn int8 images to float32 in [0, 1], and create a memory-contiguous tensor"""
|
||||
image = image.type(torch.float32) / 255
|
||||
image = image.contiguous()
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def extract_state_from_raw_observation(
|
||||
lerobot_obs: RawObservation,
|
||||
) -> torch.Tensor:
|
||||
"""Extract the state from a raw observation."""
|
||||
state = torch.tensor(lerobot_obs[OBS_STATE])
|
||||
|
||||
if state.ndim == 1:
|
||||
state = state.unsqueeze(0)
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def extract_images_from_raw_observation(
|
||||
lerobot_obs: RawObservation,
|
||||
camera_key: str,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Extract the images from a raw observation."""
|
||||
return torch.tensor(lerobot_obs[camera_key])
|
||||
|
||||
|
||||
def make_lerobot_observation(
|
||||
robot_obs: RawObservation,
|
||||
lerobot_features: dict[str, dict],
|
||||
) -> LeRobotObservation:
|
||||
"""Make a lerobot observation from a raw observation."""
|
||||
return build_dataset_frame(lerobot_features, robot_obs, prefix=OBS_STR)
|
||||
|
||||
|
||||
def prepare_raw_observation(
|
||||
robot_obs: RawObservation,
|
||||
lerobot_features: dict[str, dict],
|
||||
policy_image_features: dict[str, PolicyFeature],
|
||||
) -> Observation:
|
||||
"""Matches keys from the raw robot_obs dict to the keys expected by a given policy (passed as
|
||||
policy_image_features)."""
|
||||
# 1. {motor.pos1:value1, motor.pos2:value2, ..., laptop:np.ndarray} ->
|
||||
# -> {observation.state:[value1,value2,...], observation.images.laptop:np.ndarray}
|
||||
lerobot_obs = make_lerobot_observation(robot_obs, lerobot_features)
|
||||
|
||||
# 2. Greps all observation.images.<> keys
|
||||
image_keys = list(filter(is_image_key, lerobot_obs))
|
||||
# state's shape is expected as (B, state_dim)
|
||||
state_dict = {OBS_STATE: extract_state_from_raw_observation(lerobot_obs)}
|
||||
image_dict = {
|
||||
image_k: extract_images_from_raw_observation(lerobot_obs, image_k) for image_k in image_keys
|
||||
}
|
||||
|
||||
# Turns the image features to (C, H, W) with H, W matching the policy image features.
|
||||
# This reduces the resolution of the images
|
||||
image_dict = {
|
||||
key: resize_robot_observation_image(torch.tensor(lerobot_obs[key]), policy_image_features[key].shape)
|
||||
for key in image_keys
|
||||
}
|
||||
|
||||
if "task" in robot_obs:
|
||||
state_dict["task"] = robot_obs["task"]
|
||||
|
||||
return {**state_dict, **image_dict}
|
||||
|
||||
|
||||
def get_logger(name: str, log_to_file: bool = True) -> logging.Logger:
|
||||
"""
|
||||
Get a logger using the standardized logging setup from utils.py.
|
||||
|
||||
Args:
|
||||
name: Logger name (e.g., 'policy_server', 'robot_client')
|
||||
log_to_file: Whether to also log to a file
|
||||
|
||||
Returns:
|
||||
Configured logger instance
|
||||
"""
|
||||
# Create logs directory if logging to file
|
||||
if log_to_file:
|
||||
os.makedirs("logs", exist_ok=True)
|
||||
log_file = Path(f"logs/{name}_{int(time.time())}.log")
|
||||
else:
|
||||
log_file = None
|
||||
|
||||
# Initialize the standardized logging
|
||||
init_logging(log_file=log_file, display_pid=False)
|
||||
|
||||
# Return a named logger
|
||||
return logging.getLogger(name)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimedData:
|
||||
"""A data object with timestamp and timestep information.
|
||||
|
||||
Args:
|
||||
timestamp: Unix timestamp relative to data's creation.
|
||||
data: The actual data to wrap a timestamp around.
|
||||
timestep: The timestep of the data.
|
||||
"""
|
||||
|
||||
timestamp: float
|
||||
timestep: int
|
||||
|
||||
def get_timestamp(self):
|
||||
return self.timestamp
|
||||
|
||||
def get_timestep(self):
|
||||
return self.timestep
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimedAction(TimedData):
|
||||
action: Action
|
||||
|
||||
def get_action(self):
|
||||
return self.action
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimedObservation(TimedData):
|
||||
observation: RawObservation
|
||||
must_go: bool = False
|
||||
|
||||
def get_observation(self):
|
||||
return self.observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class FPSTracker:
|
||||
"""Utility class to track FPS metrics over time."""
|
||||
|
||||
target_fps: float
|
||||
first_timestamp: float = None
|
||||
total_obs_count: int = 0
|
||||
|
||||
def calculate_fps_metrics(self, current_timestamp: float) -> dict[str, float]:
|
||||
"""Calculate average FPS vs target"""
|
||||
self.total_obs_count += 1
|
||||
|
||||
# Initialize first observation time
|
||||
if self.first_timestamp is None:
|
||||
self.first_timestamp = current_timestamp
|
||||
|
||||
# Calculate overall average FPS (since start)
|
||||
total_duration = current_timestamp - self.first_timestamp
|
||||
avg_fps = (self.total_obs_count - 1) / total_duration if total_duration > 1e-6 else 0.0
|
||||
|
||||
return {"avg_fps": avg_fps, "target_fps": self.target_fps}
|
||||
|
||||
def reset(self):
|
||||
"""Reset the FPS tracker state"""
|
||||
self.first_timestamp = None
|
||||
self.total_obs_count = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class RemotePolicyConfig:
|
||||
policy_type: str
|
||||
pretrained_name_or_path: str
|
||||
lerobot_features: dict[str, PolicyFeature]
|
||||
actions_per_chunk: int
|
||||
device: str = "cpu"
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
|
||||
"""Check if two observation states are similar, under a tolerance threshold"""
|
||||
return bool(torch.linalg.norm(obs1_state - obs2_state) < atol)
|
||||
|
||||
|
||||
def observations_similar(
|
||||
obs1: TimedObservation, obs2: TimedObservation, lerobot_features: dict[str, dict], atol: float = 1
|
||||
) -> bool:
|
||||
"""Check if two observations are similar, under a tolerance threshold. Measures distance between
|
||||
observations as the difference in joint-space between the two observations.
|
||||
|
||||
NOTE(fracapuano): This is a very simple check, and it is enough for the current use case.
|
||||
An immediate next step is to use (fast) perceptual difference metrics comparing some camera views,
|
||||
to surpass this joint-space similarity check.
|
||||
"""
|
||||
obs1_state = extract_state_from_raw_observation(
|
||||
make_lerobot_observation(obs1.get_observation(), lerobot_features)
|
||||
)
|
||||
obs2_state = extract_state_from_raw_observation(
|
||||
make_lerobot_observation(obs2.get_observation(), lerobot_features)
|
||||
)
|
||||
|
||||
return _compare_observation_states(obs1_state, obs2_state, atol=atol)
|
||||
@@ -1,439 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Example:
|
||||
```shell
|
||||
python -m lerobot.async_inference.policy_server \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080 \
|
||||
--fps=30 \
|
||||
--inference_latency=0.033 \
|
||||
--obs_queue_timeout=1
|
||||
```
|
||||
"""
|
||||
|
||||
import logging
|
||||
import pickle # nosec
|
||||
import threading
|
||||
import time
|
||||
from concurrent import futures
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
from queue import Empty, Queue
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
import grpc
|
||||
import torch
|
||||
|
||||
from lerobot.policies import get_policy_class, make_pre_post_processors
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
services_pb2_grpc, # type: ignore
|
||||
)
|
||||
from lerobot.transport.utils import receive_bytes_in_chunks
|
||||
from lerobot.types import PolicyAction
|
||||
|
||||
from .configs import PolicyServerConfig
|
||||
from .constants import SUPPORTED_POLICIES
|
||||
from .helpers import (
|
||||
FPSTracker,
|
||||
Observation,
|
||||
RemotePolicyConfig,
|
||||
TimedAction,
|
||||
TimedObservation,
|
||||
get_logger,
|
||||
observations_similar,
|
||||
raw_observation_to_observation,
|
||||
)
|
||||
|
||||
|
||||
class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
prefix = "policy_server"
|
||||
logger = get_logger(prefix)
|
||||
|
||||
def __init__(self, config: PolicyServerConfig):
|
||||
self.config = config
|
||||
self.shutdown_event = threading.Event()
|
||||
|
||||
# FPS measurement
|
||||
self.fps_tracker = FPSTracker(target_fps=config.fps)
|
||||
|
||||
self.observation_queue = Queue(maxsize=1)
|
||||
|
||||
self._predicted_timesteps_lock = threading.Lock()
|
||||
self._predicted_timesteps = set()
|
||||
|
||||
self.last_processed_obs = None
|
||||
|
||||
# Attributes will be set by SendPolicyInstructions
|
||||
self.device = None
|
||||
self.policy_type = None
|
||||
self.lerobot_features = None
|
||||
self.actions_per_chunk = None
|
||||
self.policy = None
|
||||
self.preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]] | None = None
|
||||
self.postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction] | None = None
|
||||
|
||||
@property
|
||||
def running(self):
|
||||
return not self.shutdown_event.is_set()
|
||||
|
||||
@property
|
||||
def policy_image_features(self):
|
||||
return self.policy.config.image_features
|
||||
|
||||
def _reset_server(self) -> None:
|
||||
"""Flushes server state when new client connects."""
|
||||
# only running inference on the latest observation received by the server
|
||||
self.shutdown_event.set()
|
||||
self.observation_queue = Queue(maxsize=1)
|
||||
|
||||
with self._predicted_timesteps_lock:
|
||||
self._predicted_timesteps = set()
|
||||
|
||||
def Ready(self, request, context): # noqa: N802
|
||||
client_id = context.peer()
|
||||
self.logger.info(f"Client {client_id} connected and ready")
|
||||
self._reset_server()
|
||||
self.shutdown_event.clear()
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
def SendPolicyInstructions(self, request, context): # noqa: N802
|
||||
"""Receive policy instructions from the robot client"""
|
||||
|
||||
if not self.running:
|
||||
self.logger.warning("Server is not running. Ignoring policy instructions.")
|
||||
return services_pb2.Empty()
|
||||
|
||||
client_id = context.peer()
|
||||
|
||||
policy_specs = pickle.loads(request.data) # nosec
|
||||
|
||||
if not isinstance(policy_specs, RemotePolicyConfig):
|
||||
raise TypeError(f"Policy specs must be a RemotePolicyConfig. Got {type(policy_specs)}")
|
||||
|
||||
if policy_specs.policy_type not in SUPPORTED_POLICIES:
|
||||
raise ValueError(
|
||||
f"Policy type {policy_specs.policy_type} not supported. "
|
||||
f"Supported policies: {SUPPORTED_POLICIES}"
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
f"Receiving policy instructions from {client_id} | "
|
||||
f"Policy type: {policy_specs.policy_type} | "
|
||||
f"Pretrained name or path: {policy_specs.pretrained_name_or_path} | "
|
||||
f"Actions per chunk: {policy_specs.actions_per_chunk} | "
|
||||
f"Device: {policy_specs.device}"
|
||||
)
|
||||
|
||||
self.device = policy_specs.device
|
||||
self.policy_type = policy_specs.policy_type # act, pi0, etc.
|
||||
self.lerobot_features = policy_specs.lerobot_features
|
||||
self.actions_per_chunk = policy_specs.actions_per_chunk
|
||||
|
||||
policy_class = get_policy_class(self.policy_type)
|
||||
|
||||
start = time.perf_counter()
|
||||
self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path)
|
||||
self.policy.to(self.device)
|
||||
|
||||
# Load preprocessor and postprocessor, overriding device to match requested device
|
||||
device_override = {"device": self.device}
|
||||
self.preprocessor, self.postprocessor = make_pre_post_processors(
|
||||
self.policy.config,
|
||||
pretrained_path=policy_specs.pretrained_name_or_path,
|
||||
preprocessor_overrides={
|
||||
"device_processor": device_override,
|
||||
"rename_observations_processor": {"rename_map": policy_specs.rename_map},
|
||||
},
|
||||
postprocessor_overrides={"device_processor": device_override},
|
||||
)
|
||||
|
||||
end = time.perf_counter()
|
||||
|
||||
self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
def SendObservations(self, request_iterator, context): # noqa: N802
|
||||
"""Receive observations from the robot client"""
|
||||
client_id = context.peer()
|
||||
self.logger.debug(f"Receiving observations from {client_id}")
|
||||
|
||||
receive_time = time.time() # comparing timestamps so need time.time()
|
||||
start_deserialize = time.perf_counter()
|
||||
received_bytes = receive_bytes_in_chunks(
|
||||
request_iterator, None, self.shutdown_event, self.logger
|
||||
) # blocking call while looping over request_iterator
|
||||
timed_observation = pickle.loads(received_bytes) # nosec
|
||||
deserialize_time = time.perf_counter() - start_deserialize
|
||||
|
||||
self.logger.debug(f"Received observation #{timed_observation.get_timestep()}")
|
||||
|
||||
obs_timestep = timed_observation.get_timestep()
|
||||
obs_timestamp = timed_observation.get_timestamp()
|
||||
|
||||
# Calculate FPS metrics
|
||||
fps_metrics = self.fps_tracker.calculate_fps_metrics(obs_timestamp)
|
||||
|
||||
self.logger.debug(
|
||||
f"Received observation #{obs_timestep} | "
|
||||
f"Avg FPS: {fps_metrics['avg_fps']:.2f} | " # fps at which observations are received from client
|
||||
f"Target: {fps_metrics['target_fps']:.2f} | "
|
||||
f"One-way latency: {(receive_time - obs_timestamp) * 1000:.2f}ms"
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Server timestamp: {receive_time:.6f} | "
|
||||
f"Client timestamp: {obs_timestamp:.6f} | "
|
||||
f"Deserialization time: {deserialize_time:.6f}s"
|
||||
)
|
||||
|
||||
if not self._enqueue_observation(
|
||||
timed_observation # wrapping a RawObservation
|
||||
):
|
||||
self.logger.debug(f"Observation #{obs_timestep} has been filtered out")
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
def GetActions(self, request, context): # noqa: N802
|
||||
"""Returns actions to the robot client. Actions are sent as a single
|
||||
chunk, containing multiple actions."""
|
||||
client_id = context.peer()
|
||||
self.logger.debug(f"Client {client_id} connected for action streaming")
|
||||
|
||||
# Generate action based on the most recent observation and its timestep
|
||||
try:
|
||||
getactions_starts = time.perf_counter()
|
||||
obs = self.observation_queue.get(timeout=self.config.obs_queue_timeout)
|
||||
self.logger.info(
|
||||
f"Running inference for observation #{obs.get_timestep()} (must_go: {obs.must_go})"
|
||||
)
|
||||
|
||||
with self._predicted_timesteps_lock:
|
||||
self._predicted_timesteps.add(obs.get_timestep())
|
||||
|
||||
start_time = time.perf_counter()
|
||||
action_chunk = self._predict_action_chunk(obs)
|
||||
inference_time = time.perf_counter() - start_time
|
||||
|
||||
start_time = time.perf_counter()
|
||||
actions_bytes = pickle.dumps(action_chunk) # nosec
|
||||
serialize_time = time.perf_counter() - start_time
|
||||
|
||||
# Create and return the action chunk
|
||||
actions = services_pb2.Actions(data=actions_bytes)
|
||||
|
||||
self.logger.info(
|
||||
f"Action chunk #{obs.get_timestep()} generated | "
|
||||
f"Total time: {(inference_time + serialize_time) * 1000:.2f}ms"
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Action chunk #{obs.get_timestep()} generated | "
|
||||
f"Inference time: {inference_time:.2f}s |"
|
||||
f"Serialize time: {serialize_time:.2f}s |"
|
||||
f"Total time: {inference_time + serialize_time:.2f}s"
|
||||
)
|
||||
|
||||
time.sleep(
|
||||
max(0, self.config.inference_latency - max(0, time.perf_counter() - getactions_starts))
|
||||
) # sleep controls inference latency
|
||||
|
||||
return actions
|
||||
|
||||
except Empty: # no observation added to queue in obs_queue_timeout
|
||||
return services_pb2.Empty()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in StreamActions: {e}")
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool:
|
||||
"""Check if the observation is valid to be processed by the policy"""
|
||||
with self._predicted_timesteps_lock:
|
||||
predicted_timesteps = self._predicted_timesteps
|
||||
|
||||
if obs.get_timestep() in predicted_timesteps:
|
||||
self.logger.debug(f"Skipping observation #{obs.get_timestep()} - Timestep predicted already!")
|
||||
return False
|
||||
|
||||
elif observations_similar(obs, previous_obs, lerobot_features=self.lerobot_features):
|
||||
self.logger.debug(
|
||||
f"Skipping observation #{obs.get_timestep()} - Observation too similar to last obs predicted!"
|
||||
)
|
||||
return False
|
||||
|
||||
else:
|
||||
return True
|
||||
|
||||
def _enqueue_observation(self, obs: TimedObservation) -> bool:
|
||||
"""Enqueue an observation if it must go through processing, otherwise skip it.
|
||||
Observations not in queue are never run through the policy network"""
|
||||
|
||||
if (
|
||||
obs.must_go
|
||||
or self.last_processed_obs is None
|
||||
or self._obs_sanity_checks(obs, self.last_processed_obs)
|
||||
):
|
||||
last_obs = self.last_processed_obs.get_timestep() if self.last_processed_obs else "None"
|
||||
self.logger.debug(
|
||||
f"Enqueuing observation. Must go: {obs.must_go} | Last processed obs: {last_obs}"
|
||||
)
|
||||
|
||||
# If queue is full, get the old observation to make room
|
||||
if self.observation_queue.full():
|
||||
# pops from queue
|
||||
_ = self.observation_queue.get_nowait()
|
||||
self.logger.debug("Observation queue was full, removed oldest observation")
|
||||
|
||||
# Now put the new observation (never blocks as queue is non-full here)
|
||||
self.observation_queue.put(obs)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]:
|
||||
"""Turn a chunk of actions into a list of TimedAction instances,
|
||||
with the first action corresponding to t_0 and the rest corresponding to
|
||||
t_0 + i*environment_dt for i in range(len(action_chunk))
|
||||
"""
|
||||
return [
|
||||
TimedAction(timestamp=t_0 + i * self.config.environment_dt, timestep=i_0 + i, action=action)
|
||||
for i, action in enumerate(action_chunk)
|
||||
]
|
||||
|
||||
def _get_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Get an action chunk from the policy. The chunk contains only"""
|
||||
chunk = self.policy.predict_action_chunk(observation)
|
||||
if chunk.ndim != 3:
|
||||
chunk = chunk.unsqueeze(0) # adding batch dimension, now shape is (B, chunk_size, action_dim)
|
||||
|
||||
return chunk[:, : self.actions_per_chunk, :]
|
||||
|
||||
def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
|
||||
"""Predict an action chunk based on an observation.
|
||||
|
||||
Pipeline:
|
||||
1. Convert raw observation to LeRobot format
|
||||
2. Apply preprocessor (tokenization, normalization, batching, device placement)
|
||||
3. Run policy inference to get action chunk
|
||||
4. Apply postprocessor (unnormalization, device movement)
|
||||
5. Convert to TimedAction list
|
||||
"""
|
||||
"""1. Prepare observation"""
|
||||
start_prepare = time.perf_counter()
|
||||
observation: Observation = raw_observation_to_observation(
|
||||
observation_t.get_observation(),
|
||||
self.lerobot_features,
|
||||
self.policy_image_features,
|
||||
)
|
||||
prepare_time = time.perf_counter() - start_prepare
|
||||
|
||||
"""2. Apply preprocessor"""
|
||||
start_preprocess = time.perf_counter()
|
||||
observation = self.preprocessor(observation)
|
||||
self.last_processed_obs: TimedObservation = observation_t
|
||||
preprocessing_time = time.perf_counter() - start_preprocess
|
||||
|
||||
"""3. Get action chunk"""
|
||||
start_inference = time.perf_counter()
|
||||
action_tensor = self._get_action_chunk(observation)
|
||||
inference_time = time.perf_counter() - start_inference
|
||||
self.logger.info(
|
||||
f"Preprocessing and inference took {inference_time:.4f}s, action shape: {action_tensor.shape}"
|
||||
)
|
||||
|
||||
"""4. Apply postprocessor"""
|
||||
# Apply postprocessor (handles unnormalization and device movement)
|
||||
# Postprocessor expects (B, action_dim) per action, but we have (B, chunk_size, action_dim)
|
||||
# So we process each action in the chunk individually
|
||||
start_postprocess = time.perf_counter()
|
||||
_, chunk_size, _ = action_tensor.shape
|
||||
|
||||
# Process each action in the chunk
|
||||
processed_actions = []
|
||||
for i in range(chunk_size):
|
||||
# Extract action at timestep i: (B, action_dim)
|
||||
single_action = action_tensor[:, i, :]
|
||||
processed_action = self.postprocessor(single_action)
|
||||
processed_actions.append(processed_action)
|
||||
|
||||
# Stack back to (B, chunk_size, action_dim), then remove batch dim
|
||||
action_tensor = torch.stack(processed_actions, dim=1).squeeze(0)
|
||||
self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}")
|
||||
|
||||
action_tensor = action_tensor.detach().cpu()
|
||||
|
||||
"""5. Convert to TimedAction list"""
|
||||
action_chunk = self._time_action_chunk(
|
||||
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
|
||||
)
|
||||
postprocess_stops = time.perf_counter()
|
||||
postprocessing_time = postprocess_stops - start_postprocess
|
||||
|
||||
self.logger.info(
|
||||
f"Observation {observation_t.get_timestep()} | "
|
||||
f"Total time: {1000 * (postprocess_stops - start_prepare):.2f}ms"
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Observation {observation_t.get_timestep()} | "
|
||||
f"Prepare time: {1000 * prepare_time:.2f}ms | "
|
||||
f"Preprocessing time: {1000 * preprocessing_time:.2f}ms | "
|
||||
f"Inference time: {1000 * inference_time:.2f}ms | "
|
||||
f"Postprocessing time: {1000 * postprocessing_time:.2f}ms | "
|
||||
f"Total time: {1000 * (postprocess_stops - start_prepare):.2f}ms"
|
||||
)
|
||||
|
||||
return action_chunk
|
||||
|
||||
def stop(self):
|
||||
"""Stop the server"""
|
||||
self._reset_server()
|
||||
self.logger.info("Server stopping...")
|
||||
|
||||
|
||||
@draccus.wrap()
|
||||
def serve(cfg: PolicyServerConfig):
|
||||
"""Start the PolicyServer with the given configuration.
|
||||
|
||||
Args:
|
||||
config: PolicyServerConfig instance. If None, uses default configuration.
|
||||
"""
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
# Create the server instance first
|
||||
policy_server = PolicyServer(cfg)
|
||||
|
||||
# Setup and start gRPC server
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
|
||||
services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
|
||||
server.add_insecure_port(f"{cfg.host}:{cfg.port}")
|
||||
|
||||
policy_server.logger.info(f"PolicyServer started on {cfg.host}:{cfg.port}")
|
||||
server.start()
|
||||
|
||||
server.wait_for_termination()
|
||||
|
||||
policy_server.logger.info("Server terminated")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
serve()
|
||||
@@ -1,517 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Example command:
|
||||
```shell
|
||||
python src/lerobot/async_inference/robot_client.py \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
|
||||
--robot.id=black \
|
||||
--task="dummy" \
|
||||
--server_address=127.0.0.1:8080 \
|
||||
--policy_type=act \
|
||||
--pretrained_name_or_path=user/model \
|
||||
--policy_device=mps \
|
||||
--client_device=cpu \
|
||||
--actions_per_chunk=50 \
|
||||
--chunk_size_threshold=0.5 \
|
||||
--aggregate_fn_name=weighted_average \
|
||||
--debug_visualize_queue_size=True
|
||||
```
|
||||
"""
|
||||
|
||||
import logging
|
||||
import pickle # nosec
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
from queue import Queue
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
import grpc
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_so_follower,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
so_follower,
|
||||
)
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
services_pb2_grpc, # type: ignore
|
||||
)
|
||||
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
|
||||
from .configs import RobotClientConfig
|
||||
from .helpers import (
|
||||
Action,
|
||||
FPSTracker,
|
||||
Observation,
|
||||
RawObservation,
|
||||
RemotePolicyConfig,
|
||||
TimedAction,
|
||||
TimedObservation,
|
||||
get_logger,
|
||||
map_robot_keys_to_lerobot_features,
|
||||
visualize_action_queue_size,
|
||||
)
|
||||
|
||||
|
||||
class RobotClient:
|
||||
prefix = "robot_client"
|
||||
logger = get_logger(prefix)
|
||||
|
||||
def __init__(self, config: RobotClientConfig):
|
||||
"""Initialize RobotClient with unified configuration.
|
||||
|
||||
Args:
|
||||
config: RobotClientConfig containing all configuration parameters
|
||||
"""
|
||||
# Store configuration
|
||||
self.config = config
|
||||
self.robot = make_robot_from_config(config.robot)
|
||||
self.robot.connect()
|
||||
|
||||
lerobot_features = map_robot_keys_to_lerobot_features(self.robot)
|
||||
|
||||
# Use environment variable if server_address is not provided in config
|
||||
self.server_address = config.server_address
|
||||
|
||||
self.policy_config = RemotePolicyConfig(
|
||||
config.policy_type,
|
||||
config.pretrained_name_or_path,
|
||||
lerobot_features,
|
||||
config.actions_per_chunk,
|
||||
config.policy_device,
|
||||
)
|
||||
self.channel = grpc.insecure_channel(
|
||||
self.server_address, grpc_channel_options(initial_backoff=f"{config.environment_dt:.4f}s")
|
||||
)
|
||||
self.stub = services_pb2_grpc.AsyncInferenceStub(self.channel)
|
||||
self.logger.info(f"Initializing client to connect to server at {self.server_address}")
|
||||
|
||||
self.shutdown_event = threading.Event()
|
||||
|
||||
# Initialize client side variables
|
||||
self.latest_action_lock = threading.Lock()
|
||||
self.latest_action = -1
|
||||
self.action_chunk_size = -1
|
||||
|
||||
self._chunk_size_threshold = config.chunk_size_threshold
|
||||
|
||||
self.action_queue = Queue()
|
||||
self.action_queue_lock = threading.Lock() # Protect queue operations
|
||||
self.action_queue_size = []
|
||||
self.start_barrier = threading.Barrier(2) # 2 threads: action receiver, control loop
|
||||
|
||||
# FPS measurement
|
||||
self.fps_tracker = FPSTracker(target_fps=self.config.fps)
|
||||
|
||||
self.logger.info("Robot connected and ready")
|
||||
|
||||
# Use an event for thread-safe coordination
|
||||
self.must_go = threading.Event()
|
||||
self.must_go.set() # Initially set - observations qualify for direct processing
|
||||
|
||||
@property
|
||||
def running(self):
|
||||
return not self.shutdown_event.is_set()
|
||||
|
||||
def start(self):
|
||||
"""Start the robot client and connect to the policy server"""
|
||||
try:
|
||||
# client-server handshake
|
||||
start_time = time.perf_counter()
|
||||
self.stub.Ready(services_pb2.Empty())
|
||||
end_time = time.perf_counter()
|
||||
self.logger.debug(f"Connected to policy server in {end_time - start_time:.4f}s")
|
||||
|
||||
# send policy instructions
|
||||
policy_config_bytes = pickle.dumps(self.policy_config)
|
||||
policy_setup = services_pb2.PolicySetup(data=policy_config_bytes)
|
||||
|
||||
self.logger.info("Sending policy instructions to policy server")
|
||||
self.logger.debug(
|
||||
f"Policy type: {self.policy_config.policy_type} | "
|
||||
f"Pretrained name or path: {self.policy_config.pretrained_name_or_path} | "
|
||||
f"Device: {self.policy_config.device}"
|
||||
)
|
||||
|
||||
self.stub.SendPolicyInstructions(policy_setup)
|
||||
|
||||
self.shutdown_event.clear()
|
||||
|
||||
return True
|
||||
|
||||
except grpc.RpcError as e:
|
||||
self.logger.error(f"Failed to connect to policy server: {e}")
|
||||
return False
|
||||
|
||||
def stop(self):
|
||||
"""Stop the robot client"""
|
||||
self.shutdown_event.set()
|
||||
|
||||
self.robot.disconnect()
|
||||
self.logger.debug("Robot disconnected")
|
||||
|
||||
self.channel.close()
|
||||
self.logger.debug("Client stopped, channel closed")
|
||||
|
||||
def send_observation(
|
||||
self,
|
||||
obs: TimedObservation,
|
||||
) -> bool:
|
||||
"""Send observation to the policy server.
|
||||
Returns True if the observation was sent successfully, False otherwise."""
|
||||
if not self.running:
|
||||
raise RuntimeError("Client not running. Run RobotClient.start() before sending observations.")
|
||||
|
||||
if not isinstance(obs, TimedObservation):
|
||||
raise ValueError("Input observation needs to be a TimedObservation!")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
observation_bytes = pickle.dumps(obs)
|
||||
serialize_time = time.perf_counter() - start_time
|
||||
self.logger.debug(f"Observation serialization time: {serialize_time:.6f}s")
|
||||
|
||||
try:
|
||||
observation_iterator = send_bytes_in_chunks(
|
||||
observation_bytes,
|
||||
services_pb2.Observation,
|
||||
log_prefix="[CLIENT] Observation",
|
||||
silent=True,
|
||||
)
|
||||
_ = self.stub.SendObservations(observation_iterator)
|
||||
obs_timestep = obs.get_timestep()
|
||||
self.logger.debug(f"Sent observation #{obs_timestep} | ")
|
||||
|
||||
return True
|
||||
|
||||
except grpc.RpcError as e:
|
||||
self.logger.error(f"Error sending observation #{obs.get_timestep()}: {e}")
|
||||
return False
|
||||
|
||||
def _inspect_action_queue(self):
|
||||
with self.action_queue_lock:
|
||||
queue_size = self.action_queue.qsize()
|
||||
timestamps = sorted([action.get_timestep() for action in self.action_queue.queue])
|
||||
self.logger.debug(f"Queue size: {queue_size}, Queue contents: {timestamps}")
|
||||
return queue_size, timestamps
|
||||
|
||||
def _aggregate_action_queues(
|
||||
self,
|
||||
incoming_actions: list[TimedAction],
|
||||
aggregate_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
|
||||
):
|
||||
"""Finds the same timestep actions in the queue and aggregates them using the aggregate_fn"""
|
||||
if aggregate_fn is None:
|
||||
# default aggregate function: take the latest action
|
||||
def aggregate_fn(x1, x2):
|
||||
return x2
|
||||
|
||||
future_action_queue = Queue()
|
||||
with self.action_queue_lock:
|
||||
internal_queue = self.action_queue.queue
|
||||
|
||||
current_action_queue = {action.get_timestep(): action.get_action() for action in internal_queue}
|
||||
|
||||
for new_action in incoming_actions:
|
||||
with self.latest_action_lock:
|
||||
latest_action = self.latest_action
|
||||
|
||||
# New action is older than the latest action in the queue, skip it
|
||||
if new_action.get_timestep() <= latest_action:
|
||||
continue
|
||||
|
||||
# If the new action's timestep is not in the current action queue, add it directly
|
||||
elif new_action.get_timestep() not in current_action_queue:
|
||||
future_action_queue.put(new_action)
|
||||
continue
|
||||
|
||||
# If the new action's timestep is in the current action queue, aggregate it
|
||||
# TODO: There is probably a way to do this with broadcasting of the two action tensors
|
||||
future_action_queue.put(
|
||||
TimedAction(
|
||||
timestamp=new_action.get_timestamp(),
|
||||
timestep=new_action.get_timestep(),
|
||||
action=aggregate_fn(
|
||||
current_action_queue[new_action.get_timestep()], new_action.get_action()
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
with self.action_queue_lock:
|
||||
self.action_queue = future_action_queue
|
||||
|
||||
def receive_actions(self, verbose: bool = False):
|
||||
"""Receive actions from the policy server"""
|
||||
# Wait at barrier for synchronized start
|
||||
self.start_barrier.wait()
|
||||
self.logger.info("Action receiving thread starting")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# Use StreamActions to get a stream of actions from the server
|
||||
actions_chunk = self.stub.GetActions(services_pb2.Empty())
|
||||
if len(actions_chunk.data) == 0:
|
||||
continue # received `Empty` from server, wait for next call
|
||||
|
||||
receive_time = time.time()
|
||||
|
||||
# Deserialize bytes back into list[TimedAction]
|
||||
deserialize_start = time.perf_counter()
|
||||
timed_actions = pickle.loads(actions_chunk.data) # nosec
|
||||
deserialize_time = time.perf_counter() - deserialize_start
|
||||
|
||||
# Log device type of received actions
|
||||
if len(timed_actions) > 0:
|
||||
received_device = timed_actions[0].get_action().device.type
|
||||
self.logger.debug(f"Received actions on device: {received_device}")
|
||||
|
||||
# Move actions to client_device (e.g., for downstream planners that need GPU)
|
||||
client_device = self.config.client_device
|
||||
if client_device != "cpu":
|
||||
for timed_action in timed_actions:
|
||||
if timed_action.get_action().device.type != client_device:
|
||||
timed_action.action = timed_action.get_action().to(client_device)
|
||||
self.logger.debug(f"Converted actions to device: {client_device}")
|
||||
else:
|
||||
self.logger.debug(f"Actions kept on device: {client_device}")
|
||||
|
||||
self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))
|
||||
|
||||
# Calculate network latency if we have matching observations
|
||||
if len(timed_actions) > 0 and verbose:
|
||||
with self.latest_action_lock:
|
||||
latest_action = self.latest_action
|
||||
|
||||
self.logger.debug(f"Current latest action: {latest_action}")
|
||||
|
||||
# Get queue state before changes
|
||||
old_size, old_timesteps = self._inspect_action_queue()
|
||||
if not old_timesteps:
|
||||
old_timesteps = [latest_action] # queue was empty
|
||||
|
||||
# Log incoming actions
|
||||
incoming_timesteps = [a.get_timestep() for a in timed_actions]
|
||||
|
||||
first_action_timestep = timed_actions[0].get_timestep()
|
||||
server_to_client_latency = (receive_time - timed_actions[0].get_timestamp()) * 1000
|
||||
|
||||
self.logger.info(
|
||||
f"Received action chunk for step #{first_action_timestep} | "
|
||||
f"Latest action: #{latest_action} | "
|
||||
f"Incoming actions: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
|
||||
f"Network latency (server->client): {server_to_client_latency:.2f}ms | "
|
||||
f"Deserialization time: {deserialize_time * 1000:.2f}ms"
|
||||
)
|
||||
|
||||
# Update action queue
|
||||
start_time = time.perf_counter()
|
||||
self._aggregate_action_queues(timed_actions, self.config.aggregate_fn)
|
||||
queue_update_time = time.perf_counter() - start_time
|
||||
|
||||
self.must_go.set() # after receiving actions, next empty queue triggers must-go processing!
|
||||
|
||||
if verbose:
|
||||
# Get queue state after changes
|
||||
new_size, new_timesteps = self._inspect_action_queue()
|
||||
|
||||
with self.latest_action_lock:
|
||||
latest_action = self.latest_action
|
||||
|
||||
self.logger.info(
|
||||
f"Latest action: {latest_action} | "
|
||||
f"Old action steps: {old_timesteps[0]}:{old_timesteps[-1]} | "
|
||||
f"Incoming action steps: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
|
||||
f"Updated action steps: {new_timesteps[0]}:{new_timesteps[-1]}"
|
||||
)
|
||||
self.logger.debug(
|
||||
f"Queue update complete ({queue_update_time:.6f}s) | "
|
||||
f"Before: {old_size} items | "
|
||||
f"After: {new_size} items | "
|
||||
)
|
||||
|
||||
except grpc.RpcError as e:
|
||||
self.logger.error(f"Error receiving actions: {e}")
|
||||
|
||||
def actions_available(self):
|
||||
"""Check if there are actions available in the queue"""
|
||||
with self.action_queue_lock:
|
||||
return not self.action_queue.empty()
|
||||
|
||||
def _action_tensor_to_action_dict(self, action_tensor: torch.Tensor) -> dict[str, float]:
|
||||
action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)}
|
||||
return action
|
||||
|
||||
def control_loop_action(self, verbose: bool = False) -> dict[str, Any]:
|
||||
"""Reading and performing actions in local queue"""
|
||||
|
||||
# Lock only for queue operations
|
||||
get_start = time.perf_counter()
|
||||
with self.action_queue_lock:
|
||||
self.action_queue_size.append(self.action_queue.qsize())
|
||||
# Get action from queue
|
||||
timed_action = self.action_queue.get_nowait()
|
||||
get_end = time.perf_counter() - get_start
|
||||
|
||||
_performed_action = self.robot.send_action(
|
||||
self._action_tensor_to_action_dict(timed_action.get_action())
|
||||
)
|
||||
with self.latest_action_lock:
|
||||
self.latest_action = timed_action.get_timestep()
|
||||
|
||||
if verbose:
|
||||
with self.action_queue_lock:
|
||||
current_queue_size = self.action_queue.qsize()
|
||||
|
||||
self.logger.debug(
|
||||
f"Ts={timed_action.get_timestamp()} | "
|
||||
f"Action #{timed_action.get_timestep()} performed | "
|
||||
f"Queue size: {current_queue_size}"
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Popping action from queue to perform took {get_end:.6f}s | Queue size: {current_queue_size}"
|
||||
)
|
||||
|
||||
return _performed_action
|
||||
|
||||
def _ready_to_send_observation(self):
|
||||
"""Flags when the client is ready to send an observation"""
|
||||
with self.action_queue_lock:
|
||||
return self.action_queue.qsize() / self.action_chunk_size <= self._chunk_size_threshold
|
||||
|
||||
def control_loop_observation(self, task: str, verbose: bool = False) -> RawObservation:
|
||||
try:
|
||||
# Get serialized observation bytes from the function
|
||||
start_time = time.perf_counter()
|
||||
|
||||
raw_observation: RawObservation = self.robot.get_observation()
|
||||
raw_observation["task"] = task
|
||||
|
||||
with self.latest_action_lock:
|
||||
latest_action = self.latest_action
|
||||
|
||||
observation = TimedObservation(
|
||||
timestamp=time.time(), # need time.time() to compare timestamps across client and server
|
||||
observation=raw_observation,
|
||||
timestep=max(latest_action, 0),
|
||||
)
|
||||
|
||||
obs_capture_time = time.perf_counter() - start_time
|
||||
|
||||
# If there are no actions left in the queue, the observation must go through processing!
|
||||
with self.action_queue_lock:
|
||||
observation.must_go = self.must_go.is_set() and self.action_queue.empty()
|
||||
current_queue_size = self.action_queue.qsize()
|
||||
|
||||
_ = self.send_observation(observation)
|
||||
|
||||
self.logger.debug(f"QUEUE SIZE: {current_queue_size} (Must go: {observation.must_go})")
|
||||
if observation.must_go:
|
||||
# must-go event will be set again after receiving actions
|
||||
self.must_go.clear()
|
||||
|
||||
if verbose:
|
||||
# Calculate comprehensive FPS metrics
|
||||
fps_metrics = self.fps_tracker.calculate_fps_metrics(observation.get_timestamp())
|
||||
|
||||
self.logger.info(
|
||||
f"Obs #{observation.get_timestep()} | "
|
||||
f"Avg FPS: {fps_metrics['avg_fps']:.2f} | "
|
||||
f"Target: {fps_metrics['target_fps']:.2f}"
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Ts={observation.get_timestamp():.6f} | Capturing observation took {obs_capture_time:.6f}s"
|
||||
)
|
||||
|
||||
return raw_observation
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in observation sender: {e}")
|
||||
|
||||
def control_loop(self, task: str, verbose: bool = False) -> tuple[Observation, Action]:
|
||||
"""Combined function for executing actions and streaming observations"""
|
||||
# Wait at barrier for synchronized start
|
||||
self.start_barrier.wait()
|
||||
self.logger.info("Control loop thread starting")
|
||||
|
||||
_performed_action = None
|
||||
_captured_observation = None
|
||||
|
||||
while self.running:
|
||||
control_loop_start = time.perf_counter()
|
||||
"""Control loop: (1) Performing actions, when available"""
|
||||
if self.actions_available():
|
||||
_performed_action = self.control_loop_action(verbose)
|
||||
|
||||
"""Control loop: (2) Streaming observations to the remote policy server"""
|
||||
if self._ready_to_send_observation():
|
||||
_captured_observation = self.control_loop_observation(task, verbose)
|
||||
|
||||
self.logger.debug(f"Control loop (ms): {(time.perf_counter() - control_loop_start) * 1000:.2f}")
|
||||
# Dynamically adjust sleep time to maintain the desired control frequency
|
||||
time.sleep(max(0, self.config.environment_dt - (time.perf_counter() - control_loop_start)))
|
||||
|
||||
return _captured_observation, _performed_action
|
||||
|
||||
|
||||
@draccus.wrap()
|
||||
def async_client(cfg: RobotClientConfig):
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
# TODO: Assert if checking robot support is still needed with the plugin system
|
||||
# if cfg.robot.type not in SUPPORTED_ROBOTS:
|
||||
# raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
|
||||
|
||||
client = RobotClient(cfg)
|
||||
|
||||
if client.start():
|
||||
client.logger.info("Starting action receiver thread...")
|
||||
|
||||
# Create and start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
|
||||
# Start action receiver thread
|
||||
action_receiver_thread.start()
|
||||
|
||||
try:
|
||||
# The main thread runs the control loop
|
||||
client.control_loop(task=cfg.task)
|
||||
|
||||
finally:
|
||||
client.stop()
|
||||
action_receiver_thread.join()
|
||||
if cfg.debug_visualize_queue_size:
|
||||
visualize_action_queue_size(client.action_queue_size)
|
||||
client.logger.info("Client stopped")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
register_third_party_plugins()
|
||||
async_client() # run the client
|
||||
@@ -18,6 +18,7 @@ from __future__ import annotations
|
||||
# Utilities
|
||||
########################################################################################
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from contextlib import nullcontext
|
||||
from copy import copy
|
||||
@@ -243,3 +244,72 @@ def sanity_check_dataset_robot_compatibility(
|
||||
raise ValueError(
|
||||
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
|
||||
)
|
||||
|
||||
|
||||
########################################################################################
|
||||
# Teleoperator smooth handover helpers
|
||||
# NOTE(Maxime): These functions use minimal type hints to maintain compatibility with utils
|
||||
# being a root module.
|
||||
########################################################################################
|
||||
|
||||
|
||||
def teleop_supports_feedback(teleop) -> bool:
|
||||
"""Return True when the teleop can receive position feedback (is actuated).
|
||||
|
||||
Actuated teleops (e.g. SO-101, OpenArmMini) have non-empty ``feedback_features``
|
||||
and expose ``enable_torque`` / ``disable_torque`` motor-control methods.
|
||||
|
||||
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
|
||||
"""
|
||||
return (
|
||||
bool(teleop.feedback_features)
|
||||
and hasattr(teleop, "disable_torque")
|
||||
and hasattr(teleop, "enable_torque")
|
||||
)
|
||||
|
||||
|
||||
def teleop_smooth_move_to(teleop, target_pos: dict, duration_s: float = 2.0, fps: int = 30) -> None:
|
||||
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
|
||||
|
||||
Requires the teleoperator to support feedback (i.e. have non-empty
|
||||
``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
|
||||
|
||||
``target_pos`` is expected to be in the teleop's action/feedback key space.
|
||||
For homogeneous setups (e.g. SO-101 leader + SO-101 follower) this matches
|
||||
the robot action key space directly.
|
||||
|
||||
TODO(Maxime): This blocks up to ``duration_s`` seconds; during this time the
|
||||
follower robot does not receive new actions, which could be an issue on LeKiwi.
|
||||
"""
|
||||
teleop.enable_torque()
|
||||
current = teleop.get_action()
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {
|
||||
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
|
||||
}
|
||||
teleop.send_feedback(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
def follower_smooth_move_to(
|
||||
robot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
|
||||
) -> None:
|
||||
"""Smoothly move the follower robot from ``current`` to ``target`` action.
|
||||
|
||||
Used when the teleop is non-actuated: instead of driving the leader arm to
|
||||
the follower, the follower is brought to the teleop's current pose so the
|
||||
robot meets the operator's hand rather than jumping to it on the first frame.
|
||||
|
||||
Both ``current`` and ``target`` must be in the robot action key space
|
||||
(i.e. the output of ``robot_action_processor``).
|
||||
"""
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
|
||||
robot.send_action(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
@@ -30,6 +30,7 @@ class EpisodeAwareSampler:
|
||||
drop_n_first_frames: int = 0,
|
||||
drop_n_last_frames: int = 0,
|
||||
shuffle: bool = False,
|
||||
generator: torch.Generator | None = None,
|
||||
):
|
||||
"""Sampler that optionally incorporates episode boundary information.
|
||||
|
||||
@@ -41,6 +42,10 @@ class EpisodeAwareSampler:
|
||||
drop_n_first_frames: Number of frames to drop from the start of each episode.
|
||||
drop_n_last_frames: Number of frames to drop from the end of each episode.
|
||||
shuffle: Whether to shuffle the indices.
|
||||
generator: Generator used for shuffling. Exposing this attribute (even when None) lets
|
||||
`accelerate` register it as the synchronized RNG in distributed training, so
|
||||
every rank draws the same permutation and batch shards stay disjoint. When
|
||||
None, shuffling falls back to the global torch RNG.
|
||||
"""
|
||||
if drop_n_first_frames < 0:
|
||||
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
|
||||
@@ -73,10 +78,11 @@ class EpisodeAwareSampler:
|
||||
|
||||
self.indices = indices
|
||||
self.shuffle = shuffle
|
||||
self.generator = generator
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
if self.shuffle:
|
||||
for i in torch.randperm(len(self.indices)):
|
||||
for i in torch.randperm(len(self.indices), generator=self.generator):
|
||||
yield self.indices[i]
|
||||
else:
|
||||
for i in self.indices:
|
||||
|
||||
@@ -280,22 +280,26 @@ def make_pre_post_processors(
|
||||
policy configuration type.
|
||||
"""
|
||||
if pretrained_path:
|
||||
# TODO(Steven): Temporary patch, implement correctly the processors for Gr00t
|
||||
if isinstance(policy_cfg, GrootConfig):
|
||||
from .groot.processor_groot import make_groot_pre_post_processors_from_pretrained
|
||||
# GROOT handles normalization in groot_pack_inputs_v3 step
|
||||
# Need to override both stats AND normalize_min_max since saved config might be empty
|
||||
preprocessor_overrides = {}
|
||||
postprocessor_overrides = {}
|
||||
preprocessor_overrides["groot_pack_inputs_v3"] = {
|
||||
"stats": kwargs.get("dataset_stats"),
|
||||
"normalize_min_max": True,
|
||||
}
|
||||
|
||||
return make_groot_pre_post_processors_from_pretrained(
|
||||
config=policy_cfg,
|
||||
pretrained_path=pretrained_path,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_overrides=kwargs.get("preprocessor_overrides"),
|
||||
postprocessor_overrides=kwargs.get("postprocessor_overrides"),
|
||||
preprocessor_config_filename=kwargs.get(
|
||||
"preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
|
||||
),
|
||||
postprocessor_config_filename=kwargs.get(
|
||||
"postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
|
||||
),
|
||||
)
|
||||
# Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats
|
||||
env_action_dim = policy_cfg.output_features[ACTION].shape[0]
|
||||
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = {
|
||||
"stats": kwargs.get("dataset_stats"),
|
||||
"normalize_min_max": True,
|
||||
"env_action_dim": env_action_dim,
|
||||
}
|
||||
kwargs["preprocessor_overrides"] = preprocessor_overrides
|
||||
kwargs["postprocessor_overrides"] = postprocessor_overrides
|
||||
|
||||
preprocessor = PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
|
||||
@@ -18,12 +18,4 @@ from .configuration_groot import GrootConfig
|
||||
from .modeling_groot import GrootPolicy
|
||||
from .processor_groot import make_groot_pre_post_processors
|
||||
|
||||
__all__ = ["GR00TN17", "GR00TN17Config", "GrootConfig", "GrootPolicy", "make_groot_pre_post_processors"]
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
if name in {"GR00TN17", "GR00TN17Config"}:
|
||||
from .groot_n1_7 import GR00TN17, GR00TN17Config
|
||||
|
||||
return {"GR00TN17": GR00TN17, "GR00TN17Config": GR00TN17Config}[name]
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
__all__ = ["GrootConfig", "GrootPolicy", "make_groot_pre_post_processors"]
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def swish(x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
"""
|
||||
Produces a sinusoidal encoding of shape (B, T, w)
|
||||
given timesteps of shape (B, T).
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(self, timesteps):
|
||||
# timesteps: shape (B, T)
|
||||
# We'll compute sin/cos frequencies across dim T
|
||||
timesteps = timesteps.float() # ensure float
|
||||
|
||||
b, t = timesteps.shape
|
||||
device = timesteps.device
|
||||
|
||||
half_dim = self.embedding_dim // 2
|
||||
# typical log space frequencies for sinusoidal encoding
|
||||
exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
|
||||
torch.log(torch.tensor(10000.0)) / half_dim
|
||||
)
|
||||
# Expand timesteps to (B, T, 1) then multiply
|
||||
freqs = timesteps.unsqueeze(-1) * exponent.exp() # (B, T, half_dim)
|
||||
|
||||
sin = torch.sin(freqs)
|
||||
cos = torch.cos(freqs)
|
||||
enc = torch.cat([sin, cos], dim=-1) # (B, T, w)
|
||||
|
||||
return enc
|
||||
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
@@ -43,9 +42,6 @@ else:
|
||||
Timesteps = None
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TimestepEncoder(nn.Module):
|
||||
def __init__(self, embedding_dim, compute_dtype=torch.float32):
|
||||
require_package("diffusers", extra="groot")
|
||||
@@ -185,7 +181,8 @@ class BasicTransformerBlock(nn.Module):
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask if encoder_hidden_states is not None else attention_mask,
|
||||
attention_mask=attention_mask,
|
||||
# encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
if self.final_dropout:
|
||||
attn_output = self.final_dropout(attn_output)
|
||||
@@ -269,8 +266,8 @@ class DiT(ModelMixin, ConfigMixin):
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
|
||||
self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
|
||||
logger.debug(
|
||||
"Total number of DiT parameters: %d",
|
||||
print(
|
||||
"Total number of DiT parameters: ",
|
||||
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
||||
)
|
||||
|
||||
@@ -321,71 +318,6 @@ class DiT(ModelMixin, ConfigMixin):
|
||||
return self.proj_out_2(hidden_states)
|
||||
|
||||
|
||||
class AlternateVLDiT(DiT):
|
||||
"""N1.7 DiT variant that alternates cross-attention over image and text tokens."""
|
||||
|
||||
def __init__(self, *args, attend_text_every_n_blocks: int = 2, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.attend_text_every_n_blocks = attend_text_every_n_blocks
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor | None = None,
|
||||
encoder_attention_mask: torch.Tensor | None = None,
|
||||
return_all_hidden_states: bool = False,
|
||||
image_mask: torch.Tensor | None = None,
|
||||
backbone_attention_mask: torch.Tensor | None = None,
|
||||
):
|
||||
if image_mask is None:
|
||||
raise ValueError("image_mask is required for AlternateVLDiT.")
|
||||
if backbone_attention_mask is None:
|
||||
raise ValueError("backbone_attention_mask is required for AlternateVLDiT.")
|
||||
|
||||
temb = self.timestep_encoder(timestep)
|
||||
hidden_states = hidden_states.contiguous()
|
||||
encoder_hidden_states = encoder_hidden_states.contiguous()
|
||||
|
||||
image_attention_mask = image_mask & backbone_attention_mask
|
||||
non_image_attention_mask = (~image_mask) & backbone_attention_mask
|
||||
|
||||
all_hidden_states = [hidden_states]
|
||||
if not self.config.interleave_self_attention:
|
||||
raise ValueError("AlternateVLDiT requires interleave_self_attention=True.")
|
||||
|
||||
for idx, block in enumerate(self.transformer_blocks):
|
||||
if idx % 2 == 1:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
temb=temb,
|
||||
)
|
||||
else:
|
||||
curr_encoder_attention_mask = (
|
||||
non_image_attention_mask
|
||||
if idx % (2 * self.attend_text_every_n_blocks) == 0
|
||||
else image_attention_mask
|
||||
)
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=curr_encoder_attention_mask,
|
||||
temb=temb,
|
||||
)
|
||||
all_hidden_states.append(hidden_states)
|
||||
|
||||
conditioning = temb
|
||||
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
||||
if return_all_hidden_states:
|
||||
return self.proj_out_2(hidden_states), all_hidden_states
|
||||
return self.proj_out_2(hidden_states)
|
||||
|
||||
|
||||
class SelfAttentionTransformer(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@@ -430,8 +362,8 @@ class SelfAttentionTransformer(ModelMixin, ConfigMixin):
|
||||
for _ in range(self.config.num_layers)
|
||||
]
|
||||
)
|
||||
logger.debug(
|
||||
"Total number of SelfAttentionTransformer parameters: %d",
|
||||
print(
|
||||
"Total number of SelfAttentionTransformer parameters: ",
|
||||
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,408 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import nn
|
||||
from torch.distributions import Beta
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
PretrainedConfig = object
|
||||
BatchFeature = None
|
||||
|
||||
from .action_encoder import (
|
||||
SinusoidalPositionalEncoding,
|
||||
swish,
|
||||
)
|
||||
from .cross_attention_dit import DiT, SelfAttentionTransformer
|
||||
|
||||
|
||||
class CategorySpecificLinear(nn.Module):
|
||||
def __init__(self, num_categories, input_dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.num_categories = num_categories
|
||||
# For each category, we have separate weights and biases.
|
||||
self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
|
||||
self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))
|
||||
|
||||
def forward(self, x, cat_ids):
|
||||
selected_w = self.W[cat_ids]
|
||||
selected_b = self.b[cat_ids]
|
||||
return torch.bmm(x, selected_w) + selected_b.unsqueeze(1)
|
||||
|
||||
|
||||
class CategorySpecificMLP(nn.Module):
|
||||
def __init__(self, num_categories, input_dim, hidden_dim, output_dim):
|
||||
super().__init__()
|
||||
self.num_categories = num_categories
|
||||
self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
|
||||
self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)
|
||||
|
||||
def forward(self, x, cat_ids):
|
||||
hidden = F.relu(self.layer1(x, cat_ids))
|
||||
return self.layer2(hidden, cat_ids)
|
||||
|
||||
|
||||
class MultiEmbodimentActionEncoder(nn.Module):
|
||||
def __init__(self, action_dim, hidden_size, num_embodiments):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_embodiments = num_embodiments
|
||||
|
||||
# W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w}
|
||||
self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size) # (d -> w)
|
||||
self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size) # (2w -> w)
|
||||
self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size) # (w -> w)
|
||||
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
|
||||
|
||||
def forward(self, actions, timesteps, cat_ids):
|
||||
"""
|
||||
actions: shape (B, T, action_dim)
|
||||
timesteps: shape (B,) -- a single scalar per batch item
|
||||
cat_ids: shape (B,)
|
||||
returns: shape (B, T, hidden_size)
|
||||
"""
|
||||
b, t, _ = actions.shape
|
||||
|
||||
# 1) Expand each batch's single scalar time 'tau' across all T steps
|
||||
# so that shape => (B, T)
|
||||
# e.g. if timesteps is (B,), replicate across T
|
||||
if timesteps.dim() == 1 and timesteps.shape[0] == b:
|
||||
# shape (B,) => (B,T)
|
||||
timesteps = timesteps.unsqueeze(1).expand(-1, t)
|
||||
else:
|
||||
raise ValueError("Expected `timesteps` to have shape (B,) so we can replicate across T.")
|
||||
|
||||
# 2) Standard action MLP step for shape => (B, T, w)
|
||||
a_emb = self.W1(actions, cat_ids)
|
||||
|
||||
# 3) Get the sinusoidal encoding (B, T, w)
|
||||
tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)
|
||||
|
||||
# 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
|
||||
x = torch.cat([a_emb, tau_emb], dim=-1)
|
||||
x = swish(self.W2(x, cat_ids))
|
||||
|
||||
# 5) Finally W3 => (B, T, w)
|
||||
x = self.W3(x, cat_ids)
|
||||
return x
|
||||
|
||||
|
||||
class FlowmatchingActionHeadConfig(PretrainedConfig):
|
||||
"""NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""
|
||||
|
||||
add_pos_embed: bool = field(default=True, metadata={"help": "Whether to add positional embedding"})
|
||||
model_dtype: str = field(default="float32", metadata={"help": "Model data type."})
|
||||
diffusion_model_cfg: dict = field(default=None, metadata={"help": "Diffusion model configuration."})
|
||||
input_embedding_dim: int = field(default=1536, metadata={"help": "Input embedding channel dimension."})
|
||||
backbone_embedding_dim: int = field(
|
||||
default=1536, metadata={"help": "Backbone embedding channel dimension."}
|
||||
)
|
||||
|
||||
hidden_size: int = field(default=1024, metadata={"help": "Input embedding dimension."})
|
||||
max_seq_len: int = field(default=1024, metadata={"help": "Maximum Sequence Length"})
|
||||
action_dim: int = field(default=None, metadata={"help": "Action dimension."})
|
||||
action_horizon: int = field(default=None, metadata={"help": "Action horizon."})
|
||||
noise_beta_alpha: float = field(default=1.5, metadata={"help": ""})
|
||||
noise_beta_beta: float = field(default=1.0, metadata={"help": ""})
|
||||
noise_s: float = field(default=0.999, metadata={"help": "Flow matching noise Beta distribution s."})
|
||||
num_timestep_buckets: int = field(
|
||||
default=1000, metadata={"help": "Number of timestep discretization buckets."}
|
||||
)
|
||||
num_inference_timesteps: int = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of inference steps for noise diffusion."},
|
||||
)
|
||||
max_num_embodiments: int = field(default=32, metadata={"help": "Number of embodiments."})
|
||||
tune_projector: bool = field(default=True, metadata={"help": "Whether to tune the projector."})
|
||||
tune_diffusion_model: bool = field(
|
||||
default=True, metadata={"help": "Whether to tune the diffusion model."}
|
||||
)
|
||||
load_pretrained_det_decode_layer_path: str = field(
|
||||
default=None, metadata={"help": "Path to pretrained detection model."}
|
||||
)
|
||||
detection_coeff: float = field(default=1.0, metadata={"help": "Detection coefficient."})
|
||||
|
||||
freeze_decode_layer: bool = field(default=False)
|
||||
expand_batch: int = field(default=None)
|
||||
use_vlln: bool = field(default=True)
|
||||
|
||||
vl_self_attention_cfg: dict = field(default=None)
|
||||
num_target_vision_tokens: int = field(default=32, metadata={"help": "Number of target vision tokens."})
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
class FlowmatchingActionHead(nn.Module):
|
||||
config_class = FlowmatchingActionHeadConfig
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: FlowmatchingActionHeadConfig,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.input_embedding_dim = config.input_embedding_dim
|
||||
|
||||
self.model = DiT(**config.diffusion_model_cfg)
|
||||
self.action_dim = config.action_dim
|
||||
self.action_horizon = config.action_horizon
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
|
||||
self.state_encoder = CategorySpecificMLP(
|
||||
num_categories=config.max_num_embodiments,
|
||||
input_dim=config.max_state_dim,
|
||||
hidden_dim=self.hidden_size,
|
||||
output_dim=self.input_embedding_dim,
|
||||
)
|
||||
self.action_encoder = MultiEmbodimentActionEncoder(
|
||||
action_dim=config.action_dim,
|
||||
hidden_size=self.input_embedding_dim,
|
||||
num_embodiments=config.max_num_embodiments,
|
||||
)
|
||||
self.action_decoder = CategorySpecificMLP(
|
||||
num_categories=config.max_num_embodiments,
|
||||
input_dim=self.hidden_size,
|
||||
hidden_dim=self.hidden_size,
|
||||
output_dim=self.action_dim,
|
||||
)
|
||||
self.future_tokens = nn.Embedding(config.num_target_vision_tokens, self.input_embedding_dim)
|
||||
nn.init.normal_(self.future_tokens.weight, mean=0.0, std=0.02)
|
||||
|
||||
self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity()
|
||||
self.vl_self_attention = (
|
||||
SelfAttentionTransformer(**config.vl_self_attention_cfg) if config.use_vlln else nn.Identity()
|
||||
)
|
||||
|
||||
if config.add_pos_embed:
|
||||
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
|
||||
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
|
||||
|
||||
self._noise_beta_alpha = config.noise_beta_alpha
|
||||
self._noise_beta_beta = config.noise_beta_beta
|
||||
self._beta_dist = None
|
||||
self.num_timestep_buckets = config.num_timestep_buckets
|
||||
self.config = config
|
||||
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model)
|
||||
|
||||
def set_trainable_parameters(self, tune_projector: bool, tune_diffusion_model: bool):
|
||||
self.tune_projector = tune_projector
|
||||
self.tune_diffusion_model = tune_diffusion_model
|
||||
for p in self.parameters():
|
||||
p.requires_grad = True
|
||||
if not tune_projector:
|
||||
self.state_encoder.requires_grad_(False)
|
||||
self.action_encoder.requires_grad_(False)
|
||||
self.action_decoder.requires_grad_(False)
|
||||
if self.config.add_pos_embed:
|
||||
self.position_embedding.requires_grad_(False)
|
||||
if not tune_diffusion_model:
|
||||
self.model.requires_grad_(False)
|
||||
print(f"Tune action head projector: {self.tune_projector}")
|
||||
print(f"Tune action head diffusion model: {self.tune_diffusion_model}")
|
||||
# Check if any parameters are still trainable. If not, print a warning.
|
||||
if not tune_projector and not tune_diffusion_model:
|
||||
for name, p in self.named_parameters():
|
||||
if p.requires_grad:
|
||||
print(f"Action head trainable parameter: {name}")
|
||||
if not any(p.requires_grad for p in self.parameters()):
|
||||
print("Warning: No action head trainable parameters found.")
|
||||
|
||||
def set_frozen_modules_to_eval_mode(self):
|
||||
"""
|
||||
Huggingface will call model.train() at each training_step. To ensure
|
||||
the expected behaviors for modules like dropout, batchnorm, etc., we
|
||||
need to call model.eval() for the frozen modules.
|
||||
"""
|
||||
if self.training:
|
||||
if not self.tune_projector:
|
||||
self.state_encoder.eval()
|
||||
self.action_encoder.eval()
|
||||
self.action_decoder.eval()
|
||||
if self.config.add_pos_embed:
|
||||
self.position_embedding.eval()
|
||||
if not self.tune_diffusion_model:
|
||||
self.model.eval()
|
||||
|
||||
def sample_time(self, batch_size, device, dtype):
|
||||
if self._beta_dist is None:
|
||||
self._beta_dist = Beta(self._noise_beta_alpha, self._noise_beta_beta, validate_args=False)
|
||||
sample = self._beta_dist.sample([batch_size]).to(device, dtype=dtype)
|
||||
return (self.config.noise_s - sample) / self.config.noise_s
|
||||
|
||||
def prepare_input(self, batch: dict) -> BatchFeature:
|
||||
return BatchFeature(data=batch)
|
||||
|
||||
def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature:
|
||||
backbone_features = backbone_output["backbone_features"]
|
||||
backbone_features = self.vlln(backbone_features)
|
||||
backbone_features = self.vl_self_attention(backbone_features)
|
||||
backbone_output["backbone_features"] = backbone_features
|
||||
return backbone_output
|
||||
|
||||
def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
|
||||
# Set frozen modules to eval
|
||||
self.set_frozen_modules_to_eval_mode()
|
||||
|
||||
backbone_output = self.process_backbone_output(backbone_output)
|
||||
|
||||
if self.config.expand_batch is not None:
|
||||
for k, v in backbone_output.items():
|
||||
ndim = len(v.shape)
|
||||
factors = [self.config.expand_batch]
|
||||
while len(factors) < ndim:
|
||||
factors.append(1)
|
||||
factors = tuple(factors)
|
||||
expanded = v.repeat(*factors)
|
||||
backbone_output[k] = expanded
|
||||
|
||||
for k, v in action_input.items():
|
||||
ndim = len(v.shape)
|
||||
factors = [self.config.expand_batch]
|
||||
while len(factors) < ndim:
|
||||
factors.append(1)
|
||||
factors = tuple(factors)
|
||||
expanded = v.repeat(*factors)
|
||||
action_input[k] = expanded
|
||||
|
||||
# Get vision and language embeddings.
|
||||
vl_embs = backbone_output.backbone_features
|
||||
device = vl_embs.device
|
||||
|
||||
# Get embodiment ID.
|
||||
embodiment_id = action_input.embodiment_id
|
||||
|
||||
# Embed state.
|
||||
state_features = self.state_encoder(action_input.state, embodiment_id)
|
||||
|
||||
# Embed noised action trajectory.
|
||||
actions = action_input.action
|
||||
noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
|
||||
t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
|
||||
t = t[:, None, None] # shape (B,1,1) for broadcast
|
||||
|
||||
noisy_trajectory = (1 - t) * noise + t * actions
|
||||
velocity = actions - noise
|
||||
|
||||
# Convert (continuous) t -> discrete if needed
|
||||
t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
|
||||
action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)
|
||||
|
||||
# Maybe add position embedding.
|
||||
if self.config.add_pos_embed:
|
||||
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
|
||||
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
|
||||
action_features = action_features + pos_embs
|
||||
|
||||
# Join vision, language, state and action embedding along sequence dimension.
|
||||
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
|
||||
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
|
||||
|
||||
vl_attn_mask = backbone_output.backbone_attention_mask
|
||||
|
||||
model_output = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embs,
|
||||
encoder_attention_mask=vl_attn_mask,
|
||||
timestep=t_discretized,
|
||||
return_all_hidden_states=False, # NOTE (YL): not using flare now
|
||||
)
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
pred_actions = pred[:, -actions.shape[1] :]
|
||||
|
||||
# Slice out only the action portion of pred and target.
|
||||
action_mask = action_input.action_mask
|
||||
loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask
|
||||
loss = loss.sum() / action_mask.sum()
|
||||
output_dict = {
|
||||
"loss": loss,
|
||||
}
|
||||
return BatchFeature(data=output_dict)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_action(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
|
||||
backbone_output = self.process_backbone_output(backbone_output)
|
||||
|
||||
# Get vision and language embeddings.
|
||||
vl_embs = backbone_output.backbone_features
|
||||
embodiment_id = action_input.embodiment_id
|
||||
|
||||
# Embed state.
|
||||
state_features = self.state_encoder(action_input.state, embodiment_id)
|
||||
|
||||
# Set initial actions as the sampled noise.
|
||||
batch_size = vl_embs.shape[0]
|
||||
device = vl_embs.device
|
||||
actions = torch.randn(
|
||||
size=(batch_size, self.config.action_horizon, self.config.action_dim),
|
||||
dtype=vl_embs.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
num_steps = self.num_inference_timesteps
|
||||
dt = 1.0 / num_steps
|
||||
|
||||
# Run denoising steps.
|
||||
for t in range(num_steps):
|
||||
t_cont = t / float(num_steps) # e.g. goes 0, 1/N, 2/N, ...
|
||||
t_discretized = int(t_cont * self.num_timestep_buckets)
|
||||
|
||||
# Embed noised action trajectory.
|
||||
timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device)
|
||||
action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id)
|
||||
# Maybe add position embedding.
|
||||
if self.config.add_pos_embed:
|
||||
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
|
||||
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
|
||||
action_features = action_features + pos_embs
|
||||
|
||||
# Join vision, language, state and action embedding along sequence dimension.
|
||||
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
|
||||
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
|
||||
|
||||
# Run model forward.
|
||||
model_output = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embs,
|
||||
timestep=timesteps_tensor,
|
||||
)
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
|
||||
pred_velocity = pred[:, -self.action_horizon :]
|
||||
|
||||
# Update actions using euler integration.
|
||||
actions = actions + dt * pred_velocity
|
||||
return BatchFeature(data={"action_pred": actions})
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(iter(self.parameters())).dtype
|
||||
@@ -14,327 +14,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
|
||||
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GROOT_N1_7 = "n1.7"
|
||||
# Legacy GR00T N1.5 identifier. N1.5 is NOT a supported model_version (it is
|
||||
# intentionally absent from _GROOT_MODEL_VERSION_ALIASES so normalize_groot_model_version
|
||||
# still rejects it). It is retained only so that infer_groot_model_version can recognise
|
||||
# an N1.5 base path/checkpoint and the N1.7 config/loader can reject the mismatch.
|
||||
GROOT_N1_5 = "n1.5"
|
||||
# Canonical guidance appended to every error raised when an N1.5 checkpoint, config,
|
||||
# or processor pipeline is detected. Keep this message in sync with docs/source/groot.mdx.
|
||||
GROOT_N1_5_REMOVAL_GUIDANCE = (
|
||||
"GR00T N1.5 support was removed from LeRobot. "
|
||||
"To keep using an N1.5 checkpoint, pin the last release that supports it: "
|
||||
"`pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 "
|
||||
"(model_version='n1.7', base model nvidia/GR00T-N1.7-3B)."
|
||||
)
|
||||
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
|
||||
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
|
||||
# Default GR00T N1.7 training resolution. Fallback if processor_config lacks sizing. Prevents mismatched
|
||||
# full-res patchification by forcing a resize. Mirrored by GR00T_N1_7_DEFAULTS in groot_n1_7.py.
|
||||
N1_7_DEFAULT_IMAGE_TARGET_SIZE = (256, 256)
|
||||
N1_7_DEFAULT_IMAGE_CROP_SIZE = (230, 230)
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
|
||||
# Sentinel meaning "the user did not pick an action decode transform": __post_init__ resolves it
|
||||
# to the embodiment default ('libero' for 'libero_sim', otherwise None). It is distinct from an
|
||||
# explicit 'none' (resolved to None) so an opt-out survives a draccus save/load round-trip.
|
||||
GROOT_ACTION_DECODE_TRANSFORM_AUTO = "auto"
|
||||
|
||||
_GROOT_MODEL_VERSION_ALIASES = {
|
||||
"n1.7": GROOT_N1_7,
|
||||
"n1_7": GROOT_N1_7,
|
||||
"n1d7": GROOT_N1_7,
|
||||
"n17": GROOT_N1_7,
|
||||
"1.7": GROOT_N1_7,
|
||||
}
|
||||
|
||||
# Legacy N1.5 spellings, kept ONLY so they can be detected and rejected with
|
||||
# GROOT_N1_5_REMOVAL_GUIDANCE (see GROOT_N1_5 above). Never map these to a supported version.
|
||||
_GROOT_N1_5_VERSION_ALIASES = {"n1.5", "n1_5", "n1d5", "n15", "1.5"}
|
||||
|
||||
_GROOT_ACTION_DECODE_TRANSFORM_ALIASES = {
|
||||
GROOT_ACTION_DECODE_TRANSFORM_AUTO: GROOT_ACTION_DECODE_TRANSFORM_AUTO,
|
||||
"none": None,
|
||||
"": None,
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO: GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
|
||||
}
|
||||
|
||||
|
||||
def normalize_groot_model_version(model_version: str) -> str:
|
||||
normalized = _GROOT_MODEL_VERSION_ALIASES.get(model_version.lower())
|
||||
if normalized is None:
|
||||
supported = GROOT_N1_7
|
||||
message = f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
|
||||
if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
|
||||
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||
raise ValueError(message)
|
||||
return normalized
|
||||
|
||||
|
||||
def normalize_groot_action_decode_transform(transform: str | None) -> str | None:
|
||||
if transform is None:
|
||||
return None
|
||||
normalized = _GROOT_ACTION_DECODE_TRANSFORM_ALIASES.get(transform.lower())
|
||||
if normalized is None and transform.lower() not in _GROOT_ACTION_DECODE_TRANSFORM_ALIASES:
|
||||
supported = ", ".join(
|
||||
sorted(key for key, value in _GROOT_ACTION_DECODE_TRANSFORM_ALIASES.items() if value is not None)
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unsupported GR00T N1.7 action decode transform '{transform}'. "
|
||||
f"Supported transforms: none, {supported}."
|
||||
)
|
||||
return normalized
|
||||
|
||||
|
||||
def infer_groot_model_version(model_path: str | None) -> str | None:
|
||||
if not model_path:
|
||||
return None
|
||||
model_path_lower = model_path.lower()
|
||||
if "gr00t-n1.7" in model_path_lower or "gr00t_n1.7" in model_path_lower:
|
||||
return GROOT_N1_7
|
||||
# Detect legacy N1.5 paths so the N1.7 config/loader can reject the mismatch.
|
||||
# N1.5 is unsupported, but it must still be recognised here to fail loudly
|
||||
# rather than silently treating an N1.5 checkpoint as N1.7.
|
||||
if "gr00t-n1.5" in model_path_lower or "gr00t_n1.5" in model_path_lower:
|
||||
return GROOT_N1_5
|
||||
config_version = _infer_groot_model_version_from_local_config(model_path)
|
||||
if config_version is not None:
|
||||
return config_version
|
||||
return None
|
||||
|
||||
|
||||
def is_raw_groot_n1_7_checkpoint(model_path: str | Path | None) -> bool:
|
||||
if model_path is None:
|
||||
return False
|
||||
|
||||
path = Path(model_path).expanduser()
|
||||
if path.is_dir():
|
||||
config_path = path / "config.json"
|
||||
elif path.name == "config.json":
|
||||
config_path = path
|
||||
else:
|
||||
return False
|
||||
|
||||
try:
|
||||
with config_path.open() as f:
|
||||
config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return False
|
||||
|
||||
return "type" not in config and _infer_groot_model_version_from_config(config) == GROOT_N1_7
|
||||
|
||||
|
||||
def infer_groot_n1_7_embodiment_tag(model_path: str | Path | None) -> str | None:
|
||||
if model_path is None:
|
||||
return None
|
||||
|
||||
processor_config_path = Path(model_path).expanduser() / "processor_config.json"
|
||||
try:
|
||||
with processor_config_path.open() as f:
|
||||
processor_config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
modality_configs = processor_config.get("processor_kwargs", {}).get("modality_configs", {})
|
||||
if not isinstance(modality_configs, dict):
|
||||
return None
|
||||
if "libero_sim" in modality_configs:
|
||||
return "libero_sim"
|
||||
if len(modality_configs) == 1:
|
||||
return next(iter(modality_configs))
|
||||
return None
|
||||
|
||||
|
||||
def infer_groot_n1_7_action_horizon(
|
||||
model_path: str | Path | None, embodiment_tag: str | None = None
|
||||
) -> int | None:
|
||||
if model_path is None:
|
||||
return None
|
||||
|
||||
processor_config_path = Path(model_path).expanduser() / "processor_config.json"
|
||||
try:
|
||||
with processor_config_path.open() as f:
|
||||
processor_config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
processor_kwargs = processor_config.get("processor_kwargs", {})
|
||||
if not isinstance(processor_kwargs, dict):
|
||||
return None
|
||||
modality_configs = processor_kwargs.get("modality_configs", {})
|
||||
if not isinstance(modality_configs, dict):
|
||||
return None
|
||||
|
||||
if embodiment_tag is None:
|
||||
embodiment_tag = infer_groot_n1_7_embodiment_tag(model_path)
|
||||
if embodiment_tag is None:
|
||||
return None
|
||||
|
||||
embodiment_config = modality_configs.get(embodiment_tag, {})
|
||||
if not isinstance(embodiment_config, dict):
|
||||
return None
|
||||
action_config = embodiment_config.get("action", {})
|
||||
if not isinstance(action_config, dict):
|
||||
return None
|
||||
delta_indices = action_config.get("delta_indices", [])
|
||||
if not isinstance(delta_indices, list):
|
||||
return None
|
||||
return len(delta_indices) or None
|
||||
|
||||
|
||||
def infer_groot_n1_7_action_execution_horizon(
|
||||
model_path: str | Path | None, embodiment_tag: str | None = None
|
||||
) -> int | None:
|
||||
action_horizon = infer_groot_n1_7_action_horizon(model_path, embodiment_tag)
|
||||
if action_horizon is None:
|
||||
return None
|
||||
|
||||
if embodiment_tag is None:
|
||||
embodiment_tag = infer_groot_n1_7_embodiment_tag(model_path)
|
||||
if embodiment_tag == "libero_sim":
|
||||
# NVIDIA's N1.7 LIBERO rollout wrapper replans after 8 of the 16 decoded
|
||||
# actions. Keeping that execution cadence avoids stale open-loop chunks.
|
||||
return min(action_horizon, 8)
|
||||
return action_horizon
|
||||
|
||||
|
||||
def resolve_groot_n1_7_backbone_model(model_name: str, cache_dir: str | Path | None = None) -> str:
|
||||
model_path = Path(model_name).expanduser()
|
||||
if model_path.exists():
|
||||
return str(model_path)
|
||||
|
||||
cached_snapshot = _find_cached_hf_snapshot(model_name, cache_dir=cache_dir)
|
||||
return str(cached_snapshot) if cached_snapshot is not None else model_name
|
||||
|
||||
|
||||
def _find_cached_hf_snapshot(repo_id: str, cache_dir: str | Path | None = None) -> Path | None:
|
||||
repo_cache_name = f"models--{repo_id.replace('/', '--')}"
|
||||
required_files = (
|
||||
"config.json",
|
||||
"tokenizer_config.json",
|
||||
"preprocessor_config.json",
|
||||
"video_preprocessor_config.json",
|
||||
)
|
||||
|
||||
for hub_cache in _candidate_hf_hub_caches(cache_dir):
|
||||
repo_cache = hub_cache / repo_cache_name
|
||||
snapshots_dir = repo_cache / "snapshots"
|
||||
if not snapshots_dir.is_dir():
|
||||
continue
|
||||
|
||||
candidates: list[Path] = []
|
||||
ref_path = repo_cache / "refs" / "main"
|
||||
try:
|
||||
ref = ref_path.read_text().strip()
|
||||
except OSError:
|
||||
ref = ""
|
||||
if ref:
|
||||
candidates.append(snapshots_dir / ref)
|
||||
candidates.extend(
|
||||
sorted(
|
||||
(path for path in snapshots_dir.iterdir() if path.is_dir()),
|
||||
key=lambda path: path.stat().st_mtime,
|
||||
reverse=True,
|
||||
)
|
||||
)
|
||||
|
||||
seen: set[Path] = set()
|
||||
for snapshot in candidates:
|
||||
if snapshot in seen:
|
||||
continue
|
||||
seen.add(snapshot)
|
||||
if all((snapshot / filename).exists() for filename in required_files):
|
||||
return snapshot
|
||||
return None
|
||||
|
||||
|
||||
def _candidate_hf_hub_caches(cache_dir: str | Path | None) -> list[Path]:
|
||||
candidates: list[Path] = []
|
||||
if cache_dir is not None:
|
||||
cache_path = Path(cache_dir).expanduser()
|
||||
candidates.append(cache_path)
|
||||
candidates.append(cache_path / "hub")
|
||||
|
||||
hub_cache = os.environ.get("HUGGINGFACE_HUB_CACHE")
|
||||
if hub_cache:
|
||||
candidates.append(Path(hub_cache).expanduser())
|
||||
|
||||
hf_home = os.environ.get("HF_HOME")
|
||||
if hf_home:
|
||||
candidates.append(Path(hf_home).expanduser() / "hub")
|
||||
|
||||
candidates.append(Path.home() / ".cache" / "huggingface" / "hub")
|
||||
|
||||
deduped: list[Path] = []
|
||||
seen: set[Path] = set()
|
||||
for candidate in candidates:
|
||||
resolved = candidate.resolve() if candidate.exists() else candidate
|
||||
if resolved not in seen:
|
||||
seen.add(resolved)
|
||||
deduped.append(candidate)
|
||||
return deduped
|
||||
|
||||
|
||||
def _infer_groot_model_version_from_local_config(model_path: str) -> str | None:
|
||||
path = Path(model_path).expanduser()
|
||||
if path.is_dir():
|
||||
config_path = path / "config.json"
|
||||
elif path.name == "config.json":
|
||||
config_path = path
|
||||
else:
|
||||
return None
|
||||
|
||||
if not config_path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with config_path.open() as f:
|
||||
config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
return _infer_groot_model_version_from_config(config)
|
||||
|
||||
|
||||
def _infer_groot_model_version_from_config(config: dict) -> str | None:
|
||||
model_version = config.get("model_version")
|
||||
if isinstance(model_version, str):
|
||||
if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
|
||||
return GROOT_N1_5
|
||||
try:
|
||||
return normalize_groot_model_version(model_version)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
candidates = [config.get("model_type"), *(config.get("architectures") or [])]
|
||||
for candidate in candidates:
|
||||
if not isinstance(candidate, str):
|
||||
continue
|
||||
normalized = candidate.lower().replace("-", "_")
|
||||
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
|
||||
return GROOT_N1_7
|
||||
if normalized in {"gr00t_n1_5", "gr00tn1_5", "gr00t_n15", "gr00t_n1d5", "gr00tn1d5"}:
|
||||
return GROOT_N1_5
|
||||
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
|
||||
return GROOT_N1_7
|
||||
# The Eagle VLM backbone is specific to pre-N1.7 GR00T checkpoints (N1.7 uses Cosmos/Qwen3-VL).
|
||||
backbone_cfg = config.get("backbone_cfg")
|
||||
if isinstance(backbone_cfg, dict) and "eagle_path" in backbone_cfg:
|
||||
return GROOT_N1_5
|
||||
return None
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("groot")
|
||||
@dataclass
|
||||
@@ -343,44 +28,35 @@ class GrootConfig(PreTrainedConfig):
|
||||
|
||||
# Basic policy settings
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 40
|
||||
n_action_steps: int = 40
|
||||
chunk_size: int = 50
|
||||
n_action_steps: int = 50
|
||||
|
||||
# Dimension settings (must match pretrained GR00T model expectations)
|
||||
# Maximum state dimension. Shorter states will be zero-padded.
|
||||
max_state_dim: int = 132
|
||||
max_state_dim: int = 64
|
||||
|
||||
# Maximum action dimension. Shorter actions will be zero-padded.
|
||||
max_action_dim: int = 132
|
||||
max_action_dim: int = 32
|
||||
|
||||
# GR00T normalizes state/action internally in its processor steps (min/max with
|
||||
# q01/q99 percentiles, per embodiment), and the Qwen3-VL backbone's image processor
|
||||
# handles image normalization. The policy therefore does NOT use LeRobot's
|
||||
# NormalizerProcessorStep/UnnormalizerProcessorStep, so this mapping is intentionally
|
||||
# IDENTITY for every feature and is not consulted by make_groot_pre_post_processors.
|
||||
# Normalization (start with identity, adjust as needed)
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Groot-specific model parameters
|
||||
# Image preprocessing (adjust to match Groot's expected input)
|
||||
image_size: tuple[int, int] = (224, 224)
|
||||
|
||||
# Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only.
|
||||
model_version: str = GROOT_N1_7
|
||||
# Groot-specific model parameters (from groot_finetune_script.py)
|
||||
|
||||
# Path or HuggingFace model ID for the base Groot model
|
||||
base_model_path: str | None = None
|
||||
base_model_path: str = "nvidia/GR00T-N1.5-3B"
|
||||
|
||||
# HF repo ID (or local path) for the GR00T N1.7 Cosmos/Qwen3-VL backbone processor.
|
||||
n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL
|
||||
|
||||
# Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step().
|
||||
# 'auto' (default) resolves to the embodiment default ('libero' for 'libero_sim', otherwise no
|
||||
# transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'.
|
||||
action_decode_transform: str | None = GROOT_ACTION_DECODE_TRANSFORM_AUTO
|
||||
# HF repo ID (or local path) that hosts vocab.json and merges.txt for Eagle tokenizer.
|
||||
tokenizer_assets_repo: str = "lerobot/eagle2hg-processor-groot-n1p5"
|
||||
|
||||
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
|
||||
embodiment_tag: str = "new_embodiment"
|
||||
@@ -420,16 +96,17 @@ class GrootConfig(PreTrainedConfig):
|
||||
warmup_ratio: float = 0.05
|
||||
use_bf16: bool = True
|
||||
|
||||
# TODO(Steven): Remove these deprecated fields in a future release.
|
||||
# Deprecated Isaac-GR00T runner/N1.5 fields below — unused by the LeRobot N1.7 implementation
|
||||
# (nothing in src/lerobot reads them). They are kept only so config.json files saved by
|
||||
# earlier lerobot releases still parse: draccus rejects unknown fields, so removing them
|
||||
# would break every previously saved groot checkpoint at config-load time.
|
||||
image_size: tuple[int, int] = (256, 256) # image sizing is handled by the backbone's image processor.
|
||||
tokenizer_assets_repo: str | None = None
|
||||
# Dataset parameters
|
||||
# Video backend to use for training ('decord' or 'torchvision_av')
|
||||
video_backend: str = "decord"
|
||||
|
||||
# Whether to balance dataset weights in mixture datasets
|
||||
balance_dataset_weights: bool = True
|
||||
|
||||
# Whether to sample trajectories weighted by their length
|
||||
balance_trajectory_weights: bool = True
|
||||
|
||||
# Optional dataset paths for delegating training to Isaac-GR00T runner
|
||||
dataset_paths: list[str] | None = None
|
||||
output_dir: str = "./tmp/gr00t"
|
||||
save_steps: int = 1000
|
||||
@@ -440,66 +117,6 @@ class GrootConfig(PreTrainedConfig):
|
||||
resume: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tokenizer_assets_repo is not None:
|
||||
raise ValueError(
|
||||
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
|
||||
f"like a legacy GR00T N1.5 checkpoint or config. {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||
)
|
||||
|
||||
self.model_version = normalize_groot_model_version(self.model_version)
|
||||
self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform)
|
||||
if self.base_model_path is None:
|
||||
self.base_model_path = GROOT_N1_7_BASE_MODEL
|
||||
|
||||
# The N1.7 LIBERO checkpoints emit a [0, 1] gripper action, but the LIBERO
|
||||
# simulator expects the OpenVLA/[-1, 1] sign convention. NVIDIA's rollout
|
||||
# wrapper applies this conversion; mirror it here so eval on the
|
||||
# 'libero_sim' embodiment grasps correctly instead of scoring 0% success.
|
||||
# This matches the embodiment-specific handling already done for the
|
||||
# action execution horizon (see infer_groot_n1_7_action_execution_horizon).
|
||||
# Only the 'auto' sentinel resolves to the embodiment default; an explicit
|
||||
# 'none' (normalized to None above) keeps the transform disabled.
|
||||
if self.action_decode_transform == GROOT_ACTION_DECODE_TRANSFORM_AUTO:
|
||||
self.action_decode_transform = (
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO if self.embodiment_tag == "libero_sim" else None
|
||||
)
|
||||
|
||||
# GR00T N1.5-era default values (e.g. --policy.chunk_size=50 from old commands or
|
||||
# stale configs) are migrated to the values the N1.7 checkpoints expect, with a
|
||||
# warning. The dataclass defaults are already the N1.7 values, so a plain
|
||||
# GrootConfig() never triggers this.
|
||||
legacy_default_remaps = (
|
||||
("max_state_dim", 64, 132),
|
||||
("max_action_dim", 32, 132),
|
||||
("chunk_size", 50, 40),
|
||||
("n_action_steps", 50, 40),
|
||||
("image_size", (224, 224), (256, 256)),
|
||||
)
|
||||
for field_name, legacy_value, n1_7_value in legacy_default_remaps:
|
||||
current_value = getattr(self, field_name)
|
||||
if isinstance(legacy_value, tuple):
|
||||
current_value = tuple(current_value)
|
||||
if current_value == legacy_value:
|
||||
logger.warning(
|
||||
"GrootConfig.%s=%s matches a legacy GR00T N1.5-era default; remapping it to %s, "
|
||||
"the value expected by GR00T N1.7 checkpoints. Set a different value explicitly "
|
||||
"if this is not what you want.",
|
||||
field_name,
|
||||
legacy_value,
|
||||
n1_7_value,
|
||||
)
|
||||
setattr(self, field_name, n1_7_value)
|
||||
|
||||
inferred_version = infer_groot_model_version(self.base_model_path)
|
||||
if inferred_version is not None and inferred_version != self.model_version:
|
||||
message = (
|
||||
f"GR00T model_version '{self.model_version}' does not match base_model_path "
|
||||
f"'{self.base_model_path}', which looks like '{inferred_version}'."
|
||||
)
|
||||
if inferred_version == GROOT_N1_5:
|
||||
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||
raise ValueError(message)
|
||||
|
||||
super().__post_init__()
|
||||
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
@@ -575,10 +192,7 @@ class GrootConfig(PreTrainedConfig):
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
"""Return indices for delta actions."""
|
||||
model_action_horizon = (
|
||||
infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
|
||||
)
|
||||
return list(range(min(self.chunk_size, model_action_horizon)))
|
||||
return list(range(min(self.chunk_size, 16)))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import copy
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
|
||||
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
|
||||
from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Eagle25VLConfig(PretrainedConfig):
|
||||
model_type = "eagle_2_5_vl"
|
||||
is_composition = True
|
||||
sub_configs = {"vision_config": SiglipVisionConfig, "text_config": Qwen2Config}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config=None,
|
||||
text_config=None,
|
||||
use_backbone_lora=0,
|
||||
use_llm_lora=0,
|
||||
pad2square=False,
|
||||
select_layer=-4,
|
||||
force_image_size=None,
|
||||
downsample_ratio=0.5,
|
||||
template=None,
|
||||
dynamic_image_size=False,
|
||||
use_thumbnail=False,
|
||||
loss_version="v1",
|
||||
min_dynamic_tiles=1,
|
||||
max_dynamic_tiles=6,
|
||||
mlp_checkpoint=False,
|
||||
initializer_range=0.02,
|
||||
_attn_implementation="flash_attention_2",
|
||||
_attn_implementation_autoset=False,
|
||||
llm_config=None,
|
||||
image_token_index=None,
|
||||
use_pixel_shuffle=True,
|
||||
mlp_connector_layers=2,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if vision_config is None:
|
||||
vision_config = {"model_type": "siglip_vision_model"}
|
||||
logger.info("vision_config is None. Initializing the InternVisionConfig with default values.")
|
||||
|
||||
if text_config is None:
|
||||
text_config = {"architectures": ["Qwen2ForCausalLM"]}
|
||||
logger.info(
|
||||
"text_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`)."
|
||||
)
|
||||
|
||||
if vision_config["model_type"] == "siglip_vision_model":
|
||||
self.vision_config = SiglipVisionConfig(**vision_config)
|
||||
else:
|
||||
raise ValueError("Unsupported model_type: {}".format(vision_config["model_type"]))
|
||||
|
||||
if text_config["architectures"][0] == "LlamaForCausalLM":
|
||||
self.text_config = LlamaConfig(**text_config)
|
||||
elif text_config["architectures"][0] == "Qwen2ForCausalLM":
|
||||
self.text_config = Qwen2Config(**text_config)
|
||||
elif text_config["architectures"][0] == "Qwen3ForCausalLM":
|
||||
self.text_config = Qwen3Config(**text_config)
|
||||
else:
|
||||
raise ValueError("Unsupported architecture: {}".format(text_config["architectures"][0]))
|
||||
self.use_backbone_lora = use_backbone_lora
|
||||
self.use_llm_lora = use_llm_lora
|
||||
self.mlp_checkpoint = mlp_checkpoint
|
||||
self.pad2square = pad2square
|
||||
self.select_layer = select_layer
|
||||
self.force_image_size = force_image_size
|
||||
self.downsample_ratio = downsample_ratio
|
||||
self.template = template
|
||||
self.dynamic_image_size = dynamic_image_size
|
||||
self.use_thumbnail = use_thumbnail
|
||||
self.loss_version = loss_version
|
||||
self.initializer_range = initializer_range
|
||||
self.min_dynamic_tiles = min_dynamic_tiles
|
||||
self.max_dynamic_tiles = max_dynamic_tiles
|
||||
self.tie_word_embeddings = self.text_config.tie_word_embeddings
|
||||
self._attn_implementation = _attn_implementation
|
||||
self._attn_implementation_autoset = _attn_implementation_autoset
|
||||
self.image_token_index = image_token_index
|
||||
self.use_pixel_shuffle = use_pixel_shuffle
|
||||
self.mlp_connector_layers = mlp_connector_layers
|
||||
logger.info(f"min_dynamic_tiles: {self.min_dynamic_tiles}")
|
||||
logger.info(f"max_dynamic_tiles: {self.max_dynamic_tiles}")
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
||||
|
||||
Returns:
|
||||
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
output["vision_config"] = self.vision_config.to_dict()
|
||||
output["text_config"] = self.text_config.to_dict()
|
||||
output["model_type"] = self.__class__.model_type
|
||||
output["use_backbone_lora"] = self.use_backbone_lora
|
||||
output["use_llm_lora"] = self.use_llm_lora
|
||||
output["pad2square"] = self.pad2square
|
||||
output["select_layer"] = self.select_layer
|
||||
output["force_image_size"] = self.force_image_size
|
||||
output["downsample_ratio"] = self.downsample_ratio
|
||||
output["template"] = self.template
|
||||
output["dynamic_image_size"] = self.dynamic_image_size
|
||||
output["use_thumbnail"] = self.use_thumbnail
|
||||
output["min_dynamic_tiles"] = self.min_dynamic_tiles
|
||||
output["max_dynamic_tiles"] = self.max_dynamic_tiles
|
||||
output["tie_word_embeddings"] = self.tie_word_embeddings
|
||||
output["_attn_implementation"] = self._attn_implementation
|
||||
output["_attn_implementation_autoset"] = self._attn_implementation_autoset
|
||||
output["use_pixel_shuffle"] = self.use_pixel_shuffle
|
||||
output["mlp_connector_layers"] = self.mlp_connector_layers
|
||||
return output
|
||||
@@ -0,0 +1,503 @@
|
||||
# --------------------------------------------------------
|
||||
# NVIDIA
|
||||
# Copyright (c) 2025 NVIDIA
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
|
||||
from transformers.image_processing_utils import (
|
||||
BatchFeature,
|
||||
get_patch_output_size,
|
||||
)
|
||||
from transformers.image_processing_utils_fast import (
|
||||
BaseImageProcessorFast,
|
||||
ImagesKwargs,
|
||||
group_images_by_shape,
|
||||
reorder_images,
|
||||
)
|
||||
from transformers.image_utils import (
|
||||
IMAGENET_STANDARD_MEAN, # 0.5, 0.5, 0.5
|
||||
IMAGENET_STANDARD_STD, # 0.5, 0.5, 0.5
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
get_image_size,
|
||||
make_flat_list_of_images,
|
||||
validate_kwargs,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import (
|
||||
TensorType,
|
||||
add_start_docstrings,
|
||||
is_torch_available,
|
||||
is_torchvision_v2_available,
|
||||
)
|
||||
from transformers.video_utils import VideoInput
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
if is_torchvision_v2_available():
|
||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
from transformers.image_utils import pil_torch_interpolation_mapping
|
||||
else:
|
||||
from torchvision.transforms import functional as F # noqa: N812
|
||||
|
||||
|
||||
def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> torch.Tensor:
|
||||
"""Crop the given numpy array.
|
||||
|
||||
Args:
|
||||
img (torch.Tensor): Image to be cropped. Format should be (C, H, W).
|
||||
left (int): The left coordinate of the crop box.
|
||||
top (int): The top coordinate of the crop box.
|
||||
right (int): The right coordinate of the crop box.
|
||||
bottom (int): The bottom coordinate of the crop box.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Cropped image.
|
||||
"""
|
||||
if not isinstance(img, torch.Tensor):
|
||||
raise TypeError(f"img should be torch.Tensor. Got {type(img)}")
|
||||
|
||||
if img.ndim not in [2, 3]:
|
||||
raise ValueError(f"Image should have 2 or 3 dimensions. Got {img.ndim}")
|
||||
|
||||
img_height = img.shape[1]
|
||||
img_width = img.shape[2]
|
||||
if top < 0 or left < 0 or bottom > img_height or right > img_width:
|
||||
raise ValueError("Crop coordinates out of bounds")
|
||||
|
||||
if top >= bottom or left >= right:
|
||||
raise ValueError("Invalid crop coordinates")
|
||||
|
||||
return img[:, top:bottom, left:right]
|
||||
|
||||
|
||||
class Eagle25VLFastImageProcessorKwargs(ImagesKwargs):
|
||||
max_dynamic_tiles: int | None
|
||||
min_dynamic_tiles: int | None
|
||||
use_thumbnail: bool | None
|
||||
pad_during_tiling: bool | None
|
||||
do_pad: bool | None
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"Constructs a fast ConvNeXT image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.",
|
||||
# BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, TODO: this was depreciated from transformers remove!
|
||||
"""
|
||||
image_grid_pinpoints (`List[List[int]]`, *optional*):
|
||||
A list of possible resolutions to use for processing high resolution images. The best resolution is selected
|
||||
based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
|
||||
method. Not used for processing videos.
|
||||
do_pad (`bool`, *optional*):
|
||||
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
|
||||
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
|
||||
""",
|
||||
)
|
||||
class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
||||
resample = PILImageResampling.BICUBIC
|
||||
image_mean = IMAGENET_STANDARD_MEAN
|
||||
image_std = IMAGENET_STANDARD_STD
|
||||
size = {"height": 448, "width": 448}
|
||||
default_to_square = False
|
||||
crop_size = None
|
||||
do_resize = True
|
||||
do_center_crop = None
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_convert_rgb = True
|
||||
do_pad = True
|
||||
max_dynamic_tiles = 12
|
||||
min_dynamic_tiles = 1
|
||||
use_thumbnail = True
|
||||
pad_during_tiling = False
|
||||
valid_kwargs = Eagle25VLFastImageProcessorKwargs
|
||||
model_input_names = ["pixel_values_videos"]
|
||||
|
||||
def __init__(self, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@add_start_docstrings(
|
||||
# BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, TODO: this was depreciated from transformers remove!
|
||||
"""
|
||||
max_dynamic_tiles (`int`, *optional*):
|
||||
The maximum number of dynamic tiles to use for processing high resolution images.
|
||||
min_dynamic_tiles (`int`, *optional*):
|
||||
The minimum number of dynamic tiles to use for processing high resolution images.
|
||||
use_thumbnail (`bool`, *optional*):
|
||||
Whether to use a thumbnail for processing high resolution images.
|
||||
pad_during_tiling (`bool`, *optional*):
|
||||
Whether to pad the image during tiling.
|
||||
do_pad (`bool`, *optional*):
|
||||
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
|
||||
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
|
||||
""",
|
||||
)
|
||||
|
||||
# NOTE(YL): we will overload the preprocess method to add the image_flags
|
||||
# def preprocess(
|
||||
# self, images: ImageInput, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs]
|
||||
# ) -> BatchFeature:
|
||||
# return super().preprocess(images, **kwargs)
|
||||
|
||||
def _prepare_images_structure(
|
||||
self,
|
||||
images: ImageInput,
|
||||
expected_ndims: int = 3,
|
||||
) -> ImageInput:
|
||||
"""
|
||||
Prepare the images structure for processing.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
The input images to process.
|
||||
expected_ndims (`int`, *optional*, defaults to 3):
|
||||
Expected number of dimensions for the images (added for transformers >=4.53.0 compatibility).
|
||||
|
||||
Returns:
|
||||
`ImageInput`: The images with a valid nesting.
|
||||
"""
|
||||
return make_flat_list_of_images(images)
|
||||
|
||||
def _resize_for_patching(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
target_resolution: tuple,
|
||||
interpolation: F.InterpolationMode,
|
||||
input_data_format: ChannelDimension,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Resizes an image to a target resolution while maintaining aspect ratio.
|
||||
|
||||
Args:
|
||||
image ("torch.Tensor"):
|
||||
The input image.
|
||||
target_resolution (tuple):
|
||||
The target resolution (height, width) of the image.
|
||||
interpolation (`InterpolationMode`):
|
||||
Resampling filter to use if resizing the image.
|
||||
input_data_format (`ChannelDimension` or `str`):
|
||||
The channel dimension format of the input image.
|
||||
|
||||
Returns:
|
||||
"torch.Tensor": The resized and padded image.
|
||||
"""
|
||||
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
|
||||
|
||||
# Resize the image
|
||||
resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
|
||||
|
||||
return resized_image
|
||||
|
||||
def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
|
||||
"""
|
||||
previous version mainly focus on ratio.
|
||||
We also consider area ratio here.
|
||||
"""
|
||||
best_factor = float("-inf")
|
||||
best_ratio = (1, 1)
|
||||
area = width * height
|
||||
for ratio in target_ratios:
|
||||
target_aspect_ratio = ratio[0] / ratio[1]
|
||||
# ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
||||
# area_ratio = (ratio[0] * ratio[1] * image_size * image_size) / area
|
||||
"""
|
||||
new area > 60% of original image area is enough.
|
||||
"""
|
||||
factor_based_on_area_n_ratio = min(
|
||||
(ratio[0] * ratio[1] * image_size * image_size) / area, 0.6
|
||||
) * min(target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio)
|
||||
|
||||
if factor_based_on_area_n_ratio > best_factor:
|
||||
best_factor = factor_based_on_area_n_ratio
|
||||
best_ratio = ratio
|
||||
|
||||
return best_ratio
|
||||
|
||||
def _pad_for_patching(
|
||||
self, image: torch.Tensor, target_resolution: tuple, input_data_format: ChannelDimension
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Pad an image to a target resolution while maintaining aspect ratio.
|
||||
"""
|
||||
target_height, target_width = target_resolution
|
||||
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
|
||||
|
||||
paste_x = (target_width - new_width) // 2
|
||||
paste_y = (target_height - new_height) // 2
|
||||
|
||||
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y])
|
||||
|
||||
return padded_image
|
||||
|
||||
def _get_image_patches(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
min_num: int,
|
||||
max_num: int,
|
||||
size: tuple,
|
||||
tile_size: int,
|
||||
use_thumbnail: bool,
|
||||
interpolation: F.InterpolationMode,
|
||||
pad_during_tiling: bool,
|
||||
) -> list[torch.Tensor]:
|
||||
image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
|
||||
orig_height, orig_width = image_size
|
||||
aspect_ratio = orig_width / orig_height
|
||||
|
||||
# calculate the existing image aspect ratio
|
||||
target_ratios = {
|
||||
(i, j)
|
||||
for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1)
|
||||
for j in range(1, n + 1)
|
||||
if i * j <= max_num and i * j >= min_num
|
||||
}
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = self.find_closest_aspect_ratio(
|
||||
aspect_ratio, target_ratios, orig_width, orig_height, tile_size
|
||||
)
|
||||
|
||||
# calculate the target width and height
|
||||
target_width = tile_size * target_aspect_ratio[0]
|
||||
target_height = tile_size * target_aspect_ratio[1]
|
||||
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||
if pad_during_tiling:
|
||||
resized_image = self._resize_for_patching(
|
||||
image,
|
||||
(target_height, target_width),
|
||||
interpolation=interpolation,
|
||||
input_data_format=ChannelDimension.FIRST,
|
||||
)
|
||||
padded_image = self._pad_for_patching(
|
||||
resized_image,
|
||||
(target_height, target_width),
|
||||
input_data_format=ChannelDimension.FIRST,
|
||||
)
|
||||
image_used_to_split = padded_image
|
||||
else:
|
||||
image_used_to_split = F.resize(image, (target_height, target_width), interpolation=interpolation)
|
||||
|
||||
processed_tiles = []
|
||||
for i in range(blocks):
|
||||
box = (
|
||||
(i % (target_width // tile_size)) * tile_size,
|
||||
(i // (target_width // tile_size)) * tile_size,
|
||||
((i % (target_width // tile_size)) + 1) * tile_size,
|
||||
((i // (target_width // tile_size)) + 1) * tile_size,
|
||||
)
|
||||
# split the image
|
||||
split_img = crop(image_used_to_split, box[0], box[1], box[2], box[3])
|
||||
processed_tiles.append(split_img)
|
||||
assert len(processed_tiles) == blocks
|
||||
|
||||
if use_thumbnail and len(processed_tiles) != 1:
|
||||
thumbnail_img = F.resize(image, (tile_size, tile_size), interpolation=interpolation)
|
||||
processed_tiles.append(thumbnail_img)
|
||||
|
||||
return processed_tiles
|
||||
|
||||
def _pad_for_batching(
|
||||
self,
|
||||
pixel_values: list[torch.Tensor],
|
||||
) -> list[torch.Tensor]:
|
||||
"""
|
||||
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
|
||||
|
||||
Args:
|
||||
pixel_values (`List[torch.Tensor]`):
|
||||
An array of pixel values of each images of shape (`batch_size`, `num_patches`, `image_in_3D`)
|
||||
|
||||
Returns:
|
||||
List[`torch.Tensor`]: The padded images.
|
||||
"""
|
||||
max_patch = max(len(x) for x in pixel_values)
|
||||
pixel_values = [
|
||||
torch.nn.functional.pad(image, pad=[0, 0, 0, 0, 0, 0, 0, max_patch - image.shape[0]])
|
||||
for image in pixel_values
|
||||
]
|
||||
|
||||
return pixel_values
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: list[torch.Tensor],
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
max_dynamic_tiles: int,
|
||||
min_dynamic_tiles: int,
|
||||
use_thumbnail: bool,
|
||||
pad_during_tiling: bool,
|
||||
interpolation: F.InterpolationMode | None,
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
image_mean: float | list[float] | None,
|
||||
image_std: float | list[float] | None,
|
||||
do_pad: bool,
|
||||
return_tensors: str | TensorType | None,
|
||||
pad_size: SizeDict | None = None, # Added for transformers >=4.53.0 compatibility
|
||||
disable_grouping: bool | None = None, # Added for transformers >=4.53.0 compatibility
|
||||
) -> BatchFeature:
|
||||
processed_images = []
|
||||
image_sizes = []
|
||||
# Determine the size tuple
|
||||
if size and size.height and size.width:
|
||||
size_tuple = (size.height, size.width)
|
||||
else:
|
||||
size_tuple = (size.shortest_edge, size.shortest_edge)
|
||||
|
||||
# Determine the patch size
|
||||
if crop_size and crop_size.height:
|
||||
tile_size = crop_size.height
|
||||
elif size and size.height:
|
||||
tile_size = size.height
|
||||
else:
|
||||
tile_size = size.shortest_edge
|
||||
|
||||
for image in images:
|
||||
image_patches = self._get_image_patches(
|
||||
image,
|
||||
min_num=min_dynamic_tiles,
|
||||
max_num=max_dynamic_tiles,
|
||||
size=size_tuple,
|
||||
tile_size=tile_size,
|
||||
use_thumbnail=use_thumbnail,
|
||||
interpolation=interpolation,
|
||||
pad_during_tiling=pad_during_tiling,
|
||||
)
|
||||
|
||||
# Group images by size for batched processing
|
||||
processed_image_patches_grouped = {}
|
||||
# Added for transformers >=4.53.0 compatibility
|
||||
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(
|
||||
image_patches,
|
||||
disable_grouping=disable_grouping,
|
||||
)
|
||||
|
||||
for shape, stacked_image_patches in grouped_image_patches.items():
|
||||
if do_resize:
|
||||
stacked_image_patches = self.resize(
|
||||
image=stacked_image_patches,
|
||||
size=size,
|
||||
interpolation=interpolation,
|
||||
)
|
||||
if do_center_crop:
|
||||
stacked_image_patches = self.center_crop(stacked_image_patches, crop_size)
|
||||
# Fused rescale and normalize
|
||||
stacked_image_patches = self.rescale_and_normalize(
|
||||
stacked_image_patches,
|
||||
do_rescale,
|
||||
rescale_factor,
|
||||
do_normalize,
|
||||
image_mean,
|
||||
image_std,
|
||||
)
|
||||
processed_image_patches_grouped[shape] = stacked_image_patches
|
||||
processed_image_patches = reorder_images(
|
||||
processed_image_patches_grouped, grouped_image_patches_index
|
||||
)
|
||||
processed_image_patches = (
|
||||
torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
|
||||
)
|
||||
processed_images.append(processed_image_patches)
|
||||
image_sizes.append(get_image_size(image, ChannelDimension.FIRST))
|
||||
|
||||
if do_pad:
|
||||
processed_images = self._pad_for_batching(processed_images)
|
||||
|
||||
# processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
|
||||
processed_images = torch.cat(processed_images, dim=0) if return_tensors else processed_images
|
||||
return BatchFeature(
|
||||
data={"pixel_values": processed_images, "image_sizes": image_sizes},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
videos: VideoInput = None,
|
||||
**kwargs: Unpack[Eagle25VLFastImageProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
validate_kwargs(
|
||||
captured_kwargs=kwargs.keys(),
|
||||
valid_processor_keys=self.valid_kwargs.__annotations__.keys(),
|
||||
)
|
||||
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
||||
# by the user, it gets its default value from the instance, or is set to None.
|
||||
for kwarg_name in self.valid_kwargs.__annotations__:
|
||||
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
|
||||
|
||||
# Extract parameters that are only used for preparing the input images
|
||||
do_convert_rgb = kwargs.pop("do_convert_rgb")
|
||||
input_data_format = kwargs.pop("input_data_format")
|
||||
device = kwargs.pop("device")
|
||||
# Prepare input images
|
||||
# transformers >= 4.53.0: uses _prepare_image_like_inputs instead of _prepare_input_images
|
||||
if images is not None:
|
||||
images = self._prepare_image_like_inputs(
|
||||
images=images,
|
||||
do_convert_rgb=do_convert_rgb,
|
||||
input_data_format=input_data_format,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if videos is not None:
|
||||
videos = self._prepare_image_like_inputs(
|
||||
images=videos,
|
||||
do_convert_rgb=do_convert_rgb,
|
||||
input_data_format=input_data_format,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Update kwargs that need further processing before being validated
|
||||
kwargs = self._further_process_kwargs(**kwargs)
|
||||
|
||||
# Validate kwargs
|
||||
self._validate_preprocess_kwargs(**kwargs)
|
||||
|
||||
# torch resize uses interpolation instead of resample
|
||||
# Added for transformers >=4.53.0 compatibility
|
||||
resample = kwargs.pop("resample", self.resample)
|
||||
kwargs["interpolation"] = (
|
||||
pil_torch_interpolation_mapping[resample]
|
||||
if isinstance(resample, PILImageResampling | int)
|
||||
else resample
|
||||
)
|
||||
|
||||
# Filter kwargs to only include those accepted by _preprocess
|
||||
valid_preprocess_kwargs = {
|
||||
"do_resize",
|
||||
"size",
|
||||
"max_dynamic_tiles",
|
||||
"min_dynamic_tiles",
|
||||
"use_thumbnail",
|
||||
"pad_during_tiling",
|
||||
"interpolation",
|
||||
"do_center_crop",
|
||||
"crop_size",
|
||||
"do_rescale",
|
||||
"rescale_factor",
|
||||
"do_normalize",
|
||||
"image_mean",
|
||||
"image_std",
|
||||
"do_pad",
|
||||
"return_tensors",
|
||||
"pad_size",
|
||||
"disable_grouping",
|
||||
}
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_preprocess_kwargs}
|
||||
if images is not None:
|
||||
return self._preprocess(images, **filtered_kwargs)
|
||||
elif videos is not None:
|
||||
return self._preprocess(videos, **filtered_kwargs)
|
||||
|
||||
|
||||
__all__ = ["Eagle25VLImageProcessorFast"]
|
||||
@@ -0,0 +1,396 @@
|
||||
# --------------------------------------------------------
|
||||
# NVIDIA
|
||||
# Copyright (c) 2025 NVIDIA
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
|
||||
import inspect
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint as cp
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers import GenerationConfig
|
||||
from transformers.generation import GenerationMixin
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
|
||||
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
|
||||
from transformers.models.siglip.modeling_siglip import SiglipVisionModel
|
||||
from transformers.utils import add_start_docstrings, logging
|
||||
|
||||
from .configuration_eagle2_5_vl import Eagle25VLConfig
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/modeling_llava_onevision.py#L241C1-L280C1
|
||||
EAGLE2_5_VL_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`Eagle25VLConfig`]):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Eagle2_5_VL Model outputting raw hidden-states without any specific head on top.",
|
||||
EAGLE2_5_VL_START_DOCSTRING,
|
||||
)
|
||||
class Eagle25VLPreTrainedModel(PreTrainedModel):
|
||||
config_class = Eagle25VLConfig
|
||||
base_model_prefix = "model"
|
||||
main_input_name = "input_ids"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = [
|
||||
"Qwen2DecoderLayer",
|
||||
"LlamaDecoderLayer",
|
||||
"Siglip2EncoderLayer",
|
||||
"SiglipEncoderLayer",
|
||||
]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_sdpa = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear | nn.Conv2d):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
class Eagle25VLForConditionalGeneration(Eagle25VLPreTrainedModel, GenerationMixin):
|
||||
config_class = Eagle25VLConfig
|
||||
|
||||
def __init__(self, config: Eagle25VLConfig, vision_model=None, language_model=None):
|
||||
super().__init__(config)
|
||||
|
||||
image_size = config.force_image_size or config.vision_config.image_size
|
||||
patch_size = config.vision_config.patch_size
|
||||
self.patch_size = patch_size
|
||||
if config.use_pixel_shuffle:
|
||||
self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio**2))
|
||||
else:
|
||||
self.num_image_token = int((image_size // patch_size) ** 2)
|
||||
|
||||
self.select_layer = config.select_layer
|
||||
self.downsample_ratio = config.downsample_ratio
|
||||
self.loss_version = config.loss_version
|
||||
self.mlp_checkpoint = config.mlp_checkpoint
|
||||
self.use_pixel_shuffle = config.use_pixel_shuffle
|
||||
self.mlp_connector_layers = config.mlp_connector_layers
|
||||
logger.info(f"num_image_token: {self.num_image_token}")
|
||||
logger.info(f"mlp_checkpoint: {self.mlp_checkpoint}")
|
||||
if vision_model is not None:
|
||||
self.vision_model = vision_model
|
||||
else:
|
||||
if config.vision_config.model_type == "siglip_vision_model":
|
||||
config.vision_config._attn_implementation = "flash_attention_2"
|
||||
self.vision_model = SiglipVisionModel(config.vision_config)
|
||||
else:
|
||||
raise NotImplementedError(f"{config.vision_config.model_type} is not implemented.")
|
||||
|
||||
if language_model is not None:
|
||||
self.language_model = language_model
|
||||
else:
|
||||
if config.text_config.architectures[0] == "LlamaForCausalLM":
|
||||
self.language_model = LlamaForCausalLM(config.text_config)
|
||||
elif config.text_config.architectures[0] == "Phi3ForCausalLM":
|
||||
raise NotImplementedError("Phi3 is not implemented.")
|
||||
# self.language_model = Phi3ForCausalLM(config.text_config)
|
||||
elif config.text_config.architectures[0] == "Qwen2ForCausalLM":
|
||||
assert config.text_config._attn_implementation == "flash_attention_2", (
|
||||
f"Qwen2 must use flash_attention_2 but got {config.text_config._attn_implementation}"
|
||||
)
|
||||
self.language_model = Qwen2ForCausalLM(config.text_config)
|
||||
elif config.text_config.architectures[0] == "Qwen3ForCausalLM":
|
||||
self.language_model = Qwen3ForCausalLM(config.text_config)
|
||||
else:
|
||||
raise NotImplementedError(f"{config.text_config.architectures[0]} is not implemented.")
|
||||
|
||||
vit_hidden_size = config.vision_config.hidden_size
|
||||
llm_hidden_size = config.text_config.hidden_size
|
||||
|
||||
if config.mlp_connector_layers == 2:
|
||||
self.mlp1 = nn.Sequential(
|
||||
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
|
||||
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
|
||||
nn.GELU(),
|
||||
nn.Linear(llm_hidden_size, llm_hidden_size),
|
||||
)
|
||||
elif config.mlp_connector_layers == 1 and config.use_pixel_shuffle:
|
||||
self.mlp1 = nn.Sequential(
|
||||
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
|
||||
)
|
||||
elif config.mlp_connector_layers == 1 and not config.use_pixel_shuffle:
|
||||
self.mlp1 = nn.Sequential(
|
||||
nn.Linear(vit_hidden_size, llm_hidden_size),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"{config.mlp_connector_layers} is not implemented.")
|
||||
|
||||
self.image_token_index = config.image_token_index
|
||||
self.neftune_alpha = None
|
||||
|
||||
if config.use_backbone_lora:
|
||||
self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
|
||||
|
||||
self.use_llm_lora = config.use_llm_lora
|
||||
if config.use_llm_lora:
|
||||
self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
|
||||
|
||||
self.check_forward_kwargs()
|
||||
|
||||
def check_forward_kwargs(self):
|
||||
# We intentionally avoid using **kwargs in forward because Hugging Face Transformers
|
||||
# has special handling for functions with **kwargs parameters that would affect
|
||||
# how our model is processed during training and inference.
|
||||
forward_params = inspect.signature(self.forward).parameters
|
||||
assert not any(k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values())
|
||||
|
||||
def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
|
||||
lora_config = LoraConfig(
|
||||
r=r,
|
||||
target_modules=[
|
||||
"self_attn.q_proj",
|
||||
"self_attn.k_proj",
|
||||
"self_attn.v_proj",
|
||||
"self_attn.out_proj",
|
||||
"mlp.fc1",
|
||||
"mlp.fc2",
|
||||
],
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
)
|
||||
self.vision_model = get_peft_model(self.vision_model, lora_config)
|
||||
self.vision_model.print_trainable_parameters()
|
||||
|
||||
def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
|
||||
lora_config = LoraConfig(
|
||||
r=r,
|
||||
target_modules=[
|
||||
"self_attn.q_proj",
|
||||
"self_attn.k_proj",
|
||||
"self_attn.v_proj",
|
||||
"self_attn.o_proj",
|
||||
"mlp.gate_proj",
|
||||
"mlp.down_proj",
|
||||
"mlp.up_proj",
|
||||
],
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
self.language_model = get_peft_model(self.language_model, lora_config)
|
||||
self.language_model.enable_input_require_grads()
|
||||
self.language_model.print_trainable_parameters()
|
||||
self.use_llm_lora = True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
image_flags: torch.LongTensor | None = None,
|
||||
past_key_values: list[torch.FloatTensor] | None = None,
|
||||
labels: torch.LongTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
return_dict: bool | None = None,
|
||||
num_tiles_list: list[torch.Tensor] | None = None,
|
||||
) -> tuple | CausalLMOutputWithPast:
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
input_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
|
||||
vit_embeds = self.extract_feature(pixel_values)
|
||||
|
||||
if image_flags is not None:
|
||||
image_flags = image_flags.view(-1)
|
||||
vit_embeds = vit_embeds[image_flags == 1]
|
||||
|
||||
b, n, c = input_embeds.shape
|
||||
input_embeds = input_embeds.reshape(b * n, c)
|
||||
|
||||
input_ids = input_ids.reshape(b * n)
|
||||
selected = input_ids == self.image_token_index
|
||||
try:
|
||||
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, c)
|
||||
except Exception as e:
|
||||
vit_embeds = vit_embeds.reshape(-1, c)
|
||||
print(
|
||||
f"warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, "
|
||||
f"vit_embeds.shape={vit_embeds.shape}"
|
||||
)
|
||||
n_token = selected.sum()
|
||||
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
|
||||
|
||||
input_embeds = input_embeds.reshape(b, n, c)
|
||||
|
||||
outputs = self.language_model(
|
||||
inputs_embeds=input_embeds,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
logits = outputs.logits
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def pixel_shuffle(self, x, scale_factor=0.5):
|
||||
n, w, h, c = x.size()
|
||||
# N, W, H, C --> N, W, H * scale, C // scale
|
||||
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
|
||||
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
# N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
|
||||
x = x.view(n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor)))
|
||||
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
return x
|
||||
|
||||
def extract_feature(self, pixel_values):
|
||||
if self.select_layer == -1:
|
||||
vit_embeds = self.vision_model(
|
||||
pixel_values=pixel_values, output_hidden_states=False, return_dict=True
|
||||
)
|
||||
if hasattr(vit_embeds, "last_hidden_state"):
|
||||
vit_embeds = vit_embeds.last_hidden_state
|
||||
|
||||
else:
|
||||
vit_embeds = self.vision_model(
|
||||
pixel_values=pixel_values, output_hidden_states=True, return_dict=True
|
||||
).hidden_states[self.select_layer]
|
||||
|
||||
if self.use_pixel_shuffle:
|
||||
h = w = int(vit_embeds.shape[1] ** 0.5)
|
||||
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
||||
vit_embeds = self.pixel_shuffle(
|
||||
vit_embeds, scale_factor=self.downsample_ratio
|
||||
) # torch.Size([B, 1024, 1024]) -> torch.Size([B, 16, 16, 4096])
|
||||
vit_embeds = vit_embeds.reshape(
|
||||
vit_embeds.shape[0], -1, vit_embeds.shape[-1]
|
||||
) # torch.Size([B, 16, 16, 4096]) -> torch.Size([B, 256, 4096])
|
||||
|
||||
if self.mlp_checkpoint and vit_embeds.requires_grad:
|
||||
vit_embeds = cp.checkpoint(self.mlp1, vit_embeds)
|
||||
else:
|
||||
vit_embeds = self.mlp1(vit_embeds)
|
||||
|
||||
return vit_embeds
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
input_ids: torch.FloatTensor | None = None,
|
||||
attention_mask: torch.LongTensor | None = None,
|
||||
visual_features: torch.FloatTensor | None = None,
|
||||
generation_config: GenerationConfig | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
image_sizes: list[tuple[int, int]] | None = None,
|
||||
**generate_kwargs,
|
||||
) -> torch.LongTensor:
|
||||
if pixel_values is not None:
|
||||
if visual_features is not None:
|
||||
vit_embeds = visual_features
|
||||
else:
|
||||
vit_embeds = self.extract_feature(pixel_values)
|
||||
|
||||
input_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
b, n, c = input_embeds.shape
|
||||
input_embeds = input_embeds.reshape(b * n, c)
|
||||
|
||||
input_ids = input_ids.reshape(b * n)
|
||||
selected = input_ids == self.config.image_token_index
|
||||
assert selected.sum() != 0
|
||||
input_embeds[selected] = vit_embeds.reshape(-1, c).to(input_embeds.device)
|
||||
|
||||
input_embeds = input_embeds.reshape(b, n, c)
|
||||
else:
|
||||
input_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
|
||||
if "use_cache" not in generate_kwargs:
|
||||
generate_kwargs["use_cache"] = True
|
||||
|
||||
outputs = self.language_model.generate(
|
||||
inputs_embeds=input_embeds,
|
||||
attention_mask=attention_mask,
|
||||
generation_config=generation_config,
|
||||
output_hidden_states=output_hidden_states,
|
||||
**generate_kwargs,
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings
|
||||
def get_input_embeddings(self):
|
||||
return self.language_model.get_input_embeddings()
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_input_embeddings
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.set_input_embeddings(value)
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_output_embeddings
|
||||
def get_output_embeddings(self):
|
||||
return self.language_model.get_output_embeddings()
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_output_embeddings
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.language_model.set_output_embeddings(new_embeddings)
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_decoder
|
||||
def set_decoder(self, decoder):
|
||||
self.language_model.set_decoder(decoder)
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_decoder
|
||||
def get_decoder(self):
|
||||
return self.language_model.get_decoder()
|
||||
@@ -0,0 +1,541 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Processor class for Eagle25VL.
|
||||
copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/processing_llava_onevision.py
|
||||
"""
|
||||
|
||||
import base64
|
||||
import os
|
||||
import re
|
||||
from io import BytesIO
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from transformers.utils import logging
|
||||
from transformers.video_utils import VideoInput
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
FRAME_FACTOR = 2
|
||||
FPS = 2.0
|
||||
FPS_MIN_FRAMES = 4
|
||||
FPS_MAX_FRAMES = 256
|
||||
|
||||
|
||||
def to_rgb(pil_image: Image.Image) -> Image.Image:
|
||||
if pil_image.mode == "RGBA":
|
||||
white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
|
||||
white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
|
||||
return white_background
|
||||
else:
|
||||
return pil_image.convert("RGB")
|
||||
|
||||
|
||||
def fetch_image(ele: dict[str, str | Image.Image]) -> Image.Image:
|
||||
image = ele["image"] if "image" in ele else ele["image_url"]
|
||||
image_obj = None
|
||||
if isinstance(image, Image.Image):
|
||||
image_obj = image
|
||||
elif image.startswith("http://") or image.startswith("https://"):
|
||||
response = requests.get(image, stream=True, timeout=10)
|
||||
image_obj = Image.open(BytesIO(response.content))
|
||||
elif image.startswith("file://"):
|
||||
image_obj = Image.open(image[7:])
|
||||
elif image.startswith("data:image"):
|
||||
if "base64," in image:
|
||||
_, base64_data = image.split("base64,", 1)
|
||||
data = base64.b64decode(base64_data)
|
||||
image_obj = Image.open(BytesIO(data))
|
||||
else:
|
||||
image_obj = Image.open(image)
|
||||
if image_obj is None:
|
||||
raise ValueError(
|
||||
f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}"
|
||||
)
|
||||
image = to_rgb(image_obj)
|
||||
if "scale_factor" in ele:
|
||||
scale_factor = ele["scale_factor"]
|
||||
image = image.resize((image.width * scale_factor, image.height * scale_factor), Image.BILINEAR)
|
||||
return image
|
||||
|
||||
|
||||
class Eagle25VLProcessorKwargs(ProcessingKwargs, total=False):
|
||||
# see processing_utils.ProcessingKwargs documentation for usage.
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": False,
|
||||
},
|
||||
"images_kwargs": {},
|
||||
"videos_kwargs": {"max_dynamic_tiles": 1},
|
||||
}
|
||||
|
||||
|
||||
class Eagle25VLProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a Eagle25VL processor which wraps a Eagle25VL video processor, Eagle25VL image processor and a Eagle25VL tokenizer into a single processor.
|
||||
|
||||
[`Eagle25VLProcessor`] offers all the functionalities of [`Eagle25VLVideoProcessor`], [`Eagle25VLImageProcessor`] and [`Eagle25VLTokenizer`]. See the
|
||||
[`~Eagle25VLVideoProcessor.__call__`], [`~Eagle25VLProcessor.__call__`] and [`~Eagle25VLProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
image_processor ([`LlavaOnevisionImageProcessor`], *optional*):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`LlamaTokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
num_image_tokens (`int`, *optional*):
|
||||
Number of image tokens for one imagethat will be returned by vision tower.
|
||||
vision_feature_select_strategy (`str`, *optional*):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Should be same as in model's config
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
image_token (`str`, *optional*, defaults to `"<image>"`):
|
||||
Special token used to denote image location.
|
||||
video_token (`str`, *optional*, defaults to `"<video>"`):
|
||||
Special token used to denote video location.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
valid_kwargs = [
|
||||
"chat_template",
|
||||
"num_image_tokens",
|
||||
"vision_feature_select_strategy",
|
||||
"image_token",
|
||||
"video_token",
|
||||
"images_kwargs",
|
||||
"videos_kwargs",
|
||||
"text_kwargs",
|
||||
]
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor=None,
|
||||
tokenizer=None,
|
||||
vision_feature_select_strategy=None,
|
||||
chat_template=None,
|
||||
image_token="<IMG_CONTEXT>", # nosec: B107
|
||||
video_token="<IMG_CONTEXT>", # nosec: B107
|
||||
tokens_per_tile=256,
|
||||
image_placeholder="image",
|
||||
video_placeholder="video",
|
||||
image_start_token="<img>",
|
||||
image_end_token="</img>",
|
||||
**kwargs,
|
||||
):
|
||||
self.vision_feature_select_strategy = vision_feature_select_strategy
|
||||
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
|
||||
self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
|
||||
self.image_token_id = (
|
||||
tokenizer.image_token_id
|
||||
if getattr(tokenizer, "image_token_id", None)
|
||||
else tokenizer.convert_tokens_to_ids(self.image_token)
|
||||
)
|
||||
self.video_token_id = (
|
||||
tokenizer.video_token_id
|
||||
if getattr(tokenizer, "video_token_id", None)
|
||||
else tokenizer.convert_tokens_to_ids(self.video_token)
|
||||
)
|
||||
self.image_placeholder = image_placeholder
|
||||
self.video_placeholder = video_placeholder
|
||||
self.tokens_per_tile = tokens_per_tile
|
||||
self.image_start_token = image_start_token
|
||||
self.image_end_token = image_end_token
|
||||
if "auto_map" in kwargs:
|
||||
self.auto_map = kwargs["auto_map"]
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||
|
||||
def replace_media_placeholder(
|
||||
self, text, image_list, video_list, timestamps_list, fps_list, **output_kwargs
|
||||
):
|
||||
num_of_images_in_this_sample = 0
|
||||
num_of_videos_in_this_sample = 0
|
||||
# Regular expression pattern to match formats like <image-1> or <video-2>
|
||||
pattern = re.compile(rf"<({self.image_placeholder}|{self.video_placeholder})-(\d+)>")
|
||||
unified_frame_list = []
|
||||
|
||||
# image_min_dynamic_tiles = output_kwargs["images_kwargs"].get(
|
||||
# "min_dynamic_tiles", self.image_processor.min_dynamic_tiles
|
||||
# )
|
||||
# image_max_dynamic_tiles = output_kwargs["images_kwargs"].get(
|
||||
# "max_dynamic_tiles", self.image_processor.max_dynamic_tiles
|
||||
# )
|
||||
# image_use_thumbnail = output_kwargs["images_kwargs"].get(
|
||||
# "use_thumbnail", self.image_processor.use_thumbnail
|
||||
# )
|
||||
video_min_dynamic_tiles = output_kwargs["videos_kwargs"].get(
|
||||
"min_dynamic_tiles", self.image_processor.min_dynamic_tiles
|
||||
)
|
||||
video_max_dynamic_tiles = output_kwargs["videos_kwargs"].get(
|
||||
"max_dynamic_tiles", self.image_processor.max_dynamic_tiles
|
||||
)
|
||||
video_use_thumbnail = output_kwargs["videos_kwargs"].get(
|
||||
"use_thumbnail", self.image_processor.use_thumbnail
|
||||
)
|
||||
|
||||
tile_size = self.image_processor.size.get("height", 448)
|
||||
|
||||
# Function to replace tags in a single text
|
||||
def replace_in_text(text):
|
||||
# repl callback function for each match replacement operation
|
||||
def repl(match):
|
||||
nonlocal unified_frame_list
|
||||
nonlocal num_of_images_in_this_sample
|
||||
nonlocal num_of_videos_in_this_sample
|
||||
media_type = match.group(1) # 'image' or 'video'
|
||||
idx_in_list = int(match.group(2)) - 1 # Convert to list index (0-based)
|
||||
# Select the corresponding path based on media type
|
||||
idx_mapper = {
|
||||
0: "first",
|
||||
1: "second",
|
||||
2: "third",
|
||||
3: "fourth",
|
||||
4: "fifth",
|
||||
5: "sixth",
|
||||
6: "seventh",
|
||||
7: "eighth",
|
||||
8: "ninth",
|
||||
9: "tenth",
|
||||
}
|
||||
if media_type == "image":
|
||||
image_inputs = self.image_processor(
|
||||
images=[image_list[idx_in_list]],
|
||||
videos=None,
|
||||
**output_kwargs["images_kwargs"],
|
||||
)
|
||||
if isinstance(image_inputs["pixel_values"], list):
|
||||
_pv = image_inputs["pixel_values"]
|
||||
if _pv and isinstance(_pv[0], list):
|
||||
_pv = [t for sub in _pv for t in sub]
|
||||
image_inputs["pixel_values"] = torch.stack(
|
||||
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
|
||||
)
|
||||
num_all_tiles = image_inputs["pixel_values"].shape[0]
|
||||
special_placeholder = f"<image {idx_in_list + 1}>{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}"
|
||||
unified_frame_list.append(image_inputs)
|
||||
num_of_images_in_this_sample += 1
|
||||
|
||||
elif media_type == "video":
|
||||
video_inputs = self.image_processor(
|
||||
images=None,
|
||||
videos=[video_list[idx_in_list]],
|
||||
**output_kwargs["videos_kwargs"],
|
||||
)
|
||||
if isinstance(video_inputs["pixel_values"], list):
|
||||
_pv = video_inputs["pixel_values"]
|
||||
if _pv and isinstance(_pv[0], list):
|
||||
_pv = [t for sub in _pv for t in sub]
|
||||
video_inputs["pixel_values"] = torch.stack(
|
||||
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
|
||||
)
|
||||
num_all_tiles = video_inputs["pixel_values"].shape[0]
|
||||
image_sizes = video_inputs["image_sizes"]
|
||||
if timestamps_list is not None and -1 not in timestamps_list:
|
||||
frame_timestamps = timestamps_list[idx_in_list]
|
||||
else:
|
||||
frame_timestamps = None
|
||||
sampled_fps = fps_list[idx_in_list] if fps_list is not None else None
|
||||
|
||||
num_of_tiles_each_frame = [
|
||||
self.get_number_tiles_based_on_image_size(
|
||||
image_size,
|
||||
video_min_dynamic_tiles,
|
||||
video_max_dynamic_tiles,
|
||||
video_use_thumbnail,
|
||||
tile_size,
|
||||
)
|
||||
for image_size in image_sizes
|
||||
]
|
||||
assert sum(num_of_tiles_each_frame) == num_all_tiles, (
|
||||
f"The number of tiles in each frame is not equal to the total number of tiles: {sum(num_of_tiles_each_frame)} != {num_all_tiles}"
|
||||
)
|
||||
|
||||
if frame_timestamps is not None:
|
||||
assert len(frame_timestamps) == len(num_of_tiles_each_frame), (
|
||||
f"The number of timestamps is not equal to the number of frames: {len(frame_timestamps)} != {len(num_of_tiles_each_frame)}"
|
||||
)
|
||||
special_placeholder = [
|
||||
f"Frame {i + 1} sample at {frame_timestamps[i]:.2f}s: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}"
|
||||
for i, num_of_tiles in enumerate(num_of_tiles_each_frame)
|
||||
]
|
||||
else:
|
||||
special_placeholder = [
|
||||
f"Frame {i + 1}: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}"
|
||||
for i, num_of_tiles in enumerate(num_of_tiles_each_frame)
|
||||
]
|
||||
|
||||
if sampled_fps is not None:
|
||||
special_placeholder = (
|
||||
f"The {idx_mapper[idx_in_list]} video sampled with {sampled_fps:.2f} fps: "
|
||||
+ "".join(special_placeholder)
|
||||
)
|
||||
else:
|
||||
special_placeholder = f"The {idx_mapper[idx_in_list]} video: " + "".join(
|
||||
special_placeholder
|
||||
)
|
||||
unified_frame_list.append(video_inputs)
|
||||
num_of_videos_in_this_sample += 1
|
||||
else:
|
||||
raise ValueError(f"Unknown media type: {media_type}")
|
||||
return special_placeholder
|
||||
|
||||
return pattern.sub(repl, text)
|
||||
|
||||
text = replace_in_text(text)
|
||||
if len(unified_frame_list) > 0:
|
||||
|
||||
def _to_tensor(v):
|
||||
if isinstance(v, torch.Tensor):
|
||||
return v
|
||||
if isinstance(v, list):
|
||||
if v and isinstance(v[0], list):
|
||||
v = [t for sub in v for t in sub]
|
||||
return torch.stack([t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in v])
|
||||
return torch.as_tensor(v)
|
||||
|
||||
pixel_values = torch.cat([_to_tensor(frame["pixel_values"]) for frame in unified_frame_list])
|
||||
image_sizes = torch.cat([_to_tensor(frame["image_sizes"]) for frame in unified_frame_list])
|
||||
else:
|
||||
pixel_values = None
|
||||
image_sizes = None
|
||||
return (
|
||||
text,
|
||||
pixel_values,
|
||||
image_sizes,
|
||||
num_of_images_in_this_sample,
|
||||
num_of_videos_in_this_sample,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: ImageInput = None,
|
||||
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
|
||||
audio=None,
|
||||
videos: VideoInput = None,
|
||||
**kwargs: Unpack[Eagle25VLProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
||||
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
|
||||
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
|
||||
LlavaNextImageProcessor's [`~LlavaNextImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
|
||||
of the above two methods for more information.
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||
`None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
- **pixel_values_videos** -- Pixel values of a video input to be fed to a model. Returned when `videos` is not `None`.
|
||||
- **image_sizes** -- Size of each image that will be used to unpad an image. Returned when `images` is not `None`.
|
||||
"""
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
Eagle25VLProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if isinstance(text, str):
|
||||
text_list = [text]
|
||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
elif isinstance(text, list) and isinstance(text[0], str):
|
||||
text_list = text
|
||||
|
||||
if images is None:
|
||||
images = []
|
||||
if videos is None:
|
||||
videos = []
|
||||
|
||||
pixel_values_list = []
|
||||
image_sizes_list = []
|
||||
new_sample_list = []
|
||||
image_start_idx = 0
|
||||
video_start_idx = 0
|
||||
timestamps_batch = output_kwargs["videos_kwargs"].pop("timestamps", None)
|
||||
fps_batch = output_kwargs["videos_kwargs"].pop("fps", None)
|
||||
for sample in text_list:
|
||||
timestamps_list = timestamps_batch[video_start_idx:] if timestamps_batch is not None else None
|
||||
fps_list = fps_batch[video_start_idx:] if fps_batch is not None else None
|
||||
(
|
||||
sample,
|
||||
pixel_values,
|
||||
image_sizes,
|
||||
num_of_images_in_this_sample,
|
||||
num_of_videos_in_this_sample,
|
||||
) = self.replace_media_placeholder(
|
||||
sample,
|
||||
images[image_start_idx:],
|
||||
videos[video_start_idx:],
|
||||
timestamps_list,
|
||||
fps_list,
|
||||
**output_kwargs,
|
||||
)
|
||||
new_sample_list.append(sample)
|
||||
if pixel_values is not None:
|
||||
pixel_values_list.append(pixel_values)
|
||||
image_sizes_list.append(image_sizes)
|
||||
image_start_idx += num_of_images_in_this_sample
|
||||
video_start_idx += num_of_videos_in_this_sample
|
||||
|
||||
if len(pixel_values_list) > 0:
|
||||
image_inputs = {
|
||||
"pixel_values": torch.cat(pixel_values_list),
|
||||
"image_sizes": torch.cat(image_sizes_list),
|
||||
}
|
||||
else:
|
||||
image_inputs = {}
|
||||
video_inputs = {}
|
||||
text_inputs = self.tokenizer(new_sample_list, **output_kwargs["text_kwargs"])
|
||||
return BatchFeature(data={**text_inputs, **image_inputs, **video_inputs})
|
||||
|
||||
def get_number_tiles_based_on_image_size(
|
||||
self, image_size: tuple, min_num: int, max_num: int, use_thumbnail: bool, tile_size: int
|
||||
) -> int:
|
||||
"""
|
||||
Get the number of tiles based on the image size.
|
||||
"""
|
||||
orig_height, orig_width = image_size
|
||||
aspect_ratio = orig_width / orig_height
|
||||
# calculate the existing image aspect ratio
|
||||
target_ratios = {
|
||||
(i, j)
|
||||
for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1)
|
||||
for j in range(1, n + 1)
|
||||
if i * j <= max_num and i * j >= min_num
|
||||
}
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = self.image_processor.find_closest_aspect_ratio(
|
||||
aspect_ratio, target_ratios, orig_width, orig_height, tile_size
|
||||
)
|
||||
tiles_num = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||
if use_thumbnail and tiles_num > 1:
|
||||
tiles_num += 1
|
||||
return tiles_num
|
||||
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
@property
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
||||
|
||||
# override to save video-config in a separate config file
|
||||
def save_pretrained(self, save_directory, **kwargs):
|
||||
if os.path.isfile(save_directory):
|
||||
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
outputs = super().save_pretrained(save_directory, **kwargs)
|
||||
return outputs
|
||||
|
||||
# override to load video-config from a separate config file
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||
processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
# if return_unused_kwargs a tuple is returned where the second element is 'unused_kwargs'
|
||||
if isinstance(processor, tuple):
|
||||
processor = processor[0]
|
||||
return processor
|
||||
|
||||
# Copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
|
||||
def process_vision_info(
|
||||
self,
|
||||
conversations: list[dict] | list[list[dict]],
|
||||
return_video_kwargs: bool = False,
|
||||
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, dict | None]:
|
||||
vision_infos = self.extract_vision_info(conversations)
|
||||
## Read images or videos
|
||||
image_inputs = []
|
||||
video_inputs = []
|
||||
video_sample_fps_list = []
|
||||
video_timestamps_list = []
|
||||
for vision_info in vision_infos:
|
||||
if "image" in vision_info or "image_url" in vision_info:
|
||||
image_inputs.append(fetch_image(vision_info))
|
||||
else:
|
||||
raise ValueError("image, image_url or video should in content.")
|
||||
if len(image_inputs) == 0:
|
||||
image_inputs = None
|
||||
if len(video_inputs) == 0:
|
||||
video_inputs = None
|
||||
if return_video_kwargs:
|
||||
return (
|
||||
image_inputs,
|
||||
video_inputs,
|
||||
{"fps": video_sample_fps_list, "timestamps": video_timestamps_list},
|
||||
)
|
||||
return image_inputs, video_inputs
|
||||
|
||||
def extract_vision_info(self, conversations: list[dict] | list[list[dict]]) -> list[dict]:
|
||||
vision_infos = []
|
||||
if isinstance(conversations[0], dict):
|
||||
conversations = [conversations]
|
||||
for conversation in conversations:
|
||||
for message in conversation:
|
||||
if isinstance(message["content"], list):
|
||||
for ele in message["content"]:
|
||||
if (
|
||||
"image" in ele
|
||||
or "image_url" in ele
|
||||
or "video" in ele
|
||||
or ele["type"] in ("image", "image_url", "video")
|
||||
):
|
||||
vision_infos.append(ele)
|
||||
return vision_infos
|
||||
|
||||
|
||||
__all__ = ["Eagle25VLProcessor"]
|
||||
@@ -0,0 +1,380 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.errors import HFValidationError, RepositoryNotFoundError
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from huggingface_hub.dataclasses import strict
|
||||
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
|
||||
def strict(cls):
|
||||
return cls
|
||||
|
||||
AutoConfig = None
|
||||
AutoModel = None
|
||||
PretrainedConfig = object
|
||||
PreTrainedModel = object
|
||||
BatchFeature = None
|
||||
|
||||
try:
|
||||
import tree
|
||||
except ImportError:
|
||||
tree = None
|
||||
|
||||
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME
|
||||
|
||||
from .action_head.flow_matching_action_head import (
|
||||
FlowmatchingActionHead,
|
||||
FlowmatchingActionHeadConfig,
|
||||
)
|
||||
from .utils import ensure_eagle_cache_ready
|
||||
|
||||
DEFAULT_VENDOR_EAGLE_PATH = str((Path(__file__).resolve().parent / "eagle2_hg_model").resolve())
|
||||
DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5"
|
||||
|
||||
|
||||
class EagleBackbone(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
tune_llm: bool = False,
|
||||
tune_visual: bool = False,
|
||||
select_layer: int = -1,
|
||||
reproject_vision: bool = False,
|
||||
use_flash_attention: bool = False,
|
||||
load_bf16: bool = False,
|
||||
eagle_path: str = DEFAULT_VENDOR_EAGLE_PATH,
|
||||
tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO,
|
||||
project_to_dim: int = 1536,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
tune_llm: whether to tune the LLM model (default: True)
|
||||
tune_visual: whether to tune the visual model (default: False)
|
||||
"""
|
||||
super().__init__()
|
||||
assert not reproject_vision, "Reproject vision is not implemented here, set to False"
|
||||
|
||||
# Prefer loading Eagle model config from the cache directory where vendor files were copied.
|
||||
vendor_dir = DEFAULT_VENDOR_EAGLE_PATH
|
||||
cache_dir = HF_LEROBOT_HOME / tokenizer_assets_repo
|
||||
try:
|
||||
ensure_eagle_cache_ready(vendor_dir, cache_dir, tokenizer_assets_repo)
|
||||
except Exception as exc: # nosec: B110
|
||||
print(f"[GROOT] Warning: failed to prepare Eagle cache for backbone: {exc}")
|
||||
|
||||
config = AutoConfig.from_pretrained(str(cache_dir), trust_remote_code=True)
|
||||
self.eagle_model = AutoModel.from_config(config, trust_remote_code=True)
|
||||
|
||||
if project_to_dim is not None:
|
||||
self.eagle_linear = torch.nn.Linear(2048, project_to_dim)
|
||||
else:
|
||||
self.eagle_linear = torch.nn.Identity()
|
||||
|
||||
# needed since we don't use these layers. Also saves compute
|
||||
while len(self.eagle_model.language_model.model.layers) > select_layer:
|
||||
self.eagle_model.language_model.model.layers.pop(-1)
|
||||
|
||||
self.select_layer = select_layer
|
||||
self.set_trainable_parameters(tune_llm, tune_visual)
|
||||
|
||||
def set_trainable_parameters(self, tune_llm: bool, tune_visual: bool):
|
||||
self.tune_llm = tune_llm
|
||||
self.tune_visual = tune_visual
|
||||
for p in self.parameters():
|
||||
p.requires_grad = True
|
||||
if not tune_llm:
|
||||
self.eagle_model.language_model.requires_grad_(False)
|
||||
if not tune_visual:
|
||||
self.eagle_model.vision_model.requires_grad_(False)
|
||||
self.eagle_model.mlp1.requires_grad_(False)
|
||||
print(f"Tune backbone llm: {self.tune_llm}")
|
||||
print(f"Tune backbone visual: {self.tune_visual}")
|
||||
# Check if any parameters are still trainable. If not, print a warning.
|
||||
if not tune_llm and not tune_visual:
|
||||
for name, p in self.named_parameters():
|
||||
if p.requires_grad:
|
||||
print(f"Backbone trainable parameter: {name}")
|
||||
if not any(p.requires_grad for p in self.parameters()):
|
||||
print("Warning: No backbone trainable parameters found.")
|
||||
|
||||
def set_frozen_modules_to_eval_mode(self):
|
||||
"""
|
||||
Huggingface will call model.train() at each training_step. To ensure
|
||||
the expected behaviors for modules like dropout, batchnorm, etc., we
|
||||
need to call model.eval() for the frozen modules.
|
||||
"""
|
||||
if self.training:
|
||||
if self.eagle_model.language_model and not self.tune_llm:
|
||||
self.eagle_model.language_model.eval()
|
||||
if self.eagle_model.vision_model and not self.tune_visual:
|
||||
self.eagle_model.vision_model.eval()
|
||||
|
||||
def prepare_input(self, batch: dict) -> BatchFeature:
|
||||
return BatchFeature(data=batch)
|
||||
|
||||
def forward_eagle(self, vl_input: BatchFeature) -> BatchFeature:
|
||||
eagle_prefix = "eagle_"
|
||||
eagle_input = {
|
||||
k.removeprefix(eagle_prefix): v for k, v in vl_input.items() if k.startswith(eagle_prefix)
|
||||
}
|
||||
del eagle_input["image_sizes"]
|
||||
|
||||
eagle_output = self.eagle_model(**eagle_input, output_hidden_states=True, return_dict=True)
|
||||
eagle_features = eagle_output.hidden_states[self.select_layer]
|
||||
|
||||
eagle_features = self.eagle_linear(eagle_features)
|
||||
return eagle_features, eagle_input["attention_mask"]
|
||||
|
||||
def forward(self, vl_input: BatchFeature) -> BatchFeature:
|
||||
self.set_frozen_modules_to_eval_mode()
|
||||
|
||||
eagle_embeds, eagle_mask = self.forward_eagle(vl_input)
|
||||
|
||||
# YL (TODO HACK): to resolve DDP issue when tune_visual=True
|
||||
# Ensure all trainable parameters in vision_model are used in the forward pass for DDP compatibility
|
||||
if self.training and self.tune_visual:
|
||||
dummy_term = torch.tensor(
|
||||
0.0, device=eagle_embeds.device, dtype=eagle_embeds.dtype, requires_grad=True
|
||||
)
|
||||
for param in self.eagle_model.vision_model.parameters():
|
||||
if param.requires_grad:
|
||||
dummy_term = dummy_term + 0.0 * param.sum()
|
||||
eagle_embeds = eagle_embeds + dummy_term
|
||||
|
||||
return BatchFeature(
|
||||
data={"backbone_features": eagle_embeds, "backbone_attention_mask": eagle_mask}
|
||||
) # [B, T2, hidden_size]
|
||||
|
||||
|
||||
BACKBONE_FEATURE_KEY = "backbone_features"
|
||||
ACTION_KEY = "action_pred"
|
||||
LOSS_KEY = "loss"
|
||||
ERROR_MSG = "Error: unexpected input/output"
|
||||
N_COLOR_CHANNELS = 3
|
||||
|
||||
|
||||
# config
|
||||
@strict
|
||||
class GR00TN15Config(PretrainedConfig):
|
||||
model_type = "gr00t_n1_5"
|
||||
|
||||
backbone_cfg: dict[str, Any] | None = None
|
||||
action_head_cfg: dict[str, Any] | None = None
|
||||
action_horizon: int = 0
|
||||
action_dim: int = 0
|
||||
compute_dtype: str = "float32"
|
||||
|
||||
def __post_init__(self, **kwargs):
|
||||
self.backbone_cfg = {} if self.backbone_cfg is None else self.backbone_cfg
|
||||
self.action_head_cfg = {} if self.action_head_cfg is None else self.action_head_cfg
|
||||
super().__post_init__(**kwargs)
|
||||
|
||||
|
||||
# real model
|
||||
class GR00TN15(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
config_class = GR00TN15Config
|
||||
"""
|
||||
we expect the backbone output to have a key 'backbone_features' with shape (batch_size, n, hidden_size)
|
||||
here n is variable and can be e.g. time, 1 or user specified
|
||||
we expect the action head output to have a key 'action_pred' with shape (batch_size, time, action_dim) during inference time
|
||||
we expect these to have type BatchFeature, and they can of course have many other user specified keys too
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GR00TN15Config,
|
||||
local_model_path: str,
|
||||
):
|
||||
assert isinstance(config.backbone_cfg, dict)
|
||||
assert isinstance(config.action_head_cfg, dict)
|
||||
|
||||
super().__init__(config)
|
||||
self.local_model_path = local_model_path
|
||||
|
||||
self.backbone = EagleBackbone(**config.backbone_cfg)
|
||||
action_head_cfg = FlowmatchingActionHeadConfig(**config.action_head_cfg)
|
||||
self.action_head = FlowmatchingActionHead(action_head_cfg)
|
||||
|
||||
self.action_horizon = config.action_horizon
|
||||
self.action_dim = config.action_dim
|
||||
self.compute_dtype = config.compute_dtype
|
||||
self.post_init()
|
||||
|
||||
def validate_inputs(self, inputs):
|
||||
# NOTE -- this should be handled internally by the model
|
||||
# however, doing that will likely be breaking changes -- so we'll need to do it after the deadline
|
||||
|
||||
detected_error = False
|
||||
error_msg = ERROR_MSG
|
||||
if ACTION in inputs:
|
||||
action = inputs[ACTION]
|
||||
# In inference, action may be omitted or None; validate only when it's a tensor.
|
||||
if action is None:
|
||||
pass # allow None during inference
|
||||
elif isinstance(action, torch.Tensor):
|
||||
shape_ok = (
|
||||
len(action.shape) == 3
|
||||
and action.shape[1] == self.action_horizon
|
||||
and action.shape[2] == self.action_dim
|
||||
)
|
||||
if not shape_ok:
|
||||
error_msg += f"\n{action.shape=}"
|
||||
detected_error = True
|
||||
else:
|
||||
# Unexpected non-tensor type provided for action
|
||||
error_msg += f"\nInvalid type for action: {type(action)}"
|
||||
detected_error = True
|
||||
|
||||
if "video" in inputs:
|
||||
video = inputs["video"]
|
||||
type_ok = isinstance(video, np.ndarray)
|
||||
dtype_ok = video.dtype == np.uint8
|
||||
shape_ok = len(video.shape) == 6 and video.shape[3] == N_COLOR_CHANNELS
|
||||
if not type_ok:
|
||||
error_msg += f"\n{type(video)=}"
|
||||
detected_error = True
|
||||
if not dtype_ok:
|
||||
error_msg += f"\n{video.dtype=}"
|
||||
detected_error = True
|
||||
if not shape_ok:
|
||||
error_msg += f"\n{video.shape=}"
|
||||
detected_error = True
|
||||
|
||||
if detected_error:
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def validate_data(self, action_head_outputs, backbone_outputs, is_training):
|
||||
fail_backbone = (
|
||||
not isinstance(backbone_outputs, BatchFeature) or BACKBONE_FEATURE_KEY not in backbone_outputs
|
||||
)
|
||||
|
||||
if fail_backbone:
|
||||
error_msg = ERROR_MSG
|
||||
error_msg += f"\n{isinstance(backbone_outputs, BatchFeature)=}"
|
||||
error_msg += f"\n{BACKBONE_FEATURE_KEY in backbone_outputs=}"
|
||||
error_msg += f"\n{backbone_outputs[BACKBONE_FEATURE_KEY].shape=}"
|
||||
raise ValueError(error_msg)
|
||||
|
||||
fail_action_head = (not isinstance(action_head_outputs, BatchFeature)) or not (
|
||||
(
|
||||
LOSS_KEY in action_head_outputs and is_training
|
||||
) # there might not be an action prediction during training
|
||||
or (
|
||||
ACTION_KEY in action_head_outputs
|
||||
and action_head_outputs[ACTION_KEY].shape[1] == self.action_horizon
|
||||
and action_head_outputs[ACTION_KEY].shape[2] == self.action_dim
|
||||
)
|
||||
)
|
||||
|
||||
if fail_action_head:
|
||||
error_msg = ERROR_MSG
|
||||
error_msg += f"\n{isinstance(action_head_outputs, BatchFeature)=}"
|
||||
error_msg += f"\n{LOSS_KEY in action_head_outputs=}"
|
||||
error_msg += f"\n{action_head_outputs[ACTION_KEY].shape=}"
|
||||
error_msg += f"\n{self.action_horizon=}"
|
||||
error_msg += f"\n{self.action_dim=}"
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: dict,
|
||||
) -> BatchFeature:
|
||||
backbone_inputs, action_inputs = self.prepare_input(inputs)
|
||||
backbone_outputs = self.backbone(backbone_inputs)
|
||||
action_head_outputs = self.action_head(backbone_outputs, action_inputs)
|
||||
self.validate_data(action_head_outputs, backbone_outputs, is_training=True)
|
||||
return action_head_outputs
|
||||
|
||||
def get_action(
|
||||
self,
|
||||
inputs: dict,
|
||||
) -> BatchFeature:
|
||||
backbone_inputs, action_inputs = self.prepare_input(inputs)
|
||||
# Because the behavior of backbones remains the same for training and inference, we can use `forward` for backbones.
|
||||
backbone_outputs = self.backbone(backbone_inputs)
|
||||
action_head_outputs = self.action_head.get_action(backbone_outputs, action_inputs)
|
||||
self.validate_data(action_head_outputs, backbone_outputs, is_training=False)
|
||||
return action_head_outputs
|
||||
|
||||
def prepare_input(self, inputs) -> tuple[BatchFeature, BatchFeature]:
|
||||
self.validate_inputs(inputs)
|
||||
backbone_inputs = self.backbone.prepare_input(inputs)
|
||||
action_inputs = self.action_head.prepare_input(inputs)
|
||||
|
||||
def to_device_with_maybe_dtype(x):
|
||||
# Cast floating tensors to a memory-efficient compute dtype when requested.
|
||||
# Rationale: Upcasting backbone activations to fp32 significantly increases VRAM.
|
||||
# When compute_dtype is bfloat16, prefer bf16 for activations to match AMP behavior.
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return x
|
||||
if torch.is_floating_point(x):
|
||||
if getattr(self, "compute_dtype", None) == "bfloat16":
|
||||
return x.to(self.device, dtype=torch.bfloat16)
|
||||
# Fallback: preserve previous behavior if not using bf16 compute
|
||||
return x.to(self.device, dtype=self.action_head.dtype)
|
||||
# Non-floating tensors: move device only
|
||||
return x.to(self.device)
|
||||
|
||||
backbone_inputs = tree.map_structure(to_device_with_maybe_dtype, backbone_inputs)
|
||||
action_inputs = tree.map_structure(to_device_with_maybe_dtype, action_inputs)
|
||||
return backbone_inputs, action_inputs
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
|
||||
tune_visual = kwargs.pop("tune_visual", True)
|
||||
tune_llm = kwargs.pop("tune_llm", False)
|
||||
tune_projector = kwargs.pop("tune_projector", True)
|
||||
tune_diffusion_model = kwargs.pop("tune_diffusion_model", True)
|
||||
|
||||
print(f"Loading pretrained dual brain from {pretrained_model_name_or_path}")
|
||||
print(f"Tune backbone vision tower: {tune_visual}")
|
||||
print(f"Tune backbone LLM: {tune_llm}")
|
||||
print(f"Tune action head projector: {tune_projector}")
|
||||
print(f"Tune action head DiT: {tune_diffusion_model}")
|
||||
|
||||
# get the current model path being downloaded
|
||||
try:
|
||||
# NOTE(YL) This downloads the model to the local cache and returns the local path to the model
|
||||
# saved in ~/.cache/huggingface/hub/
|
||||
local_model_path = snapshot_download(pretrained_model_name_or_path, repo_type="model")
|
||||
# HFValidationError, RepositoryNotFoundError
|
||||
except (HFValidationError, RepositoryNotFoundError):
|
||||
print(
|
||||
f"Model not found or avail in the huggingface hub. Loading from local path: {pretrained_model_name_or_path}"
|
||||
)
|
||||
local_model_path = pretrained_model_name_or_path
|
||||
|
||||
pretrained_model = super().from_pretrained(
|
||||
local_model_path, local_model_path=local_model_path, **kwargs
|
||||
)
|
||||
|
||||
pretrained_model.backbone.set_trainable_parameters(tune_visual=tune_visual, tune_llm=tune_llm)
|
||||
pretrained_model.action_head.set_trainable_parameters(
|
||||
tune_projector=tune_projector, tune_diffusion_model=tune_diffusion_model
|
||||
)
|
||||
return pretrained_model
|
||||
@@ -1,966 +0,0 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
from contextlib import suppress
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.errors import HFValidationError, RepositoryNotFoundError
|
||||
from torch import nn
|
||||
from torch.distributions import Beta
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer
|
||||
from .configuration_groot import N1_7_DEFAULT_IMAGE_CROP_SIZE, N1_7_DEFAULT_IMAGE_TARGET_SIZE
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
AutoConfig = None
|
||||
AutoModel = None
|
||||
PretrainedConfig = object
|
||||
PreTrainedModel = object
|
||||
BatchFeature = None
|
||||
|
||||
try:
|
||||
import tree
|
||||
except ImportError:
|
||||
tree = None
|
||||
|
||||
try:
|
||||
from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration
|
||||
except ImportError:
|
||||
Qwen3VLConfig = None
|
||||
Qwen3VLForConditionalGeneration = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _copy_default(value: Any) -> Any:
|
||||
return deepcopy(value)
|
||||
|
||||
|
||||
GR00T_N1_7_DEFAULTS: dict[str, Any] = {
|
||||
"model_dtype": "bfloat16",
|
||||
"dtype": "bfloat16",
|
||||
"model_name": "nvidia/Cosmos-Reason2-2B",
|
||||
"backbone_model_type": "qwen",
|
||||
"model_revision": None,
|
||||
"tune_top_llm_layers": 0,
|
||||
"backbone_embedding_dim": 2048,
|
||||
"tune_llm": False,
|
||||
"tune_visual": False,
|
||||
"select_layer": 16,
|
||||
"reproject_vision": False,
|
||||
"use_flash_attention": True,
|
||||
"load_bf16": False,
|
||||
"backbone_trainable_params_fp32": True,
|
||||
"image_crop_size": N1_7_DEFAULT_IMAGE_CROP_SIZE,
|
||||
"image_target_size": N1_7_DEFAULT_IMAGE_TARGET_SIZE,
|
||||
"shortest_image_edge": None,
|
||||
"crop_fraction": None,
|
||||
"random_rotation_angle": None,
|
||||
"color_jitter_params": None,
|
||||
"use_albumentations_transforms": True,
|
||||
"extra_augmentation_config": None,
|
||||
"formalize_language": True,
|
||||
"apply_sincos_state_encoding": False,
|
||||
"use_percentiles": True,
|
||||
"use_relative_action": False,
|
||||
"max_state_dim": 132,
|
||||
"max_action_dim": 132,
|
||||
"action_horizon": 40,
|
||||
"hidden_size": 1024,
|
||||
"input_embedding_dim": 1536,
|
||||
"state_history_length": 1,
|
||||
"add_pos_embed": True,
|
||||
"attn_dropout": 0.2,
|
||||
"use_vlln": True,
|
||||
"max_seq_len": 1024,
|
||||
"use_alternate_vl_dit": True,
|
||||
"attend_text_every_n_blocks": 2,
|
||||
"diffusion_model_cfg": {
|
||||
"positional_embeddings": None,
|
||||
"num_layers": 32,
|
||||
"num_attention_heads": 32,
|
||||
"attention_head_dim": 48,
|
||||
"norm_type": "ada_norm",
|
||||
"dropout": 0.2,
|
||||
"final_dropout": True,
|
||||
"output_dim": 1024,
|
||||
"interleave_self_attention": True,
|
||||
},
|
||||
"vl_self_attention_cfg": {
|
||||
"positional_embeddings": None,
|
||||
"num_layers": 4,
|
||||
"num_attention_heads": 32,
|
||||
"attention_head_dim": 64,
|
||||
"dropout": 0.2,
|
||||
"final_dropout": True,
|
||||
},
|
||||
"num_inference_timesteps": 4,
|
||||
"noise_beta_alpha": 1.5,
|
||||
"noise_beta_beta": 1.0,
|
||||
"noise_s": 0.999,
|
||||
"num_timestep_buckets": 1000,
|
||||
"tune_projector": True,
|
||||
"tune_diffusion_model": True,
|
||||
"tune_vlln": True,
|
||||
"state_dropout_prob": 0.2,
|
||||
"exclude_state": False,
|
||||
"use_mean_std": False,
|
||||
"max_num_embodiments": 32,
|
||||
"rtc_ramp_rate": 6.0,
|
||||
}
|
||||
|
||||
|
||||
class GR00TN17Config(PretrainedConfig):
|
||||
"""Configuration for NVIDIA GR00T N1.7.
|
||||
|
||||
N1.7 uses the Cosmos-Reason2-2B / Qwen3-VL backbone and a multi-embodiment
|
||||
flow-matching action head. This mirrors the public N1.7 checkpoint config
|
||||
while keeping it local to LeRobot and independent from the external
|
||||
Isaac-GR00T ``gr00t`` Python package.
|
||||
"""
|
||||
|
||||
model_type = "Gr00tN1d7"
|
||||
|
||||
_defaults = GR00T_N1_7_DEFAULTS
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for key, value in GR00T_N1_7_DEFAULTS.items():
|
||||
setattr(self, key, _copy_default(kwargs.pop(key, value)))
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
def to_filtered_dict(self, exclude_augment: bool = True) -> dict[str, Any]:
|
||||
cfg = self.to_dict()
|
||||
if not exclude_augment:
|
||||
return cfg
|
||||
exclude_keys = {
|
||||
"random_rotation_angle",
|
||||
"color_jitter_params",
|
||||
"use_albumentations_transforms",
|
||||
"formalize_language",
|
||||
"image_crop_size",
|
||||
"image_target_size",
|
||||
"shortest_image_edge",
|
||||
"crop_fraction",
|
||||
}
|
||||
return {k: v for k, v in cfg.items() if k not in exclude_keys}
|
||||
|
||||
def to_filtered_json(self, exclude_augment: bool = True, **kwargs) -> str:
|
||||
return json.dumps(self.to_filtered_dict(exclude_augment), indent=2, default=str, **kwargs)
|
||||
|
||||
|
||||
class CategorySpecificLinear(nn.Module):
|
||||
"""Linear layer with category-specific weights for multi-embodiment support."""
|
||||
|
||||
def __init__(self, num_categories: int, input_dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
self.num_categories = num_categories
|
||||
self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
|
||||
self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))
|
||||
|
||||
def forward(self, x: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
|
||||
selected_w = self.W[cat_ids]
|
||||
selected_b = self.b[cat_ids]
|
||||
return torch.bmm(x, selected_w) + selected_b.unsqueeze(1)
|
||||
|
||||
|
||||
class CategorySpecificMLP(nn.Module):
|
||||
"""Two-layer MLP with category-specific weights."""
|
||||
|
||||
def __init__(self, num_categories: int, input_dim: int, hidden_dim: int, output_dim: int):
|
||||
super().__init__()
|
||||
self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
|
||||
self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)
|
||||
|
||||
def forward(self, x: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
|
||||
hidden = F.relu(self.layer1(x, cat_ids))
|
||||
return self.layer2(hidden, cat_ids)
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
"""Sinusoidal encoding of shape ``(B, T, D)`` for timestep tensors ``(B, T)``.
|
||||
|
||||
The frequency scalar is intentionally created on CPU and then broadcast with
|
||||
the device-local arange result. That mirrors Isaac-GR00T's N1.7 timestep
|
||||
embedding and avoids tiny dtype/device construction differences in parity
|
||||
tests.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
timesteps = timesteps.float()
|
||||
half_dim = self.embedding_dim // 2
|
||||
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device) * (
|
||||
torch.log(torch.tensor(10000.0)) / half_dim
|
||||
)
|
||||
freqs = timesteps.unsqueeze(-1) * exponent.exp()
|
||||
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1)
|
||||
|
||||
|
||||
def swish(x: torch.Tensor) -> torch.Tensor:
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class MultiEmbodimentActionEncoder(nn.Module):
|
||||
"""Action encoder with category-specific projections and sinusoidal time encoding."""
|
||||
|
||||
def __init__(self, action_dim: int, hidden_size: int, num_embodiments: int):
|
||||
super().__init__()
|
||||
self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size)
|
||||
self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size)
|
||||
self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size)
|
||||
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
|
||||
|
||||
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, horizon, _ = actions.shape
|
||||
if timesteps.dim() != 1 or timesteps.shape[0] != batch_size:
|
||||
raise ValueError("Expected `timesteps` to have shape (B,).")
|
||||
timesteps = timesteps.unsqueeze(1).expand(-1, horizon)
|
||||
action_emb = self.W1(actions, cat_ids)
|
||||
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
|
||||
x = swish(self.W2(torch.cat([action_emb, time_emb], dim=-1), cat_ids))
|
||||
return self.W3(x, cat_ids)
|
||||
|
||||
|
||||
class Qwen3Backbone(nn.Module):
|
||||
"""Cosmos-Reason2/Qwen3-VL backbone used by GR00T N1.7.
|
||||
|
||||
The public checkpoint stores the action head in the GR00T checkpoint but
|
||||
uses a Hugging Face Qwen3-VL-compatible backbone interface. This wrapper
|
||||
keeps the nested HF module layout compatible across transformer versions
|
||||
and exposes the hidden states consumed by the action head.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "nvidia/Cosmos-Reason2-2B",
|
||||
tune_llm: bool = False,
|
||||
tune_visual: bool = False,
|
||||
select_layer: int = -1,
|
||||
reproject_vision: bool = False,
|
||||
use_flash_attention: bool = False,
|
||||
load_bf16: bool = False,
|
||||
tune_top_llm_layers: int = 0,
|
||||
trainable_params_fp32: bool = False,
|
||||
transformers_loading_kwargs: dict[str, Any] | None = None,
|
||||
load_pretrained_weights: bool = True,
|
||||
):
|
||||
if Qwen3VLForConditionalGeneration is None:
|
||||
raise ImportError(
|
||||
"Qwen3VLForConditionalGeneration is required for GR00T N1.7. "
|
||||
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'` "
|
||||
"or use a transformers version that provides Qwen3-VL support."
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True}
|
||||
|
||||
extra_kwargs: dict[str, Any] = {}
|
||||
if use_flash_attention:
|
||||
try:
|
||||
import flash_attn # noqa: F401
|
||||
|
||||
extra_kwargs["attn_implementation"] = "flash_attention_2"
|
||||
except ImportError:
|
||||
logger.warning("flash_attn is not installed. Falling back to SDPA attention.")
|
||||
extra_kwargs["attn_implementation"] = "sdpa"
|
||||
if load_bf16:
|
||||
extra_kwargs["torch_dtype"] = torch.bfloat16
|
||||
|
||||
if load_pretrained_weights:
|
||||
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
|
||||
model_name,
|
||||
**extra_kwargs,
|
||||
**transformers_loading_kwargs,
|
||||
).eval()
|
||||
else:
|
||||
self.model = self._from_backbone_config(
|
||||
model_name=model_name,
|
||||
model_kwargs=extra_kwargs,
|
||||
config_kwargs=transformers_loading_kwargs,
|
||||
).eval()
|
||||
|
||||
while len(self.language_model.layers) > select_layer:
|
||||
self.language_model.layers.pop(-1)
|
||||
|
||||
self.select_layer = select_layer
|
||||
self.set_trainable_parameters(tune_llm, tune_visual, tune_top_llm_layers)
|
||||
if load_bf16 and trainable_params_fp32:
|
||||
for parameter in self.parameters():
|
||||
if parameter.requires_grad:
|
||||
parameter.data = parameter.data.to(torch.float32)
|
||||
|
||||
def set_trainable_parameters(
|
||||
self, tune_llm: bool, tune_visual: bool, tune_top_llm_layers: int = 0
|
||||
) -> None:
|
||||
self.tune_llm = tune_llm
|
||||
self.tune_visual = tune_visual
|
||||
for parameter in self.parameters():
|
||||
parameter.requires_grad = True
|
||||
if not tune_llm:
|
||||
self.language_model.requires_grad_(False)
|
||||
if not tune_visual:
|
||||
self.visual.requires_grad_(False)
|
||||
if tune_top_llm_layers > 0:
|
||||
for layer in self.language_model.layers[-tune_top_llm_layers:]:
|
||||
for parameter in layer.parameters():
|
||||
parameter.requires_grad = True
|
||||
|
||||
def set_frozen_modules_to_eval_mode(self) -> None:
|
||||
if self.training:
|
||||
if self.language_model and not self.tune_llm:
|
||||
self.language_model.eval()
|
||||
if self.visual and not self.tune_visual:
|
||||
self.visual.eval()
|
||||
|
||||
@property
|
||||
def language_model(self) -> nn.Module:
|
||||
return getattr(self.model, "model", self.model).language_model
|
||||
|
||||
@property
|
||||
def visual(self) -> nn.Module:
|
||||
return getattr(self.model, "model", self.model).visual
|
||||
|
||||
def _from_backbone_config(
|
||||
self,
|
||||
*,
|
||||
model_name: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
config_kwargs: dict[str, Any],
|
||||
) -> nn.Module:
|
||||
if _is_cosmos_reason2_backbone(model_name):
|
||||
backbone_config = _cosmos_reason2_qwen3_vl_config()
|
||||
else:
|
||||
if AutoConfig is None:
|
||||
raise ImportError(
|
||||
"AutoConfig is required to initialize a GR00T N1.7 backbone from config. "
|
||||
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'`."
|
||||
)
|
||||
backbone_config = AutoConfig.from_pretrained(model_name, **config_kwargs)
|
||||
return Qwen3VLForConditionalGeneration._from_config(backbone_config, **model_kwargs)
|
||||
|
||||
def prepare_input(self, batch: dict[str, Any]) -> BatchFeature:
|
||||
return BatchFeature(data=batch)
|
||||
|
||||
def _ensure_mm_token_type_ids(self, model_input: dict[str, torch.Tensor]) -> None:
|
||||
if "mm_token_type_ids" in model_input:
|
||||
return
|
||||
if "image_grid_thw" not in model_input and "video_grid_thw" not in model_input:
|
||||
return
|
||||
|
||||
input_ids = model_input.get("input_ids")
|
||||
if input_ids is None:
|
||||
return
|
||||
|
||||
mm_token_type_ids = torch.zeros(input_ids.shape, dtype=torch.int32, device=input_ids.device)
|
||||
image_token_id = getattr(self.model.config, "image_token_id", None)
|
||||
video_token_id = getattr(self.model.config, "video_token_id", None)
|
||||
if image_token_id is not None:
|
||||
mm_token_type_ids[input_ids == image_token_id] = 1
|
||||
if video_token_id is not None:
|
||||
mm_token_type_ids[input_ids == video_token_id] = 2
|
||||
|
||||
model_input["mm_token_type_ids"] = mm_token_type_ids
|
||||
|
||||
def _ensure_legacy_qwen3_position_ids(self, model_input: dict[str, torch.Tensor]) -> None:
|
||||
"""Restore the Qwen3-VL text position ids used by older Transformers releases.
|
||||
|
||||
Transformers 5.x computes 3-row multimodal RoPE ids for Qwen3-VL and then
|
||||
drops text position ids before calling text-layer flash attention. GR00T
|
||||
N1.7 was aligned against the older Transformers path, where a fourth text
|
||||
position row is forwarded alongside the temporal/height/width rows. Adding
|
||||
the row here preserves the newer multimodal position computation while
|
||||
keeping flash attention on the legacy code path.
|
||||
"""
|
||||
|
||||
if "position_ids" in model_input:
|
||||
return
|
||||
|
||||
qwen3_model = getattr(self.model, "model", self.model)
|
||||
compute_3d_position_ids = getattr(qwen3_model, "compute_3d_position_ids", None)
|
||||
if compute_3d_position_ids is None:
|
||||
return
|
||||
|
||||
position_ids = compute_3d_position_ids(
|
||||
input_ids=model_input.get("input_ids"),
|
||||
image_grid_thw=model_input.get("image_grid_thw"),
|
||||
video_grid_thw=model_input.get("video_grid_thw"),
|
||||
inputs_embeds=None,
|
||||
attention_mask=model_input.get("attention_mask"),
|
||||
past_key_values=None,
|
||||
mm_token_type_ids=model_input.get("mm_token_type_ids"),
|
||||
)
|
||||
if position_ids.ndim == 3 and position_ids.shape[0] == 3:
|
||||
position_ids = torch.cat([position_ids[:1], position_ids], dim=0)
|
||||
|
||||
model_input["position_ids"] = position_ids
|
||||
|
||||
def _last_decoder_layer_output(self, model_input: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Return the pre-final-norm decoder output consumed by the N1.7 action head.
|
||||
|
||||
Older Transformers releases exposed this tensor as ``hidden_states[-1]``.
|
||||
Newer releases expose the post-final-norm tensor there instead. Capturing
|
||||
the last decoder layer output directly keeps the N1.7 action head input
|
||||
stable across Transformers versions.
|
||||
"""
|
||||
|
||||
captured: dict[str, torch.Tensor] = {}
|
||||
|
||||
def capture_output(_module: nn.Module, _inputs: tuple[Any, ...], output: Any) -> None:
|
||||
if isinstance(output, torch.Tensor):
|
||||
captured["features"] = output
|
||||
elif isinstance(output, (tuple, list)) and output:
|
||||
captured["features"] = output[0]
|
||||
elif hasattr(output, "last_hidden_state"):
|
||||
captured["features"] = output.last_hidden_state
|
||||
|
||||
hook = self.language_model.layers[-1].register_forward_hook(capture_output)
|
||||
try:
|
||||
outputs = self.model(**model_input, output_hidden_states=True)
|
||||
finally:
|
||||
hook.remove()
|
||||
|
||||
return captured.get("features", outputs.hidden_states[-1])
|
||||
|
||||
def forward(self, vl_input: BatchFeature) -> BatchFeature:
|
||||
self.set_frozen_modules_to_eval_mode()
|
||||
keys_to_use = ["input_ids", "attention_mask", "pixel_values", "image_grid_thw"]
|
||||
optional_keys = ["mm_token_type_ids", "pixel_values_videos", "video_grid_thw"]
|
||||
model_input = {key: vl_input[key] for key in keys_to_use}
|
||||
model_input.update({key: vl_input[key] for key in optional_keys if key in vl_input})
|
||||
self._ensure_mm_token_type_ids(model_input)
|
||||
self._ensure_legacy_qwen3_position_ids(model_input)
|
||||
features = self._last_decoder_layer_output(model_input)
|
||||
image_mask = model_input["input_ids"] == self.model.config.image_token_id
|
||||
attention_mask = model_input["attention_mask"] == 1
|
||||
return BatchFeature(
|
||||
data={
|
||||
"backbone_features": features,
|
||||
"backbone_attention_mask": attention_mask,
|
||||
"image_mask": image_mask,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class GR00TN17ActionHead(nn.Module):
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(self, config: GR00TN17Config):
|
||||
require_package("diffusers", extra="groot")
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.input_embedding_dim = config.input_embedding_dim
|
||||
|
||||
if config.use_alternate_vl_dit:
|
||||
self.model = AlternateVLDiT(
|
||||
**config.diffusion_model_cfg,
|
||||
cross_attention_dim=config.backbone_embedding_dim,
|
||||
attend_text_every_n_blocks=config.attend_text_every_n_blocks,
|
||||
)
|
||||
else:
|
||||
self.model = DiT(
|
||||
**config.diffusion_model_cfg,
|
||||
cross_attention_dim=config.backbone_embedding_dim,
|
||||
)
|
||||
|
||||
self.action_dim = config.max_action_dim
|
||||
self.action_horizon = config.action_horizon
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
self.state_encoder = CategorySpecificMLP(
|
||||
num_categories=config.max_num_embodiments,
|
||||
input_dim=config.max_state_dim * config.state_history_length,
|
||||
hidden_dim=self.hidden_size,
|
||||
output_dim=self.input_embedding_dim,
|
||||
)
|
||||
self.action_encoder = MultiEmbodimentActionEncoder(
|
||||
action_dim=self.action_dim,
|
||||
hidden_size=self.input_embedding_dim,
|
||||
num_embodiments=config.max_num_embodiments,
|
||||
)
|
||||
self.action_decoder = CategorySpecificMLP(
|
||||
num_categories=config.max_num_embodiments,
|
||||
input_dim=self.hidden_size,
|
||||
hidden_dim=self.hidden_size,
|
||||
output_dim=self.action_dim,
|
||||
)
|
||||
self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity()
|
||||
vl_self_attention_cfg = getattr(config, "vl_self_attention_cfg", None)
|
||||
if vl_self_attention_cfg and vl_self_attention_cfg.get("num_layers", 0) > 0:
|
||||
self.vl_self_attention = SelfAttentionTransformer(**vl_self_attention_cfg)
|
||||
else:
|
||||
self.vl_self_attention = nn.Identity()
|
||||
if config.add_pos_embed:
|
||||
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
|
||||
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
|
||||
self.state_dropout_prob = config.state_dropout_prob
|
||||
self._noise_beta_alpha = config.noise_beta_alpha
|
||||
self._noise_beta_beta = config.noise_beta_beta
|
||||
self._beta_dist = None
|
||||
self.num_timestep_buckets = config.num_timestep_buckets
|
||||
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model, config.tune_vlln)
|
||||
|
||||
def set_trainable_parameters(
|
||||
self, tune_projector: bool, tune_diffusion_model: bool, tune_vlln: bool
|
||||
) -> None:
|
||||
self.tune_projector = tune_projector
|
||||
self.tune_diffusion_model = tune_diffusion_model
|
||||
self.tune_vlln = tune_vlln
|
||||
for parameter in self.parameters():
|
||||
parameter.requires_grad = True
|
||||
if not tune_projector:
|
||||
self.state_encoder.requires_grad_(False)
|
||||
self.action_encoder.requires_grad_(False)
|
||||
self.action_decoder.requires_grad_(False)
|
||||
if self.config.add_pos_embed:
|
||||
self.position_embedding.requires_grad_(False)
|
||||
if not tune_diffusion_model:
|
||||
self.model.requires_grad_(False)
|
||||
if not tune_vlln:
|
||||
self.vlln.requires_grad_(False)
|
||||
self.vl_self_attention.requires_grad_(False)
|
||||
|
||||
def set_frozen_modules_to_eval_mode(self) -> None:
|
||||
if self.training:
|
||||
if not self.tune_projector:
|
||||
self.state_encoder.eval()
|
||||
self.action_encoder.eval()
|
||||
self.action_decoder.eval()
|
||||
if self.config.add_pos_embed:
|
||||
self.position_embedding.eval()
|
||||
if not self.tune_diffusion_model:
|
||||
self.model.eval()
|
||||
if not self.tune_vlln:
|
||||
self.vlln.eval()
|
||||
self.vl_self_attention.eval()
|
||||
|
||||
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
if self._beta_dist is None:
|
||||
beta_alpha = torch.tensor(self._noise_beta_alpha, device="cpu", dtype=torch.float32)
|
||||
beta_beta = torch.tensor(self._noise_beta_beta, device="cpu", dtype=torch.float32)
|
||||
self._beta_dist = Beta(beta_alpha, beta_beta, validate_args=False)
|
||||
sample = self._beta_dist.sample([batch_size]).to(device, dtype=dtype)
|
||||
return (1 - sample) * self.config.noise_s
|
||||
|
||||
def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature:
|
||||
backbone_features = self.vlln(backbone_output["backbone_features"])
|
||||
backbone_output["backbone_features"] = self.vl_self_attention(backbone_features)
|
||||
return backbone_output
|
||||
|
||||
def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
|
||||
self.set_frozen_modules_to_eval_mode()
|
||||
backbone_output = self.process_backbone_output(backbone_output)
|
||||
vl_embeds = backbone_output.backbone_features
|
||||
device = vl_embeds.device
|
||||
embodiment_id = action_input.embodiment_id
|
||||
|
||||
if action_input.state.shape[1] != self.config.state_history_length:
|
||||
raise ValueError("state history length does not match GR00T N1.7 config.")
|
||||
state = action_input.state.view(action_input.state.shape[0], 1, -1)
|
||||
state_features = self.state_encoder(state, embodiment_id)
|
||||
|
||||
if self.training and self.state_dropout_prob > 0:
|
||||
do_dropout = (
|
||||
torch.rand(state_features.shape[0], device=state_features.device) < self.state_dropout_prob
|
||||
)
|
||||
state_features = state_features * (1 - do_dropout[:, None, None].to(dtype=state_features.dtype))
|
||||
|
||||
actions = action_input.action
|
||||
noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
|
||||
t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
|
||||
t = t[:, None, None]
|
||||
noisy_trajectory = (1 - t) * noise + t * actions
|
||||
velocity = actions - noise
|
||||
t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
|
||||
action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)
|
||||
|
||||
if self.config.add_pos_embed:
|
||||
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
|
||||
action_features = action_features + self.position_embedding(pos_ids).unsqueeze(0)
|
||||
|
||||
sa_embs = torch.cat((state_features, action_features), dim=1)
|
||||
if self.config.use_alternate_vl_dit:
|
||||
model_output, _ = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embeds,
|
||||
encoder_attention_mask=backbone_output.backbone_attention_mask,
|
||||
timestep=t_discretized,
|
||||
return_all_hidden_states=True,
|
||||
image_mask=backbone_output.image_mask,
|
||||
backbone_attention_mask=backbone_output.backbone_attention_mask,
|
||||
)
|
||||
else:
|
||||
model_output, _ = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embeds,
|
||||
encoder_attention_mask=backbone_output.backbone_attention_mask,
|
||||
timestep=t_discretized,
|
||||
return_all_hidden_states=True,
|
||||
)
|
||||
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
pred_actions = pred[:, -actions.shape[1] :]
|
||||
action_mask = action_input.action_mask.to(dtype=pred_actions.dtype)
|
||||
action_loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask
|
||||
loss = action_loss.sum() / (action_mask.sum() + 1e-6)
|
||||
return BatchFeature(
|
||||
data={
|
||||
"loss": loss,
|
||||
"action_loss": action_loss,
|
||||
"action_mask": action_mask,
|
||||
"backbone_features": vl_embeds,
|
||||
"state_features": state_features,
|
||||
}
|
||||
)
|
||||
|
||||
def _encode_features(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
|
||||
backbone_output = self.process_backbone_output(backbone_output)
|
||||
state = action_input.state
|
||||
if state.shape[1] != self.config.state_history_length:
|
||||
raise ValueError("state history length does not match GR00T N1.7 config.")
|
||||
state = state.view(state.shape[0], 1, -1)
|
||||
state_features = self.state_encoder(state, action_input.embodiment_id)
|
||||
return BatchFeature(
|
||||
data={"backbone_features": backbone_output.backbone_features, "state_features": state_features}
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_action_with_features(
|
||||
self,
|
||||
backbone_features: torch.Tensor,
|
||||
state_features: torch.Tensor,
|
||||
embodiment_id: torch.Tensor,
|
||||
backbone_output: BatchFeature,
|
||||
action_input: BatchFeature,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> BatchFeature:
|
||||
vl_embeds = backbone_features
|
||||
batch_size = vl_embeds.shape[0]
|
||||
device = vl_embeds.device
|
||||
actions = torch.randn(
|
||||
size=(batch_size, self.config.action_horizon, self.action_dim),
|
||||
dtype=vl_embeds.dtype,
|
||||
device=device,
|
||||
)
|
||||
dt = 1.0 / self.num_inference_timesteps
|
||||
vel_strength = torch.ones_like(actions)
|
||||
|
||||
if "action" in action_input:
|
||||
if options is None:
|
||||
raise ValueError("RTC options are required when action is provided to get_action.")
|
||||
action_horizon_before_padding = options["action_horizon"]
|
||||
actions[:, : options["rtc_overlap_steps"], :] = action_input["action"][
|
||||
:,
|
||||
action_horizon_before_padding - options["rtc_overlap_steps"] : action_horizon_before_padding,
|
||||
:,
|
||||
]
|
||||
vel_strength[:, : options["rtc_frozen_steps"], :] = 0.0
|
||||
intermediate_steps = options["rtc_overlap_steps"] - options["rtc_frozen_steps"]
|
||||
t = torch.linspace(0.0, 1.0, intermediate_steps + 2, device=device)
|
||||
ramp = 1 - torch.exp(-options["rtc_ramp_rate"] * t)
|
||||
ramp = ramp / ramp[-1].clamp_min(1e-8)
|
||||
vel_strength[:, options["rtc_frozen_steps"] : options["rtc_overlap_steps"], :] = ramp[1:-1][
|
||||
None, :, None
|
||||
].to(device)
|
||||
|
||||
for t_step in range(self.num_inference_timesteps):
|
||||
t_cont = t_step / float(self.num_inference_timesteps)
|
||||
t_discretized = int(t_cont * self.num_timestep_buckets)
|
||||
timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device)
|
||||
action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id)
|
||||
if self.config.add_pos_embed:
|
||||
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
|
||||
action_features = action_features + self.position_embedding(pos_ids).unsqueeze(0)
|
||||
sa_embs = torch.cat((state_features, action_features), dim=1)
|
||||
|
||||
if self.config.use_alternate_vl_dit:
|
||||
model_output = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embeds,
|
||||
timestep=timesteps_tensor,
|
||||
image_mask=backbone_output.image_mask,
|
||||
backbone_attention_mask=backbone_output.backbone_attention_mask,
|
||||
)
|
||||
else:
|
||||
model_output = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embeds,
|
||||
timestep=timesteps_tensor,
|
||||
)
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
actions = actions + dt * pred[:, -self.action_horizon :] * vel_strength
|
||||
|
||||
return BatchFeature(
|
||||
data={
|
||||
"action_pred": actions,
|
||||
"backbone_features": vl_embeds,
|
||||
"state_features": state_features,
|
||||
}
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_action(
|
||||
self,
|
||||
backbone_output: BatchFeature,
|
||||
action_input: BatchFeature,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> BatchFeature:
|
||||
features = self._encode_features(backbone_output, action_input)
|
||||
return self.get_action_with_features(
|
||||
backbone_features=features.backbone_features,
|
||||
state_features=features.state_features,
|
||||
embodiment_id=action_input.embodiment_id,
|
||||
backbone_output=backbone_output,
|
||||
action_input=action_input,
|
||||
options=options,
|
||||
)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
return next(iter(self.parameters())).dtype
|
||||
|
||||
def prepare_input(self, batch: dict[str, Any]) -> BatchFeature:
|
||||
return BatchFeature(data=batch)
|
||||
|
||||
|
||||
def _is_cosmos_reason2_backbone(model_name: str) -> bool:
|
||||
return str(model_name).rstrip("/") == "nvidia/Cosmos-Reason2-2B"
|
||||
|
||||
|
||||
def _cosmos_reason2_qwen3_vl_config() -> PretrainedConfig:
|
||||
if Qwen3VLConfig is None:
|
||||
raise ImportError(
|
||||
"Qwen3VLConfig is required for GR00T N1.7. "
|
||||
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'`."
|
||||
)
|
||||
return Qwen3VLConfig(
|
||||
image_token_id=151655,
|
||||
video_token_id=151656,
|
||||
vision_start_token_id=151652,
|
||||
vision_end_token_id=151653,
|
||||
tie_word_embeddings=True,
|
||||
text_config={
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 151643,
|
||||
"dtype": "bfloat16",
|
||||
"eos_token_id": 151645,
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 2048,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 6144,
|
||||
"max_position_embeddings": 262144,
|
||||
"model_type": "qwen3_vl_text",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 28,
|
||||
"num_key_value_heads": 8,
|
||||
"rms_norm_eps": 1e-6,
|
||||
"rope_scaling": {
|
||||
"mrope_interleaved": True,
|
||||
"mrope_section": [24, 20, 20],
|
||||
"rope_type": "default",
|
||||
},
|
||||
"rope_theta": 5000000,
|
||||
"tie_word_embeddings": True,
|
||||
"use_cache": True,
|
||||
"vocab_size": 151936,
|
||||
},
|
||||
vision_config={
|
||||
"deepstack_visual_indexes": [5, 11, 17],
|
||||
"depth": 24,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_size": 1024,
|
||||
"in_channels": 3,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 4096,
|
||||
"model_type": "qwen3_vl",
|
||||
"num_heads": 16,
|
||||
"num_position_embeddings": 2304,
|
||||
"out_hidden_size": 2048,
|
||||
"patch_size": 16,
|
||||
"spatial_merge_size": 2,
|
||||
"temporal_patch_size": 2,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def get_backbone_cls(config: GR00TN17Config):
|
||||
if "nvidia/Cosmos-Reason2" in config.model_name or "Qwen/Qwen3-VL" in config.model_name:
|
||||
return Qwen3Backbone
|
||||
if config.backbone_model_type == "qwen":
|
||||
logger.warning(
|
||||
"Unrecognized GR00T N1.7 backbone model name '%s'; assuming a Qwen3-VL-compatible "
|
||||
"backbone because backbone_model_type='qwen'.",
|
||||
config.model_name,
|
||||
)
|
||||
return Qwen3Backbone
|
||||
raise ValueError(f"Unsupported GR00T N1.7 backbone model: {config.model_name}")
|
||||
|
||||
|
||||
class GR00TN17(PreTrainedModel):
|
||||
"""GR00T N1.7 model with a Cosmos-Reason2/Qwen3-VL backbone."""
|
||||
|
||||
config_class = GR00TN17Config
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GR00TN17Config,
|
||||
transformers_loading_kwargs: dict[str, Any] | None = None,
|
||||
load_backbone_weights: bool = True,
|
||||
):
|
||||
super().__init__(config)
|
||||
transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True}
|
||||
self.config = config
|
||||
backbone_cls = get_backbone_cls(config)
|
||||
self.backbone = backbone_cls(
|
||||
model_name=config.model_name,
|
||||
tune_llm=config.tune_llm,
|
||||
tune_visual=config.tune_visual,
|
||||
select_layer=config.select_layer,
|
||||
reproject_vision=config.reproject_vision,
|
||||
use_flash_attention=config.use_flash_attention,
|
||||
load_bf16=config.load_bf16,
|
||||
tune_top_llm_layers=config.tune_top_llm_layers,
|
||||
trainable_params_fp32=config.backbone_trainable_params_fp32,
|
||||
transformers_loading_kwargs=transformers_loading_kwargs,
|
||||
load_pretrained_weights=load_backbone_weights,
|
||||
)
|
||||
self.action_head = GR00TN17ActionHead(config)
|
||||
self.post_init()
|
||||
|
||||
def prepare_input(self, inputs: dict[str, Any]) -> tuple[BatchFeature, BatchFeature]:
|
||||
global tree
|
||||
if tree is None:
|
||||
require_package("dm-tree", extra="groot", import_name="tree")
|
||||
tree = importlib.import_module("tree")
|
||||
backbone_inputs = self.backbone.prepare_input(inputs)
|
||||
action_inputs = self.action_head.prepare_input(inputs)
|
||||
|
||||
def to_device_with_dtype(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return x
|
||||
if torch.is_floating_point(x):
|
||||
return x.to(self.device, dtype=self.dtype)
|
||||
return x.to(self.device)
|
||||
|
||||
return (
|
||||
tree.map_structure(to_device_with_dtype, backbone_inputs),
|
||||
tree.map_structure(to_device_with_dtype, action_inputs),
|
||||
)
|
||||
|
||||
def forward(self, inputs: dict[str, Any]) -> BatchFeature:
|
||||
backbone_inputs, action_inputs = self.prepare_input(inputs)
|
||||
backbone_outputs = self.backbone(backbone_inputs)
|
||||
return self.action_head(backbone_outputs, action_inputs)
|
||||
|
||||
def get_action(self, inputs: dict[str, Any], options: dict[str, Any] | None = None) -> BatchFeature:
|
||||
backbone_inputs, action_inputs = self.prepare_input(inputs)
|
||||
backbone_outputs = self.backbone(backbone_inputs)
|
||||
return self.action_head.get_action(backbone_outputs, action_inputs, options)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
return next(iter(self.parameters())).dtype
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
|
||||
tune_visual = kwargs.pop("tune_visual", True)
|
||||
tune_llm = kwargs.pop("tune_llm", False)
|
||||
tune_projector = kwargs.pop("tune_projector", True)
|
||||
tune_diffusion_model = kwargs.pop("tune_diffusion_model", True)
|
||||
tune_vlln = kwargs.pop("tune_vlln", True)
|
||||
transformers_loading_kwargs = kwargs.pop("transformers_loading_kwargs", None) or {
|
||||
"trust_remote_code": True
|
||||
}
|
||||
load_backbone_weights = kwargs.pop("load_backbone_weights", False)
|
||||
for key in ("cache_dir", "local_files_only", "token"):
|
||||
if key in kwargs:
|
||||
transformers_loading_kwargs.setdefault(key, kwargs[key])
|
||||
|
||||
try:
|
||||
local_model_path = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
repo_type="model",
|
||||
revision=kwargs.get("revision"),
|
||||
cache_dir=kwargs.get("cache_dir"),
|
||||
local_files_only=kwargs.get("local_files_only", False),
|
||||
token=kwargs.get("token"),
|
||||
)
|
||||
except (HFValidationError, RepositoryNotFoundError):
|
||||
local_model_path = pretrained_model_name_or_path
|
||||
|
||||
pretrained_model = super().from_pretrained(
|
||||
local_model_path,
|
||||
transformers_loading_kwargs=transformers_loading_kwargs,
|
||||
load_backbone_weights=load_backbone_weights,
|
||||
**kwargs,
|
||||
)
|
||||
pretrained_model.backbone.set_trainable_parameters(
|
||||
tune_visual=tune_visual,
|
||||
tune_llm=tune_llm,
|
||||
tune_top_llm_layers=pretrained_model.config.tune_top_llm_layers,
|
||||
)
|
||||
pretrained_model.action_head.set_trainable_parameters(
|
||||
tune_projector=tune_projector,
|
||||
tune_diffusion_model=tune_diffusion_model,
|
||||
tune_vlln=tune_vlln,
|
||||
)
|
||||
return pretrained_model
|
||||
|
||||
|
||||
def _register_with_transformers() -> None:
|
||||
if AutoConfig is None or AutoModel is None:
|
||||
return
|
||||
try:
|
||||
AutoConfig.register(GR00TN17Config.model_type, GR00TN17Config, exist_ok=True)
|
||||
except TypeError:
|
||||
with suppress(ValueError):
|
||||
AutoConfig.register(GR00TN17Config.model_type, GR00TN17Config)
|
||||
try:
|
||||
AutoModel.register(GR00TN17Config, GR00TN17, exist_ok=True)
|
||||
except TypeError:
|
||||
with suppress(ValueError):
|
||||
AutoModel.register(GR00TN17Config, GR00TN17)
|
||||
|
||||
|
||||
_register_with_transformers()
|
||||
@@ -17,13 +17,22 @@
|
||||
"""
|
||||
Groot Policy Wrapper for LeRobot Integration
|
||||
|
||||
Minimal integration that delegates to Isaac-GR00T N1.7 components where
|
||||
possible without porting their code. Dataset loading and training
|
||||
orchestration are handled by LeRobot's standard training stack.
|
||||
Minimal integration that delegates to Isaac-GR00T components where possible
|
||||
without porting their code. The intent is to:
|
||||
|
||||
- Download and load the pretrained GR00T model via GR00TN15.from_pretrained
|
||||
- Optionally align action horizon similar to gr00t_finetune.py
|
||||
- Expose predict_action via GR00T model.get_action
|
||||
- Provide a training forward that can call the GR00T model forward if batch
|
||||
structure matches.
|
||||
|
||||
Notes:
|
||||
- Dataset loading and full training orchestration is handled by Isaac-GR00T
|
||||
TrainRunner in their codebase. If you want to invoke that flow end-to-end
|
||||
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import logging
|
||||
import os
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
@@ -37,19 +46,8 @@ from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
from ..pretrained import PreTrainedPolicy
|
||||
from ..utils import get_device_from_parameters
|
||||
from .configuration_groot import (
|
||||
GROOT_N1_5,
|
||||
GROOT_N1_5_REMOVAL_GUIDANCE,
|
||||
GROOT_N1_7,
|
||||
GrootConfig,
|
||||
infer_groot_model_version,
|
||||
infer_groot_n1_7_action_execution_horizon,
|
||||
infer_groot_n1_7_action_horizon,
|
||||
normalize_groot_model_version,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from .configuration_groot import GrootConfig
|
||||
from .groot_n1 import GR00TN15
|
||||
|
||||
T = TypeVar("T", bound="GrootPolicy")
|
||||
|
||||
@@ -69,35 +67,37 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
|
||||
# Initialize GR00T model using ported components
|
||||
self._groot_model = self._create_groot_model()
|
||||
self._action_queue_steps = self._resolve_action_queue_steps()
|
||||
|
||||
self.reset()
|
||||
|
||||
def _create_groot_model(self):
|
||||
"""Create and initialize the GR00T N1.7 model using Isaac-GR00T APIs."""
|
||||
"""Create and initialize the GR00T model using Isaac-GR00T API.
|
||||
|
||||
This is only called when creating a NEW policy (not when loading from checkpoint).
|
||||
|
||||
Steps (delegating to Isaac-GR00T):
|
||||
1) Download and load pretrained model via GR00TN15.from_pretrained
|
||||
2) Align action horizon with data_config if provided
|
||||
"""
|
||||
# Handle Flash Attention compatibility issues
|
||||
self._handle_flash_attention_compatibility()
|
||||
|
||||
model_kwargs = {
|
||||
"pretrained_model_name_or_path": self.config.base_model_path,
|
||||
"tune_llm": self.config.tune_llm,
|
||||
"tune_visual": self.config.tune_visual,
|
||||
"tune_projector": self.config.tune_projector,
|
||||
"tune_diffusion_model": self.config.tune_diffusion_model,
|
||||
}
|
||||
from .groot_n1_7 import GR00TN17
|
||||
|
||||
model = GR00TN17.from_pretrained(
|
||||
**model_kwargs,
|
||||
tune_vlln=True,
|
||||
transformers_loading_kwargs={"trust_remote_code": True},
|
||||
model = GR00TN15.from_pretrained(
|
||||
pretrained_model_name_or_path=self.config.base_model_path,
|
||||
tune_llm=self.config.tune_llm,
|
||||
tune_visual=self.config.tune_visual,
|
||||
tune_projector=self.config.tune_projector,
|
||||
tune_diffusion_model=self.config.tune_diffusion_model,
|
||||
)
|
||||
|
||||
model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype
|
||||
model.config.compute_dtype = model.compute_dtype
|
||||
|
||||
return model
|
||||
|
||||
def reset(self):
|
||||
"""Reset policy state when environment resets."""
|
||||
self._action_queue = deque([], maxlen=self._action_queue_steps)
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
@@ -118,7 +118,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
"""Load Groot policy from pretrained model.
|
||||
|
||||
Handles two cases:
|
||||
1. Base GR00T N1.7 models - loads the raw model
|
||||
1. Base GR00T models (e.g., 'nvidia/GR00T-N1.5-3B') - loads the raw model
|
||||
2. Fine-tuned LeRobot checkpoints - loads config and weights from safetensors
|
||||
|
||||
Args:
|
||||
@@ -141,15 +141,9 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
requested_version = (
|
||||
normalize_groot_model_version(config.model_version)
|
||||
if config is not None
|
||||
else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
||||
)
|
||||
logger.info(
|
||||
"The Groot policy wraps NVIDIA's GR00T %s model. Loading pretrained model from: %s",
|
||||
requested_version,
|
||||
pretrained_name_or_path,
|
||||
print(
|
||||
"The Groot policy is a wrapper around Nvidia's GR00T N1.5 model.\n"
|
||||
f"Loading pretrained model from: {pretrained_name_or_path}"
|
||||
)
|
||||
|
||||
model_id = str(pretrained_name_or_path)
|
||||
@@ -180,7 +174,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
|
||||
if is_finetuned_checkpoint:
|
||||
# This is a fine-tuned LeRobot checkpoint - use parent class loading
|
||||
logger.info("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
|
||||
print("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
|
||||
return super().from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
config=config,
|
||||
@@ -196,15 +190,11 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
)
|
||||
|
||||
# This is a base GR00T model - load it fresh
|
||||
logger.info("Detected base GR00T model, loading from HuggingFace...")
|
||||
print("Detected base GR00T model, loading from HuggingFace...")
|
||||
|
||||
if config is None:
|
||||
model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
||||
# Create default config with the pretrained path
|
||||
config = GrootConfig(
|
||||
model_version=model_version,
|
||||
base_model_path=str(pretrained_name_or_path),
|
||||
)
|
||||
config = GrootConfig(base_model_path=str(pretrained_name_or_path))
|
||||
|
||||
# Add minimal visual feature required for validation
|
||||
# validate_features() will automatically add state and action features
|
||||
@@ -225,16 +215,6 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
if hasattr(config, key):
|
||||
setattr(config, key, value)
|
||||
|
||||
config.model_version = normalize_groot_model_version(config.model_version)
|
||||
inferred_version = infer_groot_model_version(config.base_model_path)
|
||||
if inferred_version is not None and inferred_version != config.model_version:
|
||||
message = (
|
||||
f"GR00T model_version '{config.model_version}' does not match base_model_path "
|
||||
f"'{config.base_model_path}', which looks like '{inferred_version}'."
|
||||
)
|
||||
if inferred_version == GROOT_N1_5:
|
||||
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||
raise ValueError(message)
|
||||
# Create a fresh policy instance - this will automatically load the GR00T model
|
||||
# in __init__ via _create_groot_model()
|
||||
policy = cls(config)
|
||||
@@ -245,160 +225,21 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
def _resolve_action_queue_steps(self) -> int:
|
||||
n_action_steps = int(self.config.n_action_steps)
|
||||
checkpoint_action_horizon = infer_groot_n1_7_action_horizon(
|
||||
self.config.base_model_path,
|
||||
self.config.embodiment_tag,
|
||||
)
|
||||
execution_horizon = infer_groot_n1_7_action_execution_horizon(
|
||||
self.config.base_model_path,
|
||||
self.config.embodiment_tag,
|
||||
)
|
||||
horizons = [n_action_steps]
|
||||
if checkpoint_action_horizon is not None:
|
||||
horizons.append(checkpoint_action_horizon)
|
||||
if execution_horizon is not None:
|
||||
horizons.append(execution_horizon)
|
||||
return min(horizons)
|
||||
|
||||
def _resolve_prediction_horizon(self, actions: Tensor) -> int:
|
||||
"""Return the policy-facing action horizon for a native GR00T prediction."""
|
||||
|
||||
horizons = [actions.shape[1]]
|
||||
checkpoint_action_horizon = infer_groot_n1_7_action_horizon(
|
||||
self.config.base_model_path,
|
||||
self.config.embodiment_tag,
|
||||
)
|
||||
if checkpoint_action_horizon is not None:
|
||||
horizons.append(checkpoint_action_horizon)
|
||||
|
||||
for horizon in (self.config.chunk_size, self.config.n_action_steps):
|
||||
horizon = int(horizon)
|
||||
if horizon > 0:
|
||||
horizons.append(horizon)
|
||||
|
||||
return max(1, min(horizons))
|
||||
|
||||
def _filter_groot_inputs(self, batch: dict[str, Tensor], *, include_action: bool) -> dict[str, Tensor]:
|
||||
allowed_base = {"state", "state_mask", "embodiment_id"}
|
||||
if include_action:
|
||||
allowed_base.update({"action", "action_mask"})
|
||||
|
||||
allowed_base.update(
|
||||
{
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"pixel_values",
|
||||
"image_grid_thw",
|
||||
"mm_token_type_ids",
|
||||
"pixel_values_videos",
|
||||
"video_grid_thw",
|
||||
}
|
||||
)
|
||||
allowed_base.add("action_mask")
|
||||
|
||||
return {
|
||||
k: v for k, v in batch.items() if k in allowed_base and not (k.startswith("next.") or k == "info")
|
||||
}
|
||||
|
||||
def _prepare_n1_7_rtc_inputs(
|
||||
self,
|
||||
inputs: dict[str, Tensor],
|
||||
*,
|
||||
inference_delay: object,
|
||||
prev_chunk_left_over: object,
|
||||
) -> tuple[dict[str, Tensor], dict[str, object] | None]:
|
||||
if prev_chunk_left_over is None:
|
||||
return inputs, None
|
||||
if not isinstance(prev_chunk_left_over, torch.Tensor):
|
||||
raise TypeError("prev_chunk_left_over must be a torch.Tensor for GR00T N1.7 RTC.")
|
||||
if prev_chunk_left_over.numel() == 0:
|
||||
return inputs, None
|
||||
|
||||
prev_actions = prev_chunk_left_over
|
||||
if prev_actions.ndim == 2:
|
||||
prev_actions = prev_actions.unsqueeze(0)
|
||||
elif prev_actions.ndim != 3:
|
||||
raise ValueError("prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC.")
|
||||
|
||||
state = inputs.get("state")
|
||||
if state is None:
|
||||
raise ValueError("GR00T N1.7 RTC requires `state` in the preprocessed batch.")
|
||||
batch_size = state.shape[0]
|
||||
if prev_actions.shape[0] == 1 and batch_size > 1:
|
||||
prev_actions = prev_actions.expand(batch_size, -1, -1).clone()
|
||||
elif prev_actions.shape[0] != batch_size:
|
||||
raise ValueError("prev_chunk_left_over batch size must match the current GR00T N1.7 batch size.")
|
||||
|
||||
# The generic LeRobot RTC engine pads short leftovers with exact zero
|
||||
# rows for fixed-shape policy calls. Native GR00T N1.7 RTC treats every
|
||||
# provided prefix row as a real action constraint, so strip that padding
|
||||
# before constructing the native overlap options.
|
||||
valid_prefix_rows = prev_actions.detach().abs().sum(dim=(0, 2)) > 0
|
||||
if valid_prefix_rows.any():
|
||||
valid_prefix_steps = int(valid_prefix_rows.nonzero()[-1].item()) + 1
|
||||
prev_actions = prev_actions[:, :valid_prefix_steps, :]
|
||||
else:
|
||||
return inputs, None
|
||||
|
||||
model_action_horizon = int(
|
||||
getattr(self._groot_model.config, "action_horizon", self.config.chunk_size)
|
||||
)
|
||||
max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim))
|
||||
if prev_actions.shape[1] > model_action_horizon:
|
||||
prev_actions = prev_actions[:, -model_action_horizon:, :]
|
||||
|
||||
action_horizon = int(prev_actions.shape[1])
|
||||
if action_horizon <= 0:
|
||||
return inputs, None
|
||||
|
||||
if prev_actions.shape[2] > max_action_dim:
|
||||
prev_actions = prev_actions[:, :, :max_action_dim]
|
||||
elif prev_actions.shape[2] < max_action_dim:
|
||||
pad = torch.zeros(
|
||||
prev_actions.shape[0],
|
||||
prev_actions.shape[1],
|
||||
max_action_dim - prev_actions.shape[2],
|
||||
dtype=prev_actions.dtype,
|
||||
device=prev_actions.device,
|
||||
)
|
||||
prev_actions = torch.cat([prev_actions, pad], dim=2)
|
||||
|
||||
prev_actions = prev_actions.to(device=state.device, dtype=state.dtype)
|
||||
|
||||
rtc_config = getattr(self.config, "rtc_config", None)
|
||||
execution_horizon = int(getattr(rtc_config, "execution_horizon", action_horizon))
|
||||
overlap_steps = max(0, min(action_horizon, execution_horizon))
|
||||
if overlap_steps == 0:
|
||||
return inputs, None
|
||||
|
||||
try:
|
||||
frozen_steps = int(inference_delay or 0)
|
||||
except (TypeError, ValueError):
|
||||
frozen_steps = 0
|
||||
frozen_steps = max(0, min(frozen_steps, overlap_steps))
|
||||
|
||||
options = {
|
||||
"action_horizon": action_horizon,
|
||||
"rtc_overlap_steps": overlap_steps,
|
||||
"rtc_frozen_steps": frozen_steps,
|
||||
"rtc_ramp_rate": float(getattr(self._groot_model.config, "rtc_ramp_rate", 6.0)),
|
||||
}
|
||||
|
||||
inputs = dict(inputs)
|
||||
inputs["action"] = prev_actions
|
||||
return inputs, options
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Training forward pass.
|
||||
|
||||
Delegates to Isaac-GR00T model.forward when inputs are compatible.
|
||||
"""
|
||||
groot_inputs = self._filter_groot_inputs(batch, include_action=True)
|
||||
# Build a clean input dict for GR00T: keep only tensors GR00T consumes
|
||||
allowed_base = {"state", "state_mask", "action", "action_mask", "embodiment_id"}
|
||||
groot_inputs = {
|
||||
k: v
|
||||
for k, v in batch.items()
|
||||
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
|
||||
}
|
||||
|
||||
# Get device from model parameters
|
||||
device = get_device_from_parameters(self)
|
||||
device = next(self.parameters()).device
|
||||
|
||||
# Run GR00T forward under bf16 autocast when enabled to reduce activation memory
|
||||
# Rationale: Matches original GR00T finetuning (bf16 compute, fp32 params) and avoids fp32 upcasts.
|
||||
@@ -407,54 +248,38 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
|
||||
# Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
|
||||
loss = outputs.get("loss")
|
||||
if loss is None:
|
||||
raise RuntimeError(
|
||||
"GR00T model.forward did not return a 'loss'. Training batches must include "
|
||||
"'action' and 'action_mask'; check the preprocessor output."
|
||||
)
|
||||
|
||||
loss_dict = {"loss": loss.item()}
|
||||
|
||||
return loss, loss_dict
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: object) -> Tensor:
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions for inference by delegating to Isaac-GR00T.
|
||||
|
||||
Returns a tensor of shape (B, n_action_steps, action_dim).
|
||||
|
||||
For N1.7, LeRobot's RTC leftovers are converted into the native GR00T
|
||||
action-overlap options before calling the underlying model.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
# Preprocessing is handled by the processor pipeline, so we just filter the batch.
|
||||
# During inference, we do not pass action because it is predicted.
|
||||
# N1.7 still carries a 2-D action horizon mask from its checkpoint processor.
|
||||
groot_inputs = self._filter_groot_inputs(batch, include_action=False)
|
||||
groot_options = None
|
||||
if self.config.model_version == GROOT_N1_7:
|
||||
groot_inputs, groot_options = self._prepare_n1_7_rtc_inputs(
|
||||
groot_inputs,
|
||||
inference_delay=kwargs.get("inference_delay"),
|
||||
prev_chunk_left_over=kwargs.get("prev_chunk_left_over"),
|
||||
)
|
||||
# Build a clean input dict for GR00T: keep only tensors GR00T consumes
|
||||
# Preprocessing is handled by the processor pipeline, so we just filter the batch
|
||||
# NOTE: During inference, we should NOT pass action/action_mask (that's what we're predicting)
|
||||
allowed_base = {"state", "state_mask", "embodiment_id"}
|
||||
groot_inputs = {
|
||||
k: v
|
||||
for k, v in batch.items()
|
||||
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
|
||||
}
|
||||
|
||||
# Get device from model parameters
|
||||
device = get_device_from_parameters(self)
|
||||
device = next(self.parameters()).device
|
||||
|
||||
# Use bf16 autocast for inference to keep memory low and match backbone dtype
|
||||
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
|
||||
if groot_options is not None:
|
||||
outputs = self._groot_model.get_action(groot_inputs, options=groot_options)
|
||||
else:
|
||||
outputs = self._groot_model.get_action(groot_inputs)
|
||||
outputs = self._groot_model.get_action(groot_inputs)
|
||||
|
||||
actions = outputs.get("action_pred")
|
||||
|
||||
prediction_horizon = self._resolve_prediction_horizon(actions)
|
||||
actions = actions[:, :prediction_horizon]
|
||||
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
@@ -467,28 +292,40 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
actions = self.predict_action_chunk(batch)
|
||||
self._action_queue.extend(actions[:, : self._action_queue_steps].transpose(0, 1))
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
# -------------------------
|
||||
# Internal helpers
|
||||
# -------------------------
|
||||
def _handle_flash_attention_compatibility(self) -> None:
|
||||
"""Log Flash Attention availability (diagnostic only).
|
||||
"""Handle Flash Attention compatibility issues by setting environment variables.
|
||||
|
||||
The GR00T N1.7 backbone automatically falls back to SDPA when ``flash_attn`` is
|
||||
unavailable (see ``Qwen3Backbone``), so this probe only emits a hint; it does not
|
||||
change behaviour or mutate global state.
|
||||
This addresses the common 'undefined symbol' error that occurs when Flash Attention
|
||||
is compiled against a different PyTorch version than what's currently installed.
|
||||
"""
|
||||
|
||||
# Set environment variables to handle Flash Attention compatibility
|
||||
# These help with symbol resolution issues
|
||||
os.environ.setdefault("FLASH_ATTENTION_FORCE_BUILD", "0")
|
||||
os.environ.setdefault("FLASH_ATTENTION_SKIP_CUDA_BUILD", "0")
|
||||
|
||||
# Try to import flash_attn and handle failures gracefully
|
||||
try:
|
||||
import flash_attn
|
||||
|
||||
logger.debug("Flash Attention %s is available.", flash_attn.__version__)
|
||||
except ImportError:
|
||||
logger.debug("Flash Attention is not installed; the GR00T backbone will use SDPA.")
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning(
|
||||
"Flash Attention failed to import (%s); the GR00T backbone will use SDPA. If this is "
|
||||
"an 'undefined symbol' error, reinstall a flash-attn build matching your torch version.",
|
||||
e,
|
||||
)
|
||||
print(f"[GROOT] Flash Attention version: {flash_attn.__version__}")
|
||||
except ImportError as e:
|
||||
print(f"[GROOT] Flash Attention not available: {e}")
|
||||
print("[GROOT] Will use fallback attention mechanism")
|
||||
except Exception as e:
|
||||
if "undefined symbol" in str(e):
|
||||
print(f"[GROOT] Flash Attention compatibility issue detected: {e}")
|
||||
print("[GROOT] This is likely due to PyTorch/Flash Attention version mismatch")
|
||||
print("[GROOT] Consider reinstalling Flash Attention with compatible version:")
|
||||
print(" pip uninstall flash-attn")
|
||||
print(" pip install --no-build-isolation flash-attn==2.6.3")
|
||||
print("[GROOT] Continuing with fallback attention mechanism")
|
||||
else:
|
||||
print(f"[GROOT] Flash Attention error: {e}")
|
||||
print("[GROOT] Continuing with fallback attention mechanism")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,47 @@
|
||||
from pathlib import Path
|
||||
from shutil import copytree
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
|
||||
def ensure_eagle_cache_ready(vendor_dir: Path, cache_dir: Path, assets_repo: str) -> None:
|
||||
"""Populate the Eagle processor directory in cache and ensure tokenizer assets exist.
|
||||
|
||||
- Copies the vendored Eagle files into cache_dir (overwriting when needed).
|
||||
- Downloads vocab.json and merges.txt into the same cache_dir if missing.
|
||||
"""
|
||||
cache_dir = Path(cache_dir)
|
||||
vendor_dir = Path(vendor_dir)
|
||||
|
||||
try:
|
||||
# Populate/refresh cache with vendor files to ensure a complete processor directory
|
||||
print(f"[GROOT] Copying vendor Eagle files to cache: {vendor_dir} -> {cache_dir}")
|
||||
copytree(vendor_dir, cache_dir, dirs_exist_ok=True)
|
||||
except Exception as exc: # nosec: B110
|
||||
print(f"[GROOT] Warning: Failed to copy vendor Eagle files to cache: {exc}")
|
||||
|
||||
required_assets = [
|
||||
"vocab.json",
|
||||
"merges.txt",
|
||||
"added_tokens.json",
|
||||
"chat_template.json",
|
||||
"special_tokens_map.json",
|
||||
"config.json",
|
||||
"generation_config.json",
|
||||
"preprocessor_config.json",
|
||||
"processor_config.json",
|
||||
"tokenizer_config.json",
|
||||
]
|
||||
|
||||
print(f"[GROOT] Assets repo: {assets_repo} \n Cache dir: {cache_dir}")
|
||||
|
||||
for fname in required_assets:
|
||||
dst = cache_dir / fname
|
||||
if not dst.exists():
|
||||
print(f"[GROOT] Fetching {fname}")
|
||||
hf_hub_download(
|
||||
repo_id=assets_repo,
|
||||
filename=fname,
|
||||
repo_type="model",
|
||||
local_dir=str(cache_dir),
|
||||
)
|
||||
@@ -0,0 +1,286 @@
|
||||
# Remote Inference Architecture
|
||||
|
||||
How `lerobot-policy-server` and `lerobot-rollout --inference.type=remote` decouple GPU-bound policy inference from high-frequency robot control over Zenoh.
|
||||
|
||||
This document explains the **internals** — the wire protocol, threading models, state machines, and safety invariants. For the user-facing guide (CLI quickstarts, deployment), see [`docs/source/remote_inference.mdx`](../../../docs/source/remote_inference.mdx).
|
||||
|
||||
## 1. The problem and the shape of the solution
|
||||
|
||||
Running a large policy (Pi0-class, ~150 ms inference) inside a 33 ms control loop doesn't work, and putting a GPU next to every robot doesn't scale. LeRobot already solved the _local_ version of this problem: `RTCInferenceEngine` runs inference in a background **thread** that fills a thread-safe `ActionQueue`, while the control loop pops one action per tick.
|
||||
|
||||
Remote inference is **that same architecture with the thread boundary replaced by a network boundary**:
|
||||
|
||||
```
|
||||
local RTC: control loop ──ActionQueue── inference thread (same process, same GPU)
|
||||
remote: control loop ──ActionQueue── network worker ══zenoh══ policy server (GPU, elsewhere)
|
||||
```
|
||||
|
||||
Three design commitments follow from this:
|
||||
|
||||
- **The client is a backend, not a CLI.** `RemoteInferenceEngine` plugs into the existing `InferenceEngine` seam (`rollout/inference/base.py`), so every rollout strategy (base, sentry, highlight, dagger, episodic) gets network inference — including dataset recording, pause/resume, and safe teardown — without changing a line.
|
||||
- **The client is weightless.** No policy weights, no policy processors on the edge. `--policy.path` resolves to a config-only `PreTrainedConfig` used for pre-flight validation and action ordering.
|
||||
- **The server is stateless per request.** All chunk state (RTC prefixes, latency tracking, delay computation) lives client-side in the existing `ActionQueue`/`LatencyTracker`. The client ships prefixes + a delay hint with every observation, so a server crash loses zero control state.
|
||||
|
||||
## 2. Component map
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
subgraph EDGE["Edge (per robot, weightless)"]
|
||||
R[Robot HW] --> S["Rollout strategy<br/>(sentry / dagger / ...)"]
|
||||
S -->|"notify_observation()"| E[RemoteInferenceEngine]
|
||||
E -->|"get_action()"| S
|
||||
S -->|actions| R
|
||||
E --- AQ[("ActionQueue<br/>(chunk buffer)")]
|
||||
end
|
||||
|
||||
subgraph NET["Transport"]
|
||||
Z["zenohd router(s)<br/>(robots dial out, mTLS + ACL)"]
|
||||
end
|
||||
|
||||
subgraph GPU["GPU pod (one model · one device · one process)"]
|
||||
PS[PolicyServer]
|
||||
PS --- SR["SessionRegistry<br/>(per-client mailboxes + pipelines)"]
|
||||
PS --- W["Inference worker<br/>(1 thread, owns GPU)"]
|
||||
W --- P["PreTrainedPolicy<br/>(pre-warmed)"]
|
||||
end
|
||||
|
||||
E <-->|"obs ↑ / chunks ↓ (pub/sub)<br/>status · session · reset (queryables)"| Z
|
||||
Z <--> PS
|
||||
```
|
||||
|
||||
One server process = one pre-warmed `(model, revision, dtype, device)` serving up to `max_sessions` robots. Scaling out = more pods; clients rejected with the current load retry another replica.
|
||||
|
||||
## 3. Where the network cut goes
|
||||
|
||||
The local RTC pipeline is split at the cheapest, most hardware-coupled point. Everything policy-coupled (resize, normalize, tokenize) runs server-side with the **canonical training-time processors**, so serve-time preprocessing is byte-identical to train-time:
|
||||
|
||||
```
|
||||
robot obs (processed dict)
|
||||
→ build_dataset_frame(...) CLIENT cheap, hardware-coupled
|
||||
→ rename_map applied to keys CLIENT wire format = canonical policy keys
|
||||
══════════════════════ network (msgpack + JPEG) ══════════════════════
|
||||
→ prepare_observation_for_inference(...) SERVER tensors, batch dim, device
|
||||
→ per-session preprocessor(...) SERVER stateful within the request
|
||||
→ policy.predict_action_chunk(obs, delay, prefix) SERVER pure for allowlisted policies
|
||||
→ per-session postprocessor(...) SERVER reads state cached at preprocess
|
||||
══════════════════════ network ══════════════════════
|
||||
→ ActionQueue.merge(original, processed, delay, idx_before) CLIENT
|
||||
```
|
||||
|
||||
The reply carries **both** the model-space (`chunk_model`) and robot-space (`chunk_robot`) chunks because `ActionQueue.merge` needs both, and the next request's relative-action prefix re-anchoring needs the robot-space tail.
|
||||
|
||||
## 4. Wire protocol
|
||||
|
||||
### 4.1 Key-expression schema (Zenoh)
|
||||
|
||||
```
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/<client_uuid>/obs client → server pub/sub
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/<client_uuid>/action server → client pub/sub
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/status queryable (capabilities)
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/session queryable (open / close)
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/<client_uuid>/reset queryable (episode boundary)
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/<client_uuid>/alive liveliness token (client)
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/server/alive liveliness token (server)
|
||||
```
|
||||
|
||||
`@lerobot` is a **verbatim chunk**: wildcards never match it, so third-party `**` subscribers on a shared router cannot scrape the tree. User-supplied segments are sanitized (`sanitize_key_segment`), and the server subscribes with single-depth wildcards only (`.../*/obs`, never `**`).
|
||||
|
||||
Data plane = pub/sub (a late chunk is still usable; a timed-out query reply is not). Control plane = queryables with explicit timeouts (the rmw_zenoh pattern). QoS (`zenoh_utils.py`): actions are `RELIABLE + congestion DROP + express + INTERACTIVE_HIGH` — **never BLOCK**, so one dead robot uplink can never stall the server's publish path; a dropped chunk is recoverable because the client buffer keeps the robot moving.
|
||||
|
||||
### 4.2 Messages
|
||||
|
||||
Every data-plane message carries a **packed little-endian attachment header** (27 bytes, parsed without touching the body):
|
||||
|
||||
| field | type | meaning |
|
||||
| ---------------- | ---- | --------------------------------------------------------- |
|
||||
| `schema_version` | u16 | negotiated at session open; additive-only body evolution |
|
||||
| `msg_type` | u8 | OBS / CHUNK / EVENT |
|
||||
| `seq_id` | u64 | per-session monotonic; echoed in the chunk |
|
||||
| `episode_id` | u32 | bumped by `reset()` |
|
||||
| `client_mono_ns` | i64 | client monotonic clock — **opaque to the server, echoed** |
|
||||
| `session_epoch` | u32 | bumped per (re)connect; stale-epoch chunks dropped |
|
||||
|
||||
Bodies are msgpack (`codec.py`): tensors as raw little-endian bytes + dtype + shape, images JPEG (RGB convention enforced inside the codec; `jpeg_quality=0` = raw). No pickle anywhere — nothing on the wire can carry code.
|
||||
|
||||
**Clock iron rule:** wall-clock instants never cross machines. The client computes RTT from its own monotonic clock via the echoed `client_mono_ns`; the server reports only **durations** (`queue_wait_ms`, `inference_ms`).
|
||||
|
||||
### 4.3 Session lifecycle
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant C as RemoteInferenceEngine
|
||||
participant S as PolicyServer
|
||||
|
||||
C->>S: GET status (timeout 2s)
|
||||
S-->>C: capabilities (model, action_names, cameras, chunk_size, supports_rtc, ...)
|
||||
C->>S: GET session {op: open, action_names, cameras, state_dim, fps, rtc, task}
|
||||
Note over S: validate (hard: action name ORDER,<br/>cameras, state_dim, schema, capacity)
|
||||
S-->>C: SessionAck {session_id, warnings, rtc_execution_horizon, ...}
|
||||
Note over C,S: both declare liveliness tokens
|
||||
|
||||
loop self-clocked by buffer_time_s (one-in-flight)
|
||||
C->>S: PUB obs {state, images, delay_steps, prefix_model, prefix_robot} + header
|
||||
Note over S: latest-only mailbox → worker →<br/>preprocess → predict_action_chunk → postprocess
|
||||
S-->>C: PUB chunk {chunk_model, chunk_robot, durations} + echoed header
|
||||
Note over C: validate (episode, epoch) → ActionQueue.merge(..., idx_before)
|
||||
end
|
||||
|
||||
C->>S: GET reset {episode_id} (episode boundary, acked)
|
||||
C->>S: GET session {op: close} (graceful stop)
|
||||
```
|
||||
|
||||
The **action-name order check is a hard reject**: it is the contract that maps chunk columns to motors. A mismatch means wrong-joint commands, so the session never opens.
|
||||
|
||||
## 5. The client: `RemoteInferenceEngine`
|
||||
|
||||
File: `src/lerobot/rollout/inference/remote.py`, registered as `--inference.type=remote` (`RemoteInferenceConfig` in `factory.py`).
|
||||
|
||||
### 5.1 Threading model
|
||||
|
||||
| thread | role |
|
||||
| ---------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| main (strategy loop) | `notify_observation()` → latest-only slot; `get_action()` → `ActionQueue.get()` + staleness check + fallback. **Never any I/O.** |
|
||||
| network worker (1) | gate on `buffer_time_s` → snapshot `(seq, episode, epoch)` then `idx_before` + RTC prefixes → publish obs → await chunk (timeout) → revalidate → merge. Owns the state machine and reconnects. |
|
||||
| zenoh callback threads | deposit-only: chunk → bounded queue; server liveliness → event. |
|
||||
|
||||
**One-in-flight is a correctness requirement, not a tuning choice.** `merge(..., idx_before)` validates against the consumption index snapshotted at send time; two in-flight requests would carry conflicting snapshots and corrupt both RTC-replace and append modes. The worker therefore publishes one observation, waits for its chunk (or timeout), then sends the next. A late chunk is accepted only if it answers the latest outstanding `seq_id` _and_ the current `(episode, epoch)`.
|
||||
|
||||
### 5.2 The request cycle
|
||||
|
||||
```
|
||||
queue playback ≤ buffer_time_s? (self-clocking: ~1–4 Hz, not the 30 Hz control rate)
|
||||
├─ snapshot (seq, episode, epoch)
|
||||
├─ snapshot idx_before, prefix_model = queue.get_left_over()[:H],
|
||||
│ prefix_robot = queue.get_processed_left_over()[:H]
|
||||
├─ revalidate (episode, epoch) unchanged ← a reset racing the snapshot skips the cycle
|
||||
├─ delay_steps = ceil(LatencyTracker.max() / dt)
|
||||
├─ publish obs + header
|
||||
├─ await chunk (request_timeout_s)
|
||||
├─ revalidate (episode, epoch) under _anchor_lock ← a stale chunk can never survive a reset
|
||||
└─ merge(chunk_model, chunk_robot, ceil(measured_latency/dt), idx_before); update anchor
|
||||
```
|
||||
|
||||
Because the `LatencyTracker` samples are full network-inclusive cycle times, RTT compensation falls out for free — the same `delay`-trimming machinery local RTC uses absorbs network latency as just more delay.
|
||||
|
||||
### 5.3 Fail-safe state machine
|
||||
|
||||
```mermaid
|
||||
stateDiagram-v2
|
||||
[*] --> CONNECTING
|
||||
CONNECTING --> STREAMING: first merge
|
||||
STREAMING --> DEGRADED: request timeouts,<br/>queue still has actions
|
||||
DEGRADED --> STREAMING: merge
|
||||
DEGRADED --> STALLED: queue empty or<br/>max_action_age_s hit
|
||||
STALLED --> RECONNECTING: timeout streak /<br/>server liveliness drop
|
||||
DEGRADED --> RECONNECTING: timeout streak /<br/>server liveliness drop
|
||||
RECONNECTING --> STREAMING: re-handshake OK<br/>(epoch++)
|
||||
RECONNECTING --> DEAD: offline > max_offline_s,<br/>capability/model mismatch
|
||||
DEAD --> [*]: failed=True → shutdown_event<br/>→ strategy teardown
|
||||
```
|
||||
|
||||
- **DEGRADED**: the chunk buffer _is_ the fault tolerance — 1–3 s of buffered actions makes network blips and clean server drains invisible to the robot.
|
||||
- **Staleness bound** (`max_action_age_s`): `get_action` refuses any action whose source observation is too old, bounding open-loop execution after a stall. Then the **fallback ladder** applies: `hold` (return `None`; the robot holds), `repeat_last`, or `zero` (the safe stop for velocity-controlled robots).
|
||||
- **Watchdog layering**: per-request timeout (catches a _hung-but-connected_ server) → server liveliness token (catches a dead server/router) → staleness bound (the robot-side invariant that holds regardless of why data stopped).
|
||||
- **DEAD** is reserved for hard failures: offline beyond `max_offline_s` with no successful merge (a server that handshakes but never delivers chunks still runs out of budget), or a contract violation on reconnect (model/revision changed, RTC capability flipped — never execute wrong-model chunks). It triggers the exact mechanism local RTC uses: `failed=True` + the global `shutdown_event`, so the existing teardown (return-to-initial-pose) runs unchanged.
|
||||
- **Pause/resume** (DAgger): `pause()` stops publishing; the queue stays intact. A pause during an outage freezes the offline budget so a human correction can never be aborted by `max_offline_s`.
|
||||
|
||||
### 5.4 Episode boundaries
|
||||
|
||||
`reset()` (control thread) atomically — under the same lock the merge path takes — clears the `ActionQueue`, nulls the staleness anchor, bumps `episode_id`, and invalidates the observation slot (the previous episode's final frame must not seed the new one). The worker sends an acked `reset` query, and the next observation header carries the new `episode_id` anyway — so a lost ack costs nothing (the server is stateless per request).
|
||||
|
||||
## 6. The server: `PolicyServer`
|
||||
|
||||
Files: `src/lerobot/policy_server/`. Entry point: `lerobot-policy-server --manifest server.yaml` (draccus dataclasses in `manifest.py`).
|
||||
|
||||
### 6.1 Concurrency model
|
||||
|
||||
zenoh-python is thread-based (no asyncio); callbacks must be deposit-only:
|
||||
|
||||
```
|
||||
zenoh subscriber (.../*/obs) inference worker (1 thread, owns GPU)
|
||||
deposit-only callback: loop:
|
||||
session.deposit(header, body) ──► scheduler picks next session with pending obs
|
||||
(per-client latest-only mailbox) decode → episode-boundary check
|
||||
preprocess → predict_action_chunk(delay, prefix)
|
||||
control queryables (status / postprocess → encode
|
||||
session / reset): validate, publisher.put(.../<uuid>/action)
|
||||
mutate registry, reply inline
|
||||
|
||||
liveliness subscriber (.../*/alive): mark sessions for GC on token DELETE
|
||||
```
|
||||
|
||||
- **Latest-only mailboxes**: the newest observation wins; superseded requests are counted and reported in the next reply (`superseded_seqs`), so drops are visible client-side. The client decides _when_ to request; the server never second-guesses observation content.
|
||||
- **Single inference worker** + round-robin over ready sessions: every ready session gets exactly one inference per cycle — starvation is structurally impossible. Overload degrades into longer cycle times → larger (but correct) client `delay_steps` → eventually the client staleness bound trips and the robot holds. Safe by construction.
|
||||
- The `Scheduler` seam (`scheduler.py`) exists so cross-session micro-batching can land later without redesign (blocked today on `predict_action_chunk` taking a _scalar_ `inference_delay`).
|
||||
- `_inference_lock` serializes the worker's predict path against episode resets arriving on queryable threads (in exclusive mode a `policy.reset()` mid-predict would corrupt the in-flight request).
|
||||
|
||||
### 6.2 Multi-tenancy: engineered, not assumed
|
||||
|
||||
Sharing one policy instance across sessions is only safe when `predict_action_chunk` touches no cross-request instance state. That property is **verified per family and encoded as a registry** (`validation.py`) — never inferred:
|
||||
|
||||
| class | policies | mode | why |
|
||||
| --------------- | ------------------------------------------------- | ----------- | ----------------------------------------------------------------------------------- |
|
||||
| chunk-stateless | `act`, `pi0`, `pi05`, `smolvla` (`n_obs_steps=1`) | `shared` | chunk call is pure (smolvla overwrites its 1-deep queue with the request's own obs) |
|
||||
| chunk-stateful | `diffusion` (and `smolvla` with `n_obs_steps>1`) | `exclusive` | chunk call reads `select_action`-fed `_queues` → server populates them per request |
|
||||
| no chunk API | `sac`, `tdmpc`, ... (no `predict_action_chunk`) | refused | nothing to serve |
|
||||
| unverified | any other chunk-API policy | `exclusive` | a manifest can force `exclusive`, but never `shared` for an unverified policy |
|
||||
|
||||
The real multi-tenancy hazard is **processor state**, not just policy purity: `RelativeActionsProcessorStep` caches `_last_state` at preprocess and the postprocessor reads it back. The server therefore builds a **fresh pre/post pipeline pair per session** — two robots at different joint positions can never cross-contaminate each other's action conversions. `policy.reset()` is **never** called in shared mode (it is global to the shared instance).
|
||||
|
||||
### 6.3 Statelessness and the RTC prefix
|
||||
|
||||
The server holds no cross-request control state. Each observation ships everything inference needs:
|
||||
|
||||
- `inference_delay_steps` — computed client-side from network-inclusive latency.
|
||||
- `prefix_model` — the unexecuted tail of the previous chunk in model space (feeds `prev_chunk_left_over`).
|
||||
- `prefix_robot` — the same tail in robot space. For relative-action policies the server **re-anchors** it against the state cached by _this request's_ preprocess (`reanchor_relative_rtc_prefix`, mirroring `rtc.py`), so the prefix is expressed relative to where the robot actually is now.
|
||||
|
||||
Consequences: reconnects are trivial, horizontal scaling is trivial, and a `kill -9` on the server loses nothing the client can't re-send.
|
||||
|
||||
### 6.4 Episode and reconnect hygiene
|
||||
|
||||
- Fresh sessions start at the `episode_id = -1` sentinel: the **first** observation of any session always triggers the boundary branch (pipelines reset; exclusive policies `reset()`), so a mid-episode reconnect can never inherit stale state.
|
||||
- Session replacement is identity-checked (`SessionRegistry.remove(expected=...)`): a GC sweep that snapshotted an old session can never tear down its just-handshaked replacement.
|
||||
- Liveliness GC double-checks with an explicit liveliness `get` before closing: the token key is per-client (not per-epoch), so a _late_ DELETE from a previous incarnation must not kill the live session.
|
||||
- Drain (`SIGTERM`): drop the liveliness token first (clients ride their buffers), finish the in-flight inference, undeclare the control surface, then close. Clients reconnect to another replica invisibly.
|
||||
|
||||
## 7. Latency budget (why the transport is never the bottleneck)
|
||||
|
||||
| stage | LAN | WAN (50 ms RTT) |
|
||||
| ------------------------------ | ------------- | --------------- |
|
||||
| JPEG encode + serialize (edge) | 2–9 ms | 2–9 ms |
|
||||
| uplink | ~2 ms | ~54 ms |
|
||||
| decode + canonical preprocess | 4–10 ms | 4–10 ms |
|
||||
| **inference** | **15–150 ms** | **15–150 ms** |
|
||||
| postprocess + downlink + merge | ~2 ms | ~27 ms |
|
||||
|
||||
Inference dominates (60–85% on LAN). At 30 fps a WAN deployment lands `delay_steps ≈ 4–8`, comfortably inside RTC execution horizons: WAN degrades smoothness parameters, never correctness. Requests are self-clocked by `buffer_time_s` to ~1–4 Hz per robot, so 300 robots cost ~0.3–10 Mbps each.
|
||||
|
||||
Capacity per GPU: `N_max ≈ 0.8 / (request_rate × inference_time)` → ~40 ACT-class or ~5 Pi0-class clients; `max_sessions` enforces it at session open (rejected clients receive the current load and retry another replica).
|
||||
|
||||
## 8. Observability & reproducibility
|
||||
|
||||
The contract is **fully logged + replayable**, not "deterministic" (no seed controls hardware or network jitter):
|
||||
|
||||
- **Client = source of truth**: recording strategies persist observations + executed actions as usual; the engine tracks `(session_id, seq_id, episode_id)` and per-cycle stats.
|
||||
- **Server**: one JSON audit line per request on the `lerobot.policy_server.audit` logger — `{session_id, client_uuid, seq_id, episode_id, queue_wait_ms, inference_ms, superseded, outcome}` — plus `/healthz` and Prometheus-style `/metrics`, and an optional bounded raw request/response capture (`debug.capture_dir`) for byte-exact offline replay.
|
||||
- Every hop shares `(session_id, seq_id)`, so joining a robot-side stutter to a server-side cause is mechanical.
|
||||
|
||||
## 9. File map
|
||||
|
||||
| path | contents |
|
||||
| ---------------------------------- | ---------------------------------------------------------------------------------------------------- |
|
||||
| `policy_server/schema.py` | wire messages, packed header, key-expression schema + sanitizer |
|
||||
| `policy_server/codec.py` | msgpack bodies, tensor codec (LE bytes), JPEG image codec (RGB convention) |
|
||||
| `policy_server/manifest.py` | draccus config: model, zenoh endpoints/TLS, serving mode, capacity, RTC, health |
|
||||
| `policy_server/validation.py` | serving-mode registry + session-open capability matrix |
|
||||
| `policy_server/session.py` | per-client `Session` (pipelines, latest-only mailbox, stats) + identity-safe registry |
|
||||
| `policy_server/scheduler.py` | `Scheduler` seam; `RoundRobinScheduler` |
|
||||
| `policy_server/zenoh_utils.py` | config builder, QoS profiles, lazy import with install hint |
|
||||
| `policy_server/server.py` | `PolicyServer`: zenoh surface, inference worker, GC, warmup, drain, health/metrics |
|
||||
| `rollout/inference/remote.py` | `RemoteInferenceEngine` (the edge client) |
|
||||
| `rollout/inference/factory.py` | `RemoteInferenceConfig`, `FallbackMode`, factory dispatch |
|
||||
| `scripts/lerobot_policy_server.py` | console entry point (`--manifest` → draccus `--config_path`) |
|
||||
| `tests/policy_server/` | codec/schema/validation/scheduler/session units, server logic, zenoh loopback + chaos, golden parity |
|
||||
|
||||
The golden parity test (`tests/policy_server/test_golden_parity.py`) is the standing contract: the remote request path (encode → decode → `run_inference_request` → encode → decode → merge) must produce **byte-identical** action queues to the local RTC compute path on identical inputs.
|
||||
@@ -0,0 +1,53 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Multi-client GPU policy serving over Zenoh (``lerobot-policy-server``).
|
||||
|
||||
The wire schema (:mod:`.schema`) and codecs (:mod:`.codec`) are shared
|
||||
with the edge-side :class:`~lerobot.rollout.inference.remote.RemoteInferenceEngine`.
|
||||
Heavy/optional imports (msgpack, zenoh, torch server) are deferred so the
|
||||
schema stays importable without the ``async`` extra.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .manifest import (
|
||||
DebugSpec,
|
||||
ModelSpec,
|
||||
PolicyServerManifest,
|
||||
ZenohSpec,
|
||||
)
|
||||
from .schema import SCHEMA_VERSION, MsgHeader, service_prefix
|
||||
|
||||
__all__ = [
|
||||
"SCHEMA_VERSION",
|
||||
"DebugSpec",
|
||||
"ModelSpec",
|
||||
"MsgHeader",
|
||||
"PolicyServer",
|
||||
"PolicyServerManifest",
|
||||
"ZenohSpec",
|
||||
"codec",
|
||||
"service_prefix",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
import importlib
|
||||
|
||||
if name == "PolicyServer":
|
||||
return importlib.import_module(".server", __name__).PolicyServer
|
||||
if name == "codec":
|
||||
return importlib.import_module(".codec", __name__)
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
@@ -0,0 +1,262 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""MessagePack codecs for the remote-inference wire schema.
|
||||
|
||||
Encoding rules:
|
||||
- Tensors are raw little-endian bytes + dtype + shape (msgpack's ``bin``
|
||||
type), so decoding is a zero-parse ``np.frombuffer``.
|
||||
- Images are JPEG by default (``jpeg_quality=0`` sends raw bytes). The
|
||||
in-memory convention on both ends is **RGB** uint8 HWC; the OpenCV
|
||||
BGR↔RGB conversion happens inside this module only.
|
||||
- Decoders are tolerant: unknown keys are ignored, missing optional keys
|
||||
take dataclass defaults — schema evolution is additive-only.
|
||||
- No pickle anywhere: nothing in this codec can carry code.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import msgpack
|
||||
except ImportError as e: # pragma: no cover
|
||||
raise ImportError(
|
||||
"Remote inference requires the 'async' extra: pip install 'lerobot[async]' (eclipse-zenoh + msgpack)"
|
||||
) from e
|
||||
|
||||
from .schema import (
|
||||
IMAGE_CODEC_JPEG,
|
||||
IMAGE_CODEC_RAW,
|
||||
ActionChunkMsg,
|
||||
ObservationMsg,
|
||||
ResetAckMsg,
|
||||
ResetMsg,
|
||||
SessionAckMsg,
|
||||
SessionCloseMsg,
|
||||
SessionOpenMsg,
|
||||
StatusMsg,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tensor codec
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _to_little_endian(arr: np.ndarray) -> np.ndarray:
|
||||
if arr.dtype.byteorder == ">":
|
||||
arr = arr.astype(arr.dtype.newbyteorder("<"))
|
||||
return np.ascontiguousarray(arr)
|
||||
|
||||
|
||||
def encode_tensor(arr: np.ndarray | None) -> dict[str, Any] | None:
|
||||
"""Encode an ndarray as raw little-endian bytes + dtype + shape."""
|
||||
if arr is None:
|
||||
return None
|
||||
arr = np.asarray(arr)
|
||||
# Record the shape before ascontiguousarray, which promotes 0-d to 1-d.
|
||||
shape = list(arr.shape)
|
||||
arr = _to_little_endian(arr)
|
||||
return {"dtype": arr.dtype.str, "shape": shape, "data": arr.tobytes()}
|
||||
|
||||
|
||||
def decode_tensor(obj: dict[str, Any] | None) -> np.ndarray | None:
|
||||
if obj is None:
|
||||
return None
|
||||
dtype = np.dtype(obj["dtype"])
|
||||
if dtype.hasobject:
|
||||
raise ValueError(f"Refusing object dtype {dtype} on the wire")
|
||||
arr = np.frombuffer(obj["data"], dtype=dtype).reshape(obj["shape"])
|
||||
# frombuffer returns a read-only view; copy so downstream torch.from_numpy works.
|
||||
return arr.copy()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image codec (RGB uint8 HWC on both ends)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def encode_image(img: np.ndarray, jpeg_quality: int = 90) -> dict[str, Any]:
|
||||
"""Encode an RGB uint8 HWC image; ``jpeg_quality=0`` keeps it raw."""
|
||||
img = np.asarray(img)
|
||||
if img.dtype != np.uint8 or img.ndim != 3 or img.shape[2] != 3:
|
||||
raise ValueError(f"Expected uint8 HWC RGB image, got dtype={img.dtype} shape={img.shape}")
|
||||
if jpeg_quality <= 0:
|
||||
return {"codec": IMAGE_CODEC_RAW, "shape": list(img.shape), "data": _to_little_endian(img).tobytes()}
|
||||
ok, buf = cv2.imencode(
|
||||
".jpg", cv2.cvtColor(img, cv2.COLOR_RGB2BGR), [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_quality)]
|
||||
)
|
||||
if not ok:
|
||||
raise ValueError("JPEG encoding failed")
|
||||
return {"codec": IMAGE_CODEC_JPEG, "data": buf.tobytes()}
|
||||
|
||||
|
||||
def decode_image(obj: dict[str, Any]) -> np.ndarray:
|
||||
"""Decode to an RGB uint8 HWC image."""
|
||||
codec = obj.get("codec", IMAGE_CODEC_JPEG)
|
||||
if codec == IMAGE_CODEC_RAW:
|
||||
return np.frombuffer(obj["data"], dtype=np.uint8).reshape(obj["shape"]).copy()
|
||||
if codec == IMAGE_CODEC_JPEG:
|
||||
bgr = cv2.imdecode(np.frombuffer(obj["data"], dtype=np.uint8), cv2.IMREAD_COLOR)
|
||||
if bgr is None:
|
||||
raise ValueError("JPEG decoding failed")
|
||||
return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
||||
raise ValueError(f"Unknown image codec: {codec!r}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# msgpack helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _packb(obj: dict[str, Any]) -> bytes:
|
||||
return msgpack.packb(obj, use_bin_type=True)
|
||||
|
||||
|
||||
def _unpackb(data: bytes) -> dict[str, Any]:
|
||||
return msgpack.unpackb(data, raw=False)
|
||||
|
||||
|
||||
def decode_raw(data: bytes) -> dict[str, Any]:
|
||||
"""Decode a body to a plain dict (e.g. to peek a control-plane ``op``)."""
|
||||
return _unpackb(data)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data-plane messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def encode_observation(msg: ObservationMsg) -> bytes:
|
||||
return _packb(
|
||||
{
|
||||
"state": encode_tensor(msg.state),
|
||||
"images": {name: encode_image(img, msg.jpeg_quality) for name, img in msg.images.items()},
|
||||
"task": msg.task,
|
||||
"inference_delay_steps": int(msg.inference_delay_steps),
|
||||
"prefix_model": encode_tensor(msg.prefix_model),
|
||||
"prefix_robot": encode_tensor(msg.prefix_robot),
|
||||
"episode_start": bool(msg.episode_start),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def decode_observation(data: bytes) -> ObservationMsg:
|
||||
obj = _unpackb(data)
|
||||
return ObservationMsg(
|
||||
state=decode_tensor(obj.get("state")),
|
||||
images={name: decode_image(img) for name, img in obj.get("images", {}).items()},
|
||||
task=obj.get("task", ""),
|
||||
inference_delay_steps=obj.get("inference_delay_steps", 0),
|
||||
prefix_model=decode_tensor(obj.get("prefix_model")),
|
||||
prefix_robot=decode_tensor(obj.get("prefix_robot")),
|
||||
episode_start=obj.get("episode_start", False),
|
||||
)
|
||||
|
||||
|
||||
def encode_action_chunk(msg: ActionChunkMsg) -> bytes:
|
||||
return _packb(
|
||||
{
|
||||
"seq_id_echo": int(msg.seq_id_echo),
|
||||
"client_mono_ns_echo": int(msg.client_mono_ns_echo),
|
||||
"episode_id_echo": int(msg.episode_id_echo),
|
||||
"chunk_model": encode_tensor(msg.chunk_model),
|
||||
"chunk_robot": encode_tensor(msg.chunk_robot),
|
||||
"queue_wait_ms": float(msg.queue_wait_ms),
|
||||
"inference_ms": float(msg.inference_ms),
|
||||
"superseded_seqs": int(msg.superseded_seqs),
|
||||
"server_load": float(msg.server_load),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def decode_action_chunk(data: bytes) -> ActionChunkMsg:
|
||||
obj = _unpackb(data)
|
||||
return ActionChunkMsg(
|
||||
seq_id_echo=obj.get("seq_id_echo", 0),
|
||||
client_mono_ns_echo=obj.get("client_mono_ns_echo", 0),
|
||||
episode_id_echo=obj.get("episode_id_echo", 0),
|
||||
chunk_model=decode_tensor(obj.get("chunk_model")),
|
||||
chunk_robot=decode_tensor(obj.get("chunk_robot")),
|
||||
queue_wait_ms=obj.get("queue_wait_ms", 0.0),
|
||||
inference_ms=obj.get("inference_ms", 0.0),
|
||||
superseded_seqs=obj.get("superseded_seqs", 0),
|
||||
server_load=obj.get("server_load", 0.0),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Control-plane messages (flat scalar/list/dict fields → generic codec)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _encode_flat(msg: Any) -> bytes:
|
||||
return _packb(dict(vars(msg).items()))
|
||||
|
||||
|
||||
def _decode_flat(cls: type, data: bytes) -> Any:
|
||||
obj = _unpackb(data)
|
||||
known = set(cls.__dataclass_fields__)
|
||||
return cls(**{k: v for k, v in obj.items() if k in known})
|
||||
|
||||
|
||||
def encode_session_open(msg: SessionOpenMsg) -> bytes:
|
||||
return _encode_flat(msg)
|
||||
|
||||
|
||||
def decode_session_open(data: bytes) -> SessionOpenMsg:
|
||||
return _decode_flat(SessionOpenMsg, data)
|
||||
|
||||
|
||||
def encode_session_ack(msg: SessionAckMsg) -> bytes:
|
||||
return _encode_flat(msg)
|
||||
|
||||
|
||||
def decode_session_ack(data: bytes) -> SessionAckMsg:
|
||||
return _decode_flat(SessionAckMsg, data)
|
||||
|
||||
|
||||
def encode_status(msg: StatusMsg) -> bytes:
|
||||
return _encode_flat(msg)
|
||||
|
||||
|
||||
def decode_status(data: bytes) -> StatusMsg:
|
||||
return _decode_flat(StatusMsg, data)
|
||||
|
||||
|
||||
def encode_reset(msg: ResetMsg) -> bytes:
|
||||
return _encode_flat(msg)
|
||||
|
||||
|
||||
def decode_reset(data: bytes) -> ResetMsg:
|
||||
return _decode_flat(ResetMsg, data)
|
||||
|
||||
|
||||
def encode_reset_ack(msg: ResetAckMsg) -> bytes:
|
||||
return _encode_flat(msg)
|
||||
|
||||
|
||||
def decode_reset_ack(data: bytes) -> ResetAckMsg:
|
||||
return _decode_flat(ResetAckMsg, data)
|
||||
|
||||
|
||||
def encode_session_close(msg: SessionCloseMsg) -> bytes:
|
||||
return _encode_flat(msg)
|
||||
|
||||
|
||||
def decode_session_close(data: bytes) -> SessionCloseMsg:
|
||||
return _decode_flat(SessionCloseMsg, data)
|
||||
@@ -0,0 +1,139 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Policy-server manifest: one process = one (model, revision, dtype, device) on one GPU.
|
||||
|
||||
Loaded from YAML via ``lerobot-policy-server --manifest server.yaml`` (or
|
||||
individual ``--model.repo_or_path=...`` CLI overrides through draccus).
|
||||
Dynamic model loading is deliberately unsupported: pre-warmed processes
|
||||
keep capacity planning honest and keep code-carrying payloads off the wire.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SERVING_MODE_AUTO = "auto"
|
||||
SERVING_MODE_SHARED = "shared"
|
||||
SERVING_MODE_EXCLUSIVE = "exclusive"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelSpec:
|
||||
"""Which policy this process serves, and where it runs."""
|
||||
|
||||
repo_or_path: str = ""
|
||||
revision: str = "main"
|
||||
# Optional torch dtype cast applied after load (e.g. "bfloat16").
|
||||
dtype: str | None = None
|
||||
device: str = "cuda"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ZenohSpec:
|
||||
"""Transport endpoints and security.
|
||||
|
||||
Robots and servers both *dial out* to a ``zenohd`` router in
|
||||
production (``mode=client``). ``mode=peer`` + ``listen_endpoints``
|
||||
supports router-less LAN and loopback test deployments. Multicast
|
||||
scouting is always disabled: fleet discovery is configuration, not
|
||||
protocol magic.
|
||||
"""
|
||||
|
||||
mode: str = "client" # "client" (via router) | "peer" (direct)
|
||||
connect_endpoints: list[str] = field(default_factory=list)
|
||||
listen_endpoints: list[str] = field(default_factory=list)
|
||||
# mTLS material (PEM paths). All three are required for TLS endpoints.
|
||||
tls_root_ca_certificate: str | None = None
|
||||
tls_connect_certificate: str | None = None
|
||||
tls_connect_private_key: str | None = None
|
||||
# Escape hatch: raw JSON5 merged into the zenoh config last.
|
||||
extra_config_json5: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DebugSpec:
|
||||
"""Optional bounded request/response capture for offline replay."""
|
||||
|
||||
capture_dir: str | None = None
|
||||
capture_max: int = 256
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyServerManifest:
|
||||
"""Top-level config for ``lerobot-policy-server``."""
|
||||
|
||||
model: ModelSpec = field(default_factory=ModelSpec)
|
||||
zenoh: ZenohSpec = field(default_factory=ZenohSpec)
|
||||
|
||||
# The task namespace this service is published under. When
|
||||
# ``pin_task`` is true, session opens with a different task string
|
||||
# are rejected; otherwise VLA clients may set the task per session.
|
||||
default_task: str = ""
|
||||
pin_task: bool = False
|
||||
# Optional override for the <task_slug> key segment (defaults to a
|
||||
# slug of ``default_task``).
|
||||
service_name: str = ""
|
||||
|
||||
# "auto" resolves from the policy classification (shared for
|
||||
# chunk-stateless policies, exclusive otherwise). "exclusive" can be
|
||||
# forced; "shared" cannot override a chunk-stateful classification.
|
||||
serving_mode: str = SERVING_MODE_AUTO
|
||||
max_sessions: int = 5
|
||||
warmup_inferences: int = 2
|
||||
|
||||
# FPS contract: warn on mismatch unless strict.
|
||||
trained_fps: float = 30.0
|
||||
strict_fps: bool = False
|
||||
|
||||
# RTC behaviour for this server process (global to the shared policy:
|
||||
# ``init_rtc_processor`` mutates the policy instance, so it is a
|
||||
# per-process decision, not per-session).
|
||||
rtc: RTCConfig = field(default_factory=RTCConfig)
|
||||
|
||||
# Sessions with no liveliness token and no traffic for this long are
|
||||
# garbage-collected (belt-and-braces behind liveliness GC).
|
||||
session_idle_timeout_s: float = 300.0
|
||||
|
||||
# HTTP health + Prometheus metrics port; 0 disables the endpoint.
|
||||
health_port: int = 9100
|
||||
|
||||
debug: DebugSpec = field(default_factory=DebugSpec)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.model.repo_or_path:
|
||||
raise ValueError("--model.repo_or_path is required (the policy this server serves)")
|
||||
if self.serving_mode not in (SERVING_MODE_AUTO, SERVING_MODE_SHARED, SERVING_MODE_EXCLUSIVE):
|
||||
raise ValueError(f"serving_mode must be one of auto|shared|exclusive, got {self.serving_mode!r}")
|
||||
if self.max_sessions < 1:
|
||||
raise ValueError(f"max_sessions must be >= 1, got {self.max_sessions}")
|
||||
if self.zenoh.mode not in ("client", "peer"):
|
||||
raise ValueError(f"zenoh.mode must be 'client' or 'peer', got {self.zenoh.mode!r}")
|
||||
if self.zenoh.mode == "client" and not self.zenoh.connect_endpoints:
|
||||
raise ValueError("zenoh.connect_endpoints is required in client mode (router address)")
|
||||
tls_fields = (
|
||||
self.zenoh.tls_root_ca_certificate,
|
||||
self.zenoh.tls_connect_certificate,
|
||||
self.zenoh.tls_connect_private_key,
|
||||
)
|
||||
if any(tls_fields) and not all(tls_fields):
|
||||
raise ValueError(
|
||||
"TLS requires all of tls_root_ca_certificate, tls_connect_certificate, "
|
||||
"tls_connect_private_key"
|
||||
)
|
||||
@@ -0,0 +1,58 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Scheduling seam between the session registry and the inference worker.
|
||||
|
||||
The v1 scheduler is strict round-robin over sessions with a pending
|
||||
observation: every ready session gets exactly one inference per cycle,
|
||||
so starvation is structurally impossible. The seam exists so that
|
||||
cross-session micro-batching can land later without redesign (blocked
|
||||
today on ``predict_action_chunk`` taking a *scalar* ``inference_delay``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
|
||||
from .session import Session
|
||||
|
||||
|
||||
class Scheduler(abc.ABC):
|
||||
"""Pick which ready session(s) the worker serves next."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def select(self, ready: list[Session]) -> list[Session]:
|
||||
"""Return the sessions to serve this cycle (subset of ``ready``)."""
|
||||
|
||||
|
||||
class RoundRobinScheduler(Scheduler):
|
||||
"""Serve one session per cycle, fairly, in client_uuid order."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._last_served: str | None = None
|
||||
|
||||
def select(self, ready: list[Session]) -> list[Session]:
|
||||
if not ready:
|
||||
return []
|
||||
ring = sorted(ready, key=lambda s: s.client_uuid)
|
||||
if self._last_served is not None:
|
||||
for i, session in enumerate(ring):
|
||||
if session.client_uuid > self._last_served:
|
||||
ring = ring[i:] + ring[:i]
|
||||
break
|
||||
else:
|
||||
pass # wrap: everyone is <= last served, keep sorted order
|
||||
chosen = ring[0]
|
||||
self._last_served = chosen.client_uuid
|
||||
return [chosen]
|
||||
@@ -0,0 +1,340 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Wire schema for remote policy inference.
|
||||
|
||||
Message dataclasses, the packed attachment header, and the Zenoh
|
||||
key-expression layout shared by the policy server and the remote
|
||||
inference engine. This module is transport-free (no zenoh import) so
|
||||
codecs and validation can be unit-tested without the optional extra.
|
||||
|
||||
Schema discipline: bodies are MessagePack maps decoded tolerantly
|
||||
(unknown keys ignored, missing optional keys defaulted) so evolution is
|
||||
additive-only. Any change to the attachment layout requires a
|
||||
``SCHEMA_VERSION`` bump; versions are negotiated at session open.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import struct
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import numpy as np
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Versioning
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SCHEMA_VERSION = 1
|
||||
# Oldest schema version this build can still serve.
|
||||
MIN_SUPPORTED_SCHEMA_VERSION = 1
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Attachment header (fixed layout, packed little-endian)
|
||||
#
|
||||
# Parsed without touching the msgpack body so routing, correlation and
|
||||
# supersession decisions never pay deserialization costs. The
|
||||
# ``client_mono_ns`` field is a client-monotonic timestamp that is
|
||||
# OPAQUE to the server: it is echoed back verbatim so the client can
|
||||
# compute round-trip times on its own clock. Wall-clock instants never
|
||||
# cross machines (the clock iron rule).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_HEADER_STRUCT = struct.Struct("<HBQIqI") # schema_version, msg_type, seq_id, episode_id, mono_ns, epoch
|
||||
|
||||
MSG_TYPE_OBS = 1
|
||||
MSG_TYPE_CHUNK = 2
|
||||
MSG_TYPE_EVENT = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
class MsgHeader:
|
||||
"""Packed per-message header carried in the Zenoh attachment."""
|
||||
|
||||
schema_version: int = SCHEMA_VERSION
|
||||
msg_type: int = MSG_TYPE_OBS
|
||||
seq_id: int = 0
|
||||
episode_id: int = 0
|
||||
client_mono_ns: int = 0
|
||||
session_epoch: int = 0
|
||||
|
||||
def pack(self) -> bytes:
|
||||
return _HEADER_STRUCT.pack(
|
||||
self.schema_version,
|
||||
self.msg_type,
|
||||
self.seq_id,
|
||||
self.episode_id,
|
||||
self.client_mono_ns,
|
||||
self.session_epoch,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def unpack(cls, data: bytes) -> MsgHeader:
|
||||
if len(data) != _HEADER_STRUCT.size:
|
||||
raise ValueError(f"Bad header length: {len(data)} (expected {_HEADER_STRUCT.size})")
|
||||
version, msg_type, seq_id, episode_id, mono_ns, epoch = _HEADER_STRUCT.unpack(data)
|
||||
return cls(
|
||||
schema_version=version,
|
||||
msg_type=msg_type,
|
||||
seq_id=seq_id,
|
||||
episode_id=episode_id,
|
||||
client_mono_ns=mono_ns,
|
||||
session_epoch=epoch,
|
||||
)
|
||||
|
||||
|
||||
HEADER_SIZE = _HEADER_STRUCT.size
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message bodies
|
||||
#
|
||||
# ``np.ndarray`` fields travel as raw little-endian bytes + dtype + shape
|
||||
# (see codec.py). Images travel JPEG-compressed by default.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
IMAGE_CODEC_JPEG = "jpeg"
|
||||
IMAGE_CODEC_RAW = "raw"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ObservationMsg:
|
||||
"""Client → server: one inference request (data plane)."""
|
||||
|
||||
state: np.ndarray | None = None # float32 [state_dim]
|
||||
images: dict[str, np.ndarray] = field(default_factory=dict) # name -> uint8 HWC RGB
|
||||
task: str = ""
|
||||
inference_delay_steps: int = 0
|
||||
# RTC prefixes: the unexecuted tail of the previous chunk, in model
|
||||
# space (original) and robot space (postprocessed). Both are needed
|
||||
# because the server re-anchors relative-action prefixes against the
|
||||
# current state and the client's ActionQueue.merge needs both chunks.
|
||||
prefix_model: np.ndarray | None = None # float32 [T, action_dim]
|
||||
prefix_robot: np.ndarray | None = None # float32 [T, action_dim]
|
||||
episode_start: bool = False
|
||||
# JPEG quality the images were encoded with; 0 means raw.
|
||||
jpeg_quality: int = 90
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionChunkMsg:
|
||||
"""Server → client: one action chunk (data plane)."""
|
||||
|
||||
seq_id_echo: int = 0
|
||||
client_mono_ns_echo: int = 0
|
||||
episode_id_echo: int = 0
|
||||
chunk_model: np.ndarray | None = None # float32 [H, action_dim] (pre-postprocessor)
|
||||
chunk_robot: np.ndarray | None = None # float32 [H, action_dim] (postprocessed)
|
||||
# Durations only — measured on the server's monotonic clock, never
|
||||
# compared against client time (the clock iron rule).
|
||||
queue_wait_ms: float = 0.0
|
||||
inference_ms: float = 0.0
|
||||
# Observations from this client that were superseded (overwritten in
|
||||
# the latest-only mailbox) since the previous reply — makes drops visible.
|
||||
superseded_seqs: int = 0
|
||||
server_load: float = 0.0 # active_sessions / max_sessions
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionOpenMsg:
|
||||
"""Client → server (control plane): open and validate a session."""
|
||||
|
||||
op: str = "open"
|
||||
client_uuid: str = ""
|
||||
robot_type: str = ""
|
||||
policy_type: str = ""
|
||||
fps: float = 0.0
|
||||
# Hard sync-safety contract: must equal the server's action feature
|
||||
# names *and order* — this maps chunk columns to motors.
|
||||
action_names: list[str] = field(default_factory=list)
|
||||
camera_names: list[str] = field(default_factory=list) # canonical keys (post-rename)
|
||||
state_dim: int = 0
|
||||
schema_version: int = SCHEMA_VERSION
|
||||
rtc_enabled: bool = False
|
||||
task: str = ""
|
||||
tags: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionAckMsg:
|
||||
"""Server → client (control plane): session accept/reject + capabilities."""
|
||||
|
||||
accepted: bool = False
|
||||
error: str = ""
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
session_id: str = ""
|
||||
model_repo: str = ""
|
||||
model_revision: str = ""
|
||||
policy_type: str = ""
|
||||
action_names: list[str] = field(default_factory=list)
|
||||
expected_cameras: list[str] = field(default_factory=list)
|
||||
state_dim: int = 0
|
||||
chunk_size: int = 0
|
||||
trained_fps: float = 0.0
|
||||
supports_rtc: bool = False
|
||||
rtc_execution_horizon: int = 0
|
||||
serving_mode: str = ""
|
||||
warmed_up: bool = False
|
||||
schema_version: int = SCHEMA_VERSION
|
||||
server_load: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class StatusMsg:
|
||||
"""Server → client (control plane): pre-flight capability snapshot."""
|
||||
|
||||
model_repo: str = ""
|
||||
model_revision: str = ""
|
||||
policy_type: str = ""
|
||||
action_names: list[str] = field(default_factory=list)
|
||||
expected_cameras: list[str] = field(default_factory=list)
|
||||
state_dim: int = 0
|
||||
chunk_size: int = 0
|
||||
trained_fps: float = 0.0
|
||||
supports_rtc: bool = False
|
||||
rtc_execution_horizon: int = 0
|
||||
serving_mode: str = ""
|
||||
warmed_up: bool = False
|
||||
min_schema_version: int = MIN_SUPPORTED_SCHEMA_VERSION
|
||||
max_schema_version: int = SCHEMA_VERSION
|
||||
active_sessions: int = 0
|
||||
max_sessions: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResetMsg:
|
||||
"""Client → server (control plane): episode boundary (acknowledged)."""
|
||||
|
||||
client_uuid: str = ""
|
||||
episode_id: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResetAckMsg:
|
||||
"""Server → client: reset acknowledgement."""
|
||||
|
||||
ok: bool = True
|
||||
error: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionCloseMsg:
|
||||
"""Client → server (control plane): graceful session close."""
|
||||
|
||||
op: str = "close"
|
||||
client_uuid: str = ""
|
||||
session_id: str = ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Key-expression schema
|
||||
#
|
||||
# @lerobot/<service>/<client_uuid>/obs client → server (pub/sub)
|
||||
# @lerobot/<service>/<client_uuid>/action server → client (pub/sub)
|
||||
# @lerobot/<service>/status queryable (capabilities)
|
||||
# @lerobot/<service>/session queryable (open/close)
|
||||
# @lerobot/<service>/<client_uuid>/reset queryable (episode boundary)
|
||||
# @lerobot/<service>/<client_uuid>/alive liveliness token (client)
|
||||
# @lerobot/<service>/server/alive liveliness token (server)
|
||||
#
|
||||
# where <service> = <model_slug>/<revision_slug>/<task_slug>. The task
|
||||
# segment is a *namespace label* derived from the server's default task
|
||||
# (or an explicit service name) — the actual inference task string
|
||||
# travels in the session/observation messages.
|
||||
#
|
||||
# ``@lerobot`` is a verbatim chunk: it is only matched by an identical
|
||||
# chunk, so third-party ``**`` subscribers on a shared router can never
|
||||
# scrape this tree.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
KEY_ROOT = "@lerobot"
|
||||
|
||||
# Conservative allowlist for user-supplied key segments. Everything
|
||||
# else (including '/', '*', '$', '?', '#', whitespace) is folded to '-'.
|
||||
_SEGMENT_SANITIZE_RE = re.compile(r"[^A-Za-z0-9_.\-]+")
|
||||
|
||||
# Reserved final chunks of the key tree; a client UUID must never
|
||||
# collide with them.
|
||||
RESERVED_SEGMENTS = frozenset({"status", "session", "server", "obs", "action", "reset", "alive"})
|
||||
|
||||
|
||||
def sanitize_key_segment(segment: str) -> str:
|
||||
"""Fold an arbitrary string into a single safe Zenoh key chunk."""
|
||||
slug = _SEGMENT_SANITIZE_RE.sub("-", segment.strip()).strip("-.")
|
||||
if not slug:
|
||||
raise ValueError(f"Key segment {segment!r} sanitizes to an empty chunk")
|
||||
if slug in RESERVED_SEGMENTS:
|
||||
raise ValueError(f"Key segment {segment!r} collides with reserved chunk {slug!r}")
|
||||
return slug
|
||||
|
||||
|
||||
def service_prefix(model_id: str, revision: str, task: str) -> str:
|
||||
"""Build the shared key prefix for one served (model, revision, task) triple."""
|
||||
return "/".join(
|
||||
(
|
||||
KEY_ROOT,
|
||||
sanitize_key_segment(model_id),
|
||||
sanitize_key_segment(revision or "main"),
|
||||
sanitize_key_segment(task or "default"),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def obs_key(prefix: str, client_uuid: str) -> str:
|
||||
return f"{prefix}/{sanitize_key_segment(client_uuid)}/obs"
|
||||
|
||||
|
||||
def action_key(prefix: str, client_uuid: str) -> str:
|
||||
return f"{prefix}/{sanitize_key_segment(client_uuid)}/action"
|
||||
|
||||
|
||||
def reset_key(prefix: str, client_uuid: str) -> str:
|
||||
return f"{prefix}/{sanitize_key_segment(client_uuid)}/reset"
|
||||
|
||||
|
||||
def client_alive_key(prefix: str, client_uuid: str) -> str:
|
||||
return f"{prefix}/{sanitize_key_segment(client_uuid)}/alive"
|
||||
|
||||
|
||||
def status_key(prefix: str) -> str:
|
||||
return f"{prefix}/status"
|
||||
|
||||
|
||||
def session_key(prefix: str) -> str:
|
||||
return f"{prefix}/session"
|
||||
|
||||
|
||||
def server_alive_key(prefix: str) -> str:
|
||||
return f"{prefix}/server/alive"
|
||||
|
||||
|
||||
# Single-depth wildcards only — '**' would also match status/session/alive.
|
||||
def obs_wildcard(prefix: str) -> str:
|
||||
return f"{prefix}/*/obs"
|
||||
|
||||
|
||||
def reset_wildcard(prefix: str) -> str:
|
||||
return f"{prefix}/*/reset"
|
||||
|
||||
|
||||
def client_alive_wildcard(prefix: str) -> str:
|
||||
return f"{prefix}/*/alive"
|
||||
|
||||
|
||||
def client_uuid_from_key(key: str) -> str:
|
||||
"""Extract the client UUID chunk from an obs/reset/alive key."""
|
||||
parts = key.split("/")
|
||||
if len(parts) < 2:
|
||||
raise ValueError(f"Key {key!r} has no client chunk")
|
||||
return parts[-2]
|
||||
@@ -0,0 +1,934 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""``lerobot-policy-server``: multi-client GPU inference over Zenoh.
|
||||
|
||||
One process serves one pre-warmed (model, revision, dtype, device) to up
|
||||
to ``max_sessions`` robot clients. The process is **stateless per
|
||||
request**: clients ship RTC prefixes and a delay hint with every
|
||||
observation, so a server crash loses zero control state and reconnects
|
||||
are trivial.
|
||||
|
||||
Concurrency model (pure threads — zenoh-python has no asyncio API):
|
||||
|
||||
zenoh subscriber (.../*/obs) inference worker (1 thread, owns GPU)
|
||||
deposit-only callback: loop:
|
||||
session.deposit(header, body) ──► pick next session with pending obs (RR)
|
||||
(per-client latest-only slot) decode → per-session preprocess
|
||||
predict_action_chunk(delay, prefix)
|
||||
control queryables (status/session/ per-session postprocess → encode
|
||||
reset): validate, mutate session publisher.put(.../<uuid>/action)
|
||||
registry, reply inline
|
||||
|
||||
The single worker thread serializes GPU access; newest-wins mailboxes
|
||||
mean overload degrades into longer cycle times (larger but correct
|
||||
client delays), never into queue buildup.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import http.server
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import uuid as uuid_module
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs import FeatureType
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc.relative import reanchor_relative_rtc_prefix
|
||||
from lerobot.policies.utils import populate_queues, prepare_observation_for_inference
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
from . import codec
|
||||
from .manifest import PolicyServerManifest
|
||||
from .scheduler import RoundRobinScheduler, Scheduler
|
||||
from .schema import (
|
||||
SCHEMA_VERSION,
|
||||
ActionChunkMsg,
|
||||
MsgHeader,
|
||||
ObservationMsg,
|
||||
ResetAckMsg,
|
||||
SessionAckMsg,
|
||||
SessionCloseMsg,
|
||||
SessionOpenMsg,
|
||||
StatusMsg,
|
||||
action_key,
|
||||
client_alive_key,
|
||||
client_alive_wildcard,
|
||||
client_uuid_from_key,
|
||||
obs_wildcard,
|
||||
reset_wildcard,
|
||||
server_alive_key,
|
||||
service_prefix,
|
||||
session_key,
|
||||
status_key,
|
||||
)
|
||||
from .session import Session, SessionRegistry
|
||||
from .validation import (
|
||||
PolicyClassification,
|
||||
classify_policy,
|
||||
resolve_serving_mode,
|
||||
validate_session_open,
|
||||
)
|
||||
from .zenoh_utils import action_publisher_qos, build_zenoh_config, import_zenoh
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
audit_logger = logging.getLogger("lerobot.policy_server.audit")
|
||||
|
||||
# Grace period after a client liveliness token drops before its session
|
||||
# is garbage-collected (rides out router blips and reconnects).
|
||||
_LIVELINESS_GC_GRACE_S = 5.0
|
||||
# Worker idle wait between work-event checks (also paces the GC sweep).
|
||||
_WORKER_IDLE_WAIT_S = 0.05
|
||||
|
||||
|
||||
def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int) -> torch.Tensor:
|
||||
"""Pad or truncate RTC prefix actions to a fixed length (mirrors rtc.py)."""
|
||||
if prev_actions.ndim != 2:
|
||||
raise ValueError(f"Expected 2D [T, A] tensor, got shape={tuple(prev_actions.shape)}")
|
||||
steps, action_dim = prev_actions.shape
|
||||
if steps == target_steps:
|
||||
return prev_actions
|
||||
if steps > target_steps:
|
||||
return prev_actions[:target_steps]
|
||||
padded = torch.zeros((target_steps, action_dim), dtype=prev_actions.dtype, device=prev_actions.device)
|
||||
padded[:steps] = prev_actions
|
||||
return padded
|
||||
|
||||
|
||||
class PolicyServer:
|
||||
"""Zenoh policy server: control-plane queryables + data-plane pub/sub + one GPU worker."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manifest: PolicyServerManifest,
|
||||
*,
|
||||
policy: PreTrainedPolicy | None = None,
|
||||
policy_cfg: PreTrainedConfig | None = None,
|
||||
processor_factory: Callable[[], tuple[Any, Any]] | None = None,
|
||||
classification: PolicyClassification | None = None,
|
||||
scheduler: Scheduler | None = None,
|
||||
) -> None:
|
||||
"""``policy``/``policy_cfg``/``processor_factory``/``classification``
|
||||
are injection points for tests; production loads everything from
|
||||
the manifest via :meth:`load_policy`.
|
||||
"""
|
||||
self._manifest = manifest
|
||||
self._device = torch.device(manifest.model.device)
|
||||
self._policy = policy
|
||||
self._policy_cfg = policy_cfg
|
||||
self._processor_factory = processor_factory
|
||||
self._classification = classification
|
||||
self._scheduler = scheduler or RoundRobinScheduler()
|
||||
|
||||
self._serving_mode: str = ""
|
||||
self._max_sessions: int = manifest.max_sessions
|
||||
self._rtc_active = False
|
||||
self._warmed_up = False
|
||||
|
||||
self.registry = SessionRegistry()
|
||||
self._registry_lock = threading.Lock() # serializes open/close/GC decisions
|
||||
# Serializes inference against episode resets: in exclusive mode a
|
||||
# reset (policy.reset(), pipeline reset) arriving on a queryable
|
||||
# thread mid-predict would corrupt the in-flight request's state.
|
||||
self._inference_lock = threading.Lock()
|
||||
|
||||
self._zenoh = None
|
||||
self._declarations: list[Any] = []
|
||||
self._alive_token = None
|
||||
|
||||
self._work = threading.Event()
|
||||
self._shutdown = threading.Event()
|
||||
self._worker: threading.Thread | None = None
|
||||
self._health_server: http.server.ThreadingHTTPServer | None = None
|
||||
|
||||
self._unknown_clients_warned: set[str] = set()
|
||||
self._capture_count = 0
|
||||
|
||||
self.metrics: dict[str, float] = {
|
||||
"requests_total": 0,
|
||||
"errors_total": 0,
|
||||
"superseded_total": 0,
|
||||
"dropped_unknown_client_total": 0,
|
||||
"sessions_opened_total": 0,
|
||||
"sessions_closed_total": 0,
|
||||
}
|
||||
self._metrics_lock = threading.Lock()
|
||||
|
||||
task_slug_source = manifest.service_name or manifest.default_task or "default"
|
||||
self.prefix = service_prefix(manifest.model.repo_or_path, manifest.model.revision, task_slug_source)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Loading & warmup
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def load_policy(self) -> None:
|
||||
"""Load config + weights, apply RTC settings, classify, warm up."""
|
||||
manifest = self._manifest
|
||||
if self._policy is None:
|
||||
logger.info(
|
||||
"Loading policy from '%s' (revision=%s)...",
|
||||
manifest.model.repo_or_path,
|
||||
manifest.model.revision,
|
||||
)
|
||||
policy_cfg = PreTrainedConfig.from_pretrained(manifest.model.repo_or_path)
|
||||
policy_cfg.pretrained_path = manifest.model.repo_or_path
|
||||
policy_class = get_policy_class(policy_cfg.type)
|
||||
policy = policy_class.from_pretrained(manifest.model.repo_or_path, config=policy_cfg)
|
||||
self._policy = policy
|
||||
self._policy_cfg = policy_cfg
|
||||
elif self._policy_cfg is None:
|
||||
self._policy_cfg = self._policy.config
|
||||
|
||||
if self._classification is None:
|
||||
self._classification = classify_policy(self._policy)
|
||||
logger.info("Policy classification: %s", self._classification.reason)
|
||||
|
||||
self._serving_mode, self._max_sessions = resolve_serving_mode(self._classification, manifest)
|
||||
logger.info("Serving mode: %s (max_sessions=%d)", self._serving_mode, self._max_sessions)
|
||||
|
||||
# RTC is a per-process decision: init_rtc_processor mutates the
|
||||
# shared policy instance.
|
||||
self._rtc_active = (
|
||||
manifest.rtc.enabled
|
||||
and self._classification.supports_rtc
|
||||
and hasattr(self._policy.config, "rtc_config")
|
||||
)
|
||||
if self._rtc_active:
|
||||
self._policy.config.rtc_config = manifest.rtc
|
||||
if hasattr(self._policy, "init_rtc_processor"):
|
||||
self._policy.init_rtc_processor()
|
||||
logger.info("RTC active (execution_horizon=%d)", manifest.rtc.execution_horizon)
|
||||
|
||||
if manifest.model.dtype:
|
||||
self._policy = self._policy.to(getattr(torch, manifest.model.dtype))
|
||||
self._policy = self._policy.to(self._device)
|
||||
self._policy.eval()
|
||||
|
||||
if not self.action_names:
|
||||
logger.warning(
|
||||
"Policy config has no action_feature_names: the action-order contract "
|
||||
"cannot be enforced at session open. Clients are trusted to match training order."
|
||||
)
|
||||
|
||||
if manifest.warmup_inferences > 0:
|
||||
self._warmup(manifest.warmup_inferences)
|
||||
self._warmed_up = True
|
||||
|
||||
def make_session_processors(self) -> tuple[Any, Any]:
|
||||
"""Build a fresh per-session pre/post pipeline pair.
|
||||
|
||||
The rename step is forced to identity: clients apply their
|
||||
rename map before encoding, so the wire format is canonical
|
||||
policy-feature keys across heterogeneous robots.
|
||||
"""
|
||||
if self._processor_factory is not None:
|
||||
return self._processor_factory()
|
||||
return make_pre_post_processors(
|
||||
policy_cfg=self._policy_cfg,
|
||||
pretrained_path=self._policy_cfg.pretrained_path,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": str(self._device)},
|
||||
"rename_observations_processor": {"rename_map": {}},
|
||||
},
|
||||
)
|
||||
|
||||
def _warmup(self, n: int) -> None:
|
||||
"""Run dummy forwards through the full request path (covers compile/caches)."""
|
||||
logger.info("Warmup: %d inferences...", n)
|
||||
obs = self._synthetic_observation()
|
||||
preprocessor, postprocessor = self.make_session_processors()
|
||||
session = Session(
|
||||
session_id="warmup",
|
||||
client_uuid="warmup",
|
||||
task=self._manifest.default_task,
|
||||
robot_type="",
|
||||
rtc_enabled=self._rtc_active,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
reply = self.run_inference_request(session, MsgHeader(), obs)
|
||||
if self._rtc_active and reply.chunk_model is not None and n > 1:
|
||||
# Exercise the prefix-conditioned path too so its compile/cache
|
||||
# cost isn't paid by the first real RTC request.
|
||||
action_dim = reply.chunk_model.shape[-1]
|
||||
horizon = self._manifest.rtc.execution_horizon
|
||||
obs.prefix_model = np.zeros((horizon, action_dim), dtype=np.float32)
|
||||
obs.prefix_robot = np.zeros((horizon, action_dim), dtype=np.float32)
|
||||
obs.inference_delay_steps = 1
|
||||
for _ in range(n - 1):
|
||||
self.run_inference_request(session, MsgHeader(), obs)
|
||||
session.close()
|
||||
# Stateful policies must not carry warmup observations into real sessions.
|
||||
if self._serving_mode == "exclusive":
|
||||
self._policy.reset()
|
||||
logger.info("Warmup complete")
|
||||
|
||||
def _synthetic_observation(self) -> ObservationMsg:
|
||||
cfg = self._policy_cfg
|
||||
state_dim = self.state_dim or 1
|
||||
images = {}
|
||||
for key, feature in cfg.input_features.items():
|
||||
if feature.type == FeatureType.VISUAL:
|
||||
channels, height, width = feature.shape
|
||||
images[key] = np.zeros((height, width, channels), dtype=np.uint8)
|
||||
return ObservationMsg(
|
||||
state=np.zeros(state_dim, dtype=np.float32),
|
||||
images=images,
|
||||
task=self._manifest.default_task,
|
||||
jpeg_quality=0,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Capabilities
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def action_names(self) -> list[str]:
|
||||
names = getattr(self._policy_cfg, "action_feature_names", None)
|
||||
return list(names) if names else []
|
||||
|
||||
@property
|
||||
def state_dim(self) -> int:
|
||||
cfg = self._policy_cfg
|
||||
for key, feature in getattr(cfg, "input_features", {}).items():
|
||||
if key == OBS_STATE or feature.type == FeatureType.STATE:
|
||||
return int(feature.shape[0])
|
||||
return 0
|
||||
|
||||
@property
|
||||
def chunk_size(self) -> int:
|
||||
cfg = self._policy_cfg
|
||||
for attr in ("chunk_size", "n_action_steps", "horizon"):
|
||||
value = getattr(cfg, attr, None)
|
||||
if value:
|
||||
return int(value)
|
||||
return 0
|
||||
|
||||
def status_snapshot(self) -> StatusMsg:
|
||||
cfg = self._policy_cfg
|
||||
expected_cameras = [
|
||||
key
|
||||
for key, feature in getattr(cfg, "input_features", {}).items()
|
||||
if feature.type == FeatureType.VISUAL
|
||||
]
|
||||
return StatusMsg(
|
||||
model_repo=self._manifest.model.repo_or_path,
|
||||
model_revision=self._manifest.model.revision,
|
||||
policy_type=getattr(cfg, "type", "") or getattr(self._policy, "name", ""),
|
||||
action_names=self.action_names,
|
||||
expected_cameras=expected_cameras,
|
||||
state_dim=self.state_dim,
|
||||
chunk_size=self.chunk_size,
|
||||
trained_fps=self._manifest.trained_fps,
|
||||
supports_rtc=self._rtc_active,
|
||||
rtc_execution_horizon=self._manifest.rtc.execution_horizon if self._rtc_active else 0,
|
||||
serving_mode=self._serving_mode,
|
||||
warmed_up=self._warmed_up,
|
||||
active_sessions=len(self.registry),
|
||||
max_sessions=self._max_sessions,
|
||||
)
|
||||
|
||||
@property
|
||||
def server_load(self) -> float:
|
||||
return len(self.registry) / max(1, self._max_sessions)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# The per-request inference path (pure: no zenoh — parity-testable)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def run_inference_request(
|
||||
self, session: Session, header: MsgHeader, obs: ObservationMsg
|
||||
) -> ActionChunkMsg:
|
||||
"""Mirror of the local RTC loop's compute step (rtc.py), minus the queue merge."""
|
||||
t0 = time.perf_counter()
|
||||
|
||||
obs_np: dict[str, np.ndarray] = {}
|
||||
if obs.state is not None:
|
||||
obs_np[OBS_STATE] = np.asarray(obs.state, dtype=np.float32)
|
||||
for name, img in obs.images.items():
|
||||
obs_np[name] = img
|
||||
|
||||
task = obs.task or session.task or self._manifest.default_task
|
||||
batch = prepare_observation_for_inference(obs_np, self._device, task, session.robot_type)
|
||||
batch["task"] = [task]
|
||||
|
||||
preprocessed = session.preprocessor(batch)
|
||||
|
||||
use_rtc = self._rtc_active and session.rtc_enabled
|
||||
if use_rtc:
|
||||
delay = max(0, int(obs.inference_delay_steps))
|
||||
prev_actions: torch.Tensor | None = None
|
||||
if obs.prefix_model is not None and obs.prefix_model.size:
|
||||
prev_actions = torch.from_numpy(np.ascontiguousarray(obs.prefix_model)).to(self._device)
|
||||
|
||||
if prev_actions is not None and session.relative_step is not None:
|
||||
# Re-anchor the absolute leftover tail against the state
|
||||
# cached by THIS request's preprocess (mirrors rtc.py).
|
||||
raw_state = session.relative_step.get_cached_state()
|
||||
prefix_robot = obs.prefix_robot
|
||||
if raw_state is not None and prefix_robot is not None and prefix_robot.size:
|
||||
prev_actions = reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute=torch.from_numpy(np.ascontiguousarray(prefix_robot)),
|
||||
current_state=raw_state,
|
||||
relative_step=session.relative_step,
|
||||
normalizer_step=session.normalizer_step,
|
||||
policy_device=self._device,
|
||||
)
|
||||
|
||||
if prev_actions is not None:
|
||||
prev_actions = _normalize_prev_actions_length(
|
||||
prev_actions, target_steps=self._manifest.rtc.execution_horizon
|
||||
)
|
||||
|
||||
actions = self._policy.predict_action_chunk(
|
||||
preprocessed, inference_delay=delay, prev_chunk_left_over=prev_actions
|
||||
)
|
||||
else:
|
||||
if self._classification is not None and self._classification.needs_queue_population:
|
||||
preprocessed = self._populate_select_queues(preprocessed)
|
||||
actions = self._policy.predict_action_chunk(preprocessed)
|
||||
|
||||
original = actions.squeeze(0).clone()
|
||||
processed = session.postprocessor(actions).squeeze(0)
|
||||
inference_ms = (time.perf_counter() - t0) * 1e3
|
||||
|
||||
session.stats.requests += 1
|
||||
session.stats.last_inference_ms = inference_ms
|
||||
superseded = session.take_superseded()
|
||||
|
||||
return ActionChunkMsg(
|
||||
seq_id_echo=header.seq_id,
|
||||
client_mono_ns_echo=header.client_mono_ns,
|
||||
episode_id_echo=header.episode_id,
|
||||
chunk_model=original.detach().to("cpu", torch.float32).numpy(),
|
||||
chunk_robot=processed.detach().to("cpu", torch.float32).numpy(),
|
||||
inference_ms=inference_ms,
|
||||
superseded_seqs=superseded,
|
||||
server_load=self.server_load,
|
||||
)
|
||||
|
||||
def _populate_select_queues(self, batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Exclusive-mode shim for select_action-fed policies (diffusion family).
|
||||
|
||||
Mirrors ``DiffusionPolicy.select_action``: stack camera features
|
||||
into OBS_IMAGES, then populate the policy's observation queues so
|
||||
``predict_action_chunk`` sees the same history it would locally.
|
||||
"""
|
||||
policy = self._policy
|
||||
batch = {k: v for k, v in batch.items() if k != ACTION}
|
||||
if getattr(policy.config, "image_features", None):
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in policy.config.image_features], dim=-4)
|
||||
policy._queues = populate_queues(policy._queues, batch)
|
||||
return batch
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Zenoh wiring
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def start(self) -> None:
|
||||
"""Open zenoh, declare the service surface, start worker + health threads."""
|
||||
if self._policy is None or not self._warmed_up:
|
||||
self.load_policy()
|
||||
|
||||
zenoh = import_zenoh()
|
||||
spec = self._manifest.zenoh
|
||||
self._zenoh = zenoh.open(
|
||||
build_zenoh_config(
|
||||
mode=spec.mode,
|
||||
connect_endpoints=spec.connect_endpoints,
|
||||
listen_endpoints=spec.listen_endpoints,
|
||||
tls_root_ca_certificate=spec.tls_root_ca_certificate,
|
||||
tls_connect_certificate=spec.tls_connect_certificate,
|
||||
tls_connect_private_key=spec.tls_connect_private_key,
|
||||
extra_config_json5=spec.extra_config_json5,
|
||||
)
|
||||
)
|
||||
handlers = zenoh.handlers
|
||||
|
||||
# Data plane: wildcard subscriber, deposit-only callback.
|
||||
self._declarations.append(
|
||||
self._zenoh.declare_subscriber(obs_wildcard(self.prefix), handlers.Callback(self._on_obs))
|
||||
)
|
||||
# Control plane: queryables reply inline (low rate).
|
||||
self._declarations.append(
|
||||
self._zenoh.declare_queryable(status_key(self.prefix), handlers.Callback(self._on_status_query))
|
||||
)
|
||||
self._declarations.append(
|
||||
self._zenoh.declare_queryable(session_key(self.prefix), handlers.Callback(self._on_session_query))
|
||||
)
|
||||
self._declarations.append(
|
||||
self._zenoh.declare_queryable(
|
||||
reset_wildcard(self.prefix), handlers.Callback(self._on_reset_query)
|
||||
)
|
||||
)
|
||||
# Presence: watch client tokens; publish our own.
|
||||
self._declarations.append(
|
||||
self._zenoh.liveliness().declare_subscriber(
|
||||
client_alive_wildcard(self.prefix), handlers.Callback(self._on_liveliness), history=True
|
||||
)
|
||||
)
|
||||
self._alive_token = self._zenoh.liveliness().declare_token(server_alive_key(self.prefix))
|
||||
|
||||
self._shutdown.clear()
|
||||
self._worker = threading.Thread(target=self._worker_loop, daemon=True, name="PolicyServerWorker")
|
||||
self._worker.start()
|
||||
|
||||
if self._manifest.health_port:
|
||||
self._start_health_server(self._manifest.health_port)
|
||||
|
||||
logger.info(
|
||||
"Policy server up: prefix=%s mode=%s max_sessions=%d rtc=%s",
|
||||
self.prefix,
|
||||
self._serving_mode,
|
||||
self._max_sessions,
|
||||
self._rtc_active,
|
||||
)
|
||||
|
||||
def serve_forever(self) -> None:
|
||||
try:
|
||||
while not self._shutdown.is_set():
|
||||
self._shutdown.wait(timeout=0.5)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Interrupted — draining")
|
||||
finally:
|
||||
self.stop()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Drain: drop the liveliness token first (clients ride their buffers
|
||||
through the drain), finish the in-flight inference, then close."""
|
||||
if self._alive_token is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
self._alive_token.undeclare()
|
||||
self._alive_token = None
|
||||
|
||||
self._shutdown.set()
|
||||
self._work.set()
|
||||
if self._worker is not None and self._worker.is_alive():
|
||||
self._worker.join(timeout=10.0)
|
||||
if self._worker.is_alive():
|
||||
logger.warning("Inference worker did not join within 10s")
|
||||
self._worker = None
|
||||
|
||||
# Undeclare the control/data surface BEFORE closing sessions so a
|
||||
# late session open cannot be accepted by a server that has
|
||||
# already drained its worker.
|
||||
for declaration in self._declarations:
|
||||
with contextlib.suppress(Exception):
|
||||
declaration.undeclare()
|
||||
self._declarations.clear()
|
||||
|
||||
for session in self.registry.snapshot():
|
||||
self._close_session(session, reason="server shutdown")
|
||||
|
||||
if self._zenoh is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
self._zenoh.close()
|
||||
self._zenoh = None
|
||||
|
||||
if self._health_server is not None:
|
||||
self._health_server.shutdown()
|
||||
self._health_server = None
|
||||
logger.info("Policy server stopped")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Zenoh callbacks (deposit-only on the data plane)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _on_obs(self, sample: Any) -> None:
|
||||
try:
|
||||
attachment = sample.attachment
|
||||
if attachment is None:
|
||||
return
|
||||
header = MsgHeader.unpack(attachment.to_bytes())
|
||||
if header.schema_version != SCHEMA_VERSION:
|
||||
return
|
||||
client_uuid = client_uuid_from_key(str(sample.key_expr))
|
||||
session = self.registry.get(client_uuid)
|
||||
if session is None:
|
||||
self._bump("dropped_unknown_client_total")
|
||||
# Bounded: garbage publishers must not grow this set (or
|
||||
# the log) without limit.
|
||||
if (
|
||||
client_uuid not in self._unknown_clients_warned
|
||||
and len(self._unknown_clients_warned) < 256
|
||||
):
|
||||
self._unknown_clients_warned.add(client_uuid)
|
||||
logger.warning(
|
||||
"Observation from unknown client '%s' (no session) — dropping", client_uuid
|
||||
)
|
||||
return
|
||||
session.deposit(header, sample.payload.to_bytes())
|
||||
self._work.set()
|
||||
except Exception as e: # noqa: BLE001 — a malformed sample must never kill the subscriber
|
||||
logger.error("obs callback error: %s", e)
|
||||
|
||||
def _on_liveliness(self, sample: Any) -> None:
|
||||
try:
|
||||
import zenoh
|
||||
|
||||
client_uuid = client_uuid_from_key(str(sample.key_expr))
|
||||
session = self.registry.get(client_uuid)
|
||||
if session is None:
|
||||
return
|
||||
if sample.kind == zenoh.SampleKind.DELETE:
|
||||
session.alive = False
|
||||
session.token_dropped_mono = time.monotonic()
|
||||
logger.info(
|
||||
"Client '%s' liveliness dropped — GC in %.0fs", client_uuid, _LIVELINESS_GC_GRACE_S
|
||||
)
|
||||
else:
|
||||
session.alive = True
|
||||
session.token_dropped_mono = None
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("liveliness callback error: %s", e)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Control-plane queryables
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _on_status_query(self, query: Any) -> None:
|
||||
try:
|
||||
query.reply(status_key(self.prefix), codec.encode_status(self.status_snapshot()))
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("status query error: %s", e)
|
||||
|
||||
def _on_session_query(self, query: Any) -> None:
|
||||
try:
|
||||
payload = query.payload.to_bytes() if query.payload is not None else b""
|
||||
op = codec.decode_raw(payload).get("op", "open") if payload else "open"
|
||||
if op == "close":
|
||||
self._handle_session_close(codec.decode_session_close(payload))
|
||||
query.reply(session_key(self.prefix), codec.encode_reset_ack(ResetAckMsg(ok=True)))
|
||||
return
|
||||
ack = self._handle_session_open(codec.decode_session_open(payload))
|
||||
query.reply(session_key(self.prefix), codec.encode_session_ack(ack))
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("session query error: %s\n%s", e, traceback.format_exc())
|
||||
with contextlib.suppress(Exception):
|
||||
query.reply(
|
||||
session_key(self.prefix),
|
||||
codec.encode_session_ack(SessionAckMsg(accepted=False, error=f"server error: {e}")),
|
||||
)
|
||||
|
||||
def _on_reset_query(self, query: Any) -> None:
|
||||
try:
|
||||
payload = query.payload.to_bytes() if query.payload is not None else b""
|
||||
msg = codec.decode_reset(payload)
|
||||
session = self.registry.get(msg.client_uuid)
|
||||
if session is None:
|
||||
ack = ResetAckMsg(ok=False, error=f"unknown client '{msg.client_uuid}'")
|
||||
else:
|
||||
# Serialize with the worker: resetting pipelines/policy
|
||||
# mid-predict would corrupt the in-flight request.
|
||||
with self._inference_lock:
|
||||
session.reset_episode(msg.episode_id)
|
||||
if self._serving_mode == "exclusive":
|
||||
# Safe: max_sessions=1, the policy belongs to this client.
|
||||
self._policy.reset()
|
||||
ack = ResetAckMsg(ok=True)
|
||||
query.reply(str(query.key_expr), codec.encode_reset_ack(ack))
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("reset query error: %s", e)
|
||||
|
||||
def _handle_session_open(self, msg: SessionOpenMsg) -> SessionAckMsg:
|
||||
capabilities = self.status_snapshot()
|
||||
with self._registry_lock:
|
||||
# A re-handshake from a known client replaces its session and
|
||||
# does not count against capacity.
|
||||
existing = self.registry.get(msg.client_uuid)
|
||||
active = len(self.registry) - (1 if existing else 0)
|
||||
result = validate_session_open(msg, capabilities, self._manifest, active)
|
||||
if not result.ok:
|
||||
logger.warning("Session rejected for '%s': %s", msg.client_uuid, result.error)
|
||||
return SessionAckMsg(accepted=False, error=result.error, server_load=self.server_load)
|
||||
|
||||
preprocessor, postprocessor = self.make_session_processors()
|
||||
session = Session(
|
||||
session_id=uuid_module.uuid4().hex,
|
||||
client_uuid=msg.client_uuid,
|
||||
task=msg.task or self._manifest.default_task,
|
||||
robot_type=msg.robot_type,
|
||||
rtc_enabled=msg.rtc_enabled and not result.rtc_downgraded,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
action_publisher=self._declare_action_publisher(msg.client_uuid),
|
||||
)
|
||||
if session.relative_step is not None and session.relative_step.action_names is None:
|
||||
session.relative_step.action_names = self.action_names or list(msg.action_names)
|
||||
# Sentinel: the FIRST observation of a fresh session always
|
||||
# triggers the episode-boundary branch in _serve_one, so a
|
||||
# mid-episode reconnect can never inherit stale state.
|
||||
session.episode_id = -1
|
||||
displaced = self.registry.add(session)
|
||||
if displaced is not None:
|
||||
displaced.close()
|
||||
self._bump("sessions_closed_total")
|
||||
logger.info("Client '%s' re-handshake: previous session replaced", msg.client_uuid)
|
||||
if self._serving_mode == "exclusive":
|
||||
# A new exclusive session must start from fresh policy state.
|
||||
with self._inference_lock:
|
||||
self._policy.reset()
|
||||
self._bump("sessions_opened_total")
|
||||
self._unknown_clients_warned.discard(msg.client_uuid)
|
||||
logger.info(
|
||||
"Session opened: client=%s session=%s task=%r rtc=%s (%d/%d)",
|
||||
msg.client_uuid,
|
||||
session.session_id,
|
||||
session.task,
|
||||
session.rtc_enabled,
|
||||
len(self.registry),
|
||||
self._max_sessions,
|
||||
)
|
||||
return SessionAckMsg(
|
||||
accepted=True,
|
||||
warnings=result.warnings,
|
||||
session_id=session.session_id,
|
||||
model_repo=capabilities.model_repo,
|
||||
model_revision=capabilities.model_revision,
|
||||
policy_type=capabilities.policy_type,
|
||||
action_names=capabilities.action_names,
|
||||
expected_cameras=capabilities.expected_cameras,
|
||||
state_dim=capabilities.state_dim,
|
||||
chunk_size=capabilities.chunk_size,
|
||||
trained_fps=capabilities.trained_fps,
|
||||
supports_rtc=capabilities.supports_rtc and session.rtc_enabled,
|
||||
rtc_execution_horizon=capabilities.rtc_execution_horizon,
|
||||
serving_mode=capabilities.serving_mode,
|
||||
warmed_up=capabilities.warmed_up,
|
||||
server_load=self.server_load,
|
||||
)
|
||||
|
||||
def _declare_action_publisher(self, client_uuid: str) -> Any:
|
||||
if self._zenoh is None: # pure-logic tests run without transport
|
||||
return None
|
||||
zenoh = import_zenoh()
|
||||
return self._zenoh.declare_publisher(
|
||||
action_key(self.prefix, client_uuid), **action_publisher_qos(zenoh)
|
||||
)
|
||||
|
||||
def _handle_session_close(self, msg: SessionCloseMsg) -> None:
|
||||
session = self.registry.get(msg.client_uuid)
|
||||
if session is not None and (not msg.session_id or msg.session_id == session.session_id):
|
||||
self._close_session(session, reason="client close")
|
||||
|
||||
def _close_session(self, session: Session, reason: str) -> None:
|
||||
# Identity-checked removal: never tear down a same-uuid session
|
||||
# that replaced this one via a re-handshake.
|
||||
removed = self.registry.remove(session.client_uuid, expected=session)
|
||||
if removed is not None:
|
||||
removed.close()
|
||||
self._bump("sessions_closed_total")
|
||||
logger.info("Session closed: client=%s (%s)", session.client_uuid, reason)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Inference worker
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _worker_loop(self) -> None:
|
||||
last_gc = time.monotonic()
|
||||
while not self._shutdown.is_set():
|
||||
ready = [s for s in self.registry.snapshot() if s.has_pending()]
|
||||
if not ready:
|
||||
self._work.wait(timeout=_WORKER_IDLE_WAIT_S)
|
||||
self._work.clear()
|
||||
else:
|
||||
for session in self._scheduler.select(ready):
|
||||
self._serve_one(session)
|
||||
|
||||
now = time.monotonic()
|
||||
if now - last_gc > 1.0:
|
||||
last_gc = now
|
||||
self._gc_sessions(now)
|
||||
|
||||
def _serve_one(self, session: Session) -> None:
|
||||
item = session.take()
|
||||
if item is None:
|
||||
return
|
||||
queue_wait_ms = (time.monotonic() - item.recv_mono) * 1e3
|
||||
outcome = "ok"
|
||||
try:
|
||||
obs = codec.decode_observation(item.payload)
|
||||
self._capture("req", item.payload)
|
||||
|
||||
with self._inference_lock:
|
||||
# Belt-and-braces episode ordering: the first observation of
|
||||
# an episode also announces the boundary (one-in-flight makes
|
||||
# the reset query race-free, but a lost ack must not desync
|
||||
# us; fresh sessions start at the -1 sentinel so their first
|
||||
# request always lands here).
|
||||
if obs.episode_start or item.header.episode_id != session.episode_id:
|
||||
session.preprocessor.reset()
|
||||
session.postprocessor.reset()
|
||||
session.episode_id = item.header.episode_id
|
||||
if self._serving_mode == "exclusive":
|
||||
self._policy.reset()
|
||||
|
||||
reply = self.run_inference_request(session, item.header, obs)
|
||||
reply.queue_wait_ms = queue_wait_ms
|
||||
session.stats.last_queue_wait_ms = queue_wait_ms
|
||||
|
||||
body = codec.encode_action_chunk(reply)
|
||||
self._capture("rep", body)
|
||||
# Local ref: a re-handshake can null session.action_publisher
|
||||
# between the check and the put.
|
||||
publisher = session.action_publisher
|
||||
if publisher is not None:
|
||||
reply_header = MsgHeader(
|
||||
schema_version=SCHEMA_VERSION,
|
||||
msg_type=2, # MSG_TYPE_CHUNK
|
||||
seq_id=item.header.seq_id,
|
||||
episode_id=item.header.episode_id,
|
||||
client_mono_ns=item.header.client_mono_ns,
|
||||
session_epoch=item.header.session_epoch,
|
||||
)
|
||||
publisher.put(body, attachment=reply_header.pack())
|
||||
self._bump("requests_total")
|
||||
self._bump("superseded_total", reply.superseded_seqs)
|
||||
except Exception as e: # noqa: BLE001 — one bad request must not kill the worker
|
||||
outcome = f"error: {e}"
|
||||
session.stats.errors += 1
|
||||
self._bump("errors_total")
|
||||
logger.error(
|
||||
"Inference error for client '%s' seq=%d: %s\n%s",
|
||||
session.client_uuid,
|
||||
item.header.seq_id,
|
||||
e,
|
||||
traceback.format_exc(),
|
||||
)
|
||||
finally:
|
||||
audit_logger.info(
|
||||
json.dumps(
|
||||
{
|
||||
"session_id": session.session_id,
|
||||
"client_uuid": session.client_uuid,
|
||||
"seq_id": item.header.seq_id,
|
||||
"episode_id": item.header.episode_id,
|
||||
"queue_wait_ms": round(queue_wait_ms, 3),
|
||||
"inference_ms": round(session.stats.last_inference_ms, 3),
|
||||
"superseded": session.stats.superseded,
|
||||
"outcome": outcome,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
def _gc_sessions(self, now: float) -> None:
|
||||
for session in self.registry.snapshot():
|
||||
if (
|
||||
session.token_dropped_mono is not None
|
||||
and now - session.token_dropped_mono > _LIVELINESS_GC_GRACE_S
|
||||
):
|
||||
if self._client_token_alive(session.client_uuid):
|
||||
# The DELETE was a late echo of a previous incarnation
|
||||
# (the token key is per client, not per epoch) — the
|
||||
# client re-declared and is alive.
|
||||
session.token_dropped_mono = None
|
||||
session.alive = True
|
||||
continue
|
||||
self._close_session(session, reason="liveliness token dropped")
|
||||
elif now - session.last_seen_mono > self._manifest.session_idle_timeout_s:
|
||||
self._close_session(session, reason="idle timeout")
|
||||
|
||||
def _client_token_alive(self, client_uuid: str) -> bool:
|
||||
"""Confirm a client's liveliness token via an explicit get (GC double-check)."""
|
||||
if self._zenoh is None:
|
||||
return False
|
||||
try:
|
||||
zenoh = import_zenoh()
|
||||
replies = self._zenoh.liveliness().get(
|
||||
client_alive_key(self.prefix, client_uuid),
|
||||
handler=zenoh.handlers.FifoChannel(4),
|
||||
timeout=0.5,
|
||||
)
|
||||
deadline = time.monotonic() + 1.0
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
reply = replies.try_recv()
|
||||
except Exception: # channel closed: no token found # noqa: BLE001
|
||||
return False
|
||||
if reply is None:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
if reply.ok is not None:
|
||||
return True
|
||||
return False
|
||||
except Exception: # noqa: BLE001 — treat transport trouble as "not alive"
|
||||
return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Misc
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _bump(self, key: str, amount: float = 1) -> None:
|
||||
with self._metrics_lock:
|
||||
self.metrics[key] = self.metrics.get(key, 0) + amount
|
||||
|
||||
def _capture(self, kind: str, data: bytes) -> None:
|
||||
capture_dir = self._manifest.debug.capture_dir
|
||||
if not capture_dir:
|
||||
return
|
||||
try:
|
||||
directory = Path(capture_dir)
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
index = self._capture_count % max(1, self._manifest.debug.capture_max)
|
||||
(directory / f"{kind}_{index:05d}.bin").write_bytes(data)
|
||||
if kind == "rep":
|
||||
self._capture_count += 1
|
||||
except OSError as e:
|
||||
logger.warning("debug capture failed: %s", e)
|
||||
|
||||
def _start_health_server(self, port: int) -> None:
|
||||
server_ref = self
|
||||
|
||||
class Handler(http.server.BaseHTTPRequestHandler):
|
||||
def do_GET(self) -> None: # noqa: N802 — http.server API
|
||||
if self.path == "/healthz":
|
||||
worker = server_ref._worker # local ref: stop() may null it mid-read
|
||||
healthy = worker is not None and worker.is_alive()
|
||||
self.send_response(200 if healthy else 503)
|
||||
self.end_headers()
|
||||
self.wfile.write(b"ok" if healthy else b"worker dead")
|
||||
elif self.path == "/metrics":
|
||||
with server_ref._metrics_lock:
|
||||
counters = dict(server_ref.metrics)
|
||||
counters["active_sessions"] = len(server_ref.registry)
|
||||
counters["server_load"] = server_ref.server_load
|
||||
body = "".join(
|
||||
f"lerobot_policy_server_{name} {value}\n" for name, value in sorted(counters.items())
|
||||
).encode()
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/plain; version=0.0.4")
|
||||
self.end_headers()
|
||||
self.wfile.write(body)
|
||||
else:
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
|
||||
def log_message(self, *args: Any) -> None: # silence per-request logging
|
||||
pass
|
||||
|
||||
self._health_server = http.server.ThreadingHTTPServer(("0.0.0.0", port), Handler) # nosec B104
|
||||
threading.Thread(target=self._health_server.serve_forever, daemon=True, name="HealthHTTP").start()
|
||||
logger.info("Health/metrics on :%d (/healthz, /metrics)", port)
|
||||
@@ -0,0 +1,203 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Per-client session state and the latest-only observation mailbox.
|
||||
|
||||
The server holds **no cross-request control state**: RTC prefixes and
|
||||
delay hints arrive with every observation. What a session does hold:
|
||||
|
||||
- Per-session processor pipeline instances. Mandatory:
|
||||
``RelativeActionsProcessorStep`` caches ``_last_state`` at preprocess
|
||||
and the postprocessor reads it back — a pipeline shared across clients
|
||||
would be a race.
|
||||
- A one-slot mailbox: the newest observation wins; superseded requests
|
||||
are counted so drops stay visible to the client.
|
||||
- Counters for the audit log and ``/metrics``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from lerobot.processor import (
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
RelativeActionsProcessorStep,
|
||||
)
|
||||
|
||||
from .schema import MsgHeader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MailboxItem:
|
||||
header: MsgHeader
|
||||
payload: bytes
|
||||
recv_mono: float # server-local monotonic deposit time (for queue_wait_ms)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionStats:
|
||||
requests: int = 0
|
||||
errors: int = 0
|
||||
superseded: int = 0 # observations overwritten before inference (lifetime)
|
||||
superseded_since_reply: int = 0
|
||||
last_inference_ms: float = 0.0
|
||||
last_queue_wait_ms: float = 0.0
|
||||
|
||||
|
||||
class Session:
|
||||
"""One connected robot client."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
client_uuid: str,
|
||||
task: str,
|
||||
robot_type: str,
|
||||
rtc_enabled: bool,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
action_publisher: Any = None, # zenoh.Publisher (Any: zenoh optional at import)
|
||||
) -> None:
|
||||
self.session_id = session_id
|
||||
self.client_uuid = client_uuid
|
||||
self.task = task
|
||||
self.robot_type = robot_type
|
||||
self.rtc_enabled = rtc_enabled
|
||||
self.preprocessor = preprocessor
|
||||
self.postprocessor = postprocessor
|
||||
self.action_publisher = action_publisher
|
||||
|
||||
self.episode_id = 0
|
||||
self.stats = SessionStats()
|
||||
self.alive = True
|
||||
self.last_seen_mono = time.monotonic()
|
||||
# Set when the client's liveliness token drops; GC after grace.
|
||||
self.token_dropped_mono: float | None = None
|
||||
|
||||
# Processor introspection for relative-action prefix re-anchoring
|
||||
# (mirrors RTCInferenceEngine.__init__).
|
||||
self.relative_step = next(
|
||||
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
|
||||
None,
|
||||
)
|
||||
self.normalizer_step = next(
|
||||
(s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)),
|
||||
None,
|
||||
)
|
||||
|
||||
self._mailbox: MailboxItem | None = None
|
||||
self._mailbox_lock = Lock()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Mailbox (deposit-only callbacks write, the inference worker reads)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def deposit(self, header: MsgHeader, payload: bytes) -> None:
|
||||
"""Latest-only deposit; counts superseded observations."""
|
||||
item = MailboxItem(header=header, payload=payload, recv_mono=time.monotonic())
|
||||
with self._mailbox_lock:
|
||||
if self._mailbox is not None:
|
||||
self.stats.superseded += 1
|
||||
self.stats.superseded_since_reply += 1
|
||||
self._mailbox = item
|
||||
self.alive = True
|
||||
self.token_dropped_mono = None
|
||||
self.last_seen_mono = item.recv_mono
|
||||
|
||||
def take(self) -> MailboxItem | None:
|
||||
with self._mailbox_lock:
|
||||
item, self._mailbox = self._mailbox, None
|
||||
return item
|
||||
|
||||
def take_superseded(self) -> int:
|
||||
"""Atomically read-and-reset the per-reply supersession counter."""
|
||||
with self._mailbox_lock:
|
||||
count = self.stats.superseded_since_reply
|
||||
self.stats.superseded_since_reply = 0
|
||||
return count
|
||||
|
||||
def has_pending(self) -> bool:
|
||||
with self._mailbox_lock:
|
||||
return self._mailbox is not None
|
||||
|
||||
def clear_mailbox(self) -> None:
|
||||
with self._mailbox_lock:
|
||||
self._mailbox = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Episode boundary
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def reset_episode(self, episode_id: int | None = None) -> None:
|
||||
"""Clear per-episode state. The shared policy is NOT touched here."""
|
||||
self.clear_mailbox()
|
||||
self.preprocessor.reset()
|
||||
self.postprocessor.reset()
|
||||
self.episode_id = episode_id if episode_id is not None else self.episode_id + 1
|
||||
|
||||
def close(self) -> None:
|
||||
self.clear_mailbox()
|
||||
publisher = self.action_publisher
|
||||
self.action_publisher = None
|
||||
if publisher is not None:
|
||||
# Already-closed transport is fine on teardown.
|
||||
with contextlib.suppress(Exception):
|
||||
publisher.undeclare()
|
||||
|
||||
|
||||
class SessionRegistry:
|
||||
"""Thread-safe map of client_uuid → Session."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._sessions: dict[str, Session] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def add(self, session: Session) -> Session | None:
|
||||
"""Register, returning a displaced same-client session (caller closes it)."""
|
||||
with self._lock:
|
||||
old = self._sessions.get(session.client_uuid)
|
||||
self._sessions[session.client_uuid] = session
|
||||
return old
|
||||
|
||||
def get(self, client_uuid: str) -> Session | None:
|
||||
with self._lock:
|
||||
return self._sessions.get(client_uuid)
|
||||
|
||||
def remove(self, client_uuid: str, expected: Session | None = None) -> Session | None:
|
||||
"""Remove by uuid; with ``expected``, only if it is still that exact session.
|
||||
|
||||
The identity check stops a GC sweep that snapshotted an old
|
||||
session from tearing down its just-handshaked replacement.
|
||||
"""
|
||||
with self._lock:
|
||||
current = self._sessions.get(client_uuid)
|
||||
if current is None or (expected is not None and current is not expected):
|
||||
return None
|
||||
return self._sessions.pop(client_uuid)
|
||||
|
||||
def snapshot(self) -> list[Session]:
|
||||
with self._lock:
|
||||
return list(self._sessions.values())
|
||||
|
||||
def __len__(self) -> int:
|
||||
with self._lock:
|
||||
return len(self._sessions)
|
||||
@@ -0,0 +1,265 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Serving-mode classification and session capability validation.
|
||||
|
||||
Multi-tenancy is engineered, not assumed: sharing one policy instance
|
||||
across sessions is only safe when ``predict_action_chunk`` touches no
|
||||
instance state. That property has been verified per policy family and
|
||||
is encoded here as an explicit registry — never inferred.
|
||||
|
||||
- ``act``/``pi0``/``pi05``: chunk-stateless (verified in-tree).
|
||||
- ``smolvla``: populates its ``_queues`` *inside* ``predict_action_chunk``;
|
||||
with ``n_obs_steps == 1`` the queue is overwritten with the request's
|
||||
own observation before being read, so sharing is safe. With history
|
||||
(``n_obs_steps > 1``) requests would read other sessions' frames →
|
||||
exclusive.
|
||||
- ``diffusion``: ``predict_action_chunk`` reads ``_queues`` that only
|
||||
``select_action`` populates → exclusive, with the server populating
|
||||
the observation queues per request (mirroring ``select_action``).
|
||||
- Policies without a ``predict_action_chunk`` override are refused.
|
||||
- Unverified chunk-API policies default to exclusive; ``shared`` cannot
|
||||
be forced for them (the roadmap upstreams a
|
||||
``supports_stateless_chunking`` attribute to policy classes).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
from .manifest import (
|
||||
SERVING_MODE_EXCLUSIVE,
|
||||
SERVING_MODE_SHARED,
|
||||
PolicyServerManifest,
|
||||
)
|
||||
from .schema import SessionOpenMsg, StatusMsg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ServingClass(Enum):
|
||||
SHARED = "shared"
|
||||
EXCLUSIVE = "exclusive"
|
||||
REFUSED = "refused"
|
||||
|
||||
|
||||
# Verified chunk-stateless families (predict_action_chunk touches no
|
||||
# cross-request instance state).
|
||||
VERIFIED_CHUNK_STATELESS: frozenset[str] = frozenset({"act", "pi0", "pi05"})
|
||||
|
||||
# Families whose predict_action_chunk reads select_action-fed queues:
|
||||
# the server must populate the observation queues per request.
|
||||
QUEUE_POPULATED_IN_SELECT: frozenset[str] = frozenset({"diffusion"})
|
||||
|
||||
# Families whose predict_action_chunk accepts the RTC kwargs
|
||||
# (inference_delay / prev_chunk_left_over) — see each family's
|
||||
# ActionSelectKwargs TypedDict.
|
||||
RTC_CAPABLE: frozenset[str] = frozenset({"pi0", "pi05", "smolvla"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyClassification:
|
||||
serving_class: ServingClass
|
||||
supports_rtc: bool
|
||||
needs_queue_population: bool
|
||||
reason: str
|
||||
|
||||
|
||||
def _has_chunk_api(policy: PreTrainedPolicy) -> bool:
|
||||
method = getattr(type(policy), "predict_action_chunk", None)
|
||||
return method is not None and method is not PreTrainedPolicy.predict_action_chunk
|
||||
|
||||
|
||||
def classify_policy(policy: PreTrainedPolicy) -> PolicyClassification:
|
||||
"""Classify a loaded policy into a serving class. Registry-driven, never inferred."""
|
||||
name = getattr(policy, "name", type(policy).__name__)
|
||||
|
||||
if not _has_chunk_api(policy):
|
||||
return PolicyClassification(
|
||||
ServingClass.REFUSED,
|
||||
supports_rtc=False,
|
||||
needs_queue_population=False,
|
||||
reason=f"policy '{name}' does not implement predict_action_chunk",
|
||||
)
|
||||
|
||||
supports_rtc = name in RTC_CAPABLE
|
||||
|
||||
if name in VERIFIED_CHUNK_STATELESS:
|
||||
return PolicyClassification(
|
||||
ServingClass.SHARED, supports_rtc, False, f"'{name}' is verified chunk-stateless"
|
||||
)
|
||||
|
||||
if name == "smolvla":
|
||||
n_obs_steps = getattr(policy.config, "n_obs_steps", 1)
|
||||
if n_obs_steps == 1:
|
||||
return PolicyClassification(
|
||||
ServingClass.SHARED,
|
||||
supports_rtc,
|
||||
False,
|
||||
"'smolvla' with n_obs_steps=1 overwrites its queues per request",
|
||||
)
|
||||
return PolicyClassification(
|
||||
ServingClass.EXCLUSIVE,
|
||||
supports_rtc,
|
||||
False,
|
||||
f"'smolvla' with n_obs_steps={n_obs_steps} keeps observation history across requests",
|
||||
)
|
||||
|
||||
if name in QUEUE_POPULATED_IN_SELECT:
|
||||
return PolicyClassification(
|
||||
ServingClass.EXCLUSIVE,
|
||||
supports_rtc,
|
||||
True,
|
||||
f"'{name}' predict_action_chunk reads select_action-fed queues",
|
||||
)
|
||||
|
||||
return PolicyClassification(
|
||||
ServingClass.EXCLUSIVE,
|
||||
supports_rtc,
|
||||
False,
|
||||
f"'{name}' has a chunk API but is not in the verified chunk-stateless registry",
|
||||
)
|
||||
|
||||
|
||||
def resolve_serving_mode(
|
||||
classification: PolicyClassification, manifest: PolicyServerManifest
|
||||
) -> tuple[str, int]:
|
||||
"""Resolve the final (serving_mode, max_sessions) from classification + manifest.
|
||||
|
||||
The manifest may force ``exclusive`` but can never force ``shared``
|
||||
for a policy that is not verified chunk-stateless.
|
||||
"""
|
||||
if classification.serving_class is ServingClass.REFUSED:
|
||||
raise ValueError(f"Refusing to serve this policy: {classification.reason}")
|
||||
|
||||
if manifest.serving_mode == SERVING_MODE_SHARED:
|
||||
if classification.serving_class is not ServingClass.SHARED:
|
||||
raise ValueError(
|
||||
f"serving_mode=shared is unsafe for this policy: {classification.reason}. "
|
||||
"Use serving_mode=exclusive (or auto)."
|
||||
)
|
||||
mode = SERVING_MODE_SHARED
|
||||
elif manifest.serving_mode == SERVING_MODE_EXCLUSIVE:
|
||||
mode = SERVING_MODE_EXCLUSIVE
|
||||
else: # auto
|
||||
mode = (
|
||||
SERVING_MODE_SHARED
|
||||
if classification.serving_class is ServingClass.SHARED
|
||||
else SERVING_MODE_EXCLUSIVE
|
||||
)
|
||||
|
||||
max_sessions = manifest.max_sessions
|
||||
if mode == SERVING_MODE_EXCLUSIVE and max_sessions != 1:
|
||||
logger.warning(
|
||||
"serving_mode=exclusive forces max_sessions=1 (manifest had %d)", manifest.max_sessions
|
||||
)
|
||||
max_sessions = 1
|
||||
return mode, max_sessions
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session-open validation (fail fast, fail loud)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
error: str = "" # non-empty → hard reject
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
# RTC requested but unsupported → downgrade to plain chunk-append.
|
||||
rtc_downgraded: bool = False
|
||||
|
||||
@property
|
||||
def ok(self) -> bool:
|
||||
return not self.error
|
||||
|
||||
|
||||
def validate_session_open(
|
||||
msg: SessionOpenMsg,
|
||||
capabilities: StatusMsg,
|
||||
manifest: PolicyServerManifest,
|
||||
active_sessions: int,
|
||||
) -> ValidationResult:
|
||||
"""Apply the capability matrix from the design doc (§8.4)."""
|
||||
result = ValidationResult()
|
||||
|
||||
# Schema version: client must be within the server's supported range.
|
||||
if not (capabilities.min_schema_version <= msg.schema_version <= capabilities.max_schema_version):
|
||||
result.error = (
|
||||
f"schema_version {msg.schema_version} outside supported range "
|
||||
f"[{capabilities.min_schema_version}, {capabilities.max_schema_version}]"
|
||||
)
|
||||
return result
|
||||
|
||||
# Capacity: reject with current load so the client can retry another replica.
|
||||
if active_sessions >= capabilities.max_sessions:
|
||||
result.error = f"server full: {active_sessions}/{capabilities.max_sessions} sessions active"
|
||||
return result
|
||||
|
||||
# Action names AND order: the hard sync-safety contract mapping
|
||||
# chunk columns to motors.
|
||||
if capabilities.action_names and msg.action_names != capabilities.action_names:
|
||||
result.error = (
|
||||
"action feature names/order mismatch — refusing to map chunk columns to motors.\n"
|
||||
f" server: {capabilities.action_names}\n"
|
||||
f" client: {msg.action_names}"
|
||||
)
|
||||
return result
|
||||
|
||||
# State dim.
|
||||
if capabilities.state_dim and msg.state_dim and msg.state_dim != capabilities.state_dim:
|
||||
result.error = f"state dim mismatch: server={capabilities.state_dim}, client={msg.state_dim}"
|
||||
return result
|
||||
|
||||
# Camera names: the client set must cover the policy's visual features.
|
||||
missing = set(capabilities.expected_cameras) - set(msg.camera_names)
|
||||
if missing:
|
||||
result.error = (
|
||||
f"missing camera features {sorted(missing)} "
|
||||
f"(client provides {sorted(msg.camera_names)}; resolution may differ — names may not)"
|
||||
)
|
||||
return result
|
||||
|
||||
# Task pinning.
|
||||
if manifest.pin_task and msg.task and msg.task != manifest.default_task:
|
||||
result.error = f"task is pinned to {manifest.default_task!r} on this server, got {msg.task!r}"
|
||||
return result
|
||||
|
||||
# fps: warn unless strict.
|
||||
if capabilities.trained_fps and msg.fps and abs(msg.fps - capabilities.trained_fps) > 1e-6:
|
||||
fps_msg = f"client fps={msg.fps:g} != trained fps={capabilities.trained_fps:g}"
|
||||
if manifest.strict_fps:
|
||||
result.error = fps_msg + " (strict_fps=true)"
|
||||
return result
|
||||
result.warnings.append(fps_msg)
|
||||
|
||||
# Policy type sanity (informational mismatch is a warning, not fatal:
|
||||
# the action/state/camera contracts above are the binding ones).
|
||||
if msg.policy_type and capabilities.policy_type and msg.policy_type != capabilities.policy_type:
|
||||
result.warnings.append(
|
||||
f"client expected policy_type={msg.policy_type!r}, server runs {capabilities.policy_type!r}"
|
||||
)
|
||||
|
||||
# RTC: requested but unsupported → serve plain chunks, client appends.
|
||||
if msg.rtc_enabled and not capabilities.supports_rtc:
|
||||
result.rtc_downgraded = True
|
||||
result.warnings.append(
|
||||
"RTC requested but this server/policy does not support it — downgrading to chunk-append"
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,101 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Zenoh session construction shared by the policy server and the remote engine.
|
||||
|
||||
Verified against eclipse-zenoh 1.9 (thread-based; no asyncio API).
|
||||
Multicast scouting is always disabled — fleet "discovery" is static
|
||||
endpoint configuration plus liveliness tokens, never protocol magic.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ZENOH_IMPORT_HINT = (
|
||||
"Remote inference requires the 'async' extra: pip install 'lerobot[async]' (eclipse-zenoh + msgpack)"
|
||||
)
|
||||
|
||||
|
||||
def import_zenoh():
|
||||
"""Import zenoh lazily with an actionable error message."""
|
||||
try:
|
||||
import zenoh
|
||||
except ImportError as e:
|
||||
raise ImportError(_ZENOH_IMPORT_HINT) from e
|
||||
return zenoh
|
||||
|
||||
|
||||
def build_zenoh_config(
|
||||
*,
|
||||
mode: str = "client",
|
||||
connect_endpoints: list[str] | None = None,
|
||||
listen_endpoints: list[str] | None = None,
|
||||
tls_root_ca_certificate: str | None = None,
|
||||
tls_connect_certificate: str | None = None,
|
||||
tls_connect_private_key: str | None = None,
|
||||
extra_config_json5: str | None = None,
|
||||
):
|
||||
"""Build a zenoh.Config (values are JSON5 strings — note the inner quoting)."""
|
||||
zenoh = import_zenoh()
|
||||
cfg = zenoh.Config()
|
||||
cfg.insert_json5("mode", json.dumps(mode))
|
||||
cfg.insert_json5("scouting/multicast/enabled", "false")
|
||||
if connect_endpoints:
|
||||
cfg.insert_json5("connect/endpoints", json.dumps(list(connect_endpoints)))
|
||||
if listen_endpoints:
|
||||
cfg.insert_json5("listen/endpoints", json.dumps(list(listen_endpoints)))
|
||||
if tls_root_ca_certificate:
|
||||
cfg.insert_json5("transport/link/tls/root_ca_certificate", json.dumps(tls_root_ca_certificate))
|
||||
if tls_connect_certificate:
|
||||
cfg.insert_json5("transport/link/tls/connect_certificate", json.dumps(tls_connect_certificate))
|
||||
if tls_connect_private_key:
|
||||
cfg.insert_json5("transport/link/tls/connect_private_key", json.dumps(tls_connect_private_key))
|
||||
if extra_config_json5:
|
||||
merged = json.loads(extra_config_json5)
|
||||
for key, value in merged.items():
|
||||
cfg.insert_json5(key, json.dumps(value))
|
||||
return cfg
|
||||
|
||||
|
||||
def action_publisher_qos(zenoh) -> dict:
|
||||
"""QoS for the action topic: RELIABLE + congestion DROP (never BLOCK) + express.
|
||||
|
||||
DROP so one dead robot uplink can never stall the server's publish
|
||||
path; a dropped chunk is recoverable by design — the client's action
|
||||
buffer keeps the robot moving and the next chunk replaces it.
|
||||
"""
|
||||
return {
|
||||
"reliability": zenoh.Reliability.RELIABLE,
|
||||
"congestion_control": zenoh.CongestionControl.DROP,
|
||||
"express": True,
|
||||
"priority": zenoh.Priority.INTERACTIVE_HIGH,
|
||||
}
|
||||
|
||||
|
||||
def obs_publisher_qos(zenoh) -> dict:
|
||||
"""QoS for the observation topic: best-effort drop, default priority.
|
||||
|
||||
Intentional drop already happened at the client's one-slot holder;
|
||||
if the uplink stalls, dropping a frame protects the control loop.
|
||||
"""
|
||||
return {
|
||||
"reliability": zenoh.Reliability.BEST_EFFORT,
|
||||
"congestion_control": zenoh.CongestionControl.DROP,
|
||||
"express": False,
|
||||
"priority": zenoh.Priority.DATA,
|
||||
}
|
||||
@@ -32,7 +32,6 @@ from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
@@ -281,6 +280,11 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
|
||||
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
_serialized_state_filenames: tuple[str | None, ...] | None = field(
|
||||
default=None,
|
||||
init=False,
|
||||
repr=False,
|
||||
)
|
||||
|
||||
def __call__(self, data: TInput) -> TOutput:
|
||||
"""Processes input data through the full pipeline.
|
||||
@@ -338,30 +342,108 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
transition = processor_step(transition)
|
||||
yield transition
|
||||
|
||||
def _save_pretrained(self, save_directory: Path, **kwargs):
|
||||
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||
def _get_sanitized_name(self) -> str:
|
||||
"""Return a filename-safe version of the pipeline name.
|
||||
|
||||
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||
Returns:
|
||||
The lower-cased pipeline name with non-alphanumeric characters replaced by underscores.
|
||||
"""
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||
|
||||
# Sanitize the pipeline name to create a valid filename prefix.
|
||||
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||
@staticmethod
|
||||
def _get_state_filename(
|
||||
*,
|
||||
step_index: int,
|
||||
registry_name: str | None,
|
||||
sanitized_name: str,
|
||||
) -> str:
|
||||
"""Return the safetensors filename for one stateful processor step.
|
||||
|
||||
if config_filename is None:
|
||||
config_filename = f"{sanitized_name}.json"
|
||||
Args:
|
||||
step_index: The index of the processor step in this pipeline.
|
||||
registry_name: The registered processor step name, if available.
|
||||
sanitized_name: The filename-safe pipeline name.
|
||||
|
||||
config: dict[str, Any] = {
|
||||
Returns:
|
||||
The state filename used by the existing disk serialization format.
|
||||
"""
|
||||
if registry_name:
|
||||
return f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||||
|
||||
return f"{sanitized_name}_step_{step_index}.safetensors"
|
||||
|
||||
@staticmethod
|
||||
def _get_state_key(state_filename: str) -> str:
|
||||
"""Return the in-memory state key for a serialized state filename.
|
||||
|
||||
Args:
|
||||
state_filename: The `.safetensors` filename from the serialized config.
|
||||
|
||||
Returns:
|
||||
The state key used by the in-memory pipeline state dictionary.
|
||||
"""
|
||||
return state_filename.removesuffix(".safetensors")
|
||||
|
||||
@staticmethod
|
||||
def _get_state_filenames_from_config(loaded_config: dict[str, Any]) -> tuple[str | None, ...]:
|
||||
"""Return serialized state filenames in step order.
|
||||
|
||||
Args:
|
||||
loaded_config: A validated processor pipeline config.
|
||||
|
||||
Returns:
|
||||
A tuple containing each step's serialized state filename, or None for stateless steps.
|
||||
"""
|
||||
return tuple(step_entry.get("state_file") for step_entry in loaded_config["steps"])
|
||||
|
||||
def _get_state_filenames_for_loading(self) -> tuple[str | None, ...]:
|
||||
"""Return expected state filenames in step order for `load_state_dict()`.
|
||||
|
||||
Returns:
|
||||
The preserved serialized state filenames when available, otherwise filenames derived from
|
||||
current non-empty step state.
|
||||
"""
|
||||
if self._serialized_state_filenames is not None and len(self._serialized_state_filenames) == len(
|
||||
self.steps
|
||||
):
|
||||
return self._serialized_state_filenames
|
||||
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
state_filenames: list[str | None] = []
|
||||
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if not step_state_dict:
|
||||
state_filenames.append(None)
|
||||
continue
|
||||
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
state_filenames.append(
|
||||
self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
)
|
||||
|
||||
return tuple(state_filenames)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return the JSON-serializable pipeline configuration.
|
||||
|
||||
Returns:
|
||||
A dictionary with the same content that `save_pretrained()` writes as JSON.
|
||||
"""
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
pipeline_config: dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"steps": [],
|
||||
}
|
||||
|
||||
# Iterate through each step to build its configuration entry.
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
|
||||
step_entry: dict[str, Any] = {}
|
||||
# Prefer registry name for portability, otherwise fall back to full class path.
|
||||
|
||||
if registry_name:
|
||||
step_entry["registry_name"] = registry_name
|
||||
else:
|
||||
@@ -369,31 +451,110 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
|
||||
)
|
||||
|
||||
# Save step configuration if `get_config` is implemented.
|
||||
if hasattr(processor_step, "get_config"):
|
||||
step_entry["config"] = processor_step.get_config()
|
||||
step_entry["config"] = processor_step.get_config()
|
||||
|
||||
# Save step state if `state_dict` is implemented and returns a non-empty dict.
|
||||
if hasattr(processor_step, "state_dict"):
|
||||
state = processor_step.state_dict()
|
||||
if state:
|
||||
# Clone tensors to avoid modifying the original state.
|
||||
cloned_state = {key: tensor.clone() for key, tensor in state.items()}
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if step_state_dict:
|
||||
step_entry["state_file"] = self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
|
||||
# Create a unique filename for the state file.
|
||||
if registry_name:
|
||||
state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||||
else:
|
||||
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
|
||||
pipeline_config["steps"].append(step_entry)
|
||||
|
||||
save_file(cloned_state, os.path.join(str(save_directory), state_filename))
|
||||
step_entry["state_file"] = state_filename
|
||||
return pipeline_config
|
||||
|
||||
config["steps"].append(step_entry)
|
||||
def state_dict(self) -> dict[str, dict[str, torch.Tensor]]:
|
||||
"""Return pipeline state tensors grouped by state key.
|
||||
|
||||
# Write the main configuration JSON file.
|
||||
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
|
||||
json.dump(config, file_pointer, indent=2)
|
||||
Returns:
|
||||
A dictionary mapping suffixless state keys to cloned step state dictionaries.
|
||||
"""
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
pipeline_state_dict: dict[str, dict[str, torch.Tensor]] = {}
|
||||
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if not step_state_dict:
|
||||
continue
|
||||
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
state_filename = self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
state_key = self._get_state_key(state_filename)
|
||||
pipeline_state_dict[state_key] = {
|
||||
tensor_name: tensor.clone() for tensor_name, tensor in step_state_dict.items()
|
||||
}
|
||||
|
||||
return pipeline_state_dict
|
||||
|
||||
def load_state_dict(
|
||||
self,
|
||||
state_dict: dict[str, dict[str, torch.Tensor]],
|
||||
) -> None:
|
||||
"""Load pipeline state tensors into the existing steps.
|
||||
|
||||
Args:
|
||||
state_dict: A dictionary mapping suffixless state keys to step state dictionaries.
|
||||
|
||||
Raises:
|
||||
KeyError: If loading finds missing expected state or unexpected extra state.
|
||||
"""
|
||||
expected_state_filenames = self._get_state_filenames_for_loading()
|
||||
used_state_keys: set[str] = set()
|
||||
|
||||
for step_index, (processor_step, state_filename) in enumerate(
|
||||
zip(self.steps, expected_state_filenames, strict=True)
|
||||
):
|
||||
if state_filename is None:
|
||||
continue
|
||||
|
||||
state_key = self._get_state_key(state_filename)
|
||||
if state_key not in state_dict:
|
||||
raise KeyError(
|
||||
f"Missing state key '{state_key}' for processor step {step_index}. "
|
||||
f"Available state keys: {sorted(state_dict.keys())}"
|
||||
)
|
||||
|
||||
processor_step.load_state_dict(state_dict[state_key])
|
||||
used_state_keys.add(state_key)
|
||||
|
||||
unexpected_state_keys = set(state_dict) - used_state_keys
|
||||
if unexpected_state_keys:
|
||||
expected_state_key_set = {
|
||||
self._get_state_key(state_filename)
|
||||
for state_filename in expected_state_filenames
|
||||
if state_filename is not None
|
||||
}
|
||||
raise KeyError(
|
||||
f"Unexpected processor state keys: {sorted(unexpected_state_keys)}. "
|
||||
f"Expected state keys: {sorted(expected_state_key_set)}"
|
||||
)
|
||||
|
||||
def _save_pretrained(self, save_directory: Path, **kwargs) -> None:
|
||||
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||
|
||||
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||
"""
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
|
||||
if config_filename is None:
|
||||
config_filename = f"{sanitized_name}.json"
|
||||
|
||||
pipeline_config = self.get_config()
|
||||
pipeline_state_dict = self.state_dict()
|
||||
|
||||
for state_key, step_state_dict in pipeline_state_dict.items():
|
||||
state_filename = f"{state_key}.safetensors"
|
||||
save_file(step_state_dict, save_directory / state_filename)
|
||||
|
||||
with open(save_directory / config_filename, "w") as file_pointer:
|
||||
json.dump(pipeline_config, file_pointer, indent=2)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
@@ -577,12 +738,54 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
cls._validate_overrides_used(validated_overrides, loaded_config)
|
||||
|
||||
# 5. Construct and return the final pipeline instance
|
||||
return cls(
|
||||
pipeline = cls(
|
||||
steps=steps,
|
||||
name=loaded_config.get("name", "DataProcessorPipeline"),
|
||||
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||
)
|
||||
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(loaded_config)
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: dict[str, Any],
|
||||
*,
|
||||
state_dict: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
overrides: dict[str, Any] | None = None,
|
||||
to_transition: Callable[[TInput], EnvTransition] | None = None,
|
||||
to_output: Callable[[EnvTransition], TOutput] | None = None,
|
||||
) -> DataProcessorPipeline[TInput, TOutput]:
|
||||
"""Build a pipeline from an in-memory config and optional state tensors.
|
||||
|
||||
Args:
|
||||
config: A config dictionary with the same structure as the saved processor JSON.
|
||||
state_dict: Optional in-memory pipeline state grouped by suffixless state key.
|
||||
overrides: Optional constructor overrides keyed by registry name or class name.
|
||||
to_transition: Optional converter from input data to `EnvTransition`.
|
||||
to_output: Optional converter from `EnvTransition` to output data.
|
||||
|
||||
Returns:
|
||||
A processor pipeline built from the config and optional state.
|
||||
"""
|
||||
cls._validate_loaded_config("<in-memory config>", config, "<in-memory config>")
|
||||
|
||||
steps, remaining_override_keys = cls._build_steps_from_config(config, overrides or {})
|
||||
cls._validate_overrides_used(remaining_override_keys, config)
|
||||
|
||||
pipeline = cls(
|
||||
steps=steps,
|
||||
name=config.get("name", "DataProcessorPipeline"),
|
||||
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||
)
|
||||
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(config)
|
||||
|
||||
if state_dict is not None:
|
||||
pipeline.load_state_dict(state_dict)
|
||||
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
def _load_config(
|
||||
@@ -666,9 +869,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
) from e
|
||||
|
||||
@classmethod
|
||||
def _validate_loaded_config(
|
||||
cls, model_id: str, loaded_config: dict[str, Any], config_filename: str
|
||||
) -> None:
|
||||
def _validate_loaded_config(cls, model_id: str, loaded_config: Any, config_filename: str) -> None:
|
||||
"""Validate that a config was loaded and is a valid processor config.
|
||||
|
||||
This method validates processor config format with intelligent migration detection:
|
||||
@@ -688,7 +889,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
|
||||
Args:
|
||||
model_id: The model identifier (used for migration detection)
|
||||
loaded_config: The loaded config dictionary (guaranteed non-None)
|
||||
loaded_config: The loaded config value to validate (may be non-dict)
|
||||
config_filename: The config filename that was loaded (for error messages)
|
||||
|
||||
Raises:
|
||||
@@ -702,9 +903,14 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
model_id,
|
||||
f"Config file '{config_filename}' is not a valid processor configuration",
|
||||
)
|
||||
loaded_config_description = (
|
||||
list(loaded_config.keys())
|
||||
if isinstance(loaded_config, dict)
|
||||
else type(loaded_config).__name__
|
||||
)
|
||||
raise ValueError(
|
||||
f"Config file '{config_filename}' is not a valid processor configuration. "
|
||||
f"Expected a config with 'steps' field, but got: {list(loaded_config.keys())}"
|
||||
f"Expected a config with 'steps' field, but got: {loaded_config_description}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -766,26 +972,41 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
ImportError: If a step class cannot be imported or found in registry
|
||||
ValueError: If a step cannot be instantiated with its configuration
|
||||
"""
|
||||
steps: list[ProcessorStep] = []
|
||||
override_keys = set(overrides.keys())
|
||||
steps, remaining_override_keys = cls._build_steps_from_config(loaded_config, overrides)
|
||||
|
||||
for step_entry in loaded_config["steps"]:
|
||||
# 1. Get step class and key
|
||||
step_class, step_key = cls._resolve_step_class(step_entry)
|
||||
|
||||
# 2. Instantiate step with overrides
|
||||
step_instance = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
||||
|
||||
# 3. Load step state if available
|
||||
for step_instance, step_entry in zip(steps, loaded_config["steps"], strict=True):
|
||||
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
|
||||
|
||||
# 4. Track used overrides
|
||||
if step_key in override_keys:
|
||||
override_keys.discard(step_key)
|
||||
return steps, remaining_override_keys
|
||||
|
||||
steps.append(step_instance)
|
||||
@classmethod
|
||||
def _build_steps_from_config(
|
||||
cls,
|
||||
loaded_config: dict[str, Any],
|
||||
overrides: dict[str, Any],
|
||||
) -> tuple[list[ProcessorStep], set[str]]:
|
||||
"""Build processor steps from config without loading tensor state.
|
||||
|
||||
return steps, override_keys
|
||||
Args:
|
||||
loaded_config: The loaded processor configuration.
|
||||
overrides: User-provided constructor overrides keyed by step key.
|
||||
|
||||
Returns:
|
||||
A tuple containing instantiated steps and override keys that did not match a step.
|
||||
"""
|
||||
processor_steps: list[ProcessorStep] = []
|
||||
remaining_override_keys = set(overrides.keys())
|
||||
|
||||
for step_entry in loaded_config["steps"]:
|
||||
step_class, step_key = cls._resolve_step_class(step_entry)
|
||||
processor_step = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
||||
|
||||
if step_key in remaining_override_keys:
|
||||
remaining_override_keys.discard(step_key)
|
||||
|
||||
processor_steps.append(processor_step)
|
||||
|
||||
return processor_steps, remaining_override_keys
|
||||
|
||||
@classmethod
|
||||
def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]:
|
||||
@@ -1096,7 +1317,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def _is_processor_config(cls, config: dict) -> bool:
|
||||
def _is_processor_config(cls, config: Any) -> bool:
|
||||
"""Check if config follows DataProcessorPipeline format.
|
||||
|
||||
This method validates the processor configuration structure:
|
||||
@@ -1147,6 +1368,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
Returns:
|
||||
True if config follows valid DataProcessorPipeline format, False otherwise
|
||||
"""
|
||||
if not isinstance(config, dict):
|
||||
return False
|
||||
|
||||
# Must have a "steps" field with a list of step configurations
|
||||
if not isinstance(config.get("steps"), list):
|
||||
return False
|
||||
|
||||
@@ -23,6 +23,7 @@ from .configs import (
|
||||
DAggerKeyboardConfig,
|
||||
DAggerPedalConfig,
|
||||
DAggerStrategyConfig,
|
||||
EpisodicStrategyConfig,
|
||||
HighlightStrategyConfig,
|
||||
RolloutConfig,
|
||||
RolloutStrategyConfig,
|
||||
@@ -38,8 +39,10 @@ from .context import (
|
||||
build_rollout_context,
|
||||
)
|
||||
from .inference import (
|
||||
FallbackMode,
|
||||
InferenceEngine,
|
||||
InferenceEngineConfig,
|
||||
RemoteInferenceConfig,
|
||||
RTCInferenceConfig,
|
||||
RTCInferenceEngine,
|
||||
SyncInferenceConfig,
|
||||
@@ -49,6 +52,7 @@ from .inference import (
|
||||
from .strategies import (
|
||||
BaseStrategy,
|
||||
DAggerStrategy,
|
||||
EpisodicStrategy,
|
||||
HighlightStrategy,
|
||||
RolloutStrategy,
|
||||
SentryStrategy,
|
||||
@@ -66,12 +70,16 @@ __all__ = [
|
||||
"HardwareContext",
|
||||
"HighlightStrategy",
|
||||
"HighlightStrategyConfig",
|
||||
"EpisodicStrategy",
|
||||
"EpisodicStrategyConfig",
|
||||
"FallbackMode",
|
||||
"InferenceEngine",
|
||||
"InferenceEngineConfig",
|
||||
"PolicyContext",
|
||||
"ProcessorContext",
|
||||
"RTCInferenceConfig",
|
||||
"RTCInferenceEngine",
|
||||
"RemoteInferenceConfig",
|
||||
"RolloutConfig",
|
||||
"RolloutContext",
|
||||
"RolloutStrategy",
|
||||
|
||||
@@ -121,6 +121,35 @@ class DAggerPedalConfig:
|
||||
upload: str = "KEY_C"
|
||||
|
||||
|
||||
@RolloutStrategyConfig.register_subclass("episodic")
|
||||
@dataclass
|
||||
class EpisodicStrategyConfig(RolloutStrategyConfig):
|
||||
"""Episode-oriented recording that mirrors the behavior of ``lerobot-record``.
|
||||
|
||||
Records ``dataset.num_episodes`` episodes of maximum ``dataset.episode_time_s`` each.
|
||||
After each episode, runs ``dataset.reset_time_s`` seconds of reset time.
|
||||
|
||||
Keyboard controls:
|
||||
Right arrow — end current episode or reset phase early
|
||||
Left arrow — discard current episode and re-record
|
||||
Escape — stop recording session
|
||||
|
||||
In between episodes:
|
||||
- if there is no teleop leader, the robot is held at its initial joint positions captured at startup.
|
||||
- else, the robot is moved smoothly to the position of the teleop leader.
|
||||
"""
|
||||
|
||||
# This only applies if there are no teleop leaders specified.
|
||||
# When True (default), moves the robot back to the joint positions captured at startup.
|
||||
# Otherwise, leave the robot in its current position.
|
||||
reset_to_initial_position: bool = True
|
||||
|
||||
# Whether to turn on or off the leader -> follower smooth handover behavior.
|
||||
# When False, fallback to follower -> leader handover.
|
||||
# Note that leader -> follower handover is only supported when the leader has `send_feedback` capability.
|
||||
smooth_leader_to_follower_handover: bool = True
|
||||
|
||||
|
||||
@RolloutStrategyConfig.register_subclass("dagger")
|
||||
@dataclass
|
||||
class DAggerStrategyConfig(RolloutStrategyConfig):
|
||||
@@ -229,7 +258,13 @@ class RolloutConfig:
|
||||
|
||||
# TODO(Steven): DAgger shouldn't require a dataset (user may want to just rollout+intervene without recording), but for now we require it to simplify the implementation.
|
||||
needs_dataset = isinstance(
|
||||
self.strategy, (SentryStrategyConfig, HighlightStrategyConfig, DAggerStrategyConfig)
|
||||
self.strategy,
|
||||
(
|
||||
SentryStrategyConfig,
|
||||
HighlightStrategyConfig,
|
||||
DAggerStrategyConfig,
|
||||
EpisodicStrategyConfig,
|
||||
),
|
||||
)
|
||||
if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
|
||||
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
|
||||
|
||||
@@ -51,6 +51,7 @@ from lerobot.utils.feature_utils import combine_feature_dicts, hw_to_dataset_fea
|
||||
from .configs import BaseStrategyConfig, DAggerStrategyConfig, RolloutConfig
|
||||
from .inference import (
|
||||
InferenceEngine,
|
||||
RemoteInferenceConfig,
|
||||
RTCInferenceConfig,
|
||||
SyncInferenceConfig,
|
||||
create_inference_engine,
|
||||
@@ -113,11 +114,17 @@ class HardwareContext:
|
||||
|
||||
@dataclass
|
||||
class PolicyContext:
|
||||
"""Loaded policy and its inference engine."""
|
||||
"""Loaded policy and its inference engine.
|
||||
|
||||
policy: PreTrainedPolicy
|
||||
preprocessor: PolicyProcessorPipeline
|
||||
postprocessor: PolicyProcessorPipeline
|
||||
``policy``/``preprocessor``/``postprocessor`` are ``None`` for the
|
||||
weightless remote backend (``--inference.type=remote``): inference
|
||||
runs on a ``lerobot-policy-server`` and strategies only ever consume
|
||||
``inference``.
|
||||
"""
|
||||
|
||||
policy: PreTrainedPolicy | None
|
||||
preprocessor: PolicyProcessorPipeline | None
|
||||
postprocessor: PolicyProcessorPipeline | None
|
||||
inference: InferenceEngine
|
||||
|
||||
|
||||
@@ -172,54 +179,66 @@ def build_rollout_context(
|
||||
fails fast without touching the robot.
|
||||
"""
|
||||
is_rtc = isinstance(cfg.inference, RTCInferenceConfig)
|
||||
is_remote = isinstance(cfg.inference, RemoteInferenceConfig)
|
||||
|
||||
# --- 1. Policy (heavy I/O, but no hardware yet) -------------------
|
||||
logger.info("Loading policy from '%s'...", cfg.policy.pretrained_path)
|
||||
# Remote inference keeps the edge weightless: the config-only
|
||||
# PreTrainedConfig (already loaded by RolloutConfig.__post_init__,
|
||||
# no weight download) is all the client needs for pre-flight
|
||||
# validation and action ordering.
|
||||
policy_config = cfg.policy
|
||||
policy_class = get_policy_class(policy_config.type)
|
||||
|
||||
if hasattr(policy_config, "compile_model"):
|
||||
policy_config.compile_model = cfg.use_torch_compile
|
||||
|
||||
if policy_config.type == "vqbet" and cfg.device == "mps":
|
||||
raise NotImplementedError(
|
||||
"Current implementation of VQBeT does not support `mps` backend. "
|
||||
"Please use `cpu` or `cuda` backend."
|
||||
policy = None
|
||||
if is_remote:
|
||||
logger.info(
|
||||
"Remote inference: weightless client for '%s' (no weights downloaded)",
|
||||
cfg.policy.pretrained_path,
|
||||
)
|
||||
|
||||
if policy_config.use_peft:
|
||||
from peft import PeftConfig, PeftModel
|
||||
|
||||
peft_path = policy_config.pretrained_path
|
||||
peft_config = PeftConfig.from_pretrained(peft_path)
|
||||
policy = policy_class.from_pretrained(
|
||||
pretrained_name_or_path=peft_config.base_model_name_or_path, config=policy_config
|
||||
)
|
||||
policy = PeftModel.from_pretrained(policy, peft_path, config=peft_config)
|
||||
else:
|
||||
policy = policy_class.from_pretrained(policy_config.pretrained_path, config=policy_config)
|
||||
logger.info("Loading policy from '%s'...", cfg.policy.pretrained_path)
|
||||
policy_class = get_policy_class(policy_config.type)
|
||||
|
||||
if is_rtc:
|
||||
policy.config.rtc_config = cfg.inference.rtc
|
||||
if hasattr(policy, "init_rtc_processor"):
|
||||
policy.init_rtc_processor()
|
||||
if hasattr(policy_config, "compile_model"):
|
||||
policy_config.compile_model = cfg.use_torch_compile
|
||||
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
logger.info("Policy loaded: type=%s, device=%s", policy_config.type, cfg.device)
|
||||
if policy_config.type == "vqbet" and cfg.device == "mps":
|
||||
raise NotImplementedError(
|
||||
"Current implementation of VQBeT does not support `mps` backend. "
|
||||
"Please use `cpu` or `cuda` backend."
|
||||
)
|
||||
|
||||
if cfg.use_torch_compile and policy.type not in ("pi0", "pi05"):
|
||||
try:
|
||||
if hasattr(torch, "compile"):
|
||||
compile_kwargs = {
|
||||
"backend": cfg.torch_compile_backend,
|
||||
"mode": cfg.torch_compile_mode,
|
||||
"options": {"triton.cudagraphs": False},
|
||||
}
|
||||
policy.predict_action_chunk = torch.compile(policy.predict_action_chunk, **compile_kwargs)
|
||||
logger.info("torch.compile applied to predict_action_chunk")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to apply torch.compile: %s", e)
|
||||
if policy_config.use_peft:
|
||||
from peft import PeftConfig, PeftModel
|
||||
|
||||
peft_path = policy_config.pretrained_path
|
||||
peft_config = PeftConfig.from_pretrained(peft_path)
|
||||
policy = policy_class.from_pretrained(
|
||||
pretrained_name_or_path=peft_config.base_model_name_or_path, config=policy_config
|
||||
)
|
||||
policy = PeftModel.from_pretrained(policy, peft_path, config=peft_config)
|
||||
else:
|
||||
policy = policy_class.from_pretrained(policy_config.pretrained_path, config=policy_config)
|
||||
|
||||
if is_rtc:
|
||||
policy.config.rtc_config = cfg.inference.rtc
|
||||
if hasattr(policy, "init_rtc_processor"):
|
||||
policy.init_rtc_processor()
|
||||
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
logger.info("Policy loaded: type=%s, device=%s", policy_config.type, cfg.device)
|
||||
|
||||
if cfg.use_torch_compile and policy.type not in ("pi0", "pi05"):
|
||||
try:
|
||||
if hasattr(torch, "compile"):
|
||||
compile_kwargs = {
|
||||
"backend": cfg.torch_compile_backend,
|
||||
"mode": cfg.torch_compile_mode,
|
||||
"options": {"triton.cudagraphs": False},
|
||||
}
|
||||
policy.predict_action_chunk = torch.compile(policy.predict_action_chunk, **compile_kwargs)
|
||||
logger.info("torch.compile applied to predict_action_chunk")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to apply torch.compile: %s", e)
|
||||
|
||||
# --- 2. Robot-side processors (user-supplied or defaults) --------
|
||||
if (
|
||||
@@ -378,31 +397,36 @@ def build_rollout_context(
|
||||
logger.info("Dataset ready: %s (%d existing episodes)", dataset.repo_id, dataset.num_episodes)
|
||||
|
||||
# --- 6. Policy pre/post processors (needs dataset stats if any) ---
|
||||
dataset_stats = None
|
||||
if dataset is not None:
|
||||
dataset_stats = rename_stats(
|
||||
dataset.meta.stats,
|
||||
cfg.rename_map,
|
||||
# Remote inference runs the policy processors server-side (per
|
||||
# session); the edge ships canonical dataset-format observations.
|
||||
preprocessor = None
|
||||
postprocessor = None
|
||||
if not is_remote:
|
||||
dataset_stats = None
|
||||
if dataset is not None:
|
||||
dataset_stats = rename_stats(
|
||||
dataset.meta.stats,
|
||||
cfg.rename_map,
|
||||
)
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy_config,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=dataset_stats,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy_config,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=dataset_stats,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
if isinstance(cfg.inference, SyncInferenceConfig) and any(
|
||||
isinstance(step, RelativeActionsProcessorStep) and step.enabled
|
||||
for step in getattr(preprocessor, "steps", ())
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"SyncInferenceEngine does not support policies with relative actions for now."
|
||||
"Use --inference.type=rtc or remove relative action processor steps from the policy pipeline."
|
||||
)
|
||||
if isinstance(cfg.inference, SyncInferenceConfig) and any(
|
||||
isinstance(step, RelativeActionsProcessorStep) and step.enabled
|
||||
for step in getattr(preprocessor, "steps", ())
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"SyncInferenceEngine does not support policies with relative actions for now."
|
||||
"Use --inference.type=rtc or remove relative action processor steps from the policy pipeline."
|
||||
)
|
||||
|
||||
# --- 7. Inference strategy (needs policy + pre/post + hardware) --
|
||||
logger.info(
|
||||
@@ -425,6 +449,8 @@ def build_rollout_context(
|
||||
use_torch_compile=cfg.use_torch_compile,
|
||||
compile_warmup_inferences=cfg.compile_warmup_inferences,
|
||||
shutdown_event=shutdown_event,
|
||||
policy_config=policy_config,
|
||||
rename_map=cfg.rename_map,
|
||||
)
|
||||
|
||||
# --- 8. Assemble ---------------------------------------------------
|
||||
|
||||
@@ -14,13 +14,18 @@
|
||||
|
||||
"""Inference engine package — backend-agnostic action production.
|
||||
|
||||
Concrete backends (``sync``, ``rtc``, ...) expose the same small interface so
|
||||
rollout strategies never branch on which backend is in use.
|
||||
Concrete backends (``sync``, ``rtc``, ``remote``, ...) expose the same
|
||||
small interface so rollout strategies never branch on which backend is
|
||||
in use.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .base import InferenceEngine
|
||||
from .factory import (
|
||||
FallbackMode,
|
||||
InferenceEngineConfig,
|
||||
RemoteInferenceConfig,
|
||||
RTCInferenceConfig,
|
||||
SyncInferenceConfig,
|
||||
create_inference_engine,
|
||||
@@ -29,11 +34,23 @@ from .rtc import RTCInferenceEngine
|
||||
from .sync import SyncInferenceEngine
|
||||
|
||||
__all__ = [
|
||||
"FallbackMode",
|
||||
"InferenceEngine",
|
||||
"InferenceEngineConfig",
|
||||
"RTCInferenceConfig",
|
||||
"RTCInferenceEngine",
|
||||
"RemoteInferenceConfig",
|
||||
"RemoteInferenceEngine",
|
||||
"SyncInferenceConfig",
|
||||
"SyncInferenceEngine",
|
||||
"create_inference_engine",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
# Lazy: RemoteInferenceEngine pulls in msgpack/zenoh ('async' extra).
|
||||
if name == "RemoteInferenceEngine":
|
||||
from .remote import RemoteInferenceEngine
|
||||
|
||||
return RemoteInferenceEngine
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@@ -14,9 +14,9 @@
|
||||
|
||||
"""Inference engine configs and factory.
|
||||
|
||||
Selection is explicit via ``--inference.type=sync|rtc``. Adding a new
|
||||
backend requires registering its config subclass and dispatching it in
|
||||
:func:`create_inference_engine`.
|
||||
Selection is explicit via ``--inference.type=sync|rtc|remote``. Adding a
|
||||
new backend requires registering its config subclass and dispatching it
|
||||
in :func:`create_inference_engine`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -24,10 +24,12 @@ from __future__ import annotations
|
||||
import abc
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
from threading import Event
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
@@ -74,6 +76,73 @@ class RTCInferenceConfig(InferenceEngineConfig):
|
||||
queue_threshold: int = 30
|
||||
|
||||
|
||||
class FallbackMode(StrEnum):
|
||||
"""What ``get_action`` returns when the remote queue runs dry (STALLED)."""
|
||||
|
||||
HOLD = "hold" # return None: the robot holds its last commanded position
|
||||
REPEAT_LAST = "repeat_last" # re-send the last executed action
|
||||
ZERO = "zero" # explicit zero command (required for velocity-controlled robots)
|
||||
|
||||
|
||||
@InferenceEngineConfig.register_subclass("remote")
|
||||
@dataclass
|
||||
class RemoteInferenceConfig(InferenceEngineConfig):
|
||||
"""Network inference against a ``lerobot-policy-server`` over Zenoh.
|
||||
|
||||
The edge stays weightless: ``--policy.path`` resolves to a
|
||||
config-only ``PreTrainedConfig`` (no weight download) used for
|
||||
pre-flight validation and action ordering. Requires the ``async``
|
||||
extra (``pip install 'lerobot[async]'``).
|
||||
"""
|
||||
|
||||
# Transport: robots dial out to a zenoh router (NAT-friendly).
|
||||
connect_endpoint: str = "tcp/localhost:7447"
|
||||
# "client" via a zenohd router (production) | "peer" direct (LAN/tests).
|
||||
zenoh_mode: str = "client"
|
||||
tls_ca: str | None = None
|
||||
tls_cert: str | None = None
|
||||
tls_key: str | None = None
|
||||
|
||||
# Service addressing: which (model, revision, task) key tree to dial.
|
||||
# service_model_id defaults to --policy.path; service_task to the
|
||||
# rollout task. These must match the server manifest's namespace.
|
||||
service_model_id: str = ""
|
||||
service_revision: str = "main"
|
||||
service_task: str = ""
|
||||
|
||||
# Identity: "" → a fresh uuid4 per run. Set a stable ID per robot for
|
||||
# fleet-wide log correlation and per-robot router ACLs.
|
||||
client_uuid: str = ""
|
||||
|
||||
# Observation encoding: JPEG quality (0 = raw, LAN/debug only).
|
||||
jpeg_quality: int = 90
|
||||
|
||||
# Self-clocking: request the next chunk when the local queue holds
|
||||
# less than this many seconds of playback.
|
||||
buffer_time_s: float = 0.5
|
||||
|
||||
# Safety: never execute an action whose source observation is older
|
||||
# than this (bounds open-loop execution after a network stall).
|
||||
max_action_age_s: float = 3.0
|
||||
# Fallback when the queue runs dry (see FallbackMode).
|
||||
fallback: FallbackMode = FallbackMode.HOLD
|
||||
|
||||
# Watchdogs & reconnection.
|
||||
degraded_after_s: float = 1.0
|
||||
request_timeout_s: float = 5.0
|
||||
handshake_timeout_s: float = 2.0
|
||||
reconnect_initial_backoff_s: float = 0.5
|
||||
reconnect_max_backoff_s: float = 10.0
|
||||
max_offline_s: float = 60.0
|
||||
|
||||
# RTC settings (enabled → replace-merge with prefix conditioning when
|
||||
# the server supports it; otherwise downgraded to chunk-append).
|
||||
rtc: RTCConfig = field(default_factory=RTCConfig)
|
||||
|
||||
# Free-form labels forwarded in the session handshake (telemetry only).
|
||||
tags: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Factory
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -82,9 +151,9 @@ class RTCInferenceConfig(InferenceEngineConfig):
|
||||
def create_inference_engine(
|
||||
config: InferenceEngineConfig,
|
||||
*,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
policy: PreTrainedPolicy | None,
|
||||
preprocessor: PolicyProcessorPipeline | None,
|
||||
postprocessor: PolicyProcessorPipeline | None,
|
||||
robot_wrapper: ThreadSafeRobot,
|
||||
hw_features: dict,
|
||||
dataset_features: dict,
|
||||
@@ -95,10 +164,19 @@ def create_inference_engine(
|
||||
use_torch_compile: bool = False,
|
||||
compile_warmup_inferences: int = 2,
|
||||
shutdown_event: Event | None = None,
|
||||
policy_config: PreTrainedConfig | None = None,
|
||||
rename_map: dict[str, str] | None = None,
|
||||
) -> InferenceEngine:
|
||||
"""Instantiate the appropriate inference engine from a config object."""
|
||||
"""Instantiate the appropriate inference engine from a config object.
|
||||
|
||||
``policy``/``preprocessor``/``postprocessor`` are required for the
|
||||
local backends (``sync``, ``rtc``) and must be ``None``-free there;
|
||||
the ``remote`` backend is weightless and needs only ``policy_config``.
|
||||
"""
|
||||
logger.info("Creating inference engine: %s", config.type)
|
||||
if isinstance(config, SyncInferenceConfig):
|
||||
if policy is None or preprocessor is None or postprocessor is None:
|
||||
raise ValueError("sync inference requires a loaded policy and processors")
|
||||
return SyncInferenceEngine(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
@@ -110,6 +188,8 @@ def create_inference_engine(
|
||||
robot_type=robot_wrapper.robot_type,
|
||||
)
|
||||
if isinstance(config, RTCInferenceConfig):
|
||||
if policy is None or preprocessor is None or postprocessor is None:
|
||||
raise ValueError("rtc inference requires a loaded policy and processors")
|
||||
return RTCInferenceEngine(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
@@ -125,4 +205,25 @@ def create_inference_engine(
|
||||
rtc_queue_threshold=config.queue_threshold,
|
||||
shutdown_event=shutdown_event,
|
||||
)
|
||||
if isinstance(config, RemoteInferenceConfig):
|
||||
if policy_config is None:
|
||||
raise ValueError("remote inference requires policy_config (from config-only --policy.path)")
|
||||
if use_torch_compile:
|
||||
logger.warning("--use_torch_compile is ignored with remote inference (server-side concern)")
|
||||
if device not in (None, "cpu"):
|
||||
logger.warning("--device=%s is ignored with remote inference (server-side concern)", device)
|
||||
# Lazy import: eclipse-zenoh/msgpack live behind the 'async' extra.
|
||||
from .remote import RemoteInferenceEngine
|
||||
|
||||
return RemoteInferenceEngine(
|
||||
config=config,
|
||||
policy_config=policy_config,
|
||||
hw_features=hw_features,
|
||||
ordered_action_keys=ordered_action_keys,
|
||||
task=task,
|
||||
fps=fps,
|
||||
robot_type=robot_wrapper.robot_type,
|
||||
rename_map=rename_map,
|
||||
shutdown_event=shutdown_event,
|
||||
)
|
||||
raise ValueError(f"Unknown inference engine type: {type(config).__name__}")
|
||||
|
||||
@@ -0,0 +1,851 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Remote inference engine: network-decoupled policy inference over Zenoh.
|
||||
|
||||
The same architecture as :class:`RTCInferenceEngine` with the thread
|
||||
boundary replaced by a network boundary. The edge stays **weightless**
|
||||
(no policy weights, no policy processors); a ``lerobot-policy-server``
|
||||
runs the heavy half. All chunk state — leftover prefixes, latency
|
||||
tracking, delay computation — lives client-side in the existing
|
||||
``ActionQueue``/``LatencyTracker`` machinery, so the server is stateless
|
||||
per request and a server crash loses zero control state.
|
||||
|
||||
Threading model:
|
||||
- **Main thread** (strategy loop): ``notify_observation`` writes a
|
||||
latest-only slot; ``get_action`` pops the local queue and applies the
|
||||
staleness bound + fallback ladder. Never any I/O.
|
||||
- **Network worker** (one daemon thread): self-clocked by
|
||||
``buffer_time_s``, publishes one observation, awaits its chunk (or
|
||||
timeout), merges, repeats. One-in-flight is a *correctness*
|
||||
requirement: ``idx_before``/prefix snapshots must serialize with
|
||||
merges.
|
||||
- **Zenoh threads**: deposit-only callbacks (chunk → bounded queue,
|
||||
liveliness → event).
|
||||
|
||||
Clock iron rule: wall-clock instants never cross machines. The header's
|
||||
``client_mono_ns`` is opaque to the server and echoed back; the server
|
||||
reports only durations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import math
|
||||
import queue as queue_module
|
||||
import time
|
||||
import traceback
|
||||
import uuid as uuid_module
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.policies.rtc import ActionQueue, LatencyTracker
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policy_server import codec
|
||||
from lerobot.policy_server.schema import (
|
||||
MSG_TYPE_OBS,
|
||||
SCHEMA_VERSION,
|
||||
MsgHeader,
|
||||
ObservationMsg,
|
||||
ResetMsg,
|
||||
SessionAckMsg,
|
||||
SessionCloseMsg,
|
||||
SessionOpenMsg,
|
||||
action_key,
|
||||
client_alive_key,
|
||||
obs_key,
|
||||
reset_key,
|
||||
sanitize_key_segment,
|
||||
server_alive_key,
|
||||
service_prefix,
|
||||
session_key,
|
||||
status_key,
|
||||
)
|
||||
from lerobot.policy_server.zenoh_utils import build_zenoh_config, import_zenoh, obs_publisher_qos
|
||||
from lerobot.utils.constants import OBS_STATE, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
|
||||
from .base import InferenceEngine
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
from .factory import RemoteInferenceConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_IDLE_SLEEP_S = 0.01
|
||||
_MAX_CONSECUTIVE_WORKER_ERRORS = 10
|
||||
|
||||
|
||||
class ClientState:
|
||||
"""Fail-safe state machine states (see the design doc §9.2)."""
|
||||
|
||||
CONNECTING = "CONNECTING"
|
||||
STREAMING = "STREAMING"
|
||||
DEGRADED = "DEGRADED"
|
||||
STALLED = "STALLED"
|
||||
RECONNECTING = "RECONNECTING"
|
||||
DEAD = "DEAD"
|
||||
|
||||
|
||||
class RemoteInferenceEngine(InferenceEngine):
|
||||
"""``--inference.type=remote``: weightless edge client of a policy server."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: RemoteInferenceConfig,
|
||||
policy_config: PreTrainedConfig,
|
||||
hw_features: dict,
|
||||
ordered_action_keys: list[str],
|
||||
task: str,
|
||||
fps: float,
|
||||
robot_type: str,
|
||||
rename_map: dict[str, str] | None = None,
|
||||
shutdown_event: Event | None = None,
|
||||
) -> None:
|
||||
self._config = config
|
||||
self._policy_config = policy_config
|
||||
self._hw_features = hw_features
|
||||
self._ordered_action_keys = list(ordered_action_keys)
|
||||
self._task = task
|
||||
self._fps = float(fps)
|
||||
self._dt = 1.0 / self._fps
|
||||
self._robot_type = robot_type
|
||||
self._rename_map = dict(rename_map or {})
|
||||
self._global_shutdown_event = shutdown_event
|
||||
|
||||
self._client_uuid = sanitize_key_segment(config.client_uuid or uuid_module.uuid4().hex)
|
||||
model_id = config.service_model_id or getattr(policy_config, "pretrained_path", "") or "model"
|
||||
self._prefix = service_prefix(model_id, config.service_revision, config.service_task or task)
|
||||
|
||||
# Latest-only observation slot (identical to rtc.py's _obs_holder).
|
||||
self._obs_holder: dict[str, Any] = {"obs": None}
|
||||
self._obs_lock = Lock()
|
||||
|
||||
self._action_queue: ActionQueue | None = None
|
||||
self._latency_tracker = LatencyTracker()
|
||||
self._effective_rtc: RTCConfig = config.rtc
|
||||
|
||||
# Replies deposited by the zenoh callback, consumed by the worker.
|
||||
self._reply_queue: queue_module.Queue[tuple[MsgHeader, bytes]] = queue_module.Queue(maxsize=4)
|
||||
|
||||
self._zenoh = None
|
||||
self._obs_publisher = None
|
||||
self._declarations: list[Any] = []
|
||||
self._alive_token = None
|
||||
self._server_alive = Event()
|
||||
|
||||
self._worker: Thread | None = None
|
||||
self._stop_event = Event()
|
||||
self._active = Event()
|
||||
self._dead = Event()
|
||||
self._session_ack: SessionAckMsg | None = None
|
||||
|
||||
self.state = ClientState.CONNECTING
|
||||
self._state_lock = Lock()
|
||||
self._seq_id = 0
|
||||
self._epoch = 0
|
||||
self._episode_id = 0
|
||||
self._pending_reset = False
|
||||
|
||||
# Staleness bookkeeping: client-monotonic send time of the
|
||||
# observation that produced the current queue contents.
|
||||
# _anchor_lock serializes {merge + anchor update} (worker),
|
||||
# {staleness clear} (control thread), and {reset clear} so a
|
||||
# stale chunk can never merge into a freshly-reset queue and the
|
||||
# safety path can never clear a just-merged one.
|
||||
self._anchor_lock = Lock()
|
||||
self._chunk_anchor_mono: float | None = None
|
||||
self._last_chunk_mono: float | None = None
|
||||
self._offline_since_mono: float | None = None
|
||||
self._last_action: torch.Tensor | None = None
|
||||
|
||||
self.stats: dict[str, float] = {
|
||||
"requests": 0,
|
||||
"timeouts": 0,
|
||||
"merges": 0,
|
||||
"stale_drops": 0,
|
||||
"fallback_ticks": 0,
|
||||
"reconnects": 0,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def ready(self) -> bool:
|
||||
"""Session opened, capabilities validated, server warmed up."""
|
||||
ack = self._session_ack
|
||||
return ack is not None and ack.warmed_up and not self._dead.is_set()
|
||||
|
||||
@property
|
||||
def failed(self) -> bool:
|
||||
return self._dead.is_set()
|
||||
|
||||
@property
|
||||
def action_queue(self) -> ActionQueue | None:
|
||||
return self._action_queue
|
||||
|
||||
def start(self) -> None:
|
||||
"""Open transport, handshake, start the network worker.
|
||||
|
||||
Raises on initial connection/validation failure so a bad
|
||||
deployment aborts before the robot moves (reconnect logic only
|
||||
guards established sessions).
|
||||
"""
|
||||
zenoh = import_zenoh()
|
||||
cfg = self._config
|
||||
self._zenoh = zenoh.open(
|
||||
build_zenoh_config(
|
||||
mode=cfg.zenoh_mode,
|
||||
connect_endpoints=[cfg.connect_endpoint] if cfg.connect_endpoint else None,
|
||||
tls_root_ca_certificate=cfg.tls_ca,
|
||||
tls_connect_certificate=cfg.tls_cert,
|
||||
tls_connect_private_key=cfg.tls_key,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
ack = self._handshake(initial=True)
|
||||
except Exception:
|
||||
# Fail fast without leaking the transport session.
|
||||
with contextlib.suppress(Exception):
|
||||
self._zenoh.close()
|
||||
self._zenoh = None
|
||||
raise
|
||||
self._configure_from_ack(ack)
|
||||
|
||||
handlers = zenoh.handlers
|
||||
self._declarations.append(
|
||||
self._zenoh.declare_subscriber(
|
||||
action_key(self._prefix, self._client_uuid), handlers.Callback(self._on_chunk)
|
||||
)
|
||||
)
|
||||
self._obs_publisher = self._zenoh.declare_publisher(
|
||||
obs_key(self._prefix, self._client_uuid), **obs_publisher_qos(zenoh)
|
||||
)
|
||||
self._declarations.append(self._obs_publisher)
|
||||
self._server_alive.set()
|
||||
self._declarations.append(
|
||||
self._zenoh.liveliness().declare_subscriber(
|
||||
server_alive_key(self._prefix), handlers.Callback(self._on_server_liveliness), history=True
|
||||
)
|
||||
)
|
||||
self._alive_token = self._zenoh.liveliness().declare_token(
|
||||
client_alive_key(self._prefix, self._client_uuid)
|
||||
)
|
||||
|
||||
self._stop_event.clear()
|
||||
self._dead.clear()
|
||||
self._active.set()
|
||||
self._worker = Thread(target=self._worker_loop, daemon=True, name="RemoteInference")
|
||||
self._worker.start()
|
||||
logger.info(
|
||||
"Remote inference started: prefix=%s client=%s rtc=%s",
|
||||
self._prefix,
|
||||
self._client_uuid,
|
||||
self._effective_rtc.enabled,
|
||||
)
|
||||
|
||||
def stop(self) -> None:
|
||||
logger.info("Stopping remote inference engine...")
|
||||
self._stop_event.set()
|
||||
self._active.clear()
|
||||
if self._worker is not None and self._worker.is_alive():
|
||||
# Worst case the worker is mid-handshake inside _enter_reconnect.
|
||||
join_timeout = max(3.0, self._config.handshake_timeout_s + self._config.request_timeout_s + 2.0)
|
||||
self._worker.join(timeout=join_timeout)
|
||||
if self._worker.is_alive():
|
||||
logger.warning("Remote inference worker did not join")
|
||||
self._worker = None
|
||||
|
||||
if self._zenoh is not None:
|
||||
# Best-effort graceful close; the server also GCs on liveliness drop.
|
||||
with contextlib.suppress(Exception):
|
||||
self._control_query(
|
||||
session_key(self._prefix),
|
||||
codec.encode_session_close(
|
||||
SessionCloseMsg(
|
||||
client_uuid=self._client_uuid,
|
||||
session_id=self._session_ack.session_id if self._session_ack else "",
|
||||
)
|
||||
),
|
||||
timeout=1.0,
|
||||
)
|
||||
if self._alive_token is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
self._alive_token.undeclare()
|
||||
self._alive_token = None
|
||||
for declaration in self._declarations:
|
||||
with contextlib.suppress(Exception):
|
||||
declaration.undeclare()
|
||||
self._declarations.clear()
|
||||
self._obs_publisher = None
|
||||
with contextlib.suppress(Exception):
|
||||
self._zenoh.close()
|
||||
self._zenoh = None
|
||||
logger.info("Remote inference engine stopped")
|
||||
|
||||
def pause(self) -> None:
|
||||
"""Stop publishing observations; the local queue stays intact."""
|
||||
logger.info("Pausing remote inference (publishing stops, queue intact)")
|
||||
self._active.clear()
|
||||
|
||||
def resume(self) -> None:
|
||||
logger.info("Resuming remote inference")
|
||||
self._active.set()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Episode boundary: clear local chunk state, notify the server.
|
||||
|
||||
The acked reset query runs on the worker thread (never I/O on the
|
||||
control thread); thanks to per-request server statelessness a
|
||||
lost ack only costs a warning — the next observation announces
|
||||
the new episode in its header anyway.
|
||||
"""
|
||||
logger.info("Resetting remote inference state (queue + episode)")
|
||||
with self._anchor_lock:
|
||||
if self._action_queue is not None:
|
||||
self._action_queue.clear()
|
||||
self._chunk_anchor_mono = None
|
||||
with self._state_lock:
|
||||
self._episode_id += 1
|
||||
self._pending_reset = True
|
||||
with self._obs_lock:
|
||||
# The previous episode's final frame must not seed the new
|
||||
# episode's first request.
|
||||
self._obs_holder["obs"] = None
|
||||
self._last_action = None
|
||||
# LatencyTracker intentionally survives reset: latency is
|
||||
# episode-invariant (parity with local RTC).
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Action production (main thread — never any I/O)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def notify_observation(self, obs: dict) -> None:
|
||||
with self._obs_lock:
|
||||
self._obs_holder["obs"] = obs
|
||||
|
||||
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
|
||||
queue = self._action_queue
|
||||
if queue is None:
|
||||
return None
|
||||
|
||||
# Staleness bound (sync safety): never execute an action whose
|
||||
# source observation is older than max_action_age_s. The lock
|
||||
# makes the check-and-clear atomic with the worker's merge.
|
||||
with self._anchor_lock:
|
||||
anchor = self._chunk_anchor_mono
|
||||
if (
|
||||
anchor is not None
|
||||
and queue.qsize() > 0
|
||||
and time.monotonic() - anchor > self._config.max_action_age_s
|
||||
):
|
||||
logger.warning(
|
||||
"Dropping %d stale actions (older than %.1fs) — applying fallback",
|
||||
queue.qsize(),
|
||||
self._config.max_action_age_s,
|
||||
)
|
||||
self.stats["stale_drops"] += 1
|
||||
queue.clear()
|
||||
self._chunk_anchor_mono = None
|
||||
|
||||
action = queue.get()
|
||||
if action is not None:
|
||||
self._last_action = action
|
||||
return action
|
||||
|
||||
self._set_state(ClientState.STALLED if self.state == ClientState.DEGRADED else self.state)
|
||||
return self._fallback_action()
|
||||
|
||||
def _fallback_action(self) -> torch.Tensor | None:
|
||||
from .factory import FallbackMode
|
||||
|
||||
mode = self._config.fallback
|
||||
if mode == FallbackMode.REPEAT_LAST and self._last_action is not None:
|
||||
self.stats["fallback_ticks"] += 1
|
||||
return self._last_action.clone()
|
||||
if mode == FallbackMode.ZERO:
|
||||
# For velocity-controlled robots "send nothing" means "keep
|
||||
# last velocity" — an explicit zero command is the safe stop.
|
||||
self.stats["fallback_ticks"] += 1
|
||||
return torch.zeros(len(self._ordered_action_keys))
|
||||
return None # HOLD: send_next_action tolerates None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Handshake & control plane
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _handshake(self, initial: bool) -> SessionAckMsg:
|
||||
"""status (pre-flight) + session open; raises on rejection."""
|
||||
cfg = self._config
|
||||
status_data = self._control_query(status_key(self._prefix), b"", timeout=cfg.handshake_timeout_s)
|
||||
if status_data is None:
|
||||
raise ConnectionError(
|
||||
f"No policy server answered status query at {status_key(self._prefix)!r} "
|
||||
f"via {cfg.connect_endpoint!r} (timeout {cfg.handshake_timeout_s}s)"
|
||||
)
|
||||
status = codec.decode_status(status_data)
|
||||
logger.info(
|
||||
"Server status: model=%s@%s policy=%s sessions=%d/%d warmed_up=%s",
|
||||
status.model_repo,
|
||||
status.model_revision,
|
||||
status.policy_type,
|
||||
status.active_sessions,
|
||||
status.max_sessions,
|
||||
status.warmed_up,
|
||||
)
|
||||
|
||||
open_msg = SessionOpenMsg(
|
||||
client_uuid=self._client_uuid,
|
||||
robot_type=self._robot_type,
|
||||
policy_type=getattr(self._policy_config, "type", ""),
|
||||
fps=self._fps,
|
||||
action_names=self._ordered_action_keys,
|
||||
camera_names=self._wire_camera_names(),
|
||||
state_dim=self._state_dim(),
|
||||
schema_version=SCHEMA_VERSION,
|
||||
rtc_enabled=cfg.rtc.enabled,
|
||||
task=self._task,
|
||||
tags=cfg.tags,
|
||||
)
|
||||
ack_data = self._control_query(
|
||||
session_key(self._prefix), codec.encode_session_open(open_msg), timeout=cfg.request_timeout_s
|
||||
)
|
||||
if ack_data is None:
|
||||
raise ConnectionError("Session open query timed out")
|
||||
ack = codec.decode_session_ack(ack_data)
|
||||
if not ack.accepted:
|
||||
raise ConnectionError(f"Policy server rejected the session: {ack.error}")
|
||||
for warning in ack.warnings:
|
||||
logger.warning("Server warning: %s", warning)
|
||||
|
||||
# Hard sync-safety contract: chunk columns map to motors by order.
|
||||
if ack.action_names and ack.action_names != self._ordered_action_keys:
|
||||
raise ValueError(
|
||||
"Action name/order mismatch between server policy and this robot.\n"
|
||||
f" server: {ack.action_names}\n client: {self._ordered_action_keys}"
|
||||
)
|
||||
if not initial and self._session_ack is not None:
|
||||
previous = self._session_ack
|
||||
if (ack.model_repo, ack.model_revision) != (previous.model_repo, previous.model_revision):
|
||||
raise ValueError(
|
||||
f"Server model changed across reconnect "
|
||||
f"({previous.model_repo}@{previous.model_revision} → "
|
||||
f"{ack.model_repo}@{ack.model_revision}) — refusing to execute wrong-model chunks"
|
||||
)
|
||||
return ack
|
||||
|
||||
def _configure_from_ack(self, ack: SessionAckMsg) -> None:
|
||||
rtc_requested = self._config.rtc.enabled
|
||||
rtc_effective = rtc_requested and ack.supports_rtc
|
||||
if rtc_requested and not rtc_effective:
|
||||
logger.warning("RTC downgraded to chunk-append (server does not support RTC)")
|
||||
if self._action_queue is not None and self._action_queue.cfg.enabled != rtc_effective:
|
||||
# The queue's merge semantics (replace vs append) were fixed at
|
||||
# session start; a server whose RTC capability changed across a
|
||||
# reconnect would corrupt them.
|
||||
raise ValueError(
|
||||
"Server RTC capability changed across reconnect "
|
||||
f"(queue merge mode {'replace' if self._action_queue.cfg.enabled else 'append'} "
|
||||
f"vs server RTC={rtc_effective}) — refusing to continue"
|
||||
)
|
||||
self._effective_rtc = RTCConfig(
|
||||
enabled=rtc_effective,
|
||||
prefix_attention_schedule=self._config.rtc.prefix_attention_schedule,
|
||||
max_guidance_weight=self._config.rtc.max_guidance_weight,
|
||||
execution_horizon=ack.rtc_execution_horizon or self._config.rtc.execution_horizon,
|
||||
debug=self._config.rtc.debug,
|
||||
debug_maxlen=self._config.rtc.debug_maxlen,
|
||||
)
|
||||
if self._action_queue is None:
|
||||
self._action_queue = ActionQueue(self._effective_rtc)
|
||||
self._session_ack = ack
|
||||
|
||||
def _control_query(self, key: str, payload: bytes, timeout: float) -> bytes | None:
|
||||
"""One request/reply on the control plane; None on timeout/no-server."""
|
||||
zenoh = import_zenoh()
|
||||
try:
|
||||
replies = self._zenoh.get(
|
||||
key,
|
||||
handler=zenoh.handlers.FifoChannel(4),
|
||||
payload=payload,
|
||||
timeout=timeout,
|
||||
)
|
||||
deadline = time.monotonic() + timeout + 0.5
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
reply = replies.try_recv()
|
||||
except Exception: # zenoh.ZError: channel closed (no queryable / finished)
|
||||
return None
|
||||
if reply is None:
|
||||
time.sleep(0.005)
|
||||
continue
|
||||
if reply.ok is not None:
|
||||
return reply.ok.payload.to_bytes()
|
||||
return None # Reply.err (e.g. b"Timeout")
|
||||
return None
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning("Control query %s failed: %s", key, e)
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Zenoh callbacks (deposit-only)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _on_chunk(self, sample: Any) -> None:
|
||||
try:
|
||||
attachment = sample.attachment
|
||||
if attachment is None:
|
||||
return
|
||||
header = MsgHeader.unpack(attachment.to_bytes())
|
||||
item = (header, sample.payload.to_bytes())
|
||||
try:
|
||||
self._reply_queue.put_nowait(item)
|
||||
except queue_module.Full:
|
||||
# Drop oldest, keep newest.
|
||||
with contextlib.suppress(queue_module.Empty):
|
||||
self._reply_queue.get_nowait()
|
||||
with contextlib.suppress(queue_module.Full):
|
||||
self._reply_queue.put_nowait(item)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("chunk callback error: %s", e)
|
||||
|
||||
def _on_server_liveliness(self, sample: Any) -> None:
|
||||
try:
|
||||
import zenoh
|
||||
|
||||
if sample.kind == zenoh.SampleKind.DELETE:
|
||||
logger.warning("Server liveliness token dropped")
|
||||
self._server_alive.clear()
|
||||
else:
|
||||
self._server_alive.set()
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("liveliness callback error: %s", e)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Network worker
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _worker_loop(self) -> None:
|
||||
consecutive_errors = 0
|
||||
try:
|
||||
while not self._stop_event.is_set():
|
||||
if not self._active.is_set():
|
||||
time.sleep(_IDLE_SLEEP_S)
|
||||
continue
|
||||
try:
|
||||
self._maybe_send_reset()
|
||||
|
||||
if not self._server_alive.is_set():
|
||||
self._enter_reconnect("server liveliness dropped")
|
||||
continue
|
||||
|
||||
queue = self._action_queue
|
||||
if queue is not None and queue.qsize() * self._dt > self._config.buffer_time_s:
|
||||
time.sleep(_IDLE_SLEEP_S)
|
||||
continue
|
||||
|
||||
with self._obs_lock:
|
||||
obs = self._obs_holder.get("obs")
|
||||
if obs is None:
|
||||
time.sleep(_IDLE_SLEEP_S)
|
||||
continue
|
||||
|
||||
self._request_cycle(obs)
|
||||
consecutive_errors = 0
|
||||
except ConnectionError as e:
|
||||
# Raised by reconnect on hard contract violations.
|
||||
raise e
|
||||
except Exception as e: # noqa: BLE001 — transient worker errors retry
|
||||
consecutive_errors += 1
|
||||
logger.error(
|
||||
"Remote inference worker error (%d/%d): %s",
|
||||
consecutive_errors,
|
||||
_MAX_CONSECUTIVE_WORKER_ERRORS,
|
||||
e,
|
||||
)
|
||||
logger.debug(traceback.format_exc())
|
||||
if consecutive_errors >= _MAX_CONSECUTIVE_WORKER_ERRORS:
|
||||
raise
|
||||
time.sleep(0.5)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("Fatal error in remote inference worker: %s", e)
|
||||
logger.error(traceback.format_exc())
|
||||
self._go_dead(str(e))
|
||||
|
||||
def _request_cycle(self, obs: dict) -> None:
|
||||
"""Publish one observation and merge its chunk (one-in-flight)."""
|
||||
cfg = self._config
|
||||
queue = self._action_queue
|
||||
|
||||
obs_frame = build_dataset_frame(self._hw_features, obs, prefix=OBS_STR)
|
||||
if self._rename_map:
|
||||
obs_frame = {self._rename_map.get(k, k): v for k, v in obs_frame.items()}
|
||||
|
||||
state = obs_frame.pop(OBS_STATE, None)
|
||||
images = {k: v for k, v in obs_frame.items() if isinstance(v, np.ndarray) and v.ndim == 3}
|
||||
|
||||
with self._state_lock:
|
||||
self._seq_id += 1
|
||||
seq_id = self._seq_id
|
||||
episode_id = self._episode_id
|
||||
epoch = self._epoch
|
||||
|
||||
# Snapshot RTC state (must precede the publish; merge validates
|
||||
# against idx_before).
|
||||
idx_before = queue.get_action_index()
|
||||
prefix_model: np.ndarray | None = None
|
||||
prefix_robot: np.ndarray | None = None
|
||||
delay_steps = 0
|
||||
if self._effective_rtc.enabled:
|
||||
horizon = self._effective_rtc.execution_horizon
|
||||
left_over = queue.get_left_over()
|
||||
if left_over is not None and left_over.numel():
|
||||
prefix_model = left_over[:horizon].to(torch.float32).numpy()
|
||||
processed_left_over = queue.get_processed_left_over()
|
||||
if processed_left_over is not None and processed_left_over.numel():
|
||||
prefix_robot = processed_left_over[:horizon].to(torch.float32).numpy()
|
||||
max_latency = self._latency_tracker.max() if len(self._latency_tracker) else 0.0
|
||||
delay_steps = math.ceil(max_latency / self._dt) if max_latency else 0
|
||||
|
||||
# A reset/reconnect between the counter snapshot and the prefix
|
||||
# snapshot would pair a new episode id with old-episode prefixes
|
||||
# — skip the cycle instead.
|
||||
with self._state_lock:
|
||||
if (self._episode_id, self._epoch) != (episode_id, epoch):
|
||||
return
|
||||
|
||||
header = MsgHeader(
|
||||
schema_version=SCHEMA_VERSION,
|
||||
msg_type=MSG_TYPE_OBS,
|
||||
seq_id=seq_id,
|
||||
episode_id=episode_id,
|
||||
client_mono_ns=time.monotonic_ns(),
|
||||
session_epoch=epoch,
|
||||
)
|
||||
msg = ObservationMsg(
|
||||
state=state,
|
||||
images=images,
|
||||
task=self._task,
|
||||
inference_delay_steps=delay_steps,
|
||||
prefix_model=prefix_model,
|
||||
prefix_robot=prefix_robot,
|
||||
episode_start=(queue.qsize() == 0 and idx_before == 0 and self._chunk_anchor_mono is None),
|
||||
jpeg_quality=cfg.jpeg_quality,
|
||||
)
|
||||
|
||||
t_send = time.perf_counter()
|
||||
self._obs_publisher.put(codec.encode_observation(msg), attachment=header.pack())
|
||||
self.stats["requests"] += 1
|
||||
|
||||
reply = self._await_chunk(seq_id, episode_id, epoch, timeout=cfg.request_timeout_s)
|
||||
if reply is None:
|
||||
self.stats["timeouts"] += 1
|
||||
self._on_request_timeout()
|
||||
return
|
||||
|
||||
chunk = codec.decode_action_chunk(reply)
|
||||
if chunk.chunk_model is None or chunk.chunk_robot is None:
|
||||
# A persistently malformed server must still trip the
|
||||
# degradation ladder, not stall in nominal state.
|
||||
logger.warning("Chunk for seq=%d had empty tensors — dropping", seq_id)
|
||||
self.stats["timeouts"] += 1
|
||||
self._on_request_timeout()
|
||||
return
|
||||
|
||||
latency = time.perf_counter() - t_send
|
||||
real_delay = math.ceil(latency / self._dt)
|
||||
with self._anchor_lock:
|
||||
# reset() takes the same lock before clearing: either the
|
||||
# reset fully precedes this merge (episode changed → drop the
|
||||
# stale chunk) or the merge completes first (and the reset
|
||||
# then clears it) — a stale chunk can never survive a reset.
|
||||
with self._state_lock:
|
||||
if (self._episode_id, self._epoch) != (episode_id, epoch):
|
||||
logger.debug("Dropping chunk seq=%d: episode/epoch changed mid-flight", seq_id)
|
||||
return
|
||||
queue.merge(
|
||||
torch.from_numpy(np.ascontiguousarray(chunk.chunk_model)),
|
||||
torch.from_numpy(np.ascontiguousarray(chunk.chunk_robot)),
|
||||
real_delay,
|
||||
idx_before,
|
||||
)
|
||||
self._chunk_anchor_mono = time.monotonic() - latency # ≈ when the source obs was sent
|
||||
self._latency_tracker.add(latency)
|
||||
self._last_chunk_mono = time.monotonic()
|
||||
self._offline_since_mono = None
|
||||
self.stats["merges"] += 1
|
||||
self._set_state(ClientState.STREAMING)
|
||||
logger.debug(
|
||||
"merge: seq=%d latency=%.0fms delay=%d queue=%d server(inf=%.0fms wait=%.0fms load=%.2f)",
|
||||
seq_id,
|
||||
latency * 1e3,
|
||||
real_delay,
|
||||
queue.qsize(),
|
||||
chunk.inference_ms,
|
||||
chunk.queue_wait_ms,
|
||||
chunk.server_load,
|
||||
)
|
||||
|
||||
def _await_chunk(self, seq_id: int, episode_id: int, epoch: int, timeout: float) -> bytes | None:
|
||||
"""Wait for the chunk answering the latest outstanding request.
|
||||
|
||||
Stale replies (older seq/episode/epoch) are dropped — under
|
||||
one-in-flight a late chunk can only ever answer an older request.
|
||||
"""
|
||||
deadline = time.monotonic() + timeout
|
||||
while not self._stop_event.is_set():
|
||||
remaining = deadline - time.monotonic()
|
||||
if remaining <= 0:
|
||||
return None
|
||||
try:
|
||||
header, payload = self._reply_queue.get(timeout=min(remaining, 0.1))
|
||||
except queue_module.Empty:
|
||||
continue
|
||||
if header.session_epoch != epoch or header.episode_id != episode_id:
|
||||
continue # stale epoch/episode (reset or reconnect happened)
|
||||
if header.seq_id != seq_id:
|
||||
continue # late reply to a superseded request
|
||||
return payload
|
||||
return None
|
||||
|
||||
def _maybe_send_reset(self) -> None:
|
||||
with self._state_lock:
|
||||
pending, episode_id = self._pending_reset, self._episode_id
|
||||
self._pending_reset = False
|
||||
if pending and self._zenoh is not None:
|
||||
ack_data = self._control_query(
|
||||
reset_key(self._prefix, self._client_uuid),
|
||||
codec.encode_reset(ResetMsg(client_uuid=self._client_uuid, episode_id=episode_id)),
|
||||
timeout=1.0,
|
||||
)
|
||||
if ack_data is None:
|
||||
# Harmless: the server is stateless per request and the next
|
||||
# observation header announces the new episode anyway.
|
||||
logger.warning("Reset ack not received (continuing — header carries the episode bump)")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Degradation / reconnect / death
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _on_request_timeout(self) -> None:
|
||||
if self._stop_event.is_set():
|
||||
# _await_chunk aborted by a normal stop(), not by the network.
|
||||
return
|
||||
now = time.monotonic()
|
||||
if self._offline_since_mono is None:
|
||||
self._offline_since_mono = now
|
||||
offline_for = now - self._offline_since_mono
|
||||
|
||||
queue = self._action_queue
|
||||
if queue is not None and queue.qsize() > 0:
|
||||
self._set_state(ClientState.DEGRADED)
|
||||
else:
|
||||
self._set_state(ClientState.STALLED)
|
||||
|
||||
last = self._last_chunk_mono or 0.0
|
||||
if (now - last if last else offline_for) >= self._config.degraded_after_s:
|
||||
logger.warning(
|
||||
"No chunk for %.1fs (queue=%d) — %s",
|
||||
offline_for,
|
||||
queue.qsize() if queue else 0,
|
||||
self.state,
|
||||
)
|
||||
if offline_for > self._config.max_offline_s:
|
||||
self._go_dead(f"offline for {offline_for:.0f}s (> max_offline_s)")
|
||||
return
|
||||
if not self._server_alive.is_set() or offline_for >= 2 * self._config.request_timeout_s:
|
||||
self._enter_reconnect(f"request timeouts for {offline_for:.0f}s")
|
||||
|
||||
def _enter_reconnect(self, reason: str) -> None:
|
||||
"""Backoff + re-handshake loop. Hard contract violations → DEAD."""
|
||||
self._set_state(ClientState.RECONNECTING)
|
||||
logger.warning("Reconnecting: %s", reason)
|
||||
if self._offline_since_mono is None:
|
||||
self._offline_since_mono = time.monotonic()
|
||||
backoff = self._config.reconnect_initial_backoff_s
|
||||
while not self._stop_event.is_set():
|
||||
if not self._active.is_set():
|
||||
# Paused (e.g. DAgger human correction): keep trying to
|
||||
# reconnect, but a pause must never burn the offline budget
|
||||
# into a mid-correction shutdown.
|
||||
self._offline_since_mono = time.monotonic()
|
||||
offline_for = time.monotonic() - self._offline_since_mono
|
||||
if offline_for > self._config.max_offline_s:
|
||||
self._go_dead(f"offline for {offline_for:.0f}s (> max_offline_s)")
|
||||
return
|
||||
self._stop_event.wait(timeout=backoff)
|
||||
if self._stop_event.is_set():
|
||||
return
|
||||
backoff = min(backoff * 2, self._config.reconnect_max_backoff_s)
|
||||
try:
|
||||
with self._state_lock:
|
||||
self._epoch += 1
|
||||
ack = self._handshake(initial=False)
|
||||
self._configure_from_ack(ack)
|
||||
except ValueError as e:
|
||||
# Capability/schema/model mismatch: never execute wrong-model chunks.
|
||||
self._go_dead(str(e))
|
||||
return
|
||||
except Exception as e: # noqa: BLE001 — server still down, keep trying
|
||||
logger.info("Re-handshake failed (%s) — retrying in %.1fs", e, backoff)
|
||||
continue
|
||||
# A successful handshake is proof of life even if the liveliness
|
||||
# PUT was missed or hasn't been delivered yet.
|
||||
self._server_alive.set()
|
||||
# The offline budget is only reset by the next successful merge:
|
||||
# a server that handshakes but never delivers chunks must still
|
||||
# run out of budget and go DEAD.
|
||||
self.stats["reconnects"] += 1
|
||||
self._set_state(ClientState.STREAMING)
|
||||
logger.info("Reconnected (epoch=%d, session=%s)", self._epoch, ack.session_id)
|
||||
return
|
||||
|
||||
def _go_dead(self, reason: str) -> None:
|
||||
if self._dead.is_set():
|
||||
return
|
||||
logger.error("Remote inference DEAD: %s", reason)
|
||||
self._set_state(ClientState.DEAD)
|
||||
self._dead.set()
|
||||
if self._global_shutdown_event is not None:
|
||||
self._global_shutdown_event.set()
|
||||
|
||||
def _set_state(self, new_state: str) -> None:
|
||||
if new_state != self.state:
|
||||
logger.info("Client state: %s → %s", self.state, new_state)
|
||||
self.state = new_state
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Feature helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _wire_camera_names(self) -> list[str]:
|
||||
names = [
|
||||
key for key, feature in self._hw_features.items() if feature.get("dtype") in ("image", "video")
|
||||
]
|
||||
return [self._rename_map.get(name, name) for name in names]
|
||||
|
||||
def _state_dim(self) -> int:
|
||||
state_feature = self._hw_features.get(OBS_STATE)
|
||||
if state_feature and state_feature.get("names"):
|
||||
return len(state_feature["names"])
|
||||
return 0
|
||||
@@ -17,6 +17,7 @@
|
||||
from .base import BaseStrategy
|
||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||
from .dagger import DAggerEvents, DAggerPhase, DAggerStrategy
|
||||
from .episodic import EpisodicStrategy
|
||||
from .factory import create_strategy
|
||||
from .highlight import HighlightStrategy
|
||||
from .sentry import SentryStrategy
|
||||
@@ -27,6 +28,7 @@ __all__ = [
|
||||
"DAggerPhase",
|
||||
"DAggerStrategy",
|
||||
"HighlightStrategy",
|
||||
"EpisodicStrategy",
|
||||
"RolloutStrategy",
|
||||
"SentryStrategy",
|
||||
"create_strategy",
|
||||
|
||||
@@ -56,10 +56,14 @@ from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.control_utils import is_headless
|
||||
from lerobot.common.control_utils import (
|
||||
follower_smooth_move_to,
|
||||
is_headless,
|
||||
teleop_smooth_move_to,
|
||||
teleop_supports_feedback,
|
||||
)
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
from lerobot.teleoperators import Teleoperator
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.import_utils import _pynput_available
|
||||
@@ -69,7 +73,6 @@ from lerobot.utils.utils import log_say
|
||||
|
||||
from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from ..robot_wrapper import ThreadSafeRobot
|
||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||
|
||||
PYNPUT_AVAILABLE = _pynput_available
|
||||
@@ -171,64 +174,6 @@ class DAggerEvents:
|
||||
self.upload_requested.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Teleoperator helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _teleop_supports_feedback(teleop: Teleoperator) -> bool:
|
||||
"""Return True when the teleop can receive position feedback (is actuated).
|
||||
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
|
||||
"""
|
||||
return (
|
||||
bool(teleop.feedback_features)
|
||||
and hasattr(teleop, "disable_torque")
|
||||
and hasattr(teleop, "enable_torque")
|
||||
)
|
||||
|
||||
|
||||
def _teleop_smooth_move_to(
|
||||
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 30
|
||||
) -> None:
|
||||
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
|
||||
|
||||
Requires the teleoperator to support feedback
|
||||
(i.e. have non-empty ``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
|
||||
|
||||
TODO(Maxime): This blocks up to ``duration_s`` seconds, during this time
|
||||
the follower robot doesn't receive new actions, this could be an issue on LeKiwi.
|
||||
"""
|
||||
teleop.enable_torque()
|
||||
current = teleop.get_action()
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {
|
||||
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
|
||||
}
|
||||
teleop.send_feedback(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
def _follower_smooth_move_to(
|
||||
robot: ThreadSafeRobot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
|
||||
) -> None:
|
||||
"""Smoothly move the follower robot from ``current`` to ``target`` action.
|
||||
|
||||
Used when the teleop is non-actuated: instead of driving the leader arm
|
||||
to the follower, we bring the follower to the teleop's current pose.
|
||||
Both ``current`` and ``target`` must be in robot-action key space.
|
||||
"""
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
|
||||
robot.send_action(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Input device handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -756,31 +701,31 @@ class DAggerStrategy(RolloutStrategy):
|
||||
logger.info("Pausing engine - robot holds position")
|
||||
engine.pause()
|
||||
|
||||
if _teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
if teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
# TODO(Maxime): prev_action is in robot action key space (output of robot_action_processor).
|
||||
# send_feedback expects teleop feedback key space. For homogeneous setups (e.g. SO-101
|
||||
# leader + SO-101 follower) the keys are identical so this works. If the processor pipeline
|
||||
# does non-trivial key renaming (e.g. a rename_map on action keys), the interpolation in
|
||||
# _teleop_smooth_move_to silently no-ops and the arm doesn't move.
|
||||
# teleop_smooth_move_to silently no-ops and the arm doesn't move.
|
||||
logger.info("Smooth handover: moving leader arm to follower position")
|
||||
_teleop_smooth_move_to(teleop, prev_action)
|
||||
teleop_smooth_move_to(teleop, prev_action)
|
||||
|
||||
elif old_phase == DAggerPhase.PAUSED and new_phase == DAggerPhase.CORRECTING:
|
||||
logger.info("Entering correction mode - human teleop control")
|
||||
if not _teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
if not teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
logger.info("Smooth handover: sliding follower to teleop position")
|
||||
obs = robot.get_observation()
|
||||
teleop_action = teleop.get_action()
|
||||
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
|
||||
target = ctx.processors.robot_action_processor((processed, obs))
|
||||
_follower_smooth_move_to(robot, prev_action, target)
|
||||
follower_smooth_move_to(robot, prev_action, target)
|
||||
|
||||
# unlock the teleop for human control
|
||||
if _teleop_supports_feedback(teleop):
|
||||
if teleop_supports_feedback(teleop):
|
||||
teleop.disable_torque()
|
||||
|
||||
elif old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
|
||||
if _teleop_supports_feedback(teleop):
|
||||
if teleop_supports_feedback(teleop):
|
||||
teleop.enable_torque()
|
||||
|
||||
elif new_phase == DAggerPhase.AUTONOMOUS:
|
||||
@@ -790,7 +735,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
engine.resume()
|
||||
|
||||
# release teleop before resuming the policy
|
||||
if _teleop_supports_feedback(teleop):
|
||||
if teleop_supports_feedback(teleop):
|
||||
teleop.disable_torque()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,335 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Episodic rollout strategy: mirrors the behavior of ``lerobot-record``.
|
||||
|
||||
- Policy drives the robot during each recording episode.
|
||||
- An optional teleoperator can drive the robot during reset phases so the
|
||||
operator can bring the environment back to its starting configuration.
|
||||
If no teleop is connected the robot stays in its current position.
|
||||
- Keyboard controls:
|
||||
|
||||
Right arrow — end the current episode or reset phase early
|
||||
Left arrow — discard the current episode and re-record it
|
||||
Escape — stop the recording session
|
||||
|
||||
Dataset naming follows the rollout convention: repo names must start with ``rollout_``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
|
||||
from lerobot.common.control_utils import (
|
||||
follower_smooth_move_to,
|
||||
init_keyboard_listener,
|
||||
is_headless,
|
||||
teleop_smooth_move_to,
|
||||
teleop_supports_feedback,
|
||||
)
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import log_rerun_data
|
||||
|
||||
from ..configs import EpisodicStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from .core import RolloutStrategy, safe_push_to_hub, send_next_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EpisodicStrategy(RolloutStrategy):
|
||||
"""Policy-driven multi-episode recording, mirrors the behavior of ``lerobot-record``.
|
||||
|
||||
Each recording episode runs the policy for maximum ``dataset.episode_time_s``
|
||||
seconds, recording every frame. A reset phase of ``dataset.reset_time_s``
|
||||
follows every episode (except the last) so the operator can manually
|
||||
reset the environment. During the reset phase, an optional teleoperator
|
||||
drives the robot; if none is present the robot returns to its initial joint positions captured at startup.
|
||||
|
||||
The policy state (hidden state, RTC queue, interpolator) is reset at
|
||||
the start of each recording episode.
|
||||
|
||||
Keyboard events:
|
||||
right arrow → end current episode or reset phase early
|
||||
left arrow → discard & re-record current episode
|
||||
ESC → stop the session
|
||||
"""
|
||||
|
||||
config: EpisodicStrategyConfig
|
||||
|
||||
def __init__(self, config: EpisodicStrategyConfig) -> None:
|
||||
super().__init__(config)
|
||||
self._listener = None
|
||||
self._events: dict | None = None
|
||||
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
"""Start the inference engine and attach the keyboard listener."""
|
||||
self._init_engine(ctx)
|
||||
self._listener, self._events = init_keyboard_listener()
|
||||
logger.info("Episodic strategy ready")
|
||||
|
||||
def run(self, ctx: RolloutContext) -> None:
|
||||
"""Main multi-episode recording loop."""
|
||||
cfg = ctx.runtime.cfg
|
||||
dataset_cfg = cfg.dataset
|
||||
robot = ctx.hardware.robot_wrapper
|
||||
teleop = ctx.hardware.teleop
|
||||
dataset = ctx.data.dataset
|
||||
events = self._events
|
||||
features = ctx.data.dataset_features
|
||||
|
||||
fps = cfg.fps
|
||||
episode_time_s = dataset_cfg.episode_time_s
|
||||
reset_time_s = dataset_cfg.reset_time_s
|
||||
num_episodes = dataset_cfg.num_episodes
|
||||
single_task = dataset_cfg.single_task or cfg.task
|
||||
play_sounds = cfg.play_sounds
|
||||
|
||||
display_compressed = (
|
||||
True
|
||||
if (cfg.display_data and cfg.display_ip is not None and cfg.display_port is not None)
|
||||
else cfg.display_compressed_images
|
||||
)
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
try:
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < num_episodes and not events["stop_recording"]:
|
||||
if ctx.runtime.shutdown_event.is_set():
|
||||
break
|
||||
|
||||
# Reset policy state at episode start (discard leftover hidden state / queue)
|
||||
self._engine.reset()
|
||||
self._interpolator.reset()
|
||||
self._engine.resume()
|
||||
|
||||
log_say(f"Recording episode {dataset.num_episodes}", play_sounds)
|
||||
self._policy_loop(
|
||||
ctx=ctx,
|
||||
robot=robot,
|
||||
events=events,
|
||||
features=features,
|
||||
fps=fps,
|
||||
control_time_s=episode_time_s,
|
||||
dataset=dataset,
|
||||
single_task=single_task,
|
||||
)
|
||||
|
||||
# Reset phase, skip after the last episode (but run when re-recording)
|
||||
if not events["stop_recording"] and (
|
||||
recorded_episodes < num_episodes - 1 or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment", play_sounds)
|
||||
|
||||
if teleop:
|
||||
# Smooth handover so the transition to teleop control is jerk-free.
|
||||
# For actuated teleops: drive the leader arm to the follower's current
|
||||
# position so the operator takes over without fighting the arm.
|
||||
# For non-actuated teleops: slide the follower to the teleop's current
|
||||
# pose instead, since the leader cannot be driven.
|
||||
obs = robot.get_observation()
|
||||
current_pos = {k: v for k, v in obs.items() if k.endswith(".pos")}
|
||||
if (
|
||||
teleop_supports_feedback(teleop)
|
||||
and self.config.smooth_leader_to_follower_handover
|
||||
):
|
||||
logger.info("Smooth handover: moving leader arm to follower position")
|
||||
teleop_smooth_move_to(teleop, current_pos, duration_s=2)
|
||||
teleop.disable_torque()
|
||||
else:
|
||||
logger.info("Smooth handover: sliding follower to teleop position")
|
||||
teleop_action = teleop.get_action()
|
||||
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
|
||||
target = ctx.processors.robot_action_processor((processed, obs))
|
||||
follower_smooth_move_to(robot, current_pos, target, duration_s=1)
|
||||
|
||||
elif self.config.reset_to_initial_position:
|
||||
# No teleop: return the robot to its startup position.
|
||||
self._return_to_initial_position(hw=ctx.hardware, duration_s=1)
|
||||
|
||||
self._reset_loop(
|
||||
ctx=ctx,
|
||||
robot=robot,
|
||||
teleop=teleop,
|
||||
events=events,
|
||||
fps=fps,
|
||||
control_time_s=reset_time_s,
|
||||
display_data=cfg.display_data,
|
||||
display_compressed=display_compressed,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode", play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
|
||||
# returns to its initial joint positions captured at startup
|
||||
if not teleop and self.config.reset_to_initial_position:
|
||||
self._return_to_initial_position(hw=ctx.hardware, duration_s=1)
|
||||
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
finally:
|
||||
# Save any frames buffered in the current episode so an unexpected
|
||||
# exception or KeyboardInterrupt does not silently drop recorded data.
|
||||
# suppress: save_episode raises if the buffer is empty (nothing to lose).
|
||||
logger.info("Episodic control loop ended — saving any in-progress episode")
|
||||
with contextlib.suppress(Exception):
|
||||
dataset.save_episode()
|
||||
|
||||
def _policy_loop(
|
||||
self,
|
||||
ctx: RolloutContext,
|
||||
robot,
|
||||
events: dict,
|
||||
features: dict,
|
||||
fps: float,
|
||||
control_time_s: float,
|
||||
dataset,
|
||||
single_task: str,
|
||||
) -> None:
|
||||
"""Policy-driven recording loop for a single episode."""
|
||||
interpolator = self._interpolator
|
||||
control_interval = interpolator.get_control_interval(fps)
|
||||
|
||||
timestamp = 0.0
|
||||
start_t = time.perf_counter()
|
||||
|
||||
while timestamp < control_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
if ctx.runtime.shutdown_event.is_set():
|
||||
break
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_processed = self._process_observation_and_notify(ctx.processors, obs)
|
||||
|
||||
if self._handle_warmup(ctx.runtime.cfg.use_torch_compile, loop_start, control_interval):
|
||||
continue
|
||||
|
||||
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
|
||||
|
||||
if action_dict is not None:
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
|
||||
dataset.add_frame({**obs_frame, **action_frame, "task": single_task})
|
||||
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
sleep_t = control_interval - dt
|
||||
if sleep_t < 0:
|
||||
logger.warning(
|
||||
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({fps} Hz). "
|
||||
"Dataset frames might be dropped and robot control might be unstable. "
|
||||
"Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long "
|
||||
"3) CPU starvation"
|
||||
)
|
||||
precise_sleep(max(sleep_t, 0.0))
|
||||
timestamp = time.perf_counter() - start_t
|
||||
|
||||
def _reset_loop(
|
||||
self,
|
||||
ctx: RolloutContext,
|
||||
robot,
|
||||
teleop,
|
||||
events: dict,
|
||||
fps: float,
|
||||
control_time_s: float,
|
||||
display_data: bool,
|
||||
display_compressed: bool,
|
||||
) -> None:
|
||||
"""Reset-phase loop: teleop drives the robot if available, no recording."""
|
||||
processors = ctx.processors
|
||||
control_interval = 1.0 / fps
|
||||
|
||||
timestamp = 0.0
|
||||
start_t = time.perf_counter()
|
||||
|
||||
while timestamp < control_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
if ctx.runtime.shutdown_event.is_set():
|
||||
break
|
||||
|
||||
obs = robot.get_observation()
|
||||
|
||||
if teleop is not None:
|
||||
act = teleop.get_action()
|
||||
act_teleop = processors.teleop_action_processor((act, obs))
|
||||
robot_action = processors.robot_action_processor((act_teleop, obs))
|
||||
robot.send_action(robot_action)
|
||||
|
||||
if display_data:
|
||||
obs_processed = processors.robot_observation_processor(obs)
|
||||
log_rerun_data(
|
||||
observation=obs_processed,
|
||||
action=act_teleop,
|
||||
compress_images=display_compressed,
|
||||
)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
sleep_t = control_interval - dt
|
||||
precise_sleep(max(sleep_t, 0.0))
|
||||
timestamp = time.perf_counter() - start_t
|
||||
|
||||
def teardown(self, ctx: RolloutContext) -> None:
|
||||
"""Finalise dataset, stop listener, push to hub, and disconnect hardware."""
|
||||
cfg = ctx.runtime.cfg
|
||||
play_sounds = cfg.play_sounds
|
||||
|
||||
log_say("Stop recording", play_sounds, blocking=True)
|
||||
|
||||
if not is_headless() and self._listener is not None:
|
||||
self._listener.stop()
|
||||
|
||||
if ctx.data.dataset is not None:
|
||||
logger.info("Finalizing dataset...")
|
||||
ctx.data.dataset.finalize()
|
||||
|
||||
if (
|
||||
cfg.dataset is not None
|
||||
and cfg.dataset.push_to_hub
|
||||
and ctx.data.dataset is not None
|
||||
and safe_push_to_hub(
|
||||
ctx.data.dataset,
|
||||
tags=cfg.dataset.tags,
|
||||
private=cfg.dataset.private,
|
||||
)
|
||||
):
|
||||
logger.info("Dataset uploaded to hub")
|
||||
log_say("Dataset uploaded to hub", play_sounds)
|
||||
|
||||
self._teardown_hardware(
|
||||
ctx.hardware,
|
||||
return_to_initial_position=cfg.return_to_initial_position,
|
||||
)
|
||||
log_say("Exiting", play_sounds)
|
||||
logger.info("Episodic strategy teardown complete")
|
||||
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
|
||||
from .base import BaseStrategy
|
||||
from .core import RolloutStrategy
|
||||
from .dagger import DAggerStrategy
|
||||
from .episodic import EpisodicStrategy
|
||||
from .highlight import HighlightStrategy
|
||||
from .sentry import SentryStrategy
|
||||
|
||||
@@ -42,4 +43,8 @@ def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
|
||||
return HighlightStrategy(config)
|
||||
if config.type == "dagger":
|
||||
return DAggerStrategy(config)
|
||||
raise ValueError(f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger")
|
||||
if config.type == "episodic":
|
||||
return EpisodicStrategy(config)
|
||||
raise ValueError(
|
||||
f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger, episodic"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Serve a pretrained policy to remote ``lerobot-rollout`` clients over Zenoh.
|
||||
|
||||
One process = one pre-warmed (model, revision, dtype, device) on one GPU.
|
||||
Robots connect with ``lerobot-rollout --inference.type=remote``.
|
||||
|
||||
Usage examples
|
||||
--------------
|
||||
|
||||
Serve a model from a YAML manifest::
|
||||
|
||||
lerobot-policy-server --manifest server.yaml
|
||||
|
||||
Minimal manifest::
|
||||
|
||||
model:
|
||||
repo_or_path: lerobot/pi0_towels
|
||||
device: cuda
|
||||
default_task: "fold the towel"
|
||||
max_sessions: 5
|
||||
zenoh:
|
||||
mode: client
|
||||
connect_endpoints: ["tcp/router.gpu-cluster.internal:7447"]
|
||||
|
||||
Everything in the manifest can also be set directly on the CLI::
|
||||
|
||||
lerobot-policy-server \\
|
||||
--model.repo_or_path=lerobot/pi0_towels \\
|
||||
--default_task="fold the towel" \\
|
||||
--zenoh.mode=peer --zenoh.listen_endpoints='["tcp/0.0.0.0:7447"]'
|
||||
|
||||
SIGTERM/SIGINT drains gracefully: the server drops its liveliness token
|
||||
(clients ride their action buffers through the drain), finishes the
|
||||
in-flight inference, and exits.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.policy_server.manifest import PolicyServerManifest
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def policy_server(manifest: PolicyServerManifest):
|
||||
init_logging()
|
||||
from lerobot.policy_server.server import PolicyServer
|
||||
|
||||
server = PolicyServer(manifest)
|
||||
server.load_policy()
|
||||
|
||||
def _drain(signum, frame): # noqa: ARG001
|
||||
logger.info("Signal %s received — draining", signum)
|
||||
server.stop()
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGTERM, _drain)
|
||||
server.start()
|
||||
server.serve_forever()
|
||||
|
||||
|
||||
def main():
|
||||
# `--manifest path.yaml` is sugar for draccus's `--config_path`.
|
||||
sys.argv = [
|
||||
arg.replace("--manifest=", "--config_path=") if arg.startswith("--manifest=") else arg
|
||||
for arg in sys.argv
|
||||
]
|
||||
if "--manifest" in sys.argv:
|
||||
sys.argv[sys.argv.index("--manifest")] = "--config_path"
|
||||
policy_server()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -25,11 +25,13 @@ Strategies
|
||||
--strategy.type=sentry Continuous recording with auto-upload
|
||||
--strategy.type=highlight Ring buffer + keystroke save
|
||||
--strategy.type=dagger Human-in-the-loop (DAgger / RaC)
|
||||
--strategy.type=episodic Episode-oriented recording with reset phases
|
||||
|
||||
Inference backends
|
||||
------------------
|
||||
--inference.type=sync One policy call per control tick (default)
|
||||
--inference.type=rtc Real-Time Chunking for slow VLA models
|
||||
--inference.type=remote Network inference via lerobot-policy-server (weightless edge)
|
||||
|
||||
Usage examples
|
||||
--------------
|
||||
@@ -111,6 +113,18 @@ Usage examples
|
||||
--display_data=true \\
|
||||
--use_torch_compile=true
|
||||
|
||||
# Episodic mode — episode-oriented recording with reset phases
|
||||
lerobot-rollout \\
|
||||
--strategy.type=episodic \\
|
||||
--policy.path=user/my_policy \\
|
||||
--robot.type=so100_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--teleop.type=so100_leader \\
|
||||
--teleop.port=/dev/ttyACM1 \\
|
||||
--dataset.repo_id=user/rollout_episodic_data \\
|
||||
--dataset.num_episodes=20 \\
|
||||
--dataset.single_task="Grab the cube"
|
||||
|
||||
# Resume a previous sentry recording session
|
||||
lerobot-rollout \\
|
||||
--strategy.type=sentry \\
|
||||
@@ -132,6 +146,19 @@ Usage examples
|
||||
--dataset.camera_encoder.vcodec=h264 \\
|
||||
--dataset.camera_encoder.preset=fast \\
|
||||
--dataset.camera_encoder.extra_options={"tune": "film", "profile:v": "high", "bf": 2}
|
||||
|
||||
# Sentry mode — remote inference against a lerobot-policy-server (weightless edge)
|
||||
lerobot-rollout \\
|
||||
--strategy.type=sentry \\
|
||||
--strategy.upload_every_n_episodes=5 \\
|
||||
--policy.path=lerobot/pi0_base \\
|
||||
--inference.type=remote \\
|
||||
--inference.connect_endpoint=tcp/router.gpu-cluster.internal:7447 \\
|
||||
--inference.rtc.execution_horizon=10 \\
|
||||
--robot.type=so100_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--dataset.repo_id=user/rollout_sentry_data \\
|
||||
--dataset.single_task="patrol" --duration=3600
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
@@ -99,6 +99,9 @@ def update_policy(
|
||||
start_time = time.perf_counter()
|
||||
policy.train()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
# Compute sample weights if a weighter is provided
|
||||
sample_weights = None
|
||||
weight_stats = None
|
||||
@@ -158,6 +161,8 @@ def update_policy(
|
||||
train_metrics.grad_norm = grad_norm.item()
|
||||
train_metrics.lr = optimizer.param_groups[0]["lr"]
|
||||
train_metrics.update_s = time.perf_counter() - start_time
|
||||
if torch.cuda.is_available():
|
||||
train_metrics.gpu_mem_gb = torch.cuda.max_memory_allocated() / (1024**3)
|
||||
return train_metrics, output_dict
|
||||
|
||||
|
||||
@@ -232,15 +237,18 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
# Dataset loading synchronization: main process downloads first to avoid race conditions
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
# Dataset loading synchronization: each node's local main process downloads first to avoid
|
||||
# race conditions (the global main process only exists on node 0, so gating on it would let
|
||||
# all ranks of the other nodes download and build the Arrow cache concurrently).
|
||||
if accelerator.is_local_main_process:
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Now all other processes can safely load the dataset
|
||||
if not is_main_process:
|
||||
# Now all other processes can safely load the dataset from the local cache
|
||||
if not accelerator.is_local_main_process:
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||
@@ -386,12 +394,19 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
# create dataloader for offline training
|
||||
if hasattr(active_cfg, "drop_n_last_frames"):
|
||||
shuffle = False
|
||||
# A dedicated generator (rather than the global torch RNG) lets accelerator.prepare
|
||||
# synchronize the shuffle permutation across ranks, keeping batch shards disjoint even
|
||||
# when ranks consume the global RNG asymmetrically (e.g. eval on the main process only).
|
||||
sampler_generator = torch.Generator()
|
||||
if cfg.seed is not None:
|
||||
sampler_generator.manual_seed(cfg.seed)
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=active_cfg.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
generator=sampler_generator,
|
||||
)
|
||||
else:
|
||||
shuffle = True
|
||||
@@ -424,12 +439,22 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
policy.train()
|
||||
|
||||
train_metrics = {
|
||||
"loss": AverageMeter("loss", ":.3f"),
|
||||
# Per-rank loss reflects only one shard of the global batch; mean recovers the loss DDP
|
||||
# is actually optimizing. grad_norm and lr are already identical on every rank (post
|
||||
# gradient sync / deterministic scheduler) so reducing them would be a no-op collective.
|
||||
"loss": AverageMeter("loss", ":.3f", reduction="mean"),
|
||||
"grad_norm": AverageMeter("grdn", ":.3f"),
|
||||
"lr": AverageMeter("lr", ":0.1e"),
|
||||
"update_s": AverageMeter("updt_s", ":.3f"),
|
||||
"dataloading_s": AverageMeter("data_s", ":.3f"),
|
||||
# Report the slowest rank for bottleneck-style timings so multi-GPU runs surface the
|
||||
# true straggler instead of rank 0's view.
|
||||
"update_s": AverageMeter("updt_s", ":.3f", reduction="max"),
|
||||
"dataloading_s": AverageMeter("data_s", ":.3f", reduction="max"),
|
||||
# Derived from the post-reduce max step time; set once per log window on the main rank.
|
||||
"samples_per_s": AverageMeter("smp/s", ":.0f"),
|
||||
}
|
||||
if torch.cuda.is_available():
|
||||
# max() because headroom is gated by the worst-case rank.
|
||||
train_metrics["gpu_mem_gb"] = AverageMeter("mem_gb", ":.2f", reduction="max")
|
||||
|
||||
# Keep global batch size for logging; MetricsTracker handles world size internally.
|
||||
effective_batch_size = cfg.batch_size * accelerator.num_processes
|
||||
@@ -481,21 +506,29 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if is_main_process:
|
||||
progbar.update(1)
|
||||
train_tracker.step()
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
|
||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
|
||||
|
||||
if is_log_step:
|
||||
logging.info(train_tracker)
|
||||
if wandb_logger:
|
||||
wandb_log_dict = train_tracker.to_dict()
|
||||
if output_dict:
|
||||
wandb_log_dict.update(output_dict)
|
||||
# Log sample weighting statistics if enabled
|
||||
if sample_weighter is not None:
|
||||
weighter_stats = sample_weighter.get_stats()
|
||||
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
# Collective reduce must run on every rank, before the main-process gate below.
|
||||
train_tracker.reduce_across_ranks()
|
||||
if is_main_process:
|
||||
# Cluster-wide throughput, derived from the already-reduced (max) step time so it
|
||||
# reflects the slowest rank — which is what actually gates the next iteration.
|
||||
step_time = train_tracker.update_s.avg + train_tracker.dataloading_s.avg
|
||||
if step_time > 0:
|
||||
train_tracker.samples_per_s = effective_batch_size / step_time
|
||||
logging.info(train_tracker)
|
||||
if wandb_logger:
|
||||
wandb_log_dict = train_tracker.to_dict()
|
||||
if output_dict:
|
||||
wandb_log_dict.update(output_dict)
|
||||
# Log sample weighting statistics if enabled
|
||||
if sample_weighter is not None:
|
||||
weighter_stats = sample_weighter.get_stats()
|
||||
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
train_tracker.reset_averages()
|
||||
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
gRPC transport layer for async inference.
|
||||
gRPC transport layer for the HIL-SERL RL stack (actor ↔ learner).
|
||||
|
||||
Requires: ``pip install 'lerobot[grpcio-dep]'``
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.python -m grpc_tools.protoc -I src --python_out=src --grpc_python_out=src src/lerobot/transport/services.proto
|
||||
// limitations under the License.
|
||||
|
||||
// To generate a classes for transport part (services_pb2.py and services_pb2_grpc.py) use the following command:
|
||||
//
|
||||
@@ -33,17 +33,6 @@ service LearnerService {
|
||||
rpc Ready(Empty) returns (Empty);
|
||||
}
|
||||
|
||||
// AsyncInference: from Robot perspective
|
||||
// Robot send observations to & executes action received from a remote Policy server
|
||||
service AsyncInference {
|
||||
// Robot -> Policy to share observations with a remote inference server
|
||||
// Policy -> Robot to share actions predicted for given observations
|
||||
rpc SendObservations(stream Observation) returns (Empty);
|
||||
rpc GetActions(Empty) returns (Actions);
|
||||
rpc SendPolicyInstructions(PolicySetup) returns (Empty);
|
||||
rpc Ready(Empty) returns (Empty);
|
||||
}
|
||||
|
||||
enum TransferState {
|
||||
TRANSFER_UNKNOWN = 0;
|
||||
TRANSFER_BEGIN = 1;
|
||||
@@ -67,21 +56,4 @@ message InteractionMessage {
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
// Messages
|
||||
message Observation {
|
||||
// sent by Robot, to remote Policy
|
||||
TransferState transfer_state = 1; // Observations can be streamed exceeding 4MB of size
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
message Actions {
|
||||
// sent by remote Policy, to Robot
|
||||
bytes data = 1;
|
||||
}
|
||||
|
||||
message PolicySetup {
|
||||
// sent by Robot to remote server, to init Policy
|
||||
bytes data = 1;
|
||||
}
|
||||
|
||||
message Empty {}
|
||||
|
||||
@@ -23,31 +23,23 @@ _sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"M\n\x0bObservation\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x17\n\x07\x41\x63tions\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x1b\n\x0bPolicySetup\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Empty2\xf5\x01\n\x0e\x41syncInference\x12>\n\x10SendObservations\x12\x16.transport.Observation\x1a\x10.transport.Empty(\x01\x12\x32\n\nGetActions\x12\x10.transport.Empty\x1a\x12.transport.Actions\x12\x42\n\x16SendPolicyInstructions\x12\x16.transport.PolicySetup\x1a\x10.transport.Empty\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'lerobot.transport.services_pb2', _globals)
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
DESCRIPTOR._loaded_options = None
|
||||
_globals['_TRANSFERSTATE']._serialized_start=431
|
||||
_globals['_TRANSFERSTATE']._serialized_end=527
|
||||
_globals['_TRANSFERSTATE']._serialized_start=298
|
||||
_globals['_TRANSFERSTATE']._serialized_end=394
|
||||
_globals['_TRANSITION']._serialized_start=47
|
||||
_globals['_TRANSITION']._serialized_end=123
|
||||
_globals['_PARAMETERS']._serialized_start=125
|
||||
_globals['_PARAMETERS']._serialized_end=201
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_start=203
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_end=287
|
||||
_globals['_OBSERVATION']._serialized_start=289
|
||||
_globals['_OBSERVATION']._serialized_end=366
|
||||
_globals['_ACTIONS']._serialized_start=368
|
||||
_globals['_ACTIONS']._serialized_end=391
|
||||
_globals['_POLICYSETUP']._serialized_start=393
|
||||
_globals['_POLICYSETUP']._serialized_end=420
|
||||
_globals['_EMPTY']._serialized_start=422
|
||||
_globals['_EMPTY']._serialized_end=429
|
||||
_globals['_LEARNERSERVICE']._serialized_start=530
|
||||
_globals['_LEARNERSERVICE']._serialized_end=787
|
||||
_globals['_ASYNCINFERENCE']._serialized_start=790
|
||||
_globals['_ASYNCINFERENCE']._serialized_end=1035
|
||||
_globals['_EMPTY']._serialized_start=289
|
||||
_globals['_EMPTY']._serialized_end=296
|
||||
_globals['_LEARNERSERVICE']._serialized_start=397
|
||||
_globals['_LEARNERSERVICE']._serialized_end=654
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
||||
@@ -231,212 +231,3 @@ class LearnerService:
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class AsyncInferenceStub:
|
||||
"""AsyncInference: from Robot perspective
|
||||
Robot send observations to & executes action received from a remote Policy server
|
||||
"""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.SendObservations = channel.stream_unary(
|
||||
'/transport.AsyncInference/SendObservations',
|
||||
request_serializer=lerobot_dot_transport_dot_services__pb2.Observation.SerializeToString,
|
||||
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.GetActions = channel.unary_unary(
|
||||
'/transport.AsyncInference/GetActions',
|
||||
request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
response_deserializer=lerobot_dot_transport_dot_services__pb2.Actions.FromString,
|
||||
_registered_method=True)
|
||||
self.SendPolicyInstructions = channel.unary_unary(
|
||||
'/transport.AsyncInference/SendPolicyInstructions',
|
||||
request_serializer=lerobot_dot_transport_dot_services__pb2.PolicySetup.SerializeToString,
|
||||
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.Ready = channel.unary_unary(
|
||||
'/transport.AsyncInference/Ready',
|
||||
request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class AsyncInferenceServicer:
|
||||
"""AsyncInference: from Robot perspective
|
||||
Robot send observations to & executes action received from a remote Policy server
|
||||
"""
|
||||
|
||||
def SendObservations(self, request_iterator, context):
|
||||
"""Robot -> Policy to share observations with a remote inference server
|
||||
Policy -> Robot to share actions predicted for given observations
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def GetActions(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendPolicyInstructions(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Ready(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_AsyncInferenceServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'SendObservations': grpc.stream_unary_rpc_method_handler(
|
||||
servicer.SendObservations,
|
||||
request_deserializer=lerobot_dot_transport_dot_services__pb2.Observation.FromString,
|
||||
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'GetActions': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.GetActions,
|
||||
request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
response_serializer=lerobot_dot_transport_dot_services__pb2.Actions.SerializeToString,
|
||||
),
|
||||
'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendPolicyInstructions,
|
||||
request_deserializer=lerobot_dot_transport_dot_services__pb2.PolicySetup.FromString,
|
||||
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'Ready': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Ready,
|
||||
request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'transport.AsyncInference', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
server.add_registered_method_handlers('transport.AsyncInference', rpc_method_handlers)
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class AsyncInference:
|
||||
"""AsyncInference: from Robot perspective
|
||||
Robot send observations to & executes action received from a remote Policy server
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def SendObservations(request_iterator,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.stream_unary(
|
||||
request_iterator,
|
||||
target,
|
||||
'/transport.AsyncInference/SendObservations',
|
||||
lerobot_dot_transport_dot_services__pb2.Observation.SerializeToString,
|
||||
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def GetActions(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/transport.AsyncInference/GetActions',
|
||||
lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
lerobot_dot_transport_dot_services__pb2.Actions.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SendPolicyInstructions(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/transport.AsyncInference/SendPolicyInstructions',
|
||||
lerobot_dot_transport_dot_services__pb2.PolicySetup.SerializeToString,
|
||||
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def Ready(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/transport.AsyncInference/Ready',
|
||||
lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@@ -13,21 +13,39 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from .utils import format_big_number
|
||||
|
||||
_VALID_REDUCTIONS = ("none", "max", "mean", "sum")
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
"""
|
||||
Computes and stores the average and current value
|
||||
Adapted from https://github.com/pytorch/examples/blob/main/imagenet/main.py
|
||||
|
||||
Args:
|
||||
name: Display name of the metric.
|
||||
fmt: Format string used when rendering the metric.
|
||||
reduction: Cross-process reduction applied by :meth:`MetricsTracker.reduce_across_ranks`
|
||||
before logging. One of ``"none"`` (per-rank value, default), ``"max"``, ``"mean"``,
|
||||
or ``"sum"``. Use ``"max"`` for bottleneck-style metrics (e.g. dataloading or
|
||||
update wall time) so multi-GPU runs report the slowest rank rather than rank 0.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, fmt: str = ":f"):
|
||||
def __init__(self, name: str, fmt: str = ":f", reduction: str = "none"):
|
||||
if reduction not in _VALID_REDUCTIONS:
|
||||
raise ValueError(
|
||||
f"Invalid reduction {reduction!r} for AverageMeter; expected one of {_VALID_REDUCTIONS}."
|
||||
)
|
||||
self.name = name
|
||||
self.fmt = fmt
|
||||
self.reduction = reduction
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
@@ -138,6 +156,37 @@ class MetricsTracker:
|
||||
self.episodes = self.samples / self._avg_samples_per_ep
|
||||
self.epochs = self.samples / self._num_frames
|
||||
|
||||
def reduce_across_ranks(self) -> None:
|
||||
"""
|
||||
Synchronises the running averages of every metric whose ``reduction`` is not ``"none"``
|
||||
across all distributed processes (in-place).
|
||||
|
||||
This is a collective operation and MUST be invoked on every rank — typically just before
|
||||
logging. With no accelerator or in single-process runs it is a no-op. Without it, metrics
|
||||
reported by the main process only reflect rank 0; for bottleneck-style timings
|
||||
(``dataloading_s``, ``update_s``, ...) that means the slowest worker's stall is invisible.
|
||||
"""
|
||||
if self.accelerator is None or self.accelerator.num_processes <= 1:
|
||||
return
|
||||
|
||||
buckets: dict[str, list[str]] = defaultdict(list)
|
||||
for name, meter in self.metrics.items():
|
||||
if meter.reduction != "none":
|
||||
buckets[meter.reduction].append(name)
|
||||
if not buckets:
|
||||
return
|
||||
|
||||
device = self.accelerator.device
|
||||
for reduction, names in buckets.items():
|
||||
tensor = torch.tensor([self.metrics[n].avg for n in names], dtype=torch.float32, device=device)
|
||||
reduced = self.accelerator.reduce(tensor, reduction=reduction)
|
||||
for name, value in zip(names, reduced.tolist(), strict=True):
|
||||
meter = self.metrics[name]
|
||||
# Preserve avg == sum / count so a later .update() on this meter accumulates
|
||||
# against the cluster view, not the stale per-rank history.
|
||||
meter.avg = value
|
||||
meter.sum = value * meter.count
|
||||
|
||||
def __str__(self) -> str:
|
||||
display_list = [
|
||||
f"step:{format_big_number(self.steps)}",
|
||||
|
||||
@@ -1,187 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""End-to-end test of the asynchronous inference stack (client ↔ server).
|
||||
|
||||
This test spins up a lightweight gRPC `PolicyServer` instance with a stubbed
|
||||
policy network and launches a `RobotClient` that uses a `MockRobot`. The goal
|
||||
is to exercise the full communication loop:
|
||||
|
||||
1. Client sends policy specification → Server
|
||||
2. Client streams observations → Server
|
||||
3. Server streams action chunks → Client
|
||||
4. Client executes received actions
|
||||
|
||||
The test succeeds if at least one action is executed and the server records at
|
||||
least one predicted timestep - demonstrating that the gRPC round-trip works
|
||||
end-to-end using real (but lightweight) protocol messages.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from concurrent import futures
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip entire module if required deps are not available
|
||||
pytest.importorskip("grpc")
|
||||
pytest.importorskip("serial", reason="pyserial is required (install lerobot[hardware])")
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# End-to-end test
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_async_inference_e2e(monkeypatch):
|
||||
"""Tests the full asynchronous inference pipeline."""
|
||||
# Import grpc-dependent modules inside the test function
|
||||
import grpc
|
||||
|
||||
from lerobot.async_inference.configs import PolicyServerConfig, RobotClientConfig
|
||||
from lerobot.async_inference.helpers import map_robot_keys_to_lerobot_features
|
||||
from lerobot.async_inference.policy_server import PolicyServer
|
||||
from lerobot.async_inference.robot_client import RobotClient
|
||||
from lerobot.robots.utils import make_robot_from_config
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
services_pb2_grpc, # type: ignore
|
||||
)
|
||||
from tests.mocks.mock_robot import MockRobotConfig
|
||||
|
||||
# Create a stub policy similar to test_policy_server.py
|
||||
class MockPolicy:
|
||||
"""A minimal mock for an actual policy, returning zeros."""
|
||||
|
||||
class _Config:
|
||||
robot_type = "dummy_robot"
|
||||
|
||||
@property
|
||||
def image_features(self):
|
||||
"""Empty image features since this test doesn't use images."""
|
||||
return {}
|
||||
|
||||
def __init__(self):
|
||||
self.config = self._Config()
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def model(self, batch):
|
||||
# Return a chunk of 20 dummy actions.
|
||||
batch_size = len(batch["robot_type"])
|
||||
return torch.zeros(batch_size, 20, 6)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 1. Create PolicyServer instance with mock policy
|
||||
# ------------------------------------------------------------------
|
||||
policy_server_config = PolicyServerConfig(host="localhost", port=9999)
|
||||
policy_server = PolicyServer(policy_server_config)
|
||||
# Replace the real policy with our fast, deterministic stub.
|
||||
policy_server.policy = MockPolicy()
|
||||
policy_server.actions_per_chunk = 20
|
||||
policy_server.device = "cpu"
|
||||
# NOTE(Steven): Smelly tests as the Server is a state machine being partially mocked. Adding these processors as a quick fix.
|
||||
policy_server.preprocessor = lambda obs: obs
|
||||
policy_server.postprocessor = lambda tensor: tensor
|
||||
|
||||
# Set up robot config and features
|
||||
robot_config = MockRobotConfig()
|
||||
mock_robot = make_robot_from_config(robot_config)
|
||||
|
||||
lerobot_features = map_robot_keys_to_lerobot_features(mock_robot)
|
||||
policy_server.lerobot_features = lerobot_features
|
||||
|
||||
# Force server to produce deterministic action chunks in test mode
|
||||
policy_server.policy_type = "act"
|
||||
|
||||
def _fake_get_action_chunk(_self, _obs, _type="test"):
|
||||
action_dim = 6
|
||||
batch_size = 1
|
||||
actions_per_chunk = policy_server.actions_per_chunk
|
||||
|
||||
return torch.zeros(batch_size, actions_per_chunk, action_dim)
|
||||
|
||||
monkeypatch.setattr(PolicyServer, "_get_action_chunk", _fake_get_action_chunk, raising=True)
|
||||
|
||||
# Bypass potentially heavy model loading inside SendPolicyInstructions
|
||||
def _fake_send_policy_instructions(self, request, context): # noqa: N802
|
||||
return services_pb2.Empty()
|
||||
|
||||
monkeypatch.setattr(PolicyServer, "SendPolicyInstructions", _fake_send_policy_instructions, raising=True)
|
||||
|
||||
# Build gRPC server running a PolicyServer
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="policy_server"))
|
||||
services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
|
||||
|
||||
# Use the host/port specified in the fixture's config
|
||||
server_address = f"{policy_server.config.host}:{policy_server.config.port}"
|
||||
server.add_insecure_port(server_address)
|
||||
server.start()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 2. Create a RobotClient around the MockRobot
|
||||
# ------------------------------------------------------------------
|
||||
client_config = RobotClientConfig(
|
||||
server_address=server_address,
|
||||
robot=robot_config,
|
||||
chunk_size_threshold=0.0,
|
||||
policy_type="test",
|
||||
pretrained_name_or_path="test",
|
||||
actions_per_chunk=20,
|
||||
)
|
||||
|
||||
client = RobotClient(client_config)
|
||||
assert client.start(), "Client failed initial handshake with the server"
|
||||
|
||||
# Track action chunks received and verify device type
|
||||
action_chunks_received = {"count": 0, "actions_on_cpu": True}
|
||||
original_aggregate = client._aggregate_action_queues
|
||||
|
||||
def counting_aggregate(*args, **kwargs):
|
||||
action_chunks_received["count"] += 1
|
||||
# Check that all received actions are on CPU
|
||||
if args:
|
||||
for timed_action in args[0]: # args[0] is the list of TimedAction
|
||||
action_tensor = timed_action.get_action()
|
||||
if action_tensor.device.type != "cpu":
|
||||
action_chunks_received["actions_on_cpu"] = False
|
||||
return original_aggregate(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(client, "_aggregate_action_queues", counting_aggregate)
|
||||
|
||||
# Start client threads
|
||||
action_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
control_thread = threading.Thread(target=client.control_loop, args=({"task": ""}), daemon=True)
|
||||
action_thread.start()
|
||||
control_thread.start()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 3. System exchanges a few messages
|
||||
# ------------------------------------------------------------------
|
||||
# Wait for 5 seconds
|
||||
server.wait_for_termination(timeout=5)
|
||||
|
||||
assert action_chunks_received["count"] > 0, "Client did not receive any action chunks"
|
||||
assert len(policy_server._predicted_timesteps) > 0, "Server did not record any predicted timesteps"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 4. Stop the system
|
||||
# ------------------------------------------------------------------
|
||||
client.stop()
|
||||
action_thread.join()
|
||||
control_thread.join()
|
||||
policy_server.stop()
|
||||
server.stop(grace=None)
|
||||
@@ -1,454 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import pickle
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("grpc")
|
||||
|
||||
import numpy as np # noqa: E402
|
||||
import torch # noqa: E402
|
||||
|
||||
from lerobot.async_inference.helpers import ( # noqa: E402
|
||||
FPSTracker,
|
||||
TimedAction,
|
||||
TimedObservation,
|
||||
observations_similar,
|
||||
prepare_image,
|
||||
prepare_raw_observation,
|
||||
raw_observation_to_observation,
|
||||
resize_robot_observation_image,
|
||||
)
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# FPSTracker
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_fps_tracker_first_observation():
|
||||
"""First observation should initialize timestamp and return 0 FPS."""
|
||||
tracker = FPSTracker(target_fps=30.0)
|
||||
timestamp = 1000.0
|
||||
|
||||
metrics = tracker.calculate_fps_metrics(timestamp)
|
||||
|
||||
assert tracker.first_timestamp == timestamp
|
||||
assert tracker.total_obs_count == 1
|
||||
assert metrics["avg_fps"] == 0.0
|
||||
assert metrics["target_fps"] == 30.0
|
||||
|
||||
|
||||
def test_fps_tracker_single_interval():
|
||||
"""Two observations 1 second apart should give 1 FPS."""
|
||||
tracker = FPSTracker(target_fps=30.0)
|
||||
|
||||
# First observation at t=0
|
||||
metrics1 = tracker.calculate_fps_metrics(0.0)
|
||||
assert metrics1["avg_fps"] == 0.0
|
||||
|
||||
# Second observation at t=1 (1 second later)
|
||||
metrics2 = tracker.calculate_fps_metrics(1.0)
|
||||
expected_fps = 1.0 # (2-1) observations / 1.0 seconds = 1 FPS
|
||||
assert math.isclose(metrics2["avg_fps"], expected_fps, rel_tol=1e-6)
|
||||
|
||||
|
||||
def test_fps_tracker_multiple_intervals():
|
||||
"""Multiple observations should calculate correct average FPS."""
|
||||
tracker = FPSTracker(target_fps=30.0)
|
||||
|
||||
# Simulate 5 observations over 2 seconds (should be 2 FPS average)
|
||||
timestamps = [0.0, 0.5, 1.0, 1.5, 2.0]
|
||||
|
||||
for i, ts in enumerate(timestamps):
|
||||
metrics = tracker.calculate_fps_metrics(ts)
|
||||
|
||||
if i == 0:
|
||||
assert metrics["avg_fps"] == 0.0
|
||||
elif i == len(timestamps) - 1:
|
||||
# After 5 observations over 2 seconds: (5-1)/2 = 2 FPS
|
||||
expected_fps = 2.0
|
||||
assert math.isclose(metrics["avg_fps"], expected_fps, rel_tol=1e-6)
|
||||
|
||||
|
||||
def test_fps_tracker_irregular_intervals():
|
||||
"""FPS calculation should work with irregular time intervals."""
|
||||
tracker = FPSTracker(target_fps=30.0)
|
||||
|
||||
# Irregular timestamps: 0, 0.1, 0.5, 2.0, 3.0 seconds
|
||||
timestamps = [0.0, 0.1, 0.5, 2.0, 3.0]
|
||||
|
||||
for ts in timestamps:
|
||||
metrics = tracker.calculate_fps_metrics(ts)
|
||||
|
||||
# 5 observations over 3 seconds: (5-1)/3 = 1.333... FPS
|
||||
expected_fps = 4.0 / 3.0
|
||||
assert math.isclose(metrics["avg_fps"], expected_fps, rel_tol=1e-6)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# TimedData helpers
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_timed_action_getters():
|
||||
"""TimedAction stores & returns timestamp, action tensor and timestep."""
|
||||
ts = time.time()
|
||||
action = torch.arange(10)
|
||||
ta = TimedAction(timestamp=ts, action=action, timestep=0)
|
||||
|
||||
assert math.isclose(ta.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
|
||||
torch.testing.assert_close(ta.get_action(), action)
|
||||
assert ta.get_timestep() == 0
|
||||
|
||||
|
||||
def test_timed_observation_getters():
|
||||
"""TimedObservation stores & returns timestamp, dict and timestep."""
|
||||
ts = time.time()
|
||||
obs_dict = {OBS_STATE: torch.ones(6)}
|
||||
to = TimedObservation(timestamp=ts, observation=obs_dict, timestep=0)
|
||||
|
||||
assert math.isclose(to.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
|
||||
assert to.get_observation() is obs_dict
|
||||
assert to.get_timestep() == 0
|
||||
|
||||
|
||||
def test_timed_data_deserialization_data_getters():
|
||||
"""TimedAction / TimedObservation survive a round-trip through ``pickle``.
|
||||
|
||||
The async-inference stack uses ``pickle.dumps`` to move these objects across
|
||||
the gRPC boundary (see RobotClient.send_observation and PolicyServer.StreamActions).
|
||||
This test ensures that the payload keeps its content intact after
|
||||
the (de)serialization round-trip.
|
||||
"""
|
||||
ts = time.time()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# TimedAction
|
||||
# ------------------------------------------------------------------
|
||||
original_action = torch.randn(6)
|
||||
ta_in = TimedAction(timestamp=ts, action=original_action, timestep=13)
|
||||
|
||||
# Serialize → bytes → deserialize
|
||||
ta_bytes = pickle.dumps(ta_in) # nosec
|
||||
ta_out: TimedAction = pickle.loads(ta_bytes) # nosec B301
|
||||
|
||||
# Identity & content checks
|
||||
assert math.isclose(ta_out.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
|
||||
assert ta_out.get_timestep() == 13
|
||||
torch.testing.assert_close(ta_out.get_action(), original_action)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# TimedObservation
|
||||
# ------------------------------------------------------------------
|
||||
obs_dict = {OBS_STATE: torch.arange(4).float()}
|
||||
to_in = TimedObservation(timestamp=ts, observation=obs_dict, timestep=7, must_go=True)
|
||||
|
||||
to_bytes = pickle.dumps(to_in) # nosec
|
||||
to_out: TimedObservation = pickle.loads(to_bytes) # nosec B301
|
||||
|
||||
assert math.isclose(to_out.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
|
||||
assert to_out.get_timestep() == 7
|
||||
assert to_out.must_go is True
|
||||
assert to_out.get_observation().keys() == obs_dict.keys()
|
||||
torch.testing.assert_close(to_out.get_observation()[OBS_STATE], obs_dict[OBS_STATE])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# observations_similar()
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_obs(state: torch.Tensor) -> TimedObservation:
|
||||
"""Create a TimedObservation with raw robot observation format."""
|
||||
return TimedObservation(
|
||||
timestamp=time.time(),
|
||||
observation={
|
||||
"shoulder": state[0].item() if len(state) > 0 else 0.0,
|
||||
"elbow": state[1].item() if len(state) > 1 else 0.0,
|
||||
"wrist": state[2].item() if len(state) > 2 else 0.0,
|
||||
"gripper": state[3].item() if len(state) > 3 else 0.0,
|
||||
},
|
||||
timestep=0,
|
||||
)
|
||||
|
||||
|
||||
def test_observations_similar_true():
|
||||
"""Distance below atol → observations considered similar."""
|
||||
# Create mock lerobot features for the similarity check
|
||||
lerobot_features = {
|
||||
OBS_STATE: {
|
||||
"dtype": "float32",
|
||||
"shape": [4],
|
||||
"names": ["shoulder", "elbow", "wrist", "gripper"],
|
||||
}
|
||||
}
|
||||
|
||||
obs1 = _make_obs(torch.zeros(4))
|
||||
obs2 = _make_obs(0.5 * torch.ones(4))
|
||||
assert observations_similar(obs1, obs2, lerobot_features, atol=2.0)
|
||||
|
||||
obs3 = _make_obs(2.0 * torch.ones(4))
|
||||
assert not observations_similar(obs1, obs3, lerobot_features, atol=2.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# raw_observation_to_observation and helpers
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
def _create_mock_robot_observation():
|
||||
"""Create a mock robot observation with motor positions and camera images."""
|
||||
return {
|
||||
"shoulder": 1.0,
|
||||
"elbow": 2.0,
|
||||
"wrist": 3.0,
|
||||
"gripper": 0.5,
|
||||
"laptop": np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8),
|
||||
"phone": np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8),
|
||||
}
|
||||
|
||||
|
||||
def _create_mock_lerobot_features():
|
||||
"""Create mock lerobot features mapping similar to what hw_to_dataset_features returns."""
|
||||
return {
|
||||
OBS_STATE: {
|
||||
"dtype": "float32",
|
||||
"shape": [4],
|
||||
"names": ["shoulder", "elbow", "wrist", "gripper"],
|
||||
},
|
||||
f"{OBS_IMAGES}.laptop": {
|
||||
"dtype": "image",
|
||||
"shape": [480, 640, 3],
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
f"{OBS_IMAGES}.phone": {
|
||||
"dtype": "image",
|
||||
"shape": [480, 640, 3],
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _create_mock_policy_image_features():
|
||||
"""Create mock policy image features with different resolutions."""
|
||||
return {
|
||||
f"{OBS_IMAGES}.laptop": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 224, 224), # Policy expects smaller resolution
|
||||
),
|
||||
f"{OBS_IMAGES}.phone": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 160, 160), # Different resolution for second camera
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def test_prepare_image():
|
||||
"""Test image preprocessing: int8 → float32, normalization to [0,1]."""
|
||||
# Create mock int8 image data
|
||||
image_int8 = torch.randint(0, 256, size=(3, 224, 224), dtype=torch.uint8)
|
||||
|
||||
processed = prepare_image(image_int8)
|
||||
|
||||
# Check dtype conversion
|
||||
assert processed.dtype == torch.float32
|
||||
|
||||
# Check normalization range
|
||||
assert processed.min() >= 0.0
|
||||
assert processed.max() <= 1.0
|
||||
|
||||
# Check that values are scaled correctly (255 → 1.0, 0 → 0.0)
|
||||
if image_int8.max() == 255:
|
||||
assert torch.isclose(processed.max(), torch.tensor(1.0), atol=1e-6)
|
||||
if image_int8.min() == 0:
|
||||
assert torch.isclose(processed.min(), torch.tensor(0.0), atol=1e-6)
|
||||
|
||||
# Check memory contiguity
|
||||
assert processed.is_contiguous()
|
||||
|
||||
|
||||
def test_resize_robot_observation_image():
|
||||
"""Test image resizing from robot resolution to policy resolution."""
|
||||
# Create mock image: (H=480, W=640, C=3)
|
||||
original_image = torch.randint(0, 256, size=(480, 640, 3), dtype=torch.uint8)
|
||||
target_shape = (3, 224, 224) # (C, H, W)
|
||||
|
||||
resized = resize_robot_observation_image(original_image, target_shape)
|
||||
|
||||
# Check output shape matches target
|
||||
assert resized.shape == target_shape
|
||||
|
||||
# Check that original image had different dimensions
|
||||
assert original_image.shape != resized.shape
|
||||
|
||||
# Check that resizing preserves value range
|
||||
assert resized.min() >= 0
|
||||
assert resized.max() <= 255
|
||||
|
||||
|
||||
def test_prepare_raw_observation():
|
||||
"""Test the preparation of raw robot observation to lerobot format."""
|
||||
robot_obs = _create_mock_robot_observation()
|
||||
lerobot_features = _create_mock_lerobot_features()
|
||||
policy_image_features = _create_mock_policy_image_features()
|
||||
|
||||
prepared = prepare_raw_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
|
||||
# Check that state is properly extracted and batched
|
||||
assert OBS_STATE in prepared
|
||||
state = prepared[OBS_STATE]
|
||||
assert isinstance(state, torch.Tensor)
|
||||
assert state.shape == (1, 4) # Batched state
|
||||
|
||||
# Check that images are processed and resized
|
||||
assert f"{OBS_IMAGES}.laptop" in prepared
|
||||
assert f"{OBS_IMAGES}.phone" in prepared
|
||||
|
||||
laptop_img = prepared[f"{OBS_IMAGES}.laptop"]
|
||||
phone_img = prepared[f"{OBS_IMAGES}.phone"]
|
||||
|
||||
# Check image shapes match policy requirements
|
||||
assert laptop_img.shape == policy_image_features[f"{OBS_IMAGES}.laptop"].shape
|
||||
assert phone_img.shape == policy_image_features[f"{OBS_IMAGES}.phone"].shape
|
||||
|
||||
# Check that images are tensors
|
||||
assert isinstance(laptop_img, torch.Tensor)
|
||||
assert isinstance(phone_img, torch.Tensor)
|
||||
|
||||
|
||||
def test_raw_observation_to_observation_basic():
|
||||
"""Test the main raw_observation_to_observation function."""
|
||||
robot_obs = _create_mock_robot_observation()
|
||||
lerobot_features = _create_mock_lerobot_features()
|
||||
policy_image_features = _create_mock_policy_image_features()
|
||||
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
|
||||
# Check that all expected keys are present
|
||||
assert OBS_STATE in observation
|
||||
assert f"{OBS_IMAGES}.laptop" in observation
|
||||
assert f"{OBS_IMAGES}.phone" in observation
|
||||
|
||||
# Check state processing
|
||||
state = observation[OBS_STATE]
|
||||
assert isinstance(state, torch.Tensor)
|
||||
assert state.shape == (1, 4) # Batched
|
||||
|
||||
# Check image processing
|
||||
laptop_img = observation[f"{OBS_IMAGES}.laptop"]
|
||||
phone_img = observation[f"{OBS_IMAGES}.phone"]
|
||||
|
||||
# Images should have batch dimension: (B, C, H, W)
|
||||
assert laptop_img.shape == (1, 3, 224, 224)
|
||||
assert phone_img.shape == (1, 3, 160, 160)
|
||||
|
||||
# Check image dtype and range (should be float32 in [0, 1])
|
||||
assert laptop_img.dtype == torch.float32
|
||||
assert phone_img.dtype == torch.float32
|
||||
assert laptop_img.min() >= 0.0 and laptop_img.max() <= 1.0
|
||||
assert phone_img.min() >= 0.0 and phone_img.max() <= 1.0
|
||||
|
||||
|
||||
def test_raw_observation_to_observation_with_non_tensor_data():
|
||||
"""Test that non-tensor data (like task strings) is preserved."""
|
||||
robot_obs = _create_mock_robot_observation()
|
||||
robot_obs["task"] = "pick up the red cube" # Add string instruction
|
||||
|
||||
lerobot_features = _create_mock_lerobot_features()
|
||||
policy_image_features = _create_mock_policy_image_features()
|
||||
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
|
||||
# Check that task string is preserved
|
||||
assert "task" in observation
|
||||
assert observation["task"] == "pick up the red cube"
|
||||
assert isinstance(observation["task"], str)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_raw_observation_to_observation_device_handling():
|
||||
"""Test that tensors are created (device placement is handled by preprocessor)."""
|
||||
robot_obs = _create_mock_robot_observation()
|
||||
lerobot_features = _create_mock_lerobot_features()
|
||||
policy_image_features = _create_mock_policy_image_features()
|
||||
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
|
||||
# Check that all expected keys produce tensors (device placement handled by preprocessor later)
|
||||
for key, value in observation.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
assert value.device.type in ["cpu", "cuda", "mps", "xpu"], f"Tensor {key} on unexpected device"
|
||||
|
||||
|
||||
def test_raw_observation_to_observation_deterministic():
|
||||
"""Test that the function produces consistent results for the same input."""
|
||||
robot_obs = _create_mock_robot_observation()
|
||||
lerobot_features = _create_mock_lerobot_features()
|
||||
policy_image_features = _create_mock_policy_image_features()
|
||||
|
||||
# Run twice with same input
|
||||
obs1 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
obs2 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
|
||||
# Results should be identical
|
||||
assert set(obs1.keys()) == set(obs2.keys())
|
||||
|
||||
for key in obs1:
|
||||
if isinstance(obs1[key], torch.Tensor):
|
||||
torch.testing.assert_close(obs1[key], obs2[key])
|
||||
else:
|
||||
assert obs1[key] == obs2[key]
|
||||
|
||||
|
||||
def test_image_processing_pipeline_preserves_content():
|
||||
"""Test that the image processing pipeline preserves recognizable patterns."""
|
||||
# Create an image with a specific pattern
|
||||
original_img = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
original_img[25:75, 25:75, :] = 255 # White square in center
|
||||
|
||||
robot_obs = {"shoulder": 1.0, "elbow": 1.0, "wrist": 1.0, "gripper": 1.0, "laptop": original_img}
|
||||
lerobot_features = {
|
||||
OBS_STATE: {
|
||||
"dtype": "float32",
|
||||
"shape": [4],
|
||||
"names": ["shoulder", "elbow", "wrist", "gripper"],
|
||||
},
|
||||
f"{OBS_IMAGES}.laptop": {
|
||||
"dtype": "image",
|
||||
"shape": [100, 100, 3],
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
}
|
||||
policy_image_features = {
|
||||
f"{OBS_IMAGES}.laptop": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 50, 50), # Downsamples from 100x100
|
||||
)
|
||||
}
|
||||
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
|
||||
processed_img = observation[f"{OBS_IMAGES}.laptop"].squeeze(0) # Remove batch dim
|
||||
|
||||
# Check that the center region has higher values than corners
|
||||
# Due to bilinear interpolation, exact values will change but pattern should remain
|
||||
center_val = processed_img[:, 25, 25].mean() # Center of 50x50 image
|
||||
corner_val = processed_img[:, 5, 5].mean() # Corner
|
||||
|
||||
assert center_val > corner_val, "Image processing should preserve recognizable patterns"
|
||||
@@ -1,219 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit-tests for the `PolicyServer` core logic.
|
||||
Monkey-patch the `policy` attribute with a stub so that no real model inference is performed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
from tests.utils import skip_if_package_missing
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Test fixtures
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockPolicy:
|
||||
"""A minimal mock for an actual policy, returning zeros.
|
||||
Refer to tests/policies for tests of the individual policies supported."""
|
||||
|
||||
class _Config:
|
||||
robot_type = "dummy_robot"
|
||||
|
||||
@property
|
||||
def image_features(self) -> dict[str, PolicyFeature]:
|
||||
"""Empty image features since this test doesn't use images."""
|
||||
return {}
|
||||
|
||||
def predict_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Return a chunk of 20 dummy actions."""
|
||||
batch_size = len(observation[OBS_STATE])
|
||||
return torch.zeros(batch_size, 20, 6)
|
||||
|
||||
def __init__(self):
|
||||
self.config = self._Config()
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
# The server calls `policy.to(device)`. This stub ignores it.
|
||||
return self
|
||||
|
||||
def model(self, batch: dict) -> torch.Tensor:
|
||||
# Return a chunk of 20 dummy actions.
|
||||
batch_size = len(batch["robot_type"])
|
||||
return torch.zeros(batch_size, 20, 6)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@skip_if_package_missing("grpcio", "grpc")
|
||||
def policy_server():
|
||||
"""Fresh `PolicyServer` instance with a stubbed-out policy model."""
|
||||
# Import only when the test actually runs (after decorator check)
|
||||
from lerobot.async_inference.configs import PolicyServerConfig
|
||||
from lerobot.async_inference.policy_server import PolicyServer
|
||||
|
||||
test_config = PolicyServerConfig(host="localhost", port=9999)
|
||||
server = PolicyServer(test_config)
|
||||
# Replace the real policy with our fast, deterministic stub.
|
||||
server.policy = MockPolicy()
|
||||
server.actions_per_chunk = 20
|
||||
server.device = "cpu"
|
||||
|
||||
# Add mock lerobot_features that the observation similarity functions need
|
||||
server.lerobot_features = {
|
||||
OBS_STATE: {
|
||||
"dtype": "float32",
|
||||
"shape": [6],
|
||||
"names": ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"],
|
||||
}
|
||||
}
|
||||
|
||||
return server
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Helper utilities for tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_obs(state: torch.Tensor, timestep: int = 0, must_go: bool = False):
|
||||
"""Create a TimedObservation with a given state vector."""
|
||||
# Import only when needed
|
||||
from lerobot.async_inference.helpers import TimedObservation
|
||||
|
||||
return TimedObservation(
|
||||
observation={
|
||||
"joint1": state[0].item() if len(state) > 0 else 0.0,
|
||||
"joint2": state[1].item() if len(state) > 1 else 0.0,
|
||||
"joint3": state[2].item() if len(state) > 2 else 0.0,
|
||||
"joint4": state[3].item() if len(state) > 3 else 0.0,
|
||||
"joint5": state[4].item() if len(state) > 4 else 0.0,
|
||||
"joint6": state[5].item() if len(state) > 5 else 0.0,
|
||||
},
|
||||
timestamp=time.time(),
|
||||
timestep=timestep,
|
||||
must_go=must_go,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_time_action_chunk(policy_server):
|
||||
"""Verify that `_time_action_chunk` assigns correct timestamps and timesteps."""
|
||||
start_ts = time.time()
|
||||
start_t = 10
|
||||
# A chunk of 3 action tensors.
|
||||
action_tensors = [torch.randn(6) for _ in range(3)]
|
||||
|
||||
timed_actions = policy_server._time_action_chunk(start_ts, action_tensors, start_t)
|
||||
|
||||
assert len(timed_actions) == 3
|
||||
# Check timesteps
|
||||
assert [ta.get_timestep() for ta in timed_actions] == [10, 11, 12]
|
||||
# Check timestamps
|
||||
expected_timestamps = [
|
||||
start_ts,
|
||||
start_ts + policy_server.config.environment_dt,
|
||||
start_ts + 2 * policy_server.config.environment_dt,
|
||||
]
|
||||
for ta, expected_ts in zip(timed_actions, expected_timestamps, strict=True):
|
||||
assert abs(ta.get_timestamp() - expected_ts) < 1e-6
|
||||
|
||||
|
||||
def test_maybe_enqueue_observation_must_go(policy_server):
|
||||
"""An observation with `must_go=True` is always enqueued."""
|
||||
obs = _make_obs(torch.zeros(6), must_go=True)
|
||||
assert policy_server._enqueue_observation(obs) is True
|
||||
assert policy_server.observation_queue.qsize() == 1
|
||||
assert policy_server.observation_queue.get_nowait() is obs
|
||||
|
||||
|
||||
def test_maybe_enqueue_observation_dissimilar(policy_server):
|
||||
"""A dissimilar observation (not `must_go`) is enqueued."""
|
||||
# Set a last predicted observation.
|
||||
policy_server.last_processed_obs = _make_obs(torch.zeros(6))
|
||||
# Create a new, dissimilar observation.
|
||||
new_obs = _make_obs(torch.ones(6) * 5) # High norm difference
|
||||
|
||||
assert policy_server._enqueue_observation(new_obs) is True
|
||||
assert policy_server.observation_queue.qsize() == 1
|
||||
|
||||
|
||||
def test_maybe_enqueue_observation_is_skipped(policy_server):
|
||||
"""A similar observation (not `must_go`) is skipped."""
|
||||
# Set a last predicted observation.
|
||||
policy_server.last_processed_obs = _make_obs(torch.zeros(6))
|
||||
# Create a new, very similar observation.
|
||||
new_obs = _make_obs(torch.zeros(6) + 1e-4)
|
||||
|
||||
assert policy_server._enqueue_observation(new_obs) is False
|
||||
assert policy_server.observation_queue.empty() is True
|
||||
|
||||
|
||||
def test_obs_sanity_checks(policy_server):
|
||||
"""Unit-test the private `_obs_sanity_checks` helper."""
|
||||
prev = _make_obs(torch.zeros(6), timestep=0)
|
||||
|
||||
# Case 1 – timestep already predicted
|
||||
policy_server._predicted_timesteps.add(1)
|
||||
obs_same_ts = _make_obs(torch.ones(6), timestep=1)
|
||||
assert policy_server._obs_sanity_checks(obs_same_ts, prev) is False
|
||||
|
||||
# Case 2 – observation too similar
|
||||
policy_server._predicted_timesteps.clear()
|
||||
obs_similar = _make_obs(torch.zeros(6) + 1e-4, timestep=2)
|
||||
assert policy_server._obs_sanity_checks(obs_similar, prev) is False
|
||||
|
||||
# Case 3 – genuinely new & dissimilar observation passes
|
||||
obs_ok = _make_obs(torch.ones(6) * 5, timestep=3)
|
||||
assert policy_server._obs_sanity_checks(obs_ok, prev) is True
|
||||
|
||||
|
||||
def test_predict_action_chunk(monkeypatch, policy_server):
|
||||
"""End-to-end test of `_predict_action_chunk` with a stubbed _get_action_chunk."""
|
||||
# Import only when needed
|
||||
from lerobot.async_inference.policy_server import PolicyServer
|
||||
|
||||
# Force server to act-style policy; patch method to return deterministic tensor
|
||||
policy_server.policy_type = "act"
|
||||
# NOTE(Steven): Smelly tests as the Server is a state machine being partially mocked. Adding these processors as a quick fix.
|
||||
policy_server.preprocessor = lambda obs: obs
|
||||
policy_server.postprocessor = lambda tensor: tensor
|
||||
action_dim = 6
|
||||
batch_size = 1
|
||||
actions_per_chunk = policy_server.actions_per_chunk
|
||||
|
||||
def _fake_get_action_chunk(_self, _obs, _type="act"):
|
||||
return torch.zeros(batch_size, actions_per_chunk, action_dim)
|
||||
|
||||
monkeypatch.setattr(PolicyServer, "_get_action_chunk", _fake_get_action_chunk, raising=True)
|
||||
|
||||
obs = _make_obs(torch.zeros(6), timestep=5)
|
||||
timed_actions = policy_server._predict_action_chunk(obs)
|
||||
|
||||
assert len(timed_actions) == actions_per_chunk
|
||||
assert [ta.get_timestep() for ta in timed_actions] == list(range(5, 5 + actions_per_chunk))
|
||||
|
||||
for i, ta in enumerate(timed_actions):
|
||||
expected_ts = obs.get_timestamp() + i * policy_server.config.environment_dt
|
||||
assert abs(ta.get_timestamp() - expected_ts) < 1e-6
|
||||
@@ -1,271 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit-tests for the `RobotClient` action-queue logic (pure Python, no gRPC).
|
||||
|
||||
We monkey-patch `lerobot.robots.utils.make_robot_from_config` so that
|
||||
no real hardware is accessed. Only the queue-update mechanism is verified.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from queue import Queue
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip entire module if required deps are not available
|
||||
pytest.importorskip("grpc")
|
||||
pytest.importorskip("serial", reason="pyserial is required (install lerobot[hardware])")
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Test fixtures
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def robot_client():
|
||||
"""Fresh `RobotClient` instance for each test case (no threads started).
|
||||
Uses DummyRobot."""
|
||||
# Import only when the test actually runs (after decorator check)
|
||||
from lerobot.async_inference.configs import RobotClientConfig
|
||||
from lerobot.async_inference.robot_client import RobotClient
|
||||
from tests.mocks.mock_robot import MockRobotConfig
|
||||
|
||||
test_config = MockRobotConfig()
|
||||
|
||||
# gRPC channel is not actually used in tests, so using a dummy address
|
||||
test_config = RobotClientConfig(
|
||||
robot=test_config,
|
||||
server_address="localhost:9999",
|
||||
policy_type="test",
|
||||
pretrained_name_or_path="test",
|
||||
actions_per_chunk=20,
|
||||
)
|
||||
|
||||
client = RobotClient(test_config)
|
||||
|
||||
# Initialize attributes that are normally set in start() method
|
||||
client.chunks_received = 0
|
||||
client.available_actions_size = []
|
||||
|
||||
yield client
|
||||
|
||||
if client.robot.is_connected:
|
||||
client.stop()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Helper utilities for tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_actions(start_ts: float, start_t: int, count: int):
|
||||
"""Generate `count` consecutive TimedAction objects starting at timestep `start_t`."""
|
||||
from lerobot.async_inference.helpers import TimedAction
|
||||
|
||||
fps = 30 # emulates most common frame-rate
|
||||
actions = []
|
||||
for i in range(count):
|
||||
timestep = start_t + i
|
||||
timestamp = start_ts + i * (1 / fps)
|
||||
action_tensor = torch.full((6,), timestep, dtype=torch.float32)
|
||||
actions.append(TimedAction(action=action_tensor, timestep=timestep, timestamp=timestamp))
|
||||
return actions
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_update_action_queue_discards_stale(robot_client):
|
||||
"""`_update_action_queue` must drop actions with `timestep` <= `latest_action`."""
|
||||
|
||||
# Pretend we already executed up to action #4
|
||||
robot_client.latest_action = 4
|
||||
|
||||
# Incoming chunk contains timesteps 3..7 -> expect 5,6,7 kept.
|
||||
incoming = _make_actions(start_ts=time.time(), start_t=3, count=5) # 3,4,5,6,7
|
||||
|
||||
robot_client._aggregate_action_queues(incoming)
|
||||
|
||||
# Extract timesteps from queue
|
||||
resulting_timesteps = [a.get_timestep() for a in robot_client.action_queue.queue]
|
||||
|
||||
assert resulting_timesteps == [5, 6, 7]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"weight_old, weight_new",
|
||||
[
|
||||
(1.0, 0.0),
|
||||
(0.0, 1.0),
|
||||
(0.5, 0.5),
|
||||
(0.2, 0.8),
|
||||
(0.8, 0.2),
|
||||
(0.1, 0.9),
|
||||
(0.9, 0.1),
|
||||
],
|
||||
)
|
||||
def test_aggregate_action_queues_combines_actions_in_overlap(
|
||||
robot_client, weight_old: float, weight_new: float
|
||||
):
|
||||
"""`_aggregate_action_queues` must combine actions on overlapping timesteps according
|
||||
to the provided aggregate_fn, here tested with multiple coefficients."""
|
||||
from lerobot.async_inference.helpers import TimedAction
|
||||
|
||||
robot_client.chunks_received = 0
|
||||
|
||||
# Pretend we already executed up to action #4, and queue contains actions for timesteps 5..6
|
||||
robot_client.latest_action = 4
|
||||
current_actions = _make_actions(
|
||||
start_ts=time.time(), start_t=5, count=2
|
||||
) # actions are [torch.ones(6), torch.ones(6), ...]
|
||||
current_actions = [
|
||||
TimedAction(action=10 * a.get_action(), timestep=a.get_timestep(), timestamp=a.get_timestamp())
|
||||
for a in current_actions
|
||||
]
|
||||
|
||||
for a in current_actions:
|
||||
robot_client.action_queue.put(a)
|
||||
|
||||
# Incoming chunk contains timesteps 3..7 -> expect 5,6,7 kept.
|
||||
incoming = _make_actions(start_ts=time.time(), start_t=3, count=5) # 3,4,5,6,7
|
||||
|
||||
overlap_timesteps = [5, 6] # properly tested in test_aggregate_action_queues_discards_stale
|
||||
nonoverlap_timesteps = [7]
|
||||
|
||||
robot_client._aggregate_action_queues(
|
||||
incoming, aggregate_fn=lambda x1, x2: weight_old * x1 + weight_new * x2
|
||||
)
|
||||
|
||||
queue_overlap_actions = []
|
||||
queue_non_overlap_actions = []
|
||||
for a in robot_client.action_queue.queue:
|
||||
if a.get_timestep() in overlap_timesteps:
|
||||
queue_overlap_actions.append(a)
|
||||
elif a.get_timestep() in nonoverlap_timesteps:
|
||||
queue_non_overlap_actions.append(a)
|
||||
|
||||
queue_overlap_actions = sorted(queue_overlap_actions, key=lambda x: x.get_timestep())
|
||||
queue_non_overlap_actions = sorted(queue_non_overlap_actions, key=lambda x: x.get_timestep())
|
||||
|
||||
assert torch.allclose(
|
||||
queue_overlap_actions[0].get_action(),
|
||||
weight_old * current_actions[0].get_action() + weight_new * incoming[-3].get_action(),
|
||||
)
|
||||
assert torch.allclose(
|
||||
queue_overlap_actions[1].get_action(),
|
||||
weight_old * current_actions[1].get_action() + weight_new * incoming[-2].get_action(),
|
||||
)
|
||||
assert torch.allclose(queue_non_overlap_actions[0].get_action(), incoming[-1].get_action())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"chunk_size, queue_len, expected",
|
||||
[
|
||||
(20, 12, False), # 12 / 20 = 0.6 > g=0.5 threshold, not ready to send
|
||||
(20, 8, True), # 8 / 20 = 0.4 <= g=0.5, ready to send
|
||||
(10, 5, True),
|
||||
(10, 6, False),
|
||||
],
|
||||
)
|
||||
def test_ready_to_send_observation(robot_client, chunk_size: int, queue_len: int, expected: bool):
|
||||
"""Validate `_ready_to_send_observation` ratio logic for various sizes."""
|
||||
|
||||
robot_client.action_chunk_size = chunk_size
|
||||
|
||||
# Clear any existing actions then fill with `queue_len` dummy entries ----
|
||||
robot_client.action_queue = Queue()
|
||||
|
||||
dummy_actions = _make_actions(start_ts=time.time(), start_t=0, count=queue_len)
|
||||
for act in dummy_actions:
|
||||
robot_client.action_queue.put(act)
|
||||
|
||||
assert robot_client._ready_to_send_observation() is expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"g_threshold, expected",
|
||||
[
|
||||
# The condition is `queue_size / chunk_size <= g`.
|
||||
# Here, ratio = 6 / 10 = 0.6.
|
||||
(0.0, False), # 0.6 <= 0.0 is False
|
||||
(0.1, False),
|
||||
(0.2, False),
|
||||
(0.3, False),
|
||||
(0.4, False),
|
||||
(0.5, False),
|
||||
(0.6, True), # 0.6 <= 0.6 is True
|
||||
(0.7, True),
|
||||
(0.8, True),
|
||||
(0.9, True),
|
||||
(1.0, True),
|
||||
],
|
||||
)
|
||||
def test_ready_to_send_observation_with_varying_threshold(robot_client, g_threshold: float, expected: bool):
|
||||
"""Validate `_ready_to_send_observation` with fixed sizes and varying `g`."""
|
||||
# Fixed sizes for this test: ratio = 6 / 10 = 0.6
|
||||
chunk_size = 10
|
||||
queue_len = 6
|
||||
|
||||
robot_client.action_chunk_size = chunk_size
|
||||
# This is the parameter we are testing
|
||||
robot_client._chunk_size_threshold = g_threshold
|
||||
|
||||
# Fill queue with dummy actions
|
||||
robot_client.action_queue = Queue()
|
||||
dummy_actions = _make_actions(start_ts=time.time(), start_t=0, count=queue_len)
|
||||
for act in dummy_actions:
|
||||
robot_client.action_queue.put(act)
|
||||
|
||||
assert robot_client._ready_to_send_observation() is expected
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Regression test: robot type registry populated by robot_client imports
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_robot_client_registers_builtin_robot_types():
|
||||
"""Importing robot_client must populate RobotConfig's ChoiceRegistry.
|
||||
|
||||
This is a regression test for a bug introduced in #2425, where removing
|
||||
robot module imports from robot_client.py caused RobotConfig's registry to
|
||||
be empty, breaking CLI argument parsing with:
|
||||
error: argument --robot.type: invalid choice: 'so101_follower' (choose from )
|
||||
|
||||
Robot types are registered via @RobotConfig.register_subclass() decorators
|
||||
at import time, so all supported modules must be explicitly imported.
|
||||
"""
|
||||
import lerobot.async_inference.robot_client # noqa: F401
|
||||
from lerobot.robots.config import RobotConfig
|
||||
|
||||
known_choices = RobotConfig.get_known_choices()
|
||||
|
||||
expected_robot_types = [
|
||||
"so100_follower",
|
||||
"so101_follower",
|
||||
"koch_follower",
|
||||
"omx_follower",
|
||||
"bi_so_follower",
|
||||
]
|
||||
for robot_type in expected_robot_types:
|
||||
assert robot_type in known_choices, (
|
||||
f"Robot type '{robot_type}' is not registered in RobotConfig's ChoiceRegistry. "
|
||||
f"Ensure the corresponding module is imported in robot_client.py. "
|
||||
f"Known choices: {sorted(known_choices)}"
|
||||
)
|
||||
@@ -114,6 +114,30 @@ def test_shuffle():
|
||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
||||
|
||||
|
||||
def test_shuffle_with_generator_is_deterministic():
|
||||
# Two samplers shuffling with same-seed generators must yield identical permutations.
|
||||
# This is what keeps batch shards disjoint across ranks in distributed training, where
|
||||
# accelerate synchronizes the sampler's generator state instead of the global torch RNG.
|
||||
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
|
||||
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
|
||||
assert list(sampler_a) == list(sampler_b)
|
||||
|
||||
# Desyncing the global RNG must not affect the permutation.
|
||||
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
|
||||
order_before = list(sampler_c)
|
||||
sampler_c.generator.manual_seed(42)
|
||||
torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would
|
||||
assert list(sampler_c) == order_before
|
||||
|
||||
|
||||
def test_generator_attribute_defaults_to_none():
|
||||
# accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`,
|
||||
# so the attribute must exist even when no generator is passed.
|
||||
sampler = EpisodeAwareSampler([0], [6], shuffle=True)
|
||||
assert sampler.generator is None
|
||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
||||
|
||||
|
||||
def test_negative_drop_first_frames_raises():
|
||||
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
|
||||
EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
# Local-only parity artifacts (regenerated via dump_original_n1_7.py); never committed.
|
||||
*.npz
|
||||
@@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test script for LeRobot's GR00T N1.7 policy forward and inference passes."""
|
||||
"""Test script for LeRobot's Groot policy forward and inference passes."""
|
||||
|
||||
import gc
|
||||
import os
|
||||
@@ -41,20 +41,13 @@ pytestmark = pytest.mark.skipif(
|
||||
)
|
||||
|
||||
|
||||
# Define constants for dummy data (GR00T N1.7 native conventions).
|
||||
# N1.7 internally uses a 40-step action chunk, 132-dim state/action, and 256px images
|
||||
# (see GrootConfig.__post_init__). Use a chunk-sized action horizon so the dummy batch
|
||||
# matches the model's native action space.
|
||||
# Define constants for dummy data
|
||||
DUMMY_STATE_DIM = 44
|
||||
DUMMY_ACTION_DIM = 44
|
||||
DUMMY_ACTION_HORIZON = 40
|
||||
DUMMY_ACTION_HORIZON = 16
|
||||
IMAGE_SIZE = 256
|
||||
DEVICE = auto_select_torch_device()
|
||||
# GR00T N1.7 checkpoint (N1.5 is no longer supported). The N1.7-3B base model loads
|
||||
# via GrootPolicy.from_pretrained with root-level sharded safetensors.
|
||||
MODEL_PATH = "nvidia/GR00T-N1.7-3B"
|
||||
# Valid N1.7 embodiment tag carried by the checkpoint metadata.
|
||||
EMBODIMENT_TAG = "gr1_unified"
|
||||
MODEL_PATH = "aractingi/bimanual-handover-groot-10k"
|
||||
|
||||
|
||||
def cleanup_memory():
|
||||
@@ -95,13 +88,13 @@ def instantiate_lerobot_groot(
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Instantiate LeRobot GR00T N1.7 policy with preprocessor and postprocessor."""
|
||||
"""Instantiate LeRobot Groot policy with preprocessor and postprocessor."""
|
||||
if from_pretrained:
|
||||
policy = GrootPolicy.from_pretrained(
|
||||
pretrained_name_or_path=model_path,
|
||||
strict=False,
|
||||
)
|
||||
policy.config.embodiment_tag = EMBODIMENT_TAG
|
||||
policy.config.embodiment_tag = "gr1"
|
||||
else:
|
||||
config = GrootConfig(
|
||||
base_model_path=model_path,
|
||||
@@ -109,7 +102,7 @@ def instantiate_lerobot_groot(
|
||||
chunk_size=DUMMY_ACTION_HORIZON,
|
||||
image_size=[IMAGE_SIZE, IMAGE_SIZE],
|
||||
device=DEVICE,
|
||||
embodiment_tag=EMBODIMENT_TAG,
|
||||
embodiment_tag="gr1",
|
||||
)
|
||||
policy = GrootPolicy(config)
|
||||
|
||||
@@ -155,8 +148,8 @@ def create_dummy_data(device=DEVICE):
|
||||
|
||||
@require_cuda
|
||||
def test_lerobot_groot_inference():
|
||||
"""Test the inference pass (select_action) of LeRobot's GR00T N1.7 policy."""
|
||||
print("Test: LeRobot GR00T N1.7 Inference Pass")
|
||||
"""Test the inference pass (select_action) of LeRobot's Groot policy."""
|
||||
print("Test: LeRobot Groot Inference Pass")
|
||||
|
||||
set_seed_all(42)
|
||||
|
||||
@@ -188,9 +181,9 @@ def test_lerobot_groot_inference():
|
||||
|
||||
@require_cuda
|
||||
def test_lerobot_groot_forward_pass():
|
||||
"""Test the forward pass of LeRobot's GR00T N1.7 policy."""
|
||||
"""Test the forward pass of LeRobot's Groot policy."""
|
||||
print("\n" + "=" * 50)
|
||||
print("Test: LeRobot GR00T N1.7 Forward Pass (Training Mode)")
|
||||
print("Test: LeRobot Groot Forward Pass (Training Mode)")
|
||||
|
||||
set_seed_all(42)
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -14,194 +14,431 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Parity test: original NVIDIA GR00T N1.7 vs the GR00T N1.7 integration in LeRobot.
|
||||
|
||||
Verifies that the self-contained LeRobot reimplementation of the GR00T N1.7 action
|
||||
head + Qwen3-VL backbone produces the SAME raw model output (``action_pred``, the
|
||||
normalized flow-matching prediction before any action decoding) as NVIDIA's original
|
||||
``gr00t`` package, given byte-identical pre-processed inputs and the same
|
||||
flow-matching seed. The comparison is parametrized over every embodiment tag present
|
||||
in the checkpoint.
|
||||
|
||||
To keep the comparison fair, the original outputs + the exact collated inputs are
|
||||
produced once per embodiment in the original ``gr00t`` env via the companion script
|
||||
``utils/dump_original_n1_7.py`` (in the ``utils`` package next to this file) and saved
|
||||
to per-tag ``.npz`` files.
|
||||
This test discovers those artifacts, replays the identical inputs through the LeRobot
|
||||
model, and compares.
|
||||
|
||||
This test is LOCAL-only and skips on CI, when ``gr00t``-side prerequisites are not
|
||||
present, or when no artifact has been generated. By default it looks for artifacts in
|
||||
``<this dir>/artifacts/``; override with ``GROOT_N1_7_PARITY_DIR``. See the
|
||||
"Original-vs-LeRobot parity test" section of ``src/lerobot/policies/groot/README.md``
|
||||
for the full run procedure.
|
||||
"""
|
||||
"""Test script to verify Groot policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
|
||||
|
||||
import gc
|
||||
import os
|
||||
from pathlib import Path
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.policies.groot.modeling_groot import GrootPolicy
|
||||
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.types import PolicyAction
|
||||
|
||||
pytest.importorskip("gr00t")
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="Requires a local GR00T N1.7 checkpoint + pre-generated artifacts; not for CI.",
|
||||
reason="This test requires local Groot installation and is not meant for CI",
|
||||
)
|
||||
|
||||
from lerobot.policies.groot.configuration_groot import GROOT_N1_7 # noqa: E402,F401
|
||||
|
||||
SEED = 42
|
||||
DEVICE = os.environ.get("GROOT_PARITY_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
|
||||
ATOL = float(os.environ.get("GROOT_PARITY_ATOL", "1e-3"))
|
||||
RTOL = float(os.environ.get("GROOT_PARITY_RTOL", "1e-3"))
|
||||
from gr00t.data.dataset import ModalityConfig # noqa: E402
|
||||
from gr00t.data.embodiment_tags import EmbodimentTag # noqa: E402
|
||||
from gr00t.data.transform.base import ComposedModalityTransform # noqa: E402
|
||||
from gr00t.model.policy import Gr00tPolicy # noqa: E402
|
||||
|
||||
# Artifact filenames are original_n1_7_<embodiment_tag>.npz
|
||||
_ARTIFACT_PREFIX = "original_n1_7_"
|
||||
_ARTIFACT_SUFFIX = ".npz"
|
||||
# GR1 humanoid dimensions (from pretrained model metadata)
|
||||
# The actual GR1 robot has 44 dimensions for both state and action
|
||||
# GR00TTransform will pad state to 64 and truncate action to 32
|
||||
DUMMY_STATE_DIM = 44
|
||||
DUMMY_ACTION_DIM = 44
|
||||
DUMMY_ACTION_HORIZON = 16
|
||||
IMAGE_SIZE = 256
|
||||
DEVICE = "cpu"
|
||||
MODEL_PATH = "nvidia/GR00T-N1.5-3B"
|
||||
|
||||
GR1_BODY_PARTS = {
|
||||
"left_arm": 7,
|
||||
"left_hand": 6,
|
||||
"left_leg": 6,
|
||||
"neck": 3,
|
||||
"right_arm": 7,
|
||||
"right_hand": 6,
|
||||
"right_leg": 6,
|
||||
"waist": 3,
|
||||
}
|
||||
|
||||
|
||||
def _artifact_dir() -> Path:
|
||||
"""Directory holding the per-embodiment .npz artifacts.
|
||||
|
||||
Self-contained by default: a sibling ``artifacts/`` directory next to this test.
|
||||
Override with ``GROOT_N1_7_PARITY_DIR`` (e.g. to point at a scratch location).
|
||||
The directory is read-only here -- it is populated by ``utils/dump_original_n1_7.py``
|
||||
run in the original gr00t environment; the test never creates it.
|
||||
"""
|
||||
env = os.environ.get("GROOT_N1_7_PARITY_DIR")
|
||||
if env:
|
||||
return Path(env)
|
||||
return Path(__file__).resolve().parent / "artifacts"
|
||||
|
||||
|
||||
def _discover_artifacts() -> list[tuple[str, Path]]:
|
||||
"""Return [(embodiment_tag, npz_path), ...] for every dumped artifact."""
|
||||
d = _artifact_dir()
|
||||
if not d.is_dir():
|
||||
return []
|
||||
out = []
|
||||
for p in sorted(d.glob(f"{_ARTIFACT_PREFIX}*{_ARTIFACT_SUFFIX}")):
|
||||
tag = p.name[len(_ARTIFACT_PREFIX) : -len(_ARTIFACT_SUFFIX)]
|
||||
out.append((tag, p))
|
||||
return out
|
||||
|
||||
|
||||
def _resolve_checkpoint() -> str:
|
||||
env = os.environ.get("GROOT_N1_7_LIBERO_CKPT")
|
||||
if env:
|
||||
if not Path(env).exists():
|
||||
pytest.skip(f"GROOT_N1_7_LIBERO_CKPT={env} does not exist")
|
||||
return env
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
root = snapshot_download(
|
||||
"nvidia/GR00T-N1.7-LIBERO",
|
||||
local_files_only=True,
|
||||
allow_patterns=["libero_10/*"],
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
pytest.skip(f"GR00T N1.7 LIBERO checkpoint not available locally: {exc}")
|
||||
ckpt = Path(root) / "libero_10"
|
||||
if not (ckpt / "config.json").exists():
|
||||
pytest.skip(f"GR00T N1.7 LIBERO checkpoint incomplete at {ckpt}")
|
||||
return str(ckpt)
|
||||
|
||||
|
||||
def _load_artifact(path: Path):
|
||||
data = np.load(path, allow_pickle=True)
|
||||
original_action = torch.from_numpy(data["action_pred"]).float()
|
||||
dtypes = dict(zip(data["meta_keys"].tolist(), data["meta_dtypes"].tolist(), strict=False))
|
||||
inputs = {}
|
||||
for key in data.files:
|
||||
if not key.startswith("in::"):
|
||||
continue
|
||||
name = key[4:]
|
||||
arr = data[key]
|
||||
t = torch.from_numpy(np.asarray(arr))
|
||||
declared = dtypes.get(key, "")
|
||||
if "int" in declared or "long" in declared:
|
||||
t = t.long()
|
||||
inputs[name] = t
|
||||
return original_action, inputs
|
||||
|
||||
|
||||
def _unflatten(inputs: dict[str, torch.Tensor]) -> dict:
|
||||
"""Rebuild the nested model-input dict from dot-prefixed flat keys."""
|
||||
nested: dict = {}
|
||||
for dotted, value in inputs.items():
|
||||
parts = dotted.split(".")
|
||||
cur = nested
|
||||
for p in parts[:-1]:
|
||||
cur = cur.setdefault(p, {})
|
||||
cur[parts[-1]] = value
|
||||
return nested.get("inputs", nested)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def lerobot_model():
|
||||
"""Load the LeRobot GR00T N1.7 model once (fp32 + SDPA) and reuse across tags."""
|
||||
ckpt = _resolve_checkpoint()
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
model = GR00TN17.from_pretrained(
|
||||
ckpt,
|
||||
tune_llm=False,
|
||||
tune_visual=False,
|
||||
tune_projector=False,
|
||||
tune_diffusion_model=False,
|
||||
tune_vlln=False,
|
||||
transformers_loading_kwargs={"trust_remote_code": True},
|
||||
)
|
||||
# fp32 + SDPA on both sides: bf16 + differing attention kernels otherwise introduce
|
||||
# ~1e-2 numerical noise unrelated to the implementations.
|
||||
model.compute_dtype = "float32"
|
||||
model.config.compute_dtype = model.compute_dtype
|
||||
model.to(device=DEVICE, dtype=torch.float32)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
_ARTIFACTS = _discover_artifacts()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not _ARTIFACTS,
|
||||
reason=(
|
||||
"No GR00T N1.7 parity artifacts found. Generate them first in the original gr00t "
|
||||
"env:\n .venv-original/bin/python tests/policies/groot/utils/dump_original_n1_7.py "
|
||||
"--ckpt <ckpt> --out-dir tests/policies/groot/artifacts --device cuda"
|
||||
),
|
||||
)
|
||||
@pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[t for t, _ in _ARTIFACTS])
|
||||
def test_groot_get_action_parity(embodiment_tag, artifact, lerobot_model):
|
||||
"""Raw model.get_action(action_pred) parity per embodiment: original vs LeRobot."""
|
||||
original_action, flat_inputs = _load_artifact(artifact)
|
||||
model_inputs = _unflatten(flat_inputs)
|
||||
|
||||
# Align the flow-matching RNG exactly as the producer did (seed right before sampling).
|
||||
torch.manual_seed(SEED)
|
||||
def cleanup_memory():
|
||||
"""Clean up GPU/MPS memory to prevent OOM errors between tests."""
|
||||
print("\nCleaning up memory...")
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(SEED)
|
||||
with torch.inference_mode():
|
||||
out = lerobot_model.get_action(model_inputs)
|
||||
lerobot_action = out["action_pred"].float().cpu()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
if torch.backends.mps.is_available():
|
||||
torch.mps.empty_cache()
|
||||
print("Memory cleanup complete.")
|
||||
|
||||
t = min(original_action.shape[1], lerobot_action.shape[1])
|
||||
d = min(original_action.shape[2], lerobot_action.shape[2])
|
||||
original_action = original_action[:, :t, :d]
|
||||
lerobot_action = lerobot_action[:, :t, :d]
|
||||
|
||||
diff = torch.abs(lerobot_action - original_action)
|
||||
max_diff = diff.max().item()
|
||||
print(
|
||||
f"\n[{embodiment_tag}] shapes lerobot={tuple(lerobot_action.shape)} "
|
||||
f"original={tuple(original_action.shape)} "
|
||||
f"max|diff|={max_diff:.6e} mean|diff|={diff.mean().item():.6e}"
|
||||
def set_seed_all(seed: int):
|
||||
"""Set random seed for all RNG sources to ensure reproducibility."""
|
||||
import random
|
||||
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
# Set deterministic behavior
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.use_deterministic_algorithms(True, warn_only=True)
|
||||
|
||||
|
||||
def instantiate_lerobot_groot(
|
||||
from_pretrained: bool = False,
|
||||
model_path: str = MODEL_PATH,
|
||||
) -> tuple[
|
||||
GrootPolicy,
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Instantiate LeRobot Groot policy with preprocessor and postprocessor."""
|
||||
if from_pretrained:
|
||||
policy = GrootPolicy.from_pretrained(
|
||||
pretrained_name_or_path=model_path,
|
||||
strict=False,
|
||||
)
|
||||
policy.config.embodiment_tag = "gr1"
|
||||
else:
|
||||
config = GrootConfig(
|
||||
base_model_path=model_path,
|
||||
n_action_steps=DUMMY_ACTION_HORIZON,
|
||||
chunk_size=DUMMY_ACTION_HORIZON,
|
||||
image_size=[IMAGE_SIZE, IMAGE_SIZE],
|
||||
device=DEVICE,
|
||||
embodiment_tag="gr1",
|
||||
)
|
||||
policy = GrootPolicy(config)
|
||||
|
||||
policy.to(DEVICE)
|
||||
policy.config.device = DEVICE
|
||||
|
||||
preprocessor, postprocessor = make_groot_pre_post_processors(
|
||||
config=policy.config,
|
||||
dataset_stats=None, # Pass None for dataset_stats to disable normalization (original GR00T doesn't normalize)
|
||||
)
|
||||
|
||||
assert torch.allclose(lerobot_action, original_action, atol=ATOL, rtol=RTOL), (
|
||||
f"GR00T N1.7 raw action_pred differs for embodiment '{embodiment_tag}' beyond "
|
||||
f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}"
|
||||
return (policy, preprocessor, postprocessor)
|
||||
|
||||
|
||||
def instantiate_original_groot(
|
||||
from_pretrained: bool = False,
|
||||
model_path: str = MODEL_PATH,
|
||||
):
|
||||
"""Instantiate original Groot policy from NVIDIA's implementation."""
|
||||
from gr00t.data.transform.concat import ConcatTransform
|
||||
from gr00t.data.transform.state_action import StateActionToTensor
|
||||
from gr00t.data.transform.video import VideoToNumpy, VideoToTensor
|
||||
from gr00t.model.transforms import GR00TTransform
|
||||
|
||||
video_keys = ["video.ego_view"]
|
||||
state_keys = [
|
||||
"state"
|
||||
] # Important: Use single concatenated "state" key (not split body parts) to match preprocessing
|
||||
action_keys = [
|
||||
"action.left_arm",
|
||||
"action.right_arm",
|
||||
"action.left_hand",
|
||||
"action.right_hand",
|
||||
"action.left_leg",
|
||||
"action.right_leg",
|
||||
"action.neck",
|
||||
"action.waist",
|
||||
]
|
||||
language_keys = ["annotation.human.action.task_description"]
|
||||
|
||||
modality_config = {
|
||||
"video": ModalityConfig(
|
||||
delta_indices=[0], # Current frame only
|
||||
modality_keys=video_keys,
|
||||
),
|
||||
"state": ModalityConfig(
|
||||
delta_indices=[0],
|
||||
modality_keys=state_keys,
|
||||
),
|
||||
"action": ModalityConfig(
|
||||
delta_indices=list(range(DUMMY_ACTION_HORIZON)),
|
||||
modality_keys=action_keys,
|
||||
),
|
||||
"language": ModalityConfig(
|
||||
delta_indices=[0],
|
||||
modality_keys=language_keys,
|
||||
),
|
||||
}
|
||||
|
||||
modality_transform = ComposedModalityTransform(
|
||||
transforms=[
|
||||
VideoToTensor(apply_to=video_keys),
|
||||
VideoToNumpy(apply_to=video_keys), # Convert to numpy (GR00TTransform expects numpy arrays)
|
||||
# State is already a single concatenated key, so no StateActionToTensor needed
|
||||
# Convert action from numpy to tensor
|
||||
StateActionToTensor(apply_to=action_keys),
|
||||
# Concatenate only video and actions (state is already single key)
|
||||
ConcatTransform(
|
||||
video_concat_order=video_keys,
|
||||
state_concat_order=[], # Empty:state is already single key
|
||||
action_concat_order=action_keys,
|
||||
),
|
||||
GR00TTransform(
|
||||
max_state_dim=64,
|
||||
max_action_dim=32,
|
||||
state_horizon=1,
|
||||
action_horizon=DUMMY_ACTION_HORIZON,
|
||||
training=False,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
policy = Gr00tPolicy(
|
||||
model_path=model_path,
|
||||
embodiment_tag=EmbodimentTag.GR1,
|
||||
modality_config=modality_config,
|
||||
modality_transform=modality_transform,
|
||||
device=DEVICE,
|
||||
)
|
||||
|
||||
return policy, modality_config, modality_transform
|
||||
|
||||
|
||||
def create_dummy_data(device=DEVICE):
|
||||
"""Create dummy data for testing both implementations."""
|
||||
batch_size = 2
|
||||
prompt = "Pick up the red cube and place it in the bin"
|
||||
state = torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device)
|
||||
|
||||
batch = {
|
||||
"observation.state": state,
|
||||
"action": torch.randn(
|
||||
batch_size,
|
||||
DUMMY_ACTION_HORIZON,
|
||||
DUMMY_ACTION_DIM,
|
||||
dtype=torch.float32,
|
||||
device=device, # Action ground truth (for training)
|
||||
),
|
||||
"observation.images.ego_view": torch.rand(
|
||||
batch_size,
|
||||
3,
|
||||
IMAGE_SIZE,
|
||||
IMAGE_SIZE,
|
||||
dtype=torch.float32,
|
||||
device=device, # Images in [0, 1] range as expected by LeRobot
|
||||
),
|
||||
"task": [prompt for _ in range(batch_size)],
|
||||
}
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def convert_lerobot_to_original_format(batch, modality_config):
|
||||
"""Convert LeRobot batch format to original Groot format.
|
||||
|
||||
The original Groot expects observations in this format:
|
||||
{
|
||||
"video.<camera_name>": np.ndarray (T, H, W, C) or (B, T, H, W, C)
|
||||
"state.<state_component>": np.ndarray (T, D) or (B, T, D)
|
||||
"action.<action_component>": np.ndarray (T, D) or (B, T, D)
|
||||
"annotation.<annotation_type>": str or list[str]
|
||||
}
|
||||
"""
|
||||
# Original Groot expects (T, H, W, C) format for images
|
||||
# LeRobot has (B, C, H, W) format, so we need to convert
|
||||
observation = {}
|
||||
|
||||
for img_key in ["ego_view"]:
|
||||
lerobot_key = f"observation.images.{img_key}"
|
||||
if lerobot_key in batch:
|
||||
img = batch[lerobot_key]
|
||||
# Convert from (B, C, H, W) to (B, T=1, H, W, C)
|
||||
img_np = img.permute(0, 2, 3, 1).unsqueeze(1).cpu().numpy()
|
||||
# Convert [0, 1] to [0, 255] uint8 as expected by original
|
||||
img_np = (img_np * 255).astype(np.uint8)
|
||||
observation[f"video.{img_key}"] = img_np
|
||||
|
||||
# Important: The Original's GR00TTransform expects "state" as (B, T, D), not split body parts
|
||||
if "observation.state" in batch:
|
||||
state = batch["observation.state"]
|
||||
state_np = state.unsqueeze(1).cpu().numpy() # (B, 1, D)
|
||||
observation["state"] = state_np
|
||||
|
||||
if "action" in batch:
|
||||
action = batch["action"]
|
||||
action_np = action.cpu().numpy()
|
||||
|
||||
start_idx = 0
|
||||
for part_name, part_dim in GR1_BODY_PARTS.items():
|
||||
end_idx = start_idx + part_dim
|
||||
observation[f"action.{part_name}"] = action_np[:, :, start_idx:end_idx]
|
||||
start_idx = end_idx
|
||||
|
||||
if "task" in batch:
|
||||
task_list = batch["task"]
|
||||
# GR00TTransform expects language with (B, T) shape for batched data
|
||||
# Create a (B, T=1) array where each element is the string directly
|
||||
bsz = len(task_list)
|
||||
task_array = np.empty((bsz, 1), dtype=object)
|
||||
for i in range(bsz):
|
||||
task_array[i, 0] = task_list[i] # Assign string directly to each (i, 0) position
|
||||
observation["annotation.human.action.task_description"] = task_array
|
||||
|
||||
return observation
|
||||
|
||||
|
||||
def test_groot_original_vs_lerobot_pretrained():
|
||||
"""Test Groot original implementation vs LeRobot implementation with pretrained weights."""
|
||||
print("Test: Groot Original vs LeRobot with Pretrained Weights (Inference)")
|
||||
|
||||
set_seed_all(42)
|
||||
|
||||
lerobot_policy, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_groot(
|
||||
from_pretrained=True
|
||||
)
|
||||
original_policy, modality_config, modality_transform = instantiate_original_groot(from_pretrained=True)
|
||||
|
||||
batch = create_dummy_data()
|
||||
batch_lerobot = deepcopy(batch)
|
||||
|
||||
print("\n[LeRobot] Running inference...")
|
||||
lerobot_policy.eval()
|
||||
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
|
||||
|
||||
# Important: Reset seed immediately before inference to ensure identical RNG state
|
||||
torch.manual_seed(42)
|
||||
|
||||
with torch.no_grad():
|
||||
lerobot_actions = lerobot_policy.select_action(batch_lerobot_processed)
|
||||
|
||||
print("\n[Original] Running inference...")
|
||||
original_policy.model.eval()
|
||||
observation = convert_lerobot_to_original_format(batch, modality_config)
|
||||
original_obs_transformed = modality_transform(deepcopy(observation))
|
||||
|
||||
# Important: Reset seed immediately before inference to ensure identical RNG state
|
||||
torch.manual_seed(42)
|
||||
|
||||
with torch.no_grad():
|
||||
original_model_output = original_policy.model.get_action(original_obs_transformed)
|
||||
original_actions_raw = original_model_output["action_pred"] # [2, 16, 32]
|
||||
# Take first timestep
|
||||
original_actions = original_actions_raw[:, 0, :].to(lerobot_actions.device).to(lerobot_actions.dtype)
|
||||
|
||||
print("Action Comparison:")
|
||||
diff = lerobot_actions - original_actions
|
||||
abs_diff = torch.abs(diff)
|
||||
|
||||
for batch_idx in range(lerobot_actions.shape[0]):
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Batch {batch_idx}")
|
||||
print(f"{'=' * 60}")
|
||||
print(f"{'Idx':<5} {'LeRobot':<14} {'Original':<14} {'Difference':<14}")
|
||||
print("-" * 60)
|
||||
for action_idx in range(lerobot_actions.shape[1]):
|
||||
lr_val = lerobot_actions[batch_idx, action_idx].item()
|
||||
orig_val = original_actions[batch_idx, action_idx].item()
|
||||
diff_val = abs(lr_val - orig_val)
|
||||
sign = "+" if (lr_val - orig_val) > 0 else "-"
|
||||
print(f"{action_idx:<5} {lr_val:>13.6f} {orig_val:>13.6f} {sign}{diff_val:>12.6f}")
|
||||
|
||||
max_diff = abs_diff.max().item()
|
||||
tolerance = 0.001
|
||||
assert torch.allclose(lerobot_actions, original_actions, atol=tolerance), (
|
||||
f"Actions differ by more than tolerance ({tolerance}): max diff = {max_diff:.6f}"
|
||||
)
|
||||
print(f"\nSuccess: Actions match within tolerance ({tolerance})!")
|
||||
|
||||
del lerobot_policy, lerobot_preprocessor, lerobot_postprocessor
|
||||
del original_policy, modality_config, modality_transform
|
||||
del batch, batch_lerobot, observation
|
||||
cleanup_memory()
|
||||
|
||||
|
||||
def test_groot_forward_pass_comparison():
|
||||
"""Test forward pass comparison between LeRobot and Original Groot implementations."""
|
||||
print("Test: Forward Pass Comparison (Training Mode)")
|
||||
|
||||
set_seed_all(42)
|
||||
|
||||
lerobot_policy, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_groot(
|
||||
from_pretrained=True
|
||||
)
|
||||
original_policy, modality_config, modality_transform = instantiate_original_groot(from_pretrained=True)
|
||||
|
||||
batch = create_dummy_data()
|
||||
lerobot_policy.eval()
|
||||
original_policy.model.eval()
|
||||
|
||||
print("\n[LeRobot] Running forward pass...")
|
||||
batch_lerobot = deepcopy(batch)
|
||||
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
|
||||
|
||||
set_seed_all(42)
|
||||
with torch.no_grad():
|
||||
lerobot_loss, lerobot_metrics = lerobot_policy.forward(batch_lerobot_processed)
|
||||
|
||||
print(f" Loss: {lerobot_loss.item():.6f}")
|
||||
|
||||
print("\n[Original] Running forward pass...")
|
||||
observation = convert_lerobot_to_original_format(batch, modality_config)
|
||||
transformed_obs = modality_transform(observation)
|
||||
|
||||
if "action" not in transformed_obs:
|
||||
action_for_forward = batch_lerobot_processed["action"]
|
||||
action_mask_for_forward = batch_lerobot_processed["action_mask"]
|
||||
|
||||
# Match action horizon if needed
|
||||
if action_for_forward.shape[1] != original_policy.model.action_horizon:
|
||||
if action_for_forward.shape[1] < original_policy.model.action_horizon:
|
||||
pad_size = original_policy.model.action_horizon - action_for_forward.shape[1]
|
||||
last_action = action_for_forward[:, -1:, :]
|
||||
padding = last_action.repeat(1, pad_size, 1)
|
||||
action_for_forward = torch.cat([action_for_forward, padding], dim=1)
|
||||
|
||||
mask_padding = torch.zeros(
|
||||
action_mask_for_forward.shape[0],
|
||||
pad_size,
|
||||
action_mask_for_forward.shape[2],
|
||||
dtype=action_mask_for_forward.dtype,
|
||||
device=action_mask_for_forward.device,
|
||||
)
|
||||
action_mask_for_forward = torch.cat([action_mask_for_forward, mask_padding], dim=1)
|
||||
else:
|
||||
action_for_forward = action_for_forward[:, : original_policy.model.action_horizon, :]
|
||||
action_mask_for_forward = action_mask_for_forward[
|
||||
:, : original_policy.model.action_horizon, :
|
||||
]
|
||||
|
||||
transformed_obs["action"] = action_for_forward
|
||||
transformed_obs["action_mask"] = action_mask_for_forward
|
||||
|
||||
set_seed_all(42)
|
||||
with torch.no_grad():
|
||||
original_outputs = original_policy.model.forward(transformed_obs)
|
||||
|
||||
original_loss = original_outputs["loss"]
|
||||
print(f" Loss: {original_loss.item():.6f}")
|
||||
|
||||
loss_diff = abs(lerobot_loss.item() - original_loss.item())
|
||||
loss_rel_diff = loss_diff / (abs(original_loss.item()) + 1e-8) * 100
|
||||
|
||||
print("\nLoss Values:")
|
||||
print(f" LeRobot: {lerobot_loss.item():.6f}")
|
||||
print(f" Original: {original_loss.item():.6f}")
|
||||
print(f" Absolute difference: {loss_diff:.6f}")
|
||||
print(f" Relative difference: {loss_rel_diff:.2f}%")
|
||||
|
||||
del lerobot_policy, lerobot_preprocessor, lerobot_postprocessor
|
||||
del original_policy, modality_config, modality_transform
|
||||
del batch, batch_lerobot, observation, transformed_obs
|
||||
cleanup_memory()
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""Utilities shared by GR00T policy tests."""
|
||||
@@ -1,198 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
"""Producer (run in the ORIGINAL gr00t env): dump original GR00T N1.7 outputs + inputs.
|
||||
|
||||
The original NVIDIA ``gr00t`` package pins ``transformers==4.57.3`` (py3.10) and its
|
||||
model-config dataclasses are incompatible with the ``transformers==5.x`` that the
|
||||
LeRobot GR00T N1.7 integration requires. The two implementations therefore cannot be
|
||||
imported in the same Python process. To keep the parity comparison FAIR, we run the
|
||||
original model in its native env here and serialize, PER EMBODIMENT TAG:
|
||||
|
||||
* the exact pre-processed/collated model inputs (so the LeRobot side consumes the
|
||||
byte-identical tensors -- same image preprocessing, tokenization, normalization),
|
||||
* the random seed used right before the flow-matching sampler,
|
||||
* the raw ``action_pred`` tensor returned by ``model.get_action`` (normalized space,
|
||||
before any per-implementation action decoding).
|
||||
|
||||
Inputs are built GENERICALLY from the checkpoint metadata (no per-tag hardcoding):
|
||||
state keys + dims come from ``statistics.json``; video + language keys come from the
|
||||
processor's per-embodiment modality configs. This lets us test many embodiment tags
|
||||
from the SAME checkpoint and confirm the LeRobot integration is not overfit to
|
||||
``libero_sim``.
|
||||
|
||||
The companion pytest (run in the LeRobot env) loads each .npz, replays the identical
|
||||
inputs + seed through the LeRobot GR00T N1.7 model, and asserts the outputs match.
|
||||
|
||||
Usage:
|
||||
.venv-original/bin/python tests/policies/groot/utils/dump_original_n1_7.py \
|
||||
--ckpt <path-to-GR00T-N1.7-LIBERO/libero_10> \
|
||||
--out-dir tests/policies/groot/artifacts \
|
||||
[--tags libero_sim,oxe_droid_relative_eef_relative_joint,...] \
|
||||
[--device cuda] [--seed 42]
|
||||
|
||||
If --tags is omitted, every embodiment present in the checkpoint statistics is dumped.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
IMAGE_SIZE = 256
|
||||
BATCH_SIZE = 2
|
||||
PROMPT = "pick up the black bowl and place it on the plate"
|
||||
|
||||
|
||||
def load_statistics(ckpt: str) -> dict:
|
||||
with open(os.path.join(ckpt, "statistics.json")) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def make_observation(seed: int, video_keys, lang_key, state_spec):
|
||||
"""Build a dummy observation dict generically from the embodiment metadata."""
|
||||
rng = np.random.default_rng(seed)
|
||||
video = {
|
||||
k: rng.integers(0, 256, (BATCH_SIZE, 1, IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.uint8)
|
||||
for k in video_keys
|
||||
}
|
||||
# One ndarray per state key, shape (B, T=1, key_dim); dim taken from statistics.
|
||||
# Keys with dim 0 (e.g. disabled eef on some embodiments) are still emitted as
|
||||
# present-but-empty so the processor's state transform finds every expected key.
|
||||
state = {
|
||||
k: rng.standard_normal((BATCH_SIZE, 1, dim)).astype(np.float32)
|
||||
for k, dim in state_spec
|
||||
}
|
||||
language = {lang_key: [[PROMPT] for _ in range(BATCH_SIZE)]}
|
||||
return {"video": video, "state": state, "language": language}
|
||||
|
||||
|
||||
def dump_one_tag(policy, fair_model, tag, modality_cfg, state_spec, args, out_path):
|
||||
from gr00t.data.types import MessageType
|
||||
|
||||
video_keys = modality_cfg["video"].modality_keys
|
||||
lang_key = modality_cfg["language"].modality_keys[0]
|
||||
observation = make_observation(args.seed, video_keys, lang_key, state_spec)
|
||||
|
||||
# Point the policy preprocessing at this embodiment (mirrors Gr00tPolicy.__init__).
|
||||
policy.embodiment_tag = type(policy.embodiment_tag)(tag)
|
||||
policy.modality_configs = {
|
||||
k: v for k, v in policy.processor.get_modality_configs()[tag].items() if k != "rl_info"
|
||||
}
|
||||
policy.language_key = policy.modality_configs["language"].modality_keys[0]
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
unbatched = policy._unbatch_observation(observation)
|
||||
processed = []
|
||||
for obs in unbatched:
|
||||
vla = policy._to_vla_step_data(obs)
|
||||
processed.append(policy.processor([{"type": MessageType.EPISODE_STEP.value, "content": vla}]))
|
||||
collated = policy.collate_fn(processed)
|
||||
|
||||
def to_dev(x):
|
||||
if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
|
||||
return x.to(args.device, torch.float32)
|
||||
if isinstance(x, torch.Tensor):
|
||||
return x.to(args.device)
|
||||
if isinstance(x, dict):
|
||||
return {k: to_dev(v) for k, v in x.items()}
|
||||
return x
|
||||
|
||||
collated = {k: to_dev(v) for k, v in collated.items()}
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
with torch.inference_mode():
|
||||
out = fair_model.get_action(**collated)
|
||||
action_pred = out["action_pred"].float().cpu().numpy()
|
||||
|
||||
flat, meta = {}, {}
|
||||
|
||||
def flatten(prefix, obj):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
arr = obj.float().cpu().numpy() if torch.is_floating_point(obj) else obj.cpu().numpy()
|
||||
flat[f"in::{prefix}"] = arr
|
||||
meta[f"in::{prefix}"] = str(obj.dtype)
|
||||
elif isinstance(obj, dict):
|
||||
for k, v in obj.items():
|
||||
flatten(f"{prefix}.{k}" if prefix else k, v)
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
flat[f"in::{prefix}"] = np.array(obj, dtype=object)
|
||||
else:
|
||||
flat[f"in::{prefix}"] = np.array(obj)
|
||||
|
||||
flatten("", collated)
|
||||
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
np.savez(
|
||||
out_path,
|
||||
action_pred=action_pred,
|
||||
seed=np.array(args.seed),
|
||||
device=np.array(args.device),
|
||||
embodiment_tag=np.array(tag),
|
||||
meta_keys=np.array(list(meta.keys()), dtype=object),
|
||||
meta_dtypes=np.array(list(meta.values()), dtype=object),
|
||||
**flat,
|
||||
)
|
||||
print(f"[{tag}] action_pred {action_pred.shape} -> {out_path.name} ({os.path.getsize(out_path)} B)")
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--ckpt", required=True)
|
||||
ap.add_argument("--out-dir", required=True, help="directory for per-tag .npz files")
|
||||
ap.add_argument("--tags", default="", help="comma-separated embodiment tags (default: all in stats)")
|
||||
ap.add_argument("--device", default="cuda")
|
||||
ap.add_argument("--seed", type=int, default=42)
|
||||
args = ap.parse_args()
|
||||
|
||||
from gr00t.policy.gr00t_policy import Gr00tPolicy
|
||||
from transformers import AutoConfig, AutoModel
|
||||
|
||||
stats = load_statistics(args.ckpt)
|
||||
requested = [t.strip() for t in args.tags.split(",") if t.strip()] or list(stats.keys())
|
||||
|
||||
# Load the policy once (for its processor/preprocessing) on any valid tag.
|
||||
bootstrap_tag = "libero_sim" if "libero_sim" in stats else requested[0]
|
||||
policy = Gr00tPolicy(embodiment_tag=bootstrap_tag, model_path=args.ckpt, device=args.device)
|
||||
all_modality = policy.processor.get_modality_configs()
|
||||
|
||||
# Load a FAIR model (SDPA + fp32) once and reuse across tags. Otherwise the
|
||||
# original checkpoint default (flash_attention_2 + bf16) introduces kernel/rounding
|
||||
# noise vs the LeRobot env (which has no flash_attn and runs SDPA).
|
||||
cfg = AutoConfig.from_pretrained(args.ckpt, trust_remote_code=True)
|
||||
cfg.use_flash_attention = False
|
||||
cfg.load_bf16 = False
|
||||
fair_model = AutoModel.from_pretrained(args.ckpt, config=cfg, trust_remote_code=True)
|
||||
fair_model.to(device=args.device, dtype=torch.float32)
|
||||
fair_model.eval()
|
||||
|
||||
out_dir = Path(args.out_dir)
|
||||
done, skipped = [], []
|
||||
for tag in requested:
|
||||
if tag not in stats or tag not in all_modality:
|
||||
print(f"[skip] {tag}: not present in checkpoint statistics/modality configs")
|
||||
skipped.append(tag)
|
||||
continue
|
||||
state_spec = [(k, len(v["min"])) for k, v in stats[tag]["state"].items()]
|
||||
try:
|
||||
dump_one_tag(
|
||||
policy, fair_model, tag, all_modality[tag], state_spec, args,
|
||||
out_dir / f"original_n1_7_{tag}.npz",
|
||||
)
|
||||
done.append(tag)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(f"[fail] {tag}: {type(exc).__name__}: {exc}")
|
||||
skipped.append(tag)
|
||||
|
||||
print(f"\nDumped {len(done)} tags: {done}")
|
||||
if skipped:
|
||||
print(f"Skipped/failed {len(skipped)} tags: {skipped}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,275 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Shared fixtures for the remote-inference test suite.
|
||||
|
||||
The mock policy is deterministic: chunk[t, j] = state[j] + 0.01 * t (so
|
||||
tests can predict exact values), accepts the RTC kwargs, and records
|
||||
every call for assertions. Pipelines mimic the
|
||||
``PolicyProcessorPipeline`` surface the server uses (``__call__``,
|
||||
``reset``, ``steps``); the mock postprocessor doubles actions so tests
|
||||
can tell model-space from robot-space chunks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Event
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policy_server.manifest import ModelSpec, PolicyServerManifest, ZenohSpec
|
||||
from lerobot.policy_server.validation import PolicyClassification, ServingClass
|
||||
|
||||
ACTION_DIM = 6
|
||||
CHUNK_SIZE = 20
|
||||
STATE_DIM = 6
|
||||
IMG_H, IMG_W = 48, 64
|
||||
ACTION_NAMES = [f"joint_{i}.pos" for i in range(ACTION_DIM)]
|
||||
TASK = "test task"
|
||||
MODEL_ID = "mock/model"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock policy & config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockPolicyConfig:
|
||||
type: str = "mockchunk"
|
||||
pretrained_path: str = MODEL_ID
|
||||
chunk_size: int = CHUNK_SIZE
|
||||
action_feature_names: list[str] = field(default_factory=lambda: list(ACTION_NAMES))
|
||||
input_features: dict = field(
|
||||
default_factory=lambda: {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
|
||||
"observation.images.front": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMG_H, IMG_W)),
|
||||
}
|
||||
)
|
||||
rtc_config: object | None = None
|
||||
|
||||
|
||||
class MockChunkPolicy:
|
||||
"""Deterministic chunk policy with the RTC kwargs surface."""
|
||||
|
||||
name = "mockchunk"
|
||||
|
||||
def __init__(self, config: MockPolicyConfig | None = None):
|
||||
self.config = config or MockPolicyConfig()
|
||||
self.calls: list[dict] = []
|
||||
self.reset_count = 0
|
||||
self.rtc_initialized = False
|
||||
|
||||
# nn.Module surface the server touches
|
||||
def to(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def eval(self):
|
||||
return self
|
||||
|
||||
def reset(self):
|
||||
self.reset_count += 1
|
||||
|
||||
def init_rtc_processor(self):
|
||||
self.rtc_initialized = True
|
||||
|
||||
def predict_action_chunk(self, batch, inference_delay=None, prev_chunk_left_over=None):
|
||||
state = batch["observation.state"]
|
||||
if state.ndim == 1:
|
||||
state = state.unsqueeze(0)
|
||||
self.calls.append(
|
||||
{
|
||||
"state": state.detach().clone(),
|
||||
"inference_delay": inference_delay,
|
||||
"prev_chunk_left_over": None
|
||||
if prev_chunk_left_over is None
|
||||
else prev_chunk_left_over.detach().clone(),
|
||||
"task": batch.get("task"),
|
||||
}
|
||||
)
|
||||
steps = torch.arange(CHUNK_SIZE, dtype=torch.float32).unsqueeze(1) * 0.01
|
||||
return (state[:, :ACTION_DIM].unsqueeze(1) + steps.unsqueeze(0)).clone()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock processor pipelines
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockPipeline:
|
||||
"""Mimics the PolicyProcessorPipeline surface used by the server."""
|
||||
|
||||
def __init__(self, transform=None, steps=()):
|
||||
self._transform = transform
|
||||
self.steps = list(steps)
|
||||
self.reset_count = 0
|
||||
self.call_count = 0
|
||||
|
||||
def __call__(self, x):
|
||||
self.call_count += 1
|
||||
return self._transform(x) if self._transform is not None else x
|
||||
|
||||
def reset(self):
|
||||
self.reset_count += 1
|
||||
|
||||
|
||||
def make_mock_processors():
|
||||
"""Identity preprocessor + doubling postprocessor (model vs robot space)."""
|
||||
return MockPipeline(), MockPipeline(transform=lambda actions: actions * 2.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Server fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_policy():
|
||||
return MockChunkPolicy()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def shared_rtc_classification():
|
||||
return PolicyClassification(
|
||||
ServingClass.SHARED, supports_rtc=True, needs_queue_population=False, reason="mock"
|
||||
)
|
||||
|
||||
|
||||
def make_manifest(**overrides) -> PolicyServerManifest:
|
||||
kwargs = {
|
||||
"model": ModelSpec(repo_or_path=MODEL_ID, device="cpu"),
|
||||
"zenoh": ZenohSpec(mode="peer"),
|
||||
"default_task": TASK,
|
||||
"max_sessions": 4,
|
||||
"warmup_inferences": 0,
|
||||
"trained_fps": 30.0,
|
||||
"health_port": 0,
|
||||
}
|
||||
kwargs.update(overrides)
|
||||
return PolicyServerManifest(**kwargs)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manifest():
|
||||
return make_manifest()
|
||||
|
||||
|
||||
def make_logic_server(
|
||||
manifest: PolicyServerManifest | None = None,
|
||||
policy: MockChunkPolicy | None = None,
|
||||
classification: PolicyClassification | None = None,
|
||||
processor_factory=None,
|
||||
):
|
||||
"""A PolicyServer with everything injected and no zenoh transport."""
|
||||
from lerobot.policy_server.server import PolicyServer
|
||||
|
||||
policy = policy or MockChunkPolicy()
|
||||
factory_calls = []
|
||||
|
||||
def default_factory():
|
||||
pair = make_mock_processors()
|
||||
factory_calls.append(pair)
|
||||
return pair
|
||||
|
||||
server = PolicyServer(
|
||||
manifest or make_manifest(),
|
||||
policy=policy,
|
||||
policy_cfg=policy.config,
|
||||
processor_factory=processor_factory or default_factory,
|
||||
classification=classification
|
||||
or PolicyClassification(
|
||||
ServingClass.SHARED, supports_rtc=True, needs_queue_population=False, reason="mock"
|
||||
),
|
||||
)
|
||||
server.load_policy()
|
||||
server.factory_calls = factory_calls
|
||||
return server
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Client-side fixtures (hw features, observations)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hw_features():
|
||||
return {
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (STATE_DIM,),
|
||||
"names": list(ACTION_NAMES),
|
||||
},
|
||||
"observation.images.front": {
|
||||
"dtype": "video",
|
||||
"shape": (IMG_H, IMG_W, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def make_robot_obs(seed: float = 1.0) -> dict:
|
||||
obs = {name: seed + 0.1 * i for i, name in enumerate(ACTION_NAMES)}
|
||||
rng = np.random.default_rng(int(seed * 10))
|
||||
obs["front"] = rng.integers(0, 255, size=(IMG_H, IMG_W, 3), dtype=np.uint8)
|
||||
return obs
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def shutdown_event():
|
||||
return Event()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Loopback helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def free_tcp_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.bind(("127.0.0.1", 0))
|
||||
return sock.getsockname()[1]
|
||||
|
||||
|
||||
def make_loopback_manifest(port: int, **overrides) -> PolicyServerManifest:
|
||||
return make_manifest(
|
||||
zenoh=ZenohSpec(mode="peer", listen_endpoints=[f"tcp/127.0.0.1:{port}"]),
|
||||
**overrides,
|
||||
)
|
||||
|
||||
|
||||
def make_remote_config(port: int, **overrides):
|
||||
"""RemoteInferenceConfig dialing a loopback server (fast watchdogs)."""
|
||||
from lerobot.rollout.inference.factory import RemoteInferenceConfig
|
||||
|
||||
kwargs = {
|
||||
"connect_endpoint": f"tcp/127.0.0.1:{port}",
|
||||
"zenoh_mode": "peer",
|
||||
"service_model_id": MODEL_ID,
|
||||
"service_task": TASK,
|
||||
"jpeg_quality": 0, # raw images: byte-exact loopback
|
||||
"buffer_time_s": 0.2,
|
||||
"handshake_timeout_s": 2.0,
|
||||
"request_timeout_s": 1.0,
|
||||
"degraded_after_s": 0.3,
|
||||
"reconnect_initial_backoff_s": 0.1,
|
||||
"reconnect_max_backoff_s": 0.5,
|
||||
"max_offline_s": 8.0,
|
||||
}
|
||||
kwargs.update(overrides)
|
||||
return RemoteInferenceConfig(**kwargs)
|
||||
@@ -0,0 +1,412 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for the MessagePack wire codecs (tensors, images, messages)."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
msgpack = pytest.importorskip("msgpack")
|
||||
|
||||
from lerobot.policy_server.codec import ( # noqa: E402
|
||||
decode_action_chunk,
|
||||
decode_image,
|
||||
decode_observation,
|
||||
decode_raw,
|
||||
decode_reset,
|
||||
decode_reset_ack,
|
||||
decode_session_ack,
|
||||
decode_session_close,
|
||||
decode_session_open,
|
||||
decode_status,
|
||||
decode_tensor,
|
||||
encode_action_chunk,
|
||||
encode_image,
|
||||
encode_observation,
|
||||
encode_reset,
|
||||
encode_reset_ack,
|
||||
encode_session_ack,
|
||||
encode_session_close,
|
||||
encode_session_open,
|
||||
encode_status,
|
||||
encode_tensor,
|
||||
)
|
||||
from lerobot.policy_server.schema import ( # noqa: E402
|
||||
IMAGE_CODEC_JPEG,
|
||||
IMAGE_CODEC_RAW,
|
||||
ActionChunkMsg,
|
||||
ObservationMsg,
|
||||
ResetAckMsg,
|
||||
ResetMsg,
|
||||
SessionAckMsg,
|
||||
SessionCloseMsg,
|
||||
SessionOpenMsg,
|
||||
StatusMsg,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tensor codec
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"arr",
|
||||
[
|
||||
np.array([1.5, -2.25, 3.0], dtype=np.float32),
|
||||
np.arange(12, dtype=np.float64).reshape(3, 4),
|
||||
np.array([[1, -2], [3, 4]], dtype=np.int64),
|
||||
np.array([[True, False], [False, True]], dtype=np.bool_),
|
||||
np.zeros((0,), dtype=np.float32), # empty 1-d
|
||||
np.zeros((0, 6), dtype=np.float64), # empty 2-d
|
||||
np.arange(24, dtype=np.int64).reshape(2, 3, 4),
|
||||
],
|
||||
ids=["f32_1d", "f64_2d", "i64_2d", "bool_2d", "f32_empty", "f64_empty_2d", "i64_3d"],
|
||||
)
|
||||
def test_tensor_roundtrip(arr):
|
||||
out = decode_tensor(encode_tensor(arr))
|
||||
assert out.dtype == arr.dtype
|
||||
assert out.shape == arr.shape
|
||||
np.testing.assert_array_equal(out, arr)
|
||||
|
||||
|
||||
def test_tensor_roundtrip_0d_preserves_value():
|
||||
# KNOWN QUIRK: np.ascontiguousarray inside encode_tensor promotes
|
||||
# 0-d arrays to shape (1,), so the round-trip is value-preserving
|
||||
# but not shape-preserving for scalars.
|
||||
arr = np.array(3.5, dtype=np.float32)
|
||||
out = decode_tensor(encode_tensor(arr))
|
||||
assert out.dtype == arr.dtype
|
||||
assert out.shape in ((), (1,))
|
||||
assert float(np.squeeze(out)) == 3.5
|
||||
|
||||
|
||||
def test_tensor_none_passthrough():
|
||||
assert encode_tensor(None) is None
|
||||
assert decode_tensor(None) is None
|
||||
|
||||
|
||||
def test_tensor_big_endian_input_values_identical():
|
||||
be = np.array([1.0, 2.5, -3.75], dtype=">f4")
|
||||
enc = encode_tensor(be)
|
||||
assert np.dtype(enc["dtype"]).byteorder != ">"
|
||||
out = decode_tensor(enc)
|
||||
np.testing.assert_array_equal(out, be.astype("<f4"))
|
||||
np.testing.assert_array_equal(out, np.array([1.0, 2.5, -3.75], dtype=np.float32))
|
||||
|
||||
|
||||
def test_tensor_decoded_writable_and_contiguous():
|
||||
arr = np.arange(6, dtype=np.float32).reshape(2, 3)
|
||||
out = decode_tensor(encode_tensor(arr))
|
||||
assert out.flags.writeable
|
||||
assert out.flags.c_contiguous
|
||||
out[0, 0] = 99.0 # must not raise
|
||||
assert out[0, 0] == 99.0
|
||||
|
||||
|
||||
def test_tensor_decode_refuses_object_dtype():
|
||||
with pytest.raises(ValueError, match="object dtype"):
|
||||
decode_tensor({"dtype": "|O", "shape": [1], "data": b"\x00" * 8})
|
||||
|
||||
|
||||
def test_tensor_roundtrip_through_msgpack():
|
||||
arr = np.arange(10, dtype=np.float32)
|
||||
packed = msgpack.packb(encode_tensor(arr), use_bin_type=True)
|
||||
out = decode_tensor(msgpack.unpackb(packed, raw=False))
|
||||
np.testing.assert_array_equal(out, arr)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image codec
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _gradient_image(h: int = 32, w: int = 48) -> np.ndarray:
|
||||
"""Smooth RGB gradient: JPEG-friendly, deterministic."""
|
||||
img = np.zeros((h, w, 3), dtype=np.uint8)
|
||||
img[..., 0] = np.linspace(0, 255, w, dtype=np.uint8)[None, :]
|
||||
img[..., 1] = np.linspace(0, 255, h, dtype=np.uint8)[:, None]
|
||||
img[..., 2] = 128
|
||||
return img
|
||||
|
||||
|
||||
def test_image_raw_roundtrip_byte_exact():
|
||||
img = _gradient_image()
|
||||
enc = encode_image(img, jpeg_quality=0)
|
||||
assert enc["codec"] == IMAGE_CODEC_RAW
|
||||
out = decode_image(enc)
|
||||
assert out.dtype == np.uint8
|
||||
assert out.shape == img.shape
|
||||
np.testing.assert_array_equal(out, img)
|
||||
|
||||
|
||||
def test_image_jpeg_roundtrip_approximately_equal():
|
||||
img = _gradient_image()
|
||||
enc = encode_image(img, jpeg_quality=95)
|
||||
assert enc["codec"] == IMAGE_CODEC_JPEG
|
||||
out = decode_image(enc)
|
||||
assert out.dtype == np.uint8
|
||||
assert out.shape == img.shape
|
||||
err = np.abs(out.astype(np.int32) - img.astype(np.int32)).mean()
|
||||
assert err < 5.0, f"JPEG round-trip too lossy: mean abs error {err}"
|
||||
|
||||
|
||||
def test_image_jpeg_rgb_order_regression_pure_red_stays_red():
|
||||
# A silent BGR swap would poison every VLA in a fleet: pure red must
|
||||
# come back red-dominant, not blue-dominant.
|
||||
img = np.zeros((32, 32, 3), dtype=np.uint8)
|
||||
img[..., 0] = 255 # RGB: red channel
|
||||
out = decode_image(encode_image(img, jpeg_quality=90))
|
||||
red_mean = out[..., 0].astype(np.float64).mean()
|
||||
blue_mean = out[..., 2].astype(np.float64).mean()
|
||||
assert red_mean > 200, f"red channel lost: mean {red_mean}"
|
||||
assert blue_mean < 50, f"blue channel gained: mean {blue_mean}"
|
||||
assert red_mean > blue_mean
|
||||
|
||||
|
||||
def test_encode_image_rejects_float_arrays():
|
||||
with pytest.raises(ValueError, match="uint8 HWC RGB"):
|
||||
encode_image(np.zeros((8, 8, 3), dtype=np.float32))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"shape", [(3, 16, 24), (16, 24), (16, 24, 4), (16, 24, 1)], ids=["chw", "hw", "hwc4", "hwc1"]
|
||||
)
|
||||
def test_encode_image_rejects_non_hwc(shape):
|
||||
with pytest.raises(ValueError, match="uint8 HWC RGB"):
|
||||
encode_image(np.zeros(shape, dtype=np.uint8))
|
||||
|
||||
|
||||
def test_decode_image_rejects_unknown_codec():
|
||||
with pytest.raises(ValueError, match="Unknown image codec"):
|
||||
decode_image({"codec": "webp", "data": b""})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data-plane messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_observation_roundtrip_full():
|
||||
rng = np.random.default_rng(0)
|
||||
state = np.array([0.1, -0.2, 0.3, 0.4], dtype=np.float32)
|
||||
front = rng.integers(0, 255, size=(16, 24, 3), dtype=np.uint8)
|
||||
wrist = rng.integers(0, 255, size=(8, 12, 3), dtype=np.uint8)
|
||||
prefix_model = rng.standard_normal((5, 4)).astype(np.float32)
|
||||
prefix_robot = (prefix_model * 2.0).astype(np.float32)
|
||||
msg = ObservationMsg(
|
||||
state=state,
|
||||
images={"front": front, "wrist": wrist},
|
||||
task="fold the towel",
|
||||
inference_delay_steps=3,
|
||||
prefix_model=prefix_model,
|
||||
prefix_robot=prefix_robot,
|
||||
episode_start=True,
|
||||
jpeg_quality=0, # raw: byte-exact images
|
||||
)
|
||||
out = decode_observation(encode_observation(msg))
|
||||
np.testing.assert_array_equal(out.state, state)
|
||||
assert set(out.images) == {"front", "wrist"}
|
||||
np.testing.assert_array_equal(out.images["front"], front)
|
||||
np.testing.assert_array_equal(out.images["wrist"], wrist)
|
||||
assert out.task == "fold the towel"
|
||||
assert out.inference_delay_steps == 3
|
||||
np.testing.assert_array_equal(out.prefix_model, prefix_model)
|
||||
np.testing.assert_array_equal(out.prefix_robot, prefix_robot)
|
||||
assert out.episode_start is True
|
||||
|
||||
|
||||
def test_observation_roundtrip_minimal_defaults():
|
||||
out = decode_observation(encode_observation(ObservationMsg()))
|
||||
assert out.state is None
|
||||
assert out.images == {}
|
||||
assert out.task == ""
|
||||
assert out.inference_delay_steps == 0
|
||||
assert out.prefix_model is None
|
||||
assert out.prefix_robot is None
|
||||
assert out.episode_start is False
|
||||
|
||||
|
||||
def test_action_chunk_roundtrip():
|
||||
chunk_model = np.arange(12, dtype=np.float32).reshape(3, 4)
|
||||
chunk_robot = chunk_model * 2.0
|
||||
msg = ActionChunkMsg(
|
||||
seq_id_echo=17,
|
||||
client_mono_ns_echo=123456789,
|
||||
episode_id_echo=2,
|
||||
chunk_model=chunk_model,
|
||||
chunk_robot=chunk_robot,
|
||||
queue_wait_ms=1.5,
|
||||
inference_ms=12.25,
|
||||
superseded_seqs=4,
|
||||
server_load=0.75,
|
||||
)
|
||||
out = decode_action_chunk(encode_action_chunk(msg))
|
||||
assert out.seq_id_echo == 17
|
||||
assert out.client_mono_ns_echo == 123456789
|
||||
assert out.episode_id_echo == 2
|
||||
np.testing.assert_array_equal(out.chunk_model, chunk_model)
|
||||
np.testing.assert_array_equal(out.chunk_robot, chunk_robot)
|
||||
assert out.queue_wait_ms == 1.5
|
||||
assert out.inference_ms == 12.25
|
||||
assert out.superseded_seqs == 4
|
||||
assert out.server_load == 0.75
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Control-plane messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_session_open_roundtrip():
|
||||
msg = SessionOpenMsg(
|
||||
client_uuid="uuid-1",
|
||||
robot_type="so101_follower",
|
||||
policy_type="pi0",
|
||||
fps=30.0,
|
||||
action_names=["j0.pos", "j1.pos"],
|
||||
camera_names=["front", "wrist"],
|
||||
state_dim=6,
|
||||
rtc_enabled=True,
|
||||
task="fold",
|
||||
tags={"site": "lab-3"},
|
||||
)
|
||||
out = decode_session_open(encode_session_open(msg))
|
||||
assert out == msg
|
||||
|
||||
|
||||
def test_session_ack_roundtrip():
|
||||
msg = SessionAckMsg(
|
||||
accepted=True,
|
||||
warnings=["fps mismatch"],
|
||||
session_id="sess-1",
|
||||
model_repo="org/model",
|
||||
model_revision="main",
|
||||
policy_type="pi0",
|
||||
action_names=["j0.pos"],
|
||||
expected_cameras=["front"],
|
||||
state_dim=6,
|
||||
chunk_size=50,
|
||||
trained_fps=30.0,
|
||||
supports_rtc=True,
|
||||
rtc_execution_horizon=25,
|
||||
serving_mode="shared",
|
||||
warmed_up=True,
|
||||
server_load=0.5,
|
||||
)
|
||||
out = decode_session_ack(encode_session_ack(msg))
|
||||
assert out == msg
|
||||
|
||||
|
||||
def test_status_roundtrip():
|
||||
msg = StatusMsg(
|
||||
model_repo="org/model",
|
||||
model_revision="abc123",
|
||||
policy_type="act",
|
||||
action_names=["j0.pos", "j1.pos"],
|
||||
expected_cameras=["front"],
|
||||
state_dim=6,
|
||||
chunk_size=100,
|
||||
trained_fps=30.0,
|
||||
supports_rtc=False,
|
||||
rtc_execution_horizon=0,
|
||||
serving_mode="exclusive",
|
||||
warmed_up=False,
|
||||
active_sessions=2,
|
||||
max_sessions=4,
|
||||
)
|
||||
out = decode_status(encode_status(msg))
|
||||
assert out == msg
|
||||
|
||||
|
||||
def test_reset_and_reset_ack_roundtrip():
|
||||
out = decode_reset(encode_reset(ResetMsg(client_uuid="uuid-1", episode_id=5)))
|
||||
assert out == ResetMsg(client_uuid="uuid-1", episode_id=5)
|
||||
out_ack = decode_reset_ack(encode_reset_ack(ResetAckMsg(ok=False, error="busy")))
|
||||
assert out_ack == ResetAckMsg(ok=False, error="busy")
|
||||
|
||||
|
||||
def test_session_close_roundtrip():
|
||||
msg = SessionCloseMsg(client_uuid="uuid-1", session_id="sess-1")
|
||||
out = decode_session_close(encode_session_close(msg))
|
||||
assert out == msg
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema evolution (additive-only contract)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("encoded", "decoder", "expected"),
|
||||
[
|
||||
(
|
||||
encode_session_ack(SessionAckMsg(accepted=True, session_id="s")),
|
||||
decode_session_ack,
|
||||
SessionAckMsg(accepted=True, session_id="s"),
|
||||
),
|
||||
(
|
||||
encode_reset(ResetMsg(client_uuid="u", episode_id=1)),
|
||||
decode_reset,
|
||||
ResetMsg(client_uuid="u", episode_id=1),
|
||||
),
|
||||
(
|
||||
encode_session_open(SessionOpenMsg(client_uuid="u")),
|
||||
decode_session_open,
|
||||
SessionOpenMsg(client_uuid="u"),
|
||||
),
|
||||
],
|
||||
ids=["session_ack", "reset", "session_open"],
|
||||
)
|
||||
def test_unknown_keys_ignored(encoded, decoder, expected):
|
||||
obj = msgpack.unpackb(encoded, raw=False)
|
||||
obj["a_future_field"] = {"nested": [1, 2, 3]}
|
||||
out = decoder(msgpack.packb(obj, use_bin_type=True))
|
||||
assert out == expected
|
||||
|
||||
|
||||
def test_missing_optional_keys_take_defaults():
|
||||
minimal = msgpack.packb({"accepted": True}, use_bin_type=True)
|
||||
out = decode_session_ack(minimal)
|
||||
assert out.accepted is True
|
||||
assert out.error == ""
|
||||
assert out.warnings == []
|
||||
assert out.chunk_size == 0
|
||||
assert out.server_load == 0.0
|
||||
|
||||
out_chunk = decode_action_chunk(msgpack.packb({"seq_id_echo": 9}, use_bin_type=True))
|
||||
assert out_chunk.seq_id_echo == 9
|
||||
assert out_chunk.chunk_model is None
|
||||
assert out_chunk.chunk_robot is None
|
||||
assert out_chunk.queue_wait_ms == 0.0
|
||||
|
||||
out_obs = decode_observation(msgpack.packb({"task": "t"}, use_bin_type=True))
|
||||
assert out_obs.task == "t"
|
||||
assert out_obs.state is None
|
||||
assert out_obs.images == {}
|
||||
assert out_obs.episode_start is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# decode_raw
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_decode_raw_returns_plain_dict_with_op():
|
||||
open_obj = decode_raw(encode_session_open(SessionOpenMsg(client_uuid="u")))
|
||||
assert isinstance(open_obj, dict)
|
||||
assert open_obj["op"] == "open"
|
||||
|
||||
close_obj = decode_raw(encode_session_close(SessionCloseMsg(client_uuid="u")))
|
||||
assert isinstance(close_obj, dict)
|
||||
assert close_obj["op"] == "close"
|
||||
@@ -0,0 +1,235 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Golden parity contract test: remote request path == local RTC compute path.
|
||||
|
||||
The local side replicates exactly what ``RTCInferenceEngine._rtc_loop``
|
||||
(rtc.py) does per iteration; the remote side runs the same observation
|
||||
through the wire codec (encode -> decode), ``PolicyServer.run_inference_request``,
|
||||
and the action-chunk codec — no network, no threads. With the same
|
||||
deterministic policy and identical inputs, both ActionQueues must stay
|
||||
byte-identical merge after merge.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("msgpack")
|
||||
|
||||
from lerobot.policies.rtc import ActionQueue # noqa: E402
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402
|
||||
from lerobot.policies.utils import prepare_observation_for_inference # noqa: E402
|
||||
from lerobot.policy_server import codec # noqa: E402
|
||||
from lerobot.policy_server.schema import MsgHeader, ObservationMsg # noqa: E402
|
||||
from lerobot.policy_server.server import _normalize_prev_actions_length # noqa: E402
|
||||
from lerobot.policy_server.session import Session # noqa: E402
|
||||
from lerobot.utils.constants import OBS_STATE, OBS_STR # noqa: E402
|
||||
from lerobot.utils.feature_utils import build_dataset_frame # noqa: E402
|
||||
from tests.policy_server.conftest import ( # noqa: E402
|
||||
ACTION_NAMES,
|
||||
CHUNK_SIZE,
|
||||
STATE_DIM,
|
||||
TASK,
|
||||
MockChunkPolicy,
|
||||
make_logic_server,
|
||||
make_mock_processors,
|
||||
make_robot_obs,
|
||||
)
|
||||
|
||||
# Must match make_manifest()'s default RTCConfig (enabled=True, horizon=10).
|
||||
EXECUTION_HORIZON = 10
|
||||
ROBOT_TYPE = "mock_robot"
|
||||
# Fixed per-cycle inference-delay hints; cycle 2 exercises a non-zero delay.
|
||||
DELAYS = [0, 2, 1]
|
||||
# Actions consumed from both queues between cycles (makes prefixes non-trivial).
|
||||
CONSUME_K = 4
|
||||
|
||||
STATE_ONLY_FEATURES = {
|
||||
OBS_STATE: {
|
||||
"dtype": "float32",
|
||||
"shape": (STATE_DIM,),
|
||||
"names": list(ACTION_NAMES),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _make_queue() -> ActionQueue:
|
||||
return ActionQueue(RTCConfig(enabled=True, execution_horizon=EXECUTION_HORIZON))
|
||||
|
||||
|
||||
def _local_cycle(policy, pre, post, queue, features, obs, delay) -> None:
|
||||
"""Replicates the loop body of RTCInferenceEngine._rtc_loop (rtc.py)."""
|
||||
idx_before = queue.get_action_index()
|
||||
prev_actions = queue.get_left_over()
|
||||
|
||||
obs_batch = build_dataset_frame(features, obs, prefix=OBS_STR)
|
||||
obs_batch = prepare_observation_for_inference(obs_batch, torch.device("cpu"), TASK, ROBOT_TYPE)
|
||||
obs_batch["task"] = [TASK]
|
||||
preprocessed = pre(obs_batch)
|
||||
|
||||
if prev_actions is not None:
|
||||
prev_actions = _normalize_prev_actions_length(prev_actions, target_steps=EXECUTION_HORIZON)
|
||||
|
||||
actions = policy.predict_action_chunk(
|
||||
preprocessed, inference_delay=delay, prev_chunk_left_over=prev_actions
|
||||
)
|
||||
original = actions.squeeze(0).clone()
|
||||
processed = post(actions).squeeze(0)
|
||||
queue.merge(original, processed, delay, idx_before)
|
||||
|
||||
|
||||
def _remote_cycle(server, session, queue, features, obs, delay, seq_id) -> None:
|
||||
"""Replicates RemoteInferenceEngine._request_cycle with the wire codec in
|
||||
the loop (encode -> decode on both legs) but no network or threads."""
|
||||
idx_before = queue.get_action_index()
|
||||
|
||||
obs_frame = build_dataset_frame(features, obs, prefix=OBS_STR)
|
||||
state = obs_frame.pop(OBS_STATE, None)
|
||||
images = {k: v for k, v in obs_frame.items() if isinstance(v, np.ndarray) and v.ndim == 3}
|
||||
|
||||
prefix_model: np.ndarray | None = None
|
||||
prefix_robot: np.ndarray | None = None
|
||||
left_over = queue.get_left_over()
|
||||
if left_over is not None and left_over.numel():
|
||||
prefix_model = left_over[:EXECUTION_HORIZON].to(torch.float32).numpy()
|
||||
processed_left_over = queue.get_processed_left_over()
|
||||
if processed_left_over is not None and processed_left_over.numel():
|
||||
prefix_robot = processed_left_over[:EXECUTION_HORIZON].to(torch.float32).numpy()
|
||||
|
||||
msg = ObservationMsg(
|
||||
state=state,
|
||||
images=images,
|
||||
task=TASK,
|
||||
inference_delay_steps=delay,
|
||||
prefix_model=prefix_model,
|
||||
prefix_robot=prefix_robot,
|
||||
jpeg_quality=0, # raw image codec: byte-exact transport
|
||||
)
|
||||
decoded = codec.decode_observation(codec.encode_observation(msg))
|
||||
|
||||
# The float32 wire dtype must introduce zero drift: byte-exact roundtrip.
|
||||
assert decoded.state.dtype == np.float32
|
||||
assert decoded.state.tobytes() == np.ascontiguousarray(state).tobytes()
|
||||
if prefix_model is not None:
|
||||
assert decoded.prefix_model.tobytes() == np.ascontiguousarray(prefix_model).tobytes()
|
||||
if prefix_robot is not None:
|
||||
assert decoded.prefix_robot.tobytes() == np.ascontiguousarray(prefix_robot).tobytes()
|
||||
for name, img in images.items():
|
||||
assert np.array_equal(decoded.images[name], img)
|
||||
|
||||
reply = server.run_inference_request(session, MsgHeader(seq_id=seq_id), decoded)
|
||||
chunk = codec.decode_action_chunk(codec.encode_action_chunk(reply))
|
||||
|
||||
# Reply leg is byte-exact too (float32 in, float32 on the wire).
|
||||
assert chunk.chunk_model.tobytes() == np.ascontiguousarray(reply.chunk_model).tobytes()
|
||||
assert chunk.chunk_robot.tobytes() == np.ascontiguousarray(reply.chunk_robot).tobytes()
|
||||
|
||||
queue.merge(
|
||||
torch.from_numpy(np.ascontiguousarray(chunk.chunk_model)),
|
||||
torch.from_numpy(np.ascontiguousarray(chunk.chunk_robot)),
|
||||
delay,
|
||||
idx_before,
|
||||
)
|
||||
|
||||
|
||||
def _drive_parity(features) -> tuple[MockChunkPolicy, MockChunkPolicy]:
|
||||
"""Run DELAYS cycles through both paths, asserting queue parity after
|
||||
each merge and consuming CONSUME_K actions from both queues between
|
||||
cycles. Returns (local_policy, remote_policy) for call-level checks."""
|
||||
policy_local = MockChunkPolicy()
|
||||
pre_local, post_local = make_mock_processors()
|
||||
queue_local = _make_queue()
|
||||
|
||||
policy_remote = MockChunkPolicy()
|
||||
server = make_logic_server(policy=policy_remote)
|
||||
pre_remote, post_remote = make_mock_processors()
|
||||
session = Session(
|
||||
session_id="parity",
|
||||
client_uuid="parity-client",
|
||||
task=TASK,
|
||||
robot_type=ROBOT_TYPE,
|
||||
rtc_enabled=True,
|
||||
preprocessor=pre_remote,
|
||||
postprocessor=post_remote,
|
||||
)
|
||||
queue_remote = _make_queue()
|
||||
|
||||
for cycle, delay in enumerate(DELAYS):
|
||||
obs = make_robot_obs(seed=float(cycle + 1))
|
||||
_local_cycle(policy_local, pre_local, post_local, queue_local, features, obs, delay)
|
||||
_remote_cycle(server, session, queue_remote, features, obs, delay, seq_id=cycle + 1)
|
||||
|
||||
assert queue_local.queue is not None and queue_remote.queue is not None
|
||||
assert torch.equal(queue_local.queue, queue_remote.queue), (
|
||||
f"robot-space queues diverged (cycle {cycle})"
|
||||
)
|
||||
assert torch.equal(queue_local.original_queue, queue_remote.original_queue), (
|
||||
f"model-space queues diverged (cycle {cycle})"
|
||||
)
|
||||
assert queue_local.queue.shape == (CHUNK_SIZE - min(delay, CHUNK_SIZE), len(ACTION_NAMES))
|
||||
|
||||
# Consume the same k actions on both sides so the next cycle's RTC
|
||||
# prefixes are non-trivial (and identical).
|
||||
for _ in range(CONSUME_K):
|
||||
action_local = queue_local.get()
|
||||
action_remote = queue_remote.get()
|
||||
assert action_local is not None and action_remote is not None
|
||||
assert torch.equal(action_local, action_remote)
|
||||
|
||||
return policy_local, policy_remote
|
||||
|
||||
|
||||
def test_remote_path_matches_local_rtc_path_state_only():
|
||||
"""3 cycles, state-only features: queues stay byte-identical."""
|
||||
_drive_parity(STATE_ONLY_FEATURES)
|
||||
|
||||
|
||||
def test_remote_path_matches_local_rtc_path_with_images(hw_features):
|
||||
"""Images in the loop (raw codec) must not perturb state-driven outputs."""
|
||||
_drive_parity(hw_features)
|
||||
|
||||
|
||||
def test_policy_inputs_identical_across_paths(hw_features):
|
||||
"""The strongest contract: both policies saw byte-identical inputs."""
|
||||
policy_local, policy_remote = _drive_parity(hw_features)
|
||||
|
||||
assert len(policy_local.calls) == len(policy_remote.calls) == len(DELAYS)
|
||||
for i, (local_call, remote_call) in enumerate(zip(policy_local.calls, policy_remote.calls, strict=True)):
|
||||
assert torch.equal(local_call["state"], remote_call["state"]), f"state diverged (call {i})"
|
||||
assert local_call["state"].dtype == remote_call["state"].dtype == torch.float32
|
||||
assert local_call["inference_delay"] == remote_call["inference_delay"] == DELAYS[i]
|
||||
if local_call["prev_chunk_left_over"] is None:
|
||||
assert remote_call["prev_chunk_left_over"] is None
|
||||
else:
|
||||
assert torch.equal(local_call["prev_chunk_left_over"], remote_call["prev_chunk_left_over"])
|
||||
assert local_call["prev_chunk_left_over"].shape == (EXECUTION_HORIZON, len(ACTION_NAMES))
|
||||
|
||||
# First cycle has no leftover; later cycles must carry a real prefix.
|
||||
assert policy_local.calls[0]["prev_chunk_left_over"] is None
|
||||
assert all(call["prev_chunk_left_over"] is not None for call in policy_local.calls[1:])
|
||||
|
||||
|
||||
def test_float32_wire_dtype_is_byte_exact():
|
||||
"""Round-tripping non-dyadic float32 values through the tensor codec
|
||||
must reproduce the exact bytes (no dtype casts, no re-quantization)."""
|
||||
rng = np.random.default_rng(7)
|
||||
arr = (rng.standard_normal((CHUNK_SIZE, STATE_DIM)) * 0.1).astype(np.float32)
|
||||
decoded = codec.decode_tensor(codec.encode_tensor(arr))
|
||||
assert decoded.dtype == np.float32
|
||||
assert decoded.shape == arr.shape
|
||||
assert decoded.tobytes() == arr.tobytes()
|
||||
assert torch.equal(torch.from_numpy(decoded), torch.from_numpy(arr))
|
||||
@@ -0,0 +1,366 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Real zenoh peer-to-peer loopback tests: PolicyServer ↔ RemoteInferenceEngine.
|
||||
|
||||
The server listens on a fresh loopback TCP port per test; the engine
|
||||
dials it directly (``mode=peer``, no router). Mock policy values are
|
||||
deterministic — chunk_robot[t, j] = 2 * (state[j] + 0.01 t) — so first
|
||||
actions identify which client's observation produced them (the
|
||||
per-session isolation regression). Chaos tests kill/restart the server
|
||||
mid-episode and assert the engine degrades and recovers without ever
|
||||
raising on the control thread.
|
||||
"""
|
||||
|
||||
import time
|
||||
from threading import Event
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("msgpack")
|
||||
zenoh = pytest.importorskip("zenoh")
|
||||
|
||||
from lerobot.policy_server.schema import MsgHeader, obs_key # noqa: E402
|
||||
from lerobot.policy_server.zenoh_utils import build_zenoh_config # noqa: E402
|
||||
from lerobot.rollout.inference.factory import FallbackMode # noqa: E402
|
||||
from lerobot.rollout.inference.remote import ClientState, RemoteInferenceEngine # noqa: E402
|
||||
from tests.policy_server.conftest import ( # noqa: E402
|
||||
ACTION_DIM,
|
||||
ACTION_NAMES,
|
||||
TASK,
|
||||
free_tcp_port,
|
||||
make_logic_server,
|
||||
make_loopback_manifest,
|
||||
make_remote_config,
|
||||
make_robot_obs,
|
||||
)
|
||||
|
||||
_FPS = 30.0
|
||||
_TICK_S = 1.0 / _FPS
|
||||
# Settle time after server.start() for zenoh declarations to propagate.
|
||||
_DECLARATION_SETTLE_S = 0.5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _wait_until(predicate, timeout_s: float, interval_s: float = 0.05) -> bool:
|
||||
"""Poll ``predicate`` until true or the deadline passes (never a fixed sleep)."""
|
||||
deadline = time.monotonic() + timeout_s
|
||||
while time.monotonic() < deadline:
|
||||
if predicate():
|
||||
return True
|
||||
time.sleep(interval_s)
|
||||
return bool(predicate())
|
||||
|
||||
|
||||
def _start_loopback_server(port: int, attempts: int = 3):
|
||||
"""Start a fully-injected PolicyServer listening on tcp/127.0.0.1:<port>."""
|
||||
last_error: Exception | None = None
|
||||
for _ in range(attempts):
|
||||
server = make_logic_server(make_loopback_manifest(port))
|
||||
try:
|
||||
server.start()
|
||||
except Exception as e: # noqa: BLE001 — e.g. lingering socket on restart
|
||||
last_error = e
|
||||
server.stop()
|
||||
time.sleep(0.5)
|
||||
continue
|
||||
time.sleep(_DECLARATION_SETTLE_S)
|
||||
return server
|
||||
raise last_error
|
||||
|
||||
|
||||
def _make_engine(port: int, server, hw_features: dict, **config_overrides) -> RemoteInferenceEngine:
|
||||
return RemoteInferenceEngine(
|
||||
config=make_remote_config(port, **config_overrides),
|
||||
policy_config=server._policy_cfg,
|
||||
hw_features=hw_features,
|
||||
ordered_action_keys=list(ACTION_NAMES),
|
||||
task=TASK,
|
||||
fps=_FPS,
|
||||
robot_type="mock",
|
||||
shutdown_event=Event(),
|
||||
)
|
||||
|
||||
|
||||
def _start_engine(engine: RemoteInferenceEngine, attempts: int = 4) -> None:
|
||||
"""Engine start with handshake retries (declarations may still be settling)."""
|
||||
last_error: Exception | None = None
|
||||
for _ in range(attempts):
|
||||
try:
|
||||
engine.start()
|
||||
return
|
||||
except ConnectionError as e:
|
||||
last_error = e
|
||||
engine.stop()
|
||||
time.sleep(0.3)
|
||||
raise last_error
|
||||
|
||||
|
||||
def _collect_actions(engine: RemoteInferenceEngine, n: int, timeout_s: float) -> list[torch.Tensor]:
|
||||
"""Poll ``get_action`` at ~30 Hz until ``n`` actions arrive or the deadline passes."""
|
||||
actions: list[torch.Tensor] = []
|
||||
deadline = time.monotonic() + timeout_s
|
||||
while len(actions) < n and time.monotonic() < deadline:
|
||||
action = engine.get_action(None)
|
||||
if action is not None:
|
||||
actions.append(action)
|
||||
time.sleep(_TICK_S)
|
||||
return actions
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Happy path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_end_to_end_chunks(hw_features):
|
||||
port = free_tcp_port()
|
||||
server = _start_loopback_server(port)
|
||||
engine = _make_engine(port, server, hw_features)
|
||||
try:
|
||||
_start_engine(engine)
|
||||
engine.notify_observation(make_robot_obs(2.0))
|
||||
|
||||
actions = _collect_actions(engine, n=20, timeout_s=15.0)
|
||||
assert len(actions) >= 20
|
||||
|
||||
# chunk_robot[t, j] = 2 * (2.0 + 0.1 j + 0.01 t); the queue head is
|
||||
# trimmed by the (small, loopback) delay → within 0.1 of the t=0 value.
|
||||
first = actions[0]
|
||||
assert first.shape == (ACTION_DIM,)
|
||||
for j in range(ACTION_DIM):
|
||||
expected = 2.0 * (2.0 + 0.1 * j)
|
||||
assert abs(first[j].item() - expected) < 0.1, f"joint {j}: {first[j].item()} vs {expected}"
|
||||
|
||||
assert engine.state == ClientState.STREAMING
|
||||
assert engine.ready
|
||||
assert engine.failed is False
|
||||
assert engine.stats["merges"] >= 1
|
||||
finally:
|
||||
engine.stop()
|
||||
server.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_multi_client_no_cross_contamination(hw_features):
|
||||
port = free_tcp_port()
|
||||
server = _start_loopback_server(port)
|
||||
engine_a = _make_engine(port, server, hw_features, client_uuid="client-a")
|
||||
engine_b = _make_engine(port, server, hw_features, client_uuid="client-b")
|
||||
try:
|
||||
_start_engine(engine_a)
|
||||
_start_engine(engine_b)
|
||||
engine_a.notify_observation(make_robot_obs(2.0))
|
||||
engine_b.notify_observation(make_robot_obs(7.0))
|
||||
|
||||
actions_a = _collect_actions(engine_a, n=1, timeout_s=10.0)
|
||||
actions_b = _collect_actions(engine_b, n=1, timeout_s=10.0)
|
||||
assert actions_a, "engine A produced no actions"
|
||||
assert actions_b, "engine B produced no actions"
|
||||
|
||||
# Each engine's first action must reflect ITS OWN observation seed:
|
||||
# 2*(2.0) = 4.0 for A, 2*(7.0) = 14.0 for B (gap 10.0 ≫ tolerance).
|
||||
first_a = actions_a[0][0].item()
|
||||
first_b = actions_b[0][0].item()
|
||||
assert abs(first_a - 4.0) < 0.3, f"engine A got {first_a} (cross-contamination?)"
|
||||
assert abs(first_b - 14.0) < 0.3, f"engine B got {first_b} (cross-contamination?)"
|
||||
finally:
|
||||
engine_a.stop()
|
||||
engine_b.stop()
|
||||
server.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_reset_roundtrip(hw_features):
|
||||
port = free_tcp_port()
|
||||
server = _start_loopback_server(port)
|
||||
engine = _make_engine(port, server, hw_features)
|
||||
try:
|
||||
_start_engine(engine)
|
||||
engine.notify_observation(make_robot_obs(2.0))
|
||||
assert _collect_actions(engine, n=3, timeout_s=10.0), "no actions before reset"
|
||||
|
||||
merges_before = engine.stats["merges"]
|
||||
engine.reset()
|
||||
engine.notify_observation(make_robot_obs(5.0))
|
||||
|
||||
# New merges land after the reset (worker keeps cycling).
|
||||
assert _wait_until(lambda: engine.stats["merges"] > merges_before, timeout_s=10.0)
|
||||
# The queue refills with post-reset chunks.
|
||||
assert _collect_actions(engine, n=1, timeout_s=5.0), "queue did not refill after reset"
|
||||
|
||||
# Server-side session advanced to the new episode (via the acked
|
||||
# reset query, or via the episode bump in the next obs header).
|
||||
def _episode_advanced() -> bool:
|
||||
sessions = server.registry.snapshot()
|
||||
return bool(sessions) and sessions[0].episode_id >= 1
|
||||
|
||||
assert _wait_until(_episode_advanced, timeout_s=8.0), "server session episode_id never advanced"
|
||||
assert engine.failed is False
|
||||
finally:
|
||||
engine.stop()
|
||||
server.stop()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Chaos: server death / restart
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_server_death_is_safe(hw_features):
|
||||
port = free_tcp_port()
|
||||
server = _start_loopback_server(port)
|
||||
engine = _make_engine(port, server, hw_features)
|
||||
try:
|
||||
_start_engine(engine)
|
||||
engine.notify_observation(make_robot_obs(2.0))
|
||||
assert _collect_actions(engine, n=5, timeout_s=10.0), "no actions before server death"
|
||||
|
||||
server.stop()
|
||||
|
||||
# Keep ticking at 30 Hz for ~2 s: get_action must never raise.
|
||||
results = []
|
||||
deadline = time.monotonic() + 2.0
|
||||
while time.monotonic() < deadline:
|
||||
results.append(engine.get_action(None))
|
||||
time.sleep(_TICK_S)
|
||||
|
||||
# The local queue drains and HOLD fallback yields None.
|
||||
assert all(result is None for result in results[-5:]), "queue never drained to HOLD fallback"
|
||||
# max_offline_s=8 not reached → not failed, in a degraded-but-alive state.
|
||||
assert engine.failed is False
|
||||
assert engine.state in {ClientState.DEGRADED, ClientState.STALLED, ClientState.RECONNECTING}
|
||||
finally:
|
||||
engine.stop()
|
||||
server.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_server_restart_recovery(hw_features):
|
||||
port = free_tcp_port()
|
||||
server = _start_loopback_server(port)
|
||||
engine = _make_engine(port, server, hw_features, max_offline_s=45.0)
|
||||
new_server = None
|
||||
try:
|
||||
_start_engine(engine)
|
||||
engine.notify_observation(make_robot_obs(2.0))
|
||||
assert _collect_actions(engine, n=3, timeout_s=10.0), "no actions before server death"
|
||||
|
||||
server.stop()
|
||||
# Let the engine notice the death (liveliness drop / request timeout).
|
||||
_wait_until(lambda: engine.state != ClientState.STREAMING, timeout_s=5.0)
|
||||
|
||||
new_server = _start_loopback_server(port)
|
||||
|
||||
# Re-handshake: bounded by the engine backoff and zenoh's TCP
|
||||
# reconnect period — poll generously rather than sleeping.
|
||||
reconnected = _wait_until(lambda: engine.stats["reconnects"] >= 1, timeout_s=25.0, interval_s=0.1)
|
||||
assert reconnected, f"engine never re-handshook (state={engine.state})"
|
||||
|
||||
engine.notify_observation(make_robot_obs(2.0))
|
||||
actions = _collect_actions(engine, n=3, timeout_s=10.0)
|
||||
assert len(actions) >= 3, "no actions after server restart"
|
||||
assert abs(actions[-1][0].item() - 4.0) < 0.3
|
||||
assert engine.failed is False
|
||||
finally:
|
||||
engine.stop()
|
||||
server.stop()
|
||||
if new_server is not None:
|
||||
new_server.stop()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Robustness: unknown clients, fallback modes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_unknown_client_dropped(hw_features):
|
||||
port = free_tcp_port()
|
||||
server = _start_loopback_server(port)
|
||||
intruder = None
|
||||
engine = None
|
||||
try:
|
||||
intruder = zenoh.open(build_zenoh_config(mode="peer", connect_endpoints=[f"tcp/127.0.0.1:{port}"]))
|
||||
key = obs_key(server.prefix, "intruder-uuid")
|
||||
header_bytes = MsgHeader().pack() # valid header, garbage body, no session
|
||||
|
||||
deadline = time.monotonic() + 8.0
|
||||
while server.metrics["dropped_unknown_client_total"] < 1 and time.monotonic() < deadline:
|
||||
intruder.put(key, b"\xde\xad\xbe\xef", attachment=header_bytes)
|
||||
time.sleep(0.1)
|
||||
assert server.metrics["dropped_unknown_client_total"] >= 1
|
||||
assert len(server.registry) == 0
|
||||
|
||||
# The server stays healthy: a legitimate engine still works.
|
||||
engine = _make_engine(port, server, hw_features)
|
||||
_start_engine(engine)
|
||||
engine.notify_observation(make_robot_obs(2.0))
|
||||
actions = _collect_actions(engine, n=1, timeout_s=10.0)
|
||||
assert actions, "legitimate engine got no actions after garbage traffic"
|
||||
assert abs(actions[0][0].item() - 4.0) < 0.3
|
||||
finally:
|
||||
if engine is not None:
|
||||
engine.stop()
|
||||
if intruder is not None:
|
||||
intruder.close()
|
||||
server.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_fallback_zero(hw_features):
|
||||
port = free_tcp_port()
|
||||
server = _start_loopback_server(port)
|
||||
engine = _make_engine(port, server, hw_features, fallback=FallbackMode.ZERO)
|
||||
try:
|
||||
_start_engine(engine)
|
||||
engine.notify_observation(make_robot_obs(2.0))
|
||||
# With ZERO fallback an empty queue already yields zeros, so wait
|
||||
# for a *streamed* (nonzero ~4.0) action to prove chunks flowed.
|
||||
streamed = False
|
||||
deadline = time.monotonic() + 10.0
|
||||
while time.monotonic() < deadline:
|
||||
action = engine.get_action(None)
|
||||
if action is not None and torch.count_nonzero(action) > 0:
|
||||
streamed = True
|
||||
break
|
||||
time.sleep(_TICK_S)
|
||||
assert streamed, "no streamed (nonzero) actions before server death"
|
||||
|
||||
server.stop()
|
||||
|
||||
# Drain the local queue; once dry, ZERO fallback must return an
|
||||
# explicit zero command (never None) of the action dimension.
|
||||
saw_zero = False
|
||||
deadline = time.monotonic() + 6.0
|
||||
while time.monotonic() < deadline:
|
||||
action = engine.get_action(None)
|
||||
assert action is not None, "FallbackMode.ZERO returned None"
|
||||
if torch.count_nonzero(action) == 0:
|
||||
assert action.shape == (len(ACTION_NAMES),)
|
||||
saw_zero = True
|
||||
break
|
||||
time.sleep(_TICK_S)
|
||||
assert saw_zero, "queue never drained to the zero fallback"
|
||||
assert engine.failed is False
|
||||
finally:
|
||||
engine.stop()
|
||||
server.stop()
|
||||
@@ -0,0 +1,190 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for the policy-server manifest (defaults + __post_init__ validation)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import textwrap
|
||||
|
||||
import draccus
|
||||
import pytest
|
||||
|
||||
from lerobot.policy_server.manifest import (
|
||||
SERVING_MODE_AUTO,
|
||||
ModelSpec,
|
||||
PolicyServerManifest,
|
||||
ZenohSpec,
|
||||
)
|
||||
|
||||
|
||||
def _manifest(**overrides) -> PolicyServerManifest:
|
||||
kwargs: dict = {"model": ModelSpec(repo_or_path="mock/model")}
|
||||
kwargs.update(overrides)
|
||||
return PolicyServerManifest(**kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Defaults
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_defaults_parse():
|
||||
# The default zenoh mode is "client", which requires a router address;
|
||||
# peer mode is the minimal valid transport for a defaults-only manifest.
|
||||
manifest = _manifest(zenoh=ZenohSpec(mode="peer"))
|
||||
assert manifest.model.repo_or_path == "mock/model"
|
||||
assert manifest.model.revision == "main"
|
||||
assert manifest.model.device == "cuda"
|
||||
assert manifest.serving_mode == SERVING_MODE_AUTO
|
||||
assert manifest.max_sessions == 5
|
||||
assert manifest.warmup_inferences == 2
|
||||
assert manifest.trained_fps == 30.0
|
||||
assert manifest.strict_fps is False
|
||||
assert manifest.pin_task is False
|
||||
assert manifest.session_idle_timeout_s == 300.0
|
||||
assert manifest.health_port == 9100
|
||||
assert manifest.debug.capture_dir is None
|
||||
assert manifest.rtc.enabled is True
|
||||
# Bare ZenohSpec defaults (validated only when embedded in a manifest).
|
||||
assert ZenohSpec().mode == "client"
|
||||
assert ZenohSpec().connect_endpoints == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# __post_init__ validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_missing_repo_or_path_raises():
|
||||
with pytest.raises(ValueError, match="repo_or_path"):
|
||||
PolicyServerManifest()
|
||||
|
||||
|
||||
def test_bad_serving_mode_raises():
|
||||
with pytest.raises(ValueError, match="serving_mode"):
|
||||
_manifest(serving_mode="multiplexed")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("max_sessions", [0, -3])
|
||||
def test_max_sessions_below_one_raises(max_sessions):
|
||||
with pytest.raises(ValueError, match="max_sessions"):
|
||||
_manifest(max_sessions=max_sessions)
|
||||
|
||||
|
||||
def test_zenoh_client_mode_requires_connect_endpoints():
|
||||
with pytest.raises(ValueError, match="connect_endpoints"):
|
||||
_manifest(zenoh=ZenohSpec(mode="client", connect_endpoints=[]))
|
||||
|
||||
|
||||
def test_zenoh_client_mode_with_endpoints_ok():
|
||||
manifest = _manifest(zenoh=ZenohSpec(mode="client", connect_endpoints=["tcp/router:7447"]))
|
||||
assert manifest.zenoh.connect_endpoints == ["tcp/router:7447"]
|
||||
|
||||
|
||||
def test_zenoh_peer_mode_without_endpoints_ok():
|
||||
manifest = _manifest(zenoh=ZenohSpec(mode="peer"))
|
||||
assert manifest.zenoh.mode == "peer"
|
||||
assert manifest.zenoh.connect_endpoints == []
|
||||
|
||||
|
||||
def test_partial_tls_triple_raises():
|
||||
with pytest.raises(ValueError, match="TLS"):
|
||||
_manifest(zenoh=ZenohSpec(mode="peer", tls_root_ca_certificate="/certs/ca.pem"))
|
||||
|
||||
|
||||
def test_full_tls_triple_ok():
|
||||
manifest = _manifest(
|
||||
zenoh=ZenohSpec(
|
||||
mode="peer",
|
||||
tls_root_ca_certificate="/certs/ca.pem",
|
||||
tls_connect_certificate="/certs/cert.pem",
|
||||
tls_connect_private_key="/certs/key.pem",
|
||||
)
|
||||
)
|
||||
assert manifest.zenoh.tls_connect_private_key == "/certs/key.pem"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Draccus round-trip (YAML manifest → dataclass)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_draccus_yaml_round_trip(tmp_path):
|
||||
yaml_path = tmp_path / "server.yaml"
|
||||
yaml_path.write_text(
|
||||
textwrap.dedent(
|
||||
"""\
|
||||
model:
|
||||
repo_or_path: mock/model
|
||||
revision: v2.0
|
||||
device: cpu
|
||||
zenoh:
|
||||
mode: peer
|
||||
listen_endpoints:
|
||||
- tcp/127.0.0.1:7447
|
||||
default_task: pick the cube
|
||||
pin_task: true
|
||||
serving_mode: exclusive
|
||||
max_sessions: 1
|
||||
trained_fps: 25.0
|
||||
strict_fps: true
|
||||
health_port: 0
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
manifest = draccus.parse(PolicyServerManifest, config_path=str(yaml_path), args=[])
|
||||
|
||||
assert isinstance(manifest, PolicyServerManifest)
|
||||
assert manifest.model.repo_or_path == "mock/model"
|
||||
assert manifest.model.revision == "v2.0"
|
||||
assert manifest.model.device == "cpu"
|
||||
assert manifest.zenoh.mode == "peer"
|
||||
assert manifest.zenoh.listen_endpoints == ["tcp/127.0.0.1:7447"]
|
||||
assert manifest.default_task == "pick the cube"
|
||||
assert manifest.pin_task is True
|
||||
assert manifest.serving_mode == "exclusive"
|
||||
assert manifest.max_sessions == 1
|
||||
assert manifest.trained_fps == 25.0
|
||||
assert manifest.strict_fps is True
|
||||
assert manifest.health_port == 0
|
||||
# Untouched fields keep their defaults.
|
||||
assert manifest.warmup_inferences == 2
|
||||
assert manifest.session_idle_timeout_s == 300.0
|
||||
|
||||
|
||||
def test_draccus_cli_override_on_top_of_yaml(tmp_path):
|
||||
yaml_path = tmp_path / "server.yaml"
|
||||
yaml_path.write_text(
|
||||
textwrap.dedent(
|
||||
"""\
|
||||
model:
|
||||
repo_or_path: mock/model
|
||||
device: cpu
|
||||
zenoh:
|
||||
mode: peer
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
manifest = draccus.parse(
|
||||
PolicyServerManifest,
|
||||
config_path=str(yaml_path),
|
||||
args=["--max_sessions", "3", "--model.revision", "exp-1"],
|
||||
)
|
||||
|
||||
assert manifest.max_sessions == 3
|
||||
assert manifest.model.revision == "exp-1"
|
||||
assert manifest.model.repo_or_path == "mock/model"
|
||||
@@ -0,0 +1,117 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for the RoundRobinScheduler fairness guarantees.
|
||||
|
||||
The scheduler only reads ``client_uuid`` from sessions, so minimal fakes
|
||||
stand in for real Session objects.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import Counter
|
||||
|
||||
from lerobot.policy_server.scheduler import RoundRobinScheduler
|
||||
|
||||
|
||||
class FakeSession:
|
||||
"""The scheduler touches nothing but client_uuid."""
|
||||
|
||||
def __init__(self, client_uuid: str):
|
||||
self.client_uuid = client_uuid
|
||||
|
||||
|
||||
def picks(scheduler: RoundRobinScheduler, ready: list[FakeSession], n: int) -> list[str]:
|
||||
served = []
|
||||
for _ in range(n):
|
||||
chosen = scheduler.select(list(ready))
|
||||
assert len(chosen) == 1
|
||||
served.append(chosen[0].client_uuid)
|
||||
return served
|
||||
|
||||
|
||||
def test_empty_ready_returns_empty():
|
||||
scheduler = RoundRobinScheduler()
|
||||
assert scheduler.select([]) == []
|
||||
|
||||
|
||||
def test_empty_ready_after_serving_returns_empty():
|
||||
scheduler = RoundRobinScheduler()
|
||||
scheduler.select([FakeSession("a")])
|
||||
assert scheduler.select([]) == []
|
||||
|
||||
|
||||
def test_single_session_picked_repeatedly():
|
||||
scheduler = RoundRobinScheduler()
|
||||
only = FakeSession("solo")
|
||||
for _ in range(5):
|
||||
assert scheduler.select([only]) == [only]
|
||||
|
||||
|
||||
def test_three_sessions_served_fairly_in_sorted_uuid_order():
|
||||
scheduler = RoundRobinScheduler()
|
||||
a, b, c = FakeSession("a"), FakeSession("b"), FakeSession("c")
|
||||
# Pass ready in non-sorted order: the ring is sorted by uuid internally.
|
||||
served = picks(scheduler, [c, a, b], 9)
|
||||
|
||||
assert served == ["a", "b", "c", "a", "b", "c", "a", "b", "c"]
|
||||
assert Counter(served) == {"a": 3, "b": 3, "c": 3}
|
||||
|
||||
|
||||
def test_session_leaving_between_calls_keeps_fairness():
|
||||
scheduler = RoundRobinScheduler()
|
||||
a, b, c = FakeSession("a"), FakeSession("b"), FakeSession("c")
|
||||
assert scheduler.select([a, b, c]) == [a]
|
||||
|
||||
# 'a' leaves; remaining sessions alternate with no crash or starvation.
|
||||
served = picks(scheduler, [b, c], 4)
|
||||
assert served == ["b", "c", "b", "c"]
|
||||
|
||||
|
||||
def test_departed_last_served_uuid_resumes_after_it():
|
||||
scheduler = RoundRobinScheduler()
|
||||
a, b, c = FakeSession("a"), FakeSession("b"), FakeSession("c")
|
||||
picks(scheduler, [a, b, c], 2) # last served is 'b'
|
||||
|
||||
# 'b' leaves; the next pick is the first uuid greater than 'b'.
|
||||
assert scheduler.select([a, c]) == [c]
|
||||
assert scheduler.select([a, c]) == [a]
|
||||
|
||||
|
||||
def test_wraparound_from_last_uuid_back_to_first():
|
||||
scheduler = RoundRobinScheduler()
|
||||
a, b, c = FakeSession("a"), FakeSession("b"), FakeSession("c")
|
||||
assert scheduler.select([c]) == [c] # last served is the highest uuid
|
||||
|
||||
# Everyone is <= last served: wrap back to the first sorted uuid.
|
||||
assert scheduler.select([a, b, c]) == [a]
|
||||
|
||||
|
||||
def test_newly_ready_session_joins_ring_fairly():
|
||||
scheduler = RoundRobinScheduler()
|
||||
a, c = FakeSession("a"), FakeSession("c")
|
||||
served = picks(scheduler, [a, c], 2)
|
||||
assert served == ["a", "c"]
|
||||
|
||||
# 'b' becomes ready; wrap-around lands on 'a', then 'b' gets its turn.
|
||||
b = FakeSession("b")
|
||||
served = picks(scheduler, [a, b, c], 3)
|
||||
assert served == ["a", "b", "c"]
|
||||
|
||||
|
||||
def test_no_starvation_over_many_cycles():
|
||||
scheduler = RoundRobinScheduler()
|
||||
ready = [FakeSession(f"u{i:02d}") for i in range(5)]
|
||||
served = picks(scheduler, ready, 50)
|
||||
assert Counter(served) == {f"u{i:02d}": 10 for i in range(5)}
|
||||
@@ -0,0 +1,197 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for the wire schema: packed header and key-expression layout."""
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.policy_server.schema import (
|
||||
HEADER_SIZE,
|
||||
MSG_TYPE_CHUNK,
|
||||
MSG_TYPE_EVENT,
|
||||
MSG_TYPE_OBS,
|
||||
RESERVED_SEGMENTS,
|
||||
SCHEMA_VERSION,
|
||||
MsgHeader,
|
||||
action_key,
|
||||
client_alive_key,
|
||||
client_alive_wildcard,
|
||||
client_uuid_from_key,
|
||||
obs_key,
|
||||
obs_wildcard,
|
||||
reset_key,
|
||||
reset_wildcard,
|
||||
sanitize_key_segment,
|
||||
server_alive_key,
|
||||
service_prefix,
|
||||
session_key,
|
||||
status_key,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MsgHeader
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_header_roundtrip_all_fields():
|
||||
hdr = MsgHeader(
|
||||
schema_version=3,
|
||||
msg_type=MSG_TYPE_CHUNK,
|
||||
seq_id=123456789,
|
||||
episode_id=42,
|
||||
client_mono_ns=987654321012345,
|
||||
session_epoch=7,
|
||||
)
|
||||
out = MsgHeader.unpack(hdr.pack())
|
||||
assert out == hdr
|
||||
|
||||
|
||||
def test_header_defaults_roundtrip():
|
||||
out = MsgHeader.unpack(MsgHeader().pack())
|
||||
assert out.schema_version == SCHEMA_VERSION
|
||||
assert out.msg_type == MSG_TYPE_OBS
|
||||
assert out.seq_id == 0
|
||||
assert out.episode_id == 0
|
||||
assert out.client_mono_ns == 0
|
||||
assert out.session_epoch == 0
|
||||
|
||||
|
||||
def test_header_negative_client_mono_ns():
|
||||
hdr = MsgHeader(msg_type=MSG_TYPE_EVENT, client_mono_ns=-123456789)
|
||||
out = MsgHeader.unpack(hdr.pack())
|
||||
assert out.client_mono_ns == -123456789
|
||||
|
||||
|
||||
def test_header_max_u64_seq_id():
|
||||
max_u64 = 2**64 - 1
|
||||
hdr = MsgHeader(seq_id=max_u64)
|
||||
out = MsgHeader.unpack(hdr.pack())
|
||||
assert out.seq_id == max_u64
|
||||
|
||||
|
||||
def test_header_size_constant_matches_pack():
|
||||
assert len(MsgHeader().pack()) == HEADER_SIZE
|
||||
|
||||
|
||||
def test_header_unpack_rejects_wrong_length():
|
||||
packed = MsgHeader().pack()
|
||||
with pytest.raises(ValueError, match="Bad header length"):
|
||||
MsgHeader.unpack(packed[:-1])
|
||||
with pytest.raises(ValueError, match="Bad header length"):
|
||||
MsgHeader.unpack(packed + b"\x00")
|
||||
with pytest.raises(ValueError, match="Bad header length"):
|
||||
MsgHeader.unpack(b"")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# sanitize_key_segment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("bad_char", ["/", "*", "$", "?", "#", " "])
|
||||
def test_sanitize_folds_unsafe_chars_to_dash(bad_char):
|
||||
assert sanitize_key_segment(f"a{bad_char}b") == "a-b"
|
||||
|
||||
|
||||
def test_sanitize_folds_runs_to_single_dash():
|
||||
assert sanitize_key_segment("a/*$?# b") == "a-b"
|
||||
|
||||
|
||||
def test_sanitize_strips_whitespace_and_edge_dashes():
|
||||
assert sanitize_key_segment(" hello ") == "hello"
|
||||
assert sanitize_key_segment("/leading/trailing/") == "leading-trailing"
|
||||
|
||||
|
||||
def test_sanitize_preserves_allowed_chars():
|
||||
assert sanitize_key_segment("Az09_.-x") == "Az09_.-x"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("empty_input", ["", " ", "***", "/?#$*"])
|
||||
def test_sanitize_raises_on_empty_after_sanitize(empty_input):
|
||||
with pytest.raises(ValueError, match="empty"):
|
||||
sanitize_key_segment(empty_input)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"reserved", sorted(["status", "session", "server", "obs", "action", "reset", "alive"])
|
||||
)
|
||||
def test_sanitize_raises_on_reserved_segments(reserved):
|
||||
assert reserved in RESERVED_SEGMENTS
|
||||
with pytest.raises(ValueError, match="reserved"):
|
||||
sanitize_key_segment(reserved)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# service_prefix
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_service_prefix_example():
|
||||
prefix = service_prefix("lerobot/pi0_towels", "main", "fold the towel")
|
||||
assert prefix == "@lerobot/lerobot-pi0_towels/main/fold-the-towel"
|
||||
|
||||
|
||||
def test_service_prefix_defaults_for_empty_revision_and_task():
|
||||
prefix = service_prefix("org/model", "", "")
|
||||
assert prefix == "@lerobot/org-model/main/default"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Key builders and wildcards
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
PREFIX = "@lerobot/org-model/main/default"
|
||||
UUID = "client-uuid-1234"
|
||||
|
||||
|
||||
def test_per_client_keys():
|
||||
assert obs_key(PREFIX, UUID) == f"{PREFIX}/{UUID}/obs"
|
||||
assert action_key(PREFIX, UUID) == f"{PREFIX}/{UUID}/action"
|
||||
assert reset_key(PREFIX, UUID) == f"{PREFIX}/{UUID}/reset"
|
||||
assert client_alive_key(PREFIX, UUID) == f"{PREFIX}/{UUID}/alive"
|
||||
|
||||
|
||||
def test_service_level_keys():
|
||||
assert status_key(PREFIX) == f"{PREFIX}/status"
|
||||
assert session_key(PREFIX) == f"{PREFIX}/session"
|
||||
assert server_alive_key(PREFIX) == f"{PREFIX}/server/alive"
|
||||
|
||||
|
||||
def test_wildcards_are_single_depth():
|
||||
assert obs_wildcard(PREFIX) == f"{PREFIX}/*/obs"
|
||||
assert reset_wildcard(PREFIX) == f"{PREFIX}/*/reset"
|
||||
assert client_alive_wildcard(PREFIX) == f"{PREFIX}/*/alive"
|
||||
assert "**" not in obs_wildcard(PREFIX)
|
||||
|
||||
|
||||
def test_key_builders_sanitize_client_uuid():
|
||||
assert obs_key(PREFIX, "bad uuid") == f"{PREFIX}/bad-uuid/obs"
|
||||
with pytest.raises(ValueError):
|
||||
obs_key(PREFIX, "status")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# client_uuid_from_key
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_client_uuid_from_obs_reset_alive_keys():
|
||||
assert client_uuid_from_key(obs_key(PREFIX, UUID)) == UUID
|
||||
assert client_uuid_from_key(reset_key(PREFIX, UUID)) == UUID
|
||||
assert client_uuid_from_key(client_alive_key(PREFIX, UUID)) == UUID
|
||||
|
||||
|
||||
def test_client_uuid_from_key_rejects_keys_without_client_chunk():
|
||||
with pytest.raises(ValueError, match="no client chunk"):
|
||||
client_uuid_from_key("obs")
|
||||
@@ -0,0 +1,405 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Pure-logic PolicyServer tests (no zenoh transport).
|
||||
|
||||
Covers status snapshots, session open/reject/re-handshake, the
|
||||
per-request inference path (determinism, RTC forwarding, echo fields,
|
||||
supersession), episode boundaries in ``_serve_one``, warmup, and the
|
||||
error/metrics accounting.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("msgpack")
|
||||
|
||||
from lerobot.policy_server import codec # noqa: E402
|
||||
from lerobot.policy_server.schema import MsgHeader, ObservationMsg, SessionOpenMsg # noqa: E402
|
||||
from lerobot.policy_server.validation import PolicyClassification, ServingClass # noqa: E402
|
||||
from tests.policy_server.conftest import ( # noqa: E402
|
||||
ACTION_DIM,
|
||||
ACTION_NAMES,
|
||||
CHUNK_SIZE,
|
||||
IMG_H,
|
||||
IMG_W,
|
||||
MODEL_ID,
|
||||
STATE_DIM,
|
||||
TASK,
|
||||
MockChunkPolicy,
|
||||
make_logic_server,
|
||||
make_manifest,
|
||||
)
|
||||
|
||||
CAMERA_KEY = "observation.images.front"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_open_msg(client_uuid: str = "client-a", **overrides) -> SessionOpenMsg:
|
||||
kwargs = {
|
||||
"client_uuid": client_uuid,
|
||||
"robot_type": "so101",
|
||||
"policy_type": "mockchunk",
|
||||
"fps": 30.0,
|
||||
"action_names": list(ACTION_NAMES),
|
||||
"camera_names": [CAMERA_KEY],
|
||||
"state_dim": STATE_DIM,
|
||||
"rtc_enabled": True,
|
||||
"task": TASK,
|
||||
}
|
||||
kwargs.update(overrides)
|
||||
return SessionOpenMsg(**kwargs)
|
||||
|
||||
|
||||
def make_obs(**overrides) -> ObservationMsg:
|
||||
kwargs = {
|
||||
"state": np.arange(STATE_DIM, dtype=np.float32),
|
||||
"images": {CAMERA_KEY: np.zeros((IMG_H, IMG_W, 3), dtype=np.uint8)},
|
||||
"task": TASK,
|
||||
"jpeg_quality": 0,
|
||||
}
|
||||
kwargs.update(overrides)
|
||||
return ObservationMsg(**kwargs)
|
||||
|
||||
|
||||
def open_session(server, client_uuid: str = "client-a", **overrides):
|
||||
ack = server._handle_session_open(make_open_msg(client_uuid=client_uuid, **overrides))
|
||||
assert ack.accepted, ack.error
|
||||
return server.registry.get(client_uuid), ack
|
||||
|
||||
|
||||
def deposit(session, obs: ObservationMsg, header: MsgHeader | None = None) -> None:
|
||||
session.deposit(header or MsgHeader(episode_id=session.episode_id), codec.encode_observation(obs))
|
||||
|
||||
|
||||
def make_exclusive_server(policy=None):
|
||||
return make_logic_server(
|
||||
manifest=make_manifest(serving_mode="exclusive"),
|
||||
policy=policy,
|
||||
classification=PolicyClassification(
|
||||
ServingClass.EXCLUSIVE, supports_rtc=False, needs_queue_population=False, reason="x"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def expected_chunk(state: np.ndarray) -> np.ndarray:
|
||||
steps = np.arange(CHUNK_SIZE, dtype=np.float32)[:, None] * np.float32(0.01)
|
||||
return state[None, :ACTION_DIM] + steps
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# status_snapshot
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_status_snapshot_capabilities():
|
||||
server = make_logic_server()
|
||||
snap = server.status_snapshot()
|
||||
assert snap.model_repo == MODEL_ID
|
||||
assert snap.policy_type == "mockchunk"
|
||||
assert snap.action_names == ACTION_NAMES
|
||||
assert snap.state_dim == STATE_DIM
|
||||
assert snap.chunk_size == CHUNK_SIZE
|
||||
assert snap.expected_cameras == [CAMERA_KEY]
|
||||
assert snap.supports_rtc is True
|
||||
assert snap.warmed_up is True
|
||||
assert snap.serving_mode == "shared"
|
||||
assert snap.active_sessions == 0
|
||||
assert snap.max_sessions == 4 # make_manifest default
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _handle_session_open
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_session_open_happy_path():
|
||||
server = make_logic_server()
|
||||
session, ack = open_session(server, "client-a")
|
||||
assert ack.accepted is True
|
||||
assert ack.error == ""
|
||||
assert ack.session_id == session.session_id
|
||||
assert ack.session_id != ""
|
||||
assert ack.model_repo == MODEL_ID
|
||||
assert ack.policy_type == "mockchunk"
|
||||
assert ack.action_names == ACTION_NAMES
|
||||
assert ack.expected_cameras == [CAMERA_KEY]
|
||||
assert ack.state_dim == STATE_DIM
|
||||
assert ack.chunk_size == CHUNK_SIZE
|
||||
assert ack.supports_rtc is True
|
||||
assert ack.serving_mode == "shared"
|
||||
assert ack.warmed_up is True
|
||||
assert len(server.registry) == 1
|
||||
assert session.rtc_enabled is True
|
||||
assert session.task == TASK
|
||||
assert server.metrics["sessions_opened_total"] == 1
|
||||
|
||||
|
||||
def test_session_open_fresh_processor_pair_per_session():
|
||||
server = make_logic_server()
|
||||
assert len(server.factory_calls) == 0 # warmup_inferences=0: no warmup pair
|
||||
session_a, _ = open_session(server, "client-a")
|
||||
assert len(server.factory_calls) == 1
|
||||
session_b, _ = open_session(server, "client-b")
|
||||
assert len(server.factory_calls) == 2
|
||||
assert session_a.preprocessor is not session_b.preprocessor
|
||||
assert session_a.postprocessor is not session_b.postprocessor
|
||||
|
||||
|
||||
def test_session_open_rejects_action_order_mismatch():
|
||||
server = make_logic_server()
|
||||
ack = server._handle_session_open(make_open_msg(action_names=list(reversed(ACTION_NAMES))))
|
||||
assert ack.accepted is False
|
||||
assert "action" in ack.error
|
||||
assert len(server.registry) == 0
|
||||
|
||||
|
||||
def test_session_open_rejects_at_capacity():
|
||||
server = make_logic_server()
|
||||
for i in range(4): # make_manifest max_sessions=4
|
||||
open_session(server, f"client-{i}")
|
||||
ack = server._handle_session_open(make_open_msg(client_uuid="client-overflow"))
|
||||
assert ack.accepted is False
|
||||
assert "sessions" in ack.error
|
||||
assert len(server.registry) == 4
|
||||
|
||||
|
||||
def test_rehandshake_replaces_session_without_counting_against_capacity():
|
||||
server = make_logic_server()
|
||||
first_ack = None
|
||||
for i in range(4):
|
||||
_, ack = open_session(server, f"client-{i}")
|
||||
if i == 0:
|
||||
first_ack = ack
|
||||
# Server is full; the same client re-handshakes and must be accepted.
|
||||
session, ack = open_session(server, "client-0")
|
||||
assert ack.accepted is True
|
||||
assert len(server.registry) == 4
|
||||
assert ack.session_id != first_ack.session_id
|
||||
assert server.registry.get("client-0").session_id == session.session_id
|
||||
|
||||
|
||||
def test_session_open_rtc_downgrade():
|
||||
server = make_logic_server(
|
||||
classification=PolicyClassification(
|
||||
ServingClass.SHARED, supports_rtc=False, needs_queue_population=False, reason="x"
|
||||
)
|
||||
)
|
||||
session, ack = open_session(server, "client-a", rtc_enabled=True)
|
||||
assert ack.accepted is True
|
||||
assert ack.supports_rtc is False
|
||||
assert session.rtc_enabled is False
|
||||
assert any("RTC" in w for w in ack.warnings)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# run_inference_request
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_inference_deterministic_chunks():
|
||||
server = make_logic_server()
|
||||
session, _ = open_session(server)
|
||||
state = np.arange(STATE_DIM, dtype=np.float32)
|
||||
reply = server.run_inference_request(session, MsgHeader(), make_obs(state=state))
|
||||
assert reply.chunk_model.shape == (CHUNK_SIZE, ACTION_DIM)
|
||||
assert reply.chunk_robot.shape == (CHUNK_SIZE, ACTION_DIM)
|
||||
np.testing.assert_allclose(reply.chunk_model[0], state, rtol=0, atol=0)
|
||||
np.testing.assert_allclose(reply.chunk_model, expected_chunk(state), rtol=1e-6)
|
||||
np.testing.assert_allclose(reply.chunk_robot, 2.0 * reply.chunk_model, rtol=0, atol=0)
|
||||
|
||||
|
||||
def test_inference_delay_forwarded_to_policy():
|
||||
server = make_logic_server()
|
||||
session, _ = open_session(server)
|
||||
policy = server._policy
|
||||
server.run_inference_request(session, MsgHeader(), make_obs(inference_delay_steps=3))
|
||||
assert policy.calls[-1]["inference_delay"] == 3
|
||||
assert policy.calls[-1]["prev_chunk_left_over"] is None
|
||||
|
||||
|
||||
def test_prefix_model_forwarded_padded_to_execution_horizon():
|
||||
server = make_logic_server()
|
||||
session, _ = open_session(server)
|
||||
policy = server._policy
|
||||
prefix = (np.arange(3 * ACTION_DIM, dtype=np.float32).reshape(3, ACTION_DIM)) + 1.0
|
||||
server.run_inference_request(session, MsgHeader(), make_obs(prefix_model=prefix))
|
||||
received = policy.calls[-1]["prev_chunk_left_over"]
|
||||
assert received is not None
|
||||
horizon = server._manifest.rtc.execution_horizon # 10 by default
|
||||
assert received.shape == (horizon, ACTION_DIM)
|
||||
np.testing.assert_allclose(received[:3].numpy(), prefix, rtol=0, atol=0)
|
||||
np.testing.assert_allclose(received[3:].numpy(), np.zeros((horizon - 3, ACTION_DIM)), rtol=0, atol=0)
|
||||
|
||||
|
||||
def test_prefix_model_truncated_to_execution_horizon():
|
||||
server = make_logic_server()
|
||||
session, _ = open_session(server)
|
||||
policy = server._policy
|
||||
horizon = server._manifest.rtc.execution_horizon
|
||||
prefix = np.ones((horizon + 5, ACTION_DIM), dtype=np.float32)
|
||||
server.run_inference_request(session, MsgHeader(), make_obs(prefix_model=prefix))
|
||||
assert policy.calls[-1]["prev_chunk_left_over"].shape == (horizon, ACTION_DIM)
|
||||
|
||||
|
||||
def test_reply_echo_fields():
|
||||
server = make_logic_server()
|
||||
session, _ = open_session(server)
|
||||
header = MsgHeader(seq_id=7, episode_id=2, client_mono_ns=123_456_789)
|
||||
reply = server.run_inference_request(session, header, make_obs())
|
||||
assert reply.seq_id_echo == 7
|
||||
assert reply.episode_id_echo == 2
|
||||
assert reply.client_mono_ns_echo == 123_456_789
|
||||
|
||||
|
||||
def test_superseded_seqs_reported_then_reset():
|
||||
server = make_logic_server()
|
||||
session, _ = open_session(server)
|
||||
deposit(session, make_obs(), MsgHeader(seq_id=1))
|
||||
deposit(session, make_obs(), MsgHeader(seq_id=2)) # supersedes seq 1
|
||||
item = session.take()
|
||||
assert item.header.seq_id == 2 # latest-only mailbox
|
||||
reply = server.run_inference_request(session, item.header, codec.decode_observation(item.payload))
|
||||
assert reply.superseded_seqs == 1
|
||||
deposit(session, make_obs(), MsgHeader(seq_id=3))
|
||||
item = session.take()
|
||||
reply = server.run_inference_request(session, item.header, codec.decode_observation(item.payload))
|
||||
assert reply.superseded_seqs == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _serve_one
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_serve_one_episode_boundary_resets_session_pipelines():
|
||||
server = make_logic_server()
|
||||
session, _ = open_session(server)
|
||||
# Fresh sessions start at the -1 sentinel so their first request
|
||||
# always lands on the episode-boundary branch (mid-episode reconnects
|
||||
# can never inherit stale state).
|
||||
assert session.episode_id == -1
|
||||
deposit(session, make_obs(episode_start=True), MsgHeader(episode_id=1))
|
||||
server._serve_one(session)
|
||||
assert session.preprocessor.reset_count == 1
|
||||
assert session.postprocessor.reset_count == 1
|
||||
assert session.episode_id == 1
|
||||
# Shared mode never resets the policy itself.
|
||||
assert server._policy.reset_count == 0
|
||||
|
||||
|
||||
def test_serve_one_no_boundary_no_reset():
|
||||
server = make_logic_server()
|
||||
session, _ = open_session(server)
|
||||
# First request always resets (the -1 sentinel) and syncs the episode.
|
||||
deposit(session, make_obs(), MsgHeader(episode_id=0))
|
||||
server._serve_one(session)
|
||||
assert session.preprocessor.reset_count == 1
|
||||
assert session.episode_id == 0
|
||||
# Same-episode follow-up: no further reset.
|
||||
deposit(session, make_obs(), MsgHeader(episode_id=0, seq_id=2))
|
||||
server._serve_one(session)
|
||||
assert session.preprocessor.reset_count == 1
|
||||
assert session.postprocessor.reset_count == 1
|
||||
|
||||
|
||||
def test_serve_one_exclusive_mode_resets_policy_on_boundary():
|
||||
policy = MockChunkPolicy()
|
||||
server = make_exclusive_server(policy=policy)
|
||||
assert server.status_snapshot().serving_mode == "exclusive"
|
||||
assert server.status_snapshot().max_sessions == 1 # exclusive forces 1
|
||||
session, _ = open_session(server, rtc_enabled=False)
|
||||
# Exclusive session open already resets the policy to fresh state.
|
||||
base_resets = policy.reset_count
|
||||
assert base_resets >= 1
|
||||
deposit(session, make_obs(episode_start=True), MsgHeader(episode_id=1))
|
||||
server._serve_one(session)
|
||||
assert policy.reset_count == base_resets + 1
|
||||
assert session.episode_id == 1
|
||||
|
||||
|
||||
def test_serve_one_inference_error_counted_not_propagated(monkeypatch):
|
||||
server = make_logic_server()
|
||||
session, _ = open_session(server)
|
||||
|
||||
def boom(*args, **kwargs):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr(server._policy, "predict_action_chunk", boom)
|
||||
deposit(session, make_obs())
|
||||
server._serve_one(session) # must not raise
|
||||
assert server.metrics["errors_total"] == 1
|
||||
assert server.metrics["requests_total"] == 0
|
||||
assert session.stats.errors == 1
|
||||
|
||||
|
||||
def test_serve_one_increments_requests_total():
|
||||
server = make_logic_server()
|
||||
session, _ = open_session(server)
|
||||
for seq in (1, 2):
|
||||
deposit(session, make_obs(), MsgHeader(seq_id=seq))
|
||||
server._serve_one(session)
|
||||
assert server.metrics["requests_total"] == 2
|
||||
assert server.metrics["errors_total"] == 0
|
||||
assert session.stats.requests == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Episode reset semantics (session-level, as used by _on_reset_query)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_session_reset_episode_clears_state():
|
||||
server = make_logic_server()
|
||||
session, _ = open_session(server)
|
||||
deposit(session, make_obs())
|
||||
assert session.has_pending()
|
||||
session.reset_episode(5)
|
||||
assert not session.has_pending() # mailbox cleared
|
||||
assert session.episode_id == 5
|
||||
assert session.preprocessor.reset_count == 1
|
||||
assert session.postprocessor.reset_count == 1
|
||||
session.reset_episode() # no explicit id: increments
|
||||
assert session.episode_id == 6
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Warmup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_warmup_runs_before_any_session():
|
||||
policy = MockChunkPolicy()
|
||||
server = make_logic_server(manifest=make_manifest(warmup_inferences=2), policy=policy)
|
||||
assert len(policy.calls) == 2
|
||||
assert len(server.registry) == 0 # warmup session is not registered
|
||||
for call in policy.calls:
|
||||
assert tuple(call["state"].shape) == (1, STATE_DIM)
|
||||
assert float(call["state"].abs().sum()) == 0.0 # synthetic zeros
|
||||
assert server.status_snapshot().warmed_up is True
|
||||
|
||||
|
||||
def test_synthetic_observation_matches_input_features():
|
||||
server = make_logic_server()
|
||||
obs = server._synthetic_observation()
|
||||
assert obs.state.shape == (STATE_DIM,)
|
||||
assert obs.state.dtype == np.float32
|
||||
assert set(obs.images) == {CAMERA_KEY}
|
||||
assert obs.images[CAMERA_KEY].shape == (IMG_H, IMG_W, 3)
|
||||
assert obs.images[CAMERA_KEY].dtype == np.uint8
|
||||
assert obs.jpeg_quality == 0
|
||||
@@ -0,0 +1,298 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for Session (latest-only mailbox, episode reset, close,
|
||||
processor-step introspection) and SessionRegistry (thread-safe map)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
|
||||
from lerobot.policy_server.schema import MsgHeader
|
||||
from lerobot.policy_server.session import Session, SessionRegistry
|
||||
from lerobot.processor import NormalizerProcessorStep, RelativeActionsProcessorStep
|
||||
from tests.policy_server.conftest import TASK, MockPipeline, make_mock_processors
|
||||
|
||||
|
||||
def make_session(
|
||||
client_uuid: str = "client-a",
|
||||
preprocessor: MockPipeline | None = None,
|
||||
postprocessor: MockPipeline | None = None,
|
||||
publisher=None,
|
||||
) -> Session:
|
||||
default_pre, default_post = make_mock_processors()
|
||||
return Session(
|
||||
session_id=f"sess-{client_uuid}",
|
||||
client_uuid=client_uuid,
|
||||
task=TASK,
|
||||
robot_type="mock_robot",
|
||||
rtc_enabled=False,
|
||||
preprocessor=preprocessor if preprocessor is not None else default_pre,
|
||||
postprocessor=postprocessor if postprocessor is not None else default_post,
|
||||
action_publisher=publisher,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mailbox: latest-only deposit / take
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_deposit_then_take_returns_item():
|
||||
session = make_session()
|
||||
header = MsgHeader(seq_id=7)
|
||||
session.deposit(header, b"payload-7")
|
||||
|
||||
item = session.take()
|
||||
assert item is not None
|
||||
assert item.header.seq_id == 7
|
||||
assert item.payload == b"payload-7"
|
||||
assert item.recv_mono > 0
|
||||
|
||||
|
||||
def test_second_deposit_supersedes_and_take_returns_newer():
|
||||
session = make_session()
|
||||
session.deposit(MsgHeader(seq_id=1), b"old")
|
||||
session.deposit(MsgHeader(seq_id=2), b"new")
|
||||
|
||||
assert session.stats.superseded == 1
|
||||
assert session.stats.superseded_since_reply == 1
|
||||
|
||||
item = session.take()
|
||||
assert item is not None
|
||||
assert item.header.seq_id == 2
|
||||
assert item.payload == b"new"
|
||||
|
||||
|
||||
def test_deposit_after_take_is_not_superseded():
|
||||
session = make_session()
|
||||
session.deposit(MsgHeader(seq_id=1), b"one")
|
||||
session.take()
|
||||
session.deposit(MsgHeader(seq_id=2), b"two")
|
||||
|
||||
assert session.stats.superseded == 0
|
||||
assert session.stats.superseded_since_reply == 0
|
||||
|
||||
|
||||
def test_take_clears_mailbox_second_take_is_none():
|
||||
session = make_session()
|
||||
session.deposit(MsgHeader(seq_id=1), b"one")
|
||||
|
||||
assert session.take() is not None
|
||||
assert session.take() is None
|
||||
|
||||
|
||||
def test_has_pending_transitions():
|
||||
session = make_session()
|
||||
assert not session.has_pending()
|
||||
|
||||
session.deposit(MsgHeader(seq_id=1), b"one")
|
||||
assert session.has_pending()
|
||||
|
||||
session.take()
|
||||
assert not session.has_pending()
|
||||
|
||||
|
||||
def test_deposit_marks_alive_and_clears_token_drop():
|
||||
session = make_session()
|
||||
session.alive = False
|
||||
session.token_dropped_mono = 123.4
|
||||
before = session.last_seen_mono
|
||||
|
||||
session.deposit(MsgHeader(seq_id=1), b"one")
|
||||
|
||||
assert session.alive is True
|
||||
assert session.token_dropped_mono is None
|
||||
assert session.last_seen_mono >= before
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Episode boundary
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_reset_episode_resets_pipelines_clears_mailbox_and_increments():
|
||||
pre, post = make_mock_processors()
|
||||
session = make_session(preprocessor=pre, postprocessor=post)
|
||||
session.deposit(MsgHeader(seq_id=1), b"stale")
|
||||
assert session.episode_id == 0
|
||||
|
||||
session.reset_episode()
|
||||
|
||||
assert not session.has_pending()
|
||||
assert pre.reset_count == 1
|
||||
assert post.reset_count == 1
|
||||
assert session.episode_id == 1
|
||||
|
||||
session.reset_episode()
|
||||
assert session.episode_id == 2
|
||||
assert pre.reset_count == 2
|
||||
assert post.reset_count == 2
|
||||
|
||||
|
||||
def test_reset_episode_with_explicit_id():
|
||||
session = make_session()
|
||||
session.reset_episode(episode_id=7)
|
||||
assert session.episode_id == 7
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# close()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class FakePublisher:
|
||||
def __init__(self, raise_on_undeclare: bool = False):
|
||||
self.undeclare_calls = 0
|
||||
self.raise_on_undeclare = raise_on_undeclare
|
||||
|
||||
def undeclare(self):
|
||||
self.undeclare_calls += 1
|
||||
if self.raise_on_undeclare:
|
||||
raise RuntimeError("transport already closed")
|
||||
|
||||
|
||||
def test_close_clears_mailbox_and_undeclares_publisher_exactly_once():
|
||||
publisher = FakePublisher()
|
||||
session = make_session(publisher=publisher)
|
||||
session.deposit(MsgHeader(seq_id=1), b"stale")
|
||||
|
||||
session.close()
|
||||
|
||||
assert not session.has_pending()
|
||||
assert publisher.undeclare_calls == 1
|
||||
assert session.action_publisher is None
|
||||
|
||||
# Idempotent: a second close must not undeclare again.
|
||||
session.close()
|
||||
assert publisher.undeclare_calls == 1
|
||||
|
||||
|
||||
def test_close_tolerates_undeclare_raising():
|
||||
publisher = FakePublisher(raise_on_undeclare=True)
|
||||
session = make_session(publisher=publisher)
|
||||
session.deposit(MsgHeader(seq_id=1), b"stale")
|
||||
|
||||
session.close() # must not raise
|
||||
|
||||
assert publisher.undeclare_calls == 1
|
||||
assert not session.has_pending()
|
||||
assert session.action_publisher is None
|
||||
|
||||
|
||||
def test_close_without_publisher_is_noop():
|
||||
session = make_session(publisher=None)
|
||||
session.close() # must not raise
|
||||
assert session.action_publisher is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Processor-step introspection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_relative_and_normalizer_steps_detected():
|
||||
relative = RelativeActionsProcessorStep(enabled=True)
|
||||
normalizer = NormalizerProcessorStep(features={}, norm_map={})
|
||||
pre = MockPipeline(steps=[relative, normalizer])
|
||||
session = make_session(preprocessor=pre)
|
||||
|
||||
assert session.relative_step is relative
|
||||
assert session.normalizer_step is normalizer
|
||||
|
||||
|
||||
def test_disabled_relative_step_is_not_detected():
|
||||
relative = RelativeActionsProcessorStep(enabled=False)
|
||||
pre = MockPipeline(steps=[relative])
|
||||
session = make_session(preprocessor=pre)
|
||||
|
||||
assert session.relative_step is None
|
||||
assert session.normalizer_step is None
|
||||
|
||||
|
||||
def test_empty_pipeline_yields_no_introspected_steps():
|
||||
session = make_session()
|
||||
assert session.relative_step is None
|
||||
assert session.normalizer_step is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SessionRegistry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_registry_add_get_remove_len_snapshot():
|
||||
registry = SessionRegistry()
|
||||
assert len(registry) == 0
|
||||
assert registry.get("missing") is None
|
||||
assert registry.snapshot() == []
|
||||
|
||||
session_a = make_session("uuid-a")
|
||||
session_b = make_session("uuid-b")
|
||||
assert registry.add(session_a) is None
|
||||
assert registry.add(session_b) is None
|
||||
|
||||
assert len(registry) == 2
|
||||
assert registry.get("uuid-a") is session_a
|
||||
assert registry.get("uuid-b") is session_b
|
||||
assert set(registry.snapshot()) == {session_a, session_b}
|
||||
|
||||
removed = registry.remove("uuid-a")
|
||||
assert removed is session_a
|
||||
assert len(registry) == 1
|
||||
assert registry.get("uuid-a") is None
|
||||
|
||||
|
||||
def test_registry_remove_missing_returns_none():
|
||||
registry = SessionRegistry()
|
||||
assert registry.remove("never-added") is None
|
||||
|
||||
|
||||
def test_registry_add_returns_displaced_same_uuid_session():
|
||||
registry = SessionRegistry()
|
||||
first = make_session("uuid-x")
|
||||
second = make_session("uuid-x")
|
||||
|
||||
assert registry.add(first) is None
|
||||
displaced = registry.add(second)
|
||||
|
||||
assert displaced is first
|
||||
assert registry.get("uuid-x") is second
|
||||
assert len(registry) == 1
|
||||
|
||||
|
||||
def test_registry_thread_safety_smoke():
|
||||
registry = SessionRegistry()
|
||||
errors: list[Exception] = []
|
||||
|
||||
def worker(prefix: str) -> None:
|
||||
try:
|
||||
for i in range(200):
|
||||
session = make_session(f"{prefix}-{i}")
|
||||
registry.add(session)
|
||||
assert registry.get(session.client_uuid) is session
|
||||
assert registry.remove(session.client_uuid) is session
|
||||
except Exception as exc: # noqa: BLE001 — surfaced to the main thread
|
||||
errors.append(exc)
|
||||
|
||||
threads = [threading.Thread(target=worker, args=(p,)) for p in ("alpha", "beta")]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join(timeout=10)
|
||||
assert not t.is_alive()
|
||||
|
||||
assert errors == []
|
||||
assert len(registry) == 0
|
||||
assert registry.snapshot() == []
|
||||
@@ -0,0 +1,323 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for serving-mode classification and session-open validation.
|
||||
|
||||
Uses tiny fake policy classes (deliberately NOT subclassing
|
||||
``PreTrainedPolicy``): classification keys off the ``name`` attribute and
|
||||
the presence of a ``predict_action_chunk`` override, never off the class
|
||||
hierarchy.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.policy_server.schema import (
|
||||
MIN_SUPPORTED_SCHEMA_VERSION,
|
||||
SCHEMA_VERSION,
|
||||
SessionOpenMsg,
|
||||
StatusMsg,
|
||||
)
|
||||
from lerobot.policy_server.validation import (
|
||||
PolicyClassification,
|
||||
ServingClass,
|
||||
classify_policy,
|
||||
resolve_serving_mode,
|
||||
validate_session_open,
|
||||
)
|
||||
from tests.policy_server.conftest import ACTION_NAMES, STATE_DIM, TASK, make_manifest
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fake policy classes (classification only needs `name`, an optional
|
||||
# `predict_action_chunk` method, and `.config` for smolvla)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _fake_policy(name: str, *, chunk_api: bool = True, n_obs_steps: int | None = None):
|
||||
"""Build a minimal fake policy instance with a class-level chunk method."""
|
||||
|
||||
namespace = {"name": name}
|
||||
if chunk_api:
|
||||
namespace["predict_action_chunk"] = lambda self, batch, **kwargs: None
|
||||
cls = type(f"Fake_{name}_Policy", (), namespace)
|
||||
policy = cls()
|
||||
if n_obs_steps is not None:
|
||||
policy.config = SimpleNamespace(n_obs_steps=n_obs_steps)
|
||||
return policy
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# classify_policy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_classify_act_is_shared_without_rtc():
|
||||
classification = classify_policy(_fake_policy("act"))
|
||||
assert classification.serving_class is ServingClass.SHARED
|
||||
assert classification.supports_rtc is False
|
||||
assert classification.needs_queue_population is False
|
||||
|
||||
|
||||
@pytest.mark.parametrize("name", ["pi0", "pi05"])
|
||||
def test_classify_pi_families_are_shared_with_rtc(name):
|
||||
classification = classify_policy(_fake_policy(name))
|
||||
assert classification.serving_class is ServingClass.SHARED
|
||||
assert classification.supports_rtc is True
|
||||
assert classification.needs_queue_population is False
|
||||
|
||||
|
||||
def test_classify_smolvla_single_obs_step_is_shared():
|
||||
classification = classify_policy(_fake_policy("smolvla", n_obs_steps=1))
|
||||
assert classification.serving_class is ServingClass.SHARED
|
||||
assert classification.supports_rtc is True
|
||||
|
||||
|
||||
def test_classify_smolvla_with_history_is_exclusive():
|
||||
classification = classify_policy(_fake_policy("smolvla", n_obs_steps=2))
|
||||
assert classification.serving_class is ServingClass.EXCLUSIVE
|
||||
assert classification.supports_rtc is True
|
||||
assert classification.needs_queue_population is False
|
||||
|
||||
|
||||
def test_classify_diffusion_is_exclusive_with_queue_population():
|
||||
classification = classify_policy(_fake_policy("diffusion"))
|
||||
assert classification.serving_class is ServingClass.EXCLUSIVE
|
||||
assert classification.supports_rtc is False
|
||||
assert classification.needs_queue_population is True
|
||||
|
||||
|
||||
def test_classify_without_chunk_api_is_refused():
|
||||
classification = classify_policy(_fake_policy("act", chunk_api=False))
|
||||
assert classification.serving_class is ServingClass.REFUSED
|
||||
assert classification.supports_rtc is False
|
||||
assert "predict_action_chunk" in classification.reason
|
||||
|
||||
|
||||
def test_classify_unknown_name_with_chunk_api_is_exclusive():
|
||||
classification = classify_policy(_fake_policy("totally_new_policy"))
|
||||
assert classification.serving_class is ServingClass.EXCLUSIVE
|
||||
assert classification.supports_rtc is False
|
||||
assert classification.needs_queue_population is False
|
||||
assert "verified" in classification.reason
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_serving_mode
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _classification(serving_class: ServingClass, reason: str = "test") -> PolicyClassification:
|
||||
return PolicyClassification(
|
||||
serving_class, supports_rtc=False, needs_queue_population=False, reason=reason
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_auto_maps_shared_to_shared():
|
||||
mode, max_sessions = resolve_serving_mode(
|
||||
_classification(ServingClass.SHARED), make_manifest(serving_mode="auto", max_sessions=4)
|
||||
)
|
||||
assert mode == "shared"
|
||||
assert max_sessions == 4
|
||||
|
||||
|
||||
def test_resolve_auto_maps_exclusive_to_exclusive():
|
||||
mode, max_sessions = resolve_serving_mode(
|
||||
_classification(ServingClass.EXCLUSIVE), make_manifest(serving_mode="auto", max_sessions=4)
|
||||
)
|
||||
assert mode == "exclusive"
|
||||
assert max_sessions == 1
|
||||
|
||||
|
||||
def test_resolve_forced_shared_rejected_for_non_verified_policy():
|
||||
with pytest.raises(ValueError, match="unsafe"):
|
||||
resolve_serving_mode(_classification(ServingClass.EXCLUSIVE), make_manifest(serving_mode="shared"))
|
||||
|
||||
|
||||
def test_resolve_forced_shared_allowed_for_verified_policy():
|
||||
mode, max_sessions = resolve_serving_mode(
|
||||
_classification(ServingClass.SHARED), make_manifest(serving_mode="shared", max_sessions=4)
|
||||
)
|
||||
assert mode == "shared"
|
||||
assert max_sessions == 4
|
||||
|
||||
|
||||
def test_resolve_forced_exclusive_allowed_for_shared_policy():
|
||||
mode, _ = resolve_serving_mode(
|
||||
_classification(ServingClass.SHARED), make_manifest(serving_mode="exclusive")
|
||||
)
|
||||
assert mode == "exclusive"
|
||||
|
||||
|
||||
def test_resolve_exclusive_forces_single_session():
|
||||
mode, max_sessions = resolve_serving_mode(
|
||||
_classification(ServingClass.SHARED), make_manifest(serving_mode="exclusive", max_sessions=4)
|
||||
)
|
||||
assert mode == "exclusive"
|
||||
assert max_sessions == 1
|
||||
|
||||
|
||||
def test_resolve_refused_raises_with_reason():
|
||||
with pytest.raises(ValueError, match="no chunk API here"):
|
||||
resolve_serving_mode(
|
||||
_classification(ServingClass.REFUSED, reason="no chunk API here"), make_manifest()
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# validate_session_open
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
EXPECTED_CAMERAS = ["observation.images.front"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def capabilities() -> StatusMsg:
|
||||
return StatusMsg(
|
||||
model_repo="mock/model",
|
||||
policy_type="mockchunk",
|
||||
action_names=list(ACTION_NAMES),
|
||||
expected_cameras=list(EXPECTED_CAMERAS),
|
||||
state_dim=STATE_DIM,
|
||||
chunk_size=20,
|
||||
trained_fps=30.0,
|
||||
supports_rtc=True,
|
||||
min_schema_version=MIN_SUPPORTED_SCHEMA_VERSION,
|
||||
max_schema_version=SCHEMA_VERSION,
|
||||
max_sessions=4,
|
||||
)
|
||||
|
||||
|
||||
def _open_msg(**overrides) -> SessionOpenMsg:
|
||||
kwargs: dict = {
|
||||
"client_uuid": "client-1",
|
||||
"policy_type": "mockchunk",
|
||||
"fps": 30.0,
|
||||
"action_names": list(ACTION_NAMES),
|
||||
"camera_names": list(EXPECTED_CAMERAS),
|
||||
"state_dim": STATE_DIM,
|
||||
"schema_version": SCHEMA_VERSION,
|
||||
"task": TASK,
|
||||
}
|
||||
kwargs.update(overrides)
|
||||
return SessionOpenMsg(**kwargs)
|
||||
|
||||
|
||||
def test_validate_happy_path(capabilities):
|
||||
result = validate_session_open(_open_msg(), capabilities, make_manifest(), active_sessions=0)
|
||||
assert result.ok
|
||||
assert result.error == ""
|
||||
assert result.warnings == []
|
||||
assert result.rtc_downgraded is False
|
||||
|
||||
|
||||
def test_validate_action_name_order_is_a_hard_reject(capabilities):
|
||||
# Same set of names, different order: chunk columns would map to the
|
||||
# wrong motors, so this must be a hard reject.
|
||||
result = validate_session_open(
|
||||
_open_msg(action_names=list(reversed(ACTION_NAMES))),
|
||||
capabilities,
|
||||
make_manifest(),
|
||||
active_sessions=0,
|
||||
)
|
||||
assert not result.ok
|
||||
assert "action" in result.error
|
||||
assert "mismatch" in result.error
|
||||
|
||||
|
||||
def test_validate_missing_camera_rejected(capabilities):
|
||||
result = validate_session_open(
|
||||
_open_msg(camera_names=[]), capabilities, make_manifest(), active_sessions=0
|
||||
)
|
||||
assert not result.ok
|
||||
assert "observation.images.front" in result.error
|
||||
|
||||
|
||||
def test_validate_wrong_state_dim_rejected(capabilities):
|
||||
result = validate_session_open(
|
||||
_open_msg(state_dim=STATE_DIM + 1), capabilities, make_manifest(), active_sessions=0
|
||||
)
|
||||
assert not result.ok
|
||||
assert "state dim" in result.error
|
||||
|
||||
|
||||
def test_validate_schema_version_out_of_range_rejected(capabilities):
|
||||
result = validate_session_open(
|
||||
_open_msg(schema_version=SCHEMA_VERSION + 99),
|
||||
capabilities,
|
||||
make_manifest(),
|
||||
active_sessions=0,
|
||||
)
|
||||
assert not result.ok
|
||||
assert "schema_version" in result.error
|
||||
|
||||
|
||||
def test_validate_at_capacity_rejected_with_load(capabilities):
|
||||
result = validate_session_open(
|
||||
_open_msg(), capabilities, make_manifest(), active_sessions=capabilities.max_sessions
|
||||
)
|
||||
assert not result.ok
|
||||
assert "full" in result.error
|
||||
assert f"{capabilities.max_sessions}/{capabilities.max_sessions}" in result.error
|
||||
|
||||
|
||||
def test_validate_pinned_task_rejects_other_task(capabilities):
|
||||
result = validate_session_open(
|
||||
_open_msg(task="another task"),
|
||||
capabilities,
|
||||
make_manifest(pin_task=True),
|
||||
active_sessions=0,
|
||||
)
|
||||
assert not result.ok
|
||||
assert "pinned" in result.error
|
||||
|
||||
|
||||
def test_validate_fps_mismatch_strict_rejects(capabilities):
|
||||
result = validate_session_open(
|
||||
_open_msg(fps=15.0), capabilities, make_manifest(strict_fps=True), active_sessions=0
|
||||
)
|
||||
assert not result.ok
|
||||
assert "fps" in result.error
|
||||
|
||||
|
||||
def test_validate_fps_mismatch_lenient_warns_only(capabilities):
|
||||
result = validate_session_open(
|
||||
_open_msg(fps=15.0), capabilities, make_manifest(strict_fps=False), active_sessions=0
|
||||
)
|
||||
assert result.ok
|
||||
assert len(result.warnings) == 1
|
||||
assert "fps" in result.warnings[0]
|
||||
|
||||
|
||||
def test_validate_rtc_downgraded_when_unsupported(capabilities):
|
||||
capabilities.supports_rtc = False
|
||||
result = validate_session_open(
|
||||
_open_msg(rtc_enabled=True), capabilities, make_manifest(), active_sessions=0
|
||||
)
|
||||
assert result.ok
|
||||
assert result.rtc_downgraded is True
|
||||
assert any("RTC" in warning for warning in result.warnings)
|
||||
|
||||
|
||||
def test_validate_empty_capability_action_names_skips_action_check(capabilities):
|
||||
capabilities.action_names = []
|
||||
result = validate_session_open(
|
||||
_open_msg(action_names=["whatever.pos"]),
|
||||
capabilities,
|
||||
make_manifest(),
|
||||
active_sessions=0,
|
||||
)
|
||||
assert result.ok
|
||||
@@ -24,6 +24,7 @@ from typing import Any
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from safetensors.torch import load_file
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
@@ -174,6 +175,53 @@ class MockStepWithTensorState(ProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
class MockLazyTensorStateStep(ProcessorStep):
|
||||
"""Mock step whose tensor state is not present in constructor config."""
|
||||
|
||||
def __init__(
|
||||
self, name: str = "lazy_tensor_step", scale: float = 1.0, initial_value: float | None = None
|
||||
):
|
||||
self.name = name
|
||||
self.scale = scale
|
||||
self.tensor_state: torch.Tensor | None = None
|
||||
|
||||
if initial_value is not None:
|
||||
self.tensor_state = torch.tensor([initial_value], dtype=torch.float32)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Return the transition unchanged."""
|
||||
return transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return constructor config while intentionally omitting tensor state."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"scale": self.scale,
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Return tensor state only after it has been initialized or loaded."""
|
||||
if self.tensor_state is None:
|
||||
return {}
|
||||
|
||||
return {"tensor_state": self.tensor_state}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Load tensor state."""
|
||||
self.tensor_state = state["tensor_state"].clone()
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Return features unchanged."""
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("registered_lazy_tensor_state_step")
|
||||
class RegisteredLazyTensorStateStep(MockLazyTensorStateStep):
|
||||
"""Registered lazy tensor state step for registry-based serialization tests."""
|
||||
|
||||
|
||||
def test_empty_pipeline():
|
||||
"""Test pipeline with no steps."""
|
||||
pipeline = DataProcessorPipeline([], to_transition=identity_transition, to_output=identity_transition)
|
||||
@@ -620,6 +668,178 @@ def test_mixed_json_and_tensor_state():
|
||||
assert torch.allclose(loaded_step.running_mean, step.running_mean)
|
||||
|
||||
|
||||
def test_get_config_matches_saved_json():
|
||||
"""Test that in-memory config matches the config written by save_pretrained."""
|
||||
stateless_step = MockStep(name="stateless")
|
||||
stateful_step = MockLazyTensorStateStep(name="stateful", initial_value=4.0)
|
||||
pipeline = DataProcessorPipeline([stateless_step, stateful_step], name="Memory Pipeline")
|
||||
|
||||
in_memory_config = pipeline.get_config()
|
||||
|
||||
assert pipeline.get_config() == in_memory_config
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
config_path = Path(tmp_dir) / "memory_pipeline.json"
|
||||
with open(config_path) as file_pointer:
|
||||
saved_config = json.load(file_pointer)
|
||||
|
||||
assert in_memory_config == saved_config
|
||||
assert "state_file" not in in_memory_config["steps"][0]
|
||||
assert in_memory_config["steps"][1]["state_file"] == "memory_pipeline_step_1.safetensors"
|
||||
|
||||
|
||||
def test_state_dict_matches_saved_safetensors():
|
||||
"""Test that in-memory state matches the safetensors written by save_pretrained."""
|
||||
stateful_step = MockLazyTensorStateStep(initial_value=7.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Stateful Pipeline")
|
||||
|
||||
in_memory_state_dict = pipeline.state_dict()
|
||||
state_filename = "stateful_pipeline_step_0.safetensors"
|
||||
state_key = "stateful_pipeline_step_0"
|
||||
|
||||
assert set(in_memory_state_dict) == {state_key}
|
||||
assert set(in_memory_state_dict[state_key]) == {"tensor_state"}
|
||||
|
||||
in_memory_state_dict[state_key]["tensor_state"].add_(1)
|
||||
assert stateful_step.tensor_state is not None
|
||||
assert torch.equal(stateful_step.tensor_state, torch.tensor([7.0]))
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
saved_state_dict = load_file(Path(tmp_dir) / state_filename)
|
||||
|
||||
torch.testing.assert_close(saved_state_dict["tensor_state"], torch.tensor([7.0]))
|
||||
|
||||
|
||||
def test_save_pretrained_still_writes_expected_serialization_files():
|
||||
"""Test that save_pretrained keeps the existing config and state filenames."""
|
||||
stateful_step = MockLazyTensorStateStep(initial_value=3.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Policy Preprocessor")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
save_path = Path(tmp_dir)
|
||||
assert (save_path / "policy_preprocessor.json").exists()
|
||||
assert (save_path / "policy_preprocessor_step_0.safetensors").exists()
|
||||
|
||||
|
||||
def test_from_config_round_trips_stateful_pipeline():
|
||||
"""Test that from_config rebuilds a stateful pipeline from in-memory artifacts."""
|
||||
stateful_step = MockLazyTensorStateStep(name="roundtrip", initial_value=11.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Roundtrip Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert len(loaded_pipeline) == 1
|
||||
assert isinstance(loaded_step, MockLazyTensorStateStep)
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([11.0]))
|
||||
|
||||
|
||||
def test_from_config_round_trips_registered_stateful_pipeline():
|
||||
"""Test that from_config resolves registry steps and loads their named tensor state."""
|
||||
stateful_step = RegisteredLazyTensorStateStep(name="registered", initial_value=29.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Registry Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
state_filename = "registry_pipeline_step_0_registered_lazy_tensor_state_step.safetensors"
|
||||
state_key = "registry_pipeline_step_0_registered_lazy_tensor_state_step"
|
||||
|
||||
assert config["steps"][0]["registry_name"] == "registered_lazy_tensor_state_step"
|
||||
assert config["steps"][0]["state_file"] == state_filename
|
||||
assert set(pipeline_state_dict) == {state_key}
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert isinstance(loaded_step, RegisteredLazyTensorStateStep)
|
||||
assert loaded_step.tensor_state is not None
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([29.0]))
|
||||
|
||||
|
||||
def test_from_config_preserves_state_metadata_for_empty_initial_state():
|
||||
"""Test in-memory loading when rebuilt steps start without tensor state."""
|
||||
stateful_step = MockLazyTensorStateStep(name="lazy", initial_value=13.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Lazy Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert isinstance(loaded_step, MockLazyTensorStateStep)
|
||||
assert loaded_step.state_dict() == {}
|
||||
assert "state_file" not in loaded_pipeline.get_config()["steps"][0]
|
||||
|
||||
loaded_pipeline.load_state_dict(pipeline_state_dict)
|
||||
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([13.0]))
|
||||
|
||||
|
||||
def test_from_config_applies_overrides_before_state_loading():
|
||||
"""Test that constructor overrides and tensor state loading are separate operations."""
|
||||
stateful_step = MockLazyTensorStateStep(name="override", scale=1.0, initial_value=17.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Override Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(
|
||||
config,
|
||||
state_dict=pipeline_state_dict,
|
||||
overrides={"MockLazyTensorStateStep": {"scale": 5.0}},
|
||||
)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert isinstance(loaded_step, MockLazyTensorStateStep)
|
||||
assert loaded_step.scale == 5.0
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([17.0]))
|
||||
|
||||
|
||||
def test_load_state_dict_raises_on_missing_expected_state():
|
||||
"""Test loading raises when serialized config expects missing state."""
|
||||
stateful_step = MockLazyTensorStateStep(initial_value=19.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Missing Pipeline")
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(pipeline.get_config())
|
||||
|
||||
with pytest.raises(KeyError, match="missing_pipeline_step_0"):
|
||||
loaded_pipeline.load_state_dict({})
|
||||
|
||||
|
||||
def test_load_state_dict_raises_on_unexpected_extra_state():
|
||||
"""Test loading raises on unexpected top-level state keys."""
|
||||
pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Unexpected Pipeline")
|
||||
|
||||
with pytest.raises(KeyError, match="extra"):
|
||||
pipeline.load_state_dict({"extra": {"tensor_state": torch.tensor([1.0])}})
|
||||
|
||||
|
||||
def test_stateless_pipeline_in_memory_serialization_returns_empty_state():
|
||||
"""Test stateless in-memory serialization and loading."""
|
||||
pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Stateless Pipeline")
|
||||
config = pipeline.get_config()
|
||||
config_without_name = {"steps": config["steps"]}
|
||||
|
||||
assert pipeline.state_dict() == {}
|
||||
assert all("state_file" not in step_entry for step_entry in config["steps"])
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config_without_name, state_dict={})
|
||||
|
||||
assert loaded_pipeline.name == "DataProcessorPipeline"
|
||||
assert loaded_pipeline.state_dict() == {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("invalid_config", [None, [], "not config"])
|
||||
def test_from_config_rejects_non_dict_config(invalid_config):
|
||||
"""Test from_config reports invalid top-level config values cleanly."""
|
||||
with pytest.raises(ValueError, match="not a valid processor configuration"):
|
||||
DataProcessorPipeline.from_config(invalid_config) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class MockModuleStep(ProcessorStep, nn.Module):
|
||||
"""Mock step that inherits from nn.Module to test state_dict handling of module parameters."""
|
||||
|
||||
|
||||
@@ -59,6 +59,7 @@ def test_strategy_config_types():
|
||||
from lerobot.rollout import (
|
||||
BaseStrategyConfig,
|
||||
DAggerStrategyConfig,
|
||||
EpisodicStrategyConfig,
|
||||
HighlightStrategyConfig,
|
||||
SentryStrategyConfig,
|
||||
)
|
||||
@@ -67,6 +68,7 @@ def test_strategy_config_types():
|
||||
assert SentryStrategyConfig().type == "sentry"
|
||||
assert HighlightStrategyConfig().type == "highlight"
|
||||
assert DAggerStrategyConfig().type == "dagger"
|
||||
assert EpisodicStrategyConfig().type == "episodic"
|
||||
|
||||
|
||||
def test_dagger_config_invalid_input_device():
|
||||
@@ -203,6 +205,8 @@ def test_create_strategy_dispatches():
|
||||
BaseStrategyConfig,
|
||||
DAggerStrategy,
|
||||
DAggerStrategyConfig,
|
||||
EpisodicStrategy,
|
||||
EpisodicStrategyConfig,
|
||||
SentryStrategy,
|
||||
SentryStrategyConfig,
|
||||
create_strategy,
|
||||
@@ -211,6 +215,7 @@ def test_create_strategy_dispatches():
|
||||
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
|
||||
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
|
||||
assert isinstance(create_strategy(DAggerStrategyConfig()), DAggerStrategy)
|
||||
assert isinstance(create_strategy(EpisodicStrategyConfig()), EpisodicStrategy)
|
||||
|
||||
|
||||
def test_create_strategy_unknown_raises():
|
||||
@@ -343,3 +348,70 @@ def test_rollout_context_fields():
|
||||
|
||||
field_names = {f.name for f in dataclasses.fields(RolloutContext)}
|
||||
assert field_names == {"runtime", "hardware", "policy", "processors", "data"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Remote inference config & factory dispatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_inference_engine_remote_requires_policy_config():
|
||||
from lerobot.rollout.inference.factory import RemoteInferenceConfig, create_inference_engine
|
||||
|
||||
with pytest.raises(ValueError, match="policy_config"):
|
||||
create_inference_engine(
|
||||
RemoteInferenceConfig(),
|
||||
policy=None,
|
||||
preprocessor=None,
|
||||
postprocessor=None,
|
||||
robot_wrapper=MagicMock(robot_type="mock"),
|
||||
hw_features={},
|
||||
dataset_features={},
|
||||
ordered_action_keys=["k"],
|
||||
task="t",
|
||||
fps=30.0,
|
||||
device=None,
|
||||
policy_config=None,
|
||||
)
|
||||
|
||||
|
||||
def test_remote_config_draccus_registration():
|
||||
from lerobot.rollout.inference.factory import InferenceEngineConfig, RemoteInferenceConfig
|
||||
|
||||
assert RemoteInferenceConfig().type == "remote"
|
||||
assert InferenceEngineConfig.get_choice_class("remote") is RemoteInferenceConfig
|
||||
assert "remote" in dict(InferenceEngineConfig.get_known_choices())
|
||||
|
||||
|
||||
def test_fallback_mode_values():
|
||||
from lerobot.rollout.inference.factory import FallbackMode
|
||||
|
||||
assert FallbackMode.HOLD.value == "hold"
|
||||
assert FallbackMode.REPEAT_LAST.value == "repeat_last"
|
||||
assert FallbackMode.ZERO.value == "zero"
|
||||
assert {mode.value for mode in FallbackMode} == {"hold", "repeat_last", "zero"}
|
||||
|
||||
|
||||
def test_local_backends_require_loaded_policy():
|
||||
from lerobot.rollout.inference.factory import (
|
||||
RTCInferenceConfig,
|
||||
SyncInferenceConfig,
|
||||
create_inference_engine,
|
||||
)
|
||||
|
||||
common = {
|
||||
"policy": None,
|
||||
"preprocessor": None,
|
||||
"postprocessor": None,
|
||||
"robot_wrapper": MagicMock(robot_type="mock"),
|
||||
"hw_features": {},
|
||||
"dataset_features": {},
|
||||
"ordered_action_keys": ["k"],
|
||||
"task": "t",
|
||||
"fps": 30.0,
|
||||
"device": "cpu",
|
||||
}
|
||||
with pytest.raises(ValueError, match="requires a loaded policy"):
|
||||
create_inference_engine(SyncInferenceConfig(), **common)
|
||||
with pytest.raises(ValueError, match="requires a loaded policy"):
|
||||
create_inference_engine(RTCInferenceConfig(), **common)
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
|
||||
@@ -25,8 +26,16 @@ def mock_metrics():
|
||||
|
||||
|
||||
class MockAccelerator:
|
||||
def __init__(self, num_processes: int):
|
||||
def __init__(self, num_processes: int, reduce_fn=None):
|
||||
self.num_processes = num_processes
|
||||
self.device = torch.device("cpu")
|
||||
self._reduce_fn = reduce_fn
|
||||
|
||||
def reduce(self, tensor, reduction="mean"):
|
||||
# In single-process tests we just want a deterministic stand-in for accelerate's reduce.
|
||||
if self._reduce_fn is not None:
|
||||
return self._reduce_fn(tensor, reduction)
|
||||
return tensor
|
||||
|
||||
|
||||
def test_average_meter_initialization():
|
||||
@@ -157,3 +166,70 @@ def test_metrics_tracker_reset_averages(mock_metrics):
|
||||
tracker.reset_averages()
|
||||
assert tracker.loss.avg == 0.0
|
||||
assert tracker.accuracy.avg == 0.0
|
||||
|
||||
|
||||
def test_average_meter_invalid_reduction():
|
||||
with pytest.raises(ValueError):
|
||||
AverageMeter("loss", reduction="median")
|
||||
|
||||
|
||||
def test_average_meter_reduction_stored():
|
||||
meter = AverageMeter("updt_s", reduction="max")
|
||||
assert meter.reduction == "max"
|
||||
|
||||
|
||||
def test_metrics_tracker_reduce_across_ranks_no_accelerator():
|
||||
metrics = {"update_s": AverageMeter("update_s", reduction="max")}
|
||||
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=metrics)
|
||||
tracker.update_s = 0.5
|
||||
tracker.reduce_across_ranks() # no-op without accelerator
|
||||
assert tracker.update_s.avg == 0.5
|
||||
|
||||
|
||||
def test_metrics_tracker_reduce_across_ranks_single_process():
|
||||
metrics = {"update_s": AverageMeter("update_s", reduction="max")}
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32,
|
||||
num_frames=1000,
|
||||
num_episodes=50,
|
||||
metrics=metrics,
|
||||
accelerator=MockAccelerator(num_processes=1),
|
||||
)
|
||||
tracker.update_s = 0.5
|
||||
tracker.reduce_across_ranks() # no-op when world size is 1
|
||||
assert tracker.update_s.avg == 0.5
|
||||
|
||||
|
||||
def test_metrics_tracker_reduce_across_ranks_invokes_reduce():
|
||||
captured = {}
|
||||
|
||||
def fake_reduce(tensor, reduction):
|
||||
captured["reduction"] = reduction
|
||||
captured["values"] = tensor.clone()
|
||||
# Pretend the slowest rank reported 0.9 instead of this rank's 0.4.
|
||||
return torch.tensor([0.9], dtype=tensor.dtype, device=tensor.device)
|
||||
|
||||
metrics = {
|
||||
"loss": AverageMeter("loss"), # reduction="none" -> not touched
|
||||
"update_s": AverageMeter("update_s", reduction="max"),
|
||||
}
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32,
|
||||
num_frames=1000,
|
||||
num_episodes=50,
|
||||
metrics=metrics,
|
||||
accelerator=MockAccelerator(num_processes=4, reduce_fn=fake_reduce),
|
||||
)
|
||||
tracker.loss = 1.0
|
||||
tracker.update_s = 0.4
|
||||
tracker.reduce_across_ranks()
|
||||
|
||||
assert captured["reduction"] == "max"
|
||||
assert torch.allclose(captured["values"], torch.tensor([0.4]))
|
||||
assert tracker.update_s.avg == pytest.approx(0.9)
|
||||
# Metrics without a reduction stay untouched.
|
||||
assert tracker.loss.avg == 1.0
|
||||
# Invariant: avg == sum / count must hold after reduce, so subsequent .update() calls
|
||||
# accumulate against the cluster view rather than the stale per-rank sum.
|
||||
meter = tracker.update_s
|
||||
assert meter.sum / meter.count == pytest.approx(meter.avg)
|
||||
|
||||
@@ -59,7 +59,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "accelerate"
|
||||
version = "1.13.0"
|
||||
version = "1.14.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "huggingface-hub" },
|
||||
@@ -71,9 +71,9 @@ dependencies = [
|
||||
{ name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux'" },
|
||||
{ name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ca/14/787e5498cd062640f0f3d92ef4ae4063174f76f9afd29d13fc52a319daae/accelerate-1.13.0.tar.gz", hash = "sha256:d631b4e0f5b3de4aff2d7e9e6857d164810dfc3237d54d017f075122d057b236", size = 402835, upload-time = "2026-03-04T19:34:12.359Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/8d/75/94cd5d389649578aca399e5aa822637eec18319a1dadc400ffe2f9a7493f/accelerate-1.14.0.tar.gz", hash = "sha256:41b9c4377a54e0b460a959b0defa1b736e4ca0a2373252d9a539964c2afe3c8d", size = 412167, upload-time = "2026-06-11T13:45:52.326Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7e/46/02ac5e262d4af18054b3e922b2baedbb2a03289ee792162de60a865defc5/accelerate-1.13.0-py3-none-any.whl", hash = "sha256:cf1a3efb96c18f7b152eb0fa7490f3710b19c3f395699358f08decca2b8b62e0", size = 383744, upload-time = "2026-03-04T19:34:10.313Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a8/db/253133d7e7cb40d3af384bb2f5c0b4a2b7fdcffbc95c688cc67a20a3c103/accelerate-1.14.0-py3-none-any.whl", hash = "sha256:e94390c2863b873be18f623f9df48a0d8fe5eff13ea7f1a00092b0a7904888c6", size = 389246, upload-time = "2026-06-11T13:45:50.477Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1315,6 +1315,22 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/05/ec/fa6963f1198172c2b75c9ab6ecefb3045991f92f75f5eb41b6621b198123/easydict-1.13-py3-none-any.whl", hash = "sha256:6b787daf4dcaf6377b4ad9403a5cee5a86adbc0ca9a5bcf5410e9902002aeac2", size = 6804, upload-time = "2024-03-04T12:04:39.508Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "eclipse-zenoh"
|
||||
version = "1.9.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d9/42/c8502d0e77f74b9cf4c192a01e620b3d15273d371464485796807d202d9d/eclipse_zenoh-1.9.0.tar.gz", hash = "sha256:b0477ab431132ebfe1096eccac13ea0066d50d1528d726c8872c00e0345070d1", size = 164557, upload-time = "2026-04-10T13:23:35.883Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e7/3b/22b9104b0a022bd2b1627b4866876831585eda2eacb9ca1f3b4b8e847945/eclipse_zenoh-1.9.0-cp39-abi3-linux_armv6l.whl", hash = "sha256:15b6f37c407617ea4de32d32835cbcab4d1a116b892477490fc6c10a7d27c73b", size = 10664168, upload-time = "2026-04-10T13:23:15.008Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/05/c5/ee0815c7ec49c5a29307cd935478305159bb3f0b2489f8c54fc6db3fdf36/eclipse_zenoh-1.9.0-cp39-abi3-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:6f66059b12e1ec53c70bc25192b0e74502751759064726dbb153ed6dd8f4dc8b", size = 19942168, upload-time = "2026-04-10T13:23:17.785Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7b/6a/42b83b4e8c262ebbb3bcae702394478326c807f54b3162130b0a603e1a01/eclipse_zenoh-1.9.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:180dd2a6da3b86b52e87f5e470a1f8a86db03c519978b22ffb1dc7c11f98ef3b", size = 10225694, upload-time = "2026-04-10T13:23:20.244Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/27/57/28e66893801b63df36fea355a64b6fc22637e1148a952ee11e3039ae955e/eclipse_zenoh-1.9.0-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:949d82851bc9e3ad646fd1307ee544ed23359dcfd18d4065075fc592f6ab6fa7", size = 10517069, upload-time = "2026-04-10T13:23:23.053Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f0/2f/be614f1f7f4e046da2764cd36227d19db3655839219744ce7a12e6e2dae6/eclipse_zenoh-1.9.0-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a1fe847225cda21e3e74677cfd4ddfd2e72600d5a56968d4229d981c67f78d4", size = 11580068, upload-time = "2026-04-10T13:23:25.594Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/58/1b/2a074d4f4595bd37c3d12f1b2ad49bceef5c8cd0962cbfd97d1d39f32e1f/eclipse_zenoh-1.9.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:43299593891cfd648bca4b2aa00f3dca916508a49a0c9e6960902e6e867b247e", size = 10537556, upload-time = "2026-04-10T13:23:28.414Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/33/c3116f1bf7647ee0ea8972efbe0fe5710ae75ea7226440a8fda7f04a4cbc/eclipse_zenoh-1.9.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:8c139a43706c8ff3c94fa625008af8667687c161a8395ad1fa3faff29c16fae4", size = 10721249, upload-time = "2026-04-10T13:23:30.843Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/26/16/a94c4f37e3a088faadf4b5fbc64e5f69dea1023dc7efc49b3be0e0ecc953/eclipse_zenoh-1.9.0-cp39-abi3-win_amd64.whl", hash = "sha256:5dfb352eca4585b85edbbc84c6db58906008e202823ca280496c0b867f9719f0", size = 9124510, upload-time = "2026-04-10T13:23:34.119Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "egl-probe"
|
||||
version = "1.0.2"
|
||||
@@ -1764,7 +1780,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "gym-aloha"
|
||||
version = "0.1.3"
|
||||
version = "0.1.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "dm-control" },
|
||||
@@ -1772,14 +1788,14 @@ dependencies = [
|
||||
{ name = "imageio", extra = ["ffmpeg"] },
|
||||
{ name = "mujoco" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b5/5e/4bb7204730501c2f645e0532a2df4339206948b2882f77cbf0eaf75bc5fe/gym_aloha-0.1.3.tar.gz", hash = "sha256:b794b246a2e6da6ce5f75e152f553fbd4412704bc217fe6311d0ede3bb72a75e", size = 443468, upload-time = "2025-10-09T14:02:35.024Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/4a/c5/a5b8bdbddfcadec0b52b50e6d1a70325e09e6b594e5f55929d67d9122e2c/gym_aloha-0.1.4.tar.gz", hash = "sha256:0dc4e645045aeb3e74e3c320872d28df6dc93a8751d6ab2f266a2ca11323131f", size = 443466, upload-time = "2026-06-10T09:13:25.525Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/57/6c/10da397177c48ce360efa66ec21b10b10ef5fa2766256fcd8d7d9b5fa6fc/gym_aloha-0.1.3-py3-none-any.whl", hash = "sha256:a94e5747e71307897ded7ae17ed97fab05e814dcb714a16d320f110444f9d0c3", size = 447908, upload-time = "2025-10-09T14:02:33.253Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/35/e3/3afd0e517a503aabe255bf65f5136490acb79c43189e8d56a3aa63081a10/gym_aloha-0.1.4-py3-none-any.whl", hash = "sha256:d9044290fbccddf0be4246b5287cf0eb6b9ddee545a3d222ce8d78c93ce7125e", size = 447908, upload-time = "2026-06-10T09:13:23.868Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gym-hil"
|
||||
version = "0.1.13"
|
||||
version = "0.1.14"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "gymnasium" },
|
||||
@@ -1789,9 +1805,9 @@ dependencies = [
|
||||
{ name = "pygame" },
|
||||
{ name = "pynput" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f3/41/e89c87b3c66fb2f8ab5818bff4aa552977911eabaee7c12a8a336dcc406f/gym_hil-0.1.13.tar.gz", hash = "sha256:b9eab7a0acc811f181254e3ad72865830fdbb292c236895f374135d3d62f1b27", size = 5668001, upload-time = "2025-10-21T09:57:24.01Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0c/64/b5cfe59d6a69d20497218f01ad2bdaa2a5a72b850bdb1a445d804ecc9948/gym_hil-0.1.14.tar.gz", hash = "sha256:aeee688dcb3ec72e7bcbe604df4a3f990cce49c8a2da469dd67c3a4eeb4c6bbb", size = 5667991, upload-time = "2026-06-10T09:16:38.98Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c2/8d/9e3ab53f9aac7bd542f339efd0a9283fa76e034474987e0705379274dfcf/gym_hil-0.1.13-py3-none-any.whl", hash = "sha256:b6444fc43ce1a68ce403df14f99100d9c903ae05d822959e9cd0b76a50b93320", size = 5750805, upload-time = "2025-10-21T09:57:22.068Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/72/97/a7a9c3886306a89046ba5c989bc8b79008e7ec973228bad1fa20d7a94bba/gym_hil-0.1.14-py3-none-any.whl", hash = "sha256:9a2799d47a4561e0b0bb8d37fb3d84934657240be328d13991ea06758726533d", size = 5750805, upload-time = "2026-06-10T09:16:36.827Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1881,7 +1897,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/e6/3e/ffad88145b342d5a9
|
||||
|
||||
[[package]]
|
||||
name = "hf-libero"
|
||||
version = "0.1.3"
|
||||
version = "0.1.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "bddl", marker = "sys_platform == 'linux'" },
|
||||
@@ -1902,7 +1918,10 @@ dependencies = [
|
||||
{ name = "transformers", marker = "sys_platform == 'linux'" },
|
||||
{ name = "wandb", marker = "sys_platform == 'linux'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/7e/ca/7f1c90aedcd067d608681cf03469ae548990ba0806f68a67927dcc801f04/hf_libero-0.1.3.tar.gz", hash = "sha256:0d6b9a215a658db86f66c03d063d6d877d2e9f96d2d326cfa9f43ba4da4a6d5a", size = 2960521, upload-time = "2025-11-03T17:58:00.003Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/af/aa/4e9eb8715e0bff9cb6553db563a35d253393097d446f82bd53575e8b253d/hf_libero-0.1.4.tar.gz", hash = "sha256:c058d67ad5a2b589529c14d614282ef4cca3a7763dafa134f58a6c9039657e34", size = 2961319, upload-time = "2026-06-10T09:56:13.994Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/79/c286b894c051988d062241682834df915c945bcf51009ffdffbe5ecf69bf/hf_libero-0.1.4-py3-none-any.whl", hash = "sha256:207f76e2f28bff30f78132223d8592fe8f64b1f8fd90ce7024948ada0d7e2c27", size = 3169084, upload-time = "2026-06-10T09:56:12.441Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hf-xet"
|
||||
@@ -2684,6 +2703,9 @@ dependencies = [
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
accelerate-dep = [
|
||||
{ name = "accelerate" },
|
||||
]
|
||||
all = [
|
||||
{ name = "accelerate" },
|
||||
{ name = "av" },
|
||||
@@ -2693,6 +2715,7 @@ all = [
|
||||
{ name = "deepdiff" },
|
||||
{ name = "diffusers" },
|
||||
{ name = "dynamixel-sdk" },
|
||||
{ name = "eclipse-zenoh" },
|
||||
{ name = "faker" },
|
||||
{ name = "fastapi" },
|
||||
{ name = "feetech-servo-sdk" },
|
||||
@@ -2712,6 +2735,7 @@ all = [
|
||||
{ name = "mock-serial", marker = "sys_platform != 'win32'" },
|
||||
{ name = "motorbridge" },
|
||||
{ name = "motorbridge-smart-servo" },
|
||||
{ name = "msgpack" },
|
||||
{ name = "mypy" },
|
||||
{ name = "num2words" },
|
||||
{ name = "pandas" },
|
||||
@@ -2755,10 +2779,8 @@ aloha = [
|
||||
{ name = "torchcodec", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'win32'" },
|
||||
]
|
||||
async = [
|
||||
{ name = "contourpy" },
|
||||
{ name = "grpcio" },
|
||||
{ name = "matplotlib" },
|
||||
{ name = "protobuf" },
|
||||
{ name = "eclipse-zenoh" },
|
||||
{ name = "msgpack" },
|
||||
]
|
||||
av-dep = [
|
||||
{ name = "av" },
|
||||
@@ -2983,6 +3005,8 @@ qwen-vl-utils-dep = [
|
||||
{ name = "qwen-vl-utils" },
|
||||
]
|
||||
reachy2 = [
|
||||
{ name = "grpcio" },
|
||||
{ name = "protobuf" },
|
||||
{ name = "reachy2-sdk" },
|
||||
]
|
||||
rebot = [
|
||||
@@ -3070,8 +3094,7 @@ xvla = [
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "accelerate", marker = "extra == 'smolvla'", specifier = ">=1.7.0,<2.0.0" },
|
||||
{ name = "accelerate", marker = "extra == 'training'", specifier = ">=1.10.0,<2.0.0" },
|
||||
{ name = "accelerate", marker = "extra == 'accelerate-dep'", specifier = ">=1.14.0,<2.0.0" },
|
||||
{ name = "av", marker = "extra == 'av-dep'", specifier = ">=15.0.0,<16.0.0" },
|
||||
{ name = "cmake", specifier = ">=3.29.0.1,<4.2.0" },
|
||||
{ name = "contourpy", marker = "extra == 'matplotlib-dep'", specifier = ">=1.3.0,<2.0.0" },
|
||||
@@ -3083,24 +3106,28 @@ requires-dist = [
|
||||
{ name = "dm-tree", marker = "extra == 'groot'", specifier = ">=0.1.8,<1.0.0" },
|
||||
{ name = "draccus", specifier = "==0.10.0" },
|
||||
{ name = "dynamixel-sdk", marker = "extra == 'dynamixel'", specifier = ">=3.7.31,<3.9.0" },
|
||||
{ name = "eclipse-zenoh", marker = "extra == 'async'", specifier = ">=1.9,<2.0" },
|
||||
{ name = "einops", specifier = ">=0.8.0,<0.9.0" },
|
||||
{ name = "faker", marker = "extra == 'sarm'", specifier = ">=33.0.0,<35.0.0" },
|
||||
{ name = "fastapi", marker = "extra == 'phone'", specifier = "<1.0" },
|
||||
{ name = "feetech-servo-sdk", marker = "extra == 'feetech'", specifier = ">=1.0.0,<2.0.0" },
|
||||
{ name = "flash-attn", marker = "sys_platform != 'darwin' and extra == 'groot'", specifier = ">=2.5.9,<3.0.0" },
|
||||
{ name = "grpcio", marker = "extra == 'grpcio-dep'", specifier = "==1.73.1" },
|
||||
{ name = "grpcio-tools", marker = "extra == 'dev'", specifier = "==1.73.1" },
|
||||
{ name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.2,<0.2.0" },
|
||||
{ name = "gym-hil", marker = "extra == 'hilserl'", specifier = ">=0.1.13,<0.2.0" },
|
||||
{ name = "grpcio", marker = "extra == 'grpcio-dep'", specifier = ">=1.73.1,<2.0.0" },
|
||||
{ name = "grpcio", marker = "extra == 'reachy2'", specifier = "<=1.73.1" },
|
||||
{ name = "grpcio-tools", marker = "extra == 'dev'", specifier = ">=1.73.1,<2.0.0" },
|
||||
{ name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.4,<0.2.0" },
|
||||
{ name = "gym-hil", marker = "extra == 'hilserl'", specifier = ">=0.1.14,<0.2.0" },
|
||||
{ name = "gym-pusht", marker = "extra == 'pusht'", specifier = ">=0.1.5,<0.2.0" },
|
||||
{ name = "gymnasium", specifier = ">=1.1.1,<2.0.0" },
|
||||
{ name = "hebi-py", marker = "extra == 'phone'", specifier = ">=2.8.0,<2.12.0" },
|
||||
{ name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.3,<0.2.0" },
|
||||
{ name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.4,<0.2.0" },
|
||||
{ name = "hidapi", marker = "extra == 'gamepad'", specifier = ">=0.14.0,<0.15.0" },
|
||||
{ name = "huggingface-hub", specifier = ">=1.0.0,<2.0.0" },
|
||||
{ name = "ipykernel", marker = "extra == 'notebook'", specifier = ">=6.0.0,<7.0.0" },
|
||||
{ name = "jsonlines", marker = "extra == 'dataset'", specifier = ">=4.0.0,<5.0.0" },
|
||||
{ name = "jupyter", marker = "extra == 'notebook'", specifier = ">=1.0.0,<2.0.0" },
|
||||
{ name = "lerobot", extras = ["accelerate-dep"], marker = "extra == 'smolvla'" },
|
||||
{ name = "lerobot", extras = ["accelerate-dep"], marker = "extra == 'training'" },
|
||||
{ name = "lerobot", extras = ["aloha"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["async"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["av-dep"], marker = "extra == 'dataset'" },
|
||||
@@ -3132,7 +3159,6 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'hopejr'" },
|
||||
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'lekiwi'" },
|
||||
{ name = "lerobot", extras = ["gamepad"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["grpcio-dep"], marker = "extra == 'async'" },
|
||||
{ name = "lerobot", extras = ["grpcio-dep"], marker = "extra == 'dev'" },
|
||||
{ name = "lerobot", extras = ["grpcio-dep"], marker = "extra == 'hilserl'" },
|
||||
{ name = "lerobot", extras = ["hardware"], marker = "extra == 'all'" },
|
||||
@@ -3143,7 +3169,6 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["kinematics"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["lekiwi"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["libero"], marker = "sys_platform == 'linux' and extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["matplotlib-dep"], marker = "extra == 'async'" },
|
||||
{ name = "lerobot", extras = ["matplotlib-dep"], marker = "extra == 'sarm'" },
|
||||
{ name = "lerobot", extras = ["matplotlib-dep"], marker = "extra == 'unitree-g1'" },
|
||||
{ name = "lerobot", extras = ["metaworld"], marker = "extra == 'all'" },
|
||||
@@ -3223,6 +3248,7 @@ requires-dist = [
|
||||
{ name = "mock-serial", marker = "sys_platform != 'win32' and extra == 'test'", specifier = ">=0.0.1,<0.1.0" },
|
||||
{ name = "motorbridge", marker = "extra == 'motorbridge-dep'", specifier = ">=0.3.2,<0.4.0" },
|
||||
{ name = "motorbridge-smart-servo", marker = "extra == 'motorbridge-smart-servo-dep'", specifier = ">=0.0.4,<0.1.0" },
|
||||
{ name = "msgpack", marker = "extra == 'async'", specifier = ">=1.0.0,<2.0.0" },
|
||||
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.19.1" },
|
||||
{ name = "ninja", marker = "extra == 'groot'", specifier = ">=1.11.1,<2.0.0" },
|
||||
{ name = "num2words", marker = "extra == 'smolvla'", specifier = ">=0.5.14,<0.6.0" },
|
||||
@@ -3237,7 +3263,8 @@ requires-dist = [
|
||||
{ name = "pillow", specifier = ">=10.0.0,<13.0.0" },
|
||||
{ name = "placo", marker = "extra == 'placo-dep'", specifier = ">=0.9.6,<0.9.16" },
|
||||
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.7.0,<5.0.0" },
|
||||
{ name = "protobuf", marker = "extra == 'grpcio-dep'", specifier = ">=6.31.1,<6.32.0" },
|
||||
{ name = "protobuf", marker = "extra == 'grpcio-dep'", specifier = ">=6.31.1,<8.0.0" },
|
||||
{ name = "protobuf", marker = "extra == 'reachy2'", specifier = "<=6.32.0" },
|
||||
{ name = "pyarrow", marker = "extra == 'dataset'", specifier = ">=21.0.0,<30.0.0" },
|
||||
{ name = "pydantic", marker = "extra == 'sarm'", specifier = ">=2.0.0,<3.0.0" },
|
||||
{ name = "pygame", marker = "extra == 'pygame-dep'", specifier = ">=2.5.1,<2.7.0" },
|
||||
@@ -3274,9 +3301,9 @@ requires-dist = [
|
||||
{ name = "torchvision", marker = "sys_platform == 'linux'", specifier = ">=0.22.0,<0.27.0", index = "https://download.pytorch.org/whl/cu128" },
|
||||
{ name = "tqdm", specifier = ">=4.66.0,<5.0.0" },
|
||||
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" },
|
||||
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
|
||||
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.28.0" },
|
||||
]
|
||||
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
||||
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "accelerate-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
||||
|
||||
[[package]]
|
||||
name = "librt"
|
||||
@@ -3740,6 +3767,58 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198, upload-time = "2023-03-07T16:47:09.197Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "msgpack"
|
||||
version = "1.2.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/92/23/6139781ca7aadf656fa8e384fa84693ffb13f299e6931b6526427fe5e297/msgpack-1.2.0.tar.gz", hash = "sha256:8e17af38197bf58e7e819041678f6178f4491493f5b8c8580414f40f7c2c3c41", size = 183017, upload-time = "2026-06-11T04:16:10.775Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/44/07/dcb13f37e670257c8d0e944f116c799c34ac6968ecb48c83619f7e91d8b5/msgpack-1.2.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:e2d6047ccd11a12c96a69f2bfe026471abef67334c3d0494a93e5310e45140a2", size = 82888, upload-time = "2026-06-11T04:15:08.992Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/84/5f/6643b2a6a36ca4bc73c7674831be1d4d581cceecc7eb019dba1915951739/msgpack-1.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0347e3ac0dfee99086d3b68fe959da3f5f657c0019ddbaeaaa259a85f8603422", size = 82223, upload-time = "2026-06-11T04:15:10.182Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2c/c8/9e1668b9897358e5ab39a18142e38be3cf15807e643757782da9f4a53cb3/msgpack-1.2.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:25552ff1f2ff3dc8333e27eabb94f702da5929ed0e07969688194a3e9f12e151", size = 409700, upload-time = "2026-06-11T04:15:11.441Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/38/ed/b7728573156d70b6b094233b0f38d876fc37340826cf852347ec2c7ca8ca/msgpack-1.2.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a0d94420d9d52c56568159a69200af7e45eadb29615fa9d09fada140de1c38c7", size = 420090, upload-time = "2026-06-11T04:15:12.868Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3f/f7/5ea755a89868c04f9cdf6d96d2d99da4b3d198af10e76a6082dd0fceccc0/msgpack-1.2.0-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:d16e1f2db4a9eebc07b7cc91898d71e710f2eed8358711a605fee802caff8923", size = 378538, upload-time = "2026-06-11T04:15:14.511Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/80/2d/126e59332a439c94ffd682c38ca0102b23480e2784b3dac48d8959b0bbac/msgpack-1.2.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e9cb2e700e85f1e27bbb5c9de6cc1c9a4bc5ac64d5404bdcbcb37a0dc7a947a3", size = 399468, upload-time = "2026-06-11T04:15:16.133Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/da/f9/7abcef683a0ad2e5ab3a4940344aad9f20cdf1f42057ecb0982cf55085d6/msgpack-1.2.0-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:717d0b166dd176a5f786aeafff081f6439680acf5af193eb63e6266c12b04d3d", size = 374212, upload-time = "2026-06-11T04:15:17.536Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/27/23/2d62cf0e971678e96f8a3cfa9bd77fb719ddb98da73790f63c53fd847ad8/msgpack-1.2.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e87c7a21654d18111eb1a89bd5c42baba42e61887365d9e89585e112b4203f9e", size = 414361, upload-time = "2026-06-11T04:15:18.99Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/32/fb/f5c153f614037aaf802d291a4653ba1bb731f56feacba886f7c21c109e56/msgpack-1.2.0-cp312-cp312-win32.whl", hash = "sha256:967e0c891f5f23ab65762f2e5dc95922759c79f1ef99ef4c7e1fdd863e0d0af9", size = 64389, upload-time = "2026-06-11T04:15:20.237Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/90/af/8aafce6e5544b43b84cb670aca40c8bea7eb5ae8f42bfcbdc7098739987a/msgpack-1.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:6c23e33cee28dcffa112ae205661da4636fd7b06bd9ad1559a890623b92d060b", size = 71185, upload-time = "2026-06-11T04:15:21.51Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ba/08/9cc94be1fc1fe3d1379d439326259aef0344274f64623a8138feb54dff68/msgpack-1.2.0-cp312-cp312-win_arm64.whl", hash = "sha256:6eeb771571f63f68045433b1a35c0256b946f31ed62f006997e40b8ad8b735af", size = 64481, upload-time = "2026-06-11T04:15:22.639Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7d/26/2902c6946ab5c8fe1e46e40842dfc32b8824464ad5cd4725364fd83f7a58/msgpack-1.2.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:3a1d30df1f302f2b7a7404afbac2ab76d510036c34cf34dffb01f704a7288e45", size = 82621, upload-time = "2026-06-11T04:15:23.844Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c9/59/7e6b812629d2f919e586041bffc130e1af32079f71bb20699eed54ed6d92/msgpack-1.2.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:581e317112260d8ca488d490cad9290a5682276f309c41c7de237a85ed8799c8", size = 81866, upload-time = "2026-06-11T04:15:25.032Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/31/13/8c291196e60aafdbae38f482205d79432297749ac5d412fe638154fb6f1d/msgpack-1.2.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c6827d12eacc16873eba62408a1b7bbe8ecfb4a8f7ed78a631ae9bae6ad43cf2", size = 405618, upload-time = "2026-06-11T04:15:26.235Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fb/63/68f5d0ea81e167db5f59ddb94dc6f837667062113feff1c73fabf8907061/msgpack-1.2.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a186027e4279efa4c8bf06ce30605498d7d0d3af0fba0b9799dce85a3fd4a93c", size = 416468, upload-time = "2026-06-11T04:15:27.732Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/73/58/567dddf5c5a2790f673bcd7d80c83466d68e5ee9a9674ebca3db8101c0c8/msgpack-1.2.0-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:a96142c14a11cf1a509e8b9aaf72858a3b742b7613e095ce646913e88ce7bd99", size = 374464, upload-time = "2026-06-11T04:15:29.286Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0d/30/0c2342fc9092e4498045f5f60bca6ccbe4f4d87789778c2300e6fd6efe82/msgpack-1.2.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:50c220579b68a6085b95408b2eaa486b259520f55d8e363ddc9b5d7ba5a6ac6d", size = 395879, upload-time = "2026-06-11T04:15:30.973Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b9/11/9565b29b58ce3c33e177b490478b7aaeb8f726ecaaeda26d815893c1db5a/msgpack-1.2.0-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:4dcb9d12ab100ecacdfaaf37a3d72fe8392eacc7054afc1916b12d1b747c8446", size = 371749, upload-time = "2026-06-11T04:15:32.418Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f2/da/7bade19d60b73e2ef73fb76aaf4504c112a70cb760951b7202a0c64b5111/msgpack-1.2.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a804727188ab0ebb237fadb303b743f04925a69d8c3247292d1e33e679767c15", size = 410416, upload-time = "2026-06-11T04:15:34.053Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6d/14/c0c619571c02432208a5977a8dbdd3fc65fe1369f8226ca4b6d08cca87d8/msgpack-1.2.0-cp313-cp313-win32.whl", hash = "sha256:1a1ac6ae1fe23298f79380e7b144c8a454e5d05616b0096584f353ba2d750114", size = 64357, upload-time = "2026-06-11T04:15:35.535Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/50/a5/de06718460909aa965737fec4cfe8a15dedc6544a8c55feeb6956fa0d6e3/msgpack-1.2.0-cp313-cp313-win_amd64.whl", hash = "sha256:1c3c80949d79578f9dc85fd9fb91edfe6694e8a729cd5744634d59d8455fdde3", size = 71057, upload-time = "2026-06-11T04:15:36.83Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c7/52/73446b0141c94a856e22b787c56709c0815fc34f185326577e15b26d8cfe/msgpack-1.2.0-cp313-cp313-win_arm64.whl", hash = "sha256:fcf8f76fa587c2395fd0057c7232dbf071241f9ad280b235adb7ab585289989e", size = 64490, upload-time = "2026-06-11T04:15:38.001Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/35/3d/a7e3cdafa8c0cf36c81e2fa848ec4d30cf089459af45b390ad03f9ce6f49/msgpack-1.2.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:f854fa1a8b55d75d82ef9a905d9cdbeffdf7897c088f6020bd221867da5e56a5", size = 83032, upload-time = "2026-06-11T04:15:39.38Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ca/aa/53ddfba0e347cc4b484e95f629c5850b9e800ca8390c91ffc604407acf87/msgpack-1.2.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:e90df581f80f53b372d5d9d9349078d729851a3a0d0bd74f53ccb598d01e45b8", size = 82600, upload-time = "2026-06-11T04:15:40.609Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/59/fd/e64c2c776e6dbad0af3c963fe0c0dd1ee1ba09efac478b233ab1db41868f/msgpack-1.2.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b276ed50d8ac75d1f134a433ae79af8557d0fa25ee5b4737da533dfc2ce382e8", size = 404342, upload-time = "2026-06-11T04:15:41.87Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1b/60/fb9a08e6ccba882dfd370a5837fe3a07572938fdfe954f0f17fdf3e574b9/msgpack-1.2.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:544d972459c92aa32e63b800d07c2d9cf2734a3be29cee3a0b478a622850e9f5", size = 412351, upload-time = "2026-06-11T04:15:43.253Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/37/4d/df5c575c274fedc68ac9c6c61d045161899efad2afcdc25138efa7edde69/msgpack-1.2.0-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:a070147cc2cf6b8a891734e0f5c8fe8f70ed8739ab30ba140b058005a6e86af4", size = 373331, upload-time = "2026-06-11T04:15:44.754Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7d/a4/c8b98f8191e985ed2003d87664ce3c95cca41db5d0cf6bf4f54327d32ec8/msgpack-1.2.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:7685e23b0f51745a751629c31713fbefdef8896b31b2bb38299dfa4ae6c0740c", size = 394654, upload-time = "2026-06-11T04:15:46.423Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d4/49/76f036720a602ea24428cfec5ec806f2487c0380b1bff0a2aa3094e15f87/msgpack-1.2.0-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:b9204daeee8d91a7ae5acf2d2a8e3983be9a3025f38aa21bfaefbd7eea84a7dc", size = 370624, upload-time = "2026-06-11T04:15:48.062Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9f/38/40af3d29232833705a43b0fce0d07425cc280a7b92ab2b29932425b40df4/msgpack-1.2.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:bfc057248609742ebbabf6bcd27fea4fd99c4980584e613c168c9b002318298f", size = 408038, upload-time = "2026-06-11T04:15:49.669Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/30/b2/f140ca450524dff4d8d0eb81eb9ed75f8f3e0b1f12e49c5b01617cfa0b1c/msgpack-1.2.0-cp314-cp314-win32.whl", hash = "sha256:a3faa7edf2388337ae849239878e92f0298b4dab4488e4f1834062f9d0c410c9", size = 65823, upload-time = "2026-06-11T04:15:51.062Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4d/13/6517bf966b841c7675ded30701a068ce141f3e698a27aaa35c702d8e078b/msgpack-1.2.0-cp314-cp314-win_amd64.whl", hash = "sha256:1a3effc392a57744e4681e55d05f97d5ee7b598747d718340a9b4b8a970c40e1", size = 72484, upload-time = "2026-06-11T04:15:52.289Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/45/8c/1d948420fdaa24de4efdb8012a6a5bebe09c82ee002b8c2ca745e9917f1f/msgpack-1.2.0-cp314-cp314-win_arm64.whl", hash = "sha256:56a318f7df6bec7b40928d6b0519961f20a510d8baabf6baa393a70444588f0a", size = 66657, upload-time = "2026-06-11T04:15:53.583Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/39/16/1674faa1b7bddc19e79b465fd8e88e2cf4e3f7cae90723740701e8541068/msgpack-1.2.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:afa4a65ab2097795e771a74a3a81ea49534aaeba874eaf426a3332268e045ae6", size = 86093, upload-time = "2026-06-11T04:15:54.98Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/dd/24/f241bcfdd9e96b2246289357c5a5e5a496189fd41c5844bee802c116aac7/msgpack-1.2.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:409550770632bb28daa70a11d0ed5763f7db38f40b06f7db9f11dd2794d01102", size = 86372, upload-time = "2026-06-11T04:15:56.381Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/94/c9/57f8ab98a1b21808c27b6dd6029053e0a796ffbb9b371e460dbe997011a9/msgpack-1.2.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bf47e3cd11ce044965a9736a322afdd390b31ed602d1c1b10211d1a841f1d587", size = 428207, upload-time = "2026-06-11T04:15:57.739Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/17/6b/4fd4aa739f131ded751ca7167c8ee87d2aab32506ebbeea893b60b51d343/msgpack-1.2.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:204bc9f5d6e59c1718c0a4a84fc8ff71b5b4562faac257c1a68bca611ecf9b72", size = 426082, upload-time = "2026-06-11T04:15:59.356Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f9/00/db88e9a08fcd6513decaad06cbd5c168142bc3e662fb2f1aca3a563b7aa1/msgpack-1.2.0-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:610154307b27267266368bc1d1c7bb8aeb71da7be9356d403cb2442d9e6399f5", size = 378355, upload-time = "2026-06-11T04:16:00.916Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/54/84/eee4dd703d7a600cf46159d621c070b0b9468cf3dbade4ea8272bf5232a4/msgpack-1.2.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:6799f157bb63e79f11e2e590cfdb28423fc18dd60c270c3914b5b4586ae36f7e", size = 410848, upload-time = "2026-06-11T04:16:02.745Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/12/0a/195e2c549fd4631eb7f157d016ff15a10c4c1cf82b6d0a9b1edaef5174b1/msgpack-1.2.0-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:72bd844902cf0a5ac3af2ef742f253cd0b1e5bcd184f49b4fb9a6a1f7bf305e8", size = 376152, upload-time = "2026-06-11T04:16:04.041Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/45/9b/bdd143fa79baec411dc658f5686fed680a18b36fcea5fccb6af1b8c7d832/msgpack-1.2.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:3c0bd450f78d0d81722c80da6cdbf674a856967870a9db2f6c4debc4d8b3c67c", size = 417061, upload-time = "2026-06-11T04:16:05.63Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2d/ce/011ffcd8b919f55196ec53f12ae162e21c879d95afba226894314ff62c07/msgpack-1.2.0-cp314-cp314t-win32.whl", hash = "sha256:378caf74c4c718dfc17590ce68a6d710ed398ff6fcf08237de23b77755730b55", size = 70782, upload-time = "2026-06-11T04:16:07.105Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/57/a8/9b8791ca96b1be6b9f659c718271e2cb7f99f73f58aad2dd0b30f750f6c0/msgpack-1.2.0-cp314-cp314t-win_amd64.whl", hash = "sha256:553b42598165c4dd3235994fd6e4b0dfb1ce5f3fd33d94ba9609442643015f38", size = 77899, upload-time = "2026-06-11T04:16:08.353Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5b/04/3fa2dffb87bf598696b86bde7cd642d0a7590520c3fa24cd19611dfebeb7/msgpack-1.2.0-cp314-cp314t-win_arm64.whl", hash = "sha256:2825bb1da548d214ab8a810906b7dd69a10f3838b615a2cc46e5172d3cb44f6e", size = 71004, upload-time = "2026-06-11T04:16:09.556Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mujoco"
|
||||
version = "3.8.1"
|
||||
|
||||
Reference in New Issue
Block a user