mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 16:27:04 +00:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b968020ec4 | |||
| fc019d3902 | |||
| 87242cfced | |||
| 1edc83a0ef | |||
| 6fbcf67249 | |||
| 41166b39fb | |||
| 79c6821407 | |||
| 507083249f | |||
| bd22407d93 | |||
| 49755a3d9e |
Binary file not shown.
|
After Width: | Height: | Size: 445 KiB |
@@ -178,9 +178,3 @@ test-smolvla-ete-eval:
|
||||
--env.episode_length=5 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1
|
||||
|
||||
# E2E annotation pipeline smoke test against a tiny in-memory fixture
|
||||
# dataset. Opt-in (not part of `make test-end-to-end`) and uses a stub VLM
|
||||
# backend, so it does not require a real model checkpoint or GPU.
|
||||
annotation-e2e:
|
||||
uv run python -m tests.annotations.run_e2e_smoke
|
||||
|
||||
@@ -58,7 +58,7 @@ action = model.select_action(obs)
|
||||
robot.send_action(action)
|
||||
```
|
||||
|
||||
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1.
|
||||
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1, reBot B601.
|
||||
|
||||
While these devices are natively integrated into the LeRobot codebase, the library is designed to be extensible. You can easily implement the Robot interface to utilize LeRobot's data collection, training, and visualization tools for your own custom robot.
|
||||
|
||||
@@ -101,11 +101,13 @@ lerobot-train \
|
||||
--dataset.repo_id=lerobot/aloha_mobile_cabinet
|
||||
```
|
||||
|
||||
| Category | Models |
|
||||
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
|
||||
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
||||
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
|
||||
| Category | Models |
|
||||
| -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
|
||||
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
||||
| **VLAs Models** | [Pi0](./docs/source/pi0.mdx), [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx), [EO-1](./docs/source/eo1.mdx), [MolmoAct2](./docs/source/molmoact2.mdx), [WALL-OSS](./docs/source/walloss.mdx) |
|
||||
| **World Models** | [VLA-JEPA](./docs/source/vla_jepa.mdx) (more coming soon) |
|
||||
| **Reward Models** | [SARM](./docs/source/sarm.mdx), [TOPReward](./docs/source/topreward.mdx), [Robometer](./docs/source/robometer.mdx) |
|
||||
|
||||
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
|
||||
|
||||
@@ -133,6 +135,7 @@ Learn how to implement your own simulation environment or benchmark and distribu
|
||||
- **[Discord](https://discord.gg/q8Dzzpym3f):** Join the `LeRobot` server to discuss with the community.
|
||||
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
|
||||
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
|
||||
- **[T-Shirt Folding Experiment](https://huggingface.co/spaces/lerobot/robot-folding):** An end-to-end demonstration of folding t-shirts with LeRobot.
|
||||
|
||||
## Citation
|
||||
|
||||
@@ -140,7 +143,7 @@ If you use LeRobot in your project, please cite the GitHub repository to acknowl
|
||||
|
||||
```bibtex
|
||||
@misc{cadene2024lerobot,
|
||||
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
|
||||
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Meftah, Khalil and Ellerbach, Maxime and Moss, Jess and Wolf, Thomas},
|
||||
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
|
||||
howpublished = "\url{https://github.com/huggingface/lerobot}",
|
||||
year = {2024}
|
||||
|
||||
@@ -0,0 +1,417 @@
|
||||
# Decoupled VLA Inference & Edge Control: System Design Proposal
|
||||
|
||||
## 1. Executive Summary
|
||||
|
||||
This document proposes a production-grade system for decoupling GPU-bound VLA (Vision-Language-Action) policy inference from high-frequency, CPU-bound robot control in LeRobot. The system adopts a **Model-as-a-Service (MaaS)** paradigm using **Zenoh** as the sole transport protocol, enabling multiple edge devices to be served by centralized GPU servers with minimal latency and high reliability.
|
||||
|
||||
An initial prototype exists in `src/lerobot/async_inference/` (gRPC-based, single-client). This proposal defines the target architecture, identifies gaps between the prototype and production requirements, documents known bugs, and establishes the design for the new system.
|
||||
|
||||
---
|
||||
|
||||
## 2. Motivation
|
||||
|
||||
LeRobot's standard control loop runs policy inference and robot I/O in the same process. This works for lightweight policies on local GPUs, but breaks down when:
|
||||
|
||||
- **The policy is too large for edge hardware** (e.g., Pi0 at ~3B parameters requires a dedicated GPU).
|
||||
- **Multiple robots need the same policy** (redundant GPU allocation per robot).
|
||||
- **Inference latency exceeds the control deadline** (e.g., 200ms inference on a 33ms control loop at 30 FPS).
|
||||
|
||||
Decoupling inference from control solves all three: the edge device runs a tight I/O loop on a CPU, while a GPU server handles inference for one or more clients.
|
||||
|
||||
---
|
||||
|
||||
## 3. Core Architectural Principles
|
||||
|
||||
### 3.1 Model-as-a-Service (MaaS)
|
||||
|
||||
Servers initialize models **once at startup** from a configuration manifest. Edge devices do **not** trigger dynamic model loading — they route to pre-warmed servers and validate compatibility via a status endpoint.
|
||||
|
||||
### 3.2 Multi-Tenant & Stateless Inference
|
||||
|
||||
A single GPU server handles multiple edge devices executing the same task. The server is stateless per inference call — `predict_action_chunk()` is a pure function with no side effects on the model. Client isolation is achieved through per-client observation slots and Zenoh key-expression routing.
|
||||
|
||||
> **Invariant**: `predict_action_chunk()` must remain a pure function (no mutation of `self`) for all supported policies. This is what enables safe multi-tenant sharing of a single model instance. This invariant must be documented and tested.
|
||||
|
||||
### 3.3 Zenoh as primary Transport
|
||||
|
||||
The system uses Zenoh's pub/sub model, replacing the current gRPC implementation. Zenoh provides:
|
||||
|
||||
- **Hierarchical key expressions** for routing (natural fit for the cluster/experiment/model/task topology).
|
||||
- **Built-in discovery** (no external service discovery needed).
|
||||
- **Non-blocking publish** for observations (fire-and-forget with best-effort QoS).
|
||||
- **Reliable delivery** configurable per-topic (required for action chunks).
|
||||
- **Shared-memory transport** for same-machine deployments (zero-copy) (if available).
|
||||
|
||||
### 3.4 Local Edge CPU
|
||||
|
||||
Edge devices rely on standard CPUs for sensor polling, image compression, payload serialization, motor control, and data logging. No edge-GPU dependency.
|
||||
|
||||
---
|
||||
|
||||
## 4. System Topology
|
||||
|
||||

|
||||
|
||||
- **Cluster**: A set of GPU machines. Identified by `cluster_uuid`.
|
||||
- **Experiment**: A logical grouping of servers and clients. Identified by `experiment_tag`.
|
||||
- **Server**: One model + one task, pre-warmed. Serves N clients for that model/task combination.
|
||||
- **Client**: One robot, one task. Publishes observations, subscribes to actions.
|
||||
|
||||
The number of clients a single server can handle is a **user decision** based on model inference time and acceptable latency.
|
||||
|
||||
---
|
||||
|
||||
## 5. Component Specifications
|
||||
|
||||
### 5.1 The Edge Device (Client)
|
||||
|
||||
**Responsibilities:**
|
||||
|
||||
1. **Observation capture**: Read sensors (cameras, motors) at the control loop frequency.
|
||||
2. **Image compression**: JPEG-encode RGB images before transmission.
|
||||
3. **Observation publishing**: Non-blocking Zenoh put to the observation topic.
|
||||
4. **Action subscription**: Zenoh callback receives action chunks, deposits into local buffer.
|
||||
5. **Action execution**: Pop actions from buffer, send to robot at control frequency.
|
||||
6. **Action blending**: When a new action chunk overlaps with the current buffer, blend via configurable aggregation function (weighted average, latest-only, etc.).
|
||||
7. **Latency compensation**: Calculate one-way latency from RTT, discard expired initial steps of incoming action chunks.
|
||||
8. **Fail-safe**: If action buffer empties, logs a warning.
|
||||
9. **Data logging**: Record raw observations and executed actions to local `LeRobotDataset` storage for deferred upload.
|
||||
|
||||
**Threading model:**
|
||||
|
||||
- **Control loop thread** (main): Capture observation → deposit in outbox → pop action from buffer → send to robot → sleep to maintain frequency.
|
||||
- **Zenoh action callback** (Zenoh-managed): Receives action chunks, processes RTT, trims stale steps, deposits into action buffer.
|
||||
- **Observation publisher thread**: Drains the outbox, compresses images, serializes, publishes via Zenoh.
|
||||
|
||||
> **Design note**: The current prototype blocks on `send_observation` inside the control loop (BUG-1, see Section 9). The new design decouples observation publishing from the control loop entirely, using a separate thread and Zenoh's non-blocking put.
|
||||
|
||||
### 5.2 The Inference Server (GPU Pod)
|
||||
|
||||
**Responsibilities:**
|
||||
|
||||
1. **Model pre-warming**: Load model and processor pipelines at startup from config manifest (including expected clients & policy parameters).
|
||||
2. **Status publishing**: Expose model capabilities (policy type, expected camera names, resolutions, action dimensions) via Zenoh queryable.
|
||||
3. **Observation subscription**: Subscribe to observation topics for all clients of this model/task. Maintain per-client observation slots (newest-only semantics).
|
||||
4. **Inference**: Single inference thread processes observations sequentially (round-robin across clients). Calls `policy.predict_action_chunk()`.
|
||||
5. **Action publishing**: Publish action chunks to per-client action topics with reliable QoS.
|
||||
|
||||
> **Thread safety**: PyTorch's `model.forward()` is not guaranteed thread-safe. Inference will be sequential, latency is mostly about the capabilities of the server to serve multiple requests.
|
||||
|
||||
---
|
||||
|
||||
## 6. Zenoh Routing & Key Expressions
|
||||
|
||||
### 6.1 Key Expression Schema
|
||||
|
||||
```
|
||||
[cluster_uuid] / [experiment_tag] / [model_id] / [model_version] / [application_tag] / [client_uuid] / [topic]
|
||||
```
|
||||
|
||||
**Example key expressions:**
|
||||
|
||||
| Key Expression | Direction | Purpose |
|
||||
| ------------------------------------------------ | ----------------- | ---------------------------------- |
|
||||
| `jupiter/fabio2/pi0/v1/cookie/robot_a4b9/obs` | Client → Server | Observation payload |
|
||||
| `jupiter/fabio2/pi0/v1/cookie/robot_a4b9/action` | Server → Client | Action chunk |
|
||||
| `jupiter/fabio2/pi0/v1/cookie/*/obs` | Server subscribes | All observations for pi0/v1/cookie |
|
||||
| `jupiter/fabio2/pi0/v1/cookie/status` | Server publishes | Model capabilities (queryable) |
|
||||
|
||||
### 6.2 QoS Configuration
|
||||
|
||||
| Topic | Reliability | Rationale |
|
||||
| -------- | ----------- | -------------------------------------------------------------------- |
|
||||
| `obs` | Best-effort | Dropping stale observations is expected behavior. |
|
||||
| `action` | Reliable | Every action chunk must be delivered; loss causes action starvation. |
|
||||
| `status` | Reliable | Client needs accurate capability info before starting. |
|
||||
|
||||
### 6.3 Discovery Flow
|
||||
|
||||
0. Server goes up with the static configuration.
|
||||
1. Client constructs its target key prefix: `cluster/experiment/model/version/task/`.
|
||||
2. Client queries `cluster/experiment/model/version/task/status` (Zenoh queryable).
|
||||
3. Server responds with its capabilities (expected camera names, image resolutions, action dimensions, model metadata).
|
||||
4. Client validates its own configuration against server capabilities.
|
||||
5. On match: client starts publishing observations and subscribing to actions.
|
||||
6. On mismatch: client logs an error and refuses to start.
|
||||
|
||||
No dynamic client discovery for now.
|
||||
|
||||
---
|
||||
|
||||
## 7. Message Schema
|
||||
|
||||
### 7.1 Observation Payload (Client → Server)
|
||||
|
||||
| Field | Type | Purpose |
|
||||
| ------------- | ------------------ | ----------------------------------------------------------- |
|
||||
| `seq_id` | `uint64` | Incrementing ID for causality tracking and RTT computation. |
|
||||
| `client_uuid` | `string` | Identifies the sending client. |
|
||||
| `state` | `bytes` | Proprioceptive state vector (`numpy.tobytes()`). |
|
||||
| `images` | `dict[str, bytes]` | JPEG-compressed camera images, keyed by camera name. |
|
||||
| `task` | `string` | Natural-language task instruction (for VLA conditioning). |
|
||||
|
||||
### 7.2 Action Payload (Server → Client)
|
||||
|
||||
| Field | Type | Purpose |
|
||||
| -------------------- | --------- | --------------------------------------------------------------- |
|
||||
| `response_to_seq_id` | `uint64` | Echoes the observation `seq_id` this action corresponds to. |
|
||||
| `inference_time_ms` | `float32` | Server-side compute duration (for edge RTT math). |
|
||||
| `actions` | `bytes` | Action chunk as numpy array bytes (`(chunk_size, action_dim)`). |
|
||||
|
||||
### 7.3 Status Payload (Server, Queryable)
|
||||
|
||||
| Field | Type | Purpose |
|
||||
| ----------------------- | ------------------- | ------------------------------------------ |
|
||||
| `model_id` | `string` | Policy identifier (e.g., `pi0`). |
|
||||
| `model_version` | `string` | Model version or checkpoint path. |
|
||||
| `expected_cameras` | `dict[str, (H, W)]` | Expected camera names and shapes. |
|
||||
| `action_dim` | `int` | Dimensionality of the action space. |
|
||||
| `max_actions_per_chunk` | `int` | Maximum chunk size the model supports. |
|
||||
| `observation_features` | `dict` | Full feature specification for validation. |
|
||||
|
||||
### 7.4 Serialization Format
|
||||
|
||||
**MessagePack** for all structured metadata (compact, fast, cross-language). Image payloads are raw JPEG bytes embedded in the MessagePack structure. State vectors use `numpy.tobytes()` with shape/dtype metadata for zero-copy reconstruction.
|
||||
|
||||
**No pickle.** The current prototype uses `pickle.dumps`/`pickle.loads` throughout, which allows arbitrary code execution. This is replaced entirely.
|
||||
|
||||
---
|
||||
|
||||
## 8. Latency Compensation
|
||||
|
||||
### 8.1 RTT Calculation
|
||||
|
||||
The edge device tracks in-flight observations:
|
||||
|
||||
```python
|
||||
in_flight: dict[int, float] = {} # seq_id -> time.perf_counter() at send
|
||||
|
||||
# On send:
|
||||
in_flight[seq_id] = time.perf_counter()
|
||||
|
||||
# On receive action chunk:
|
||||
rtt = time.perf_counter() - in_flight[response_to_seq_id]
|
||||
# delete older keys than the one received
|
||||
```
|
||||
|
||||
> **Important**: Delete only the exact `response_to_seq_id` key from `in_flight`, not all keys `<= response_to_seq_id`. With Zenoh's best-effort transport, messages can arrive out of order. Clearing earlier keys would make their RTT unmeasurable.
|
||||
|
||||
### 8.2 Stale Action Trimming
|
||||
|
||||
When an action chunk arrives, the edge calculates how many initial steps have already expired:
|
||||
|
||||
```python
|
||||
expired_steps = int(rtt / environment_dt)
|
||||
valid_actions = action_chunk[expired_steps:]
|
||||
```
|
||||
|
||||
The valid actions are then blended into the action buffer using the configured aggregation function.
|
||||
|
||||
### 8.3 Edge Cases
|
||||
|
||||
| Scenario | Behavior |
|
||||
| -------------------------------------- | -------------------------------------------------------------------------------------- |
|
||||
| **First observation** (no RTT history) | Apply all action steps without trimming. |
|
||||
| **Dropped observations** | Server infers on next received observation. No special handling needed. |
|
||||
| **Dropped action chunks** | Edge continues executing current buffer. If buffer empties, warn & hold last position. |
|
||||
| **Server crash** | Edge exhausts buffer, holds position, warns & re-validates via status query. |
|
||||
|
||||
> **Assumption**: All currently supported robots are position-controlled (SO100, SO101, OMX). For velocity-controlled robots, the fail-safe must send zero-velocity instead of holding position. This should be configurable per-robot.
|
||||
|
||||
---
|
||||
|
||||
## 9. Known Bugs in Current Prototype
|
||||
|
||||
These issues exist in `src/lerobot/async_inference/` and must be addressed in the new implementation.
|
||||
|
||||
### BUG-1: `send_observation` Blocks the Control Loop (Critical)
|
||||
|
||||
**Location**: `robot_client.py:207`
|
||||
|
||||
`self.stub.SendObservations(observation_iterator)` is a synchronous gRPC call inside the 33ms control loop. For multi-camera observations (several MB after pickle), this consumes 10-20ms on the network, leaving no headroom for sensor capture and motor commands. The robot stutters.
|
||||
|
||||
**Resolution in new design**: Observation publishing is moved to a dedicated thread. Zenoh's `session.put()` is non-blocking by default. The control loop only deposits observations into a local outbox.
|
||||
|
||||
### BUG-2: Race Condition in Action Queue Aggregation (Correctness)
|
||||
|
||||
**Location**: `robot_client.py:236-267`
|
||||
|
||||
The lock on `self.action_queue` is acquired to read `internal_queue = self.action_queue.queue` (a reference to the internal deque), then **released** at line 238. The aggregation logic iterates over this reference outside the lock. Meanwhile, the control loop thread can `get_nowait()` from the same queue, mutating the deque during iteration. At line 267, the entire queue is replaced, but actions popped between 238-267 are silently lost.
|
||||
|
||||
**Fix**: Either hold the lock for the entire aggregation, or `list(self.action_queue.queue)` to copy contents before releasing.
|
||||
|
||||
### BUG-3: No RPC Deadlines (Reliability)
|
||||
|
||||
**Location**: `robot_client.py:278`
|
||||
|
||||
`GetActions` blocks indefinitely if the server hangs (GPU OOM, deadlock). The retry policy handles `UNAVAILABLE` but not a hung connection.
|
||||
|
||||
**Resolution in new design**: The polling `GetActions` pattern is replaced by Zenoh subscription callbacks. The client needs a watchdog timer or check when action queue is empty: if no actions are received for `T` seconds, trigger re-validation via the status service.
|
||||
|
||||
### BUG-4: Similarity Check Ignores Images (Correctness for VLAs)
|
||||
|
||||
**Location**: `helpers.py:280-297`
|
||||
|
||||
`observations_similar()` + `must_go` is a workaround for current architecure limitations to avoid filling up the server queue the first seconds of the task & the robot remaining idle.
|
||||
|
||||
**Resolution in new design**: the server always processes the latest observation per client in its inference loop, and doesn't need similarity gating at all. The client can always push.
|
||||
|
||||
---
|
||||
|
||||
## 10. Gaps Between Prototype and Target Architecture
|
||||
|
||||
### 10.1 Critical (Must Address)
|
||||
|
||||
| # | Gap | Current State | Target State |
|
||||
| --- | ------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| G1 | **Single-client server** | One `observation_queue(maxsize=1)`, one `last_processed_obs`, one `_predicted_timesteps`. `_reset_server()` flushes all state on any new connection. | Per-client state (`ClientState` dataclass) keyed by `client_uuid`. Zenoh key-expression routing provides client isolation. |
|
||||
| G2 | **Dynamic model loading** | Client sends `RemotePolicyConfig` → server calls `from_pretrained()` on demand. | Server loads models at startup from config manifest. `SendPolicyInstructions` RPC eliminated. Client validates via status query. |
|
||||
| G3 | **gRPC transport** | Entire `transport/` directory: proto definitions, generated stubs, chunking utils. 4 RPCs: `Ready`, `SendPolicyInstructions`, `SendObservations`, `GetActions`. | Zenoh pub/sub. Client publishes obs, subscribes to actions. Server subscribes to obs, publishes actions. Dispatching via key expressions. |
|
||||
| G4 | **Pickle serialization** | `pickle.dumps`/`pickle.loads` throughout (arbitrary code execution risk, `# nosec` suppression). | MessagePack for structured metadata + raw JPEG bytes for images + `numpy.tobytes()` for state vectors. |
|
||||
|
||||
### 10.2 Important
|
||||
|
||||
| # | Gap | Current State | Target State |
|
||||
| --- | -------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| G5 | **No RTT/latency compensation** | No `seq_id`, no `response_to_seq_id`, no `inference_time_ms`. Timestamps use `time.time()` (unreliable across machines). | Edge-local `perf_counter` + echoed `seq_id` + server inference duration. Stale action step trimming. |
|
||||
| G6 | **No hierarchical routing** | Direct gRPC channel to `host:port`. | Zenoh key expressions: `cluster/experiment/model/version/task/client/topic`. |
|
||||
| G7 | **No data logging** | `control_loop` has access to obs and actions but doesn't persist them. | Edge records via `LeRobotDataset` (`build_dataset_frame` + `dataset.add_frame`). |
|
||||
| G8 | **No authentication** | `grpc.insecure_channel`. | Zenoh TLS + access control lists on key expressions. |
|
||||
| G9 | **ProcessorPipeline divergence** | Server reimplements observation prep in `helpers.py` (custom `resize_robot_observation_image` with `F.interpolate` bilinear). Diverges from standard `RobotProcessorPipeline`. | Use the standard `RobotProcessorPipeline` + `build_dataset_frame` to ensure behavioral equivalence between record and async inference. |
|
||||
|
||||
### 10.3 Nice-to-Have
|
||||
|
||||
| # | Gap | Current State | Target State |
|
||||
| --- | ------------------------------------- | --------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| G11 | **No status/discovery service** | Bare `Ready()` ping. | Zenoh queryable at `cluster/exp/model/version/task/status`. |
|
||||
| G12 | **No monitoring** | `FPSTracker` + `logging.debug`. | Structured metrics via Zenoh telemetry topics. Wildcard subscriptions for centralized monitoring. |
|
||||
| G13 | **No entry points** | Module-level `__main__`. | `lerobot-policy-server` and `lerobot-robot-client` console scripts in `pyproject.toml`. |
|
||||
| G14 | **Ratio-based observation threshold** | `chunk_size_threshold` (0-1 ratio of queue fill). Scales oddly with different `actions_per_chunk` values. | Absolute time threshold: `buffer_time_s` calibrated to observed RTT. Send observation when `queue_size * environment_dt < buffer_time_s`. |
|
||||
|
||||
---
|
||||
|
||||
## 11. Design Decisions & Rationale
|
||||
|
||||
### 11.1 Why Zenoh Over gRPC
|
||||
|
||||
| Aspect | Zenoh | gRPC |
|
||||
| ------------------------- | -------------------------------------------------------------------------- | ---------------------------------------------------------------------------------- |
|
||||
| Communication model | Pub/sub — natural fit for "client publishes obs, server publishes actions" | Request/response — requires polling (`GetActions` loop) or bidirectional streaming |
|
||||
| Multi-tenant routing | Hierarchical key expressions provide built-in per-client topic isolation | Requires manual per-client channel/stream management |
|
||||
| Discovery | Built-in discovery | Requires external service (mDNS, Consul, etc.) |
|
||||
| Observation publishing | Non-blocking put (fire-and-forget) — resolves BUG-1 automatically | Synchronous stream-unary call — blocks the control loop |
|
||||
| Same-machine optimization | Shared-memory transport (zero-copy) | Loopback TCP |
|
||||
| Telemetry | Wildcard subscriptions (`+/+/+/+/+/metrics`) | Requires separate monitoring infrastructure |
|
||||
|
||||
**Tradeoffs of going Zenoh-only:**
|
||||
|
||||
- Smaller community, less tooling for monitoring/tracing vs. gRPC's mature ecosystem.
|
||||
- No built-in schema enforcement (Zenoh sends raw bytes) — serialization correctness is entirely on us.
|
||||
- Default QoS is best-effort (like UDP). Must explicitly configure reliable delivery for action chunks.
|
||||
- `zenoh-python` bindings are less battle-tested than `grpcio`. Needs integration testing under network stress.
|
||||
|
||||
### 11.2 Why Single Inference Thread (Not Batching)
|
||||
|
||||
True GPU batching across clients requires collecting observations from multiple clients and running a single forward pass. This is difficult because:
|
||||
|
||||
- Clients send observations at different times — waiting to batch adds latency.
|
||||
- Different clients may have slightly different image resolutions.
|
||||
- Error in one client's observation shouldn't affect others.
|
||||
|
||||
**Decision**: Start with sequential processing (single inference thread, round-robin across clients). Profile GPU utilization.
|
||||
|
||||
### 11.4 Why MessagePack (Not Protobuf, Not FlatBuffers)
|
||||
|
||||
- **Protobuf**: Strong schema enforcement but heavier toolchain (proto compilation, generated code). Since we're dropping gRPC, the protobuf dependency becomes unnecessary overhead.
|
||||
- **MessagePack**: Fast, compact, schema-less (enforced by application), excellent Python support (`msgpack` package), good for nested dicts with mixed types. Natural fit for observation/action payloads.
|
||||
|
||||
Images are embedded as raw JPEG bytes within the MessagePack structure. State vectors use `numpy.tobytes()` with shape/dtype metadata for zero-copy reconstruction.
|
||||
|
||||
### 11.5 Action Aggregation Strategy
|
||||
|
||||
When a new action chunk overlaps with the existing buffer, the overlapping timesteps must be blended. The current prototype supports configurable aggregation functions:
|
||||
|
||||
| Function | Formula | Character |
|
||||
| ------------------ | ----------------------- | ------------------------------------------ |
|
||||
| `weighted_average` | `0.3 * old + 0.7 * new` | Smooth transitions, favors new predictions |
|
||||
| `latest_only` | `new` | Most responsive, can cause discontinuities |
|
||||
| `average` | `0.5 * old + 0.5 * new` | Equal weight |
|
||||
| `conservative` | `0.7 * old + 0.3 * new` | Smooth, slow to adapt |
|
||||
|
||||
Ultimately, this should be the user's decision. Default to `weighted_average`. The goal of async is not to do temporal ensembling, but to provide a solution when we want to decouple inference and execution.
|
||||
|
||||
---
|
||||
|
||||
## 12. Configuration
|
||||
|
||||
### 12.1 Server Configuration (Manifest)
|
||||
|
||||
Servers are configured via a YAML manifest that declares which models to pre-warm & clients to serve:
|
||||
|
||||
```yaml
|
||||
cluster_uuid: jupiter
|
||||
experiment_tag: fabio2
|
||||
server:
|
||||
- model_id: pi0
|
||||
model_version: v1
|
||||
pretrained_path: lerobot/pi0-cookie-v1
|
||||
application_tag: cookie
|
||||
device: cuda:0
|
||||
fps: 30
|
||||
endpoint: tcp/192.168.1.50:7447
|
||||
clients:
|
||||
- client_uuid: cookie-worker-4269
|
||||
```
|
||||
|
||||
### 12.2 Client Configuration
|
||||
|
||||
Clients are configured via draccus dataclass (CLI-compatible):
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class AsyncClientConfig:
|
||||
# Zenoh routing
|
||||
cluster_uuid: str
|
||||
experiment_tag: str
|
||||
model_id: str
|
||||
model_version: str
|
||||
application_tag: str
|
||||
client_uuid: str
|
||||
endpoint: str
|
||||
|
||||
# Robot
|
||||
robot: RobotConfig
|
||||
|
||||
# Control
|
||||
fps: int = 30
|
||||
actions_per_chunk: int = 50
|
||||
aggregate_fn_name: str = "weighted_average"
|
||||
jpeg_quality: int = 90
|
||||
|
||||
# Fail-safe
|
||||
max_empty_cycles_before_warning: int = 10
|
||||
|
||||
# Datset recording
|
||||
dataset_repo_id: str | None = None # None = no logging
|
||||
|
||||
# Task
|
||||
task: str = ""
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 14. Data Logging Integration
|
||||
|
||||
The client records observations and executed actions into a local `LeRobotDataset` for deferred upload to the training dataset:
|
||||
|
||||
```python
|
||||
# In control_loop, after executing an action:
|
||||
if self.dataset is not None:
|
||||
frame = build_dataset_frame(
|
||||
self.dataset.features,
|
||||
processed_observation,
|
||||
prefix=OBS_STR,
|
||||
)
|
||||
frame["action"] = executed_action_tensor
|
||||
self.dataset.add_frame(frame)
|
||||
```
|
||||
@@ -0,0 +1,498 @@
|
||||
# Decoupled VLA Inference & Edge Control v2: Async Network Inference for `lerobot-rollout`
|
||||
|
||||
> **Status**: supersedes the v1 proposal in full. v1 was written against the standalone `src/lerobot/async_inference/` prototype, before `lerobot-rollout` existed. This revision re-grounds the design in the current codebase, keeps v1's decisions that survived contact with it (marked **KEPT** throughout), reverses the ones that didn't, and adds the safety, multi-tenancy, and operations specifications v1 lacked.
|
||||
|
||||
## 1. Executive Summary
|
||||
|
||||
This document specifies a production-grade system for decoupling GPU-bound policy inference from high-frequency robot control, targeting power users running **hundreds of robots** against centralized GPU clusters. The system keeps v1's **Model-as-a-Service (MaaS)** paradigm and **Zenoh** transport, but changes the integration architecture fundamentally:
|
||||
|
||||
- **The client is not a standalone CLI.** It is `--inference.type=remote`, a new `InferenceEngine` backend inside `lerobot-rollout` (`src/lerobot/rollout/inference/`). Every rollout strategy (base, sentry, highlight, dagger, episodic) gets network inference for free — including dataset recording, DAgger pause/resume, Rerun visualization, and safe teardown.
|
||||
- **The client is weightless.** No policy weights, no policy processors on the edge. `--policy.path` resolves to a config-only `PreTrainedConfig` (no weight download) used for pre-flight validation and action ordering.
|
||||
- **The server is stateless per request.** All RTC chunk state (leftover prefixes, latency tracking, delay computation) lives client-side in the existing `ActionQueue`/`LatencyTracker` machinery — the client ships prefixes + a delay hint with each observation. A server crash loses zero control state; reconnects and horizontal scaling are trivial.
|
||||
- **Multi-tenancy is engineered, not assumed.** The real hazards are stateful processor pipelines and episode-scoped policy state — not `predict_action_chunk` purity (which holds for ACT/Pi0/Pi0.5/SmolVLA but _not_ diffusion). The server uses per-session processor instances, a chunk-stateless allowlist, and an exclusive serving mode for policies that need it.
|
||||
- **The legacy module dies.** `src/lerobot/async_inference/` (~1,900 lines, pickle-over-gRPC, single-client, four confirmed bugs) is deleted in the same PR that lands the new backend. No deprecation cycle: the module is experimental, its CLI undocumented in the main flow, and every config field has a mapped successor (§13.4).
|
||||
|
||||
---
|
||||
|
||||
## 2. Motivation (unchanged from v1) — **KEPT**
|
||||
|
||||
LeRobot's standard control loop runs policy inference and robot I/O in the same process. This breaks down when:
|
||||
|
||||
- **The policy is too large for edge hardware** (Pi0-class models need a dedicated GPU).
|
||||
- **Multiple robots need the same policy** (redundant GPU allocation per robot).
|
||||
- **Inference latency exceeds the control deadline** (e.g. 150 ms inference on a 33 ms control tick).
|
||||
|
||||
Decoupling solves all three: the edge runs a tight CPU loop; a GPU server performs inference for N clients.
|
||||
|
||||
What changed since v1: the _local_ version of this decoupling already shipped. `RTCInferenceEngine` (`src/lerobot/rollout/inference/rtc.py`) runs inference in a background thread against a thread-safe `ActionQueue` with latency-aware chunk merging. **The network system is that same architecture with the thread boundary replaced by a network boundary.** This is the design's central simplification: reuse, don't reinvent.
|
||||
|
||||
---
|
||||
|
||||
## 3. Gap Analysis: v1 Proposal vs. Modern Codebase
|
||||
|
||||
| Topic | v1 assumed | Modern reality | Verdict |
|
||||
| ----------------------------------------- | --------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------- | --------------------------------------- |
|
||||
| Client architecture | Standalone robot-client CLI (§5.1 of v1) | `InferenceEngine` ABC seam in `lerobot-rollout` (`rollout/inference/base.py`); strategies are backend-agnostic | **Superseded** — backend, not CLI |
|
||||
| Chunk blending | Configurable aggregation zoo (`weighted_average`, …) | `ActionQueue` replace-with-delay-trim (RTC) / append (non-RTC) (`policies/rtc/action_queue.py:147-217`) | **Superseded** — drop blending entirely |
|
||||
| Latency compensation | Hand-rolled RTT trim (`expired_steps = int(rtt/dt)`, v1 §8.2) | `ActionQueue.merge(..., real_delay, idx_before)` + `LatencyTracker` already do this, validated | **Superseded** |
|
||||
| Multi-tenancy invariant | "`predict_action_chunk()` pure ⇒ safe to share" | Processor state + episode-scoped policy state are the real hazards (§7) | **Incomplete** — fixed in §8.3 |
|
||||
| Data logging | Client-side `build_dataset_frame` + `add_frame` sketch (v1 §14) | Recording strategies (sentry/episodic/dagger) already log obs + executed actions | **Superseded** — free via rollout |
|
||||
| MaaS pre-warm, no dynamic loading | ✓ | Still right; legacy `SendPolicyInstructions` is a pickle/RCE + capacity-planning disaster | **KEPT** |
|
||||
| JPEG observation compression | ✓ | Still right (§10.1) | **KEPT** |
|
||||
| Status/capability validation before start | ✓ (Zenoh queryable) | Still right; extended into a hard sync-safety contract (§8.4) | **KEPT, extended** |
|
||||
| Time-based send threshold (v1 G14) | ✓ | Adopted as `buffer_time_s` | **KEPT** |
|
||||
| Zenoh pub/sub data plane | ✓ | Confirmed; QoS corrected (§6.3), control plane moved to queryables, liveliness added | **KEPT, hardened** |
|
||||
| MessagePack serialization | ✓ | Endorsed (zenoh's `ext` serializer cannot encode numpy); must be version-gated (§10.4) | **KEPT, with schema discipline** |
|
||||
| QoS table (v1 §6.2) | "obs best-effort, actions reliable" | Conflates transport reliability with congestion control; BLOCK on actions is dangerous | **Revised** (§6.3) |
|
||||
| Bugs BUG-1…BUG-4, gaps G1…G14 | Listed as work items | Every one resolved _structurally_ by this design (§13.5 mapping) | **Resolved by design** |
|
||||
|
||||
---
|
||||
|
||||
## 4. Critical Pushbacks on v1
|
||||
|
||||
Each pushback: claim → evidence → consequence for this design.
|
||||
|
||||
**P1 — A standalone client duplicates `lerobot-rollout`.**
|
||||
v1 §5.1 assigns the client: observation capture, action execution at frequency, fail-safe, data logging. Every one of those is already owned by rollout strategies and `send_next_action` (`rollout/strategies/core.py:269-304`), which tolerates `None` actions, runs the interpolator, and routes through the canonical robot processors. A standalone client re-implements loop timing, recording, DAgger UX, Rerun, and teardown safety — and then drifts. _Consequence_: the client is `RemoteInferenceEngine`, registered as `--inference.type=remote` next to `sync` and `rtc`.
|
||||
|
||||
**P2 — The aggregation-function zoo fabricates actions no policy predicted.**
|
||||
`0.3*old + 0.7*new` produces hybrid actions that exist in no policy's output distribution; the logged action becomes unexplainable (bad for the reproducibility story) and the implementation hosted a real lock-release race (BUG-2, `async_inference/robot_client.py:236-267`). RTC's prefix-conditioned chunk generation is the principled mechanism for smooth chunk transitions; plain append covers non-RTC chunking. _Consequence_: `ActionQueue` replace/append are the only two merge semantics. The zoo is deleted.
|
||||
|
||||
**P3 — "predict_action_chunk pure ⇒ multi-tenant safe" is incomplete.**
|
||||
Verified in-tree: (a) `RelativeActionsProcessorStep` caches `_last_state` at preprocess (`processor/relative_action_processor.py:131`) and the postprocessor reads it back (`:189`) — a shared pipeline across clients is a race; (b) `DiffusionPolicy.predict_action_chunk` reads `self._queues`, which only `select_action` populates (`policies/diffusion/modeling_diffusion.py:90-108`) — it is **not** chunk-stateless; (c) SAC/SARM have no `predict_action_chunk` at all. _Consequence_: per-session processor instances (mandatory), a chunk-stateless allowlist, `serving_mode: exclusive` for diffusion-family, refusal at startup for SAC/SARM, and `policy.reset()` is **never** called in shared mode (§8.3).
|
||||
|
||||
**P4 — v1 re-derives latency compensation that already exists, on top of broken clocks.**
|
||||
v1 §8 specifies an in-flight RTT dict and manual stale-step trimming. `ActionQueue.merge(original, processed, real_delay, idx_before)` already trims `real_delay` stale steps and cross-validates against actions consumed in flight (`action_queue.py:219-246`). Worse, the legacy code compares wall clocks across machines (`robot_client.py:420` stamps `time.time()` "to compare timestamps across client and server"; `policy_server.py:178` compares it) — NTP skew is the same order as the latencies being measured. _Consequence_: the **monotonic iron rule** (§11): instants never cross machines; client timestamps are opaque echoed tokens; servers report only durations. `delay_steps = ceil((rtt + inference)/dt)` is computed client-side from client-local `perf_counter` samples and shipped per request.
|
||||
|
||||
**P5 — One-in-flight per client is a correctness requirement, not a tuning choice.**
|
||||
At send time the client snapshots `idx_before = queue.get_action_index()` and the leftover prefixes; `merge` validates against them. Two in-flight requests carry conflicting snapshots — the second merge corrupts both RTC replace mode and append mode. The local RTC thread is also strictly one-inference-at-a-time; one-in-flight preserves exact parity. _Consequence_: the worker publishes one observation, waits for its chunk (or timeout), then sends the next. v1 §8.1's out-of-order in-flight dict is dead weight; a late chunk is accepted only if it answers the _latest_ outstanding `seq_id`, otherwise dropped.
|
||||
|
||||
**P6 — v1's QoS table conflates transport reliability with congestion behavior.**
|
||||
"Reliable delivery for actions" sounds right but the dangerous knob is congestion control: a publisher configured `BLOCK` on the action topic can stall the **server's** publish path on one robot's dead uplink (Zenoh blocks up to `wait_before_close`, then may close the transport). A dropped action chunk is _recoverable by design_ — the client's queue keeps the robot moving and the next chunk replaces it. _Consequence_ (§6.3): actions = `reliability=RELIABLE` (hop-level) + `congestion_control=DROP` + `express=True` + `priority=INTERACTIVE_HIGH`; observations = `DROP` + `DATA`. If WAN loss proves material, upgrade the action topic to Zenoh Advanced Pub/Sub (cache + recovery, zenoh ≥ 1.5) rather than BLOCK.
|
||||
|
||||
**P7 — Schema-less MessagePack invites silent version drift across a 300-robot fleet.**
|
||||
msgpack stays (zenoh's `ext` serializer cannot encode numpy/dataclasses, and the team's choice stands), but naked msgpack dicts across heterogeneous fleet versions fail at runtime, on the robot. _Consequence_ (§10.4): a packed little-endian **attachment header** (`schema_version`, `seq_id`, `episode_id`, `client_mono_ns` — the rmw_zenoh pattern) so routing/correlation never deserializes the body; `schema_version` negotiated at the session handshake; additive-only evolution; golden codec tests. Protobuf-over-ZBytes is the documented fallback if drift bites in practice.
|
||||
|
||||
**P8 — "Deterministic rollout reproducibility" is unattainable on real robots.**
|
||||
No seed controls hardware, sensor noise, or network jitter; RTC's latency-driven trimming is inherently timing-dependent. _Consequence_: the contract is **fully logged + replayable** (§12): recording strategies already persist observations and executed actions; the remote engine adds `(session_id, seq_id, episode_id)` provenance so client datasets join server audit logs mechanically.
|
||||
|
||||
**P9 — v1 has no safety specification.**
|
||||
"Log a warning when the buffer empties" is not a fail-safe for a 300-robot fleet. _Consequence_ (§9): a staleness bound (`max_action_age_s` — never execute an action older than X relative to its source observation), an explicit fallback ladder (`hold` / `repeat_last` / `zero` — zero-command required for future velocity-controlled robots), and a DEAD state that triggers the existing strategy shutdown path (return-to-initial-pose, disconnect) via the same `shutdown_event` mechanism RTC uses (`rtc.py:359-360`).
|
||||
|
||||
**P10 — Capacity must be formula-driven, not "a user decision".**
|
||||
v1 §4 says clients-per-server "is a user decision". With `t` = server time per request, `r` = per-client request rate, `H` = RTC execution horizon, `dt` = control period:
|
||||
`N_max = min( 0.8 / (r·t), (H·dt/2 − RTT_net) / t )`
|
||||
→ ACT @ 20 ms, 1 Hz: ~40 clients/GPU. Pi0 @ 150 ms, 1 Hz: ~5 clients/GPU. 300 robots on Pi0 ≈ 60 GPU pods. _Consequence_: the manifest carries `max_sessions`; the server rejects session opens beyond it (with current load in the reply) so clients retry another replica. Micro-batching is deferred — blocked on a real API issue (`predict_action_chunk` takes a _scalar_ `inference_delay`; batched clients have different delays) — behind a `Scheduler` seam so it can land later without redesign (§8.5).
|
||||
|
||||
**P11 — Discovery ≠ multicast.**
|
||||
Zenoh's multicast scouting does not cross WAN, NAT, or most k8s CNIs. _Consequence_: multicast scouting disabled; clients use static `connect.endpoints` (DNS name of the router) + gossip; presence and liveness come from Zenoh **liveliness tokens** (§6.4), not discovery. "Discovery" for a robot fleet is configuration.
|
||||
|
||||
---
|
||||
|
||||
## 5. System Topology
|
||||
|
||||

|
||||
_(Diagram unchanged from v1 — the topology survives; transport/QoS/session details in it are superseded by §6.)_
|
||||
|
||||
- **Router tier**: one or more `zenohd` routers (k8s Deployment + Service, TLS on 7447). Robots **dial out** to the router (NAT-friendly: labs only need outbound 7447/443). GPU servers join as peers via cluster DNS.
|
||||
- **Server**: one process = one `(model_repo, revision, dtype, device)` on one GPU, pre-warmed from a YAML manifest (**KEPT** from v1, amended: `pin_task: bool` — VLA prompts may vary per session unless pinned).
|
||||
- **Client**: one robot running `lerobot-rollout --inference.type=remote`. Weightless: config-only policy metadata.
|
||||
- **Identity**: `client_uuid` per robot; `session_id` per connection epoch; both in every log line on both sides.
|
||||
|
||||
---
|
||||
|
||||
## 6. Zenoh Design
|
||||
|
||||
All Zenoh claims below were verified against zenoh / zenoh-python 1.x (eclipse-zenoh 1.9.0). Pin: `eclipse-zenoh>=1.9,<2.0`; keep `zenohd` on the same minor as the Python binding. Wheels cover manylinux x86_64/aarch64/armv7l/armv6l + macOS — Raspberry Pi edge clients are covered.
|
||||
|
||||
### 6.1 Key-expression schema
|
||||
|
||||
```
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/<client_uuid>/obs client → server
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/<client_uuid>/action server → client
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/status queryable (capabilities)
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/session queryable (open/validate)
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/<client_uuid>/reset queryable (episode boundary)
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/<client_uuid>/alive liveliness token (client)
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/server/alive liveliness token (server)
|
||||
```
|
||||
|
||||
Rules (hard, enforced by a `sanitize_keyexpr()` helper):
|
||||
|
||||
- Root at the **verbatim chunk** `@lerobot` — verbatim chunks are only matched by identical chunks, so third-party `**` subscribers on a shared router can never scrape the tree.
|
||||
- Sanitize every user-supplied segment (model ids, task strings, uuids): non-empty, no `* $ ? # /`, no leading/trailing/double `/`. A task string containing `/` must be slugified before it becomes a key chunk.
|
||||
- Server subscribes with a **single-depth** wildcard (`.../*/obs`) — never `**` (it would also match `status`, `alive`, …).
|
||||
- v1's `cluster/experiment` prefix segments are dropped from the key schema; they return as free-form `tags` metadata in the session handshake (telemetry/labeling, not routing). Routing topology belongs to deployment (which router you dial), not to key depth.
|
||||
|
||||
### 6.2 Data plane vs. control plane (the rmw_zenoh split)
|
||||
|
||||
- **Data plane = pub/sub** (KEPT from v1): observations up, action chunks down, correlated by `seq_id` in **attachments** (§10.4). Pub/sub rather than query-per-inference because: a timed-out query's late reply is _dropped by the transport_ (wasted inference), whereas a late pub/sub chunk is still mergeable if it answers the latest outstanding seq; and pub/sub leaves room for server-initiated messages (drain notices). The one-in-flight discipline (P5) is enforced in the client worker, not by the transport.
|
||||
- **Control plane = queryables** (request/reply with explicit timeouts; the pattern rmw*zenoh uses for ROS 2 services): `status` (pre-flight capability fetch, 2 s timeout), `session` (open/validate → ack with capabilities + `session_id`), `reset` (episode boundary — \_acknowledged*, so episodic strategies know the server-side episode state is clean). Always pass an explicit `timeout` to `session.get()` — the config default is 10 s, far too long for our watchdogs.
|
||||
- **Episode ordering**: under one-in-flight there is no obs/reset race window in the data plane, but as belt-and-braces the first observation of each episode also carries `episode_start=True` + the new `episode_id` in its header.
|
||||
|
||||
### 6.3 QoS (revised from v1 §6.2 — see P6)
|
||||
|
||||
| Topic | reliability | congestion_control | express | priority | Why |
|
||||
| ------------------ | ----------- | ---------------------- | -------- | ---------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| `obs` | default | **DROP** | false | DATA | Intentional drop already happened at the client's one-slot holder; if the uplink stalls, dropping a frame protects the control loop. |
|
||||
| `action` | RELIABLE | **DROP** (never BLOCK) | **true** | INTERACTIVE_HIGH | Hop-level reliability over TCP; express skips batching for the small (4–50 KB) latency-critical payload; DROP so one dead robot uplink can never stall the server's publish path. Chunk loss is recoverable: the client buffer rides through it. |
|
||||
| control queryables | RELIABLE | default | — | — | Correctness over latency; explicit timeouts bound them. |
|
||||
|
||||
Upgrade path if WAN chunk loss proves material: `AdvancedPublisher`/`AdvancedSubscriber` (zenoh ≥ 1.5) with a small cache + heartbeat-based recovery **on the action topic only**. Hop-by-hop RELIABLE is not end-to-end reliability — Zenoh has no broker persistence; a disconnected subscriber's data is gone. The design assumes this (client state machine, §9).
|
||||
|
||||
### 6.4 Liveliness (presence + watchdogs)
|
||||
|
||||
- Client declares a liveliness token on `.../<client_uuid>/alive`. The server liveliness-subscribes with `history=True`: token appear → ensure session state; token drop → GC the session (mailbox, processor instances) after a grace period.
|
||||
- Server declares `.../server/alive`. The client liveliness-subscribes: on drop → treat as RECONNECTING (§9), hold/fallback per config, re-run the `status`/`session` handshake when the token reappears.
|
||||
- Tune the transport lease down from its default so ungraceful-death detection is seconds, not tens of seconds (verify the default in the pinned version; it is config `transport/link/tx/lease`).
|
||||
- Liveliness cannot detect a _hung-but-connected_ server. The client's per-request timeout (`request_timeout_s`) is the authoritative watchdog — this is the structural fix for legacy BUG-3 (no deadlines on `GetActions`).
|
||||
|
||||
### 6.5 Threading constraints (zenoh-python facts that shape both processes)
|
||||
|
||||
- **No asyncio API** in zenoh-python — both client and server are thread-based. This matches the existing RTC engine pattern exactly.
|
||||
- Each callback-based subscriber spawns a dedicated Python thread; **blocking Zenoh calls inside callbacks are disallowed**. Callbacks must be deposit-only (write a slot, set an event, return).
|
||||
- Channel handlers (`FifoChannel`, `RingChannel`) are Rust-side; `try_recv()` polls without spawning Python threads. `RingChannel(1)` is native latest-only semantics.
|
||||
- No zero-copy path for our payloads (SHM API is `@_unstable` and same-host-only; `ZBytes` copy behavior undocumented). At ~200 KB × a few Hz per robot, one memcpy is irrelevant.
|
||||
|
||||
### 6.6 Router deployment
|
||||
|
||||
- `zenohd` official image as a k8s Deployment (1–N replicas; routers mesh and reroute around failures) behind a `LoadBalancer`/`NodePort` Service exposing TLS 7447. No official Helm chart exists — roll-your-own manifests.
|
||||
- `scouting.multicast.enabled: false`; `scouting.gossip.enabled: true`; clients/servers use static `connect.endpoints`.
|
||||
- **Auth**: mTLS per robot (`transport.link.tls` with `enable_mtls`) + router **ACL** keyed on `cert_common_names`: a robot's cert may only `put` to `@lerobot/**/<its-uuid>/obs` and receive on `.../<its-uuid>/action`. Caveat (flagged): ACL config reloads require a router restart — plan cert/ACL changes as rolling router restarts.
|
||||
- Security review input: the third-party Zenoh protocol security analysis (Census Labs, 2025) should be read before exposing 7447 publicly.
|
||||
|
||||
---
|
||||
|
||||
## 7. The Statelessness Boundary (the load-bearing section)
|
||||
|
||||
**Where the network cut goes.** The local RTC pipeline is:
|
||||
|
||||
```
|
||||
obs (robot-processed dict)
|
||||
→ build_dataset_frame(hw_features, obs, "observation") CLIENT (cheap, hardware-coupled)
|
||||
─────────────────────────── network ───────────────────────────
|
||||
→ prepare_observation_for_inference(...) SERVER (policy-coupled, heavy)
|
||||
→ per-session preprocessor(...) SERVER (stateful within the request)
|
||||
→ policy.predict_action_chunk(obs, inference_delay, prefix) SERVER (pure for allowlisted policies)
|
||||
→ per-session postprocessor(...) SERVER (reads state cached at preprocess)
|
||||
─────────────────────────── network ───────────────────────────
|
||||
→ ActionQueue.merge(original, processed, real_delay, idx_before) CLIENT
|
||||
```
|
||||
|
||||
Three consequences:
|
||||
|
||||
1. **The server needs no cross-request state.** `RelativeActionsProcessorStep` writes `_last_state` at preprocess and the postprocessor reads it back _within the same request_. Per-session pipeline instances + one-request-at-a-time-per-session give correctness with zero persistent state.
|
||||
2. **RTC state stays client-side**, exactly where `RTCInferenceEngine` already keeps it. Each request ships: `inference_delay_steps = ceil(L_max/dt)` (from the client `LatencyTracker`, whose samples are full network-inclusive cycle times — RTT compensation falls out for free), `prefix_model = queue.get_left_over()[:H]`, and `prefix_robot = queue.get_processed_left_over()[:H]` (needed for server-side relative-prefix re-anchoring, mirroring `rtc.py:287-305`). The response returns **both** the model-space and robot-space chunks because `merge` needs both. ≤ `execution_horizon × action_dim` float32 each — a few hundred bytes.
|
||||
3. **G9 dies structurally.** No bespoke client resize (`F.interpolate` in legacy `helpers.py`), no client-side normalization. Clients ship native camera resolution; the server's canonical processor path does everything — serve-time preprocessing is byte-identical to train-time.
|
||||
|
||||
**What the server _does_ hold** (and what it means):
|
||||
|
||||
- Per-session processor instances (cheap; normalization stat tensors shared read-only).
|
||||
- Per-session episode counter + stats. Episode reset = reset the session's pipelines, clear its mailbox. **`policy.reset()` is never called in shared mode** — it is global to the shared policy instance and unnecessary for chunk-pure policies (ACT's ensembler and Pi0/SmolVLA's queues live in `select_action`, not `predict_action_chunk` — verified).
|
||||
- Policies that are _not_ chunk-pure get `serving_mode: exclusive` (§8.3).
|
||||
|
||||
---
|
||||
|
||||
## 8. The Inference Server: `lerobot-policy-server`
|
||||
|
||||
New package `src/lerobot/policy_server/`; console script `lerobot-policy-server --manifest manifest.yaml`.
|
||||
|
||||
### 8.1 Process model — **KEPT** from v1, amended
|
||||
|
||||
One process = one model+task on one GPU, loaded and warmed at startup (`warmup_inferences` dummy forwards; covers torch.compile). Multi-GPU nodes run N processes (`CUDA_VISIBLE_DEVICES` pinning). Dynamic model loading (`SendPolicyInstructions`) is **rejected**: pickle/RCE surface, arbitrary-download surface, and it destroys capacity planning. Amendment: `pin_task: false` (default) lets VLA clients set the task per session; `pin_task: true` rejects mismatched tasks at session open.
|
||||
|
||||
### 8.2 Concurrency (pure threads — no asyncio in zenoh-python)
|
||||
|
||||
```
|
||||
zenoh subscriber (.../*/obs) inference worker (1 thread, owns GPU)
|
||||
deposit-only callback: loop:
|
||||
slots[client_uuid] = sample ──► pick next session with pending obs (RR ring)
|
||||
(per-client latest-only) decode JPEG → per-session preprocess
|
||||
predict_action_chunk(delay, prefix)
|
||||
control queryables (status/session/ per-session postprocess → encode
|
||||
reset): validate, mutate session publisher.put(.../<uuid>/action)
|
||||
registry, reply (publishing from the worker thread is fine)
|
||||
```
|
||||
|
||||
- **Per-client latest-only mailbox**: a wildcard subscriber with a deposit-only callback writing per-client slots (scales to dynamic fleets), or — when the manifest enumerates clients — one `RingChannel(1)` subscriber per client polled via `try_recv()`. Either way: newest observation wins; a superseded request is counted (`superseded_seqs` in the next response) so drops are visible. This deletes legacy BUG-4 (`observations_similar` + `must_go`) by construction — the **client** decides when to request; the server never second-guesses observation content.
|
||||
- **Single inference worker**: torch releases the GIL inside `forward`, callbacks stay responsive. Strict round-robin over sessions with pending observations: each gets exactly one inference per cycle; starvation is structurally impossible. Overload degrades into longer cycle times → larger (but correct) client `delay_steps` → eventually the client staleness bound trips and the robot holds — safe by construction.
|
||||
|
||||
### 8.3 Chunk-stateless allowlist and serving modes
|
||||
|
||||
At startup the server classifies the loaded policy:
|
||||
|
||||
| Class | Policies (verified) | Mode |
|
||||
| --------------- | ------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| chunk-stateless | ACT, Pi0, Pi0.5, SmolVLA (and any policy whose `predict_action_chunk` touches no instance state) | `shared`: N sessions, per-session pipelines, `policy.reset()` never called |
|
||||
| chunk-stateful | Diffusion family (`predict_action_chunk` reads `select_action`-fed `self._queues`) | `exclusive`: `max_sessions=1` enforced; episode reset additionally calls `policy.reset()`; second session open → rejected with a self-explanatory error |
|
||||
| no chunk API | SAC, SARM | refused at startup |
|
||||
|
||||
Implemented as a registry in `policy_server/validation.py`; the cleaner follow-up is a `supports_stateless_chunking` class attribute on `PreTrainedPolicy` (needs a pass over policy families — roadmap §14).
|
||||
|
||||
### 8.4 Session open & capability validation (fail fast, fail loud)
|
||||
|
||||
`session` queryable payload: `client_uuid`, `policy_type`, `fps`, feature summary (post-rename observation feature names + shapes, ordered action keys), `schema_version`, RTC intent, `tags`. Checks:
|
||||
|
||||
| Check | Rule | On mismatch |
|
||||
| -------------------------- | --------------------------------------------------------------- | ---------------------------------------------------------------------------------- |
|
||||
| Action names **and order** | must equal server's `action_feature_names` exactly | **hard reject** — this is the sync-safety contract mapping chunk columns to motors |
|
||||
| Camera names | client set must cover `policy.config.input_features` image keys | hard reject |
|
||||
| Resolution | any H×W accepted (server resizes canonically) | warn if aspect ratio differs from training |
|
||||
| State dim | flattened dim must match | hard reject |
|
||||
| `schema_version` | client within server's supported range | hard reject |
|
||||
| fps | vs. manifest `trained_fps` | warn (reject only when `strict_fps: true`) |
|
||||
| Task | when `pin_task: true`, must equal `default_task` | reject |
|
||||
| RTC | client RTC requires policy RTC kwargs support | downgrade to append mode + warning |
|
||||
| Capacity | `active_sessions < max_sessions` | reject with current load → client retries another replica |
|
||||
|
||||
Reply: `session_id`, model info (repo, revision — consider a checkpoint hash, §15), `action_feature_names`, `chunk_size`, `trained_fps`, `supports_rtc`, `serving_mode`, `warmed_up`, `schema_version`, warnings. **rename_map is applied client-side** so the wire format is canonical policy-feature keys across heterogeneous robots (also a prerequisite for future batching).
|
||||
|
||||
### 8.5 Scheduler seam (micro-batching later, not in v1)
|
||||
|
||||
The worker calls a `Scheduler.select(ready: list[Session]) -> list[Session]`; v1 ships `RoundRobin` (`return ready[:1]`). Cross-session batching is blocked on the policy API (`inference_delay` is scalar; batched clients have different delays/prefixes) — when that lands, a `MicroBatch` scheduler groups same-shape sessions. The seam costs nothing now and prevents a redesign later.
|
||||
|
||||
### 8.6 Manifest
|
||||
|
||||
```yaml
|
||||
model:
|
||||
{
|
||||
repo_or_path: lerobot/pi0_towels,
|
||||
revision: main,
|
||||
dtype: bfloat16,
|
||||
device: cuda,
|
||||
}
|
||||
default_task: "fold the towel"
|
||||
pin_task: false
|
||||
serving_mode: shared # forced to exclusive for chunk-stateful policies
|
||||
max_sessions: 5 # from the §P10 formula: Pi0 @150ms, 1 Hz refresh
|
||||
warmup_inferences: 2
|
||||
strict_fps: false
|
||||
zenoh:
|
||||
connect_endpoints: ["tls/router.gpu-cluster.internal:7447"]
|
||||
tls:
|
||||
{
|
||||
connect_certificate: ...,
|
||||
connect_private_key: ...,
|
||||
root_ca_certificate: ...,
|
||||
}
|
||||
health_port: 9100 # HTTP health + Prometheus metrics
|
||||
debug: { capture_dir: null, capture_max: 256 }
|
||||
```
|
||||
|
||||
Draccus dataclass in `policy_server/manifest.py`; YAML via `--manifest`, individual overrides via CLI.
|
||||
|
||||
---
|
||||
|
||||
## 9. The Edge Client: `RemoteInferenceEngine`
|
||||
|
||||
New file `src/lerobot/rollout/inference/remote.py`, registered `@InferenceEngineConfig.register_subclass("remote")`.
|
||||
|
||||
### 9.1 Threading model
|
||||
|
||||
| Thread | Role |
|
||||
| -------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Main (strategy loop) | `notify_observation(obs)` → lock-protected latest-only slot (identical to `rtc.py` `_obs_holder`). `get_action()` → `ActionQueue.get()` + staleness check. **Never any I/O.** Structurally fixes legacy BUG-1 (blocking send inside the 33 ms loop). |
|
||||
| Network worker (1 daemon thread) | Cycle: wait until `queue_remaining·dt ≤ buffer_time_s` and active → snapshot `idx_before`, prefixes, `delay_steps = ceil(L_max/dt)` → encode (JPEG q=`jpeg_quality`) → `publisher.put(obs, attachment=header)` → await chunk on the action subscriber channel (timeout `request_timeout_s`) → `merge(original, processed, ceil(L/dt), idx_before)` → `latency_tracker.add(L)`. Owns the state machine, reconnects, and control queries. One-in-flight (P5). |
|
||||
| Zenoh action subscriber | `FifoChannel(2)` handler drained by the worker (no Python callback thread on the hot path); liveliness subscriber callback is deposit-only (sets an event). |
|
||||
|
||||
Reused unchanged: `ActionQueue` (`policies/rtc/action_queue.py`), `LatencyTracker`, `ActionInterpolator` (lives in strategies — `interpolation_multiplier` works with remote for free). Deleted concepts: aggregation zoo, `observations_similar`, `must_go`, `TimedObservation`/`TimedAction` pickles.
|
||||
|
||||
### 9.2 Fail-safe state machine
|
||||
|
||||
```
|
||||
ok no chunk for degraded_after_s
|
||||
CONNECTING ─────► STREAMING ───────────────────────────────► DEGRADED
|
||||
│ ▲ ▲ │ queue empty OR max_action_age_s hit │
|
||||
│ │ backoff, │ └───────────────────────────────────► STALLED ◄──┘
|
||||
│ │ re-handshake │ first successful merge │
|
||||
│ └─ RECONNECTING ◄── timeout streak / server liveliness drop ◄─┘
|
||||
│ │ offline > max_offline_s, capability/schema mismatch, auth failure
|
||||
└──────► DEAD (failed=True → shutdown_event → strategy teardown: return-to-initial-pose)
|
||||
```
|
||||
|
||||
- **DEGRADED**: requests failing but the queue still holds actions — the robot keeps executing; chunks _are_ the fault-tolerance buffer (1–3 s of coverage makes blips and clean server drains invisible).
|
||||
- **STALLED**: queue empty or staleness bound hit → apply `fallback`: `hold` (`get_action` → `None`; `send_next_action` already tolerates it), `repeat_last`, or `zero` (required for velocity-controlled robots, where "send nothing" means "keep last velocity").
|
||||
- **Staleness bound** (sync safety): every merge records `(chunk_start_index, t_send)`; `get_action` refuses any action whose source observation is older than `max_action_age_s` (default 3.0 s ≈ 90 steps @ 30 fps). Bounds open-loop execution after a network stall.
|
||||
- **DEAD**: only after `max_offline_s` (default 60 s) or a hard contract violation (capability/schema mismatch on reconnect — e.g. the server restarted with a different model; never execute wrong-model chunks). Uses the exact mechanism RTC uses (`failed=True` + global `shutdown_event`) so existing teardown runs unchanged.
|
||||
- **Watchdog layering**: per-request timeout (hung server — the BUG-3 fix) → server liveliness token (dead server/router) → staleness bound (the robot-side invariant that holds regardless of why data stopped).
|
||||
- **Pause/resume (DAgger)**: `pause()` stops the worker publishing (slot keeps refreshing, ignored); queue intact — parity with `RTCInferenceEngine.pause`. DAgger's existing `interpolator.reset(); engine.reset(); engine.resume()` sequence works unchanged.
|
||||
- **`reset()` (episode boundary)**: clear `ActionQueue` + staleness bookkeeping, bump `episode_id`, fire the acked `reset` query (1 s timeout, failure logged — the server has nothing it _must_ do thanks to per-request statelessness), flag `episode_start` on the next observation. `LatencyTracker` intentionally survives reset (latency is episode-invariant; parity with local RTC).
|
||||
- **`ready`** = session opened ∧ capabilities validated ∧ server `warmed_up`. First-chunk gating is implicit (`get_action` → `None` until the first merge).
|
||||
|
||||
### 9.3 Weightless client — exact integration changes
|
||||
|
||||
- `rollout/context.py`: `PolicyContext.{policy, preprocessor, postprocessor}` become `| None`. For remote configs, skip step 1 (weight load / PEFT / `.to(device)` / torch.compile / `init_rtc_processor`) and step 6 (`make_pre_post_processors`). Verified safe: strategies only consume `ctx.policy.inference`. Keep steps 2–5 (robot processors, hardware, features, dataset) — they are robot-derived. Keep the visual pre-flight check (`context.py:309-324`): `--policy.path` already loads config-only (`rollout/configs.py:324-328`, no weight download) and failing before dialing the server is free. `use_torch_compile` / explicit `--device` → warn-and-ignore for remote.
|
||||
- `rollout/inference/factory.py`: signature loosens to `policy: PreTrainedPolicy | None` (+ `policy_config: PreTrainedConfig`); `sync`/`rtc` branches guard `policy is None`; the `remote` branch lazy-imports (`eclipse-zenoh` stays an optional extra).
|
||||
- The authoritative validation moves to session open (§8.4); the local check becomes a fast-fail convenience.
|
||||
|
||||
### 9.4 Config
|
||||
|
||||
```python
|
||||
@InferenceEngineConfig.register_subclass("remote")
|
||||
@dataclass
|
||||
class RemoteInferenceConfig(InferenceEngineConfig):
|
||||
connect_endpoint: str = "tls/localhost:7447" # zenoh router endpoint
|
||||
tls_cert: str | None = None; tls_key: str | None = None; tls_ca: str | None = None
|
||||
client_uuid: str = "" # "" → uuid4 at start()
|
||||
jpeg_quality: int = 90 # 0 = raw (LAN/debug)
|
||||
buffer_time_s: float = 0.5 # send next obs when queue playback ≤ this (v1 G14) — KEPT
|
||||
max_action_age_s: float = 3.0 # staleness bound (safety)
|
||||
degraded_after_s: float = 1.0
|
||||
request_timeout_s: float = 5.0
|
||||
reconnect_initial_backoff_s: float = 0.5
|
||||
reconnect_max_backoff_s: float = 10.0
|
||||
max_offline_s: float = 60.0
|
||||
fallback: FallbackBehavior = FallbackBehavior.HOLD # hold | repeat_last | zero
|
||||
rtc: RTCConfig = field(default_factory=RTCConfig) # enabled → replace mode; horizon caps prefix
|
||||
tags: dict[str, str] = field(default_factory=dict) # ex-cluster/experiment labels
|
||||
```
|
||||
|
||||
```bash
|
||||
# Remote RTC + sentry recording (the reproducibility path)
|
||||
lerobot-rollout \
|
||||
--strategy.type=sentry \
|
||||
--policy.path=lerobot/pi0_towels \ # config-only: no weights downloaded
|
||||
--inference.type=remote \
|
||||
--inference.connect_endpoint=tls/router.gpu-cluster.internal:7447 \
|
||||
--inference.rtc.execution_horizon=10 \
|
||||
--robot.type=so100_follower --robot.port=/dev/ttyACM0 \
|
||||
--robot.cameras="{front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--dataset.repo_id=user/rollout_fleet_a --dataset.single_task="fold the towel"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 10. Wire Schema
|
||||
|
||||
### 10.1 Payload anatomy & rates — **KEPT** (JPEG) with numbers
|
||||
|
||||
Upstream per request: joints (24–128 B) + JPEG frames (480p q90 ≈ 40–90 KB each; 720p ≈ 110–230 KB) + RTC prefixes (≤ a few KB) → 60–450 KB depending on cameras. Downstream: `2 × chunk_size × action_dim × 4 B` + metadata → 3–50 KB. Effective request rate is self-clocked by `buffer_time_s` to ~1–4 Hz per robot (not the 30 Hz control rate). 300 robots ≈ 0.3–10 Mbps each — the wire is never the bottleneck; bandwidth budgeting is about camera count/resolution, and each GPU pod only ever sees its own ≤ `max_sessions` clients. Zenoh fragments >64 KiB payloads transparently; multi-MB messages are fine.
|
||||
|
||||
### 10.2 Attachment header (fixed-layout, packed little-endian — parsed without touching the body)
|
||||
|
||||
| Field | Type | Notes |
|
||||
| ---------------- | ---- | -------------------------------------------------------------- |
|
||||
| `schema_version` | u16 | negotiated at session open |
|
||||
| `msg_type` | u8 | OBS / CHUNK / EVENT |
|
||||
| `seq_id` | u64 | per-session monotonic; echoed in the chunk |
|
||||
| `episode_id` | u32 | bumped by `reset()` |
|
||||
| `client_mono_ns` | i64 | client `monotonic_ns()`; **opaque to the server, echoed back** |
|
||||
| `session_epoch` | u32 | bumped per (re)connect; stale-epoch chunks dropped |
|
||||
|
||||
### 10.3 msgpack bodies
|
||||
|
||||
**ObservationMsg** (client → server): `state: {names_ref, data: f32 LE bytes}`, `images: {name: {codec: jpeg|raw, bytes, (h,w,c) if raw}}`, `task: str`, `inference_delay_steps: int`, `prefix_model: tensor?`, `prefix_robot: tensor?` (tensors = raw LE bytes + dtype + shape), `episode_start: bool`.
|
||||
**ActionChunkMsg** (server → client): `seq_id_echo`, `client_mono_ns_echo`, `chunk_model: tensor`, `chunk_robot: tensor`, `queue_wait_ms: f32`, `inference_ms: f32`, `superseded_seqs: u32`, `server_load: f32`.
|
||||
**Status / SessionOpen / SessionAck / ResetMsg**: as specified in §8.4.
|
||||
|
||||
### 10.4 Schema discipline (P7)
|
||||
|
||||
`schema_version` gates at handshake; evolution is additive-only (new optional msgpack keys; unknown keys ignored); attachment layout changes require a version bump; golden codec round-trip tests (tensor exactness, JPEG RGB-channel-order regression — a silent BGR swap poisons every VLA in the fleet) are part of the test suite. **No pickle anywhere** — KEPT from v1 and now structural: nothing in the schema can carry code.
|
||||
|
||||
---
|
||||
|
||||
## 11. Latency Budget & the Clock Iron Rule
|
||||
|
||||
| Stage | LAN | WAN (50 ms RTT) |
|
||||
| ------------------------------ | --------------- | --------------- |
|
||||
| JPEG encode ×3 (edge CPU) | 2–9 ms | 2–9 ms |
|
||||
| Serialize | <1 ms | <1 ms |
|
||||
| Uplink (tx + ½RTT) | ~2 ms | ~54 ms |
|
||||
| Server queue wait | 0 → 1×inference | 0 → 1×inference |
|
||||
| Decode + canonical preprocess | 4–10 ms | 4–10 ms |
|
||||
| **Inference** | **15–150 ms** | **15–150 ms** |
|
||||
| Postprocess + downlink + merge | ~2 ms | ~27 ms |
|
||||
| **Total (Pi0-class)** | **~110–175 ms** | **~190–250 ms** |
|
||||
|
||||
Inference is 60–85 % of end-to-end on LAN; the entire transport+serialization stack is <10 ms. WAN adds propagation + uplink bandwidth — identical under any transport. At 30 fps this lands `delay_steps` ≈ 4–8, comfortably inside RTC execution horizons: WAN degrades smoothness parameters, never correctness. _This table is the standing answer to transport-performance bikeshedding._
|
||||
|
||||
**Clock iron rule** (P4): wall-clock instants never cross machines. Client stamps `monotonic_ns`, the server echoes it opaquely; `RTT = now − echo`. The server reports only **durations** (`queue_wait_ms`, `inference_ms`) measured on its own monotonic clock; `network_time = RTT − queue_wait − inference` for diagnostics. The schema has no field in which a foreign wall-clock instant can be compared — the legacy `time.time()` bug is unrepresentable.
|
||||
|
||||
---
|
||||
|
||||
## 12. Reproducibility & Audit (P8)
|
||||
|
||||
The contract is **fully logged + replayable**, not "deterministic":
|
||||
|
||||
- **Client = source of truth.** Recording strategies already persist observations + executed actions to `LeRobotDataset`. The remote engine logs, per executed action, the `(session_id, seq_id, episode_id)` of its source chunk plus the echoed `queue_wait_ms`/`inference_ms` (dataset-extras columns are a follow-up; client logs in v1).
|
||||
- **Server audit line per request** (structured JSON): `{ts, session_id, client_uuid, seq_id, episode_id, queue_wait_ms, inference_ms, chunk_range, superseded_seqs, outcome}`.
|
||||
- **Optional bounded capture**: `debug.capture_dir` writes a ring of request/response pairs (safetensors) for byte-exact offline replay through the same server pipeline.
|
||||
- **Runbook — "robot #217 stuttered at 14:03"**: (1) Grafana `session_staleness{client="217"}` — spike ⇒ server side, flat ⇒ client/network. (2) Server side: audit lines — `queue_wait_ms` rising across _all_ sessions ⇒ overloaded replica (check `active_sessions` vs `max_sessions`); `superseded_seqs` streak on 217 only ⇒ that client over-requesting; `outcome=error` ⇒ adjacent stack trace. (3) Client side: state-machine transitions + reconnects in the client log; dataset rows show which seq's chunk was executing and where `None` ticks occurred. Every hop shares `(session_id, seq_id)` — the join is mechanical.
|
||||
|
||||
---
|
||||
|
||||
## 13. Integration & Migration Plan
|
||||
|
||||
### 13.1 New
|
||||
|
||||
| Path | Content |
|
||||
| --------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `src/lerobot/policy_server/{__init__,schema,codec,manifest,session,scheduler,validation,server}.py` | wire schema constants, msgpack/attachment codecs, manifest dataclasses, `Session` + mailbox, `Scheduler` seam, capability rules + chunk-stateless registry, zenoh servicer + inference worker + drain + HTTP health/metrics |
|
||||
| `src/lerobot/rollout/inference/remote.py` | `RemoteInferenceEngine` (~600 lines; mirrors `rtc.py` structure) |
|
||||
| `src/lerobot/scripts/lerobot_policy_server.py` + `[project.scripts]` entry | thin `main()` |
|
||||
| `docker/Dockerfile.policy-server` | CUDA runtime base + uv; manifest via ConfigMap |
|
||||
| `docs/source/remote_inference.mdx` (+ `_toctree.yml`) | replaces `async.mdx` |
|
||||
|
||||
### 13.2 Modified
|
||||
|
||||
`rollout/inference/factory.py` (config + Optional-typed signature + lazy import) · `rollout/context.py` (weightless branch) · `rollout/inference/__init__.py` · `scripts/lerobot_rollout.py` docstring · `pyproject.toml`: `[async]` extra becomes `eclipse-zenoh>=1.9,<2.0` + `msgpack` (grpcio/matplotlib leave it; grpcio remains under `[hilserl]`/`dev` for the RL stack).
|
||||
|
||||
### 13.3 Removed — same landing PR
|
||||
|
||||
`src/lerobot/async_inference/` · `tests/async_inference/` · `docs/source/async.mdx` + its `_toctree.yml` entry · the `AsyncInference` service + `Observation`/`Actions`/`PolicySetup` messages from `src/lerobot/transport/services.proto` (regenerate pb2; **`LearnerService` untouched** — `transport/` is shared with HIL-SERL (`src/lerobot/rl/`); the RL test suite gates this change).
|
||||
|
||||
### 13.4 Legacy config → successor mapping
|
||||
|
||||
| Legacy (`RobotClientConfig`/`PolicyServerConfig`) | Successor |
|
||||
| ------------------------------------------------- | ---------------------------------------------------------- |
|
||||
| `server_address` | `--inference.connect_endpoint` (zenoh router) |
|
||||
| `policy_type`, `pretrained_name_or_path` | `--policy.path` (config-only) + server manifest |
|
||||
| `chunk_size_threshold` (0–1 ratio) | `--inference.buffer_time_s` (seconds) |
|
||||
| `actions_per_chunk` | server manifest (validated at session open) |
|
||||
| `aggregate_fn_name` + `AGGREGATE_FUNCTIONS` | **dropped** — `ActionQueue` replace/append |
|
||||
| `policy_device`, `client_device` | **dropped** — server concern / chunks arrive CPU f32 |
|
||||
| `debug_visualize_queue_size` | **dropped** — Rerun (`--display_data`) + engine stats |
|
||||
| `PolicyServerConfig.{host,port}` | manifest `zenoh.connect_endpoints` |
|
||||
| `inference_latency`, `obs_queue_timeout` | **dropped** — latency client-measured; no server obs queue |
|
||||
| `SendPolicyInstructions` | **dropped** — MaaS manifest + session validation |
|
||||
| `observations_similar` / `must_go` | **dropped** — latest-only slots + client send gate |
|
||||
| pickle envelopes | **dropped** — msgpack + attachment headers |
|
||||
|
||||
### 13.5 Legacy bugs/gaps → structural resolution
|
||||
|
||||
BUG-1 → worker thread owns all I/O. BUG-2 → aggregation deleted; `ActionQueue` is internally locked. BUG-3 → per-request timeout + liveliness. BUG-4 → client-side send gating; server newest-wins. G1 → per-session registry. G2 → manifest. G4 → msgpack+attachments. G5 → monotonic echo + `delay_steps`. G7 → recording strategies. G8 → mTLS + ACL. G9 → server-side canonical processors. G11 → `status` queryable. G12 → Prometheus + audit logs. G13 → `lerobot-policy-server` console script. G14 → `buffer_time_s`.
|
||||
|
||||
### 13.6 Tests
|
||||
|
||||
- **Unit**: codec round-trips (tensor exact; JPEG RGB-order regression), capability-validation matrix (§8.4 as parametrized cases), scheduler fairness + newest-wins supersession (mock policy with configurable sleep), manifest parsing, key-expr sanitization.
|
||||
- **Loopback integration** (CPU, fast CI): client+server in one process over zenoh peer-to-peer (or a localhost `zenohd` started by the fixture), tiny-ACT, fake 2-camera robot, N=8 concurrent sessions. The headline regression: two sessions with different joint states must not cross-contaminate `RelativeActionsProcessorStep` postprocessing — the test that proves the multi-tenancy claim.
|
||||
- **Chaos**: kill the server mid-episode → client returns `None`, never raises into the control loop, `failed` stays False within `max_offline_s`, resumes on restart; `docker kill zenohd` → liveliness flap → safe state → re-handshake (explicitly tests re-declaration behavior, flagged unverified upstream); SIGTERM drain → in-flight chunk completes, clients reconnect invisibly.
|
||||
- **Golden parity**: remote RTC vs local `RTCInferenceEngine` on identical observation sequences → byte-identical merged queues (the re-anchoring contract test). Gate for any real-robot remote-RTC use.
|
||||
|
||||
---
|
||||
|
||||
## 14. Roadmap
|
||||
|
||||
1. **PR1 — schema & codecs** (no torch deps): `policy_server/{schema,codec,manifest}.py`, key-expr sanitizer, golden codec tests.
|
||||
2. **PR2 — server core**: session registry, scheduler, validation/allowlist, inference worker with mock policy, loopback harness.
|
||||
3. **PR3 — client engine**: `RemoteInferenceEngine`, factory/context weightless integration, loopback integration + chaos + golden-parity tests.
|
||||
4. **PR4 — ops & docs**: Dockerfile, health/metrics, drain, ACL examples, `remote_inference.mdx`, rollout docstring.
|
||||
5. **Landing PR — legacy deletion**: remove `async_inference/` + tests + docs + proto service (RL suite gates), `[async]` extra swap.
|
||||
6. **Pre-release field validation**: one real robot on a lossy network (watchdog default tuning); JPEG q90 vs raw A/B on one policy (train/serve shift).
|
||||
7. **Future**: micro-batching (needs per-sample `inference_delay` across policy families), client-side downscale-to-policy-resolution (config-only shapes make it possible), Advanced Pub/Sub on the action topic, per-robot quotas, dataset provenance columns, `supports_stateless_chunking` attribute upstreamed to policy classes.
|
||||
|
||||
---
|
||||
|
||||
## 15. Open Risks
|
||||
|
||||
| Risk | Mitigation / decision needed |
|
||||
| ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Re-anchoring parity (server-side relative-prefix re-anchor vs `rtc.py`) | Golden parity test (§13.6) is a hard gate before robot use; likely failure mode is normalizer dtype/device drift |
|
||||
| First-chunk over-trim when idle: `merge` trims `ceil(L/dt)` even when nothing was consumed (queue empty at episode start) — wasteful at network latencies (600 ms ⇒ 18 steps) | Proposed clamp `real_delay = min(real_delay, last_index - idx_before)` touches the shared `ActionQueue` used by local RTC — needs sign-off + regression tests |
|
||||
| JPEG train/serve distribution shift | Unmeasured; A/B before locking q90 default (roadmap §14.6) |
|
||||
| Watchdog defaults untuned (`request_timeout_s=5`, `degraded_after_s=1`, `max_action_age_s=3`) | Field validation on wired and Wi-Fi; consider named profiles |
|
||||
| Capability check can pass while semantics differ (different finetune, different normalization stats, identical feature names) | Add checkpoint hash/revision pinning to SessionAck — decide in PR2 |
|
||||
| zenoh-python long-session maturity: re-declaration after router restart partially verified; SHM unstable; no asyncio | Chaos tests own this; thread-based design avoids the asyncio gap entirely |
|
||||
| Router ACL reload requires restart | Operational runbook: cert/ACL changes = rolling router restart |
|
||||
| `fallback=zero` has no consumer until velocity actions land in rollout (only `.pos` features routed today) | Validate the enum against robot capabilities when velocity support lands |
|
||||
| Per-client mailbox memory under fleet-scale wildcard subscription | One decoded-obs slot per client is small; add an LRU GC tied to liveliness drops |
|
||||
@@ -0,0 +1,82 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This Dockerfile builds a GPU inference pod for `lerobot-policy-server`
|
||||
# (remote inference over Zenoh). It starts from an NVIDIA CUDA base image;
|
||||
# the cu128 PyTorch wheels bundle their own CUDA runtime (driver floor 570.86,
|
||||
# see pyproject.toml [tool.uv]).
|
||||
|
||||
# docker build -f docker/Dockerfile.policy-server -t lerobot-policy-server .
|
||||
# docker run --gpus all -v ./server.yaml:/etc/lerobot/server.yaml lerobot-policy-server
|
||||
#
|
||||
# Extra policy-family dependencies (e.g. pi0/smolvla need transformers) can be
|
||||
# added at build time:
|
||||
# docker build -f docker/Dockerfile.policy-server \
|
||||
# --build-arg LEROBOT_EXTRAS="async pi0" -t lerobot-policy-server .
|
||||
|
||||
# Configure the base image (same CUDA family as Dockerfile.internal)
|
||||
ARG CUDA_VERSION=12.8.1
|
||||
ARG OS_VERSION=24.04
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
|
||||
|
||||
# Define Python version and lerobot extras arguments
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ARG LEROBOT_EXTRAS="async"
|
||||
|
||||
# Configure environment variables
|
||||
ENV DEBIAN_FRONTEND=noninteractive \
|
||||
PATH=/lerobot/.venv/bin:$PATH
|
||||
|
||||
# Install system dependencies and uv (as root).
|
||||
# Kept lean: no hardware/teleop libraries — this image only serves policies.
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
git curl ca-certificates libglib2.0-0 ffmpeg \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
|
||||
&& mv /root/.local/bin/uv /usr/local/bin/uv \
|
||||
&& useradd --create-home --shell /bin/bash user_lerobot \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create application directory and set permissions
|
||||
WORKDIR /lerobot
|
||||
RUN chown -R user_lerobot:user_lerobot /lerobot
|
||||
|
||||
# Switch to the non-root user
|
||||
USER user_lerobot
|
||||
|
||||
# Model checkpoints are cached under HF_HOME — mount it as a volume
|
||||
# (or a PVC in Kubernetes) so warm restarts skip the Hub download.
|
||||
ENV HOME=/home/user_lerobot \
|
||||
HF_HOME=/home/user_lerobot/.cache/huggingface \
|
||||
HF_LEROBOT_HOME=/home/user_lerobot/.cache/huggingface/lerobot \
|
||||
TORCH_HOME=/home/user_lerobot/.cache/torch \
|
||||
TRITON_CACHE_DIR=/home/user_lerobot/.cache/triton
|
||||
|
||||
# Create the virtual environment (Python provisioned by uv)
|
||||
RUN uv venv --python ${PYTHON_VERSION}
|
||||
|
||||
# Install lerobot from the build context with the async extra
|
||||
# (eclipse-zenoh + msgpack — see pyproject.toml [project.optional-dependencies])
|
||||
COPY --chown=user_lerobot:user_lerobot setup.py pyproject.toml uv.lock README.md MANIFEST.in ./
|
||||
COPY --chown=user_lerobot:user_lerobot src/ src/
|
||||
|
||||
RUN uv sync --locked --no-cache $(printf -- '--extra %s ' ${LEROBOT_EXTRAS})
|
||||
|
||||
# HTTP health + Prometheus metrics (manifest `health_port`, 0 disables)
|
||||
EXPOSE 9100
|
||||
|
||||
# The manifest is typically mounted as a ConfigMap (Kubernetes) or a bind
|
||||
# mount (docker run -v) at /etc/lerobot/server.yaml; any field can also be
|
||||
# overridden on the command line, e.g. --model.repo_or_path=lerobot/pi0_towels
|
||||
ENTRYPOINT ["lerobot-policy-server"]
|
||||
CMD ["--manifest", "/etc/lerobot/server.yaml"]
|
||||
@@ -45,8 +45,6 @@
|
||||
title: Language Columns and Recipes
|
||||
- local: tools
|
||||
title: Tools
|
||||
- local: annotation_pipeline
|
||||
title: Annotation Pipeline
|
||||
- local: video_encoding_parameters
|
||||
title: Video encoding parameters
|
||||
- local: streaming_video_encoding
|
||||
@@ -89,8 +87,8 @@
|
||||
- sections:
|
||||
- local: inference
|
||||
title: Policy Deployment (lerobot-rollout)
|
||||
- local: async
|
||||
title: Use Async Inference
|
||||
- local: remote_inference
|
||||
title: Remote Inference (lerobot-policy-server)
|
||||
- local: rtc
|
||||
title: Real-Time Chunking (RTC)
|
||||
title: "Inference"
|
||||
|
||||
@@ -1,281 +0,0 @@
|
||||
# Annotation Pipeline
|
||||
|
||||
`lerobot-annotate` watches each episode's video with a vision-language
|
||||
model (VLM) and writes natural-language annotations back into your
|
||||
dataset. It fills the two language columns from the
|
||||
[Language Columns and Recipes](./language_and_recipes) page —
|
||||
`language_persistent` and `language_events` — straight into
|
||||
`data/chunk-*/file-*.parquet`.
|
||||
|
||||
In short: point it at a LeRobot dataset, and it adds subtasks, plans,
|
||||
memory, interjections, speech, and visual Q&A that a policy can be
|
||||
trained on.
|
||||
|
||||
## How it fits together
|
||||
|
||||
```text
|
||||
your dataset lerobot-annotate
|
||||
(LeRobot v3.1)
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────┐
|
||||
│ read episodes │
|
||||
└──────────────────────────┬──────────────────────────┘
|
||||
│
|
||||
┌────────────────────┼────────────────────┐
|
||||
▼ ▼ ▼
|
||||
┌──────────┐ ┌───────────────┐ ┌──────────┐ one shared Qwen-VL
|
||||
│ plan │ │ interjections │ │ vqa │ ◀── server (vLLM, OpenAI
|
||||
└────┬─────┘ └───────┬───────┘ └────┬─────┘ API) drives all three
|
||||
└────────────────────┼─────────────────────┘
|
||||
│ each module stages raw JSONL
|
||||
▼ into .annotate_staging/
|
||||
┌─────────────────┐
|
||||
│ validator │ ◀── checks everything
|
||||
└────────┬────────┘
|
||||
▼
|
||||
┌─────────────────┐
|
||||
│ writer │
|
||||
└────────┬────────┘
|
||||
▼
|
||||
data/chunk-*/file-*.parquet
|
||||
(+ meta/info.json tools)
|
||||
```
|
||||
|
||||
Three modules (`plan`, `interjections`, `vqa`) all talk to **one** shared
|
||||
VLM. Each module stages its output to disk, a validator checks it, and a
|
||||
single writer rewrites the dataset shards in place.
|
||||
|
||||
## What the pipeline produces
|
||||
|
||||
Each module emits a few kinds of annotation ("styles"), routed to one of
|
||||
the two language columns:
|
||||
|
||||
| Style / atom | Column | Module |
|
||||
| ------------------------------------------- | --------------------- | --------------- |
|
||||
| `subtask` (Pi0.7-style "how, not what") | `language_persistent` | `plan` |
|
||||
| `plan` (initial + refresh on interjection) | `language_persistent` | `plan` |
|
||||
| `memory` (MEM-style compression) | `language_persistent` | `plan` |
|
||||
| `task_aug` (rephrasings of the task) | `language_persistent` | `plan` |
|
||||
| `interjection` | `language_events` | `interjections` |
|
||||
| speech tool-call atom (`style=null`, `say`) | `language_events` | `interjections` |
|
||||
| `vqa` (user / assistant pair) | `language_events` | `vqa` |
|
||||
|
||||
### How subtasks are generated
|
||||
|
||||
The `plan` module doesn't ask the VLM for subtasks in one shot. Instead
|
||||
it uses a two-step **describe → segment** flow:
|
||||
|
||||
1. **Describe** — the VLM narrates only what it actually sees in the
|
||||
chosen camera (no guessing about the task).
|
||||
2. **Segment** — that description is fed back in, and the VLM splits the
|
||||
episode into consecutive atomic subtasks.
|
||||
|
||||
The resulting spans are then stitched into a gap-free, full-episode
|
||||
cover, so **every frame has exactly one active subtask**. See
|
||||
[`run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py)
|
||||
for the production settings (single camera, embedded frames, windowed
|
||||
subtask generation).
|
||||
|
||||
### Tools
|
||||
|
||||
The writer does **not** add a `tools` column to the parquet. The tool
|
||||
catalog lives in `meta/info.json["tools"]` instead (see [Tools](./tools)).
|
||||
After every run, the pipeline makes sure the canonical `say` schema is in
|
||||
that list, keeping any tools you declared beforehand.
|
||||
|
||||
Want to add your own tool? Edit `meta/info.json["tools"]` directly — the
|
||||
pipeline preserves whatever is already there. That makes the tool visible
|
||||
to the chat template, so the model can learn to _generate_ the call. The
|
||||
runtime layer that actually _executes_ a generated call (the `Tool`
|
||||
protocol / `TOOL_REGISTRY` under `src/lerobot/tools/`) is not part of
|
||||
this PR — the [Tools](./tools) doc marks those pieces as
|
||||
not-yet-implemented.
|
||||
|
||||
## Running on Hugging Face Jobs
|
||||
|
||||
Annotation runs on [Hugging Face Jobs](https://huggingface.co/docs/hub/en/jobs).
|
||||
The repo ships a launcher script you copy and tweak for your dataset:
|
||||
|
||||
```bash
|
||||
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
|
||||
```
|
||||
|
||||
[`run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py)
|
||||
starts a single-GPU `h200` job (bump it to `h200x4` for big datasets)
|
||||
that:
|
||||
|
||||
1. installs `lerobot` (from `main`) plus the annotation extras,
|
||||
2. boots one vLLM server per GPU (using the `vllm/vllm-openai` image) and
|
||||
drives it over the OpenAI-compatible API,
|
||||
3. runs the `plan` / `interjections` / `vqa` modules across the dataset
|
||||
with `lerobot-annotate`,
|
||||
4. with `--push_to_hub=true`, uploads the result to `--new_repo_id` (or
|
||||
back to `--repo_id` in place if you leave that unset).
|
||||
|
||||
To use a different dataset, model, or hub repo, edit the `CMD` block in
|
||||
the script. Every flag there maps directly to a `lerobot-annotate` flag
|
||||
(run `lerobot-annotate --help` for the full list).
|
||||
|
||||
## Key options
|
||||
|
||||
These are the flags you'll reach for most often. Run
|
||||
`lerobot-annotate --help` for everything else; the defaults are tuned for
|
||||
short manipulation episodes.
|
||||
|
||||
### Dataset in / out
|
||||
|
||||
| Flag | Default | What it does |
|
||||
| ----------------- | ------- | ----------------------------------------------------------------------- |
|
||||
| `--repo_id` | — | Hub dataset to annotate (downloaded if `--root` unset). |
|
||||
| `--root` | — | Annotate a local dataset directory instead. |
|
||||
| `--new_repo_id` | — | Push the result to a new repo (leaves the source repo untouched). |
|
||||
| `--push_to_hub` | `false` | Upload after annotating (to `--new_repo_id`, else back to `--repo_id`). |
|
||||
| `--only_episodes` | all | Annotate just these episode indices (handy for a test run). |
|
||||
| `--seed` | `1729` | Seeds the RNGs that pick interjection timestamps + VQA question types. |
|
||||
|
||||
### Which modules run
|
||||
|
||||
Every module is on by default and can be toggled independently (set to
|
||||
`false` to skip it, e.g. to iterate on one module at a time):
|
||||
|
||||
| Flag | Default | Turns off |
|
||||
| ------------------------- | ------- | ----------------------------------- |
|
||||
| `--plan.enabled` | `true` | subtasks + plan + memory + task_aug |
|
||||
| `--interjections.enabled` | `true` | interjections + speech atoms |
|
||||
| `--vqa.enabled` | `true` | the VQA pairs |
|
||||
|
||||
### The VLM (`--vlm.*`)
|
||||
|
||||
| Flag | Default | What it does |
|
||||
| -------------------------- | ------------------ | ----------------------------------------------------------------------------------- |
|
||||
| `--vlm.model_id` | `Qwen/Qwen3.6-27B` | The model to serve and prompt. |
|
||||
| `--vlm.camera_key` | first `images.*` | Which camera every prompt is grounded on. |
|
||||
| `--vlm.serve_command` | auto | The exact `vllm serve …` command (set TP size, GPU memory, `--max-model-len` here). |
|
||||
| `--vlm.parallel_servers` | `1` | Independent servers for round-robin routing (one per GPU). |
|
||||
| `--vlm.num_gpus` | `0` | GPUs per server (`0` = one each). |
|
||||
| `--vlm.client_concurrency` | `16` | In-flight requests across all servers. |
|
||||
| `--vlm.max_new_tokens` | `512` | Generation cap per call. |
|
||||
| `--vlm.temperature` | `0.2` | Sampling temperature. |
|
||||
|
||||
### Subtasks / plan / memory (`--plan.*`)
|
||||
|
||||
| Flag | Default | What it does |
|
||||
| ------------------------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `--plan.frames_per_second` | `1.0` | How densely the episode video is sampled. |
|
||||
| `--plan.max_video_frames` | `32` | Hard cap on frames per call (context-budget guard — don't exceed ~32 for a 32k context). |
|
||||
| `--plan.subtask_window_seconds` | `0` | Split long episodes into fixed windows for constant frame density (`0` = whole episode). |
|
||||
| `--plan.plan_max_steps` | `8` | Upper bound on subtasks per episode. |
|
||||
| `--plan.subtask_describe_first` | `true` | Run the describe→segment grounding pass (best subtask quality; +1 call/episode). |
|
||||
| `--plan.emit_plan` | `true` | Emit the numbered `plan` rows (`false` = subtasks + memory only). |
|
||||
| `--plan.n_task_rephrasings` | `10` | How many `task_aug` rephrasings to emit (`0` disables). |
|
||||
| `--plan.derive_task_from_video` | `if_short` | Use the dataset task as-is (`off`), only when it's missing/short (`if_short`), or always re-derive from video (`always`). |
|
||||
| `--plan.use_video_url` | `false` | Send a server-side video clip instead of embedded frames. |
|
||||
|
||||
### Interjections + VQA
|
||||
|
||||
| Flag | Default | What it does |
|
||||
| ----------------------------------------------- | ------- | ---------------------------------------------------------- |
|
||||
| `--interjections.max_interjections_per_episode` | `3` | Cap on interjection/speech pairs per episode. |
|
||||
| `--vqa.vqa_emission_hz` | `1.0` | How often VQA pairs are emitted. |
|
||||
| `--vqa.restrict_to_default_camera` | `false` | Ground VQA only on `--vlm.camera_key` (else every camera). |
|
||||
| `--executor.episode_parallelism` | `16` | Episodes processed concurrently within each phase. |
|
||||
|
||||
## Contributing new modules
|
||||
|
||||
The pipeline is built to grow, and **contributions are very welcome** —
|
||||
a brand-new module (say, trajectory traces or affordances), a new prompt
|
||||
template, a smarter grounding flow, or quality fixes to the existing
|
||||
`plan` / `interjections` / `vqa` modules.
|
||||
|
||||
Every module lives under
|
||||
`src/lerobot/annotations/steerable_pipeline/modules/`, shares the VLM
|
||||
client and the keyframe cache, writes its raw output to the staging
|
||||
tree, and plugs into the executor as its own phase. Got an idea? Open an
|
||||
issue or PR on [the repo](https://github.com/huggingface/lerobot).
|
||||
|
||||
## How recipes consume the output
|
||||
|
||||
The annotations are meant to be read by recipes (see
|
||||
[Language Columns and Recipes](./language_and_recipes)). Typically:
|
||||
|
||||
- low-level / high-level / memory-update branches read
|
||||
`subtask` / `plan` / `memory` from `language_persistent`.
|
||||
- an interjection-response branch reads `interjection` events plus the
|
||||
paired speech atom (merged into one assistant turn via `tool_calls_from`)
|
||||
and the matching `plan` refresh at the same timestamp.
|
||||
- a VQA branch reads the `(vqa, user)` and `(vqa, assistant)` pairs from
|
||||
`language_events`.
|
||||
|
||||
## Why state and events are split
|
||||
|
||||
Two ideas shape the design:
|
||||
|
||||
1. **Persistent state vs. exact events.** Persistent rows (`subtask`,
|
||||
`plan`, `memory`) apply to the whole episode and answer "what's true
|
||||
right now?". Event rows (`interjection`, `vqa`, speech) appear only on
|
||||
the one frame whose timestamp matches. Timestamps are copied straight
|
||||
from the source parquet — never recomputed in floating point.
|
||||
2. **One VLM pass.** All three modules share a single VLM client (the
|
||||
OpenAI-compatible client talking to the job's vLLM server), so you pay
|
||||
for one model load per dataset, not three.
|
||||
|
||||
## Re-running a single module
|
||||
|
||||
Each module stages its raw output to
|
||||
`<root>/.annotate_staging/episode_{N:06d}/<module>.jsonl`. This makes
|
||||
prompt iteration cheap: re-running one module overwrites only its own
|
||||
JSONL, then the writer recomposes the final parquet. Disable modules you
|
||||
don't want with `--plan.enabled=false` (and likewise
|
||||
`--interjections.enabled` / `--vqa.enabled`) to test one at a time.
|
||||
|
||||
## What the validator checks
|
||||
|
||||
Before the writer runs, `StagingValidator` confirms:
|
||||
|
||||
- every event row lands exactly on a real frame timestamp;
|
||||
- no speech / interjection pairs are left orphaned;
|
||||
- `plan` is refreshed at every interjection timestamp;
|
||||
- `memory` rows fall on subtask boundaries (a warning, not an error);
|
||||
- each VQA assistant `content` is valid JSON in one of the
|
||||
bbox / keypoint / count / attribute / spatial shapes;
|
||||
- every row goes to the column chosen by `column_for_style(style)`.
|
||||
|
||||
Any error aborts the writer. Pass `--skip_validation=true` to override
|
||||
while debugging.
|
||||
|
||||
## Where each module's ideas come from
|
||||
|
||||
- **`plan` — subtasks.** Hi Robot ([Shi 2025](https://arxiv.org/abs/2502.19417))
|
||||
for atom granularity ("pick up one piece of lettuce", "place bowl to
|
||||
box"); Pi0.7 ([Physical Intelligence 2025](https://pi.website/pi07))
|
||||
for "how, not what" detail.
|
||||
- **`plan` — memory.** MEM ([Torne 2026](https://arxiv.org/abs/2603.03596)):
|
||||
keep only the minimal relevant information — preserve outcomes, drop
|
||||
specific attributes.
|
||||
- **`interjections`.** Hi Robot's scenario taxonomy: negative task,
|
||||
situated correction, specific constraint, preference. Speech is a
|
||||
tool-call-only atom
|
||||
(`tool_calls=[{type:function, function:{name:"say", arguments:{text:...}}}]`).
|
||||
- **`vqa`.** ECoT ([Zawalski 2024](https://arxiv.org/abs/2407.08693)) for
|
||||
grounded features (pixel bounding boxes `[x_min, y_min, x_max, y_max]`,
|
||||
keypoints) and Steerable VLA Policies
|
||||
([Zhao 2025](https://arxiv.org/abs/2509.07626)) for multi-abstraction
|
||||
grounding. Pi0.7 also grounds answers across abstraction levels.
|
||||
|
||||
When improving a module, tweak its prompt template in
|
||||
`src/lerobot/annotations/steerable_pipeline/prompts/` rather than
|
||||
rewriting from scratch.
|
||||
|
||||
## Roughly how much it costs
|
||||
|
||||
Per episode, the pipeline makes about `max_steps` plan calls,
|
||||
`max_interjections_per_episode` interjection calls, and
|
||||
`vqa_emission_hz × episode_seconds` VQA calls. With the defaults (8
|
||||
subtasks, 1 interjection, 1 Hz × 3 pairs) on a 30-second episode, that's
|
||||
~50 VLM calls.
|
||||
|
||||
Storage stays small: `language_persistent` is at most tens of KB per
|
||||
episode (parquet dictionary-encodes the one entry that repeats across
|
||||
frames), and `language_events` is empty on most frames — its size scales
|
||||
with the number of emissions, not `num_frames × num_emissions`.
|
||||
@@ -1,313 +0,0 @@
|
||||
# Asynchronous Inference
|
||||
|
||||
With our [SmolVLA](https://huggingface.co/papers/2506.01844) we introduced a new way to run inference on real-world robots, **decoupling action prediction from action execution**.
|
||||
In this tutorial, we'll show how to use asynchronous inference (_async inference_) using a finetuned version of SmolVLA, and all the policies supported by LeRobot.
|
||||
**Try async inference with all the policies** supported by LeRobot!
|
||||
|
||||
**What you'll learn:**
|
||||
|
||||
1. Why asynchronous inference matters and how it compares to, more traditional, sequential inference.
|
||||
2. How to spin-up a `PolicyServer` and connect a `RobotClient` from the same machine, and even over the network.
|
||||
3. How to tune key parameters (`actions_per_chunk`, `chunk_size_threshold`) for your robot and policy.
|
||||
|
||||
If you get stuck, hop into our [Discord community](https://discord.gg/s3KuuzsPFb)!
|
||||
|
||||
In a nutshell: with _async inference_, your robot keeps acting while the policy server is already busy computing the next chunk of actions---eliminating "wait-for-inference" lags and unlocking smoother, more reactive behaviours.
|
||||
This is fundamentally different from synchronous inference (sync), where the robot stays idle while the policy computes the next chunk of actions.
|
||||
|
||||
---
|
||||
|
||||
## Getting started with async inference
|
||||
|
||||
You can read more information on asynchronous inference in our [blogpost](https://huggingface.co/blog/async-robot-inference). This guide is designed to help you quickly set up and run asynchronous inference in your environment.
|
||||
|
||||
First, install `lerobot` with the `async` tag, to install the extra dependencies required to run async inference.
|
||||
|
||||
```shell
|
||||
pip install -e ".[async]"
|
||||
```
|
||||
|
||||
Then, spin up a policy server (in one terminal, or in a separate machine) specifying the host address and port for the client to connect to.
|
||||
You can spin up a policy server running:
|
||||
|
||||
```shell
|
||||
python -m lerobot.async_inference.policy_server \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080
|
||||
```
|
||||
|
||||
This will start a policy server listening on `127.0.0.1:8080` (`localhost`, port 8080). At this stage, the policy server is empty, as all information related to which policy to run and with which parameters are specified during the first handshake with the client. Spin up a client with:
|
||||
|
||||
```shell
|
||||
python -m lerobot.async_inference.robot_client \
|
||||
--server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
|
||||
--robot.type=so100_follower \ # ROBOT: your robot type
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
|
||||
--robot.id=follower_so100 \ # ROBOT: your robot id, to load calibration file
|
||||
--robot.cameras="{ laptop: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}, phone: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ # POLICY: the cameras used to acquire frames, with keys matching the keys expected by the policy
|
||||
--task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
|
||||
--policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
|
||||
--pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
|
||||
--policy_device=mps \ # POLICY: the device to run the policy on, on the server (cuda, mps, xpu, cpu)
|
||||
--actions_per_chunk=50 \ # POLICY: the number of actions to output at once
|
||||
--chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
|
||||
--aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
|
||||
--debug_visualize_queue_size=True # CLIENT: whether to visualize the queue size at runtime
|
||||
```
|
||||
|
||||
In summary, you need to specify instructions for:
|
||||
|
||||
- `SERVER`: the address and port of the policy server
|
||||
- `ROBOT`: the type of robot to connect to, the port to connect to, and the local `id` of the robot
|
||||
- `POLICY`: the type of policy to run, and the model name/path on server to the checkpoint to run. You also need to specify which device should the sever be using, and how many actions to output at once (capped at the policy max actions value).
|
||||
- `CLIENT`: the threshold for the chunk size before sending a new observation to the server, and the function to aggregate actions on overlapping portions. Optionally, you can also visualize the queue size at runtime, to help you tune the `CLIENT` parameters.
|
||||
|
||||
Importantly,
|
||||
|
||||
- `actions_per_chunk` and `chunk_size_threshold` are key parameters to tune for your setup.
|
||||
- `aggregate_fn_name` is the function to aggregate actions on overlapping portions. You can either add a new one to a registry of functions, or add your own in `robot_client.py` (see [here](NOTE:addlinktoLOC))
|
||||
- `debug_visualize_queue_size` is a useful tool to tune the `CLIENT` parameters.
|
||||
|
||||
## Done! You should see your robot moving around by now 😉
|
||||
|
||||
## Async vs. synchronous inference
|
||||
|
||||
Synchronous inference relies on interleaving action chunk prediction and action execution. This inherently results in _idle frames_, frames where the robot awaits idle the policy's output: a new action chunk.
|
||||
In turn, inference is plagued by evident real-time lags, where the robot simply stops acting due to the lack of available actions.
|
||||
With robotics models increasing in size, this problem risks becoming only more severe.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/async-inference/sync.png"
|
||||
width="80%"
|
||||
></img>
|
||||
</p>
|
||||
<p align="center">
|
||||
<i>Synchronous inference</i> makes the robot idle while the policy is
|
||||
computing the next chunk of actions.
|
||||
</p>
|
||||
|
||||
To overcome this, we design async inference, a paradigm where action planning and execution are decoupled, resulting in (1) higher adaptability and, most importantly, (2) no idle frames.
|
||||
Crucially, with async inference, the next action chunk is computed _before_ the current one is exhausted, resulting in no idleness.
|
||||
Higher adaptability is ensured by aggregating the different action chunks on overlapping portions, obtaining an up-to-date plan and a tighter control loop.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/async-inference/async.png"
|
||||
width="80%"
|
||||
></img>
|
||||
</p>
|
||||
<p align="center">
|
||||
<i>Asynchronous inference</i> results in no idleness because the next chunk is
|
||||
computed before the current chunk is exhausted.
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
## Start the Policy Server
|
||||
|
||||
Policy servers are wrappers around a `PreTrainedPolicy` interfacing them with observations coming from a robot client.
|
||||
Policy servers are initialized as empty containers which are populated with the requested policy specified in the initial handshake between the robot client and the policy server.
|
||||
As such, spinning up a policy server is as easy as specifying the host address and port. If you're running the policy server on the same machine as the robot client, you can use `localhost` as the host address.
|
||||
|
||||
<hfoptions id="start_policy_server">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.async_inference.policy_server \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.async_inference.configs import PolicyServerConfig
|
||||
from lerobot.async_inference.policy_server import serve
|
||||
|
||||
config = PolicyServerConfig(
|
||||
host="localhost",
|
||||
port=8080,
|
||||
)
|
||||
serve(config)
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
This listens on `localhost:8080` for an incoming connection from the associated`RobotClient`, which will communicate which policy to run during the first client-server handshake.
|
||||
|
||||
---
|
||||
|
||||
## Launch the Robot Client
|
||||
|
||||
`RobotClient` is a wrapper around a `Robot` instance, which `RobotClient` connects to the (possibly remote) `PolicyServer`.
|
||||
The `RobotClient` streams observations to the `PolicyServer`, and receives action chunks obtained running inference on the server (which we assume to have better computational resources than the robot controller).
|
||||
|
||||
<hfoptions id="start_robot_client">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.async_inference.robot_client \
|
||||
--server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
|
||||
--robot.type=so100_follower \ # ROBOT: your robot type
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
|
||||
--robot.id=follower_so100 \ # ROBOT: your robot id, to load calibration file
|
||||
--robot.cameras="{ laptop: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}, phone: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ # POLICY: the cameras used to acquire frames, with keys matching the keys expected by the policy
|
||||
--task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
|
||||
--policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
|
||||
--pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
|
||||
--policy_device=mps \ # POLICY: the device to run the policy on, on the server
|
||||
--actions_per_chunk=50 \ # POLICY: the number of actions to output at once
|
||||
--chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
|
||||
--aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
|
||||
--debug_visualize_queue_size=True # CLIENT: whether to visualize the queue size at runtime
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
import threading
|
||||
from lerobot.robots.so_follower import SO100FollowerConfig
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.async_inference.configs import RobotClientConfig
|
||||
from lerobot.async_inference.robot_client import RobotClient
|
||||
from lerobot.async_inference.helpers import visualize_action_queue_size
|
||||
|
||||
# 1. Create the robot instance
|
||||
"""Check out the cameras available in your setup by running `python lerobot/find_cameras.py`"""
|
||||
# these cameras must match the ones expected by the policy
|
||||
# check the config.json on the Hub for the policy you are using
|
||||
camera_cfg = {
|
||||
"top": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem585A0076841",
|
||||
id="follower_so100",
|
||||
cameras=camera_cfg
|
||||
)
|
||||
|
||||
# 3. Create client configuration
|
||||
client_cfg = RobotClientConfig(
|
||||
robot=robot_cfg,
|
||||
server_address="localhost:8080",
|
||||
policy_device="mps",
|
||||
client_device="cpu",
|
||||
policy_type="smolvla",
|
||||
pretrained_name_or_path="<user>/smolvla_async",
|
||||
chunk_size_threshold=0.5,
|
||||
actions_per_chunk=50, # make sure this is less than the max actions of the policy
|
||||
)
|
||||
|
||||
# 4. Create and start client
|
||||
client = RobotClient(client_cfg)
|
||||
|
||||
# 5. Specify the task
|
||||
task = "Don't do anything, stay still"
|
||||
|
||||
if client.start():
|
||||
# Start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
action_receiver_thread.start()
|
||||
|
||||
try:
|
||||
# Run the control loop
|
||||
client.control_loop(task)
|
||||
except KeyboardInterrupt:
|
||||
client.stop()
|
||||
action_receiver_thread.join()
|
||||
# (Optionally) plot the action queue size
|
||||
visualize_action_queue_size(client.action_queue_size)
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
The following two parameters are key in every setup:
|
||||
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Hyperparameter</th>
|
||||
<th>Default</th>
|
||||
<th>What it does</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td>
|
||||
<code>actions_per_chunk</code>
|
||||
</td>
|
||||
<td>50</td>
|
||||
<td>
|
||||
How many actions the policy outputs at once. Typical values: 10-50.
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<code>chunk_size_threshold</code>
|
||||
</td>
|
||||
<td>0.7</td>
|
||||
<td>
|
||||
When the queue is ≤ 50% full, the client sends a fresh observation.
|
||||
Value in [0, 1].
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
<Tip>
|
||||
Different values of `actions_per_chunk` and `chunk_size_threshold` do result
|
||||
in different behaviours.
|
||||
</Tip>
|
||||
|
||||
On the one hand, increasing the value of `actions_per_chunk` will result in reducing the likelihood of ending up with no actions to execute, as more actions will be available when the new chunk is computed.
|
||||
However, larger values of `actions_per_chunk` might also result in less precise actions, due to the compounding errors consequent to predicting actions over longer timespans.
|
||||
|
||||
On the other hand, increasing the value of `chunk_size_threshold` will result in sending out to the `PolicyServer` observations for inference more often, resulting in a larger number of updates action chunks, overlapping on significant portions. This results in high adaptability, in the limit predicting one action chunk for each observation, which is in turn only marginally consumed while a new one is produced.
|
||||
This option does also put more pressure on the inference pipeline, as a consequence of the many requests. Conversely, values of `chunk_size_threshold` close to 0.0 collapse to the synchronous edge case, whereby new observations are only sent out whenever the current chunk is exhausted.
|
||||
|
||||
We found the default values of `actions_per_chunk` and `chunk_size_threshold` to work well in the experiments we developed for the [SmolVLA paper](https://huggingface.co/papers/2506.01844), but recommend experimenting with different values to find the best fit for your setup.
|
||||
|
||||
### Tuning async inference for your setup
|
||||
|
||||
1. **Choose your computational resources carefully.** [PI0](https://huggingface.co/lerobot/pi0) occupies 14GB of memory at inference time, while [SmolVLA](https://huggingface.co/lerobot/smolvla_base) requires only ~2GB. You should identify the best computational resource for your use case keeping in mind smaller policies require less computational resources. The combination of policy and device used (CPU-intensive, using MPS, or the number of CUDA cores on a given NVIDIA GPU) directly impacts the average inference latency you should expect.
|
||||
2. **Adjust your `fps` based on inference latency.** While the server generates a new action chunk, the client is not idle and is stepping through its current action queue. If the two processes happen at fundamentally different speeds, the client might end up with an empty queue. As such, you should reduce your fps if you consistently run out of actions in queue.
|
||||
3. **Adjust `chunk_size_threshold`**.
|
||||
- Values closer to `0.0` result in almost sequential behavior. Values closer to `1.0` → send observation every step (more bandwidth, relies on good world-model).
|
||||
- We found values around 0.5-0.6 to work well. If you want to tweak this, spin up a `RobotClient` setting the `--debug_visualize_queue_size` to `True`. This will plot the action queue size evolution at runtime, and you can use it to find the value of `chunk_size_threshold` that works best for your setup.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/async-inference/queues.png"
|
||||
width="80%"
|
||||
></img>
|
||||
</p>
|
||||
<p align="center">
|
||||
<i>
|
||||
The action queue size is plotted at runtime when the
|
||||
`--debug_visualize_queue_size` flag is passed, for various levels of
|
||||
`chunk_size_threshold` (`g` in the SmolVLA paper).
|
||||
</i>
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
## Conclusion
|
||||
|
||||
Asynchronous inference represents a significant advancement in real-time robotics control, addressing the fundamental challenge of inference latency that has long plagued robotics applications. Through this tutorial, you've learned how to implement a complete async inference pipeline that eliminates idle frames and enables smoother, more reactive robot behaviors.
|
||||
|
||||
**Key Takeaways:**
|
||||
|
||||
- **Paradigm Shift**: Async inference decouples action prediction from execution, allowing robots to continue acting while new action chunks are computed in parallel
|
||||
- **Performance Benefits**: Eliminates "wait-for-inference" lags that are inherent in synchronous approaches, becoming increasingly important as policy models grow larger
|
||||
- **Flexible Architecture**: The server-client design enables distributed computing, where inference can run on powerful remote hardware while maintaining real-time robot control
|
||||
- **Tunable Parameters**: Success depends on properly configuring `actions_per_chunk` and `chunk_size_threshold` for your specific hardware, policy, and task requirements
|
||||
- **Universal Compatibility**: Works with all LeRobot-supported policies, from lightweight ACT models to vision-language models like SmolVLA
|
||||
|
||||
Start experimenting with the default parameters, monitor your action queue sizes, and iteratively refine your setup to achieve optimal performance for your specific use case.
|
||||
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/huggingface/lerobot/issues).
|
||||
@@ -141,11 +141,6 @@ sample["target_message_indices"]
|
||||
|
||||
The renderer does not apply a tokenizer chat template. Policy processors decide how to serialize the messages for their backbone, which keeps the same dataset usable across SmolVLA, Pi0.5, and any future VLM that expects OpenAI-style chat messages.
|
||||
|
||||
## Blends
|
||||
|
||||
Blend recipes select one weighted sub-recipe deterministically from the sample index.
|
||||
`recipes/subtasks_vqa.yaml` trains the core blend — high-level subtask prediction, low-level execution, and VQA. `recipes/subtask_mem_vqa_speech.yaml` is the fuller variant that also adds memory updates and spoken interjection responses.
|
||||
|
||||
## Graceful absence
|
||||
|
||||
If both language columns are missing, `None`, or empty, `RenderMessagesStep` is a no-op.
|
||||
|
||||
@@ -0,0 +1,250 @@
|
||||
# Remote Inference (lerobot-policy-server)
|
||||
|
||||
Remote inference decouples GPU policy inference from robot control. A `lerobot-policy-server` process runs the policy on a GPU machine; the robot runs `lerobot-rollout --inference.type=remote` as a **weightless edge client** — no policy weights, no GPU, no policy processors on the robot. One GPU server can serve several robots at once, and the remote backend works with every rollout strategy (`base`, `sentry`, `highlight`, `dagger`, `episodic`).
|
||||
|
||||
Use remote inference when:
|
||||
|
||||
- The policy is too large or too slow for the machine attached to the robot (e.g. Pi0/Pi0.5 on a Raspberry Pi or laptop edge).
|
||||
- You want one GPU to serve a fleet of robots running the same policy.
|
||||
- You want to update or restart the inference side without touching the robots.
|
||||
|
||||
<Tip>
|
||||
|
||||
Remote inference requires the `async` extra on **both** sides: `pip install 'lerobot[async]'` (installs `eclipse-zenoh` and `msgpack`). The server additionally needs the extras of the policy it serves (e.g. `lerobot[pi]`, `lerobot[smolvla]`).
|
||||
|
||||
</Tip>
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
robot (edge, weightless) GPU machine
|
||||
┌───────────────────────────┐ ┌────────────────────────────┐
|
||||
│ lerobot-rollout │ │ lerobot-policy-server │
|
||||
│ --inference.type=remote │ zenoh │ one process = one │
|
||||
│ │ router │ (model, revision, GPU) │
|
||||
│ control loop @ fps │ ┌────────┐ │ │
|
||||
│ └─ pops local action ◄──┼───┤ zenohd ├─────┼─► inference worker thread │
|
||||
│ buffer (chunks) │ └────────┘ │ (round-robin over │
|
||||
│ │ observations ► │ client sessions) │
|
||||
│ network worker thread ───┼──► ◄ action │ │
|
||||
│ (publishes obs, merges │ chunks │ stateless per request │
|
||||
│ chunks into buffer) │ │ │
|
||||
└───────────────────────────┘ └────────────────────────────┘
|
||||
```
|
||||
|
||||
The client keeps a local **action buffer** filled with chunks of future actions, so the control loop never blocks on the network: short network blips are absorbed by the buffer and the robot keeps moving. The client self-clocks — it requests a new chunk whenever the buffer holds less than `--inference.buffer_time_s` seconds of playback.
|
||||
|
||||
The server is **stateless per request**: clients ship their RTC prefixes and a delay hint with every observation, so a server crash or restart loses zero control state and reconnects are trivial. In production both robots and servers _dial out_ to a `zenohd` router (NAT-friendly: nothing on the robot network needs an open inbound port).
|
||||
|
||||
## Quickstart on a LAN (peer mode, no router)
|
||||
|
||||
For a quick test on one network you can skip the router: the server listens directly and the robot connects to it.
|
||||
|
||||
On the GPU machine:
|
||||
|
||||
```bash
|
||||
lerobot-policy-server \
|
||||
--model.repo_or_path=${HF_USER}/my_pi0_policy \
|
||||
--default_task="pick up the cube" \
|
||||
--zenoh.mode=peer \
|
||||
--zenoh.listen_endpoints='["tcp/0.0.0.0:7447"]'
|
||||
```
|
||||
|
||||
Wait for `Policy server up: ...` (the model is downloaded, loaded, and warmed up first).
|
||||
|
||||
On the robot machine (replace `192.168.1.42` with the GPU machine's IP):
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=${HF_USER}/my_pi0_policy \
|
||||
--inference.type=remote \
|
||||
--inference.zenoh_mode=peer \
|
||||
--inference.connect_endpoint=tcp/192.168.1.42:7447 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="pick up the cube" \
|
||||
--duration=60
|
||||
```
|
||||
|
||||
`--policy.path` on the client resolves to a config-only download (no weights): it is used for pre-flight validation and action ordering, and doubles as the default service address. The client's `--policy.path` and `--task` must match the server's `--model.repo_or_path` and `--default_task` — that pair is the namespace the service is published under (see [Troubleshooting](#troubleshooting)).
|
||||
|
||||
## Production deployment (router)
|
||||
|
||||
In production, run a [zenoh router](https://zenoh.io/docs/getting-started/installation/) (`zenohd`) somewhere both sides can reach, and have robots and servers dial out to it:
|
||||
|
||||
```bash
|
||||
zenohd # listens on tcp/0.0.0.0:7447 by default
|
||||
```
|
||||
|
||||
Configure the server with a YAML manifest:
|
||||
|
||||
```yaml
|
||||
# server.yaml
|
||||
model:
|
||||
repo_or_path: lerobot/pi0_towels
|
||||
revision: main
|
||||
dtype: bfloat16 # optional cast after load
|
||||
device: cuda
|
||||
default_task: "fold the towel"
|
||||
serving_mode: auto # shared for verified chunk-stateless policies, exclusive otherwise
|
||||
max_sessions: 5
|
||||
warmup_inferences: 2
|
||||
trained_fps: 30.0
|
||||
rtc:
|
||||
enabled: true
|
||||
execution_horizon: 10
|
||||
max_guidance_weight: 10.0
|
||||
health_port: 9100 # /healthz + /metrics; 0 disables
|
||||
zenoh:
|
||||
mode: client
|
||||
connect_endpoints: ["tcp/router.gpu-cluster.internal:7447"]
|
||||
```
|
||||
|
||||
```bash
|
||||
lerobot-policy-server --manifest server.yaml
|
||||
```
|
||||
|
||||
Everything in the manifest can also be set directly on the CLI (`--model.repo_or_path=...`, `--max_sessions=...`, etc.). One process serves exactly one `(model, revision, dtype, device)` — to serve two models, or one model on two GPUs, run two processes. Dynamic model loading is deliberately unsupported: pre-warmed processes keep capacity planning honest.
|
||||
|
||||
On the robot, only the endpoint changes (the default `--inference.zenoh_mode=client` is already router mode):
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=lerobot/pi0_towels \
|
||||
--inference.type=remote \
|
||||
--inference.connect_endpoint=tcp/router.gpu-cluster.internal:7447 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="fold the towel" \
|
||||
--duration=600
|
||||
```
|
||||
|
||||
### TLS / mTLS
|
||||
|
||||
For traffic that leaves a trusted network, terminate TLS at the router and give both sides client certificates (all three PEM paths are required together):
|
||||
|
||||
```yaml
|
||||
# server.yaml (zenoh section)
|
||||
zenoh:
|
||||
mode: client
|
||||
connect_endpoints: ["tls/router.gpu-cluster.internal:7447"]
|
||||
tls_root_ca_certificate: /etc/lerobot/ca.pem
|
||||
tls_connect_certificate: /etc/lerobot/server.pem
|
||||
tls_connect_private_key: /etc/lerobot/server.key
|
||||
```
|
||||
|
||||
On the robot the equivalent flags are `--inference.tls_ca`, `--inference.tls_cert`, and `--inference.tls_key`, with `--inference.connect_endpoint=tls/...`.
|
||||
|
||||
<Tip>
|
||||
|
||||
Multicast scouting is always disabled: discovery is configuration, not protocol magic. If nothing connects, check the endpoints — there is no fallback discovery mechanism.
|
||||
|
||||
</Tip>
|
||||
|
||||
## RTC over the network
|
||||
|
||||
The remote engine reuses the [Real-Time Chunking](./rtc) machinery: the client keeps the chunk leftover and latency tracking locally and ships an action prefix plus a delay hint with every observation; the server runs prefix-conditioned chunk generation. This gives the same smooth chunk-to-chunk transitions as local RTC, with network latency folded into the delay computation.
|
||||
|
||||
RTC is enabled by default on both sides (`rtc.enabled: true`). Tune it from the client:
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
... \
|
||||
--inference.type=remote \
|
||||
--inference.rtc.execution_horizon=10 \
|
||||
--inference.rtc.max_guidance_weight=10.0
|
||||
```
|
||||
|
||||
If the server or its policy does not support RTC (only `pi0`, `pi05`, and `smolvla` are RTC-capable, and the server manifest must have `rtc.enabled: true`), the session is **downgraded to plain chunk-append** and the client logs:
|
||||
|
||||
```
|
||||
RTC downgraded to chunk-append (server does not support RTC)
|
||||
```
|
||||
|
||||
The robot still runs — chunks are simply appended to the buffer without prefix blending, which can produce visible seams between chunks on slow policies.
|
||||
|
||||
## Fail-safe behavior
|
||||
|
||||
The client runs a fail-safe state machine (`CONNECTING → STREAMING → DEGRADED → STALLED → RECONNECTING → DEAD`). A bad initial deployment fails fast: `lerobot-rollout` aborts before the robot moves if the handshake or validation fails. Once streaming, faults degrade in stages:
|
||||
|
||||
| Condition | Behavior |
|
||||
| -------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Short network blip / late chunk | The robot rides its action buffer; state goes `DEGRADED` after `--inference.degraded_after_s` (default 1.0 s) without a fresh chunk |
|
||||
| Buffered actions older than `max_action_age_s` | Stale actions are dropped (never executed); default `--inference.max_action_age_s=3.0` |
|
||||
| Buffer runs dry (`STALLED`) | Fallback per `--inference.fallback`: `hold` (default — robot holds its last commanded position), `repeat_last`, or `zero` |
|
||||
| Server liveliness lost / repeated request timeouts | `RECONNECTING`: re-handshake with exponential backoff (`reconnect_initial_backoff_s=0.5` doubling up to `reconnect_max_backoff_s=10.0`) |
|
||||
| Reconnected server runs a different model/revision | Hard refusal (`DEAD`) — the client never executes wrong-model chunks |
|
||||
| Offline longer than `max_offline_s` (default 60 s) | `DEAD`: the engine signals the rollout's shutdown event for a clean stop |
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
`--inference.fallback=zero` is required for velocity-controlled robots: for them "send nothing" means "keep the last velocity", so an explicit zero command is the only safe stop. For position-controlled arms the default `hold` is safe.
|
||||
|
||||
</Tip>
|
||||
|
||||
Server restarts are equally graceful: on SIGTERM the server drops its liveliness token first (clients ride their buffers through the drain), finishes the in-flight inference, and exits. Clients reconnect when the replacement comes up.
|
||||
|
||||
## Serving multiple robots
|
||||
|
||||
`max_sessions` caps concurrent clients per server process. A single inference worker thread serializes GPU access and round-robins over sessions with a pending observation; per-client newest-wins mailboxes mean overload degrades into longer cycle times (larger but correct client-side delays), never into queue buildup.
|
||||
|
||||
A rough capacity estimate, keeping ~20% headroom:
|
||||
|
||||
```
|
||||
N_robots ≈ 0.8 / (rate × inference_time)
|
||||
```
|
||||
|
||||
where `rate` is each robot's chunk-request rate in Hz (how often the client's buffer dips below `buffer_time_s`) and `inference_time` is the server's seconds per chunk. For example, at 100 ms per chunk and ~2 chunk requests per second per robot: `N ≈ 0.8 / (2 × 0.1) = 4` robots.
|
||||
|
||||
The actual serving mode is classified per policy family, never inferred:
|
||||
|
||||
- **shared** — verified chunk-stateless policies (`act`, `pi0`, `pi05`, and `smolvla` with `n_obs_steps=1`) serve up to `max_sessions` clients from one policy instance.
|
||||
- **exclusive** — stateful families (diffusion-family policies, `smolvla` with observation history, and any unverified policy) are forced to `max_sessions=1`. Run one server process per robot for these.
|
||||
|
||||
`serving_mode: auto` (the default) resolves this automatically; you may force `exclusive`, but `shared` can never override a stateful classification.
|
||||
|
||||
## Observability
|
||||
|
||||
With `health_port` set (default 9100), the server exposes:
|
||||
|
||||
- `GET /healthz` — `200 ok` while the inference worker is alive, `503` otherwise. Wire this to your orchestrator's liveness probe.
|
||||
- `GET /metrics` — Prometheus text format: `lerobot_policy_server_requests_total`, `errors_total`, `superseded_total`, `dropped_unknown_client_total`, `sessions_opened_total`, `sessions_closed_total`, `active_sessions`, `server_load`.
|
||||
|
||||
Every inference request also emits one structured audit line on the `lerobot.policy_server.audit` logger:
|
||||
|
||||
```json
|
||||
{
|
||||
"session_id": "9f2c...",
|
||||
"client_uuid": "robot-07",
|
||||
"seq_id": 412,
|
||||
"episode_id": 3,
|
||||
"queue_wait_ms": 1.8,
|
||||
"inference_ms": 93.2,
|
||||
"superseded": 0,
|
||||
"outcome": "ok"
|
||||
}
|
||||
```
|
||||
|
||||
`(session_id, seq_id)` correlates a server-side audit line with the client's request. Set a stable `--inference.client_uuid` per robot (instead of the default fresh UUID per run) for fleet-wide log correlation, and use `--inference.tags` to forward free-form labels in the handshake.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**`No policy server answered status query at '@lerobot/...'`**
|
||||
|
||||
The client found no server under the key it dialed. Either the endpoint is wrong (check `--inference.connect_endpoint`, the router, and firewalls), or the **service namespace** does not match. The namespace is the `(model_id, revision, task)` triple: on the client it comes from `--inference.service_model_id` (default: `--policy.path`), `--inference.service_revision` (default: `main`), and `--inference.service_task` (default: the rollout `--task`); on the server from `model.repo_or_path`, `model.revision`, and `service_name` (default: a slug of `default_task`). A robot task string that differs from the server's `default_task` is the most common cause — fix the task, or pin the namespace explicitly with `--inference.service_task` on the client / `service_name` in the manifest.
|
||||
|
||||
**`Action name/order mismatch between server policy and this robot`**
|
||||
|
||||
The hard sync-safety contract: chunk columns map to motors **by order**, so the robot's ordered action keys must exactly equal the policy's `action_feature_names`. This fires when the robot type, motor naming, or rename map differs from the training setup. Use the same robot type (and rename map) the policy was trained with.
|
||||
|
||||
**`RTC requested but this server/policy does not support it — downgrading to chunk-append`**
|
||||
|
||||
Informational, not fatal. Enable RTC in the server manifest (`rtc.enabled: true`) and make sure the policy family is RTC-capable (`pi0`, `pi05`, `smolvla`). Otherwise, expect chunk-append behavior (see [RTC over the network](#rtc-over-the-network)).
|
||||
|
||||
**`server full: N/N sessions active`**
|
||||
|
||||
The session-open was rejected at capacity. Raise `max_sessions` (shared mode only), or point the robot at another server replica — the rejection includes the current load so orchestration can retry elsewhere.
|
||||
+9
-9
@@ -151,18 +151,18 @@ lerobot-rollout \
|
||||
--device=cuda
|
||||
```
|
||||
|
||||
## How It Differs from the Async Inference in LeRobot
|
||||
## How It Relates to Remote Inference
|
||||
|
||||
Both RTC and [async inference](./async) improve real-time robot control, but they solve different problems.
|
||||
Both RTC and [remote inference](./remote_inference) improve real-time robot control, but they solve different problems.
|
||||
|
||||
| Aspect | Async Inference | RTC |
|
||||
| ------------- | -------------------------------------------------------------------------- | --------------------------------------------------- |
|
||||
| **Problem** | Idle frames while waiting for inference | Discontinuities between action chunks |
|
||||
| **Solution** | Decouple prediction from execution | Guide new chunks to continue smoothly from previous |
|
||||
| **Benefit** | No waiting, continuous action | Smooth transitions, natural motion |
|
||||
| **Best Used** | Async inference is best used with large models with high inference latency | Flow-matching based policies |
|
||||
| Aspect | Remote Inference | RTC |
|
||||
| ------------- | ------------------------------------------------------------------------ | --------------------------------------------------- |
|
||||
| **Problem** | The policy is too large (or too slow) for the edge machine | Discontinuities between action chunks |
|
||||
| **Solution** | Run inference on a GPU server; the robot executes buffered action chunks | Guide new chunks to continue smoothly from previous |
|
||||
| **Benefit** | Weightless edge clients, one GPU serves many robots | Smooth transitions, natural motion |
|
||||
| **Best Used** | Large models with high inference latency, robot fleets | Flow-matching based policies |
|
||||
|
||||
**Use both together** for maximum smoothness and reactivity!
|
||||
**Use both together** (`--inference.type=remote` with `--inference.rtc.execution_horizon=...`) for maximum smoothness and reactivity: the remote engine reuses RTC's chunk-merging machinery client-side while the server runs prefix-conditioned chunk generation.
|
||||
|
||||
## Advanced: Debug Tracking
|
||||
|
||||
|
||||
@@ -1,109 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Launch ``lerobot-annotate`` on a Hugging Face job (vllm + Qwen3.6-27B VLM).
|
||||
|
||||
Spawns one single-GPU ``h200`` job that:
|
||||
|
||||
1. installs ``lerobot`` from ``main`` plus the annotation extras,
|
||||
2. boots one vllm server with Qwen3.6-27B (dense VLM),
|
||||
3. runs the plan / interjections / vqa modules across the dataset
|
||||
in free-form mode (each episode generates its own subtasks +
|
||||
memory),
|
||||
4. uploads the annotated dataset to ``--new_repo_id`` (when set)
|
||||
or back to ``--repo_id``.
|
||||
|
||||
Usage:
|
||||
|
||||
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
|
||||
|
||||
Adjust ``CMD`` (dataset, model, hub repo) and ``flavor`` below for your
|
||||
run. For larger datasets, scale to ``h200x4`` and raise
|
||||
``--vlm.parallel_servers`` / ``--vlm.num_gpus`` to match.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from huggingface_hub import get_token, run_job
|
||||
|
||||
token = os.environ.get("HF_TOKEN") or get_token()
|
||||
if not token:
|
||||
raise RuntimeError("No HF token. Run `huggingface-cli login` or `export HF_TOKEN=hf_...`")
|
||||
|
||||
CMD = (
|
||||
"apt-get update -qq && apt-get install -y -qq git ffmpeg && "
|
||||
"pip install --no-deps "
|
||||
"'lerobot @ git+https://github.com/huggingface/lerobot.git@main' && "
|
||||
"pip install --upgrade-strategy only-if-needed "
|
||||
"datasets pyarrow av jsonlines draccus gymnasium torchcodec mergedeep pyyaml-include toml typing-inspect "
|
||||
"openai && "
|
||||
"export VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=0 && "
|
||||
"export VLLM_VIDEO_BACKEND=pyav && "
|
||||
"lerobot-annotate "
|
||||
"--repo_id=pepijn223/robocasa_pretrain_human300_v4 "
|
||||
"--new_repo_id=pepijn223/robocasa_pretrain_human300_v4_annotated5 "
|
||||
"--push_to_hub=true "
|
||||
"--vlm.backend=openai "
|
||||
"--vlm.model_id=Qwen/Qwen3.6-27B "
|
||||
"--vlm.parallel_servers=1 "
|
||||
"--vlm.num_gpus=1 "
|
||||
'--vlm.serve_command="vllm serve Qwen/Qwen3.6-27B '
|
||||
"--tensor-parallel-size 1 --max-model-len 32768 "
|
||||
'--gpu-memory-utilization 0.8 --uvicorn-log-level warning --port {port}" '
|
||||
"--vlm.serve_ready_timeout_s=1800 "
|
||||
"--vlm.client_concurrency=128 "
|
||||
"--vlm.max_new_tokens=512 "
|
||||
"--vlm.temperature=0.7 "
|
||||
"--executor.episode_parallelism=16 "
|
||||
"--vlm.chat_template_kwargs='{\"enable_thinking\": false}' "
|
||||
"--vlm.camera_key=observation.images.robot0_agentview_right "
|
||||
# Phase 1 — plan module (subtasks + memory).
|
||||
# Embed decoded frames (not a file:// clip): if clip extraction fails,
|
||||
# the video_url path silently sends no video and the VLM hallucinates.
|
||||
"--plan.use_video_url=false "
|
||||
"--plan.frames_per_second=1.0 "
|
||||
# 32 frames ≈ 8-10k vision tokens, fits the 32768 context. Don't push
|
||||
# toward 128 — that overflows the context (BadRequestError 400).
|
||||
"--plan.max_video_frames=32 "
|
||||
# Window long episodes into 32s chunks (constant 1 fps density) so they
|
||||
# get more subtasks; per-window spans are merged + stitched. 0 disables.
|
||||
"--plan.subtask_window_seconds=32 "
|
||||
# RoboCasa: the dataset task string is authoritative (eval uses it), so
|
||||
# keep it driving subtasks. ``always`` would throw it away and hallucinate.
|
||||
"--plan.derive_task_from_video=off "
|
||||
# No task augmentation: eval conditions on the exact task strings, so
|
||||
# rephrasings are unused at best and harmful when they drift.
|
||||
"--plan.n_task_rephrasings=0 "
|
||||
# Keep subtask decomposition tight for atomic tasks.
|
||||
"--plan.plan_max_steps=10 "
|
||||
# Only subtasks + memory — skip the numbered "plan" rows. true re-enables.
|
||||
"--plan.emit_plan=false "
|
||||
# The describe->segment grounding pass (+1 VLM call/episode) is ON by
|
||||
# default; pass --plan.subtask_describe_first=false to skip it.
|
||||
# Phase 2 — interjections + speech.
|
||||
"--interjections.max_interjections_per_episode=6 "
|
||||
# Phase 4 — general VQA: disabled for this run.
|
||||
"--vqa.enabled=false"
|
||||
)
|
||||
|
||||
job = run_job(
|
||||
image="vllm/vllm-openai:latest",
|
||||
command=["bash", "-c", CMD],
|
||||
flavor="h200",
|
||||
secrets={"HF_TOKEN": token},
|
||||
timeout="2h",
|
||||
)
|
||||
print(f"Job URL: {job.url}")
|
||||
print(f"Job ID: {job.id}")
|
||||
@@ -0,0 +1,115 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Example manifest for `lerobot-policy-server --manifest server.yaml`.
|
||||
#
|
||||
# One process = one (model, revision, dtype, device) on one GPU. Dynamic
|
||||
# model loading is deliberately unsupported: pre-warmed processes keep
|
||||
# capacity planning honest. Every field below can also be overridden on
|
||||
# the command line via draccus, e.g. --model.repo_or_path=... or
|
||||
# --zenoh.connect_endpoints='["tcp/other-router:7447"]'.
|
||||
#
|
||||
# Field names mirror the dataclasses in src/lerobot/policy_server/manifest.py.
|
||||
|
||||
# --- Which policy this process serves, and where it runs ------------------
|
||||
model:
|
||||
# Hub repo id (org/name) or a local checkpoint directory. Required.
|
||||
repo_or_path: lerobot/pi0_towels
|
||||
# Hub revision: branch, tag, or commit sha.
|
||||
revision: main
|
||||
# Optional torch dtype cast applied after load (e.g. "bfloat16",
|
||||
# "float16"). null keeps the checkpoint's native dtype.
|
||||
dtype: bfloat16
|
||||
# Inference device, e.g. "cuda", "cuda:1", "cpu".
|
||||
device: cuda
|
||||
|
||||
# --- Task namespace --------------------------------------------------------
|
||||
# The task this service is published under. VLA clients may override the
|
||||
# task per session unless `pin_task` is true, in which case session opens
|
||||
# with a different task string are rejected.
|
||||
default_task: "fold the towel"
|
||||
pin_task: false
|
||||
# Optional override for the <task_slug> key segment of the Zenoh prefix
|
||||
# (defaults to a slug of `default_task`).
|
||||
service_name: ""
|
||||
|
||||
# --- Serving mode & capacity ------------------------------------------------
|
||||
# "auto" resolves from the policy classification: shared for verified
|
||||
# chunk-stateless policies (act/pi0/pi05, smolvla with n_obs_steps=1),
|
||||
# exclusive otherwise. Chunk-stateful policies — e.g. diffusion, whose
|
||||
# predict_action_chunk reads select_action-fed queues — are always forced
|
||||
# to "exclusive" (max_sessions=1); "shared" cannot override that.
|
||||
serving_mode: auto
|
||||
|
||||
# Capacity rule-of-thumb: with t = server seconds per inference, r = each
|
||||
# client's request rate (self-clocked to ~1-4 Hz, not the control rate),
|
||||
# H = RTC execution horizon, and dt = control period:
|
||||
# max_sessions ~= min( 0.8 / (r*t), (H*dt/2 - network RTT) / t )
|
||||
# e.g. ACT @ 20 ms, 1 Hz refresh -> ~40 clients/GPU; Pi0 @ 150 ms -> ~5.
|
||||
# Session opens beyond this are rejected with the current load in the
|
||||
# reply, so clients retry another replica.
|
||||
max_sessions: 5
|
||||
|
||||
# Dummy inferences run at startup so the first real request does not pay
|
||||
# for CUDA graph/kernel warmup.
|
||||
warmup_inferences: 2
|
||||
|
||||
# --- FPS contract -----------------------------------------------------------
|
||||
# Control rate the policy was trained at. Clients reporting a different
|
||||
# fps get a warning — or a hard reject when `strict_fps` is true.
|
||||
trained_fps: 30.0
|
||||
strict_fps: false
|
||||
|
||||
# --- Real Time Chunking (RTC) -----------------------------------------------
|
||||
# Global to this process: init_rtc_processor mutates the policy instance,
|
||||
# so RTC is a per-process decision, not per-session. Only rtc-capable
|
||||
# families (pi0/pi05/smolvla) honor it; others are downgraded to plain
|
||||
# chunk-append at session open.
|
||||
rtc:
|
||||
enabled: true
|
||||
# Number of actions executed from each chunk before the next chunk is
|
||||
# blended in (the H in the capacity formula above).
|
||||
execution_horizon: 10
|
||||
|
||||
# --- Housekeeping ------------------------------------------------------------
|
||||
# Sessions with no liveliness token and no traffic for this long are
|
||||
# garbage-collected (belt-and-braces behind liveliness GC).
|
||||
session_idle_timeout_s: 300.0
|
||||
|
||||
# --- Transport ----------------------------------------------------------------
|
||||
# Robots and servers both *dial out* to a zenohd router in production
|
||||
# (mode: client). mode: peer + listen_endpoints supports router-less LAN
|
||||
# and loopback test deployments. Multicast scouting is always disabled:
|
||||
# fleet discovery is configuration, not protocol magic.
|
||||
zenoh:
|
||||
mode: client
|
||||
connect_endpoints:
|
||||
- tcp/router.gpu-cluster.internal:7447
|
||||
listen_endpoints: []
|
||||
# mTLS material (PEM paths). All three are required for tls/ endpoints;
|
||||
# leave them null for plain tcp/ inside a trusted network.
|
||||
# tls_root_ca_certificate: /etc/lerobot/tls/ca.pem
|
||||
# tls_connect_certificate: /etc/lerobot/tls/server.pem
|
||||
# tls_connect_private_key: /etc/lerobot/tls/server.key
|
||||
# Escape hatch: raw JSON5 merged into the zenoh config last.
|
||||
# extra_config_json5: '{transport: {link: {tx: {queue: {size: {data: 4}}}}}}'
|
||||
|
||||
# --- Observability -------------------------------------------------------------
|
||||
# HTTP health + Prometheus metrics port; 0 disables the endpoint.
|
||||
health_port: 9100
|
||||
|
||||
# Optional bounded request/response capture for offline replay.
|
||||
debug:
|
||||
capture_dir: null
|
||||
capture_max: 256
|
||||
@@ -1,17 +0,0 @@
|
||||
from lerobot.async_inference.configs import PolicyServerConfig
|
||||
from lerobot.async_inference.policy_server import serve
|
||||
|
||||
|
||||
def main():
|
||||
host = ... # something like "127.0.0.1" if you're exposing to localhost
|
||||
port = ... # something like 8080
|
||||
|
||||
config = PolicyServerConfig(
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
serve(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,62 +0,0 @@
|
||||
import threading
|
||||
|
||||
from lerobot.async_inference.configs import RobotClientConfig
|
||||
from lerobot.async_inference.helpers import visualize_action_queue_size
|
||||
from lerobot.async_inference.robot_client import RobotClient
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.robots.so_follower import SO100FollowerConfig
|
||||
|
||||
|
||||
def main():
|
||||
# these cameras must match the ones expected by the policy - find your cameras with lerobot-find-cameras
|
||||
# check the config.json on the Hub for the policy you are using to see the expected camera specs
|
||||
camera_cfg = {
|
||||
"up": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_cfg)
|
||||
|
||||
server_address = ... # something like "127.0.0.1:8080" if using localhost
|
||||
|
||||
# 3. Create client configuration
|
||||
client_cfg = RobotClientConfig(
|
||||
robot=robot_cfg,
|
||||
server_address=server_address,
|
||||
policy_device="mps",
|
||||
client_device="cpu",
|
||||
policy_type="act",
|
||||
pretrained_name_or_path="<user>/robot_learning_tutorial_act",
|
||||
chunk_size_threshold=0.5, # g
|
||||
actions_per_chunk=50, # make sure this is less than the max actions of the policy
|
||||
)
|
||||
|
||||
# 4. Create and start client
|
||||
client = RobotClient(client_cfg)
|
||||
|
||||
# 5. Provide a textual description of the task
|
||||
task = ...
|
||||
|
||||
if client.start():
|
||||
# Start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
action_receiver_thread.start()
|
||||
|
||||
try:
|
||||
# Run the control loop
|
||||
client.control_loop(task)
|
||||
except KeyboardInterrupt:
|
||||
client.stop()
|
||||
action_receiver_thread.join()
|
||||
# (Optionally) plot the action queue size
|
||||
visualize_action_queue_size(client.action_queue_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+21
-49
@@ -85,11 +85,6 @@ dependencies = [
|
||||
"termcolor>=2.4.0,<4.0.0",
|
||||
"tqdm>=4.66.0,<5.0.0",
|
||||
|
||||
# Training utilities
|
||||
# EMA of policy parameters (Diffusion Policy / pi05 style). Tiny
|
||||
# pure-python dependency — preferred over a hand-rolled implementation.
|
||||
"ema-pytorch>=0.7.7,<1.0.0",
|
||||
|
||||
# Build tools (required by opencv-python-headless on some platforms)
|
||||
"cmake>=3.29.0.1,<4.2.0",
|
||||
"setuptools>=71.0.0,<81.0.0",
|
||||
@@ -120,8 +115,8 @@ dataset = [
|
||||
]
|
||||
training = [
|
||||
"lerobot[dataset]",
|
||||
"accelerate>=1.10.0,<2.0.0",
|
||||
"wandb>=0.24.0,<0.25.0",
|
||||
"wandb>=0.24.0,<0.28.0",
|
||||
"lerobot[accelerate-dep]",
|
||||
]
|
||||
hardware = [
|
||||
"lerobot[pynput-dep]",
|
||||
@@ -147,8 +142,8 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
|
||||
placo-dep = ["placo>=0.9.6,<0.9.16"]
|
||||
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
||||
sentencepiece-dep = ["sentencepiece>=0.2.0,<0.3.0"] # FAST action tokenizer backend (pi052, pi0_fast)
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
grpcio-dep = ["grpcio>=1.73.1,<2.0.0", "protobuf>=6.31.1,<8.0.0"]
|
||||
accelerate-dep = ["accelerate>=1.14.0,<2.0.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
|
||||
@@ -183,7 +178,12 @@ unitree_g1 = [
|
||||
"lerobot[matplotlib-dep]",
|
||||
"lerobot[pygame-dep]",
|
||||
]
|
||||
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
||||
# reachy2-sdk caps grpcio<=1.73.1 and protobuf<=6.32.0; quarantined here so downstream users aren't held back. reachy2-sdk is unlikely to release new versions.
|
||||
reachy2 = [
|
||||
"reachy2_sdk>=1.0.15,<1.1.0",
|
||||
"grpcio<=1.73.1",
|
||||
"protobuf<=6.32.0",
|
||||
]
|
||||
# Seeed Studio reBot B601-DM follower (motorbridge / CAN) + StarArm102 / reBot Arm 102
|
||||
# leader (motorbridge-smart-servo / FashionStar UART servos).
|
||||
rebot = ["lerobot[motorbridge-dep]", "lerobot[motorbridge-smart-servo-dep]"]
|
||||
@@ -203,9 +203,9 @@ wallx = [
|
||||
"torchdiffeq>=0.2.4,<0.3.0",
|
||||
"lerobot[qwen-vl-utils-dep]",
|
||||
]
|
||||
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]", "lerobot[sentencepiece-dep]"]
|
||||
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
|
||||
molmoact2 = ["lerobot[transformers-dep]", "lerobot[peft-dep]", "lerobot[scipy-dep]"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "lerobot[accelerate-dep]"]
|
||||
multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
|
||||
groot = [
|
||||
"lerobot[transformers-dep]",
|
||||
@@ -222,47 +222,26 @@ robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot
|
||||
topreward = ["lerobot[transformers-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
# Remote inference over Zenoh: lerobot-policy-server + lerobot-rollout --inference.type=remote.
|
||||
# Keep zenohd routers on the same minor version as the Python binding.
|
||||
async = ["eclipse-zenoh>=1.9,<2.0", "msgpack>=1.0.0,<2.0.0"]
|
||||
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
||||
|
||||
# Annotation pipeline (lerobot-annotate). The only backend is ``openai``,
|
||||
# which talks to any OpenAI-compatible server (``vllm serve`` /
|
||||
# ``transformers serve`` / hosted). Distributed runs use Hugging Face Jobs
|
||||
# (see examples/annotations/run_hf_job.py).
|
||||
annotations = [
|
||||
"lerobot[dataset]",
|
||||
"lerobot[transformers-dep]",
|
||||
"openai>=1.40,<2.0",
|
||||
# ``vllm`` is intentionally NOT a hard dep: it pins an older torch, and
|
||||
# uv's single unified lock would then cap ``torch`` for every extra
|
||||
# (e.g. forcing 2.8 while ``torchcodec`` in [dataset] needs 2.11 -> ABI
|
||||
# break in CI). The HF Jobs image (``vllm/vllm-openai``) provides vLLM;
|
||||
# install it locally only if you run your own ``vllm serve``.
|
||||
]
|
||||
|
||||
# Tool implementations under src/lerobot/tools/. Each tool's dependencies
|
||||
# are isolated so adding a new tool doesn't bloat the base install.
|
||||
# Currently only `say` (Kyutai pocket-tts; CPU-only, ~100M params).
|
||||
tools = [
|
||||
"pocket-tts>=1.0.0,<3.0.0",
|
||||
"scipy>=1.11.0,<2.0.0", # SayTool.output_dir uses scipy.io.wavfile
|
||||
]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools>=1.73.1,<2.0.0", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
||||
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
|
||||
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
|
||||
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
|
||||
# Simulation
|
||||
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
|
||||
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
|
||||
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.4,<0.2.0", "lerobot[scipy-dep]"]
|
||||
pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.4,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"]
|
||||
# NOTE: vlabench is NOT exposed as a `lerobot` extra. Its only distribution
|
||||
# is the OpenMOSS/VLABench GitHub repo (package name `VLABench`, no PyPI
|
||||
@@ -346,10 +325,8 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
|
||||
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
|
||||
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
|
||||
# Interactive hierarchical-VLA runtime for PI052 (PaliGemma backbone).
|
||||
lerobot-pi052-runtime="lerobot.scripts.lerobot_pi052_runtime:main"
|
||||
lerobot-policy-server="lerobot.scripts.lerobot_policy_server:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
|
||||
@@ -367,7 +344,7 @@ torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json", "annotations/steerable_pipeline/prompts/*.txt"]
|
||||
lerobot = ["envs/*.json"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -542,11 +519,6 @@ ignore_errors = false
|
||||
# module = "lerobot.rl.*"
|
||||
# ignore_errors = false
|
||||
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.async_inference.*"
|
||||
# ignore_errors = false
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.transport.*"
|
||||
ignore_errors = false
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -1,36 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Steerable annotation pipeline producing ``language_persistent`` and
|
||||
``language_events`` columns for LeRobot datasets.
|
||||
|
||||
The pipeline is decomposed into three independently runnable modules whose
|
||||
outputs are staged per-episode before a final parquet rewrite:
|
||||
|
||||
- :mod:`.modules.plan_subtasks_memory` (the ``plan`` module) — persistent styles
|
||||
- :mod:`.modules.interjections_and_speech` (the ``interjections`` module) — event styles + speech
|
||||
- :mod:`.modules.general_vqa` (the ``vqa`` module) — event-style VQA pairs
|
||||
"""
|
||||
|
||||
from .config import AnnotationPipelineConfig
|
||||
from .validator import StagingValidator, ValidationReport
|
||||
from .writer import LanguageColumnsWriter
|
||||
|
||||
__all__ = [
|
||||
"AnnotationPipelineConfig",
|
||||
"LanguageColumnsWriter",
|
||||
"StagingValidator",
|
||||
"ValidationReport",
|
||||
]
|
||||
@@ -1,196 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlanConfig:
|
||||
"""``plan`` module: subtasks + plan + memory + task augmentation."""
|
||||
|
||||
enabled: bool = True
|
||||
|
||||
# ``task_aug`` rephrasings at t=0 (renderer rotates ${task} among them); 0 disables.
|
||||
n_task_rephrasings: int = 10
|
||||
|
||||
# Derive the task from video instead of episode_task: off / if_short / always.
|
||||
# Affects prompts only; ``meta/tasks.parquet`` is untouched.
|
||||
derive_task_from_video: str = "if_short"
|
||||
derive_task_min_words: int = 3
|
||||
|
||||
# Frames sampled uniformly, capped at max_video_frames — a hard context cap
|
||||
# (~300 tokens/frame, so 32 fit a 32k VLM; 128 overflow).
|
||||
frames_per_second: float = 1.0
|
||||
max_video_frames: int = 32
|
||||
|
||||
# >0: split long episodes into windows of this length (constant fps density)
|
||||
# instead of subsampling the whole episode; spans merged + stitched. 0 disables.
|
||||
subtask_window_seconds: float = 0.0
|
||||
|
||||
min_subtask_seconds: float = 1.5
|
||||
plan_max_steps: int = 8
|
||||
|
||||
# Narrate-only grounding pass before segmenting — best defense against subtasks
|
||||
# invented from the task text (+1 VLM call/episode).
|
||||
subtask_describe_first: bool = True
|
||||
|
||||
# Emit ``style="plan"`` rows at each boundary; False = subtasks + memory only.
|
||||
emit_plan: bool = True
|
||||
|
||||
# (subtask spans are always stitched to a contiguous full-episode cover; not configurable.)
|
||||
|
||||
# Send a server-side ``video_url`` clip (at use_video_url_fps) instead of embedded frames.
|
||||
use_video_url: bool = False
|
||||
use_video_url_fps: float = 1.0
|
||||
|
||||
# Optional EgoMimic-style 5-axis task augmentation; replaces n_task_rephrasings.
|
||||
task_aug_axes: TaskAugAxesConfig = field(default_factory=lambda: TaskAugAxesConfig())
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskAugAxesConfig:
|
||||
"""5-axis t=0 task augmentation (EgoMimic-style): synonym / omit_arm /
|
||||
omit_orientation / omit_grasp_method / combined. Replaces n_task_rephrasings
|
||||
when enabled; each variant becomes a ``task_aug`` row. Axes with nothing to
|
||||
omit emit fewer entries. Defaults (3+3+2+2+2) match EgoMimic."""
|
||||
|
||||
enabled: bool = False
|
||||
|
||||
synonym_paraphrase: int = 3
|
||||
omit_arm: int = 3
|
||||
omit_orientation: int = 2
|
||||
omit_grasp_method: int = 2
|
||||
combined_omissions: int = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterjectionsConfig:
|
||||
"""``interjections`` module: interjections + paired speech."""
|
||||
|
||||
enabled: bool = True
|
||||
|
||||
# Each emits a paired (interjection, speech) row + a plan refresh at that ts.
|
||||
max_interjections_per_episode: int = 3
|
||||
interjection_min_t: float = 2.0
|
||||
|
||||
# Frame window centered on the timestamp so the VLM sees motion, not one frame.
|
||||
interjection_window_seconds: float = 2.0
|
||||
interjection_window_frames: int = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class VqaConfig:
|
||||
"""``vqa`` module: general VQA."""
|
||||
|
||||
enabled: bool = True
|
||||
vqa_emission_hz: float = 1.0
|
||||
K: int = 1
|
||||
"""Consecutive frames per emission tick. The VLM grounds on the FIRST frame,
|
||||
so K>1 smears stale labels onto moved frames. Default 1 (no smear)."""
|
||||
question_types: tuple[str, ...] = ("bbox", "keypoint", "count", "attribute", "spatial")
|
||||
|
||||
# True: ground VQA only on --vlm.camera_key (default: every camera).
|
||||
restrict_to_default_camera: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class VlmConfig:
|
||||
"""Shared Qwen-VL client configuration."""
|
||||
|
||||
# Only ``openai`` (OpenAI-compatible vLLM server, auto-spawned when
|
||||
# auto_serve=True); ``stub`` is for tests.
|
||||
backend: str = "openai"
|
||||
model_id: str = "Qwen/Qwen3.6-27B"
|
||||
|
||||
# OpenAI-compatible endpoint; ``EMPTY`` key works for local servers.
|
||||
api_base: str = "http://localhost:8000/v1"
|
||||
api_key: str = "EMPTY"
|
||||
|
||||
# Spawn a server if none answers api_base; False = fail fast on a remote.
|
||||
auto_serve: bool = True
|
||||
serve_port: int = 8000
|
||||
# Override the auto-serve command; ``{port}`` substituted per replica.
|
||||
serve_command: str | None = None
|
||||
|
||||
# Independent servers for round-robin routing (one per GPU). num_gpus=0 = one each.
|
||||
parallel_servers: int = 1
|
||||
num_gpus: int = 0
|
||||
client_concurrency: int = 16
|
||||
serve_ready_timeout_s: float = 600.0
|
||||
|
||||
max_new_tokens: int = 512
|
||||
temperature: float = 0.2
|
||||
|
||||
# Auto-serve context length (None → 32768); other vLLM flags go in serve_command.
|
||||
max_model_len: int | None = None
|
||||
|
||||
# Camera for keyframes; None → first ``observation.images.*`` key.
|
||||
camera_key: str | None = None
|
||||
# Forwarded as extra_body.chat_template_kwargs (e.g. {"enable_thinking": false}).
|
||||
chat_template_kwargs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutorConfig:
|
||||
"""Executor settings (intra-process episode concurrency; distribution via HF Jobs)."""
|
||||
|
||||
# Episodes processed concurrently per phase; main knob for saturating the servers.
|
||||
episode_parallelism: int = 16
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnnotationPipelineConfig:
|
||||
"""Top-level config for ``lerobot-annotate`` (rewrites data shards in place)."""
|
||||
|
||||
# Hub dataset: download source when ``root`` unset; push target when push_to_hub
|
||||
# is on and ``new_repo_id`` unset.
|
||||
repo_id: str | None = None
|
||||
|
||||
# Separate push target (matches the LeRobot edit tools). Unset → push in place.
|
||||
new_repo_id: str | None = None
|
||||
|
||||
root: Path | None = None
|
||||
|
||||
# Defaults to ``<root>/.annotate_staging/``.
|
||||
staging_dir: Path | None = None
|
||||
|
||||
seed: int = 1729
|
||||
|
||||
plan: PlanConfig = field(default_factory=PlanConfig)
|
||||
interjections: InterjectionsConfig = field(default_factory=InterjectionsConfig)
|
||||
vqa: VqaConfig = field(default_factory=VqaConfig)
|
||||
|
||||
vlm: VlmConfig = field(default_factory=VlmConfig)
|
||||
executor: ExecutorConfig = field(default_factory=ExecutorConfig)
|
||||
|
||||
skip_validation: bool = False
|
||||
only_episodes: tuple[int, ...] | None = None
|
||||
|
||||
# Keyframe decode backend. None → ffmpeg CLI (crash-/thread-safe; torchcodec
|
||||
# SIGSEGVs under concurrent decode). Or ``"torchcodec"`` / ``"pyav"``.
|
||||
video_backend: str | None = None
|
||||
|
||||
# Upload to the Hub (new_repo_id if set, else repo_id; one must be set).
|
||||
push_to_hub: bool = False
|
||||
push_private: bool = False
|
||||
push_commit_message: str | None = None
|
||||
|
||||
def resolved_staging_dir(self, root: Path) -> Path:
|
||||
return self.staging_dir if self.staging_dir is not None else root / ".annotate_staging"
|
||||
@@ -1,253 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""In-process executor that runs the annotation phases.
|
||||
|
||||
The executor runs **six phases** in dependency order:
|
||||
|
||||
phase 1: ``plan`` module (plan + subtasks + memory)
|
||||
phase 2: ``interjections`` module (interjections + speech)
|
||||
phase 3: ``plan`` plan-update pass — re-runs plan emission at every
|
||||
interjection timestamp produced by phase 2
|
||||
phase 4: ``vqa`` module (VQA)
|
||||
phase 5: validator
|
||||
phase 6: writer
|
||||
|
||||
Phase 3 is why the ``plan`` module must be re-entered after the
|
||||
``interjections`` module — to refresh ``plan`` rows at interjection
|
||||
timestamps.
|
||||
|
||||
Distributed execution is provided by Hugging Face Jobs (see
|
||||
``examples/annotations/run_hf_job.py``); the runner inside the job
|
||||
invokes ``lerobot-annotate`` which uses this in-process executor.
|
||||
Episode-level concurrency is controlled by
|
||||
``ExecutorConfig.episode_parallelism``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .config import AnnotationPipelineConfig
|
||||
from .reader import EpisodeRecord, iter_episodes
|
||||
from .staging import EpisodeStaging
|
||||
from .validator import StagingValidator
|
||||
from .writer import LanguageColumnsWriter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PhaseResult:
|
||||
"""Summary of one pipeline phase across all episodes."""
|
||||
|
||||
name: str
|
||||
episodes_processed: int
|
||||
episodes_skipped: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineRunSummary:
|
||||
"""Aggregated result returned by :meth:`Executor.run`."""
|
||||
|
||||
phases: list[PhaseResult]
|
||||
written_paths: list[Path]
|
||||
validation_report: Any # ValidationReport, kept Any to avoid import cycle
|
||||
|
||||
|
||||
@dataclass
|
||||
class Executor:
|
||||
"""Run all six phases over a dataset root in-process.
|
||||
|
||||
Episode-level concurrency comes from ``ExecutorConfig.episode_parallelism``
|
||||
(a thread pool); cluster-level concurrency comes from running this
|
||||
executor inside a Hugging Face Job. Tests construct the executor
|
||||
directly with stub modules.
|
||||
"""
|
||||
|
||||
config: AnnotationPipelineConfig
|
||||
plan: Any # PlanSubtasksMemoryModule
|
||||
interjections: Any # InterjectionsAndSpeechModule
|
||||
vqa: Any # GeneralVqaModule
|
||||
writer: LanguageColumnsWriter
|
||||
validator: StagingValidator
|
||||
|
||||
def run(self, root: Path) -> PipelineRunSummary:
|
||||
records = list(iter_episodes(root, only_episodes=self.config.only_episodes))
|
||||
n = len(records)
|
||||
if n == 0:
|
||||
raise ValueError(f"No episodes found under {root}/data/")
|
||||
|
||||
print(f"[annotate] {n} episodes total", flush=True)
|
||||
|
||||
staging_dir = self.config.resolved_staging_dir(root)
|
||||
staging_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
phases: list[PhaseResult] = []
|
||||
|
||||
# Phase 1: ``plan`` module (plan + subtasks + memory)
|
||||
phases.append(self._run_module_phase("plan", records, staging_dir, self.plan))
|
||||
# Phase 2: ``interjections`` module (interjections + speech). It
|
||||
# reads the ``plan`` module's subtask rows from the same staging
|
||||
# tree to ground the interjection prompt in the correct local subtask.
|
||||
phases.append(self._run_module_phase("interjections", records, staging_dir, self.interjections))
|
||||
# Phase 3: ``plan`` plan-update pass at interjection timestamps.
|
||||
phases.append(self._run_plan_update_phase(records, staging_dir))
|
||||
# Phase 4: ``vqa`` module (VQA)
|
||||
phases.append(self._run_module_phase("vqa", records, staging_dir, self.vqa))
|
||||
|
||||
print("[annotate] running validator...", flush=True)
|
||||
report = self.validator.validate(records, staging_dir)
|
||||
if not report.ok and not self.config.skip_validation:
|
||||
raise RuntimeError(f"Staging validation failed: {report.summary()}")
|
||||
print(f"[annotate] validator: {report.summary()}", flush=True)
|
||||
|
||||
print(f"[annotate] writing parquet shards into {root}/data/...", flush=True)
|
||||
written = self.writer.write_all(records, staging_dir, root)
|
||||
print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True)
|
||||
|
||||
# Keep meta/info.json aligned with the parquet schema we just wrote.
|
||||
# Idempotent and additive: existing user metadata is preserved.
|
||||
self._ensure_annotation_metadata_in_info(root)
|
||||
|
||||
return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_annotation_metadata_in_info(root: Path) -> None:
|
||||
"""Write language features and canonical tools to ``meta/info.json``.
|
||||
|
||||
``LanguageColumnsWriter`` adds ``language_persistent`` and
|
||||
``language_events`` to parquet shards. The metadata must advertise
|
||||
those columns too, otherwise non-streaming ``LeRobotDataset`` loads
|
||||
cast against the old schema and fail on the extra parquet columns.
|
||||
"""
|
||||
from lerobot.datasets.io_utils import load_info, write_info # noqa: PLC0415
|
||||
from lerobot.datasets.language import SAY_TOOL_SCHEMA, language_feature_info # noqa: PLC0415
|
||||
|
||||
info_path = root / "meta" / "info.json"
|
||||
if not info_path.exists():
|
||||
return
|
||||
try:
|
||||
info = load_info(root)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(f"[annotate] could not read {info_path}: {exc}", flush=True)
|
||||
return
|
||||
|
||||
changed = False
|
||||
|
||||
merged_features = {**info.features, **language_feature_info()}
|
||||
if merged_features != info.features:
|
||||
info.features = merged_features
|
||||
changed = True
|
||||
|
||||
existing = info.tools or []
|
||||
names = {(t.get("function") or {}).get("name") for t in existing if isinstance(t, dict)}
|
||||
if SAY_TOOL_SCHEMA["function"]["name"] not in names:
|
||||
info.tools = [*existing, SAY_TOOL_SCHEMA]
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
write_info(info, root)
|
||||
print(
|
||||
"[annotate] meta/info.json: "
|
||||
f"language_features={list(language_feature_info())}, "
|
||||
f"tools={[t['function']['name'] for t in (info.tools or [])]}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
def _run_module_phase(
|
||||
self,
|
||||
name: str,
|
||||
records: list[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
module: Any,
|
||||
) -> PhaseResult:
|
||||
if not module.enabled:
|
||||
print(f"[annotate] phase={name} skipped (module disabled)", flush=True)
|
||||
return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records))
|
||||
n = len(records)
|
||||
parallelism = max(1, min(self.config.executor.episode_parallelism, n))
|
||||
print(
|
||||
f"[annotate] phase={name} starting on {n} episode(s) (parallelism={parallelism})",
|
||||
flush=True,
|
||||
)
|
||||
t0 = time.time()
|
||||
|
||||
def _do(idx_record: tuple[int, EpisodeRecord]) -> tuple[int, int, float]:
|
||||
i, record = idx_record
|
||||
ep_start = time.time()
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
return i, record.episode_index, time.time() - ep_start
|
||||
|
||||
processed = 0
|
||||
if parallelism == 1:
|
||||
for i, record in enumerate(records, 1):
|
||||
_, ep_idx, elapsed = _do((i, record))
|
||||
processed += 1
|
||||
print(
|
||||
f"[annotate] {name} episode {i}/{n} (idx={ep_idx}) done in {elapsed:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=parallelism) as pool:
|
||||
futures = [pool.submit(_do, (i, r)) for i, r in enumerate(records, 1)]
|
||||
for fut in as_completed(futures):
|
||||
i, ep_idx, elapsed = fut.result()
|
||||
processed += 1
|
||||
print(
|
||||
f"[annotate] {name} episode {processed}/{n} "
|
||||
f"(idx={ep_idx}, submit_order={i}) done in {elapsed:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
total = time.time() - t0
|
||||
print(f"[annotate] phase={name} complete: {processed}/{n} in {total:.1f}s", flush=True)
|
||||
return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0)
|
||||
|
||||
def _run_plan_update_phase( # noqa: PLR0915
|
||||
self, records: list[EpisodeRecord], staging_dir: Path
|
||||
) -> PhaseResult:
|
||||
"""Re-emit ``plan`` rows at each timestamp the ``interjections`` module produced.
|
||||
|
||||
The ``plan`` module owns the prompt; the ``interjections`` module
|
||||
produced the timestamps. This phase therefore calls back into the
|
||||
``plan`` module with the interjection timestamps so its existing
|
||||
prompt path is reused.
|
||||
"""
|
||||
if not self.plan.enabled or not self.interjections.enabled:
|
||||
return PhaseResult(name="plan_update", episodes_processed=0, episodes_skipped=len(records))
|
||||
processed = 0
|
||||
for record in records:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
interjection_rows = [
|
||||
row for row in staging.read("interjections") if row.get("style") == "interjection"
|
||||
]
|
||||
interjection_times = [float(row["timestamp"]) for row in interjection_rows]
|
||||
interjection_texts = [str(row.get("content") or "") for row in interjection_rows]
|
||||
if interjection_times:
|
||||
self.plan.run_plan_updates(record, staging, interjection_times, interjection_texts)
|
||||
processed += 1
|
||||
# Episodes without any interjections are skipped (no plan refresh
|
||||
# needed); count them so the summary's processed+skipped == total.
|
||||
return PhaseResult(
|
||||
name="plan_update",
|
||||
episodes_processed=processed,
|
||||
episodes_skipped=len(records) - processed,
|
||||
)
|
||||
@@ -1,498 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Keyframe extraction for the annotation pipeline.
|
||||
|
||||
Modules attach decoded camera frames to their VLM prompts so the model can
|
||||
ground subtask decomposition, interjection scenarios, and VQA in actual
|
||||
visual content. The pipeline shares one provider across modules and one
|
||||
episode at a time, with a small per-episode cache so multiple modules
|
||||
querying the same timestamp pay decode cost once.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.video_utils import decode_video_frames
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FrameProvider(Protocol):
|
||||
"""Decodes camera frames at episode-relative timestamps."""
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""All ``observation.images.*`` feature keys this provider can decode."""
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return one decoded frame per timestamp from ``camera_key`` (or default).
|
||||
|
||||
Frames are ``torch.Tensor`` (``C, H, W`` uint8) — the shape
|
||||
:func:`lerobot.datasets.video_utils.decode_video_frames` returns.
|
||||
:func:`to_image_blocks` converts them to PIL only at the VLM-message
|
||||
boundary.
|
||||
|
||||
Empty list if the camera is unavailable. ``camera_key=None`` falls back
|
||||
to the provider's default camera so existing single-camera callers
|
||||
(the ``plan`` and ``interjections`` modules) keep working unchanged.
|
||||
"""
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` decoded frames covering the whole episode.
|
||||
|
||||
Sampling is uniform across the episode duration. Frames are
|
||||
``torch.Tensor`` (``C, H, W`` uint8); :func:`to_video_block` wraps
|
||||
them into one ``{"type":"video", "video":<list>}`` block for a
|
||||
Qwen-VL-compatible model that pools temporally itself. Empty list if
|
||||
no camera available.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class _NullProvider:
|
||||
"""No-op provider used when the dataset has no video keys or in tests."""
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
return []
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
return []
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
return []
|
||||
|
||||
|
||||
def null_provider() -> FrameProvider:
|
||||
return _NullProvider()
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoFrameProvider:
|
||||
"""Decodes frames from the dataset's ``observation.images.*`` streams.
|
||||
|
||||
By default the *first* camera key is used for the ``plan`` module
|
||||
(subtask decomposition) and the ``interjections`` module (interjection
|
||||
scenarios) — those prompts care about *what is happening*, not which
|
||||
angle. The ``vqa`` module instead iterates over every camera in
|
||||
:attr:`camera_keys` so each frame's
|
||||
grounded answer (bbox/keypoint/...) is tagged with the camera it was
|
||||
grounded against.
|
||||
|
||||
``camera_key`` overrides the default-camera choice but does not restrict
|
||||
:attr:`camera_keys`. Pass ``camera_key`` explicitly to ``frames_at`` /
|
||||
``video_for_episode`` to read a non-default stream.
|
||||
|
||||
Caches up to ``cache_size`` decoded frames per process to keep
|
||||
co-timestamped ``interjections`` + ``plan`` plan-update calls cheap.
|
||||
"""
|
||||
|
||||
root: Path
|
||||
camera_key: str | None = None
|
||||
tolerance_s: float = 1e-2
|
||||
cache_size: int = 256
|
||||
# Keyframe decode backend. ``None`` uses the ffmpeg CLI — the
|
||||
# concurrency- and crash-safe default for the pipeline's threaded
|
||||
# decode. Set to ``"torchcodec"`` or ``"pyav"`` to pin an in-process
|
||||
# decoder when the build is known thread-safe.
|
||||
video_backend: str | None = None
|
||||
_meta: Any = field(default=None, init=False, repr=False)
|
||||
_cache: dict = field(default_factory=dict, init=False, repr=False)
|
||||
_camera_keys: list[str] = field(default_factory=list, init=False, repr=False)
|
||||
# Pipeline runs the three module phases under a ThreadPoolExecutor (see
|
||||
# ``ExecutorConfig.episode_parallelism``); guard the dict cache and the
|
||||
# one-shot warn flag against concurrent updates from worker threads.
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
|
||||
_warned_decode_fail: bool = field(default=False, init=False, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415
|
||||
|
||||
self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
|
||||
# Only ``video_keys`` are decodable here: the clip/decode paths read
|
||||
# ``videos/<key>/from_timestamp`` from episode metadata, which exists
|
||||
# only for video-stored cameras. Image-stored cameras (also in
|
||||
# ``camera_keys``) would KeyError, so restrict the list — and the
|
||||
# default — to video keys.
|
||||
keys = list(self._meta.video_keys)
|
||||
# Last-resort fallback: if metadata didn't surface any video keys but
|
||||
# the caller explicitly named a camera (``--vlm.camera_key=...``),
|
||||
# trust them — the key is by definition known to exist on the dataset.
|
||||
if not keys and self.camera_key:
|
||||
keys = [self.camera_key]
|
||||
self._camera_keys = keys
|
||||
if self.camera_key is None:
|
||||
self.camera_key = keys[0] if keys else None
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""All ``observation.images.*`` keys available on this dataset."""
|
||||
return list(self._camera_keys)
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
target = camera_key if camera_key is not None else self.camera_key
|
||||
if not timestamps or target is None:
|
||||
return []
|
||||
|
||||
out: list[Any] = []
|
||||
misses: list[float] = []
|
||||
miss_indices: list[int] = []
|
||||
with self._lock:
|
||||
for i, ts in enumerate(timestamps):
|
||||
key = (record.episode_index, target, round(float(ts), 6))
|
||||
cached = self._cache.get(key)
|
||||
if cached is not None:
|
||||
out.append(cached)
|
||||
else:
|
||||
out.append(None)
|
||||
misses.append(float(ts))
|
||||
miss_indices.append(i)
|
||||
|
||||
if misses:
|
||||
decoded = self._decode(record.episode_index, misses, target)
|
||||
# ``_decode`` returns exactly one frame per requested timestamp,
|
||||
# or an empty list if decoding failed wholesale. A partial list
|
||||
# would mean a frame/timestamp misalignment, so only pair them up
|
||||
# when the counts match (``strict=True`` then guards regressions).
|
||||
if len(decoded) == len(miss_indices):
|
||||
with self._lock:
|
||||
for i, frame in zip(miss_indices, decoded, strict=True):
|
||||
out[i] = frame
|
||||
key = (record.episode_index, target, round(float(timestamps[i]), 6))
|
||||
if len(self._cache) >= self.cache_size:
|
||||
self._cache.pop(next(iter(self._cache)))
|
||||
self._cache[key] = frame
|
||||
# filter out any None left over from decode failures
|
||||
return [frame for frame in out if frame is not None]
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` frames uniformly sampled across the episode.
|
||||
|
||||
The whole episode duration is covered; the model picks subtask
|
||||
boundaries from the temporal pooling it does internally. Frames are
|
||||
``torch.Tensor`` (see :meth:`frames_at`).
|
||||
"""
|
||||
target = camera_key if camera_key is not None else self.camera_key
|
||||
if max_frames <= 0 or target is None or not record.frame_timestamps:
|
||||
return []
|
||||
n_frames = min(max_frames, len(record.frame_timestamps))
|
||||
if n_frames == len(record.frame_timestamps):
|
||||
timestamps = list(record.frame_timestamps)
|
||||
else:
|
||||
t0 = record.frame_timestamps[0]
|
||||
t_last = record.frame_timestamps[-1]
|
||||
if t_last <= t0:
|
||||
timestamps = [float(t0)] * n_frames
|
||||
else:
|
||||
step = (t_last - t0) / (n_frames - 1) if n_frames > 1 else 0.0
|
||||
timestamps = [float(t0 + i * step) for i in range(n_frames)]
|
||||
return self.frames_at(record, timestamps, camera_key=target)
|
||||
|
||||
def episode_clip_path(self, record: EpisodeRecord, cache_dir: Path) -> Path | None:
|
||||
"""Extract the episode's subclip to ``cache_dir/ep_{idx:06d}.mp4``.
|
||||
|
||||
Returns ``None`` if the dataset has no video tracks. Skips
|
||||
re-extract when the cached clip already exists. Re-encodes to
|
||||
H.264 (libx264) so the resulting mp4 is decodable by every
|
||||
downstream video processor — stream-copy would inherit the
|
||||
source codec (often AV1 in modern LeRobot datasets), which
|
||||
vllm's libav build cannot decode.
|
||||
"""
|
||||
import subprocess # noqa: PLC0415
|
||||
|
||||
if self.camera_key is None:
|
||||
return None
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_path = cache_dir / f"ep_{record.episode_index:06d}.mp4"
|
||||
if out_path.exists() and out_path.stat().st_size > 0:
|
||||
return out_path
|
||||
ep = self._meta.episodes[record.episode_index]
|
||||
from_timestamp = float(ep[f"videos/{self.camera_key}/from_timestamp"])
|
||||
to_timestamp = float(ep[f"videos/{self.camera_key}/to_timestamp"])
|
||||
src = self.root / self._meta.get_video_file_path(record.episode_index, self.camera_key)
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-loglevel",
|
||||
"error",
|
||||
"-ss",
|
||||
f"{from_timestamp:.3f}",
|
||||
"-to",
|
||||
f"{to_timestamp:.3f}",
|
||||
"-i",
|
||||
str(src),
|
||||
"-c:v",
|
||||
"libx264",
|
||||
"-preset",
|
||||
"ultrafast",
|
||||
"-crf",
|
||||
"23",
|
||||
"-pix_fmt",
|
||||
"yuv420p",
|
||||
"-an",
|
||||
str(out_path),
|
||||
]
|
||||
try:
|
||||
# ffmpeg is invoked by name via PATH lookup (the standard way to
|
||||
# call the CLI); the arg list is fully controlled here, not shell.
|
||||
subprocess.run(cmd, check=True, timeout=300) # nosec B607
|
||||
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError):
|
||||
return None
|
||||
return out_path if out_path.exists() and out_path.stat().st_size > 0 else None
|
||||
|
||||
def _decode(self, episode_index: int, timestamps: list[float], camera_key: str) -> list[Any]:
|
||||
"""Decode ``timestamps`` from the episode's video as ``(C, H, W)`` tensors.
|
||||
|
||||
Delegates to :func:`lerobot.datasets.video_utils.decode_video_frames`
|
||||
(torchcodec by default, PyAV fallback) rather than a bespoke decoder.
|
||||
Returns one frame per requested timestamp, or ``[]`` if decoding
|
||||
failed wholesale — callers treat ``[]`` as "no frames available".
|
||||
"""
|
||||
ep = self._meta.episodes[episode_index]
|
||||
from_timestamp = ep[f"videos/{camera_key}/from_timestamp"]
|
||||
shifted = [from_timestamp + ts for ts in timestamps]
|
||||
video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key)
|
||||
|
||||
# Default to the ffmpeg CLI. The pipeline decodes under a 16-wide
|
||||
# ThreadPoolExecutor and the in-process decoders are unsafe there:
|
||||
# torchcodec is not thread-safe and SIGSEGVs under concurrent decode
|
||||
# (a crash no try/except can catch), PyAV can likewise segfault on
|
||||
# AV1, and lerobot's ``pyav`` backend routes through the removed
|
||||
# ``torchvision.io.VideoReader``. ``_decode_frames_ffmpeg`` shells
|
||||
# out per frame: each decode is an isolated child process, so it is
|
||||
# both crash-safe and concurrency-safe. ``video_backend`` can pin
|
||||
# ``torchcodec`` / ``pyav`` explicitly for callers that know their
|
||||
# build is safe.
|
||||
chain = [self.video_backend] if self.video_backend else ["ffmpeg"]
|
||||
|
||||
exc: Exception | None = None
|
||||
for backend in chain:
|
||||
try:
|
||||
if backend == "ffmpeg":
|
||||
return _decode_frames_ffmpeg(video_path, shifted)
|
||||
if backend in ("pyav", "av"):
|
||||
return _decode_frames_av(video_path, shifted)
|
||||
# Stacked ``(N, C, H, W)`` uint8 tensor; one row per timestamp.
|
||||
decoded = decode_video_frames(
|
||||
video_path, shifted, self.tolerance_s, backend=backend, return_uint8=True
|
||||
)
|
||||
return list(decoded)
|
||||
except Exception as e: # noqa: PERF203
|
||||
exc = e
|
||||
|
||||
# Every backend raised. Log loudly the first time so a silent
|
||||
# vqa-module no-op (every prompt skipped because frames_at returned
|
||||
# []) is debuggable from the job log instead of post-hoc parquet
|
||||
# inspection. Subsequent failures stay quiet.
|
||||
with self._lock:
|
||||
already_warned = self._warned_decode_fail
|
||||
if not already_warned:
|
||||
self._warned_decode_fail = True
|
||||
if not already_warned:
|
||||
logger.warning(
|
||||
"VideoFrameProvider._decode failed for episode=%s camera=%s video_path=%s backends=%s: %s",
|
||||
episode_index,
|
||||
camera_key,
|
||||
video_path,
|
||||
chain,
|
||||
exc,
|
||||
exc_info=exc,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def make_frame_provider(
|
||||
root: Path, camera_key: str | None = None, video_backend: str | None = None
|
||||
) -> FrameProvider:
|
||||
"""Build a :class:`VideoFrameProvider` if videos are present, else null."""
|
||||
try:
|
||||
provider = VideoFrameProvider(root=root, camera_key=camera_key, video_backend=video_backend)
|
||||
except Exception:
|
||||
return null_provider()
|
||||
if provider.camera_key is None:
|
||||
return null_provider()
|
||||
return provider
|
||||
|
||||
|
||||
def _decode_frames_ffmpeg(video_path: Path, timestamps: list[float]) -> list[Any]:
|
||||
"""Decode the frames nearest to ``timestamps`` via the ffmpeg CLI.
|
||||
|
||||
Runs one ``ffmpeg`` process per timestamp, seeking with ``-ss`` and
|
||||
piping a single PNG to stdout. Unlike the in-process decoders this
|
||||
survives a hostile container: a full ffmpeg build decodes AV1 (the codec
|
||||
modern LeRobot datasets use) where torchcodec raises and PyAV can
|
||||
SIGSEGV, and a crash stays isolated to the child process — a non-zero
|
||||
exit is a catchable error, not a segfault of the whole job. Returns one
|
||||
``(C, H, W)`` uint8 tensor per timestamp.
|
||||
"""
|
||||
import io # noqa: PLC0415
|
||||
import subprocess # noqa: PLC0415
|
||||
|
||||
import numpy as np # noqa: PLC0415
|
||||
|
||||
frames: list[Any] = []
|
||||
for ts in timestamps:
|
||||
# ffmpeg invoked by name via PATH lookup; fully-controlled arg list, no shell.
|
||||
proc = subprocess.run( # nosec B607
|
||||
[
|
||||
"ffmpeg",
|
||||
"-nostdin",
|
||||
"-loglevel",
|
||||
"error",
|
||||
"-ss",
|
||||
f"{max(ts, 0.0):.3f}",
|
||||
"-i",
|
||||
str(video_path),
|
||||
"-frames:v",
|
||||
"1",
|
||||
"-f",
|
||||
"image2pipe",
|
||||
"-vcodec",
|
||||
"png",
|
||||
"pipe:1",
|
||||
],
|
||||
capture_output=True,
|
||||
check=True,
|
||||
timeout=120,
|
||||
)
|
||||
if not proc.stdout:
|
||||
raise RuntimeError(f"ffmpeg returned no frame for t={ts:.3f}s of {video_path}")
|
||||
img = PIL.Image.open(io.BytesIO(proc.stdout)).convert("RGB")
|
||||
frames.append(torch.from_numpy(np.asarray(img).copy()).permute(2, 0, 1).contiguous())
|
||||
return frames
|
||||
|
||||
|
||||
def _decode_frames_av(video_path: Path, timestamps: list[float]) -> list[Any]:
|
||||
"""Decode the frames nearest to ``timestamps`` using PyAV directly.
|
||||
|
||||
lerobot's ``decode_video_frames(backend="pyav")`` routes through
|
||||
``torchvision.io.VideoReader``, removed in torchvision 0.23+. This helper
|
||||
talks to the ``av`` package directly. Note PyAV can SIGSEGV on AV1
|
||||
streams in some builds — prefer ``_decode_frames_ffmpeg`` as the default
|
||||
fallback; this stays available behind ``video_backend="pyav"``. Returns
|
||||
one ``(C, H, W)`` uint8 tensor per timestamp.
|
||||
"""
|
||||
import av # noqa: PLC0415
|
||||
|
||||
first_ts = min(timestamps)
|
||||
last_ts = max(timestamps)
|
||||
loaded_frames: list[torch.Tensor] = []
|
||||
loaded_ts: list[float] = []
|
||||
with av.open(str(video_path)) as container:
|
||||
stream = container.streams.video[0]
|
||||
# Seek to the keyframe at or before the first requested timestamp.
|
||||
offset = max(int(first_ts / stream.time_base), 0) if stream.time_base else 0
|
||||
container.seek(offset, stream=stream, backward=True, any_frame=False)
|
||||
for idx, frame in enumerate(container.decode(stream)):
|
||||
ts = frame.time
|
||||
if ts is None:
|
||||
ts = float(frame.pts * stream.time_base) if frame.pts is not None else float(idx)
|
||||
loaded_ts.append(ts)
|
||||
loaded_frames.append(
|
||||
torch.from_numpy(frame.to_ndarray(format="rgb24")).permute(2, 0, 1).contiguous()
|
||||
)
|
||||
if ts >= last_ts:
|
||||
break
|
||||
if not loaded_frames:
|
||||
raise RuntimeError(f"PyAV decoded no frames from {video_path}")
|
||||
ts_tensor = torch.tensor(loaded_ts)
|
||||
return [loaded_frames[int(torch.argmin((ts_tensor - q).abs()))] for q in timestamps]
|
||||
|
||||
|
||||
def _frame_to_pil(frame: Any) -> Any:
|
||||
"""Materialise a decoded frame as a ``PIL.Image`` for the VLM message.
|
||||
|
||||
Frames flow through the provider as ``torch.Tensor`` (``C, H, W`` uint8,
|
||||
straight from :func:`decode_video_frames`); PIL is only created here, at
|
||||
the VLM-message boundary, because the chat backends expect PIL images /
|
||||
data URLs. Non-tensor inputs (e.g. test stubs) pass through untouched.
|
||||
"""
|
||||
if not isinstance(frame, torch.Tensor):
|
||||
return frame
|
||||
array = frame.detach().cpu()
|
||||
if array.ndim == 3 and array.shape[0] in (1, 3):
|
||||
array = array.permute(1, 2, 0) # (C, H, W) -> (H, W, C)
|
||||
if array.shape[-1] == 1:
|
||||
array = array.squeeze(-1)
|
||||
return PIL.Image.fromarray(array.to(torch.uint8).numpy())
|
||||
|
||||
|
||||
def to_image_blocks(frames: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Convert decoded frames to Qwen-VL-compatible image content blocks."""
|
||||
return [{"type": "image", "image": _frame_to_pil(frame)} for frame in frames]
|
||||
|
||||
|
||||
def to_video_block(frames: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Wrap a list of decoded frames as one Qwen-VL video block.
|
||||
|
||||
Returns ``[]`` when the list is empty, so the caller can splat the result
|
||||
into a content array without a separate emptiness check.
|
||||
"""
|
||||
if not frames:
|
||||
return []
|
||||
return [{"type": "video", "video": [_frame_to_pil(frame) for frame in frames]}]
|
||||
|
||||
|
||||
def to_video_url_block(url: str | None, fps: float = 2.0) -> list[dict[str, Any]]:
|
||||
"""Wrap a video file URL as one ``video_url`` block.
|
||||
|
||||
Used by the ``openai`` backend (transformers serve / vllm serve /
|
||||
ktransformers serve), where the server handles frame sampling.
|
||||
Returns ``[]`` when ``url`` is ``None`` so the caller can splat.
|
||||
"""
|
||||
if not url:
|
||||
return []
|
||||
return [{"type": "video_url", "video_url": {"url": url}, "fps": fps}]
|
||||
@@ -1,25 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .general_vqa import GeneralVqaModule
|
||||
from .interjections_and_speech import InterjectionsAndSpeechModule
|
||||
from .plan_subtasks_memory import PlanSubtasksMemoryModule
|
||||
|
||||
__all__ = [
|
||||
"GeneralVqaModule",
|
||||
"InterjectionsAndSpeechModule",
|
||||
"PlanSubtasksMemoryModule",
|
||||
]
|
||||
@@ -1,248 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``vqa`` module: general VQA at a timed cadence.
|
||||
|
||||
Every ``1/hz`` seconds an emission tick fires; each tick anchors ``K``
|
||||
consecutive frames, and every anchored frame gets its own VQA pair. Each
|
||||
pair is grounded on that single anchor frame — there is no per-pair frame
|
||||
window. For datasets with multiple cameras, every anchored frame produces
|
||||
one ``(vqa, user)`` + ``(vqa, assistant)`` pair *per camera*: each pair is
|
||||
generated against that camera's frame and stamped with the matching
|
||||
``camera`` field on the emitted rows. The resolver disambiguates via
|
||||
``camera=...``; recipes that consume VQA do so through one sub-recipe
|
||||
per camera (see ``recipes/pi05_hirobot.yaml``).
|
||||
|
||||
Within a single (frame, camera) we still emit at most one ``(vqa, user)``
|
||||
and one ``(vqa, assistant)`` row, so the resolver contract stays scalar.
|
||||
|
||||
Question types covered (per the plan's ``vqa`` table): bbox, keypoint,
|
||||
count, attribute, spatial. The assistant's ``content`` is a JSON string
|
||||
whose schema depends on the question type. Malformed JSON triggers one
|
||||
retry inside :meth:`VlmClient.generate_json`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import VqaConfig
|
||||
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord
|
||||
from ..staging import EpisodeStaging
|
||||
from ..validator import classify_vqa_answer
|
||||
from ..vlm_client import VlmClient
|
||||
|
||||
|
||||
def _emission_anchor_indices(frame_timestamps: Sequence[float], hz: float, k: int) -> list[int]:
|
||||
"""Return the relative frame indices to anchor VQA emissions to.
|
||||
|
||||
For each emission tick (every ``1/hz`` seconds), we anchor ``k``
|
||||
consecutive frames starting at the tick. Ticks fall on the nearest
|
||||
available source frame timestamp.
|
||||
"""
|
||||
if hz <= 0 or k <= 0 or not frame_timestamps:
|
||||
return []
|
||||
t0 = frame_timestamps[0]
|
||||
t_last = frame_timestamps[-1]
|
||||
period = 1.0 / hz
|
||||
indices: list[int] = []
|
||||
t = t0
|
||||
while t <= t_last + 1e-9:
|
||||
# find the index of the nearest frame to t
|
||||
nearest_i = min(range(len(frame_timestamps)), key=lambda i: abs(frame_timestamps[i] - t))
|
||||
for offset in range(k):
|
||||
j = nearest_i + offset
|
||||
if j >= len(frame_timestamps):
|
||||
break
|
||||
if not indices or indices[-1] != j:
|
||||
indices.append(j)
|
||||
t += period
|
||||
# dedupe while preserving order
|
||||
seen: set[int] = set()
|
||||
deduped: list[int] = []
|
||||
for i in indices:
|
||||
if i in seen:
|
||||
continue
|
||||
seen.add(i)
|
||||
deduped.append(i)
|
||||
return deduped
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneralVqaModule:
|
||||
"""Emit grounded VQA pairs at a timed cadence."""
|
||||
|
||||
vlm: VlmClient
|
||||
config: VqaConfig
|
||||
seed: int = 1729
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
_warned_no_camera: bool = field(default=False, init=False, repr=False)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
if not record.frame_timestamps:
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
rng = random.Random(f"{self.seed}:{record.episode_index}:vqa")
|
||||
anchor_idx = _emission_anchor_indices(
|
||||
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
|
||||
)
|
||||
cameras = self._target_cameras()
|
||||
if not cameras:
|
||||
# No camera available — emit nothing rather than producing
|
||||
# untagged rows that would fail validation. Surface a loud one-
|
||||
# time warning so this is never silently a no-op.
|
||||
if not self._warned_no_camera:
|
||||
logging.getLogger(__name__).warning(
|
||||
"vqa module found no cameras on the frame provider — "
|
||||
"every episode will emit zero VQA rows. Check that the "
|
||||
"dataset declares observation.images.* features in "
|
||||
"meta/info.json; passing --vlm.camera_key=<key> at the "
|
||||
"CLI now also seeds the cameras list as a fallback."
|
||||
)
|
||||
self._warned_no_camera = True
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
|
||||
# Build all messages first (one per (frame, camera)), then issue them
|
||||
# as a single batched generate_json call so the client can fan them
|
||||
# out concurrently.
|
||||
per_call: list[tuple[float, str, str, list[dict[str, Any]]]] = []
|
||||
for idx in anchor_idx:
|
||||
ts = float(record.frame_timestamps[idx])
|
||||
qtype = rng.choice(self.config.question_types)
|
||||
for camera in cameras:
|
||||
messages = self._build_messages(record, qtype, ts, camera)
|
||||
# Skip cameras that decoded to zero frames at this ts: no point
|
||||
# asking the VLM to ground a bbox without an image.
|
||||
if not _has_image_block(messages):
|
||||
continue
|
||||
per_call.append((ts, camera, qtype, messages))
|
||||
|
||||
if not per_call:
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
|
||||
results = self.vlm.generate_json([m for _, _, _, m in per_call])
|
||||
|
||||
rows: list[dict[str, Any]] = []
|
||||
for (ts, camera, _qtype, _messages), result in zip(per_call, results, strict=True):
|
||||
qa = self._postprocess(result)
|
||||
if qa is None:
|
||||
continue
|
||||
question, answer = qa
|
||||
rows.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": question,
|
||||
"style": "vqa",
|
||||
"timestamp": ts,
|
||||
"camera": camera,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": json.dumps(answer, sort_keys=True),
|
||||
"style": "vqa",
|
||||
"timestamp": ts,
|
||||
"camera": camera,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
staging.write("vqa", rows)
|
||||
|
||||
def _target_cameras(self) -> list[str]:
|
||||
"""Return the cameras the ``vqa`` module should iterate per anchored frame.
|
||||
|
||||
Defaults to every camera the provider exposes. Datasets with no
|
||||
cameras (or test/null providers) yield an empty list, which makes
|
||||
``run_episode`` a no-op.
|
||||
|
||||
When ``config.restrict_to_default_camera`` is set, VQA grounds on
|
||||
only the provider's default camera (the single ``--vlm.camera_key``
|
||||
stream), matching the plan / interjection modules so the whole
|
||||
pipeline focuses on one view.
|
||||
"""
|
||||
all_cameras = list(getattr(self.frame_provider, "camera_keys", []) or [])
|
||||
if getattr(self.config, "restrict_to_default_camera", False):
|
||||
default = getattr(self.frame_provider, "camera_key", None)
|
||||
if default and default in all_cameras:
|
||||
return [default]
|
||||
# ``restrict_to_default_camera`` is set but the configured default
|
||||
# isn't one the provider exposes. Returning it anyway would make
|
||||
# ``_decode`` raise a KeyError deep in frame extraction, so warn and
|
||||
# fall through to every available camera instead.
|
||||
if default:
|
||||
logging.getLogger(__name__).warning(
|
||||
"restrict_to_default_camera is set but camera_key=%r is not in the "
|
||||
"provider's cameras %s; grounding VQA on all available cameras instead.",
|
||||
default,
|
||||
all_cameras,
|
||||
)
|
||||
return all_cameras
|
||||
|
||||
def _build_messages(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
question_type: str,
|
||||
frame_timestamp: float,
|
||||
camera_key: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
prompt = load_prompt("vqa").format(
|
||||
episode_task=record.episode_task,
|
||||
question_type=question_type,
|
||||
)
|
||||
images = self.frame_provider.frames_at(record, [frame_timestamp], camera_key=camera_key)
|
||||
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def _postprocess(self, result: Any) -> tuple[str, dict[str, Any]] | None:
|
||||
if not isinstance(result, dict):
|
||||
return None
|
||||
question = result.get("question")
|
||||
answer = result.get("answer")
|
||||
if not isinstance(question, str) or not question.strip():
|
||||
return None
|
||||
if not isinstance(answer, dict):
|
||||
return None
|
||||
# The validator will enforce shape; here we just sanity-check that the
|
||||
# answer matches *some* known shape so we can drop garbage early.
|
||||
if classify_vqa_answer(answer) is None:
|
||||
return None
|
||||
return question.strip(), answer
|
||||
|
||||
|
||||
def _has_image_block(messages: list[dict[str, Any]]) -> bool:
|
||||
"""Return True if any user content block is a populated image block."""
|
||||
for msg in messages:
|
||||
content = msg.get("content")
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "image":
|
||||
return True
|
||||
return False
|
||||
@@ -1,211 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``interjections`` module: interjections + paired speech (EVENT styles + speech atoms).
|
||||
|
||||
Two sub-passes:
|
||||
|
||||
1. At ``t=0``, emit ONLY a speech tool-call atom (acknowledgement of the
|
||||
canonical task). No interjection row — the canonical task is already the
|
||||
user utterance from ``meta/tasks.parquet``.
|
||||
|
||||
2. For mid-episode interruptions, emit a co-timestamped pair:
|
||||
{role:user, style:interjection, content:<text>}
|
||||
speech atom (role:assistant, style:None, tool_calls=[say(...)])
|
||||
Both rows go in ``language_events`` at the same timestamp.
|
||||
|
||||
The ``plan`` module's :meth:`run_plan_updates` reuses this module's
|
||||
interjection timestamps to refresh the ``plan`` row at the same instant.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import InterjectionsConfig
|
||||
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
||||
from ..staging import EpisodeStaging
|
||||
from ..vlm_client import VlmClient
|
||||
from ..writer import speech_atom
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterjectionsAndSpeechModule:
|
||||
"""Generate task-start speech and mid-episode interjection/speech pairs."""
|
||||
|
||||
vlm: VlmClient
|
||||
config: InterjectionsConfig
|
||||
seed: int = 1729
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
rows: list[dict[str, Any]] = []
|
||||
if record.frame_timestamps:
|
||||
t0 = float(record.frame_timestamps[0])
|
||||
initial = self._initial_speech(record)
|
||||
if initial:
|
||||
rows.append(speech_atom(t0, initial))
|
||||
# Pull the ``plan`` module's subtask spans for this episode so the
|
||||
# interjection prompt can ground itself in the actual current
|
||||
# subtask at each chosen timestamp. The ``plan`` module ran first.
|
||||
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
|
||||
subtask_spans = reconstruct_subtask_spans(staging.read("plan"), episode_end_t=episode_end_t)
|
||||
rows.extend(self._mid_episode_interjections(record, subtask_spans))
|
||||
staging.write("interjections", rows)
|
||||
|
||||
@staticmethod
|
||||
def _subtask_at(spans: Sequence[dict[str, Any]], t: float) -> str | None:
|
||||
current: str | None = None
|
||||
for span in spans:
|
||||
if float(span["start"]) <= t:
|
||||
current = span.get("text")
|
||||
else:
|
||||
break
|
||||
return current
|
||||
|
||||
def _initial_speech(self, record: EpisodeRecord) -> str | None:
|
||||
prompt = load_prompt("interjections_initial_speech").format(
|
||||
episode_task=record.episode_task,
|
||||
)
|
||||
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if isinstance(result, dict) and isinstance(result.get("text"), str):
|
||||
text = result["text"].strip()
|
||||
if text:
|
||||
return text
|
||||
return None
|
||||
|
||||
def _mid_episode_interjections(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
subtask_spans: Sequence[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Generate interjections aligned with the actual demo trajectory.
|
||||
|
||||
Teleop data is frozen — the robot already executed every step in
|
||||
the video. A *counterfactual* interjection like "actually skip
|
||||
the wipe" contradicts what then happens in the video, which is
|
||||
what qwen36moe-10/11 surfaced as low-quality interjections.
|
||||
|
||||
Instead, anchor every interjection at a subtask boundary and
|
||||
write it as a natural user request for the *upcoming* subtask.
|
||||
The robot's visible next behavior IS the interjection's effect,
|
||||
so the training signal stays consistent: interjection text →
|
||||
plan refresh → action stream all line up.
|
||||
"""
|
||||
if self.config.max_interjections_per_episode <= 0:
|
||||
return []
|
||||
if len(subtask_spans) < 2:
|
||||
# Need at least one transition (subtask 0 → subtask 1).
|
||||
return []
|
||||
# Deterministic per-episode RNG so reruns are stable across SLURM jobs.
|
||||
rng = random.Random(f"{self.seed}:{record.episode_index}:interjection")
|
||||
|
||||
# Boundaries: the start time of every subtask except the first
|
||||
# (which is just t0 and is covered by the initial-task speech atom).
|
||||
boundaries: list[tuple[float, str, str]] = []
|
||||
for i in range(1, len(subtask_spans)):
|
||||
ts = float(subtask_spans[i]["start"])
|
||||
if ts < self.config.interjection_min_t:
|
||||
continue
|
||||
prev_text = (subtask_spans[i - 1].get("text") or "").strip()
|
||||
next_text = (subtask_spans[i].get("text") or "").strip()
|
||||
if not next_text:
|
||||
continue
|
||||
boundaries.append((ts, prev_text, next_text))
|
||||
if not boundaries:
|
||||
return []
|
||||
|
||||
n = min(self.config.max_interjections_per_episode, len(boundaries))
|
||||
chosen = sorted(rng.sample(boundaries, n), key=lambda b: b[0])
|
||||
|
||||
out: list[dict[str, Any]] = []
|
||||
for t, prev_subtask, next_subtask in chosen:
|
||||
t_snap = snap_to_frame(t, record.frame_timestamps)
|
||||
# Window straddles the boundary so the VLM sees the end of the
|
||||
# previous subtask and the start of the next one — same
|
||||
# conditioning the policy will see at training time.
|
||||
window_ts = self._window_timestamps(t_snap, record.frame_timestamps)
|
||||
prompt = load_prompt("interjections_interjection").format(
|
||||
episode_task=record.episode_task,
|
||||
prev_subtask=prev_subtask or "(starting from initial state)",
|
||||
next_subtask=next_subtask,
|
||||
timestamp=t_snap,
|
||||
window_seconds=self.config.interjection_window_seconds,
|
||||
)
|
||||
images = self.frame_provider.frames_at(record, window_ts)
|
||||
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||
messages = [{"role": "user", "content": content}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if not isinstance(result, dict):
|
||||
continue
|
||||
interjection_text = result.get("interjection")
|
||||
speech_text = result.get("speech")
|
||||
if not isinstance(interjection_text, str) or not interjection_text.strip():
|
||||
continue
|
||||
if not isinstance(speech_text, str) or not speech_text.strip():
|
||||
continue
|
||||
out.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": interjection_text.strip(),
|
||||
"style": "interjection",
|
||||
"timestamp": t_snap,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
out.append(speech_atom(t_snap, speech_text.strip()))
|
||||
return out
|
||||
|
||||
def _window_timestamps(self, t_anchor: float, frame_timestamps: Sequence[float]) -> list[float]:
|
||||
"""Return a small set of frame timestamps centered on ``t_anchor``.
|
||||
|
||||
The window straddles the subtask boundary the interjection sits
|
||||
on: roughly half the frames cover the end of the previous
|
||||
subtask, half cover the start of the next one. The VLM therefore
|
||||
sees BOTH what just finished AND what's about to start, which is
|
||||
the conditioning we need to write a natural "now please do X"
|
||||
request that matches the visible upcoming behavior.
|
||||
"""
|
||||
if not frame_timestamps:
|
||||
return [t_anchor]
|
||||
n = max(1, int(self.config.interjection_window_frames))
|
||||
if n == 1:
|
||||
return [t_anchor]
|
||||
window = float(self.config.interjection_window_seconds)
|
||||
step = window / max(1, n - 1)
|
||||
# Center the window on the anchor so half lands before, half after.
|
||||
start_offset = -window / 2.0
|
||||
targets = [t_anchor + start_offset + step * i for i in range(n)]
|
||||
first_ts = float(frame_timestamps[0])
|
||||
last_ts = float(frame_timestamps[-1])
|
||||
snapped: list[float] = []
|
||||
seen: set[float] = set()
|
||||
for tgt in targets:
|
||||
clamped = min(last_ts, max(first_ts, tgt))
|
||||
t = snap_to_frame(clamped, frame_timestamps)
|
||||
if t not in seen:
|
||||
seen.add(t)
|
||||
snapped.append(t)
|
||||
return snapped or [t_anchor]
|
||||
@@ -1,712 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``plan`` module: subtask decomposition + plan + memory (PERSISTENT styles)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ..config import PlanConfig
|
||||
from ..frames import (
|
||||
FrameProvider,
|
||||
VideoFrameProvider,
|
||||
null_provider,
|
||||
to_video_block,
|
||||
to_video_url_block,
|
||||
)
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
||||
from ..staging import EpisodeStaging
|
||||
from ..vlm_client import VlmClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlanSubtasksMemoryModule:
|
||||
"""Generate subtask spans, plan, and memory rows.
|
||||
|
||||
All output is persistent (lives in ``language_persistent``):
|
||||
|
||||
- ``subtask`` rows: one per span, stamped at the span's *start* timestamp
|
||||
(snapped to an exact frame).
|
||||
- ``plan`` rows: emitted at ``t=0``; refreshed at every interjection
|
||||
timestamp via :meth:`run_plan_updates` (called by the executor after
|
||||
the ``interjections`` module completes).
|
||||
- ``memory`` rows: emitted at each subtask boundary (= subtask start
|
||||
timestamp from the second subtask onward).
|
||||
"""
|
||||
|
||||
vlm: VlmClient
|
||||
config: PlanConfig
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
rows: list[dict[str, Any]] = []
|
||||
# Task driving every plan-module prompt: canonical episode_task, or a
|
||||
# video-derived one when it's empty/placeholder (see derive_task_*).
|
||||
effective_task = self._resolve_effective_task(record)
|
||||
# task_aug rows at t=0: phrasings the renderer rotates ${task} through.
|
||||
# Either the structured 5-axis taxonomy (task_aug_axes.enabled) or
|
||||
# free-form n_task_rephrasings; the effective task is always emitted
|
||||
# first so the rotation covers the source-of-truth phrasing.
|
||||
t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0
|
||||
variants: list[str] | None = None
|
||||
if self.config.task_aug_axes.enabled and effective_task:
|
||||
variants = self._generate_task_aug_by_axes(effective_task, self.config.task_aug_axes)
|
||||
elif self.config.n_task_rephrasings > 0 and effective_task:
|
||||
variants = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings)
|
||||
if variants is not None:
|
||||
rows.extend(self._task_aug_rows([effective_task, *variants], t0))
|
||||
|
||||
subtask_spans = self._generate_subtasks(record, task=effective_task)
|
||||
|
||||
# subtask rows
|
||||
for span in subtask_spans:
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": span["text"],
|
||||
"style": "subtask",
|
||||
"timestamp": snap_to_frame(span["start"], record.frame_timestamps),
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
# Plan rows at every subtask boundary (incl. t=0). The plan is a
|
||||
# numbered list of still-todo subtasks, so re-emitting at each
|
||||
# boundary makes it shrink as work progresses — ${plan} at frame t is
|
||||
# exactly what's left to do.
|
||||
if self.config.emit_plan:
|
||||
for span in subtask_spans:
|
||||
boundary_t = snap_to_frame(span["start"], record.frame_timestamps)
|
||||
plan_text = self._generate_plan(
|
||||
record, subtask_spans, refresh_t=boundary_t, task=effective_task
|
||||
)
|
||||
if plan_text is not None:
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": plan_text,
|
||||
"style": "plan",
|
||||
"timestamp": float(boundary_t),
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
# memory rows at every subtask boundary except the very first start
|
||||
prior_memory = ""
|
||||
for i, span in enumerate(subtask_spans[1:], start=1):
|
||||
completed = subtask_spans[i - 1]["text"]
|
||||
remaining = [s["text"] for s in subtask_spans[i:]]
|
||||
mem_text = self._generate_memory(record, prior_memory, completed, remaining, task=effective_task)
|
||||
if mem_text:
|
||||
ts = snap_to_frame(span["start"], record.frame_timestamps)
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": mem_text,
|
||||
"style": "memory",
|
||||
"timestamp": ts,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
prior_memory = mem_text
|
||||
staging.write("plan", rows)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Task derivation + rephrasings
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
_PLACEHOLDER_TASKS: frozenset[str] = frozenset(
|
||||
{
|
||||
"debug",
|
||||
"test",
|
||||
"tbd",
|
||||
"todo",
|
||||
"n/a",
|
||||
"na",
|
||||
"untitled",
|
||||
"unnamed",
|
||||
"default",
|
||||
"placeholder",
|
||||
}
|
||||
)
|
||||
|
||||
def _resolve_effective_task(self, record: EpisodeRecord) -> str:
|
||||
"""Decide which task string drives the ``plan`` module for this episode.
|
||||
|
||||
Returns the user-supplied ``record.episode_task`` unless
|
||||
``derive_task_from_video`` says otherwise (see config docstring).
|
||||
Falls back gracefully to the canonical task if video derivation
|
||||
fails.
|
||||
"""
|
||||
canonical = (record.episode_task or "").strip()
|
||||
mode = (self.config.derive_task_from_video or "off").strip().lower()
|
||||
if mode == "always":
|
||||
derived = self._derive_task_from_video(record)
|
||||
return derived or canonical
|
||||
if mode == "if_short" and self._task_seems_bad(canonical):
|
||||
derived = self._derive_task_from_video(record)
|
||||
if derived:
|
||||
return derived
|
||||
return canonical
|
||||
|
||||
def _task_seems_bad(self, task: str) -> bool:
|
||||
if not task:
|
||||
return True
|
||||
if len(task.split()) < int(self.config.derive_task_min_words):
|
||||
return True
|
||||
return task.lower() in self._PLACEHOLDER_TASKS
|
||||
|
||||
@staticmethod
|
||||
def _task_aug_rows(phrasings: Sequence[str], t0: float) -> list[dict[str, Any]]:
|
||||
"""Build deduplicated ``task_aug`` rows (role=user) at ``t0``."""
|
||||
seen: set[str] = set()
|
||||
rows: list[dict[str, Any]] = []
|
||||
for phrasing in phrasings:
|
||||
key = phrasing.strip()
|
||||
if not key or key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
rows.append(
|
||||
{"role": "user", "content": key, "style": "task_aug", "timestamp": t0, "tool_calls": None}
|
||||
)
|
||||
return rows
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# VLM call helpers — every plan-module prompt follows the same shape:
|
||||
# build messages → single VLM call → pull a named field.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _vlm_field(self, messages: list[dict[str, Any]], field: str) -> Any:
|
||||
"""Run a single VLM call and return ``result[field]`` or ``None``.
|
||||
|
||||
Centralizes the ``vlm.generate_json([m])[0]`` + ``isinstance(dict)``
|
||||
dance every prompt-call site needs.
|
||||
"""
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if isinstance(result, dict):
|
||||
return result.get(field)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _text_message(text: str) -> list[dict[str, Any]]:
|
||||
"""One-shot text-only user message wrapped for ``generate_json``."""
|
||||
return [{"role": "user", "content": [{"type": "text", "text": text}]}]
|
||||
|
||||
def _video_message(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
prompt: str,
|
||||
window: tuple[float, float] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""User message combining the (optionally windowed) video block with ``prompt``."""
|
||||
content = [*self._episode_video_block(record, window=window), {"type": "text", "text": prompt}]
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def _derive_task_from_video(self, record: EpisodeRecord) -> str | None:
|
||||
"""Ask the VLM "what is this video about" with no task hint at all."""
|
||||
text = self._vlm_field(self._video_message(record, load_prompt("plan_video_task")), "task")
|
||||
return text.strip() if isinstance(text, str) and text.strip() else None
|
||||
|
||||
def _generate_task_rephrasings(self, base_task: str, *, n: int) -> list[str]:
|
||||
"""Generate ``n`` text-only paraphrases of ``base_task``."""
|
||||
if n <= 0 or not base_task:
|
||||
return []
|
||||
prompt = load_prompt("plan_task_rephrasings").format(base_task=base_task, n=n)
|
||||
raw = self._vlm_field(self._text_message(prompt), "rephrasings")
|
||||
if not isinstance(raw, list):
|
||||
return []
|
||||
out = [item.strip().strip('"').strip("'") for item in raw if isinstance(item, str)]
|
||||
return [s for s in out if s][:n]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Structured 5-axis task augmentation (EgoMimic-style taxonomy)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _generate_task_aug_by_axes(self, base_task: str, axes_cfg: Any) -> list[str]:
|
||||
"""One VLM call → variants along the 5-axis taxonomy.
|
||||
|
||||
Variants from all axes are flattened into a single list (the
|
||||
downstream pipeline doesn't need to know about the per-axis
|
||||
bucketing — every variant becomes a ``task_aug`` row). Order
|
||||
is preserved for reproducibility: synonym_paraphrase first,
|
||||
then omit_arm, then omit_orientation, then omit_grasp_method,
|
||||
then combined_omissions.
|
||||
"""
|
||||
if not base_task:
|
||||
return []
|
||||
prompt = load_prompt("plan_task_aug_axes").format(
|
||||
base_task=base_task,
|
||||
n_synonym=axes_cfg.synonym_paraphrase,
|
||||
n_omit_arm=axes_cfg.omit_arm,
|
||||
n_omit_orientation=axes_cfg.omit_orientation,
|
||||
n_omit_grasp_method=axes_cfg.omit_grasp_method,
|
||||
n_combined=axes_cfg.combined_omissions,
|
||||
)
|
||||
result = self.vlm.generate_json([self._text_message(prompt)])[0]
|
||||
if not isinstance(result, dict):
|
||||
return []
|
||||
ordered_axes = (
|
||||
"synonym_paraphrase",
|
||||
"omit_arm",
|
||||
"omit_orientation",
|
||||
"omit_grasp_method",
|
||||
"combined_omissions",
|
||||
)
|
||||
flat: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for axis in ordered_axes:
|
||||
entries = result.get(axis)
|
||||
if not isinstance(entries, list):
|
||||
continue
|
||||
for item in entries:
|
||||
if not isinstance(item, str):
|
||||
continue
|
||||
key = item.strip().strip('"').strip("'")
|
||||
if not key or key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
flat.append(key)
|
||||
return flat
|
||||
|
||||
def _episode_video_block(
|
||||
self, record: EpisodeRecord, window: tuple[float, float] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Video block for the segmentation / describe prompts.
|
||||
|
||||
Always returns a block that actually carries the video. When
|
||||
``use_video_url`` is set we try the server-side ``video_url``
|
||||
path first, but if clip extraction fails we FALL BACK to
|
||||
decoding + embedding frames rather than returning an empty
|
||||
block — an empty block would leave the VLM with no visual
|
||||
grounding at all and it would hallucinate subtasks purely from
|
||||
the task text.
|
||||
|
||||
When ``window=(w0, w1)`` is given (windowed subtask generation,
|
||||
``subtask_window_seconds > 0``), embed frames sampled at the FIXED
|
||||
``frames_per_second`` rate within ``[w0, w1]`` — constant temporal
|
||||
density regardless of episode length, so long episodes are split
|
||||
into windows rather than subsampled to a sparse 32-frame whole-
|
||||
episode view. The ``video_url`` path is skipped for windows (it is
|
||||
a whole-episode clip). ``max_video_frames`` still caps each window
|
||||
as a context-budget safety net.
|
||||
"""
|
||||
if not record.frame_timestamps:
|
||||
return []
|
||||
if window is not None:
|
||||
w0, w1 = float(window[0]), float(window[1])
|
||||
dur = max(0.0, w1 - w0)
|
||||
n = max(1, int(round(dur * self.config.frames_per_second)) + 1)
|
||||
n = min(n, self.config.max_video_frames)
|
||||
if n <= 1 or dur <= 0.0:
|
||||
timestamps = [0.5 * (w0 + w1)]
|
||||
else:
|
||||
step = dur / (n - 1)
|
||||
timestamps = [w0 + i * step for i in range(n)]
|
||||
return to_video_block(self.frame_provider.frames_at(record, timestamps))
|
||||
if self.config.use_video_url and isinstance(self.frame_provider, VideoFrameProvider):
|
||||
cache_dir = Path(self.frame_provider.root) / ".annotate_staging" / ".video_clips"
|
||||
clip = self.frame_provider.episode_clip_path(record, cache_dir)
|
||||
if clip is not None:
|
||||
return to_video_url_block(f"file://{clip}", fps=self.config.use_video_url_fps)
|
||||
logger.warning(
|
||||
"episode %d: video_url clip extraction failed — falling back to "
|
||||
"embedded frames so the VLM still sees the demonstration",
|
||||
record.episode_index,
|
||||
)
|
||||
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||
target_count = max(1, int(round(episode_duration * self.config.frames_per_second)))
|
||||
target_count = min(target_count, self.config.max_video_frames)
|
||||
video_frames = self.frame_provider.video_for_episode(record, target_count)
|
||||
return to_video_block(video_frames)
|
||||
|
||||
def run_plan_updates(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
staging: EpisodeStaging,
|
||||
interjection_times: Sequence[float],
|
||||
interjection_texts: Sequence[str] | None = None,
|
||||
) -> None:
|
||||
"""Append additional ``plan`` rows at every interjection timestamp.
|
||||
|
||||
Plans refresh ONLY on user interjections (event-driven). The
|
||||
interjection text is forwarded into the prompt so the refreshed plan
|
||||
reflects the user's correction.
|
||||
"""
|
||||
if not self.config.emit_plan:
|
||||
return
|
||||
existing = staging.read("plan")
|
||||
# Pass the last frame timestamp so the final span is closed (else its
|
||||
# end == start, zero duration, and a refresh inside it is missed).
|
||||
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
|
||||
spans = reconstruct_subtask_spans(existing, episode_end_t=episode_end_t)
|
||||
already_planned: set[float] = {float(r["timestamp"]) for r in existing if r.get("style") == "plan"}
|
||||
new_rows = list(existing)
|
||||
|
||||
texts: list[str | None] = (
|
||||
[None] * len(interjection_times)
|
||||
if interjection_texts is None
|
||||
else [str(t) if t else None for t in interjection_texts]
|
||||
)
|
||||
for raw_t, inter_text in zip(interjection_times, texts, strict=True):
|
||||
t = snap_to_frame(raw_t, record.frame_timestamps)
|
||||
if t in already_planned:
|
||||
continue
|
||||
already_planned.add(t)
|
||||
plan_text = self._generate_plan(record, spans, refresh_t=t, interjection=inter_text)
|
||||
if plan_text is not None:
|
||||
new_rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": plan_text,
|
||||
"style": "plan",
|
||||
"timestamp": t,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
staging.write("plan", new_rows)
|
||||
|
||||
def _generate_subtasks(self, record: EpisodeRecord, *, task: str | None = None) -> list[dict[str, Any]]:
|
||||
"""Generate subtask spans, optionally via a multi-call quality chain.
|
||||
|
||||
Single call (default): watch video → emit subtask JSON.
|
||||
|
||||
Multi-call (opt-in, higher quality, more VLM calls):
|
||||
1. ``subtask_describe_first`` — a grounding pass that narrates
|
||||
ONLY what is visible (no JSON commitment to subtasks yet);
|
||||
its description is injected into the segmentation prompt so
|
||||
the model segments its own grounded observations instead of
|
||||
pattern-matching the task text.
|
||||
2. segmentation — emit subtask JSON (as before).
|
||||
"""
|
||||
if record.row_count == 0 or not record.frame_timestamps:
|
||||
return []
|
||||
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||
effective_task = task if task is not None else record.episode_task
|
||||
|
||||
# ---- Windowed path (constant temporal density) ---------------
|
||||
# If subtask_window_seconds > 0 and the episode exceeds one window,
|
||||
# process fixed-length windows so the VLM always sees
|
||||
# frames_per_second density; results are merged + stitched.
|
||||
window_s = float(getattr(self.config, "subtask_window_seconds", 0.0) or 0.0)
|
||||
if window_s > 0.0 and episode_duration > window_s:
|
||||
return self._generate_subtasks_windowed(record, effective_task, window_s)
|
||||
|
||||
# ---- Pass 1 (optional): grounding description ----------------
|
||||
observation_block = ""
|
||||
if getattr(self.config, "subtask_describe_first", False):
|
||||
description = self._describe_episode(record, effective_task)
|
||||
if description:
|
||||
observation_block = (
|
||||
"You watched this video and described, chronologically, "
|
||||
"ONLY what the robot actually does:\n"
|
||||
f'"""{description}"""\n\n'
|
||||
"Segment THAT grounded description (cross-checked against "
|
||||
"the video) into atomic subtasks. Do not introduce any "
|
||||
"action that is not in your description above.\n\n"
|
||||
)
|
||||
|
||||
# ---- Pass 2: segmentation ------------------------------------
|
||||
prompt = load_prompt("plan_subtasks").format(
|
||||
episode_task=effective_task,
|
||||
min_subtask_seconds=self.config.min_subtask_seconds,
|
||||
max_steps=self.config.plan_max_steps,
|
||||
episode_duration=f"{episode_duration:.3f}",
|
||||
observation_block=observation_block,
|
||||
)
|
||||
spans = self._vlm_field(self._video_message(record, prompt), "subtasks")
|
||||
cleaned = self._clean_spans(spans, record)
|
||||
if not cleaned:
|
||||
return []
|
||||
|
||||
# ---- Full-episode coverage stitch ----------------------------
|
||||
# The VLM can start after t0 or leave gaps, so frames fall through
|
||||
# with no active subtask. Always stitch into a contiguous
|
||||
# [t0, t_last] cover.
|
||||
cleaned = self._stitch_full_coverage(cleaned, record)
|
||||
|
||||
return cleaned
|
||||
|
||||
def _generate_subtasks_windowed(
|
||||
self, record: EpisodeRecord, task: str, window_s: float
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Subtask generation in fixed-length windows at constant fps.
|
||||
|
||||
Splits ``[t0, t_last]`` into consecutive windows of ``window_s``
|
||||
seconds, runs the describe -> segment chain on each window's own
|
||||
frames (sampled at ``frames_per_second``), offsets
|
||||
each window's spans back to absolute episode time, then merges +
|
||||
stitches into a contiguous whole-episode cover.
|
||||
"""
|
||||
t0 = float(record.frame_timestamps[0])
|
||||
t_last = float(record.frame_timestamps[-1])
|
||||
all_spans: list[dict[str, Any]] = []
|
||||
w0 = t0
|
||||
n_windows = 0
|
||||
while w0 < t_last - 1e-6:
|
||||
w1 = min(w0 + window_s, t_last)
|
||||
all_spans.extend(self._subtasks_for_window(record, task, w0, w1))
|
||||
n_windows += 1
|
||||
w0 = w1
|
||||
logger.info(
|
||||
"episode %d: windowed subtask gen over %d window(s) of %.1fs -> %d raw spans",
|
||||
record.episode_index,
|
||||
n_windows,
|
||||
window_s,
|
||||
len(all_spans),
|
||||
)
|
||||
# Merge across windows: clamp to the absolute episode, sort, and
|
||||
# frame-snap to distinct starts (handles any boundary collisions).
|
||||
cleaned = self._clean_spans(all_spans, record)
|
||||
if not cleaned:
|
||||
return []
|
||||
return self._stitch_full_coverage(cleaned, record)
|
||||
|
||||
def _subtasks_for_window(
|
||||
self, record: EpisodeRecord, task: str, w0: float, w1: float
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Run describe -> segment on one ``[w0, w1]`` window.
|
||||
|
||||
The model works in window-RELATIVE time ``[0, L]`` (it perceives
|
||||
the window as a clip starting at 0); spans are offset back to
|
||||
absolute ``[w0, w1]`` before returning.
|
||||
"""
|
||||
window = (w0, w1)
|
||||
win_len = max(0.0, w1 - w0)
|
||||
|
||||
observation_block = ""
|
||||
if getattr(self.config, "subtask_describe_first", False):
|
||||
description = self._describe_episode(record, task, window=window)
|
||||
if description:
|
||||
observation_block = (
|
||||
"You watched this video clip and described, chronologically, "
|
||||
"ONLY what the robot actually does:\n"
|
||||
f'"""{description}"""\n\n'
|
||||
"Segment THAT grounded description (cross-checked against "
|
||||
"the clip) into atomic subtasks. Do not introduce any "
|
||||
"action that is not in your description above.\n\n"
|
||||
)
|
||||
|
||||
prompt = load_prompt("plan_subtasks").format(
|
||||
episode_task=task,
|
||||
min_subtask_seconds=self.config.min_subtask_seconds,
|
||||
max_steps=self.config.plan_max_steps,
|
||||
episode_duration=f"{win_len:.3f}",
|
||||
observation_block=observation_block,
|
||||
)
|
||||
spans = self._vlm_field(self._video_message(record, prompt, window=window), "subtasks")
|
||||
# Window-relative clamp; no frame-snap dedupe yet (done on the
|
||||
# merged absolute set).
|
||||
cleaned = self._clean_spans(spans, record, bounds=(0.0, win_len), dedupe=False)
|
||||
if not cleaned:
|
||||
return []
|
||||
|
||||
# Offset window-relative spans back to absolute episode time.
|
||||
for s in cleaned:
|
||||
s["start"] = w0 + float(s["start"])
|
||||
s["end"] = w0 + float(s["end"])
|
||||
return cleaned
|
||||
|
||||
def _stitch_full_coverage(
|
||||
self, spans: list[dict[str, Any]], record: EpisodeRecord
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Make subtask spans tile the full episode with no gaps.
|
||||
|
||||
* The first subtask starts at the episode's first frame ``t0``
|
||||
(any idle / approach before the first labelled action is folded
|
||||
into it), so every early frame has an active subtask.
|
||||
* Each subtask's ``end`` is snapped to the next subtask's
|
||||
``start`` (gaps between spans are closed), and the final
|
||||
subtask's ``end`` extends to the last frame ``t_last``.
|
||||
|
||||
Starts are otherwise left as the (already frame-snapped, distinct)
|
||||
values the VLM produced — only the FIRST start is pulled
|
||||
back to ``t0``, which can't collide with a later span because it
|
||||
was already the earliest. Purely deterministic; runs after the
|
||||
VLM passes.
|
||||
"""
|
||||
if not spans or not record.frame_timestamps:
|
||||
return spans
|
||||
t0 = float(record.frame_timestamps[0])
|
||||
t_last = float(record.frame_timestamps[-1])
|
||||
spans = sorted(spans, key=lambda s: float(s["start"]))
|
||||
spans[0]["start"] = t0
|
||||
for i in range(len(spans) - 1):
|
||||
spans[i]["end"] = float(spans[i + 1]["start"])
|
||||
spans[-1]["end"] = t_last
|
||||
for s in spans:
|
||||
if float(s["end"]) < float(s["start"]):
|
||||
s["end"] = float(s["start"])
|
||||
return spans
|
||||
|
||||
def _clean_spans(
|
||||
self,
|
||||
spans: Any,
|
||||
record: EpisodeRecord,
|
||||
bounds: tuple[float, float] | None = None,
|
||||
dedupe: bool = True,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Clamp / sort / (optionally) dedupe raw VLM subtask spans into valid rows.
|
||||
|
||||
``bounds`` overrides the clamp range — pass the window's
|
||||
``(w_lo, w_hi)`` when cleaning window-relative spans, or leave
|
||||
``None`` to clamp to the whole episode ``[t0, t_last]``.
|
||||
``dedupe`` runs the frame-snap distinct-start step; skip it for
|
||||
window-relative spans (frame snapping is done once on the merged,
|
||||
absolute-time set).
|
||||
"""
|
||||
if not spans:
|
||||
return []
|
||||
if bounds is not None:
|
||||
lo, hi = float(bounds[0]), float(bounds[1])
|
||||
else:
|
||||
lo = record.frame_timestamps[0]
|
||||
hi = record.frame_timestamps[-1]
|
||||
cleaned: list[dict[str, Any]] = []
|
||||
for span in spans:
|
||||
try:
|
||||
start = float(span["start"])
|
||||
end = float(span["end"])
|
||||
text = str(span["text"]).strip()
|
||||
except (KeyError, ValueError, TypeError):
|
||||
continue
|
||||
start = max(lo, min(start, hi))
|
||||
end = max(lo, min(end, hi))
|
||||
if end < start:
|
||||
start, end = end, start
|
||||
if not text:
|
||||
continue
|
||||
cleaned.append({"text": text, "start": start, "end": end})
|
||||
cleaned.sort(key=lambda s: s["start"])
|
||||
if dedupe:
|
||||
return self._dedupe_starts_to_distinct_frames(cleaned, record)
|
||||
return cleaned
|
||||
|
||||
def _describe_episode(
|
||||
self, record: EpisodeRecord, task: str, window: tuple[float, float] | None = None
|
||||
) -> str:
|
||||
"""Grounding pass: free-form chronological description of the (windowed) video."""
|
||||
prompt = load_prompt("plan_subtask_describe").format(episode_task=task)
|
||||
text = self._vlm_field(self._video_message(record, prompt, window=window), "description")
|
||||
return text.strip() if isinstance(text, str) and text.strip() else ""
|
||||
|
||||
@staticmethod
|
||||
def _dedupe_starts_to_distinct_frames(
|
||||
spans: list[dict[str, Any]], record: EpisodeRecord
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Bump same-frame subtask starts onto distinct frames.
|
||||
|
||||
Two consecutive VLM spans whose ``start`` rounds to the same
|
||||
source frame (after :func:`snap_to_frame`) would otherwise emit
|
||||
two ``style=subtask`` rows at the identical persistent
|
||||
timestamp. The training-time renderer's ``active_at(t,
|
||||
style=subtask)`` resolver can't disambiguate that and raises
|
||||
``Ambiguous resolver for style='subtask'``.
|
||||
|
||||
Walk the (sorted-by-start) spans, snap each to its frame, and
|
||||
if the snapped frame is already taken push the span onto the
|
||||
next unused frame so both subtasks survive on distinct
|
||||
timestamps. If the episode ends before a free frame is found,
|
||||
the trailing span is dropped with a warning — better than
|
||||
poisoning the render.
|
||||
"""
|
||||
if not spans:
|
||||
return spans
|
||||
frames = record.frame_timestamps
|
||||
if not frames:
|
||||
return spans
|
||||
used: set[float] = set()
|
||||
out: list[dict[str, Any]] = []
|
||||
for span in spans:
|
||||
ts = snap_to_frame(span["start"], frames)
|
||||
if ts in used:
|
||||
next_ts = next((f for f in frames if f > ts and f not in used), None)
|
||||
if next_ts is None:
|
||||
logger.warning(
|
||||
"episode %d: subtask %r snapped to occupied frame "
|
||||
"%.3f and no free later frame exists — dropping",
|
||||
record.episode_index,
|
||||
span.get("text"),
|
||||
ts,
|
||||
)
|
||||
continue
|
||||
ts = next_ts
|
||||
used.add(ts)
|
||||
new_span = {**span, "start": ts}
|
||||
if float(new_span.get("end", ts)) < ts:
|
||||
new_span["end"] = ts
|
||||
out.append(new_span)
|
||||
return out
|
||||
|
||||
def _generate_plan(
|
||||
self,
|
||||
record: EpisodeRecord, # noqa: ARG002 (kept for signature stability)
|
||||
subtask_spans: Sequence[dict[str, Any]],
|
||||
*,
|
||||
refresh_t: float | None = None,
|
||||
interjection: str | None = None, # noqa: ARG002
|
||||
task: str | None = None, # noqa: ARG002
|
||||
) -> str | None:
|
||||
"""Deterministic plan = numbered list of *still-todo* subtasks.
|
||||
|
||||
No VLM call: a plain numbered list keeps the plan aligned with the
|
||||
upcoming subtasks (the old VLM "compact hierarchical plan" prompt
|
||||
cost a round-trip per episode/refresh and could diverge).
|
||||
|
||||
1. <subtask 1>
|
||||
2. <subtask 2>
|
||||
|
||||
On a refresh at ``refresh_t`` (from ``run_plan_updates`` on
|
||||
interjections, and ``run_episode`` at each boundary), only subtasks
|
||||
starting at or after ``refresh_t`` are included — so it always
|
||||
describes what's left.
|
||||
"""
|
||||
if not subtask_spans:
|
||||
return None
|
||||
remaining = [
|
||||
s for s in subtask_spans if refresh_t is None or float(s.get("start", 0.0)) >= float(refresh_t)
|
||||
]
|
||||
if not remaining:
|
||||
# Past the last subtask boundary on a late refresh — nothing
|
||||
# left to plan; emit None so the caller skips the row.
|
||||
return None
|
||||
return "\n".join(f"{i}. {span.get('text', '').strip()}" for i, span in enumerate(remaining, start=1))
|
||||
|
||||
def _generate_memory(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
prior_memory: str,
|
||||
completed: str,
|
||||
remaining: Sequence[str],
|
||||
*,
|
||||
task: str | None = None,
|
||||
) -> str:
|
||||
prompt = load_prompt("plan_memory").format(
|
||||
episode_task=(task if task is not None else record.episode_task),
|
||||
prior_memory=prior_memory or "(none)",
|
||||
completed_subtask=completed,
|
||||
remaining_subtasks=", ".join(remaining) if remaining else "(none)",
|
||||
)
|
||||
memory = self._vlm_field(self._text_message(prompt), "memory")
|
||||
return memory.strip() if isinstance(memory, str) else ""
|
||||
@@ -1,33 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Prompt templates loaded as plain text.
|
||||
|
||||
One file per use site. Templates use ``str.format(**vars)`` substitution; we
|
||||
intentionally avoid jinja2 here so the templates remain inspectable in
|
||||
plain editors and roundtrip cleanly through ``ruff format``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
def load(name: str) -> str:
|
||||
"""Read prompt template ``name.txt`` from the ``prompts/`` directory."""
|
||||
path = _DIR / f"{name}.txt"
|
||||
return path.read_text(encoding="utf-8")
|
||||
@@ -1,12 +0,0 @@
|
||||
The user just asked the robot: "{episode_task}".
|
||||
|
||||
Generate a short verbal acknowledgement the robot would speak back before
|
||||
beginning the task. Style: compact, confident, friendly.
|
||||
|
||||
Examples (Hi Robot, Shi 2025): "Sure, I won't put cheese on it.",
|
||||
"OK, starting with the sponge.", "Got it.".
|
||||
|
||||
Prefer very short replies: "Got it.", "On it.", "OK."
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{ "text": "<the spoken acknowledgement>" }}
|
||||
@@ -1,46 +0,0 @@
|
||||
You are generating training data for a Hi Robot-style hierarchical
|
||||
robot policy. The robot in this demonstration has ALREADY executed
|
||||
every step shown in the video — we cannot retroactively change the
|
||||
action stream. To keep training data consistent with the video, the
|
||||
"interjection" must align with what the robot is *about to do next* in
|
||||
the demonstration, framed as a natural mid-task user request.
|
||||
|
||||
The episode's overall task: "{episode_task}".
|
||||
|
||||
The images above show roughly {window_seconds:.1f} seconds straddling a
|
||||
subtask boundary in the demonstration:
|
||||
|
||||
- Subtask the robot just finished: "{prev_subtask}"
|
||||
- Subtask the robot is about to start: "{next_subtask}"
|
||||
- Time into episode: {timestamp:.2f}s
|
||||
|
||||
Write ONE compact interjection the user would naturally say at this
|
||||
moment to prompt / confirm / encourage the robot to do "{next_subtask}".
|
||||
Keep it like a mid-task coaching cue, not a full instruction paragraph.
|
||||
Also write the robot's compact verbal acknowledgement.
|
||||
|
||||
Hard rules:
|
||||
|
||||
- The interjection MUST be consistent with the next subtask. The user
|
||||
cannot ask for something different from what the robot then does in
|
||||
the video. If you're tempted to say "actually skip X" or "do Y
|
||||
instead", DO NOT — those would contradict the demonstration.
|
||||
- The interjection must reference an object, location, or action that
|
||||
is plausible given the visible scene and the next subtask text.
|
||||
- One short phrase or sentence each. Conversational, not robotic.
|
||||
- Prefer direct cues: "{next_subtask}, please."; "Now {next_subtask}."
|
||||
- Keep robot speech very short: "OK.", "On it.", "Doing that."
|
||||
|
||||
Style examples (vary the phrasing — don't reuse these verbatim):
|
||||
- "Now go ahead and {next_subtask}."
|
||||
- "Great, can you {next_subtask} next?"
|
||||
- "{next_subtask}, please."
|
||||
- "Before you continue, please {next_subtask}."
|
||||
- "Looking good — {next_subtask} now."
|
||||
- "Okay, {next_subtask}."
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{
|
||||
"interjection": "<short cue from the user, asking for the next subtask>",
|
||||
"speech": "<short robot acknowledgement>"
|
||||
}}
|
||||
@@ -1,36 +0,0 @@
|
||||
You are updating the robot's compressed semantic memory at the boundary of
|
||||
a completed subtask.
|
||||
|
||||
Reference (verbatim from MEM, Torne 2026):
|
||||
"Remove or compress information in the language memory whenever
|
||||
appropriate. Keep ONLY the minimal set of relevant information for future
|
||||
task execution. Specific object attributes (colors, precise quantities of
|
||||
each item) get discarded when their details won't affect subsequent
|
||||
actions. Functional outcomes (where items went, how many) are preserved."
|
||||
|
||||
Episode task: "{episode_task}"
|
||||
Previous memory: {prior_memory}
|
||||
Just-completed subtask: "{completed_subtask}"
|
||||
Remaining subtasks (for relevance judgement only): {remaining_subtasks}
|
||||
|
||||
Write the memory as a short FIRST-PERSON, PAST-TENSE narrative of what the
|
||||
robot has accomplished so far — the running story it would tell itself.
|
||||
|
||||
Authoring rules:
|
||||
- First person, past tense. Every sentence starts with "I": "I picked
|
||||
up...", "I opened...", "I moved to...".
|
||||
- One or two short sentences. Extend the previous memory with the
|
||||
just-completed subtask; do not rewrite it from scratch.
|
||||
- Keep WHAT happened (functional outcomes — where items went, how many),
|
||||
drop HOW (grasp details, motions).
|
||||
- Compress completed steps and drop object attributes (colors, exact
|
||||
counts) once they no longer affect the remaining subtasks.
|
||||
|
||||
Example (MEM, Torne 2026):
|
||||
Before: "I prepared the pot and got the potatoes, milk, and butter. I
|
||||
moved to the drawer."
|
||||
After: "I prepared the pot and got the ingredients. I opened the
|
||||
drawer with the masher."
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{ "memory": "<one or two short first-person past-tense sentences>" }}
|
||||
@@ -1,27 +0,0 @@
|
||||
You are watching a teleoperated robot demonstration from a single
|
||||
camera. The user asked the robot to: "{episode_task}"
|
||||
|
||||
This is an OBSERVATION pass. Watch the entire clip and describe, in
|
||||
chronological order, ONLY what the robot physically does — the concrete
|
||||
motions, approaches, contacts, grasps, releases, and relocations you can
|
||||
actually SEE in the frames.
|
||||
|
||||
Hard rules:
|
||||
- Describe only motion visible in the video. Do NOT use the task
|
||||
instruction to guess steps that aren't shown. The instruction is the
|
||||
goal; the video is ground truth.
|
||||
- Do NOT segment into named subtasks yet and do NOT output JSON beyond
|
||||
the single field below. Just narrate what happens.
|
||||
- Give an approximate timestamp (in seconds) for each distinct event,
|
||||
e.g. "0.0-1.4s: the base drives forward toward the stove".
|
||||
- Do NOT invent objects, grasps, destinations, or steps. If the robot
|
||||
only does one thing (e.g. it just navigates and the clip ends), say
|
||||
exactly that and nothing more.
|
||||
- Be concrete and literal. "the gripper closes on the mug" — not "the
|
||||
robot prepares to make coffee".
|
||||
|
||||
Output strictly valid JSON:
|
||||
|
||||
{{
|
||||
"description": "<chronological, timestamped description of ONLY what is visible>"
|
||||
}}
|
||||
@@ -1,112 +0,0 @@
|
||||
You are labeling a teleoperated robot demonstration.
|
||||
|
||||
The user originally asked: "{episode_task}"
|
||||
|
||||
You are shown the entire demonstration as a single video. Watch the
|
||||
whole clip, then segment it into a list of consecutive atomic subtasks
|
||||
the robot performs.
|
||||
|
||||
{observation_block}GROUNDING — read this first, it overrides everything below:
|
||||
- Label ONLY what the robot actually does in the video. Every subtask
|
||||
you emit must correspond to motion you can SEE in specific frames.
|
||||
- Do NOT invent, anticipate, or pad. If the robot only does one thing
|
||||
(e.g. it just navigates to a location and the clip ends), emit
|
||||
EXACTLY ONE subtask. Many demonstrations are a single atomic skill.
|
||||
- ``max_steps`` below is a hard CEILING, not a target. Emitting fewer
|
||||
subtasks than the ceiling is not just allowed, it is expected for
|
||||
short / atomic demonstrations. One correct subtask is far better
|
||||
than several invented ones.
|
||||
- If the video does not clearly show the action implied by the task,
|
||||
describe what you actually see — do NOT fabricate the task's steps
|
||||
from the instruction text. The instruction tells you the goal; the
|
||||
VIDEO is the ground truth for what happened.
|
||||
|
||||
Authoring rules — Hi Robot atom granularity, pi0.7-style short prompts:
|
||||
|
||||
- Each subtask = one COMPOSITE atomic skill the low-level policy can
|
||||
execute end-to-end. A "skill" bundles its own approach motion with
|
||||
its terminal action — do NOT split the approach off as its own
|
||||
subtask. The whole-arm policy already learns to reach as part of
|
||||
every manipulation primitive.
|
||||
- Write each subtask as an IMPERATIVE COMMAND, starting with one of
|
||||
these verbs (extend only when none fits):
|
||||
pick up <obj> — approach + grasp + lift in one subtask
|
||||
put <obj> on/in <loc> — transport + release in one subtask
|
||||
place <obj> on/in <loc> — synonym of "put"; pick one and stay consistent
|
||||
push <obj> — contact + linear shove
|
||||
pull <obj> — contact + linear retract
|
||||
turn <knob/dial/handle> — rotary actuation
|
||||
press <button> — single-press contact
|
||||
open <drawer/door/lid> — full open motion
|
||||
close <drawer/door/lid> — full close motion
|
||||
pour <src> into <dst> — tilt + flow
|
||||
insert <obj> into <slot>— alignment + push-fit
|
||||
go to <loc> — ONLY when no grasp / actuation follows
|
||||
(e.g. a pure relocation between phases).
|
||||
If the next subtask grasps something at
|
||||
that location, drop "go to ..." and just
|
||||
write "pick up ..." instead.
|
||||
- Forbidden ultra-fine splits — the VLM is NOT allowed to emit these
|
||||
as standalone subtasks; fold them into the parent composite:
|
||||
"move to X" → fold into "pick up X" (or whatever follows)
|
||||
"reach for X" → fold into "pick up X"
|
||||
"grasp X" → fold into "pick up X"
|
||||
"lift X" → fold into "pick up X" (or "put X on Y" if it's
|
||||
the transport phase of a place)
|
||||
"release X" → fold into "put X on Y" (or "place X in Y")
|
||||
- Keep it SHORT — a verb phrase, not a sentence. Drop articles
|
||||
("the", "a") and adverbs ("carefully", "slowly"). Add a "how"
|
||||
detail (which hand, which grasp point) ONLY when it is needed to
|
||||
disambiguate. Every subtask must begin with one of the verbs
|
||||
above (no leading nouns, no "then", no "first").
|
||||
- NEVER use third person. Never write "the robot", "the arm", "the
|
||||
gripper moves", "it picks up" — the robot is implied. Command it,
|
||||
do not describe it.
|
||||
- Use the exact object nouns from the task above. If the task says
|
||||
"cube", every subtask says "cube" — never switch to "block". If it
|
||||
says "box", never switch to "bin"/"container". Keep vocabulary
|
||||
consistent across the whole episode.
|
||||
- Good: "pick up blue cube", "put blue cube in box", "open drawer",
|
||||
"turn red knob", "press start button", "go to sink".
|
||||
- Bad: "move to blue cube" (approach as its own subtask — forbidden,
|
||||
must be folded into "pick up blue cube"); "the robot arm moves
|
||||
towards the blue cube" (third person, too long); "carefully pick
|
||||
up the cube" (adverb, article); "release the yellow block"
|
||||
("block" when the task said "cube", and "release" must be folded
|
||||
into a "put"/"place" subtask).
|
||||
- Subtasks are non-overlapping and cover the full episode in order.
|
||||
Choose the cut points yourself based on what you see in the video
|
||||
(gripper open/close events, contact, regrasps, transitions).
|
||||
- Each subtask spans at least {min_subtask_seconds} seconds. If a
|
||||
candidate span would be shorter, merge it into its neighbour
|
||||
rather than emitting it.
|
||||
- Do not exceed {max_steps} subtasks total. Fewer, larger composites
|
||||
are preferred over many micro-steps.
|
||||
- Every subtask's [start_time, end_time] must lie within
|
||||
[0.0, {episode_duration}] seconds.
|
||||
|
||||
SPECIAL CASES — verb disambiguation (each rule is narrowly visual and
|
||||
fires ONLY on the spatial situation it names; it must not change how you
|
||||
label any other situation):
|
||||
- STACK vs PUT: if an object is placed ON TOP OF another specific object
|
||||
(not on a flat table / shelf / counter), use "stack ... on ...", not
|
||||
"put". "stack blue book on green book", NOT "put blue book on table".
|
||||
- INSERT vs PUT: if an object goes INTO a fitted slot / hole / socket /
|
||||
receptacle (push-fit), use "insert ... into ...", not "put".
|
||||
- RETRIEVE/PICK-UP vs PUT (direction): watch the gripper. If it CLOSES
|
||||
on the object and the object moves WITH the hand, it is "pick up" /
|
||||
"retrieve" (object leaves its location). If the gripper OPENS and the
|
||||
object stays where the hand left it, it is "put" / "place" (object
|
||||
arrives at a location). Decide by which way the object moves, not by
|
||||
where the hand ends up.
|
||||
- POUR vs PUT: only use "pour" when the source is tilted and contents
|
||||
flow out; moving a full container without tilting is "put"/"place".
|
||||
|
||||
Output strictly valid JSON of shape:
|
||||
|
||||
{{
|
||||
"subtasks": [
|
||||
{{"text": "<short imperative verb phrase>", "start": <float>, "end": <float>}},
|
||||
...
|
||||
]
|
||||
}}
|
||||
@@ -1,67 +0,0 @@
|
||||
You are generating structured augmentations of a robot task instruction
|
||||
for training a language-conditioned policy. Unlike free-form rephrasing,
|
||||
your variants follow a NAMED 5-axis taxonomy — each axis omits or varies
|
||||
a specific element of the task while preserving its meaning.
|
||||
|
||||
Original task: "{base_task}"
|
||||
|
||||
Produce variants along five named axes. Each axis has a target count.
|
||||
The whole batch should expose the policy to maximum linguistic diversity
|
||||
WITHOUT changing what the robot is supposed to do.
|
||||
|
||||
Axes and target counts:
|
||||
|
||||
synonym_paraphrase ({n_synonym}):
|
||||
Different wording / verbs / sentence structure. ALL information
|
||||
from the original task is preserved — same object, same arm
|
||||
specification if present, same orientation if present, same grasp
|
||||
if present.
|
||||
|
||||
omit_arm ({n_omit_arm}):
|
||||
Drop the left/right/both arm specification from the task. Skip
|
||||
entirely (emit 0 entries) if the original task does NOT mention an
|
||||
arm. Do not invent an arm specification just to omit it.
|
||||
|
||||
omit_orientation ({n_omit_orientation}):
|
||||
Drop orientation cues (upright, sideways, facing the user,
|
||||
long-edge-first, etc.). Skip entirely if no orientation cue is
|
||||
present in the original task.
|
||||
|
||||
omit_grasp_method ({n_omit_grasp_method}):
|
||||
Drop the grip / grasp method specification (pinch, wrap, hold by
|
||||
the rim, etc.). Skip entirely if no grasp method is mentioned.
|
||||
|
||||
combined_omissions ({n_combined}):
|
||||
Combine TWO of the above omissions simultaneously (e.g. drop both
|
||||
arm and orientation). Skip entirely if fewer than two of (arm,
|
||||
orientation, grasp_method) appear in the original task.
|
||||
|
||||
Hard rules:
|
||||
- Each variant MUST preserve the core action, the target object, AND
|
||||
the goal / destination. Do not change which object is involved, where
|
||||
it goes, or the high-level action. "Navigate to the stove" may become
|
||||
"go to the stove" or "head over to the stove" — it must NEVER become
|
||||
"wander around the kitchen", "explore the room", or anything that
|
||||
drops or generalises the stove destination. If you cannot vary the
|
||||
wording without changing the goal, emit fewer variants.
|
||||
- Only the FIVE listed elements (wording, arm, orientation, grasp
|
||||
method, or a combination) may be varied or omitted. The verb's
|
||||
meaning, the object, and the destination are fixed.
|
||||
- Each variant is plain prose, no markdown, no quotes, no list numbers.
|
||||
- Each variant must be DISTINCT from every other variant in the entire
|
||||
output, both within and across axes. Near-duplicates are not allowed.
|
||||
- If an axis cannot reach its target count because the original task
|
||||
lacks the omittable element, emit fewer entries — do NOT pad the
|
||||
axis with paraphrases that belong to a different axis.
|
||||
- Variants should not all start with verbs — vary sentence structure
|
||||
(some imperative, some polite request, some question).
|
||||
|
||||
Output strictly valid JSON of shape:
|
||||
|
||||
{{
|
||||
"synonym_paraphrase": ["<v1>", "<v2>", ...],
|
||||
"omit_arm": ["<v1>", "<v2>", ...],
|
||||
"omit_orientation": ["<v1>", ...],
|
||||
"omit_grasp_method": ["<v1>", ...],
|
||||
"combined_omissions": ["<v1>", ...]
|
||||
}}
|
||||
@@ -1,32 +0,0 @@
|
||||
You are generating training data for a Hi Robot-style policy. We need
|
||||
{n} alternative phrasings of the same robot task so the policy sees
|
||||
diverse user prompts during training instead of the same canonical
|
||||
string repeated every frame.
|
||||
|
||||
Original task:
|
||||
"{base_task}"
|
||||
|
||||
Generate exactly {n} alternative phrasings of the same task. Vary:
|
||||
|
||||
- formality (casual / polite / curt)
|
||||
- verbosity (mostly short imperative; occasional polite request)
|
||||
- word choice (synonyms, different verbs)
|
||||
- sentence structure (imperative / question / suggestion)
|
||||
|
||||
Hard rules:
|
||||
- Each phrasing MUST preserve the exact meaning of the original task.
|
||||
Do not change which object is involved, the destination, or the
|
||||
action. Do not add extra steps. Do not invent new objects.
|
||||
- Each phrasing must be a short phrase or sentence, plain prose, no
|
||||
markdown, no quotes, no list numbers.
|
||||
- Phrasings must be distinct — no near-duplicates.
|
||||
- Output exactly {n} entries.
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{
|
||||
"rephrasings": [
|
||||
"<phrasing 1>",
|
||||
"<phrasing 2>",
|
||||
...
|
||||
]
|
||||
}}
|
||||
@@ -1,17 +0,0 @@
|
||||
The video above shows a robot manipulation episode in full. Look at
|
||||
the entire video and describe in ONE concise sentence what the robot
|
||||
is doing.
|
||||
|
||||
Rules:
|
||||
- One sentence, in natural English, like a user instruction.
|
||||
- Capture the goal of the demonstration, not low-level motions.
|
||||
Example: "place the yellow cube into the red bin" — not "move the
|
||||
end-effector down 5cm and close the gripper".
|
||||
- 4 to 15 words. Plain prose, no markdown, no bullets, no quotes.
|
||||
- Do not invent objects or actions that aren't visible.
|
||||
- Do not output anything other than the JSON object below.
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{
|
||||
"task": "<single concise sentence describing what the robot does in this video>"
|
||||
}}
|
||||
@@ -1,32 +0,0 @@
|
||||
You are generating a frame-grounded visual question/answer pair for
|
||||
chain-of-thought training. Reference: ECoT (Zawalski 2024) and Steerable
|
||||
Policies — both train policies on grounded features such as bounding box
|
||||
pixel coordinates, keypoints, counts, attributes, and spatial relations.
|
||||
|
||||
The frame shows a robot working on: "{episode_task}".
|
||||
|
||||
Question types and the EXACT answer JSON shape required for each:
|
||||
|
||||
bbox => {{"detections": [{{"label": "<obj>", "bbox_format": "xyxy",
|
||||
"bbox": [x1, y1, x2, y2]}}, ...]}}
|
||||
bbox is in pixel coordinates (x_min, y_min, x_max, y_max).
|
||||
ECoT example: "a white cup [124, 25, 176, 113]".
|
||||
|
||||
keypoint => {{"label": "<point>", "point_format": "xy",
|
||||
"point": [x, y]}}
|
||||
|
||||
count => {{"label": "<obj>", "count": <int>,
|
||||
"note": "<optional short note>"}}
|
||||
|
||||
attribute => {{"label": "<obj>", "attribute": "<color|shape|state|...>",
|
||||
"value": "<observed value>"}}
|
||||
|
||||
spatial => {{"subject": "<obj>", "relation": "<left_of|right_of|on|in|"
|
||||
"above|below|near>", "object": "<obj>"}}
|
||||
|
||||
Generate a question of type "{question_type}". Output strictly valid JSON:
|
||||
|
||||
{{
|
||||
"question": "<short, frame-grounded question>",
|
||||
"answer": <object whose shape matches the schema above>
|
||||
}}
|
||||
@@ -1,216 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Datatrove-shaped reader.
|
||||
|
||||
The reader walks ``data/chunk-*/file-*.parquet`` and yields one record per
|
||||
episode containing:
|
||||
|
||||
- ``episode_index``: int
|
||||
- ``frame_timestamps``: tuple[float, ...]
|
||||
- ``frame_indices``: tuple[int, ...]
|
||||
- ``episode_task``: str (canonical task from ``meta/tasks.parquet``)
|
||||
- ``data_path``: pathlib.Path of the source parquet shard
|
||||
- ``frames_df``: pandas.DataFrame slice for the episode (only loaded on demand)
|
||||
|
||||
This shape lets each module operate per-episode without loading all parquet
|
||||
rows into memory at once.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.io_utils import load_tasks
|
||||
from lerobot.datasets.utils import DEFAULT_TASKS_PATH
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpisodeRecord:
|
||||
"""Per-episode record yielded by the reader."""
|
||||
|
||||
episode_index: int
|
||||
episode_task: str
|
||||
frame_timestamps: tuple[float, ...]
|
||||
frame_indices: tuple[int, ...]
|
||||
data_path: Path
|
||||
row_offset: int # row offset within the parquet file where this episode starts
|
||||
row_count: int # number of rows for this episode
|
||||
|
||||
# Memoized parquet slice — populated on first ``frames_df()`` call so
|
||||
# repeat queries from different modules don't re-read the whole shard.
|
||||
_frames_df_cache: Any = field(default=None, init=False, repr=False, compare=False)
|
||||
|
||||
def frames_df(self): # type: ignore[no-untyped-def]
|
||||
"""Lazy-load the pandas slice for this episode (memoized)."""
|
||||
if self._frames_df_cache is None:
|
||||
import pandas as pd # noqa: PLC0415 - deferred for optional dataset extra
|
||||
|
||||
table = pq.read_table(self.data_path)
|
||||
df: pd.DataFrame = table.to_pandas()
|
||||
self._frames_df_cache = df.iloc[self.row_offset : self.row_offset + self.row_count].reset_index(
|
||||
drop=True
|
||||
)
|
||||
return self._frames_df_cache
|
||||
|
||||
|
||||
def reconstruct_subtask_spans(
|
||||
rows: Sequence[dict[str, Any]],
|
||||
*,
|
||||
episode_end_t: float | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Turn ``style="subtask"`` rows into ``{text, start, end}`` spans.
|
||||
|
||||
Each span's ``end`` is the next span's ``start``. The final span's
|
||||
``end`` defaults to its own ``start`` (zero-duration) — pass
|
||||
``episode_end_t`` to extend it to the episode's last frame instead,
|
||||
which is what downstream consumers (memory, interjection boundary
|
||||
selection) expect.
|
||||
|
||||
Used by the ``plan`` module (plan-update pass) and the
|
||||
``interjections`` module (interjection anchoring), which both need the
|
||||
same span shape.
|
||||
"""
|
||||
sorted_rows = sorted(
|
||||
(r for r in rows if r.get("style") == "subtask"),
|
||||
key=lambda r: float(r["timestamp"]),
|
||||
)
|
||||
spans: list[dict[str, Any]] = []
|
||||
for r in sorted_rows:
|
||||
t = float(r["timestamp"])
|
||||
if spans:
|
||||
spans[-1]["end"] = t
|
||||
spans.append({"text": r.get("content") or "", "start": t, "end": t})
|
||||
if spans and episode_end_t is not None and float(episode_end_t) > spans[-1]["start"]:
|
||||
spans[-1]["end"] = float(episode_end_t)
|
||||
return spans
|
||||
|
||||
|
||||
def snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
|
||||
"""Snap an arbitrary float to the nearest exact source frame timestamp.
|
||||
|
||||
Modules use this when emitting event-style rows so the row's
|
||||
timestamp matches a real parquet frame: event rows must land on an
|
||||
exact frame, otherwise the per-frame event lookup the writer does
|
||||
would never match them.
|
||||
"""
|
||||
if not frame_timestamps:
|
||||
return float(t)
|
||||
nearest = min(frame_timestamps, key=lambda f: abs(f - t))
|
||||
return float(nearest)
|
||||
|
||||
|
||||
def _load_tasks_lookup(root: Path) -> dict[int, str]:
|
||||
"""Map ``task_index -> task`` from ``meta/tasks.parquet``.
|
||||
|
||||
Returns an empty dict when the file is absent — the task description is
|
||||
derived later from the video if needed. Reuses the library-level
|
||||
:func:`lerobot.datasets.io_utils.load_tasks`, which returns the tasks
|
||||
frame indexed by task string with a ``task_index`` column.
|
||||
"""
|
||||
if not (root / DEFAULT_TASKS_PATH).exists():
|
||||
return {}
|
||||
tasks = load_tasks(root)
|
||||
return {int(idx): str(task) for task, idx in zip(tasks.index, tasks["task_index"], strict=True)}
|
||||
|
||||
|
||||
def iter_episodes(root: Path, *, only_episodes: tuple[int, ...] | None = None) -> Iterator[EpisodeRecord]:
|
||||
"""Yield :class:`EpisodeRecord` for every episode under ``root/data/``.
|
||||
|
||||
Episodes are yielded in ascending ``episode_index`` order. The reader does
|
||||
not assume a specific chunk/file layout: it scans every ``*.parquet``
|
||||
under ``data/`` and groups by ``episode_index``.
|
||||
"""
|
||||
tasks = _load_tasks_lookup(root)
|
||||
data_dir = root / "data"
|
||||
parquet_files = sorted(data_dir.rglob("*.parquet"))
|
||||
|
||||
only_set = set(only_episodes) if only_episodes is not None else None
|
||||
|
||||
for path in parquet_files:
|
||||
yield from _iter_one_path(path, tasks, only_set)
|
||||
|
||||
|
||||
def _iter_one_path(path: Path, tasks: dict[int, str], only_set: set[int] | None) -> Iterator[EpisodeRecord]:
|
||||
table = pq.read_table(path)
|
||||
names = table.column_names
|
||||
if "episode_index" not in names:
|
||||
return
|
||||
episode_col = table.column("episode_index").to_pylist()
|
||||
timestamp_col = (
|
||||
table.column("timestamp").to_pylist() if "timestamp" in names else [0.0] * len(episode_col)
|
||||
)
|
||||
frame_col = (
|
||||
table.column("frame_index").to_pylist() if "frame_index" in names else list(range(len(episode_col)))
|
||||
)
|
||||
task_col = table.column("task_index").to_pylist() if "task_index" in names else None
|
||||
|
||||
def _build(
|
||||
ep: int,
|
||||
start: int,
|
||||
end: int,
|
||||
task_idx: int | None,
|
||||
ts_buf: list[float],
|
||||
fi_buf: list[int],
|
||||
) -> EpisodeRecord | None:
|
||||
if only_set is not None and ep not in only_set:
|
||||
return None
|
||||
task = tasks.get(task_idx, "") if task_idx is not None else ""
|
||||
return EpisodeRecord(
|
||||
episode_index=ep,
|
||||
episode_task=task,
|
||||
frame_timestamps=tuple(ts_buf),
|
||||
frame_indices=tuple(fi_buf),
|
||||
data_path=path,
|
||||
row_offset=start,
|
||||
row_count=end - start,
|
||||
)
|
||||
|
||||
cur_ep: int | None = None
|
||||
start_offset = 0
|
||||
ts_buf: list[float] = []
|
||||
fi_buf: list[int] = []
|
||||
cur_task_idx: int | None = None
|
||||
|
||||
for i, ep in enumerate(episode_col):
|
||||
if cur_ep is None:
|
||||
cur_ep = ep
|
||||
start_offset = i
|
||||
ts_buf = [timestamp_col[i]]
|
||||
fi_buf = [frame_col[i]]
|
||||
cur_task_idx = task_col[i] if task_col is not None else None
|
||||
continue
|
||||
if ep != cur_ep:
|
||||
rec = _build(cur_ep, start_offset, i, cur_task_idx, ts_buf, fi_buf)
|
||||
if rec is not None:
|
||||
yield rec
|
||||
cur_ep = ep
|
||||
start_offset = i
|
||||
ts_buf = [timestamp_col[i]]
|
||||
fi_buf = [frame_col[i]]
|
||||
cur_task_idx = task_col[i] if task_col is not None else None
|
||||
else:
|
||||
ts_buf.append(timestamp_col[i])
|
||||
fi_buf.append(frame_col[i])
|
||||
|
||||
if cur_ep is not None:
|
||||
rec = _build(cur_ep, start_offset, len(episode_col), cur_task_idx, ts_buf, fi_buf)
|
||||
if rec is not None:
|
||||
yield rec
|
||||
@@ -1,92 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Per-episode staging.
|
||||
|
||||
Each module writes its raw output as a JSONL file under
|
||||
``<staging_dir>/episode_{ep:06d}/<module>.jsonl``. The writer reads back this
|
||||
staging tree and partitions rows into the two language columns.
|
||||
|
||||
JSONL is preferred over parquet here because the staging artifact is meant to
|
||||
be human-inspectable, easy to diff between prompt iterations, and trivially
|
||||
appended to. The final dataset format is parquet; staging is just an
|
||||
intermediate.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
ModuleName = str
|
||||
|
||||
_MODULES: tuple[ModuleName, ...] = (
|
||||
"plan",
|
||||
"interjections",
|
||||
"vqa",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpisodeStaging:
|
||||
"""Filesystem layout for a single episode's staged module outputs."""
|
||||
|
||||
root: Path
|
||||
episode_index: int
|
||||
|
||||
@property
|
||||
def episode_dir(self) -> Path:
|
||||
return self.root / f"episode_{self.episode_index:06d}"
|
||||
|
||||
def path_for(self, module: ModuleName) -> Path:
|
||||
if module not in _MODULES:
|
||||
raise ValueError(f"Unknown module {module!r}; expected one of {_MODULES}")
|
||||
return self.episode_dir / f"{module}.jsonl"
|
||||
|
||||
def write(self, module: ModuleName, rows: Iterable[dict[str, Any]]) -> Path:
|
||||
path = self.path_for(module)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# Atomic replace: a crash mid-write would otherwise leave a
|
||||
# half-written JSONL file that ``read()`` would then fail to
|
||||
# parse. Write to a sibling .tmp and rename so the target path
|
||||
# only ever points at a complete file.
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
with tmp_path.open("w", encoding="utf-8") as f:
|
||||
for row in rows:
|
||||
f.write(json.dumps(row, ensure_ascii=False, sort_keys=True))
|
||||
f.write("\n")
|
||||
tmp_path.replace(path)
|
||||
return path
|
||||
|
||||
def read(self, module: ModuleName) -> list[dict[str, Any]]:
|
||||
path = self.path_for(module)
|
||||
if not path.exists():
|
||||
return []
|
||||
out: list[dict[str, Any]] = []
|
||||
with path.open(encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
out.append(json.loads(line))
|
||||
return out
|
||||
|
||||
def read_all(self) -> dict[ModuleName, list[dict[str, Any]]]:
|
||||
return {m: self.read(m) for m in _MODULES}
|
||||
|
||||
def has(self, module: ModuleName) -> bool:
|
||||
return self.path_for(module).exists()
|
||||
@@ -1,332 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Pre-write validation against staged outputs.
|
||||
|
||||
Runs after all three modules have written their per-episode artifacts but
|
||||
*before* the writer rewrites parquet shards. The validator never touches
|
||||
parquet; it only inspects the staging tree and the source frame timestamps
|
||||
exposed by :class:`EpisodeRecord`.
|
||||
|
||||
Checks (per the plan's "Intermediate staging and validation" section):
|
||||
|
||||
- exact timestamp alignment against source frame timestamps
|
||||
- no orphan speech / interjection pairs
|
||||
- plan / memory emission consistency (events have a paired persistent row)
|
||||
- VQA assistant ``content`` is valid JSON (one of bbox / keypoint / count /
|
||||
attribute / spatial)
|
||||
- every row maps to its correct column under :func:`column_for_style`
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from lerobot.datasets.language import (
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
column_for_style,
|
||||
is_view_dependent_style,
|
||||
validate_camera_field,
|
||||
)
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
from .staging import EpisodeStaging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationReport:
|
||||
"""Outcome of one validation pass across all episodes."""
|
||||
|
||||
errors: list[str] = field(default_factory=list)
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
episodes_checked: int = 0
|
||||
|
||||
@property
|
||||
def ok(self) -> bool:
|
||||
return not self.errors
|
||||
|
||||
def add_error(self, message: str) -> None:
|
||||
self.errors.append(message)
|
||||
|
||||
def add_warning(self, message: str) -> None:
|
||||
self.warnings.append(message)
|
||||
|
||||
def summary(self) -> str:
|
||||
return f"checked={self.episodes_checked} errors={len(self.errors)} warnings={len(self.warnings)}"
|
||||
|
||||
|
||||
VQA_ANSWER_SHAPES: dict[str, set[str]] = {
|
||||
"bbox": {"detections"},
|
||||
"keypoint": {"label", "point_format", "point"},
|
||||
"count": {"label", "count"},
|
||||
"attribute": {"label", "attribute", "value"},
|
||||
"spatial": {"subject", "relation", "object"},
|
||||
}
|
||||
|
||||
|
||||
def classify_vqa_answer(payload: Any) -> str | None:
|
||||
"""Best-effort classification of a VQA answer payload to a question type."""
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
keys = set(payload.keys())
|
||||
for kind, required in VQA_ANSWER_SHAPES.items():
|
||||
if required.issubset(keys):
|
||||
return kind
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class StagingValidator:
|
||||
"""Walks the staging tree and produces a :class:`ValidationReport`."""
|
||||
|
||||
timestamp_atol: float = 0.0 # exact-match by default
|
||||
dataset_camera_keys: tuple[str, ...] | None = None
|
||||
"""Known ``observation.images.*`` keys on the dataset. When set, the
|
||||
validator additionally enforces that every view-dependent row's
|
||||
``camera`` field references one of these keys. Pass ``None`` (default)
|
||||
to skip that cross-check (e.g. in unit tests with no real dataset)."""
|
||||
|
||||
def validate(
|
||||
self,
|
||||
records: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
) -> ValidationReport:
|
||||
report = ValidationReport()
|
||||
for record in records:
|
||||
self._validate_episode(record, staging_dir, report)
|
||||
report.episodes_checked += 1
|
||||
return report
|
||||
|
||||
def _validate_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
staging_dir: Path,
|
||||
report: ValidationReport,
|
||||
) -> None:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
staged = staging.read_all()
|
||||
all_rows: list[dict[str, Any]] = []
|
||||
for module_name, rows in staged.items():
|
||||
for row in rows:
|
||||
row = {**row, "_module": module_name}
|
||||
all_rows.append(row)
|
||||
|
||||
frame_ts = set(record.frame_timestamps)
|
||||
|
||||
events: list[dict[str, Any]] = []
|
||||
persistent: list[dict[str, Any]] = []
|
||||
for row in all_rows:
|
||||
self._check_column_routing(row, report, record.episode_index)
|
||||
self._check_camera_field(row, report, record.episode_index, self.dataset_camera_keys)
|
||||
# ``_check_column_routing`` already recorded any unknown-style error;
|
||||
# don't let the same ``column_for_style`` lookup raise here uncaught.
|
||||
try:
|
||||
column = column_for_style(row.get("style"))
|
||||
except ValueError:
|
||||
continue
|
||||
if column == LANGUAGE_PERSISTENT:
|
||||
persistent.append(row)
|
||||
else:
|
||||
events.append(row)
|
||||
|
||||
for row in events:
|
||||
self._check_event_timestamp_alignment(row, frame_ts, report, record.episode_index)
|
||||
|
||||
self._check_speech_interjection_pairs(events, report, record.episode_index)
|
||||
self._check_plan_memory_consistency(persistent, events, report, record.episode_index)
|
||||
self._check_vqa_json(events, report, record.episode_index)
|
||||
self._check_vqa_uniqueness_per_frame_camera(events, report, record.episode_index)
|
||||
|
||||
def _check_camera_field(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
dataset_camera_keys: Sequence[str] | None,
|
||||
) -> None:
|
||||
"""Enforce the camera invariant + that the key matches the dataset's cameras."""
|
||||
style = row.get("style")
|
||||
camera = row.get("camera")
|
||||
try:
|
||||
validate_camera_field(style, camera)
|
||||
except ValueError as exc:
|
||||
report.add_error(f"ep={episode_index} module={row.get('_module')}: {exc}")
|
||||
return
|
||||
if is_view_dependent_style(style) and dataset_camera_keys and camera not in dataset_camera_keys:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module={row.get('_module')}: camera {camera!r} on style "
|
||||
f"{style!r} is not one of the dataset's video keys {sorted(dataset_camera_keys)!r}"
|
||||
)
|
||||
|
||||
def _check_vqa_uniqueness_per_frame_camera(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
"""Ensure at most one (vqa, user) and one (vqa, assistant) per (t, camera)."""
|
||||
counts: dict[tuple[float, str, str], int] = {}
|
||||
for row in events:
|
||||
if row.get("style") != "vqa":
|
||||
continue
|
||||
ts = row.get("timestamp")
|
||||
camera = row.get("camera")
|
||||
role = row.get("role")
|
||||
if ts is None or camera is None or role is None:
|
||||
continue # other validators flag these
|
||||
key = (float(ts), str(camera), str(role))
|
||||
counts[key] = counts.get(key, 0) + 1
|
||||
for (ts, camera, role), n in counts.items():
|
||||
if n > 1:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: {n} duplicate vqa rows at t={ts} "
|
||||
f"camera={camera!r} role={role!r}; expected at most one per (t, camera, role)"
|
||||
)
|
||||
|
||||
def _check_column_routing(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
style = row.get("style")
|
||||
module = row.get("_module")
|
||||
try:
|
||||
target_col = column_for_style(style)
|
||||
except ValueError:
|
||||
report.add_error(f"ep={episode_index} module={module}: unknown style {style!r}")
|
||||
return
|
||||
if module == "plan" and target_col != LANGUAGE_PERSISTENT:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module=plan emitted style {style!r} that routes to {target_col} (must be persistent)"
|
||||
)
|
||||
if module in {"interjections", "vqa"} and target_col != LANGUAGE_EVENTS:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module={module} emitted style {style!r} that routes to {target_col} (must be events)"
|
||||
)
|
||||
|
||||
def _check_event_timestamp_alignment(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
frame_ts: set[float],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
ts = row.get("timestamp")
|
||||
if ts is None:
|
||||
report.add_error(f"ep={episode_index}: event row missing timestamp: {row!r}")
|
||||
return
|
||||
if self.timestamp_atol == 0.0:
|
||||
if float(ts) not in frame_ts:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: event row timestamp {ts!r} does not match any source frame timestamp"
|
||||
)
|
||||
else:
|
||||
if not any(abs(float(ts) - f) <= self.timestamp_atol for f in frame_ts):
|
||||
report.add_error(
|
||||
f"ep={episode_index}: event row timestamp {ts!r} not within {self.timestamp_atol}s of any frame"
|
||||
)
|
||||
|
||||
def _check_speech_interjection_pairs(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
speech_ts: dict[float, int] = {}
|
||||
interjection_ts: dict[float, int] = {}
|
||||
for row in events:
|
||||
ts = row.get("timestamp")
|
||||
if ts is None:
|
||||
continue
|
||||
ts_f = float(ts)
|
||||
if row.get("style") is None and row.get("role") == "assistant":
|
||||
speech_ts[ts_f] = speech_ts.get(ts_f, 0) + 1
|
||||
if row.get("style") == "interjection":
|
||||
interjection_ts[ts_f] = interjection_ts.get(ts_f, 0) + 1
|
||||
|
||||
for ts in interjection_ts:
|
||||
if ts not in speech_ts:
|
||||
report.add_error(f"ep={episode_index}: interjection at t={ts} has no paired speech atom")
|
||||
|
||||
def _check_plan_memory_consistency(
|
||||
self,
|
||||
persistent: Sequence[dict[str, Any]],
|
||||
events: Sequence[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
plan_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "plan"})
|
||||
memory_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "memory"})
|
||||
subtask_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "subtask"})
|
||||
interjection_ts = sorted(
|
||||
{
|
||||
float(r["timestamp"])
|
||||
for r in events
|
||||
if r.get("style") == "interjection" and r.get("timestamp") is not None
|
||||
}
|
||||
)
|
||||
|
||||
if persistent and not plan_ts:
|
||||
report.add_warning(f"ep={episode_index}: persistent rows present but no plan emitted")
|
||||
# every interjection should have a same-timestamp plan refresh
|
||||
for ts in interjection_ts:
|
||||
if ts not in set(plan_ts):
|
||||
report.add_error(
|
||||
f"ep={episode_index}: interjection at t={ts} has no co-timestamped plan update"
|
||||
)
|
||||
# memory should be emitted at subtask boundaries (subset relation)
|
||||
if memory_ts and subtask_ts:
|
||||
mem_set = set(memory_ts)
|
||||
sub_set = set(subtask_ts)
|
||||
stray = sorted(mem_set - sub_set)
|
||||
if stray:
|
||||
report.add_warning(f"ep={episode_index}: memory rows at {stray} not at any subtask boundary")
|
||||
|
||||
def _check_vqa_json(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
for row in events:
|
||||
if row.get("style") != "vqa" or row.get("role") != "assistant":
|
||||
continue
|
||||
content = row.get("content")
|
||||
if content is None:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant row at t={row.get('timestamp')} has null content"
|
||||
)
|
||||
continue
|
||||
try:
|
||||
payload = json.loads(content)
|
||||
except (TypeError, ValueError) as exc:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant content not valid JSON at t={row.get('timestamp')}: {exc}"
|
||||
)
|
||||
continue
|
||||
shape = classify_vqa_answer(payload)
|
||||
if shape is None:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant payload at t={row.get('timestamp')} does not match any known shape: keys={list(payload) if isinstance(payload, dict) else type(payload).__name__}"
|
||||
)
|
||||
@@ -1,599 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Shared Qwen-VL client.
|
||||
|
||||
The pipeline uses a single shared VLM across modules. vLLM is preferred when
|
||||
available (high throughput, JSON-guided decoding); transformers is the
|
||||
fallback. A ``stub`` backend is used for unit tests so fixtures never call
|
||||
into a real model.
|
||||
|
||||
The client speaks one method, :meth:`VlmClient.generate_json`, which:
|
||||
|
||||
- accepts a list of OpenAI/HF-style multimodal messages,
|
||||
- requests JSON output from the server,
|
||||
- batches requests transparently,
|
||||
- and reprompts once on a JSON parse failure with an inline correction
|
||||
message before raising.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import shlex
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import urllib.request
|
||||
from collections.abc import Callable, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol
|
||||
|
||||
from .config import VlmConfig
|
||||
|
||||
|
||||
class VlmClient(Protocol):
|
||||
"""Protocol every backend must implement."""
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
"""Generate one JSON-decoded response per messages list."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class StubVlmClient:
|
||||
"""Deterministic stub used in unit tests.
|
||||
|
||||
A test passes a callable that maps the *last user message text* (or, if
|
||||
that is empty, the full message list) to a JSON-serializable response.
|
||||
"""
|
||||
|
||||
responder: Callable[[Sequence[dict[str, Any]]], Any]
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
return [self.responder(list(messages)) for messages in messages_batch]
|
||||
|
||||
|
||||
def _strip_to_json(text: str) -> Any:
|
||||
text = text.strip()
|
||||
# Strip <think>...</think> blocks (Qwen3 Thinking style)
|
||||
while "<think>" in text and "</think>" in text:
|
||||
start = text.find("<think>")
|
||||
end = text.find("</think>", start) + len("</think>")
|
||||
text = (text[:start] + text[end:]).strip()
|
||||
# Strip ```json ... ``` fences from chat-tuned backbones
|
||||
if text.startswith("```"):
|
||||
first = text.find("\n")
|
||||
last = text.rfind("```")
|
||||
if first != -1 and last != -1 and last > first:
|
||||
text = text[first + 1 : last].strip()
|
||||
try:
|
||||
return json.loads(text)
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
pass
|
||||
# Fall back to extracting the first balanced {...} block.
|
||||
obj_text = _extract_first_json_object(text)
|
||||
if obj_text is None:
|
||||
raise json.JSONDecodeError("No JSON object found", text, 0)
|
||||
return json.loads(obj_text)
|
||||
|
||||
|
||||
def _extract_first_json_object(text: str) -> str | None:
|
||||
"""Return the first balanced ``{...}`` substring, ignoring braces in
|
||||
string literals. Returns ``None`` if no balanced block is found."""
|
||||
start = text.find("{")
|
||||
if start < 0:
|
||||
return None
|
||||
depth = 0
|
||||
in_string = False
|
||||
escape = False
|
||||
for i in range(start, len(text)):
|
||||
ch = text[i]
|
||||
if escape:
|
||||
escape = False
|
||||
continue
|
||||
if ch == "\\":
|
||||
escape = True
|
||||
continue
|
||||
# Note: ``escape`` is always False here — the ``if escape`` branch
|
||||
# above already handled and reset it.
|
||||
if ch == '"':
|
||||
in_string = not in_string
|
||||
continue
|
||||
if in_string:
|
||||
continue
|
||||
if ch == "{":
|
||||
depth += 1
|
||||
elif ch == "}":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return text[start : i + 1]
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _GenericTextClient:
|
||||
"""Wraps any text-generation callable in JSON-mode + one-retry semantics."""
|
||||
|
||||
generate_text: Callable[[Sequence[Sequence[dict[str, Any]]], int, float], list[str]]
|
||||
config: VlmConfig
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
max_tok = max_new_tokens if max_new_tokens is not None else self.config.max_new_tokens
|
||||
temp = temperature if temperature is not None else self.config.temperature
|
||||
raw = self.generate_text(messages_batch, max_tok, temp)
|
||||
out: list[Any] = []
|
||||
for messages, text in zip(messages_batch, raw, strict=True):
|
||||
try:
|
||||
out.append(_strip_to_json(text))
|
||||
continue
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
pass
|
||||
retry = list(messages) + [
|
||||
{"role": "assistant", "content": text},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Your previous reply was not valid JSON. "
|
||||
"Reply with strictly valid JSON, no prose, no fences."
|
||||
),
|
||||
},
|
||||
]
|
||||
retry_text = self.generate_text([retry], max_tok, temp)[0]
|
||||
try:
|
||||
out.append(_strip_to_json(retry_text))
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
# After retry: log preview and return None instead of crashing
|
||||
# the whole pipeline. Modules treat None as "skip".
|
||||
preview = retry_text.strip().replace("\n", " ")[:200]
|
||||
print(
|
||||
f"[vlm] WARNING: failed to parse JSON after retry; preview: {preview!r}",
|
||||
flush=True,
|
||||
)
|
||||
out.append(None)
|
||||
return out
|
||||
|
||||
|
||||
def make_vlm_client(config: VlmConfig) -> VlmClient:
|
||||
"""Build the shared VLM client.
|
||||
|
||||
Only the ``openai`` backend is supported for now. The shipped workflow
|
||||
is Hugging Face Jobs (``examples/annotations/run_hf_job.py``): it boots
|
||||
a vLLM server inside the ``vllm/vllm-openai`` image and the pipeline
|
||||
talks to it over the OpenAI-compatible API (``--vlm.backend=openai``,
|
||||
optionally auto-spawning the server via ``auto_serve`` /
|
||||
``serve_command``). The former in-process ``vllm`` / ``transformers``
|
||||
backends were removed to keep the support surface to the HF Jobs path.
|
||||
|
||||
For ``stub``, construct :class:`StubVlmClient` directly with a responder
|
||||
callable; it is rejected here to make accidental misuse obvious.
|
||||
"""
|
||||
if config.backend == "openai":
|
||||
return _make_openai_client(config)
|
||||
if config.backend == "stub":
|
||||
raise ValueError(
|
||||
"Use StubVlmClient(...) directly for the stub backend; make_vlm_client builds real clients."
|
||||
)
|
||||
if config.backend in {"vllm", "transformers"}:
|
||||
raise ValueError(
|
||||
f"backend={config.backend!r} (in-process local model) is not supported for now — "
|
||||
"only backend='openai' (the Hugging Face Jobs flow) is. Run the pipeline via "
|
||||
"examples/annotations/run_hf_job.py, which serves the model with vLLM in the "
|
||||
"vllm/vllm-openai image and talks to it over the OpenAI-compatible API."
|
||||
)
|
||||
raise ValueError(f"Unknown VLM backend: {config.backend!r}")
|
||||
|
||||
|
||||
def _make_openai_client(config: VlmConfig) -> VlmClient:
|
||||
"""Backend that talks to any OpenAI-compatible server.
|
||||
|
||||
Compatible with ``vllm serve``, ``transformers serve``,
|
||||
``ktransformers serve``, and hosted endpoints. By default the server
|
||||
is expected to be already running. Set ``auto_serve=True`` to have
|
||||
this client spawn one (default: ``transformers serve``), wait until
|
||||
it's ready, and tear it down on process exit.
|
||||
|
||||
Image blocks ``{"type":"image", "image":<PIL.Image>}`` are
|
||||
auto-converted to ``image_url`` data-URLs. Video blocks
|
||||
``{"type":"video", "video":[<PIL>...]}`` are forwarded as
|
||||
multi-frame ``video_url`` items where supported.
|
||||
"""
|
||||
try:
|
||||
from openai import OpenAI # type: ignore[import-not-found]
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"openai package is required for backend='openai'. Install with `pip install openai`."
|
||||
) from exc
|
||||
|
||||
api_base = config.api_base
|
||||
api_key = config.api_key
|
||||
auto_serve = config.auto_serve
|
||||
api_bases: list[str] = [api_base]
|
||||
|
||||
print(
|
||||
f"[lerobot-annotate] backend=openai model={config.model_id} "
|
||||
f"api_base={api_base} auto_serve={auto_serve}",
|
||||
flush=True,
|
||||
)
|
||||
if auto_serve:
|
||||
if config.parallel_servers > 1:
|
||||
print(
|
||||
f"[lerobot-annotate] spawning {config.parallel_servers} parallel servers",
|
||||
flush=True,
|
||||
)
|
||||
api_bases = _spawn_parallel_inference_servers(config)
|
||||
elif _server_is_up(api_base):
|
||||
print(f"[lerobot-annotate] reusing server already up at {api_base}", flush=True)
|
||||
else:
|
||||
print("[lerobot-annotate] no server reachable; spawning one", flush=True)
|
||||
api_base = _spawn_inference_server(config)
|
||||
api_bases = [api_base]
|
||||
print(f"[lerobot-annotate] server ready at {api_base}", flush=True)
|
||||
|
||||
clients = [OpenAI(base_url=base, api_key=api_key) for base in api_bases]
|
||||
# round-robin counter for parallel mode
|
||||
rr_counter = {"i": 0}
|
||||
|
||||
# ``mm_processor_kwargs`` is a vllm-specific extra; transformers serve
|
||||
# rejects it with HTTP 422. Send it only when explicitly opted in via
|
||||
# an env var (e.g. ``LEROBOT_OPENAI_SEND_MM_KWARGS=1`` for vllm).
|
||||
send_mm_kwargs = os.environ.get("LEROBOT_OPENAI_SEND_MM_KWARGS", "").lower() in {"1", "true", "yes"}
|
||||
|
||||
rr_lock = threading.Lock()
|
||||
|
||||
def _one_call(messages: Sequence[dict[str, Any]], max_tok: int, temp: float) -> str:
|
||||
api_messages, mm_kwargs = _to_openai_messages(messages)
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": config.model_id,
|
||||
"messages": api_messages,
|
||||
"max_tokens": max_tok,
|
||||
"temperature": temp,
|
||||
}
|
||||
extra_body: dict[str, Any] = {}
|
||||
if send_mm_kwargs and mm_kwargs:
|
||||
extra_body["mm_processor_kwargs"] = {**mm_kwargs, "do_sample_frames": True}
|
||||
if config.chat_template_kwargs:
|
||||
extra_body["chat_template_kwargs"] = config.chat_template_kwargs
|
||||
if extra_body:
|
||||
kwargs["extra_body"] = extra_body
|
||||
with rr_lock:
|
||||
chosen = clients[rr_counter["i"] % len(clients)]
|
||||
rr_counter["i"] += 1
|
||||
response = chosen.chat.completions.create(**kwargs)
|
||||
return response.choices[0].message.content or ""
|
||||
|
||||
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
||||
if len(batch) <= 1 or config.client_concurrency <= 1:
|
||||
return [_one_call(messages, max_tok, temp) for messages in batch]
|
||||
# Parallel fan-out — vllm batches these on the server side.
|
||||
max_workers = min(config.client_concurrency, len(batch))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
futures = [pool.submit(_one_call, messages, max_tok, temp) for messages in batch]
|
||||
return [f.result() for f in futures]
|
||||
|
||||
return _GenericTextClient(_gen, config)
|
||||
|
||||
|
||||
def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
|
||||
"""Spawn ``config.parallel_servers`` independent vllm replicas.
|
||||
|
||||
Each replica:
|
||||
- is pinned to a single GPU via ``CUDA_VISIBLE_DEVICES``
|
||||
- listens on ``serve_port + i``
|
||||
- is shut down via the same atexit hook as the single-server path
|
||||
|
||||
Returns the list of ``api_base`` URLs the client should round-robin
|
||||
across.
|
||||
"""
|
||||
n = config.parallel_servers
|
||||
api_bases: list[str] = []
|
||||
procs: list[subprocess.Popen] = []
|
||||
ready_events: list[threading.Event] = []
|
||||
# Multiple readiness signals — uvicorn's own banner is suppressed at
|
||||
# ``--uvicorn-log-level warning``, so we also accept vllm's own
|
||||
# "Starting vLLM API server" line and the route-listing line. The
|
||||
# HTTP probe below is the ultimate fallback.
|
||||
ready_markers = (
|
||||
"Uvicorn running",
|
||||
"Application startup complete",
|
||||
"Starting vLLM API server",
|
||||
"Available routes are",
|
||||
)
|
||||
# Single lock for all server-stream threads so multibyte chars from
|
||||
# different servers don't interleave and tear UTF-8 sequences.
|
||||
print_lock = threading.Lock()
|
||||
|
||||
base_cmd = config.serve_command or (
|
||||
f"vllm serve {shlex.quote(config.model_id)} "
|
||||
f"--tensor-parallel-size 1 "
|
||||
f"--max-model-len {config.max_model_len or 32768} "
|
||||
f"--uvicorn-log-level warning"
|
||||
)
|
||||
|
||||
num_gpus = config.num_gpus if config.num_gpus > 0 else n
|
||||
for i in range(n):
|
||||
port = config.serve_port + i
|
||||
gpu = i % num_gpus
|
||||
env = os.environ.copy()
|
||||
env["CUDA_VISIBLE_DEVICES"] = str(gpu)
|
||||
cmd = base_cmd.replace("{port}", str(port)) if "{port}" in base_cmd else f"{base_cmd} --port {port}"
|
||||
api_base = f"http://localhost:{port}/v1"
|
||||
api_bases.append(api_base)
|
||||
print(f"[server-{i}] launching on GPU {gpu} port {port}: {cmd}", flush=True)
|
||||
proc = subprocess.Popen(
|
||||
shlex.split(cmd),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
env=env,
|
||||
)
|
||||
procs.append(proc)
|
||||
ready = threading.Event()
|
||||
ready_events.append(ready)
|
||||
|
||||
def _stream(idx: int, p: subprocess.Popen, ev: threading.Event) -> None:
|
||||
# Read whole lines and emit each line atomically under the
|
||||
# shared print_lock so output from N servers stays readable.
|
||||
assert p.stdout is not None
|
||||
for line in iter(p.stdout.readline, ""):
|
||||
with print_lock:
|
||||
sys.stdout.write(f"[server-{idx}] {line}")
|
||||
if not line.endswith(("\n", "\r")):
|
||||
sys.stdout.write("\n")
|
||||
sys.stdout.flush()
|
||||
if any(m in line for m in ready_markers):
|
||||
ev.set()
|
||||
|
||||
threading.Thread(target=_stream, args=(i, proc, ready), daemon=True).start()
|
||||
|
||||
def _probe(idx: int, base: str, ev: threading.Event, p: subprocess.Popen) -> None:
|
||||
while not ev.is_set() and p.poll() is None:
|
||||
if _server_is_up(base):
|
||||
print(f"[server-{idx}] ready (http probe)", flush=True)
|
||||
ev.set()
|
||||
return
|
||||
time.sleep(2)
|
||||
|
||||
threading.Thread(target=_probe, args=(i, api_base, ready, proc), daemon=True).start()
|
||||
|
||||
def _shutdown() -> None:
|
||||
for i, p in enumerate(procs):
|
||||
if p.poll() is None:
|
||||
print(f"[server-{i}] stopping pid={p.pid}", flush=True)
|
||||
p.send_signal(signal.SIGINT)
|
||||
for p in procs:
|
||||
try:
|
||||
p.wait(timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
p.kill()
|
||||
p.wait(timeout=5)
|
||||
|
||||
atexit.register(_shutdown)
|
||||
|
||||
deadline = time.monotonic() + config.serve_ready_timeout_s
|
||||
while any(not ev.is_set() for ev in ready_events) and time.monotonic() < deadline:
|
||||
for i, p in enumerate(procs):
|
||||
if p.poll() is not None:
|
||||
raise RuntimeError(
|
||||
f"[server-{i}] inference server exited unexpectedly with rc={p.returncode}"
|
||||
)
|
||||
time.sleep(2)
|
||||
if any(not ev.is_set() for ev in ready_events):
|
||||
raise RuntimeError(f"[server] not all replicas became ready within {config.serve_ready_timeout_s}s")
|
||||
print(f"[lerobot-annotate] all {n} servers ready: {api_bases}", flush=True)
|
||||
return api_bases
|
||||
|
||||
|
||||
def _server_is_up(api_base: str) -> bool:
|
||||
"""Return True if ``api_base/models`` answers 200 within 2 seconds."""
|
||||
url = api_base.rstrip("/") + "/models"
|
||||
# ``api_base`` is the user-configured local-server URL we just spawned
|
||||
# or the user passed in via ``--vlm.api_base``; the bandit B310 warning
|
||||
# is for arbitrary user-controlled URLs with file:/ schemes which
|
||||
# cannot reach this code path.
|
||||
try:
|
||||
with urllib.request.urlopen(url, timeout=2) as resp: # noqa: S310 # nosec B310
|
||||
return resp.status == 200
|
||||
except Exception: # noqa: BLE001
|
||||
return False
|
||||
|
||||
|
||||
def _spawn_inference_server(config: VlmConfig) -> str:
|
||||
"""Spawn ``transformers serve`` (or ``serve_command``), wait until it
|
||||
accepts ``/v1/models``, and register a shutdown hook.
|
||||
|
||||
Streams the server's stdout/stderr to the parent terminal in
|
||||
real-time on a background thread so users can see model-load
|
||||
progress and errors as they happen.
|
||||
|
||||
Returns the full ``api_base`` URL the OpenAI client should use.
|
||||
"""
|
||||
cmd = config.serve_command
|
||||
if not cmd:
|
||||
cmd = (
|
||||
f"transformers serve {shlex.quote(config.model_id)} "
|
||||
f"--port {config.serve_port} --continuous-batching"
|
||||
)
|
||||
api_base = f"http://localhost:{config.serve_port}/v1"
|
||||
print(f"[server] launching: {cmd}", flush=True)
|
||||
proc = subprocess.Popen(
|
||||
shlex.split(cmd),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
)
|
||||
|
||||
# Watch the server output for the uvicorn readiness banner. This is
|
||||
# more reliable than polling /v1/models because transformers serve
|
||||
# rescans its cache on every model-list request, which can exceed
|
||||
# the urllib timeout and trigger an infinite probe loop.
|
||||
ready_event = threading.Event()
|
||||
# See _spawn_parallel_inference_servers for why we accept these.
|
||||
ready_markers = (
|
||||
"Uvicorn running",
|
||||
"Application startup complete",
|
||||
"Starting vLLM API server",
|
||||
"Available routes are",
|
||||
)
|
||||
|
||||
def _probe() -> None:
|
||||
while not ready_event.is_set() and proc.poll() is None:
|
||||
if _server_is_up(api_base):
|
||||
print("[server] ready (http probe)", flush=True)
|
||||
ready_event.set()
|
||||
return
|
||||
time.sleep(2)
|
||||
|
||||
threading.Thread(target=_probe, daemon=True).start()
|
||||
|
||||
def _stream_output() -> None:
|
||||
# Read raw chunks instead of iterating lines so tqdm progress
|
||||
# bars (which overwrite using \r) flush in real time.
|
||||
assert proc.stdout is not None
|
||||
buf = ""
|
||||
prefix_started = False
|
||||
while True:
|
||||
ch = proc.stdout.read(1)
|
||||
if ch == "":
|
||||
# process exited; flush any tail
|
||||
if buf:
|
||||
sys.stdout.write(buf)
|
||||
sys.stdout.flush()
|
||||
return
|
||||
if not prefix_started:
|
||||
sys.stdout.write("[server] ")
|
||||
prefix_started = True
|
||||
sys.stdout.write(ch)
|
||||
sys.stdout.flush()
|
||||
buf += ch
|
||||
if ch in ("\n", "\r"):
|
||||
if any(marker in buf for marker in ready_markers):
|
||||
ready_event.set()
|
||||
buf = ""
|
||||
prefix_started = False
|
||||
|
||||
threading.Thread(target=_stream_output, daemon=True).start()
|
||||
|
||||
def _shutdown() -> None:
|
||||
if proc.poll() is None:
|
||||
print(f"[server] stopping pid={proc.pid}", flush=True)
|
||||
proc.send_signal(signal.SIGINT)
|
||||
try:
|
||||
proc.wait(timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
proc.wait(timeout=5)
|
||||
|
||||
atexit.register(_shutdown)
|
||||
|
||||
deadline = time.monotonic() + config.serve_ready_timeout_s
|
||||
while time.monotonic() < deadline:
|
||||
if proc.poll() is not None:
|
||||
raise RuntimeError(
|
||||
f"[server] inference server exited unexpectedly with rc={proc.returncode}. "
|
||||
f"See [server] log lines above for the cause."
|
||||
)
|
||||
if ready_event.wait(timeout=2):
|
||||
return api_base
|
||||
proc.terminate()
|
||||
raise RuntimeError(f"[server] did not become ready within {config.serve_ready_timeout_s}s")
|
||||
|
||||
|
||||
def _to_openai_messages(
|
||||
messages: Sequence[dict[str, Any]],
|
||||
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
||||
"""Convert internal messages to OpenAI chat format.
|
||||
|
||||
Returns ``(api_messages, mm_kwargs)``. Multimodal-processor kwargs
|
||||
(``fps`` from ``video_url`` blocks) are extracted out so the caller
|
||||
can pass them via ``extra_body.mm_processor_kwargs`` rather than
|
||||
inside the content blocks (which transformers serve rejects).
|
||||
|
||||
File-URL video blocks are inlined as base64 data URLs.
|
||||
"""
|
||||
out_messages: list[dict[str, Any]] = []
|
||||
mm_kwargs: dict[str, Any] = {}
|
||||
for message in messages:
|
||||
content = message.get("content")
|
||||
if not isinstance(content, list):
|
||||
out_messages.append({"role": message["role"], "content": content})
|
||||
continue
|
||||
out_blocks: list[dict[str, Any]] = []
|
||||
for block in content:
|
||||
block_type = block.get("type") if isinstance(block, dict) else None
|
||||
if block_type == "text":
|
||||
out_blocks.append({"type": "text", "text": block.get("text", "")})
|
||||
elif block_type == "image":
|
||||
out_blocks.append(
|
||||
{"type": "image_url", "image_url": {"url": _pil_to_data_url(block["image"])}}
|
||||
)
|
||||
elif block_type == "video":
|
||||
frames = block.get("video", [])
|
||||
for img in frames:
|
||||
out_blocks.append({"type": "image_url", "image_url": {"url": _pil_to_data_url(img)}})
|
||||
elif block_type == "video_url":
|
||||
video_url = dict(block["video_url"])
|
||||
url = video_url.get("url", "")
|
||||
if url.startswith("file://"):
|
||||
video_url["url"] = _file_to_data_url(url[len("file://") :])
|
||||
out_blocks.append({"type": "video_url", "video_url": video_url})
|
||||
fps = block.get("fps")
|
||||
if fps is not None:
|
||||
mm_kwargs["fps"] = fps
|
||||
else:
|
||||
out_blocks.append(block)
|
||||
out_messages.append({"role": message["role"], "content": out_blocks})
|
||||
return out_messages, mm_kwargs
|
||||
|
||||
|
||||
def _file_to_data_url(path: str) -> str:
|
||||
"""Read a local video file and return a base64 ``data:video/mp4`` URL."""
|
||||
with open(path, "rb") as f:
|
||||
b64 = base64.b64encode(f.read()).decode("ascii")
|
||||
return f"data:video/mp4;base64,{b64}"
|
||||
|
||||
|
||||
def _pil_to_data_url(image: Any) -> str:
|
||||
"""Encode a PIL.Image as a base64 data URL."""
|
||||
buf = io.BytesIO()
|
||||
image.save(buf, format="PNG")
|
||||
b64 = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
return f"data:image/png;base64,{b64}"
|
||||
@@ -1,341 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Final parquet rewrite.
|
||||
|
||||
For every episode the writer:
|
||||
|
||||
1. reads the staged module outputs,
|
||||
2. partitions them into a persistent slice (PERSISTENT_STYLES) and an event
|
||||
slice (EVENT_ONLY_STYLES + style=None tool-call atoms),
|
||||
3. sorts each slice deterministically,
|
||||
4. broadcasts the persistent slice across every frame in the episode,
|
||||
5. for each frame, materializes the sublist of event rows whose timestamp
|
||||
exactly equals that frame's timestamp,
|
||||
6. drops the legacy ``subtask_index`` column,
|
||||
7. writes the parquet shard back in place.
|
||||
|
||||
The writer does NOT add a dataset-level ``tools`` column. Tool *calls* are
|
||||
emitted per-row via the existing ``tool_calls`` field on the v3.1 row
|
||||
struct for every speech atom. The tool *schema* (the description
|
||||
of the ``say`` function and its parameters) is a fixed code constant —
|
||||
``SAY_TOOL_SCHEMA`` below — and downstream chat-template consumers import
|
||||
it directly rather than reading a redundant per-row column.
|
||||
|
||||
Invariants enforced here (and re-checked by the validator):
|
||||
|
||||
- per-episode persistent slice is byte-identical across every frame;
|
||||
- ``language_events`` rows on a frame all have ``timestamp == frame_ts``
|
||||
(timestamps come straight from the source parquet — never recomputed);
|
||||
- every row passes ``column_for_style(style)``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
column_for_style,
|
||||
validate_camera_field,
|
||||
)
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
from .staging import EpisodeStaging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Tool schema constants live in lerobot.datasets.language — single
|
||||
# source of truth. Re-exported here so existing imports
|
||||
# (``from lerobot.annotations.steerable_pipeline.writer import SAY_TOOL_SCHEMA``)
|
||||
# keep working.
|
||||
from lerobot.datasets.language import DEFAULT_TOOLS, SAY_TOOL_SCHEMA # noqa: F401, E402
|
||||
|
||||
|
||||
def _row_persistent_sort_key(row: dict[str, Any]) -> tuple:
|
||||
return (float(row["timestamp"]), row.get("style") or "", row.get("role") or "")
|
||||
|
||||
|
||||
def _row_event_sort_key(row: dict[str, Any]) -> tuple:
|
||||
# events are bucketed per-frame, but within a frame we still want determinism
|
||||
return (
|
||||
row.get("style") or "",
|
||||
row.get("role") or "",
|
||||
row.get("camera") or "",
|
||||
)
|
||||
|
||||
|
||||
def _normalize_row(row: dict[str, Any], style: str | None, *, with_timestamp: bool) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the language-column struct shape.
|
||||
|
||||
Key order matches ``PERSISTENT_ROW_FIELDS`` / ``EVENT_ROW_FIELDS`` — the
|
||||
writer infers the parquet struct schema from insertion order, so
|
||||
``timestamp`` (persistent rows only) sits between ``style`` and ``camera``.
|
||||
"""
|
||||
camera = row.get("camera")
|
||||
validate_camera_field(style, camera)
|
||||
out: dict[str, Any] = {
|
||||
"role": str(row["role"]),
|
||||
"content": None if row.get("content") is None else str(row["content"]),
|
||||
"style": style,
|
||||
}
|
||||
if with_timestamp:
|
||||
out["timestamp"] = float(row["timestamp"])
|
||||
out["camera"] = None if camera is None else str(camera)
|
||||
out["tool_calls"] = _normalize_tool_calls(row.get("tool_calls"))
|
||||
return out
|
||||
|
||||
|
||||
def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the persistent column's struct shape."""
|
||||
style = row.get("style")
|
||||
if style not in PERSISTENT_STYLES:
|
||||
raise ValueError(
|
||||
f"persistent slice contains row with non-persistent style {style!r}; "
|
||||
"row would be misrouted under column_for_style()"
|
||||
)
|
||||
if "timestamp" not in row:
|
||||
raise ValueError(f"persistent row missing timestamp: {row!r}")
|
||||
if "role" not in row:
|
||||
# Friendly error from the writer instead of a raw KeyError below;
|
||||
# the validator doesn't check ``role`` yet.
|
||||
raise ValueError(f"persistent row missing role: {row!r}")
|
||||
return _normalize_row(row, style, with_timestamp=True)
|
||||
|
||||
|
||||
def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the event column's struct shape (no timestamp)."""
|
||||
style = row.get("style")
|
||||
if style is not None and style not in EVENT_ONLY_STYLES:
|
||||
raise ValueError(
|
||||
f"event slice contains row with style {style!r}; expected None or one of {EVENT_ONLY_STYLES}"
|
||||
)
|
||||
if column_for_style(style) != LANGUAGE_EVENTS:
|
||||
raise ValueError(f"event row with style {style!r} would not route to language_events")
|
||||
if "role" not in row:
|
||||
raise ValueError(f"event row missing role: {row!r}")
|
||||
return _normalize_row(row, style, with_timestamp=False)
|
||||
|
||||
|
||||
def _normalize_tool_calls(value: Any) -> list[Any] | None:
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, list):
|
||||
raise ValueError(f"tool_calls must be a list or None, got {type(value).__name__}")
|
||||
return list(value)
|
||||
|
||||
|
||||
def _validate_atom_invariants(row: dict[str, Any]) -> None:
|
||||
"""At-least-one of content/tool_calls; style=None implies tool_calls."""
|
||||
has_content = row.get("content") is not None
|
||||
has_tools = row.get("tool_calls") is not None
|
||||
if not (has_content or has_tools):
|
||||
raise ValueError(f"row has neither content nor tool_calls: {row!r}")
|
||||
if row.get("style") is None and not has_tools:
|
||||
raise ValueError(f"style=None requires tool_calls: {row!r}")
|
||||
|
||||
|
||||
def _validate_speech_atom(row: dict[str, Any]) -> None:
|
||||
"""Speech atoms: role=assistant, style=None, content=None, say tool call."""
|
||||
if row.get("style") is not None:
|
||||
return # not a speech atom
|
||||
if row.get("role") != "assistant":
|
||||
raise ValueError(f"speech atom must have role=assistant: {row!r}")
|
||||
if row.get("content") is not None:
|
||||
raise ValueError(f"speech atom must have content=null: {row!r}")
|
||||
tool_calls = row.get("tool_calls")
|
||||
if not tool_calls or not isinstance(tool_calls, list):
|
||||
raise ValueError(f"speech atom must have non-empty tool_calls list: {row!r}")
|
||||
first = tool_calls[0]
|
||||
if not isinstance(first, dict):
|
||||
raise ValueError(f"speech atom tool_calls[0] must be a dict: {row!r}")
|
||||
if first.get("type") != "function":
|
||||
raise ValueError(f"speech atom tool_calls[0].type must be 'function': {row!r}")
|
||||
fn = first.get("function") or {}
|
||||
if fn.get("name") != "say":
|
||||
raise ValueError(f"speech atom tool_calls[0].function.name must be 'say': {row!r}")
|
||||
args = fn.get("arguments") or {}
|
||||
if not isinstance(args, dict) or "text" not in args or not isinstance(args["text"], str):
|
||||
raise ValueError(f"speech atom must carry 'text' string in arguments: {row!r}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class LanguageColumnsWriter:
|
||||
"""Rewrite ``data/chunk-*/file-*.parquet`` with the two language columns."""
|
||||
|
||||
drop_existing_subtask_index: bool = True
|
||||
|
||||
def write_all(
|
||||
self,
|
||||
records: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
root: Path,
|
||||
) -> list[Path]:
|
||||
episodes_by_path: dict[Path, list[EpisodeRecord]] = defaultdict(list)
|
||||
for record in records:
|
||||
episodes_by_path[record.data_path].append(record)
|
||||
|
||||
written: list[Path] = []
|
||||
for path, eps in episodes_by_path.items():
|
||||
self._rewrite_one(path, eps, staging_dir, root)
|
||||
written.append(path)
|
||||
return written
|
||||
|
||||
def _rewrite_one(
|
||||
self,
|
||||
path: Path,
|
||||
episodes: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
root: Path,
|
||||
) -> None:
|
||||
table = pq.read_table(path)
|
||||
n_rows = table.num_rows
|
||||
|
||||
# Ensure we cover every episode in the file. Episodes that don't have
|
||||
# staging artifacts are passed through with empty annotation lists —
|
||||
# this keeps the writer idempotent and safe for partial reruns.
|
||||
staged_per_ep: dict[int, dict[str, list[dict[str, Any]]]] = {}
|
||||
for record in episodes:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
staged_per_ep[record.episode_index] = staging.read_all()
|
||||
|
||||
persistent_by_ep: dict[int, list[dict[str, Any]]] = {}
|
||||
events_by_ep_ts: dict[int, dict[float, list[dict[str, Any]]]] = {}
|
||||
|
||||
for ep_index, ep_staged in staged_per_ep.items():
|
||||
persistent_rows: list[dict[str, Any]] = []
|
||||
event_rows: list[dict[str, Any]] = [] # carry timestamp until bucketed
|
||||
for _module_name, rows in ep_staged.items():
|
||||
for row in rows:
|
||||
style = row.get("style")
|
||||
if column_for_style(style) == LANGUAGE_PERSISTENT:
|
||||
persistent_rows.append(row)
|
||||
else:
|
||||
event_rows.append(row)
|
||||
|
||||
persistent_rows.sort(key=_row_persistent_sort_key)
|
||||
normalized_persistent = []
|
||||
for r in persistent_rows:
|
||||
_validate_atom_invariants(r)
|
||||
_validate_speech_atom(r)
|
||||
normalized_persistent.append(_normalize_persistent_row(r))
|
||||
persistent_by_ep[ep_index] = normalized_persistent
|
||||
|
||||
buckets: dict[float, list[dict[str, Any]]] = defaultdict(list)
|
||||
for r in event_rows:
|
||||
_validate_atom_invariants(r)
|
||||
_validate_speech_atom(r)
|
||||
ts = float(r["timestamp"])
|
||||
buckets[ts].append(_normalize_event_row(r))
|
||||
for ts in list(buckets.keys()):
|
||||
buckets[ts].sort(key=_row_event_sort_key)
|
||||
events_by_ep_ts[ep_index] = buckets
|
||||
|
||||
episode_col = (
|
||||
table.column("episode_index").to_pylist() if "episode_index" in table.column_names else None
|
||||
)
|
||||
ts_col = table.column("timestamp").to_pylist() if "timestamp" in table.column_names else None
|
||||
if episode_col is None or ts_col is None:
|
||||
raise ValueError(f"{path} is missing 'episode_index' or 'timestamp' — required by the writer.")
|
||||
|
||||
per_row_persistent: list[list[dict[str, Any]]] = []
|
||||
per_row_events: list[list[dict[str, Any]]] = []
|
||||
for i in range(n_rows):
|
||||
ep = episode_col[i]
|
||||
ts = float(ts_col[i])
|
||||
per_row_persistent.append(persistent_by_ep.get(ep, []))
|
||||
buckets = events_by_ep_ts.get(ep, {})
|
||||
per_row_events.append(buckets.get(ts, []))
|
||||
|
||||
new_table = self._materialize_table(
|
||||
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
|
||||
)
|
||||
# Atomic replace: write to a sibling tmp path and rename so a crash
|
||||
# mid-write can't leave a half-written shard that ``pq.read_table``
|
||||
# would then fail to open. ``Path.replace`` is atomic on POSIX +
|
||||
# Windows when source and target sit on the same filesystem.
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
pq.write_table(new_table, tmp_path)
|
||||
tmp_path.replace(path)
|
||||
|
||||
def _materialize_table(
|
||||
self,
|
||||
table: pa.Table,
|
||||
persistent: list[list[dict[str, Any]]],
|
||||
events: list[list[dict[str, Any]]],
|
||||
*,
|
||||
drop_old: bool,
|
||||
) -> pa.Table:
|
||||
cols = []
|
||||
names = []
|
||||
for name in table.column_names:
|
||||
if drop_old and name == "subtask_index":
|
||||
continue
|
||||
if name in (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS):
|
||||
continue # we'll re-add canonical versions
|
||||
# Strip any legacy ``tools`` column previously emitted by older
|
||||
# writers — the schema no longer uses it (constant lives in
|
||||
# SAY_TOOL_SCHEMA / DEFAULT_TOOLS).
|
||||
if name == "tools":
|
||||
continue
|
||||
cols.append(table.column(name))
|
||||
names.append(name)
|
||||
|
||||
# We let pyarrow infer struct/list schema rather than passing the
|
||||
# canonical type from `lerobot.datasets.language` directly: that type
|
||||
# uses `pa.json_()` for the `tool_calls` element type, which
|
||||
# `pa.array(..., type=...)` cannot materialize from Python lists on
|
||||
# current pyarrow versions. The inferred schema round-trips through
|
||||
# parquet and `LeRobotDataset` correctly — `tests/datasets/test_language.py`
|
||||
# exercises the same flow.
|
||||
persistent_arr = pa.array(persistent)
|
||||
events_arr = pa.array(events)
|
||||
|
||||
cols.extend([persistent_arr, events_arr])
|
||||
names.extend([LANGUAGE_PERSISTENT, LANGUAGE_EVENTS])
|
||||
|
||||
return pa.Table.from_arrays(cols, names=names)
|
||||
|
||||
|
||||
def speech_atom(timestamp: float, text: str) -> dict[str, Any]:
|
||||
"""Build a canonical speech tool-call atom for the events column."""
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"style": None,
|
||||
"timestamp": float(timestamp),
|
||||
"camera": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "say",
|
||||
"arguments": {"text": text},
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Async inference server/client.
|
||||
|
||||
Requires: ``pip install 'lerobot[async]'``
|
||||
|
||||
Available modules (import directly)::
|
||||
|
||||
from lerobot.async_inference.policy_server import ...
|
||||
from lerobot.async_inference.robot_client import ...
|
||||
"""
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
require_package("grpcio", extra="async", import_name="grpc")
|
||||
|
||||
__all__: list[str] = []
|
||||
@@ -1,203 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.robots.config import RobotConfig
|
||||
|
||||
from .constants import (
|
||||
DEFAULT_FPS,
|
||||
DEFAULT_INFERENCE_LATENCY,
|
||||
DEFAULT_OBS_QUEUE_TIMEOUT,
|
||||
)
|
||||
|
||||
# Aggregate function registry for CLI usage
|
||||
AGGREGATE_FUNCTIONS = {
|
||||
"weighted_average": lambda old, new: 0.3 * old + 0.7 * new,
|
||||
"latest_only": lambda old, new: new,
|
||||
"average": lambda old, new: 0.5 * old + 0.5 * new,
|
||||
"conservative": lambda old, new: 0.7 * old + 0.3 * new,
|
||||
}
|
||||
|
||||
|
||||
def get_aggregate_function(name: str) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
|
||||
"""Get aggregate function by name from registry."""
|
||||
if name not in AGGREGATE_FUNCTIONS:
|
||||
available = list(AGGREGATE_FUNCTIONS.keys())
|
||||
raise ValueError(f"Unknown aggregate function '{name}'. Available: {available}")
|
||||
return AGGREGATE_FUNCTIONS[name]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyServerConfig:
|
||||
"""Configuration for PolicyServer.
|
||||
|
||||
This class defines all configurable parameters for the PolicyServer,
|
||||
including networking settings and action chunking specifications.
|
||||
"""
|
||||
|
||||
# Networking configuration
|
||||
host: str = field(default="localhost", metadata={"help": "Host address to bind the server to"})
|
||||
port: int = field(default=8080, metadata={"help": "Port number to bind the server to"})
|
||||
|
||||
# Timing configuration
|
||||
fps: int = field(default=DEFAULT_FPS, metadata={"help": "Frames per second"})
|
||||
inference_latency: float = field(
|
||||
default=DEFAULT_INFERENCE_LATENCY, metadata={"help": "Target inference latency in seconds"}
|
||||
)
|
||||
|
||||
obs_queue_timeout: float = field(
|
||||
default=DEFAULT_OBS_QUEUE_TIMEOUT, metadata={"help": "Timeout for observation queue in seconds"}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate configuration after initialization."""
|
||||
if self.port < 1 or self.port > 65535:
|
||||
raise ValueError(f"Port must be between 1 and 65535, got {self.port}")
|
||||
|
||||
if self.environment_dt <= 0:
|
||||
raise ValueError(f"environment_dt must be positive, got {self.environment_dt}")
|
||||
|
||||
if self.inference_latency < 0:
|
||||
raise ValueError(f"inference_latency must be non-negative, got {self.inference_latency}")
|
||||
|
||||
if self.obs_queue_timeout < 0:
|
||||
raise ValueError(f"obs_queue_timeout must be non-negative, got {self.obs_queue_timeout}")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: dict) -> "PolicyServerConfig":
|
||||
"""Create a PolicyServerConfig from a dictionary."""
|
||||
return cls(**config_dict)
|
||||
|
||||
@property
|
||||
def environment_dt(self) -> float:
|
||||
"""Environment time step, in seconds"""
|
||||
return 1 / self.fps
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert the configuration to a dictionary."""
|
||||
return {
|
||||
"host": self.host,
|
||||
"port": self.port,
|
||||
"fps": self.fps,
|
||||
"environment_dt": self.environment_dt,
|
||||
"inference_latency": self.inference_latency,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class RobotClientConfig:
|
||||
"""Configuration for RobotClient.
|
||||
|
||||
This class defines all configurable parameters for the RobotClient,
|
||||
including network connection, policy settings, and control behavior.
|
||||
"""
|
||||
|
||||
# Policy configuration
|
||||
policy_type: str = field(metadata={"help": "Type of policy to use"})
|
||||
pretrained_name_or_path: str = field(metadata={"help": "Pretrained model name or path"})
|
||||
|
||||
# Robot configuration (for CLI usage - robot instance will be created from this)
|
||||
robot: RobotConfig = field(metadata={"help": "Robot configuration"})
|
||||
|
||||
# Policies typically output K actions at max, but we can use less to avoid wasting bandwidth (as actions
|
||||
# would be aggregated on the client side anyway, depending on the value of `chunk_size_threshold`)
|
||||
actions_per_chunk: int = field(metadata={"help": "Number of actions per chunk"})
|
||||
|
||||
# Task instruction for the robot to execute (e.g., 'fold my tshirt')
|
||||
task: str = field(default="", metadata={"help": "Task instruction for the robot to execute"})
|
||||
|
||||
# Network configuration
|
||||
server_address: str = field(default="localhost:8080", metadata={"help": "Server address to connect to"})
|
||||
|
||||
# Device configuration
|
||||
policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"})
|
||||
client_device: str = field(
|
||||
default="cpu",
|
||||
metadata={
|
||||
"help": "Device to move actions to after receiving from server (e.g., for downstream planners)"
|
||||
},
|
||||
)
|
||||
|
||||
# Control behavior configuration
|
||||
chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"})
|
||||
fps: int = field(default=DEFAULT_FPS, metadata={"help": "Frames per second"})
|
||||
|
||||
# Aggregate function configuration (CLI-compatible)
|
||||
aggregate_fn_name: str = field(
|
||||
default="weighted_average",
|
||||
metadata={"help": f"Name of aggregate function to use. Options: {list(AGGREGATE_FUNCTIONS.keys())}"},
|
||||
)
|
||||
|
||||
# Debug configuration
|
||||
debug_visualize_queue_size: bool = field(
|
||||
default=False, metadata={"help": "Visualize the action queue size"}
|
||||
)
|
||||
|
||||
@property
|
||||
def environment_dt(self) -> float:
|
||||
"""Environment time step, in seconds"""
|
||||
return 1 / self.fps
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate configuration after initialization."""
|
||||
if not self.server_address:
|
||||
raise ValueError("server_address cannot be empty")
|
||||
|
||||
if not self.policy_type:
|
||||
raise ValueError("policy_type cannot be empty")
|
||||
|
||||
if not self.pretrained_name_or_path:
|
||||
raise ValueError("pretrained_name_or_path cannot be empty")
|
||||
|
||||
if not self.policy_device:
|
||||
raise ValueError("policy_device cannot be empty")
|
||||
|
||||
if not self.client_device:
|
||||
raise ValueError("client_device cannot be empty")
|
||||
|
||||
if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1:
|
||||
raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}")
|
||||
|
||||
if self.fps <= 0:
|
||||
raise ValueError(f"fps must be positive, got {self.fps}")
|
||||
|
||||
if self.actions_per_chunk <= 0:
|
||||
raise ValueError(f"actions_per_chunk must be positive, got {self.actions_per_chunk}")
|
||||
|
||||
self.aggregate_fn = get_aggregate_function(self.aggregate_fn_name)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: dict) -> "RobotClientConfig":
|
||||
"""Create a RobotClientConfig from a dictionary."""
|
||||
return cls(**config_dict)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert the configuration to a dictionary."""
|
||||
return {
|
||||
"server_address": self.server_address,
|
||||
"policy_type": self.policy_type,
|
||||
"pretrained_name_or_path": self.pretrained_name_or_path,
|
||||
"policy_device": self.policy_device,
|
||||
"client_device": self.client_device,
|
||||
"chunk_size_threshold": self.chunk_size_threshold,
|
||||
"fps": self.fps,
|
||||
"actions_per_chunk": self.actions_per_chunk,
|
||||
"task": self.task,
|
||||
"debug_visualize_queue_size": self.debug_visualize_queue_size,
|
||||
"aggregate_fn_name": self.aggregate_fn_name,
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Client side: The environment evolves with a time resolution equal to 1/fps"""
|
||||
|
||||
DEFAULT_FPS = 30
|
||||
|
||||
"""Server side: Running inference on (at most) 1/fps"""
|
||||
DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS
|
||||
|
||||
"""Server side: Timeout for observation queue in seconds"""
|
||||
DEFAULT_OBS_QUEUE_TIMEOUT = 2
|
||||
|
||||
# All action chunking policies
|
||||
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05", "groot"]
|
||||
|
||||
# TODO: Add all other robots
|
||||
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so_follower", "omx_follower"]
|
||||
@@ -1,297 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs import PolicyFeature
|
||||
|
||||
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
|
||||
from lerobot.policies import ( # noqa: F401
|
||||
ACTConfig,
|
||||
DiffusionConfig,
|
||||
PI0Config,
|
||||
PI05Config,
|
||||
SmolVLAConfig,
|
||||
VQBeTConfig,
|
||||
)
|
||||
from lerobot.robots.robot import Robot
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
Action = torch.Tensor
|
||||
|
||||
# observation as received from the robot (can be numpy arrays, floats, etc.)
|
||||
RawObservation = dict[str, Any]
|
||||
|
||||
# observation as those recorded in LeRobot dataset (keys are different)
|
||||
LeRobotObservation = dict[str, torch.Tensor]
|
||||
|
||||
# observation, ready for policy inference (image keys resized)
|
||||
Observation = dict[str, torch.Tensor]
|
||||
|
||||
|
||||
def visualize_action_queue_size(action_queue_size: list[int]) -> None:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
_, ax = plt.subplots()
|
||||
ax.set_title("Action Queue Size Over Time")
|
||||
ax.set_xlabel("Environment steps")
|
||||
ax.set_ylabel("Action Queue Size")
|
||||
ax.set_ylim(0, max(action_queue_size) * 1.1)
|
||||
ax.grid(True, alpha=0.3)
|
||||
ax.plot(range(len(action_queue_size)), action_queue_size)
|
||||
plt.show()
|
||||
|
||||
|
||||
def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]:
|
||||
return hw_to_dataset_features(robot.observation_features, OBS_STR, use_video=False)
|
||||
|
||||
|
||||
def is_image_key(k: str) -> bool:
|
||||
return k.startswith(OBS_IMAGES)
|
||||
|
||||
|
||||
def resize_robot_observation_image(image: torch.tensor, resize_dims: tuple[int, int, int]) -> torch.tensor:
|
||||
assert image.ndim == 3, f"Image must be (C, H, W)! Received {image.shape}"
|
||||
# (H, W, C) -> (C, H, W) for resizing from robot obsevation resolution to policy image resolution
|
||||
image = image.permute(2, 0, 1)
|
||||
dims = (resize_dims[1], resize_dims[2])
|
||||
# Add batch dimension for interpolate: (C, H, W) -> (1, C, H, W)
|
||||
image_batched = image.unsqueeze(0)
|
||||
# Interpolate and remove batch dimension: (1, C, H, W) -> (C, H, W)
|
||||
resized = torch.nn.functional.interpolate(image_batched, size=dims, mode="bilinear", align_corners=False)
|
||||
|
||||
return resized.squeeze(0)
|
||||
|
||||
|
||||
# TODO(Steven): Consider implementing a pipeline step for this
|
||||
def raw_observation_to_observation(
|
||||
raw_observation: RawObservation,
|
||||
lerobot_features: dict[str, dict],
|
||||
policy_image_features: dict[str, PolicyFeature],
|
||||
) -> Observation:
|
||||
observation = {}
|
||||
|
||||
observation = prepare_raw_observation(raw_observation, lerobot_features, policy_image_features)
|
||||
for k, v in observation.items():
|
||||
if isinstance(v, torch.Tensor): # VLAs present natural-language instructions in observations
|
||||
if "image" in k:
|
||||
# Policy expects images in shape (B, C, H, W)
|
||||
observation[k] = prepare_image(v).unsqueeze(0)
|
||||
else:
|
||||
observation[k] = v
|
||||
|
||||
return observation
|
||||
|
||||
|
||||
def prepare_image(image: torch.Tensor) -> torch.Tensor:
|
||||
"""Minimal preprocessing to turn int8 images to float32 in [0, 1], and create a memory-contiguous tensor"""
|
||||
image = image.type(torch.float32) / 255
|
||||
image = image.contiguous()
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def extract_state_from_raw_observation(
|
||||
lerobot_obs: RawObservation,
|
||||
) -> torch.Tensor:
|
||||
"""Extract the state from a raw observation."""
|
||||
state = torch.tensor(lerobot_obs[OBS_STATE])
|
||||
|
||||
if state.ndim == 1:
|
||||
state = state.unsqueeze(0)
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def extract_images_from_raw_observation(
|
||||
lerobot_obs: RawObservation,
|
||||
camera_key: str,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Extract the images from a raw observation."""
|
||||
return torch.tensor(lerobot_obs[camera_key])
|
||||
|
||||
|
||||
def make_lerobot_observation(
|
||||
robot_obs: RawObservation,
|
||||
lerobot_features: dict[str, dict],
|
||||
) -> LeRobotObservation:
|
||||
"""Make a lerobot observation from a raw observation."""
|
||||
return build_dataset_frame(lerobot_features, robot_obs, prefix=OBS_STR)
|
||||
|
||||
|
||||
def prepare_raw_observation(
|
||||
robot_obs: RawObservation,
|
||||
lerobot_features: dict[str, dict],
|
||||
policy_image_features: dict[str, PolicyFeature],
|
||||
) -> Observation:
|
||||
"""Matches keys from the raw robot_obs dict to the keys expected by a given policy (passed as
|
||||
policy_image_features)."""
|
||||
# 1. {motor.pos1:value1, motor.pos2:value2, ..., laptop:np.ndarray} ->
|
||||
# -> {observation.state:[value1,value2,...], observation.images.laptop:np.ndarray}
|
||||
lerobot_obs = make_lerobot_observation(robot_obs, lerobot_features)
|
||||
|
||||
# 2. Greps all observation.images.<> keys
|
||||
image_keys = list(filter(is_image_key, lerobot_obs))
|
||||
# state's shape is expected as (B, state_dim)
|
||||
state_dict = {OBS_STATE: extract_state_from_raw_observation(lerobot_obs)}
|
||||
image_dict = {
|
||||
image_k: extract_images_from_raw_observation(lerobot_obs, image_k) for image_k in image_keys
|
||||
}
|
||||
|
||||
# Turns the image features to (C, H, W) with H, W matching the policy image features.
|
||||
# This reduces the resolution of the images
|
||||
image_dict = {
|
||||
key: resize_robot_observation_image(torch.tensor(lerobot_obs[key]), policy_image_features[key].shape)
|
||||
for key in image_keys
|
||||
}
|
||||
|
||||
if "task" in robot_obs:
|
||||
state_dict["task"] = robot_obs["task"]
|
||||
|
||||
return {**state_dict, **image_dict}
|
||||
|
||||
|
||||
def get_logger(name: str, log_to_file: bool = True) -> logging.Logger:
|
||||
"""
|
||||
Get a logger using the standardized logging setup from utils.py.
|
||||
|
||||
Args:
|
||||
name: Logger name (e.g., 'policy_server', 'robot_client')
|
||||
log_to_file: Whether to also log to a file
|
||||
|
||||
Returns:
|
||||
Configured logger instance
|
||||
"""
|
||||
# Create logs directory if logging to file
|
||||
if log_to_file:
|
||||
os.makedirs("logs", exist_ok=True)
|
||||
log_file = Path(f"logs/{name}_{int(time.time())}.log")
|
||||
else:
|
||||
log_file = None
|
||||
|
||||
# Initialize the standardized logging
|
||||
init_logging(log_file=log_file, display_pid=False)
|
||||
|
||||
# Return a named logger
|
||||
return logging.getLogger(name)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimedData:
|
||||
"""A data object with timestamp and timestep information.
|
||||
|
||||
Args:
|
||||
timestamp: Unix timestamp relative to data's creation.
|
||||
data: The actual data to wrap a timestamp around.
|
||||
timestep: The timestep of the data.
|
||||
"""
|
||||
|
||||
timestamp: float
|
||||
timestep: int
|
||||
|
||||
def get_timestamp(self):
|
||||
return self.timestamp
|
||||
|
||||
def get_timestep(self):
|
||||
return self.timestep
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimedAction(TimedData):
|
||||
action: Action
|
||||
|
||||
def get_action(self):
|
||||
return self.action
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimedObservation(TimedData):
|
||||
observation: RawObservation
|
||||
must_go: bool = False
|
||||
|
||||
def get_observation(self):
|
||||
return self.observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class FPSTracker:
|
||||
"""Utility class to track FPS metrics over time."""
|
||||
|
||||
target_fps: float
|
||||
first_timestamp: float = None
|
||||
total_obs_count: int = 0
|
||||
|
||||
def calculate_fps_metrics(self, current_timestamp: float) -> dict[str, float]:
|
||||
"""Calculate average FPS vs target"""
|
||||
self.total_obs_count += 1
|
||||
|
||||
# Initialize first observation time
|
||||
if self.first_timestamp is None:
|
||||
self.first_timestamp = current_timestamp
|
||||
|
||||
# Calculate overall average FPS (since start)
|
||||
total_duration = current_timestamp - self.first_timestamp
|
||||
avg_fps = (self.total_obs_count - 1) / total_duration if total_duration > 1e-6 else 0.0
|
||||
|
||||
return {"avg_fps": avg_fps, "target_fps": self.target_fps}
|
||||
|
||||
def reset(self):
|
||||
"""Reset the FPS tracker state"""
|
||||
self.first_timestamp = None
|
||||
self.total_obs_count = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class RemotePolicyConfig:
|
||||
policy_type: str
|
||||
pretrained_name_or_path: str
|
||||
lerobot_features: dict[str, PolicyFeature]
|
||||
actions_per_chunk: int
|
||||
device: str = "cpu"
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
|
||||
"""Check if two observation states are similar, under a tolerance threshold"""
|
||||
return bool(torch.linalg.norm(obs1_state - obs2_state) < atol)
|
||||
|
||||
|
||||
def observations_similar(
|
||||
obs1: TimedObservation, obs2: TimedObservation, lerobot_features: dict[str, dict], atol: float = 1
|
||||
) -> bool:
|
||||
"""Check if two observations are similar, under a tolerance threshold. Measures distance between
|
||||
observations as the difference in joint-space between the two observations.
|
||||
|
||||
NOTE(fracapuano): This is a very simple check, and it is enough for the current use case.
|
||||
An immediate next step is to use (fast) perceptual difference metrics comparing some camera views,
|
||||
to surpass this joint-space similarity check.
|
||||
"""
|
||||
obs1_state = extract_state_from_raw_observation(
|
||||
make_lerobot_observation(obs1.get_observation(), lerobot_features)
|
||||
)
|
||||
obs2_state = extract_state_from_raw_observation(
|
||||
make_lerobot_observation(obs2.get_observation(), lerobot_features)
|
||||
)
|
||||
|
||||
return _compare_observation_states(obs1_state, obs2_state, atol=atol)
|
||||
@@ -1,439 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Example:
|
||||
```shell
|
||||
python -m lerobot.async_inference.policy_server \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080 \
|
||||
--fps=30 \
|
||||
--inference_latency=0.033 \
|
||||
--obs_queue_timeout=1
|
||||
```
|
||||
"""
|
||||
|
||||
import logging
|
||||
import pickle # nosec
|
||||
import threading
|
||||
import time
|
||||
from concurrent import futures
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
from queue import Empty, Queue
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
import grpc
|
||||
import torch
|
||||
|
||||
from lerobot.policies import get_policy_class, make_pre_post_processors
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
services_pb2_grpc, # type: ignore
|
||||
)
|
||||
from lerobot.transport.utils import receive_bytes_in_chunks
|
||||
from lerobot.types import PolicyAction
|
||||
|
||||
from .configs import PolicyServerConfig
|
||||
from .constants import SUPPORTED_POLICIES
|
||||
from .helpers import (
|
||||
FPSTracker,
|
||||
Observation,
|
||||
RemotePolicyConfig,
|
||||
TimedAction,
|
||||
TimedObservation,
|
||||
get_logger,
|
||||
observations_similar,
|
||||
raw_observation_to_observation,
|
||||
)
|
||||
|
||||
|
||||
class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
prefix = "policy_server"
|
||||
logger = get_logger(prefix)
|
||||
|
||||
def __init__(self, config: PolicyServerConfig):
|
||||
self.config = config
|
||||
self.shutdown_event = threading.Event()
|
||||
|
||||
# FPS measurement
|
||||
self.fps_tracker = FPSTracker(target_fps=config.fps)
|
||||
|
||||
self.observation_queue = Queue(maxsize=1)
|
||||
|
||||
self._predicted_timesteps_lock = threading.Lock()
|
||||
self._predicted_timesteps = set()
|
||||
|
||||
self.last_processed_obs = None
|
||||
|
||||
# Attributes will be set by SendPolicyInstructions
|
||||
self.device = None
|
||||
self.policy_type = None
|
||||
self.lerobot_features = None
|
||||
self.actions_per_chunk = None
|
||||
self.policy = None
|
||||
self.preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]] | None = None
|
||||
self.postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction] | None = None
|
||||
|
||||
@property
|
||||
def running(self):
|
||||
return not self.shutdown_event.is_set()
|
||||
|
||||
@property
|
||||
def policy_image_features(self):
|
||||
return self.policy.config.image_features
|
||||
|
||||
def _reset_server(self) -> None:
|
||||
"""Flushes server state when new client connects."""
|
||||
# only running inference on the latest observation received by the server
|
||||
self.shutdown_event.set()
|
||||
self.observation_queue = Queue(maxsize=1)
|
||||
|
||||
with self._predicted_timesteps_lock:
|
||||
self._predicted_timesteps = set()
|
||||
|
||||
def Ready(self, request, context): # noqa: N802
|
||||
client_id = context.peer()
|
||||
self.logger.info(f"Client {client_id} connected and ready")
|
||||
self._reset_server()
|
||||
self.shutdown_event.clear()
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
def SendPolicyInstructions(self, request, context): # noqa: N802
|
||||
"""Receive policy instructions from the robot client"""
|
||||
|
||||
if not self.running:
|
||||
self.logger.warning("Server is not running. Ignoring policy instructions.")
|
||||
return services_pb2.Empty()
|
||||
|
||||
client_id = context.peer()
|
||||
|
||||
policy_specs = pickle.loads(request.data) # nosec
|
||||
|
||||
if not isinstance(policy_specs, RemotePolicyConfig):
|
||||
raise TypeError(f"Policy specs must be a RemotePolicyConfig. Got {type(policy_specs)}")
|
||||
|
||||
if policy_specs.policy_type not in SUPPORTED_POLICIES:
|
||||
raise ValueError(
|
||||
f"Policy type {policy_specs.policy_type} not supported. "
|
||||
f"Supported policies: {SUPPORTED_POLICIES}"
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
f"Receiving policy instructions from {client_id} | "
|
||||
f"Policy type: {policy_specs.policy_type} | "
|
||||
f"Pretrained name or path: {policy_specs.pretrained_name_or_path} | "
|
||||
f"Actions per chunk: {policy_specs.actions_per_chunk} | "
|
||||
f"Device: {policy_specs.device}"
|
||||
)
|
||||
|
||||
self.device = policy_specs.device
|
||||
self.policy_type = policy_specs.policy_type # act, pi0, etc.
|
||||
self.lerobot_features = policy_specs.lerobot_features
|
||||
self.actions_per_chunk = policy_specs.actions_per_chunk
|
||||
|
||||
policy_class = get_policy_class(self.policy_type)
|
||||
|
||||
start = time.perf_counter()
|
||||
self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path)
|
||||
self.policy.to(self.device)
|
||||
|
||||
# Load preprocessor and postprocessor, overriding device to match requested device
|
||||
device_override = {"device": self.device}
|
||||
self.preprocessor, self.postprocessor = make_pre_post_processors(
|
||||
self.policy.config,
|
||||
pretrained_path=policy_specs.pretrained_name_or_path,
|
||||
preprocessor_overrides={
|
||||
"device_processor": device_override,
|
||||
"rename_observations_processor": {"rename_map": policy_specs.rename_map},
|
||||
},
|
||||
postprocessor_overrides={"device_processor": device_override},
|
||||
)
|
||||
|
||||
end = time.perf_counter()
|
||||
|
||||
self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
def SendObservations(self, request_iterator, context): # noqa: N802
|
||||
"""Receive observations from the robot client"""
|
||||
client_id = context.peer()
|
||||
self.logger.debug(f"Receiving observations from {client_id}")
|
||||
|
||||
receive_time = time.time() # comparing timestamps so need time.time()
|
||||
start_deserialize = time.perf_counter()
|
||||
received_bytes = receive_bytes_in_chunks(
|
||||
request_iterator, None, self.shutdown_event, self.logger
|
||||
) # blocking call while looping over request_iterator
|
||||
timed_observation = pickle.loads(received_bytes) # nosec
|
||||
deserialize_time = time.perf_counter() - start_deserialize
|
||||
|
||||
self.logger.debug(f"Received observation #{timed_observation.get_timestep()}")
|
||||
|
||||
obs_timestep = timed_observation.get_timestep()
|
||||
obs_timestamp = timed_observation.get_timestamp()
|
||||
|
||||
# Calculate FPS metrics
|
||||
fps_metrics = self.fps_tracker.calculate_fps_metrics(obs_timestamp)
|
||||
|
||||
self.logger.debug(
|
||||
f"Received observation #{obs_timestep} | "
|
||||
f"Avg FPS: {fps_metrics['avg_fps']:.2f} | " # fps at which observations are received from client
|
||||
f"Target: {fps_metrics['target_fps']:.2f} | "
|
||||
f"One-way latency: {(receive_time - obs_timestamp) * 1000:.2f}ms"
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Server timestamp: {receive_time:.6f} | "
|
||||
f"Client timestamp: {obs_timestamp:.6f} | "
|
||||
f"Deserialization time: {deserialize_time:.6f}s"
|
||||
)
|
||||
|
||||
if not self._enqueue_observation(
|
||||
timed_observation # wrapping a RawObservation
|
||||
):
|
||||
self.logger.debug(f"Observation #{obs_timestep} has been filtered out")
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
def GetActions(self, request, context): # noqa: N802
|
||||
"""Returns actions to the robot client. Actions are sent as a single
|
||||
chunk, containing multiple actions."""
|
||||
client_id = context.peer()
|
||||
self.logger.debug(f"Client {client_id} connected for action streaming")
|
||||
|
||||
# Generate action based on the most recent observation and its timestep
|
||||
try:
|
||||
getactions_starts = time.perf_counter()
|
||||
obs = self.observation_queue.get(timeout=self.config.obs_queue_timeout)
|
||||
self.logger.info(
|
||||
f"Running inference for observation #{obs.get_timestep()} (must_go: {obs.must_go})"
|
||||
)
|
||||
|
||||
with self._predicted_timesteps_lock:
|
||||
self._predicted_timesteps.add(obs.get_timestep())
|
||||
|
||||
start_time = time.perf_counter()
|
||||
action_chunk = self._predict_action_chunk(obs)
|
||||
inference_time = time.perf_counter() - start_time
|
||||
|
||||
start_time = time.perf_counter()
|
||||
actions_bytes = pickle.dumps(action_chunk) # nosec
|
||||
serialize_time = time.perf_counter() - start_time
|
||||
|
||||
# Create and return the action chunk
|
||||
actions = services_pb2.Actions(data=actions_bytes)
|
||||
|
||||
self.logger.info(
|
||||
f"Action chunk #{obs.get_timestep()} generated | "
|
||||
f"Total time: {(inference_time + serialize_time) * 1000:.2f}ms"
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Action chunk #{obs.get_timestep()} generated | "
|
||||
f"Inference time: {inference_time:.2f}s |"
|
||||
f"Serialize time: {serialize_time:.2f}s |"
|
||||
f"Total time: {inference_time + serialize_time:.2f}s"
|
||||
)
|
||||
|
||||
time.sleep(
|
||||
max(0, self.config.inference_latency - max(0, time.perf_counter() - getactions_starts))
|
||||
) # sleep controls inference latency
|
||||
|
||||
return actions
|
||||
|
||||
except Empty: # no observation added to queue in obs_queue_timeout
|
||||
return services_pb2.Empty()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in StreamActions: {e}")
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool:
|
||||
"""Check if the observation is valid to be processed by the policy"""
|
||||
with self._predicted_timesteps_lock:
|
||||
predicted_timesteps = self._predicted_timesteps
|
||||
|
||||
if obs.get_timestep() in predicted_timesteps:
|
||||
self.logger.debug(f"Skipping observation #{obs.get_timestep()} - Timestep predicted already!")
|
||||
return False
|
||||
|
||||
elif observations_similar(obs, previous_obs, lerobot_features=self.lerobot_features):
|
||||
self.logger.debug(
|
||||
f"Skipping observation #{obs.get_timestep()} - Observation too similar to last obs predicted!"
|
||||
)
|
||||
return False
|
||||
|
||||
else:
|
||||
return True
|
||||
|
||||
def _enqueue_observation(self, obs: TimedObservation) -> bool:
|
||||
"""Enqueue an observation if it must go through processing, otherwise skip it.
|
||||
Observations not in queue are never run through the policy network"""
|
||||
|
||||
if (
|
||||
obs.must_go
|
||||
or self.last_processed_obs is None
|
||||
or self._obs_sanity_checks(obs, self.last_processed_obs)
|
||||
):
|
||||
last_obs = self.last_processed_obs.get_timestep() if self.last_processed_obs else "None"
|
||||
self.logger.debug(
|
||||
f"Enqueuing observation. Must go: {obs.must_go} | Last processed obs: {last_obs}"
|
||||
)
|
||||
|
||||
# If queue is full, get the old observation to make room
|
||||
if self.observation_queue.full():
|
||||
# pops from queue
|
||||
_ = self.observation_queue.get_nowait()
|
||||
self.logger.debug("Observation queue was full, removed oldest observation")
|
||||
|
||||
# Now put the new observation (never blocks as queue is non-full here)
|
||||
self.observation_queue.put(obs)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]:
|
||||
"""Turn a chunk of actions into a list of TimedAction instances,
|
||||
with the first action corresponding to t_0 and the rest corresponding to
|
||||
t_0 + i*environment_dt for i in range(len(action_chunk))
|
||||
"""
|
||||
return [
|
||||
TimedAction(timestamp=t_0 + i * self.config.environment_dt, timestep=i_0 + i, action=action)
|
||||
for i, action in enumerate(action_chunk)
|
||||
]
|
||||
|
||||
def _get_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Get an action chunk from the policy. The chunk contains only"""
|
||||
chunk = self.policy.predict_action_chunk(observation)
|
||||
if chunk.ndim != 3:
|
||||
chunk = chunk.unsqueeze(0) # adding batch dimension, now shape is (B, chunk_size, action_dim)
|
||||
|
||||
return chunk[:, : self.actions_per_chunk, :]
|
||||
|
||||
def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
|
||||
"""Predict an action chunk based on an observation.
|
||||
|
||||
Pipeline:
|
||||
1. Convert raw observation to LeRobot format
|
||||
2. Apply preprocessor (tokenization, normalization, batching, device placement)
|
||||
3. Run policy inference to get action chunk
|
||||
4. Apply postprocessor (unnormalization, device movement)
|
||||
5. Convert to TimedAction list
|
||||
"""
|
||||
"""1. Prepare observation"""
|
||||
start_prepare = time.perf_counter()
|
||||
observation: Observation = raw_observation_to_observation(
|
||||
observation_t.get_observation(),
|
||||
self.lerobot_features,
|
||||
self.policy_image_features,
|
||||
)
|
||||
prepare_time = time.perf_counter() - start_prepare
|
||||
|
||||
"""2. Apply preprocessor"""
|
||||
start_preprocess = time.perf_counter()
|
||||
observation = self.preprocessor(observation)
|
||||
self.last_processed_obs: TimedObservation = observation_t
|
||||
preprocessing_time = time.perf_counter() - start_preprocess
|
||||
|
||||
"""3. Get action chunk"""
|
||||
start_inference = time.perf_counter()
|
||||
action_tensor = self._get_action_chunk(observation)
|
||||
inference_time = time.perf_counter() - start_inference
|
||||
self.logger.info(
|
||||
f"Preprocessing and inference took {inference_time:.4f}s, action shape: {action_tensor.shape}"
|
||||
)
|
||||
|
||||
"""4. Apply postprocessor"""
|
||||
# Apply postprocessor (handles unnormalization and device movement)
|
||||
# Postprocessor expects (B, action_dim) per action, but we have (B, chunk_size, action_dim)
|
||||
# So we process each action in the chunk individually
|
||||
start_postprocess = time.perf_counter()
|
||||
_, chunk_size, _ = action_tensor.shape
|
||||
|
||||
# Process each action in the chunk
|
||||
processed_actions = []
|
||||
for i in range(chunk_size):
|
||||
# Extract action at timestep i: (B, action_dim)
|
||||
single_action = action_tensor[:, i, :]
|
||||
processed_action = self.postprocessor(single_action)
|
||||
processed_actions.append(processed_action)
|
||||
|
||||
# Stack back to (B, chunk_size, action_dim), then remove batch dim
|
||||
action_tensor = torch.stack(processed_actions, dim=1).squeeze(0)
|
||||
self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}")
|
||||
|
||||
action_tensor = action_tensor.detach().cpu()
|
||||
|
||||
"""5. Convert to TimedAction list"""
|
||||
action_chunk = self._time_action_chunk(
|
||||
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
|
||||
)
|
||||
postprocess_stops = time.perf_counter()
|
||||
postprocessing_time = postprocess_stops - start_postprocess
|
||||
|
||||
self.logger.info(
|
||||
f"Observation {observation_t.get_timestep()} | "
|
||||
f"Total time: {1000 * (postprocess_stops - start_prepare):.2f}ms"
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Observation {observation_t.get_timestep()} | "
|
||||
f"Prepare time: {1000 * prepare_time:.2f}ms | "
|
||||
f"Preprocessing time: {1000 * preprocessing_time:.2f}ms | "
|
||||
f"Inference time: {1000 * inference_time:.2f}ms | "
|
||||
f"Postprocessing time: {1000 * postprocessing_time:.2f}ms | "
|
||||
f"Total time: {1000 * (postprocess_stops - start_prepare):.2f}ms"
|
||||
)
|
||||
|
||||
return action_chunk
|
||||
|
||||
def stop(self):
|
||||
"""Stop the server"""
|
||||
self._reset_server()
|
||||
self.logger.info("Server stopping...")
|
||||
|
||||
|
||||
@draccus.wrap()
|
||||
def serve(cfg: PolicyServerConfig):
|
||||
"""Start the PolicyServer with the given configuration.
|
||||
|
||||
Args:
|
||||
config: PolicyServerConfig instance. If None, uses default configuration.
|
||||
"""
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
# Create the server instance first
|
||||
policy_server = PolicyServer(cfg)
|
||||
|
||||
# Setup and start gRPC server
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
|
||||
services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
|
||||
server.add_insecure_port(f"{cfg.host}:{cfg.port}")
|
||||
|
||||
policy_server.logger.info(f"PolicyServer started on {cfg.host}:{cfg.port}")
|
||||
server.start()
|
||||
|
||||
server.wait_for_termination()
|
||||
|
||||
policy_server.logger.info("Server terminated")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
serve()
|
||||
@@ -1,517 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Example command:
|
||||
```shell
|
||||
python src/lerobot/async_inference/robot_client.py \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
|
||||
--robot.id=black \
|
||||
--task="dummy" \
|
||||
--server_address=127.0.0.1:8080 \
|
||||
--policy_type=act \
|
||||
--pretrained_name_or_path=user/model \
|
||||
--policy_device=mps \
|
||||
--client_device=cpu \
|
||||
--actions_per_chunk=50 \
|
||||
--chunk_size_threshold=0.5 \
|
||||
--aggregate_fn_name=weighted_average \
|
||||
--debug_visualize_queue_size=True
|
||||
```
|
||||
"""
|
||||
|
||||
import logging
|
||||
import pickle # nosec
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
from queue import Queue
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
import grpc
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_so_follower,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
so_follower,
|
||||
)
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
services_pb2_grpc, # type: ignore
|
||||
)
|
||||
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
|
||||
from .configs import RobotClientConfig
|
||||
from .helpers import (
|
||||
Action,
|
||||
FPSTracker,
|
||||
Observation,
|
||||
RawObservation,
|
||||
RemotePolicyConfig,
|
||||
TimedAction,
|
||||
TimedObservation,
|
||||
get_logger,
|
||||
map_robot_keys_to_lerobot_features,
|
||||
visualize_action_queue_size,
|
||||
)
|
||||
|
||||
|
||||
class RobotClient:
|
||||
prefix = "robot_client"
|
||||
logger = get_logger(prefix)
|
||||
|
||||
def __init__(self, config: RobotClientConfig):
|
||||
"""Initialize RobotClient with unified configuration.
|
||||
|
||||
Args:
|
||||
config: RobotClientConfig containing all configuration parameters
|
||||
"""
|
||||
# Store configuration
|
||||
self.config = config
|
||||
self.robot = make_robot_from_config(config.robot)
|
||||
self.robot.connect()
|
||||
|
||||
lerobot_features = map_robot_keys_to_lerobot_features(self.robot)
|
||||
|
||||
# Use environment variable if server_address is not provided in config
|
||||
self.server_address = config.server_address
|
||||
|
||||
self.policy_config = RemotePolicyConfig(
|
||||
config.policy_type,
|
||||
config.pretrained_name_or_path,
|
||||
lerobot_features,
|
||||
config.actions_per_chunk,
|
||||
config.policy_device,
|
||||
)
|
||||
self.channel = grpc.insecure_channel(
|
||||
self.server_address, grpc_channel_options(initial_backoff=f"{config.environment_dt:.4f}s")
|
||||
)
|
||||
self.stub = services_pb2_grpc.AsyncInferenceStub(self.channel)
|
||||
self.logger.info(f"Initializing client to connect to server at {self.server_address}")
|
||||
|
||||
self.shutdown_event = threading.Event()
|
||||
|
||||
# Initialize client side variables
|
||||
self.latest_action_lock = threading.Lock()
|
||||
self.latest_action = -1
|
||||
self.action_chunk_size = -1
|
||||
|
||||
self._chunk_size_threshold = config.chunk_size_threshold
|
||||
|
||||
self.action_queue = Queue()
|
||||
self.action_queue_lock = threading.Lock() # Protect queue operations
|
||||
self.action_queue_size = []
|
||||
self.start_barrier = threading.Barrier(2) # 2 threads: action receiver, control loop
|
||||
|
||||
# FPS measurement
|
||||
self.fps_tracker = FPSTracker(target_fps=self.config.fps)
|
||||
|
||||
self.logger.info("Robot connected and ready")
|
||||
|
||||
# Use an event for thread-safe coordination
|
||||
self.must_go = threading.Event()
|
||||
self.must_go.set() # Initially set - observations qualify for direct processing
|
||||
|
||||
@property
|
||||
def running(self):
|
||||
return not self.shutdown_event.is_set()
|
||||
|
||||
def start(self):
|
||||
"""Start the robot client and connect to the policy server"""
|
||||
try:
|
||||
# client-server handshake
|
||||
start_time = time.perf_counter()
|
||||
self.stub.Ready(services_pb2.Empty())
|
||||
end_time = time.perf_counter()
|
||||
self.logger.debug(f"Connected to policy server in {end_time - start_time:.4f}s")
|
||||
|
||||
# send policy instructions
|
||||
policy_config_bytes = pickle.dumps(self.policy_config)
|
||||
policy_setup = services_pb2.PolicySetup(data=policy_config_bytes)
|
||||
|
||||
self.logger.info("Sending policy instructions to policy server")
|
||||
self.logger.debug(
|
||||
f"Policy type: {self.policy_config.policy_type} | "
|
||||
f"Pretrained name or path: {self.policy_config.pretrained_name_or_path} | "
|
||||
f"Device: {self.policy_config.device}"
|
||||
)
|
||||
|
||||
self.stub.SendPolicyInstructions(policy_setup)
|
||||
|
||||
self.shutdown_event.clear()
|
||||
|
||||
return True
|
||||
|
||||
except grpc.RpcError as e:
|
||||
self.logger.error(f"Failed to connect to policy server: {e}")
|
||||
return False
|
||||
|
||||
def stop(self):
|
||||
"""Stop the robot client"""
|
||||
self.shutdown_event.set()
|
||||
|
||||
self.robot.disconnect()
|
||||
self.logger.debug("Robot disconnected")
|
||||
|
||||
self.channel.close()
|
||||
self.logger.debug("Client stopped, channel closed")
|
||||
|
||||
def send_observation(
|
||||
self,
|
||||
obs: TimedObservation,
|
||||
) -> bool:
|
||||
"""Send observation to the policy server.
|
||||
Returns True if the observation was sent successfully, False otherwise."""
|
||||
if not self.running:
|
||||
raise RuntimeError("Client not running. Run RobotClient.start() before sending observations.")
|
||||
|
||||
if not isinstance(obs, TimedObservation):
|
||||
raise ValueError("Input observation needs to be a TimedObservation!")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
observation_bytes = pickle.dumps(obs)
|
||||
serialize_time = time.perf_counter() - start_time
|
||||
self.logger.debug(f"Observation serialization time: {serialize_time:.6f}s")
|
||||
|
||||
try:
|
||||
observation_iterator = send_bytes_in_chunks(
|
||||
observation_bytes,
|
||||
services_pb2.Observation,
|
||||
log_prefix="[CLIENT] Observation",
|
||||
silent=True,
|
||||
)
|
||||
_ = self.stub.SendObservations(observation_iterator)
|
||||
obs_timestep = obs.get_timestep()
|
||||
self.logger.debug(f"Sent observation #{obs_timestep} | ")
|
||||
|
||||
return True
|
||||
|
||||
except grpc.RpcError as e:
|
||||
self.logger.error(f"Error sending observation #{obs.get_timestep()}: {e}")
|
||||
return False
|
||||
|
||||
def _inspect_action_queue(self):
|
||||
with self.action_queue_lock:
|
||||
queue_size = self.action_queue.qsize()
|
||||
timestamps = sorted([action.get_timestep() for action in self.action_queue.queue])
|
||||
self.logger.debug(f"Queue size: {queue_size}, Queue contents: {timestamps}")
|
||||
return queue_size, timestamps
|
||||
|
||||
def _aggregate_action_queues(
|
||||
self,
|
||||
incoming_actions: list[TimedAction],
|
||||
aggregate_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
|
||||
):
|
||||
"""Finds the same timestep actions in the queue and aggregates them using the aggregate_fn"""
|
||||
if aggregate_fn is None:
|
||||
# default aggregate function: take the latest action
|
||||
def aggregate_fn(x1, x2):
|
||||
return x2
|
||||
|
||||
future_action_queue = Queue()
|
||||
with self.action_queue_lock:
|
||||
internal_queue = self.action_queue.queue
|
||||
|
||||
current_action_queue = {action.get_timestep(): action.get_action() for action in internal_queue}
|
||||
|
||||
for new_action in incoming_actions:
|
||||
with self.latest_action_lock:
|
||||
latest_action = self.latest_action
|
||||
|
||||
# New action is older than the latest action in the queue, skip it
|
||||
if new_action.get_timestep() <= latest_action:
|
||||
continue
|
||||
|
||||
# If the new action's timestep is not in the current action queue, add it directly
|
||||
elif new_action.get_timestep() not in current_action_queue:
|
||||
future_action_queue.put(new_action)
|
||||
continue
|
||||
|
||||
# If the new action's timestep is in the current action queue, aggregate it
|
||||
# TODO: There is probably a way to do this with broadcasting of the two action tensors
|
||||
future_action_queue.put(
|
||||
TimedAction(
|
||||
timestamp=new_action.get_timestamp(),
|
||||
timestep=new_action.get_timestep(),
|
||||
action=aggregate_fn(
|
||||
current_action_queue[new_action.get_timestep()], new_action.get_action()
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
with self.action_queue_lock:
|
||||
self.action_queue = future_action_queue
|
||||
|
||||
def receive_actions(self, verbose: bool = False):
|
||||
"""Receive actions from the policy server"""
|
||||
# Wait at barrier for synchronized start
|
||||
self.start_barrier.wait()
|
||||
self.logger.info("Action receiving thread starting")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# Use StreamActions to get a stream of actions from the server
|
||||
actions_chunk = self.stub.GetActions(services_pb2.Empty())
|
||||
if len(actions_chunk.data) == 0:
|
||||
continue # received `Empty` from server, wait for next call
|
||||
|
||||
receive_time = time.time()
|
||||
|
||||
# Deserialize bytes back into list[TimedAction]
|
||||
deserialize_start = time.perf_counter()
|
||||
timed_actions = pickle.loads(actions_chunk.data) # nosec
|
||||
deserialize_time = time.perf_counter() - deserialize_start
|
||||
|
||||
# Log device type of received actions
|
||||
if len(timed_actions) > 0:
|
||||
received_device = timed_actions[0].get_action().device.type
|
||||
self.logger.debug(f"Received actions on device: {received_device}")
|
||||
|
||||
# Move actions to client_device (e.g., for downstream planners that need GPU)
|
||||
client_device = self.config.client_device
|
||||
if client_device != "cpu":
|
||||
for timed_action in timed_actions:
|
||||
if timed_action.get_action().device.type != client_device:
|
||||
timed_action.action = timed_action.get_action().to(client_device)
|
||||
self.logger.debug(f"Converted actions to device: {client_device}")
|
||||
else:
|
||||
self.logger.debug(f"Actions kept on device: {client_device}")
|
||||
|
||||
self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))
|
||||
|
||||
# Calculate network latency if we have matching observations
|
||||
if len(timed_actions) > 0 and verbose:
|
||||
with self.latest_action_lock:
|
||||
latest_action = self.latest_action
|
||||
|
||||
self.logger.debug(f"Current latest action: {latest_action}")
|
||||
|
||||
# Get queue state before changes
|
||||
old_size, old_timesteps = self._inspect_action_queue()
|
||||
if not old_timesteps:
|
||||
old_timesteps = [latest_action] # queue was empty
|
||||
|
||||
# Log incoming actions
|
||||
incoming_timesteps = [a.get_timestep() for a in timed_actions]
|
||||
|
||||
first_action_timestep = timed_actions[0].get_timestep()
|
||||
server_to_client_latency = (receive_time - timed_actions[0].get_timestamp()) * 1000
|
||||
|
||||
self.logger.info(
|
||||
f"Received action chunk for step #{first_action_timestep} | "
|
||||
f"Latest action: #{latest_action} | "
|
||||
f"Incoming actions: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
|
||||
f"Network latency (server->client): {server_to_client_latency:.2f}ms | "
|
||||
f"Deserialization time: {deserialize_time * 1000:.2f}ms"
|
||||
)
|
||||
|
||||
# Update action queue
|
||||
start_time = time.perf_counter()
|
||||
self._aggregate_action_queues(timed_actions, self.config.aggregate_fn)
|
||||
queue_update_time = time.perf_counter() - start_time
|
||||
|
||||
self.must_go.set() # after receiving actions, next empty queue triggers must-go processing!
|
||||
|
||||
if verbose:
|
||||
# Get queue state after changes
|
||||
new_size, new_timesteps = self._inspect_action_queue()
|
||||
|
||||
with self.latest_action_lock:
|
||||
latest_action = self.latest_action
|
||||
|
||||
self.logger.info(
|
||||
f"Latest action: {latest_action} | "
|
||||
f"Old action steps: {old_timesteps[0]}:{old_timesteps[-1]} | "
|
||||
f"Incoming action steps: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
|
||||
f"Updated action steps: {new_timesteps[0]}:{new_timesteps[-1]}"
|
||||
)
|
||||
self.logger.debug(
|
||||
f"Queue update complete ({queue_update_time:.6f}s) | "
|
||||
f"Before: {old_size} items | "
|
||||
f"After: {new_size} items | "
|
||||
)
|
||||
|
||||
except grpc.RpcError as e:
|
||||
self.logger.error(f"Error receiving actions: {e}")
|
||||
|
||||
def actions_available(self):
|
||||
"""Check if there are actions available in the queue"""
|
||||
with self.action_queue_lock:
|
||||
return not self.action_queue.empty()
|
||||
|
||||
def _action_tensor_to_action_dict(self, action_tensor: torch.Tensor) -> dict[str, float]:
|
||||
action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)}
|
||||
return action
|
||||
|
||||
def control_loop_action(self, verbose: bool = False) -> dict[str, Any]:
|
||||
"""Reading and performing actions in local queue"""
|
||||
|
||||
# Lock only for queue operations
|
||||
get_start = time.perf_counter()
|
||||
with self.action_queue_lock:
|
||||
self.action_queue_size.append(self.action_queue.qsize())
|
||||
# Get action from queue
|
||||
timed_action = self.action_queue.get_nowait()
|
||||
get_end = time.perf_counter() - get_start
|
||||
|
||||
_performed_action = self.robot.send_action(
|
||||
self._action_tensor_to_action_dict(timed_action.get_action())
|
||||
)
|
||||
with self.latest_action_lock:
|
||||
self.latest_action = timed_action.get_timestep()
|
||||
|
||||
if verbose:
|
||||
with self.action_queue_lock:
|
||||
current_queue_size = self.action_queue.qsize()
|
||||
|
||||
self.logger.debug(
|
||||
f"Ts={timed_action.get_timestamp()} | "
|
||||
f"Action #{timed_action.get_timestep()} performed | "
|
||||
f"Queue size: {current_queue_size}"
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Popping action from queue to perform took {get_end:.6f}s | Queue size: {current_queue_size}"
|
||||
)
|
||||
|
||||
return _performed_action
|
||||
|
||||
def _ready_to_send_observation(self):
|
||||
"""Flags when the client is ready to send an observation"""
|
||||
with self.action_queue_lock:
|
||||
return self.action_queue.qsize() / self.action_chunk_size <= self._chunk_size_threshold
|
||||
|
||||
def control_loop_observation(self, task: str, verbose: bool = False) -> RawObservation:
|
||||
try:
|
||||
# Get serialized observation bytes from the function
|
||||
start_time = time.perf_counter()
|
||||
|
||||
raw_observation: RawObservation = self.robot.get_observation()
|
||||
raw_observation["task"] = task
|
||||
|
||||
with self.latest_action_lock:
|
||||
latest_action = self.latest_action
|
||||
|
||||
observation = TimedObservation(
|
||||
timestamp=time.time(), # need time.time() to compare timestamps across client and server
|
||||
observation=raw_observation,
|
||||
timestep=max(latest_action, 0),
|
||||
)
|
||||
|
||||
obs_capture_time = time.perf_counter() - start_time
|
||||
|
||||
# If there are no actions left in the queue, the observation must go through processing!
|
||||
with self.action_queue_lock:
|
||||
observation.must_go = self.must_go.is_set() and self.action_queue.empty()
|
||||
current_queue_size = self.action_queue.qsize()
|
||||
|
||||
_ = self.send_observation(observation)
|
||||
|
||||
self.logger.debug(f"QUEUE SIZE: {current_queue_size} (Must go: {observation.must_go})")
|
||||
if observation.must_go:
|
||||
# must-go event will be set again after receiving actions
|
||||
self.must_go.clear()
|
||||
|
||||
if verbose:
|
||||
# Calculate comprehensive FPS metrics
|
||||
fps_metrics = self.fps_tracker.calculate_fps_metrics(observation.get_timestamp())
|
||||
|
||||
self.logger.info(
|
||||
f"Obs #{observation.get_timestep()} | "
|
||||
f"Avg FPS: {fps_metrics['avg_fps']:.2f} | "
|
||||
f"Target: {fps_metrics['target_fps']:.2f}"
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Ts={observation.get_timestamp():.6f} | Capturing observation took {obs_capture_time:.6f}s"
|
||||
)
|
||||
|
||||
return raw_observation
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in observation sender: {e}")
|
||||
|
||||
def control_loop(self, task: str, verbose: bool = False) -> tuple[Observation, Action]:
|
||||
"""Combined function for executing actions and streaming observations"""
|
||||
# Wait at barrier for synchronized start
|
||||
self.start_barrier.wait()
|
||||
self.logger.info("Control loop thread starting")
|
||||
|
||||
_performed_action = None
|
||||
_captured_observation = None
|
||||
|
||||
while self.running:
|
||||
control_loop_start = time.perf_counter()
|
||||
"""Control loop: (1) Performing actions, when available"""
|
||||
if self.actions_available():
|
||||
_performed_action = self.control_loop_action(verbose)
|
||||
|
||||
"""Control loop: (2) Streaming observations to the remote policy server"""
|
||||
if self._ready_to_send_observation():
|
||||
_captured_observation = self.control_loop_observation(task, verbose)
|
||||
|
||||
self.logger.debug(f"Control loop (ms): {(time.perf_counter() - control_loop_start) * 1000:.2f}")
|
||||
# Dynamically adjust sleep time to maintain the desired control frequency
|
||||
time.sleep(max(0, self.config.environment_dt - (time.perf_counter() - control_loop_start)))
|
||||
|
||||
return _captured_observation, _performed_action
|
||||
|
||||
|
||||
@draccus.wrap()
|
||||
def async_client(cfg: RobotClientConfig):
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
# TODO: Assert if checking robot support is still needed with the plugin system
|
||||
# if cfg.robot.type not in SUPPORTED_ROBOTS:
|
||||
# raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
|
||||
|
||||
client = RobotClient(cfg)
|
||||
|
||||
if client.start():
|
||||
client.logger.info("Starting action receiver thread...")
|
||||
|
||||
# Create and start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
|
||||
# Start action receiver thread
|
||||
action_receiver_thread.start()
|
||||
|
||||
try:
|
||||
# The main thread runs the control loop
|
||||
client.control_loop(task=cfg.task)
|
||||
|
||||
finally:
|
||||
client.stop()
|
||||
action_receiver_thread.join()
|
||||
if cfg.debug_visualize_queue_size:
|
||||
visualize_action_queue_size(client.action_queue_size)
|
||||
client.logger.info("Client stopped")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
register_third_party_plugins()
|
||||
async_client() # run the client
|
||||
@@ -205,149 +205,3 @@ class WandBLogger:
|
||||
|
||||
wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")
|
||||
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
|
||||
|
||||
def log_training_examples(
|
||||
self,
|
||||
batch: dict,
|
||||
step: int,
|
||||
*,
|
||||
camera_keys: list[str],
|
||||
n_samples: int = 4,
|
||||
policy=None,
|
||||
predict_actions: bool = False,
|
||||
mode: str = "train",
|
||||
) -> None:
|
||||
"""Push a ``wandb.Table`` of training-example rows for the current batch.
|
||||
|
||||
Each row is one batch element with:
|
||||
* one ``wandb.Image`` column per camera in ``camera_keys`` (CHW or
|
||||
HWC, uint8 or float in [0,1] — auto-detected),
|
||||
* any text fields present in the batch (``task`` / ``subtask`` /
|
||||
``memory`` / ``instruction``),
|
||||
* ground-truth action first/last frame (the action chunk's
|
||||
endpoints — gives a quick sense of trajectory direction),
|
||||
* if ``predict_actions=True`` and ``policy`` is supplied, the model's
|
||||
``predict_action_chunk`` first/last frame alongside.
|
||||
|
||||
This is opt-in via ``--wandb.log_examples_freq=N`` on the CLI; the
|
||||
training loop calls it once every N steps. Cheap to keep on: with
|
||||
N=4 samples and 3 cameras you upload 12 small PNGs per dump and (if
|
||||
enabled) run one extra inference forward pass.
|
||||
"""
|
||||
import logging # noqa: PLC0415
|
||||
import numpy as np # noqa: PLC0415
|
||||
import torch # noqa: PLC0415
|
||||
|
||||
if mode not in {"train", "eval"}:
|
||||
raise ValueError(mode)
|
||||
|
||||
# Batch size — first tensor-like value wins.
|
||||
bsz = next(
|
||||
(int(v.shape[0]) for v in batch.values() if hasattr(v, "shape") and v.ndim > 0),
|
||||
None,
|
||||
)
|
||||
if not bsz:
|
||||
return
|
||||
n = min(int(n_samples), bsz)
|
||||
|
||||
# Optional predicted-action forward pass on the first n samples.
|
||||
pred_actions: np.ndarray | None = None
|
||||
if predict_actions and policy is not None:
|
||||
was_training = policy.training
|
||||
try:
|
||||
policy.eval()
|
||||
sub_batch = {}
|
||||
for k, v in batch.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
sub_batch[k] = v[:n]
|
||||
elif isinstance(v, (list, tuple)):
|
||||
sub_batch[k] = list(v[:n])
|
||||
else:
|
||||
sub_batch[k] = v
|
||||
with torch.no_grad():
|
||||
pred = policy.predict_action_chunk(sub_batch)
|
||||
pred_actions = pred.detach().cpu().float().numpy()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning(
|
||||
"log_training_examples: predict_action_chunk failed (%s) — "
|
||||
"skipping predicted-action columns",
|
||||
exc,
|
||||
)
|
||||
pred_actions = None
|
||||
finally:
|
||||
if was_training:
|
||||
policy.train()
|
||||
|
||||
present_cameras = [c for c in camera_keys if c in batch]
|
||||
text_keys = [k for k in ("task", "subtask", "memory", "instruction") if k in batch]
|
||||
|
||||
columns = ["sample"]
|
||||
columns.extend(c.removeprefix("observation.images.") or c for c in present_cameras)
|
||||
columns.extend(text_keys)
|
||||
columns.append("gt_action_first")
|
||||
columns.append("gt_action_last")
|
||||
if pred_actions is not None:
|
||||
columns.append("pred_action_first")
|
||||
columns.append("pred_action_last")
|
||||
|
||||
table = self._wandb.Table(columns=columns)
|
||||
|
||||
def _to_uint8_hwc(t: torch.Tensor) -> np.ndarray:
|
||||
# Strip an outer time dim if present: (T, C, H, W) -> first frame.
|
||||
if t.ndim == 4:
|
||||
t = t[0]
|
||||
# CHW -> HWC.
|
||||
if t.ndim == 3 and t.shape[0] in (1, 3, 4) and t.shape[-1] not in (1, 3, 4):
|
||||
t = t.permute(1, 2, 0)
|
||||
arr = t.detach().cpu().float().numpy()
|
||||
if arr.size and float(arr.max()) <= 1.5:
|
||||
arr = arr * 255.0
|
||||
return np.clip(arr, 0, 255).astype(np.uint8)
|
||||
|
||||
def _action_endpoints(a: torch.Tensor) -> tuple[str, str]:
|
||||
arr = a.detach().cpu().float().numpy()
|
||||
if arr.ndim == 2: # (T, D)
|
||||
return (
|
||||
str(np.round(arr[0], 3).tolist()),
|
||||
str(np.round(arr[-1], 3).tolist()),
|
||||
)
|
||||
if arr.ndim == 1:
|
||||
rounded = np.round(arr, 3).tolist()
|
||||
return (str(rounded), str(rounded))
|
||||
return (str(arr.tolist()), str(arr.tolist()))
|
||||
|
||||
for i in range(n):
|
||||
row: list = [i]
|
||||
for cam in present_cameras:
|
||||
try:
|
||||
row.append(self._wandb.Image(_to_uint8_hwc(batch[cam][i])))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning(
|
||||
"log_training_examples: camera %s sample %d failed (%s)",
|
||||
cam,
|
||||
i,
|
||||
exc,
|
||||
)
|
||||
row.append(None)
|
||||
for tk in text_keys:
|
||||
v = batch[tk]
|
||||
if isinstance(v, (list, tuple)):
|
||||
row.append(str(v[i]) if i < len(v) else "")
|
||||
else:
|
||||
row.append(str(v))
|
||||
action = batch.get("action")
|
||||
if isinstance(action, torch.Tensor) and action.ndim >= 1:
|
||||
first, last = _action_endpoints(action[i])
|
||||
row.append(first)
|
||||
row.append(last)
|
||||
else:
|
||||
row.append("")
|
||||
row.append("")
|
||||
if pred_actions is not None:
|
||||
p = torch.from_numpy(pred_actions[i])
|
||||
pfirst, plast = _action_endpoints(p)
|
||||
row.append(pfirst)
|
||||
row.append(plast)
|
||||
table.add_data(*row)
|
||||
|
||||
self._wandb.log({f"{mode}/examples": table}, step=step)
|
||||
|
||||
@@ -62,72 +62,6 @@ class WandBConfig:
|
||||
run_id: str | None = None
|
||||
mode: str | None = None # Allowed values: 'online', 'offline' 'disabled'. Defaults to 'online'
|
||||
add_tags: bool = True # If True, save configuration as tags in the WandB run.
|
||||
# Periodic training-example dump (independent of ``log_freq``). When > 0,
|
||||
# every ``log_examples_freq`` steps the trainer pushes a ``wandb.Table``
|
||||
# with one row per sampled batch element containing each camera view
|
||||
# (rendered as ``wandb.Image``), any text fields present in the batch
|
||||
# (``task`` / ``subtask`` / ``memory`` / ``instruction``), and the
|
||||
# ground-truth action chunk's first + last frames. Defaults to 5000 — set
|
||||
# to 0 to disable. Only fires when ``enable=True``, so runs without wandb
|
||||
# are unaffected.
|
||||
log_examples_freq: int = 5000
|
||||
# Number of batch elements to include in each example dump.
|
||||
log_examples_n: int = 4
|
||||
# If True (default), also run ``policy.predict_action_chunk`` on the logged
|
||||
# samples (in eval mode, no_grad) and add predicted vs ground-truth action
|
||||
# columns to the table. Costs one extra forward pass per dump — negligible
|
||||
# at the 5k-step default cadence. Set to ``False`` if your policy doesn't
|
||||
# implement ``predict_action_chunk`` or you want to skip the extra forward.
|
||||
log_examples_predict_actions: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class EMAConfig:
|
||||
"""Exponential Moving Average of trainable policy parameters.
|
||||
|
||||
Diffusion / flow-matching policies (Diffusion Policy, π0/π0.5,
|
||||
pi052) benefit substantially from averaging late-training
|
||||
parameter oscillations — see Chi et al. 2023 §V.D. The official
|
||||
JAX openpi trainer ships EMA with ``ema_decay=0.99`` (default) and
|
||||
``0.999`` for its pi05_libero config; the openpi PyTorch port
|
||||
explicitly lists EMA as unsupported, and LeRobot main inherited
|
||||
that gap. Enabling this flag plugs ema-pytorch
|
||||
(https://github.com/lucidrains/ema-pytorch) into the LeRobot
|
||||
training loop with a shadow ``nn.Module`` clone of the policy.
|
||||
|
||||
Cost: 1× model params in fp32 shadow (~13 GB for pi052's 3.3B
|
||||
params) + one elementwise update per training step (~1% step time).
|
||||
|
||||
Off by default (opt-in): EMA is only beneficial for flow-matching /
|
||||
diffusion policies (pi0/pi05/pi052), and the fp32 shadow copy is pure
|
||||
overhead for other policies (e.g. VLA-JEPA). Set ``--ema.enable=true``
|
||||
to turn it on (the pi05/pi052 training recipes do this). openpi (JAX)
|
||||
ships EMA on for every config; enable it explicitly to match that.
|
||||
"""
|
||||
|
||||
enable: bool = False
|
||||
# Target EMA decay β in θ_ema ← β·θ_ema + (1-β)·θ_live (passed to
|
||||
# ema-pytorch as ``beta``).
|
||||
# 0.999 — last ~1000 steps; pi05_libero default in openpi
|
||||
# 0.99 — last ~100 steps; openpi top-level default
|
||||
# 0.75 — very fast EMA (Diffusion Policy original setting)
|
||||
# 0.9999 — very slow EMA (long classification runs)
|
||||
decay: float = 0.99
|
||||
# Skip the first N calls to ``ema.update()``; during this window
|
||||
# the shadow is just a hard copy of the live weights (no averaging).
|
||||
# Lets early-training rapid changes settle before averaging begins.
|
||||
# Maps to ema-pytorch's ``update_after_step`` (NOT a smooth decay
|
||||
# ramp like older lerobot EMA implementations).
|
||||
warmup_steps: int = 0
|
||||
# When True, the periodic eval block uses the EMA shadow model
|
||||
# directly (``ema.ema_model``) instead of the live policy. Standard
|
||||
# practice for diffusion-style policies — eval scores are usually
|
||||
# 1–3% higher than the live policy at the same step.
|
||||
use_for_eval: bool = True
|
||||
# When True, the periodic wandb training-example dump uses the EMA
|
||||
# shadow for the optional predicted-action columns (so what you see
|
||||
# in W&B matches eval behavior).
|
||||
use_for_wandb_examples: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -147,16 +147,7 @@ class TrainingRecipe:
|
||||
return cls.from_dict(data)
|
||||
|
||||
def _validate_message_recipe(self) -> None:
|
||||
"""Ensure every templated binding is known and the recipe supervises something.
|
||||
|
||||
A recipe is valid if it has at least one of:
|
||||
|
||||
* a ``target: true`` assistant turn (drives text-CE supervision), or
|
||||
* a ``stream: low_level`` turn (drives flow / action supervision via
|
||||
``predict_actions=True``, even when no assistant turn is targeted —
|
||||
e.g. π0.5-style ``low_level_execution`` where the action expert
|
||||
conditions on a user-only ``${subtask}`` prompt).
|
||||
"""
|
||||
"""Ensure every templated binding is known and at least one turn is a target."""
|
||||
assert self.messages is not None
|
||||
known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"}
|
||||
|
||||
@@ -165,14 +156,8 @@ class TrainingRecipe:
|
||||
if missing:
|
||||
raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}")
|
||||
|
||||
has_target = any(turn.target for turn in self.messages)
|
||||
has_low_level = any(turn.stream == "low_level" for turn in self.messages)
|
||||
if not (has_target or has_low_level):
|
||||
raise ValueError(
|
||||
"Message recipes must contain at least one supervised turn — "
|
||||
"either ``target: true`` (text CE) or ``stream: low_level`` "
|
||||
"(flow/action loss)."
|
||||
)
|
||||
if not any(turn.target for turn in self.messages):
|
||||
raise ValueError("Message recipes must contain at least one target turn.")
|
||||
|
||||
def _validate_blend_recipe(self) -> None:
|
||||
"""Ensure each blend component is a non-empty, weighted message recipe."""
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
# subtask_mem_vqa_speech — Hi-Robot blend + memory + spoken responses.
|
||||
#
|
||||
# Superset of subtasks_vqa.yaml. Keeps the core subtask + action + VQA
|
||||
# training, and adds two text-supervised tasks:
|
||||
#
|
||||
# high_level_subtask — predict the subtask from the task.
|
||||
# low_level_execution — flow loss with [images, subtask, state].
|
||||
# memory_update — compress progress into a memory note.
|
||||
# user_interjection_response — reply to a user interjection with a
|
||||
# spoken `say` tool call (no plan, no
|
||||
# subtask text — just the spoken reply).
|
||||
# ask_vqa_{top,wrist} — camera-grounded VQA.
|
||||
#
|
||||
# Plan is intentionally left out — memory is the only persistent
|
||||
# high-level state here, keeping the prompt short.
|
||||
#
|
||||
# Requires the dataset to carry `memory`, `interjection` and `say`-tool
|
||||
# annotations (the annotation pipeline's memory + interjection modules)
|
||||
# in addition to `subtask` and `vqa`. Sub-recipes whose `if_present`
|
||||
# bindings are missing simply don't render for that sample, so a
|
||||
# dataset without interjections still trains the rest of the blend.
|
||||
#
|
||||
# Tool-call note: the `say` tool call on the interjection-response turn
|
||||
# is flattened to a `<say>...</say>` text marker by the tokenizer step
|
||||
# (`_flatten_say_tool_calls`) so the LM head learns to emit exactly the
|
||||
# marker the runtime parses back (`_split_plan_and_say`).
|
||||
|
||||
blend:
|
||||
|
||||
high_level_subtask:
|
||||
weight: 0.30
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
|
||||
|
||||
low_level_execution:
|
||||
weight: 0.55
|
||||
messages:
|
||||
# The action expert is conditioned on the SUBTASK — at inference
|
||||
# `HighLevelSubtaskFwd` generates it via the LM head and feeds it
|
||||
# here. `stream: low_level` flips `predict_actions=True` so the
|
||||
# flow loss fires; no text-CE target (subtask prediction is owned
|
||||
# by `high_level_subtask`).
|
||||
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
|
||||
|
||||
memory_update:
|
||||
# At inference, `MemoryUpdateFwd` is triggered only on
|
||||
# `subtask_change` events (sparse). Training densely with
|
||||
# `active_at` — i.e. on every frame inside a subtask interval,
|
||||
# not just the boundary frame — supervises the same
|
||||
# (prior_memory, completed_subtask) → current_memory mapping
|
||||
# against varied observations within the interval. The model
|
||||
# learns a stateless transformation; the *when* to emit lives in
|
||||
# the inference trigger, not the model. Annotations only exist
|
||||
# for ~1% of frames as boundary events, so `emitted_at` would
|
||||
# waste 99% of the blend draws (and silently leak them into a
|
||||
# task-conditioned fallback); `active_at` lifts the renderable
|
||||
# rate to ~87% on this dataset.
|
||||
weight: 0.15
|
||||
bindings:
|
||||
prior_memory: "nth_prev(style=memory, offset=1)"
|
||||
current_memory: "active_at(t, style=memory)"
|
||||
completed_subtask: "nth_prev(style=subtask, offset=1)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
|
||||
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
|
||||
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
|
||||
@@ -1,99 +0,0 @@
|
||||
# subtask_mem_vqa_robocasa — Hi-Robot blend tuned for RoboCasa cameras.
|
||||
#
|
||||
# Same supervision as ``subtask_mem.yaml`` (subtask + memory) plus
|
||||
# camera-grounded VQA across the three RoboCasa camera keys produced
|
||||
# by ``slurm_build_robocasa_composite_seen.py``:
|
||||
#
|
||||
# observation.images.robot0_agentview_left (left scene view)
|
||||
# observation.images.robot0_agentview_right (right scene view)
|
||||
# observation.images.robot0_eye_in_hand (wrist)
|
||||
#
|
||||
# The annotation pipeline (``examples/annotations/run_hf_job.py``) emits
|
||||
# VQA per camera, so each anchor frame produces three (user, assistant)
|
||||
# rows tagged with their source camera. Each VQA sub-recipe consumes
|
||||
# the rows for one camera via ``camera=...`` resolver bindings.
|
||||
#
|
||||
# Spatial VQA targets (bbox / point) are rewritten from JSON to
|
||||
# PaliGemma ``<locDDDD>`` tokens by ``_messages_vqa_to_loc`` —
|
||||
# ``register_paligemma_loc_tokens`` already collapses them to single
|
||||
# detection-vocab ids so the LM head learns the pretrained pointing /
|
||||
# detection prior, not a 7-piece BPE salad.
|
||||
#
|
||||
# Interjections / spoken responses are intentionally absent — the
|
||||
# annotation job runs with ``--interjections.enabled=false``.
|
||||
|
||||
blend:
|
||||
|
||||
high_level_subtask:
|
||||
weight: 0.25
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
|
||||
|
||||
low_level_execution:
|
||||
weight: 0.45
|
||||
messages:
|
||||
# Action expert is conditioned on the SUBTASK; at inference the
|
||||
# high-level loop generates it via the LM head and feeds it here.
|
||||
# ``stream: low_level`` flips ``predict_actions=True`` so the flow
|
||||
# loss fires; subtask CE is owned by ``high_level_subtask``.
|
||||
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
|
||||
|
||||
memory_update:
|
||||
# Trained densely with ``active_at`` — every frame inside a subtask
|
||||
# interval — so the (prior_memory, completed_subtask) → current_memory
|
||||
# mapping is supervised against varied observations. The *when* to
|
||||
# emit lives in the inference trigger (subtask_change), not the
|
||||
# model. See ``subtask_mem.yaml`` for the long version of this note.
|
||||
weight: 0.15
|
||||
bindings:
|
||||
prior_memory: "nth_prev(style=memory, offset=1)"
|
||||
current_memory: "active_at(t, style=memory)"
|
||||
completed_subtask: "nth_prev(style=subtask, offset=1)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
|
||||
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
|
||||
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
|
||||
|
||||
ask_vqa_agentview_left:
|
||||
weight: 0.05
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_agentview_left)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_agentview_left)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.robot0_agentview_left}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
|
||||
ask_vqa_agentview_right:
|
||||
weight: 0.05
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_agentview_right)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_agentview_right)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.robot0_agentview_right}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
|
||||
ask_vqa_wrist:
|
||||
weight: 0.05
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_eye_in_hand)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_eye_in_hand)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.robot0_eye_in_hand}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
@@ -1,114 +0,0 @@
|
||||
# subtask_mem_vqa_speech — Hi-Robot blend + memory + spoken responses.
|
||||
#
|
||||
# Superset of subtasks_vqa.yaml. Keeps the core subtask + action + VQA
|
||||
# training, and adds two text-supervised tasks:
|
||||
#
|
||||
# high_level_subtask — predict the subtask from the task.
|
||||
# low_level_execution — flow loss with [images, subtask, state].
|
||||
# memory_update — compress progress into a memory note.
|
||||
# user_interjection_response — reply to a user interjection with a
|
||||
# spoken `say` tool call (no plan, no
|
||||
# subtask text — just the spoken reply).
|
||||
# ask_vqa_{top,wrist} — camera-grounded VQA.
|
||||
#
|
||||
# Plan is intentionally left out — memory is the only persistent
|
||||
# high-level state here, keeping the prompt short.
|
||||
#
|
||||
# Requires the dataset to carry `memory`, `interjection` and `say`-tool
|
||||
# annotations (the annotation pipeline's memory + interjection modules)
|
||||
# in addition to `subtask` and `vqa`. Sub-recipes whose `if_present`
|
||||
# bindings are missing simply don't render for that sample, so a
|
||||
# dataset without interjections still trains the rest of the blend.
|
||||
#
|
||||
# Tool-call note: the `say` tool call on the interjection-response turn
|
||||
# is flattened to a `<say>...</say>` text marker by the tokenizer step
|
||||
# (`_flatten_say_tool_calls`) so the LM head learns to emit exactly the
|
||||
# marker the runtime parses back (`_split_plan_and_say`).
|
||||
|
||||
blend:
|
||||
|
||||
high_level_subtask:
|
||||
weight: 0.25
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
|
||||
|
||||
low_level_execution:
|
||||
weight: 0.40
|
||||
messages:
|
||||
# The action expert is conditioned on the SUBTASK — at inference
|
||||
# `HighLevelSubtaskFwd` generates it via the LM head and feeds it
|
||||
# here. `stream: low_level` flips `predict_actions=True` so the
|
||||
# flow loss fires; no text-CE target (subtask prediction is owned
|
||||
# by `high_level_subtask`).
|
||||
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
|
||||
|
||||
memory_update:
|
||||
# At inference, `MemoryUpdateFwd` is triggered only on
|
||||
# `subtask_change` events (sparse). Training densely with
|
||||
# `active_at` — i.e. on every frame inside a subtask interval,
|
||||
# not just the boundary frame — supervises the same
|
||||
# (prior_memory, completed_subtask) → current_memory mapping
|
||||
# against varied observations within the interval. The model
|
||||
# learns a stateless transformation; the *when* to emit lives in
|
||||
# the inference trigger, not the model. Annotations only exist
|
||||
# for ~1% of frames as boundary events, so `emitted_at` would
|
||||
# waste 99% of the blend draws (and silently leak them into the
|
||||
# task-conditioned fallback); `active_at` lifts the renderable
|
||||
# rate to ~87% on Hi-Robot-style datasets.
|
||||
weight: 0.10
|
||||
bindings:
|
||||
prior_memory: "nth_prev(style=memory, offset=1)"
|
||||
current_memory: "active_at(t, style=memory)"
|
||||
completed_subtask: "nth_prev(style=subtask, offset=1)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
|
||||
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
|
||||
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
|
||||
|
||||
user_interjection_response:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
interjection: "emitted_at(t, style=interjection)"
|
||||
speech: "emitted_at(t, role=assistant, tool_name=say)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: user, content: "${interjection}", stream: high_level, if_present: interjection}
|
||||
# Spoken reply only: the assistant turn carries no text content,
|
||||
# just a `say` tool call (`tool_calls_from: speech`). The chat
|
||||
# tokenizer flattens it to a `<say>...</say>` marker, so the
|
||||
# supervised target trains the model to respond to an
|
||||
# interjection with a spoken acknowledgement.
|
||||
- {role: assistant, stream: high_level, target: true, if_present: speech, tool_calls_from: speech}
|
||||
|
||||
# VQA is view-dependent — each camera gets its own sub-recipe so the
|
||||
# resolver disambiguates via `camera=...`. Camera keys match
|
||||
# subtasks_vqa.yaml (`front` + `wrist`); adjust to your dataset.
|
||||
ask_vqa_top:
|
||||
weight: 0.075
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.front)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.front)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.front}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
|
||||
ask_vqa_wrist:
|
||||
weight: 0.075
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.wrist}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
@@ -1,61 +0,0 @@
|
||||
# subtasks_vqa — Hi-Robot blend for PI052 (PaliGemma backbone).
|
||||
#
|
||||
# Trains two things only: subtasks and VQA. Plan and memory are
|
||||
# intentionally left out — keeps the prompt short and the training
|
||||
# surface small. The fuller blend with memory + spoken replies is
|
||||
# ``subtask_mem_vqa_speech.yaml``.
|
||||
#
|
||||
# high_level_subtask — predict the subtask from the task.
|
||||
# low_level_execution — flow loss with [images, subtask, state].
|
||||
# ask_vqa_{top,wrist} — camera-grounded VQA.
|
||||
#
|
||||
# PI052's text tokenizer renders these messages as plain
|
||||
# ``Role: content`` text (PaliGemma is not chat-pretrained).
|
||||
|
||||
blend:
|
||||
|
||||
high_level_subtask:
|
||||
weight: 0.40
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
|
||||
|
||||
low_level_execution:
|
||||
weight: 0.40
|
||||
messages:
|
||||
# The action expert is conditioned on the SUBTASK — at inference
|
||||
# the high-level loop (``HighLevelSubtaskFwd``) generates the
|
||||
# subtask via the LM head and feeds it here. The action expert's
|
||||
# prefix is [images, subtask, state]. ``stream: low_level`` flips
|
||||
# ``predict_actions=True`` so the flow loss fires; no text-CE
|
||||
# target here (subtask prediction is owned by
|
||||
# ``high_level_subtask``).
|
||||
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
|
||||
|
||||
ask_vqa_top:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.front)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.front)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.front}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
|
||||
ask_vqa_wrist:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.wrist}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
@@ -30,7 +30,7 @@ from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.sample_weighting import SampleWeightingConfig
|
||||
|
||||
from . import parser
|
||||
from .default import DatasetConfig, EMAConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .rewards import RewardModelConfig
|
||||
|
||||
@@ -111,20 +111,9 @@ class TrainPipelineConfig(HubMixin):
|
||||
scheduler: LRSchedulerConfig | None = None
|
||||
eval: EvalConfig = field(default_factory=EvalConfig)
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
ema: EMAConfig = field(default_factory=EMAConfig)
|
||||
peft: PeftConfig | None = None
|
||||
|
||||
# VQA oversampling. When set (a fraction in (0, 1)), the training
|
||||
# dataloader uses a WeightedEpisodeAwareSampler that draws frames
|
||||
# carrying a `vqa` language annotation often enough that they make
|
||||
# up roughly this fraction of the training stream. VQA annotations
|
||||
# are typically sparse, so without this they are underrepresented.
|
||||
# `None` (default) keeps uniform episode-aware sampling.
|
||||
vqa_target_fraction: float | None = None
|
||||
|
||||
# Sample weighting configuration (e.g., for RA-BC training). Old
|
||||
# inline ``use_rabc`` / ``rabc_*`` params are migrated to this
|
||||
# field by ``_migrate_legacy_rabc_keys`` above.
|
||||
# Sample weighting configuration (e.g., for RA-BC training)
|
||||
sample_weighting: SampleWeightingConfig | None = None
|
||||
|
||||
# Rename map for the observation to override the image and state keys
|
||||
|
||||
@@ -35,6 +35,7 @@ from .dataset_tools import (
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
from .factory import make_dataset, resolve_delta_timestamps
|
||||
from .image_writer import safe_stop_image_writer
|
||||
from .io_utils import load_episodes, write_stats
|
||||
from .language import (
|
||||
@@ -49,24 +50,11 @@ from .lerobot_dataset import LeRobotDataset
|
||||
from .multi_dataset import MultiLeRobotDataset
|
||||
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
|
||||
from .sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler
|
||||
from .sampler import EpisodeAwareSampler
|
||||
from .streaming_dataset import StreamingLeRobotDataset
|
||||
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
||||
from .video_utils import VideoEncodingManager
|
||||
|
||||
|
||||
def make_dataset(*args, **kwargs):
|
||||
from .factory import make_dataset as _make_dataset
|
||||
|
||||
return _make_dataset(*args, **kwargs)
|
||||
|
||||
|
||||
def resolve_delta_timestamps(*args, **kwargs):
|
||||
from .factory import resolve_delta_timestamps as _resolve_delta_timestamps
|
||||
|
||||
return _resolve_delta_timestamps(*args, **kwargs)
|
||||
|
||||
|
||||
# NOTE: Low-level I/O functions (cast_stats_to_numpy, get_parquet_file_size_in_mb, etc.)
|
||||
# and legacy migration constants are intentionally NOT re-exported here.
|
||||
# Import directly: ``from lerobot.datasets.io_utils import ...``
|
||||
@@ -77,7 +65,6 @@ __all__ = [
|
||||
"DEFAULT_QUANTILES",
|
||||
"EVENT_ONLY_STYLES",
|
||||
"EpisodeAwareSampler",
|
||||
"WeightedEpisodeAwareSampler",
|
||||
"LANGUAGE_EVENTS",
|
||||
"LANGUAGE_PERSISTENT",
|
||||
"LeRobotDataset",
|
||||
|
||||
@@ -126,53 +126,10 @@ class DatasetReader:
|
||||
def _load_hf_dataset(self) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
features = get_hf_features_from_features(self._meta.features)
|
||||
# Datasets annotated with the PR1 language columns may have been
|
||||
# written without registering those columns in ``meta/info.json``
|
||||
# (e.g. they predate ``CODEBASE_VERSION="v3.1"`` and were
|
||||
# back-filled by ``lerobot-annotate``). Probe a single parquet
|
||||
# shard and graft the column features on so the strict
|
||||
# ``Dataset.from_parquet`` cast doesn't fail with
|
||||
# ``column names don't match``.
|
||||
features = self._extend_features_with_language_columns(features)
|
||||
hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
def _extend_features_with_language_columns(
|
||||
self, features: datasets.Features
|
||||
) -> datasets.Features:
|
||||
"""Add ``language_persistent`` / ``language_events`` to ``features``
|
||||
when the underlying parquet shards declare them but the metadata
|
||||
doesn't. No-op when neither column is present or both are
|
||||
already registered.
|
||||
"""
|
||||
# Find any one parquet to peek at; bail if there are none yet
|
||||
# (the dataset will fail later for an unrelated reason and we
|
||||
# want that error to surface as-is).
|
||||
try:
|
||||
sample = next((self.root / "data").glob("*/*.parquet"))
|
||||
except StopIteration:
|
||||
return features
|
||||
|
||||
from pyarrow import parquet as _pq # noqa: PLC0415
|
||||
|
||||
schema_names = set(_pq.read_schema(sample).names)
|
||||
from .language import ( # noqa: PLC0415
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
language_events_column_feature,
|
||||
language_persistent_column_feature,
|
||||
)
|
||||
|
||||
extra: dict[str, object] = {}
|
||||
if LANGUAGE_PERSISTENT in schema_names and LANGUAGE_PERSISTENT not in features:
|
||||
extra[LANGUAGE_PERSISTENT] = language_persistent_column_feature()
|
||||
if LANGUAGE_EVENTS in schema_names and LANGUAGE_EVENTS not in features:
|
||||
extra[LANGUAGE_EVENTS] = language_events_column_feature()
|
||||
if not extra:
|
||||
return features
|
||||
return datasets.Features({**features, **extra})
|
||||
|
||||
def _check_cached_episodes_sufficient(self) -> bool:
|
||||
"""Check if the cached dataset contains all requested episodes and their video files."""
|
||||
if self.hf_dataset is None or len(self.hf_dataset) == 0:
|
||||
|
||||
@@ -170,29 +170,6 @@ def render_sample(
|
||||
"""
|
||||
persistent_rows = _normalize_rows(persistent or [])
|
||||
event_rows = _normalize_rows(events or [])
|
||||
|
||||
# VQA-priority routing. A ``vqa`` annotation is sparse and
|
||||
# view-dependent; the plain weighted blend would (a) waste a draw
|
||||
# whenever it picks an ``ask_vqa*`` sub-recipe for a frame that has
|
||||
# no VQA, and (b) silently drop a VQA-annotated frame whenever it
|
||||
# picks a non-VQA sub-recipe. So: if the blend has ``ask_vqa*``
|
||||
# sub-recipes and *this* frame carries one of their VQA bindings,
|
||||
# render VQA here regardless of the weighted draw. That makes VQA's
|
||||
# recipe-side training share equal the VQA-annotation density (the
|
||||
# maximum reachable without a dataset-level oversampling sampler).
|
||||
if recipe.blend is not None:
|
||||
vqa_rendered = _render_vqa_if_present(
|
||||
recipe,
|
||||
persistent=persistent_rows,
|
||||
events=event_rows,
|
||||
t=t,
|
||||
sample_idx=sample_idx,
|
||||
task=task,
|
||||
dataset_ctx=dataset_ctx,
|
||||
)
|
||||
if vqa_rendered is not None:
|
||||
return vqa_rendered
|
||||
|
||||
selected_recipe = _select_recipe(recipe, sample_idx)
|
||||
bindings = _resolve_bindings(
|
||||
selected_recipe,
|
||||
@@ -206,59 +183,6 @@ def render_sample(
|
||||
return _render_message_recipe(selected_recipe, bindings)
|
||||
|
||||
|
||||
def _render_vqa_if_present(
|
||||
recipe: TrainingRecipe,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
t: float,
|
||||
sample_idx: int,
|
||||
task: str | None,
|
||||
dataset_ctx: Any | None,
|
||||
) -> RenderedMessages | None:
|
||||
"""Render an ``ask_vqa*`` sub-recipe iff this frame carries a VQA
|
||||
annotation; otherwise return ``None`` so the caller falls back to the
|
||||
normal weighted blend.
|
||||
|
||||
When several VQA sub-recipes resolve (e.g. a frame annotated for more
|
||||
than one camera), one is chosen deterministically by relative weight.
|
||||
"""
|
||||
assert recipe.blend is not None
|
||||
renderable: list[tuple[float, RenderedMessages]] = []
|
||||
for name, component in recipe.blend.items():
|
||||
if not name.startswith("ask_vqa"):
|
||||
continue
|
||||
bindings = _resolve_bindings(
|
||||
component,
|
||||
persistent=persistent,
|
||||
events=events,
|
||||
t=t,
|
||||
sample_idx=sample_idx,
|
||||
task=task,
|
||||
dataset_ctx=dataset_ctx,
|
||||
)
|
||||
rendered = _render_message_recipe(component, bindings)
|
||||
if rendered is not None:
|
||||
renderable.append((float(component.weight or 0.0), rendered))
|
||||
|
||||
if not renderable:
|
||||
return None
|
||||
if len(renderable) == 1:
|
||||
return renderable[0][1]
|
||||
|
||||
# Multiple cameras have a VQA for this frame — deterministic pick by
|
||||
# relative weight (fall back to a uniform draw if all weights are 0).
|
||||
total = sum(w for w, _ in renderable) or float(len(renderable))
|
||||
digest = hashlib.blake2b(f"vqa:{sample_idx}".encode(), digest_size=8).digest()
|
||||
draw = int.from_bytes(digest, "big") / 2**64 * total
|
||||
cumulative = 0.0
|
||||
for w, rendered in renderable:
|
||||
cumulative += w or (total / len(renderable))
|
||||
if draw < cumulative:
|
||||
return rendered
|
||||
return renderable[-1][1]
|
||||
|
||||
|
||||
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
|
||||
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
|
||||
if recipe.blend is None:
|
||||
@@ -422,15 +346,7 @@ def _render_message_recipe(
|
||||
if turn.target:
|
||||
target_indices.append(message_idx)
|
||||
|
||||
# A render is meaningful if it supervises *something*: either a
|
||||
# text-CE target turn, or a ``low_level`` stream turn (flow / action
|
||||
# supervision — e.g. the flow-only ``low_level_execution`` recipe,
|
||||
# ``user(${subtask})`` with ``stream: low_level`` and no target).
|
||||
# Without this, a flow-only recipe renders to ``None`` every time
|
||||
# the blend draws it → ``predict_actions`` is never True → the
|
||||
# action expert never receives a flow loss.
|
||||
has_low_level = any(stream == "low_level" for stream in streams)
|
||||
if not target_indices and not has_low_level:
|
||||
if not target_indices:
|
||||
return None
|
||||
|
||||
rendered = {
|
||||
@@ -487,10 +403,8 @@ def _validate_rendered(rendered: RenderedMessages) -> None:
|
||||
|
||||
if len(streams) != len(messages):
|
||||
raise ValueError("message_streams must be aligned with messages.")
|
||||
# Valid iff it supervises something: a text-CE target turn OR a
|
||||
# ``low_level`` stream turn (flow / action supervision).
|
||||
if not target_indices and not any(s == "low_level" for s in streams):
|
||||
raise ValueError("Rendered samples must contain a target message or a low_level-stream message.")
|
||||
if not target_indices:
|
||||
raise ValueError("Rendered samples must contain at least one target message.")
|
||||
for idx in target_indices:
|
||||
if idx < 0 or idx >= len(messages):
|
||||
raise ValueError(f"Target message index {idx} is out of bounds.")
|
||||
|
||||
@@ -30,6 +30,7 @@ class EpisodeAwareSampler:
|
||||
drop_n_first_frames: int = 0,
|
||||
drop_n_last_frames: int = 0,
|
||||
shuffle: bool = False,
|
||||
generator: torch.Generator | None = None,
|
||||
):
|
||||
"""Sampler that optionally incorporates episode boundary information.
|
||||
|
||||
@@ -41,6 +42,10 @@ class EpisodeAwareSampler:
|
||||
drop_n_first_frames: Number of frames to drop from the start of each episode.
|
||||
drop_n_last_frames: Number of frames to drop from the end of each episode.
|
||||
shuffle: Whether to shuffle the indices.
|
||||
generator: Generator used for shuffling. Exposing this attribute (even when None) lets
|
||||
`accelerate` register it as the synchronized RNG in distributed training, so
|
||||
every rank draws the same permutation and batch shards stay disjoint. When
|
||||
None, shuffling falls back to the global torch RNG.
|
||||
"""
|
||||
if drop_n_first_frames < 0:
|
||||
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
|
||||
@@ -73,10 +78,11 @@ class EpisodeAwareSampler:
|
||||
|
||||
self.indices = indices
|
||||
self.shuffle = shuffle
|
||||
self.generator = generator
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
if self.shuffle:
|
||||
for i in torch.randperm(len(self.indices)):
|
||||
for i in torch.randperm(len(self.indices), generator=self.generator):
|
||||
yield self.indices[i]
|
||||
else:
|
||||
for i in self.indices:
|
||||
@@ -84,66 +90,3 @@ class EpisodeAwareSampler:
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.indices)
|
||||
|
||||
|
||||
class WeightedEpisodeAwareSampler(EpisodeAwareSampler):
|
||||
"""``EpisodeAwareSampler`` that draws frames *with replacement* in
|
||||
proportion to per-frame weights.
|
||||
|
||||
Used to oversample frames carrying a sparse annotation (e.g. a VQA
|
||||
question) so the policy sees them more often than their natural
|
||||
dataset density. One epoch still yields ``len(self.indices)``
|
||||
samples — the weights only change the *composition* of the stream,
|
||||
not its length. Each epoch re-draws, so the oversampled subset
|
||||
varies run to run.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_from_indices: list[int],
|
||||
dataset_to_indices: list[int],
|
||||
frame_weights,
|
||||
*,
|
||||
episode_indices_to_use: list | None = None,
|
||||
drop_n_first_frames: int = 0,
|
||||
drop_n_last_frames: int = 0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dataset_from_indices: Episode start indices (see ``EpisodeAwareSampler``).
|
||||
dataset_to_indices: Episode end indices.
|
||||
frame_weights: 1-D sequence/tensor of non-negative weights, one per
|
||||
dataset frame (length == total dataset frames). Higher weight ⇒
|
||||
that frame is sampled more often.
|
||||
episode_indices_to_use / drop_n_first_frames / drop_n_last_frames:
|
||||
Same meaning as ``EpisodeAwareSampler`` — the episode-boundary
|
||||
frame filtering is applied first, then weighting is restricted
|
||||
to the surviving frames.
|
||||
"""
|
||||
super().__init__(
|
||||
dataset_from_indices,
|
||||
dataset_to_indices,
|
||||
episode_indices_to_use=episode_indices_to_use,
|
||||
drop_n_first_frames=drop_n_first_frames,
|
||||
drop_n_last_frames=drop_n_last_frames,
|
||||
shuffle=False,
|
||||
)
|
||||
weights = torch.as_tensor(frame_weights, dtype=torch.double).flatten()
|
||||
idx = torch.tensor(self.indices, dtype=torch.long)
|
||||
if weights.numel() <= int(idx.max()):
|
||||
raise ValueError(
|
||||
f"frame_weights has {weights.numel()} entries but the sampler "
|
||||
f"references frame index {int(idx.max())}."
|
||||
)
|
||||
selected = weights[idx]
|
||||
if not torch.isfinite(selected).all() or bool((selected < 0).any()):
|
||||
raise ValueError("frame_weights must be finite and non-negative.")
|
||||
if float(selected.sum()) <= 0.0:
|
||||
# All surviving frames have zero weight — fall back to uniform.
|
||||
selected = torch.ones_like(selected)
|
||||
self._weights = selected
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
picks = torch.multinomial(self._weights, num_samples=len(self.indices), replacement=True)
|
||||
for i in picks.tolist():
|
||||
yield self.indices[i]
|
||||
|
||||
@@ -366,24 +366,17 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
|
||||
hub_versions = get_repo_versions(repo_id)
|
||||
|
||||
if not hub_versions:
|
||||
msg = (
|
||||
f"Repo {repo_id!r} has no codebase-version tags. The dataset "
|
||||
f"either doesn't exist on the Hub yet, or it was uploaded "
|
||||
f"without a ``v3.x``-style tag. To tag an existing dataset run:\n"
|
||||
f" from huggingface_hub import HfApi\n"
|
||||
f" HfApi().create_tag({repo_id!r}, tag='v3.0', repo_type='dataset', exist_ok=True)"
|
||||
raise RevisionNotFoundError(
|
||||
f"""Your dataset must be tagged with a codebase version.
|
||||
Assuming _version_ is the codebase_version value in the info.json, you can run this:
|
||||
```python
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
hub_api = HfApi()
|
||||
hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
|
||||
```
|
||||
"""
|
||||
)
|
||||
# ``RevisionNotFoundError`` extends ``HfHubHTTPError`` whose
|
||||
# ``__init__`` indexes ``response.headers`` unconditionally on
|
||||
# current ``huggingface_hub`` versions. Constructing it without
|
||||
# a real ``Response`` object crashes with either
|
||||
# ``TypeError: missing 1 required keyword-only argument`` (old
|
||||
# builds) or ``AttributeError: 'NoneType' object has no attribute
|
||||
# 'headers'`` (new builds). Skip that path entirely — this isn't
|
||||
# really an HTTP error, it's a configuration issue — and raise a
|
||||
# plain ``RuntimeError`` so the message actually reaches the
|
||||
# caller.
|
||||
raise RuntimeError(msg)
|
||||
|
||||
if target_version in hub_versions:
|
||||
return f"v{target_version}"
|
||||
|
||||
@@ -33,8 +33,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Dimensions for the flat action/state vectors used by the LeRobot wrapper.
|
||||
# These correspond to the PandaOmron robot in RoboCasa365.
|
||||
OBS_STATE_DIM = 16 # ee_pos_rel(3) + ee_quat_rel(4) + base_pos(3) + base_quat(4) + gripper_qpos(2)
|
||||
ACTION_DIM = 12 # ee_pos(3) + ee_rot(3) + gripper(1) + base_motion(4) + control_mode(1)
|
||||
OBS_STATE_DIM = 16 # base_pos(3) + base_quat(4) + ee_pos_rel(3) + ee_quat_rel(4) + gripper_qpos(2)
|
||||
ACTION_DIM = 12 # base_motion(4) + control_mode(1) + ee_pos(3) + ee_rot(3) + gripper(1)
|
||||
ACTION_LOW = -1.0
|
||||
ACTION_HIGH = 1.0
|
||||
|
||||
@@ -101,15 +101,14 @@ def _resolve_tasks(task: str) -> tuple[list[str], str | None]:
|
||||
def convert_action(flat_action: np.ndarray) -> dict[str, Any]:
|
||||
"""Split a flat (12,) action vector into a RoboCasa action dict.
|
||||
|
||||
Layout (openpi / robocasa.utils.env_utils.convert_action order):
|
||||
ee_pos(3) + ee_rot(3) + gripper(1) + base_motion(4) + control_mode(1)
|
||||
Layout: base_motion(4) + control_mode(1) + ee_pos(3) + ee_rot(3) + gripper(1)
|
||||
"""
|
||||
return {
|
||||
"action.end_effector_position": flat_action[0:3],
|
||||
"action.end_effector_rotation": flat_action[3:6],
|
||||
"action.gripper_close": flat_action[6:7],
|
||||
"action.base_motion": flat_action[7:11],
|
||||
"action.control_mode": flat_action[11:12],
|
||||
"action.base_motion": flat_action[0:4],
|
||||
"action.control_mode": flat_action[4:5],
|
||||
"action.end_effector_position": flat_action[5:8],
|
||||
"action.end_effector_rotation": flat_action[8:11],
|
||||
"action.gripper_close": flat_action[11:12],
|
||||
}
|
||||
|
||||
|
||||
@@ -231,14 +230,12 @@ class RoboCasaEnv(gym.Env):
|
||||
return {"pixels": images}
|
||||
|
||||
# `state.*` keys come from PandaOmronKeyConverter inside the wrapper.
|
||||
# openpi state order: ee first, then base, then gripper (matches the
|
||||
# openpi robocasa pipeline / examples/robocasa/main.py state layout).
|
||||
agent_pos = np.concatenate(
|
||||
[
|
||||
raw_obs.get("state.end_effector_position_relative", np.zeros(3)),
|
||||
raw_obs.get("state.end_effector_rotation_relative", np.zeros(4)),
|
||||
raw_obs.get("state.base_position", np.zeros(3)),
|
||||
raw_obs.get("state.base_rotation", np.zeros(4)),
|
||||
raw_obs.get("state.end_effector_position_relative", np.zeros(3)),
|
||||
raw_obs.get("state.end_effector_rotation_relative", np.zeros(4)),
|
||||
raw_obs.get("state.gripper_qpos", np.zeros(2)),
|
||||
],
|
||||
axis=-1,
|
||||
|
||||
@@ -104,8 +104,6 @@ class AdamWConfig(OptimizerConfig):
|
||||
eps: float = 1e-8
|
||||
weight_decay: float = 1e-2
|
||||
grad_clip_norm: float = 10.0
|
||||
foreach: bool | None = None
|
||||
fused: bool | None = None
|
||||
|
||||
def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
|
||||
kwargs = asdict(self)
|
||||
|
||||
@@ -25,7 +25,6 @@ from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as M
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
||||
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
||||
from .pi052.configuration_pi052 import PI052Config as PI052Config
|
||||
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
@@ -50,7 +49,6 @@ __all__ = [
|
||||
"PI0Config",
|
||||
"PI0FastConfig",
|
||||
"PI05Config",
|
||||
"PI052Config",
|
||||
"SmolVLAConfig",
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
|
||||
@@ -63,79 +63,6 @@ from .wall_x.configuration_wall_x import WallXConfig
|
||||
from .xvla.configuration_xvla import XVLAConfig
|
||||
|
||||
|
||||
def _restore_pi052_pretrained_state(
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
pretrained_path: str,
|
||||
) -> None:
|
||||
"""Transplant saved stateful blobs from a pi052 checkpoint into fresh pipelines.
|
||||
|
||||
pi052's preprocessor includes steps whose constructor args don't
|
||||
JSON-roundtrip (``RenderMessagesStep.recipe`` is a Python object,
|
||||
``ActionTokenizerProcessorStep.action_tokenizer_name`` is a
|
||||
fitted-tokenizer path that may not exist at eval time). We rebuild
|
||||
those pipelines fresh from ``config.recipe_path`` and then walk
|
||||
over the saved ``policy_{pre,post}processor.json`` files to find
|
||||
each step's ``state_file`` reference and load the bytes back into
|
||||
the corresponding fresh step. Today that's only the
|
||||
NormalizerProcessorStep / UnnormalizerProcessorStep (the action /
|
||||
state quantile stats), but the loop is generic so any future
|
||||
stateful step picks up its blob automatically.
|
||||
|
||||
Pairing is by ``registry_name`` AND position so a benign reorder
|
||||
on the saved side surfaces a warning rather than silently feeding
|
||||
the wrong tensors into the wrong step.
|
||||
"""
|
||||
import json # noqa: PLC0415
|
||||
import logging # noqa: PLC0415
|
||||
from pathlib import Path # noqa: PLC0415
|
||||
|
||||
from safetensors.torch import load_file # noqa: PLC0415
|
||||
|
||||
base = Path(pretrained_path)
|
||||
if not base.exists():
|
||||
return
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
for pipeline, config_filename in [
|
||||
(preprocessor, f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"),
|
||||
(postprocessor, f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"),
|
||||
]:
|
||||
config_path = base / config_filename
|
||||
if not config_path.exists():
|
||||
continue
|
||||
saved = json.loads(config_path.read_text())
|
||||
|
||||
for idx, (saved_step, fresh_step) in enumerate(
|
||||
zip(saved.get("steps", []), pipeline.steps, strict=False)
|
||||
):
|
||||
state_file = saved_step.get("state_file")
|
||||
if not state_file:
|
||||
continue
|
||||
saved_name = saved_step.get("registry_name")
|
||||
fresh_name = getattr(type(fresh_step), "_registry_name", None)
|
||||
if saved_name and fresh_name and saved_name != fresh_name:
|
||||
log.warning(
|
||||
"PI052 state restore: %s step %d registry name mismatch "
|
||||
"(saved=%s, fresh=%s); skipping %s",
|
||||
config_filename, idx, saved_name, fresh_name, state_file,
|
||||
)
|
||||
continue
|
||||
state_path = base / state_file
|
||||
if not state_path.exists():
|
||||
log.warning(
|
||||
"PI052 state restore: %s missing at %s; %s left at fresh init",
|
||||
state_file, base, fresh_name,
|
||||
)
|
||||
continue
|
||||
fresh_step.load_state_dict(load_file(str(state_path)))
|
||||
log.info(
|
||||
"PI052 state restore: loaded %s into %s (step %d)",
|
||||
state_file, fresh_name, idx,
|
||||
)
|
||||
|
||||
|
||||
def _reconnect_relative_absolute_steps(
|
||||
preprocessor: PolicyProcessorPipeline, postprocessor: PolicyProcessorPipeline
|
||||
) -> None:
|
||||
@@ -203,10 +130,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .pi05.modeling_pi05 import PI05Policy
|
||||
|
||||
return PI05Policy
|
||||
elif name == "pi052":
|
||||
from .pi052.modeling_pi052 import PI052Policy
|
||||
|
||||
return PI052Policy
|
||||
elif name == "gaussian_actor":
|
||||
from .gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
|
||||
@@ -255,8 +178,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05",
|
||||
"pi052", "gaussian_actor", "smolvla", "wall_x", "molmoact2".
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
|
||||
"smolvla", "wall_x", "molmoact2".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -279,10 +202,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi05":
|
||||
return PI05Config(**kwargs)
|
||||
elif policy_type == "pi052":
|
||||
from .pi052.configuration_pi052 import PI052Config
|
||||
|
||||
return PI052Config(**kwargs)
|
||||
elif policy_type == "gaussian_actor":
|
||||
return GaussianActorConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
@@ -327,12 +246,6 @@ class ProcessorConfigKwargs(TypedDict, total=False):
|
||||
preprocessor_overrides: dict[str, Any] | None
|
||||
postprocessor_overrides: dict[str, Any] | None
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
|
||||
# Optional: HF Hub repo id of the dataset the policy is being
|
||||
# trained on. Used by policies that auto-fit pieces of their
|
||||
# preprocessing (e.g. pi052's FAST action tokenizer per
|
||||
# Pertsch et al. 2025 [64], π0.5 §III.C). When omitted, those
|
||||
# policies fall back to their universal pre-fitted tokenizers.
|
||||
dataset_repo_id: str | None
|
||||
dataset_meta: Any | None
|
||||
|
||||
|
||||
@@ -366,29 +279,6 @@ def make_pre_post_processors(
|
||||
NotImplementedError: If a processor factory is not implemented for the given
|
||||
policy configuration type.
|
||||
"""
|
||||
if pretrained_path and getattr(policy_cfg, "type", None) == "pi052":
|
||||
# pi052 pipelines don't roundtrip through the saved
|
||||
# ``policy_preprocessor.json``: ``RenderMessagesStep`` holds a
|
||||
# Python ``TrainingRecipe`` (not JSON-serializable; saved as
|
||||
# ``{}``) and ``ActionTokenizerProcessorStep`` saves a host-only
|
||||
# FAST tokenizer path. Generic ``from_pretrained`` then dies
|
||||
# with ``RenderMessagesStep.__init__() missing 1 required
|
||||
# positional argument: 'recipe'`` (job 22164494).
|
||||
#
|
||||
# Mirror ``lerobot_pi052_runtime``'s bootstrap: build pipelines
|
||||
# fresh from ``config.recipe_path`` and transplant the saved
|
||||
# stateful blobs (normalizer stats) from the checkpoint dir.
|
||||
from .pi052.processor_pi052 import make_pi052_pre_post_processors
|
||||
|
||||
preprocessor, postprocessor = make_pi052_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_repo_id=kwargs.get("dataset_repo_id"),
|
||||
)
|
||||
_restore_pi052_pretrained_state(preprocessor, postprocessor, pretrained_path)
|
||||
_reconnect_relative_absolute_steps(preprocessor, postprocessor)
|
||||
return preprocessor, postprocessor
|
||||
|
||||
if pretrained_path:
|
||||
# TODO(Steven): Temporary patch, implement correctly the processors for Gr00t
|
||||
if isinstance(policy_cfg, GrootConfig):
|
||||
@@ -483,22 +373,6 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif policy_cfg.type == "pi052":
|
||||
# NOTE: PI052Config subclasses PI05Config, so this branch MUST
|
||||
# come before the PI05Config isinstance check below (otherwise
|
||||
# pi052 would silently pick up π0.5's processor).
|
||||
from .pi052.processor_pi052 import make_pi052_pre_post_processors
|
||||
|
||||
processors = make_pi052_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
# ``dataset_repo_id`` flows in via kwargs when FAST CE is
|
||||
# enabled — the train loop sets it from ``--dataset.repo_id``.
|
||||
# When ``None``, ``make_pi052_pre_post_processors`` skips
|
||||
# the auto-fit and uses the universal tokenizer.
|
||||
dataset_repo_id=kwargs.get("dataset_repo_id"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, PI05Config):
|
||||
from .pi05.processor_pi05 import make_pi05_pre_post_processors
|
||||
|
||||
|
||||
@@ -178,6 +178,7 @@ N_COLOR_CHANNELS = 3
|
||||
|
||||
|
||||
# config
|
||||
@strict
|
||||
class GR00TN15Config(PretrainedConfig):
|
||||
model_type = "gr00t_n1_5"
|
||||
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""π0.5 v2 — full reproduction of the π0.5 paper's hierarchical
|
||||
inference recipe on lerobot.
|
||||
|
||||
Extends :class:`lerobot.policies.pi05.PI05Policy` with:
|
||||
|
||||
* recipe-driven training (PR 1's :class:`RenderMessagesStep`),
|
||||
* PaliGemma ``lm_head`` cross-entropy on supervised subtask spans
|
||||
(the "high-level subtask prediction" of the paper, §IV.D),
|
||||
* AR text generation at inference (:meth:`PI052Policy.select_message`),
|
||||
* per-component prompt dropout (Pi 0.7 §V.E) for regularising the
|
||||
text head against missing context at inference.
|
||||
|
||||
See ``src/lerobot/configs/recipes/subtasks_vqa.yaml`` for the
|
||||
canonical training recipe and
|
||||
``examples/training/pi052_hirobot.slurm`` for the launcher.
|
||||
"""
|
||||
|
||||
from .configuration_pi052 import PI052Config
|
||||
from .modeling_pi052 import PI052Policy
|
||||
from .processor_pi052 import make_pi052_pre_post_processors
|
||||
from .text_processor_pi052 import PI052TextTokenizerStep
|
||||
|
||||
__all__ = [
|
||||
"PI052Config",
|
||||
"PI052Policy",
|
||||
"PI052TextTokenizerStep",
|
||||
"make_pi052_pre_post_processors",
|
||||
]
|
||||
@@ -1,235 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""π0.5 v2 (with text head) — reproduction of the π0.5 paper's
|
||||
hierarchical inference recipe.
|
||||
|
||||
Same architecture as the existing ``PI05Policy`` (PaliGemma 2B VLM +
|
||||
~300M Gemma action expert, joint training with FAST tokens during
|
||||
pre-train and flow matching during post-train), but with the
|
||||
PaliGemma ``lm_head`` re-enabled so the same model can be supervised
|
||||
to predict both:
|
||||
|
||||
* **subtask strings** at the high level (cross-entropy on the LM
|
||||
head), and
|
||||
* **action chunks** at the low level (flow matching on the
|
||||
action-expert tokens).
|
||||
|
||||
This is the dual-head co-training pattern from the paper:
|
||||
|
||||
L = H(x, f_θ_text) + α * ‖ω - a - f_θ_action(a_τ, o, ℓ)‖²
|
||||
|
||||
with α = 10.0 per § IV.D of arxiv:2504.16054. The π0.5 model splits
|
||||
inference into a text-prediction step followed by an action-prediction
|
||||
step, which the multi-rate ``PI052Runtime`` (in
|
||||
``lerobot.policies.pi052.inference``) drives at separate rates.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
|
||||
from ..pi05.configuration_pi05 import PI05Config
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi052")
|
||||
@dataclass
|
||||
class PI052Config(PI05Config):
|
||||
"""π0.5 with the PaliGemma LM head re-enabled for subtask prediction.
|
||||
|
||||
Recipe-driven dual-head training: the flow head supervises actions,
|
||||
the LM head supervises subtask / plan / memory / VQA text. The
|
||||
flow:text loss split is the milder 5:1 (see ``flow_loss_weight``).
|
||||
"""
|
||||
|
||||
# Recipe / language stack ---------------------------------------------
|
||||
recipe_path: str | None = "recipes/subtasks_vqa.yaml"
|
||||
"""Path (absolute or relative to ``src/lerobot/configs/``) to a
|
||||
``TrainingRecipe`` YAML. Defaults to the canonical Hi-Robot blend
|
||||
shipped alongside this policy. Set to ``None`` to disable recipe
|
||||
rendering and fall back to π0.5's single-task ``Task: ... Action:``
|
||||
prompt path (unannotated datasets keep working that way)."""
|
||||
|
||||
apply_chat_template: bool = False
|
||||
"""PaliGemma is *not* chat-pretrained — its tokenizer doesn't ship a
|
||||
chat template, so we don't apply one. The recipe renderer's output
|
||||
is concatenated as a plain prefix + assistant suffix instead,
|
||||
mirroring how the π0.5 paper's high-level inference samples text
|
||||
auto-regressively after the prefix."""
|
||||
|
||||
# Loss weights --------------------------------------------------------
|
||||
# Paper §IV.D uses α=10 between the flow and text terms, assuming
|
||||
# text is a rare auxiliary task. With the recipe stack the flow-only
|
||||
# `low_level` branch fires on a large share of samples, so α=10
|
||||
# swamps the LM head and collapses generation into degenerate
|
||||
# repetition. We use the milder 5:1 split here.
|
||||
text_loss_weight: float = 1.0
|
||||
"""Weight on the LM-head cross-entropy term. Set to ``0`` to disable
|
||||
text training entirely (reverts to flow-only / π0.5 behaviour)."""
|
||||
|
||||
flow_loss_weight: float = 5.0
|
||||
"""Weight on the action-expert flow-matching term. ``5.0`` — a milder
|
||||
flow:text split than the paper's α=10, since the flow-only
|
||||
``low_level`` recipe already gives the action expert frequent
|
||||
gradient. Lower it further if the LM head still underfits."""
|
||||
|
||||
# Backbone training ---------------------------------------------------
|
||||
unfreeze_lm_head: bool = True
|
||||
"""Whether to keep the PaliGemma ``lm_head`` unfrozen for fine-tuning.
|
||||
The existing ``PI05Policy`` zeroes / freezes the head on load
|
||||
because it never reads from it. Must be ``True`` for π0.5-style
|
||||
hierarchical inference."""
|
||||
|
||||
# Per-component prompt dropout (Pi0.7 §V.E) ---------------------------
|
||||
# Randomly drop non-target context messages so the LM head learns
|
||||
# to handle missing /
|
||||
# stale plan / memory at inference. Defaults to 0.0 so behaviour
|
||||
# is identical until explicitly enabled.
|
||||
plan_dropout_prob: float = 0.0
|
||||
memory_dropout_prob: float = 0.0
|
||||
subtask_dropout_prob: float = 0.0
|
||||
|
||||
# FAST discrete-action supervision — paper §III.B-C ------------------
|
||||
# When enabled, actions are *also* tokenised via the FAST tokenizer
|
||||
# ("physical-intelligence/fast") and supervised with cross-entropy
|
||||
# on the PaliGemma LM head — exactly as in the paper's pre-training
|
||||
# objective (Eq. 1 mixes FAST CE + flow MSE + subtask CE). The
|
||||
# ActionTokenizerProcessorStep is wired into the preprocessor
|
||||
# pipeline when this flag is set; the loss is computed in
|
||||
# PI052Policy.forward.
|
||||
enable_fast_action_loss: bool = True
|
||||
"""If True, tokenise actions with the FAST tokenizer and add a
|
||||
cross-entropy loss on the LM head. On by default to match the
|
||||
π0.5 paper's three-loss objective (text CE + FAST CE + flow MSE,
|
||||
§III.B-C Eq. 1). Set to False if you only want the
|
||||
post-training-style flow + text recipe."""
|
||||
|
||||
action_tokenizer_name: str = "physical-intelligence/fast"
|
||||
"""HF identifier for the FAST action tokenizer."""
|
||||
|
||||
max_action_tokens: int = 256
|
||||
"""Maximum number of FAST tokens per action chunk."""
|
||||
|
||||
fast_skip_tokens: int = 128
|
||||
"""Number of low-vocab tokens the FAST tokenizer skips to avoid
|
||||
collisions with PaliGemma's text vocabulary."""
|
||||
|
||||
fast_action_loss_weight: float = 1.0
|
||||
"""Weight on the FAST-action-token CE loss. Paper §III.C uses 1.0."""
|
||||
|
||||
auto_fit_fast_tokenizer: bool = False
|
||||
"""If True, the processor factory checks ``fast_tokenizer_cache_dir``
|
||||
for a previously-fitted tokenizer keyed on ``(dataset_repo_id,
|
||||
base_tokenizer_name, fit_samples)``. On cache miss, it loads
|
||||
``action_tokenizer_name`` as a base, samples
|
||||
``fast_tokenizer_fit_samples`` action chunks from the dataset, runs
|
||||
``.fit()``, saves the result, and uses *that* fitted path as the
|
||||
actual tokenizer. Pertsch et al. 2025 (FAST paper [64], π0.5 §III.C)
|
||||
explicitly recommend per-dataset fitting for best compression.
|
||||
|
||||
Off by default because the fit requires a separate pre-training
|
||||
pass over the dataset (~1-2 min on a medium dataset) and depends
|
||||
on the FAST tokenizer snapshot having a ``.fit()`` method. Opt in
|
||||
when you want paper-faithful compression; leave off to fall back
|
||||
on the universal ``physical-intelligence/fast`` codebook."""
|
||||
|
||||
fast_tokenizer_cache_dir: str = "~/.cache/lerobot/fast_tokenizers"
|
||||
"""Where fitted FAST tokenizers are stored. ``~`` expands."""
|
||||
|
||||
fast_tokenizer_fit_samples: int = 1024
|
||||
"""Number of action chunks to sample for the fit. The FAST paper uses
|
||||
a few thousand; 1024 is a reasonable default for medium datasets."""
|
||||
|
||||
# Knowledge insulation — paper §III.B --------------------------------
|
||||
# When enabled, gradients from the action expert's flow loss are
|
||||
# blocked from flowing back into the VLM's K/V projections. This
|
||||
# prevents the action loss from over-fitting the language backbone
|
||||
# to robot-specific features. Implemented in ``modeling_pi052`` as
|
||||
# a per-instance monkey-patch on ``paligemma_with_expert.forward``
|
||||
# that splits queries into VLM and action halves and ``.detach()``-s
|
||||
# the VLM K/V tensors used in the action-half's attention.
|
||||
knowledge_insulation: bool = False
|
||||
"""If True, route every transformer layer through the KI
|
||||
attention path that blocks action→VLM gradient flow on K/V."""
|
||||
|
||||
# Learning-rate defaults --------------------------------------------
|
||||
# pi052 inherits π0.5's openpi-validated optimizer config (peak LR
|
||||
# 2.5e-5, cosine→2.5e-6, 1k warmup, AdamW (0.9, 0.95), wd=0.01,
|
||||
# grad_clip=1.0). The only place pi052 needs to diverge from pi05
|
||||
# is the LM-head LR multiplier: pi05 has no text supervision so the
|
||||
# head doesn't get gradients; pi052 always has text supervision
|
||||
# (subtask / memory / VQA) via the recipe, and under KI the LM head
|
||||
# only sees gradients on ~30–45% of the batch (the text-CE mask
|
||||
# share of the recipe). Under aggressive cosine decay this is too
|
||||
# weak to keep the head pinned, so it drifts back toward PaliGemma's
|
||||
# pretrained ``<loc>`` first-token bias. 5x is the documented fix
|
||||
# (see ``PI05Config.lm_head_lr_scale`` docstring); the wiring is
|
||||
# already in ``PI05Policy.get_optim_params`` — it splits the LM head
|
||||
# + tied ``embed_tokens`` into their own param group while sharing
|
||||
# the same cosine lambda, so the 5x ratio is preserved across decay.
|
||||
lm_head_lr_scale: float = 5.0
|
||||
|
||||
# PaLM-style z-loss on text CE. Penalises the log-partition function
|
||||
# ``z = log Σ exp(logits)`` drifting away from zero — without it, large-
|
||||
# vocab models (PaliGemma is 257k) can let ``logsumexp`` grow unbounded
|
||||
# while CE stays low, because a uniform additive logit bias cancels in
|
||||
# softmax. PaLM appendix B / Chinchilla report z-loss is essential for
|
||||
# stable large-vocab CE; it especially helps under ``lm_head_lr_scale=
|
||||
# 5.0`` which amplifies drift risk on the LM head. ``1e-4`` is the
|
||||
# commonly cited weight; set 0 to disable entirely.
|
||||
text_ce_z_loss_weight: float = 1e-4
|
||||
|
||||
# Liger Triton kernels (rope + geglu + layer_norm) are now patched
|
||||
# unconditionally at model build time — see ``_enable_hf_kernels``
|
||||
# in ``modeling_pi052``. The patch is process-global, idempotent
|
||||
# and degrades gracefully if ``liger-kernel`` is missing. Measured
|
||||
# at -4.5% step time on H100 (bench job 22161421); peak memory
|
||||
# unchanged. ``fused_linear_cross_entropy`` ships separately via
|
||||
# ``_shifted_lin_ce`` / ``_fast_lin_ce``.
|
||||
use_hf_kernels: bool = True
|
||||
"""Deprecated. Liger HF kernels are patched unconditionally by
|
||||
``_enable_hf_kernels`` — this field is retained as a no-op for
|
||||
backward compatibility with checkpoints saved before commit
|
||||
d70c8104 (which still serialize ``use_hf_kernels: true`` into
|
||||
``config.json``). Loading those configs would otherwise raise
|
||||
``DecodingError: The fields use_hf_kernels are not valid for
|
||||
PI052Config`` (job 22164492). Remove in a future major bump."""
|
||||
|
||||
# Optimizer foreach/fused. pi052 carries these locally because the shared
|
||||
# PI05Config (kept identical to upstream main) does not define them; the
|
||||
# checkpoints we train serialize both keys into config.json, so they must
|
||||
# be valid PI052Config fields and flow into the AdamW preset below.
|
||||
optimizer_foreach: bool | None = False
|
||||
optimizer_fused: bool | None = True
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
foreach=self.optimizer_foreach,
|
||||
fused=self.optimizer_fused,
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
# Backbone needs gradients flowing through the text head when
|
||||
# we're training it. Override the π0.5 default
|
||||
# (``train_expert_only=True``) unless the user explicitly opts
|
||||
# out of text training via ``text_loss_weight=0``.
|
||||
if self.text_loss_weight > 0 and self.unfreeze_lm_head:
|
||||
self.train_expert_only = False
|
||||
@@ -1,304 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Dataset-specific FAST action tokenizer fitting.
|
||||
|
||||
The published ``physical-intelligence/fast`` tokenizer is a *universal*
|
||||
codebook fitted on a heterogeneous mix of robot datasets. Per Pertsch
|
||||
et al. 2025 (the FAST paper, [64] in the π0.5 paper) and §III.C of
|
||||
π0.5 itself, the recommended practice is to **finetune the tokenizer on
|
||||
your specific dataset's action distribution** before training the
|
||||
policy — same way one would adapt a language tokenizer to a domain
|
||||
corpus. Without this finetune step, action sequences from your robot
|
||||
may require more tokens per chunk than necessary, lowering effective
|
||||
compression and slowing convergence of the action-CE loss.
|
||||
|
||||
This module provides a single utility, :func:`fit_fast_tokenizer`,
|
||||
that does the finetune. The training entry point invokes it
|
||||
automatically when the policy's ``enable_fast_action_loss`` and
|
||||
``auto_fit_fast_tokenizer`` flags are both ``True`` and no cached
|
||||
fitted tokenizer is found at ``fast_tokenizer_cache_dir``.
|
||||
|
||||
The fitted tokenizer is saved to
|
||||
``{cache_dir}/{dataset_hash}_{base_hash}/`` so successive training
|
||||
runs over the same dataset re-use it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Marker file the cache-hit check looks for. ``ProcessorMixin.save_pretrained``
|
||||
# writes ``processor_config.json`` (NOT ``preprocessor_config.json`` —
|
||||
# that's the image / feature-extractor convention). Centralised here so
|
||||
# the cache-hit check and the rank-N readiness wait agree on the same
|
||||
# sentinel.
|
||||
_CACHE_SENTINEL = "processor_config.json"
|
||||
|
||||
|
||||
def _dataset_signature(
|
||||
dataset_repo_id: str,
|
||||
base_tokenizer_name: str,
|
||||
n_samples: int,
|
||||
chunk_size: int,
|
||||
) -> str:
|
||||
"""Deterministic short hash for naming the cache directory.
|
||||
|
||||
Keys on (dataset, base tokenizer, sample count, chunk size) so any
|
||||
of those changing re-runs the fit. ``chunk_size`` matters because
|
||||
the tokenizer is fit on chunks of that length.
|
||||
"""
|
||||
h = hashlib.sha256()
|
||||
h.update(dataset_repo_id.encode("utf-8"))
|
||||
h.update(b"\0")
|
||||
h.update(base_tokenizer_name.encode("utf-8"))
|
||||
h.update(b"\0")
|
||||
h.update(str(n_samples).encode("utf-8"))
|
||||
h.update(b"\0")
|
||||
h.update(str(chunk_size).encode("utf-8"))
|
||||
return h.hexdigest()[:16]
|
||||
|
||||
|
||||
def fit_fast_tokenizer(
|
||||
*,
|
||||
dataset_repo_id: str,
|
||||
cache_dir: str | Path,
|
||||
base_tokenizer_name: str = "physical-intelligence/fast",
|
||||
n_samples: int = 1024,
|
||||
chunk_size: int = 50,
|
||||
seed: int = 42,
|
||||
) -> str:
|
||||
"""Fit a FAST tokenizer on a LeRobot dataset's action distribution.
|
||||
|
||||
Args:
|
||||
dataset_repo_id: HF Hub repo id of the LeRobotDataset to fit on.
|
||||
cache_dir: Directory under which to save (and look up) fitted
|
||||
tokenizers. The actual save path is
|
||||
``{cache_dir}/{signature}``.
|
||||
base_tokenizer_name: HF identifier for the base FAST tokenizer
|
||||
to finetune from. ``physical-intelligence/fast`` is the
|
||||
universal one.
|
||||
n_samples: Number of action chunks to sample for the fit. The
|
||||
FAST paper uses a few thousand; ``1024`` is a good default
|
||||
for medium datasets.
|
||||
chunk_size: Length of each action chunk (matches
|
||||
``policy.chunk_size``). The FAST tokenizer is fit on
|
||||
sequences of this length.
|
||||
seed: RNG seed for sample selection.
|
||||
|
||||
Returns:
|
||||
The local path to the fitted tokenizer. Passed directly to
|
||||
``--policy.action_tokenizer_name`` for the training run.
|
||||
|
||||
Raises:
|
||||
ImportError: If the ``transformers`` library doesn't expose
|
||||
``AutoProcessor`` or the FAST tokenizer doesn't have a
|
||||
``.fit()`` method (then you're on an older FAST snapshot —
|
||||
update to the current published model).
|
||||
FileNotFoundError: If the dataset can't be loaded.
|
||||
"""
|
||||
cache_dir = Path(cache_dir)
|
||||
sig = _dataset_signature(dataset_repo_id, base_tokenizer_name, n_samples, chunk_size)
|
||||
out_dir = cache_dir / sig
|
||||
|
||||
if out_dir.exists() and (out_dir / _CACHE_SENTINEL).exists():
|
||||
logger.info(
|
||||
"FAST tokenizer cache hit: %s — re-using fitted tokenizer for "
|
||||
"dataset=%s base=%s n_samples=%d",
|
||||
out_dir, dataset_repo_id, base_tokenizer_name, n_samples,
|
||||
)
|
||||
return str(out_dir)
|
||||
|
||||
# DDP-safe fit: only the (local) main process actually fits + saves;
|
||||
# other ranks poll the cache sentinel until the leader is done.
|
||||
# Without this guard, all N ranks fit concurrently and race on
|
||||
# ``save_pretrained`` + ``AutoProcessor.from_pretrained`` (the latter
|
||||
# copies ``processing_action_tokenizer.py`` into ``HF_MODULES_CACHE``
|
||||
# and compiles a ``.pyc`` — concurrent writers occasionally produce
|
||||
# a stale / partial ``.pyc`` and the subsequent ``from .. import
|
||||
# UniversalActionProcessor`` raises ``AttributeError``.
|
||||
is_leader = (
|
||||
int(os.environ.get("RANK", "0")) == 0
|
||||
and int(os.environ.get("LOCAL_RANK", "0")) == 0
|
||||
)
|
||||
if not is_leader:
|
||||
timeout_s = 1800.0 # 30 min — covers ~1024-sample fits on cold caches
|
||||
start = time.monotonic()
|
||||
while not (out_dir / _CACHE_SENTINEL).exists():
|
||||
if time.monotonic() - start > timeout_s:
|
||||
raise RuntimeError(
|
||||
f"FAST tokenizer fit: non-leader rank timed out after "
|
||||
f"{timeout_s:.0f}s waiting for {out_dir / _CACHE_SENTINEL}. "
|
||||
"Leader rank likely crashed during the fit."
|
||||
)
|
||||
time.sleep(2.0)
|
||||
logger.info("FAST tokenizer ready (leader populated cache): %s", out_dir)
|
||||
return str(out_dir)
|
||||
|
||||
logger.info(
|
||||
"FAST tokenizer cache miss — fitting on dataset=%s "
|
||||
"base=%s n_samples=%d chunk_size=%d → %s",
|
||||
dataset_repo_id, base_tokenizer_name, n_samples, chunk_size, out_dir,
|
||||
)
|
||||
|
||||
from transformers import AutoProcessor # noqa: PLC0415
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: PLC0415
|
||||
|
||||
# Stream a single episode's worth of action chunks at a time so
|
||||
# we don't blow memory on huge datasets. Random episode +
|
||||
# random start offset gives a reasonable spread.
|
||||
#
|
||||
# Actions are read straight from the underlying HF dataset's
|
||||
# ``action`` *column* — never via ``ds[i]``. ``ds[i]`` builds a full
|
||||
# training item (delta-timestamp expansion + video decode + image
|
||||
# transforms); a single bad video frame would then throw and, since
|
||||
# the failure was swallowed at debug level, silently starve the fit
|
||||
# of every chunk. The action column carries no video, so reading it
|
||||
# directly is both faster and immune to decode errors.
|
||||
rng = np.random.default_rng(seed)
|
||||
actions_buf: list[np.ndarray] = []
|
||||
|
||||
# Resolve the dataset's data parquet shards directly, sidestepping
|
||||
# ``LeRobotDataset(repo_id, episodes=[N])`` which on v3-format
|
||||
# datasets routes through HF datasets'' split lookup and raises
|
||||
# ``ValueError: Instruction "train" corresponds to no data!`` for
|
||||
# every episode (job 22182985 looped through 13,293 skipped episodes
|
||||
# for ~2.5 h before NCCL killed it). Reading the ``action`` column
|
||||
# straight from the parquet shards is also faster: each per-episode
|
||||
# ``LeRobotDataset`` instantiation re-parses every meta file.
|
||||
from huggingface_hub import snapshot_download # noqa: PLC0415
|
||||
import pyarrow as _pa # noqa: PLC0415
|
||||
import pyarrow.parquet as _pq # noqa: PLC0415
|
||||
|
||||
snap = Path(snapshot_download(repo_id=dataset_repo_id, repo_type="dataset"))
|
||||
data_files = sorted((snap / "data").glob("chunk-*/file-*.parquet"))
|
||||
if not data_files:
|
||||
raise RuntimeError(
|
||||
f"FAST fit: no ``data/chunk-*/file-*.parquet`` shards found under {snap!s}."
|
||||
)
|
||||
|
||||
# Read just the (episode_index, action) columns once across all
|
||||
# shards. This is the same pattern used elsewhere in the codebase
|
||||
# for whole-dataset audits and stays under ~2 GB even on 32 k-episode
|
||||
# / 29 M-frame datasets because the action column is a fixed-length
|
||||
# float vector.
|
||||
tables = [_pq.read_table(f, columns=["episode_index", "action"]) for f in data_files]
|
||||
table = _pa.concat_tables(tables)
|
||||
eps = table["episode_index"].to_numpy()
|
||||
acts_col = table["action"]
|
||||
# ``action`` may be a fixed-shape ListArray or a 2-D NumericArray;
|
||||
# ``to_numpy(zero_copy_only=False)`` produces an object array of
|
||||
# 1-D NumPy actions either way, which we stack into (N, D).
|
||||
try:
|
||||
acts = np.stack(acts_col.to_numpy(zero_copy_only=False)).astype(np.float32)
|
||||
except Exception: # noqa: BLE001
|
||||
# Fallback path for nested-list types: flatten via to_pylist().
|
||||
acts = np.asarray(acts_col.to_pylist(), dtype=np.float32)
|
||||
if acts.ndim != 2:
|
||||
raise RuntimeError(
|
||||
f"FAST fit: expected ``action`` rows to be 1-D vectors; got shape {acts.shape}."
|
||||
)
|
||||
|
||||
# Episode index → slice (start, stop) into ``acts`` along axis 0.
|
||||
# ``eps`` is monotonically increasing within each parquet shard but
|
||||
# we make no assumption across shards — sort once and group.
|
||||
order = np.argsort(eps, kind="stable")
|
||||
eps_sorted = eps[order]
|
||||
boundaries = np.searchsorted(eps_sorted, np.arange(int(eps_sorted.max()) + 2))
|
||||
ep_to_slice: dict[int, tuple[int, int]] = {
|
||||
int(ep): (int(boundaries[ep]), int(boundaries[ep + 1]))
|
||||
for ep in range(len(boundaries) - 1)
|
||||
if boundaries[ep] < boundaries[ep + 1]
|
||||
}
|
||||
num_episodes = len(ep_to_slice)
|
||||
# ``acts`` is in original (un-sorted-by-episode) row order; reorder
|
||||
# so per-episode slices are contiguous.
|
||||
acts = acts[order]
|
||||
|
||||
samples_per_episode = max(1, n_samples // max(num_episodes, 1))
|
||||
collected = 0
|
||||
eps_visited = 0
|
||||
short_episodes = 0
|
||||
ep_indices = list(ep_to_slice.keys())
|
||||
for ep_idx in rng.permutation(ep_indices):
|
||||
if collected >= n_samples:
|
||||
break
|
||||
start, stop = ep_to_slice[int(ep_idx)]
|
||||
ep_actions = acts[start:stop]
|
||||
if ep_actions.shape[0] < chunk_size:
|
||||
short_episodes += 1
|
||||
continue
|
||||
starts = rng.integers(0, ep_actions.shape[0] - chunk_size + 1, size=samples_per_episode)
|
||||
for s in starts:
|
||||
actions_buf.append(ep_actions[int(s) : int(s) + chunk_size])
|
||||
collected += 1
|
||||
if collected >= n_samples:
|
||||
break
|
||||
eps_visited += 1
|
||||
|
||||
if not actions_buf:
|
||||
raise RuntimeError(
|
||||
f"FAST fit collected zero action chunks from {dataset_repo_id!r}: "
|
||||
f"all {num_episodes} episodes were shorter than chunk_size="
|
||||
f"{chunk_size} ({short_episodes} too short) or had an unreadable "
|
||||
"``action`` column. Lower ``chunk_size`` to match your episode "
|
||||
"lengths."
|
||||
)
|
||||
|
||||
actions = np.stack(actions_buf, axis=0).astype(np.float32) # (N, H, D)
|
||||
logger.info(
|
||||
"FAST fit: collected %d chunks of shape %s from %d episodes",
|
||||
actions.shape[0], actions.shape[1:], eps_visited,
|
||||
)
|
||||
|
||||
# Quantile-normalise per dimension before fitting.
|
||||
#
|
||||
# The FAST tokenizer DCT-transforms actions, scales by ``scale`` and
|
||||
# rounds to integer tokens; the integer *range* must fit the
|
||||
# codebook (vocab_size, default 1024). Raw motor units (e.g. encoder
|
||||
# ticks) blow that range up — hence "Vocab size 1024 is too small".
|
||||
# More importantly, at training time ``ActionTokenizerProcessorStep``
|
||||
# runs *after* the QUANTILES ``NormalizerProcessorStep``, so it
|
||||
# encodes normalised actions. Fitting on raw actions would mismatch
|
||||
# that space. We replicate QUANTILES normalisation here (per-dim
|
||||
# [q01, q99] → [-1, 1], clipped) so the fit and the training-time
|
||||
# encode see the same distribution.
|
||||
flat = actions.reshape(-1, actions.shape[-1])
|
||||
q01 = np.quantile(flat, 0.01, axis=0)
|
||||
q99 = np.quantile(flat, 0.99, axis=0)
|
||||
span = np.where((q99 - q01) > 1e-6, q99 - q01, 1.0)
|
||||
actions = np.clip((actions - q01) / span * 2.0 - 1.0, -1.0, 1.0).astype(np.float32)
|
||||
|
||||
base = AutoProcessor.from_pretrained(base_tokenizer_name, trust_remote_code=True)
|
||||
if not hasattr(base, "fit"):
|
||||
raise ImportError(
|
||||
f"Base FAST tokenizer {base_tokenizer_name!r} has no ``.fit()`` "
|
||||
"method — your transformers / model snapshot is too old. Update "
|
||||
"to the current ``physical-intelligence/fast`` revision."
|
||||
)
|
||||
|
||||
fitted = base.fit(actions)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
fitted.save_pretrained(str(out_dir))
|
||||
logger.info("FAST fit: saved fitted tokenizer to %s", out_dir)
|
||||
return str(out_dir)
|
||||
@@ -1,73 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PI052 inference / runtime orchestration.
|
||||
|
||||
Multi-rate runtime that mirrors the recipe-time training shape:
|
||||
|
||||
low_level_execution → LowLevelForward + DispatchAction (high Hz)
|
||||
high_level_subtask → HighLevelSubtaskFwd (~1 Hz)
|
||||
memory_update → MemoryUpdateFwd (event: subtask_change)
|
||||
user_interjection_response → UserInterjectionFwd (event: stdin)
|
||||
ask_vqa_* → AskVQAFwd (event: stdin question)
|
||||
speech tool calls → DispatchToolCalls (event: tool_call_pending)
|
||||
|
||||
The CLI ``lerobot-pi052-runtime`` builds a ``PI052Runtime`` and calls
|
||||
``run()``.
|
||||
"""
|
||||
|
||||
from .repl import StdinReader
|
||||
from .runtime import PI052Runtime
|
||||
from .runtime_state import initial_runtime_state, push_log, set_if_changed, take_event
|
||||
from .steps import (
|
||||
AskVQAFwd,
|
||||
DispatchAction,
|
||||
DispatchToolCalls,
|
||||
HighLevelSubtaskFwd,
|
||||
InferenceStep,
|
||||
LowLevelForward,
|
||||
MemoryUpdateFwd,
|
||||
UserInterjectionFwd,
|
||||
)
|
||||
from .triggers import EventTrigger, HzTrigger, Tick, TickClock, Trigger
|
||||
from .ui import make_state_panel, print_robot_lines, print_user_line
|
||||
|
||||
__all__ = [
|
||||
# runtime
|
||||
"PI052Runtime",
|
||||
"StdinReader",
|
||||
# state helpers
|
||||
"initial_runtime_state",
|
||||
"push_log",
|
||||
"set_if_changed",
|
||||
"take_event",
|
||||
# triggers
|
||||
"Trigger",
|
||||
"Tick",
|
||||
"TickClock",
|
||||
"HzTrigger",
|
||||
"EventTrigger",
|
||||
# steps
|
||||
"InferenceStep",
|
||||
"LowLevelForward",
|
||||
"DispatchAction",
|
||||
"HighLevelSubtaskFwd",
|
||||
"MemoryUpdateFwd",
|
||||
"UserInterjectionFwd",
|
||||
"AskVQAFwd",
|
||||
"DispatchToolCalls",
|
||||
# UI
|
||||
"make_state_panel",
|
||||
"print_robot_lines",
|
||||
"print_user_line",
|
||||
]
|
||||
@@ -1,105 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Stdin REPL event collector for the PI052 runtime.
|
||||
|
||||
Reads non-blocking stdin lines, classifies each one heuristically:
|
||||
|
||||
"stop" / "quit" / "exit" → state["stop"] = True
|
||||
"/action" / "/pause" → set state["mode"]
|
||||
ends with "?" → user_vqa_query event
|
||||
starts with "task:" or first line → set runtime task
|
||||
anything else → user_interjection event
|
||||
|
||||
Plugged into the runtime via ``event_collector=StdinReader().poll``.
|
||||
|
||||
Note: the shipped CLI (``lerobot-pi052-runtime``) drives stdin
|
||||
directly in its REPL / autonomous loops and does *not* wire this
|
||||
collector; it's kept as the documented embedding hook and for tests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import select
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class StdinReader:
|
||||
"""Non-blocking stdin line collector for the runtime loop."""
|
||||
|
||||
prompt: str = "> "
|
||||
_seen_first_line: bool = field(default=False, init=False)
|
||||
_prompted: bool = field(default=False, init=False)
|
||||
|
||||
def poll(self, state: dict[str, Any]) -> None:
|
||||
"""Drain pending stdin lines into runtime events."""
|
||||
# Print the input prompt once on every fresh tick if we don't
|
||||
# already have a pending line; matches the expected REPL feel.
|
||||
if not self._prompted:
|
||||
print(self.prompt, end="", flush=True)
|
||||
self._prompted = True
|
||||
|
||||
# ``select`` with timeout=0 makes this non-blocking. Only works
|
||||
# for actual TTY / pipe stdins; CI / scripted runs hit EOF.
|
||||
try:
|
||||
ready, _, _ = select.select([sys.stdin], [], [], 0)
|
||||
except (ValueError, OSError):
|
||||
return
|
||||
if not ready:
|
||||
return
|
||||
|
||||
line = sys.stdin.readline()
|
||||
if not line: # EOF
|
||||
state["stop"] = True
|
||||
return
|
||||
line = line.strip()
|
||||
self._prompted = False # we'll re-prompt next tick
|
||||
if not line:
|
||||
return
|
||||
|
||||
lower = line.lower()
|
||||
if lower in {"stop", "quit", "exit"}:
|
||||
state["stop"] = True
|
||||
return
|
||||
|
||||
# Slash commands flip the run mode. ``/pause`` stops the action
|
||||
# loop (the action steps gate on ``state["mode"]``); ``/action``
|
||||
# resumes it.
|
||||
if lower.split(" ", 1)[0] in {"/action", "/act", "/run"}:
|
||||
state["mode"] = "action"
|
||||
return
|
||||
if lower in {"/pause", "/p"}:
|
||||
state["mode"] = "paused"
|
||||
queue = state.get("action_queue")
|
||||
if hasattr(queue, "clear"):
|
||||
queue.clear()
|
||||
return
|
||||
|
||||
# First non-control line sets the task if no task is active.
|
||||
if not state.get("task"):
|
||||
task = line[5:].strip() if lower.startswith("task:") else line
|
||||
state["task"] = task
|
||||
print(f"[pi052] Task: {task}", flush=True)
|
||||
self._seen_first_line = True
|
||||
return
|
||||
|
||||
# Question → VQA; statement → interjection.
|
||||
if lower.endswith("?"):
|
||||
state["recent_vqa_query"] = line
|
||||
state.setdefault("events_this_tick", []).append("user_vqa_query")
|
||||
else:
|
||||
state["recent_interjection"] = line
|
||||
state.setdefault("events_this_tick", []).append("user_interjection")
|
||||
@@ -1,205 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PI052 runtime loop.
|
||||
|
||||
Threads the multi-rate inference pipeline together with a stdin REPL
|
||||
event collector, drives ticks through :class:`TickClock`, and prints
|
||||
state-change updates to the user.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable
|
||||
|
||||
from .runtime_state import initial_runtime_state, push_log
|
||||
from .steps import (
|
||||
AskVQAFwd,
|
||||
DispatchAction,
|
||||
DispatchToolCalls,
|
||||
HighLevelSubtaskFwd,
|
||||
InferenceStep,
|
||||
LowLevelForward,
|
||||
MemoryUpdateFwd,
|
||||
)
|
||||
from .triggers import EventTrigger, HzTrigger, TickClock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PI052Runtime:
|
||||
"""Compose the inference pipeline and drive it tick-by-tick."""
|
||||
|
||||
policy: Any
|
||||
tools: dict[str, Any] = field(default_factory=dict)
|
||||
"""Name → tool-instance dict, e.g. ``{"say": SayTool(...)}``. Read
|
||||
from :func:`lerobot.tools.get_tools(meta)` when wiring the
|
||||
runtime."""
|
||||
observation_provider: Callable[[], dict | None] | None = None
|
||||
"""Closure returning the current preprocessed observation batch.
|
||||
``None`` for dry-run / language-only sessions."""
|
||||
robot_executor: Callable[[Any], None] | None = None
|
||||
"""Closure that takes one action chunk and forwards it to the
|
||||
robot. ``None`` for dry-run."""
|
||||
event_collector: Callable[[dict], None] | None = None
|
||||
"""Per-tick hook that polls external sources (stdin, network) and
|
||||
appends event names to ``state["events_this_tick"]``."""
|
||||
chunk_hz: float = 4.0
|
||||
ctrl_hz: float = 50.0
|
||||
high_level_hz: float = 1.0
|
||||
max_rate_hz: float = 50.0
|
||||
|
||||
pipeline: list[InferenceStep] = field(init=False)
|
||||
state: dict[str, Any] = field(init=False)
|
||||
_stop: bool = field(default=False, init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Subtask + memory + VQA configuration. Pipeline:
|
||||
#
|
||||
# HighLevelSubtaskFwd → generate the next subtask via the LM
|
||||
# head at ~``high_level_hz``; writes
|
||||
# ``current_subtask`` and emits
|
||||
# ``subtask_change`` on a transition.
|
||||
# MemoryUpdateFwd → on ``subtask_change``, refresh
|
||||
# ``current_memory`` from the
|
||||
# ``memory_update`` head.
|
||||
# AskVQAFwd → answer camera-grounded stdin questions.
|
||||
# LowLevelForward → action chunk conditioned on the
|
||||
# generated ``current_subtask``.
|
||||
# DispatchAction → drain the chunk to the robot.
|
||||
# DispatchToolCalls → fire any pending tool calls.
|
||||
#
|
||||
# Order matters: ``HighLevelSubtaskFwd`` must run before
|
||||
# ``MemoryUpdateFwd`` so the event is visible the same tick, and
|
||||
# both must run before ``LowLevelForward`` (which is gated on
|
||||
# "action queue empty") so the chunk consumes the freshest
|
||||
# subtask. ``UserInterjectionFwd`` is still importable but
|
||||
# disabled until plan generation is wired in.
|
||||
self.pipeline = [
|
||||
HighLevelSubtaskFwd(
|
||||
trigger=HzTrigger(self.high_level_hz),
|
||||
policy=self.policy,
|
||||
observation_provider=self.observation_provider,
|
||||
),
|
||||
# Listens for the ``subtask_change`` event raised by
|
||||
# ``HighLevelSubtaskFwd`` and refreshes ``current_memory``.
|
||||
MemoryUpdateFwd(
|
||||
trigger=EventTrigger("subtask_change"),
|
||||
policy=self.policy,
|
||||
observation_provider=self.observation_provider,
|
||||
),
|
||||
AskVQAFwd(
|
||||
policy=self.policy,
|
||||
observation_provider=self.observation_provider,
|
||||
),
|
||||
LowLevelForward(
|
||||
trigger=HzTrigger(self.chunk_hz),
|
||||
policy=self.policy,
|
||||
observation_provider=self.observation_provider,
|
||||
),
|
||||
DispatchAction(
|
||||
trigger=HzTrigger(self.ctrl_hz),
|
||||
robot_executor=self.robot_executor,
|
||||
),
|
||||
DispatchToolCalls(tools=self.tools),
|
||||
]
|
||||
self.state = initial_runtime_state()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def set_task(self, task: str) -> None:
|
||||
"""Set or replace the active task. Logged for the REPL."""
|
||||
self.state["task"] = task
|
||||
push_log(self.state, f"Task: {task}")
|
||||
|
||||
def stop(self) -> None:
|
||||
self._stop = True
|
||||
|
||||
def run(self, *, max_ticks: int | None = None) -> None:
|
||||
"""Main loop. Returns when ``stop()`` is called or after
|
||||
``max_ticks`` ticks (useful for tests / dry-run)."""
|
||||
clock = TickClock(max_rate_hz=self.max_rate_hz)
|
||||
while not self._stop:
|
||||
tick = clock.advance()
|
||||
self.state["_tick"] = tick
|
||||
self.state["events_this_tick"] = []
|
||||
self.state["log_lines"] = []
|
||||
|
||||
if self.event_collector is not None:
|
||||
self.event_collector(self.state)
|
||||
if self.state.get("stop"):
|
||||
self._stop = True
|
||||
break
|
||||
|
||||
for step in self.pipeline:
|
||||
self.state = step(self.state)
|
||||
|
||||
self._flush_logs()
|
||||
if max_ticks is not None and tick.index >= max_ticks:
|
||||
break
|
||||
|
||||
self._on_shutdown()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# REPL helper: drive one full pipeline pass and return its logs
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def step_once(self) -> list[str]:
|
||||
"""Run one tick of the pipeline and return the log lines.
|
||||
|
||||
Used by the interactive REPL: instead of a background thread,
|
||||
the CLI drives ticks synchronously after each user input. Logs
|
||||
are returned (not printed) so the caller can route them into
|
||||
the rich-Live chat scrollback.
|
||||
"""
|
||||
from .triggers import Tick # noqa: PLC0415
|
||||
|
||||
# Synthesize a tick. We don't need the real wall-clock pacing
|
||||
# here — the REPL drives the runtime, not vice versa — but
|
||||
# ``HzTrigger`` uses ``tick.monotonic_seconds`` to gate, so we
|
||||
# bump it generously so every Hz-triggered step considers
|
||||
# itself due.
|
||||
import time as _time # noqa: PLC0415
|
||||
|
||||
prev_index = self.state.get("_tick").index if isinstance(self.state.get("_tick"), Tick) else 0
|
||||
self.state["_tick"] = Tick(index=prev_index + 1, monotonic_seconds=_time.monotonic())
|
||||
self.state["log_lines"] = []
|
||||
# ``events_this_tick`` is set up by the caller before
|
||||
# ``step_once`` (the REPL pushes user-driven events first).
|
||||
self.state.setdefault("events_this_tick", [])
|
||||
|
||||
for step in self.pipeline:
|
||||
self.state = step(self.state)
|
||||
|
||||
return list(self.state.get("log_lines") or [])
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# I/O
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _flush_logs(self) -> None:
|
||||
for line in self.state.get("log_lines") or []:
|
||||
print(f"[pi052] {line}", flush=True)
|
||||
|
||||
def _on_shutdown(self) -> None:
|
||||
# Drain any queued action chunks safely.
|
||||
queue = self.state.get("action_queue")
|
||||
if isinstance(queue, deque):
|
||||
queue.clear()
|
||||
print("[pi052] runtime stopped", flush=True)
|
||||
@@ -1,95 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Runtime state passed between inference steps each tick.
|
||||
|
||||
The runtime threads a single dict through the pipeline; this module
|
||||
documents the shape and provides factories. We use a plain ``dict``
|
||||
rather than a frozen dataclass because steps freely add and remove
|
||||
keys (``events_this_tick``, ``messages_pending``, ``tool_calls_pending``,
|
||||
…) and dataclass field churn would just get in the way.
|
||||
|
||||
Stable keys (read by multiple steps):
|
||||
|
||||
task str the current top-level task
|
||||
current_plan str | None latest plan emitted by the planner
|
||||
current_subtask str | None latest subtask the policy is executing
|
||||
current_memory str | None latest compressed memory
|
||||
recent_interjection str | None most recent user interjection text (consumed)
|
||||
|
||||
action_queue collections.deque[Tensor] pending action chunks
|
||||
tool_calls_pending list[dict] parsed but not-yet-dispatched tool calls
|
||||
|
||||
events_this_tick list[str] triggers consumed this tick
|
||||
_tick Tick current tick (set by the loop)
|
||||
|
||||
mode str "action" (run the robot) | "paused"
|
||||
(action loop stopped — robot holds)
|
||||
|
||||
log_lines list[str] human-readable status lines printed each tick
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import Any
|
||||
|
||||
|
||||
def initial_runtime_state(task: str | None = None) -> dict[str, Any]:
|
||||
"""Build a fresh runtime state dict with sensible defaults."""
|
||||
return {
|
||||
"task": task,
|
||||
"current_plan": None,
|
||||
"current_subtask": None,
|
||||
"current_memory": None,
|
||||
"recent_interjection": None,
|
||||
"action_queue": deque(),
|
||||
"tool_calls_pending": [],
|
||||
"events_this_tick": [],
|
||||
"log_lines": [],
|
||||
"mode": "action",
|
||||
"stop": False,
|
||||
}
|
||||
|
||||
|
||||
def take_event(state: dict[str, Any], event_name: str) -> bool:
|
||||
"""Pop ``event_name`` from ``events_this_tick`` if present.
|
||||
|
||||
Steps that consume an event call this so the same event doesn't
|
||||
re-fire on a sibling step within the same tick.
|
||||
"""
|
||||
events: list[str] = state.get("events_this_tick") or []
|
||||
if event_name in events:
|
||||
events.remove(event_name)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def push_log(state: dict[str, Any], line: str) -> None:
|
||||
"""Append ``line`` to the per-tick log buffer; the runtime prints
|
||||
it at the end of the tick."""
|
||||
state.setdefault("log_lines", []).append(line)
|
||||
|
||||
|
||||
def set_if_changed(state: dict[str, Any], key: str, value: Any, label: str | None = None) -> bool:
|
||||
"""Update ``state[key]`` and log a diff line if the value changed.
|
||||
|
||||
Returns ``True`` if the value actually changed.
|
||||
"""
|
||||
prev = state.get(key)
|
||||
if prev == value:
|
||||
return False
|
||||
state[key] = value
|
||||
if label is not None:
|
||||
push_log(state, f" {label}: {value}")
|
||||
return True
|
||||
@@ -1,955 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference steps for the PI052 multi-rate runtime.
|
||||
|
||||
Each step is a tiny class with a ``trigger`` and an ``__call__(state)``;
|
||||
the runtime applies them in order each tick. When a step's trigger
|
||||
doesn't fire, the step is a no-op and the runtime moves on.
|
||||
|
||||
Stream-to-step mapping mirrors the ``subtasks_vqa.yaml`` recipe:
|
||||
|
||||
* ``LowLevelForward`` — calls ``policy.select_action`` for the
|
||||
action chunk; trained by
|
||||
``low_level_execution``
|
||||
* ``EnqueueChunk`` — pushes the chunk to ``action_queue``
|
||||
* ``DispatchAction`` — pops one action per control tick and
|
||||
forwards to the robot
|
||||
* ``HighLevelSubtaskFwd`` — calls ``policy.select_message`` for the
|
||||
next subtask; trained by
|
||||
``high_level_subtask``
|
||||
* ``MemoryUpdateFwd`` — fires on subtask boundary; trained by
|
||||
``memory_update``
|
||||
* ``UserInterjectionFwd`` — fires on stdin interjection; trained by
|
||||
``user_interjection_response``
|
||||
* ``AskVQAFwd`` — fires on stdin question; trained by
|
||||
``ask_vqa_*``
|
||||
* ``DispatchToolCalls`` — pops ``tool_calls_pending`` and calls
|
||||
the matching ``Tool`` instance
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from .runtime_state import push_log, set_if_changed, take_event
|
||||
from .triggers import EventTrigger, HzTrigger, Trigger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step base + runner
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceStep:
|
||||
"""A trigger-gated callable. Subclasses override :meth:`run`."""
|
||||
|
||||
trigger: Trigger
|
||||
|
||||
def __call__(self, state: dict[str, Any]) -> dict[str, Any]:
|
||||
if not self.trigger.should_fire(state["_tick"], state):
|
||||
return state
|
||||
return self.run(state) or state
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None: # pragma: no cover
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Low-level (action) path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class LowLevelForward(InferenceStep):
|
||||
"""Run the policy's action head and produce one action chunk."""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
"""Callable ``() -> dict``: returns the current observation batch
|
||||
(already preprocessed). Typically wraps the robot's camera /
|
||||
proprio reads. ``None`` in dry-run mode → step skips."""
|
||||
|
||||
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=4.0))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if self.policy is None or self.observation_provider is None:
|
||||
return None
|
||||
# ``/vlm`` mode pauses the whole action loop so the robot holds
|
||||
# position while the operator probes the VLM with VQA.
|
||||
if state.get("mode", "action") != "action":
|
||||
return None
|
||||
if not state.get("task"):
|
||||
return None
|
||||
|
||||
# PI052 produces *action chunks* (typically 50 steps via
|
||||
# flow-matching). Every step gets dispatched to the robot;
|
||||
# popping one per dispatch tick is essentially free. Only
|
||||
# generate a new chunk once the previous one has fully
|
||||
# drained — this is the canonical "sense → think → act"
|
||||
# loop. Refreshing while a chunk is still queued causes the
|
||||
# new chunk to "telescope" past the old one (planned from an
|
||||
# observation that's already 25+ steps stale by the time it
|
||||
# starts dispatching).
|
||||
queue = state.setdefault("action_queue", [])
|
||||
if len(queue) > 0:
|
||||
return None
|
||||
|
||||
observation = self.observation_provider()
|
||||
if observation is None:
|
||||
return None
|
||||
|
||||
# The action expert is conditioned on the SUBTASK generated by
|
||||
# the high-level loop (``HighLevelSubtaskFwd`` runs earlier in
|
||||
# the pipeline and writes ``current_subtask``). Matches the
|
||||
# training-time ``low_level_execution`` recipe — ``user(${subtask})``.
|
||||
# Falls back to the task string only on the very first frame,
|
||||
# before the high-level loop has produced a subtask.
|
||||
subtask = state.get("current_subtask") or state.get("task") or ""
|
||||
ctx = [{"role": "user", "content": subtask}]
|
||||
# ``add_generation_prompt=False`` to match the training-time
|
||||
# prefix shape: at training the action expert sees the rendered
|
||||
# user turn ending at ``<|im_end|>`` (no trailing
|
||||
# ``<|im_start|>assistant\n``). Passing True here would append
|
||||
# extra role-marker tokens the action expert never saw during
|
||||
# training.
|
||||
text_batch = _build_text_batch(self.policy, ctx, add_generation_prompt=False)
|
||||
from lerobot.utils.constants import ( # noqa: PLC0415
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
)
|
||||
|
||||
observation = dict(observation)
|
||||
observation[OBS_LANGUAGE_TOKENS] = text_batch["lang_tokens"]
|
||||
observation[OBS_LANGUAGE_ATTENTION_MASK] = text_batch["lang_masks"]
|
||||
|
||||
try:
|
||||
# ``predict_action_chunk`` returns the *full* chunk shape
|
||||
# ``(batch, n_action_steps, action_dim)``. Enqueue every
|
||||
# step so DispatchAction at ctrl_hz can drain them
|
||||
# smoothly until the next refresh.
|
||||
chunk = self.policy.predict_action_chunk(observation)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning(
|
||||
"predict_action_chunk failed: %s",
|
||||
exc,
|
||||
exc_info=logger.isEnabledFor(logging.DEBUG),
|
||||
)
|
||||
push_log(
|
||||
state,
|
||||
f" [warn] predict_action_chunk failed: "
|
||||
f"{type(exc).__name__}: {exc}",
|
||||
)
|
||||
return None
|
||||
|
||||
# ``chunk`` shape: ``(batch, n_action_steps, action_dim)``. Push
|
||||
# each step as a ``(1, action_dim)`` tensor so the existing
|
||||
# action executor's batch-squeeze logic works unchanged.
|
||||
if chunk.ndim == 3:
|
||||
chunk_iter = chunk[0] # ``(n_action_steps, action_dim)``
|
||||
elif chunk.ndim == 2:
|
||||
chunk_iter = chunk
|
||||
else:
|
||||
chunk_iter = chunk.unsqueeze(0)
|
||||
|
||||
for step in chunk_iter:
|
||||
queue.append(step.unsqueeze(0))
|
||||
state["last_chunk_size"] = int(chunk_iter.shape[0])
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DispatchAction(InferenceStep):
|
||||
"""Pop one action per tick and hand it to the robot.
|
||||
|
||||
In dry-run mode (``robot_executor=None``) the step still pops the
|
||||
queue so it doesn't grow unbounded — the popped tensor is logged
|
||||
instead of executed.
|
||||
|
||||
Wall-clock catch-up: the action queue represents an open-loop
|
||||
trajectory at a fixed step rate (``trigger.hz`` ≈ ``ctrl_hz``).
|
||||
When the main loop stalls — e.g. an LLM call for the high-level
|
||||
subtask blocks for ~2 s on MPS — the dispatch trigger fires only
|
||||
once over that whole interval. Naively popping a single entry per
|
||||
fire makes the robot lag further and further behind the planned
|
||||
timeline, and a 50-step chunk would take ~125 s to drain instead
|
||||
of ~1.7 s. Track real elapsed time between dispatches and pop
|
||||
``round(elapsed * hz)`` entries, sending the most recent one. The
|
||||
skipped intermediate joint targets are stale anyway — the dynamixel
|
||||
will smooth toward the latest goal position.
|
||||
"""
|
||||
|
||||
robot_executor: Any = None
|
||||
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=50.0))
|
||||
_last_dispatch_t: float | None = field(default=None, init=False)
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
import time as _time # noqa: PLC0415
|
||||
|
||||
# ``/vlm`` mode pauses dispatch — the robot holds its last
|
||||
# commanded position while the operator runs VQA.
|
||||
if state.get("mode", "action") != "action":
|
||||
self._last_dispatch_t = None
|
||||
return None
|
||||
|
||||
queue = state.get("action_queue")
|
||||
if not queue:
|
||||
# Reset wall-clock anchor when the queue is empty so the
|
||||
# next chunk doesn't see a huge fake "elapsed" window.
|
||||
self._last_dispatch_t = None
|
||||
return None
|
||||
|
||||
now = _time.monotonic()
|
||||
hz = getattr(self.trigger, "hz", 30.0)
|
||||
if self._last_dispatch_t is None or hz <= 0:
|
||||
n_to_pop = 1
|
||||
else:
|
||||
elapsed = now - self._last_dispatch_t
|
||||
# ``max(1, ...)`` so we always pop at least one when the
|
||||
# trigger fires; ``min(len(queue), ...)`` so we don't run
|
||||
# off the end of the chunk.
|
||||
n_to_pop = max(1, min(len(queue), int(round(elapsed * hz))))
|
||||
self._last_dispatch_t = now
|
||||
|
||||
# Drain ``n_to_pop`` stale entries, keep only the latest as the
|
||||
# action actually sent. The intermediate joint targets would
|
||||
# all be ~10–30 ms apart in chunk time — the robot can't track
|
||||
# them individually anyway when the host loop is slow.
|
||||
latest = None
|
||||
for _ in range(n_to_pop):
|
||||
if not queue:
|
||||
break
|
||||
latest = queue.popleft() if hasattr(queue, "popleft") else queue.pop(0)
|
||||
state["actions_dispatched"] = state.get("actions_dispatched", 0) + 1
|
||||
|
||||
if latest is not None and self.robot_executor is not None:
|
||||
self.robot_executor(latest)
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# High-level (text) paths — all use policy.select_message
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_LOC_TOKENIZER_CACHE: dict[str, Any] = {}
|
||||
|
||||
|
||||
def _get_loc_tokenizer(tok_name: str, auto_tokenizer_cls: Any, register_loc_fn: Any) -> Any:
|
||||
"""Return a loc-token-registered tokenizer, loading from disk only once.
|
||||
|
||||
``AutoTokenizer.from_pretrained`` + loc-token registration is expensive and
|
||||
the result is immutable, so cache per ``tok_name``.
|
||||
"""
|
||||
tokenizer = _LOC_TOKENIZER_CACHE.get(tok_name)
|
||||
if tokenizer is None:
|
||||
tokenizer = register_loc_fn(auto_tokenizer_cls.from_pretrained(tok_name))
|
||||
_LOC_TOKENIZER_CACHE[tok_name] = tokenizer
|
||||
return tokenizer
|
||||
|
||||
|
||||
def _build_text_batch(
|
||||
policy: Any,
|
||||
prompt_messages: list[dict[str, Any]],
|
||||
*,
|
||||
add_generation_prompt: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""Tokenize chat messages into the batch ``select_message`` expects.
|
||||
|
||||
PI052's backbone (PaliGemma) ships no chat template, so we train on
|
||||
a plain role-prefixed concatenation built by
|
||||
``PI052TextTokenizerStep``. We reuse that exact formatter so the
|
||||
inference prefix matches training; ``add_generation_prompt`` appends
|
||||
the bare ``Assistant: `` header the LM head continues from.
|
||||
"""
|
||||
import torch # noqa: PLC0415
|
||||
from transformers import AutoTokenizer # noqa: PLC0415
|
||||
|
||||
from lerobot.policies.pi052.text_processor_pi052 import ( # noqa: PLC0415
|
||||
_flatten_say_tool_calls,
|
||||
_format_messages,
|
||||
_strip_blocks,
|
||||
register_paligemma_loc_tokens,
|
||||
)
|
||||
|
||||
tok_name = (
|
||||
getattr(policy.config, "tokenizer_name", None) or "google/paligemma-3b-pt-224"
|
||||
)
|
||||
# Register PaliGemma's <locDDDD> tokens so inference encoding /
|
||||
# decoding sees them as single vocab ids — must match training.
|
||||
# The tokenizer is read-only after registration, so cache it: rebuilding it
|
||||
# from disk on every call dominated eval runtime (this runs twice per env
|
||||
# per replan — subtask gen + action prompt).
|
||||
tokenizer = _get_loc_tokenizer(tok_name, AutoTokenizer, register_paligemma_loc_tokens)
|
||||
|
||||
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in prompt_messages]
|
||||
prompt, _spans = _format_messages(messages)
|
||||
if add_generation_prompt:
|
||||
prompt = prompt + "Assistant: "
|
||||
|
||||
encoded = tokenizer(prompt, return_tensors="pt")
|
||||
ids = encoded["input_ids"]
|
||||
attn = encoded.get("attention_mask")
|
||||
if attn is None and tokenizer.pad_token_id is not None:
|
||||
attn = ids != tokenizer.pad_token_id
|
||||
if attn is not None and hasattr(attn, "dtype") and attn.dtype != torch.bool:
|
||||
attn = attn.bool()
|
||||
|
||||
# Move tokens onto the policy's device — otherwise prefix embedding
|
||||
# raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA
|
||||
# model), which the caller's broad except would swallow silently.
|
||||
device = getattr(getattr(policy, "config", None), "device", None)
|
||||
if device is not None:
|
||||
try:
|
||||
ids = ids.to(device)
|
||||
if attn is not None and hasattr(attn, "to"):
|
||||
attn = attn.to(device)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("could not move pi052 lang tokens to %s: %s", device, exc)
|
||||
return {"lang_tokens": ids, "lang_masks": attn, "tokenizer": tokenizer}
|
||||
|
||||
|
||||
def _strip_recipe_keys(m: dict[str, Any]) -> dict[str, Any]:
|
||||
new = dict(m)
|
||||
new.pop("stream", None)
|
||||
new.pop("target", None)
|
||||
return new
|
||||
|
||||
|
||||
@dataclass
|
||||
class HighLevelSubtaskFwd(InferenceStep):
|
||||
"""At ~1 Hz, ask the policy for the next subtask.
|
||||
|
||||
Mirrors the ``high_level_subtask`` recipe layout exactly:
|
||||
|
||||
user: "${task}\\nPlan: ${plan}\\nMemory: ${memory}"
|
||||
user: "Current subtask: ${subtask}" (if subtask present)
|
||||
↓ generate ↓
|
||||
assistant: <next subtask>
|
||||
"""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
"""Same shape as ``LowLevelForward.observation_provider``. When
|
||||
set, the resulting observation is merged into ``select_message``'s
|
||||
batch so text generation runs against real video + state."""
|
||||
|
||||
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=1.0))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if self.policy is None or not state.get("task"):
|
||||
return None
|
||||
# ``/vlm`` mode pauses subtask generation along with the rest of
|
||||
# the action loop.
|
||||
if state.get("mode", "action") != "action":
|
||||
return None
|
||||
# Gate to chunk boundaries: only generate a fresh subtask when
|
||||
# the action queue is empty (i.e. right before LowLevelForward
|
||||
# refreshes the chunk). ``select_message`` takes ~2 s on MPS,
|
||||
# and running it every loop iteration starves DispatchAction
|
||||
# at ctrl_hz=30 — the queue drains at ~0.4 actions/sec instead
|
||||
# of 30/sec and the robot barely moves. Tying it to the same
|
||||
# "queue empty" condition as the chunk refresh produces a
|
||||
# clean sense → think → act cycle.
|
||||
#
|
||||
# Rearm the trigger when skipping so a low-hz schedule
|
||||
# (e.g. ``--high_level_hz=0.2`` = once per 5 s) doesn't lose
|
||||
# the slot: the trigger fires once on the timer but the brief
|
||||
# queue-empty window almost never coincides, so without rearm
|
||||
# HL would effectively never run.
|
||||
queue = state.get("action_queue") or []
|
||||
if len(queue) > 0:
|
||||
if hasattr(self.trigger, "rearm"):
|
||||
self.trigger.rearm()
|
||||
return None
|
||||
# Per-chunk-boundary throttle: at each "queue empty" moment we
|
||||
# increment a counter; subtask gen only fires once the counter
|
||||
# reaches ``subtask_chunks_per_gen``. Lets the operator run e.g.
|
||||
# 5 action chunks per subtask-gen so the LM head doesn't churn
|
||||
# every 1.7 s (a fresh subtask while the previous one is still
|
||||
# being executed is wasted compute *and* causes the action
|
||||
# expert's flow trajectory to be re-planned mid-grasp).
|
||||
chunks_per_gen = max(1, int(state.get("subtask_chunks_per_gen", 1) or 1))
|
||||
# Initialise so the first chunk boundary fires immediately
|
||||
# (counter starts at chunks_per_gen, decrements per skip,
|
||||
# generates and resets when it hits 0).
|
||||
if "_hl_chunks_until_gen" not in state:
|
||||
state["_hl_chunks_until_gen"] = 0
|
||||
if state["_hl_chunks_until_gen"] > 0:
|
||||
state["_hl_chunks_until_gen"] -= 1
|
||||
if hasattr(self.trigger, "rearm"):
|
||||
self.trigger.rearm()
|
||||
return None
|
||||
state["_hl_chunks_until_gen"] = chunks_per_gen - 1
|
||||
ctx = _msgs_for_subtask(state)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
# Default: greedy argmax, no min_new_tokens, no special-token
|
||||
# suppression — matches training. Operator can override via
|
||||
# ``--text_min_new_tokens=N --text_temperature=T --text_top_p=P``
|
||||
# on the CLI; useful for under-trained checkpoints whose LM
|
||||
# head still favours EOS at position 0 (pre-trained chat
|
||||
# backbone's short-turn prior hasn't been fully overridden
|
||||
# by the fine-tuning supervision yet).
|
||||
msg = _generate_with_policy(
|
||||
self.policy,
|
||||
ctx,
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="subtask gen",
|
||||
min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0),
|
||||
temperature=float(state.get("text_gen_temperature") or 0.0),
|
||||
top_p=float(state.get("text_gen_top_p") or 1.0),
|
||||
# Subtasks never legitimately contain PaliGemma ``<loc>``
|
||||
# tokens — suppress them so a checkpoint whose LM head
|
||||
# has drifted toward the pretrained loc-prior falls back
|
||||
# to its (still-correct) text mass.
|
||||
suppress_loc_tokens=True,
|
||||
)
|
||||
# Diagnostics: surface what the model is *actually* producing
|
||||
# at chunk boundaries, even when the output gets rejected or
|
||||
# repeats. Memorisation collapse looks like "same accepted
|
||||
# subtask N times in a row" or "gibberish_count rising while
|
||||
# current_subtask is stuck". The state panel renders these.
|
||||
state["last_subtask_raw"] = msg or ""
|
||||
# Persistent empty completion is its own failure mode (model
|
||||
# immediately EOS-es from the chat-template generation
|
||||
# prompt) — surface it once every N occurrences so the
|
||||
# operator can distinguish "generation failing silently"
|
||||
# from "generating fine but filter rejecting".
|
||||
if not msg:
|
||||
empties = state.get("subtask_empty_count", 0) + 1
|
||||
state["subtask_empty_count"] = empties
|
||||
if empties == 1 or empties % 5 == 0:
|
||||
debug = getattr(self.policy, "_last_select_message_debug", "") or ""
|
||||
if debug:
|
||||
push_log(
|
||||
state,
|
||||
f" [info] subtask gen empty (×{empties}); {debug}",
|
||||
)
|
||||
else:
|
||||
push_log(
|
||||
state,
|
||||
f" [info] subtask gen returned empty (×{empties}) — "
|
||||
"no tokens generated (head EOS-ing before any "
|
||||
"non-special token).",
|
||||
)
|
||||
if msg and _looks_like_gibberish(msg):
|
||||
# Bump a counter so the operator can see the model is
|
||||
# struggling without spamming the log every tick. A first
|
||||
# rejection still logs once so the failure is visible.
|
||||
count = state.get("subtask_gibberish_count", 0) + 1
|
||||
state["subtask_gibberish_count"] = count
|
||||
if count == 1 or count % 30 == 0:
|
||||
push_log(
|
||||
state,
|
||||
f" [info] subtask gen rejected (gibberish ×{count}): {msg[:60]!r}",
|
||||
)
|
||||
return None
|
||||
if msg:
|
||||
prev_subtask = state.get("current_subtask")
|
||||
changed = set_if_changed(state, "current_subtask", msg, label="subtask")
|
||||
if changed:
|
||||
# Stash the just-completed subtask so ``MemoryUpdateFwd``
|
||||
# can drop it into its prompt as ``Completed subtask:``
|
||||
# — the recipe binds ``completed_subtask`` to
|
||||
# ``nth_prev(style=subtask, offset=1)``, i.e. the subtask
|
||||
# that was active *before* the change.
|
||||
if prev_subtask:
|
||||
state["prior_subtask"] = prev_subtask
|
||||
# Subtask change is a downstream trigger.
|
||||
state.setdefault("events_this_tick", []).append("subtask_change")
|
||||
state["subtask_repeat_count"] = 0
|
||||
else:
|
||||
# Same accepted string regenerated — memorisation tell.
|
||||
# Once this counter climbs past a few, you're seeing
|
||||
# the model unable to move past the current subtask
|
||||
# despite the chunk having drained (visual scene may
|
||||
# have changed but the LM is replaying training
|
||||
# tokens).
|
||||
state["subtask_repeat_count"] = (
|
||||
state.get("subtask_repeat_count", 0) + 1
|
||||
)
|
||||
# Silently skip empty completions — common when the model
|
||||
# warms up or generates only EOS; logging it every tick at
|
||||
# ctrl_hz is just noise.
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryUpdateFwd(InferenceStep):
|
||||
"""On subtask boundary, refresh the compressed memory.
|
||||
|
||||
Mirrors the ``memory_update`` recipe layout exactly:
|
||||
|
||||
user: "${task}"
|
||||
assistant: "Previous memory: ${prior_memory}" (if prior memory)
|
||||
user: "Completed subtask: ${completed_subtask}" (if subtask)
|
||||
↓ generate ↓
|
||||
assistant: <new memory>
|
||||
"""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
trigger: Trigger = field(default_factory=lambda: EventTrigger("subtask_change"))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
# Don't consume the event — multiple steps may want to react.
|
||||
if self.policy is None:
|
||||
return None
|
||||
ctx = _msgs_for_memory(state)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
new_memory = _generate_with_policy(
|
||||
self.policy,
|
||||
ctx,
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="memory gen",
|
||||
suppress_loc_tokens=True,
|
||||
)
|
||||
state["last_memory_raw"] = new_memory or ""
|
||||
if new_memory and _looks_like_gibberish(new_memory):
|
||||
count = state.get("memory_gibberish_count", 0) + 1
|
||||
state["memory_gibberish_count"] = count
|
||||
push_log(
|
||||
state,
|
||||
f" [info] memory gen rejected (gibberish ×{count}): {new_memory[:60]!r}",
|
||||
)
|
||||
return None
|
||||
if new_memory:
|
||||
set_if_changed(state, "current_memory", new_memory, label="memory")
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserInterjectionFwd(InferenceStep):
|
||||
"""On stdin interjection, refresh the plan + emit a paired ``say``.
|
||||
|
||||
Mirrors the ``user_interjection_response`` recipe layout exactly:
|
||||
|
||||
user: "${task}"
|
||||
assistant: "Previous plan:\\n${prior_plan}" (if prior plan)
|
||||
user: "${interjection}" (the new utterance)
|
||||
↓ generate ↓
|
||||
assistant: <plan + <say>...</say>>
|
||||
"""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
trigger: Trigger = field(default_factory=lambda: EventTrigger("user_interjection"))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if self.policy is None or not take_event(state, "user_interjection"):
|
||||
return None
|
||||
ctx = _msgs_for_interjection(state)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
out = _generate_with_policy(
|
||||
self.policy,
|
||||
ctx,
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="plan/say gen",
|
||||
suppress_loc_tokens=True,
|
||||
)
|
||||
if not out:
|
||||
# Don't log every empty completion — happens repeatedly on
|
||||
# MPS during warm-up and floods the panel. The user can
|
||||
# re-trigger by typing again.
|
||||
return None
|
||||
if _looks_like_gibberish(out):
|
||||
count = state.get("plan_gibberish_count", 0) + 1
|
||||
state["plan_gibberish_count"] = count
|
||||
push_log(
|
||||
state,
|
||||
f" [info] plan/say gen rejected (gibberish ×{count}): {out[:60]!r}",
|
||||
)
|
||||
return None
|
||||
# Heuristic split: model is trained to emit one assistant turn
|
||||
# carrying both plan text AND a `say` tool call. Look for a
|
||||
# "<say>...</say>" or "say(...)" marker; fall back to whole
|
||||
# text → plan, no speech.
|
||||
plan_text, speech_text = _split_plan_and_say(out)
|
||||
if plan_text and _looks_like_gibberish(plan_text):
|
||||
plan_text = ""
|
||||
if plan_text:
|
||||
set_if_changed(state, "current_plan", plan_text, label="plan")
|
||||
if speech_text:
|
||||
push_log(state, f" speech: {speech_text}")
|
||||
state.setdefault("tool_calls_pending", []).append(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {"name": "say", "arguments": {"text": speech_text}},
|
||||
}
|
||||
)
|
||||
state.setdefault("events_this_tick", []).append("tool_call_pending")
|
||||
# Mark interjection consumed.
|
||||
state["recent_interjection"] = None
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AskVQAFwd(InferenceStep):
|
||||
"""On stdin question, answer a frame-grounded VQA.
|
||||
|
||||
Mirrors the ``ask_vqa_*`` recipe layout exactly: a single user
|
||||
turn carrying just the VQA question, plus the camera image block
|
||||
in training (we drop the image at inference because the dataset's
|
||||
image preprocessing doesn't match SmolVLM's vision tower input).
|
||||
|
||||
user: <question>
|
||||
↓ generate ↓
|
||||
assistant: <vqa answer>
|
||||
"""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
trigger: Trigger = field(default_factory=lambda: EventTrigger("user_vqa_query"))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if self.policy is None or not take_event(state, "user_vqa_query"):
|
||||
return None
|
||||
question = state.get("recent_vqa_query")
|
||||
if not question:
|
||||
return None
|
||||
ctx = _msgs_for_vqa(question)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
answer = _generate_with_policy(
|
||||
self.policy,
|
||||
ctx,
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="vqa gen",
|
||||
)
|
||||
# VQA answers are intentionally JSON-like during training, so
|
||||
# ``_looks_like_gibberish`` would false-positive on them. Keep
|
||||
# the answer as-is — the VQA panel line lets the user judge.
|
||||
if answer:
|
||||
push_log(state, f" vqa: {answer}")
|
||||
state["recent_vqa_query"] = None
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool dispatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class DispatchToolCalls(InferenceStep):
|
||||
"""Pop ``tool_calls_pending`` and execute them via :data:`TOOL_REGISTRY`."""
|
||||
|
||||
tools: dict[str, Any] = field(default_factory=dict)
|
||||
trigger: Trigger = field(default_factory=lambda: EventTrigger("tool_call_pending"))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
take_event(state, "tool_call_pending")
|
||||
pending = state.get("tool_calls_pending") or []
|
||||
for call in pending:
|
||||
try:
|
||||
fn = (call or {}).get("function") or {}
|
||||
name = fn.get("name")
|
||||
args = fn.get("arguments") or {}
|
||||
tool = self.tools.get(name)
|
||||
if tool is None:
|
||||
push_log(state, f" [warn] tool {name!r} not registered — skipping call")
|
||||
continue
|
||||
tool.call(args)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
push_log(state, f" [error] tool dispatch failed: {exc}")
|
||||
state["tool_calls_pending"] = []
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _looks_like_gibberish(text: str) -> bool:
|
||||
"""Heuristically detect generation that's clearly off the rails.
|
||||
|
||||
Memorised models can collapse to dominant-mode outputs when the
|
||||
prompt drifts even slightly from training distribution. Reject:
|
||||
|
||||
* empty / whitespace-only
|
||||
* too few alphabetic characters (mostly punctuation)
|
||||
* a single character repeated past the threshold
|
||||
* starts with ``":"`` and contains no letters
|
||||
* too few unique tokens — e.g. ``"the"``, ``"the the the"``,
|
||||
``"Ass\\n::\\nthe"`` (the collapse seen on real-robot frames
|
||||
where the model emits one or two memorised tokens repeatedly)
|
||||
* chat-template fragment leakage (``Assistant:``, ``User:``,
|
||||
``Ass\\n``)
|
||||
|
||||
Real subtasks look like ``"close the gripper to grasp the blue
|
||||
cube"`` — multiple unique alphabetic tokens, no role-marker
|
||||
fragments. Anything materially shorter than that is rejected.
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return True
|
||||
stripped = text.strip()
|
||||
alpha = sum(1 for c in stripped if c.isalpha())
|
||||
if alpha < max(3, len(stripped) // 8):
|
||||
return True
|
||||
if stripped.startswith('":') and stripped.count('"') > stripped.count(" "):
|
||||
return True
|
||||
# Single repeating char: e.g. ``""""""``.
|
||||
if len(set(stripped)) <= 2 and len(stripped) > 4:
|
||||
return True
|
||||
# Chat-template fragment leakage — the model emits ``Ass``,
|
||||
# ``Assistant:``, ``User:``, often with extra newlines/colons.
|
||||
# Reject if the cleaned text is mostly role-marker shards.
|
||||
cleaned = stripped.replace("\n", " ").replace(":", " ")
|
||||
for marker in ("Assistant", "User", "Ass "):
|
||||
if marker in cleaned and len(cleaned.split()) < 4:
|
||||
return True
|
||||
tokens = [t for t in cleaned.split() if any(c.isalpha() for c in t)]
|
||||
unique_alpha = {t.lower() for t in tokens}
|
||||
# Short degenerate output — model stuck on ``the`` or a couple of
|
||||
# memorised single-token continuations.
|
||||
if len(unique_alpha) < 3 and len(stripped) < 80:
|
||||
return True
|
||||
# Long repetition collapse — the LM head loops an n-gram for the
|
||||
# whole generation budget ("the arm the arm … the the the the").
|
||||
# Length-independent: many tokens but a tiny unique ratio. The
|
||||
# earlier ``< 80`` check missed these because the looped string
|
||||
# blows well past 80 chars.
|
||||
if len(tokens) >= 8 and len(unique_alpha) <= max(3, len(tokens) // 10):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _control_context_messages(
|
||||
state: dict[str, Any],
|
||||
*,
|
||||
include_completed: bool = False,
|
||||
extra_user: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build a chat-template-ready prompt from current runtime state.
|
||||
|
||||
Mirrors what ``subtasks_vqa.yaml`` renders into ``${task}\nPlan:
|
||||
${plan}\nMemory: ${memory}`` for the high-level branches.
|
||||
"""
|
||||
# Always emit ``Plan: `` / ``Memory: `` labels — even with empty
|
||||
# values — to mirror the training-time recipe substitution.
|
||||
task = state.get("task") or ""
|
||||
plan = state.get("current_plan") or ""
|
||||
memory = state.get("current_memory") or ""
|
||||
parts = [task, f"Plan: {plan}", f"Memory: {memory}"]
|
||||
if include_completed and state.get("current_subtask"):
|
||||
parts.append(f"Completed subtask: {state['current_subtask']}")
|
||||
head = "\n".join(parts)
|
||||
msgs: list[dict[str, Any]] = [{"role": "user", "content": head}]
|
||||
if extra_user:
|
||||
msgs.append({"role": "user", "content": extra_user})
|
||||
return msgs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-recipe prompt builders. Each one mirrors a single sub-recipe's
|
||||
# message layout in ``subtasks_vqa.yaml`` so the chat-templated
|
||||
# prompt at inference matches what the model saw during training.
|
||||
# Generic ``_control_context_messages`` is kept around as a fallback
|
||||
# for ad-hoc callers but the four high-level steps now use these.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _hirobot_user_head(state: dict[str, Any]) -> str:
|
||||
"""Build the ``task\\nPlan: …\\nMemory: …`` user content string.
|
||||
|
||||
Mirrors what the recipe renders at training time, where
|
||||
``language_render._substitute`` substitutes empty strings for
|
||||
missing ``${plan}`` / ``${memory}`` bindings — i.e. the
|
||||
``Plan: `` / ``Memory: `` prefix labels are *always* in the
|
||||
user turn, even when their values aren't set yet. Skipping them
|
||||
here (the previous behaviour) produced a different prompt shape
|
||||
on early frames before plan / memory are populated and on
|
||||
samples where the dataset has no plan / memory annotation.
|
||||
"""
|
||||
task = state.get("task") or ""
|
||||
plan = state.get("current_plan") or ""
|
||||
memory = state.get("current_memory") or ""
|
||||
return f"{task}\nPlan: {plan}\nMemory: {memory}"
|
||||
|
||||
|
||||
def _msgs_for_subtask(state: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""``high_level_subtask`` recipe layout — predict the subtask from the
|
||||
task. The v-current recipe's user turn is just ``${task}`` (plan and
|
||||
memory are not trained), so the inference prompt is the bare task —
|
||||
no ``Plan: `` / ``Memory: `` lines.
|
||||
"""
|
||||
return [{"role": "user", "content": state.get("task") or ""}]
|
||||
|
||||
|
||||
def _msgs_for_memory(state: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""Memory-update prompt — mirrors ``memory_update`` recipe layout.
|
||||
|
||||
Recipe layout (``subtask_mem.yaml``):
|
||||
|
||||
user: "${task}"
|
||||
assistant: "Previous memory: ${prior_memory}" (if_present prior)
|
||||
user: "Completed subtask: ${completed}" (if_present completed)
|
||||
assistant: → predicts new memory
|
||||
|
||||
Fired by ``MemoryUpdateFwd`` on a ``subtask_change`` event:
|
||||
``state['current_memory']`` is the memory the policy last emitted
|
||||
(= the ``prior_memory`` binding at training), and
|
||||
``state['prior_subtask']`` is the subtask that just got replaced
|
||||
(= the ``completed_subtask`` binding at training).
|
||||
"""
|
||||
msgs: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": state.get("task") or ""},
|
||||
]
|
||||
prior_memory = state.get("current_memory")
|
||||
if prior_memory:
|
||||
msgs.append(
|
||||
{"role": "assistant", "content": f"Previous memory: {prior_memory}"}
|
||||
)
|
||||
completed_subtask = state.get("prior_subtask")
|
||||
if completed_subtask:
|
||||
msgs.append(
|
||||
{"role": "user", "content": f"Completed subtask: {completed_subtask}"}
|
||||
)
|
||||
return msgs
|
||||
|
||||
|
||||
def _msgs_for_interjection(state: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""``user_interjection_response`` recipe layout."""
|
||||
msgs: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": state.get("task") or ""}
|
||||
]
|
||||
if state.get("current_plan"):
|
||||
msgs.append(
|
||||
{"role": "assistant", "content": f"Previous plan:\n{state['current_plan']}"}
|
||||
)
|
||||
interjection = state.get("recent_interjection")
|
||||
if interjection:
|
||||
msgs.append({"role": "user", "content": interjection})
|
||||
return msgs
|
||||
|
||||
|
||||
def _msgs_for_plan(state: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""``plan_generation`` recipe layout — bare task → plan.
|
||||
|
||||
The assistant turn is the generation target, so we only render
|
||||
the user turn at inference; the runtime appends the predicted
|
||||
plan after sampling.
|
||||
"""
|
||||
return [{"role": "user", "content": state.get("task") or ""}]
|
||||
|
||||
|
||||
def _msgs_for_vqa(question: str) -> list[dict[str, Any]]:
|
||||
"""``ask_vqa_*`` recipe layout (text-only at inference)."""
|
||||
return [{"role": "user", "content": question}]
|
||||
|
||||
|
||||
def _maybe_observation(provider: Any) -> dict | None:
|
||||
"""Pull one observation from ``provider`` if it's set, else ``None``.
|
||||
|
||||
Errors from the provider are logged at debug level and swallowed —
|
||||
text generation still runs (in text-only mode) so a flaky frame
|
||||
source doesn't kill the REPL.
|
||||
"""
|
||||
if provider is None:
|
||||
return None
|
||||
try:
|
||||
return provider()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("observation_provider raised %s — falling back to text-only", exc)
|
||||
return None
|
||||
|
||||
|
||||
def _generate_with_policy(
|
||||
policy: Any,
|
||||
messages: list[dict[str, Any]],
|
||||
*,
|
||||
observation: dict | None = None,
|
||||
state: dict[str, Any] | None = None,
|
||||
label: str = "select_message",
|
||||
min_new_tokens: int = 0,
|
||||
temperature: float = 0.0,
|
||||
top_p: float = 1.0,
|
||||
suppress_loc_tokens: bool = False,
|
||||
) -> str:
|
||||
"""Drive ``policy.select_message`` with a chat batch (and optional obs).
|
||||
|
||||
When ``observation`` carries ``observation.images.*`` and
|
||||
``observation.state``, those are merged into the batch so
|
||||
``select_message`` runs the same VLM prefix the policy was trained
|
||||
on. Without an observation the runtime falls back to a text-only
|
||||
prompt — the text head still runs, but generations may drift from
|
||||
the training distribution.
|
||||
|
||||
Failures are surfaced both to the module logger (``warning``) and,
|
||||
when ``state`` is given, to the runtime's user-visible log via
|
||||
:func:`push_log`, so the REPL no longer "looks dead" when
|
||||
something goes wrong inside generation.
|
||||
"""
|
||||
if not hasattr(policy, "select_message"):
|
||||
if state is not None:
|
||||
push_log(state, f" [warn] policy has no select_message — skipping {label}")
|
||||
return ""
|
||||
text_batch = _build_text_batch(policy, messages)
|
||||
try:
|
||||
from lerobot.utils.constants import ( # noqa: PLC0415
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
)
|
||||
|
||||
batch: dict[str, Any] = {
|
||||
OBS_LANGUAGE_TOKENS: text_batch["lang_tokens"],
|
||||
OBS_LANGUAGE_ATTENTION_MASK: text_batch["lang_masks"],
|
||||
}
|
||||
if observation:
|
||||
for k, v in observation.items():
|
||||
if isinstance(k, str) and k.startswith("observation.") and k not in batch:
|
||||
batch[k] = v
|
||||
kwargs: dict[str, Any] = {
|
||||
"tokenizer": text_batch["tokenizer"],
|
||||
"min_new_tokens": min_new_tokens,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
}
|
||||
kwargs["suppress_loc_tokens"] = suppress_loc_tokens
|
||||
return policy.select_message(batch, **kwargs)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("%s failed: %s", label, exc, exc_info=logger.isEnabledFor(logging.DEBUG))
|
||||
if state is not None:
|
||||
push_log(state, f" [warn] {label} failed: {type(exc).__name__}: {exc}")
|
||||
return ""
|
||||
|
||||
|
||||
_SAY_RE = re.compile(r"<\s*say\s*>(.*?)<\s*/\s*say\s*>", re.IGNORECASE | re.DOTALL)
|
||||
|
||||
|
||||
def _split_plan_and_say(text: str) -> tuple[str, str]:
|
||||
"""Pull a ``<say>...</say>`` snippet out of ``text``; remainder is plan.
|
||||
|
||||
The training-time tool-call serializer wraps ``say(text="…")`` in a
|
||||
deterministic textual marker so prefix-LM-style training learns to
|
||||
emit it. The runtime parses it back here. If no marker is present,
|
||||
the entire text is treated as plan with no speech.
|
||||
"""
|
||||
if not text:
|
||||
return "", ""
|
||||
match = _SAY_RE.search(text)
|
||||
if not match:
|
||||
return text.strip(), ""
|
||||
speech = match.group(1).strip().strip('"').strip("'")
|
||||
plan = (text[: match.start()] + text[match.end() :]).strip()
|
||||
return plan, speech
|
||||
@@ -1,134 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Trigger primitives for PI052's multi-rate inference runtime.
|
||||
|
||||
Mirrors the plan's Section "Runtime orchestration": each
|
||||
``InferenceStep`` is gated by a :class:`Trigger` that decides per tick
|
||||
whether the step fires. Two trigger flavours cover all the cadences
|
||||
the canonical recipe needs:
|
||||
|
||||
* :class:`HzTrigger` for periodic beats (action chunks at ~3-5 Hz,
|
||||
high-level subtask generation at ~1 Hz, action dispatch at ~50 Hz)
|
||||
* :class:`EventTrigger` for one-shot reactions (subtask boundary →
|
||||
memory update; user interjection → plan refresh; user VQA query →
|
||||
vqa answer; pending tool call → dispatcher)
|
||||
|
||||
Triggers are stateless except for ``HzTrigger``'s last-fire timestamp.
|
||||
The runtime stores the :class:`Tick` clock as ``state["_tick"]`` so
|
||||
every step shares a single time source.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Protocol
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tick:
|
||||
"""Single tick from :class:`TickClock`. Carries time references the
|
||||
runtime steps consume to gate themselves."""
|
||||
|
||||
index: int
|
||||
"""Monotonic counter — increments by one per tick."""
|
||||
|
||||
monotonic_seconds: float
|
||||
"""``time.monotonic()`` at the start of this tick."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class TickClock:
|
||||
"""Drives the runtime loop at up to ``max_rate_hz``.
|
||||
|
||||
Sleeps just enough between :meth:`advance` calls to enforce the
|
||||
rate. With ``max_rate_hz=50`` the loop wakes ~every 20ms; the
|
||||
higher-level ``HzTrigger`` slices that timeline into sub-cadences.
|
||||
"""
|
||||
|
||||
max_rate_hz: float = 50.0
|
||||
_index: int = field(default=0, init=False)
|
||||
_last_seconds: float | None = field(default=None, init=False)
|
||||
|
||||
def advance(self) -> Tick:
|
||||
period = 1.0 / max(self.max_rate_hz, 0.1)
|
||||
now = time.monotonic()
|
||||
if self._last_seconds is not None:
|
||||
sleep_for = (self._last_seconds + period) - now
|
||||
if sleep_for > 0:
|
||||
time.sleep(sleep_for)
|
||||
now = time.monotonic()
|
||||
self._last_seconds = now
|
||||
self._index += 1
|
||||
return Tick(index=self._index, monotonic_seconds=now)
|
||||
|
||||
|
||||
class Trigger(Protocol):
|
||||
"""Decide whether the next ``InferenceStep`` should fire."""
|
||||
|
||||
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class HzTrigger:
|
||||
"""Fire at most ``hz`` times per second.
|
||||
|
||||
A step that gates further (e.g. ``HighLevelSubtaskFwd`` skipping
|
||||
when the action queue is non-empty) and wants the trigger to
|
||||
retry next tick instead of waiting a full period can call
|
||||
:meth:`rearm` from inside ``run``. Without this, a low-hz trigger
|
||||
(e.g. ``hz=0.2`` = once per 5 s) almost never coincides with the
|
||||
brief queue-empty window and the step never fires at all.
|
||||
"""
|
||||
|
||||
hz: float
|
||||
_last_seconds: float | None = field(default=None, init=False)
|
||||
|
||||
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool:
|
||||
period = 1.0 / max(self.hz, 1e-6)
|
||||
if self._last_seconds is None or (tick.monotonic_seconds - self._last_seconds) >= period:
|
||||
self._last_seconds = tick.monotonic_seconds
|
||||
return True
|
||||
return False
|
||||
|
||||
def rearm(self) -> None:
|
||||
"""Mark the trigger as not having fired, so the next tick re-evaluates.
|
||||
|
||||
Used by a step that decided to skip after ``should_fire`` already
|
||||
committed the firing — keeps the cadence honest without losing
|
||||
the slot.
|
||||
"""
|
||||
self._last_seconds = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventTrigger:
|
||||
"""Fire when ``event_name`` is in ``state["events_this_tick"]``.
|
||||
|
||||
The runtime fills ``events_this_tick`` once per tick from:
|
||||
|
||||
* stdin / network input (``user_interjection``, ``user_vqa_query``,
|
||||
``stop``)
|
||||
* internal state transitions (``subtask_change``,
|
||||
``tool_call_pending``)
|
||||
|
||||
The list is consumed (cleared at the end of the tick) so events
|
||||
fire at most once.
|
||||
"""
|
||||
|
||||
event_name: str
|
||||
|
||||
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool:
|
||||
events: list[str] = state.get("events_this_tick") or []
|
||||
return self.event_name in events
|
||||
@@ -1,127 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Rich-based REPL layout for the PI052 runtime.
|
||||
|
||||
Two-zone terminal layout:
|
||||
|
||||
[chat scrollback — user messages / robot responses, scrolls naturally]
|
||||
|
||||
┌── State ──────────────────────────────────────────┐
|
||||
│ task please clean up the kitchen │
|
||||
│ subtask grasp the handle of the sponge │
|
||||
│ plan 1. grasp sponge 2. wipe 3. tidy │
|
||||
│ memory sponge picked up; counter still dirty │
|
||||
└───────────────────────────────────────────────────┘
|
||||
> _
|
||||
|
||||
The state panel re-renders on every state change. Chat lines are
|
||||
``console.print``'d above the live region so they accumulate naturally
|
||||
in scrollback. Implemented with :class:`rich.live.Live` plus
|
||||
:func:`rich.console.Console.input` for the prompt — when an input is
|
||||
pending, ``rich.Live`` auto-suspends so the input doesn't fight the
|
||||
panel for cursor position.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
try: # rich is optional; only required for the interactive REPL.
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
_HAS_RICH = True
|
||||
except ImportError: # pragma: no cover
|
||||
_HAS_RICH = False
|
||||
Console = Any # type: ignore[assignment]
|
||||
Panel = Any # type: ignore[assignment]
|
||||
Table = Any # type: ignore[assignment]
|
||||
Text = Any # type: ignore[assignment]
|
||||
|
||||
|
||||
_STATE_KEYS = (
|
||||
("task", "task"),
|
||||
("current_subtask", "subtask"),
|
||||
("current_plan", "plan"),
|
||||
("current_memory", "memory"),
|
||||
)
|
||||
|
||||
|
||||
def make_state_panel(state: dict[str, Any]) -> Any:
|
||||
"""Render the persistent state panel for the live region.
|
||||
|
||||
Returns a :class:`rich.panel.Panel`. Caller passes it to
|
||||
``Live.update(panel)`` whenever the state changes.
|
||||
"""
|
||||
if not _HAS_RICH:
|
||||
raise RuntimeError(
|
||||
"rich is required for the interactive REPL. "
|
||||
"`pip install rich` (it's a transitive dep of lerobot)."
|
||||
)
|
||||
table = Table.grid(padding=(0, 2), expand=True)
|
||||
table.add_column(justify="right", style="dim", no_wrap=True, width=10)
|
||||
table.add_column(justify="left")
|
||||
for key, label in _STATE_KEYS:
|
||||
value = state.get(key)
|
||||
if value is None:
|
||||
rendered = Text("(not set)", style="dim italic")
|
||||
else:
|
||||
rendered = Text(str(value), style="bold")
|
||||
table.add_row(label, rendered)
|
||||
queue = state.get("action_queue")
|
||||
queue_len = len(queue) if hasattr(queue, "__len__") else 0
|
||||
pending = state.get("tool_calls_pending") or []
|
||||
footer = Text.assemble(
|
||||
("queued actions: ", "dim"),
|
||||
(str(queue_len), "bold cyan"),
|
||||
(" pending tool calls: ", "dim"),
|
||||
(str(len(pending)), "bold magenta"),
|
||||
)
|
||||
table.add_row("", footer)
|
||||
run_mode = state.get("mode", "action")
|
||||
mode_tag = (
|
||||
"[green]action[/]" if run_mode == "action" else "[yellow]paused[/]"
|
||||
)
|
||||
return Panel(
|
||||
table,
|
||||
title=f"[bold]PI052 state[/] · mode: {mode_tag}",
|
||||
border_style="cyan",
|
||||
)
|
||||
|
||||
|
||||
def print_user_line(console: Any, line: str) -> None:
|
||||
"""Append a user-typed line to the chat scrollback."""
|
||||
if not _HAS_RICH:
|
||||
print(f"you: {line}", flush=True)
|
||||
return
|
||||
console.print(f"[bold cyan]you:[/] {line}")
|
||||
|
||||
|
||||
def print_robot_lines(console: Any, lines: list[str]) -> None:
|
||||
"""Append robot/runtime log lines to the chat scrollback."""
|
||||
if not _HAS_RICH:
|
||||
for line in lines:
|
||||
print(f"robot: {line.lstrip()}", flush=True)
|
||||
return
|
||||
for line in lines:
|
||||
# The runtime uses leading whitespace + "label: text"; render
|
||||
# the label in green and the value in default for readability.
|
||||
stripped = line.lstrip()
|
||||
if ":" in stripped:
|
||||
label, _, value = stripped.partition(":")
|
||||
console.print(f"[bold green]robot[/] [dim]({label.strip()})[/] {value.strip()}")
|
||||
else:
|
||||
console.print(f"[bold green]robot:[/] {stripped}")
|
||||
@@ -1,423 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Interactive VQA for the PI052 runtime.
|
||||
|
||||
In ``/vlm`` mode a typed line is treated as a VQA question. This module
|
||||
runs the full interactive flow:
|
||||
|
||||
1. pull the current observation and list available cameras,
|
||||
2. ask the operator which camera to ground the question on,
|
||||
3. generate the answer with the VLM conditioned on that one camera,
|
||||
4. parse the JSON answer; if it carries a bounding box (``bbox``) or a
|
||||
point (``keypoint``), draw the overlay on the camera frame, save a
|
||||
PNG to ``./vqa_overlays/`` and auto-open it.
|
||||
|
||||
VQA answer schemas mirror the annotation pipeline's ``VQA_ANSWER_SHAPES``
|
||||
(see ``lerobot.annotations.steerable_pipeline.validator``):
|
||||
|
||||
* ``bbox`` — ``{"detections": [{"label", "bbox_format": "xyxy",
|
||||
"bbox": [x1, y1, x2, y2]}, ...]}``
|
||||
* ``keypoint`` — ``{"label", "point_format": "xy", "point": [x, y]}``
|
||||
* ``count`` / ``attribute`` / ``spatial`` — text-only, no overlay.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import webbrowser
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .runtime_state import push_log
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_IMAGE_PREFIX = "observation.images."
|
||||
|
||||
# PaliGemma detection / pointing vocabulary. PI052 trains spatial VQA
|
||||
# answers in this native ``<locNNNN>`` format (index in [0, 1023],
|
||||
# normalized to the image axis) instead of pixel-coordinate JSON, so the
|
||||
# answer string the runtime parses can be e.g.
|
||||
# ``<loc0512><loc0301> blue cube`` (point) or
|
||||
# ``<loc0100><loc0080><loc0400><loc0360> blue cube`` (box).
|
||||
_LOC_RE = re.compile(r"<loc(\d{1,4})>")
|
||||
|
||||
# Iteration order for shape matching — most specific keys first so an
|
||||
# answer is classified deterministically.
|
||||
_SHAPE_ORDER = ("bbox", "keypoint", "count", "attribute", "spatial")
|
||||
|
||||
_BBOX_COLOR = (255, 64, 64)
|
||||
_POINT_COLOR = (64, 220, 64)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Camera selection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def available_cameras(observation: dict | None) -> list[str]:
|
||||
"""Return the sorted ``observation.images.*`` keys present in ``observation``."""
|
||||
if not observation:
|
||||
return []
|
||||
return sorted(k for k in observation if isinstance(k, str) and k.startswith(_IMAGE_PREFIX))
|
||||
|
||||
|
||||
def camera_short_name(camera_key: str) -> str:
|
||||
"""Strip the ``observation.images.`` prefix for display."""
|
||||
return camera_key[len(_IMAGE_PREFIX) :] if camera_key.startswith(_IMAGE_PREFIX) else camera_key
|
||||
|
||||
|
||||
def prompt_camera_choice(
|
||||
cameras: list[str],
|
||||
*,
|
||||
input_fn: Any = input,
|
||||
print_fn: Any = print,
|
||||
) -> str | None:
|
||||
"""Ask the operator which camera frame to draw a VQA overlay on.
|
||||
|
||||
Accepts either the menu number or the (short or full) camera name.
|
||||
A single-camera setup auto-selects without prompting. Returns the
|
||||
chosen ``observation.images.*`` key, or ``None`` if the operator
|
||||
cancels / gives an invalid answer.
|
||||
"""
|
||||
if not cameras:
|
||||
return None
|
||||
if len(cameras) == 1:
|
||||
return cameras[0]
|
||||
print_fn("Draw the result on which camera?")
|
||||
for i, cam in enumerate(cameras, 1):
|
||||
print_fn(f" [{i}] {camera_short_name(cam)}")
|
||||
try:
|
||||
raw = str(input_fn("camera> ")).strip()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
return None
|
||||
if not raw:
|
||||
return cameras[0]
|
||||
if raw.isdigit():
|
||||
idx = int(raw) - 1
|
||||
return cameras[idx] if 0 <= idx < len(cameras) else None
|
||||
for cam in cameras:
|
||||
if raw == cam or raw == camera_short_name(cam):
|
||||
return cam
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Answer parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _loc_to_norm(idx: int) -> float:
|
||||
"""PaliGemma ``<locNNNN>`` index → normalized [0, 1] axis coordinate."""
|
||||
return max(0.0, min(1023.0, float(idx))) / 1023.0
|
||||
|
||||
|
||||
def parse_loc_answer(answer: str) -> dict | None:
|
||||
"""Parse a PaliGemma ``<loc>``-format spatial VQA answer.
|
||||
|
||||
PI052 trains spatial answers in PaliGemma's native detection
|
||||
vocabulary, label-first: a point is ``<label> <locY><locX>``, a box
|
||||
is ``<label> <locY0><locX0><locY1><locX1>``, and multiple boxes are
|
||||
joined by `` ; `` (e.g. ``cube <loc..><loc..><loc..><loc..> ; box
|
||||
<loc..><loc..><loc..><loc..>``). Loc-first formats are also accepted
|
||||
— this parser strips loc tokens and treats the remainder as the
|
||||
label, so order is irrelevant. Coordinates come back *normalized*
|
||||
([0, 1]); the overlay denormalizes them against the chosen camera
|
||||
frame's pixel size.
|
||||
|
||||
Returns ``{"kind", "payload", "normalized": True}`` on success
|
||||
(``payload`` mirrors the JSON shapes so the overlay code is shared),
|
||||
or ``None`` when the answer carries no ``<loc>`` tokens.
|
||||
"""
|
||||
if not answer or "<loc" not in answer:
|
||||
return None
|
||||
segments = [seg for seg in answer.split(";") if "<loc" in seg]
|
||||
points: list[tuple[float, float, str]] = []
|
||||
boxes: list[tuple[float, float, float, float, str]] = []
|
||||
for seg in segments:
|
||||
locs = [int(m) for m in _LOC_RE.findall(seg)]
|
||||
label = _LOC_RE.sub("", seg).strip()
|
||||
if len(locs) == 2:
|
||||
y, x = (_loc_to_norm(v) for v in locs[:2])
|
||||
points.append((x, y, label))
|
||||
elif len(locs) >= 4:
|
||||
y1, x1, y2, x2 = (_loc_to_norm(v) for v in locs[:4])
|
||||
boxes.append((x1, y1, x2, y2, label))
|
||||
if boxes:
|
||||
detections = [
|
||||
{"label": lbl, "bbox_format": "xyxy", "bbox": [x1, y1, x2, y2]}
|
||||
for (x1, y1, x2, y2, lbl) in boxes
|
||||
]
|
||||
return {"kind": "bbox", "payload": {"detections": detections}, "normalized": True}
|
||||
if len(points) == 1:
|
||||
x, y, lbl = points[0]
|
||||
return {
|
||||
"kind": "keypoint",
|
||||
"payload": {"label": lbl, "point_format": "xy", "point": [x, y]},
|
||||
"normalized": True,
|
||||
}
|
||||
if points: # several bare points → treat as detections-as-points
|
||||
detections = [
|
||||
{"label": lbl, "bbox_format": "xyxy", "bbox": [x, y, x, y]} for (x, y, lbl) in points
|
||||
]
|
||||
return {"kind": "bbox", "payload": {"detections": detections}, "normalized": True}
|
||||
return None
|
||||
|
||||
|
||||
def parse_vqa_answer(answer: str) -> dict | None:
|
||||
"""Parse a VQA answer string into ``{"kind", "payload"}``.
|
||||
|
||||
``kind`` is one of the ``VQA_ANSWER_SHAPES`` names (``bbox``,
|
||||
``keypoint``, ``count``, ``attribute``, ``spatial``) or ``"unknown"``
|
||||
when the JSON doesn't match any known shape. PaliGemma ``<loc>``
|
||||
spatial answers are detected first (PI052 trains them in that native
|
||||
format). Returns ``None`` when the answer is neither ``<loc>`` text
|
||||
nor a parseable JSON object.
|
||||
"""
|
||||
if not answer or not answer.strip():
|
||||
return None
|
||||
loc_parsed = parse_loc_answer(answer)
|
||||
if loc_parsed is not None:
|
||||
return loc_parsed
|
||||
try:
|
||||
payload = json.loads(answer)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
|
||||
try:
|
||||
from lerobot.annotations.steerable_pipeline.validator import ( # noqa: PLC0415
|
||||
VQA_ANSWER_SHAPES,
|
||||
)
|
||||
|
||||
shapes = VQA_ANSWER_SHAPES
|
||||
except ImportError: # pragma: no cover - annotation extra not installed
|
||||
shapes = {
|
||||
"bbox": {"detections"},
|
||||
"keypoint": {"label", "point_format", "point"},
|
||||
"count": {"label", "count"},
|
||||
"attribute": {"label", "attribute", "value"},
|
||||
"spatial": {"subject", "relation", "object"},
|
||||
}
|
||||
|
||||
keys = set(payload)
|
||||
for kind in _SHAPE_ORDER:
|
||||
required = shapes.get(kind)
|
||||
if required and required <= keys:
|
||||
return {"kind": kind, "payload": payload}
|
||||
return {"kind": "unknown", "payload": payload}
|
||||
|
||||
|
||||
def answer_has_overlay(parsed: dict | None) -> bool:
|
||||
"""True iff ``parsed`` carries drawable spatial coordinates."""
|
||||
return bool(parsed) and parsed.get("kind") in ("bbox", "keypoint")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Overlay drawing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def observation_image_to_pil(image_tensor: Any) -> Any:
|
||||
"""Convert an ``observation.images.*`` tensor to a PIL RGB image.
|
||||
|
||||
The runtime observation stores images as ``(1, C, H, W)`` (or
|
||||
``(C, H, W)``) float tensors in ``[0, 1]``. Reuses
|
||||
``image_array_to_pil_image`` which handles the CHW→HWC transpose and
|
||||
the float→uint8 scaling.
|
||||
"""
|
||||
from lerobot.datasets.image_writer import image_array_to_pil_image # noqa: PLC0415
|
||||
|
||||
arr = image_tensor
|
||||
if hasattr(arr, "detach"):
|
||||
arr = arr.detach().cpu()
|
||||
if hasattr(arr, "numpy"):
|
||||
arr = arr.numpy()
|
||||
while arr.ndim > 3: # drop leading batch dim(s)
|
||||
arr = arr[0]
|
||||
return image_array_to_pil_image(arr).convert("RGB")
|
||||
|
||||
|
||||
def draw_vqa_overlay(image: Any, parsed: dict) -> Any:
|
||||
"""Draw ``bbox`` / ``keypoint`` answers onto a copy of ``image``.
|
||||
|
||||
Non-spatial answers (``count`` / ``attribute`` / ``spatial`` /
|
||||
``unknown``) are returned as an unmodified copy. When ``parsed`` has
|
||||
``normalized=True`` (PaliGemma ``<loc>`` answers) the [0, 1]
|
||||
coordinates are scaled to the image's pixel size.
|
||||
"""
|
||||
from PIL import ImageDraw # noqa: PLC0415
|
||||
|
||||
img = image.convert("RGB").copy()
|
||||
kind = parsed.get("kind")
|
||||
payload = parsed.get("payload") or {}
|
||||
draw = ImageDraw.Draw(img)
|
||||
w, h = img.size
|
||||
sx, sy = (w, h) if parsed.get("normalized") else (1, 1)
|
||||
|
||||
if kind == "bbox":
|
||||
for det in payload.get("detections") or []:
|
||||
if not isinstance(det, dict):
|
||||
continue
|
||||
box = det.get("bbox")
|
||||
if not (isinstance(box, list | tuple) and len(box) == 4):
|
||||
continue
|
||||
try:
|
||||
x1, y1, x2, y2 = (float(v) for v in box)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
x1, x2 = x1 * sx, x2 * sx
|
||||
y1, y2 = y1 * sy, y2 * sy
|
||||
draw.rectangle([x1, y1, x2, y2], outline=_BBOX_COLOR, width=3)
|
||||
label = str(det.get("label", "")).strip()
|
||||
if label:
|
||||
draw.text((x1 + 3, max(0.0, y1 - 12)), label, fill=_BBOX_COLOR)
|
||||
elif kind == "keypoint":
|
||||
point = payload.get("point")
|
||||
if isinstance(point, list | tuple) and len(point) == 2:
|
||||
try:
|
||||
x, y = float(point[0]) * sx, float(point[1]) * sy
|
||||
except (TypeError, ValueError):
|
||||
return img
|
||||
r = 6
|
||||
draw.ellipse([x - r, y - r, x + r, y + r], outline=_POINT_COLOR, width=3)
|
||||
draw.line([x - 2 * r, y, x + 2 * r, y], fill=_POINT_COLOR, width=2)
|
||||
draw.line([x, y - 2 * r, x, y + 2 * r], fill=_POINT_COLOR, width=2)
|
||||
label = str(payload.get("label", "")).strip()
|
||||
if label:
|
||||
draw.text((x + r + 3, y - r), label, fill=_POINT_COLOR)
|
||||
return img
|
||||
|
||||
|
||||
def _open_file(path: Path) -> None:
|
||||
"""Best-effort open ``path`` in the OS default viewer."""
|
||||
try:
|
||||
if sys.platform == "darwin":
|
||||
subprocess.run(["open", str(path)], check=False)
|
||||
elif sys.platform.startswith("linux"):
|
||||
subprocess.run(["xdg-open", str(path)], check=False)
|
||||
elif os.name == "nt":
|
||||
os.startfile(str(path)) # type: ignore[attr-defined] # noqa: S606
|
||||
else: # pragma: no cover - exotic platform
|
||||
webbrowser.open(path.resolve().as_uri())
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("could not auto-open %s: %s", path, exc)
|
||||
|
||||
|
||||
def save_and_open_overlay(image: Any, out_dir: str | Path = "./vqa_overlays") -> Path:
|
||||
"""Save ``image`` as a timestamped PNG under ``out_dir`` and auto-open it."""
|
||||
out = Path(out_dir)
|
||||
out.mkdir(parents=True, exist_ok=True)
|
||||
path = out / f"vqa_{int(time.time() * 1000)}.png"
|
||||
image.save(path)
|
||||
_open_file(path)
|
||||
return path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Orchestrator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def handle_vqa_query(
|
||||
*,
|
||||
policy: Any,
|
||||
observation_provider: Any,
|
||||
question: str,
|
||||
state: dict[str, Any],
|
||||
input_fn: Any = input,
|
||||
print_fn: Any = print,
|
||||
) -> None:
|
||||
"""Run one interactive VQA question end to end.
|
||||
|
||||
Called synchronously from the input layer while the runtime is in
|
||||
``/question`` mode (the action loop is gated off, so the policy is
|
||||
not in concurrent use). Progress is reported via both
|
||||
:func:`push_log` (REPL panel scrollback) and ``print_fn`` (direct
|
||||
stdout) — in autonomous question mode the panel redraw is suspended,
|
||||
so the direct print is what the operator actually sees.
|
||||
"""
|
||||
from .steps import _generate_with_policy, _msgs_for_vqa # noqa: PLC0415
|
||||
|
||||
def report(line: str) -> None:
|
||||
"""Surface a line both to the panel scrollback and to stdout."""
|
||||
push_log(state, line)
|
||||
try:
|
||||
print_fn(line)
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
|
||||
if policy is None or not hasattr(policy, "select_message"):
|
||||
report(" [warn] vqa: policy has no select_message — skipping")
|
||||
return
|
||||
|
||||
observation: dict | None = None
|
||||
if observation_provider is not None:
|
||||
try:
|
||||
observation = observation_provider()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("observation_provider raised %s", exc)
|
||||
|
||||
# Feed the FULL observation (every camera + state) to the VLM. The
|
||||
# ``ask_vqa_*`` recipes look single-camera, but the image *block* is
|
||||
# stripped before tokenization — the actual frames reach the model
|
||||
# via PI052's ``OBS_IMAGES_*`` channels, and ``embed_prefix``
|
||||
# consumes *all* ``config.image_features`` regardless of which
|
||||
# camera the sub-recipe was tagged for. So the model always sees
|
||||
# every camera; the operator never has to name one to ask.
|
||||
answer = _generate_with_policy(
|
||||
policy,
|
||||
_msgs_for_vqa(question),
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="vqa gen",
|
||||
)
|
||||
if not answer:
|
||||
report(" [info] vqa gen returned empty")
|
||||
return
|
||||
report(f" vqa: {answer}")
|
||||
|
||||
parsed = parse_vqa_answer(answer)
|
||||
if not answer_has_overlay(parsed):
|
||||
if parsed is None:
|
||||
report(" [info] vqa answer is not JSON — no overlay")
|
||||
return
|
||||
|
||||
# The answer carries a bounding box / point. Its pixel coordinates
|
||||
# are camera-specific and the text answer doesn't say which camera,
|
||||
# so ask the operator *now* — only when there is actually something
|
||||
# to draw — which camera frame to render the overlay on.
|
||||
cameras = available_cameras(observation)
|
||||
if observation is None or not cameras:
|
||||
report(" [info] no camera image — cannot draw overlay")
|
||||
return
|
||||
chosen = prompt_camera_choice(cameras, input_fn=input_fn, print_fn=print_fn)
|
||||
if chosen is None:
|
||||
report(" [info] overlay skipped — no camera selected")
|
||||
return
|
||||
try:
|
||||
pil = observation_image_to_pil(observation[chosen])
|
||||
overlay = draw_vqa_overlay(pil, parsed)
|
||||
path = save_and_open_overlay(overlay)
|
||||
report(f" vqa overlay ({camera_short_name(chosen)}) saved: {path}")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("vqa overlay failed: %s", exc, exc_info=logger.isEnabledFor(logging.DEBUG))
|
||||
report(f" [warn] vqa overlay failed: {type(exc).__name__}: {exc}")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,198 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""π0.5 v2 pre/post-processor factory.
|
||||
|
||||
When ``config.recipe_path`` is set, the pre-processor pipeline becomes:
|
||||
|
||||
rename observations
|
||||
add batch dim
|
||||
relative-action prep (inherited from π0.5)
|
||||
NormalizerProcessorStep
|
||||
RenderMessagesStep — recipe → messages, target_message_indices,
|
||||
message_streams (PR 1 of the steerable
|
||||
stack)
|
||||
PI052TextTokenizerStep — messages → input_ids + label mask +
|
||||
predict_actions
|
||||
DeviceProcessorStep
|
||||
|
||||
When ``recipe_path`` is ``None`` we delegate to the plain π0.5 pipeline
|
||||
so unannotated datasets keep working.
|
||||
|
||||
Post-processor is unchanged from π0.5.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.recipe import TrainingRecipe
|
||||
from lerobot.processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
ActionTokenizerProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
RelativeActionsProcessorStep,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
# RenderMessagesStep is intentionally not re-exported from
|
||||
# ``lerobot.processor`` because it pulls in optional language-stack deps;
|
||||
# import it directly.
|
||||
from lerobot.processor.render_messages_processor import RenderMessagesStep
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
from ..pi05.processor_pi05 import make_pi05_pre_post_processors
|
||||
from .configuration_pi052 import PI052Config
|
||||
from .text_processor_pi052 import PI052TextTokenizerStep
|
||||
|
||||
|
||||
def make_pi052_pre_post_processors(
|
||||
config: PI052Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
dataset_repo_id: str | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Build PI0.5-v2's pre/post-processor pipelines.
|
||||
|
||||
Falls through to π0.5's stock pipeline when ``recipe_path`` is unset.
|
||||
"""
|
||||
if not config.recipe_path:
|
||||
return make_pi05_pre_post_processors(config, dataset_stats=dataset_stats)
|
||||
|
||||
recipe = _load_recipe(config.recipe_path)
|
||||
|
||||
relative_step = RelativeActionsProcessorStep(
|
||||
enabled=config.use_relative_actions,
|
||||
exclude_joints=getattr(config, "relative_exclude_joints", []),
|
||||
action_names=getattr(config, "action_feature_names", None),
|
||||
)
|
||||
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
relative_step,
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
RenderMessagesStep(recipe=recipe),
|
||||
PI052TextTokenizerStep(
|
||||
tokenizer_name="google/paligemma-3b-pt-224",
|
||||
max_length=config.tokenizer_max_length,
|
||||
plan_dropout_prob=getattr(config, "plan_dropout_prob", 0.0),
|
||||
memory_dropout_prob=getattr(config, "memory_dropout_prob", 0.0),
|
||||
subtask_dropout_prob=getattr(config, "subtask_dropout_prob", 0.0),
|
||||
),
|
||||
]
|
||||
|
||||
# FAST tokenizer for discrete-action CE supervision (paper §III.C).
|
||||
# Only inserted when explicitly enabled — keeps the post-training-
|
||||
# style recipe (flow + text) as the default. When on, the step
|
||||
# writes ACTION_TOKENS / ACTION_TOKEN_MASK into
|
||||
# ``COMPLEMENTARY_DATA`` and the modeling forward picks them up.
|
||||
if getattr(config, "enable_fast_action_loss", False):
|
||||
# Per Pertsch et al. 2025 (FAST [64], π0.5 §III.C): fit the
|
||||
# tokenizer on this dataset's action distribution rather than
|
||||
# using the universal codebook off the shelf. We do this once
|
||||
# and cache to disk, keyed on (dataset, base, n_samples).
|
||||
action_tokenizer_path = config.action_tokenizer_name
|
||||
if (
|
||||
getattr(config, "auto_fit_fast_tokenizer", False)
|
||||
and dataset_repo_id is not None
|
||||
):
|
||||
from .fit_fast_tokenizer import fit_fast_tokenizer # noqa: PLC0415
|
||||
|
||||
cache_dir = Path(config.fast_tokenizer_cache_dir).expanduser()
|
||||
try:
|
||||
action_tokenizer_path = fit_fast_tokenizer(
|
||||
dataset_repo_id=dataset_repo_id,
|
||||
cache_dir=cache_dir,
|
||||
base_tokenizer_name=config.action_tokenizer_name,
|
||||
n_samples=config.fast_tokenizer_fit_samples,
|
||||
chunk_size=config.chunk_size,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
import logging # noqa: PLC0415
|
||||
|
||||
logging.getLogger(__name__).warning(
|
||||
"FAST tokenizer fit failed (%s) — falling back to "
|
||||
"the universal base tokenizer %r. Train will still "
|
||||
"work but compression will be suboptimal.",
|
||||
exc, config.action_tokenizer_name,
|
||||
)
|
||||
|
||||
input_steps.append(
|
||||
ActionTokenizerProcessorStep(
|
||||
action_tokenizer_name=action_tokenizer_path,
|
||||
max_action_tokens=config.max_action_tokens,
|
||||
fast_skip_tokens=config.fast_skip_tokens,
|
||||
paligemma_tokenizer_name="google/paligemma-3b-pt-224",
|
||||
)
|
||||
)
|
||||
|
||||
input_steps.append(DeviceProcessorStep(device=config.device))
|
||||
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
AbsoluteActionsProcessorStep(
|
||||
enabled=config.use_relative_actions,
|
||||
relative_step=relative_step,
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _load_recipe(path_str: str) -> TrainingRecipe:
|
||||
"""Resolve ``path_str`` to a ``TrainingRecipe``.
|
||||
|
||||
Accepts an absolute path or a path relative to
|
||||
``src/lerobot/configs/``.
|
||||
"""
|
||||
p = Path(path_str)
|
||||
if not p.is_absolute() and not p.exists():
|
||||
from lerobot.configs import recipe as _recipe_module # noqa: PLC0415
|
||||
|
||||
configs_dir = Path(_recipe_module.__file__).resolve().parent
|
||||
candidate = configs_dir / path_str
|
||||
if candidate.exists():
|
||||
p = candidate
|
||||
return TrainingRecipe.from_yaml(p)
|
||||
@@ -1,641 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""π0.5 v2 text-tokenisation step.
|
||||
|
||||
PaliGemma is *not* chat-pretrained, so we can't lean on
|
||||
``tokenizer.apply_chat_template``. Instead we concatenate the rendered
|
||||
messages as plain text with simple ``User: ... Assistant: ...`` role
|
||||
delimiters — matching the prompt format π0.5 uses in the paper
|
||||
(``Task: ... State: ... Action: ...``).
|
||||
|
||||
Outputs:
|
||||
|
||||
* ``OBS_LANGUAGE_TOKENS`` / ``OBS_LANGUAGE_ATTENTION_MASK`` — the
|
||||
concatenated prompt tokenised by the PaliGemma tokenizer (the same
|
||||
one ``processor_pi05`` already uses).
|
||||
* ``text_labels`` — same shape as token ids, ``-100`` everywhere except
|
||||
positions belonging to messages whose index is in
|
||||
``target_message_indices``. ``modeling_pi052`` runs cross-entropy on
|
||||
those positions via the PaliGemma ``lm_head``.
|
||||
* ``predict_actions`` — bool tensor, ``True`` iff any of the rendered
|
||||
target messages has ``message_streams[i] == "low_level"``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def discretize_state_str(state_row: Any) -> str:
|
||||
"""Discretize a single normalized state vector into 256 bins, space-joined.
|
||||
|
||||
Mirrors pi05's ``Pi05PrepareStateTokenizerProcessorStep`` (same bins /
|
||||
convention) so pi052's low-level action prompt carries proprioception in
|
||||
the exact format pi05 was trained on. Expects state already normalized by
|
||||
the upstream ``NormalizerProcessorStep``.
|
||||
"""
|
||||
arr = state_row.detach().cpu().numpy() if hasattr(state_row, "detach") else np.asarray(state_row)
|
||||
disc = np.digitize(arr, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
return " ".join(str(int(x)) for x in disc.reshape(-1).tolist())
|
||||
|
||||
|
||||
def _state_row_at(state_all: Any, pos: int) -> Any:
|
||||
"""Select the per-sample state row from a (possibly batched) state tensor."""
|
||||
if state_all is None:
|
||||
return None
|
||||
if hasattr(state_all, "ndim") and state_all.ndim >= 2:
|
||||
return state_all[pos]
|
||||
return state_all
|
||||
|
||||
|
||||
def _content_to_text(content: Any) -> str:
|
||||
"""Collapse a message's ``content`` (string or multimodal blocks) to text."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
b["text"]
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "text" and isinstance(b.get("text"), str)
|
||||
]
|
||||
return "\n".join(parts)
|
||||
return ""
|
||||
|
||||
|
||||
def _flatten_say_tool_calls(message: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Serialize assistant ``say`` tool calls into a ``<say>...</say>`` marker.
|
||||
|
||||
PaliGemma's flat text prompt has no notion of structured tool calls,
|
||||
and ``_format_messages`` only reads ``role`` / ``content`` — so
|
||||
without this a ``say`` tool call is dropped entirely and never
|
||||
supervised. Rewriting it into the content text as a ``<say>...</say>``
|
||||
marker lets the LM head learn to emit it; the runtime parses it back
|
||||
via ``_split_plan_and_say``. Messages without ``say`` tool calls are
|
||||
returned unchanged (the structured calls, if any, are still dropped).
|
||||
"""
|
||||
tool_calls = message.get("tool_calls")
|
||||
if not tool_calls:
|
||||
return message
|
||||
say_texts: list[str] = []
|
||||
for call in tool_calls:
|
||||
if not isinstance(call, dict):
|
||||
continue
|
||||
fn = call.get("function") or {}
|
||||
if fn.get("name") != "say":
|
||||
continue
|
||||
args = fn.get("arguments")
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
import json # noqa: PLC0415
|
||||
|
||||
args = json.loads(args)
|
||||
except (ValueError, TypeError):
|
||||
args = {}
|
||||
text = args.get("text", "") if isinstance(args, dict) else ""
|
||||
if text:
|
||||
say_texts.append(str(text))
|
||||
new = dict(message)
|
||||
new.pop("tool_calls", None)
|
||||
if not say_texts:
|
||||
return new
|
||||
base = _content_to_text(new.get("content")).strip()
|
||||
marker = "".join(f"<say>{t}</say>" for t in say_texts)
|
||||
new["content"] = f"{base}\n{marker}" if base else marker
|
||||
return new
|
||||
|
||||
|
||||
def _strip_blocks(message: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Normalise a message's content to a plain string.
|
||||
|
||||
The recipe renderer can emit ``content`` as a string OR as a list
|
||||
of HF-style multimodal blocks (``{type: text, text: ...}``,
|
||||
``{type: image, feature: ...}``). PaliGemma's text tokenizer can
|
||||
only consume strings, so we flatten: drop image blocks (cameras
|
||||
flow through ``observation.images.*`` separately) and join text
|
||||
block texts.
|
||||
"""
|
||||
new = dict(message)
|
||||
new.pop("stream", None)
|
||||
new.pop("target", None)
|
||||
content = new.get("content")
|
||||
if content is None:
|
||||
new["content"] = ""
|
||||
elif isinstance(content, str):
|
||||
pass
|
||||
elif isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if block.get("type") == "text":
|
||||
t = block.get("text", "")
|
||||
if isinstance(t, str):
|
||||
parts.append(t)
|
||||
new["content"] = "\n".join(parts)
|
||||
else:
|
||||
new["content"] = str(content)
|
||||
return new
|
||||
|
||||
|
||||
def _is_batched_messages(messages: Any) -> bool:
|
||||
return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list)
|
||||
|
||||
|
||||
def _sample_indices(value: Any, batch_size: int) -> list[int | None]:
|
||||
if value is None:
|
||||
return [None] * batch_size
|
||||
if isinstance(value, torch.Tensor):
|
||||
if value.numel() == 1:
|
||||
return [int(value.item())] * batch_size
|
||||
values = value.reshape(-1).tolist()
|
||||
return [int(v) for v in values[:batch_size]]
|
||||
if isinstance(value, (list, tuple)):
|
||||
if len(value) == 1:
|
||||
return _sample_indices(value[0], batch_size)
|
||||
return [int(v.item() if hasattr(v, "item") else v) for v in value[:batch_size]]
|
||||
return [int(value)] * batch_size
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VQA spatial answers → PaliGemma <loc> format (PI052 only)
|
||||
#
|
||||
# PaliGemma is pre-trained on detection / pointing with a ``<locNNNN>``
|
||||
# vocabulary (normalized [0, 1023]). The recipe's bbox / keypoint VQA
|
||||
# answers are stored as JSON in Qwen2.5-VL's grounding convention:
|
||||
# **0–1000 normalized coordinates**, NOT pixels. (Verified empirically
|
||||
# on the published datasets: x and y both span 0..1000 with ~30% of
|
||||
# values exceeding the camera's pixel dimensions — they're not pixels.)
|
||||
# Converting to ``<loc>`` is therefore camera-resolution-independent:
|
||||
# ``loc_idx = round(coord / 1000 * 1023)``. We do the conversion here —
|
||||
# not in the dataset — so the dataset keeps the raw JSON and stays
|
||||
# backbone-agnostic.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# The 0–1000 scale Qwen2.5-VL emits for grounding coordinates.
|
||||
_VQA_COORD_SCALE = 1000.0
|
||||
|
||||
|
||||
def register_paligemma_loc_tokens(tokenizer: Any) -> Any:
|
||||
"""Make PaliGemma's ``<locDDDD>`` ids match on raw text — single tokens.
|
||||
|
||||
PaliGemma reserves vocab ids [256000, 257023] for ``<locDDDD>``
|
||||
(detection / pointing) tokens, but the *stock* tokenizer does NOT
|
||||
match them when encoding raw text — it BPE-splits ``<loc0162>`` into
|
||||
7 pieces (``<``, ``loc``, ``0``, ``1``, ``6``, ``2``, ``>``). Training
|
||||
the LM head on a ``<loc>`` target then supervises those 7 generic
|
||||
BPE pieces instead of one detection-vocab id, the LM head learns to
|
||||
emit the *character sequence*, and those pieces' logits dominate
|
||||
other turns (the ``<loc>``-salad on subtasks). Registering the loc
|
||||
tokens once makes them tokenize as their single ids (256000+idx),
|
||||
leveraging PaliGemma's detection prior properly. Idempotent.
|
||||
"""
|
||||
if "<loc0000>" in getattr(tokenizer, "added_tokens_encoder", {}):
|
||||
return tokenizer
|
||||
tokenizer.add_tokens([f"<loc{i:04d}>" for i in range(1024)])
|
||||
return tokenizer
|
||||
|
||||
|
||||
def _loc_token(coord: float, scale: float = _VQA_COORD_SCALE) -> str:
|
||||
"""PaliGemma ``<locNNNN>`` for a coord on a ``[0, scale]`` axis."""
|
||||
idx = round(float(coord) / scale * 1023) if scale > 0 else 0
|
||||
return f"<loc{max(0, min(1023, idx)):04d}>"
|
||||
|
||||
|
||||
def _vqa_answer_to_loc(answer: dict[str, Any]) -> str | None:
|
||||
"""Convert a bbox / keypoint VQA answer dict to PaliGemma ``<loc>`` text.
|
||||
|
||||
Input coordinates are in Qwen2.5-VL's 0–1000 normalized space (see
|
||||
module-level note). y is emitted before x for each coordinate pair
|
||||
(PaliGemma convention), with the integer indices in [0, 1023].
|
||||
|
||||
**Format: label first, locs after.** PaliGemma's pretraining puts
|
||||
locs first (``<loc><loc> label``), but for our small-dataset VQA
|
||||
blend that turns the LM head into a loc-emission attractor at every
|
||||
``Assistant:`` position — VQA targets share their first supervised
|
||||
token with ~25% of all text samples, and the head collapses to
|
||||
emitting ``<loc>`` regardless of the prompt. Putting the label
|
||||
first (``label <locY><locX>``) means every text sample (subtask,
|
||||
memory, VQA, …) starts the supervised target with a real word,
|
||||
breaking the attractor. The model still learns the loc vocabulary
|
||||
for the *spatial* portion of the answer; it just can't fire it as
|
||||
the first generation step from a clean prompt.
|
||||
|
||||
Returns ``None`` for non-spatial answers (count / attribute /
|
||||
spatial-relation) — those keep their JSON form.
|
||||
"""
|
||||
point = answer.get("point")
|
||||
if isinstance(point, list | tuple) and len(point) == 2 and "point_format" in answer:
|
||||
try:
|
||||
x, y = float(point[0]), float(point[1])
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
label = str(answer.get("label", "")).strip()
|
||||
if not label:
|
||||
return None
|
||||
return f"{label} {_loc_token(y)}{_loc_token(x)}"
|
||||
|
||||
detections = answer.get("detections")
|
||||
if isinstance(detections, list) and detections:
|
||||
parts: list[str] = []
|
||||
for det in detections:
|
||||
if not isinstance(det, dict):
|
||||
continue
|
||||
box = det.get("bbox")
|
||||
if not (isinstance(box, list | tuple) and len(box) == 4):
|
||||
continue
|
||||
try:
|
||||
x1, y1, x2, y2 = (float(v) for v in box)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
label = str(det.get("label", "")).strip()
|
||||
if not label:
|
||||
continue
|
||||
toks = (
|
||||
f"{_loc_token(y1)}{_loc_token(x1)}"
|
||||
f"{_loc_token(y2)}{_loc_token(x2)}"
|
||||
)
|
||||
parts.append(f"{label} {toks}")
|
||||
return " ; ".join(parts) if parts else None
|
||||
return None
|
||||
|
||||
|
||||
def _messages_vqa_to_loc(
|
||||
messages: list[dict[str, Any]],
|
||||
target_indices: list[int],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Rewrite bbox / keypoint VQA *target* answers from JSON to ``<loc>`` text.
|
||||
|
||||
Each target turn whose content parses as a spatial VQA answer is
|
||||
converted. Non-spatial answers and subtask / memory targets (plain
|
||||
text → not JSON) are left untouched. Camera-independent: VQA coords
|
||||
are 0–1000 normalized, so no observation lookup is needed.
|
||||
"""
|
||||
if not target_indices:
|
||||
return messages
|
||||
out = list(messages)
|
||||
for idx in target_indices:
|
||||
if not (0 <= idx < len(out)):
|
||||
continue
|
||||
content = out[idx].get("content")
|
||||
if not isinstance(content, str) or not content.strip():
|
||||
continue
|
||||
try:
|
||||
answer = json.loads(content)
|
||||
except (ValueError, TypeError):
|
||||
continue # subtask / memory targets are plain text — skip
|
||||
if not isinstance(answer, dict):
|
||||
continue
|
||||
loc_text = _vqa_answer_to_loc(answer)
|
||||
if loc_text is not None:
|
||||
out[idx] = {**out[idx], "content": loc_text}
|
||||
return out
|
||||
|
||||
|
||||
def _format_messages(
|
||||
messages: list[dict[str, Any]],
|
||||
target_indices: list[int] | None = None,
|
||||
eos_token: str | None = None,
|
||||
) -> tuple[str, list[tuple[int, int]]]:
|
||||
"""Concatenate messages into the π0.5-style flat prompt.
|
||||
|
||||
When both ``target_indices`` and ``eos_token`` are given, the EOS
|
||||
string is appended to each supervised target turn's content and the
|
||||
returned span covers it — so the label builder marks the EOS token
|
||||
as a supervised label. That teaches the LM head where the answer
|
||||
*ends*: without an EOS in the target span the model is never given a
|
||||
stop signal and rambles to ``max_length`` at inference. Inference
|
||||
callers omit both args (no EOS baked into the prompt — the model
|
||||
generates it and ``select_message`` stops on it).
|
||||
|
||||
Returns:
|
||||
prompt: the full text the tokenizer will consume.
|
||||
msg_spans: list of ``(char_start, char_end)`` covering each
|
||||
message's supervised payload (content, plus the
|
||||
appended EOS for target turns) within ``prompt``.
|
||||
"""
|
||||
targets = set(target_indices or [])
|
||||
parts: list[str] = []
|
||||
spans: list[tuple[int, int]] = []
|
||||
cursor = 0
|
||||
for i, m in enumerate(messages):
|
||||
role = m.get("role", "user")
|
||||
content = m.get("content", "") or ""
|
||||
# Role tag + newline. The model has to learn to emit the same
|
||||
# role tokens at generation time, which is fine for greedy
|
||||
# decoding because the chat template is implicit in the
|
||||
# supervised target span.
|
||||
header = f"{role.capitalize()}: "
|
||||
# A supervised target turn ends with EOS so the model learns to
|
||||
# terminate; the span below covers content + EOS. Non-target
|
||||
# turns (and inference) carry no EOS.
|
||||
body = content + eos_token if (eos_token and i in targets) else content
|
||||
# span covers the content (+ EOS) portion only — never the role
|
||||
# tag — so labels are computed over the supervised payload.
|
||||
full = header + body + "\n"
|
||||
start = cursor + len(header)
|
||||
end = start + len(body)
|
||||
parts.append(full)
|
||||
spans.append((start, end))
|
||||
cursor += len(full)
|
||||
return "".join(parts), spans
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="pi052_text_tokenizer")
|
||||
class PI052TextTokenizerStep(ProcessorStep):
|
||||
"""Render messages → token ids + label mask + predict_actions flag.
|
||||
|
||||
No chat template; concatenates messages as
|
||||
``User: ... \\nAssistant: ...`` text.
|
||||
"""
|
||||
|
||||
tokenizer_name: str = "google/paligemma-3b-pt-224"
|
||||
max_length: int = 200
|
||||
padding: str = "max_length"
|
||||
padding_side: str = "right"
|
||||
plan_dropout_prob: float = 0.0
|
||||
memory_dropout_prob: float = 0.0
|
||||
subtask_dropout_prob: float = 0.0
|
||||
interjection_dropout_prob: float = 0.0
|
||||
dropout_seed: int | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._tokenizer: Any = None
|
||||
|
||||
def _ensure_tokenizer(self) -> Any:
|
||||
if self._tokenizer is not None:
|
||||
return self._tokenizer
|
||||
from transformers import AutoTokenizer # noqa: PLC0415
|
||||
|
||||
self._tokenizer = register_paligemma_loc_tokens(
|
||||
AutoTokenizer.from_pretrained(self.tokenizer_name)
|
||||
)
|
||||
return self._tokenizer
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Pipeline step
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
|
||||
transition = transition.copy()
|
||||
complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
|
||||
messages = complementary.get("messages") or []
|
||||
|
||||
if not messages:
|
||||
# No recipe was rendered — caller will fall back to the
|
||||
# plain Pi0.5 prompt path. We pass the transition through
|
||||
# unmodified.
|
||||
return transition
|
||||
|
||||
tokenizer = self._ensure_tokenizer()
|
||||
# Normalized proprioceptive state (set by NormalizerProcessorStep, which
|
||||
# runs before this step). Injected into low-level action prompts so the
|
||||
# action expert sees proprioception, matching pi05's discretized State:.
|
||||
state_all = (transition.get(TransitionKey.OBSERVATION) or {}).get(OBS_STATE)
|
||||
# VQA coords are 0–1000 normalized (Qwen2.5-VL convention) — the
|
||||
# <loc> conversion is camera-resolution-independent and needs no
|
||||
# observation lookup here.
|
||||
if _is_batched_messages(messages):
|
||||
indices_iter = _sample_indices(complementary.get("index"), len(messages))
|
||||
encoded = [
|
||||
self._encode_messages(
|
||||
tokenizer,
|
||||
msg,
|
||||
list(streams),
|
||||
list(tgt_indices),
|
||||
complementary,
|
||||
sample_idx=int(s_idx) if s_idx is not None else None,
|
||||
state_row=_state_row_at(state_all, pos),
|
||||
)
|
||||
for pos, (msg, streams, tgt_indices, s_idx) in enumerate(
|
||||
zip(
|
||||
messages,
|
||||
complementary.get("message_streams") or [[] for _ in messages],
|
||||
complementary.get("target_message_indices") or [[] for _ in messages],
|
||||
indices_iter,
|
||||
strict=False,
|
||||
)
|
||||
)
|
||||
]
|
||||
else:
|
||||
sample_idx = _sample_indices(complementary.get("index"), 1)[0]
|
||||
encoded = [
|
||||
self._encode_messages(
|
||||
tokenizer,
|
||||
messages,
|
||||
list(complementary.get("message_streams") or []),
|
||||
list(complementary.get("target_message_indices") or []),
|
||||
complementary,
|
||||
sample_idx=sample_idx,
|
||||
state_row=_state_row_at(state_all, 0),
|
||||
)
|
||||
]
|
||||
|
||||
obs = dict(transition.get(TransitionKey.OBSERVATION) or {})
|
||||
obs[OBS_LANGUAGE_TOKENS] = torch.stack([ids for ids, _, _, _, _ in encoded])
|
||||
obs[OBS_LANGUAGE_ATTENTION_MASK] = torch.stack([attn for _, attn, _, _, _ in encoded])
|
||||
transition[TransitionKey.OBSERVATION] = obs
|
||||
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = {
|
||||
**complementary,
|
||||
"text_labels": torch.stack([labels for _, _, labels, _, _ in encoded]),
|
||||
"predict_actions": torch.stack([pred for _, _, _, pred, _ in encoded]),
|
||||
}
|
||||
return transition
|
||||
|
||||
def _encode_messages(
|
||||
self,
|
||||
tokenizer: Any,
|
||||
messages: list[dict[str, Any]],
|
||||
message_streams: list[str | None],
|
||||
target_indices: list[int],
|
||||
complementary: dict[str, Any],
|
||||
sample_idx: int | None = None,
|
||||
state_row: Any = None,
|
||||
) -> tuple[Tensor, Tensor, Tensor, Tensor, str]:
|
||||
# Optional: drop non-target messages per the dropout config.
|
||||
# Keeps the supervised-target indices stable by re-mapping
|
||||
# after removal.
|
||||
if (
|
||||
self.plan_dropout_prob
|
||||
or self.memory_dropout_prob
|
||||
or self.subtask_dropout_prob
|
||||
or self.interjection_dropout_prob
|
||||
):
|
||||
messages, target_indices = self._apply_prompt_dropout(
|
||||
messages,
|
||||
target_indices,
|
||||
complementary,
|
||||
sample_idx=sample_idx,
|
||||
)
|
||||
|
||||
# Rewrite bbox / keypoint VQA target answers from JSON to
|
||||
# PaliGemma <loc> text. Coords are 0–1000 normalized so this is
|
||||
# camera-independent.
|
||||
messages = _messages_vqa_to_loc(messages, target_indices)
|
||||
|
||||
# Flatten ``say`` tool calls into ``<say>...</say>`` text before
|
||||
# stripping, so the spoken reply is actually tokenized and
|
||||
# supervised (PaliGemma's flat prompt has no structured calls).
|
||||
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in messages]
|
||||
# Low-level (action-conditioning) samples get the discretized state
|
||||
# appended to their user message, mirroring pi05's
|
||||
# "..., State: {256-bin};" so the action expert sees proprioception.
|
||||
# Higher-level text streams (subtask/memory generation) stay state-free.
|
||||
if state_row is not None and any(s == "low_level" for s in message_streams):
|
||||
state_str = discretize_state_str(state_row)
|
||||
for m in reversed(messages):
|
||||
if m.get("role") == "user":
|
||||
base = _content_to_text(m.get("content", ""))
|
||||
m["content"] = f"{base}, State: {state_str};"
|
||||
break
|
||||
# Append EOS to supervised target turns so the LM head learns to
|
||||
# stop (the span covers it → it becomes a supervised label).
|
||||
prompt, spans = _format_messages(
|
||||
messages, target_indices, getattr(tokenizer, "eos_token", None)
|
||||
)
|
||||
|
||||
encoded = tokenizer(
|
||||
prompt,
|
||||
max_length=self.max_length,
|
||||
padding=self.padding,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
return_offsets_mapping=True,
|
||||
padding_side=self.padding_side,
|
||||
)
|
||||
|
||||
input_ids = encoded["input_ids"][0]
|
||||
attention_mask = encoded["attention_mask"][0].bool()
|
||||
offsets = encoded["offset_mapping"][0] # (seq, 2), char (start,end)
|
||||
|
||||
# Build label mask: -100 everywhere except over supervised
|
||||
# target message char ranges.
|
||||
labels = torch.full_like(input_ids, fill_value=-100)
|
||||
for idx in target_indices:
|
||||
if idx >= len(spans):
|
||||
continue
|
||||
char_start, char_end = spans[idx]
|
||||
for token_pos in range(input_ids.shape[0]):
|
||||
if not attention_mask[token_pos]:
|
||||
continue
|
||||
tok_start, tok_end = int(offsets[token_pos, 0]), int(offsets[token_pos, 1])
|
||||
if tok_end <= char_start or tok_start >= char_end:
|
||||
continue
|
||||
labels[token_pos] = input_ids[token_pos]
|
||||
|
||||
# Scan ALL message streams (not just targets): the
|
||||
# ``low_level_execution`` recipe drops ``target: true`` on
|
||||
# the assistant to avoid trivial copy-from-user text-CE; the
|
||||
# flow loss still needs to fire, gated by ``stream: low_level``.
|
||||
predict_actions = torch.tensor(
|
||||
bool(any(s == "low_level" for s in message_streams)),
|
||||
dtype=torch.bool,
|
||||
)
|
||||
return input_ids, attention_mask, labels, predict_actions, prompt
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Per-component prompt dropout (Pi0.7 §V.E)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _apply_prompt_dropout(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
target_indices: list[int],
|
||||
complementary: dict[str, Any],
|
||||
sample_idx: int | None = None,
|
||||
) -> tuple[list[dict[str, Any]], list[int]]:
|
||||
"""Drop messages classified as plan/memory/subtask context.
|
||||
|
||||
Targets are *never* dropped (they're the supervised payload).
|
||||
Re-maps target_indices to the new positions after drops.
|
||||
"""
|
||||
import random # noqa: PLC0415
|
||||
|
||||
seed = self.dropout_seed
|
||||
if seed is None:
|
||||
# Canonical row-index key set by ``BatchProcessor`` /
|
||||
# ``render_messages_processor``. Falling back to other
|
||||
# keys silently gave every sample seed=0 → identical
|
||||
# dropout pattern across the whole epoch.
|
||||
seed_src = sample_idx if sample_idx is not None else complementary.get("index", 0)
|
||||
try:
|
||||
if hasattr(seed_src, "item"):
|
||||
seed_src = seed_src.item()
|
||||
seed = int(seed_src)
|
||||
except (TypeError, ValueError):
|
||||
seed = 0
|
||||
rng = random.Random(seed)
|
||||
|
||||
keep_indices: list[int] = []
|
||||
for idx, msg in enumerate(messages):
|
||||
if idx in target_indices:
|
||||
keep_indices.append(idx)
|
||||
continue
|
||||
kind = _classify_for_dropout(msg)
|
||||
prob = {
|
||||
"plan": self.plan_dropout_prob,
|
||||
"memory": self.memory_dropout_prob,
|
||||
"subtask": self.subtask_dropout_prob,
|
||||
"interjection": self.interjection_dropout_prob,
|
||||
}.get(kind, 0.0)
|
||||
if prob > 0.0 and rng.random() < prob:
|
||||
continue
|
||||
keep_indices.append(idx)
|
||||
|
||||
# Build remap and apply
|
||||
new_messages = [messages[i] for i in keep_indices]
|
||||
old_to_new = {old: new for new, old in enumerate(keep_indices)}
|
||||
new_targets = [old_to_new[t] for t in target_indices if t in old_to_new]
|
||||
return new_messages, new_targets
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
def _classify_for_dropout(message: dict[str, Any]) -> str | None:
|
||||
"""Heuristic content-prefix classifier (plan / memory / subtask)."""
|
||||
content = message.get("content")
|
||||
if isinstance(content, list):
|
||||
text_parts = [b.get("text", "") for b in content if isinstance(b, dict) and b.get("type") == "text"]
|
||||
content = " ".join(text_parts)
|
||||
elif content is None:
|
||||
return None
|
||||
elif not isinstance(content, str):
|
||||
return None
|
||||
s = content.strip()
|
||||
if s.startswith("Plan:") or s.startswith("Previous plan"):
|
||||
return "plan"
|
||||
if s.startswith("Memory:") or s.startswith("Previous memory"):
|
||||
return "memory"
|
||||
if s.startswith("Current subtask") or s.startswith("Completed subtask"):
|
||||
return "subtask"
|
||||
return None
|
||||
@@ -14,28 +14,18 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import functional as F # noqa: N812
|
||||
from torch import nn
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
# Default PaliGemma SigLIP input resolution. Mirrors
|
||||
# ``pi05.configuration_pi05.DEFAULT_IMAGE_SIZE``; duplicated as a plain constant
|
||||
# to avoid importing the pi05 package here (which would create an import cycle:
|
||||
# pi_gemma -> pi05.__init__ -> modeling_pi05 -> pi_gemma).
|
||||
DEFAULT_IMAGE_SIZE = 224
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.masking_utils import create_causal_mask
|
||||
from transformers.modeling_layers import GradientCheckpointingLayer
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
from transformers.models.gemma.modeling_gemma import (
|
||||
GemmaAttention,
|
||||
GemmaConfig,
|
||||
@@ -59,8 +49,6 @@ else:
|
||||
GradientCheckpointingLayer = None
|
||||
BaseModelOutputWithPast = None
|
||||
create_causal_mask = None
|
||||
CONFIG_MAPPING = None
|
||||
modeling_gemma = None
|
||||
|
||||
|
||||
def _gated_residual(
|
||||
@@ -287,8 +275,6 @@ class PiGemmaModel(GemmaModel): # type: ignore[misc]
|
||||
# Convert to bfloat16 if the first layer uses bfloat16
|
||||
if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(torch.bfloat16)
|
||||
if causal_mask is not None and torch.is_floating_point(causal_mask):
|
||||
causal_mask = causal_mask.to(dtype=hidden_states.dtype)
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
@@ -381,374 +367,3 @@ __all__ = [
|
||||
"PaliGemmaModelWithPiGemma",
|
||||
"PaliGemmaForConditionalGenerationWithPiGemma",
|
||||
]
|
||||
|
||||
|
||||
# PI0.5 / PI052 dual-expert backbone: generic PaliGemma + Gemma action-expert
|
||||
# transformer machinery used by the pi052 policy. GemmaVariantConfig is openpi's
|
||||
# width/depth variant config (renamed from GemmaConfig to avoid clashing with
|
||||
# transformers' GemmaConfig).
|
||||
|
||||
def sdpa_attention_forward(
|
||||
module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
):
|
||||
"""Drop-in for ``modeling_gemma.eager_attention_forward`` using
|
||||
``torch.nn.functional.scaled_dot_product_attention``.
|
||||
|
||||
PyTorch SDPA picks the memory-efficient kernel for arbitrary additive
|
||||
bias masks (the FA backend only accepts causal/sliding-window). On
|
||||
H100 that is ~1.3-1.7x faster and uses ~30-40% less attention memory
|
||||
than the eager softmax(QK^T)+matmul path. Mirrors eager's signature
|
||||
and output shape (``(B, Lq, H, D)``) so call sites are unchanged.
|
||||
"""
|
||||
n_rep = module.num_key_value_groups
|
||||
if n_rep > 1:
|
||||
key = key.repeat_interleave(n_rep, dim=1)
|
||||
value = value.repeat_interleave(n_rep, dim=1)
|
||||
if attention_mask is not None and attention_mask.dtype != query.dtype:
|
||||
attention_mask = attention_mask.to(dtype=query.dtype)
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=dropout if module.training else 0.0,
|
||||
is_causal=False,
|
||||
scale=scaling,
|
||||
)
|
||||
return attn_output.transpose(1, 2).contiguous(), None
|
||||
|
||||
|
||||
# Define the complete layer computation function for gradient checkpointing
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
models = [paligemma.model.language_model, gemma_expert.model]
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||
gates.append(gate)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
query_states.append(query_state)
|
||||
key_states.append(key_state)
|
||||
value_states.append(value_state)
|
||||
# Concatenate and process attention
|
||||
query_states = torch.cat(query_states, dim=2)
|
||||
key_states = torch.cat(key_states, dim=2)
|
||||
value_states = torch.cat(value_states, dim=2)
|
||||
dummy_tensor = torch.zeros(
|
||||
query_states.shape[0],
|
||||
query_states.shape[2],
|
||||
query_states.shape[-1],
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
att_output, _ = sdpa_attention_forward(
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
start_pos = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
end_pos = start_pos + hidden_states.shape[1]
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
|
||||
# first residual
|
||||
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
|
||||
after_first_residual = out_emb.clone()
|
||||
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
|
||||
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
|
||||
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
|
||||
out_emb = out_emb.to(dtype=torch.bfloat16)
|
||||
out_emb = layer.mlp(out_emb)
|
||||
# second residual
|
||||
out_emb = _gated_residual(after_first_residual, out_emb, gate)
|
||||
outputs_embeds.append(out_emb)
|
||||
start_pos = end_pos
|
||||
return outputs_embeds
|
||||
|
||||
|
||||
class GemmaVariantConfig: # see openpi `gemma.py: Config`
|
||||
"""Configuration for Gemma model variants."""
|
||||
|
||||
def __init__(self, width, depth, mlp_dim, num_heads, num_kv_heads, head_dim):
|
||||
self.width = width
|
||||
self.depth = depth
|
||||
self.mlp_dim = mlp_dim
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
|
||||
def get_gemma_config(variant: str) -> GemmaVariantConfig: # see openpi `gemma.py: get_config`
|
||||
"""Returns config for specified gemma variant."""
|
||||
if variant == "gemma_300m":
|
||||
return GemmaVariantConfig(
|
||||
width=1024,
|
||||
depth=18,
|
||||
mlp_dim=4096,
|
||||
num_heads=8,
|
||||
num_kv_heads=1,
|
||||
head_dim=256,
|
||||
)
|
||||
elif variant == "gemma_2b":
|
||||
return GemmaVariantConfig(
|
||||
width=2048,
|
||||
depth=18,
|
||||
mlp_dim=16_384,
|
||||
num_heads=8,
|
||||
num_kv_heads=1,
|
||||
head_dim=256,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown variant: {variant}")
|
||||
|
||||
|
||||
class PaliGemmaWithExpertModel(
|
||||
nn.Module
|
||||
): # see openpi `gemma_pytorch.py: PaliGemmaWithExpertModel` this class is almost a exact copy of PaliGemmaWithExpertModel in openpi
|
||||
"""PaliGemma model with action expert for PI05."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vlm_config,
|
||||
action_expert_config,
|
||||
use_adarms=None,
|
||||
precision: Literal["bfloat16", "float32"] = "bfloat16",
|
||||
image_size: int = DEFAULT_IMAGE_SIZE,
|
||||
freeze_vision_encoder: bool = False,
|
||||
train_expert_only: bool = False,
|
||||
):
|
||||
if use_adarms is None:
|
||||
use_adarms = [False, False]
|
||||
super().__init__()
|
||||
self.freeze_vision_encoder = freeze_vision_encoder
|
||||
self.train_expert_only = train_expert_only
|
||||
|
||||
vlm_config_hf = CONFIG_MAPPING["paligemma"]()
|
||||
vlm_config_hf._vocab_size = 257152 # noqa: SLF001
|
||||
vlm_config_hf.image_token_index = 257152
|
||||
vlm_config_hf.text_config.hidden_size = vlm_config.width
|
||||
vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
|
||||
vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
|
||||
vlm_config_hf.text_config.head_dim = vlm_config.head_dim
|
||||
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
||||
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
||||
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
||||
vlm_config_hf.text_config.dtype = "float32"
|
||||
vlm_config_hf.text_config.vocab_size = 257152
|
||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||
vlm_config_hf.vision_config.image_size = image_size
|
||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||
vlm_config_hf.vision_config.projection_dim = 2048
|
||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||
vlm_config_hf.vision_config.dtype = "float32"
|
||||
|
||||
action_expert_config_hf = CONFIG_MAPPING["gemma"](
|
||||
head_dim=action_expert_config.head_dim,
|
||||
hidden_size=action_expert_config.width,
|
||||
intermediate_size=action_expert_config.mlp_dim,
|
||||
num_attention_heads=action_expert_config.num_heads,
|
||||
num_hidden_layers=action_expert_config.depth,
|
||||
num_key_value_heads=action_expert_config.num_kv_heads,
|
||||
vocab_size=257152,
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
dtype="float32",
|
||||
use_adarms=use_adarms[1],
|
||||
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
|
||||
)
|
||||
|
||||
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
|
||||
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
|
||||
self.gemma_expert.model.embed_tokens = None
|
||||
|
||||
self.to_bfloat16_for_selected_params(precision)
|
||||
self._set_requires_grad()
|
||||
|
||||
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
|
||||
if precision == "bfloat16":
|
||||
self.to(dtype=torch.bfloat16)
|
||||
elif precision == "float32":
|
||||
self.to(dtype=torch.float32)
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Invalid precision: {precision}")
|
||||
|
||||
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
|
||||
# "same dtype" error). Saves memory vs full float32; more memory than only 3 params.
|
||||
params_to_keep_float32 = [
|
||||
"vision_tower",
|
||||
"multi_modal_projector",
|
||||
"lm_head",
|
||||
"input_layernorm",
|
||||
"post_attention_layernorm",
|
||||
"model.norm",
|
||||
]
|
||||
|
||||
for name, param in self.named_parameters():
|
||||
if any(selector in name for selector in params_to_keep_float32):
|
||||
param.data = param.data.to(dtype=torch.float32)
|
||||
|
||||
def _set_requires_grad(self):
|
||||
if self.freeze_vision_encoder:
|
||||
self.paligemma.model.vision_tower.eval()
|
||||
for param in self.paligemma.model.vision_tower.parameters():
|
||||
param.requires_grad = False
|
||||
if self.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
for param in self.paligemma.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
super().train(mode)
|
||||
if self.freeze_vision_encoder:
|
||||
self.paligemma.model.vision_tower.eval()
|
||||
if self.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32).
|
||||
out_dtype = image.dtype
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
# OpenPI / big_vision convention: image (soft) tokens are NOT scaled by the
|
||||
# Gemma embedder normalizer (sqrt(hidden_size)) — only text tokens are. lerobot/pi05_base
|
||||
# was trained in this regime, so scaling image features here over-scales them ~45x and
|
||||
# breaks the pretrained vision-language alignment. Keep image features un-normalized.
|
||||
features = image_outputs.pooler_output
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: list[torch.FloatTensor] | None = None,
|
||||
inputs_embeds: list[torch.FloatTensor] | None = None,
|
||||
use_cache: bool | None = None,
|
||||
adarms_cond: list[torch.Tensor] | None = None,
|
||||
):
|
||||
if adarms_cond is None:
|
||||
adarms_cond = [None, None]
|
||||
if inputs_embeds[1] is None:
|
||||
prefix_output = self.paligemma.model.language_model.forward(
|
||||
inputs_embeds=inputs_embeds[0],
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
|
||||
)
|
||||
prefix_past_key_values = prefix_output.past_key_values
|
||||
prefix_output = prefix_output.last_hidden_state
|
||||
suffix_output = None
|
||||
elif inputs_embeds[0] is None:
|
||||
suffix_output = self.gemma_expert.model.forward(
|
||||
inputs_embeds=inputs_embeds[1],
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
|
||||
)
|
||||
suffix_output = suffix_output.last_hidden_state
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
use_gradient_checkpointing = (
|
||||
hasattr(self.gemma_expert.model, "gradient_checkpointing")
|
||||
and self.gemma_expert.model.gradient_checkpointing
|
||||
and self.training
|
||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||
|
||||
# Process all layers with gradient checkpointing if enabled
|
||||
for layer_idx in range(num_layers):
|
||||
if use_gradient_checkpointing:
|
||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_layer_complete,
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
)
|
||||
else:
|
||||
inputs_embeds = compute_layer_complete(
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
)
|
||||
|
||||
# final norm
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
# Apply gradient checkpointing to final norm if enabled
|
||||
if use_gradient_checkpointing:
|
||||
outputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_final_norms,
|
||||
inputs_embeds,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
)
|
||||
else:
|
||||
outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
|
||||
|
||||
prefix_output = outputs_embeds[0]
|
||||
suffix_output = outputs_embeds[1]
|
||||
prefix_past_key_values = None
|
||||
|
||||
return [prefix_output, suffix_output], prefix_past_key_values
|
||||
|
||||
|
||||
@@ -0,0 +1,286 @@
|
||||
# Remote Inference Architecture
|
||||
|
||||
How `lerobot-policy-server` and `lerobot-rollout --inference.type=remote` decouple GPU-bound policy inference from high-frequency robot control over Zenoh.
|
||||
|
||||
This document explains the **internals** — the wire protocol, threading models, state machines, and safety invariants. For the user-facing guide (CLI quickstarts, deployment), see [`docs/source/remote_inference.mdx`](../../../docs/source/remote_inference.mdx).
|
||||
|
||||
## 1. The problem and the shape of the solution
|
||||
|
||||
Running a large policy (Pi0-class, ~150 ms inference) inside a 33 ms control loop doesn't work, and putting a GPU next to every robot doesn't scale. LeRobot already solved the _local_ version of this problem: `RTCInferenceEngine` runs inference in a background **thread** that fills a thread-safe `ActionQueue`, while the control loop pops one action per tick.
|
||||
|
||||
Remote inference is **that same architecture with the thread boundary replaced by a network boundary**:
|
||||
|
||||
```
|
||||
local RTC: control loop ──ActionQueue── inference thread (same process, same GPU)
|
||||
remote: control loop ──ActionQueue── network worker ══zenoh══ policy server (GPU, elsewhere)
|
||||
```
|
||||
|
||||
Three design commitments follow from this:
|
||||
|
||||
- **The client is a backend, not a CLI.** `RemoteInferenceEngine` plugs into the existing `InferenceEngine` seam (`rollout/inference/base.py`), so every rollout strategy (base, sentry, highlight, dagger, episodic) gets network inference — including dataset recording, pause/resume, and safe teardown — without changing a line.
|
||||
- **The client is weightless.** No policy weights, no policy processors on the edge. `--policy.path` resolves to a config-only `PreTrainedConfig` used for pre-flight validation and action ordering.
|
||||
- **The server is stateless per request.** All chunk state (RTC prefixes, latency tracking, delay computation) lives client-side in the existing `ActionQueue`/`LatencyTracker`. The client ships prefixes + a delay hint with every observation, so a server crash loses zero control state.
|
||||
|
||||
## 2. Component map
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
subgraph EDGE["Edge (per robot, weightless)"]
|
||||
R[Robot HW] --> S["Rollout strategy<br/>(sentry / dagger / ...)"]
|
||||
S -->|"notify_observation()"| E[RemoteInferenceEngine]
|
||||
E -->|"get_action()"| S
|
||||
S -->|actions| R
|
||||
E --- AQ[("ActionQueue<br/>(chunk buffer)")]
|
||||
end
|
||||
|
||||
subgraph NET["Transport"]
|
||||
Z["zenohd router(s)<br/>(robots dial out, mTLS + ACL)"]
|
||||
end
|
||||
|
||||
subgraph GPU["GPU pod (one model · one device · one process)"]
|
||||
PS[PolicyServer]
|
||||
PS --- SR["SessionRegistry<br/>(per-client mailboxes + pipelines)"]
|
||||
PS --- W["Inference worker<br/>(1 thread, owns GPU)"]
|
||||
W --- P["PreTrainedPolicy<br/>(pre-warmed)"]
|
||||
end
|
||||
|
||||
E <-->|"obs ↑ / chunks ↓ (pub/sub)<br/>status · session · reset (queryables)"| Z
|
||||
Z <--> PS
|
||||
```
|
||||
|
||||
One server process = one pre-warmed `(model, revision, dtype, device)` serving up to `max_sessions` robots. Scaling out = more pods; clients rejected with the current load retry another replica.
|
||||
|
||||
## 3. Where the network cut goes
|
||||
|
||||
The local RTC pipeline is split at the cheapest, most hardware-coupled point. Everything policy-coupled (resize, normalize, tokenize) runs server-side with the **canonical training-time processors**, so serve-time preprocessing is byte-identical to train-time:
|
||||
|
||||
```
|
||||
robot obs (processed dict)
|
||||
→ build_dataset_frame(...) CLIENT cheap, hardware-coupled
|
||||
→ rename_map applied to keys CLIENT wire format = canonical policy keys
|
||||
══════════════════════ network (msgpack + JPEG) ══════════════════════
|
||||
→ prepare_observation_for_inference(...) SERVER tensors, batch dim, device
|
||||
→ per-session preprocessor(...) SERVER stateful within the request
|
||||
→ policy.predict_action_chunk(obs, delay, prefix) SERVER pure for allowlisted policies
|
||||
→ per-session postprocessor(...) SERVER reads state cached at preprocess
|
||||
══════════════════════ network ══════════════════════
|
||||
→ ActionQueue.merge(original, processed, delay, idx_before) CLIENT
|
||||
```
|
||||
|
||||
The reply carries **both** the model-space (`chunk_model`) and robot-space (`chunk_robot`) chunks because `ActionQueue.merge` needs both, and the next request's relative-action prefix re-anchoring needs the robot-space tail.
|
||||
|
||||
## 4. Wire protocol
|
||||
|
||||
### 4.1 Key-expression schema (Zenoh)
|
||||
|
||||
```
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/<client_uuid>/obs client → server pub/sub
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/<client_uuid>/action server → client pub/sub
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/status queryable (capabilities)
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/session queryable (open / close)
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/<client_uuid>/reset queryable (episode boundary)
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/<client_uuid>/alive liveliness token (client)
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/server/alive liveliness token (server)
|
||||
```
|
||||
|
||||
`@lerobot` is a **verbatim chunk**: wildcards never match it, so third-party `**` subscribers on a shared router cannot scrape the tree. User-supplied segments are sanitized (`sanitize_key_segment`), and the server subscribes with single-depth wildcards only (`.../*/obs`, never `**`).
|
||||
|
||||
Data plane = pub/sub (a late chunk is still usable; a timed-out query reply is not). Control plane = queryables with explicit timeouts (the rmw_zenoh pattern). QoS (`zenoh_utils.py`): actions are `RELIABLE + congestion DROP + express + INTERACTIVE_HIGH` — **never BLOCK**, so one dead robot uplink can never stall the server's publish path; a dropped chunk is recoverable because the client buffer keeps the robot moving.
|
||||
|
||||
### 4.2 Messages
|
||||
|
||||
Every data-plane message carries a **packed little-endian attachment header** (27 bytes, parsed without touching the body):
|
||||
|
||||
| field | type | meaning |
|
||||
| ---------------- | ---- | --------------------------------------------------------- |
|
||||
| `schema_version` | u16 | negotiated at session open; additive-only body evolution |
|
||||
| `msg_type` | u8 | OBS / CHUNK / EVENT |
|
||||
| `seq_id` | u64 | per-session monotonic; echoed in the chunk |
|
||||
| `episode_id` | u32 | bumped by `reset()` |
|
||||
| `client_mono_ns` | i64 | client monotonic clock — **opaque to the server, echoed** |
|
||||
| `session_epoch` | u32 | bumped per (re)connect; stale-epoch chunks dropped |
|
||||
|
||||
Bodies are msgpack (`codec.py`): tensors as raw little-endian bytes + dtype + shape, images JPEG (RGB convention enforced inside the codec; `jpeg_quality=0` = raw). No pickle anywhere — nothing on the wire can carry code.
|
||||
|
||||
**Clock iron rule:** wall-clock instants never cross machines. The client computes RTT from its own monotonic clock via the echoed `client_mono_ns`; the server reports only **durations** (`queue_wait_ms`, `inference_ms`).
|
||||
|
||||
### 4.3 Session lifecycle
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant C as RemoteInferenceEngine
|
||||
participant S as PolicyServer
|
||||
|
||||
C->>S: GET status (timeout 2s)
|
||||
S-->>C: capabilities (model, action_names, cameras, chunk_size, supports_rtc, ...)
|
||||
C->>S: GET session {op: open, action_names, cameras, state_dim, fps, rtc, task}
|
||||
Note over S: validate (hard: action name ORDER,<br/>cameras, state_dim, schema, capacity)
|
||||
S-->>C: SessionAck {session_id, warnings, rtc_execution_horizon, ...}
|
||||
Note over C,S: both declare liveliness tokens
|
||||
|
||||
loop self-clocked by buffer_time_s (one-in-flight)
|
||||
C->>S: PUB obs {state, images, delay_steps, prefix_model, prefix_robot} + header
|
||||
Note over S: latest-only mailbox → worker →<br/>preprocess → predict_action_chunk → postprocess
|
||||
S-->>C: PUB chunk {chunk_model, chunk_robot, durations} + echoed header
|
||||
Note over C: validate (episode, epoch) → ActionQueue.merge(..., idx_before)
|
||||
end
|
||||
|
||||
C->>S: GET reset {episode_id} (episode boundary, acked)
|
||||
C->>S: GET session {op: close} (graceful stop)
|
||||
```
|
||||
|
||||
The **action-name order check is a hard reject**: it is the contract that maps chunk columns to motors. A mismatch means wrong-joint commands, so the session never opens.
|
||||
|
||||
## 5. The client: `RemoteInferenceEngine`
|
||||
|
||||
File: `src/lerobot/rollout/inference/remote.py`, registered as `--inference.type=remote` (`RemoteInferenceConfig` in `factory.py`).
|
||||
|
||||
### 5.1 Threading model
|
||||
|
||||
| thread | role |
|
||||
| ---------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| main (strategy loop) | `notify_observation()` → latest-only slot; `get_action()` → `ActionQueue.get()` + staleness check + fallback. **Never any I/O.** |
|
||||
| network worker (1) | gate on `buffer_time_s` → snapshot `(seq, episode, epoch)` then `idx_before` + RTC prefixes → publish obs → await chunk (timeout) → revalidate → merge. Owns the state machine and reconnects. |
|
||||
| zenoh callback threads | deposit-only: chunk → bounded queue; server liveliness → event. |
|
||||
|
||||
**One-in-flight is a correctness requirement, not a tuning choice.** `merge(..., idx_before)` validates against the consumption index snapshotted at send time; two in-flight requests would carry conflicting snapshots and corrupt both RTC-replace and append modes. The worker therefore publishes one observation, waits for its chunk (or timeout), then sends the next. A late chunk is accepted only if it answers the latest outstanding `seq_id` _and_ the current `(episode, epoch)`.
|
||||
|
||||
### 5.2 The request cycle
|
||||
|
||||
```
|
||||
queue playback ≤ buffer_time_s? (self-clocking: ~1–4 Hz, not the 30 Hz control rate)
|
||||
├─ snapshot (seq, episode, epoch)
|
||||
├─ snapshot idx_before, prefix_model = queue.get_left_over()[:H],
|
||||
│ prefix_robot = queue.get_processed_left_over()[:H]
|
||||
├─ revalidate (episode, epoch) unchanged ← a reset racing the snapshot skips the cycle
|
||||
├─ delay_steps = ceil(LatencyTracker.max() / dt)
|
||||
├─ publish obs + header
|
||||
├─ await chunk (request_timeout_s)
|
||||
├─ revalidate (episode, epoch) under _anchor_lock ← a stale chunk can never survive a reset
|
||||
└─ merge(chunk_model, chunk_robot, ceil(measured_latency/dt), idx_before); update anchor
|
||||
```
|
||||
|
||||
Because the `LatencyTracker` samples are full network-inclusive cycle times, RTT compensation falls out for free — the same `delay`-trimming machinery local RTC uses absorbs network latency as just more delay.
|
||||
|
||||
### 5.3 Fail-safe state machine
|
||||
|
||||
```mermaid
|
||||
stateDiagram-v2
|
||||
[*] --> CONNECTING
|
||||
CONNECTING --> STREAMING: first merge
|
||||
STREAMING --> DEGRADED: request timeouts,<br/>queue still has actions
|
||||
DEGRADED --> STREAMING: merge
|
||||
DEGRADED --> STALLED: queue empty or<br/>max_action_age_s hit
|
||||
STALLED --> RECONNECTING: timeout streak /<br/>server liveliness drop
|
||||
DEGRADED --> RECONNECTING: timeout streak /<br/>server liveliness drop
|
||||
RECONNECTING --> STREAMING: re-handshake OK<br/>(epoch++)
|
||||
RECONNECTING --> DEAD: offline > max_offline_s,<br/>capability/model mismatch
|
||||
DEAD --> [*]: failed=True → shutdown_event<br/>→ strategy teardown
|
||||
```
|
||||
|
||||
- **DEGRADED**: the chunk buffer _is_ the fault tolerance — 1–3 s of buffered actions makes network blips and clean server drains invisible to the robot.
|
||||
- **Staleness bound** (`max_action_age_s`): `get_action` refuses any action whose source observation is too old, bounding open-loop execution after a stall. Then the **fallback ladder** applies: `hold` (return `None`; the robot holds), `repeat_last`, or `zero` (the safe stop for velocity-controlled robots).
|
||||
- **Watchdog layering**: per-request timeout (catches a _hung-but-connected_ server) → server liveliness token (catches a dead server/router) → staleness bound (the robot-side invariant that holds regardless of why data stopped).
|
||||
- **DEAD** is reserved for hard failures: offline beyond `max_offline_s` with no successful merge (a server that handshakes but never delivers chunks still runs out of budget), or a contract violation on reconnect (model/revision changed, RTC capability flipped — never execute wrong-model chunks). It triggers the exact mechanism local RTC uses: `failed=True` + the global `shutdown_event`, so the existing teardown (return-to-initial-pose) runs unchanged.
|
||||
- **Pause/resume** (DAgger): `pause()` stops publishing; the queue stays intact. A pause during an outage freezes the offline budget so a human correction can never be aborted by `max_offline_s`.
|
||||
|
||||
### 5.4 Episode boundaries
|
||||
|
||||
`reset()` (control thread) atomically — under the same lock the merge path takes — clears the `ActionQueue`, nulls the staleness anchor, bumps `episode_id`, and invalidates the observation slot (the previous episode's final frame must not seed the new one). The worker sends an acked `reset` query, and the next observation header carries the new `episode_id` anyway — so a lost ack costs nothing (the server is stateless per request).
|
||||
|
||||
## 6. The server: `PolicyServer`
|
||||
|
||||
Files: `src/lerobot/policy_server/`. Entry point: `lerobot-policy-server --manifest server.yaml` (draccus dataclasses in `manifest.py`).
|
||||
|
||||
### 6.1 Concurrency model
|
||||
|
||||
zenoh-python is thread-based (no asyncio); callbacks must be deposit-only:
|
||||
|
||||
```
|
||||
zenoh subscriber (.../*/obs) inference worker (1 thread, owns GPU)
|
||||
deposit-only callback: loop:
|
||||
session.deposit(header, body) ──► scheduler picks next session with pending obs
|
||||
(per-client latest-only mailbox) decode → episode-boundary check
|
||||
preprocess → predict_action_chunk(delay, prefix)
|
||||
control queryables (status / postprocess → encode
|
||||
session / reset): validate, publisher.put(.../<uuid>/action)
|
||||
mutate registry, reply inline
|
||||
|
||||
liveliness subscriber (.../*/alive): mark sessions for GC on token DELETE
|
||||
```
|
||||
|
||||
- **Latest-only mailboxes**: the newest observation wins; superseded requests are counted and reported in the next reply (`superseded_seqs`), so drops are visible client-side. The client decides _when_ to request; the server never second-guesses observation content.
|
||||
- **Single inference worker** + round-robin over ready sessions: every ready session gets exactly one inference per cycle — starvation is structurally impossible. Overload degrades into longer cycle times → larger (but correct) client `delay_steps` → eventually the client staleness bound trips and the robot holds. Safe by construction.
|
||||
- The `Scheduler` seam (`scheduler.py`) exists so cross-session micro-batching can land later without redesign (blocked today on `predict_action_chunk` taking a _scalar_ `inference_delay`).
|
||||
- `_inference_lock` serializes the worker's predict path against episode resets arriving on queryable threads (in exclusive mode a `policy.reset()` mid-predict would corrupt the in-flight request).
|
||||
|
||||
### 6.2 Multi-tenancy: engineered, not assumed
|
||||
|
||||
Sharing one policy instance across sessions is only safe when `predict_action_chunk` touches no cross-request instance state. That property is **verified per family and encoded as a registry** (`validation.py`) — never inferred:
|
||||
|
||||
| class | policies | mode | why |
|
||||
| --------------- | ------------------------------------------------- | ----------- | ----------------------------------------------------------------------------------- |
|
||||
| chunk-stateless | `act`, `pi0`, `pi05`, `smolvla` (`n_obs_steps=1`) | `shared` | chunk call is pure (smolvla overwrites its 1-deep queue with the request's own obs) |
|
||||
| chunk-stateful | `diffusion` (and `smolvla` with `n_obs_steps>1`) | `exclusive` | chunk call reads `select_action`-fed `_queues` → server populates them per request |
|
||||
| no chunk API | `sac`, `tdmpc`, ... (no `predict_action_chunk`) | refused | nothing to serve |
|
||||
| unverified | any other chunk-API policy | `exclusive` | a manifest can force `exclusive`, but never `shared` for an unverified policy |
|
||||
|
||||
The real multi-tenancy hazard is **processor state**, not just policy purity: `RelativeActionsProcessorStep` caches `_last_state` at preprocess and the postprocessor reads it back. The server therefore builds a **fresh pre/post pipeline pair per session** — two robots at different joint positions can never cross-contaminate each other's action conversions. `policy.reset()` is **never** called in shared mode (it is global to the shared instance).
|
||||
|
||||
### 6.3 Statelessness and the RTC prefix
|
||||
|
||||
The server holds no cross-request control state. Each observation ships everything inference needs:
|
||||
|
||||
- `inference_delay_steps` — computed client-side from network-inclusive latency.
|
||||
- `prefix_model` — the unexecuted tail of the previous chunk in model space (feeds `prev_chunk_left_over`).
|
||||
- `prefix_robot` — the same tail in robot space. For relative-action policies the server **re-anchors** it against the state cached by _this request's_ preprocess (`reanchor_relative_rtc_prefix`, mirroring `rtc.py`), so the prefix is expressed relative to where the robot actually is now.
|
||||
|
||||
Consequences: reconnects are trivial, horizontal scaling is trivial, and a `kill -9` on the server loses nothing the client can't re-send.
|
||||
|
||||
### 6.4 Episode and reconnect hygiene
|
||||
|
||||
- Fresh sessions start at the `episode_id = -1` sentinel: the **first** observation of any session always triggers the boundary branch (pipelines reset; exclusive policies `reset()`), so a mid-episode reconnect can never inherit stale state.
|
||||
- Session replacement is identity-checked (`SessionRegistry.remove(expected=...)`): a GC sweep that snapshotted an old session can never tear down its just-handshaked replacement.
|
||||
- Liveliness GC double-checks with an explicit liveliness `get` before closing: the token key is per-client (not per-epoch), so a _late_ DELETE from a previous incarnation must not kill the live session.
|
||||
- Drain (`SIGTERM`): drop the liveliness token first (clients ride their buffers), finish the in-flight inference, undeclare the control surface, then close. Clients reconnect to another replica invisibly.
|
||||
|
||||
## 7. Latency budget (why the transport is never the bottleneck)
|
||||
|
||||
| stage | LAN | WAN (50 ms RTT) |
|
||||
| ------------------------------ | ------------- | --------------- |
|
||||
| JPEG encode + serialize (edge) | 2–9 ms | 2–9 ms |
|
||||
| uplink | ~2 ms | ~54 ms |
|
||||
| decode + canonical preprocess | 4–10 ms | 4–10 ms |
|
||||
| **inference** | **15–150 ms** | **15–150 ms** |
|
||||
| postprocess + downlink + merge | ~2 ms | ~27 ms |
|
||||
|
||||
Inference dominates (60–85% on LAN). At 30 fps a WAN deployment lands `delay_steps ≈ 4–8`, comfortably inside RTC execution horizons: WAN degrades smoothness parameters, never correctness. Requests are self-clocked by `buffer_time_s` to ~1–4 Hz per robot, so 300 robots cost ~0.3–10 Mbps each.
|
||||
|
||||
Capacity per GPU: `N_max ≈ 0.8 / (request_rate × inference_time)` → ~40 ACT-class or ~5 Pi0-class clients; `max_sessions` enforces it at session open (rejected clients receive the current load and retry another replica).
|
||||
|
||||
## 8. Observability & reproducibility
|
||||
|
||||
The contract is **fully logged + replayable**, not "deterministic" (no seed controls hardware or network jitter):
|
||||
|
||||
- **Client = source of truth**: recording strategies persist observations + executed actions as usual; the engine tracks `(session_id, seq_id, episode_id)` and per-cycle stats.
|
||||
- **Server**: one JSON audit line per request on the `lerobot.policy_server.audit` logger — `{session_id, client_uuid, seq_id, episode_id, queue_wait_ms, inference_ms, superseded, outcome}` — plus `/healthz` and Prometheus-style `/metrics`, and an optional bounded raw request/response capture (`debug.capture_dir`) for byte-exact offline replay.
|
||||
- Every hop shares `(session_id, seq_id)`, so joining a robot-side stutter to a server-side cause is mechanical.
|
||||
|
||||
## 9. File map
|
||||
|
||||
| path | contents |
|
||||
| ---------------------------------- | ---------------------------------------------------------------------------------------------------- |
|
||||
| `policy_server/schema.py` | wire messages, packed header, key-expression schema + sanitizer |
|
||||
| `policy_server/codec.py` | msgpack bodies, tensor codec (LE bytes), JPEG image codec (RGB convention) |
|
||||
| `policy_server/manifest.py` | draccus config: model, zenoh endpoints/TLS, serving mode, capacity, RTC, health |
|
||||
| `policy_server/validation.py` | serving-mode registry + session-open capability matrix |
|
||||
| `policy_server/session.py` | per-client `Session` (pipelines, latest-only mailbox, stats) + identity-safe registry |
|
||||
| `policy_server/scheduler.py` | `Scheduler` seam; `RoundRobinScheduler` |
|
||||
| `policy_server/zenoh_utils.py` | config builder, QoS profiles, lazy import with install hint |
|
||||
| `policy_server/server.py` | `PolicyServer`: zenoh surface, inference worker, GC, warmup, drain, health/metrics |
|
||||
| `rollout/inference/remote.py` | `RemoteInferenceEngine` (the edge client) |
|
||||
| `rollout/inference/factory.py` | `RemoteInferenceConfig`, `FallbackMode`, factory dispatch |
|
||||
| `scripts/lerobot_policy_server.py` | console entry point (`--manifest` → draccus `--config_path`) |
|
||||
| `tests/policy_server/` | codec/schema/validation/scheduler/session units, server logic, zenoh loopback + chaos, golden parity |
|
||||
|
||||
The golden parity test (`tests/policy_server/test_golden_parity.py`) is the standing contract: the remote request path (encode → decode → `run_inference_request` → encode → decode → merge) must produce **byte-identical** action queues to the local RTC compute path on identical inputs.
|
||||
@@ -0,0 +1,53 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Multi-client GPU policy serving over Zenoh (``lerobot-policy-server``).
|
||||
|
||||
The wire schema (:mod:`.schema`) and codecs (:mod:`.codec`) are shared
|
||||
with the edge-side :class:`~lerobot.rollout.inference.remote.RemoteInferenceEngine`.
|
||||
Heavy/optional imports (msgpack, zenoh, torch server) are deferred so the
|
||||
schema stays importable without the ``async`` extra.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .manifest import (
|
||||
DebugSpec,
|
||||
ModelSpec,
|
||||
PolicyServerManifest,
|
||||
ZenohSpec,
|
||||
)
|
||||
from .schema import SCHEMA_VERSION, MsgHeader, service_prefix
|
||||
|
||||
__all__ = [
|
||||
"SCHEMA_VERSION",
|
||||
"DebugSpec",
|
||||
"ModelSpec",
|
||||
"MsgHeader",
|
||||
"PolicyServer",
|
||||
"PolicyServerManifest",
|
||||
"ZenohSpec",
|
||||
"codec",
|
||||
"service_prefix",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
import importlib
|
||||
|
||||
if name == "PolicyServer":
|
||||
return importlib.import_module(".server", __name__).PolicyServer
|
||||
if name == "codec":
|
||||
return importlib.import_module(".codec", __name__)
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
@@ -0,0 +1,262 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""MessagePack codecs for the remote-inference wire schema.
|
||||
|
||||
Encoding rules:
|
||||
- Tensors are raw little-endian bytes + dtype + shape (msgpack's ``bin``
|
||||
type), so decoding is a zero-parse ``np.frombuffer``.
|
||||
- Images are JPEG by default (``jpeg_quality=0`` sends raw bytes). The
|
||||
in-memory convention on both ends is **RGB** uint8 HWC; the OpenCV
|
||||
BGR↔RGB conversion happens inside this module only.
|
||||
- Decoders are tolerant: unknown keys are ignored, missing optional keys
|
||||
take dataclass defaults — schema evolution is additive-only.
|
||||
- No pickle anywhere: nothing in this codec can carry code.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import msgpack
|
||||
except ImportError as e: # pragma: no cover
|
||||
raise ImportError(
|
||||
"Remote inference requires the 'async' extra: pip install 'lerobot[async]' (eclipse-zenoh + msgpack)"
|
||||
) from e
|
||||
|
||||
from .schema import (
|
||||
IMAGE_CODEC_JPEG,
|
||||
IMAGE_CODEC_RAW,
|
||||
ActionChunkMsg,
|
||||
ObservationMsg,
|
||||
ResetAckMsg,
|
||||
ResetMsg,
|
||||
SessionAckMsg,
|
||||
SessionCloseMsg,
|
||||
SessionOpenMsg,
|
||||
StatusMsg,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tensor codec
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _to_little_endian(arr: np.ndarray) -> np.ndarray:
|
||||
if arr.dtype.byteorder == ">":
|
||||
arr = arr.astype(arr.dtype.newbyteorder("<"))
|
||||
return np.ascontiguousarray(arr)
|
||||
|
||||
|
||||
def encode_tensor(arr: np.ndarray | None) -> dict[str, Any] | None:
|
||||
"""Encode an ndarray as raw little-endian bytes + dtype + shape."""
|
||||
if arr is None:
|
||||
return None
|
||||
arr = np.asarray(arr)
|
||||
# Record the shape before ascontiguousarray, which promotes 0-d to 1-d.
|
||||
shape = list(arr.shape)
|
||||
arr = _to_little_endian(arr)
|
||||
return {"dtype": arr.dtype.str, "shape": shape, "data": arr.tobytes()}
|
||||
|
||||
|
||||
def decode_tensor(obj: dict[str, Any] | None) -> np.ndarray | None:
|
||||
if obj is None:
|
||||
return None
|
||||
dtype = np.dtype(obj["dtype"])
|
||||
if dtype.hasobject:
|
||||
raise ValueError(f"Refusing object dtype {dtype} on the wire")
|
||||
arr = np.frombuffer(obj["data"], dtype=dtype).reshape(obj["shape"])
|
||||
# frombuffer returns a read-only view; copy so downstream torch.from_numpy works.
|
||||
return arr.copy()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image codec (RGB uint8 HWC on both ends)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def encode_image(img: np.ndarray, jpeg_quality: int = 90) -> dict[str, Any]:
|
||||
"""Encode an RGB uint8 HWC image; ``jpeg_quality=0`` keeps it raw."""
|
||||
img = np.asarray(img)
|
||||
if img.dtype != np.uint8 or img.ndim != 3 or img.shape[2] != 3:
|
||||
raise ValueError(f"Expected uint8 HWC RGB image, got dtype={img.dtype} shape={img.shape}")
|
||||
if jpeg_quality <= 0:
|
||||
return {"codec": IMAGE_CODEC_RAW, "shape": list(img.shape), "data": _to_little_endian(img).tobytes()}
|
||||
ok, buf = cv2.imencode(
|
||||
".jpg", cv2.cvtColor(img, cv2.COLOR_RGB2BGR), [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_quality)]
|
||||
)
|
||||
if not ok:
|
||||
raise ValueError("JPEG encoding failed")
|
||||
return {"codec": IMAGE_CODEC_JPEG, "data": buf.tobytes()}
|
||||
|
||||
|
||||
def decode_image(obj: dict[str, Any]) -> np.ndarray:
|
||||
"""Decode to an RGB uint8 HWC image."""
|
||||
codec = obj.get("codec", IMAGE_CODEC_JPEG)
|
||||
if codec == IMAGE_CODEC_RAW:
|
||||
return np.frombuffer(obj["data"], dtype=np.uint8).reshape(obj["shape"]).copy()
|
||||
if codec == IMAGE_CODEC_JPEG:
|
||||
bgr = cv2.imdecode(np.frombuffer(obj["data"], dtype=np.uint8), cv2.IMREAD_COLOR)
|
||||
if bgr is None:
|
||||
raise ValueError("JPEG decoding failed")
|
||||
return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
||||
raise ValueError(f"Unknown image codec: {codec!r}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# msgpack helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _packb(obj: dict[str, Any]) -> bytes:
|
||||
return msgpack.packb(obj, use_bin_type=True)
|
||||
|
||||
|
||||
def _unpackb(data: bytes) -> dict[str, Any]:
|
||||
return msgpack.unpackb(data, raw=False)
|
||||
|
||||
|
||||
def decode_raw(data: bytes) -> dict[str, Any]:
|
||||
"""Decode a body to a plain dict (e.g. to peek a control-plane ``op``)."""
|
||||
return _unpackb(data)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data-plane messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def encode_observation(msg: ObservationMsg) -> bytes:
|
||||
return _packb(
|
||||
{
|
||||
"state": encode_tensor(msg.state),
|
||||
"images": {name: encode_image(img, msg.jpeg_quality) for name, img in msg.images.items()},
|
||||
"task": msg.task,
|
||||
"inference_delay_steps": int(msg.inference_delay_steps),
|
||||
"prefix_model": encode_tensor(msg.prefix_model),
|
||||
"prefix_robot": encode_tensor(msg.prefix_robot),
|
||||
"episode_start": bool(msg.episode_start),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def decode_observation(data: bytes) -> ObservationMsg:
|
||||
obj = _unpackb(data)
|
||||
return ObservationMsg(
|
||||
state=decode_tensor(obj.get("state")),
|
||||
images={name: decode_image(img) for name, img in obj.get("images", {}).items()},
|
||||
task=obj.get("task", ""),
|
||||
inference_delay_steps=obj.get("inference_delay_steps", 0),
|
||||
prefix_model=decode_tensor(obj.get("prefix_model")),
|
||||
prefix_robot=decode_tensor(obj.get("prefix_robot")),
|
||||
episode_start=obj.get("episode_start", False),
|
||||
)
|
||||
|
||||
|
||||
def encode_action_chunk(msg: ActionChunkMsg) -> bytes:
|
||||
return _packb(
|
||||
{
|
||||
"seq_id_echo": int(msg.seq_id_echo),
|
||||
"client_mono_ns_echo": int(msg.client_mono_ns_echo),
|
||||
"episode_id_echo": int(msg.episode_id_echo),
|
||||
"chunk_model": encode_tensor(msg.chunk_model),
|
||||
"chunk_robot": encode_tensor(msg.chunk_robot),
|
||||
"queue_wait_ms": float(msg.queue_wait_ms),
|
||||
"inference_ms": float(msg.inference_ms),
|
||||
"superseded_seqs": int(msg.superseded_seqs),
|
||||
"server_load": float(msg.server_load),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def decode_action_chunk(data: bytes) -> ActionChunkMsg:
|
||||
obj = _unpackb(data)
|
||||
return ActionChunkMsg(
|
||||
seq_id_echo=obj.get("seq_id_echo", 0),
|
||||
client_mono_ns_echo=obj.get("client_mono_ns_echo", 0),
|
||||
episode_id_echo=obj.get("episode_id_echo", 0),
|
||||
chunk_model=decode_tensor(obj.get("chunk_model")),
|
||||
chunk_robot=decode_tensor(obj.get("chunk_robot")),
|
||||
queue_wait_ms=obj.get("queue_wait_ms", 0.0),
|
||||
inference_ms=obj.get("inference_ms", 0.0),
|
||||
superseded_seqs=obj.get("superseded_seqs", 0),
|
||||
server_load=obj.get("server_load", 0.0),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Control-plane messages (flat scalar/list/dict fields → generic codec)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _encode_flat(msg: Any) -> bytes:
|
||||
return _packb(dict(vars(msg).items()))
|
||||
|
||||
|
||||
def _decode_flat(cls: type, data: bytes) -> Any:
|
||||
obj = _unpackb(data)
|
||||
known = set(cls.__dataclass_fields__)
|
||||
return cls(**{k: v for k, v in obj.items() if k in known})
|
||||
|
||||
|
||||
def encode_session_open(msg: SessionOpenMsg) -> bytes:
|
||||
return _encode_flat(msg)
|
||||
|
||||
|
||||
def decode_session_open(data: bytes) -> SessionOpenMsg:
|
||||
return _decode_flat(SessionOpenMsg, data)
|
||||
|
||||
|
||||
def encode_session_ack(msg: SessionAckMsg) -> bytes:
|
||||
return _encode_flat(msg)
|
||||
|
||||
|
||||
def decode_session_ack(data: bytes) -> SessionAckMsg:
|
||||
return _decode_flat(SessionAckMsg, data)
|
||||
|
||||
|
||||
def encode_status(msg: StatusMsg) -> bytes:
|
||||
return _encode_flat(msg)
|
||||
|
||||
|
||||
def decode_status(data: bytes) -> StatusMsg:
|
||||
return _decode_flat(StatusMsg, data)
|
||||
|
||||
|
||||
def encode_reset(msg: ResetMsg) -> bytes:
|
||||
return _encode_flat(msg)
|
||||
|
||||
|
||||
def decode_reset(data: bytes) -> ResetMsg:
|
||||
return _decode_flat(ResetMsg, data)
|
||||
|
||||
|
||||
def encode_reset_ack(msg: ResetAckMsg) -> bytes:
|
||||
return _encode_flat(msg)
|
||||
|
||||
|
||||
def decode_reset_ack(data: bytes) -> ResetAckMsg:
|
||||
return _decode_flat(ResetAckMsg, data)
|
||||
|
||||
|
||||
def encode_session_close(msg: SessionCloseMsg) -> bytes:
|
||||
return _encode_flat(msg)
|
||||
|
||||
|
||||
def decode_session_close(data: bytes) -> SessionCloseMsg:
|
||||
return _decode_flat(SessionCloseMsg, data)
|
||||
@@ -0,0 +1,139 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Policy-server manifest: one process = one (model, revision, dtype, device) on one GPU.
|
||||
|
||||
Loaded from YAML via ``lerobot-policy-server --manifest server.yaml`` (or
|
||||
individual ``--model.repo_or_path=...`` CLI overrides through draccus).
|
||||
Dynamic model loading is deliberately unsupported: pre-warmed processes
|
||||
keep capacity planning honest and keep code-carrying payloads off the wire.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SERVING_MODE_AUTO = "auto"
|
||||
SERVING_MODE_SHARED = "shared"
|
||||
SERVING_MODE_EXCLUSIVE = "exclusive"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelSpec:
|
||||
"""Which policy this process serves, and where it runs."""
|
||||
|
||||
repo_or_path: str = ""
|
||||
revision: str = "main"
|
||||
# Optional torch dtype cast applied after load (e.g. "bfloat16").
|
||||
dtype: str | None = None
|
||||
device: str = "cuda"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ZenohSpec:
|
||||
"""Transport endpoints and security.
|
||||
|
||||
Robots and servers both *dial out* to a ``zenohd`` router in
|
||||
production (``mode=client``). ``mode=peer`` + ``listen_endpoints``
|
||||
supports router-less LAN and loopback test deployments. Multicast
|
||||
scouting is always disabled: fleet discovery is configuration, not
|
||||
protocol magic.
|
||||
"""
|
||||
|
||||
mode: str = "client" # "client" (via router) | "peer" (direct)
|
||||
connect_endpoints: list[str] = field(default_factory=list)
|
||||
listen_endpoints: list[str] = field(default_factory=list)
|
||||
# mTLS material (PEM paths). All three are required for TLS endpoints.
|
||||
tls_root_ca_certificate: str | None = None
|
||||
tls_connect_certificate: str | None = None
|
||||
tls_connect_private_key: str | None = None
|
||||
# Escape hatch: raw JSON5 merged into the zenoh config last.
|
||||
extra_config_json5: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DebugSpec:
|
||||
"""Optional bounded request/response capture for offline replay."""
|
||||
|
||||
capture_dir: str | None = None
|
||||
capture_max: int = 256
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyServerManifest:
|
||||
"""Top-level config for ``lerobot-policy-server``."""
|
||||
|
||||
model: ModelSpec = field(default_factory=ModelSpec)
|
||||
zenoh: ZenohSpec = field(default_factory=ZenohSpec)
|
||||
|
||||
# The task namespace this service is published under. When
|
||||
# ``pin_task`` is true, session opens with a different task string
|
||||
# are rejected; otherwise VLA clients may set the task per session.
|
||||
default_task: str = ""
|
||||
pin_task: bool = False
|
||||
# Optional override for the <task_slug> key segment (defaults to a
|
||||
# slug of ``default_task``).
|
||||
service_name: str = ""
|
||||
|
||||
# "auto" resolves from the policy classification (shared for
|
||||
# chunk-stateless policies, exclusive otherwise). "exclusive" can be
|
||||
# forced; "shared" cannot override a chunk-stateful classification.
|
||||
serving_mode: str = SERVING_MODE_AUTO
|
||||
max_sessions: int = 5
|
||||
warmup_inferences: int = 2
|
||||
|
||||
# FPS contract: warn on mismatch unless strict.
|
||||
trained_fps: float = 30.0
|
||||
strict_fps: bool = False
|
||||
|
||||
# RTC behaviour for this server process (global to the shared policy:
|
||||
# ``init_rtc_processor`` mutates the policy instance, so it is a
|
||||
# per-process decision, not per-session).
|
||||
rtc: RTCConfig = field(default_factory=RTCConfig)
|
||||
|
||||
# Sessions with no liveliness token and no traffic for this long are
|
||||
# garbage-collected (belt-and-braces behind liveliness GC).
|
||||
session_idle_timeout_s: float = 300.0
|
||||
|
||||
# HTTP health + Prometheus metrics port; 0 disables the endpoint.
|
||||
health_port: int = 9100
|
||||
|
||||
debug: DebugSpec = field(default_factory=DebugSpec)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.model.repo_or_path:
|
||||
raise ValueError("--model.repo_or_path is required (the policy this server serves)")
|
||||
if self.serving_mode not in (SERVING_MODE_AUTO, SERVING_MODE_SHARED, SERVING_MODE_EXCLUSIVE):
|
||||
raise ValueError(f"serving_mode must be one of auto|shared|exclusive, got {self.serving_mode!r}")
|
||||
if self.max_sessions < 1:
|
||||
raise ValueError(f"max_sessions must be >= 1, got {self.max_sessions}")
|
||||
if self.zenoh.mode not in ("client", "peer"):
|
||||
raise ValueError(f"zenoh.mode must be 'client' or 'peer', got {self.zenoh.mode!r}")
|
||||
if self.zenoh.mode == "client" and not self.zenoh.connect_endpoints:
|
||||
raise ValueError("zenoh.connect_endpoints is required in client mode (router address)")
|
||||
tls_fields = (
|
||||
self.zenoh.tls_root_ca_certificate,
|
||||
self.zenoh.tls_connect_certificate,
|
||||
self.zenoh.tls_connect_private_key,
|
||||
)
|
||||
if any(tls_fields) and not all(tls_fields):
|
||||
raise ValueError(
|
||||
"TLS requires all of tls_root_ca_certificate, tls_connect_certificate, "
|
||||
"tls_connect_private_key"
|
||||
)
|
||||
@@ -0,0 +1,58 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Scheduling seam between the session registry and the inference worker.
|
||||
|
||||
The v1 scheduler is strict round-robin over sessions with a pending
|
||||
observation: every ready session gets exactly one inference per cycle,
|
||||
so starvation is structurally impossible. The seam exists so that
|
||||
cross-session micro-batching can land later without redesign (blocked
|
||||
today on ``predict_action_chunk`` taking a *scalar* ``inference_delay``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
|
||||
from .session import Session
|
||||
|
||||
|
||||
class Scheduler(abc.ABC):
|
||||
"""Pick which ready session(s) the worker serves next."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def select(self, ready: list[Session]) -> list[Session]:
|
||||
"""Return the sessions to serve this cycle (subset of ``ready``)."""
|
||||
|
||||
|
||||
class RoundRobinScheduler(Scheduler):
|
||||
"""Serve one session per cycle, fairly, in client_uuid order."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._last_served: str | None = None
|
||||
|
||||
def select(self, ready: list[Session]) -> list[Session]:
|
||||
if not ready:
|
||||
return []
|
||||
ring = sorted(ready, key=lambda s: s.client_uuid)
|
||||
if self._last_served is not None:
|
||||
for i, session in enumerate(ring):
|
||||
if session.client_uuid > self._last_served:
|
||||
ring = ring[i:] + ring[:i]
|
||||
break
|
||||
else:
|
||||
pass # wrap: everyone is <= last served, keep sorted order
|
||||
chosen = ring[0]
|
||||
self._last_served = chosen.client_uuid
|
||||
return [chosen]
|
||||
@@ -0,0 +1,340 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Wire schema for remote policy inference.
|
||||
|
||||
Message dataclasses, the packed attachment header, and the Zenoh
|
||||
key-expression layout shared by the policy server and the remote
|
||||
inference engine. This module is transport-free (no zenoh import) so
|
||||
codecs and validation can be unit-tested without the optional extra.
|
||||
|
||||
Schema discipline: bodies are MessagePack maps decoded tolerantly
|
||||
(unknown keys ignored, missing optional keys defaulted) so evolution is
|
||||
additive-only. Any change to the attachment layout requires a
|
||||
``SCHEMA_VERSION`` bump; versions are negotiated at session open.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import struct
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import numpy as np
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Versioning
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SCHEMA_VERSION = 1
|
||||
# Oldest schema version this build can still serve.
|
||||
MIN_SUPPORTED_SCHEMA_VERSION = 1
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Attachment header (fixed layout, packed little-endian)
|
||||
#
|
||||
# Parsed without touching the msgpack body so routing, correlation and
|
||||
# supersession decisions never pay deserialization costs. The
|
||||
# ``client_mono_ns`` field is a client-monotonic timestamp that is
|
||||
# OPAQUE to the server: it is echoed back verbatim so the client can
|
||||
# compute round-trip times on its own clock. Wall-clock instants never
|
||||
# cross machines (the clock iron rule).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_HEADER_STRUCT = struct.Struct("<HBQIqI") # schema_version, msg_type, seq_id, episode_id, mono_ns, epoch
|
||||
|
||||
MSG_TYPE_OBS = 1
|
||||
MSG_TYPE_CHUNK = 2
|
||||
MSG_TYPE_EVENT = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
class MsgHeader:
|
||||
"""Packed per-message header carried in the Zenoh attachment."""
|
||||
|
||||
schema_version: int = SCHEMA_VERSION
|
||||
msg_type: int = MSG_TYPE_OBS
|
||||
seq_id: int = 0
|
||||
episode_id: int = 0
|
||||
client_mono_ns: int = 0
|
||||
session_epoch: int = 0
|
||||
|
||||
def pack(self) -> bytes:
|
||||
return _HEADER_STRUCT.pack(
|
||||
self.schema_version,
|
||||
self.msg_type,
|
||||
self.seq_id,
|
||||
self.episode_id,
|
||||
self.client_mono_ns,
|
||||
self.session_epoch,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def unpack(cls, data: bytes) -> MsgHeader:
|
||||
if len(data) != _HEADER_STRUCT.size:
|
||||
raise ValueError(f"Bad header length: {len(data)} (expected {_HEADER_STRUCT.size})")
|
||||
version, msg_type, seq_id, episode_id, mono_ns, epoch = _HEADER_STRUCT.unpack(data)
|
||||
return cls(
|
||||
schema_version=version,
|
||||
msg_type=msg_type,
|
||||
seq_id=seq_id,
|
||||
episode_id=episode_id,
|
||||
client_mono_ns=mono_ns,
|
||||
session_epoch=epoch,
|
||||
)
|
||||
|
||||
|
||||
HEADER_SIZE = _HEADER_STRUCT.size
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message bodies
|
||||
#
|
||||
# ``np.ndarray`` fields travel as raw little-endian bytes + dtype + shape
|
||||
# (see codec.py). Images travel JPEG-compressed by default.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
IMAGE_CODEC_JPEG = "jpeg"
|
||||
IMAGE_CODEC_RAW = "raw"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ObservationMsg:
|
||||
"""Client → server: one inference request (data plane)."""
|
||||
|
||||
state: np.ndarray | None = None # float32 [state_dim]
|
||||
images: dict[str, np.ndarray] = field(default_factory=dict) # name -> uint8 HWC RGB
|
||||
task: str = ""
|
||||
inference_delay_steps: int = 0
|
||||
# RTC prefixes: the unexecuted tail of the previous chunk, in model
|
||||
# space (original) and robot space (postprocessed). Both are needed
|
||||
# because the server re-anchors relative-action prefixes against the
|
||||
# current state and the client's ActionQueue.merge needs both chunks.
|
||||
prefix_model: np.ndarray | None = None # float32 [T, action_dim]
|
||||
prefix_robot: np.ndarray | None = None # float32 [T, action_dim]
|
||||
episode_start: bool = False
|
||||
# JPEG quality the images were encoded with; 0 means raw.
|
||||
jpeg_quality: int = 90
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionChunkMsg:
|
||||
"""Server → client: one action chunk (data plane)."""
|
||||
|
||||
seq_id_echo: int = 0
|
||||
client_mono_ns_echo: int = 0
|
||||
episode_id_echo: int = 0
|
||||
chunk_model: np.ndarray | None = None # float32 [H, action_dim] (pre-postprocessor)
|
||||
chunk_robot: np.ndarray | None = None # float32 [H, action_dim] (postprocessed)
|
||||
# Durations only — measured on the server's monotonic clock, never
|
||||
# compared against client time (the clock iron rule).
|
||||
queue_wait_ms: float = 0.0
|
||||
inference_ms: float = 0.0
|
||||
# Observations from this client that were superseded (overwritten in
|
||||
# the latest-only mailbox) since the previous reply — makes drops visible.
|
||||
superseded_seqs: int = 0
|
||||
server_load: float = 0.0 # active_sessions / max_sessions
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionOpenMsg:
|
||||
"""Client → server (control plane): open and validate a session."""
|
||||
|
||||
op: str = "open"
|
||||
client_uuid: str = ""
|
||||
robot_type: str = ""
|
||||
policy_type: str = ""
|
||||
fps: float = 0.0
|
||||
# Hard sync-safety contract: must equal the server's action feature
|
||||
# names *and order* — this maps chunk columns to motors.
|
||||
action_names: list[str] = field(default_factory=list)
|
||||
camera_names: list[str] = field(default_factory=list) # canonical keys (post-rename)
|
||||
state_dim: int = 0
|
||||
schema_version: int = SCHEMA_VERSION
|
||||
rtc_enabled: bool = False
|
||||
task: str = ""
|
||||
tags: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionAckMsg:
|
||||
"""Server → client (control plane): session accept/reject + capabilities."""
|
||||
|
||||
accepted: bool = False
|
||||
error: str = ""
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
session_id: str = ""
|
||||
model_repo: str = ""
|
||||
model_revision: str = ""
|
||||
policy_type: str = ""
|
||||
action_names: list[str] = field(default_factory=list)
|
||||
expected_cameras: list[str] = field(default_factory=list)
|
||||
state_dim: int = 0
|
||||
chunk_size: int = 0
|
||||
trained_fps: float = 0.0
|
||||
supports_rtc: bool = False
|
||||
rtc_execution_horizon: int = 0
|
||||
serving_mode: str = ""
|
||||
warmed_up: bool = False
|
||||
schema_version: int = SCHEMA_VERSION
|
||||
server_load: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class StatusMsg:
|
||||
"""Server → client (control plane): pre-flight capability snapshot."""
|
||||
|
||||
model_repo: str = ""
|
||||
model_revision: str = ""
|
||||
policy_type: str = ""
|
||||
action_names: list[str] = field(default_factory=list)
|
||||
expected_cameras: list[str] = field(default_factory=list)
|
||||
state_dim: int = 0
|
||||
chunk_size: int = 0
|
||||
trained_fps: float = 0.0
|
||||
supports_rtc: bool = False
|
||||
rtc_execution_horizon: int = 0
|
||||
serving_mode: str = ""
|
||||
warmed_up: bool = False
|
||||
min_schema_version: int = MIN_SUPPORTED_SCHEMA_VERSION
|
||||
max_schema_version: int = SCHEMA_VERSION
|
||||
active_sessions: int = 0
|
||||
max_sessions: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResetMsg:
|
||||
"""Client → server (control plane): episode boundary (acknowledged)."""
|
||||
|
||||
client_uuid: str = ""
|
||||
episode_id: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResetAckMsg:
|
||||
"""Server → client: reset acknowledgement."""
|
||||
|
||||
ok: bool = True
|
||||
error: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionCloseMsg:
|
||||
"""Client → server (control plane): graceful session close."""
|
||||
|
||||
op: str = "close"
|
||||
client_uuid: str = ""
|
||||
session_id: str = ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Key-expression schema
|
||||
#
|
||||
# @lerobot/<service>/<client_uuid>/obs client → server (pub/sub)
|
||||
# @lerobot/<service>/<client_uuid>/action server → client (pub/sub)
|
||||
# @lerobot/<service>/status queryable (capabilities)
|
||||
# @lerobot/<service>/session queryable (open/close)
|
||||
# @lerobot/<service>/<client_uuid>/reset queryable (episode boundary)
|
||||
# @lerobot/<service>/<client_uuid>/alive liveliness token (client)
|
||||
# @lerobot/<service>/server/alive liveliness token (server)
|
||||
#
|
||||
# where <service> = <model_slug>/<revision_slug>/<task_slug>. The task
|
||||
# segment is a *namespace label* derived from the server's default task
|
||||
# (or an explicit service name) — the actual inference task string
|
||||
# travels in the session/observation messages.
|
||||
#
|
||||
# ``@lerobot`` is a verbatim chunk: it is only matched by an identical
|
||||
# chunk, so third-party ``**`` subscribers on a shared router can never
|
||||
# scrape this tree.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
KEY_ROOT = "@lerobot"
|
||||
|
||||
# Conservative allowlist for user-supplied key segments. Everything
|
||||
# else (including '/', '*', '$', '?', '#', whitespace) is folded to '-'.
|
||||
_SEGMENT_SANITIZE_RE = re.compile(r"[^A-Za-z0-9_.\-]+")
|
||||
|
||||
# Reserved final chunks of the key tree; a client UUID must never
|
||||
# collide with them.
|
||||
RESERVED_SEGMENTS = frozenset({"status", "session", "server", "obs", "action", "reset", "alive"})
|
||||
|
||||
|
||||
def sanitize_key_segment(segment: str) -> str:
|
||||
"""Fold an arbitrary string into a single safe Zenoh key chunk."""
|
||||
slug = _SEGMENT_SANITIZE_RE.sub("-", segment.strip()).strip("-.")
|
||||
if not slug:
|
||||
raise ValueError(f"Key segment {segment!r} sanitizes to an empty chunk")
|
||||
if slug in RESERVED_SEGMENTS:
|
||||
raise ValueError(f"Key segment {segment!r} collides with reserved chunk {slug!r}")
|
||||
return slug
|
||||
|
||||
|
||||
def service_prefix(model_id: str, revision: str, task: str) -> str:
|
||||
"""Build the shared key prefix for one served (model, revision, task) triple."""
|
||||
return "/".join(
|
||||
(
|
||||
KEY_ROOT,
|
||||
sanitize_key_segment(model_id),
|
||||
sanitize_key_segment(revision or "main"),
|
||||
sanitize_key_segment(task or "default"),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def obs_key(prefix: str, client_uuid: str) -> str:
|
||||
return f"{prefix}/{sanitize_key_segment(client_uuid)}/obs"
|
||||
|
||||
|
||||
def action_key(prefix: str, client_uuid: str) -> str:
|
||||
return f"{prefix}/{sanitize_key_segment(client_uuid)}/action"
|
||||
|
||||
|
||||
def reset_key(prefix: str, client_uuid: str) -> str:
|
||||
return f"{prefix}/{sanitize_key_segment(client_uuid)}/reset"
|
||||
|
||||
|
||||
def client_alive_key(prefix: str, client_uuid: str) -> str:
|
||||
return f"{prefix}/{sanitize_key_segment(client_uuid)}/alive"
|
||||
|
||||
|
||||
def status_key(prefix: str) -> str:
|
||||
return f"{prefix}/status"
|
||||
|
||||
|
||||
def session_key(prefix: str) -> str:
|
||||
return f"{prefix}/session"
|
||||
|
||||
|
||||
def server_alive_key(prefix: str) -> str:
|
||||
return f"{prefix}/server/alive"
|
||||
|
||||
|
||||
# Single-depth wildcards only — '**' would also match status/session/alive.
|
||||
def obs_wildcard(prefix: str) -> str:
|
||||
return f"{prefix}/*/obs"
|
||||
|
||||
|
||||
def reset_wildcard(prefix: str) -> str:
|
||||
return f"{prefix}/*/reset"
|
||||
|
||||
|
||||
def client_alive_wildcard(prefix: str) -> str:
|
||||
return f"{prefix}/*/alive"
|
||||
|
||||
|
||||
def client_uuid_from_key(key: str) -> str:
|
||||
"""Extract the client UUID chunk from an obs/reset/alive key."""
|
||||
parts = key.split("/")
|
||||
if len(parts) < 2:
|
||||
raise ValueError(f"Key {key!r} has no client chunk")
|
||||
return parts[-2]
|
||||
@@ -0,0 +1,934 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""``lerobot-policy-server``: multi-client GPU inference over Zenoh.
|
||||
|
||||
One process serves one pre-warmed (model, revision, dtype, device) to up
|
||||
to ``max_sessions`` robot clients. The process is **stateless per
|
||||
request**: clients ship RTC prefixes and a delay hint with every
|
||||
observation, so a server crash loses zero control state and reconnects
|
||||
are trivial.
|
||||
|
||||
Concurrency model (pure threads — zenoh-python has no asyncio API):
|
||||
|
||||
zenoh subscriber (.../*/obs) inference worker (1 thread, owns GPU)
|
||||
deposit-only callback: loop:
|
||||
session.deposit(header, body) ──► pick next session with pending obs (RR)
|
||||
(per-client latest-only slot) decode → per-session preprocess
|
||||
predict_action_chunk(delay, prefix)
|
||||
control queryables (status/session/ per-session postprocess → encode
|
||||
reset): validate, mutate session publisher.put(.../<uuid>/action)
|
||||
registry, reply inline
|
||||
|
||||
The single worker thread serializes GPU access; newest-wins mailboxes
|
||||
mean overload degrades into longer cycle times (larger but correct
|
||||
client delays), never into queue buildup.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import http.server
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import uuid as uuid_module
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs import FeatureType
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc.relative import reanchor_relative_rtc_prefix
|
||||
from lerobot.policies.utils import populate_queues, prepare_observation_for_inference
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
from . import codec
|
||||
from .manifest import PolicyServerManifest
|
||||
from .scheduler import RoundRobinScheduler, Scheduler
|
||||
from .schema import (
|
||||
SCHEMA_VERSION,
|
||||
ActionChunkMsg,
|
||||
MsgHeader,
|
||||
ObservationMsg,
|
||||
ResetAckMsg,
|
||||
SessionAckMsg,
|
||||
SessionCloseMsg,
|
||||
SessionOpenMsg,
|
||||
StatusMsg,
|
||||
action_key,
|
||||
client_alive_key,
|
||||
client_alive_wildcard,
|
||||
client_uuid_from_key,
|
||||
obs_wildcard,
|
||||
reset_wildcard,
|
||||
server_alive_key,
|
||||
service_prefix,
|
||||
session_key,
|
||||
status_key,
|
||||
)
|
||||
from .session import Session, SessionRegistry
|
||||
from .validation import (
|
||||
PolicyClassification,
|
||||
classify_policy,
|
||||
resolve_serving_mode,
|
||||
validate_session_open,
|
||||
)
|
||||
from .zenoh_utils import action_publisher_qos, build_zenoh_config, import_zenoh
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
audit_logger = logging.getLogger("lerobot.policy_server.audit")
|
||||
|
||||
# Grace period after a client liveliness token drops before its session
|
||||
# is garbage-collected (rides out router blips and reconnects).
|
||||
_LIVELINESS_GC_GRACE_S = 5.0
|
||||
# Worker idle wait between work-event checks (also paces the GC sweep).
|
||||
_WORKER_IDLE_WAIT_S = 0.05
|
||||
|
||||
|
||||
def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int) -> torch.Tensor:
|
||||
"""Pad or truncate RTC prefix actions to a fixed length (mirrors rtc.py)."""
|
||||
if prev_actions.ndim != 2:
|
||||
raise ValueError(f"Expected 2D [T, A] tensor, got shape={tuple(prev_actions.shape)}")
|
||||
steps, action_dim = prev_actions.shape
|
||||
if steps == target_steps:
|
||||
return prev_actions
|
||||
if steps > target_steps:
|
||||
return prev_actions[:target_steps]
|
||||
padded = torch.zeros((target_steps, action_dim), dtype=prev_actions.dtype, device=prev_actions.device)
|
||||
padded[:steps] = prev_actions
|
||||
return padded
|
||||
|
||||
|
||||
class PolicyServer:
|
||||
"""Zenoh policy server: control-plane queryables + data-plane pub/sub + one GPU worker."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manifest: PolicyServerManifest,
|
||||
*,
|
||||
policy: PreTrainedPolicy | None = None,
|
||||
policy_cfg: PreTrainedConfig | None = None,
|
||||
processor_factory: Callable[[], tuple[Any, Any]] | None = None,
|
||||
classification: PolicyClassification | None = None,
|
||||
scheduler: Scheduler | None = None,
|
||||
) -> None:
|
||||
"""``policy``/``policy_cfg``/``processor_factory``/``classification``
|
||||
are injection points for tests; production loads everything from
|
||||
the manifest via :meth:`load_policy`.
|
||||
"""
|
||||
self._manifest = manifest
|
||||
self._device = torch.device(manifest.model.device)
|
||||
self._policy = policy
|
||||
self._policy_cfg = policy_cfg
|
||||
self._processor_factory = processor_factory
|
||||
self._classification = classification
|
||||
self._scheduler = scheduler or RoundRobinScheduler()
|
||||
|
||||
self._serving_mode: str = ""
|
||||
self._max_sessions: int = manifest.max_sessions
|
||||
self._rtc_active = False
|
||||
self._warmed_up = False
|
||||
|
||||
self.registry = SessionRegistry()
|
||||
self._registry_lock = threading.Lock() # serializes open/close/GC decisions
|
||||
# Serializes inference against episode resets: in exclusive mode a
|
||||
# reset (policy.reset(), pipeline reset) arriving on a queryable
|
||||
# thread mid-predict would corrupt the in-flight request's state.
|
||||
self._inference_lock = threading.Lock()
|
||||
|
||||
self._zenoh = None
|
||||
self._declarations: list[Any] = []
|
||||
self._alive_token = None
|
||||
|
||||
self._work = threading.Event()
|
||||
self._shutdown = threading.Event()
|
||||
self._worker: threading.Thread | None = None
|
||||
self._health_server: http.server.ThreadingHTTPServer | None = None
|
||||
|
||||
self._unknown_clients_warned: set[str] = set()
|
||||
self._capture_count = 0
|
||||
|
||||
self.metrics: dict[str, float] = {
|
||||
"requests_total": 0,
|
||||
"errors_total": 0,
|
||||
"superseded_total": 0,
|
||||
"dropped_unknown_client_total": 0,
|
||||
"sessions_opened_total": 0,
|
||||
"sessions_closed_total": 0,
|
||||
}
|
||||
self._metrics_lock = threading.Lock()
|
||||
|
||||
task_slug_source = manifest.service_name or manifest.default_task or "default"
|
||||
self.prefix = service_prefix(manifest.model.repo_or_path, manifest.model.revision, task_slug_source)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Loading & warmup
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def load_policy(self) -> None:
|
||||
"""Load config + weights, apply RTC settings, classify, warm up."""
|
||||
manifest = self._manifest
|
||||
if self._policy is None:
|
||||
logger.info(
|
||||
"Loading policy from '%s' (revision=%s)...",
|
||||
manifest.model.repo_or_path,
|
||||
manifest.model.revision,
|
||||
)
|
||||
policy_cfg = PreTrainedConfig.from_pretrained(manifest.model.repo_or_path)
|
||||
policy_cfg.pretrained_path = manifest.model.repo_or_path
|
||||
policy_class = get_policy_class(policy_cfg.type)
|
||||
policy = policy_class.from_pretrained(manifest.model.repo_or_path, config=policy_cfg)
|
||||
self._policy = policy
|
||||
self._policy_cfg = policy_cfg
|
||||
elif self._policy_cfg is None:
|
||||
self._policy_cfg = self._policy.config
|
||||
|
||||
if self._classification is None:
|
||||
self._classification = classify_policy(self._policy)
|
||||
logger.info("Policy classification: %s", self._classification.reason)
|
||||
|
||||
self._serving_mode, self._max_sessions = resolve_serving_mode(self._classification, manifest)
|
||||
logger.info("Serving mode: %s (max_sessions=%d)", self._serving_mode, self._max_sessions)
|
||||
|
||||
# RTC is a per-process decision: init_rtc_processor mutates the
|
||||
# shared policy instance.
|
||||
self._rtc_active = (
|
||||
manifest.rtc.enabled
|
||||
and self._classification.supports_rtc
|
||||
and hasattr(self._policy.config, "rtc_config")
|
||||
)
|
||||
if self._rtc_active:
|
||||
self._policy.config.rtc_config = manifest.rtc
|
||||
if hasattr(self._policy, "init_rtc_processor"):
|
||||
self._policy.init_rtc_processor()
|
||||
logger.info("RTC active (execution_horizon=%d)", manifest.rtc.execution_horizon)
|
||||
|
||||
if manifest.model.dtype:
|
||||
self._policy = self._policy.to(getattr(torch, manifest.model.dtype))
|
||||
self._policy = self._policy.to(self._device)
|
||||
self._policy.eval()
|
||||
|
||||
if not self.action_names:
|
||||
logger.warning(
|
||||
"Policy config has no action_feature_names: the action-order contract "
|
||||
"cannot be enforced at session open. Clients are trusted to match training order."
|
||||
)
|
||||
|
||||
if manifest.warmup_inferences > 0:
|
||||
self._warmup(manifest.warmup_inferences)
|
||||
self._warmed_up = True
|
||||
|
||||
def make_session_processors(self) -> tuple[Any, Any]:
|
||||
"""Build a fresh per-session pre/post pipeline pair.
|
||||
|
||||
The rename step is forced to identity: clients apply their
|
||||
rename map before encoding, so the wire format is canonical
|
||||
policy-feature keys across heterogeneous robots.
|
||||
"""
|
||||
if self._processor_factory is not None:
|
||||
return self._processor_factory()
|
||||
return make_pre_post_processors(
|
||||
policy_cfg=self._policy_cfg,
|
||||
pretrained_path=self._policy_cfg.pretrained_path,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": str(self._device)},
|
||||
"rename_observations_processor": {"rename_map": {}},
|
||||
},
|
||||
)
|
||||
|
||||
def _warmup(self, n: int) -> None:
|
||||
"""Run dummy forwards through the full request path (covers compile/caches)."""
|
||||
logger.info("Warmup: %d inferences...", n)
|
||||
obs = self._synthetic_observation()
|
||||
preprocessor, postprocessor = self.make_session_processors()
|
||||
session = Session(
|
||||
session_id="warmup",
|
||||
client_uuid="warmup",
|
||||
task=self._manifest.default_task,
|
||||
robot_type="",
|
||||
rtc_enabled=self._rtc_active,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
reply = self.run_inference_request(session, MsgHeader(), obs)
|
||||
if self._rtc_active and reply.chunk_model is not None and n > 1:
|
||||
# Exercise the prefix-conditioned path too so its compile/cache
|
||||
# cost isn't paid by the first real RTC request.
|
||||
action_dim = reply.chunk_model.shape[-1]
|
||||
horizon = self._manifest.rtc.execution_horizon
|
||||
obs.prefix_model = np.zeros((horizon, action_dim), dtype=np.float32)
|
||||
obs.prefix_robot = np.zeros((horizon, action_dim), dtype=np.float32)
|
||||
obs.inference_delay_steps = 1
|
||||
for _ in range(n - 1):
|
||||
self.run_inference_request(session, MsgHeader(), obs)
|
||||
session.close()
|
||||
# Stateful policies must not carry warmup observations into real sessions.
|
||||
if self._serving_mode == "exclusive":
|
||||
self._policy.reset()
|
||||
logger.info("Warmup complete")
|
||||
|
||||
def _synthetic_observation(self) -> ObservationMsg:
|
||||
cfg = self._policy_cfg
|
||||
state_dim = self.state_dim or 1
|
||||
images = {}
|
||||
for key, feature in cfg.input_features.items():
|
||||
if feature.type == FeatureType.VISUAL:
|
||||
channels, height, width = feature.shape
|
||||
images[key] = np.zeros((height, width, channels), dtype=np.uint8)
|
||||
return ObservationMsg(
|
||||
state=np.zeros(state_dim, dtype=np.float32),
|
||||
images=images,
|
||||
task=self._manifest.default_task,
|
||||
jpeg_quality=0,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Capabilities
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def action_names(self) -> list[str]:
|
||||
names = getattr(self._policy_cfg, "action_feature_names", None)
|
||||
return list(names) if names else []
|
||||
|
||||
@property
|
||||
def state_dim(self) -> int:
|
||||
cfg = self._policy_cfg
|
||||
for key, feature in getattr(cfg, "input_features", {}).items():
|
||||
if key == OBS_STATE or feature.type == FeatureType.STATE:
|
||||
return int(feature.shape[0])
|
||||
return 0
|
||||
|
||||
@property
|
||||
def chunk_size(self) -> int:
|
||||
cfg = self._policy_cfg
|
||||
for attr in ("chunk_size", "n_action_steps", "horizon"):
|
||||
value = getattr(cfg, attr, None)
|
||||
if value:
|
||||
return int(value)
|
||||
return 0
|
||||
|
||||
def status_snapshot(self) -> StatusMsg:
|
||||
cfg = self._policy_cfg
|
||||
expected_cameras = [
|
||||
key
|
||||
for key, feature in getattr(cfg, "input_features", {}).items()
|
||||
if feature.type == FeatureType.VISUAL
|
||||
]
|
||||
return StatusMsg(
|
||||
model_repo=self._manifest.model.repo_or_path,
|
||||
model_revision=self._manifest.model.revision,
|
||||
policy_type=getattr(cfg, "type", "") or getattr(self._policy, "name", ""),
|
||||
action_names=self.action_names,
|
||||
expected_cameras=expected_cameras,
|
||||
state_dim=self.state_dim,
|
||||
chunk_size=self.chunk_size,
|
||||
trained_fps=self._manifest.trained_fps,
|
||||
supports_rtc=self._rtc_active,
|
||||
rtc_execution_horizon=self._manifest.rtc.execution_horizon if self._rtc_active else 0,
|
||||
serving_mode=self._serving_mode,
|
||||
warmed_up=self._warmed_up,
|
||||
active_sessions=len(self.registry),
|
||||
max_sessions=self._max_sessions,
|
||||
)
|
||||
|
||||
@property
|
||||
def server_load(self) -> float:
|
||||
return len(self.registry) / max(1, self._max_sessions)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# The per-request inference path (pure: no zenoh — parity-testable)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def run_inference_request(
|
||||
self, session: Session, header: MsgHeader, obs: ObservationMsg
|
||||
) -> ActionChunkMsg:
|
||||
"""Mirror of the local RTC loop's compute step (rtc.py), minus the queue merge."""
|
||||
t0 = time.perf_counter()
|
||||
|
||||
obs_np: dict[str, np.ndarray] = {}
|
||||
if obs.state is not None:
|
||||
obs_np[OBS_STATE] = np.asarray(obs.state, dtype=np.float32)
|
||||
for name, img in obs.images.items():
|
||||
obs_np[name] = img
|
||||
|
||||
task = obs.task or session.task or self._manifest.default_task
|
||||
batch = prepare_observation_for_inference(obs_np, self._device, task, session.robot_type)
|
||||
batch["task"] = [task]
|
||||
|
||||
preprocessed = session.preprocessor(batch)
|
||||
|
||||
use_rtc = self._rtc_active and session.rtc_enabled
|
||||
if use_rtc:
|
||||
delay = max(0, int(obs.inference_delay_steps))
|
||||
prev_actions: torch.Tensor | None = None
|
||||
if obs.prefix_model is not None and obs.prefix_model.size:
|
||||
prev_actions = torch.from_numpy(np.ascontiguousarray(obs.prefix_model)).to(self._device)
|
||||
|
||||
if prev_actions is not None and session.relative_step is not None:
|
||||
# Re-anchor the absolute leftover tail against the state
|
||||
# cached by THIS request's preprocess (mirrors rtc.py).
|
||||
raw_state = session.relative_step.get_cached_state()
|
||||
prefix_robot = obs.prefix_robot
|
||||
if raw_state is not None and prefix_robot is not None and prefix_robot.size:
|
||||
prev_actions = reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute=torch.from_numpy(np.ascontiguousarray(prefix_robot)),
|
||||
current_state=raw_state,
|
||||
relative_step=session.relative_step,
|
||||
normalizer_step=session.normalizer_step,
|
||||
policy_device=self._device,
|
||||
)
|
||||
|
||||
if prev_actions is not None:
|
||||
prev_actions = _normalize_prev_actions_length(
|
||||
prev_actions, target_steps=self._manifest.rtc.execution_horizon
|
||||
)
|
||||
|
||||
actions = self._policy.predict_action_chunk(
|
||||
preprocessed, inference_delay=delay, prev_chunk_left_over=prev_actions
|
||||
)
|
||||
else:
|
||||
if self._classification is not None and self._classification.needs_queue_population:
|
||||
preprocessed = self._populate_select_queues(preprocessed)
|
||||
actions = self._policy.predict_action_chunk(preprocessed)
|
||||
|
||||
original = actions.squeeze(0).clone()
|
||||
processed = session.postprocessor(actions).squeeze(0)
|
||||
inference_ms = (time.perf_counter() - t0) * 1e3
|
||||
|
||||
session.stats.requests += 1
|
||||
session.stats.last_inference_ms = inference_ms
|
||||
superseded = session.take_superseded()
|
||||
|
||||
return ActionChunkMsg(
|
||||
seq_id_echo=header.seq_id,
|
||||
client_mono_ns_echo=header.client_mono_ns,
|
||||
episode_id_echo=header.episode_id,
|
||||
chunk_model=original.detach().to("cpu", torch.float32).numpy(),
|
||||
chunk_robot=processed.detach().to("cpu", torch.float32).numpy(),
|
||||
inference_ms=inference_ms,
|
||||
superseded_seqs=superseded,
|
||||
server_load=self.server_load,
|
||||
)
|
||||
|
||||
def _populate_select_queues(self, batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Exclusive-mode shim for select_action-fed policies (diffusion family).
|
||||
|
||||
Mirrors ``DiffusionPolicy.select_action``: stack camera features
|
||||
into OBS_IMAGES, then populate the policy's observation queues so
|
||||
``predict_action_chunk`` sees the same history it would locally.
|
||||
"""
|
||||
policy = self._policy
|
||||
batch = {k: v for k, v in batch.items() if k != ACTION}
|
||||
if getattr(policy.config, "image_features", None):
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in policy.config.image_features], dim=-4)
|
||||
policy._queues = populate_queues(policy._queues, batch)
|
||||
return batch
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Zenoh wiring
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def start(self) -> None:
|
||||
"""Open zenoh, declare the service surface, start worker + health threads."""
|
||||
if self._policy is None or not self._warmed_up:
|
||||
self.load_policy()
|
||||
|
||||
zenoh = import_zenoh()
|
||||
spec = self._manifest.zenoh
|
||||
self._zenoh = zenoh.open(
|
||||
build_zenoh_config(
|
||||
mode=spec.mode,
|
||||
connect_endpoints=spec.connect_endpoints,
|
||||
listen_endpoints=spec.listen_endpoints,
|
||||
tls_root_ca_certificate=spec.tls_root_ca_certificate,
|
||||
tls_connect_certificate=spec.tls_connect_certificate,
|
||||
tls_connect_private_key=spec.tls_connect_private_key,
|
||||
extra_config_json5=spec.extra_config_json5,
|
||||
)
|
||||
)
|
||||
handlers = zenoh.handlers
|
||||
|
||||
# Data plane: wildcard subscriber, deposit-only callback.
|
||||
self._declarations.append(
|
||||
self._zenoh.declare_subscriber(obs_wildcard(self.prefix), handlers.Callback(self._on_obs))
|
||||
)
|
||||
# Control plane: queryables reply inline (low rate).
|
||||
self._declarations.append(
|
||||
self._zenoh.declare_queryable(status_key(self.prefix), handlers.Callback(self._on_status_query))
|
||||
)
|
||||
self._declarations.append(
|
||||
self._zenoh.declare_queryable(session_key(self.prefix), handlers.Callback(self._on_session_query))
|
||||
)
|
||||
self._declarations.append(
|
||||
self._zenoh.declare_queryable(
|
||||
reset_wildcard(self.prefix), handlers.Callback(self._on_reset_query)
|
||||
)
|
||||
)
|
||||
# Presence: watch client tokens; publish our own.
|
||||
self._declarations.append(
|
||||
self._zenoh.liveliness().declare_subscriber(
|
||||
client_alive_wildcard(self.prefix), handlers.Callback(self._on_liveliness), history=True
|
||||
)
|
||||
)
|
||||
self._alive_token = self._zenoh.liveliness().declare_token(server_alive_key(self.prefix))
|
||||
|
||||
self._shutdown.clear()
|
||||
self._worker = threading.Thread(target=self._worker_loop, daemon=True, name="PolicyServerWorker")
|
||||
self._worker.start()
|
||||
|
||||
if self._manifest.health_port:
|
||||
self._start_health_server(self._manifest.health_port)
|
||||
|
||||
logger.info(
|
||||
"Policy server up: prefix=%s mode=%s max_sessions=%d rtc=%s",
|
||||
self.prefix,
|
||||
self._serving_mode,
|
||||
self._max_sessions,
|
||||
self._rtc_active,
|
||||
)
|
||||
|
||||
def serve_forever(self) -> None:
|
||||
try:
|
||||
while not self._shutdown.is_set():
|
||||
self._shutdown.wait(timeout=0.5)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Interrupted — draining")
|
||||
finally:
|
||||
self.stop()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Drain: drop the liveliness token first (clients ride their buffers
|
||||
through the drain), finish the in-flight inference, then close."""
|
||||
if self._alive_token is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
self._alive_token.undeclare()
|
||||
self._alive_token = None
|
||||
|
||||
self._shutdown.set()
|
||||
self._work.set()
|
||||
if self._worker is not None and self._worker.is_alive():
|
||||
self._worker.join(timeout=10.0)
|
||||
if self._worker.is_alive():
|
||||
logger.warning("Inference worker did not join within 10s")
|
||||
self._worker = None
|
||||
|
||||
# Undeclare the control/data surface BEFORE closing sessions so a
|
||||
# late session open cannot be accepted by a server that has
|
||||
# already drained its worker.
|
||||
for declaration in self._declarations:
|
||||
with contextlib.suppress(Exception):
|
||||
declaration.undeclare()
|
||||
self._declarations.clear()
|
||||
|
||||
for session in self.registry.snapshot():
|
||||
self._close_session(session, reason="server shutdown")
|
||||
|
||||
if self._zenoh is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
self._zenoh.close()
|
||||
self._zenoh = None
|
||||
|
||||
if self._health_server is not None:
|
||||
self._health_server.shutdown()
|
||||
self._health_server = None
|
||||
logger.info("Policy server stopped")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Zenoh callbacks (deposit-only on the data plane)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _on_obs(self, sample: Any) -> None:
|
||||
try:
|
||||
attachment = sample.attachment
|
||||
if attachment is None:
|
||||
return
|
||||
header = MsgHeader.unpack(attachment.to_bytes())
|
||||
if header.schema_version != SCHEMA_VERSION:
|
||||
return
|
||||
client_uuid = client_uuid_from_key(str(sample.key_expr))
|
||||
session = self.registry.get(client_uuid)
|
||||
if session is None:
|
||||
self._bump("dropped_unknown_client_total")
|
||||
# Bounded: garbage publishers must not grow this set (or
|
||||
# the log) without limit.
|
||||
if (
|
||||
client_uuid not in self._unknown_clients_warned
|
||||
and len(self._unknown_clients_warned) < 256
|
||||
):
|
||||
self._unknown_clients_warned.add(client_uuid)
|
||||
logger.warning(
|
||||
"Observation from unknown client '%s' (no session) — dropping", client_uuid
|
||||
)
|
||||
return
|
||||
session.deposit(header, sample.payload.to_bytes())
|
||||
self._work.set()
|
||||
except Exception as e: # noqa: BLE001 — a malformed sample must never kill the subscriber
|
||||
logger.error("obs callback error: %s", e)
|
||||
|
||||
def _on_liveliness(self, sample: Any) -> None:
|
||||
try:
|
||||
import zenoh
|
||||
|
||||
client_uuid = client_uuid_from_key(str(sample.key_expr))
|
||||
session = self.registry.get(client_uuid)
|
||||
if session is None:
|
||||
return
|
||||
if sample.kind == zenoh.SampleKind.DELETE:
|
||||
session.alive = False
|
||||
session.token_dropped_mono = time.monotonic()
|
||||
logger.info(
|
||||
"Client '%s' liveliness dropped — GC in %.0fs", client_uuid, _LIVELINESS_GC_GRACE_S
|
||||
)
|
||||
else:
|
||||
session.alive = True
|
||||
session.token_dropped_mono = None
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("liveliness callback error: %s", e)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Control-plane queryables
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _on_status_query(self, query: Any) -> None:
|
||||
try:
|
||||
query.reply(status_key(self.prefix), codec.encode_status(self.status_snapshot()))
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("status query error: %s", e)
|
||||
|
||||
def _on_session_query(self, query: Any) -> None:
|
||||
try:
|
||||
payload = query.payload.to_bytes() if query.payload is not None else b""
|
||||
op = codec.decode_raw(payload).get("op", "open") if payload else "open"
|
||||
if op == "close":
|
||||
self._handle_session_close(codec.decode_session_close(payload))
|
||||
query.reply(session_key(self.prefix), codec.encode_reset_ack(ResetAckMsg(ok=True)))
|
||||
return
|
||||
ack = self._handle_session_open(codec.decode_session_open(payload))
|
||||
query.reply(session_key(self.prefix), codec.encode_session_ack(ack))
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("session query error: %s\n%s", e, traceback.format_exc())
|
||||
with contextlib.suppress(Exception):
|
||||
query.reply(
|
||||
session_key(self.prefix),
|
||||
codec.encode_session_ack(SessionAckMsg(accepted=False, error=f"server error: {e}")),
|
||||
)
|
||||
|
||||
def _on_reset_query(self, query: Any) -> None:
|
||||
try:
|
||||
payload = query.payload.to_bytes() if query.payload is not None else b""
|
||||
msg = codec.decode_reset(payload)
|
||||
session = self.registry.get(msg.client_uuid)
|
||||
if session is None:
|
||||
ack = ResetAckMsg(ok=False, error=f"unknown client '{msg.client_uuid}'")
|
||||
else:
|
||||
# Serialize with the worker: resetting pipelines/policy
|
||||
# mid-predict would corrupt the in-flight request.
|
||||
with self._inference_lock:
|
||||
session.reset_episode(msg.episode_id)
|
||||
if self._serving_mode == "exclusive":
|
||||
# Safe: max_sessions=1, the policy belongs to this client.
|
||||
self._policy.reset()
|
||||
ack = ResetAckMsg(ok=True)
|
||||
query.reply(str(query.key_expr), codec.encode_reset_ack(ack))
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("reset query error: %s", e)
|
||||
|
||||
def _handle_session_open(self, msg: SessionOpenMsg) -> SessionAckMsg:
|
||||
capabilities = self.status_snapshot()
|
||||
with self._registry_lock:
|
||||
# A re-handshake from a known client replaces its session and
|
||||
# does not count against capacity.
|
||||
existing = self.registry.get(msg.client_uuid)
|
||||
active = len(self.registry) - (1 if existing else 0)
|
||||
result = validate_session_open(msg, capabilities, self._manifest, active)
|
||||
if not result.ok:
|
||||
logger.warning("Session rejected for '%s': %s", msg.client_uuid, result.error)
|
||||
return SessionAckMsg(accepted=False, error=result.error, server_load=self.server_load)
|
||||
|
||||
preprocessor, postprocessor = self.make_session_processors()
|
||||
session = Session(
|
||||
session_id=uuid_module.uuid4().hex,
|
||||
client_uuid=msg.client_uuid,
|
||||
task=msg.task or self._manifest.default_task,
|
||||
robot_type=msg.robot_type,
|
||||
rtc_enabled=msg.rtc_enabled and not result.rtc_downgraded,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
action_publisher=self._declare_action_publisher(msg.client_uuid),
|
||||
)
|
||||
if session.relative_step is not None and session.relative_step.action_names is None:
|
||||
session.relative_step.action_names = self.action_names or list(msg.action_names)
|
||||
# Sentinel: the FIRST observation of a fresh session always
|
||||
# triggers the episode-boundary branch in _serve_one, so a
|
||||
# mid-episode reconnect can never inherit stale state.
|
||||
session.episode_id = -1
|
||||
displaced = self.registry.add(session)
|
||||
if displaced is not None:
|
||||
displaced.close()
|
||||
self._bump("sessions_closed_total")
|
||||
logger.info("Client '%s' re-handshake: previous session replaced", msg.client_uuid)
|
||||
if self._serving_mode == "exclusive":
|
||||
# A new exclusive session must start from fresh policy state.
|
||||
with self._inference_lock:
|
||||
self._policy.reset()
|
||||
self._bump("sessions_opened_total")
|
||||
self._unknown_clients_warned.discard(msg.client_uuid)
|
||||
logger.info(
|
||||
"Session opened: client=%s session=%s task=%r rtc=%s (%d/%d)",
|
||||
msg.client_uuid,
|
||||
session.session_id,
|
||||
session.task,
|
||||
session.rtc_enabled,
|
||||
len(self.registry),
|
||||
self._max_sessions,
|
||||
)
|
||||
return SessionAckMsg(
|
||||
accepted=True,
|
||||
warnings=result.warnings,
|
||||
session_id=session.session_id,
|
||||
model_repo=capabilities.model_repo,
|
||||
model_revision=capabilities.model_revision,
|
||||
policy_type=capabilities.policy_type,
|
||||
action_names=capabilities.action_names,
|
||||
expected_cameras=capabilities.expected_cameras,
|
||||
state_dim=capabilities.state_dim,
|
||||
chunk_size=capabilities.chunk_size,
|
||||
trained_fps=capabilities.trained_fps,
|
||||
supports_rtc=capabilities.supports_rtc and session.rtc_enabled,
|
||||
rtc_execution_horizon=capabilities.rtc_execution_horizon,
|
||||
serving_mode=capabilities.serving_mode,
|
||||
warmed_up=capabilities.warmed_up,
|
||||
server_load=self.server_load,
|
||||
)
|
||||
|
||||
def _declare_action_publisher(self, client_uuid: str) -> Any:
|
||||
if self._zenoh is None: # pure-logic tests run without transport
|
||||
return None
|
||||
zenoh = import_zenoh()
|
||||
return self._zenoh.declare_publisher(
|
||||
action_key(self.prefix, client_uuid), **action_publisher_qos(zenoh)
|
||||
)
|
||||
|
||||
def _handle_session_close(self, msg: SessionCloseMsg) -> None:
|
||||
session = self.registry.get(msg.client_uuid)
|
||||
if session is not None and (not msg.session_id or msg.session_id == session.session_id):
|
||||
self._close_session(session, reason="client close")
|
||||
|
||||
def _close_session(self, session: Session, reason: str) -> None:
|
||||
# Identity-checked removal: never tear down a same-uuid session
|
||||
# that replaced this one via a re-handshake.
|
||||
removed = self.registry.remove(session.client_uuid, expected=session)
|
||||
if removed is not None:
|
||||
removed.close()
|
||||
self._bump("sessions_closed_total")
|
||||
logger.info("Session closed: client=%s (%s)", session.client_uuid, reason)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Inference worker
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _worker_loop(self) -> None:
|
||||
last_gc = time.monotonic()
|
||||
while not self._shutdown.is_set():
|
||||
ready = [s for s in self.registry.snapshot() if s.has_pending()]
|
||||
if not ready:
|
||||
self._work.wait(timeout=_WORKER_IDLE_WAIT_S)
|
||||
self._work.clear()
|
||||
else:
|
||||
for session in self._scheduler.select(ready):
|
||||
self._serve_one(session)
|
||||
|
||||
now = time.monotonic()
|
||||
if now - last_gc > 1.0:
|
||||
last_gc = now
|
||||
self._gc_sessions(now)
|
||||
|
||||
def _serve_one(self, session: Session) -> None:
|
||||
item = session.take()
|
||||
if item is None:
|
||||
return
|
||||
queue_wait_ms = (time.monotonic() - item.recv_mono) * 1e3
|
||||
outcome = "ok"
|
||||
try:
|
||||
obs = codec.decode_observation(item.payload)
|
||||
self._capture("req", item.payload)
|
||||
|
||||
with self._inference_lock:
|
||||
# Belt-and-braces episode ordering: the first observation of
|
||||
# an episode also announces the boundary (one-in-flight makes
|
||||
# the reset query race-free, but a lost ack must not desync
|
||||
# us; fresh sessions start at the -1 sentinel so their first
|
||||
# request always lands here).
|
||||
if obs.episode_start or item.header.episode_id != session.episode_id:
|
||||
session.preprocessor.reset()
|
||||
session.postprocessor.reset()
|
||||
session.episode_id = item.header.episode_id
|
||||
if self._serving_mode == "exclusive":
|
||||
self._policy.reset()
|
||||
|
||||
reply = self.run_inference_request(session, item.header, obs)
|
||||
reply.queue_wait_ms = queue_wait_ms
|
||||
session.stats.last_queue_wait_ms = queue_wait_ms
|
||||
|
||||
body = codec.encode_action_chunk(reply)
|
||||
self._capture("rep", body)
|
||||
# Local ref: a re-handshake can null session.action_publisher
|
||||
# between the check and the put.
|
||||
publisher = session.action_publisher
|
||||
if publisher is not None:
|
||||
reply_header = MsgHeader(
|
||||
schema_version=SCHEMA_VERSION,
|
||||
msg_type=2, # MSG_TYPE_CHUNK
|
||||
seq_id=item.header.seq_id,
|
||||
episode_id=item.header.episode_id,
|
||||
client_mono_ns=item.header.client_mono_ns,
|
||||
session_epoch=item.header.session_epoch,
|
||||
)
|
||||
publisher.put(body, attachment=reply_header.pack())
|
||||
self._bump("requests_total")
|
||||
self._bump("superseded_total", reply.superseded_seqs)
|
||||
except Exception as e: # noqa: BLE001 — one bad request must not kill the worker
|
||||
outcome = f"error: {e}"
|
||||
session.stats.errors += 1
|
||||
self._bump("errors_total")
|
||||
logger.error(
|
||||
"Inference error for client '%s' seq=%d: %s\n%s",
|
||||
session.client_uuid,
|
||||
item.header.seq_id,
|
||||
e,
|
||||
traceback.format_exc(),
|
||||
)
|
||||
finally:
|
||||
audit_logger.info(
|
||||
json.dumps(
|
||||
{
|
||||
"session_id": session.session_id,
|
||||
"client_uuid": session.client_uuid,
|
||||
"seq_id": item.header.seq_id,
|
||||
"episode_id": item.header.episode_id,
|
||||
"queue_wait_ms": round(queue_wait_ms, 3),
|
||||
"inference_ms": round(session.stats.last_inference_ms, 3),
|
||||
"superseded": session.stats.superseded,
|
||||
"outcome": outcome,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
def _gc_sessions(self, now: float) -> None:
|
||||
for session in self.registry.snapshot():
|
||||
if (
|
||||
session.token_dropped_mono is not None
|
||||
and now - session.token_dropped_mono > _LIVELINESS_GC_GRACE_S
|
||||
):
|
||||
if self._client_token_alive(session.client_uuid):
|
||||
# The DELETE was a late echo of a previous incarnation
|
||||
# (the token key is per client, not per epoch) — the
|
||||
# client re-declared and is alive.
|
||||
session.token_dropped_mono = None
|
||||
session.alive = True
|
||||
continue
|
||||
self._close_session(session, reason="liveliness token dropped")
|
||||
elif now - session.last_seen_mono > self._manifest.session_idle_timeout_s:
|
||||
self._close_session(session, reason="idle timeout")
|
||||
|
||||
def _client_token_alive(self, client_uuid: str) -> bool:
|
||||
"""Confirm a client's liveliness token via an explicit get (GC double-check)."""
|
||||
if self._zenoh is None:
|
||||
return False
|
||||
try:
|
||||
zenoh = import_zenoh()
|
||||
replies = self._zenoh.liveliness().get(
|
||||
client_alive_key(self.prefix, client_uuid),
|
||||
handler=zenoh.handlers.FifoChannel(4),
|
||||
timeout=0.5,
|
||||
)
|
||||
deadline = time.monotonic() + 1.0
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
reply = replies.try_recv()
|
||||
except Exception: # channel closed: no token found # noqa: BLE001
|
||||
return False
|
||||
if reply is None:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
if reply.ok is not None:
|
||||
return True
|
||||
return False
|
||||
except Exception: # noqa: BLE001 — treat transport trouble as "not alive"
|
||||
return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Misc
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _bump(self, key: str, amount: float = 1) -> None:
|
||||
with self._metrics_lock:
|
||||
self.metrics[key] = self.metrics.get(key, 0) + amount
|
||||
|
||||
def _capture(self, kind: str, data: bytes) -> None:
|
||||
capture_dir = self._manifest.debug.capture_dir
|
||||
if not capture_dir:
|
||||
return
|
||||
try:
|
||||
directory = Path(capture_dir)
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
index = self._capture_count % max(1, self._manifest.debug.capture_max)
|
||||
(directory / f"{kind}_{index:05d}.bin").write_bytes(data)
|
||||
if kind == "rep":
|
||||
self._capture_count += 1
|
||||
except OSError as e:
|
||||
logger.warning("debug capture failed: %s", e)
|
||||
|
||||
def _start_health_server(self, port: int) -> None:
|
||||
server_ref = self
|
||||
|
||||
class Handler(http.server.BaseHTTPRequestHandler):
|
||||
def do_GET(self) -> None: # noqa: N802 — http.server API
|
||||
if self.path == "/healthz":
|
||||
worker = server_ref._worker # local ref: stop() may null it mid-read
|
||||
healthy = worker is not None and worker.is_alive()
|
||||
self.send_response(200 if healthy else 503)
|
||||
self.end_headers()
|
||||
self.wfile.write(b"ok" if healthy else b"worker dead")
|
||||
elif self.path == "/metrics":
|
||||
with server_ref._metrics_lock:
|
||||
counters = dict(server_ref.metrics)
|
||||
counters["active_sessions"] = len(server_ref.registry)
|
||||
counters["server_load"] = server_ref.server_load
|
||||
body = "".join(
|
||||
f"lerobot_policy_server_{name} {value}\n" for name, value in sorted(counters.items())
|
||||
).encode()
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/plain; version=0.0.4")
|
||||
self.end_headers()
|
||||
self.wfile.write(body)
|
||||
else:
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
|
||||
def log_message(self, *args: Any) -> None: # silence per-request logging
|
||||
pass
|
||||
|
||||
self._health_server = http.server.ThreadingHTTPServer(("0.0.0.0", port), Handler) # nosec B104
|
||||
threading.Thread(target=self._health_server.serve_forever, daemon=True, name="HealthHTTP").start()
|
||||
logger.info("Health/metrics on :%d (/healthz, /metrics)", port)
|
||||
@@ -0,0 +1,203 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Per-client session state and the latest-only observation mailbox.
|
||||
|
||||
The server holds **no cross-request control state**: RTC prefixes and
|
||||
delay hints arrive with every observation. What a session does hold:
|
||||
|
||||
- Per-session processor pipeline instances. Mandatory:
|
||||
``RelativeActionsProcessorStep`` caches ``_last_state`` at preprocess
|
||||
and the postprocessor reads it back — a pipeline shared across clients
|
||||
would be a race.
|
||||
- A one-slot mailbox: the newest observation wins; superseded requests
|
||||
are counted so drops stay visible to the client.
|
||||
- Counters for the audit log and ``/metrics``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from lerobot.processor import (
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
RelativeActionsProcessorStep,
|
||||
)
|
||||
|
||||
from .schema import MsgHeader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MailboxItem:
|
||||
header: MsgHeader
|
||||
payload: bytes
|
||||
recv_mono: float # server-local monotonic deposit time (for queue_wait_ms)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionStats:
|
||||
requests: int = 0
|
||||
errors: int = 0
|
||||
superseded: int = 0 # observations overwritten before inference (lifetime)
|
||||
superseded_since_reply: int = 0
|
||||
last_inference_ms: float = 0.0
|
||||
last_queue_wait_ms: float = 0.0
|
||||
|
||||
|
||||
class Session:
|
||||
"""One connected robot client."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
client_uuid: str,
|
||||
task: str,
|
||||
robot_type: str,
|
||||
rtc_enabled: bool,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
action_publisher: Any = None, # zenoh.Publisher (Any: zenoh optional at import)
|
||||
) -> None:
|
||||
self.session_id = session_id
|
||||
self.client_uuid = client_uuid
|
||||
self.task = task
|
||||
self.robot_type = robot_type
|
||||
self.rtc_enabled = rtc_enabled
|
||||
self.preprocessor = preprocessor
|
||||
self.postprocessor = postprocessor
|
||||
self.action_publisher = action_publisher
|
||||
|
||||
self.episode_id = 0
|
||||
self.stats = SessionStats()
|
||||
self.alive = True
|
||||
self.last_seen_mono = time.monotonic()
|
||||
# Set when the client's liveliness token drops; GC after grace.
|
||||
self.token_dropped_mono: float | None = None
|
||||
|
||||
# Processor introspection for relative-action prefix re-anchoring
|
||||
# (mirrors RTCInferenceEngine.__init__).
|
||||
self.relative_step = next(
|
||||
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
|
||||
None,
|
||||
)
|
||||
self.normalizer_step = next(
|
||||
(s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)),
|
||||
None,
|
||||
)
|
||||
|
||||
self._mailbox: MailboxItem | None = None
|
||||
self._mailbox_lock = Lock()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Mailbox (deposit-only callbacks write, the inference worker reads)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def deposit(self, header: MsgHeader, payload: bytes) -> None:
|
||||
"""Latest-only deposit; counts superseded observations."""
|
||||
item = MailboxItem(header=header, payload=payload, recv_mono=time.monotonic())
|
||||
with self._mailbox_lock:
|
||||
if self._mailbox is not None:
|
||||
self.stats.superseded += 1
|
||||
self.stats.superseded_since_reply += 1
|
||||
self._mailbox = item
|
||||
self.alive = True
|
||||
self.token_dropped_mono = None
|
||||
self.last_seen_mono = item.recv_mono
|
||||
|
||||
def take(self) -> MailboxItem | None:
|
||||
with self._mailbox_lock:
|
||||
item, self._mailbox = self._mailbox, None
|
||||
return item
|
||||
|
||||
def take_superseded(self) -> int:
|
||||
"""Atomically read-and-reset the per-reply supersession counter."""
|
||||
with self._mailbox_lock:
|
||||
count = self.stats.superseded_since_reply
|
||||
self.stats.superseded_since_reply = 0
|
||||
return count
|
||||
|
||||
def has_pending(self) -> bool:
|
||||
with self._mailbox_lock:
|
||||
return self._mailbox is not None
|
||||
|
||||
def clear_mailbox(self) -> None:
|
||||
with self._mailbox_lock:
|
||||
self._mailbox = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Episode boundary
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def reset_episode(self, episode_id: int | None = None) -> None:
|
||||
"""Clear per-episode state. The shared policy is NOT touched here."""
|
||||
self.clear_mailbox()
|
||||
self.preprocessor.reset()
|
||||
self.postprocessor.reset()
|
||||
self.episode_id = episode_id if episode_id is not None else self.episode_id + 1
|
||||
|
||||
def close(self) -> None:
|
||||
self.clear_mailbox()
|
||||
publisher = self.action_publisher
|
||||
self.action_publisher = None
|
||||
if publisher is not None:
|
||||
# Already-closed transport is fine on teardown.
|
||||
with contextlib.suppress(Exception):
|
||||
publisher.undeclare()
|
||||
|
||||
|
||||
class SessionRegistry:
|
||||
"""Thread-safe map of client_uuid → Session."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._sessions: dict[str, Session] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def add(self, session: Session) -> Session | None:
|
||||
"""Register, returning a displaced same-client session (caller closes it)."""
|
||||
with self._lock:
|
||||
old = self._sessions.get(session.client_uuid)
|
||||
self._sessions[session.client_uuid] = session
|
||||
return old
|
||||
|
||||
def get(self, client_uuid: str) -> Session | None:
|
||||
with self._lock:
|
||||
return self._sessions.get(client_uuid)
|
||||
|
||||
def remove(self, client_uuid: str, expected: Session | None = None) -> Session | None:
|
||||
"""Remove by uuid; with ``expected``, only if it is still that exact session.
|
||||
|
||||
The identity check stops a GC sweep that snapshotted an old
|
||||
session from tearing down its just-handshaked replacement.
|
||||
"""
|
||||
with self._lock:
|
||||
current = self._sessions.get(client_uuid)
|
||||
if current is None or (expected is not None and current is not expected):
|
||||
return None
|
||||
return self._sessions.pop(client_uuid)
|
||||
|
||||
def snapshot(self) -> list[Session]:
|
||||
with self._lock:
|
||||
return list(self._sessions.values())
|
||||
|
||||
def __len__(self) -> int:
|
||||
with self._lock:
|
||||
return len(self._sessions)
|
||||
@@ -0,0 +1,265 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Serving-mode classification and session capability validation.
|
||||
|
||||
Multi-tenancy is engineered, not assumed: sharing one policy instance
|
||||
across sessions is only safe when ``predict_action_chunk`` touches no
|
||||
instance state. That property has been verified per policy family and
|
||||
is encoded here as an explicit registry — never inferred.
|
||||
|
||||
- ``act``/``pi0``/``pi05``: chunk-stateless (verified in-tree).
|
||||
- ``smolvla``: populates its ``_queues`` *inside* ``predict_action_chunk``;
|
||||
with ``n_obs_steps == 1`` the queue is overwritten with the request's
|
||||
own observation before being read, so sharing is safe. With history
|
||||
(``n_obs_steps > 1``) requests would read other sessions' frames →
|
||||
exclusive.
|
||||
- ``diffusion``: ``predict_action_chunk`` reads ``_queues`` that only
|
||||
``select_action`` populates → exclusive, with the server populating
|
||||
the observation queues per request (mirroring ``select_action``).
|
||||
- Policies without a ``predict_action_chunk`` override are refused.
|
||||
- Unverified chunk-API policies default to exclusive; ``shared`` cannot
|
||||
be forced for them (the roadmap upstreams a
|
||||
``supports_stateless_chunking`` attribute to policy classes).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
from .manifest import (
|
||||
SERVING_MODE_EXCLUSIVE,
|
||||
SERVING_MODE_SHARED,
|
||||
PolicyServerManifest,
|
||||
)
|
||||
from .schema import SessionOpenMsg, StatusMsg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ServingClass(Enum):
|
||||
SHARED = "shared"
|
||||
EXCLUSIVE = "exclusive"
|
||||
REFUSED = "refused"
|
||||
|
||||
|
||||
# Verified chunk-stateless families (predict_action_chunk touches no
|
||||
# cross-request instance state).
|
||||
VERIFIED_CHUNK_STATELESS: frozenset[str] = frozenset({"act", "pi0", "pi05"})
|
||||
|
||||
# Families whose predict_action_chunk reads select_action-fed queues:
|
||||
# the server must populate the observation queues per request.
|
||||
QUEUE_POPULATED_IN_SELECT: frozenset[str] = frozenset({"diffusion"})
|
||||
|
||||
# Families whose predict_action_chunk accepts the RTC kwargs
|
||||
# (inference_delay / prev_chunk_left_over) — see each family's
|
||||
# ActionSelectKwargs TypedDict.
|
||||
RTC_CAPABLE: frozenset[str] = frozenset({"pi0", "pi05", "smolvla"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyClassification:
|
||||
serving_class: ServingClass
|
||||
supports_rtc: bool
|
||||
needs_queue_population: bool
|
||||
reason: str
|
||||
|
||||
|
||||
def _has_chunk_api(policy: PreTrainedPolicy) -> bool:
|
||||
method = getattr(type(policy), "predict_action_chunk", None)
|
||||
return method is not None and method is not PreTrainedPolicy.predict_action_chunk
|
||||
|
||||
|
||||
def classify_policy(policy: PreTrainedPolicy) -> PolicyClassification:
|
||||
"""Classify a loaded policy into a serving class. Registry-driven, never inferred."""
|
||||
name = getattr(policy, "name", type(policy).__name__)
|
||||
|
||||
if not _has_chunk_api(policy):
|
||||
return PolicyClassification(
|
||||
ServingClass.REFUSED,
|
||||
supports_rtc=False,
|
||||
needs_queue_population=False,
|
||||
reason=f"policy '{name}' does not implement predict_action_chunk",
|
||||
)
|
||||
|
||||
supports_rtc = name in RTC_CAPABLE
|
||||
|
||||
if name in VERIFIED_CHUNK_STATELESS:
|
||||
return PolicyClassification(
|
||||
ServingClass.SHARED, supports_rtc, False, f"'{name}' is verified chunk-stateless"
|
||||
)
|
||||
|
||||
if name == "smolvla":
|
||||
n_obs_steps = getattr(policy.config, "n_obs_steps", 1)
|
||||
if n_obs_steps == 1:
|
||||
return PolicyClassification(
|
||||
ServingClass.SHARED,
|
||||
supports_rtc,
|
||||
False,
|
||||
"'smolvla' with n_obs_steps=1 overwrites its queues per request",
|
||||
)
|
||||
return PolicyClassification(
|
||||
ServingClass.EXCLUSIVE,
|
||||
supports_rtc,
|
||||
False,
|
||||
f"'smolvla' with n_obs_steps={n_obs_steps} keeps observation history across requests",
|
||||
)
|
||||
|
||||
if name in QUEUE_POPULATED_IN_SELECT:
|
||||
return PolicyClassification(
|
||||
ServingClass.EXCLUSIVE,
|
||||
supports_rtc,
|
||||
True,
|
||||
f"'{name}' predict_action_chunk reads select_action-fed queues",
|
||||
)
|
||||
|
||||
return PolicyClassification(
|
||||
ServingClass.EXCLUSIVE,
|
||||
supports_rtc,
|
||||
False,
|
||||
f"'{name}' has a chunk API but is not in the verified chunk-stateless registry",
|
||||
)
|
||||
|
||||
|
||||
def resolve_serving_mode(
|
||||
classification: PolicyClassification, manifest: PolicyServerManifest
|
||||
) -> tuple[str, int]:
|
||||
"""Resolve the final (serving_mode, max_sessions) from classification + manifest.
|
||||
|
||||
The manifest may force ``exclusive`` but can never force ``shared``
|
||||
for a policy that is not verified chunk-stateless.
|
||||
"""
|
||||
if classification.serving_class is ServingClass.REFUSED:
|
||||
raise ValueError(f"Refusing to serve this policy: {classification.reason}")
|
||||
|
||||
if manifest.serving_mode == SERVING_MODE_SHARED:
|
||||
if classification.serving_class is not ServingClass.SHARED:
|
||||
raise ValueError(
|
||||
f"serving_mode=shared is unsafe for this policy: {classification.reason}. "
|
||||
"Use serving_mode=exclusive (or auto)."
|
||||
)
|
||||
mode = SERVING_MODE_SHARED
|
||||
elif manifest.serving_mode == SERVING_MODE_EXCLUSIVE:
|
||||
mode = SERVING_MODE_EXCLUSIVE
|
||||
else: # auto
|
||||
mode = (
|
||||
SERVING_MODE_SHARED
|
||||
if classification.serving_class is ServingClass.SHARED
|
||||
else SERVING_MODE_EXCLUSIVE
|
||||
)
|
||||
|
||||
max_sessions = manifest.max_sessions
|
||||
if mode == SERVING_MODE_EXCLUSIVE and max_sessions != 1:
|
||||
logger.warning(
|
||||
"serving_mode=exclusive forces max_sessions=1 (manifest had %d)", manifest.max_sessions
|
||||
)
|
||||
max_sessions = 1
|
||||
return mode, max_sessions
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session-open validation (fail fast, fail loud)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
error: str = "" # non-empty → hard reject
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
# RTC requested but unsupported → downgrade to plain chunk-append.
|
||||
rtc_downgraded: bool = False
|
||||
|
||||
@property
|
||||
def ok(self) -> bool:
|
||||
return not self.error
|
||||
|
||||
|
||||
def validate_session_open(
|
||||
msg: SessionOpenMsg,
|
||||
capabilities: StatusMsg,
|
||||
manifest: PolicyServerManifest,
|
||||
active_sessions: int,
|
||||
) -> ValidationResult:
|
||||
"""Apply the capability matrix from the design doc (§8.4)."""
|
||||
result = ValidationResult()
|
||||
|
||||
# Schema version: client must be within the server's supported range.
|
||||
if not (capabilities.min_schema_version <= msg.schema_version <= capabilities.max_schema_version):
|
||||
result.error = (
|
||||
f"schema_version {msg.schema_version} outside supported range "
|
||||
f"[{capabilities.min_schema_version}, {capabilities.max_schema_version}]"
|
||||
)
|
||||
return result
|
||||
|
||||
# Capacity: reject with current load so the client can retry another replica.
|
||||
if active_sessions >= capabilities.max_sessions:
|
||||
result.error = f"server full: {active_sessions}/{capabilities.max_sessions} sessions active"
|
||||
return result
|
||||
|
||||
# Action names AND order: the hard sync-safety contract mapping
|
||||
# chunk columns to motors.
|
||||
if capabilities.action_names and msg.action_names != capabilities.action_names:
|
||||
result.error = (
|
||||
"action feature names/order mismatch — refusing to map chunk columns to motors.\n"
|
||||
f" server: {capabilities.action_names}\n"
|
||||
f" client: {msg.action_names}"
|
||||
)
|
||||
return result
|
||||
|
||||
# State dim.
|
||||
if capabilities.state_dim and msg.state_dim and msg.state_dim != capabilities.state_dim:
|
||||
result.error = f"state dim mismatch: server={capabilities.state_dim}, client={msg.state_dim}"
|
||||
return result
|
||||
|
||||
# Camera names: the client set must cover the policy's visual features.
|
||||
missing = set(capabilities.expected_cameras) - set(msg.camera_names)
|
||||
if missing:
|
||||
result.error = (
|
||||
f"missing camera features {sorted(missing)} "
|
||||
f"(client provides {sorted(msg.camera_names)}; resolution may differ — names may not)"
|
||||
)
|
||||
return result
|
||||
|
||||
# Task pinning.
|
||||
if manifest.pin_task and msg.task and msg.task != manifest.default_task:
|
||||
result.error = f"task is pinned to {manifest.default_task!r} on this server, got {msg.task!r}"
|
||||
return result
|
||||
|
||||
# fps: warn unless strict.
|
||||
if capabilities.trained_fps and msg.fps and abs(msg.fps - capabilities.trained_fps) > 1e-6:
|
||||
fps_msg = f"client fps={msg.fps:g} != trained fps={capabilities.trained_fps:g}"
|
||||
if manifest.strict_fps:
|
||||
result.error = fps_msg + " (strict_fps=true)"
|
||||
return result
|
||||
result.warnings.append(fps_msg)
|
||||
|
||||
# Policy type sanity (informational mismatch is a warning, not fatal:
|
||||
# the action/state/camera contracts above are the binding ones).
|
||||
if msg.policy_type and capabilities.policy_type and msg.policy_type != capabilities.policy_type:
|
||||
result.warnings.append(
|
||||
f"client expected policy_type={msg.policy_type!r}, server runs {capabilities.policy_type!r}"
|
||||
)
|
||||
|
||||
# RTC: requested but unsupported → serve plain chunks, client appends.
|
||||
if msg.rtc_enabled and not capabilities.supports_rtc:
|
||||
result.rtc_downgraded = True
|
||||
result.warnings.append(
|
||||
"RTC requested but this server/policy does not support it — downgrading to chunk-append"
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,101 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Zenoh session construction shared by the policy server and the remote engine.
|
||||
|
||||
Verified against eclipse-zenoh 1.9 (thread-based; no asyncio API).
|
||||
Multicast scouting is always disabled — fleet "discovery" is static
|
||||
endpoint configuration plus liveliness tokens, never protocol magic.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ZENOH_IMPORT_HINT = (
|
||||
"Remote inference requires the 'async' extra: pip install 'lerobot[async]' (eclipse-zenoh + msgpack)"
|
||||
)
|
||||
|
||||
|
||||
def import_zenoh():
|
||||
"""Import zenoh lazily with an actionable error message."""
|
||||
try:
|
||||
import zenoh
|
||||
except ImportError as e:
|
||||
raise ImportError(_ZENOH_IMPORT_HINT) from e
|
||||
return zenoh
|
||||
|
||||
|
||||
def build_zenoh_config(
|
||||
*,
|
||||
mode: str = "client",
|
||||
connect_endpoints: list[str] | None = None,
|
||||
listen_endpoints: list[str] | None = None,
|
||||
tls_root_ca_certificate: str | None = None,
|
||||
tls_connect_certificate: str | None = None,
|
||||
tls_connect_private_key: str | None = None,
|
||||
extra_config_json5: str | None = None,
|
||||
):
|
||||
"""Build a zenoh.Config (values are JSON5 strings — note the inner quoting)."""
|
||||
zenoh = import_zenoh()
|
||||
cfg = zenoh.Config()
|
||||
cfg.insert_json5("mode", json.dumps(mode))
|
||||
cfg.insert_json5("scouting/multicast/enabled", "false")
|
||||
if connect_endpoints:
|
||||
cfg.insert_json5("connect/endpoints", json.dumps(list(connect_endpoints)))
|
||||
if listen_endpoints:
|
||||
cfg.insert_json5("listen/endpoints", json.dumps(list(listen_endpoints)))
|
||||
if tls_root_ca_certificate:
|
||||
cfg.insert_json5("transport/link/tls/root_ca_certificate", json.dumps(tls_root_ca_certificate))
|
||||
if tls_connect_certificate:
|
||||
cfg.insert_json5("transport/link/tls/connect_certificate", json.dumps(tls_connect_certificate))
|
||||
if tls_connect_private_key:
|
||||
cfg.insert_json5("transport/link/tls/connect_private_key", json.dumps(tls_connect_private_key))
|
||||
if extra_config_json5:
|
||||
merged = json.loads(extra_config_json5)
|
||||
for key, value in merged.items():
|
||||
cfg.insert_json5(key, json.dumps(value))
|
||||
return cfg
|
||||
|
||||
|
||||
def action_publisher_qos(zenoh) -> dict:
|
||||
"""QoS for the action topic: RELIABLE + congestion DROP (never BLOCK) + express.
|
||||
|
||||
DROP so one dead robot uplink can never stall the server's publish
|
||||
path; a dropped chunk is recoverable by design — the client's action
|
||||
buffer keeps the robot moving and the next chunk replaces it.
|
||||
"""
|
||||
return {
|
||||
"reliability": zenoh.Reliability.RELIABLE,
|
||||
"congestion_control": zenoh.CongestionControl.DROP,
|
||||
"express": True,
|
||||
"priority": zenoh.Priority.INTERACTIVE_HIGH,
|
||||
}
|
||||
|
||||
|
||||
def obs_publisher_qos(zenoh) -> dict:
|
||||
"""QoS for the observation topic: best-effort drop, default priority.
|
||||
|
||||
Intentional drop already happened at the client's one-slot holder;
|
||||
if the uplink stalls, dropping a frame protects the control loop.
|
||||
"""
|
||||
return {
|
||||
"reliability": zenoh.Reliability.BEST_EFFORT,
|
||||
"congestion_control": zenoh.CongestionControl.DROP,
|
||||
"express": False,
|
||||
"priority": zenoh.Priority.DATA,
|
||||
}
|
||||
@@ -175,6 +175,9 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
|
||||
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
|
||||
complementary_data["task_index"] = task_index_value.unsqueeze(0)
|
||||
|
||||
complementary_data.pop("language_persistent", None)
|
||||
complementary_data.pop("language_events", None)
|
||||
|
||||
if "messages" in complementary_data:
|
||||
messages = complementary_data["messages"]
|
||||
if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)):
|
||||
|
||||
@@ -32,7 +32,6 @@ from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
@@ -281,6 +280,11 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
|
||||
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
_serialized_state_filenames: tuple[str | None, ...] | None = field(
|
||||
default=None,
|
||||
init=False,
|
||||
repr=False,
|
||||
)
|
||||
|
||||
def __call__(self, data: TInput) -> TOutput:
|
||||
"""Processes input data through the full pipeline.
|
||||
@@ -338,30 +342,108 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
transition = processor_step(transition)
|
||||
yield transition
|
||||
|
||||
def _save_pretrained(self, save_directory: Path, **kwargs):
|
||||
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||
def _get_sanitized_name(self) -> str:
|
||||
"""Return a filename-safe version of the pipeline name.
|
||||
|
||||
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||
Returns:
|
||||
The lower-cased pipeline name with non-alphanumeric characters replaced by underscores.
|
||||
"""
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||
|
||||
# Sanitize the pipeline name to create a valid filename prefix.
|
||||
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||
@staticmethod
|
||||
def _get_state_filename(
|
||||
*,
|
||||
step_index: int,
|
||||
registry_name: str | None,
|
||||
sanitized_name: str,
|
||||
) -> str:
|
||||
"""Return the safetensors filename for one stateful processor step.
|
||||
|
||||
if config_filename is None:
|
||||
config_filename = f"{sanitized_name}.json"
|
||||
Args:
|
||||
step_index: The index of the processor step in this pipeline.
|
||||
registry_name: The registered processor step name, if available.
|
||||
sanitized_name: The filename-safe pipeline name.
|
||||
|
||||
config: dict[str, Any] = {
|
||||
Returns:
|
||||
The state filename used by the existing disk serialization format.
|
||||
"""
|
||||
if registry_name:
|
||||
return f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||||
|
||||
return f"{sanitized_name}_step_{step_index}.safetensors"
|
||||
|
||||
@staticmethod
|
||||
def _get_state_key(state_filename: str) -> str:
|
||||
"""Return the in-memory state key for a serialized state filename.
|
||||
|
||||
Args:
|
||||
state_filename: The `.safetensors` filename from the serialized config.
|
||||
|
||||
Returns:
|
||||
The state key used by the in-memory pipeline state dictionary.
|
||||
"""
|
||||
return state_filename.removesuffix(".safetensors")
|
||||
|
||||
@staticmethod
|
||||
def _get_state_filenames_from_config(loaded_config: dict[str, Any]) -> tuple[str | None, ...]:
|
||||
"""Return serialized state filenames in step order.
|
||||
|
||||
Args:
|
||||
loaded_config: A validated processor pipeline config.
|
||||
|
||||
Returns:
|
||||
A tuple containing each step's serialized state filename, or None for stateless steps.
|
||||
"""
|
||||
return tuple(step_entry.get("state_file") for step_entry in loaded_config["steps"])
|
||||
|
||||
def _get_state_filenames_for_loading(self) -> tuple[str | None, ...]:
|
||||
"""Return expected state filenames in step order for `load_state_dict()`.
|
||||
|
||||
Returns:
|
||||
The preserved serialized state filenames when available, otherwise filenames derived from
|
||||
current non-empty step state.
|
||||
"""
|
||||
if self._serialized_state_filenames is not None and len(self._serialized_state_filenames) == len(
|
||||
self.steps
|
||||
):
|
||||
return self._serialized_state_filenames
|
||||
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
state_filenames: list[str | None] = []
|
||||
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if not step_state_dict:
|
||||
state_filenames.append(None)
|
||||
continue
|
||||
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
state_filenames.append(
|
||||
self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
)
|
||||
|
||||
return tuple(state_filenames)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return the JSON-serializable pipeline configuration.
|
||||
|
||||
Returns:
|
||||
A dictionary with the same content that `save_pretrained()` writes as JSON.
|
||||
"""
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
pipeline_config: dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"steps": [],
|
||||
}
|
||||
|
||||
# Iterate through each step to build its configuration entry.
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
|
||||
step_entry: dict[str, Any] = {}
|
||||
# Prefer registry name for portability, otherwise fall back to full class path.
|
||||
|
||||
if registry_name:
|
||||
step_entry["registry_name"] = registry_name
|
||||
else:
|
||||
@@ -369,31 +451,110 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
|
||||
)
|
||||
|
||||
# Save step configuration if `get_config` is implemented.
|
||||
if hasattr(processor_step, "get_config"):
|
||||
step_entry["config"] = processor_step.get_config()
|
||||
step_entry["config"] = processor_step.get_config()
|
||||
|
||||
# Save step state if `state_dict` is implemented and returns a non-empty dict.
|
||||
if hasattr(processor_step, "state_dict"):
|
||||
state = processor_step.state_dict()
|
||||
if state:
|
||||
# Clone tensors to avoid modifying the original state.
|
||||
cloned_state = {key: tensor.clone() for key, tensor in state.items()}
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if step_state_dict:
|
||||
step_entry["state_file"] = self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
|
||||
# Create a unique filename for the state file.
|
||||
if registry_name:
|
||||
state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||||
else:
|
||||
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
|
||||
pipeline_config["steps"].append(step_entry)
|
||||
|
||||
save_file(cloned_state, os.path.join(str(save_directory), state_filename))
|
||||
step_entry["state_file"] = state_filename
|
||||
return pipeline_config
|
||||
|
||||
config["steps"].append(step_entry)
|
||||
def state_dict(self) -> dict[str, dict[str, torch.Tensor]]:
|
||||
"""Return pipeline state tensors grouped by state key.
|
||||
|
||||
# Write the main configuration JSON file.
|
||||
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
|
||||
json.dump(config, file_pointer, indent=2)
|
||||
Returns:
|
||||
A dictionary mapping suffixless state keys to cloned step state dictionaries.
|
||||
"""
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
pipeline_state_dict: dict[str, dict[str, torch.Tensor]] = {}
|
||||
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if not step_state_dict:
|
||||
continue
|
||||
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
state_filename = self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
state_key = self._get_state_key(state_filename)
|
||||
pipeline_state_dict[state_key] = {
|
||||
tensor_name: tensor.clone() for tensor_name, tensor in step_state_dict.items()
|
||||
}
|
||||
|
||||
return pipeline_state_dict
|
||||
|
||||
def load_state_dict(
|
||||
self,
|
||||
state_dict: dict[str, dict[str, torch.Tensor]],
|
||||
) -> None:
|
||||
"""Load pipeline state tensors into the existing steps.
|
||||
|
||||
Args:
|
||||
state_dict: A dictionary mapping suffixless state keys to step state dictionaries.
|
||||
|
||||
Raises:
|
||||
KeyError: If loading finds missing expected state or unexpected extra state.
|
||||
"""
|
||||
expected_state_filenames = self._get_state_filenames_for_loading()
|
||||
used_state_keys: set[str] = set()
|
||||
|
||||
for step_index, (processor_step, state_filename) in enumerate(
|
||||
zip(self.steps, expected_state_filenames, strict=True)
|
||||
):
|
||||
if state_filename is None:
|
||||
continue
|
||||
|
||||
state_key = self._get_state_key(state_filename)
|
||||
if state_key not in state_dict:
|
||||
raise KeyError(
|
||||
f"Missing state key '{state_key}' for processor step {step_index}. "
|
||||
f"Available state keys: {sorted(state_dict.keys())}"
|
||||
)
|
||||
|
||||
processor_step.load_state_dict(state_dict[state_key])
|
||||
used_state_keys.add(state_key)
|
||||
|
||||
unexpected_state_keys = set(state_dict) - used_state_keys
|
||||
if unexpected_state_keys:
|
||||
expected_state_key_set = {
|
||||
self._get_state_key(state_filename)
|
||||
for state_filename in expected_state_filenames
|
||||
if state_filename is not None
|
||||
}
|
||||
raise KeyError(
|
||||
f"Unexpected processor state keys: {sorted(unexpected_state_keys)}. "
|
||||
f"Expected state keys: {sorted(expected_state_key_set)}"
|
||||
)
|
||||
|
||||
def _save_pretrained(self, save_directory: Path, **kwargs) -> None:
|
||||
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||
|
||||
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||
"""
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
|
||||
if config_filename is None:
|
||||
config_filename = f"{sanitized_name}.json"
|
||||
|
||||
pipeline_config = self.get_config()
|
||||
pipeline_state_dict = self.state_dict()
|
||||
|
||||
for state_key, step_state_dict in pipeline_state_dict.items():
|
||||
state_filename = f"{state_key}.safetensors"
|
||||
save_file(step_state_dict, save_directory / state_filename)
|
||||
|
||||
with open(save_directory / config_filename, "w") as file_pointer:
|
||||
json.dump(pipeline_config, file_pointer, indent=2)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
@@ -577,12 +738,54 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
cls._validate_overrides_used(validated_overrides, loaded_config)
|
||||
|
||||
# 5. Construct and return the final pipeline instance
|
||||
return cls(
|
||||
pipeline = cls(
|
||||
steps=steps,
|
||||
name=loaded_config.get("name", "DataProcessorPipeline"),
|
||||
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||
)
|
||||
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(loaded_config)
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: dict[str, Any],
|
||||
*,
|
||||
state_dict: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
overrides: dict[str, Any] | None = None,
|
||||
to_transition: Callable[[TInput], EnvTransition] | None = None,
|
||||
to_output: Callable[[EnvTransition], TOutput] | None = None,
|
||||
) -> DataProcessorPipeline[TInput, TOutput]:
|
||||
"""Build a pipeline from an in-memory config and optional state tensors.
|
||||
|
||||
Args:
|
||||
config: A config dictionary with the same structure as the saved processor JSON.
|
||||
state_dict: Optional in-memory pipeline state grouped by suffixless state key.
|
||||
overrides: Optional constructor overrides keyed by registry name or class name.
|
||||
to_transition: Optional converter from input data to `EnvTransition`.
|
||||
to_output: Optional converter from `EnvTransition` to output data.
|
||||
|
||||
Returns:
|
||||
A processor pipeline built from the config and optional state.
|
||||
"""
|
||||
cls._validate_loaded_config("<in-memory config>", config, "<in-memory config>")
|
||||
|
||||
steps, remaining_override_keys = cls._build_steps_from_config(config, overrides or {})
|
||||
cls._validate_overrides_used(remaining_override_keys, config)
|
||||
|
||||
pipeline = cls(
|
||||
steps=steps,
|
||||
name=config.get("name", "DataProcessorPipeline"),
|
||||
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||
)
|
||||
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(config)
|
||||
|
||||
if state_dict is not None:
|
||||
pipeline.load_state_dict(state_dict)
|
||||
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
def _load_config(
|
||||
@@ -666,9 +869,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
) from e
|
||||
|
||||
@classmethod
|
||||
def _validate_loaded_config(
|
||||
cls, model_id: str, loaded_config: dict[str, Any], config_filename: str
|
||||
) -> None:
|
||||
def _validate_loaded_config(cls, model_id: str, loaded_config: Any, config_filename: str) -> None:
|
||||
"""Validate that a config was loaded and is a valid processor config.
|
||||
|
||||
This method validates processor config format with intelligent migration detection:
|
||||
@@ -688,7 +889,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
|
||||
Args:
|
||||
model_id: The model identifier (used for migration detection)
|
||||
loaded_config: The loaded config dictionary (guaranteed non-None)
|
||||
loaded_config: The loaded config value to validate (may be non-dict)
|
||||
config_filename: The config filename that was loaded (for error messages)
|
||||
|
||||
Raises:
|
||||
@@ -702,9 +903,14 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
model_id,
|
||||
f"Config file '{config_filename}' is not a valid processor configuration",
|
||||
)
|
||||
loaded_config_description = (
|
||||
list(loaded_config.keys())
|
||||
if isinstance(loaded_config, dict)
|
||||
else type(loaded_config).__name__
|
||||
)
|
||||
raise ValueError(
|
||||
f"Config file '{config_filename}' is not a valid processor configuration. "
|
||||
f"Expected a config with 'steps' field, but got: {list(loaded_config.keys())}"
|
||||
f"Expected a config with 'steps' field, but got: {loaded_config_description}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -766,26 +972,41 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
ImportError: If a step class cannot be imported or found in registry
|
||||
ValueError: If a step cannot be instantiated with its configuration
|
||||
"""
|
||||
steps: list[ProcessorStep] = []
|
||||
override_keys = set(overrides.keys())
|
||||
steps, remaining_override_keys = cls._build_steps_from_config(loaded_config, overrides)
|
||||
|
||||
for step_entry in loaded_config["steps"]:
|
||||
# 1. Get step class and key
|
||||
step_class, step_key = cls._resolve_step_class(step_entry)
|
||||
|
||||
# 2. Instantiate step with overrides
|
||||
step_instance = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
||||
|
||||
# 3. Load step state if available
|
||||
for step_instance, step_entry in zip(steps, loaded_config["steps"], strict=True):
|
||||
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
|
||||
|
||||
# 4. Track used overrides
|
||||
if step_key in override_keys:
|
||||
override_keys.discard(step_key)
|
||||
return steps, remaining_override_keys
|
||||
|
||||
steps.append(step_instance)
|
||||
@classmethod
|
||||
def _build_steps_from_config(
|
||||
cls,
|
||||
loaded_config: dict[str, Any],
|
||||
overrides: dict[str, Any],
|
||||
) -> tuple[list[ProcessorStep], set[str]]:
|
||||
"""Build processor steps from config without loading tensor state.
|
||||
|
||||
return steps, override_keys
|
||||
Args:
|
||||
loaded_config: The loaded processor configuration.
|
||||
overrides: User-provided constructor overrides keyed by step key.
|
||||
|
||||
Returns:
|
||||
A tuple containing instantiated steps and override keys that did not match a step.
|
||||
"""
|
||||
processor_steps: list[ProcessorStep] = []
|
||||
remaining_override_keys = set(overrides.keys())
|
||||
|
||||
for step_entry in loaded_config["steps"]:
|
||||
step_class, step_key = cls._resolve_step_class(step_entry)
|
||||
processor_step = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
||||
|
||||
if step_key in remaining_override_keys:
|
||||
remaining_override_keys.discard(step_key)
|
||||
|
||||
processor_steps.append(processor_step)
|
||||
|
||||
return processor_steps, remaining_override_keys
|
||||
|
||||
@classmethod
|
||||
def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]:
|
||||
@@ -1096,7 +1317,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def _is_processor_config(cls, config: dict) -> bool:
|
||||
def _is_processor_config(cls, config: Any) -> bool:
|
||||
"""Check if config follows DataProcessorPipeline format.
|
||||
|
||||
This method validates the processor configuration structure:
|
||||
@@ -1147,6 +1368,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
Returns:
|
||||
True if config follows valid DataProcessorPipeline format, False otherwise
|
||||
"""
|
||||
if not isinstance(config, dict):
|
||||
return False
|
||||
|
||||
# Must have a "steps" field with a list of step configurations
|
||||
if not isinstance(config.get("steps"), list):
|
||||
return False
|
||||
|
||||
@@ -50,17 +50,7 @@ class RenderMessagesStep(ProcessorStep):
|
||||
events = complementary_data.get(LANGUAGE_EVENTS) or []
|
||||
|
||||
if not persistent and not events:
|
||||
rendered = _fallback_low_level_render(complementary_data.get("task"))
|
||||
if rendered is None:
|
||||
return transition
|
||||
new_transition = transition.copy()
|
||||
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
|
||||
new_complementary_data.update(rendered)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
return new_transition
|
||||
|
||||
if _is_batched_language(persistent) or _is_batched_language(events):
|
||||
return self._call_batch(transition, complementary_data, persistent, events)
|
||||
return transition
|
||||
|
||||
timestamp = complementary_data.get("timestamp")
|
||||
if timestamp is None:
|
||||
@@ -77,147 +67,18 @@ class RenderMessagesStep(ProcessorStep):
|
||||
dataset_ctx=self.dataset_ctx,
|
||||
)
|
||||
if rendered is None:
|
||||
rendered = _fallback_low_level_render(complementary_data.get("task"))
|
||||
if rendered is None:
|
||||
return None
|
||||
return None
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
|
||||
new_complementary_data.pop(LANGUAGE_EVENTS, None)
|
||||
new_complementary_data.update(rendered)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
return new_transition
|
||||
|
||||
def _call_batch(
|
||||
self,
|
||||
transition: EnvTransition,
|
||||
complementary_data: dict[str, Any],
|
||||
persistent_batch: list,
|
||||
events_batch: list,
|
||||
) -> EnvTransition | None:
|
||||
timestamp = complementary_data.get("timestamp")
|
||||
if timestamp is None:
|
||||
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
|
||||
|
||||
batch_size = max(len(persistent_batch), len(events_batch))
|
||||
messages: list[list[dict[str, Any]]] = []
|
||||
message_streams: list[list[str | None]] = []
|
||||
target_message_indices: list[list[int]] = []
|
||||
keep_indices: list[int] = []
|
||||
|
||||
for i in range(batch_size):
|
||||
rendered = render_sample(
|
||||
recipe=self.recipe,
|
||||
persistent=persistent_batch[i] if i < len(persistent_batch) else [],
|
||||
events=events_batch[i] if i < len(events_batch) else [],
|
||||
t=_batch_value(timestamp, i),
|
||||
sample_idx=int(_batch_value(complementary_data.get("index", 0), i)),
|
||||
task=_batch_value(complementary_data.get("task"), i),
|
||||
dataset_ctx=self.dataset_ctx,
|
||||
)
|
||||
if rendered is None:
|
||||
rendered = _fallback_low_level_render(_batch_value(complementary_data.get("task"), i))
|
||||
if rendered is None:
|
||||
continue
|
||||
keep_indices.append(i)
|
||||
messages.append(rendered["messages"])
|
||||
message_streams.append(rendered["message_streams"])
|
||||
target_message_indices.append(rendered["target_message_indices"])
|
||||
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
new_transition = (
|
||||
_select_batch_indices(transition, keep_indices)
|
||||
if len(keep_indices) != batch_size
|
||||
else transition.copy()
|
||||
)
|
||||
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
|
||||
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
|
||||
new_complementary_data.pop(LANGUAGE_EVENTS, None)
|
||||
new_complementary_data["messages"] = messages
|
||||
new_complementary_data["message_streams"] = message_streams
|
||||
new_complementary_data["target_message_indices"] = target_message_indices
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Pass features through unchanged; rendering only touches complementary data."""
|
||||
return features
|
||||
|
||||
|
||||
def _scalar(value: Any) -> float | int:
|
||||
"""Unwrap a tensor/array/single-element list into a Python scalar."""
|
||||
if hasattr(value, "item"):
|
||||
return value.item()
|
||||
if isinstance(value, list):
|
||||
if len(value) != 1:
|
||||
raise ValueError(f"Expected a scalar, got list of length {len(value)}: {value!r}")
|
||||
return _scalar(value[0])
|
||||
return value
|
||||
|
||||
|
||||
def _is_batched_language(value: Any) -> bool:
|
||||
return isinstance(value, list) and bool(value) and isinstance(value[0], list)
|
||||
|
||||
|
||||
def _batch_value(value: Any, index: int) -> Any:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, list):
|
||||
return value[index]
|
||||
if hasattr(value, "ndim") and getattr(value, "ndim") > 0:
|
||||
return _scalar(value[index])
|
||||
return _scalar(value)
|
||||
|
||||
|
||||
def _select_batch_indices(transition: EnvTransition, indices: list[int]) -> EnvTransition:
|
||||
selected = transition.copy()
|
||||
for key in (TransitionKey.OBSERVATION, TransitionKey.COMPLEMENTARY_DATA):
|
||||
data = selected.get(key)
|
||||
if isinstance(data, dict):
|
||||
selected[key] = {k: _select_value(v, indices) for k, v in data.items()}
|
||||
action = selected.get(TransitionKey.ACTION)
|
||||
if action is not None:
|
||||
selected[TransitionKey.ACTION] = _select_value(action, indices)
|
||||
return selected
|
||||
|
||||
|
||||
def _select_value(value: Any, indices: list[int]) -> Any:
|
||||
if isinstance(value, list) and len(value) >= len(indices):
|
||||
return [value[i] for i in indices]
|
||||
if hasattr(value, "index_select") and hasattr(value, "new_tensor") and getattr(value, "ndim", 0) > 0:
|
||||
return value.index_select(0, value.new_tensor(indices).long())
|
||||
return value
|
||||
|
||||
|
||||
def _fallback_low_level_render(task: Any) -> dict[str, Any] | None:
|
||||
"""Keep action-only samples trainable when no recipe branch matches."""
|
||||
if hasattr(task, "item"):
|
||||
task = task.item()
|
||||
if isinstance(task, list):
|
||||
messages = []
|
||||
message_streams = []
|
||||
target_message_indices = []
|
||||
for t in task:
|
||||
rendered = _fallback_low_level_render(t)
|
||||
if rendered is None:
|
||||
return None
|
||||
messages.append(rendered["messages"])
|
||||
message_streams.append(rendered["message_streams"])
|
||||
target_message_indices.append(rendered["target_message_indices"])
|
||||
return {
|
||||
"messages": messages,
|
||||
"message_streams": message_streams,
|
||||
"target_message_indices": target_message_indices,
|
||||
}
|
||||
if not isinstance(task, str) or not task:
|
||||
return None
|
||||
return {
|
||||
"messages": [{"role": "user", "content": task}],
|
||||
"message_streams": ["low_level"],
|
||||
"target_message_indices": [],
|
||||
}
|
||||
|
||||
@@ -32,7 +32,6 @@ import torch
|
||||
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.types import EnvTransition, RobotObservation, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
ACTION_CODE_TOKEN_MASK,
|
||||
ACTION_TOKEN_MASK,
|
||||
ACTION_TOKENS,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
@@ -413,15 +412,14 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
# During inference, no action is available, skip tokenization
|
||||
return new_transition
|
||||
|
||||
# Tokenize and get masks for the full formatted sequence and the discrete action codes.
|
||||
tokens, mask, code_mask = self._tokenize_action(action)
|
||||
# Tokenize and get both tokens and mask
|
||||
tokens, mask = self._tokenize_action(action)
|
||||
|
||||
# Store mask in complementary data
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
if complementary_data is None:
|
||||
complementary_data = {}
|
||||
complementary_data[ACTION_TOKEN_MASK] = mask
|
||||
complementary_data[ACTION_CODE_TOKEN_MASK] = code_mask
|
||||
complementary_data[ACTION_TOKENS] = tokens
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
return new_transition
|
||||
@@ -432,7 +430,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
"""
|
||||
return self._paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
|
||||
|
||||
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Tokenizes the action tensor and creates a mask.
|
||||
|
||||
@@ -461,7 +459,6 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
# The fast tokenizer expects action data and returns token IDs
|
||||
tokens_list = []
|
||||
masks_list = []
|
||||
code_masks_list = []
|
||||
|
||||
for i in range(batch_size):
|
||||
# Tokenize single action (move to CPU first as tokenizer uses scipy which requires numpy)
|
||||
@@ -479,26 +476,19 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
if tokens.dim() > 1:
|
||||
tokens = tokens.flatten()
|
||||
|
||||
action_code_tokens = self._act_tokens_to_paligemma_tokens(tokens)
|
||||
bos_id = self._paligemma_tokenizer.bos_token_id
|
||||
prompt_tokens = torch.tensor(
|
||||
self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False),
|
||||
device=action.device,
|
||||
)
|
||||
end_tokens = torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device)
|
||||
|
||||
code_start = 1 + len(prompt_tokens)
|
||||
code_end = code_start + len(action_code_tokens)
|
||||
# add bos
|
||||
tokens = torch.cat(
|
||||
[
|
||||
torch.tensor([bos_id], device=action.device),
|
||||
prompt_tokens,
|
||||
action_code_tokens,
|
||||
end_tokens,
|
||||
torch.tensor(
|
||||
self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False),
|
||||
device=action.device,
|
||||
),
|
||||
self._act_tokens_to_paligemma_tokens(tokens),
|
||||
torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device),
|
||||
]
|
||||
)
|
||||
code_mask = torch.zeros(len(tokens), dtype=torch.bool, device=action.device)
|
||||
code_mask[code_start:code_end] = True
|
||||
|
||||
# Truncate or pad to max_action_tokens
|
||||
if len(tokens) > self.max_action_tokens:
|
||||
@@ -507,49 +497,44 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
"Consider increasing the `max_action_tokens` in your model config if this happens frequently."
|
||||
)
|
||||
tokens = tokens[: self.max_action_tokens]
|
||||
code_mask = code_mask[: self.max_action_tokens]
|
||||
mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device)
|
||||
else:
|
||||
pad_len = self.max_action_tokens - len(tokens)
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.ones(len(tokens), dtype=torch.bool, device=action.device),
|
||||
torch.zeros(pad_len, dtype=torch.bool, device=action.device),
|
||||
torch.zeros(
|
||||
self.max_action_tokens - len(tokens), dtype=torch.bool, device=action.device
|
||||
),
|
||||
]
|
||||
)
|
||||
code_mask = torch.nn.functional.pad(code_mask, (0, pad_len), value=False)
|
||||
# Pad tokens with zeros
|
||||
tokens = torch.nn.functional.pad(tokens, (0, pad_len), value=0)
|
||||
tokens = torch.nn.functional.pad(tokens, (0, self.max_action_tokens - len(tokens)), value=0)
|
||||
|
||||
tokens_list.append(tokens)
|
||||
masks_list.append(mask)
|
||||
code_masks_list.append(code_mask)
|
||||
|
||||
# Stack into batched tensors
|
||||
tokens_batch = torch.stack(tokens_list, dim=0) # (B, max_action_tokens)
|
||||
masks_batch = torch.stack(masks_list, dim=0) # (B, max_action_tokens)
|
||||
code_masks_batch = torch.stack(code_masks_list, dim=0) # (B, max_action_tokens)
|
||||
|
||||
# Remove batch dimension if input was single sample
|
||||
if single_sample:
|
||||
tokens_batch = tokens_batch.squeeze(0)
|
||||
masks_batch = masks_batch.squeeze(0)
|
||||
code_masks_batch = code_masks_batch.squeeze(0)
|
||||
|
||||
# Move to the same device as the input
|
||||
if device is not None:
|
||||
tokens_batch = tokens_batch.to(device)
|
||||
masks_batch = masks_batch.to(device)
|
||||
code_masks_batch = code_masks_batch.to(device)
|
||||
|
||||
return tokens_batch, masks_batch, code_masks_batch
|
||||
return tokens_batch, masks_batch
|
||||
|
||||
def action(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
This method is not used since we override __call__.
|
||||
Required by ActionProcessorStep ABC.
|
||||
"""
|
||||
tokens, _, _ = self._tokenize_action(action)
|
||||
tokens, _ = self._tokenize_action(action)
|
||||
return tokens
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
|
||||
@@ -21,8 +21,6 @@ from lerobot.utils.import_utils import make_device_from_device_class
|
||||
from .config import RobotConfig
|
||||
from .robot import Robot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
# TODO(Steven): Consider just using the make_device_from_device_class for all types
|
||||
@@ -120,7 +118,7 @@ def ensure_safe_goal_position(
|
||||
}
|
||||
|
||||
if warnings_dict:
|
||||
logger.warning(
|
||||
logging.warning(
|
||||
"Relative goal position magnitude had to be clamped to be safe.\n"
|
||||
f"{pformat(warnings_dict, indent=4)}"
|
||||
)
|
||||
|
||||
@@ -39,8 +39,10 @@ from .context import (
|
||||
build_rollout_context,
|
||||
)
|
||||
from .inference import (
|
||||
FallbackMode,
|
||||
InferenceEngine,
|
||||
InferenceEngineConfig,
|
||||
RemoteInferenceConfig,
|
||||
RTCInferenceConfig,
|
||||
RTCInferenceEngine,
|
||||
SyncInferenceConfig,
|
||||
@@ -70,12 +72,14 @@ __all__ = [
|
||||
"HighlightStrategyConfig",
|
||||
"EpisodicStrategy",
|
||||
"EpisodicStrategyConfig",
|
||||
"FallbackMode",
|
||||
"InferenceEngine",
|
||||
"InferenceEngineConfig",
|
||||
"PolicyContext",
|
||||
"ProcessorContext",
|
||||
"RTCInferenceConfig",
|
||||
"RTCInferenceEngine",
|
||||
"RemoteInferenceConfig",
|
||||
"RolloutConfig",
|
||||
"RolloutContext",
|
||||
"RolloutStrategy",
|
||||
|
||||
@@ -51,6 +51,7 @@ from lerobot.utils.feature_utils import combine_feature_dicts, hw_to_dataset_fea
|
||||
from .configs import BaseStrategyConfig, DAggerStrategyConfig, RolloutConfig
|
||||
from .inference import (
|
||||
InferenceEngine,
|
||||
RemoteInferenceConfig,
|
||||
RTCInferenceConfig,
|
||||
SyncInferenceConfig,
|
||||
create_inference_engine,
|
||||
@@ -113,11 +114,17 @@ class HardwareContext:
|
||||
|
||||
@dataclass
|
||||
class PolicyContext:
|
||||
"""Loaded policy and its inference engine."""
|
||||
"""Loaded policy and its inference engine.
|
||||
|
||||
policy: PreTrainedPolicy
|
||||
preprocessor: PolicyProcessorPipeline
|
||||
postprocessor: PolicyProcessorPipeline
|
||||
``policy``/``preprocessor``/``postprocessor`` are ``None`` for the
|
||||
weightless remote backend (``--inference.type=remote``): inference
|
||||
runs on a ``lerobot-policy-server`` and strategies only ever consume
|
||||
``inference``.
|
||||
"""
|
||||
|
||||
policy: PreTrainedPolicy | None
|
||||
preprocessor: PolicyProcessorPipeline | None
|
||||
postprocessor: PolicyProcessorPipeline | None
|
||||
inference: InferenceEngine
|
||||
|
||||
|
||||
@@ -172,54 +179,66 @@ def build_rollout_context(
|
||||
fails fast without touching the robot.
|
||||
"""
|
||||
is_rtc = isinstance(cfg.inference, RTCInferenceConfig)
|
||||
is_remote = isinstance(cfg.inference, RemoteInferenceConfig)
|
||||
|
||||
# --- 1. Policy (heavy I/O, but no hardware yet) -------------------
|
||||
logger.info("Loading policy from '%s'...", cfg.policy.pretrained_path)
|
||||
# Remote inference keeps the edge weightless: the config-only
|
||||
# PreTrainedConfig (already loaded by RolloutConfig.__post_init__,
|
||||
# no weight download) is all the client needs for pre-flight
|
||||
# validation and action ordering.
|
||||
policy_config = cfg.policy
|
||||
policy_class = get_policy_class(policy_config.type)
|
||||
|
||||
if hasattr(policy_config, "compile_model"):
|
||||
policy_config.compile_model = cfg.use_torch_compile
|
||||
|
||||
if policy_config.type == "vqbet" and cfg.device == "mps":
|
||||
raise NotImplementedError(
|
||||
"Current implementation of VQBeT does not support `mps` backend. "
|
||||
"Please use `cpu` or `cuda` backend."
|
||||
policy = None
|
||||
if is_remote:
|
||||
logger.info(
|
||||
"Remote inference: weightless client for '%s' (no weights downloaded)",
|
||||
cfg.policy.pretrained_path,
|
||||
)
|
||||
|
||||
if policy_config.use_peft:
|
||||
from peft import PeftConfig, PeftModel
|
||||
|
||||
peft_path = policy_config.pretrained_path
|
||||
peft_config = PeftConfig.from_pretrained(peft_path)
|
||||
policy = policy_class.from_pretrained(
|
||||
pretrained_name_or_path=peft_config.base_model_name_or_path, config=policy_config
|
||||
)
|
||||
policy = PeftModel.from_pretrained(policy, peft_path, config=peft_config)
|
||||
else:
|
||||
policy = policy_class.from_pretrained(policy_config.pretrained_path, config=policy_config)
|
||||
logger.info("Loading policy from '%s'...", cfg.policy.pretrained_path)
|
||||
policy_class = get_policy_class(policy_config.type)
|
||||
|
||||
if is_rtc:
|
||||
policy.config.rtc_config = cfg.inference.rtc
|
||||
if hasattr(policy, "init_rtc_processor"):
|
||||
policy.init_rtc_processor()
|
||||
if hasattr(policy_config, "compile_model"):
|
||||
policy_config.compile_model = cfg.use_torch_compile
|
||||
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
logger.info("Policy loaded: type=%s, device=%s", policy_config.type, cfg.device)
|
||||
if policy_config.type == "vqbet" and cfg.device == "mps":
|
||||
raise NotImplementedError(
|
||||
"Current implementation of VQBeT does not support `mps` backend. "
|
||||
"Please use `cpu` or `cuda` backend."
|
||||
)
|
||||
|
||||
if cfg.use_torch_compile and policy.type not in ("pi0", "pi05"):
|
||||
try:
|
||||
if hasattr(torch, "compile"):
|
||||
compile_kwargs = {
|
||||
"backend": cfg.torch_compile_backend,
|
||||
"mode": cfg.torch_compile_mode,
|
||||
"options": {"triton.cudagraphs": False},
|
||||
}
|
||||
policy.predict_action_chunk = torch.compile(policy.predict_action_chunk, **compile_kwargs)
|
||||
logger.info("torch.compile applied to predict_action_chunk")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to apply torch.compile: %s", e)
|
||||
if policy_config.use_peft:
|
||||
from peft import PeftConfig, PeftModel
|
||||
|
||||
peft_path = policy_config.pretrained_path
|
||||
peft_config = PeftConfig.from_pretrained(peft_path)
|
||||
policy = policy_class.from_pretrained(
|
||||
pretrained_name_or_path=peft_config.base_model_name_or_path, config=policy_config
|
||||
)
|
||||
policy = PeftModel.from_pretrained(policy, peft_path, config=peft_config)
|
||||
else:
|
||||
policy = policy_class.from_pretrained(policy_config.pretrained_path, config=policy_config)
|
||||
|
||||
if is_rtc:
|
||||
policy.config.rtc_config = cfg.inference.rtc
|
||||
if hasattr(policy, "init_rtc_processor"):
|
||||
policy.init_rtc_processor()
|
||||
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
logger.info("Policy loaded: type=%s, device=%s", policy_config.type, cfg.device)
|
||||
|
||||
if cfg.use_torch_compile and policy.type not in ("pi0", "pi05"):
|
||||
try:
|
||||
if hasattr(torch, "compile"):
|
||||
compile_kwargs = {
|
||||
"backend": cfg.torch_compile_backend,
|
||||
"mode": cfg.torch_compile_mode,
|
||||
"options": {"triton.cudagraphs": False},
|
||||
}
|
||||
policy.predict_action_chunk = torch.compile(policy.predict_action_chunk, **compile_kwargs)
|
||||
logger.info("torch.compile applied to predict_action_chunk")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to apply torch.compile: %s", e)
|
||||
|
||||
# --- 2. Robot-side processors (user-supplied or defaults) --------
|
||||
if (
|
||||
@@ -378,31 +397,36 @@ def build_rollout_context(
|
||||
logger.info("Dataset ready: %s (%d existing episodes)", dataset.repo_id, dataset.num_episodes)
|
||||
|
||||
# --- 6. Policy pre/post processors (needs dataset stats if any) ---
|
||||
dataset_stats = None
|
||||
if dataset is not None:
|
||||
dataset_stats = rename_stats(
|
||||
dataset.meta.stats,
|
||||
cfg.rename_map,
|
||||
# Remote inference runs the policy processors server-side (per
|
||||
# session); the edge ships canonical dataset-format observations.
|
||||
preprocessor = None
|
||||
postprocessor = None
|
||||
if not is_remote:
|
||||
dataset_stats = None
|
||||
if dataset is not None:
|
||||
dataset_stats = rename_stats(
|
||||
dataset.meta.stats,
|
||||
cfg.rename_map,
|
||||
)
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy_config,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=dataset_stats,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy_config,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=dataset_stats,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
if isinstance(cfg.inference, SyncInferenceConfig) and any(
|
||||
isinstance(step, RelativeActionsProcessorStep) and step.enabled
|
||||
for step in getattr(preprocessor, "steps", ())
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"SyncInferenceEngine does not support policies with relative actions for now."
|
||||
"Use --inference.type=rtc or remove relative action processor steps from the policy pipeline."
|
||||
)
|
||||
if isinstance(cfg.inference, SyncInferenceConfig) and any(
|
||||
isinstance(step, RelativeActionsProcessorStep) and step.enabled
|
||||
for step in getattr(preprocessor, "steps", ())
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"SyncInferenceEngine does not support policies with relative actions for now."
|
||||
"Use --inference.type=rtc or remove relative action processor steps from the policy pipeline."
|
||||
)
|
||||
|
||||
# --- 7. Inference strategy (needs policy + pre/post + hardware) --
|
||||
logger.info(
|
||||
@@ -425,6 +449,8 @@ def build_rollout_context(
|
||||
use_torch_compile=cfg.use_torch_compile,
|
||||
compile_warmup_inferences=cfg.compile_warmup_inferences,
|
||||
shutdown_event=shutdown_event,
|
||||
policy_config=policy_config,
|
||||
rename_map=cfg.rename_map,
|
||||
)
|
||||
|
||||
# --- 8. Assemble ---------------------------------------------------
|
||||
|
||||
@@ -14,13 +14,18 @@
|
||||
|
||||
"""Inference engine package — backend-agnostic action production.
|
||||
|
||||
Concrete backends (``sync``, ``rtc``, ...) expose the same small interface so
|
||||
rollout strategies never branch on which backend is in use.
|
||||
Concrete backends (``sync``, ``rtc``, ``remote``, ...) expose the same
|
||||
small interface so rollout strategies never branch on which backend is
|
||||
in use.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .base import InferenceEngine
|
||||
from .factory import (
|
||||
FallbackMode,
|
||||
InferenceEngineConfig,
|
||||
RemoteInferenceConfig,
|
||||
RTCInferenceConfig,
|
||||
SyncInferenceConfig,
|
||||
create_inference_engine,
|
||||
@@ -29,11 +34,23 @@ from .rtc import RTCInferenceEngine
|
||||
from .sync import SyncInferenceEngine
|
||||
|
||||
__all__ = [
|
||||
"FallbackMode",
|
||||
"InferenceEngine",
|
||||
"InferenceEngineConfig",
|
||||
"RTCInferenceConfig",
|
||||
"RTCInferenceEngine",
|
||||
"RemoteInferenceConfig",
|
||||
"RemoteInferenceEngine",
|
||||
"SyncInferenceConfig",
|
||||
"SyncInferenceEngine",
|
||||
"create_inference_engine",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
# Lazy: RemoteInferenceEngine pulls in msgpack/zenoh ('async' extra).
|
||||
if name == "RemoteInferenceEngine":
|
||||
from .remote import RemoteInferenceEngine
|
||||
|
||||
return RemoteInferenceEngine
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@@ -14,9 +14,9 @@
|
||||
|
||||
"""Inference engine configs and factory.
|
||||
|
||||
Selection is explicit via ``--inference.type=sync|rtc``. Adding a new
|
||||
backend requires registering its config subclass and dispatching it in
|
||||
:func:`create_inference_engine`.
|
||||
Selection is explicit via ``--inference.type=sync|rtc|remote``. Adding a
|
||||
new backend requires registering its config subclass and dispatching it
|
||||
in :func:`create_inference_engine`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -24,10 +24,12 @@ from __future__ import annotations
|
||||
import abc
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
from threading import Event
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
@@ -74,6 +76,73 @@ class RTCInferenceConfig(InferenceEngineConfig):
|
||||
queue_threshold: int = 30
|
||||
|
||||
|
||||
class FallbackMode(StrEnum):
|
||||
"""What ``get_action`` returns when the remote queue runs dry (STALLED)."""
|
||||
|
||||
HOLD = "hold" # return None: the robot holds its last commanded position
|
||||
REPEAT_LAST = "repeat_last" # re-send the last executed action
|
||||
ZERO = "zero" # explicit zero command (required for velocity-controlled robots)
|
||||
|
||||
|
||||
@InferenceEngineConfig.register_subclass("remote")
|
||||
@dataclass
|
||||
class RemoteInferenceConfig(InferenceEngineConfig):
|
||||
"""Network inference against a ``lerobot-policy-server`` over Zenoh.
|
||||
|
||||
The edge stays weightless: ``--policy.path`` resolves to a
|
||||
config-only ``PreTrainedConfig`` (no weight download) used for
|
||||
pre-flight validation and action ordering. Requires the ``async``
|
||||
extra (``pip install 'lerobot[async]'``).
|
||||
"""
|
||||
|
||||
# Transport: robots dial out to a zenoh router (NAT-friendly).
|
||||
connect_endpoint: str = "tcp/localhost:7447"
|
||||
# "client" via a zenohd router (production) | "peer" direct (LAN/tests).
|
||||
zenoh_mode: str = "client"
|
||||
tls_ca: str | None = None
|
||||
tls_cert: str | None = None
|
||||
tls_key: str | None = None
|
||||
|
||||
# Service addressing: which (model, revision, task) key tree to dial.
|
||||
# service_model_id defaults to --policy.path; service_task to the
|
||||
# rollout task. These must match the server manifest's namespace.
|
||||
service_model_id: str = ""
|
||||
service_revision: str = "main"
|
||||
service_task: str = ""
|
||||
|
||||
# Identity: "" → a fresh uuid4 per run. Set a stable ID per robot for
|
||||
# fleet-wide log correlation and per-robot router ACLs.
|
||||
client_uuid: str = ""
|
||||
|
||||
# Observation encoding: JPEG quality (0 = raw, LAN/debug only).
|
||||
jpeg_quality: int = 90
|
||||
|
||||
# Self-clocking: request the next chunk when the local queue holds
|
||||
# less than this many seconds of playback.
|
||||
buffer_time_s: float = 0.5
|
||||
|
||||
# Safety: never execute an action whose source observation is older
|
||||
# than this (bounds open-loop execution after a network stall).
|
||||
max_action_age_s: float = 3.0
|
||||
# Fallback when the queue runs dry (see FallbackMode).
|
||||
fallback: FallbackMode = FallbackMode.HOLD
|
||||
|
||||
# Watchdogs & reconnection.
|
||||
degraded_after_s: float = 1.0
|
||||
request_timeout_s: float = 5.0
|
||||
handshake_timeout_s: float = 2.0
|
||||
reconnect_initial_backoff_s: float = 0.5
|
||||
reconnect_max_backoff_s: float = 10.0
|
||||
max_offline_s: float = 60.0
|
||||
|
||||
# RTC settings (enabled → replace-merge with prefix conditioning when
|
||||
# the server supports it; otherwise downgraded to chunk-append).
|
||||
rtc: RTCConfig = field(default_factory=RTCConfig)
|
||||
|
||||
# Free-form labels forwarded in the session handshake (telemetry only).
|
||||
tags: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Factory
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -82,9 +151,9 @@ class RTCInferenceConfig(InferenceEngineConfig):
|
||||
def create_inference_engine(
|
||||
config: InferenceEngineConfig,
|
||||
*,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
policy: PreTrainedPolicy | None,
|
||||
preprocessor: PolicyProcessorPipeline | None,
|
||||
postprocessor: PolicyProcessorPipeline | None,
|
||||
robot_wrapper: ThreadSafeRobot,
|
||||
hw_features: dict,
|
||||
dataset_features: dict,
|
||||
@@ -95,10 +164,19 @@ def create_inference_engine(
|
||||
use_torch_compile: bool = False,
|
||||
compile_warmup_inferences: int = 2,
|
||||
shutdown_event: Event | None = None,
|
||||
policy_config: PreTrainedConfig | None = None,
|
||||
rename_map: dict[str, str] | None = None,
|
||||
) -> InferenceEngine:
|
||||
"""Instantiate the appropriate inference engine from a config object."""
|
||||
"""Instantiate the appropriate inference engine from a config object.
|
||||
|
||||
``policy``/``preprocessor``/``postprocessor`` are required for the
|
||||
local backends (``sync``, ``rtc``) and must be ``None``-free there;
|
||||
the ``remote`` backend is weightless and needs only ``policy_config``.
|
||||
"""
|
||||
logger.info("Creating inference engine: %s", config.type)
|
||||
if isinstance(config, SyncInferenceConfig):
|
||||
if policy is None or preprocessor is None or postprocessor is None:
|
||||
raise ValueError("sync inference requires a loaded policy and processors")
|
||||
return SyncInferenceEngine(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
@@ -110,6 +188,8 @@ def create_inference_engine(
|
||||
robot_type=robot_wrapper.robot_type,
|
||||
)
|
||||
if isinstance(config, RTCInferenceConfig):
|
||||
if policy is None or preprocessor is None or postprocessor is None:
|
||||
raise ValueError("rtc inference requires a loaded policy and processors")
|
||||
return RTCInferenceEngine(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
@@ -125,4 +205,25 @@ def create_inference_engine(
|
||||
rtc_queue_threshold=config.queue_threshold,
|
||||
shutdown_event=shutdown_event,
|
||||
)
|
||||
if isinstance(config, RemoteInferenceConfig):
|
||||
if policy_config is None:
|
||||
raise ValueError("remote inference requires policy_config (from config-only --policy.path)")
|
||||
if use_torch_compile:
|
||||
logger.warning("--use_torch_compile is ignored with remote inference (server-side concern)")
|
||||
if device not in (None, "cpu"):
|
||||
logger.warning("--device=%s is ignored with remote inference (server-side concern)", device)
|
||||
# Lazy import: eclipse-zenoh/msgpack live behind the 'async' extra.
|
||||
from .remote import RemoteInferenceEngine
|
||||
|
||||
return RemoteInferenceEngine(
|
||||
config=config,
|
||||
policy_config=policy_config,
|
||||
hw_features=hw_features,
|
||||
ordered_action_keys=ordered_action_keys,
|
||||
task=task,
|
||||
fps=fps,
|
||||
robot_type=robot_wrapper.robot_type,
|
||||
rename_map=rename_map,
|
||||
shutdown_event=shutdown_event,
|
||||
)
|
||||
raise ValueError(f"Unknown inference engine type: {type(config).__name__}")
|
||||
|
||||
@@ -0,0 +1,851 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Remote inference engine: network-decoupled policy inference over Zenoh.
|
||||
|
||||
The same architecture as :class:`RTCInferenceEngine` with the thread
|
||||
boundary replaced by a network boundary. The edge stays **weightless**
|
||||
(no policy weights, no policy processors); a ``lerobot-policy-server``
|
||||
runs the heavy half. All chunk state — leftover prefixes, latency
|
||||
tracking, delay computation — lives client-side in the existing
|
||||
``ActionQueue``/``LatencyTracker`` machinery, so the server is stateless
|
||||
per request and a server crash loses zero control state.
|
||||
|
||||
Threading model:
|
||||
- **Main thread** (strategy loop): ``notify_observation`` writes a
|
||||
latest-only slot; ``get_action`` pops the local queue and applies the
|
||||
staleness bound + fallback ladder. Never any I/O.
|
||||
- **Network worker** (one daemon thread): self-clocked by
|
||||
``buffer_time_s``, publishes one observation, awaits its chunk (or
|
||||
timeout), merges, repeats. One-in-flight is a *correctness*
|
||||
requirement: ``idx_before``/prefix snapshots must serialize with
|
||||
merges.
|
||||
- **Zenoh threads**: deposit-only callbacks (chunk → bounded queue,
|
||||
liveliness → event).
|
||||
|
||||
Clock iron rule: wall-clock instants never cross machines. The header's
|
||||
``client_mono_ns`` is opaque to the server and echoed back; the server
|
||||
reports only durations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import math
|
||||
import queue as queue_module
|
||||
import time
|
||||
import traceback
|
||||
import uuid as uuid_module
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.policies.rtc import ActionQueue, LatencyTracker
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policy_server import codec
|
||||
from lerobot.policy_server.schema import (
|
||||
MSG_TYPE_OBS,
|
||||
SCHEMA_VERSION,
|
||||
MsgHeader,
|
||||
ObservationMsg,
|
||||
ResetMsg,
|
||||
SessionAckMsg,
|
||||
SessionCloseMsg,
|
||||
SessionOpenMsg,
|
||||
action_key,
|
||||
client_alive_key,
|
||||
obs_key,
|
||||
reset_key,
|
||||
sanitize_key_segment,
|
||||
server_alive_key,
|
||||
service_prefix,
|
||||
session_key,
|
||||
status_key,
|
||||
)
|
||||
from lerobot.policy_server.zenoh_utils import build_zenoh_config, import_zenoh, obs_publisher_qos
|
||||
from lerobot.utils.constants import OBS_STATE, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
|
||||
from .base import InferenceEngine
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
from .factory import RemoteInferenceConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_IDLE_SLEEP_S = 0.01
|
||||
_MAX_CONSECUTIVE_WORKER_ERRORS = 10
|
||||
|
||||
|
||||
class ClientState:
|
||||
"""Fail-safe state machine states (see the design doc §9.2)."""
|
||||
|
||||
CONNECTING = "CONNECTING"
|
||||
STREAMING = "STREAMING"
|
||||
DEGRADED = "DEGRADED"
|
||||
STALLED = "STALLED"
|
||||
RECONNECTING = "RECONNECTING"
|
||||
DEAD = "DEAD"
|
||||
|
||||
|
||||
class RemoteInferenceEngine(InferenceEngine):
|
||||
"""``--inference.type=remote``: weightless edge client of a policy server."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: RemoteInferenceConfig,
|
||||
policy_config: PreTrainedConfig,
|
||||
hw_features: dict,
|
||||
ordered_action_keys: list[str],
|
||||
task: str,
|
||||
fps: float,
|
||||
robot_type: str,
|
||||
rename_map: dict[str, str] | None = None,
|
||||
shutdown_event: Event | None = None,
|
||||
) -> None:
|
||||
self._config = config
|
||||
self._policy_config = policy_config
|
||||
self._hw_features = hw_features
|
||||
self._ordered_action_keys = list(ordered_action_keys)
|
||||
self._task = task
|
||||
self._fps = float(fps)
|
||||
self._dt = 1.0 / self._fps
|
||||
self._robot_type = robot_type
|
||||
self._rename_map = dict(rename_map or {})
|
||||
self._global_shutdown_event = shutdown_event
|
||||
|
||||
self._client_uuid = sanitize_key_segment(config.client_uuid or uuid_module.uuid4().hex)
|
||||
model_id = config.service_model_id or getattr(policy_config, "pretrained_path", "") or "model"
|
||||
self._prefix = service_prefix(model_id, config.service_revision, config.service_task or task)
|
||||
|
||||
# Latest-only observation slot (identical to rtc.py's _obs_holder).
|
||||
self._obs_holder: dict[str, Any] = {"obs": None}
|
||||
self._obs_lock = Lock()
|
||||
|
||||
self._action_queue: ActionQueue | None = None
|
||||
self._latency_tracker = LatencyTracker()
|
||||
self._effective_rtc: RTCConfig = config.rtc
|
||||
|
||||
# Replies deposited by the zenoh callback, consumed by the worker.
|
||||
self._reply_queue: queue_module.Queue[tuple[MsgHeader, bytes]] = queue_module.Queue(maxsize=4)
|
||||
|
||||
self._zenoh = None
|
||||
self._obs_publisher = None
|
||||
self._declarations: list[Any] = []
|
||||
self._alive_token = None
|
||||
self._server_alive = Event()
|
||||
|
||||
self._worker: Thread | None = None
|
||||
self._stop_event = Event()
|
||||
self._active = Event()
|
||||
self._dead = Event()
|
||||
self._session_ack: SessionAckMsg | None = None
|
||||
|
||||
self.state = ClientState.CONNECTING
|
||||
self._state_lock = Lock()
|
||||
self._seq_id = 0
|
||||
self._epoch = 0
|
||||
self._episode_id = 0
|
||||
self._pending_reset = False
|
||||
|
||||
# Staleness bookkeeping: client-monotonic send time of the
|
||||
# observation that produced the current queue contents.
|
||||
# _anchor_lock serializes {merge + anchor update} (worker),
|
||||
# {staleness clear} (control thread), and {reset clear} so a
|
||||
# stale chunk can never merge into a freshly-reset queue and the
|
||||
# safety path can never clear a just-merged one.
|
||||
self._anchor_lock = Lock()
|
||||
self._chunk_anchor_mono: float | None = None
|
||||
self._last_chunk_mono: float | None = None
|
||||
self._offline_since_mono: float | None = None
|
||||
self._last_action: torch.Tensor | None = None
|
||||
|
||||
self.stats: dict[str, float] = {
|
||||
"requests": 0,
|
||||
"timeouts": 0,
|
||||
"merges": 0,
|
||||
"stale_drops": 0,
|
||||
"fallback_ticks": 0,
|
||||
"reconnects": 0,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def ready(self) -> bool:
|
||||
"""Session opened, capabilities validated, server warmed up."""
|
||||
ack = self._session_ack
|
||||
return ack is not None and ack.warmed_up and not self._dead.is_set()
|
||||
|
||||
@property
|
||||
def failed(self) -> bool:
|
||||
return self._dead.is_set()
|
||||
|
||||
@property
|
||||
def action_queue(self) -> ActionQueue | None:
|
||||
return self._action_queue
|
||||
|
||||
def start(self) -> None:
|
||||
"""Open transport, handshake, start the network worker.
|
||||
|
||||
Raises on initial connection/validation failure so a bad
|
||||
deployment aborts before the robot moves (reconnect logic only
|
||||
guards established sessions).
|
||||
"""
|
||||
zenoh = import_zenoh()
|
||||
cfg = self._config
|
||||
self._zenoh = zenoh.open(
|
||||
build_zenoh_config(
|
||||
mode=cfg.zenoh_mode,
|
||||
connect_endpoints=[cfg.connect_endpoint] if cfg.connect_endpoint else None,
|
||||
tls_root_ca_certificate=cfg.tls_ca,
|
||||
tls_connect_certificate=cfg.tls_cert,
|
||||
tls_connect_private_key=cfg.tls_key,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
ack = self._handshake(initial=True)
|
||||
except Exception:
|
||||
# Fail fast without leaking the transport session.
|
||||
with contextlib.suppress(Exception):
|
||||
self._zenoh.close()
|
||||
self._zenoh = None
|
||||
raise
|
||||
self._configure_from_ack(ack)
|
||||
|
||||
handlers = zenoh.handlers
|
||||
self._declarations.append(
|
||||
self._zenoh.declare_subscriber(
|
||||
action_key(self._prefix, self._client_uuid), handlers.Callback(self._on_chunk)
|
||||
)
|
||||
)
|
||||
self._obs_publisher = self._zenoh.declare_publisher(
|
||||
obs_key(self._prefix, self._client_uuid), **obs_publisher_qos(zenoh)
|
||||
)
|
||||
self._declarations.append(self._obs_publisher)
|
||||
self._server_alive.set()
|
||||
self._declarations.append(
|
||||
self._zenoh.liveliness().declare_subscriber(
|
||||
server_alive_key(self._prefix), handlers.Callback(self._on_server_liveliness), history=True
|
||||
)
|
||||
)
|
||||
self._alive_token = self._zenoh.liveliness().declare_token(
|
||||
client_alive_key(self._prefix, self._client_uuid)
|
||||
)
|
||||
|
||||
self._stop_event.clear()
|
||||
self._dead.clear()
|
||||
self._active.set()
|
||||
self._worker = Thread(target=self._worker_loop, daemon=True, name="RemoteInference")
|
||||
self._worker.start()
|
||||
logger.info(
|
||||
"Remote inference started: prefix=%s client=%s rtc=%s",
|
||||
self._prefix,
|
||||
self._client_uuid,
|
||||
self._effective_rtc.enabled,
|
||||
)
|
||||
|
||||
def stop(self) -> None:
|
||||
logger.info("Stopping remote inference engine...")
|
||||
self._stop_event.set()
|
||||
self._active.clear()
|
||||
if self._worker is not None and self._worker.is_alive():
|
||||
# Worst case the worker is mid-handshake inside _enter_reconnect.
|
||||
join_timeout = max(3.0, self._config.handshake_timeout_s + self._config.request_timeout_s + 2.0)
|
||||
self._worker.join(timeout=join_timeout)
|
||||
if self._worker.is_alive():
|
||||
logger.warning("Remote inference worker did not join")
|
||||
self._worker = None
|
||||
|
||||
if self._zenoh is not None:
|
||||
# Best-effort graceful close; the server also GCs on liveliness drop.
|
||||
with contextlib.suppress(Exception):
|
||||
self._control_query(
|
||||
session_key(self._prefix),
|
||||
codec.encode_session_close(
|
||||
SessionCloseMsg(
|
||||
client_uuid=self._client_uuid,
|
||||
session_id=self._session_ack.session_id if self._session_ack else "",
|
||||
)
|
||||
),
|
||||
timeout=1.0,
|
||||
)
|
||||
if self._alive_token is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
self._alive_token.undeclare()
|
||||
self._alive_token = None
|
||||
for declaration in self._declarations:
|
||||
with contextlib.suppress(Exception):
|
||||
declaration.undeclare()
|
||||
self._declarations.clear()
|
||||
self._obs_publisher = None
|
||||
with contextlib.suppress(Exception):
|
||||
self._zenoh.close()
|
||||
self._zenoh = None
|
||||
logger.info("Remote inference engine stopped")
|
||||
|
||||
def pause(self) -> None:
|
||||
"""Stop publishing observations; the local queue stays intact."""
|
||||
logger.info("Pausing remote inference (publishing stops, queue intact)")
|
||||
self._active.clear()
|
||||
|
||||
def resume(self) -> None:
|
||||
logger.info("Resuming remote inference")
|
||||
self._active.set()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Episode boundary: clear local chunk state, notify the server.
|
||||
|
||||
The acked reset query runs on the worker thread (never I/O on the
|
||||
control thread); thanks to per-request server statelessness a
|
||||
lost ack only costs a warning — the next observation announces
|
||||
the new episode in its header anyway.
|
||||
"""
|
||||
logger.info("Resetting remote inference state (queue + episode)")
|
||||
with self._anchor_lock:
|
||||
if self._action_queue is not None:
|
||||
self._action_queue.clear()
|
||||
self._chunk_anchor_mono = None
|
||||
with self._state_lock:
|
||||
self._episode_id += 1
|
||||
self._pending_reset = True
|
||||
with self._obs_lock:
|
||||
# The previous episode's final frame must not seed the new
|
||||
# episode's first request.
|
||||
self._obs_holder["obs"] = None
|
||||
self._last_action = None
|
||||
# LatencyTracker intentionally survives reset: latency is
|
||||
# episode-invariant (parity with local RTC).
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Action production (main thread — never any I/O)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def notify_observation(self, obs: dict) -> None:
|
||||
with self._obs_lock:
|
||||
self._obs_holder["obs"] = obs
|
||||
|
||||
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
|
||||
queue = self._action_queue
|
||||
if queue is None:
|
||||
return None
|
||||
|
||||
# Staleness bound (sync safety): never execute an action whose
|
||||
# source observation is older than max_action_age_s. The lock
|
||||
# makes the check-and-clear atomic with the worker's merge.
|
||||
with self._anchor_lock:
|
||||
anchor = self._chunk_anchor_mono
|
||||
if (
|
||||
anchor is not None
|
||||
and queue.qsize() > 0
|
||||
and time.monotonic() - anchor > self._config.max_action_age_s
|
||||
):
|
||||
logger.warning(
|
||||
"Dropping %d stale actions (older than %.1fs) — applying fallback",
|
||||
queue.qsize(),
|
||||
self._config.max_action_age_s,
|
||||
)
|
||||
self.stats["stale_drops"] += 1
|
||||
queue.clear()
|
||||
self._chunk_anchor_mono = None
|
||||
|
||||
action = queue.get()
|
||||
if action is not None:
|
||||
self._last_action = action
|
||||
return action
|
||||
|
||||
self._set_state(ClientState.STALLED if self.state == ClientState.DEGRADED else self.state)
|
||||
return self._fallback_action()
|
||||
|
||||
def _fallback_action(self) -> torch.Tensor | None:
|
||||
from .factory import FallbackMode
|
||||
|
||||
mode = self._config.fallback
|
||||
if mode == FallbackMode.REPEAT_LAST and self._last_action is not None:
|
||||
self.stats["fallback_ticks"] += 1
|
||||
return self._last_action.clone()
|
||||
if mode == FallbackMode.ZERO:
|
||||
# For velocity-controlled robots "send nothing" means "keep
|
||||
# last velocity" — an explicit zero command is the safe stop.
|
||||
self.stats["fallback_ticks"] += 1
|
||||
return torch.zeros(len(self._ordered_action_keys))
|
||||
return None # HOLD: send_next_action tolerates None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Handshake & control plane
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _handshake(self, initial: bool) -> SessionAckMsg:
|
||||
"""status (pre-flight) + session open; raises on rejection."""
|
||||
cfg = self._config
|
||||
status_data = self._control_query(status_key(self._prefix), b"", timeout=cfg.handshake_timeout_s)
|
||||
if status_data is None:
|
||||
raise ConnectionError(
|
||||
f"No policy server answered status query at {status_key(self._prefix)!r} "
|
||||
f"via {cfg.connect_endpoint!r} (timeout {cfg.handshake_timeout_s}s)"
|
||||
)
|
||||
status = codec.decode_status(status_data)
|
||||
logger.info(
|
||||
"Server status: model=%s@%s policy=%s sessions=%d/%d warmed_up=%s",
|
||||
status.model_repo,
|
||||
status.model_revision,
|
||||
status.policy_type,
|
||||
status.active_sessions,
|
||||
status.max_sessions,
|
||||
status.warmed_up,
|
||||
)
|
||||
|
||||
open_msg = SessionOpenMsg(
|
||||
client_uuid=self._client_uuid,
|
||||
robot_type=self._robot_type,
|
||||
policy_type=getattr(self._policy_config, "type", ""),
|
||||
fps=self._fps,
|
||||
action_names=self._ordered_action_keys,
|
||||
camera_names=self._wire_camera_names(),
|
||||
state_dim=self._state_dim(),
|
||||
schema_version=SCHEMA_VERSION,
|
||||
rtc_enabled=cfg.rtc.enabled,
|
||||
task=self._task,
|
||||
tags=cfg.tags,
|
||||
)
|
||||
ack_data = self._control_query(
|
||||
session_key(self._prefix), codec.encode_session_open(open_msg), timeout=cfg.request_timeout_s
|
||||
)
|
||||
if ack_data is None:
|
||||
raise ConnectionError("Session open query timed out")
|
||||
ack = codec.decode_session_ack(ack_data)
|
||||
if not ack.accepted:
|
||||
raise ConnectionError(f"Policy server rejected the session: {ack.error}")
|
||||
for warning in ack.warnings:
|
||||
logger.warning("Server warning: %s", warning)
|
||||
|
||||
# Hard sync-safety contract: chunk columns map to motors by order.
|
||||
if ack.action_names and ack.action_names != self._ordered_action_keys:
|
||||
raise ValueError(
|
||||
"Action name/order mismatch between server policy and this robot.\n"
|
||||
f" server: {ack.action_names}\n client: {self._ordered_action_keys}"
|
||||
)
|
||||
if not initial and self._session_ack is not None:
|
||||
previous = self._session_ack
|
||||
if (ack.model_repo, ack.model_revision) != (previous.model_repo, previous.model_revision):
|
||||
raise ValueError(
|
||||
f"Server model changed across reconnect "
|
||||
f"({previous.model_repo}@{previous.model_revision} → "
|
||||
f"{ack.model_repo}@{ack.model_revision}) — refusing to execute wrong-model chunks"
|
||||
)
|
||||
return ack
|
||||
|
||||
def _configure_from_ack(self, ack: SessionAckMsg) -> None:
|
||||
rtc_requested = self._config.rtc.enabled
|
||||
rtc_effective = rtc_requested and ack.supports_rtc
|
||||
if rtc_requested and not rtc_effective:
|
||||
logger.warning("RTC downgraded to chunk-append (server does not support RTC)")
|
||||
if self._action_queue is not None and self._action_queue.cfg.enabled != rtc_effective:
|
||||
# The queue's merge semantics (replace vs append) were fixed at
|
||||
# session start; a server whose RTC capability changed across a
|
||||
# reconnect would corrupt them.
|
||||
raise ValueError(
|
||||
"Server RTC capability changed across reconnect "
|
||||
f"(queue merge mode {'replace' if self._action_queue.cfg.enabled else 'append'} "
|
||||
f"vs server RTC={rtc_effective}) — refusing to continue"
|
||||
)
|
||||
self._effective_rtc = RTCConfig(
|
||||
enabled=rtc_effective,
|
||||
prefix_attention_schedule=self._config.rtc.prefix_attention_schedule,
|
||||
max_guidance_weight=self._config.rtc.max_guidance_weight,
|
||||
execution_horizon=ack.rtc_execution_horizon or self._config.rtc.execution_horizon,
|
||||
debug=self._config.rtc.debug,
|
||||
debug_maxlen=self._config.rtc.debug_maxlen,
|
||||
)
|
||||
if self._action_queue is None:
|
||||
self._action_queue = ActionQueue(self._effective_rtc)
|
||||
self._session_ack = ack
|
||||
|
||||
def _control_query(self, key: str, payload: bytes, timeout: float) -> bytes | None:
|
||||
"""One request/reply on the control plane; None on timeout/no-server."""
|
||||
zenoh = import_zenoh()
|
||||
try:
|
||||
replies = self._zenoh.get(
|
||||
key,
|
||||
handler=zenoh.handlers.FifoChannel(4),
|
||||
payload=payload,
|
||||
timeout=timeout,
|
||||
)
|
||||
deadline = time.monotonic() + timeout + 0.5
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
reply = replies.try_recv()
|
||||
except Exception: # zenoh.ZError: channel closed (no queryable / finished)
|
||||
return None
|
||||
if reply is None:
|
||||
time.sleep(0.005)
|
||||
continue
|
||||
if reply.ok is not None:
|
||||
return reply.ok.payload.to_bytes()
|
||||
return None # Reply.err (e.g. b"Timeout")
|
||||
return None
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning("Control query %s failed: %s", key, e)
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Zenoh callbacks (deposit-only)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _on_chunk(self, sample: Any) -> None:
|
||||
try:
|
||||
attachment = sample.attachment
|
||||
if attachment is None:
|
||||
return
|
||||
header = MsgHeader.unpack(attachment.to_bytes())
|
||||
item = (header, sample.payload.to_bytes())
|
||||
try:
|
||||
self._reply_queue.put_nowait(item)
|
||||
except queue_module.Full:
|
||||
# Drop oldest, keep newest.
|
||||
with contextlib.suppress(queue_module.Empty):
|
||||
self._reply_queue.get_nowait()
|
||||
with contextlib.suppress(queue_module.Full):
|
||||
self._reply_queue.put_nowait(item)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("chunk callback error: %s", e)
|
||||
|
||||
def _on_server_liveliness(self, sample: Any) -> None:
|
||||
try:
|
||||
import zenoh
|
||||
|
||||
if sample.kind == zenoh.SampleKind.DELETE:
|
||||
logger.warning("Server liveliness token dropped")
|
||||
self._server_alive.clear()
|
||||
else:
|
||||
self._server_alive.set()
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("liveliness callback error: %s", e)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Network worker
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _worker_loop(self) -> None:
|
||||
consecutive_errors = 0
|
||||
try:
|
||||
while not self._stop_event.is_set():
|
||||
if not self._active.is_set():
|
||||
time.sleep(_IDLE_SLEEP_S)
|
||||
continue
|
||||
try:
|
||||
self._maybe_send_reset()
|
||||
|
||||
if not self._server_alive.is_set():
|
||||
self._enter_reconnect("server liveliness dropped")
|
||||
continue
|
||||
|
||||
queue = self._action_queue
|
||||
if queue is not None and queue.qsize() * self._dt > self._config.buffer_time_s:
|
||||
time.sleep(_IDLE_SLEEP_S)
|
||||
continue
|
||||
|
||||
with self._obs_lock:
|
||||
obs = self._obs_holder.get("obs")
|
||||
if obs is None:
|
||||
time.sleep(_IDLE_SLEEP_S)
|
||||
continue
|
||||
|
||||
self._request_cycle(obs)
|
||||
consecutive_errors = 0
|
||||
except ConnectionError as e:
|
||||
# Raised by reconnect on hard contract violations.
|
||||
raise e
|
||||
except Exception as e: # noqa: BLE001 — transient worker errors retry
|
||||
consecutive_errors += 1
|
||||
logger.error(
|
||||
"Remote inference worker error (%d/%d): %s",
|
||||
consecutive_errors,
|
||||
_MAX_CONSECUTIVE_WORKER_ERRORS,
|
||||
e,
|
||||
)
|
||||
logger.debug(traceback.format_exc())
|
||||
if consecutive_errors >= _MAX_CONSECUTIVE_WORKER_ERRORS:
|
||||
raise
|
||||
time.sleep(0.5)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("Fatal error in remote inference worker: %s", e)
|
||||
logger.error(traceback.format_exc())
|
||||
self._go_dead(str(e))
|
||||
|
||||
def _request_cycle(self, obs: dict) -> None:
|
||||
"""Publish one observation and merge its chunk (one-in-flight)."""
|
||||
cfg = self._config
|
||||
queue = self._action_queue
|
||||
|
||||
obs_frame = build_dataset_frame(self._hw_features, obs, prefix=OBS_STR)
|
||||
if self._rename_map:
|
||||
obs_frame = {self._rename_map.get(k, k): v for k, v in obs_frame.items()}
|
||||
|
||||
state = obs_frame.pop(OBS_STATE, None)
|
||||
images = {k: v for k, v in obs_frame.items() if isinstance(v, np.ndarray) and v.ndim == 3}
|
||||
|
||||
with self._state_lock:
|
||||
self._seq_id += 1
|
||||
seq_id = self._seq_id
|
||||
episode_id = self._episode_id
|
||||
epoch = self._epoch
|
||||
|
||||
# Snapshot RTC state (must precede the publish; merge validates
|
||||
# against idx_before).
|
||||
idx_before = queue.get_action_index()
|
||||
prefix_model: np.ndarray | None = None
|
||||
prefix_robot: np.ndarray | None = None
|
||||
delay_steps = 0
|
||||
if self._effective_rtc.enabled:
|
||||
horizon = self._effective_rtc.execution_horizon
|
||||
left_over = queue.get_left_over()
|
||||
if left_over is not None and left_over.numel():
|
||||
prefix_model = left_over[:horizon].to(torch.float32).numpy()
|
||||
processed_left_over = queue.get_processed_left_over()
|
||||
if processed_left_over is not None and processed_left_over.numel():
|
||||
prefix_robot = processed_left_over[:horizon].to(torch.float32).numpy()
|
||||
max_latency = self._latency_tracker.max() if len(self._latency_tracker) else 0.0
|
||||
delay_steps = math.ceil(max_latency / self._dt) if max_latency else 0
|
||||
|
||||
# A reset/reconnect between the counter snapshot and the prefix
|
||||
# snapshot would pair a new episode id with old-episode prefixes
|
||||
# — skip the cycle instead.
|
||||
with self._state_lock:
|
||||
if (self._episode_id, self._epoch) != (episode_id, epoch):
|
||||
return
|
||||
|
||||
header = MsgHeader(
|
||||
schema_version=SCHEMA_VERSION,
|
||||
msg_type=MSG_TYPE_OBS,
|
||||
seq_id=seq_id,
|
||||
episode_id=episode_id,
|
||||
client_mono_ns=time.monotonic_ns(),
|
||||
session_epoch=epoch,
|
||||
)
|
||||
msg = ObservationMsg(
|
||||
state=state,
|
||||
images=images,
|
||||
task=self._task,
|
||||
inference_delay_steps=delay_steps,
|
||||
prefix_model=prefix_model,
|
||||
prefix_robot=prefix_robot,
|
||||
episode_start=(queue.qsize() == 0 and idx_before == 0 and self._chunk_anchor_mono is None),
|
||||
jpeg_quality=cfg.jpeg_quality,
|
||||
)
|
||||
|
||||
t_send = time.perf_counter()
|
||||
self._obs_publisher.put(codec.encode_observation(msg), attachment=header.pack())
|
||||
self.stats["requests"] += 1
|
||||
|
||||
reply = self._await_chunk(seq_id, episode_id, epoch, timeout=cfg.request_timeout_s)
|
||||
if reply is None:
|
||||
self.stats["timeouts"] += 1
|
||||
self._on_request_timeout()
|
||||
return
|
||||
|
||||
chunk = codec.decode_action_chunk(reply)
|
||||
if chunk.chunk_model is None or chunk.chunk_robot is None:
|
||||
# A persistently malformed server must still trip the
|
||||
# degradation ladder, not stall in nominal state.
|
||||
logger.warning("Chunk for seq=%d had empty tensors — dropping", seq_id)
|
||||
self.stats["timeouts"] += 1
|
||||
self._on_request_timeout()
|
||||
return
|
||||
|
||||
latency = time.perf_counter() - t_send
|
||||
real_delay = math.ceil(latency / self._dt)
|
||||
with self._anchor_lock:
|
||||
# reset() takes the same lock before clearing: either the
|
||||
# reset fully precedes this merge (episode changed → drop the
|
||||
# stale chunk) or the merge completes first (and the reset
|
||||
# then clears it) — a stale chunk can never survive a reset.
|
||||
with self._state_lock:
|
||||
if (self._episode_id, self._epoch) != (episode_id, epoch):
|
||||
logger.debug("Dropping chunk seq=%d: episode/epoch changed mid-flight", seq_id)
|
||||
return
|
||||
queue.merge(
|
||||
torch.from_numpy(np.ascontiguousarray(chunk.chunk_model)),
|
||||
torch.from_numpy(np.ascontiguousarray(chunk.chunk_robot)),
|
||||
real_delay,
|
||||
idx_before,
|
||||
)
|
||||
self._chunk_anchor_mono = time.monotonic() - latency # ≈ when the source obs was sent
|
||||
self._latency_tracker.add(latency)
|
||||
self._last_chunk_mono = time.monotonic()
|
||||
self._offline_since_mono = None
|
||||
self.stats["merges"] += 1
|
||||
self._set_state(ClientState.STREAMING)
|
||||
logger.debug(
|
||||
"merge: seq=%d latency=%.0fms delay=%d queue=%d server(inf=%.0fms wait=%.0fms load=%.2f)",
|
||||
seq_id,
|
||||
latency * 1e3,
|
||||
real_delay,
|
||||
queue.qsize(),
|
||||
chunk.inference_ms,
|
||||
chunk.queue_wait_ms,
|
||||
chunk.server_load,
|
||||
)
|
||||
|
||||
def _await_chunk(self, seq_id: int, episode_id: int, epoch: int, timeout: float) -> bytes | None:
|
||||
"""Wait for the chunk answering the latest outstanding request.
|
||||
|
||||
Stale replies (older seq/episode/epoch) are dropped — under
|
||||
one-in-flight a late chunk can only ever answer an older request.
|
||||
"""
|
||||
deadline = time.monotonic() + timeout
|
||||
while not self._stop_event.is_set():
|
||||
remaining = deadline - time.monotonic()
|
||||
if remaining <= 0:
|
||||
return None
|
||||
try:
|
||||
header, payload = self._reply_queue.get(timeout=min(remaining, 0.1))
|
||||
except queue_module.Empty:
|
||||
continue
|
||||
if header.session_epoch != epoch or header.episode_id != episode_id:
|
||||
continue # stale epoch/episode (reset or reconnect happened)
|
||||
if header.seq_id != seq_id:
|
||||
continue # late reply to a superseded request
|
||||
return payload
|
||||
return None
|
||||
|
||||
def _maybe_send_reset(self) -> None:
|
||||
with self._state_lock:
|
||||
pending, episode_id = self._pending_reset, self._episode_id
|
||||
self._pending_reset = False
|
||||
if pending and self._zenoh is not None:
|
||||
ack_data = self._control_query(
|
||||
reset_key(self._prefix, self._client_uuid),
|
||||
codec.encode_reset(ResetMsg(client_uuid=self._client_uuid, episode_id=episode_id)),
|
||||
timeout=1.0,
|
||||
)
|
||||
if ack_data is None:
|
||||
# Harmless: the server is stateless per request and the next
|
||||
# observation header announces the new episode anyway.
|
||||
logger.warning("Reset ack not received (continuing — header carries the episode bump)")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Degradation / reconnect / death
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _on_request_timeout(self) -> None:
|
||||
if self._stop_event.is_set():
|
||||
# _await_chunk aborted by a normal stop(), not by the network.
|
||||
return
|
||||
now = time.monotonic()
|
||||
if self._offline_since_mono is None:
|
||||
self._offline_since_mono = now
|
||||
offline_for = now - self._offline_since_mono
|
||||
|
||||
queue = self._action_queue
|
||||
if queue is not None and queue.qsize() > 0:
|
||||
self._set_state(ClientState.DEGRADED)
|
||||
else:
|
||||
self._set_state(ClientState.STALLED)
|
||||
|
||||
last = self._last_chunk_mono or 0.0
|
||||
if (now - last if last else offline_for) >= self._config.degraded_after_s:
|
||||
logger.warning(
|
||||
"No chunk for %.1fs (queue=%d) — %s",
|
||||
offline_for,
|
||||
queue.qsize() if queue else 0,
|
||||
self.state,
|
||||
)
|
||||
if offline_for > self._config.max_offline_s:
|
||||
self._go_dead(f"offline for {offline_for:.0f}s (> max_offline_s)")
|
||||
return
|
||||
if not self._server_alive.is_set() or offline_for >= 2 * self._config.request_timeout_s:
|
||||
self._enter_reconnect(f"request timeouts for {offline_for:.0f}s")
|
||||
|
||||
def _enter_reconnect(self, reason: str) -> None:
|
||||
"""Backoff + re-handshake loop. Hard contract violations → DEAD."""
|
||||
self._set_state(ClientState.RECONNECTING)
|
||||
logger.warning("Reconnecting: %s", reason)
|
||||
if self._offline_since_mono is None:
|
||||
self._offline_since_mono = time.monotonic()
|
||||
backoff = self._config.reconnect_initial_backoff_s
|
||||
while not self._stop_event.is_set():
|
||||
if not self._active.is_set():
|
||||
# Paused (e.g. DAgger human correction): keep trying to
|
||||
# reconnect, but a pause must never burn the offline budget
|
||||
# into a mid-correction shutdown.
|
||||
self._offline_since_mono = time.monotonic()
|
||||
offline_for = time.monotonic() - self._offline_since_mono
|
||||
if offline_for > self._config.max_offline_s:
|
||||
self._go_dead(f"offline for {offline_for:.0f}s (> max_offline_s)")
|
||||
return
|
||||
self._stop_event.wait(timeout=backoff)
|
||||
if self._stop_event.is_set():
|
||||
return
|
||||
backoff = min(backoff * 2, self._config.reconnect_max_backoff_s)
|
||||
try:
|
||||
with self._state_lock:
|
||||
self._epoch += 1
|
||||
ack = self._handshake(initial=False)
|
||||
self._configure_from_ack(ack)
|
||||
except ValueError as e:
|
||||
# Capability/schema/model mismatch: never execute wrong-model chunks.
|
||||
self._go_dead(str(e))
|
||||
return
|
||||
except Exception as e: # noqa: BLE001 — server still down, keep trying
|
||||
logger.info("Re-handshake failed (%s) — retrying in %.1fs", e, backoff)
|
||||
continue
|
||||
# A successful handshake is proof of life even if the liveliness
|
||||
# PUT was missed or hasn't been delivered yet.
|
||||
self._server_alive.set()
|
||||
# The offline budget is only reset by the next successful merge:
|
||||
# a server that handshakes but never delivers chunks must still
|
||||
# run out of budget and go DEAD.
|
||||
self.stats["reconnects"] += 1
|
||||
self._set_state(ClientState.STREAMING)
|
||||
logger.info("Reconnected (epoch=%d, session=%s)", self._epoch, ack.session_id)
|
||||
return
|
||||
|
||||
def _go_dead(self, reason: str) -> None:
|
||||
if self._dead.is_set():
|
||||
return
|
||||
logger.error("Remote inference DEAD: %s", reason)
|
||||
self._set_state(ClientState.DEAD)
|
||||
self._dead.set()
|
||||
if self._global_shutdown_event is not None:
|
||||
self._global_shutdown_event.set()
|
||||
|
||||
def _set_state(self, new_state: str) -> None:
|
||||
if new_state != self.state:
|
||||
logger.info("Client state: %s → %s", self.state, new_state)
|
||||
self.state = new_state
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Feature helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _wire_camera_names(self) -> list[str]:
|
||||
names = [
|
||||
key for key, feature in self._hw_features.items() if feature.get("dtype") in ("image", "video")
|
||||
]
|
||||
return [self._rename_map.get(name, name) for name in names]
|
||||
|
||||
def _state_dim(self) -> int:
|
||||
state_feature = self._hw_features.get(OBS_STATE)
|
||||
if state_feature and state_feature.get("names"):
|
||||
return len(state_feature["names"])
|
||||
return 0
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user