mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-16 07:49:48 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b968020ec4 | |||
| fc019d3902 |
@@ -65,9 +65,6 @@ repos:
|
||||
name: Format Markdown with Prettier
|
||||
types_or: [markdown, mdx]
|
||||
args: [--prose-wrap=preserve]
|
||||
# Jinja2 model-card templates use a .md extension but contain {% ... %} /
|
||||
# {{ ... }} tags that prettier's Markdown formatter mangles (e.g. table loops).
|
||||
exclude: ^src/lerobot/templates/.*\.md$
|
||||
|
||||
##### Security #####
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 445 KiB |
@@ -178,9 +178,3 @@ test-smolvla-ete-eval:
|
||||
--env.episode_length=5 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1
|
||||
|
||||
# E2E annotation pipeline smoke test against a tiny in-memory fixture
|
||||
# dataset. Opt-in (not part of `make test-end-to-end`) and uses a stub VLM
|
||||
# backend, so it does not require a real model checkpoint or GPU.
|
||||
annotation-e2e:
|
||||
uv run python -m tests.annotations.run_e2e_smoke
|
||||
|
||||
@@ -0,0 +1,417 @@
|
||||
# Decoupled VLA Inference & Edge Control: System Design Proposal
|
||||
|
||||
## 1. Executive Summary
|
||||
|
||||
This document proposes a production-grade system for decoupling GPU-bound VLA (Vision-Language-Action) policy inference from high-frequency, CPU-bound robot control in LeRobot. The system adopts a **Model-as-a-Service (MaaS)** paradigm using **Zenoh** as the sole transport protocol, enabling multiple edge devices to be served by centralized GPU servers with minimal latency and high reliability.
|
||||
|
||||
An initial prototype exists in `src/lerobot/async_inference/` (gRPC-based, single-client). This proposal defines the target architecture, identifies gaps between the prototype and production requirements, documents known bugs, and establishes the design for the new system.
|
||||
|
||||
---
|
||||
|
||||
## 2. Motivation
|
||||
|
||||
LeRobot's standard control loop runs policy inference and robot I/O in the same process. This works for lightweight policies on local GPUs, but breaks down when:
|
||||
|
||||
- **The policy is too large for edge hardware** (e.g., Pi0 at ~3B parameters requires a dedicated GPU).
|
||||
- **Multiple robots need the same policy** (redundant GPU allocation per robot).
|
||||
- **Inference latency exceeds the control deadline** (e.g., 200ms inference on a 33ms control loop at 30 FPS).
|
||||
|
||||
Decoupling inference from control solves all three: the edge device runs a tight I/O loop on a CPU, while a GPU server handles inference for one or more clients.
|
||||
|
||||
---
|
||||
|
||||
## 3. Core Architectural Principles
|
||||
|
||||
### 3.1 Model-as-a-Service (MaaS)
|
||||
|
||||
Servers initialize models **once at startup** from a configuration manifest. Edge devices do **not** trigger dynamic model loading — they route to pre-warmed servers and validate compatibility via a status endpoint.
|
||||
|
||||
### 3.2 Multi-Tenant & Stateless Inference
|
||||
|
||||
A single GPU server handles multiple edge devices executing the same task. The server is stateless per inference call — `predict_action_chunk()` is a pure function with no side effects on the model. Client isolation is achieved through per-client observation slots and Zenoh key-expression routing.
|
||||
|
||||
> **Invariant**: `predict_action_chunk()` must remain a pure function (no mutation of `self`) for all supported policies. This is what enables safe multi-tenant sharing of a single model instance. This invariant must be documented and tested.
|
||||
|
||||
### 3.3 Zenoh as primary Transport
|
||||
|
||||
The system uses Zenoh's pub/sub model, replacing the current gRPC implementation. Zenoh provides:
|
||||
|
||||
- **Hierarchical key expressions** for routing (natural fit for the cluster/experiment/model/task topology).
|
||||
- **Built-in discovery** (no external service discovery needed).
|
||||
- **Non-blocking publish** for observations (fire-and-forget with best-effort QoS).
|
||||
- **Reliable delivery** configurable per-topic (required for action chunks).
|
||||
- **Shared-memory transport** for same-machine deployments (zero-copy) (if available).
|
||||
|
||||
### 3.4 Local Edge CPU
|
||||
|
||||
Edge devices rely on standard CPUs for sensor polling, image compression, payload serialization, motor control, and data logging. No edge-GPU dependency.
|
||||
|
||||
---
|
||||
|
||||
## 4. System Topology
|
||||
|
||||

|
||||
|
||||
- **Cluster**: A set of GPU machines. Identified by `cluster_uuid`.
|
||||
- **Experiment**: A logical grouping of servers and clients. Identified by `experiment_tag`.
|
||||
- **Server**: One model + one task, pre-warmed. Serves N clients for that model/task combination.
|
||||
- **Client**: One robot, one task. Publishes observations, subscribes to actions.
|
||||
|
||||
The number of clients a single server can handle is a **user decision** based on model inference time and acceptable latency.
|
||||
|
||||
---
|
||||
|
||||
## 5. Component Specifications
|
||||
|
||||
### 5.1 The Edge Device (Client)
|
||||
|
||||
**Responsibilities:**
|
||||
|
||||
1. **Observation capture**: Read sensors (cameras, motors) at the control loop frequency.
|
||||
2. **Image compression**: JPEG-encode RGB images before transmission.
|
||||
3. **Observation publishing**: Non-blocking Zenoh put to the observation topic.
|
||||
4. **Action subscription**: Zenoh callback receives action chunks, deposits into local buffer.
|
||||
5. **Action execution**: Pop actions from buffer, send to robot at control frequency.
|
||||
6. **Action blending**: When a new action chunk overlaps with the current buffer, blend via configurable aggregation function (weighted average, latest-only, etc.).
|
||||
7. **Latency compensation**: Calculate one-way latency from RTT, discard expired initial steps of incoming action chunks.
|
||||
8. **Fail-safe**: If action buffer empties, logs a warning.
|
||||
9. **Data logging**: Record raw observations and executed actions to local `LeRobotDataset` storage for deferred upload.
|
||||
|
||||
**Threading model:**
|
||||
|
||||
- **Control loop thread** (main): Capture observation → deposit in outbox → pop action from buffer → send to robot → sleep to maintain frequency.
|
||||
- **Zenoh action callback** (Zenoh-managed): Receives action chunks, processes RTT, trims stale steps, deposits into action buffer.
|
||||
- **Observation publisher thread**: Drains the outbox, compresses images, serializes, publishes via Zenoh.
|
||||
|
||||
> **Design note**: The current prototype blocks on `send_observation` inside the control loop (BUG-1, see Section 9). The new design decouples observation publishing from the control loop entirely, using a separate thread and Zenoh's non-blocking put.
|
||||
|
||||
### 5.2 The Inference Server (GPU Pod)
|
||||
|
||||
**Responsibilities:**
|
||||
|
||||
1. **Model pre-warming**: Load model and processor pipelines at startup from config manifest (including expected clients & policy parameters).
|
||||
2. **Status publishing**: Expose model capabilities (policy type, expected camera names, resolutions, action dimensions) via Zenoh queryable.
|
||||
3. **Observation subscription**: Subscribe to observation topics for all clients of this model/task. Maintain per-client observation slots (newest-only semantics).
|
||||
4. **Inference**: Single inference thread processes observations sequentially (round-robin across clients). Calls `policy.predict_action_chunk()`.
|
||||
5. **Action publishing**: Publish action chunks to per-client action topics with reliable QoS.
|
||||
|
||||
> **Thread safety**: PyTorch's `model.forward()` is not guaranteed thread-safe. Inference will be sequential, latency is mostly about the capabilities of the server to serve multiple requests.
|
||||
|
||||
---
|
||||
|
||||
## 6. Zenoh Routing & Key Expressions
|
||||
|
||||
### 6.1 Key Expression Schema
|
||||
|
||||
```
|
||||
[cluster_uuid] / [experiment_tag] / [model_id] / [model_version] / [application_tag] / [client_uuid] / [topic]
|
||||
```
|
||||
|
||||
**Example key expressions:**
|
||||
|
||||
| Key Expression | Direction | Purpose |
|
||||
| ------------------------------------------------ | ----------------- | ---------------------------------- |
|
||||
| `jupiter/fabio2/pi0/v1/cookie/robot_a4b9/obs` | Client → Server | Observation payload |
|
||||
| `jupiter/fabio2/pi0/v1/cookie/robot_a4b9/action` | Server → Client | Action chunk |
|
||||
| `jupiter/fabio2/pi0/v1/cookie/*/obs` | Server subscribes | All observations for pi0/v1/cookie |
|
||||
| `jupiter/fabio2/pi0/v1/cookie/status` | Server publishes | Model capabilities (queryable) |
|
||||
|
||||
### 6.2 QoS Configuration
|
||||
|
||||
| Topic | Reliability | Rationale |
|
||||
| -------- | ----------- | -------------------------------------------------------------------- |
|
||||
| `obs` | Best-effort | Dropping stale observations is expected behavior. |
|
||||
| `action` | Reliable | Every action chunk must be delivered; loss causes action starvation. |
|
||||
| `status` | Reliable | Client needs accurate capability info before starting. |
|
||||
|
||||
### 6.3 Discovery Flow
|
||||
|
||||
0. Server goes up with the static configuration.
|
||||
1. Client constructs its target key prefix: `cluster/experiment/model/version/task/`.
|
||||
2. Client queries `cluster/experiment/model/version/task/status` (Zenoh queryable).
|
||||
3. Server responds with its capabilities (expected camera names, image resolutions, action dimensions, model metadata).
|
||||
4. Client validates its own configuration against server capabilities.
|
||||
5. On match: client starts publishing observations and subscribing to actions.
|
||||
6. On mismatch: client logs an error and refuses to start.
|
||||
|
||||
No dynamic client discovery for now.
|
||||
|
||||
---
|
||||
|
||||
## 7. Message Schema
|
||||
|
||||
### 7.1 Observation Payload (Client → Server)
|
||||
|
||||
| Field | Type | Purpose |
|
||||
| ------------- | ------------------ | ----------------------------------------------------------- |
|
||||
| `seq_id` | `uint64` | Incrementing ID for causality tracking and RTT computation. |
|
||||
| `client_uuid` | `string` | Identifies the sending client. |
|
||||
| `state` | `bytes` | Proprioceptive state vector (`numpy.tobytes()`). |
|
||||
| `images` | `dict[str, bytes]` | JPEG-compressed camera images, keyed by camera name. |
|
||||
| `task` | `string` | Natural-language task instruction (for VLA conditioning). |
|
||||
|
||||
### 7.2 Action Payload (Server → Client)
|
||||
|
||||
| Field | Type | Purpose |
|
||||
| -------------------- | --------- | --------------------------------------------------------------- |
|
||||
| `response_to_seq_id` | `uint64` | Echoes the observation `seq_id` this action corresponds to. |
|
||||
| `inference_time_ms` | `float32` | Server-side compute duration (for edge RTT math). |
|
||||
| `actions` | `bytes` | Action chunk as numpy array bytes (`(chunk_size, action_dim)`). |
|
||||
|
||||
### 7.3 Status Payload (Server, Queryable)
|
||||
|
||||
| Field | Type | Purpose |
|
||||
| ----------------------- | ------------------- | ------------------------------------------ |
|
||||
| `model_id` | `string` | Policy identifier (e.g., `pi0`). |
|
||||
| `model_version` | `string` | Model version or checkpoint path. |
|
||||
| `expected_cameras` | `dict[str, (H, W)]` | Expected camera names and shapes. |
|
||||
| `action_dim` | `int` | Dimensionality of the action space. |
|
||||
| `max_actions_per_chunk` | `int` | Maximum chunk size the model supports. |
|
||||
| `observation_features` | `dict` | Full feature specification for validation. |
|
||||
|
||||
### 7.4 Serialization Format
|
||||
|
||||
**MessagePack** for all structured metadata (compact, fast, cross-language). Image payloads are raw JPEG bytes embedded in the MessagePack structure. State vectors use `numpy.tobytes()` with shape/dtype metadata for zero-copy reconstruction.
|
||||
|
||||
**No pickle.** The current prototype uses `pickle.dumps`/`pickle.loads` throughout, which allows arbitrary code execution. This is replaced entirely.
|
||||
|
||||
---
|
||||
|
||||
## 8. Latency Compensation
|
||||
|
||||
### 8.1 RTT Calculation
|
||||
|
||||
The edge device tracks in-flight observations:
|
||||
|
||||
```python
|
||||
in_flight: dict[int, float] = {} # seq_id -> time.perf_counter() at send
|
||||
|
||||
# On send:
|
||||
in_flight[seq_id] = time.perf_counter()
|
||||
|
||||
# On receive action chunk:
|
||||
rtt = time.perf_counter() - in_flight[response_to_seq_id]
|
||||
# delete older keys than the one received
|
||||
```
|
||||
|
||||
> **Important**: Delete only the exact `response_to_seq_id` key from `in_flight`, not all keys `<= response_to_seq_id`. With Zenoh's best-effort transport, messages can arrive out of order. Clearing earlier keys would make their RTT unmeasurable.
|
||||
|
||||
### 8.2 Stale Action Trimming
|
||||
|
||||
When an action chunk arrives, the edge calculates how many initial steps have already expired:
|
||||
|
||||
```python
|
||||
expired_steps = int(rtt / environment_dt)
|
||||
valid_actions = action_chunk[expired_steps:]
|
||||
```
|
||||
|
||||
The valid actions are then blended into the action buffer using the configured aggregation function.
|
||||
|
||||
### 8.3 Edge Cases
|
||||
|
||||
| Scenario | Behavior |
|
||||
| -------------------------------------- | -------------------------------------------------------------------------------------- |
|
||||
| **First observation** (no RTT history) | Apply all action steps without trimming. |
|
||||
| **Dropped observations** | Server infers on next received observation. No special handling needed. |
|
||||
| **Dropped action chunks** | Edge continues executing current buffer. If buffer empties, warn & hold last position. |
|
||||
| **Server crash** | Edge exhausts buffer, holds position, warns & re-validates via status query. |
|
||||
|
||||
> **Assumption**: All currently supported robots are position-controlled (SO100, SO101, OMX). For velocity-controlled robots, the fail-safe must send zero-velocity instead of holding position. This should be configurable per-robot.
|
||||
|
||||
---
|
||||
|
||||
## 9. Known Bugs in Current Prototype
|
||||
|
||||
These issues exist in `src/lerobot/async_inference/` and must be addressed in the new implementation.
|
||||
|
||||
### BUG-1: `send_observation` Blocks the Control Loop (Critical)
|
||||
|
||||
**Location**: `robot_client.py:207`
|
||||
|
||||
`self.stub.SendObservations(observation_iterator)` is a synchronous gRPC call inside the 33ms control loop. For multi-camera observations (several MB after pickle), this consumes 10-20ms on the network, leaving no headroom for sensor capture and motor commands. The robot stutters.
|
||||
|
||||
**Resolution in new design**: Observation publishing is moved to a dedicated thread. Zenoh's `session.put()` is non-blocking by default. The control loop only deposits observations into a local outbox.
|
||||
|
||||
### BUG-2: Race Condition in Action Queue Aggregation (Correctness)
|
||||
|
||||
**Location**: `robot_client.py:236-267`
|
||||
|
||||
The lock on `self.action_queue` is acquired to read `internal_queue = self.action_queue.queue` (a reference to the internal deque), then **released** at line 238. The aggregation logic iterates over this reference outside the lock. Meanwhile, the control loop thread can `get_nowait()` from the same queue, mutating the deque during iteration. At line 267, the entire queue is replaced, but actions popped between 238-267 are silently lost.
|
||||
|
||||
**Fix**: Either hold the lock for the entire aggregation, or `list(self.action_queue.queue)` to copy contents before releasing.
|
||||
|
||||
### BUG-3: No RPC Deadlines (Reliability)
|
||||
|
||||
**Location**: `robot_client.py:278`
|
||||
|
||||
`GetActions` blocks indefinitely if the server hangs (GPU OOM, deadlock). The retry policy handles `UNAVAILABLE` but not a hung connection.
|
||||
|
||||
**Resolution in new design**: The polling `GetActions` pattern is replaced by Zenoh subscription callbacks. The client needs a watchdog timer or check when action queue is empty: if no actions are received for `T` seconds, trigger re-validation via the status service.
|
||||
|
||||
### BUG-4: Similarity Check Ignores Images (Correctness for VLAs)
|
||||
|
||||
**Location**: `helpers.py:280-297`
|
||||
|
||||
`observations_similar()` + `must_go` is a workaround for current architecure limitations to avoid filling up the server queue the first seconds of the task & the robot remaining idle.
|
||||
|
||||
**Resolution in new design**: the server always processes the latest observation per client in its inference loop, and doesn't need similarity gating at all. The client can always push.
|
||||
|
||||
---
|
||||
|
||||
## 10. Gaps Between Prototype and Target Architecture
|
||||
|
||||
### 10.1 Critical (Must Address)
|
||||
|
||||
| # | Gap | Current State | Target State |
|
||||
| --- | ------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| G1 | **Single-client server** | One `observation_queue(maxsize=1)`, one `last_processed_obs`, one `_predicted_timesteps`. `_reset_server()` flushes all state on any new connection. | Per-client state (`ClientState` dataclass) keyed by `client_uuid`. Zenoh key-expression routing provides client isolation. |
|
||||
| G2 | **Dynamic model loading** | Client sends `RemotePolicyConfig` → server calls `from_pretrained()` on demand. | Server loads models at startup from config manifest. `SendPolicyInstructions` RPC eliminated. Client validates via status query. |
|
||||
| G3 | **gRPC transport** | Entire `transport/` directory: proto definitions, generated stubs, chunking utils. 4 RPCs: `Ready`, `SendPolicyInstructions`, `SendObservations`, `GetActions`. | Zenoh pub/sub. Client publishes obs, subscribes to actions. Server subscribes to obs, publishes actions. Dispatching via key expressions. |
|
||||
| G4 | **Pickle serialization** | `pickle.dumps`/`pickle.loads` throughout (arbitrary code execution risk, `# nosec` suppression). | MessagePack for structured metadata + raw JPEG bytes for images + `numpy.tobytes()` for state vectors. |
|
||||
|
||||
### 10.2 Important
|
||||
|
||||
| # | Gap | Current State | Target State |
|
||||
| --- | -------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| G5 | **No RTT/latency compensation** | No `seq_id`, no `response_to_seq_id`, no `inference_time_ms`. Timestamps use `time.time()` (unreliable across machines). | Edge-local `perf_counter` + echoed `seq_id` + server inference duration. Stale action step trimming. |
|
||||
| G6 | **No hierarchical routing** | Direct gRPC channel to `host:port`. | Zenoh key expressions: `cluster/experiment/model/version/task/client/topic`. |
|
||||
| G7 | **No data logging** | `control_loop` has access to obs and actions but doesn't persist them. | Edge records via `LeRobotDataset` (`build_dataset_frame` + `dataset.add_frame`). |
|
||||
| G8 | **No authentication** | `grpc.insecure_channel`. | Zenoh TLS + access control lists on key expressions. |
|
||||
| G9 | **ProcessorPipeline divergence** | Server reimplements observation prep in `helpers.py` (custom `resize_robot_observation_image` with `F.interpolate` bilinear). Diverges from standard `RobotProcessorPipeline`. | Use the standard `RobotProcessorPipeline` + `build_dataset_frame` to ensure behavioral equivalence between record and async inference. |
|
||||
|
||||
### 10.3 Nice-to-Have
|
||||
|
||||
| # | Gap | Current State | Target State |
|
||||
| --- | ------------------------------------- | --------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| G11 | **No status/discovery service** | Bare `Ready()` ping. | Zenoh queryable at `cluster/exp/model/version/task/status`. |
|
||||
| G12 | **No monitoring** | `FPSTracker` + `logging.debug`. | Structured metrics via Zenoh telemetry topics. Wildcard subscriptions for centralized monitoring. |
|
||||
| G13 | **No entry points** | Module-level `__main__`. | `lerobot-policy-server` and `lerobot-robot-client` console scripts in `pyproject.toml`. |
|
||||
| G14 | **Ratio-based observation threshold** | `chunk_size_threshold` (0-1 ratio of queue fill). Scales oddly with different `actions_per_chunk` values. | Absolute time threshold: `buffer_time_s` calibrated to observed RTT. Send observation when `queue_size * environment_dt < buffer_time_s`. |
|
||||
|
||||
---
|
||||
|
||||
## 11. Design Decisions & Rationale
|
||||
|
||||
### 11.1 Why Zenoh Over gRPC
|
||||
|
||||
| Aspect | Zenoh | gRPC |
|
||||
| ------------------------- | -------------------------------------------------------------------------- | ---------------------------------------------------------------------------------- |
|
||||
| Communication model | Pub/sub — natural fit for "client publishes obs, server publishes actions" | Request/response — requires polling (`GetActions` loop) or bidirectional streaming |
|
||||
| Multi-tenant routing | Hierarchical key expressions provide built-in per-client topic isolation | Requires manual per-client channel/stream management |
|
||||
| Discovery | Built-in discovery | Requires external service (mDNS, Consul, etc.) |
|
||||
| Observation publishing | Non-blocking put (fire-and-forget) — resolves BUG-1 automatically | Synchronous stream-unary call — blocks the control loop |
|
||||
| Same-machine optimization | Shared-memory transport (zero-copy) | Loopback TCP |
|
||||
| Telemetry | Wildcard subscriptions (`+/+/+/+/+/metrics`) | Requires separate monitoring infrastructure |
|
||||
|
||||
**Tradeoffs of going Zenoh-only:**
|
||||
|
||||
- Smaller community, less tooling for monitoring/tracing vs. gRPC's mature ecosystem.
|
||||
- No built-in schema enforcement (Zenoh sends raw bytes) — serialization correctness is entirely on us.
|
||||
- Default QoS is best-effort (like UDP). Must explicitly configure reliable delivery for action chunks.
|
||||
- `zenoh-python` bindings are less battle-tested than `grpcio`. Needs integration testing under network stress.
|
||||
|
||||
### 11.2 Why Single Inference Thread (Not Batching)
|
||||
|
||||
True GPU batching across clients requires collecting observations from multiple clients and running a single forward pass. This is difficult because:
|
||||
|
||||
- Clients send observations at different times — waiting to batch adds latency.
|
||||
- Different clients may have slightly different image resolutions.
|
||||
- Error in one client's observation shouldn't affect others.
|
||||
|
||||
**Decision**: Start with sequential processing (single inference thread, round-robin across clients). Profile GPU utilization.
|
||||
|
||||
### 11.4 Why MessagePack (Not Protobuf, Not FlatBuffers)
|
||||
|
||||
- **Protobuf**: Strong schema enforcement but heavier toolchain (proto compilation, generated code). Since we're dropping gRPC, the protobuf dependency becomes unnecessary overhead.
|
||||
- **MessagePack**: Fast, compact, schema-less (enforced by application), excellent Python support (`msgpack` package), good for nested dicts with mixed types. Natural fit for observation/action payloads.
|
||||
|
||||
Images are embedded as raw JPEG bytes within the MessagePack structure. State vectors use `numpy.tobytes()` with shape/dtype metadata for zero-copy reconstruction.
|
||||
|
||||
### 11.5 Action Aggregation Strategy
|
||||
|
||||
When a new action chunk overlaps with the existing buffer, the overlapping timesteps must be blended. The current prototype supports configurable aggregation functions:
|
||||
|
||||
| Function | Formula | Character |
|
||||
| ------------------ | ----------------------- | ------------------------------------------ |
|
||||
| `weighted_average` | `0.3 * old + 0.7 * new` | Smooth transitions, favors new predictions |
|
||||
| `latest_only` | `new` | Most responsive, can cause discontinuities |
|
||||
| `average` | `0.5 * old + 0.5 * new` | Equal weight |
|
||||
| `conservative` | `0.7 * old + 0.3 * new` | Smooth, slow to adapt |
|
||||
|
||||
Ultimately, this should be the user's decision. Default to `weighted_average`. The goal of async is not to do temporal ensembling, but to provide a solution when we want to decouple inference and execution.
|
||||
|
||||
---
|
||||
|
||||
## 12. Configuration
|
||||
|
||||
### 12.1 Server Configuration (Manifest)
|
||||
|
||||
Servers are configured via a YAML manifest that declares which models to pre-warm & clients to serve:
|
||||
|
||||
```yaml
|
||||
cluster_uuid: jupiter
|
||||
experiment_tag: fabio2
|
||||
server:
|
||||
- model_id: pi0
|
||||
model_version: v1
|
||||
pretrained_path: lerobot/pi0-cookie-v1
|
||||
application_tag: cookie
|
||||
device: cuda:0
|
||||
fps: 30
|
||||
endpoint: tcp/192.168.1.50:7447
|
||||
clients:
|
||||
- client_uuid: cookie-worker-4269
|
||||
```
|
||||
|
||||
### 12.2 Client Configuration
|
||||
|
||||
Clients are configured via draccus dataclass (CLI-compatible):
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class AsyncClientConfig:
|
||||
# Zenoh routing
|
||||
cluster_uuid: str
|
||||
experiment_tag: str
|
||||
model_id: str
|
||||
model_version: str
|
||||
application_tag: str
|
||||
client_uuid: str
|
||||
endpoint: str
|
||||
|
||||
# Robot
|
||||
robot: RobotConfig
|
||||
|
||||
# Control
|
||||
fps: int = 30
|
||||
actions_per_chunk: int = 50
|
||||
aggregate_fn_name: str = "weighted_average"
|
||||
jpeg_quality: int = 90
|
||||
|
||||
# Fail-safe
|
||||
max_empty_cycles_before_warning: int = 10
|
||||
|
||||
# Datset recording
|
||||
dataset_repo_id: str | None = None # None = no logging
|
||||
|
||||
# Task
|
||||
task: str = ""
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 14. Data Logging Integration
|
||||
|
||||
The client records observations and executed actions into a local `LeRobotDataset` for deferred upload to the training dataset:
|
||||
|
||||
```python
|
||||
# In control_loop, after executing an action:
|
||||
if self.dataset is not None:
|
||||
frame = build_dataset_frame(
|
||||
self.dataset.features,
|
||||
processed_observation,
|
||||
prefix=OBS_STR,
|
||||
)
|
||||
frame["action"] = executed_action_tensor
|
||||
self.dataset.add_frame(frame)
|
||||
```
|
||||
@@ -0,0 +1,498 @@
|
||||
# Decoupled VLA Inference & Edge Control v2: Async Network Inference for `lerobot-rollout`
|
||||
|
||||
> **Status**: supersedes the v1 proposal in full. v1 was written against the standalone `src/lerobot/async_inference/` prototype, before `lerobot-rollout` existed. This revision re-grounds the design in the current codebase, keeps v1's decisions that survived contact with it (marked **KEPT** throughout), reverses the ones that didn't, and adds the safety, multi-tenancy, and operations specifications v1 lacked.
|
||||
|
||||
## 1. Executive Summary
|
||||
|
||||
This document specifies a production-grade system for decoupling GPU-bound policy inference from high-frequency robot control, targeting power users running **hundreds of robots** against centralized GPU clusters. The system keeps v1's **Model-as-a-Service (MaaS)** paradigm and **Zenoh** transport, but changes the integration architecture fundamentally:
|
||||
|
||||
- **The client is not a standalone CLI.** It is `--inference.type=remote`, a new `InferenceEngine` backend inside `lerobot-rollout` (`src/lerobot/rollout/inference/`). Every rollout strategy (base, sentry, highlight, dagger, episodic) gets network inference for free — including dataset recording, DAgger pause/resume, Rerun visualization, and safe teardown.
|
||||
- **The client is weightless.** No policy weights, no policy processors on the edge. `--policy.path` resolves to a config-only `PreTrainedConfig` (no weight download) used for pre-flight validation and action ordering.
|
||||
- **The server is stateless per request.** All RTC chunk state (leftover prefixes, latency tracking, delay computation) lives client-side in the existing `ActionQueue`/`LatencyTracker` machinery — the client ships prefixes + a delay hint with each observation. A server crash loses zero control state; reconnects and horizontal scaling are trivial.
|
||||
- **Multi-tenancy is engineered, not assumed.** The real hazards are stateful processor pipelines and episode-scoped policy state — not `predict_action_chunk` purity (which holds for ACT/Pi0/Pi0.5/SmolVLA but _not_ diffusion). The server uses per-session processor instances, a chunk-stateless allowlist, and an exclusive serving mode for policies that need it.
|
||||
- **The legacy module dies.** `src/lerobot/async_inference/` (~1,900 lines, pickle-over-gRPC, single-client, four confirmed bugs) is deleted in the same PR that lands the new backend. No deprecation cycle: the module is experimental, its CLI undocumented in the main flow, and every config field has a mapped successor (§13.4).
|
||||
|
||||
---
|
||||
|
||||
## 2. Motivation (unchanged from v1) — **KEPT**
|
||||
|
||||
LeRobot's standard control loop runs policy inference and robot I/O in the same process. This breaks down when:
|
||||
|
||||
- **The policy is too large for edge hardware** (Pi0-class models need a dedicated GPU).
|
||||
- **Multiple robots need the same policy** (redundant GPU allocation per robot).
|
||||
- **Inference latency exceeds the control deadline** (e.g. 150 ms inference on a 33 ms control tick).
|
||||
|
||||
Decoupling solves all three: the edge runs a tight CPU loop; a GPU server performs inference for N clients.
|
||||
|
||||
What changed since v1: the _local_ version of this decoupling already shipped. `RTCInferenceEngine` (`src/lerobot/rollout/inference/rtc.py`) runs inference in a background thread against a thread-safe `ActionQueue` with latency-aware chunk merging. **The network system is that same architecture with the thread boundary replaced by a network boundary.** This is the design's central simplification: reuse, don't reinvent.
|
||||
|
||||
---
|
||||
|
||||
## 3. Gap Analysis: v1 Proposal vs. Modern Codebase
|
||||
|
||||
| Topic | v1 assumed | Modern reality | Verdict |
|
||||
| ----------------------------------------- | --------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------- | --------------------------------------- |
|
||||
| Client architecture | Standalone robot-client CLI (§5.1 of v1) | `InferenceEngine` ABC seam in `lerobot-rollout` (`rollout/inference/base.py`); strategies are backend-agnostic | **Superseded** — backend, not CLI |
|
||||
| Chunk blending | Configurable aggregation zoo (`weighted_average`, …) | `ActionQueue` replace-with-delay-trim (RTC) / append (non-RTC) (`policies/rtc/action_queue.py:147-217`) | **Superseded** — drop blending entirely |
|
||||
| Latency compensation | Hand-rolled RTT trim (`expired_steps = int(rtt/dt)`, v1 §8.2) | `ActionQueue.merge(..., real_delay, idx_before)` + `LatencyTracker` already do this, validated | **Superseded** |
|
||||
| Multi-tenancy invariant | "`predict_action_chunk()` pure ⇒ safe to share" | Processor state + episode-scoped policy state are the real hazards (§7) | **Incomplete** — fixed in §8.3 |
|
||||
| Data logging | Client-side `build_dataset_frame` + `add_frame` sketch (v1 §14) | Recording strategies (sentry/episodic/dagger) already log obs + executed actions | **Superseded** — free via rollout |
|
||||
| MaaS pre-warm, no dynamic loading | ✓ | Still right; legacy `SendPolicyInstructions` is a pickle/RCE + capacity-planning disaster | **KEPT** |
|
||||
| JPEG observation compression | ✓ | Still right (§10.1) | **KEPT** |
|
||||
| Status/capability validation before start | ✓ (Zenoh queryable) | Still right; extended into a hard sync-safety contract (§8.4) | **KEPT, extended** |
|
||||
| Time-based send threshold (v1 G14) | ✓ | Adopted as `buffer_time_s` | **KEPT** |
|
||||
| Zenoh pub/sub data plane | ✓ | Confirmed; QoS corrected (§6.3), control plane moved to queryables, liveliness added | **KEPT, hardened** |
|
||||
| MessagePack serialization | ✓ | Endorsed (zenoh's `ext` serializer cannot encode numpy); must be version-gated (§10.4) | **KEPT, with schema discipline** |
|
||||
| QoS table (v1 §6.2) | "obs best-effort, actions reliable" | Conflates transport reliability with congestion control; BLOCK on actions is dangerous | **Revised** (§6.3) |
|
||||
| Bugs BUG-1…BUG-4, gaps G1…G14 | Listed as work items | Every one resolved _structurally_ by this design (§13.5 mapping) | **Resolved by design** |
|
||||
|
||||
---
|
||||
|
||||
## 4. Critical Pushbacks on v1
|
||||
|
||||
Each pushback: claim → evidence → consequence for this design.
|
||||
|
||||
**P1 — A standalone client duplicates `lerobot-rollout`.**
|
||||
v1 §5.1 assigns the client: observation capture, action execution at frequency, fail-safe, data logging. Every one of those is already owned by rollout strategies and `send_next_action` (`rollout/strategies/core.py:269-304`), which tolerates `None` actions, runs the interpolator, and routes through the canonical robot processors. A standalone client re-implements loop timing, recording, DAgger UX, Rerun, and teardown safety — and then drifts. _Consequence_: the client is `RemoteInferenceEngine`, registered as `--inference.type=remote` next to `sync` and `rtc`.
|
||||
|
||||
**P2 — The aggregation-function zoo fabricates actions no policy predicted.**
|
||||
`0.3*old + 0.7*new` produces hybrid actions that exist in no policy's output distribution; the logged action becomes unexplainable (bad for the reproducibility story) and the implementation hosted a real lock-release race (BUG-2, `async_inference/robot_client.py:236-267`). RTC's prefix-conditioned chunk generation is the principled mechanism for smooth chunk transitions; plain append covers non-RTC chunking. _Consequence_: `ActionQueue` replace/append are the only two merge semantics. The zoo is deleted.
|
||||
|
||||
**P3 — "predict_action_chunk pure ⇒ multi-tenant safe" is incomplete.**
|
||||
Verified in-tree: (a) `RelativeActionsProcessorStep` caches `_last_state` at preprocess (`processor/relative_action_processor.py:131`) and the postprocessor reads it back (`:189`) — a shared pipeline across clients is a race; (b) `DiffusionPolicy.predict_action_chunk` reads `self._queues`, which only `select_action` populates (`policies/diffusion/modeling_diffusion.py:90-108`) — it is **not** chunk-stateless; (c) SAC/SARM have no `predict_action_chunk` at all. _Consequence_: per-session processor instances (mandatory), a chunk-stateless allowlist, `serving_mode: exclusive` for diffusion-family, refusal at startup for SAC/SARM, and `policy.reset()` is **never** called in shared mode (§8.3).
|
||||
|
||||
**P4 — v1 re-derives latency compensation that already exists, on top of broken clocks.**
|
||||
v1 §8 specifies an in-flight RTT dict and manual stale-step trimming. `ActionQueue.merge(original, processed, real_delay, idx_before)` already trims `real_delay` stale steps and cross-validates against actions consumed in flight (`action_queue.py:219-246`). Worse, the legacy code compares wall clocks across machines (`robot_client.py:420` stamps `time.time()` "to compare timestamps across client and server"; `policy_server.py:178` compares it) — NTP skew is the same order as the latencies being measured. _Consequence_: the **monotonic iron rule** (§11): instants never cross machines; client timestamps are opaque echoed tokens; servers report only durations. `delay_steps = ceil((rtt + inference)/dt)` is computed client-side from client-local `perf_counter` samples and shipped per request.
|
||||
|
||||
**P5 — One-in-flight per client is a correctness requirement, not a tuning choice.**
|
||||
At send time the client snapshots `idx_before = queue.get_action_index()` and the leftover prefixes; `merge` validates against them. Two in-flight requests carry conflicting snapshots — the second merge corrupts both RTC replace mode and append mode. The local RTC thread is also strictly one-inference-at-a-time; one-in-flight preserves exact parity. _Consequence_: the worker publishes one observation, waits for its chunk (or timeout), then sends the next. v1 §8.1's out-of-order in-flight dict is dead weight; a late chunk is accepted only if it answers the _latest_ outstanding `seq_id`, otherwise dropped.
|
||||
|
||||
**P6 — v1's QoS table conflates transport reliability with congestion behavior.**
|
||||
"Reliable delivery for actions" sounds right but the dangerous knob is congestion control: a publisher configured `BLOCK` on the action topic can stall the **server's** publish path on one robot's dead uplink (Zenoh blocks up to `wait_before_close`, then may close the transport). A dropped action chunk is _recoverable by design_ — the client's queue keeps the robot moving and the next chunk replaces it. _Consequence_ (§6.3): actions = `reliability=RELIABLE` (hop-level) + `congestion_control=DROP` + `express=True` + `priority=INTERACTIVE_HIGH`; observations = `DROP` + `DATA`. If WAN loss proves material, upgrade the action topic to Zenoh Advanced Pub/Sub (cache + recovery, zenoh ≥ 1.5) rather than BLOCK.
|
||||
|
||||
**P7 — Schema-less MessagePack invites silent version drift across a 300-robot fleet.**
|
||||
msgpack stays (zenoh's `ext` serializer cannot encode numpy/dataclasses, and the team's choice stands), but naked msgpack dicts across heterogeneous fleet versions fail at runtime, on the robot. _Consequence_ (§10.4): a packed little-endian **attachment header** (`schema_version`, `seq_id`, `episode_id`, `client_mono_ns` — the rmw_zenoh pattern) so routing/correlation never deserializes the body; `schema_version` negotiated at the session handshake; additive-only evolution; golden codec tests. Protobuf-over-ZBytes is the documented fallback if drift bites in practice.
|
||||
|
||||
**P8 — "Deterministic rollout reproducibility" is unattainable on real robots.**
|
||||
No seed controls hardware, sensor noise, or network jitter; RTC's latency-driven trimming is inherently timing-dependent. _Consequence_: the contract is **fully logged + replayable** (§12): recording strategies already persist observations and executed actions; the remote engine adds `(session_id, seq_id, episode_id)` provenance so client datasets join server audit logs mechanically.
|
||||
|
||||
**P9 — v1 has no safety specification.**
|
||||
"Log a warning when the buffer empties" is not a fail-safe for a 300-robot fleet. _Consequence_ (§9): a staleness bound (`max_action_age_s` — never execute an action older than X relative to its source observation), an explicit fallback ladder (`hold` / `repeat_last` / `zero` — zero-command required for future velocity-controlled robots), and a DEAD state that triggers the existing strategy shutdown path (return-to-initial-pose, disconnect) via the same `shutdown_event` mechanism RTC uses (`rtc.py:359-360`).
|
||||
|
||||
**P10 — Capacity must be formula-driven, not "a user decision".**
|
||||
v1 §4 says clients-per-server "is a user decision". With `t` = server time per request, `r` = per-client request rate, `H` = RTC execution horizon, `dt` = control period:
|
||||
`N_max = min( 0.8 / (r·t), (H·dt/2 − RTT_net) / t )`
|
||||
→ ACT @ 20 ms, 1 Hz: ~40 clients/GPU. Pi0 @ 150 ms, 1 Hz: ~5 clients/GPU. 300 robots on Pi0 ≈ 60 GPU pods. _Consequence_: the manifest carries `max_sessions`; the server rejects session opens beyond it (with current load in the reply) so clients retry another replica. Micro-batching is deferred — blocked on a real API issue (`predict_action_chunk` takes a _scalar_ `inference_delay`; batched clients have different delays) — behind a `Scheduler` seam so it can land later without redesign (§8.5).
|
||||
|
||||
**P11 — Discovery ≠ multicast.**
|
||||
Zenoh's multicast scouting does not cross WAN, NAT, or most k8s CNIs. _Consequence_: multicast scouting disabled; clients use static `connect.endpoints` (DNS name of the router) + gossip; presence and liveness come from Zenoh **liveliness tokens** (§6.4), not discovery. "Discovery" for a robot fleet is configuration.
|
||||
|
||||
---
|
||||
|
||||
## 5. System Topology
|
||||
|
||||

|
||||
_(Diagram unchanged from v1 — the topology survives; transport/QoS/session details in it are superseded by §6.)_
|
||||
|
||||
- **Router tier**: one or more `zenohd` routers (k8s Deployment + Service, TLS on 7447). Robots **dial out** to the router (NAT-friendly: labs only need outbound 7447/443). GPU servers join as peers via cluster DNS.
|
||||
- **Server**: one process = one `(model_repo, revision, dtype, device)` on one GPU, pre-warmed from a YAML manifest (**KEPT** from v1, amended: `pin_task: bool` — VLA prompts may vary per session unless pinned).
|
||||
- **Client**: one robot running `lerobot-rollout --inference.type=remote`. Weightless: config-only policy metadata.
|
||||
- **Identity**: `client_uuid` per robot; `session_id` per connection epoch; both in every log line on both sides.
|
||||
|
||||
---
|
||||
|
||||
## 6. Zenoh Design
|
||||
|
||||
All Zenoh claims below were verified against zenoh / zenoh-python 1.x (eclipse-zenoh 1.9.0). Pin: `eclipse-zenoh>=1.9,<2.0`; keep `zenohd` on the same minor as the Python binding. Wheels cover manylinux x86_64/aarch64/armv7l/armv6l + macOS — Raspberry Pi edge clients are covered.
|
||||
|
||||
### 6.1 Key-expression schema
|
||||
|
||||
```
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/<client_uuid>/obs client → server
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/<client_uuid>/action server → client
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/status queryable (capabilities)
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/session queryable (open/validate)
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/<client_uuid>/reset queryable (episode boundary)
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/<client_uuid>/alive liveliness token (client)
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/server/alive liveliness token (server)
|
||||
```
|
||||
|
||||
Rules (hard, enforced by a `sanitize_keyexpr()` helper):
|
||||
|
||||
- Root at the **verbatim chunk** `@lerobot` — verbatim chunks are only matched by identical chunks, so third-party `**` subscribers on a shared router can never scrape the tree.
|
||||
- Sanitize every user-supplied segment (model ids, task strings, uuids): non-empty, no `* $ ? # /`, no leading/trailing/double `/`. A task string containing `/` must be slugified before it becomes a key chunk.
|
||||
- Server subscribes with a **single-depth** wildcard (`.../*/obs`) — never `**` (it would also match `status`, `alive`, …).
|
||||
- v1's `cluster/experiment` prefix segments are dropped from the key schema; they return as free-form `tags` metadata in the session handshake (telemetry/labeling, not routing). Routing topology belongs to deployment (which router you dial), not to key depth.
|
||||
|
||||
### 6.2 Data plane vs. control plane (the rmw_zenoh split)
|
||||
|
||||
- **Data plane = pub/sub** (KEPT from v1): observations up, action chunks down, correlated by `seq_id` in **attachments** (§10.4). Pub/sub rather than query-per-inference because: a timed-out query's late reply is _dropped by the transport_ (wasted inference), whereas a late pub/sub chunk is still mergeable if it answers the latest outstanding seq; and pub/sub leaves room for server-initiated messages (drain notices). The one-in-flight discipline (P5) is enforced in the client worker, not by the transport.
|
||||
- **Control plane = queryables** (request/reply with explicit timeouts; the pattern rmw*zenoh uses for ROS 2 services): `status` (pre-flight capability fetch, 2 s timeout), `session` (open/validate → ack with capabilities + `session_id`), `reset` (episode boundary — \_acknowledged*, so episodic strategies know the server-side episode state is clean). Always pass an explicit `timeout` to `session.get()` — the config default is 10 s, far too long for our watchdogs.
|
||||
- **Episode ordering**: under one-in-flight there is no obs/reset race window in the data plane, but as belt-and-braces the first observation of each episode also carries `episode_start=True` + the new `episode_id` in its header.
|
||||
|
||||
### 6.3 QoS (revised from v1 §6.2 — see P6)
|
||||
|
||||
| Topic | reliability | congestion_control | express | priority | Why |
|
||||
| ------------------ | ----------- | ---------------------- | -------- | ---------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| `obs` | default | **DROP** | false | DATA | Intentional drop already happened at the client's one-slot holder; if the uplink stalls, dropping a frame protects the control loop. |
|
||||
| `action` | RELIABLE | **DROP** (never BLOCK) | **true** | INTERACTIVE_HIGH | Hop-level reliability over TCP; express skips batching for the small (4–50 KB) latency-critical payload; DROP so one dead robot uplink can never stall the server's publish path. Chunk loss is recoverable: the client buffer rides through it. |
|
||||
| control queryables | RELIABLE | default | — | — | Correctness over latency; explicit timeouts bound them. |
|
||||
|
||||
Upgrade path if WAN chunk loss proves material: `AdvancedPublisher`/`AdvancedSubscriber` (zenoh ≥ 1.5) with a small cache + heartbeat-based recovery **on the action topic only**. Hop-by-hop RELIABLE is not end-to-end reliability — Zenoh has no broker persistence; a disconnected subscriber's data is gone. The design assumes this (client state machine, §9).
|
||||
|
||||
### 6.4 Liveliness (presence + watchdogs)
|
||||
|
||||
- Client declares a liveliness token on `.../<client_uuid>/alive`. The server liveliness-subscribes with `history=True`: token appear → ensure session state; token drop → GC the session (mailbox, processor instances) after a grace period.
|
||||
- Server declares `.../server/alive`. The client liveliness-subscribes: on drop → treat as RECONNECTING (§9), hold/fallback per config, re-run the `status`/`session` handshake when the token reappears.
|
||||
- Tune the transport lease down from its default so ungraceful-death detection is seconds, not tens of seconds (verify the default in the pinned version; it is config `transport/link/tx/lease`).
|
||||
- Liveliness cannot detect a _hung-but-connected_ server. The client's per-request timeout (`request_timeout_s`) is the authoritative watchdog — this is the structural fix for legacy BUG-3 (no deadlines on `GetActions`).
|
||||
|
||||
### 6.5 Threading constraints (zenoh-python facts that shape both processes)
|
||||
|
||||
- **No asyncio API** in zenoh-python — both client and server are thread-based. This matches the existing RTC engine pattern exactly.
|
||||
- Each callback-based subscriber spawns a dedicated Python thread; **blocking Zenoh calls inside callbacks are disallowed**. Callbacks must be deposit-only (write a slot, set an event, return).
|
||||
- Channel handlers (`FifoChannel`, `RingChannel`) are Rust-side; `try_recv()` polls without spawning Python threads. `RingChannel(1)` is native latest-only semantics.
|
||||
- No zero-copy path for our payloads (SHM API is `@_unstable` and same-host-only; `ZBytes` copy behavior undocumented). At ~200 KB × a few Hz per robot, one memcpy is irrelevant.
|
||||
|
||||
### 6.6 Router deployment
|
||||
|
||||
- `zenohd` official image as a k8s Deployment (1–N replicas; routers mesh and reroute around failures) behind a `LoadBalancer`/`NodePort` Service exposing TLS 7447. No official Helm chart exists — roll-your-own manifests.
|
||||
- `scouting.multicast.enabled: false`; `scouting.gossip.enabled: true`; clients/servers use static `connect.endpoints`.
|
||||
- **Auth**: mTLS per robot (`transport.link.tls` with `enable_mtls`) + router **ACL** keyed on `cert_common_names`: a robot's cert may only `put` to `@lerobot/**/<its-uuid>/obs` and receive on `.../<its-uuid>/action`. Caveat (flagged): ACL config reloads require a router restart — plan cert/ACL changes as rolling router restarts.
|
||||
- Security review input: the third-party Zenoh protocol security analysis (Census Labs, 2025) should be read before exposing 7447 publicly.
|
||||
|
||||
---
|
||||
|
||||
## 7. The Statelessness Boundary (the load-bearing section)
|
||||
|
||||
**Where the network cut goes.** The local RTC pipeline is:
|
||||
|
||||
```
|
||||
obs (robot-processed dict)
|
||||
→ build_dataset_frame(hw_features, obs, "observation") CLIENT (cheap, hardware-coupled)
|
||||
─────────────────────────── network ───────────────────────────
|
||||
→ prepare_observation_for_inference(...) SERVER (policy-coupled, heavy)
|
||||
→ per-session preprocessor(...) SERVER (stateful within the request)
|
||||
→ policy.predict_action_chunk(obs, inference_delay, prefix) SERVER (pure for allowlisted policies)
|
||||
→ per-session postprocessor(...) SERVER (reads state cached at preprocess)
|
||||
─────────────────────────── network ───────────────────────────
|
||||
→ ActionQueue.merge(original, processed, real_delay, idx_before) CLIENT
|
||||
```
|
||||
|
||||
Three consequences:
|
||||
|
||||
1. **The server needs no cross-request state.** `RelativeActionsProcessorStep` writes `_last_state` at preprocess and the postprocessor reads it back _within the same request_. Per-session pipeline instances + one-request-at-a-time-per-session give correctness with zero persistent state.
|
||||
2. **RTC state stays client-side**, exactly where `RTCInferenceEngine` already keeps it. Each request ships: `inference_delay_steps = ceil(L_max/dt)` (from the client `LatencyTracker`, whose samples are full network-inclusive cycle times — RTT compensation falls out for free), `prefix_model = queue.get_left_over()[:H]`, and `prefix_robot = queue.get_processed_left_over()[:H]` (needed for server-side relative-prefix re-anchoring, mirroring `rtc.py:287-305`). The response returns **both** the model-space and robot-space chunks because `merge` needs both. ≤ `execution_horizon × action_dim` float32 each — a few hundred bytes.
|
||||
3. **G9 dies structurally.** No bespoke client resize (`F.interpolate` in legacy `helpers.py`), no client-side normalization. Clients ship native camera resolution; the server's canonical processor path does everything — serve-time preprocessing is byte-identical to train-time.
|
||||
|
||||
**What the server _does_ hold** (and what it means):
|
||||
|
||||
- Per-session processor instances (cheap; normalization stat tensors shared read-only).
|
||||
- Per-session episode counter + stats. Episode reset = reset the session's pipelines, clear its mailbox. **`policy.reset()` is never called in shared mode** — it is global to the shared policy instance and unnecessary for chunk-pure policies (ACT's ensembler and Pi0/SmolVLA's queues live in `select_action`, not `predict_action_chunk` — verified).
|
||||
- Policies that are _not_ chunk-pure get `serving_mode: exclusive` (§8.3).
|
||||
|
||||
---
|
||||
|
||||
## 8. The Inference Server: `lerobot-policy-server`
|
||||
|
||||
New package `src/lerobot/policy_server/`; console script `lerobot-policy-server --manifest manifest.yaml`.
|
||||
|
||||
### 8.1 Process model — **KEPT** from v1, amended
|
||||
|
||||
One process = one model+task on one GPU, loaded and warmed at startup (`warmup_inferences` dummy forwards; covers torch.compile). Multi-GPU nodes run N processes (`CUDA_VISIBLE_DEVICES` pinning). Dynamic model loading (`SendPolicyInstructions`) is **rejected**: pickle/RCE surface, arbitrary-download surface, and it destroys capacity planning. Amendment: `pin_task: false` (default) lets VLA clients set the task per session; `pin_task: true` rejects mismatched tasks at session open.
|
||||
|
||||
### 8.2 Concurrency (pure threads — no asyncio in zenoh-python)
|
||||
|
||||
```
|
||||
zenoh subscriber (.../*/obs) inference worker (1 thread, owns GPU)
|
||||
deposit-only callback: loop:
|
||||
slots[client_uuid] = sample ──► pick next session with pending obs (RR ring)
|
||||
(per-client latest-only) decode JPEG → per-session preprocess
|
||||
predict_action_chunk(delay, prefix)
|
||||
control queryables (status/session/ per-session postprocess → encode
|
||||
reset): validate, mutate session publisher.put(.../<uuid>/action)
|
||||
registry, reply (publishing from the worker thread is fine)
|
||||
```
|
||||
|
||||
- **Per-client latest-only mailbox**: a wildcard subscriber with a deposit-only callback writing per-client slots (scales to dynamic fleets), or — when the manifest enumerates clients — one `RingChannel(1)` subscriber per client polled via `try_recv()`. Either way: newest observation wins; a superseded request is counted (`superseded_seqs` in the next response) so drops are visible. This deletes legacy BUG-4 (`observations_similar` + `must_go`) by construction — the **client** decides when to request; the server never second-guesses observation content.
|
||||
- **Single inference worker**: torch releases the GIL inside `forward`, callbacks stay responsive. Strict round-robin over sessions with pending observations: each gets exactly one inference per cycle; starvation is structurally impossible. Overload degrades into longer cycle times → larger (but correct) client `delay_steps` → eventually the client staleness bound trips and the robot holds — safe by construction.
|
||||
|
||||
### 8.3 Chunk-stateless allowlist and serving modes
|
||||
|
||||
At startup the server classifies the loaded policy:
|
||||
|
||||
| Class | Policies (verified) | Mode |
|
||||
| --------------- | ------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| chunk-stateless | ACT, Pi0, Pi0.5, SmolVLA (and any policy whose `predict_action_chunk` touches no instance state) | `shared`: N sessions, per-session pipelines, `policy.reset()` never called |
|
||||
| chunk-stateful | Diffusion family (`predict_action_chunk` reads `select_action`-fed `self._queues`) | `exclusive`: `max_sessions=1` enforced; episode reset additionally calls `policy.reset()`; second session open → rejected with a self-explanatory error |
|
||||
| no chunk API | SAC, SARM | refused at startup |
|
||||
|
||||
Implemented as a registry in `policy_server/validation.py`; the cleaner follow-up is a `supports_stateless_chunking` class attribute on `PreTrainedPolicy` (needs a pass over policy families — roadmap §14).
|
||||
|
||||
### 8.4 Session open & capability validation (fail fast, fail loud)
|
||||
|
||||
`session` queryable payload: `client_uuid`, `policy_type`, `fps`, feature summary (post-rename observation feature names + shapes, ordered action keys), `schema_version`, RTC intent, `tags`. Checks:
|
||||
|
||||
| Check | Rule | On mismatch |
|
||||
| -------------------------- | --------------------------------------------------------------- | ---------------------------------------------------------------------------------- |
|
||||
| Action names **and order** | must equal server's `action_feature_names` exactly | **hard reject** — this is the sync-safety contract mapping chunk columns to motors |
|
||||
| Camera names | client set must cover `policy.config.input_features` image keys | hard reject |
|
||||
| Resolution | any H×W accepted (server resizes canonically) | warn if aspect ratio differs from training |
|
||||
| State dim | flattened dim must match | hard reject |
|
||||
| `schema_version` | client within server's supported range | hard reject |
|
||||
| fps | vs. manifest `trained_fps` | warn (reject only when `strict_fps: true`) |
|
||||
| Task | when `pin_task: true`, must equal `default_task` | reject |
|
||||
| RTC | client RTC requires policy RTC kwargs support | downgrade to append mode + warning |
|
||||
| Capacity | `active_sessions < max_sessions` | reject with current load → client retries another replica |
|
||||
|
||||
Reply: `session_id`, model info (repo, revision — consider a checkpoint hash, §15), `action_feature_names`, `chunk_size`, `trained_fps`, `supports_rtc`, `serving_mode`, `warmed_up`, `schema_version`, warnings. **rename_map is applied client-side** so the wire format is canonical policy-feature keys across heterogeneous robots (also a prerequisite for future batching).
|
||||
|
||||
### 8.5 Scheduler seam (micro-batching later, not in v1)
|
||||
|
||||
The worker calls a `Scheduler.select(ready: list[Session]) -> list[Session]`; v1 ships `RoundRobin` (`return ready[:1]`). Cross-session batching is blocked on the policy API (`inference_delay` is scalar; batched clients have different delays/prefixes) — when that lands, a `MicroBatch` scheduler groups same-shape sessions. The seam costs nothing now and prevents a redesign later.
|
||||
|
||||
### 8.6 Manifest
|
||||
|
||||
```yaml
|
||||
model:
|
||||
{
|
||||
repo_or_path: lerobot/pi0_towels,
|
||||
revision: main,
|
||||
dtype: bfloat16,
|
||||
device: cuda,
|
||||
}
|
||||
default_task: "fold the towel"
|
||||
pin_task: false
|
||||
serving_mode: shared # forced to exclusive for chunk-stateful policies
|
||||
max_sessions: 5 # from the §P10 formula: Pi0 @150ms, 1 Hz refresh
|
||||
warmup_inferences: 2
|
||||
strict_fps: false
|
||||
zenoh:
|
||||
connect_endpoints: ["tls/router.gpu-cluster.internal:7447"]
|
||||
tls:
|
||||
{
|
||||
connect_certificate: ...,
|
||||
connect_private_key: ...,
|
||||
root_ca_certificate: ...,
|
||||
}
|
||||
health_port: 9100 # HTTP health + Prometheus metrics
|
||||
debug: { capture_dir: null, capture_max: 256 }
|
||||
```
|
||||
|
||||
Draccus dataclass in `policy_server/manifest.py`; YAML via `--manifest`, individual overrides via CLI.
|
||||
|
||||
---
|
||||
|
||||
## 9. The Edge Client: `RemoteInferenceEngine`
|
||||
|
||||
New file `src/lerobot/rollout/inference/remote.py`, registered `@InferenceEngineConfig.register_subclass("remote")`.
|
||||
|
||||
### 9.1 Threading model
|
||||
|
||||
| Thread | Role |
|
||||
| -------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Main (strategy loop) | `notify_observation(obs)` → lock-protected latest-only slot (identical to `rtc.py` `_obs_holder`). `get_action()` → `ActionQueue.get()` + staleness check. **Never any I/O.** Structurally fixes legacy BUG-1 (blocking send inside the 33 ms loop). |
|
||||
| Network worker (1 daemon thread) | Cycle: wait until `queue_remaining·dt ≤ buffer_time_s` and active → snapshot `idx_before`, prefixes, `delay_steps = ceil(L_max/dt)` → encode (JPEG q=`jpeg_quality`) → `publisher.put(obs, attachment=header)` → await chunk on the action subscriber channel (timeout `request_timeout_s`) → `merge(original, processed, ceil(L/dt), idx_before)` → `latency_tracker.add(L)`. Owns the state machine, reconnects, and control queries. One-in-flight (P5). |
|
||||
| Zenoh action subscriber | `FifoChannel(2)` handler drained by the worker (no Python callback thread on the hot path); liveliness subscriber callback is deposit-only (sets an event). |
|
||||
|
||||
Reused unchanged: `ActionQueue` (`policies/rtc/action_queue.py`), `LatencyTracker`, `ActionInterpolator` (lives in strategies — `interpolation_multiplier` works with remote for free). Deleted concepts: aggregation zoo, `observations_similar`, `must_go`, `TimedObservation`/`TimedAction` pickles.
|
||||
|
||||
### 9.2 Fail-safe state machine
|
||||
|
||||
```
|
||||
ok no chunk for degraded_after_s
|
||||
CONNECTING ─────► STREAMING ───────────────────────────────► DEGRADED
|
||||
│ ▲ ▲ │ queue empty OR max_action_age_s hit │
|
||||
│ │ backoff, │ └───────────────────────────────────► STALLED ◄──┘
|
||||
│ │ re-handshake │ first successful merge │
|
||||
│ └─ RECONNECTING ◄── timeout streak / server liveliness drop ◄─┘
|
||||
│ │ offline > max_offline_s, capability/schema mismatch, auth failure
|
||||
└──────► DEAD (failed=True → shutdown_event → strategy teardown: return-to-initial-pose)
|
||||
```
|
||||
|
||||
- **DEGRADED**: requests failing but the queue still holds actions — the robot keeps executing; chunks _are_ the fault-tolerance buffer (1–3 s of coverage makes blips and clean server drains invisible).
|
||||
- **STALLED**: queue empty or staleness bound hit → apply `fallback`: `hold` (`get_action` → `None`; `send_next_action` already tolerates it), `repeat_last`, or `zero` (required for velocity-controlled robots, where "send nothing" means "keep last velocity").
|
||||
- **Staleness bound** (sync safety): every merge records `(chunk_start_index, t_send)`; `get_action` refuses any action whose source observation is older than `max_action_age_s` (default 3.0 s ≈ 90 steps @ 30 fps). Bounds open-loop execution after a network stall.
|
||||
- **DEAD**: only after `max_offline_s` (default 60 s) or a hard contract violation (capability/schema mismatch on reconnect — e.g. the server restarted with a different model; never execute wrong-model chunks). Uses the exact mechanism RTC uses (`failed=True` + global `shutdown_event`) so existing teardown runs unchanged.
|
||||
- **Watchdog layering**: per-request timeout (hung server — the BUG-3 fix) → server liveliness token (dead server/router) → staleness bound (the robot-side invariant that holds regardless of why data stopped).
|
||||
- **Pause/resume (DAgger)**: `pause()` stops the worker publishing (slot keeps refreshing, ignored); queue intact — parity with `RTCInferenceEngine.pause`. DAgger's existing `interpolator.reset(); engine.reset(); engine.resume()` sequence works unchanged.
|
||||
- **`reset()` (episode boundary)**: clear `ActionQueue` + staleness bookkeeping, bump `episode_id`, fire the acked `reset` query (1 s timeout, failure logged — the server has nothing it _must_ do thanks to per-request statelessness), flag `episode_start` on the next observation. `LatencyTracker` intentionally survives reset (latency is episode-invariant; parity with local RTC).
|
||||
- **`ready`** = session opened ∧ capabilities validated ∧ server `warmed_up`. First-chunk gating is implicit (`get_action` → `None` until the first merge).
|
||||
|
||||
### 9.3 Weightless client — exact integration changes
|
||||
|
||||
- `rollout/context.py`: `PolicyContext.{policy, preprocessor, postprocessor}` become `| None`. For remote configs, skip step 1 (weight load / PEFT / `.to(device)` / torch.compile / `init_rtc_processor`) and step 6 (`make_pre_post_processors`). Verified safe: strategies only consume `ctx.policy.inference`. Keep steps 2–5 (robot processors, hardware, features, dataset) — they are robot-derived. Keep the visual pre-flight check (`context.py:309-324`): `--policy.path` already loads config-only (`rollout/configs.py:324-328`, no weight download) and failing before dialing the server is free. `use_torch_compile` / explicit `--device` → warn-and-ignore for remote.
|
||||
- `rollout/inference/factory.py`: signature loosens to `policy: PreTrainedPolicy | None` (+ `policy_config: PreTrainedConfig`); `sync`/`rtc` branches guard `policy is None`; the `remote` branch lazy-imports (`eclipse-zenoh` stays an optional extra).
|
||||
- The authoritative validation moves to session open (§8.4); the local check becomes a fast-fail convenience.
|
||||
|
||||
### 9.4 Config
|
||||
|
||||
```python
|
||||
@InferenceEngineConfig.register_subclass("remote")
|
||||
@dataclass
|
||||
class RemoteInferenceConfig(InferenceEngineConfig):
|
||||
connect_endpoint: str = "tls/localhost:7447" # zenoh router endpoint
|
||||
tls_cert: str | None = None; tls_key: str | None = None; tls_ca: str | None = None
|
||||
client_uuid: str = "" # "" → uuid4 at start()
|
||||
jpeg_quality: int = 90 # 0 = raw (LAN/debug)
|
||||
buffer_time_s: float = 0.5 # send next obs when queue playback ≤ this (v1 G14) — KEPT
|
||||
max_action_age_s: float = 3.0 # staleness bound (safety)
|
||||
degraded_after_s: float = 1.0
|
||||
request_timeout_s: float = 5.0
|
||||
reconnect_initial_backoff_s: float = 0.5
|
||||
reconnect_max_backoff_s: float = 10.0
|
||||
max_offline_s: float = 60.0
|
||||
fallback: FallbackBehavior = FallbackBehavior.HOLD # hold | repeat_last | zero
|
||||
rtc: RTCConfig = field(default_factory=RTCConfig) # enabled → replace mode; horizon caps prefix
|
||||
tags: dict[str, str] = field(default_factory=dict) # ex-cluster/experiment labels
|
||||
```
|
||||
|
||||
```bash
|
||||
# Remote RTC + sentry recording (the reproducibility path)
|
||||
lerobot-rollout \
|
||||
--strategy.type=sentry \
|
||||
--policy.path=lerobot/pi0_towels \ # config-only: no weights downloaded
|
||||
--inference.type=remote \
|
||||
--inference.connect_endpoint=tls/router.gpu-cluster.internal:7447 \
|
||||
--inference.rtc.execution_horizon=10 \
|
||||
--robot.type=so100_follower --robot.port=/dev/ttyACM0 \
|
||||
--robot.cameras="{front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--dataset.repo_id=user/rollout_fleet_a --dataset.single_task="fold the towel"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 10. Wire Schema
|
||||
|
||||
### 10.1 Payload anatomy & rates — **KEPT** (JPEG) with numbers
|
||||
|
||||
Upstream per request: joints (24–128 B) + JPEG frames (480p q90 ≈ 40–90 KB each; 720p ≈ 110–230 KB) + RTC prefixes (≤ a few KB) → 60–450 KB depending on cameras. Downstream: `2 × chunk_size × action_dim × 4 B` + metadata → 3–50 KB. Effective request rate is self-clocked by `buffer_time_s` to ~1–4 Hz per robot (not the 30 Hz control rate). 300 robots ≈ 0.3–10 Mbps each — the wire is never the bottleneck; bandwidth budgeting is about camera count/resolution, and each GPU pod only ever sees its own ≤ `max_sessions` clients. Zenoh fragments >64 KiB payloads transparently; multi-MB messages are fine.
|
||||
|
||||
### 10.2 Attachment header (fixed-layout, packed little-endian — parsed without touching the body)
|
||||
|
||||
| Field | Type | Notes |
|
||||
| ---------------- | ---- | -------------------------------------------------------------- |
|
||||
| `schema_version` | u16 | negotiated at session open |
|
||||
| `msg_type` | u8 | OBS / CHUNK / EVENT |
|
||||
| `seq_id` | u64 | per-session monotonic; echoed in the chunk |
|
||||
| `episode_id` | u32 | bumped by `reset()` |
|
||||
| `client_mono_ns` | i64 | client `monotonic_ns()`; **opaque to the server, echoed back** |
|
||||
| `session_epoch` | u32 | bumped per (re)connect; stale-epoch chunks dropped |
|
||||
|
||||
### 10.3 msgpack bodies
|
||||
|
||||
**ObservationMsg** (client → server): `state: {names_ref, data: f32 LE bytes}`, `images: {name: {codec: jpeg|raw, bytes, (h,w,c) if raw}}`, `task: str`, `inference_delay_steps: int`, `prefix_model: tensor?`, `prefix_robot: tensor?` (tensors = raw LE bytes + dtype + shape), `episode_start: bool`.
|
||||
**ActionChunkMsg** (server → client): `seq_id_echo`, `client_mono_ns_echo`, `chunk_model: tensor`, `chunk_robot: tensor`, `queue_wait_ms: f32`, `inference_ms: f32`, `superseded_seqs: u32`, `server_load: f32`.
|
||||
**Status / SessionOpen / SessionAck / ResetMsg**: as specified in §8.4.
|
||||
|
||||
### 10.4 Schema discipline (P7)
|
||||
|
||||
`schema_version` gates at handshake; evolution is additive-only (new optional msgpack keys; unknown keys ignored); attachment layout changes require a version bump; golden codec round-trip tests (tensor exactness, JPEG RGB-channel-order regression — a silent BGR swap poisons every VLA in the fleet) are part of the test suite. **No pickle anywhere** — KEPT from v1 and now structural: nothing in the schema can carry code.
|
||||
|
||||
---
|
||||
|
||||
## 11. Latency Budget & the Clock Iron Rule
|
||||
|
||||
| Stage | LAN | WAN (50 ms RTT) |
|
||||
| ------------------------------ | --------------- | --------------- |
|
||||
| JPEG encode ×3 (edge CPU) | 2–9 ms | 2–9 ms |
|
||||
| Serialize | <1 ms | <1 ms |
|
||||
| Uplink (tx + ½RTT) | ~2 ms | ~54 ms |
|
||||
| Server queue wait | 0 → 1×inference | 0 → 1×inference |
|
||||
| Decode + canonical preprocess | 4–10 ms | 4–10 ms |
|
||||
| **Inference** | **15–150 ms** | **15–150 ms** |
|
||||
| Postprocess + downlink + merge | ~2 ms | ~27 ms |
|
||||
| **Total (Pi0-class)** | **~110–175 ms** | **~190–250 ms** |
|
||||
|
||||
Inference is 60–85 % of end-to-end on LAN; the entire transport+serialization stack is <10 ms. WAN adds propagation + uplink bandwidth — identical under any transport. At 30 fps this lands `delay_steps` ≈ 4–8, comfortably inside RTC execution horizons: WAN degrades smoothness parameters, never correctness. _This table is the standing answer to transport-performance bikeshedding._
|
||||
|
||||
**Clock iron rule** (P4): wall-clock instants never cross machines. Client stamps `monotonic_ns`, the server echoes it opaquely; `RTT = now − echo`. The server reports only **durations** (`queue_wait_ms`, `inference_ms`) measured on its own monotonic clock; `network_time = RTT − queue_wait − inference` for diagnostics. The schema has no field in which a foreign wall-clock instant can be compared — the legacy `time.time()` bug is unrepresentable.
|
||||
|
||||
---
|
||||
|
||||
## 12. Reproducibility & Audit (P8)
|
||||
|
||||
The contract is **fully logged + replayable**, not "deterministic":
|
||||
|
||||
- **Client = source of truth.** Recording strategies already persist observations + executed actions to `LeRobotDataset`. The remote engine logs, per executed action, the `(session_id, seq_id, episode_id)` of its source chunk plus the echoed `queue_wait_ms`/`inference_ms` (dataset-extras columns are a follow-up; client logs in v1).
|
||||
- **Server audit line per request** (structured JSON): `{ts, session_id, client_uuid, seq_id, episode_id, queue_wait_ms, inference_ms, chunk_range, superseded_seqs, outcome}`.
|
||||
- **Optional bounded capture**: `debug.capture_dir` writes a ring of request/response pairs (safetensors) for byte-exact offline replay through the same server pipeline.
|
||||
- **Runbook — "robot #217 stuttered at 14:03"**: (1) Grafana `session_staleness{client="217"}` — spike ⇒ server side, flat ⇒ client/network. (2) Server side: audit lines — `queue_wait_ms` rising across _all_ sessions ⇒ overloaded replica (check `active_sessions` vs `max_sessions`); `superseded_seqs` streak on 217 only ⇒ that client over-requesting; `outcome=error` ⇒ adjacent stack trace. (3) Client side: state-machine transitions + reconnects in the client log; dataset rows show which seq's chunk was executing and where `None` ticks occurred. Every hop shares `(session_id, seq_id)` — the join is mechanical.
|
||||
|
||||
---
|
||||
|
||||
## 13. Integration & Migration Plan
|
||||
|
||||
### 13.1 New
|
||||
|
||||
| Path | Content |
|
||||
| --------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `src/lerobot/policy_server/{__init__,schema,codec,manifest,session,scheduler,validation,server}.py` | wire schema constants, msgpack/attachment codecs, manifest dataclasses, `Session` + mailbox, `Scheduler` seam, capability rules + chunk-stateless registry, zenoh servicer + inference worker + drain + HTTP health/metrics |
|
||||
| `src/lerobot/rollout/inference/remote.py` | `RemoteInferenceEngine` (~600 lines; mirrors `rtc.py` structure) |
|
||||
| `src/lerobot/scripts/lerobot_policy_server.py` + `[project.scripts]` entry | thin `main()` |
|
||||
| `docker/Dockerfile.policy-server` | CUDA runtime base + uv; manifest via ConfigMap |
|
||||
| `docs/source/remote_inference.mdx` (+ `_toctree.yml`) | replaces `async.mdx` |
|
||||
|
||||
### 13.2 Modified
|
||||
|
||||
`rollout/inference/factory.py` (config + Optional-typed signature + lazy import) · `rollout/context.py` (weightless branch) · `rollout/inference/__init__.py` · `scripts/lerobot_rollout.py` docstring · `pyproject.toml`: `[async]` extra becomes `eclipse-zenoh>=1.9,<2.0` + `msgpack` (grpcio/matplotlib leave it; grpcio remains under `[hilserl]`/`dev` for the RL stack).
|
||||
|
||||
### 13.3 Removed — same landing PR
|
||||
|
||||
`src/lerobot/async_inference/` · `tests/async_inference/` · `docs/source/async.mdx` + its `_toctree.yml` entry · the `AsyncInference` service + `Observation`/`Actions`/`PolicySetup` messages from `src/lerobot/transport/services.proto` (regenerate pb2; **`LearnerService` untouched** — `transport/` is shared with HIL-SERL (`src/lerobot/rl/`); the RL test suite gates this change).
|
||||
|
||||
### 13.4 Legacy config → successor mapping
|
||||
|
||||
| Legacy (`RobotClientConfig`/`PolicyServerConfig`) | Successor |
|
||||
| ------------------------------------------------- | ---------------------------------------------------------- |
|
||||
| `server_address` | `--inference.connect_endpoint` (zenoh router) |
|
||||
| `policy_type`, `pretrained_name_or_path` | `--policy.path` (config-only) + server manifest |
|
||||
| `chunk_size_threshold` (0–1 ratio) | `--inference.buffer_time_s` (seconds) |
|
||||
| `actions_per_chunk` | server manifest (validated at session open) |
|
||||
| `aggregate_fn_name` + `AGGREGATE_FUNCTIONS` | **dropped** — `ActionQueue` replace/append |
|
||||
| `policy_device`, `client_device` | **dropped** — server concern / chunks arrive CPU f32 |
|
||||
| `debug_visualize_queue_size` | **dropped** — Rerun (`--display_data`) + engine stats |
|
||||
| `PolicyServerConfig.{host,port}` | manifest `zenoh.connect_endpoints` |
|
||||
| `inference_latency`, `obs_queue_timeout` | **dropped** — latency client-measured; no server obs queue |
|
||||
| `SendPolicyInstructions` | **dropped** — MaaS manifest + session validation |
|
||||
| `observations_similar` / `must_go` | **dropped** — latest-only slots + client send gate |
|
||||
| pickle envelopes | **dropped** — msgpack + attachment headers |
|
||||
|
||||
### 13.5 Legacy bugs/gaps → structural resolution
|
||||
|
||||
BUG-1 → worker thread owns all I/O. BUG-2 → aggregation deleted; `ActionQueue` is internally locked. BUG-3 → per-request timeout + liveliness. BUG-4 → client-side send gating; server newest-wins. G1 → per-session registry. G2 → manifest. G4 → msgpack+attachments. G5 → monotonic echo + `delay_steps`. G7 → recording strategies. G8 → mTLS + ACL. G9 → server-side canonical processors. G11 → `status` queryable. G12 → Prometheus + audit logs. G13 → `lerobot-policy-server` console script. G14 → `buffer_time_s`.
|
||||
|
||||
### 13.6 Tests
|
||||
|
||||
- **Unit**: codec round-trips (tensor exact; JPEG RGB-order regression), capability-validation matrix (§8.4 as parametrized cases), scheduler fairness + newest-wins supersession (mock policy with configurable sleep), manifest parsing, key-expr sanitization.
|
||||
- **Loopback integration** (CPU, fast CI): client+server in one process over zenoh peer-to-peer (or a localhost `zenohd` started by the fixture), tiny-ACT, fake 2-camera robot, N=8 concurrent sessions. The headline regression: two sessions with different joint states must not cross-contaminate `RelativeActionsProcessorStep` postprocessing — the test that proves the multi-tenancy claim.
|
||||
- **Chaos**: kill the server mid-episode → client returns `None`, never raises into the control loop, `failed` stays False within `max_offline_s`, resumes on restart; `docker kill zenohd` → liveliness flap → safe state → re-handshake (explicitly tests re-declaration behavior, flagged unverified upstream); SIGTERM drain → in-flight chunk completes, clients reconnect invisibly.
|
||||
- **Golden parity**: remote RTC vs local `RTCInferenceEngine` on identical observation sequences → byte-identical merged queues (the re-anchoring contract test). Gate for any real-robot remote-RTC use.
|
||||
|
||||
---
|
||||
|
||||
## 14. Roadmap
|
||||
|
||||
1. **PR1 — schema & codecs** (no torch deps): `policy_server/{schema,codec,manifest}.py`, key-expr sanitizer, golden codec tests.
|
||||
2. **PR2 — server core**: session registry, scheduler, validation/allowlist, inference worker with mock policy, loopback harness.
|
||||
3. **PR3 — client engine**: `RemoteInferenceEngine`, factory/context weightless integration, loopback integration + chaos + golden-parity tests.
|
||||
4. **PR4 — ops & docs**: Dockerfile, health/metrics, drain, ACL examples, `remote_inference.mdx`, rollout docstring.
|
||||
5. **Landing PR — legacy deletion**: remove `async_inference/` + tests + docs + proto service (RL suite gates), `[async]` extra swap.
|
||||
6. **Pre-release field validation**: one real robot on a lossy network (watchdog default tuning); JPEG q90 vs raw A/B on one policy (train/serve shift).
|
||||
7. **Future**: micro-batching (needs per-sample `inference_delay` across policy families), client-side downscale-to-policy-resolution (config-only shapes make it possible), Advanced Pub/Sub on the action topic, per-robot quotas, dataset provenance columns, `supports_stateless_chunking` attribute upstreamed to policy classes.
|
||||
|
||||
---
|
||||
|
||||
## 15. Open Risks
|
||||
|
||||
| Risk | Mitigation / decision needed |
|
||||
| ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Re-anchoring parity (server-side relative-prefix re-anchor vs `rtc.py`) | Golden parity test (§13.6) is a hard gate before robot use; likely failure mode is normalizer dtype/device drift |
|
||||
| First-chunk over-trim when idle: `merge` trims `ceil(L/dt)` even when nothing was consumed (queue empty at episode start) — wasteful at network latencies (600 ms ⇒ 18 steps) | Proposed clamp `real_delay = min(real_delay, last_index - idx_before)` touches the shared `ActionQueue` used by local RTC — needs sign-off + regression tests |
|
||||
| JPEG train/serve distribution shift | Unmeasured; A/B before locking q90 default (roadmap §14.6) |
|
||||
| Watchdog defaults untuned (`request_timeout_s=5`, `degraded_after_s=1`, `max_action_age_s=3`) | Field validation on wired and Wi-Fi; consider named profiles |
|
||||
| Capability check can pass while semantics differ (different finetune, different normalization stats, identical feature names) | Add checkpoint hash/revision pinning to SessionAck — decide in PR2 |
|
||||
| zenoh-python long-session maturity: re-declaration after router restart partially verified; SHM unstable; no asyncio | Chaos tests own this; thread-based design avoids the asyncio gap entirely |
|
||||
| Router ACL reload requires restart | Operational runbook: cert/ACL changes = rolling router restart |
|
||||
| `fallback=zero` has no consumer until velocity actions land in rollout (only `.pos` features routed today) | Validate the enum against robot capabilities when velocity support lands |
|
||||
| Per-client mailbox memory under fleet-scale wildcard subscription | One decoded-obs slot per client is small; add an LRU GC tied to liveliness drops |
|
||||
@@ -0,0 +1,82 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This Dockerfile builds a GPU inference pod for `lerobot-policy-server`
|
||||
# (remote inference over Zenoh). It starts from an NVIDIA CUDA base image;
|
||||
# the cu128 PyTorch wheels bundle their own CUDA runtime (driver floor 570.86,
|
||||
# see pyproject.toml [tool.uv]).
|
||||
|
||||
# docker build -f docker/Dockerfile.policy-server -t lerobot-policy-server .
|
||||
# docker run --gpus all -v ./server.yaml:/etc/lerobot/server.yaml lerobot-policy-server
|
||||
#
|
||||
# Extra policy-family dependencies (e.g. pi0/smolvla need transformers) can be
|
||||
# added at build time:
|
||||
# docker build -f docker/Dockerfile.policy-server \
|
||||
# --build-arg LEROBOT_EXTRAS="async pi0" -t lerobot-policy-server .
|
||||
|
||||
# Configure the base image (same CUDA family as Dockerfile.internal)
|
||||
ARG CUDA_VERSION=12.8.1
|
||||
ARG OS_VERSION=24.04
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
|
||||
|
||||
# Define Python version and lerobot extras arguments
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ARG LEROBOT_EXTRAS="async"
|
||||
|
||||
# Configure environment variables
|
||||
ENV DEBIAN_FRONTEND=noninteractive \
|
||||
PATH=/lerobot/.venv/bin:$PATH
|
||||
|
||||
# Install system dependencies and uv (as root).
|
||||
# Kept lean: no hardware/teleop libraries — this image only serves policies.
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
git curl ca-certificates libglib2.0-0 ffmpeg \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
|
||||
&& mv /root/.local/bin/uv /usr/local/bin/uv \
|
||||
&& useradd --create-home --shell /bin/bash user_lerobot \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create application directory and set permissions
|
||||
WORKDIR /lerobot
|
||||
RUN chown -R user_lerobot:user_lerobot /lerobot
|
||||
|
||||
# Switch to the non-root user
|
||||
USER user_lerobot
|
||||
|
||||
# Model checkpoints are cached under HF_HOME — mount it as a volume
|
||||
# (or a PVC in Kubernetes) so warm restarts skip the Hub download.
|
||||
ENV HOME=/home/user_lerobot \
|
||||
HF_HOME=/home/user_lerobot/.cache/huggingface \
|
||||
HF_LEROBOT_HOME=/home/user_lerobot/.cache/huggingface/lerobot \
|
||||
TORCH_HOME=/home/user_lerobot/.cache/torch \
|
||||
TRITON_CACHE_DIR=/home/user_lerobot/.cache/triton
|
||||
|
||||
# Create the virtual environment (Python provisioned by uv)
|
||||
RUN uv venv --python ${PYTHON_VERSION}
|
||||
|
||||
# Install lerobot from the build context with the async extra
|
||||
# (eclipse-zenoh + msgpack — see pyproject.toml [project.optional-dependencies])
|
||||
COPY --chown=user_lerobot:user_lerobot setup.py pyproject.toml uv.lock README.md MANIFEST.in ./
|
||||
COPY --chown=user_lerobot:user_lerobot src/ src/
|
||||
|
||||
RUN uv sync --locked --no-cache $(printf -- '--extra %s ' ${LEROBOT_EXTRAS})
|
||||
|
||||
# HTTP health + Prometheus metrics (manifest `health_port`, 0 disables)
|
||||
EXPOSE 9100
|
||||
|
||||
# The manifest is typically mounted as a ConfigMap (Kubernetes) or a bind
|
||||
# mount (docker run -v) at /etc/lerobot/server.yaml; any field can also be
|
||||
# overridden on the command line, e.g. --model.repo_or_path=lerobot/pi0_towels
|
||||
ENTRYPOINT ["lerobot-policy-server"]
|
||||
CMD ["--manifest", "/etc/lerobot/server.yaml"]
|
||||
@@ -45,8 +45,6 @@
|
||||
title: Language Columns and Recipes
|
||||
- local: tools
|
||||
title: Tools
|
||||
- local: annotation_pipeline
|
||||
title: Annotation Pipeline
|
||||
- local: video_encoding_parameters
|
||||
title: Video encoding parameters
|
||||
- local: streaming_video_encoding
|
||||
@@ -89,8 +87,8 @@
|
||||
- sections:
|
||||
- local: inference
|
||||
title: Policy Deployment (lerobot-rollout)
|
||||
- local: async
|
||||
title: Use Async Inference
|
||||
- local: remote_inference
|
||||
title: Remote Inference (lerobot-policy-server)
|
||||
- local: rtc
|
||||
title: Real-Time Chunking (RTC)
|
||||
title: "Inference"
|
||||
|
||||
@@ -1,291 +0,0 @@
|
||||
# Annotation Pipeline
|
||||
|
||||
`lerobot-annotate` watches each episode's video with a vision-language
|
||||
model (VLM) and writes natural-language annotations back into your
|
||||
dataset. It fills the two language columns from the
|
||||
[Language Columns and Recipes](./language_and_recipes) page —
|
||||
`language_persistent` and `language_events` — straight into
|
||||
`data/chunk-*/file-*.parquet`.
|
||||
|
||||
In short: point it at a LeRobot dataset, and it adds subtasks, plans,
|
||||
memory, interjections, speech, and visual Q&A that a policy can be
|
||||
trained on.
|
||||
|
||||
## How it fits together
|
||||
|
||||
```text
|
||||
your dataset lerobot-annotate
|
||||
(LeRobot v3.1)
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────┐
|
||||
│ read episodes │
|
||||
└──────────────────────────┬──────────────────────────┘
|
||||
│
|
||||
┌────────────────────┼────────────────────┐
|
||||
▼ ▼ ▼
|
||||
┌──────────┐ ┌───────────────┐ ┌──────────┐ one shared Qwen-VL
|
||||
│ plan │ │ interjections │ │ vqa │ ◀── server (vLLM, OpenAI
|
||||
└────┬─────┘ └───────┬───────┘ └────┬─────┘ API) drives all three
|
||||
└────────────────────┼─────────────────────┘
|
||||
│ each module stages raw JSONL
|
||||
▼ into .annotate_staging/
|
||||
┌─────────────────┐
|
||||
│ validator │ ◀── checks everything
|
||||
└────────┬────────┘
|
||||
▼
|
||||
┌─────────────────┐
|
||||
│ writer │
|
||||
└────────┬────────┘
|
||||
▼
|
||||
data/chunk-*/file-*.parquet
|
||||
(+ meta/info.json tools)
|
||||
```
|
||||
|
||||
Three modules (`plan`, `interjections`, `vqa`) all talk to **one** shared
|
||||
VLM. Each module stages its output to disk, a validator checks it, and a
|
||||
single writer rewrites the dataset shards in place.
|
||||
|
||||
## What the pipeline produces
|
||||
|
||||
Each module emits a few kinds of annotation ("styles"), routed to one of
|
||||
the two language columns:
|
||||
|
||||
| Style / atom | Column | Module |
|
||||
| ------------------------------------------- | --------------------- | --------------- |
|
||||
| `subtask` (Pi0.7-style "how, not what") | `language_persistent` | `plan` |
|
||||
| `plan` (initial + refresh on interjection) | `language_persistent` | `plan` |
|
||||
| `memory` (MEM-style compression) | `language_persistent` | `plan` |
|
||||
| `task_aug` (rephrasings of the task) | `language_persistent` | `plan` |
|
||||
| `interjection` | `language_events` | `interjections` |
|
||||
| speech tool-call atom (`style=null`, `say`) | `language_events` | `interjections` |
|
||||
| `vqa` (user / assistant pair) | `language_events` | `vqa` |
|
||||
|
||||
### How subtasks are generated
|
||||
|
||||
The `plan` module doesn't ask the VLM for subtasks in one shot. Instead
|
||||
it uses a two-step **describe → segment** flow:
|
||||
|
||||
1. **Describe** — the VLM narrates only what it actually sees in the
|
||||
chosen camera (no guessing about the task).
|
||||
2. **Segment** — that description is fed back in, and the VLM splits the
|
||||
episode into consecutive atomic subtasks.
|
||||
|
||||
Both passes see the episode as **timestamped contact sheets** — frames
|
||||
sampled at `frames_per_second` (0.5s by default) and packed into JPEG
|
||||
grids with each frame's time burned into its corner, so the VLM cites
|
||||
exact boundary times directly. This is far cheaper in vision tokens than
|
||||
one image per frame, so the sampling can stay dense; episodes longer than
|
||||
`max_frames_per_prompt` are split into windows at the same density and
|
||||
merged. Both prompts also carry a causal **event-boundary** definition (a
|
||||
new event starts when an object becomes held / is released / reaches a new
|
||||
location / a lid changes state / contents move) to sharpen where cuts land.
|
||||
|
||||
The resulting spans are then stitched into a gap-free, full-episode
|
||||
cover, so **every frame has exactly one active subtask**. See
|
||||
[`run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py)
|
||||
for the production settings (single camera, timestamped contact sheets,
|
||||
auto-windowed subtask generation).
|
||||
|
||||
### Tools
|
||||
|
||||
The writer does **not** add a `tools` column to the parquet. The tool
|
||||
catalog lives in `meta/info.json["tools"]` instead (see [Tools](./tools)).
|
||||
After every run, the pipeline makes sure the canonical `say` schema is in
|
||||
that list, keeping any tools you declared beforehand.
|
||||
|
||||
Want to add your own tool? Edit `meta/info.json["tools"]` directly — the
|
||||
pipeline preserves whatever is already there. That makes the tool visible
|
||||
to the chat template, so the model can learn to _generate_ the call. The
|
||||
runtime layer that actually _executes_ a generated call (the `Tool`
|
||||
protocol / `TOOL_REGISTRY` under `src/lerobot/tools/`) is not part of
|
||||
this PR — the [Tools](./tools) doc marks those pieces as
|
||||
not-yet-implemented.
|
||||
|
||||
## Running on Hugging Face Jobs
|
||||
|
||||
Annotation runs on [Hugging Face Jobs](https://huggingface.co/docs/hub/en/jobs).
|
||||
The repo ships a launcher script you copy and tweak for your dataset:
|
||||
|
||||
```bash
|
||||
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
|
||||
```
|
||||
|
||||
[`run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py)
|
||||
starts a single-GPU `h200` job (bump it to `h200x4` for big datasets)
|
||||
that:
|
||||
|
||||
1. installs `lerobot` (from `main`) plus the annotation extras,
|
||||
2. boots one vLLM server per GPU (using the `vllm/vllm-openai` image) and
|
||||
drives it over the OpenAI-compatible API,
|
||||
3. runs the `plan` / `interjections` / `vqa` modules across the dataset
|
||||
with `lerobot-annotate`,
|
||||
4. with `--push_to_hub=true`, uploads the result to `--new_repo_id` (or
|
||||
back to `--repo_id` in place if you leave that unset).
|
||||
|
||||
To use a different dataset, model, or hub repo, edit the `CMD` block in
|
||||
the script. Every flag there maps directly to a `lerobot-annotate` flag
|
||||
(run `lerobot-annotate --help` for the full list).
|
||||
|
||||
## Key options
|
||||
|
||||
These are the flags you'll reach for most often. Run
|
||||
`lerobot-annotate --help` for everything else; the defaults are tuned for
|
||||
short manipulation episodes.
|
||||
|
||||
### Dataset in / out
|
||||
|
||||
| Flag | Default | What it does |
|
||||
| ----------------- | ------- | ----------------------------------------------------------------------- |
|
||||
| `--repo_id` | — | Hub dataset to annotate (downloaded if `--root` unset). |
|
||||
| `--root` | — | Annotate a local dataset directory instead. |
|
||||
| `--new_repo_id` | — | Push the result to a new repo (leaves the source repo untouched). |
|
||||
| `--push_to_hub` | `false` | Upload after annotating (to `--new_repo_id`, else back to `--repo_id`). |
|
||||
| `--only_episodes` | all | Annotate just these episode indices (handy for a test run). |
|
||||
| `--seed` | `1729` | Seeds the RNGs that pick interjection timestamps + VQA question types. |
|
||||
|
||||
### Which modules run
|
||||
|
||||
Every module is on by default and can be toggled independently (set to
|
||||
`false` to skip it, e.g. to iterate on one module at a time):
|
||||
|
||||
| Flag | Default | Turns off |
|
||||
| ------------------------- | ------- | ----------------------------------- |
|
||||
| `--plan.enabled` | `true` | subtasks + plan + memory + task_aug |
|
||||
| `--interjections.enabled` | `true` | interjections + speech atoms |
|
||||
| `--vqa.enabled` | `true` | the VQA pairs |
|
||||
|
||||
### The VLM (`--vlm.*`)
|
||||
|
||||
| Flag | Default | What it does |
|
||||
| -------------------------- | ------------------ | ----------------------------------------------------------------------------------- |
|
||||
| `--vlm.model_id` | `Qwen/Qwen3.6-27B` | The model to serve and prompt. |
|
||||
| `--vlm.camera_key` | first `images.*` | Which camera every prompt is grounded on. |
|
||||
| `--vlm.serve_command` | auto | The exact `vllm serve …` command (set TP size, GPU memory, `--max-model-len` here). |
|
||||
| `--vlm.parallel_servers` | `1` | Independent servers for round-robin routing (one per GPU). |
|
||||
| `--vlm.num_gpus` | `0` | GPUs per server (`0` = one each). |
|
||||
| `--vlm.client_concurrency` | `16` | In-flight requests across all servers. |
|
||||
| `--vlm.max_new_tokens` | `512` | Generation cap per call. |
|
||||
| `--vlm.temperature` | `0.2` | Sampling temperature. |
|
||||
|
||||
### Subtasks / plan / memory (`--plan.*`)
|
||||
|
||||
| Flag | Default | What it does |
|
||||
| ------------------------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `--plan.frames_per_second` | `2.0` | Frame sampling rate for the contact sheets (`2.0` = one frame every 0.5s). |
|
||||
| `--plan.max_frames_per_prompt` | `60` | Frame budget per VLM call. Episodes whose sampling exceeds this are auto-windowed at the same density, then stitched. |
|
||||
| `--plan.contact_sheet_columns` | `5` | Columns per contact-sheet grid (`contact_sheet_frames_per_sheet` tiles, time row-major). |
|
||||
| `--plan.plan_max_steps` | `8` | Upper bound on subtasks per episode. |
|
||||
| `--plan.subtask_describe_first` | `true` | Run the describe→segment grounding pass (best subtask quality; +1 call/episode). |
|
||||
| `--plan.emit_plan` | `true` | Emit the numbered `plan` rows (`false` = subtasks + memory only). |
|
||||
| `--plan.emit_memory` | `true` | Emit the `memory` rows (`false` = subtasks + plan only); symmetric to `emit_plan`. |
|
||||
| `--plan.n_task_rephrasings` | `10` | How many `task_aug` rephrasings to emit (`0` disables). |
|
||||
| `--plan.derive_task_from_video` | `if_short` | Use the dataset task as-is (`off`), only when it's missing/short (`if_short`), or always re-derive from video (`always`). |
|
||||
|
||||
### Interjections + VQA
|
||||
|
||||
| Flag | Default | What it does |
|
||||
| ----------------------------------------------- | ------- | ---------------------------------------------------------- |
|
||||
| `--interjections.max_interjections_per_episode` | `3` | Cap on interjection/speech pairs per episode. |
|
||||
| `--vqa.vqa_emission_hz` | `1.0` | How often VQA pairs are emitted. |
|
||||
| `--vqa.restrict_to_default_camera` | `false` | Ground VQA only on `--vlm.camera_key` (else every camera). |
|
||||
| `--executor.episode_parallelism` | `16` | Episodes processed concurrently within each phase. |
|
||||
|
||||
## Contributing new modules
|
||||
|
||||
The pipeline is built to grow, and **contributions are very welcome** —
|
||||
a brand-new module (say, trajectory traces or affordances), a new prompt
|
||||
template, a smarter grounding flow, or quality fixes to the existing
|
||||
`plan` / `interjections` / `vqa` modules.
|
||||
|
||||
Every module lives under
|
||||
`src/lerobot/annotations/steerable_pipeline/modules/`, shares the VLM
|
||||
client and the keyframe cache, writes its raw output to the staging
|
||||
tree, and plugs into the executor as its own phase. Got an idea? Open an
|
||||
issue or PR on [the repo](https://github.com/huggingface/lerobot).
|
||||
|
||||
## How recipes consume the output
|
||||
|
||||
The annotations are meant to be read by recipes (see
|
||||
[Language Columns and Recipes](./language_and_recipes)). Typically:
|
||||
|
||||
- low-level / high-level / memory-update branches read
|
||||
`subtask` / `plan` / `memory` from `language_persistent`.
|
||||
- an interjection-response branch reads `interjection` events plus the
|
||||
paired speech atom (merged into one assistant turn via `tool_calls_from`)
|
||||
and the matching `plan` refresh at the same timestamp.
|
||||
- a VQA branch reads the `(vqa, user)` and `(vqa, assistant)` pairs from
|
||||
`language_events`.
|
||||
|
||||
## Why state and events are split
|
||||
|
||||
Two ideas shape the design:
|
||||
|
||||
1. **Persistent state vs. exact events.** Persistent rows (`subtask`,
|
||||
`plan`, `memory`) apply to the whole episode and answer "what's true
|
||||
right now?". Event rows (`interjection`, `vqa`, speech) appear only on
|
||||
the one frame whose timestamp matches. Timestamps are copied straight
|
||||
from the source parquet — never recomputed in floating point.
|
||||
2. **One VLM pass.** All three modules share a single VLM client (the
|
||||
OpenAI-compatible client talking to the job's vLLM server), so you pay
|
||||
for one model load per dataset, not three.
|
||||
|
||||
## Re-running a single module
|
||||
|
||||
Each module stages its raw output to
|
||||
`<root>/.annotate_staging/episode_{N:06d}/<module>.jsonl`. This makes
|
||||
prompt iteration cheap: re-running one module overwrites only its own
|
||||
JSONL, then the writer recomposes the final parquet. Disable modules you
|
||||
don't want with `--plan.enabled=false` (and likewise
|
||||
`--interjections.enabled` / `--vqa.enabled`) to test one at a time.
|
||||
|
||||
## What the validator checks
|
||||
|
||||
Before the writer runs, `StagingValidator` confirms:
|
||||
|
||||
- every event row lands exactly on a real frame timestamp;
|
||||
- no speech / interjection pairs are left orphaned;
|
||||
- `plan` is refreshed at every interjection timestamp;
|
||||
- `memory` rows fall on subtask boundaries (a warning, not an error);
|
||||
- each VQA assistant `content` is valid JSON in one of the
|
||||
bbox / keypoint / count / attribute / spatial shapes;
|
||||
- every row goes to the column chosen by `column_for_style(style)`.
|
||||
|
||||
Any error aborts the writer. Pass `--skip_validation=true` to override
|
||||
while debugging.
|
||||
|
||||
## Where each module's ideas come from
|
||||
|
||||
- **`plan` — subtasks.** Hi Robot ([Shi 2025](https://arxiv.org/abs/2502.19417))
|
||||
for atom granularity ("pick up one piece of lettuce", "place bowl to
|
||||
box"); Pi0.7 ([Physical Intelligence 2025](https://pi.website/pi07))
|
||||
for "how, not what" detail.
|
||||
- **`plan` — memory.** MEM ([Torne 2026](https://arxiv.org/abs/2603.03596)):
|
||||
keep only the minimal relevant information — preserve outcomes, drop
|
||||
specific attributes.
|
||||
- **`interjections`.** Hi Robot's scenario taxonomy: negative task,
|
||||
situated correction, specific constraint, preference. Speech is a
|
||||
tool-call-only atom
|
||||
(`tool_calls=[{type:function, function:{name:"say", arguments:{text:...}}}]`).
|
||||
- **`vqa`.** ECoT ([Zawalski 2024](https://arxiv.org/abs/2407.08693)) for
|
||||
grounded features (pixel bounding boxes `[x_min, y_min, x_max, y_max]`,
|
||||
keypoints) and Steerable VLA Policies
|
||||
([Zhao 2025](https://arxiv.org/abs/2509.07626)) for multi-abstraction
|
||||
grounding. Pi0.7 also grounds answers across abstraction levels.
|
||||
|
||||
When improving a module, tweak its prompt template in
|
||||
`src/lerobot/annotations/steerable_pipeline/prompts/` rather than
|
||||
rewriting from scratch.
|
||||
|
||||
## Roughly how much it costs
|
||||
|
||||
Per episode, the pipeline makes about `max_steps` plan calls,
|
||||
`max_interjections_per_episode` interjection calls, and
|
||||
`vqa_emission_hz × episode_seconds` VQA calls. With the defaults (8
|
||||
subtasks, 1 interjection, 1 Hz × 3 pairs) on a 30-second episode, that's
|
||||
~50 VLM calls.
|
||||
|
||||
Storage stays small: `language_persistent` is at most tens of KB per
|
||||
episode (parquet dictionary-encodes the one entry that repeats across
|
||||
frames), and `language_events` is empty on most frames — its size scales
|
||||
with the number of emissions, not `num_frames × num_emissions`.
|
||||
@@ -1,313 +0,0 @@
|
||||
# Asynchronous Inference
|
||||
|
||||
With our [SmolVLA](https://huggingface.co/papers/2506.01844) we introduced a new way to run inference on real-world robots, **decoupling action prediction from action execution**.
|
||||
In this tutorial, we'll show how to use asynchronous inference (_async inference_) using a finetuned version of SmolVLA, and all the policies supported by LeRobot.
|
||||
**Try async inference with all the policies** supported by LeRobot!
|
||||
|
||||
**What you'll learn:**
|
||||
|
||||
1. Why asynchronous inference matters and how it compares to, more traditional, sequential inference.
|
||||
2. How to spin-up a `PolicyServer` and connect a `RobotClient` from the same machine, and even over the network.
|
||||
3. How to tune key parameters (`actions_per_chunk`, `chunk_size_threshold`) for your robot and policy.
|
||||
|
||||
If you get stuck, hop into our [Discord community](https://discord.gg/s3KuuzsPFb)!
|
||||
|
||||
In a nutshell: with _async inference_, your robot keeps acting while the policy server is already busy computing the next chunk of actions---eliminating "wait-for-inference" lags and unlocking smoother, more reactive behaviours.
|
||||
This is fundamentally different from synchronous inference (sync), where the robot stays idle while the policy computes the next chunk of actions.
|
||||
|
||||
---
|
||||
|
||||
## Getting started with async inference
|
||||
|
||||
You can read more information on asynchronous inference in our [blogpost](https://huggingface.co/blog/async-robot-inference). This guide is designed to help you quickly set up and run asynchronous inference in your environment.
|
||||
|
||||
First, install `lerobot` with the `async` tag, to install the extra dependencies required to run async inference.
|
||||
|
||||
```shell
|
||||
pip install -e ".[async]"
|
||||
```
|
||||
|
||||
Then, spin up a policy server (in one terminal, or in a separate machine) specifying the host address and port for the client to connect to.
|
||||
You can spin up a policy server running:
|
||||
|
||||
```shell
|
||||
python -m lerobot.async_inference.policy_server \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080
|
||||
```
|
||||
|
||||
This will start a policy server listening on `127.0.0.1:8080` (`localhost`, port 8080). At this stage, the policy server is empty, as all information related to which policy to run and with which parameters are specified during the first handshake with the client. Spin up a client with:
|
||||
|
||||
```shell
|
||||
python -m lerobot.async_inference.robot_client \
|
||||
--server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
|
||||
--robot.type=so100_follower \ # ROBOT: your robot type
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
|
||||
--robot.id=follower_so100 \ # ROBOT: your robot id, to load calibration file
|
||||
--robot.cameras="{ laptop: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}, phone: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ # POLICY: the cameras used to acquire frames, with keys matching the keys expected by the policy
|
||||
--task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
|
||||
--policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
|
||||
--pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
|
||||
--policy_device=mps \ # POLICY: the device to run the policy on, on the server (cuda, mps, xpu, cpu)
|
||||
--actions_per_chunk=50 \ # POLICY: the number of actions to output at once
|
||||
--chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
|
||||
--aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
|
||||
--debug_visualize_queue_size=True # CLIENT: whether to visualize the queue size at runtime
|
||||
```
|
||||
|
||||
In summary, you need to specify instructions for:
|
||||
|
||||
- `SERVER`: the address and port of the policy server
|
||||
- `ROBOT`: the type of robot to connect to, the port to connect to, and the local `id` of the robot
|
||||
- `POLICY`: the type of policy to run, and the model name/path on server to the checkpoint to run. You also need to specify which device should the sever be using, and how many actions to output at once (capped at the policy max actions value).
|
||||
- `CLIENT`: the threshold for the chunk size before sending a new observation to the server, and the function to aggregate actions on overlapping portions. Optionally, you can also visualize the queue size at runtime, to help you tune the `CLIENT` parameters.
|
||||
|
||||
Importantly,
|
||||
|
||||
- `actions_per_chunk` and `chunk_size_threshold` are key parameters to tune for your setup.
|
||||
- `aggregate_fn_name` is the function to aggregate actions on overlapping portions. You can either add a new one to a registry of functions, or add your own in `robot_client.py` (see [here](NOTE:addlinktoLOC))
|
||||
- `debug_visualize_queue_size` is a useful tool to tune the `CLIENT` parameters.
|
||||
|
||||
## Done! You should see your robot moving around by now 😉
|
||||
|
||||
## Async vs. synchronous inference
|
||||
|
||||
Synchronous inference relies on interleaving action chunk prediction and action execution. This inherently results in _idle frames_, frames where the robot awaits idle the policy's output: a new action chunk.
|
||||
In turn, inference is plagued by evident real-time lags, where the robot simply stops acting due to the lack of available actions.
|
||||
With robotics models increasing in size, this problem risks becoming only more severe.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/async-inference/sync.png"
|
||||
width="80%"
|
||||
></img>
|
||||
</p>
|
||||
<p align="center">
|
||||
<i>Synchronous inference</i> makes the robot idle while the policy is
|
||||
computing the next chunk of actions.
|
||||
</p>
|
||||
|
||||
To overcome this, we design async inference, a paradigm where action planning and execution are decoupled, resulting in (1) higher adaptability and, most importantly, (2) no idle frames.
|
||||
Crucially, with async inference, the next action chunk is computed _before_ the current one is exhausted, resulting in no idleness.
|
||||
Higher adaptability is ensured by aggregating the different action chunks on overlapping portions, obtaining an up-to-date plan and a tighter control loop.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/async-inference/async.png"
|
||||
width="80%"
|
||||
></img>
|
||||
</p>
|
||||
<p align="center">
|
||||
<i>Asynchronous inference</i> results in no idleness because the next chunk is
|
||||
computed before the current chunk is exhausted.
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
## Start the Policy Server
|
||||
|
||||
Policy servers are wrappers around a `PreTrainedPolicy` interfacing them with observations coming from a robot client.
|
||||
Policy servers are initialized as empty containers which are populated with the requested policy specified in the initial handshake between the robot client and the policy server.
|
||||
As such, spinning up a policy server is as easy as specifying the host address and port. If you're running the policy server on the same machine as the robot client, you can use `localhost` as the host address.
|
||||
|
||||
<hfoptions id="start_policy_server">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.async_inference.policy_server \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.async_inference.configs import PolicyServerConfig
|
||||
from lerobot.async_inference.policy_server import serve
|
||||
|
||||
config = PolicyServerConfig(
|
||||
host="localhost",
|
||||
port=8080,
|
||||
)
|
||||
serve(config)
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
This listens on `localhost:8080` for an incoming connection from the associated`RobotClient`, which will communicate which policy to run during the first client-server handshake.
|
||||
|
||||
---
|
||||
|
||||
## Launch the Robot Client
|
||||
|
||||
`RobotClient` is a wrapper around a `Robot` instance, which `RobotClient` connects to the (possibly remote) `PolicyServer`.
|
||||
The `RobotClient` streams observations to the `PolicyServer`, and receives action chunks obtained running inference on the server (which we assume to have better computational resources than the robot controller).
|
||||
|
||||
<hfoptions id="start_robot_client">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.async_inference.robot_client \
|
||||
--server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
|
||||
--robot.type=so100_follower \ # ROBOT: your robot type
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
|
||||
--robot.id=follower_so100 \ # ROBOT: your robot id, to load calibration file
|
||||
--robot.cameras="{ laptop: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}, phone: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ # POLICY: the cameras used to acquire frames, with keys matching the keys expected by the policy
|
||||
--task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
|
||||
--policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
|
||||
--pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
|
||||
--policy_device=mps \ # POLICY: the device to run the policy on, on the server
|
||||
--actions_per_chunk=50 \ # POLICY: the number of actions to output at once
|
||||
--chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
|
||||
--aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
|
||||
--debug_visualize_queue_size=True # CLIENT: whether to visualize the queue size at runtime
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
import threading
|
||||
from lerobot.robots.so_follower import SO100FollowerConfig
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.async_inference.configs import RobotClientConfig
|
||||
from lerobot.async_inference.robot_client import RobotClient
|
||||
from lerobot.async_inference.helpers import visualize_action_queue_size
|
||||
|
||||
# 1. Create the robot instance
|
||||
"""Check out the cameras available in your setup by running `python lerobot/find_cameras.py`"""
|
||||
# these cameras must match the ones expected by the policy
|
||||
# check the config.json on the Hub for the policy you are using
|
||||
camera_cfg = {
|
||||
"top": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem585A0076841",
|
||||
id="follower_so100",
|
||||
cameras=camera_cfg
|
||||
)
|
||||
|
||||
# 3. Create client configuration
|
||||
client_cfg = RobotClientConfig(
|
||||
robot=robot_cfg,
|
||||
server_address="localhost:8080",
|
||||
policy_device="mps",
|
||||
client_device="cpu",
|
||||
policy_type="smolvla",
|
||||
pretrained_name_or_path="<user>/smolvla_async",
|
||||
chunk_size_threshold=0.5,
|
||||
actions_per_chunk=50, # make sure this is less than the max actions of the policy
|
||||
)
|
||||
|
||||
# 4. Create and start client
|
||||
client = RobotClient(client_cfg)
|
||||
|
||||
# 5. Specify the task
|
||||
task = "Don't do anything, stay still"
|
||||
|
||||
if client.start():
|
||||
# Start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
action_receiver_thread.start()
|
||||
|
||||
try:
|
||||
# Run the control loop
|
||||
client.control_loop(task)
|
||||
except KeyboardInterrupt:
|
||||
client.stop()
|
||||
action_receiver_thread.join()
|
||||
# (Optionally) plot the action queue size
|
||||
visualize_action_queue_size(client.action_queue_size)
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
The following two parameters are key in every setup:
|
||||
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Hyperparameter</th>
|
||||
<th>Default</th>
|
||||
<th>What it does</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td>
|
||||
<code>actions_per_chunk</code>
|
||||
</td>
|
||||
<td>50</td>
|
||||
<td>
|
||||
How many actions the policy outputs at once. Typical values: 10-50.
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<code>chunk_size_threshold</code>
|
||||
</td>
|
||||
<td>0.7</td>
|
||||
<td>
|
||||
When the queue is ≤ 50% full, the client sends a fresh observation.
|
||||
Value in [0, 1].
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
<Tip>
|
||||
Different values of `actions_per_chunk` and `chunk_size_threshold` do result
|
||||
in different behaviours.
|
||||
</Tip>
|
||||
|
||||
On the one hand, increasing the value of `actions_per_chunk` will result in reducing the likelihood of ending up with no actions to execute, as more actions will be available when the new chunk is computed.
|
||||
However, larger values of `actions_per_chunk` might also result in less precise actions, due to the compounding errors consequent to predicting actions over longer timespans.
|
||||
|
||||
On the other hand, increasing the value of `chunk_size_threshold` will result in sending out to the `PolicyServer` observations for inference more often, resulting in a larger number of updates action chunks, overlapping on significant portions. This results in high adaptability, in the limit predicting one action chunk for each observation, which is in turn only marginally consumed while a new one is produced.
|
||||
This option does also put more pressure on the inference pipeline, as a consequence of the many requests. Conversely, values of `chunk_size_threshold` close to 0.0 collapse to the synchronous edge case, whereby new observations are only sent out whenever the current chunk is exhausted.
|
||||
|
||||
We found the default values of `actions_per_chunk` and `chunk_size_threshold` to work well in the experiments we developed for the [SmolVLA paper](https://huggingface.co/papers/2506.01844), but recommend experimenting with different values to find the best fit for your setup.
|
||||
|
||||
### Tuning async inference for your setup
|
||||
|
||||
1. **Choose your computational resources carefully.** [PI0](https://huggingface.co/lerobot/pi0) occupies 14GB of memory at inference time, while [SmolVLA](https://huggingface.co/lerobot/smolvla_base) requires only ~2GB. You should identify the best computational resource for your use case keeping in mind smaller policies require less computational resources. The combination of policy and device used (CPU-intensive, using MPS, or the number of CUDA cores on a given NVIDIA GPU) directly impacts the average inference latency you should expect.
|
||||
2. **Adjust your `fps` based on inference latency.** While the server generates a new action chunk, the client is not idle and is stepping through its current action queue. If the two processes happen at fundamentally different speeds, the client might end up with an empty queue. As such, you should reduce your fps if you consistently run out of actions in queue.
|
||||
3. **Adjust `chunk_size_threshold`**.
|
||||
- Values closer to `0.0` result in almost sequential behavior. Values closer to `1.0` → send observation every step (more bandwidth, relies on good world-model).
|
||||
- We found values around 0.5-0.6 to work well. If you want to tweak this, spin up a `RobotClient` setting the `--debug_visualize_queue_size` to `True`. This will plot the action queue size evolution at runtime, and you can use it to find the value of `chunk_size_threshold` that works best for your setup.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/async-inference/queues.png"
|
||||
width="80%"
|
||||
></img>
|
||||
</p>
|
||||
<p align="center">
|
||||
<i>
|
||||
The action queue size is plotted at runtime when the
|
||||
`--debug_visualize_queue_size` flag is passed, for various levels of
|
||||
`chunk_size_threshold` (`g` in the SmolVLA paper).
|
||||
</i>
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
## Conclusion
|
||||
|
||||
Asynchronous inference represents a significant advancement in real-time robotics control, addressing the fundamental challenge of inference latency that has long plagued robotics applications. Through this tutorial, you've learned how to implement a complete async inference pipeline that eliminates idle frames and enables smoother, more reactive robot behaviors.
|
||||
|
||||
**Key Takeaways:**
|
||||
|
||||
- **Paradigm Shift**: Async inference decouples action prediction from execution, allowing robots to continue acting while new action chunks are computed in parallel
|
||||
- **Performance Benefits**: Eliminates "wait-for-inference" lags that are inherent in synchronous approaches, becoming increasingly important as policy models grow larger
|
||||
- **Flexible Architecture**: The server-client design enables distributed computing, where inference can run on powerful remote hardware while maintaining real-time robot control
|
||||
- **Tunable Parameters**: Success depends on properly configuring `actions_per_chunk` and `chunk_size_threshold` for your specific hardware, policy, and task requirements
|
||||
- **Universal Compatibility**: Works with all LeRobot-supported policies, from lightweight ACT models to vision-language models like SmolVLA
|
||||
|
||||
Start experimenting with the default parameters, monitor your action queue sizes, and iteratively refine your setup to achieve optimal performance for your specific use case.
|
||||
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/huggingface/lerobot/issues).
|
||||
@@ -0,0 +1,250 @@
|
||||
# Remote Inference (lerobot-policy-server)
|
||||
|
||||
Remote inference decouples GPU policy inference from robot control. A `lerobot-policy-server` process runs the policy on a GPU machine; the robot runs `lerobot-rollout --inference.type=remote` as a **weightless edge client** — no policy weights, no GPU, no policy processors on the robot. One GPU server can serve several robots at once, and the remote backend works with every rollout strategy (`base`, `sentry`, `highlight`, `dagger`, `episodic`).
|
||||
|
||||
Use remote inference when:
|
||||
|
||||
- The policy is too large or too slow for the machine attached to the robot (e.g. Pi0/Pi0.5 on a Raspberry Pi or laptop edge).
|
||||
- You want one GPU to serve a fleet of robots running the same policy.
|
||||
- You want to update or restart the inference side without touching the robots.
|
||||
|
||||
<Tip>
|
||||
|
||||
Remote inference requires the `async` extra on **both** sides: `pip install 'lerobot[async]'` (installs `eclipse-zenoh` and `msgpack`). The server additionally needs the extras of the policy it serves (e.g. `lerobot[pi]`, `lerobot[smolvla]`).
|
||||
|
||||
</Tip>
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
robot (edge, weightless) GPU machine
|
||||
┌───────────────────────────┐ ┌────────────────────────────┐
|
||||
│ lerobot-rollout │ │ lerobot-policy-server │
|
||||
│ --inference.type=remote │ zenoh │ one process = one │
|
||||
│ │ router │ (model, revision, GPU) │
|
||||
│ control loop @ fps │ ┌────────┐ │ │
|
||||
│ └─ pops local action ◄──┼───┤ zenohd ├─────┼─► inference worker thread │
|
||||
│ buffer (chunks) │ └────────┘ │ (round-robin over │
|
||||
│ │ observations ► │ client sessions) │
|
||||
│ network worker thread ───┼──► ◄ action │ │
|
||||
│ (publishes obs, merges │ chunks │ stateless per request │
|
||||
│ chunks into buffer) │ │ │
|
||||
└───────────────────────────┘ └────────────────────────────┘
|
||||
```
|
||||
|
||||
The client keeps a local **action buffer** filled with chunks of future actions, so the control loop never blocks on the network: short network blips are absorbed by the buffer and the robot keeps moving. The client self-clocks — it requests a new chunk whenever the buffer holds less than `--inference.buffer_time_s` seconds of playback.
|
||||
|
||||
The server is **stateless per request**: clients ship their RTC prefixes and a delay hint with every observation, so a server crash or restart loses zero control state and reconnects are trivial. In production both robots and servers _dial out_ to a `zenohd` router (NAT-friendly: nothing on the robot network needs an open inbound port).
|
||||
|
||||
## Quickstart on a LAN (peer mode, no router)
|
||||
|
||||
For a quick test on one network you can skip the router: the server listens directly and the robot connects to it.
|
||||
|
||||
On the GPU machine:
|
||||
|
||||
```bash
|
||||
lerobot-policy-server \
|
||||
--model.repo_or_path=${HF_USER}/my_pi0_policy \
|
||||
--default_task="pick up the cube" \
|
||||
--zenoh.mode=peer \
|
||||
--zenoh.listen_endpoints='["tcp/0.0.0.0:7447"]'
|
||||
```
|
||||
|
||||
Wait for `Policy server up: ...` (the model is downloaded, loaded, and warmed up first).
|
||||
|
||||
On the robot machine (replace `192.168.1.42` with the GPU machine's IP):
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=${HF_USER}/my_pi0_policy \
|
||||
--inference.type=remote \
|
||||
--inference.zenoh_mode=peer \
|
||||
--inference.connect_endpoint=tcp/192.168.1.42:7447 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="pick up the cube" \
|
||||
--duration=60
|
||||
```
|
||||
|
||||
`--policy.path` on the client resolves to a config-only download (no weights): it is used for pre-flight validation and action ordering, and doubles as the default service address. The client's `--policy.path` and `--task` must match the server's `--model.repo_or_path` and `--default_task` — that pair is the namespace the service is published under (see [Troubleshooting](#troubleshooting)).
|
||||
|
||||
## Production deployment (router)
|
||||
|
||||
In production, run a [zenoh router](https://zenoh.io/docs/getting-started/installation/) (`zenohd`) somewhere both sides can reach, and have robots and servers dial out to it:
|
||||
|
||||
```bash
|
||||
zenohd # listens on tcp/0.0.0.0:7447 by default
|
||||
```
|
||||
|
||||
Configure the server with a YAML manifest:
|
||||
|
||||
```yaml
|
||||
# server.yaml
|
||||
model:
|
||||
repo_or_path: lerobot/pi0_towels
|
||||
revision: main
|
||||
dtype: bfloat16 # optional cast after load
|
||||
device: cuda
|
||||
default_task: "fold the towel"
|
||||
serving_mode: auto # shared for verified chunk-stateless policies, exclusive otherwise
|
||||
max_sessions: 5
|
||||
warmup_inferences: 2
|
||||
trained_fps: 30.0
|
||||
rtc:
|
||||
enabled: true
|
||||
execution_horizon: 10
|
||||
max_guidance_weight: 10.0
|
||||
health_port: 9100 # /healthz + /metrics; 0 disables
|
||||
zenoh:
|
||||
mode: client
|
||||
connect_endpoints: ["tcp/router.gpu-cluster.internal:7447"]
|
||||
```
|
||||
|
||||
```bash
|
||||
lerobot-policy-server --manifest server.yaml
|
||||
```
|
||||
|
||||
Everything in the manifest can also be set directly on the CLI (`--model.repo_or_path=...`, `--max_sessions=...`, etc.). One process serves exactly one `(model, revision, dtype, device)` — to serve two models, or one model on two GPUs, run two processes. Dynamic model loading is deliberately unsupported: pre-warmed processes keep capacity planning honest.
|
||||
|
||||
On the robot, only the endpoint changes (the default `--inference.zenoh_mode=client` is already router mode):
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=lerobot/pi0_towels \
|
||||
--inference.type=remote \
|
||||
--inference.connect_endpoint=tcp/router.gpu-cluster.internal:7447 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="fold the towel" \
|
||||
--duration=600
|
||||
```
|
||||
|
||||
### TLS / mTLS
|
||||
|
||||
For traffic that leaves a trusted network, terminate TLS at the router and give both sides client certificates (all three PEM paths are required together):
|
||||
|
||||
```yaml
|
||||
# server.yaml (zenoh section)
|
||||
zenoh:
|
||||
mode: client
|
||||
connect_endpoints: ["tls/router.gpu-cluster.internal:7447"]
|
||||
tls_root_ca_certificate: /etc/lerobot/ca.pem
|
||||
tls_connect_certificate: /etc/lerobot/server.pem
|
||||
tls_connect_private_key: /etc/lerobot/server.key
|
||||
```
|
||||
|
||||
On the robot the equivalent flags are `--inference.tls_ca`, `--inference.tls_cert`, and `--inference.tls_key`, with `--inference.connect_endpoint=tls/...`.
|
||||
|
||||
<Tip>
|
||||
|
||||
Multicast scouting is always disabled: discovery is configuration, not protocol magic. If nothing connects, check the endpoints — there is no fallback discovery mechanism.
|
||||
|
||||
</Tip>
|
||||
|
||||
## RTC over the network
|
||||
|
||||
The remote engine reuses the [Real-Time Chunking](./rtc) machinery: the client keeps the chunk leftover and latency tracking locally and ships an action prefix plus a delay hint with every observation; the server runs prefix-conditioned chunk generation. This gives the same smooth chunk-to-chunk transitions as local RTC, with network latency folded into the delay computation.
|
||||
|
||||
RTC is enabled by default on both sides (`rtc.enabled: true`). Tune it from the client:
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
... \
|
||||
--inference.type=remote \
|
||||
--inference.rtc.execution_horizon=10 \
|
||||
--inference.rtc.max_guidance_weight=10.0
|
||||
```
|
||||
|
||||
If the server or its policy does not support RTC (only `pi0`, `pi05`, and `smolvla` are RTC-capable, and the server manifest must have `rtc.enabled: true`), the session is **downgraded to plain chunk-append** and the client logs:
|
||||
|
||||
```
|
||||
RTC downgraded to chunk-append (server does not support RTC)
|
||||
```
|
||||
|
||||
The robot still runs — chunks are simply appended to the buffer without prefix blending, which can produce visible seams between chunks on slow policies.
|
||||
|
||||
## Fail-safe behavior
|
||||
|
||||
The client runs a fail-safe state machine (`CONNECTING → STREAMING → DEGRADED → STALLED → RECONNECTING → DEAD`). A bad initial deployment fails fast: `lerobot-rollout` aborts before the robot moves if the handshake or validation fails. Once streaming, faults degrade in stages:
|
||||
|
||||
| Condition | Behavior |
|
||||
| -------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Short network blip / late chunk | The robot rides its action buffer; state goes `DEGRADED` after `--inference.degraded_after_s` (default 1.0 s) without a fresh chunk |
|
||||
| Buffered actions older than `max_action_age_s` | Stale actions are dropped (never executed); default `--inference.max_action_age_s=3.0` |
|
||||
| Buffer runs dry (`STALLED`) | Fallback per `--inference.fallback`: `hold` (default — robot holds its last commanded position), `repeat_last`, or `zero` |
|
||||
| Server liveliness lost / repeated request timeouts | `RECONNECTING`: re-handshake with exponential backoff (`reconnect_initial_backoff_s=0.5` doubling up to `reconnect_max_backoff_s=10.0`) |
|
||||
| Reconnected server runs a different model/revision | Hard refusal (`DEAD`) — the client never executes wrong-model chunks |
|
||||
| Offline longer than `max_offline_s` (default 60 s) | `DEAD`: the engine signals the rollout's shutdown event for a clean stop |
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
`--inference.fallback=zero` is required for velocity-controlled robots: for them "send nothing" means "keep the last velocity", so an explicit zero command is the only safe stop. For position-controlled arms the default `hold` is safe.
|
||||
|
||||
</Tip>
|
||||
|
||||
Server restarts are equally graceful: on SIGTERM the server drops its liveliness token first (clients ride their buffers through the drain), finishes the in-flight inference, and exits. Clients reconnect when the replacement comes up.
|
||||
|
||||
## Serving multiple robots
|
||||
|
||||
`max_sessions` caps concurrent clients per server process. A single inference worker thread serializes GPU access and round-robins over sessions with a pending observation; per-client newest-wins mailboxes mean overload degrades into longer cycle times (larger but correct client-side delays), never into queue buildup.
|
||||
|
||||
A rough capacity estimate, keeping ~20% headroom:
|
||||
|
||||
```
|
||||
N_robots ≈ 0.8 / (rate × inference_time)
|
||||
```
|
||||
|
||||
where `rate` is each robot's chunk-request rate in Hz (how often the client's buffer dips below `buffer_time_s`) and `inference_time` is the server's seconds per chunk. For example, at 100 ms per chunk and ~2 chunk requests per second per robot: `N ≈ 0.8 / (2 × 0.1) = 4` robots.
|
||||
|
||||
The actual serving mode is classified per policy family, never inferred:
|
||||
|
||||
- **shared** — verified chunk-stateless policies (`act`, `pi0`, `pi05`, and `smolvla` with `n_obs_steps=1`) serve up to `max_sessions` clients from one policy instance.
|
||||
- **exclusive** — stateful families (diffusion-family policies, `smolvla` with observation history, and any unverified policy) are forced to `max_sessions=1`. Run one server process per robot for these.
|
||||
|
||||
`serving_mode: auto` (the default) resolves this automatically; you may force `exclusive`, but `shared` can never override a stateful classification.
|
||||
|
||||
## Observability
|
||||
|
||||
With `health_port` set (default 9100), the server exposes:
|
||||
|
||||
- `GET /healthz` — `200 ok` while the inference worker is alive, `503` otherwise. Wire this to your orchestrator's liveness probe.
|
||||
- `GET /metrics` — Prometheus text format: `lerobot_policy_server_requests_total`, `errors_total`, `superseded_total`, `dropped_unknown_client_total`, `sessions_opened_total`, `sessions_closed_total`, `active_sessions`, `server_load`.
|
||||
|
||||
Every inference request also emits one structured audit line on the `lerobot.policy_server.audit` logger:
|
||||
|
||||
```json
|
||||
{
|
||||
"session_id": "9f2c...",
|
||||
"client_uuid": "robot-07",
|
||||
"seq_id": 412,
|
||||
"episode_id": 3,
|
||||
"queue_wait_ms": 1.8,
|
||||
"inference_ms": 93.2,
|
||||
"superseded": 0,
|
||||
"outcome": "ok"
|
||||
}
|
||||
```
|
||||
|
||||
`(session_id, seq_id)` correlates a server-side audit line with the client's request. Set a stable `--inference.client_uuid` per robot (instead of the default fresh UUID per run) for fleet-wide log correlation, and use `--inference.tags` to forward free-form labels in the handshake.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**`No policy server answered status query at '@lerobot/...'`**
|
||||
|
||||
The client found no server under the key it dialed. Either the endpoint is wrong (check `--inference.connect_endpoint`, the router, and firewalls), or the **service namespace** does not match. The namespace is the `(model_id, revision, task)` triple: on the client it comes from `--inference.service_model_id` (default: `--policy.path`), `--inference.service_revision` (default: `main`), and `--inference.service_task` (default: the rollout `--task`); on the server from `model.repo_or_path`, `model.revision`, and `service_name` (default: a slug of `default_task`). A robot task string that differs from the server's `default_task` is the most common cause — fix the task, or pin the namespace explicitly with `--inference.service_task` on the client / `service_name` in the manifest.
|
||||
|
||||
**`Action name/order mismatch between server policy and this robot`**
|
||||
|
||||
The hard sync-safety contract: chunk columns map to motors **by order**, so the robot's ordered action keys must exactly equal the policy's `action_feature_names`. This fires when the robot type, motor naming, or rename map differs from the training setup. Use the same robot type (and rename map) the policy was trained with.
|
||||
|
||||
**`RTC requested but this server/policy does not support it — downgrading to chunk-append`**
|
||||
|
||||
Informational, not fatal. Enable RTC in the server manifest (`rtc.enabled: true`) and make sure the policy family is RTC-capable (`pi0`, `pi05`, `smolvla`). Otherwise, expect chunk-append behavior (see [RTC over the network](#rtc-over-the-network)).
|
||||
|
||||
**`server full: N/N sessions active`**
|
||||
|
||||
The session-open was rejected at capacity. Raise `max_sessions` (shared mode only), or point the robot at another server replica — the rejection includes the current load so orchestration can retry elsewhere.
|
||||
+9
-9
@@ -151,18 +151,18 @@ lerobot-rollout \
|
||||
--device=cuda
|
||||
```
|
||||
|
||||
## How It Differs from the Async Inference in LeRobot
|
||||
## How It Relates to Remote Inference
|
||||
|
||||
Both RTC and [async inference](./async) improve real-time robot control, but they solve different problems.
|
||||
Both RTC and [remote inference](./remote_inference) improve real-time robot control, but they solve different problems.
|
||||
|
||||
| Aspect | Async Inference | RTC |
|
||||
| ------------- | -------------------------------------------------------------------------- | --------------------------------------------------- |
|
||||
| **Problem** | Idle frames while waiting for inference | Discontinuities between action chunks |
|
||||
| **Solution** | Decouple prediction from execution | Guide new chunks to continue smoothly from previous |
|
||||
| **Benefit** | No waiting, continuous action | Smooth transitions, natural motion |
|
||||
| **Best Used** | Async inference is best used with large models with high inference latency | Flow-matching based policies |
|
||||
| Aspect | Remote Inference | RTC |
|
||||
| ------------- | ------------------------------------------------------------------------ | --------------------------------------------------- |
|
||||
| **Problem** | The policy is too large (or too slow) for the edge machine | Discontinuities between action chunks |
|
||||
| **Solution** | Run inference on a GPU server; the robot executes buffered action chunks | Guide new chunks to continue smoothly from previous |
|
||||
| **Benefit** | Weightless edge clients, one GPU serves many robots | Smooth transitions, natural motion |
|
||||
| **Best Used** | Large models with high inference latency, robot fleets | Flow-matching based policies |
|
||||
|
||||
**Use both together** for maximum smoothness and reactivity!
|
||||
**Use both together** (`--inference.type=remote` with `--inference.rtc.execution_horizon=...`) for maximum smoothness and reactivity: the remote engine reuses RTC's chunk-merging machinery client-side while the server runs prefix-conditioned chunk generation.
|
||||
|
||||
## Advanced: Debug Tracking
|
||||
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Launch ``lerobot-annotate`` on a Hugging Face job (vllm + Qwen3.6-27B VLM).
|
||||
|
||||
Spawns one single-GPU ``h200`` job that:
|
||||
|
||||
1. installs ``lerobot`` from ``main`` plus the annotation extras,
|
||||
2. boots one vllm server with Qwen3.6-27B (dense VLM),
|
||||
3. runs the plan / interjections / vqa modules across the dataset
|
||||
in free-form mode (each episode generates its own subtasks +
|
||||
memory),
|
||||
4. uploads the annotated dataset to ``--new_repo_id`` (when set)
|
||||
or back to ``--repo_id``.
|
||||
|
||||
Usage:
|
||||
|
||||
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
|
||||
|
||||
Adjust ``CMD`` (dataset, model, hub repo) and ``flavor`` below for your
|
||||
run. For larger datasets, scale to ``h200x4`` and raise
|
||||
``--vlm.parallel_servers`` / ``--vlm.num_gpus`` to match.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from huggingface_hub import get_token, run_job
|
||||
|
||||
token = os.environ.get("HF_TOKEN") or get_token()
|
||||
if not token:
|
||||
raise RuntimeError("No HF token. Run `huggingface-cli login` or `export HF_TOKEN=hf_...`")
|
||||
|
||||
CMD = (
|
||||
"apt-get update -qq && apt-get install -y -qq git ffmpeg && "
|
||||
"pip install --no-deps "
|
||||
"'lerobot @ git+https://github.com/huggingface/lerobot.git@main' && "
|
||||
"pip install --upgrade-strategy only-if-needed "
|
||||
"datasets pyarrow av jsonlines draccus gymnasium torchcodec mergedeep pyyaml-include toml typing-inspect "
|
||||
"openai && "
|
||||
"export VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=0 && "
|
||||
"export VLLM_VIDEO_BACKEND=pyav && "
|
||||
"lerobot-annotate "
|
||||
"--repo_id=pepijn223/robocasa_pretrain_human300_v4 "
|
||||
"--new_repo_id=pepijn223/robocasa_pretrain_human300_v4_annotated "
|
||||
"--push_to_hub=true "
|
||||
"--vlm.backend=openai "
|
||||
"--vlm.model_id=Qwen/Qwen3.6-27B "
|
||||
"--vlm.num_gpus=1 "
|
||||
'--vlm.serve_command="vllm serve Qwen/Qwen3.6-27B '
|
||||
"--tensor-parallel-size 1 --max-model-len 32768 "
|
||||
'--gpu-memory-utilization 0.8 --uvicorn-log-level warning --port {port}" '
|
||||
"--vlm.serve_ready_timeout_s=1800 "
|
||||
# Qwen3.6 ships with thinking on; annotation wants plain JSON answers.
|
||||
"--vlm.chat_template_kwargs='{\"enable_thinking\": false}'"
|
||||
)
|
||||
|
||||
job = run_job(
|
||||
image="vllm/vllm-openai:latest",
|
||||
command=["bash", "-c", CMD],
|
||||
flavor="h200",
|
||||
secrets={"HF_TOKEN": token},
|
||||
timeout="2h",
|
||||
)
|
||||
print(f"Job URL: {job.url}")
|
||||
print(f"Job ID: {job.id}")
|
||||
@@ -0,0 +1,115 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Example manifest for `lerobot-policy-server --manifest server.yaml`.
|
||||
#
|
||||
# One process = one (model, revision, dtype, device) on one GPU. Dynamic
|
||||
# model loading is deliberately unsupported: pre-warmed processes keep
|
||||
# capacity planning honest. Every field below can also be overridden on
|
||||
# the command line via draccus, e.g. --model.repo_or_path=... or
|
||||
# --zenoh.connect_endpoints='["tcp/other-router:7447"]'.
|
||||
#
|
||||
# Field names mirror the dataclasses in src/lerobot/policy_server/manifest.py.
|
||||
|
||||
# --- Which policy this process serves, and where it runs ------------------
|
||||
model:
|
||||
# Hub repo id (org/name) or a local checkpoint directory. Required.
|
||||
repo_or_path: lerobot/pi0_towels
|
||||
# Hub revision: branch, tag, or commit sha.
|
||||
revision: main
|
||||
# Optional torch dtype cast applied after load (e.g. "bfloat16",
|
||||
# "float16"). null keeps the checkpoint's native dtype.
|
||||
dtype: bfloat16
|
||||
# Inference device, e.g. "cuda", "cuda:1", "cpu".
|
||||
device: cuda
|
||||
|
||||
# --- Task namespace --------------------------------------------------------
|
||||
# The task this service is published under. VLA clients may override the
|
||||
# task per session unless `pin_task` is true, in which case session opens
|
||||
# with a different task string are rejected.
|
||||
default_task: "fold the towel"
|
||||
pin_task: false
|
||||
# Optional override for the <task_slug> key segment of the Zenoh prefix
|
||||
# (defaults to a slug of `default_task`).
|
||||
service_name: ""
|
||||
|
||||
# --- Serving mode & capacity ------------------------------------------------
|
||||
# "auto" resolves from the policy classification: shared for verified
|
||||
# chunk-stateless policies (act/pi0/pi05, smolvla with n_obs_steps=1),
|
||||
# exclusive otherwise. Chunk-stateful policies — e.g. diffusion, whose
|
||||
# predict_action_chunk reads select_action-fed queues — are always forced
|
||||
# to "exclusive" (max_sessions=1); "shared" cannot override that.
|
||||
serving_mode: auto
|
||||
|
||||
# Capacity rule-of-thumb: with t = server seconds per inference, r = each
|
||||
# client's request rate (self-clocked to ~1-4 Hz, not the control rate),
|
||||
# H = RTC execution horizon, and dt = control period:
|
||||
# max_sessions ~= min( 0.8 / (r*t), (H*dt/2 - network RTT) / t )
|
||||
# e.g. ACT @ 20 ms, 1 Hz refresh -> ~40 clients/GPU; Pi0 @ 150 ms -> ~5.
|
||||
# Session opens beyond this are rejected with the current load in the
|
||||
# reply, so clients retry another replica.
|
||||
max_sessions: 5
|
||||
|
||||
# Dummy inferences run at startup so the first real request does not pay
|
||||
# for CUDA graph/kernel warmup.
|
||||
warmup_inferences: 2
|
||||
|
||||
# --- FPS contract -----------------------------------------------------------
|
||||
# Control rate the policy was trained at. Clients reporting a different
|
||||
# fps get a warning — or a hard reject when `strict_fps` is true.
|
||||
trained_fps: 30.0
|
||||
strict_fps: false
|
||||
|
||||
# --- Real Time Chunking (RTC) -----------------------------------------------
|
||||
# Global to this process: init_rtc_processor mutates the policy instance,
|
||||
# so RTC is a per-process decision, not per-session. Only rtc-capable
|
||||
# families (pi0/pi05/smolvla) honor it; others are downgraded to plain
|
||||
# chunk-append at session open.
|
||||
rtc:
|
||||
enabled: true
|
||||
# Number of actions executed from each chunk before the next chunk is
|
||||
# blended in (the H in the capacity formula above).
|
||||
execution_horizon: 10
|
||||
|
||||
# --- Housekeeping ------------------------------------------------------------
|
||||
# Sessions with no liveliness token and no traffic for this long are
|
||||
# garbage-collected (belt-and-braces behind liveliness GC).
|
||||
session_idle_timeout_s: 300.0
|
||||
|
||||
# --- Transport ----------------------------------------------------------------
|
||||
# Robots and servers both *dial out* to a zenohd router in production
|
||||
# (mode: client). mode: peer + listen_endpoints supports router-less LAN
|
||||
# and loopback test deployments. Multicast scouting is always disabled:
|
||||
# fleet discovery is configuration, not protocol magic.
|
||||
zenoh:
|
||||
mode: client
|
||||
connect_endpoints:
|
||||
- tcp/router.gpu-cluster.internal:7447
|
||||
listen_endpoints: []
|
||||
# mTLS material (PEM paths). All three are required for tls/ endpoints;
|
||||
# leave them null for plain tcp/ inside a trusted network.
|
||||
# tls_root_ca_certificate: /etc/lerobot/tls/ca.pem
|
||||
# tls_connect_certificate: /etc/lerobot/tls/server.pem
|
||||
# tls_connect_private_key: /etc/lerobot/tls/server.key
|
||||
# Escape hatch: raw JSON5 merged into the zenoh config last.
|
||||
# extra_config_json5: '{transport: {link: {tx: {queue: {size: {data: 4}}}}}}'
|
||||
|
||||
# --- Observability -------------------------------------------------------------
|
||||
# HTTP health + Prometheus metrics port; 0 disables the endpoint.
|
||||
health_port: 9100
|
||||
|
||||
# Optional bounded request/response capture for offline replay.
|
||||
debug:
|
||||
capture_dir: null
|
||||
capture_max: 256
|
||||
@@ -1,17 +0,0 @@
|
||||
from lerobot.async_inference.configs import PolicyServerConfig
|
||||
from lerobot.async_inference.policy_server import serve
|
||||
|
||||
|
||||
def main():
|
||||
host = ... # something like "127.0.0.1" if you're exposing to localhost
|
||||
port = ... # something like 8080
|
||||
|
||||
config = PolicyServerConfig(
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
serve(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,62 +0,0 @@
|
||||
import threading
|
||||
|
||||
from lerobot.async_inference.configs import RobotClientConfig
|
||||
from lerobot.async_inference.helpers import visualize_action_queue_size
|
||||
from lerobot.async_inference.robot_client import RobotClient
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.robots.so_follower import SO100FollowerConfig
|
||||
|
||||
|
||||
def main():
|
||||
# these cameras must match the ones expected by the policy - find your cameras with lerobot-find-cameras
|
||||
# check the config.json on the Hub for the policy you are using to see the expected camera specs
|
||||
camera_cfg = {
|
||||
"up": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_cfg)
|
||||
|
||||
server_address = ... # something like "127.0.0.1:8080" if using localhost
|
||||
|
||||
# 3. Create client configuration
|
||||
client_cfg = RobotClientConfig(
|
||||
robot=robot_cfg,
|
||||
server_address=server_address,
|
||||
policy_device="mps",
|
||||
client_device="cpu",
|
||||
policy_type="act",
|
||||
pretrained_name_or_path="<user>/robot_learning_tutorial_act",
|
||||
chunk_size_threshold=0.5, # g
|
||||
actions_per_chunk=50, # make sure this is less than the max actions of the policy
|
||||
)
|
||||
|
||||
# 4. Create and start client
|
||||
client = RobotClient(client_cfg)
|
||||
|
||||
# 5. Provide a textual description of the task
|
||||
task = ...
|
||||
|
||||
if client.start():
|
||||
# Start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
action_receiver_thread.start()
|
||||
|
||||
try:
|
||||
# Run the control loop
|
||||
client.control_loop(task)
|
||||
except KeyboardInterrupt:
|
||||
client.stop()
|
||||
action_receiver_thread.join()
|
||||
# (Optionally) plot the action queue size
|
||||
visualize_action_queue_size(client.action_queue_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+5
-23
@@ -226,24 +226,11 @@ hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.
|
||||
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
# Remote inference over Zenoh: lerobot-policy-server + lerobot-rollout --inference.type=remote.
|
||||
# Keep zenohd routers on the same minor version as the Python binding.
|
||||
async = ["eclipse-zenoh>=1.9,<2.0", "msgpack>=1.0.0,<2.0.0"]
|
||||
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
||||
|
||||
# Annotation pipeline (lerobot-annotate). The only backend is ``openai``,
|
||||
# which talks to any OpenAI-compatible server (``vllm serve`` /
|
||||
# ``transformers serve`` / hosted). Distributed runs use Hugging Face Jobs
|
||||
# (see examples/annotations/run_hf_job.py).
|
||||
annotations = [
|
||||
"lerobot[dataset]",
|
||||
"lerobot[transformers-dep]",
|
||||
"openai>=1.40,<2.0",
|
||||
# ``vllm`` is intentionally NOT a hard dep: it pins an older torch, and
|
||||
# uv's single unified lock would then cap ``torch`` for every extra
|
||||
# (e.g. forcing 2.8 while ``torchcodec`` in [dataset] needs 2.11 -> ABI
|
||||
# break in CI). The HF Jobs image (``vllm/vllm-openai``) provides vLLM;
|
||||
# install it locally only if you run your own ``vllm serve``.
|
||||
]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools>=1.73.1,<2.0.0", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
||||
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
|
||||
@@ -338,8 +325,8 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
|
||||
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
|
||||
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
|
||||
lerobot-policy-server="lerobot.scripts.lerobot_policy_server:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
|
||||
@@ -357,7 +344,7 @@ torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json", "annotations/steerable_pipeline/prompts/*.txt"]
|
||||
lerobot = ["envs/*.json"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -532,11 +519,6 @@ ignore_errors = false
|
||||
# module = "lerobot.rl.*"
|
||||
# ignore_errors = false
|
||||
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.async_inference.*"
|
||||
# ignore_errors = false
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.transport.*"
|
||||
ignore_errors = false
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -1,36 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Steerable annotation pipeline producing ``language_persistent`` and
|
||||
``language_events`` columns for LeRobot datasets.
|
||||
|
||||
The pipeline is decomposed into three independently runnable modules whose
|
||||
outputs are staged per-episode before a final parquet rewrite:
|
||||
|
||||
- :mod:`.modules.plan_subtasks_memory` (the ``plan`` module) — persistent styles
|
||||
- :mod:`.modules.interjections_and_speech` (the ``interjections`` module) — event styles + speech
|
||||
- :mod:`.modules.general_vqa` (the ``vqa`` module) — event-style VQA pairs
|
||||
"""
|
||||
|
||||
from .config import AnnotationPipelineConfig
|
||||
from .validator import StagingValidator, ValidationReport
|
||||
from .writer import LanguageColumnsWriter
|
||||
|
||||
__all__ = [
|
||||
"AnnotationPipelineConfig",
|
||||
"LanguageColumnsWriter",
|
||||
"StagingValidator",
|
||||
"ValidationReport",
|
||||
]
|
||||
@@ -1,211 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlanConfig:
|
||||
"""``plan`` module: subtasks + plan + memory + task augmentation."""
|
||||
|
||||
enabled: bool = True
|
||||
|
||||
# ``task_aug`` rephrasings at t=0 (renderer rotates ${task} among them); 0 disables.
|
||||
n_task_rephrasings: int = 10
|
||||
|
||||
# Derive the task from video instead of episode_task: off / if_short / always.
|
||||
# Affects prompts only; ``meta/tasks.parquet`` is untouched.
|
||||
derive_task_from_video: str = "if_short"
|
||||
derive_task_min_words: int = 3
|
||||
|
||||
# --- Frame input: timestamped contact sheets (always on) ---------------
|
||||
# The subtask describe/segment passes ALWAYS render the episode as
|
||||
# macrodata/refiner-style contact sheets: sampled frames packed into JPEG
|
||||
# grids with each frame's timestamp burned into its corner, so the VLM
|
||||
# cites the exact source time of a boundary directly. This is far cheaper
|
||||
# in vision tokens than one image per frame (≈2× faster subtask generation
|
||||
# in practice), which is why the sampling is dense by default.
|
||||
#
|
||||
# ``frames_per_second`` is the sampling rate: 2.0 = one frame every 0.5s.
|
||||
frames_per_second: float = 2.0
|
||||
# Frame budget per VLM call (= columns × rows × sheets). When a whole
|
||||
# episode sampled at ``frames_per_second`` exceeds this, the episode is
|
||||
# AUTOMATICALLY split into consecutive windows of
|
||||
# ``max_frames_per_prompt`` frames each (one describe→segment call per
|
||||
# window, still at the full ``frames_per_second`` density), and the
|
||||
# per-window spans are merged + stitched into one contiguous cover. So an
|
||||
# episode of any length is always covered at the full sampling density.
|
||||
max_frames_per_prompt: int = 60
|
||||
contact_sheet_columns: int = 5
|
||||
contact_sheet_frames_per_sheet: int = 20
|
||||
contact_sheet_frame_width: int = 224
|
||||
contact_sheet_quality: int = 84
|
||||
|
||||
min_subtask_seconds: float = 1.5
|
||||
plan_max_steps: int = 8
|
||||
|
||||
# Narrate-only grounding pass before segmenting — best defense against subtasks
|
||||
# invented from the task text (+1 VLM call/episode).
|
||||
subtask_describe_first: bool = True
|
||||
|
||||
# Emit ``style="plan"`` rows at each boundary; False = subtasks + memory only.
|
||||
emit_plan: bool = True
|
||||
|
||||
# Emit ``style="memory"`` rows at each boundary; False = subtasks (+ plan) only.
|
||||
# Symmetric counterpart of ``emit_plan``.
|
||||
emit_memory: bool = True
|
||||
|
||||
# (subtask spans are always stitched to a contiguous full-episode cover; not configurable.)
|
||||
|
||||
# Optional EgoMimic-style 5-axis task augmentation; replaces n_task_rephrasings.
|
||||
task_aug_axes: TaskAugAxesConfig = field(default_factory=lambda: TaskAugAxesConfig())
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskAugAxesConfig:
|
||||
"""5-axis t=0 task augmentation (EgoMimic-style): synonym / omit_arm /
|
||||
omit_orientation / omit_grasp_method / combined. Replaces n_task_rephrasings
|
||||
when enabled; each variant becomes a ``task_aug`` row. Axes with nothing to
|
||||
omit emit fewer entries. Defaults (3+3+2+2+2) match EgoMimic."""
|
||||
|
||||
enabled: bool = False
|
||||
|
||||
synonym_paraphrase: int = 3
|
||||
omit_arm: int = 3
|
||||
omit_orientation: int = 2
|
||||
omit_grasp_method: int = 2
|
||||
combined_omissions: int = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterjectionsConfig:
|
||||
"""``interjections`` module: interjections + paired speech."""
|
||||
|
||||
enabled: bool = True
|
||||
|
||||
# Each emits a paired (interjection, speech) row + a plan refresh at that ts.
|
||||
max_interjections_per_episode: int = 3
|
||||
interjection_min_t: float = 2.0
|
||||
|
||||
# Frame window centered on the timestamp so the VLM sees motion, not one frame.
|
||||
interjection_window_seconds: float = 2.0
|
||||
interjection_window_frames: int = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class VqaConfig:
|
||||
"""``vqa`` module: general VQA."""
|
||||
|
||||
enabled: bool = True
|
||||
vqa_emission_hz: float = 1.0
|
||||
K: int = 1
|
||||
"""Consecutive frames per emission tick. The VLM grounds on the FIRST frame,
|
||||
so K>1 smears stale labels onto moved frames. Default 1 (no smear)."""
|
||||
question_types: tuple[str, ...] = ("bbox", "keypoint", "count", "attribute", "spatial")
|
||||
|
||||
# True: ground VQA only on --vlm.camera_key (default: every camera).
|
||||
restrict_to_default_camera: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class VlmConfig:
|
||||
"""Shared Qwen-VL client configuration."""
|
||||
|
||||
# Only ``openai`` (OpenAI-compatible vLLM server, auto-spawned when
|
||||
# auto_serve=True); ``stub`` is for tests.
|
||||
backend: str = "openai"
|
||||
model_id: str = "Qwen/Qwen3.6-27B"
|
||||
|
||||
# OpenAI-compatible endpoint; ``EMPTY`` key works for local servers.
|
||||
api_base: str = "http://localhost:8000/v1"
|
||||
api_key: str = "EMPTY"
|
||||
|
||||
# Spawn a server if none answers api_base; False = fail fast on a remote.
|
||||
auto_serve: bool = True
|
||||
serve_port: int = 8000
|
||||
# Override the auto-serve command; ``{port}`` substituted per replica.
|
||||
serve_command: str | None = None
|
||||
|
||||
# Independent servers for round-robin routing (one per GPU). num_gpus=0 = one each.
|
||||
parallel_servers: int = 1
|
||||
num_gpus: int = 0
|
||||
client_concurrency: int = 16
|
||||
serve_ready_timeout_s: float = 600.0
|
||||
|
||||
max_new_tokens: int = 512
|
||||
temperature: float = 0.2
|
||||
|
||||
# Auto-serve context length (None → 32768); other vLLM flags go in serve_command.
|
||||
max_model_len: int | None = None
|
||||
|
||||
# Camera for keyframes; None → first ``observation.images.*`` key.
|
||||
camera_key: str | None = None
|
||||
# Forwarded as extra_body.chat_template_kwargs (e.g. {"enable_thinking": false}).
|
||||
chat_template_kwargs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutorConfig:
|
||||
"""Executor settings (intra-process episode concurrency; distribution via HF Jobs)."""
|
||||
|
||||
# Episodes processed concurrently per phase; main knob for saturating the servers.
|
||||
episode_parallelism: int = 16
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnnotationPipelineConfig:
|
||||
"""Top-level config for ``lerobot-annotate`` (rewrites data shards in place)."""
|
||||
|
||||
# Hub dataset: download source when ``root`` unset; push target when push_to_hub
|
||||
# is on and ``new_repo_id`` unset.
|
||||
repo_id: str | None = None
|
||||
|
||||
# Separate push target (matches the LeRobot edit tools). Unset → push in place.
|
||||
new_repo_id: str | None = None
|
||||
|
||||
root: Path | None = None
|
||||
|
||||
# Defaults to ``<root>/.annotate_staging/``.
|
||||
staging_dir: Path | None = None
|
||||
|
||||
seed: int = 1729
|
||||
|
||||
plan: PlanConfig = field(default_factory=PlanConfig)
|
||||
interjections: InterjectionsConfig = field(default_factory=InterjectionsConfig)
|
||||
vqa: VqaConfig = field(default_factory=VqaConfig)
|
||||
|
||||
vlm: VlmConfig = field(default_factory=VlmConfig)
|
||||
executor: ExecutorConfig = field(default_factory=ExecutorConfig)
|
||||
|
||||
skip_validation: bool = False
|
||||
only_episodes: tuple[int, ...] | None = None
|
||||
|
||||
# Keyframe decode backend forwarded to ``decode_video_frames``. None →
|
||||
# library default (torchcodec when available, else PyAV). Or pin
|
||||
# ``"torchcodec"`` / ``"pyav"`` explicitly.
|
||||
video_backend: str | None = None
|
||||
|
||||
# Upload to the Hub (new_repo_id if set, else repo_id; one must be set).
|
||||
push_to_hub: bool = False
|
||||
push_private: bool = False
|
||||
push_commit_message: str | None = None
|
||||
|
||||
def resolved_staging_dir(self, root: Path) -> Path:
|
||||
return self.staging_dir if self.staging_dir is not None else root / ".annotate_staging"
|
||||
@@ -1,253 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""In-process executor that runs the annotation phases.
|
||||
|
||||
The executor runs **six phases** in dependency order:
|
||||
|
||||
phase 1: ``plan`` module (plan + subtasks + memory)
|
||||
phase 2: ``interjections`` module (interjections + speech)
|
||||
phase 3: ``plan`` plan-update pass — re-runs plan emission at every
|
||||
interjection timestamp produced by phase 2
|
||||
phase 4: ``vqa`` module (VQA)
|
||||
phase 5: validator
|
||||
phase 6: writer
|
||||
|
||||
Phase 3 is why the ``plan`` module must be re-entered after the
|
||||
``interjections`` module — to refresh ``plan`` rows at interjection
|
||||
timestamps.
|
||||
|
||||
Distributed execution is provided by Hugging Face Jobs (see
|
||||
``examples/annotations/run_hf_job.py``); the runner inside the job
|
||||
invokes ``lerobot-annotate`` which uses this in-process executor.
|
||||
Episode-level concurrency is controlled by
|
||||
``ExecutorConfig.episode_parallelism``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .config import AnnotationPipelineConfig
|
||||
from .reader import EpisodeRecord, iter_episodes
|
||||
from .staging import EpisodeStaging
|
||||
from .validator import StagingValidator
|
||||
from .writer import LanguageColumnsWriter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PhaseResult:
|
||||
"""Summary of one pipeline phase across all episodes."""
|
||||
|
||||
name: str
|
||||
episodes_processed: int
|
||||
episodes_skipped: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineRunSummary:
|
||||
"""Aggregated result returned by :meth:`Executor.run`."""
|
||||
|
||||
phases: list[PhaseResult]
|
||||
written_paths: list[Path]
|
||||
validation_report: Any # ValidationReport, kept Any to avoid import cycle
|
||||
|
||||
|
||||
@dataclass
|
||||
class Executor:
|
||||
"""Run all six phases over a dataset root in-process.
|
||||
|
||||
Episode-level concurrency comes from ``ExecutorConfig.episode_parallelism``
|
||||
(a thread pool); cluster-level concurrency comes from running this
|
||||
executor inside a Hugging Face Job. Tests construct the executor
|
||||
directly with stub modules.
|
||||
"""
|
||||
|
||||
config: AnnotationPipelineConfig
|
||||
plan: Any # PlanSubtasksMemoryModule
|
||||
interjections: Any # InterjectionsAndSpeechModule
|
||||
vqa: Any # GeneralVqaModule
|
||||
writer: LanguageColumnsWriter
|
||||
validator: StagingValidator
|
||||
|
||||
def run(self, root: Path) -> PipelineRunSummary:
|
||||
records = list(iter_episodes(root, only_episodes=self.config.only_episodes))
|
||||
n = len(records)
|
||||
if n == 0:
|
||||
raise ValueError(f"No episodes found under {root}/data/")
|
||||
|
||||
print(f"[annotate] {n} episodes total", flush=True)
|
||||
|
||||
staging_dir = self.config.resolved_staging_dir(root)
|
||||
staging_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
phases: list[PhaseResult] = []
|
||||
|
||||
# Phase 1: ``plan`` module (plan + subtasks + memory)
|
||||
phases.append(self._run_module_phase("plan", records, staging_dir, self.plan))
|
||||
# Phase 2: ``interjections`` module (interjections + speech). It
|
||||
# reads the ``plan`` module's subtask rows from the same staging
|
||||
# tree to ground the interjection prompt in the correct local subtask.
|
||||
phases.append(self._run_module_phase("interjections", records, staging_dir, self.interjections))
|
||||
# Phase 3: ``plan`` plan-update pass at interjection timestamps.
|
||||
phases.append(self._run_plan_update_phase(records, staging_dir))
|
||||
# Phase 4: ``vqa`` module (VQA)
|
||||
phases.append(self._run_module_phase("vqa", records, staging_dir, self.vqa))
|
||||
|
||||
print("[annotate] running validator...", flush=True)
|
||||
report = self.validator.validate(records, staging_dir)
|
||||
if not report.ok and not self.config.skip_validation:
|
||||
raise RuntimeError(f"Staging validation failed: {report.summary()}")
|
||||
print(f"[annotate] validator: {report.summary()}", flush=True)
|
||||
|
||||
print(f"[annotate] writing parquet shards into {root}/data/...", flush=True)
|
||||
written = self.writer.write_all(records, staging_dir, root)
|
||||
print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True)
|
||||
|
||||
# Keep meta/info.json aligned with the parquet schema we just wrote.
|
||||
# Idempotent and additive: existing user metadata is preserved.
|
||||
self._ensure_annotation_metadata_in_info(root)
|
||||
|
||||
return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_annotation_metadata_in_info(root: Path) -> None:
|
||||
"""Write language features and canonical tools to ``meta/info.json``.
|
||||
|
||||
``LanguageColumnsWriter`` adds ``language_persistent`` and
|
||||
``language_events`` to parquet shards. The metadata must advertise
|
||||
those columns too, otherwise non-streaming ``LeRobotDataset`` loads
|
||||
cast against the old schema and fail on the extra parquet columns.
|
||||
"""
|
||||
from lerobot.datasets.io_utils import load_info, write_info # noqa: PLC0415
|
||||
from lerobot.datasets.language import SAY_TOOL_SCHEMA, language_feature_info # noqa: PLC0415
|
||||
|
||||
info_path = root / "meta" / "info.json"
|
||||
if not info_path.exists():
|
||||
return
|
||||
try:
|
||||
info = load_info(root)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(f"[annotate] could not read {info_path}: {exc}", flush=True)
|
||||
return
|
||||
|
||||
changed = False
|
||||
|
||||
merged_features = {**info.features, **language_feature_info()}
|
||||
if merged_features != info.features:
|
||||
info.features = merged_features
|
||||
changed = True
|
||||
|
||||
existing = info.tools or []
|
||||
names = {(t.get("function") or {}).get("name") for t in existing if isinstance(t, dict)}
|
||||
if SAY_TOOL_SCHEMA["function"]["name"] not in names:
|
||||
info.tools = [*existing, SAY_TOOL_SCHEMA]
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
write_info(info, root)
|
||||
print(
|
||||
"[annotate] meta/info.json: "
|
||||
f"language_features={list(language_feature_info())}, "
|
||||
f"tools={[t['function']['name'] for t in (info.tools or [])]}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
def _run_module_phase(
|
||||
self,
|
||||
name: str,
|
||||
records: list[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
module: Any,
|
||||
) -> PhaseResult:
|
||||
if not module.enabled:
|
||||
print(f"[annotate] phase={name} skipped (module disabled)", flush=True)
|
||||
return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records))
|
||||
n = len(records)
|
||||
parallelism = max(1, min(self.config.executor.episode_parallelism, n))
|
||||
print(
|
||||
f"[annotate] phase={name} starting on {n} episode(s) (parallelism={parallelism})",
|
||||
flush=True,
|
||||
)
|
||||
t0 = time.time()
|
||||
|
||||
def _do(idx_record: tuple[int, EpisodeRecord]) -> tuple[int, int, float]:
|
||||
i, record = idx_record
|
||||
ep_start = time.time()
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
return i, record.episode_index, time.time() - ep_start
|
||||
|
||||
processed = 0
|
||||
if parallelism == 1:
|
||||
for i, record in enumerate(records, 1):
|
||||
_, ep_idx, elapsed = _do((i, record))
|
||||
processed += 1
|
||||
print(
|
||||
f"[annotate] {name} episode {i}/{n} (idx={ep_idx}) done in {elapsed:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=parallelism) as pool:
|
||||
futures = [pool.submit(_do, (i, r)) for i, r in enumerate(records, 1)]
|
||||
for fut in as_completed(futures):
|
||||
i, ep_idx, elapsed = fut.result()
|
||||
processed += 1
|
||||
print(
|
||||
f"[annotate] {name} episode {processed}/{n} "
|
||||
f"(idx={ep_idx}, submit_order={i}) done in {elapsed:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
total = time.time() - t0
|
||||
print(f"[annotate] phase={name} complete: {processed}/{n} in {total:.1f}s", flush=True)
|
||||
return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0)
|
||||
|
||||
def _run_plan_update_phase( # noqa: PLR0915
|
||||
self, records: list[EpisodeRecord], staging_dir: Path
|
||||
) -> PhaseResult:
|
||||
"""Re-emit ``plan`` rows at each timestamp the ``interjections`` module produced.
|
||||
|
||||
The ``plan`` module owns the prompt; the ``interjections`` module
|
||||
produced the timestamps. This phase therefore calls back into the
|
||||
``plan`` module with the interjection timestamps so its existing
|
||||
prompt path is reused.
|
||||
"""
|
||||
if not self.plan.enabled or not self.interjections.enabled:
|
||||
return PhaseResult(name="plan_update", episodes_processed=0, episodes_skipped=len(records))
|
||||
processed = 0
|
||||
for record in records:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
interjection_rows = [
|
||||
row for row in staging.read("interjections") if row.get("style") == "interjection"
|
||||
]
|
||||
interjection_times = [float(row["timestamp"]) for row in interjection_rows]
|
||||
interjection_texts = [str(row.get("content") or "") for row in interjection_rows]
|
||||
if interjection_times:
|
||||
self.plan.run_plan_updates(record, staging, interjection_times, interjection_texts)
|
||||
processed += 1
|
||||
# Episodes without any interjections are skipped (no plan refresh
|
||||
# needed); count them so the summary's processed+skipped == total.
|
||||
return PhaseResult(
|
||||
name="plan_update",
|
||||
episodes_processed=processed,
|
||||
episodes_skipped=len(records) - processed,
|
||||
)
|
||||
@@ -1,481 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Keyframe extraction for the annotation pipeline.
|
||||
|
||||
Modules attach decoded camera frames to their VLM prompts so the model can
|
||||
ground subtask decomposition, interjection scenarios, and VQA in actual
|
||||
visual content. The pipeline shares one provider across modules and one
|
||||
episode at a time, with a small per-episode cache so multiple modules
|
||||
querying the same timestamp pay decode cost once.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import logging
|
||||
import math
|
||||
import threading
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from lerobot.configs.video import VideoEncoderConfig
|
||||
from lerobot.datasets.video_utils import decode_video_frames, reencode_video
|
||||
|
||||
from .reader import EpisodeRecord, snap_to_frame
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FrameProvider(Protocol):
|
||||
"""Decodes camera frames at episode-relative timestamps."""
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""All ``observation.images.*`` feature keys this provider can decode."""
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return one decoded frame per timestamp from ``camera_key`` (or default).
|
||||
|
||||
Frames are ``torch.Tensor`` (``C, H, W`` uint8) — the shape
|
||||
:func:`lerobot.datasets.video_utils.decode_video_frames` returns.
|
||||
:func:`to_image_blocks` converts them to PIL only at the VLM-message
|
||||
boundary.
|
||||
|
||||
Empty list if the camera is unavailable. ``camera_key=None`` falls back
|
||||
to the provider's default camera so existing single-camera callers
|
||||
(the ``plan`` and ``interjections`` modules) keep working unchanged.
|
||||
"""
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` decoded frames covering the whole episode.
|
||||
|
||||
Sampling is uniform across the episode duration. Frames are
|
||||
``torch.Tensor`` (``C, H, W`` uint8); :func:`to_video_block` wraps
|
||||
them into one ``{"type":"video", "video":<list>}`` block for a
|
||||
Qwen-VL-compatible model that pools temporally itself. Empty list if
|
||||
no camera available.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class _NullProvider:
|
||||
"""No-op provider used when the dataset has no video keys or in tests."""
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
return []
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
return []
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
return []
|
||||
|
||||
|
||||
def null_provider() -> FrameProvider:
|
||||
return _NullProvider()
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoFrameProvider:
|
||||
"""Decodes frames from the dataset's ``observation.images.*`` streams.
|
||||
|
||||
By default the *first* camera key is used for the ``plan`` module
|
||||
(subtask decomposition) and the ``interjections`` module (interjection
|
||||
scenarios) — those prompts care about *what is happening*, not which
|
||||
angle. The ``vqa`` module instead iterates over every camera in
|
||||
:attr:`camera_keys` so each frame's
|
||||
grounded answer (bbox/keypoint/...) is tagged with the camera it was
|
||||
grounded against.
|
||||
|
||||
``camera_key`` overrides the default-camera choice but does not restrict
|
||||
:attr:`camera_keys`. Pass ``camera_key`` explicitly to ``frames_at`` /
|
||||
``video_for_episode`` to read a non-default stream.
|
||||
|
||||
Caches up to ``cache_size`` decoded frames per process to keep
|
||||
co-timestamped ``interjections`` + ``plan`` plan-update calls cheap.
|
||||
"""
|
||||
|
||||
root: Path
|
||||
camera_key: str | None = None
|
||||
tolerance_s: float = 1e-2
|
||||
cache_size: int = 256
|
||||
# Keyframe decode backend forwarded to
|
||||
# :func:`lerobot.datasets.video_utils.decode_video_frames`. ``None``
|
||||
# uses the library default (torchcodec when available, else PyAV).
|
||||
video_backend: str | None = None
|
||||
_meta: Any = field(default=None, init=False, repr=False)
|
||||
_cache: dict = field(default_factory=dict, init=False, repr=False)
|
||||
_camera_keys: list[str] = field(default_factory=list, init=False, repr=False)
|
||||
# Pipeline runs the three module phases under a ThreadPoolExecutor (see
|
||||
# ``ExecutorConfig.episode_parallelism``); guard the dict cache and the
|
||||
# one-shot warn flag against concurrent updates from worker threads.
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
|
||||
# Serializes decode_video_frames calls: torchcodec hands out one
|
||||
# ``VideoDecoder`` per file from a process-wide cache, and the decoder
|
||||
# is not safe to drive from multiple threads at once.
|
||||
_decode_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
|
||||
_warned_decode_fail: bool = field(default=False, init=False, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415
|
||||
|
||||
self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
|
||||
# Only ``video_keys`` are decodable here: the clip/decode paths read
|
||||
# ``videos/<key>/from_timestamp`` from episode metadata, which exists
|
||||
# only for video-stored cameras. Image-stored cameras (also in
|
||||
# ``camera_keys``) would KeyError, so restrict the list — and the
|
||||
# default — to video keys.
|
||||
keys = list(self._meta.video_keys)
|
||||
# Last-resort fallback: if metadata didn't surface any video keys but
|
||||
# the caller explicitly named a camera (``--vlm.camera_key=...``),
|
||||
# trust them — the key is by definition known to exist on the dataset.
|
||||
if not keys and self.camera_key:
|
||||
keys = [self.camera_key]
|
||||
self._camera_keys = keys
|
||||
if self.camera_key is None:
|
||||
self.camera_key = keys[0] if keys else None
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""All ``observation.images.*`` keys available on this dataset."""
|
||||
return list(self._camera_keys)
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
target = camera_key if camera_key is not None else self.camera_key
|
||||
if not timestamps or target is None:
|
||||
return []
|
||||
# Snap each request to the nearest real frame timestamp: callers
|
||||
# sample uniform grids whose points land mid-frame, and
|
||||
# ``decode_video_frames`` rejects queries farther than
|
||||
# ``tolerance_s`` from a decodable frame. Snapping also dedupes
|
||||
# repeat queries through the cache.
|
||||
if record.frame_timestamps:
|
||||
timestamps = [snap_to_frame(float(ts), record.frame_timestamps) for ts in timestamps]
|
||||
|
||||
out: list[Any] = []
|
||||
misses: list[float] = []
|
||||
miss_indices: list[int] = []
|
||||
with self._lock:
|
||||
for i, ts in enumerate(timestamps):
|
||||
key = (record.episode_index, target, round(float(ts), 6))
|
||||
cached = self._cache.get(key)
|
||||
if cached is not None:
|
||||
out.append(cached)
|
||||
else:
|
||||
out.append(None)
|
||||
misses.append(float(ts))
|
||||
miss_indices.append(i)
|
||||
|
||||
if misses:
|
||||
decoded = self._decode(record.episode_index, misses, target)
|
||||
# ``_decode`` returns exactly one frame per requested timestamp,
|
||||
# or an empty list if decoding failed wholesale. A partial list
|
||||
# would mean a frame/timestamp misalignment, so only pair them up
|
||||
# when the counts match (``strict=True`` then guards regressions).
|
||||
if len(decoded) == len(miss_indices):
|
||||
with self._lock:
|
||||
for i, frame in zip(miss_indices, decoded, strict=True):
|
||||
out[i] = frame
|
||||
key = (record.episode_index, target, round(float(timestamps[i]), 6))
|
||||
if len(self._cache) >= self.cache_size:
|
||||
self._cache.pop(next(iter(self._cache)))
|
||||
self._cache[key] = frame
|
||||
# filter out any None left over from decode failures
|
||||
return [frame for frame in out if frame is not None]
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` frames uniformly sampled across the episode.
|
||||
|
||||
The whole episode duration is covered; the model picks subtask
|
||||
boundaries from the temporal pooling it does internally. Frames are
|
||||
``torch.Tensor`` (see :meth:`frames_at`).
|
||||
"""
|
||||
target = camera_key if camera_key is not None else self.camera_key
|
||||
if max_frames <= 0 or target is None or not record.frame_timestamps:
|
||||
return []
|
||||
n_frames = min(max_frames, len(record.frame_timestamps))
|
||||
if n_frames == len(record.frame_timestamps):
|
||||
timestamps = list(record.frame_timestamps)
|
||||
else:
|
||||
t0 = record.frame_timestamps[0]
|
||||
t_last = record.frame_timestamps[-1]
|
||||
if t_last <= t0:
|
||||
timestamps = [float(t0)] * n_frames
|
||||
else:
|
||||
step = (t_last - t0) / (n_frames - 1) if n_frames > 1 else 0.0
|
||||
timestamps = [float(t0 + i * step) for i in range(n_frames)]
|
||||
return self.frames_at(record, timestamps, camera_key=target)
|
||||
|
||||
def episode_clip_path(self, record: EpisodeRecord, cache_dir: Path) -> Path | None:
|
||||
"""Extract the episode's subclip to ``cache_dir/ep_{idx:06d}.mp4``.
|
||||
|
||||
Returns ``None`` if the dataset has no video tracks or extraction
|
||||
failed. Skips re-extract when the cached clip already exists.
|
||||
Re-encodes to H.264 via
|
||||
:func:`lerobot.datasets.video_utils.reencode_video` so the resulting
|
||||
mp4 is decodable by every downstream video processor — stream-copy
|
||||
would inherit the source codec (often AV1 in modern LeRobot
|
||||
datasets), which vllm's libav build cannot decode.
|
||||
"""
|
||||
if self.camera_key is None:
|
||||
return None
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_path = cache_dir / f"ep_{record.episode_index:06d}.mp4"
|
||||
if out_path.exists() and out_path.stat().st_size > 0:
|
||||
return out_path
|
||||
ep = self._meta.episodes[record.episode_index]
|
||||
from_timestamp = float(ep[f"videos/{self.camera_key}/from_timestamp"])
|
||||
to_timestamp = float(ep[f"videos/{self.camera_key}/to_timestamp"])
|
||||
src = self.root / self._meta.get_video_file_path(record.episode_index, self.camera_key)
|
||||
encoder = VideoEncoderConfig(vcodec="h264", pix_fmt="yuv420p", g=None, crf=23, preset="ultrafast")
|
||||
try:
|
||||
reencode_video(
|
||||
src,
|
||||
out_path,
|
||||
camera_encoder=encoder,
|
||||
overwrite=True,
|
||||
start_time_s=from_timestamp,
|
||||
end_time_s=to_timestamp,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"clip extraction failed for episode %s (%s)", record.episode_index, src, exc_info=True
|
||||
)
|
||||
return None
|
||||
return out_path if out_path.exists() and out_path.stat().st_size > 0 else None
|
||||
|
||||
def _decode(self, episode_index: int, timestamps: list[float], camera_key: str) -> list[Any]:
|
||||
"""Decode ``timestamps`` from the episode's video as ``(C, H, W)`` tensors.
|
||||
|
||||
Delegates to :func:`lerobot.datasets.video_utils.decode_video_frames`
|
||||
(torchcodec when available, PyAV otherwise; ``video_backend`` pins
|
||||
one explicitly). Returns one frame per requested timestamp, or ``[]``
|
||||
if decoding failed — callers treat ``[]`` as "no frames available".
|
||||
"""
|
||||
ep = self._meta.episodes[episode_index]
|
||||
from_timestamp = ep[f"videos/{camera_key}/from_timestamp"]
|
||||
shifted = [from_timestamp + ts for ts in timestamps]
|
||||
video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key)
|
||||
|
||||
try:
|
||||
# The module phases decode under a ThreadPoolExecutor (see
|
||||
# ``ExecutorConfig.episode_parallelism``) but torchcodec's cached
|
||||
# per-file decoder is single-threaded, so serialize decodes on a
|
||||
# dedicated lock. Frame extraction is a small fraction of episode
|
||||
# wall time (VLM calls dominate), so the contention is cheap.
|
||||
with self._decode_lock:
|
||||
# Stacked ``(N, C, H, W)`` uint8 tensor; one row per timestamp.
|
||||
decoded = decode_video_frames(
|
||||
video_path, shifted, self.tolerance_s, backend=self.video_backend, return_uint8=True
|
||||
)
|
||||
return list(decoded)
|
||||
except Exception as exc:
|
||||
# Log loudly the first time so a silent vqa-module no-op (every
|
||||
# prompt skipped because frames_at returned []) is debuggable from
|
||||
# the job log instead of post-hoc parquet inspection. Subsequent
|
||||
# failures stay quiet.
|
||||
with self._lock:
|
||||
already_warned = self._warned_decode_fail
|
||||
if not already_warned:
|
||||
self._warned_decode_fail = True
|
||||
if not already_warned:
|
||||
logger.warning(
|
||||
"VideoFrameProvider._decode failed for episode=%s camera=%s video_path=%s backend=%s: %s",
|
||||
episode_index,
|
||||
camera_key,
|
||||
video_path,
|
||||
self.video_backend,
|
||||
exc,
|
||||
exc_info=exc,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def make_frame_provider(
|
||||
root: Path, camera_key: str | None = None, video_backend: str | None = None
|
||||
) -> FrameProvider:
|
||||
"""Build a :class:`VideoFrameProvider` if videos are present, else null."""
|
||||
try:
|
||||
provider = VideoFrameProvider(root=root, camera_key=camera_key, video_backend=video_backend)
|
||||
except Exception:
|
||||
return null_provider()
|
||||
if provider.camera_key is None:
|
||||
return null_provider()
|
||||
return provider
|
||||
|
||||
|
||||
def _frame_to_pil(frame: Any) -> Any:
|
||||
"""Materialise a decoded frame as a ``PIL.Image`` for the VLM message.
|
||||
|
||||
Frames flow through the provider as ``torch.Tensor`` (``C, H, W`` uint8,
|
||||
straight from :func:`decode_video_frames`); PIL is only created here, at
|
||||
the VLM-message boundary, because the chat backends expect PIL images /
|
||||
data URLs. Non-tensor inputs (e.g. test stubs) pass through untouched.
|
||||
"""
|
||||
if not isinstance(frame, torch.Tensor):
|
||||
return frame
|
||||
array = frame.detach().cpu()
|
||||
if array.ndim == 3 and array.shape[0] in (1, 3):
|
||||
array = array.permute(1, 2, 0) # (C, H, W) -> (H, W, C)
|
||||
if array.shape[-1] == 1:
|
||||
array = array.squeeze(-1)
|
||||
return PIL.Image.fromarray(array.to(torch.uint8).numpy())
|
||||
|
||||
|
||||
def to_image_blocks(frames: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Convert decoded frames to Qwen-VL-compatible image content blocks."""
|
||||
return [{"type": "image", "image": _frame_to_pil(frame)} for frame in frames]
|
||||
|
||||
|
||||
def to_video_block(frames: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Wrap a list of decoded frames as one Qwen-VL video block.
|
||||
|
||||
Returns ``[]`` when the list is empty, so the caller can splat the result
|
||||
into a content array without a separate emptiness check.
|
||||
"""
|
||||
if not frames:
|
||||
return []
|
||||
return [{"type": "video", "video": [_frame_to_pil(frame) for frame in frames]}]
|
||||
|
||||
|
||||
def to_video_url_block(url: str | None, fps: float = 2.0) -> list[dict[str, Any]]:
|
||||
"""Wrap a video file URL as one ``video_url`` block.
|
||||
|
||||
Used by the ``openai`` backend (transformers serve / vllm serve /
|
||||
ktransformers serve), where the server handles frame sampling.
|
||||
Returns ``[]`` when ``url`` is ``None`` so the caller can splat.
|
||||
"""
|
||||
if not url:
|
||||
return []
|
||||
return [{"type": "video_url", "video_url": {"url": url}, "fps": fps}]
|
||||
|
||||
|
||||
def _draw_timestamp_badge(image: PIL.Image.Image, timestamp: float) -> PIL.Image.Image:
|
||||
"""Burn ``timestamp`` (seconds) into the top-left corner of ``image``.
|
||||
|
||||
A solid black badge with white text, so a VLM reading a contact sheet can
|
||||
cite the exact source time of each tile (e.g. ``012.50s``) directly,
|
||||
instead of the caller having to map tile position back to time. Mirrors
|
||||
the macrodata/refiner contact-sheet convention.
|
||||
"""
|
||||
from PIL import ImageDraw, ImageFont
|
||||
|
||||
result = image.copy()
|
||||
draw = ImageDraw.Draw(result)
|
||||
font = ImageFont.load_default()
|
||||
label = f"{timestamp:06.2f}s"
|
||||
left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
|
||||
text_w, text_h = right - left, bottom - top
|
||||
pad = max(3, round(min(image.width, image.height) * 0.018))
|
||||
draw.rectangle((0, 0, text_w + pad * 2, text_h + pad * 2), fill=(0, 0, 0))
|
||||
draw.text((pad - left, pad - top), label, fill=(255, 255, 255), font=font)
|
||||
return result
|
||||
|
||||
|
||||
def to_contact_sheet_blocks(
|
||||
frames: Sequence[Any],
|
||||
timestamps: Sequence[float],
|
||||
*,
|
||||
columns: int = 5,
|
||||
frames_per_sheet: int = 20,
|
||||
frame_width: int = 224,
|
||||
quality: int = 84,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Pack decoded frames into timestamped JPEG contact-sheet image blocks.
|
||||
|
||||
Each frame is resized to ``frame_width`` wide, stamped with its
|
||||
episode-relative timestamp, and tiled row-major into grids of
|
||||
``frames_per_sheet`` (``columns`` wide). One ``{"type":"image", ...}``
|
||||
block is returned per grid; many frames collapse into a few images, so a
|
||||
long episode's temporal coverage stays dense at a fraction of the vision
|
||||
tokens N separate frames would cost. ``frames`` and ``timestamps`` must be
|
||||
aligned and equal length. Returns ``[]`` for empty input.
|
||||
"""
|
||||
from PIL import Image
|
||||
|
||||
if not frames:
|
||||
return []
|
||||
columns = max(1, columns)
|
||||
frames_per_sheet = max(1, frames_per_sheet)
|
||||
rows_per_sheet = math.ceil(frames_per_sheet / columns)
|
||||
|
||||
tiles: list[PIL.Image.Image] = []
|
||||
for ts, frame in zip(timestamps, frames, strict=False):
|
||||
img = _frame_to_pil(frame)
|
||||
if not isinstance(img, PIL.Image.Image):
|
||||
continue
|
||||
img = img.convert("RGB")
|
||||
if img.width != frame_width:
|
||||
height = max(1, round(img.height * frame_width / img.width))
|
||||
img = img.resize((frame_width, height), resample=Image.Resampling.BILINEAR)
|
||||
tiles.append(_draw_timestamp_badge(img, float(ts)))
|
||||
if not tiles:
|
||||
return []
|
||||
|
||||
blocks: list[dict[str, Any]] = []
|
||||
for start in range(0, len(tiles), frames_per_sheet):
|
||||
chunk = tiles[start : start + frames_per_sheet]
|
||||
cell_w = max(tile.width for tile in chunk)
|
||||
cell_h = max(tile.height for tile in chunk)
|
||||
sheet = Image.new("RGB", (cell_w * columns, cell_h * rows_per_sheet), color=(0, 0, 0))
|
||||
for i, tile in enumerate(chunk):
|
||||
x = (i % columns) * cell_w
|
||||
y = (i // columns) * cell_h
|
||||
sheet.paste(tile, (x, y))
|
||||
# JPEG round-trip at ``quality`` to match the refiner convention and
|
||||
# shrink the wire payload; vision-token count is set by resolution, so
|
||||
# the real saving is the grid packing, not the codec.
|
||||
buf = io.BytesIO()
|
||||
sheet.save(buf, format="JPEG", quality=quality)
|
||||
buf.seek(0)
|
||||
blocks.append({"type": "image", "image": Image.open(buf).convert("RGB")})
|
||||
return blocks
|
||||
@@ -1,25 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .general_vqa import GeneralVqaModule
|
||||
from .interjections_and_speech import InterjectionsAndSpeechModule
|
||||
from .plan_subtasks_memory import PlanSubtasksMemoryModule
|
||||
|
||||
__all__ = [
|
||||
"GeneralVqaModule",
|
||||
"InterjectionsAndSpeechModule",
|
||||
"PlanSubtasksMemoryModule",
|
||||
]
|
||||
@@ -1,248 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``vqa`` module: general VQA at a timed cadence.
|
||||
|
||||
Every ``1/hz`` seconds an emission tick fires; each tick anchors ``K``
|
||||
consecutive frames, and every anchored frame gets its own VQA pair. Each
|
||||
pair is grounded on that single anchor frame — there is no per-pair frame
|
||||
window. For datasets with multiple cameras, every anchored frame produces
|
||||
one ``(vqa, user)`` + ``(vqa, assistant)`` pair *per camera*: each pair is
|
||||
generated against that camera's frame and stamped with the matching
|
||||
``camera`` field on the emitted rows. The resolver disambiguates via
|
||||
``camera=...``; recipes that consume VQA do so through one sub-recipe
|
||||
per camera (see ``recipes/pi05_hirobot.yaml``).
|
||||
|
||||
Within a single (frame, camera) we still emit at most one ``(vqa, user)``
|
||||
and one ``(vqa, assistant)`` row, so the resolver contract stays scalar.
|
||||
|
||||
Question types covered (per the plan's ``vqa`` table): bbox, keypoint,
|
||||
count, attribute, spatial. The assistant's ``content`` is a JSON string
|
||||
whose schema depends on the question type. Malformed JSON triggers one
|
||||
retry inside :meth:`VlmClient.generate_json`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import VqaConfig
|
||||
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord
|
||||
from ..staging import EpisodeStaging
|
||||
from ..validator import classify_vqa_answer
|
||||
from ..vlm_client import VlmClient
|
||||
|
||||
|
||||
def _emission_anchor_indices(frame_timestamps: Sequence[float], hz: float, k: int) -> list[int]:
|
||||
"""Return the relative frame indices to anchor VQA emissions to.
|
||||
|
||||
For each emission tick (every ``1/hz`` seconds), we anchor ``k``
|
||||
consecutive frames starting at the tick. Ticks fall on the nearest
|
||||
available source frame timestamp.
|
||||
"""
|
||||
if hz <= 0 or k <= 0 or not frame_timestamps:
|
||||
return []
|
||||
t0 = frame_timestamps[0]
|
||||
t_last = frame_timestamps[-1]
|
||||
period = 1.0 / hz
|
||||
indices: list[int] = []
|
||||
t = t0
|
||||
while t <= t_last + 1e-9:
|
||||
# find the index of the nearest frame to t
|
||||
nearest_i = min(range(len(frame_timestamps)), key=lambda i: abs(frame_timestamps[i] - t))
|
||||
for offset in range(k):
|
||||
j = nearest_i + offset
|
||||
if j >= len(frame_timestamps):
|
||||
break
|
||||
if not indices or indices[-1] != j:
|
||||
indices.append(j)
|
||||
t += period
|
||||
# dedupe while preserving order
|
||||
seen: set[int] = set()
|
||||
deduped: list[int] = []
|
||||
for i in indices:
|
||||
if i in seen:
|
||||
continue
|
||||
seen.add(i)
|
||||
deduped.append(i)
|
||||
return deduped
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneralVqaModule:
|
||||
"""Emit grounded VQA pairs at a timed cadence."""
|
||||
|
||||
vlm: VlmClient
|
||||
config: VqaConfig
|
||||
seed: int = 1729
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
_warned_no_camera: bool = field(default=False, init=False, repr=False)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
if not record.frame_timestamps:
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
rng = random.Random(f"{self.seed}:{record.episode_index}:vqa")
|
||||
anchor_idx = _emission_anchor_indices(
|
||||
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
|
||||
)
|
||||
cameras = self._target_cameras()
|
||||
if not cameras:
|
||||
# No camera available — emit nothing rather than producing
|
||||
# untagged rows that would fail validation. Surface a loud one-
|
||||
# time warning so this is never silently a no-op.
|
||||
if not self._warned_no_camera:
|
||||
logging.getLogger(__name__).warning(
|
||||
"vqa module found no cameras on the frame provider — "
|
||||
"every episode will emit zero VQA rows. Check that the "
|
||||
"dataset declares observation.images.* features in "
|
||||
"meta/info.json; passing --vlm.camera_key=<key> at the "
|
||||
"CLI now also seeds the cameras list as a fallback."
|
||||
)
|
||||
self._warned_no_camera = True
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
|
||||
# Build all messages first (one per (frame, camera)), then issue them
|
||||
# as a single batched generate_json call so the client can fan them
|
||||
# out concurrently.
|
||||
per_call: list[tuple[float, str, str, list[dict[str, Any]]]] = []
|
||||
for idx in anchor_idx:
|
||||
ts = float(record.frame_timestamps[idx])
|
||||
qtype = rng.choice(self.config.question_types)
|
||||
for camera in cameras:
|
||||
messages = self._build_messages(record, qtype, ts, camera)
|
||||
# Skip cameras that decoded to zero frames at this ts: no point
|
||||
# asking the VLM to ground a bbox without an image.
|
||||
if not _has_image_block(messages):
|
||||
continue
|
||||
per_call.append((ts, camera, qtype, messages))
|
||||
|
||||
if not per_call:
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
|
||||
results = self.vlm.generate_json([m for _, _, _, m in per_call])
|
||||
|
||||
rows: list[dict[str, Any]] = []
|
||||
for (ts, camera, _qtype, _messages), result in zip(per_call, results, strict=True):
|
||||
qa = self._postprocess(result)
|
||||
if qa is None:
|
||||
continue
|
||||
question, answer = qa
|
||||
rows.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": question,
|
||||
"style": "vqa",
|
||||
"timestamp": ts,
|
||||
"camera": camera,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": json.dumps(answer, sort_keys=True),
|
||||
"style": "vqa",
|
||||
"timestamp": ts,
|
||||
"camera": camera,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
staging.write("vqa", rows)
|
||||
|
||||
def _target_cameras(self) -> list[str]:
|
||||
"""Return the cameras the ``vqa`` module should iterate per anchored frame.
|
||||
|
||||
Defaults to every camera the provider exposes. Datasets with no
|
||||
cameras (or test/null providers) yield an empty list, which makes
|
||||
``run_episode`` a no-op.
|
||||
|
||||
When ``config.restrict_to_default_camera`` is set, VQA grounds on
|
||||
only the provider's default camera (the single ``--vlm.camera_key``
|
||||
stream), matching the plan / interjection modules so the whole
|
||||
pipeline focuses on one view.
|
||||
"""
|
||||
all_cameras = list(getattr(self.frame_provider, "camera_keys", []) or [])
|
||||
if getattr(self.config, "restrict_to_default_camera", False):
|
||||
default = getattr(self.frame_provider, "camera_key", None)
|
||||
if default and default in all_cameras:
|
||||
return [default]
|
||||
# ``restrict_to_default_camera`` is set but the configured default
|
||||
# isn't one the provider exposes. Returning it anyway would make
|
||||
# ``_decode`` raise a KeyError deep in frame extraction, so warn and
|
||||
# fall through to every available camera instead.
|
||||
if default:
|
||||
logging.getLogger(__name__).warning(
|
||||
"restrict_to_default_camera is set but camera_key=%r is not in the "
|
||||
"provider's cameras %s; grounding VQA on all available cameras instead.",
|
||||
default,
|
||||
all_cameras,
|
||||
)
|
||||
return all_cameras
|
||||
|
||||
def _build_messages(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
question_type: str,
|
||||
frame_timestamp: float,
|
||||
camera_key: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
prompt = load_prompt("vqa").format(
|
||||
episode_task=record.episode_task,
|
||||
question_type=question_type,
|
||||
)
|
||||
images = self.frame_provider.frames_at(record, [frame_timestamp], camera_key=camera_key)
|
||||
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def _postprocess(self, result: Any) -> tuple[str, dict[str, Any]] | None:
|
||||
if not isinstance(result, dict):
|
||||
return None
|
||||
question = result.get("question")
|
||||
answer = result.get("answer")
|
||||
if not isinstance(question, str) or not question.strip():
|
||||
return None
|
||||
if not isinstance(answer, dict):
|
||||
return None
|
||||
# The validator will enforce shape; here we just sanity-check that the
|
||||
# answer matches *some* known shape so we can drop garbage early.
|
||||
if classify_vqa_answer(answer) is None:
|
||||
return None
|
||||
return question.strip(), answer
|
||||
|
||||
|
||||
def _has_image_block(messages: list[dict[str, Any]]) -> bool:
|
||||
"""Return True if any user content block is a populated image block."""
|
||||
for msg in messages:
|
||||
content = msg.get("content")
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "image":
|
||||
return True
|
||||
return False
|
||||
@@ -1,211 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``interjections`` module: interjections + paired speech (EVENT styles + speech atoms).
|
||||
|
||||
Two sub-passes:
|
||||
|
||||
1. At ``t=0``, emit ONLY a speech tool-call atom (acknowledgement of the
|
||||
canonical task). No interjection row — the canonical task is already the
|
||||
user utterance from ``meta/tasks.parquet``.
|
||||
|
||||
2. For mid-episode interruptions, emit a co-timestamped pair:
|
||||
{role:user, style:interjection, content:<text>}
|
||||
speech atom (role:assistant, style:None, tool_calls=[say(...)])
|
||||
Both rows go in ``language_events`` at the same timestamp.
|
||||
|
||||
The ``plan`` module's :meth:`run_plan_updates` reuses this module's
|
||||
interjection timestamps to refresh the ``plan`` row at the same instant.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import InterjectionsConfig
|
||||
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
||||
from ..staging import EpisodeStaging
|
||||
from ..vlm_client import VlmClient
|
||||
from ..writer import speech_atom
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterjectionsAndSpeechModule:
|
||||
"""Generate task-start speech and mid-episode interjection/speech pairs."""
|
||||
|
||||
vlm: VlmClient
|
||||
config: InterjectionsConfig
|
||||
seed: int = 1729
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
rows: list[dict[str, Any]] = []
|
||||
if record.frame_timestamps:
|
||||
t0 = float(record.frame_timestamps[0])
|
||||
initial = self._initial_speech(record)
|
||||
if initial:
|
||||
rows.append(speech_atom(t0, initial))
|
||||
# Pull the ``plan`` module's subtask spans for this episode so the
|
||||
# interjection prompt can ground itself in the actual current
|
||||
# subtask at each chosen timestamp. The ``plan`` module ran first.
|
||||
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
|
||||
subtask_spans = reconstruct_subtask_spans(staging.read("plan"), episode_end_t=episode_end_t)
|
||||
rows.extend(self._mid_episode_interjections(record, subtask_spans))
|
||||
staging.write("interjections", rows)
|
||||
|
||||
@staticmethod
|
||||
def _subtask_at(spans: Sequence[dict[str, Any]], t: float) -> str | None:
|
||||
current: str | None = None
|
||||
for span in spans:
|
||||
if float(span["start"]) <= t:
|
||||
current = span.get("text")
|
||||
else:
|
||||
break
|
||||
return current
|
||||
|
||||
def _initial_speech(self, record: EpisodeRecord) -> str | None:
|
||||
prompt = load_prompt("interjections_initial_speech").format(
|
||||
episode_task=record.episode_task,
|
||||
)
|
||||
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if isinstance(result, dict) and isinstance(result.get("text"), str):
|
||||
text = result["text"].strip()
|
||||
if text:
|
||||
return text
|
||||
return None
|
||||
|
||||
def _mid_episode_interjections(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
subtask_spans: Sequence[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Generate interjections aligned with the actual demo trajectory.
|
||||
|
||||
Teleop data is frozen — the robot already executed every step in
|
||||
the video. A *counterfactual* interjection like "actually skip
|
||||
the wipe" contradicts what then happens in the video, which is
|
||||
what qwen36moe-10/11 surfaced as low-quality interjections.
|
||||
|
||||
Instead, anchor every interjection at a subtask boundary and
|
||||
write it as a natural user request for the *upcoming* subtask.
|
||||
The robot's visible next behavior IS the interjection's effect,
|
||||
so the training signal stays consistent: interjection text →
|
||||
plan refresh → action stream all line up.
|
||||
"""
|
||||
if self.config.max_interjections_per_episode <= 0:
|
||||
return []
|
||||
if len(subtask_spans) < 2:
|
||||
# Need at least one transition (subtask 0 → subtask 1).
|
||||
return []
|
||||
# Deterministic per-episode RNG so reruns are stable across SLURM jobs.
|
||||
rng = random.Random(f"{self.seed}:{record.episode_index}:interjection")
|
||||
|
||||
# Boundaries: the start time of every subtask except the first
|
||||
# (which is just t0 and is covered by the initial-task speech atom).
|
||||
boundaries: list[tuple[float, str, str]] = []
|
||||
for i in range(1, len(subtask_spans)):
|
||||
ts = float(subtask_spans[i]["start"])
|
||||
if ts < self.config.interjection_min_t:
|
||||
continue
|
||||
prev_text = (subtask_spans[i - 1].get("text") or "").strip()
|
||||
next_text = (subtask_spans[i].get("text") or "").strip()
|
||||
if not next_text:
|
||||
continue
|
||||
boundaries.append((ts, prev_text, next_text))
|
||||
if not boundaries:
|
||||
return []
|
||||
|
||||
n = min(self.config.max_interjections_per_episode, len(boundaries))
|
||||
chosen = sorted(rng.sample(boundaries, n), key=lambda b: b[0])
|
||||
|
||||
out: list[dict[str, Any]] = []
|
||||
for t, prev_subtask, next_subtask in chosen:
|
||||
t_snap = snap_to_frame(t, record.frame_timestamps)
|
||||
# Window straddles the boundary so the VLM sees the end of the
|
||||
# previous subtask and the start of the next one — same
|
||||
# conditioning the policy will see at training time.
|
||||
window_ts = self._window_timestamps(t_snap, record.frame_timestamps)
|
||||
prompt = load_prompt("interjections_interjection").format(
|
||||
episode_task=record.episode_task,
|
||||
prev_subtask=prev_subtask or "(starting from initial state)",
|
||||
next_subtask=next_subtask,
|
||||
timestamp=t_snap,
|
||||
window_seconds=self.config.interjection_window_seconds,
|
||||
)
|
||||
images = self.frame_provider.frames_at(record, window_ts)
|
||||
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||
messages = [{"role": "user", "content": content}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if not isinstance(result, dict):
|
||||
continue
|
||||
interjection_text = result.get("interjection")
|
||||
speech_text = result.get("speech")
|
||||
if not isinstance(interjection_text, str) or not interjection_text.strip():
|
||||
continue
|
||||
if not isinstance(speech_text, str) or not speech_text.strip():
|
||||
continue
|
||||
out.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": interjection_text.strip(),
|
||||
"style": "interjection",
|
||||
"timestamp": t_snap,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
out.append(speech_atom(t_snap, speech_text.strip()))
|
||||
return out
|
||||
|
||||
def _window_timestamps(self, t_anchor: float, frame_timestamps: Sequence[float]) -> list[float]:
|
||||
"""Return a small set of frame timestamps centered on ``t_anchor``.
|
||||
|
||||
The window straddles the subtask boundary the interjection sits
|
||||
on: roughly half the frames cover the end of the previous
|
||||
subtask, half cover the start of the next one. The VLM therefore
|
||||
sees BOTH what just finished AND what's about to start, which is
|
||||
the conditioning we need to write a natural "now please do X"
|
||||
request that matches the visible upcoming behavior.
|
||||
"""
|
||||
if not frame_timestamps:
|
||||
return [t_anchor]
|
||||
n = max(1, int(self.config.interjection_window_frames))
|
||||
if n == 1:
|
||||
return [t_anchor]
|
||||
window = float(self.config.interjection_window_seconds)
|
||||
step = window / max(1, n - 1)
|
||||
# Center the window on the anchor so half lands before, half after.
|
||||
start_offset = -window / 2.0
|
||||
targets = [t_anchor + start_offset + step * i for i in range(n)]
|
||||
first_ts = float(frame_timestamps[0])
|
||||
last_ts = float(frame_timestamps[-1])
|
||||
snapped: list[float] = []
|
||||
seen: set[float] = set()
|
||||
for tgt in targets:
|
||||
clamped = min(last_ts, max(first_ts, tgt))
|
||||
t = snap_to_frame(clamped, frame_timestamps)
|
||||
if t not in seen:
|
||||
seen.add(t)
|
||||
snapped.append(t)
|
||||
return snapped or [t_anchor]
|
||||
@@ -1,780 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``plan`` module: subtask decomposition + plan + memory (PERSISTENT styles)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import PlanConfig
|
||||
from ..frames import (
|
||||
FrameProvider,
|
||||
null_provider,
|
||||
to_contact_sheet_blocks,
|
||||
)
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
||||
from ..staging import EpisodeStaging
|
||||
from ..vlm_client import VlmClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Prepended to every describe / segment prompt so the VLM knows the images are
|
||||
# timestamped contact-sheet grids, not a single video, and reads the burned-in
|
||||
# per-tile timestamp when choosing boundaries.
|
||||
def _contact_sheet_preamble(columns: int) -> str:
|
||||
return (
|
||||
"CONTACT SHEETS — how to read the images below:\n"
|
||||
f"- Each image is a grid of sampled video frames, {columns} per row, "
|
||||
"with time running left-to-right then top-to-bottom (row-major).\n"
|
||||
"- Each frame has its timestamp burned into the top-left corner, e.g. "
|
||||
'"012.50s". Use that printed timestamp (not the tile position) when you '
|
||||
"choose start/end times; boundaries should land on or near a printed "
|
||||
"timestamp.\n"
|
||||
"- Frames continue across grids: an action may span the end of one sheet "
|
||||
"and the start of the next, so do not place a boundary just because a new "
|
||||
"image begins.\n\n"
|
||||
)
|
||||
|
||||
|
||||
# Appended to every describe (and segment) prompt. A visual, causal definition
|
||||
# of where one event ends and the next begins — adapted from macrodata/refiner —
|
||||
# to sharpen cut points while the existing prompt keeps owning the imperative
|
||||
# phrasing.
|
||||
_CAUSAL_BOUNDARY_RULES = (
|
||||
"EVENT BOUNDARIES — where one event ends and the next begins:\n"
|
||||
"- Start a new event whenever the world state changes: an object becomes "
|
||||
"held (the gripper closes on it), an object is released (the gripper opens "
|
||||
"and it stays put), an object reaches a new location, a lid/door/drawer "
|
||||
"changes open/closed state, a tool starts or stops affecting a surface, or "
|
||||
"contents visibly move (e.g. poured).\n"
|
||||
"- If a single action changes the same state gradually and continuously, "
|
||||
"keep it as ONE event — do not split it.\n"
|
||||
"- If the same action repeats on different objects or target locations, "
|
||||
"treat each repetition as a separate event.\n"
|
||||
"- Do NOT create boundaries for idle time, camera motion, hesitation, or "
|
||||
"tiny hand adjustments."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlanSubtasksMemoryModule:
|
||||
"""Generate subtask spans, plan, and memory rows.
|
||||
|
||||
All output is persistent (lives in ``language_persistent``):
|
||||
|
||||
- ``subtask`` rows: one per span, stamped at the span's *start* timestamp
|
||||
(snapped to an exact frame).
|
||||
- ``plan`` rows: emitted at ``t=0``; refreshed at every interjection
|
||||
timestamp via :meth:`run_plan_updates` (called by the executor after
|
||||
the ``interjections`` module completes).
|
||||
- ``memory`` rows: emitted at each subtask boundary (= subtask start
|
||||
timestamp from the second subtask onward).
|
||||
"""
|
||||
|
||||
vlm: VlmClient
|
||||
config: PlanConfig
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
rows: list[dict[str, Any]] = []
|
||||
# Task driving every plan-module prompt: canonical episode_task, or a
|
||||
# video-derived one when it's empty/placeholder (see derive_task_*).
|
||||
effective_task = self._resolve_effective_task(record)
|
||||
# task_aug rows at t=0: phrasings the renderer rotates ${task} through.
|
||||
# Either the structured 5-axis taxonomy (task_aug_axes.enabled) or
|
||||
# free-form n_task_rephrasings; the effective task is always emitted
|
||||
# first so the rotation covers the source-of-truth phrasing.
|
||||
t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0
|
||||
variants: list[str] | None = None
|
||||
if self.config.task_aug_axes.enabled and effective_task:
|
||||
variants = self._generate_task_aug_by_axes(effective_task, self.config.task_aug_axes)
|
||||
elif self.config.n_task_rephrasings > 0 and effective_task:
|
||||
variants = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings)
|
||||
if variants is not None:
|
||||
rows.extend(self._task_aug_rows([effective_task, *variants], t0))
|
||||
|
||||
subtask_spans = self._generate_subtasks(record, task=effective_task)
|
||||
|
||||
# subtask rows
|
||||
for span in subtask_spans:
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": span["text"],
|
||||
"style": "subtask",
|
||||
"timestamp": snap_to_frame(span["start"], record.frame_timestamps),
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
# Plan rows at every subtask boundary (incl. t=0). The plan is a
|
||||
# numbered list of still-todo subtasks, so re-emitting at each
|
||||
# boundary makes it shrink as work progresses — ${plan} at frame t is
|
||||
# exactly what's left to do.
|
||||
if self.config.emit_plan:
|
||||
for span in subtask_spans:
|
||||
boundary_t = snap_to_frame(span["start"], record.frame_timestamps)
|
||||
plan_text = self._generate_plan(
|
||||
record, subtask_spans, refresh_t=boundary_t, task=effective_task
|
||||
)
|
||||
if plan_text is not None:
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": plan_text,
|
||||
"style": "plan",
|
||||
"timestamp": float(boundary_t),
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
# memory rows at every subtask boundary except the very first start;
|
||||
# skipped entirely when ``emit_memory`` is False (subtasks-only / plan-only).
|
||||
prior_memory = ""
|
||||
memory_boundaries = enumerate(subtask_spans[1:], start=1) if self.config.emit_memory else []
|
||||
for i, span in memory_boundaries:
|
||||
completed = subtask_spans[i - 1]["text"]
|
||||
remaining = [s["text"] for s in subtask_spans[i:]]
|
||||
mem_text = self._generate_memory(record, prior_memory, completed, remaining, task=effective_task)
|
||||
if mem_text:
|
||||
ts = snap_to_frame(span["start"], record.frame_timestamps)
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": mem_text,
|
||||
"style": "memory",
|
||||
"timestamp": ts,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
prior_memory = mem_text
|
||||
staging.write("plan", rows)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Task derivation + rephrasings
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
_PLACEHOLDER_TASKS: frozenset[str] = frozenset(
|
||||
{
|
||||
"debug",
|
||||
"test",
|
||||
"tbd",
|
||||
"todo",
|
||||
"n/a",
|
||||
"na",
|
||||
"untitled",
|
||||
"unnamed",
|
||||
"default",
|
||||
"placeholder",
|
||||
}
|
||||
)
|
||||
|
||||
def _resolve_effective_task(self, record: EpisodeRecord) -> str:
|
||||
"""Decide which task string drives the ``plan`` module for this episode.
|
||||
|
||||
Returns the user-supplied ``record.episode_task`` unless
|
||||
``derive_task_from_video`` says otherwise (see config docstring).
|
||||
Falls back gracefully to the canonical task if video derivation
|
||||
fails.
|
||||
"""
|
||||
canonical = (record.episode_task or "").strip()
|
||||
mode = (self.config.derive_task_from_video or "off").strip().lower()
|
||||
if mode == "always":
|
||||
derived = self._derive_task_from_video(record)
|
||||
return derived or canonical
|
||||
if mode == "if_short" and self._task_seems_bad(canonical):
|
||||
derived = self._derive_task_from_video(record)
|
||||
if derived:
|
||||
return derived
|
||||
return canonical
|
||||
|
||||
def _task_seems_bad(self, task: str) -> bool:
|
||||
if not task:
|
||||
return True
|
||||
if len(task.split()) < int(self.config.derive_task_min_words):
|
||||
return True
|
||||
return task.lower() in self._PLACEHOLDER_TASKS
|
||||
|
||||
@staticmethod
|
||||
def _task_aug_rows(phrasings: Sequence[str], t0: float) -> list[dict[str, Any]]:
|
||||
"""Build deduplicated ``task_aug`` rows (role=user) at ``t0``."""
|
||||
seen: set[str] = set()
|
||||
rows: list[dict[str, Any]] = []
|
||||
for phrasing in phrasings:
|
||||
key = phrasing.strip()
|
||||
if not key or key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
rows.append(
|
||||
{"role": "user", "content": key, "style": "task_aug", "timestamp": t0, "tool_calls": None}
|
||||
)
|
||||
return rows
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# VLM call helpers — every plan-module prompt follows the same shape:
|
||||
# build messages → single VLM call → pull a named field.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _vlm_field(self, messages: list[dict[str, Any]], field: str) -> Any:
|
||||
"""Run a single VLM call and return ``result[field]`` or ``None``.
|
||||
|
||||
Centralizes the ``vlm.generate_json([m])[0]`` + ``isinstance(dict)``
|
||||
dance every prompt-call site needs.
|
||||
"""
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if isinstance(result, dict):
|
||||
return result.get(field)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _text_message(text: str) -> list[dict[str, Any]]:
|
||||
"""One-shot text-only user message wrapped for ``generate_json``."""
|
||||
return [{"role": "user", "content": [{"type": "text", "text": text}]}]
|
||||
|
||||
def _video_message(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
prompt: str,
|
||||
window: tuple[float, float] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""User message combining the (optionally windowed) contact sheets with ``prompt``.
|
||||
|
||||
The prompt is always prefixed with a short explanation of how to read
|
||||
the timestamped grids, so the model treats them as one ordered
|
||||
sequence of frames rather than unrelated images.
|
||||
"""
|
||||
prompt = _contact_sheet_preamble(self.config.contact_sheet_columns) + prompt
|
||||
content = [*self._episode_video_block(record, window=window), {"type": "text", "text": prompt}]
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def _derive_task_from_video(self, record: EpisodeRecord) -> str | None:
|
||||
"""Ask the VLM "what is this video about" with no task hint at all."""
|
||||
text = self._vlm_field(self._video_message(record, load_prompt("plan_video_task")), "task")
|
||||
return text.strip() if isinstance(text, str) and text.strip() else None
|
||||
|
||||
def _generate_task_rephrasings(self, base_task: str, *, n: int) -> list[str]:
|
||||
"""Generate ``n`` text-only paraphrases of ``base_task``."""
|
||||
if n <= 0 or not base_task:
|
||||
return []
|
||||
prompt = load_prompt("plan_task_rephrasings").format(base_task=base_task, n=n)
|
||||
raw = self._vlm_field(self._text_message(prompt), "rephrasings")
|
||||
if not isinstance(raw, list):
|
||||
return []
|
||||
out = [item.strip().strip('"').strip("'") for item in raw if isinstance(item, str)]
|
||||
return [s for s in out if s][:n]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Structured 5-axis task augmentation (EgoMimic-style taxonomy)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _generate_task_aug_by_axes(self, base_task: str, axes_cfg: Any) -> list[str]:
|
||||
"""One VLM call → variants along the 5-axis taxonomy.
|
||||
|
||||
Variants from all axes are flattened into a single list (the
|
||||
downstream pipeline doesn't need to know about the per-axis
|
||||
bucketing — every variant becomes a ``task_aug`` row). Order
|
||||
is preserved for reproducibility: synonym_paraphrase first,
|
||||
then omit_arm, then omit_orientation, then omit_grasp_method,
|
||||
then combined_omissions.
|
||||
"""
|
||||
if not base_task:
|
||||
return []
|
||||
prompt = load_prompt("plan_task_aug_axes").format(
|
||||
base_task=base_task,
|
||||
n_synonym=axes_cfg.synonym_paraphrase,
|
||||
n_omit_arm=axes_cfg.omit_arm,
|
||||
n_omit_orientation=axes_cfg.omit_orientation,
|
||||
n_omit_grasp_method=axes_cfg.omit_grasp_method,
|
||||
n_combined=axes_cfg.combined_omissions,
|
||||
)
|
||||
result = self.vlm.generate_json([self._text_message(prompt)])[0]
|
||||
if not isinstance(result, dict):
|
||||
return []
|
||||
ordered_axes = (
|
||||
"synonym_paraphrase",
|
||||
"omit_arm",
|
||||
"omit_orientation",
|
||||
"omit_grasp_method",
|
||||
"combined_omissions",
|
||||
)
|
||||
flat: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for axis in ordered_axes:
|
||||
entries = result.get(axis)
|
||||
if not isinstance(entries, list):
|
||||
continue
|
||||
for item in entries:
|
||||
if not isinstance(item, str):
|
||||
continue
|
||||
key = item.strip().strip('"').strip("'")
|
||||
if not key or key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
flat.append(key)
|
||||
return flat
|
||||
|
||||
def _episode_video_block(
|
||||
self, record: EpisodeRecord, window: tuple[float, float] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Timestamped contact sheets for the describe / segmentation prompts.
|
||||
|
||||
Always renders the (optionally windowed) episode as contact sheets:
|
||||
frames sampled at ``frames_per_second`` and packed into timestamped
|
||||
JPEG grids. ``max_frames_per_prompt`` caps the frame count; whole
|
||||
episodes that exceed it are windowed upstream in
|
||||
:meth:`_generate_subtasks` so each call stays within budget while the
|
||||
full episode keeps its sampling density.
|
||||
|
||||
When ``window=(w0, w1)`` is given the badges are WINDOW-RELATIVE
|
||||
(``ts - w0``) to match the window-relative time frame the
|
||||
segmentation prompt works in (spans are offset back to absolute time
|
||||
afterwards).
|
||||
"""
|
||||
if not record.frame_timestamps:
|
||||
return []
|
||||
if window is not None:
|
||||
w0, w1 = float(window[0]), float(window[1])
|
||||
dur = max(0.0, w1 - w0)
|
||||
n = max(1, int(round(dur * self.config.frames_per_second)) + 1)
|
||||
n = min(n, self.config.max_frames_per_prompt)
|
||||
if n <= 1 or dur <= 0.0:
|
||||
timestamps = [0.5 * (w0 + w1)]
|
||||
else:
|
||||
step = dur / (n - 1)
|
||||
timestamps = [w0 + i * step for i in range(n)]
|
||||
frames = self.frame_provider.frames_at(record, timestamps)
|
||||
rel = [ts - w0 for ts in timestamps[: len(frames)]]
|
||||
return self._contact_sheet_blocks(frames, rel)
|
||||
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||
n = max(1, int(round(episode_duration * self.config.frames_per_second)) + 1)
|
||||
n = min(n, self.config.max_frames_per_prompt)
|
||||
timestamps = self._uniform_episode_timestamps(record, n)
|
||||
frames = self.frame_provider.frames_at(record, timestamps)
|
||||
return self._contact_sheet_blocks(frames, timestamps[: len(frames)])
|
||||
|
||||
@staticmethod
|
||||
def _uniform_episode_timestamps(record: EpisodeRecord, n: int) -> list[float]:
|
||||
"""``n`` episode-relative timestamps spanning ``[t0, t_last]`` uniformly."""
|
||||
ts = record.frame_timestamps
|
||||
if n >= len(ts):
|
||||
return [float(t) for t in ts]
|
||||
t0, t_last = float(ts[0]), float(ts[-1])
|
||||
if t_last <= t0 or n <= 1:
|
||||
return [t0] * max(1, n)
|
||||
step = (t_last - t0) / (n - 1)
|
||||
return [t0 + i * step for i in range(n)]
|
||||
|
||||
def _contact_sheet_blocks(self, frames: list[Any], timestamps: list[float]) -> list[dict[str, Any]]:
|
||||
"""Build timestamped contact-sheet image blocks from decoded frames."""
|
||||
return to_contact_sheet_blocks(
|
||||
frames,
|
||||
timestamps,
|
||||
columns=self.config.contact_sheet_columns,
|
||||
frames_per_sheet=self.config.contact_sheet_frames_per_sheet,
|
||||
frame_width=self.config.contact_sheet_frame_width,
|
||||
quality=self.config.contact_sheet_quality,
|
||||
)
|
||||
|
||||
def run_plan_updates(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
staging: EpisodeStaging,
|
||||
interjection_times: Sequence[float],
|
||||
interjection_texts: Sequence[str] | None = None,
|
||||
) -> None:
|
||||
"""Append additional ``plan`` rows at every interjection timestamp.
|
||||
|
||||
Plans refresh ONLY on user interjections (event-driven). The
|
||||
interjection text is forwarded into the prompt so the refreshed plan
|
||||
reflects the user's correction.
|
||||
"""
|
||||
if not self.config.emit_plan:
|
||||
return
|
||||
existing = staging.read("plan")
|
||||
# Pass the last frame timestamp so the final span is closed (else its
|
||||
# end == start, zero duration, and a refresh inside it is missed).
|
||||
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
|
||||
spans = reconstruct_subtask_spans(existing, episode_end_t=episode_end_t)
|
||||
already_planned: set[float] = {float(r["timestamp"]) for r in existing if r.get("style") == "plan"}
|
||||
new_rows = list(existing)
|
||||
|
||||
texts: list[str | None] = (
|
||||
[None] * len(interjection_times)
|
||||
if interjection_texts is None
|
||||
else [str(t) if t else None for t in interjection_texts]
|
||||
)
|
||||
for raw_t, inter_text in zip(interjection_times, texts, strict=True):
|
||||
t = snap_to_frame(raw_t, record.frame_timestamps)
|
||||
if t in already_planned:
|
||||
continue
|
||||
already_planned.add(t)
|
||||
plan_text = self._generate_plan(record, spans, refresh_t=t, interjection=inter_text)
|
||||
if plan_text is not None:
|
||||
new_rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": plan_text,
|
||||
"style": "plan",
|
||||
"timestamp": t,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
staging.write("plan", new_rows)
|
||||
|
||||
def _generate_subtasks(self, record: EpisodeRecord, *, task: str | None = None) -> list[dict[str, Any]]:
|
||||
"""Generate subtask spans, optionally via a multi-call quality chain.
|
||||
|
||||
Single call (default): watch video → emit subtask JSON.
|
||||
|
||||
Multi-call (opt-in, higher quality, more VLM calls):
|
||||
1. ``subtask_describe_first`` — a grounding pass that narrates
|
||||
ONLY what is visible (no JSON commitment to subtasks yet);
|
||||
its description is injected into the segmentation prompt so
|
||||
the model segments its own grounded observations instead of
|
||||
pattern-matching the task text.
|
||||
2. segmentation — emit subtask JSON (as before).
|
||||
"""
|
||||
if record.row_count == 0 or not record.frame_timestamps:
|
||||
return []
|
||||
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||
effective_task = task if task is not None else record.episode_task
|
||||
|
||||
# ---- Auto-windowing (keeps the full sampling density) --------
|
||||
# Contact sheets are cheap, but a whole long episode sampled at
|
||||
# ``frames_per_second`` can still exceed ``max_frames_per_prompt``.
|
||||
# When it does, split into consecutive windows of exactly that many
|
||||
# frames (one describe→segment call each, still at the full sampling
|
||||
# density), then merge + stitch — so an episode of any length is
|
||||
# covered at full density rather than subsampled into one sparse call.
|
||||
fps = max(1e-6, float(self.config.frames_per_second))
|
||||
n_whole = int(round(episode_duration * fps)) + 1
|
||||
if n_whole > self.config.max_frames_per_prompt:
|
||||
window_s = self.config.max_frames_per_prompt / fps
|
||||
return self._generate_subtasks_windowed(record, effective_task, window_s)
|
||||
|
||||
# ---- Pass 1 (optional): grounding description ----------------
|
||||
observation_block = ""
|
||||
if getattr(self.config, "subtask_describe_first", False):
|
||||
description = self._describe_episode(record, effective_task)
|
||||
if description:
|
||||
observation_block = (
|
||||
"You watched this video and described, chronologically, "
|
||||
"ONLY what the robot actually does:\n"
|
||||
f'"""{description}"""\n\n'
|
||||
"Segment THAT grounded description (cross-checked against "
|
||||
"the video) into atomic subtasks. Do not introduce any "
|
||||
"action that is not in your description above.\n\n"
|
||||
)
|
||||
|
||||
# ---- Pass 2: segmentation ------------------------------------
|
||||
prompt = self._with_causal_rules(
|
||||
load_prompt("plan_subtasks").format(
|
||||
episode_task=effective_task,
|
||||
min_subtask_seconds=self.config.min_subtask_seconds,
|
||||
max_steps=self.config.plan_max_steps,
|
||||
episode_duration=f"{episode_duration:.3f}",
|
||||
observation_block=observation_block,
|
||||
)
|
||||
)
|
||||
spans = self._vlm_field(self._video_message(record, prompt), "subtasks")
|
||||
cleaned = self._clean_spans(spans, record)
|
||||
if not cleaned:
|
||||
return []
|
||||
|
||||
# ---- Full-episode coverage stitch ----------------------------
|
||||
# The VLM can start after t0 or leave gaps, so frames fall through
|
||||
# with no active subtask. Always stitch into a contiguous
|
||||
# [t0, t_last] cover.
|
||||
cleaned = self._stitch_full_coverage(cleaned, record)
|
||||
|
||||
return cleaned
|
||||
|
||||
def _generate_subtasks_windowed(
|
||||
self, record: EpisodeRecord, task: str, window_s: float
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Subtask generation in fixed-length windows at constant fps.
|
||||
|
||||
Splits ``[t0, t_last]`` into consecutive windows of ``window_s``
|
||||
seconds, runs the describe -> segment chain on each window's own
|
||||
frames (sampled at ``frames_per_second``), offsets
|
||||
each window's spans back to absolute episode time, then merges +
|
||||
stitches into a contiguous whole-episode cover.
|
||||
"""
|
||||
t0 = float(record.frame_timestamps[0])
|
||||
t_last = float(record.frame_timestamps[-1])
|
||||
all_spans: list[dict[str, Any]] = []
|
||||
w0 = t0
|
||||
n_windows = 0
|
||||
while w0 < t_last - 1e-6:
|
||||
w1 = min(w0 + window_s, t_last)
|
||||
all_spans.extend(self._subtasks_for_window(record, task, w0, w1))
|
||||
n_windows += 1
|
||||
w0 = w1
|
||||
logger.info(
|
||||
"episode %d: windowed subtask gen over %d window(s) of %.1fs -> %d raw spans",
|
||||
record.episode_index,
|
||||
n_windows,
|
||||
window_s,
|
||||
len(all_spans),
|
||||
)
|
||||
# Merge across windows: clamp to the absolute episode, sort, and
|
||||
# frame-snap to distinct starts (handles any boundary collisions).
|
||||
cleaned = self._clean_spans(all_spans, record)
|
||||
if not cleaned:
|
||||
return []
|
||||
return self._stitch_full_coverage(cleaned, record)
|
||||
|
||||
def _subtasks_for_window(
|
||||
self, record: EpisodeRecord, task: str, w0: float, w1: float
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Run describe -> segment on one ``[w0, w1]`` window.
|
||||
|
||||
The model works in window-RELATIVE time ``[0, L]`` (it perceives
|
||||
the window as a clip starting at 0); spans are offset back to
|
||||
absolute ``[w0, w1]`` before returning.
|
||||
"""
|
||||
window = (w0, w1)
|
||||
win_len = max(0.0, w1 - w0)
|
||||
|
||||
observation_block = ""
|
||||
if getattr(self.config, "subtask_describe_first", False):
|
||||
description = self._describe_episode(record, task, window=window)
|
||||
if description:
|
||||
observation_block = (
|
||||
"You watched this video clip and described, chronologically, "
|
||||
"ONLY what the robot actually does:\n"
|
||||
f'"""{description}"""\n\n'
|
||||
"Segment THAT grounded description (cross-checked against "
|
||||
"the clip) into atomic subtasks. Do not introduce any "
|
||||
"action that is not in your description above.\n\n"
|
||||
)
|
||||
|
||||
prompt = self._with_causal_rules(
|
||||
load_prompt("plan_subtasks").format(
|
||||
episode_task=task,
|
||||
min_subtask_seconds=self.config.min_subtask_seconds,
|
||||
max_steps=self.config.plan_max_steps,
|
||||
episode_duration=f"{win_len:.3f}",
|
||||
observation_block=observation_block,
|
||||
)
|
||||
)
|
||||
spans = self._vlm_field(self._video_message(record, prompt, window=window), "subtasks")
|
||||
# Window-relative clamp; no frame-snap dedupe yet (done on the
|
||||
# merged absolute set).
|
||||
cleaned = self._clean_spans(spans, record, bounds=(0.0, win_len), dedupe=False)
|
||||
if not cleaned:
|
||||
return []
|
||||
|
||||
# Offset window-relative spans back to absolute episode time.
|
||||
for s in cleaned:
|
||||
s["start"] = w0 + float(s["start"])
|
||||
s["end"] = w0 + float(s["end"])
|
||||
return cleaned
|
||||
|
||||
def _stitch_full_coverage(
|
||||
self, spans: list[dict[str, Any]], record: EpisodeRecord
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Make subtask spans tile the full episode with no gaps.
|
||||
|
||||
* The first subtask starts at the episode's first frame ``t0``
|
||||
(any idle / approach before the first labelled action is folded
|
||||
into it), so every early frame has an active subtask.
|
||||
* Each subtask's ``end`` is snapped to the next subtask's
|
||||
``start`` (gaps between spans are closed), and the final
|
||||
subtask's ``end`` extends to the last frame ``t_last``.
|
||||
|
||||
Starts are otherwise left as the (already frame-snapped, distinct)
|
||||
values the VLM produced — only the FIRST start is pulled
|
||||
back to ``t0``, which can't collide with a later span because it
|
||||
was already the earliest. Purely deterministic; runs after the
|
||||
VLM passes.
|
||||
"""
|
||||
if not spans or not record.frame_timestamps:
|
||||
return spans
|
||||
t0 = float(record.frame_timestamps[0])
|
||||
t_last = float(record.frame_timestamps[-1])
|
||||
spans = sorted(spans, key=lambda s: float(s["start"]))
|
||||
spans[0]["start"] = t0
|
||||
for i in range(len(spans) - 1):
|
||||
spans[i]["end"] = float(spans[i + 1]["start"])
|
||||
spans[-1]["end"] = t_last
|
||||
for s in spans:
|
||||
if float(s["end"]) < float(s["start"]):
|
||||
s["end"] = float(s["start"])
|
||||
return spans
|
||||
|
||||
@staticmethod
|
||||
def _with_causal_rules(prompt: str) -> str:
|
||||
"""Append the causal event-boundary rules to a describe/segment prompt."""
|
||||
return f"{prompt}\n\n{_CAUSAL_BOUNDARY_RULES}"
|
||||
|
||||
def _clean_spans(
|
||||
self,
|
||||
spans: Any,
|
||||
record: EpisodeRecord,
|
||||
bounds: tuple[float, float] | None = None,
|
||||
dedupe: bool = True,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Clamp / sort / (optionally) dedupe raw VLM subtask spans into valid rows.
|
||||
|
||||
``bounds`` overrides the clamp range — pass the window's
|
||||
``(w_lo, w_hi)`` when cleaning window-relative spans, or leave
|
||||
``None`` to clamp to the whole episode ``[t0, t_last]``.
|
||||
``dedupe`` runs the frame-snap distinct-start step; skip it for
|
||||
window-relative spans (frame snapping is done once on the merged,
|
||||
absolute-time set).
|
||||
"""
|
||||
if not spans:
|
||||
return []
|
||||
if bounds is not None:
|
||||
lo, hi = float(bounds[0]), float(bounds[1])
|
||||
else:
|
||||
lo = record.frame_timestamps[0]
|
||||
hi = record.frame_timestamps[-1]
|
||||
cleaned: list[dict[str, Any]] = []
|
||||
for span in spans:
|
||||
try:
|
||||
start = float(span["start"])
|
||||
end = float(span["end"])
|
||||
text = str(span["text"]).strip()
|
||||
except (KeyError, ValueError, TypeError):
|
||||
continue
|
||||
start = max(lo, min(start, hi))
|
||||
end = max(lo, min(end, hi))
|
||||
if end < start:
|
||||
start, end = end, start
|
||||
if not text:
|
||||
continue
|
||||
cleaned.append({"text": text, "start": start, "end": end})
|
||||
cleaned.sort(key=lambda s: s["start"])
|
||||
if dedupe:
|
||||
return self._dedupe_starts_to_distinct_frames(cleaned, record)
|
||||
return cleaned
|
||||
|
||||
def _describe_episode(
|
||||
self, record: EpisodeRecord, task: str, window: tuple[float, float] | None = None
|
||||
) -> str:
|
||||
"""Grounding pass: free-form chronological description of the (windowed) video."""
|
||||
prompt = self._with_causal_rules(load_prompt("plan_subtask_describe").format(episode_task=task))
|
||||
text = self._vlm_field(self._video_message(record, prompt, window=window), "description")
|
||||
return text.strip() if isinstance(text, str) and text.strip() else ""
|
||||
|
||||
@staticmethod
|
||||
def _dedupe_starts_to_distinct_frames(
|
||||
spans: list[dict[str, Any]], record: EpisodeRecord
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Bump same-frame subtask starts onto distinct frames.
|
||||
|
||||
Two consecutive VLM spans whose ``start`` rounds to the same
|
||||
source frame (after :func:`snap_to_frame`) would otherwise emit
|
||||
two ``style=subtask`` rows at the identical persistent
|
||||
timestamp. The training-time renderer's ``active_at(t,
|
||||
style=subtask)`` resolver can't disambiguate that and raises
|
||||
``Ambiguous resolver for style='subtask'``.
|
||||
|
||||
Walk the (sorted-by-start) spans, snap each to its frame, and
|
||||
if the snapped frame is already taken push the span onto the
|
||||
next unused frame so both subtasks survive on distinct
|
||||
timestamps. If the episode ends before a free frame is found,
|
||||
the trailing span is dropped with a warning — better than
|
||||
poisoning the render.
|
||||
"""
|
||||
if not spans:
|
||||
return spans
|
||||
frames = record.frame_timestamps
|
||||
if not frames:
|
||||
return spans
|
||||
used: set[float] = set()
|
||||
out: list[dict[str, Any]] = []
|
||||
for span in spans:
|
||||
ts = snap_to_frame(span["start"], frames)
|
||||
if ts in used:
|
||||
next_ts = next((f for f in frames if f > ts and f not in used), None)
|
||||
if next_ts is None:
|
||||
logger.warning(
|
||||
"episode %d: subtask %r snapped to occupied frame "
|
||||
"%.3f and no free later frame exists — dropping",
|
||||
record.episode_index,
|
||||
span.get("text"),
|
||||
ts,
|
||||
)
|
||||
continue
|
||||
ts = next_ts
|
||||
used.add(ts)
|
||||
new_span = {**span, "start": ts}
|
||||
if float(new_span.get("end", ts)) < ts:
|
||||
new_span["end"] = ts
|
||||
out.append(new_span)
|
||||
return out
|
||||
|
||||
def _generate_plan(
|
||||
self,
|
||||
record: EpisodeRecord, # noqa: ARG002 (kept for signature stability)
|
||||
subtask_spans: Sequence[dict[str, Any]],
|
||||
*,
|
||||
refresh_t: float | None = None,
|
||||
interjection: str | None = None, # noqa: ARG002
|
||||
task: str | None = None, # noqa: ARG002
|
||||
) -> str | None:
|
||||
"""Deterministic plan = numbered list of *still-todo* subtasks.
|
||||
|
||||
No VLM call: a plain numbered list keeps the plan aligned with the
|
||||
upcoming subtasks (the old VLM "compact hierarchical plan" prompt
|
||||
cost a round-trip per episode/refresh and could diverge).
|
||||
|
||||
1. <subtask 1>
|
||||
2. <subtask 2>
|
||||
|
||||
On a refresh at ``refresh_t`` (from ``run_plan_updates`` on
|
||||
interjections, and ``run_episode`` at each boundary), only subtasks
|
||||
starting at or after ``refresh_t`` are included — so it always
|
||||
describes what's left.
|
||||
"""
|
||||
if not subtask_spans:
|
||||
return None
|
||||
remaining = [
|
||||
s for s in subtask_spans if refresh_t is None or float(s.get("start", 0.0)) >= float(refresh_t)
|
||||
]
|
||||
if not remaining:
|
||||
# Past the last subtask boundary on a late refresh — nothing
|
||||
# left to plan; emit None so the caller skips the row.
|
||||
return None
|
||||
return "\n".join(f"{i}. {span.get('text', '').strip()}" for i, span in enumerate(remaining, start=1))
|
||||
|
||||
def _generate_memory(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
prior_memory: str,
|
||||
completed: str,
|
||||
remaining: Sequence[str],
|
||||
*,
|
||||
task: str | None = None,
|
||||
) -> str:
|
||||
prompt = load_prompt("plan_memory").format(
|
||||
episode_task=(task if task is not None else record.episode_task),
|
||||
prior_memory=prior_memory or "(none)",
|
||||
completed_subtask=completed,
|
||||
remaining_subtasks=", ".join(remaining) if remaining else "(none)",
|
||||
)
|
||||
memory = self._vlm_field(self._text_message(prompt), "memory")
|
||||
return memory.strip() if isinstance(memory, str) else ""
|
||||
@@ -1,33 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Prompt templates loaded as plain text.
|
||||
|
||||
One file per use site. Templates use ``str.format(**vars)`` substitution; we
|
||||
intentionally avoid jinja2 here so the templates remain inspectable in
|
||||
plain editors and roundtrip cleanly through ``ruff format``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
def load(name: str) -> str:
|
||||
"""Read prompt template ``name.txt`` from the ``prompts/`` directory."""
|
||||
path = _DIR / f"{name}.txt"
|
||||
return path.read_text(encoding="utf-8")
|
||||
@@ -1,12 +0,0 @@
|
||||
The user just asked the robot: "{episode_task}".
|
||||
|
||||
Generate a short verbal acknowledgement the robot would speak back before
|
||||
beginning the task. Style: compact, confident, friendly.
|
||||
|
||||
Examples (Hi Robot, Shi 2025): "Sure, I won't put cheese on it.",
|
||||
"OK, starting with the sponge.", "Got it.".
|
||||
|
||||
Prefer very short replies: "Got it.", "On it.", "OK."
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{ "text": "<the spoken acknowledgement>" }}
|
||||
@@ -1,46 +0,0 @@
|
||||
You are generating training data for a Hi Robot-style hierarchical
|
||||
robot policy. The robot in this demonstration has ALREADY executed
|
||||
every step shown in the video — we cannot retroactively change the
|
||||
action stream. To keep training data consistent with the video, the
|
||||
"interjection" must align with what the robot is *about to do next* in
|
||||
the demonstration, framed as a natural mid-task user request.
|
||||
|
||||
The episode's overall task: "{episode_task}".
|
||||
|
||||
The images above show roughly {window_seconds:.1f} seconds straddling a
|
||||
subtask boundary in the demonstration:
|
||||
|
||||
- Subtask the robot just finished: "{prev_subtask}"
|
||||
- Subtask the robot is about to start: "{next_subtask}"
|
||||
- Time into episode: {timestamp:.2f}s
|
||||
|
||||
Write ONE compact interjection the user would naturally say at this
|
||||
moment to prompt / confirm / encourage the robot to do "{next_subtask}".
|
||||
Keep it like a mid-task coaching cue, not a full instruction paragraph.
|
||||
Also write the robot's compact verbal acknowledgement.
|
||||
|
||||
Hard rules:
|
||||
|
||||
- The interjection MUST be consistent with the next subtask. The user
|
||||
cannot ask for something different from what the robot then does in
|
||||
the video. If you're tempted to say "actually skip X" or "do Y
|
||||
instead", DO NOT — those would contradict the demonstration.
|
||||
- The interjection must reference an object, location, or action that
|
||||
is plausible given the visible scene and the next subtask text.
|
||||
- One short phrase or sentence each. Conversational, not robotic.
|
||||
- Prefer direct cues: "{next_subtask}, please."; "Now {next_subtask}."
|
||||
- Keep robot speech very short: "OK.", "On it.", "Doing that."
|
||||
|
||||
Style examples (vary the phrasing — don't reuse these verbatim):
|
||||
- "Now go ahead and {next_subtask}."
|
||||
- "Great, can you {next_subtask} next?"
|
||||
- "{next_subtask}, please."
|
||||
- "Before you continue, please {next_subtask}."
|
||||
- "Looking good — {next_subtask} now."
|
||||
- "Okay, {next_subtask}."
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{
|
||||
"interjection": "<short cue from the user, asking for the next subtask>",
|
||||
"speech": "<short robot acknowledgement>"
|
||||
}}
|
||||
@@ -1,36 +0,0 @@
|
||||
You are updating the robot's compressed semantic memory at the boundary of
|
||||
a completed subtask.
|
||||
|
||||
Reference (verbatim from MEM, Torne 2026):
|
||||
"Remove or compress information in the language memory whenever
|
||||
appropriate. Keep ONLY the minimal set of relevant information for future
|
||||
task execution. Specific object attributes (colors, precise quantities of
|
||||
each item) get discarded when their details won't affect subsequent
|
||||
actions. Functional outcomes (where items went, how many) are preserved."
|
||||
|
||||
Episode task: "{episode_task}"
|
||||
Previous memory: {prior_memory}
|
||||
Just-completed subtask: "{completed_subtask}"
|
||||
Remaining subtasks (for relevance judgement only): {remaining_subtasks}
|
||||
|
||||
Write the memory as a short FIRST-PERSON, PAST-TENSE narrative of what the
|
||||
robot has accomplished so far — the running story it would tell itself.
|
||||
|
||||
Authoring rules:
|
||||
- First person, past tense. Every sentence starts with "I": "I picked
|
||||
up...", "I opened...", "I moved to...".
|
||||
- One or two short sentences. Extend the previous memory with the
|
||||
just-completed subtask; do not rewrite it from scratch.
|
||||
- Keep WHAT happened (functional outcomes — where items went, how many),
|
||||
drop HOW (grasp details, motions).
|
||||
- Compress completed steps and drop object attributes (colors, exact
|
||||
counts) once they no longer affect the remaining subtasks.
|
||||
|
||||
Example (MEM, Torne 2026):
|
||||
Before: "I prepared the pot and got the potatoes, milk, and butter. I
|
||||
moved to the drawer."
|
||||
After: "I prepared the pot and got the ingredients. I opened the
|
||||
drawer with the masher."
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{ "memory": "<one or two short first-person past-tense sentences>" }}
|
||||
@@ -1,27 +0,0 @@
|
||||
You are watching a teleoperated robot demonstration from a single
|
||||
camera. The user asked the robot to: "{episode_task}"
|
||||
|
||||
This is an OBSERVATION pass. Watch the entire clip and describe, in
|
||||
chronological order, ONLY what the robot physically does — the concrete
|
||||
motions, approaches, contacts, grasps, releases, and relocations you can
|
||||
actually SEE in the frames.
|
||||
|
||||
Hard rules:
|
||||
- Describe only motion visible in the video. Do NOT use the task
|
||||
instruction to guess steps that aren't shown. The instruction is the
|
||||
goal; the video is ground truth.
|
||||
- Do NOT segment into named subtasks yet and do NOT output JSON beyond
|
||||
the single field below. Just narrate what happens.
|
||||
- Give an approximate timestamp (in seconds) for each distinct event,
|
||||
e.g. "0.0-1.4s: the base drives forward toward the stove".
|
||||
- Do NOT invent objects, grasps, destinations, or steps. If the robot
|
||||
only does one thing (e.g. it just navigates and the clip ends), say
|
||||
exactly that and nothing more.
|
||||
- Be concrete and literal. "the gripper closes on the mug" — not "the
|
||||
robot prepares to make coffee".
|
||||
|
||||
Output strictly valid JSON:
|
||||
|
||||
{{
|
||||
"description": "<chronological, timestamped description of ONLY what is visible>"
|
||||
}}
|
||||
@@ -1,112 +0,0 @@
|
||||
You are labeling a teleoperated robot demonstration.
|
||||
|
||||
The user originally asked: "{episode_task}"
|
||||
|
||||
You are shown the entire demonstration as a single video. Watch the
|
||||
whole clip, then segment it into a list of consecutive atomic subtasks
|
||||
the robot performs.
|
||||
|
||||
{observation_block}GROUNDING — read this first, it overrides everything below:
|
||||
- Label ONLY what the robot actually does in the video. Every subtask
|
||||
you emit must correspond to motion you can SEE in specific frames.
|
||||
- Do NOT invent, anticipate, or pad. If the robot only does one thing
|
||||
(e.g. it just navigates to a location and the clip ends), emit
|
||||
EXACTLY ONE subtask. Many demonstrations are a single atomic skill.
|
||||
- ``max_steps`` below is a hard CEILING, not a target. Emitting fewer
|
||||
subtasks than the ceiling is not just allowed, it is expected for
|
||||
short / atomic demonstrations. One correct subtask is far better
|
||||
than several invented ones.
|
||||
- If the video does not clearly show the action implied by the task,
|
||||
describe what you actually see — do NOT fabricate the task's steps
|
||||
from the instruction text. The instruction tells you the goal; the
|
||||
VIDEO is the ground truth for what happened.
|
||||
|
||||
Authoring rules — Hi Robot atom granularity, pi0.7-style short prompts:
|
||||
|
||||
- Each subtask = one COMPOSITE atomic skill the low-level policy can
|
||||
execute end-to-end. A "skill" bundles its own approach motion with
|
||||
its terminal action — do NOT split the approach off as its own
|
||||
subtask. The whole-arm policy already learns to reach as part of
|
||||
every manipulation primitive.
|
||||
- Write each subtask as an IMPERATIVE COMMAND, starting with one of
|
||||
these verbs (extend only when none fits):
|
||||
pick up <obj> — approach + grasp + lift in one subtask
|
||||
put <obj> on/in <loc> — transport + release in one subtask
|
||||
place <obj> on/in <loc> — synonym of "put"; pick one and stay consistent
|
||||
push <obj> — contact + linear shove
|
||||
pull <obj> — contact + linear retract
|
||||
turn <knob/dial/handle> — rotary actuation
|
||||
press <button> — single-press contact
|
||||
open <drawer/door/lid> — full open motion
|
||||
close <drawer/door/lid> — full close motion
|
||||
pour <src> into <dst> — tilt + flow
|
||||
insert <obj> into <slot>— alignment + push-fit
|
||||
go to <loc> — ONLY when no grasp / actuation follows
|
||||
(e.g. a pure relocation between phases).
|
||||
If the next subtask grasps something at
|
||||
that location, drop "go to ..." and just
|
||||
write "pick up ..." instead.
|
||||
- Forbidden ultra-fine splits — the VLM is NOT allowed to emit these
|
||||
as standalone subtasks; fold them into the parent composite:
|
||||
"move to X" → fold into "pick up X" (or whatever follows)
|
||||
"reach for X" → fold into "pick up X"
|
||||
"grasp X" → fold into "pick up X"
|
||||
"lift X" → fold into "pick up X" (or "put X on Y" if it's
|
||||
the transport phase of a place)
|
||||
"release X" → fold into "put X on Y" (or "place X in Y")
|
||||
- Keep it SHORT — a verb phrase, not a sentence. Drop articles
|
||||
("the", "a") and adverbs ("carefully", "slowly"). Add a "how"
|
||||
detail (which hand, which grasp point) ONLY when it is needed to
|
||||
disambiguate. Every subtask must begin with one of the verbs
|
||||
above (no leading nouns, no "then", no "first").
|
||||
- NEVER use third person. Never write "the robot", "the arm", "the
|
||||
gripper moves", "it picks up" — the robot is implied. Command it,
|
||||
do not describe it.
|
||||
- Use the exact object nouns from the task above. If the task says
|
||||
"cube", every subtask says "cube" — never switch to "block". If it
|
||||
says "box", never switch to "bin"/"container". Keep vocabulary
|
||||
consistent across the whole episode.
|
||||
- Good: "pick up blue cube", "put blue cube in box", "open drawer",
|
||||
"turn red knob", "press start button", "go to sink".
|
||||
- Bad: "move to blue cube" (approach as its own subtask — forbidden,
|
||||
must be folded into "pick up blue cube"); "the robot arm moves
|
||||
towards the blue cube" (third person, too long); "carefully pick
|
||||
up the cube" (adverb, article); "release the yellow block"
|
||||
("block" when the task said "cube", and "release" must be folded
|
||||
into a "put"/"place" subtask).
|
||||
- Subtasks are non-overlapping and cover the full episode in order.
|
||||
Choose the cut points yourself based on what you see in the video
|
||||
(gripper open/close events, contact, regrasps, transitions).
|
||||
- Each subtask spans at least {min_subtask_seconds} seconds. If a
|
||||
candidate span would be shorter, merge it into its neighbour
|
||||
rather than emitting it.
|
||||
- Do not exceed {max_steps} subtasks total. Fewer, larger composites
|
||||
are preferred over many micro-steps.
|
||||
- Every subtask's [start_time, end_time] must lie within
|
||||
[0.0, {episode_duration}] seconds.
|
||||
|
||||
SPECIAL CASES — verb disambiguation (each rule is narrowly visual and
|
||||
fires ONLY on the spatial situation it names; it must not change how you
|
||||
label any other situation):
|
||||
- STACK vs PUT: if an object is placed ON TOP OF another specific object
|
||||
(not on a flat table / shelf / counter), use "stack ... on ...", not
|
||||
"put". "stack blue book on green book", NOT "put blue book on table".
|
||||
- INSERT vs PUT: if an object goes INTO a fitted slot / hole / socket /
|
||||
receptacle (push-fit), use "insert ... into ...", not "put".
|
||||
- RETRIEVE/PICK-UP vs PUT (direction): watch the gripper. If it CLOSES
|
||||
on the object and the object moves WITH the hand, it is "pick up" /
|
||||
"retrieve" (object leaves its location). If the gripper OPENS and the
|
||||
object stays where the hand left it, it is "put" / "place" (object
|
||||
arrives at a location). Decide by which way the object moves, not by
|
||||
where the hand ends up.
|
||||
- POUR vs PUT: only use "pour" when the source is tilted and contents
|
||||
flow out; moving a full container without tilting is "put"/"place".
|
||||
|
||||
Output strictly valid JSON of shape:
|
||||
|
||||
{{
|
||||
"subtasks": [
|
||||
{{"text": "<short imperative verb phrase>", "start": <float>, "end": <float>}},
|
||||
...
|
||||
]
|
||||
}}
|
||||
@@ -1,67 +0,0 @@
|
||||
You are generating structured augmentations of a robot task instruction
|
||||
for training a language-conditioned policy. Unlike free-form rephrasing,
|
||||
your variants follow a NAMED 5-axis taxonomy — each axis omits or varies
|
||||
a specific element of the task while preserving its meaning.
|
||||
|
||||
Original task: "{base_task}"
|
||||
|
||||
Produce variants along five named axes. Each axis has a target count.
|
||||
The whole batch should expose the policy to maximum linguistic diversity
|
||||
WITHOUT changing what the robot is supposed to do.
|
||||
|
||||
Axes and target counts:
|
||||
|
||||
synonym_paraphrase ({n_synonym}):
|
||||
Different wording / verbs / sentence structure. ALL information
|
||||
from the original task is preserved — same object, same arm
|
||||
specification if present, same orientation if present, same grasp
|
||||
if present.
|
||||
|
||||
omit_arm ({n_omit_arm}):
|
||||
Drop the left/right/both arm specification from the task. Skip
|
||||
entirely (emit 0 entries) if the original task does NOT mention an
|
||||
arm. Do not invent an arm specification just to omit it.
|
||||
|
||||
omit_orientation ({n_omit_orientation}):
|
||||
Drop orientation cues (upright, sideways, facing the user,
|
||||
long-edge-first, etc.). Skip entirely if no orientation cue is
|
||||
present in the original task.
|
||||
|
||||
omit_grasp_method ({n_omit_grasp_method}):
|
||||
Drop the grip / grasp method specification (pinch, wrap, hold by
|
||||
the rim, etc.). Skip entirely if no grasp method is mentioned.
|
||||
|
||||
combined_omissions ({n_combined}):
|
||||
Combine TWO of the above omissions simultaneously (e.g. drop both
|
||||
arm and orientation). Skip entirely if fewer than two of (arm,
|
||||
orientation, grasp_method) appear in the original task.
|
||||
|
||||
Hard rules:
|
||||
- Each variant MUST preserve the core action, the target object, AND
|
||||
the goal / destination. Do not change which object is involved, where
|
||||
it goes, or the high-level action. "Navigate to the stove" may become
|
||||
"go to the stove" or "head over to the stove" — it must NEVER become
|
||||
"wander around the kitchen", "explore the room", or anything that
|
||||
drops or generalises the stove destination. If you cannot vary the
|
||||
wording without changing the goal, emit fewer variants.
|
||||
- Only the FIVE listed elements (wording, arm, orientation, grasp
|
||||
method, or a combination) may be varied or omitted. The verb's
|
||||
meaning, the object, and the destination are fixed.
|
||||
- Each variant is plain prose, no markdown, no quotes, no list numbers.
|
||||
- Each variant must be DISTINCT from every other variant in the entire
|
||||
output, both within and across axes. Near-duplicates are not allowed.
|
||||
- If an axis cannot reach its target count because the original task
|
||||
lacks the omittable element, emit fewer entries — do NOT pad the
|
||||
axis with paraphrases that belong to a different axis.
|
||||
- Variants should not all start with verbs — vary sentence structure
|
||||
(some imperative, some polite request, some question).
|
||||
|
||||
Output strictly valid JSON of shape:
|
||||
|
||||
{{
|
||||
"synonym_paraphrase": ["<v1>", "<v2>", ...],
|
||||
"omit_arm": ["<v1>", "<v2>", ...],
|
||||
"omit_orientation": ["<v1>", ...],
|
||||
"omit_grasp_method": ["<v1>", ...],
|
||||
"combined_omissions": ["<v1>", ...]
|
||||
}}
|
||||
@@ -1,32 +0,0 @@
|
||||
You are generating training data for a Hi Robot-style policy. We need
|
||||
{n} alternative phrasings of the same robot task so the policy sees
|
||||
diverse user prompts during training instead of the same canonical
|
||||
string repeated every frame.
|
||||
|
||||
Original task:
|
||||
"{base_task}"
|
||||
|
||||
Generate exactly {n} alternative phrasings of the same task. Vary:
|
||||
|
||||
- formality (casual / polite / curt)
|
||||
- verbosity (mostly short imperative; occasional polite request)
|
||||
- word choice (synonyms, different verbs)
|
||||
- sentence structure (imperative / question / suggestion)
|
||||
|
||||
Hard rules:
|
||||
- Each phrasing MUST preserve the exact meaning of the original task.
|
||||
Do not change which object is involved, the destination, or the
|
||||
action. Do not add extra steps. Do not invent new objects.
|
||||
- Each phrasing must be a short phrase or sentence, plain prose, no
|
||||
markdown, no quotes, no list numbers.
|
||||
- Phrasings must be distinct — no near-duplicates.
|
||||
- Output exactly {n} entries.
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{
|
||||
"rephrasings": [
|
||||
"<phrasing 1>",
|
||||
"<phrasing 2>",
|
||||
...
|
||||
]
|
||||
}}
|
||||
@@ -1,17 +0,0 @@
|
||||
The video above shows a robot manipulation episode in full. Look at
|
||||
the entire video and describe in ONE concise sentence what the robot
|
||||
is doing.
|
||||
|
||||
Rules:
|
||||
- One sentence, in natural English, like a user instruction.
|
||||
- Capture the goal of the demonstration, not low-level motions.
|
||||
Example: "place the yellow cube into the red bin" — not "move the
|
||||
end-effector down 5cm and close the gripper".
|
||||
- 4 to 15 words. Plain prose, no markdown, no bullets, no quotes.
|
||||
- Do not invent objects or actions that aren't visible.
|
||||
- Do not output anything other than the JSON object below.
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{
|
||||
"task": "<single concise sentence describing what the robot does in this video>"
|
||||
}}
|
||||
@@ -1,32 +0,0 @@
|
||||
You are generating a frame-grounded visual question/answer pair for
|
||||
chain-of-thought training. Reference: ECoT (Zawalski 2024) and Steerable
|
||||
Policies — both train policies on grounded features such as bounding box
|
||||
pixel coordinates, keypoints, counts, attributes, and spatial relations.
|
||||
|
||||
The frame shows a robot working on: "{episode_task}".
|
||||
|
||||
Question types and the EXACT answer JSON shape required for each:
|
||||
|
||||
bbox => {{"detections": [{{"label": "<obj>", "bbox_format": "xyxy",
|
||||
"bbox": [x1, y1, x2, y2]}}, ...]}}
|
||||
bbox is in pixel coordinates (x_min, y_min, x_max, y_max).
|
||||
ECoT example: "a white cup [124, 25, 176, 113]".
|
||||
|
||||
keypoint => {{"label": "<point>", "point_format": "xy",
|
||||
"point": [x, y]}}
|
||||
|
||||
count => {{"label": "<obj>", "count": <int>,
|
||||
"note": "<optional short note>"}}
|
||||
|
||||
attribute => {{"label": "<obj>", "attribute": "<color|shape|state|...>",
|
||||
"value": "<observed value>"}}
|
||||
|
||||
spatial => {{"subject": "<obj>", "relation": "<left_of|right_of|on|in|"
|
||||
"above|below|near>", "object": "<obj>"}}
|
||||
|
||||
Generate a question of type "{question_type}". Output strictly valid JSON:
|
||||
|
||||
{{
|
||||
"question": "<short, frame-grounded question>",
|
||||
"answer": <object whose shape matches the schema above>
|
||||
}}
|
||||
@@ -1,216 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Datatrove-shaped reader.
|
||||
|
||||
The reader walks ``data/chunk-*/file-*.parquet`` and yields one record per
|
||||
episode containing:
|
||||
|
||||
- ``episode_index``: int
|
||||
- ``frame_timestamps``: tuple[float, ...]
|
||||
- ``frame_indices``: tuple[int, ...]
|
||||
- ``episode_task``: str (canonical task from ``meta/tasks.parquet``)
|
||||
- ``data_path``: pathlib.Path of the source parquet shard
|
||||
- ``frames_df``: pandas.DataFrame slice for the episode (only loaded on demand)
|
||||
|
||||
This shape lets each module operate per-episode without loading all parquet
|
||||
rows into memory at once.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.io_utils import load_tasks
|
||||
from lerobot.datasets.utils import DEFAULT_TASKS_PATH
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpisodeRecord:
|
||||
"""Per-episode record yielded by the reader."""
|
||||
|
||||
episode_index: int
|
||||
episode_task: str
|
||||
frame_timestamps: tuple[float, ...]
|
||||
frame_indices: tuple[int, ...]
|
||||
data_path: Path
|
||||
row_offset: int # row offset within the parquet file where this episode starts
|
||||
row_count: int # number of rows for this episode
|
||||
|
||||
# Memoized parquet slice — populated on first ``frames_df()`` call so
|
||||
# repeat queries from different modules don't re-read the whole shard.
|
||||
_frames_df_cache: Any = field(default=None, init=False, repr=False, compare=False)
|
||||
|
||||
def frames_df(self): # type: ignore[no-untyped-def]
|
||||
"""Lazy-load the pandas slice for this episode (memoized)."""
|
||||
if self._frames_df_cache is None:
|
||||
import pandas as pd # noqa: PLC0415 - deferred for optional dataset extra
|
||||
|
||||
table = pq.read_table(self.data_path)
|
||||
df: pd.DataFrame = table.to_pandas()
|
||||
self._frames_df_cache = df.iloc[self.row_offset : self.row_offset + self.row_count].reset_index(
|
||||
drop=True
|
||||
)
|
||||
return self._frames_df_cache
|
||||
|
||||
|
||||
def reconstruct_subtask_spans(
|
||||
rows: Sequence[dict[str, Any]],
|
||||
*,
|
||||
episode_end_t: float | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Turn ``style="subtask"`` rows into ``{text, start, end}`` spans.
|
||||
|
||||
Each span's ``end`` is the next span's ``start``. The final span's
|
||||
``end`` defaults to its own ``start`` (zero-duration) — pass
|
||||
``episode_end_t`` to extend it to the episode's last frame instead,
|
||||
which is what downstream consumers (memory, interjection boundary
|
||||
selection) expect.
|
||||
|
||||
Used by the ``plan`` module (plan-update pass) and the
|
||||
``interjections`` module (interjection anchoring), which both need the
|
||||
same span shape.
|
||||
"""
|
||||
sorted_rows = sorted(
|
||||
(r for r in rows if r.get("style") == "subtask"),
|
||||
key=lambda r: float(r["timestamp"]),
|
||||
)
|
||||
spans: list[dict[str, Any]] = []
|
||||
for r in sorted_rows:
|
||||
t = float(r["timestamp"])
|
||||
if spans:
|
||||
spans[-1]["end"] = t
|
||||
spans.append({"text": r.get("content") or "", "start": t, "end": t})
|
||||
if spans and episode_end_t is not None and float(episode_end_t) > spans[-1]["start"]:
|
||||
spans[-1]["end"] = float(episode_end_t)
|
||||
return spans
|
||||
|
||||
|
||||
def snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
|
||||
"""Snap an arbitrary float to the nearest exact source frame timestamp.
|
||||
|
||||
Modules use this when emitting event-style rows so the row's
|
||||
timestamp matches a real parquet frame: event rows must land on an
|
||||
exact frame, otherwise the per-frame event lookup the writer does
|
||||
would never match them.
|
||||
"""
|
||||
if not frame_timestamps:
|
||||
return float(t)
|
||||
nearest = min(frame_timestamps, key=lambda f: abs(f - t))
|
||||
return float(nearest)
|
||||
|
||||
|
||||
def _load_tasks_lookup(root: Path) -> dict[int, str]:
|
||||
"""Map ``task_index -> task`` from ``meta/tasks.parquet``.
|
||||
|
||||
Returns an empty dict when the file is absent — the task description is
|
||||
derived later from the video if needed. Reuses the library-level
|
||||
:func:`lerobot.datasets.io_utils.load_tasks`, which returns the tasks
|
||||
frame indexed by task string with a ``task_index`` column.
|
||||
"""
|
||||
if not (root / DEFAULT_TASKS_PATH).exists():
|
||||
return {}
|
||||
tasks = load_tasks(root)
|
||||
return {int(idx): str(task) for task, idx in zip(tasks.index, tasks["task_index"], strict=True)}
|
||||
|
||||
|
||||
def iter_episodes(root: Path, *, only_episodes: tuple[int, ...] | None = None) -> Iterator[EpisodeRecord]:
|
||||
"""Yield :class:`EpisodeRecord` for every episode under ``root/data/``.
|
||||
|
||||
Episodes are yielded in ascending ``episode_index`` order. The reader does
|
||||
not assume a specific chunk/file layout: it scans every ``*.parquet``
|
||||
under ``data/`` and groups by ``episode_index``.
|
||||
"""
|
||||
tasks = _load_tasks_lookup(root)
|
||||
data_dir = root / "data"
|
||||
parquet_files = sorted(data_dir.rglob("*.parquet"))
|
||||
|
||||
only_set = set(only_episodes) if only_episodes is not None else None
|
||||
|
||||
for path in parquet_files:
|
||||
yield from _iter_one_path(path, tasks, only_set)
|
||||
|
||||
|
||||
def _iter_one_path(path: Path, tasks: dict[int, str], only_set: set[int] | None) -> Iterator[EpisodeRecord]:
|
||||
table = pq.read_table(path)
|
||||
names = table.column_names
|
||||
if "episode_index" not in names:
|
||||
return
|
||||
episode_col = table.column("episode_index").to_pylist()
|
||||
timestamp_col = (
|
||||
table.column("timestamp").to_pylist() if "timestamp" in names else [0.0] * len(episode_col)
|
||||
)
|
||||
frame_col = (
|
||||
table.column("frame_index").to_pylist() if "frame_index" in names else list(range(len(episode_col)))
|
||||
)
|
||||
task_col = table.column("task_index").to_pylist() if "task_index" in names else None
|
||||
|
||||
def _build(
|
||||
ep: int,
|
||||
start: int,
|
||||
end: int,
|
||||
task_idx: int | None,
|
||||
ts_buf: list[float],
|
||||
fi_buf: list[int],
|
||||
) -> EpisodeRecord | None:
|
||||
if only_set is not None and ep not in only_set:
|
||||
return None
|
||||
task = tasks.get(task_idx, "") if task_idx is not None else ""
|
||||
return EpisodeRecord(
|
||||
episode_index=ep,
|
||||
episode_task=task,
|
||||
frame_timestamps=tuple(ts_buf),
|
||||
frame_indices=tuple(fi_buf),
|
||||
data_path=path,
|
||||
row_offset=start,
|
||||
row_count=end - start,
|
||||
)
|
||||
|
||||
cur_ep: int | None = None
|
||||
start_offset = 0
|
||||
ts_buf: list[float] = []
|
||||
fi_buf: list[int] = []
|
||||
cur_task_idx: int | None = None
|
||||
|
||||
for i, ep in enumerate(episode_col):
|
||||
if cur_ep is None:
|
||||
cur_ep = ep
|
||||
start_offset = i
|
||||
ts_buf = [timestamp_col[i]]
|
||||
fi_buf = [frame_col[i]]
|
||||
cur_task_idx = task_col[i] if task_col is not None else None
|
||||
continue
|
||||
if ep != cur_ep:
|
||||
rec = _build(cur_ep, start_offset, i, cur_task_idx, ts_buf, fi_buf)
|
||||
if rec is not None:
|
||||
yield rec
|
||||
cur_ep = ep
|
||||
start_offset = i
|
||||
ts_buf = [timestamp_col[i]]
|
||||
fi_buf = [frame_col[i]]
|
||||
cur_task_idx = task_col[i] if task_col is not None else None
|
||||
else:
|
||||
ts_buf.append(timestamp_col[i])
|
||||
fi_buf.append(frame_col[i])
|
||||
|
||||
if cur_ep is not None:
|
||||
rec = _build(cur_ep, start_offset, len(episode_col), cur_task_idx, ts_buf, fi_buf)
|
||||
if rec is not None:
|
||||
yield rec
|
||||
@@ -1,92 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Per-episode staging.
|
||||
|
||||
Each module writes its raw output as a JSONL file under
|
||||
``<staging_dir>/episode_{ep:06d}/<module>.jsonl``. The writer reads back this
|
||||
staging tree and partitions rows into the two language columns.
|
||||
|
||||
JSONL is preferred over parquet here because the staging artifact is meant to
|
||||
be human-inspectable, easy to diff between prompt iterations, and trivially
|
||||
appended to. The final dataset format is parquet; staging is just an
|
||||
intermediate.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
ModuleName = str
|
||||
|
||||
_MODULES: tuple[ModuleName, ...] = (
|
||||
"plan",
|
||||
"interjections",
|
||||
"vqa",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpisodeStaging:
|
||||
"""Filesystem layout for a single episode's staged module outputs."""
|
||||
|
||||
root: Path
|
||||
episode_index: int
|
||||
|
||||
@property
|
||||
def episode_dir(self) -> Path:
|
||||
return self.root / f"episode_{self.episode_index:06d}"
|
||||
|
||||
def path_for(self, module: ModuleName) -> Path:
|
||||
if module not in _MODULES:
|
||||
raise ValueError(f"Unknown module {module!r}; expected one of {_MODULES}")
|
||||
return self.episode_dir / f"{module}.jsonl"
|
||||
|
||||
def write(self, module: ModuleName, rows: Iterable[dict[str, Any]]) -> Path:
|
||||
path = self.path_for(module)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# Atomic replace: a crash mid-write would otherwise leave a
|
||||
# half-written JSONL file that ``read()`` would then fail to
|
||||
# parse. Write to a sibling .tmp and rename so the target path
|
||||
# only ever points at a complete file.
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
with tmp_path.open("w", encoding="utf-8") as f:
|
||||
for row in rows:
|
||||
f.write(json.dumps(row, ensure_ascii=False, sort_keys=True))
|
||||
f.write("\n")
|
||||
tmp_path.replace(path)
|
||||
return path
|
||||
|
||||
def read(self, module: ModuleName) -> list[dict[str, Any]]:
|
||||
path = self.path_for(module)
|
||||
if not path.exists():
|
||||
return []
|
||||
out: list[dict[str, Any]] = []
|
||||
with path.open(encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
out.append(json.loads(line))
|
||||
return out
|
||||
|
||||
def read_all(self) -> dict[ModuleName, list[dict[str, Any]]]:
|
||||
return {m: self.read(m) for m in _MODULES}
|
||||
|
||||
def has(self, module: ModuleName) -> bool:
|
||||
return self.path_for(module).exists()
|
||||
@@ -1,332 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Pre-write validation against staged outputs.
|
||||
|
||||
Runs after all three modules have written their per-episode artifacts but
|
||||
*before* the writer rewrites parquet shards. The validator never touches
|
||||
parquet; it only inspects the staging tree and the source frame timestamps
|
||||
exposed by :class:`EpisodeRecord`.
|
||||
|
||||
Checks (per the plan's "Intermediate staging and validation" section):
|
||||
|
||||
- exact timestamp alignment against source frame timestamps
|
||||
- no orphan speech / interjection pairs
|
||||
- plan / memory emission consistency (events have a paired persistent row)
|
||||
- VQA assistant ``content`` is valid JSON (one of bbox / keypoint / count /
|
||||
attribute / spatial)
|
||||
- every row maps to its correct column under :func:`column_for_style`
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from lerobot.datasets.language import (
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
column_for_style,
|
||||
is_view_dependent_style,
|
||||
validate_camera_field,
|
||||
)
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
from .staging import EpisodeStaging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationReport:
|
||||
"""Outcome of one validation pass across all episodes."""
|
||||
|
||||
errors: list[str] = field(default_factory=list)
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
episodes_checked: int = 0
|
||||
|
||||
@property
|
||||
def ok(self) -> bool:
|
||||
return not self.errors
|
||||
|
||||
def add_error(self, message: str) -> None:
|
||||
self.errors.append(message)
|
||||
|
||||
def add_warning(self, message: str) -> None:
|
||||
self.warnings.append(message)
|
||||
|
||||
def summary(self) -> str:
|
||||
return f"checked={self.episodes_checked} errors={len(self.errors)} warnings={len(self.warnings)}"
|
||||
|
||||
|
||||
VQA_ANSWER_SHAPES: dict[str, set[str]] = {
|
||||
"bbox": {"detections"},
|
||||
"keypoint": {"label", "point_format", "point"},
|
||||
"count": {"label", "count"},
|
||||
"attribute": {"label", "attribute", "value"},
|
||||
"spatial": {"subject", "relation", "object"},
|
||||
}
|
||||
|
||||
|
||||
def classify_vqa_answer(payload: Any) -> str | None:
|
||||
"""Best-effort classification of a VQA answer payload to a question type."""
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
keys = set(payload.keys())
|
||||
for kind, required in VQA_ANSWER_SHAPES.items():
|
||||
if required.issubset(keys):
|
||||
return kind
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class StagingValidator:
|
||||
"""Walks the staging tree and produces a :class:`ValidationReport`."""
|
||||
|
||||
timestamp_atol: float = 0.0 # exact-match by default
|
||||
dataset_camera_keys: tuple[str, ...] | None = None
|
||||
"""Known ``observation.images.*`` keys on the dataset. When set, the
|
||||
validator additionally enforces that every view-dependent row's
|
||||
``camera`` field references one of these keys. Pass ``None`` (default)
|
||||
to skip that cross-check (e.g. in unit tests with no real dataset)."""
|
||||
|
||||
def validate(
|
||||
self,
|
||||
records: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
) -> ValidationReport:
|
||||
report = ValidationReport()
|
||||
for record in records:
|
||||
self._validate_episode(record, staging_dir, report)
|
||||
report.episodes_checked += 1
|
||||
return report
|
||||
|
||||
def _validate_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
staging_dir: Path,
|
||||
report: ValidationReport,
|
||||
) -> None:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
staged = staging.read_all()
|
||||
all_rows: list[dict[str, Any]] = []
|
||||
for module_name, rows in staged.items():
|
||||
for row in rows:
|
||||
row = {**row, "_module": module_name}
|
||||
all_rows.append(row)
|
||||
|
||||
frame_ts = set(record.frame_timestamps)
|
||||
|
||||
events: list[dict[str, Any]] = []
|
||||
persistent: list[dict[str, Any]] = []
|
||||
for row in all_rows:
|
||||
self._check_column_routing(row, report, record.episode_index)
|
||||
self._check_camera_field(row, report, record.episode_index, self.dataset_camera_keys)
|
||||
# ``_check_column_routing`` already recorded any unknown-style error;
|
||||
# don't let the same ``column_for_style`` lookup raise here uncaught.
|
||||
try:
|
||||
column = column_for_style(row.get("style"))
|
||||
except ValueError:
|
||||
continue
|
||||
if column == LANGUAGE_PERSISTENT:
|
||||
persistent.append(row)
|
||||
else:
|
||||
events.append(row)
|
||||
|
||||
for row in events:
|
||||
self._check_event_timestamp_alignment(row, frame_ts, report, record.episode_index)
|
||||
|
||||
self._check_speech_interjection_pairs(events, report, record.episode_index)
|
||||
self._check_plan_memory_consistency(persistent, events, report, record.episode_index)
|
||||
self._check_vqa_json(events, report, record.episode_index)
|
||||
self._check_vqa_uniqueness_per_frame_camera(events, report, record.episode_index)
|
||||
|
||||
def _check_camera_field(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
dataset_camera_keys: Sequence[str] | None,
|
||||
) -> None:
|
||||
"""Enforce the camera invariant + that the key matches the dataset's cameras."""
|
||||
style = row.get("style")
|
||||
camera = row.get("camera")
|
||||
try:
|
||||
validate_camera_field(style, camera)
|
||||
except ValueError as exc:
|
||||
report.add_error(f"ep={episode_index} module={row.get('_module')}: {exc}")
|
||||
return
|
||||
if is_view_dependent_style(style) and dataset_camera_keys and camera not in dataset_camera_keys:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module={row.get('_module')}: camera {camera!r} on style "
|
||||
f"{style!r} is not one of the dataset's video keys {sorted(dataset_camera_keys)!r}"
|
||||
)
|
||||
|
||||
def _check_vqa_uniqueness_per_frame_camera(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
"""Ensure at most one (vqa, user) and one (vqa, assistant) per (t, camera)."""
|
||||
counts: dict[tuple[float, str, str], int] = {}
|
||||
for row in events:
|
||||
if row.get("style") != "vqa":
|
||||
continue
|
||||
ts = row.get("timestamp")
|
||||
camera = row.get("camera")
|
||||
role = row.get("role")
|
||||
if ts is None or camera is None or role is None:
|
||||
continue # other validators flag these
|
||||
key = (float(ts), str(camera), str(role))
|
||||
counts[key] = counts.get(key, 0) + 1
|
||||
for (ts, camera, role), n in counts.items():
|
||||
if n > 1:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: {n} duplicate vqa rows at t={ts} "
|
||||
f"camera={camera!r} role={role!r}; expected at most one per (t, camera, role)"
|
||||
)
|
||||
|
||||
def _check_column_routing(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
style = row.get("style")
|
||||
module = row.get("_module")
|
||||
try:
|
||||
target_col = column_for_style(style)
|
||||
except ValueError:
|
||||
report.add_error(f"ep={episode_index} module={module}: unknown style {style!r}")
|
||||
return
|
||||
if module == "plan" and target_col != LANGUAGE_PERSISTENT:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module=plan emitted style {style!r} that routes to {target_col} (must be persistent)"
|
||||
)
|
||||
if module in {"interjections", "vqa"} and target_col != LANGUAGE_EVENTS:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module={module} emitted style {style!r} that routes to {target_col} (must be events)"
|
||||
)
|
||||
|
||||
def _check_event_timestamp_alignment(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
frame_ts: set[float],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
ts = row.get("timestamp")
|
||||
if ts is None:
|
||||
report.add_error(f"ep={episode_index}: event row missing timestamp: {row!r}")
|
||||
return
|
||||
if self.timestamp_atol == 0.0:
|
||||
if float(ts) not in frame_ts:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: event row timestamp {ts!r} does not match any source frame timestamp"
|
||||
)
|
||||
else:
|
||||
if not any(abs(float(ts) - f) <= self.timestamp_atol for f in frame_ts):
|
||||
report.add_error(
|
||||
f"ep={episode_index}: event row timestamp {ts!r} not within {self.timestamp_atol}s of any frame"
|
||||
)
|
||||
|
||||
def _check_speech_interjection_pairs(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
speech_ts: dict[float, int] = {}
|
||||
interjection_ts: dict[float, int] = {}
|
||||
for row in events:
|
||||
ts = row.get("timestamp")
|
||||
if ts is None:
|
||||
continue
|
||||
ts_f = float(ts)
|
||||
if row.get("style") is None and row.get("role") == "assistant":
|
||||
speech_ts[ts_f] = speech_ts.get(ts_f, 0) + 1
|
||||
if row.get("style") == "interjection":
|
||||
interjection_ts[ts_f] = interjection_ts.get(ts_f, 0) + 1
|
||||
|
||||
for ts in interjection_ts:
|
||||
if ts not in speech_ts:
|
||||
report.add_error(f"ep={episode_index}: interjection at t={ts} has no paired speech atom")
|
||||
|
||||
def _check_plan_memory_consistency(
|
||||
self,
|
||||
persistent: Sequence[dict[str, Any]],
|
||||
events: Sequence[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
plan_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "plan"})
|
||||
memory_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "memory"})
|
||||
subtask_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "subtask"})
|
||||
interjection_ts = sorted(
|
||||
{
|
||||
float(r["timestamp"])
|
||||
for r in events
|
||||
if r.get("style") == "interjection" and r.get("timestamp") is not None
|
||||
}
|
||||
)
|
||||
|
||||
if persistent and not plan_ts:
|
||||
report.add_warning(f"ep={episode_index}: persistent rows present but no plan emitted")
|
||||
# every interjection should have a same-timestamp plan refresh
|
||||
for ts in interjection_ts:
|
||||
if ts not in set(plan_ts):
|
||||
report.add_error(
|
||||
f"ep={episode_index}: interjection at t={ts} has no co-timestamped plan update"
|
||||
)
|
||||
# memory should be emitted at subtask boundaries (subset relation)
|
||||
if memory_ts and subtask_ts:
|
||||
mem_set = set(memory_ts)
|
||||
sub_set = set(subtask_ts)
|
||||
stray = sorted(mem_set - sub_set)
|
||||
if stray:
|
||||
report.add_warning(f"ep={episode_index}: memory rows at {stray} not at any subtask boundary")
|
||||
|
||||
def _check_vqa_json(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
for row in events:
|
||||
if row.get("style") != "vqa" or row.get("role") != "assistant":
|
||||
continue
|
||||
content = row.get("content")
|
||||
if content is None:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant row at t={row.get('timestamp')} has null content"
|
||||
)
|
||||
continue
|
||||
try:
|
||||
payload = json.loads(content)
|
||||
except (TypeError, ValueError) as exc:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant content not valid JSON at t={row.get('timestamp')}: {exc}"
|
||||
)
|
||||
continue
|
||||
shape = classify_vqa_answer(payload)
|
||||
if shape is None:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant payload at t={row.get('timestamp')} does not match any known shape: keys={list(payload) if isinstance(payload, dict) else type(payload).__name__}"
|
||||
)
|
||||
@@ -1,617 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Shared Qwen-VL client.
|
||||
|
||||
The pipeline uses a single shared VLM across modules. vLLM is preferred when
|
||||
available (high throughput, JSON-guided decoding); transformers is the
|
||||
fallback. A ``stub`` backend is used for unit tests so fixtures never call
|
||||
into a real model.
|
||||
|
||||
The client speaks one method, :meth:`VlmClient.generate_json`, which:
|
||||
|
||||
- accepts a list of OpenAI/HF-style multimodal messages,
|
||||
- requests JSON output from the server,
|
||||
- batches requests transparently,
|
||||
- and reprompts once on a JSON parse failure with an inline correction
|
||||
message before raising.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import shlex
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import urllib.request
|
||||
from collections.abc import Callable, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol
|
||||
|
||||
from .config import VlmConfig
|
||||
|
||||
|
||||
class VlmClient(Protocol):
|
||||
"""Protocol every backend must implement."""
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
"""Generate one JSON-decoded response per messages list."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class StubVlmClient:
|
||||
"""Deterministic stub used in unit tests.
|
||||
|
||||
A test passes a callable that maps the *last user message text* (or, if
|
||||
that is empty, the full message list) to a JSON-serializable response.
|
||||
"""
|
||||
|
||||
responder: Callable[[Sequence[dict[str, Any]]], Any]
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
return [self.responder(list(messages)) for messages in messages_batch]
|
||||
|
||||
|
||||
def _strip_to_json(text: str) -> Any:
|
||||
text = text.strip()
|
||||
# Strip <think>...</think> blocks (Qwen3 Thinking style)
|
||||
while "<think>" in text and "</think>" in text:
|
||||
start = text.find("<think>")
|
||||
end = text.find("</think>", start) + len("</think>")
|
||||
text = (text[:start] + text[end:]).strip()
|
||||
# Strip ```json ... ``` fences from chat-tuned backbones
|
||||
if text.startswith("```"):
|
||||
first = text.find("\n")
|
||||
last = text.rfind("```")
|
||||
if first != -1 and last != -1 and last > first:
|
||||
text = text[first + 1 : last].strip()
|
||||
try:
|
||||
return json.loads(text)
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
pass
|
||||
# Fall back to extracting the first balanced {...} block.
|
||||
obj_text = _extract_first_json_object(text)
|
||||
if obj_text is None:
|
||||
raise json.JSONDecodeError("No JSON object found", text, 0)
|
||||
return json.loads(obj_text)
|
||||
|
||||
|
||||
def _extract_first_json_object(text: str) -> str | None:
|
||||
"""Return the first balanced ``{...}`` substring, ignoring braces in
|
||||
string literals. Returns ``None`` if no balanced block is found."""
|
||||
start = text.find("{")
|
||||
if start < 0:
|
||||
return None
|
||||
depth = 0
|
||||
in_string = False
|
||||
escape = False
|
||||
for i in range(start, len(text)):
|
||||
ch = text[i]
|
||||
if escape:
|
||||
escape = False
|
||||
continue
|
||||
if ch == "\\":
|
||||
escape = True
|
||||
continue
|
||||
# Note: ``escape`` is always False here — the ``if escape`` branch
|
||||
# above already handled and reset it.
|
||||
if ch == '"':
|
||||
in_string = not in_string
|
||||
continue
|
||||
if in_string:
|
||||
continue
|
||||
if ch == "{":
|
||||
depth += 1
|
||||
elif ch == "}":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return text[start : i + 1]
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _GenericTextClient:
|
||||
"""Wraps any text-generation callable in JSON-mode + one-retry semantics."""
|
||||
|
||||
generate_text: Callable[[Sequence[Sequence[dict[str, Any]]], int, float], list[str]]
|
||||
config: VlmConfig
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
max_tok = max_new_tokens if max_new_tokens is not None else self.config.max_new_tokens
|
||||
temp = temperature if temperature is not None else self.config.temperature
|
||||
raw = self.generate_text(messages_batch, max_tok, temp)
|
||||
out: list[Any] = []
|
||||
for messages, text in zip(messages_batch, raw, strict=True):
|
||||
try:
|
||||
out.append(_strip_to_json(text))
|
||||
continue
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
pass
|
||||
retry = list(messages) + [
|
||||
{"role": "assistant", "content": text},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Your previous reply was not valid JSON. "
|
||||
"Reply with strictly valid JSON, no prose, no fences."
|
||||
),
|
||||
},
|
||||
]
|
||||
retry_text = self.generate_text([retry], max_tok, temp)[0]
|
||||
try:
|
||||
out.append(_strip_to_json(retry_text))
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
# After retry: log preview and return None instead of crashing
|
||||
# the whole pipeline. Modules treat None as "skip".
|
||||
preview = retry_text.strip().replace("\n", " ")[:200]
|
||||
print(
|
||||
f"[vlm] WARNING: failed to parse JSON after retry; preview: {preview!r}",
|
||||
flush=True,
|
||||
)
|
||||
out.append(None)
|
||||
return out
|
||||
|
||||
|
||||
def make_vlm_client(config: VlmConfig) -> VlmClient:
|
||||
"""Build the shared VLM client.
|
||||
|
||||
Only the ``openai`` backend is supported for now. The shipped workflow
|
||||
is Hugging Face Jobs (``examples/annotations/run_hf_job.py``): it boots
|
||||
a vLLM server inside the ``vllm/vllm-openai`` image and the pipeline
|
||||
talks to it over the OpenAI-compatible API (``--vlm.backend=openai``,
|
||||
optionally auto-spawning the server via ``auto_serve`` /
|
||||
``serve_command``). The former in-process ``vllm`` / ``transformers``
|
||||
backends were removed to keep the support surface to the HF Jobs path.
|
||||
|
||||
For ``stub``, construct :class:`StubVlmClient` directly with a responder
|
||||
callable; it is rejected here to make accidental misuse obvious.
|
||||
"""
|
||||
if config.backend == "openai":
|
||||
return _make_openai_client(config)
|
||||
if config.backend == "stub":
|
||||
raise ValueError(
|
||||
"Use StubVlmClient(...) directly for the stub backend; make_vlm_client builds real clients."
|
||||
)
|
||||
if config.backend in {"vllm", "transformers"}:
|
||||
raise ValueError(
|
||||
f"backend={config.backend!r} (in-process local model) is not supported for now — "
|
||||
"only backend='openai' (the Hugging Face Jobs flow) is. Run the pipeline via "
|
||||
"examples/annotations/run_hf_job.py, which serves the model with vLLM in the "
|
||||
"vllm/vllm-openai image and talks to it over the OpenAI-compatible API."
|
||||
)
|
||||
raise ValueError(f"Unknown VLM backend: {config.backend!r}")
|
||||
|
||||
|
||||
def _make_openai_client(config: VlmConfig) -> VlmClient:
|
||||
"""Backend that talks to any OpenAI-compatible server.
|
||||
|
||||
Compatible with ``vllm serve``, ``transformers serve``,
|
||||
``ktransformers serve``, and hosted endpoints. By default the server
|
||||
is expected to be already running. Set ``auto_serve=True`` to have
|
||||
this client spawn one (default: ``transformers serve``), wait until
|
||||
it's ready, and tear it down on process exit.
|
||||
|
||||
Image blocks ``{"type":"image", "image":<PIL.Image>}`` are
|
||||
auto-converted to ``image_url`` data-URLs. Video blocks
|
||||
``{"type":"video", "video":[<PIL>...]}`` are forwarded as
|
||||
multi-frame ``video_url`` items where supported.
|
||||
"""
|
||||
try:
|
||||
from openai import OpenAI # type: ignore[import-not-found]
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"openai package is required for backend='openai'. Install with `pip install openai`."
|
||||
) from exc
|
||||
|
||||
api_base = config.api_base
|
||||
api_key = config.api_key
|
||||
auto_serve = config.auto_serve
|
||||
api_bases: list[str] = [api_base]
|
||||
|
||||
print(
|
||||
f"[lerobot-annotate] backend=openai model={config.model_id} "
|
||||
f"api_base={api_base} auto_serve={auto_serve}",
|
||||
flush=True,
|
||||
)
|
||||
if auto_serve:
|
||||
if config.parallel_servers > 1:
|
||||
print(
|
||||
f"[lerobot-annotate] spawning {config.parallel_servers} parallel servers",
|
||||
flush=True,
|
||||
)
|
||||
api_bases = _spawn_parallel_inference_servers(config)
|
||||
elif _server_is_up(api_base):
|
||||
print(f"[lerobot-annotate] reusing server already up at {api_base}", flush=True)
|
||||
else:
|
||||
print("[lerobot-annotate] no server reachable; spawning one", flush=True)
|
||||
api_base = _spawn_inference_server(config)
|
||||
api_bases = [api_base]
|
||||
print(f"[lerobot-annotate] server ready at {api_base}", flush=True)
|
||||
|
||||
clients = [OpenAI(base_url=base, api_key=api_key) for base in api_bases]
|
||||
# round-robin counter for parallel mode
|
||||
rr_counter = {"i": 0}
|
||||
|
||||
# ``mm_processor_kwargs`` is a vllm-specific extra; transformers serve
|
||||
# rejects it with HTTP 422. Send it only when explicitly opted in via
|
||||
# an env var (e.g. ``LEROBOT_OPENAI_SEND_MM_KWARGS=1`` for vllm).
|
||||
send_mm_kwargs = os.environ.get("LEROBOT_OPENAI_SEND_MM_KWARGS", "").lower() in {"1", "true", "yes"}
|
||||
|
||||
rr_lock = threading.Lock()
|
||||
|
||||
def _one_call(messages: Sequence[dict[str, Any]], max_tok: int, temp: float) -> str:
|
||||
api_messages, mm_kwargs = _to_openai_messages(messages)
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": config.model_id,
|
||||
"messages": api_messages,
|
||||
"max_tokens": max_tok,
|
||||
"temperature": temp,
|
||||
}
|
||||
extra_body: dict[str, Any] = {}
|
||||
if send_mm_kwargs and mm_kwargs:
|
||||
extra_body["mm_processor_kwargs"] = {**mm_kwargs, "do_sample_frames": True}
|
||||
if config.chat_template_kwargs:
|
||||
extra_body["chat_template_kwargs"] = config.chat_template_kwargs
|
||||
if extra_body:
|
||||
kwargs["extra_body"] = extra_body
|
||||
with rr_lock:
|
||||
chosen = clients[rr_counter["i"] % len(clients)]
|
||||
rr_counter["i"] += 1
|
||||
response = chosen.chat.completions.create(**kwargs)
|
||||
return response.choices[0].message.content or ""
|
||||
|
||||
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
||||
if len(batch) <= 1 or config.client_concurrency <= 1:
|
||||
return [_one_call(messages, max_tok, temp) for messages in batch]
|
||||
# Parallel fan-out — vllm batches these on the server side.
|
||||
max_workers = min(config.client_concurrency, len(batch))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
futures = [pool.submit(_one_call, messages, max_tok, temp) for messages in batch]
|
||||
return [f.result() for f in futures]
|
||||
|
||||
return _GenericTextClient(_gen, config)
|
||||
|
||||
|
||||
def _bind_serve_port(cmd: str, port: int) -> str:
|
||||
"""Bind a serve command to ``port``: substitute a ``{port}`` placeholder
|
||||
if present, else append ``--port`` when the command omits it (leaving an
|
||||
explicit ``--port`` untouched). Shared by the single- and parallel-server
|
||||
paths so a serve_command never reaches the server with a literal
|
||||
``{port}``."""
|
||||
if "{port}" in cmd:
|
||||
return cmd.replace("{port}", str(port))
|
||||
if "--port" not in cmd:
|
||||
return f"{cmd} --port {port}"
|
||||
return cmd
|
||||
|
||||
|
||||
def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
|
||||
"""Spawn ``config.parallel_servers`` independent vllm replicas.
|
||||
|
||||
Each replica:
|
||||
- is pinned to a single GPU via ``CUDA_VISIBLE_DEVICES``
|
||||
- listens on ``serve_port + i``
|
||||
- is shut down via the same atexit hook as the single-server path
|
||||
|
||||
Returns the list of ``api_base`` URLs the client should round-robin
|
||||
across.
|
||||
"""
|
||||
n = config.parallel_servers
|
||||
api_bases: list[str] = []
|
||||
procs: list[subprocess.Popen] = []
|
||||
ready_events: list[threading.Event] = []
|
||||
# Multiple readiness signals — uvicorn's own banner is suppressed at
|
||||
# ``--uvicorn-log-level warning``, so we also accept vllm's own
|
||||
# "Starting vLLM API server" line and the route-listing line. The
|
||||
# HTTP probe below is the ultimate fallback.
|
||||
ready_markers = (
|
||||
"Uvicorn running",
|
||||
"Application startup complete",
|
||||
"Starting vLLM API server",
|
||||
"Available routes are",
|
||||
)
|
||||
# Single lock for all server-stream threads so multibyte chars from
|
||||
# different servers don't interleave and tear UTF-8 sequences.
|
||||
print_lock = threading.Lock()
|
||||
|
||||
base_cmd = config.serve_command or (
|
||||
f"vllm serve {shlex.quote(config.model_id)} "
|
||||
f"--tensor-parallel-size 1 "
|
||||
f"--max-model-len {config.max_model_len or 32768} "
|
||||
f"--uvicorn-log-level warning"
|
||||
)
|
||||
|
||||
num_gpus = config.num_gpus if config.num_gpus > 0 else n
|
||||
for i in range(n):
|
||||
port = config.serve_port + i
|
||||
gpu = i % num_gpus
|
||||
env = os.environ.copy()
|
||||
env["CUDA_VISIBLE_DEVICES"] = str(gpu)
|
||||
cmd = _bind_serve_port(base_cmd, port)
|
||||
api_base = f"http://localhost:{port}/v1"
|
||||
api_bases.append(api_base)
|
||||
print(f"[server-{i}] launching on GPU {gpu} port {port}: {cmd}", flush=True)
|
||||
proc = subprocess.Popen(
|
||||
shlex.split(cmd),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
env=env,
|
||||
)
|
||||
procs.append(proc)
|
||||
ready = threading.Event()
|
||||
ready_events.append(ready)
|
||||
|
||||
def _stream(idx: int, p: subprocess.Popen, ev: threading.Event) -> None:
|
||||
# Read whole lines and emit each line atomically under the
|
||||
# shared print_lock so output from N servers stays readable.
|
||||
assert p.stdout is not None
|
||||
for line in iter(p.stdout.readline, ""):
|
||||
with print_lock:
|
||||
sys.stdout.write(f"[server-{idx}] {line}")
|
||||
if not line.endswith(("\n", "\r")):
|
||||
sys.stdout.write("\n")
|
||||
sys.stdout.flush()
|
||||
if any(m in line for m in ready_markers):
|
||||
ev.set()
|
||||
|
||||
threading.Thread(target=_stream, args=(i, proc, ready), daemon=True).start()
|
||||
|
||||
def _probe(idx: int, base: str, ev: threading.Event, p: subprocess.Popen) -> None:
|
||||
while not ev.is_set() and p.poll() is None:
|
||||
if _server_is_up(base):
|
||||
print(f"[server-{idx}] ready (http probe)", flush=True)
|
||||
ev.set()
|
||||
return
|
||||
time.sleep(2)
|
||||
|
||||
threading.Thread(target=_probe, args=(i, api_base, ready, proc), daemon=True).start()
|
||||
|
||||
def _shutdown() -> None:
|
||||
for i, p in enumerate(procs):
|
||||
if p.poll() is None:
|
||||
print(f"[server-{i}] stopping pid={p.pid}", flush=True)
|
||||
p.send_signal(signal.SIGINT)
|
||||
for p in procs:
|
||||
try:
|
||||
p.wait(timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
p.kill()
|
||||
p.wait(timeout=5)
|
||||
|
||||
atexit.register(_shutdown)
|
||||
|
||||
deadline = time.monotonic() + config.serve_ready_timeout_s
|
||||
while any(not ev.is_set() for ev in ready_events) and time.monotonic() < deadline:
|
||||
for i, p in enumerate(procs):
|
||||
if p.poll() is not None:
|
||||
raise RuntimeError(
|
||||
f"[server-{i}] inference server exited unexpectedly with rc={p.returncode}"
|
||||
)
|
||||
time.sleep(2)
|
||||
if any(not ev.is_set() for ev in ready_events):
|
||||
raise RuntimeError(f"[server] not all replicas became ready within {config.serve_ready_timeout_s}s")
|
||||
print(f"[lerobot-annotate] all {n} servers ready: {api_bases}", flush=True)
|
||||
return api_bases
|
||||
|
||||
|
||||
def _server_is_up(api_base: str) -> bool:
|
||||
"""Return True if ``api_base/models`` answers 200 within 2 seconds."""
|
||||
url = api_base.rstrip("/") + "/models"
|
||||
# ``api_base`` is the user-configured local-server URL we just spawned
|
||||
# or the user passed in via ``--vlm.api_base``; the bandit B310 warning
|
||||
# is for arbitrary user-controlled URLs with file:/ schemes which
|
||||
# cannot reach this code path.
|
||||
try:
|
||||
with urllib.request.urlopen(url, timeout=2) as resp: # noqa: S310 # nosec B310
|
||||
return resp.status == 200
|
||||
except Exception: # noqa: BLE001
|
||||
return False
|
||||
|
||||
|
||||
def _spawn_inference_server(config: VlmConfig) -> str:
|
||||
"""Spawn ``transformers serve`` (or ``serve_command``), wait until it
|
||||
accepts ``/v1/models``, and register a shutdown hook.
|
||||
|
||||
Streams the server's stdout/stderr to the parent terminal in
|
||||
real-time on a background thread so users can see model-load
|
||||
progress and errors as they happen.
|
||||
|
||||
Returns the full ``api_base`` URL the OpenAI client should use.
|
||||
"""
|
||||
cmd = config.serve_command
|
||||
if not cmd:
|
||||
cmd = (
|
||||
f"transformers serve {shlex.quote(config.model_id)} "
|
||||
f"--port {config.serve_port} --continuous-batching"
|
||||
)
|
||||
# Bind the single server to ``serve_port`` (what ``api_base`` below
|
||||
# targets): substitute a literal ``{port}`` placeholder, else append
|
||||
# ``--port``. Without this a serve_command carrying ``{port}`` would
|
||||
# reach the server unsubstituted and fail to parse.
|
||||
cmd = _bind_serve_port(cmd, config.serve_port)
|
||||
api_base = f"http://localhost:{config.serve_port}/v1"
|
||||
print(f"[server] launching: {cmd}", flush=True)
|
||||
proc = subprocess.Popen(
|
||||
shlex.split(cmd),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
)
|
||||
|
||||
# Watch the server output for the uvicorn readiness banner. This is
|
||||
# more reliable than polling /v1/models because transformers serve
|
||||
# rescans its cache on every model-list request, which can exceed
|
||||
# the urllib timeout and trigger an infinite probe loop.
|
||||
ready_event = threading.Event()
|
||||
# See _spawn_parallel_inference_servers for why we accept these.
|
||||
ready_markers = (
|
||||
"Uvicorn running",
|
||||
"Application startup complete",
|
||||
"Starting vLLM API server",
|
||||
"Available routes are",
|
||||
)
|
||||
|
||||
def _probe() -> None:
|
||||
while not ready_event.is_set() and proc.poll() is None:
|
||||
if _server_is_up(api_base):
|
||||
print("[server] ready (http probe)", flush=True)
|
||||
ready_event.set()
|
||||
return
|
||||
time.sleep(2)
|
||||
|
||||
threading.Thread(target=_probe, daemon=True).start()
|
||||
|
||||
def _stream_output() -> None:
|
||||
# Read raw chunks instead of iterating lines so tqdm progress
|
||||
# bars (which overwrite using \r) flush in real time.
|
||||
assert proc.stdout is not None
|
||||
buf = ""
|
||||
prefix_started = False
|
||||
while True:
|
||||
ch = proc.stdout.read(1)
|
||||
if ch == "":
|
||||
# process exited; flush any tail
|
||||
if buf:
|
||||
sys.stdout.write(buf)
|
||||
sys.stdout.flush()
|
||||
return
|
||||
if not prefix_started:
|
||||
sys.stdout.write("[server] ")
|
||||
prefix_started = True
|
||||
sys.stdout.write(ch)
|
||||
sys.stdout.flush()
|
||||
buf += ch
|
||||
if ch in ("\n", "\r"):
|
||||
if any(marker in buf for marker in ready_markers):
|
||||
ready_event.set()
|
||||
buf = ""
|
||||
prefix_started = False
|
||||
|
||||
threading.Thread(target=_stream_output, daemon=True).start()
|
||||
|
||||
def _shutdown() -> None:
|
||||
if proc.poll() is None:
|
||||
print(f"[server] stopping pid={proc.pid}", flush=True)
|
||||
proc.send_signal(signal.SIGINT)
|
||||
try:
|
||||
proc.wait(timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
proc.wait(timeout=5)
|
||||
|
||||
atexit.register(_shutdown)
|
||||
|
||||
deadline = time.monotonic() + config.serve_ready_timeout_s
|
||||
while time.monotonic() < deadline:
|
||||
if proc.poll() is not None:
|
||||
raise RuntimeError(
|
||||
f"[server] inference server exited unexpectedly with rc={proc.returncode}. "
|
||||
f"See [server] log lines above for the cause."
|
||||
)
|
||||
if ready_event.wait(timeout=2):
|
||||
return api_base
|
||||
proc.terminate()
|
||||
raise RuntimeError(f"[server] did not become ready within {config.serve_ready_timeout_s}s")
|
||||
|
||||
|
||||
def _to_openai_messages(
|
||||
messages: Sequence[dict[str, Any]],
|
||||
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
||||
"""Convert internal messages to OpenAI chat format.
|
||||
|
||||
Returns ``(api_messages, mm_kwargs)``. Multimodal-processor kwargs
|
||||
(``fps`` from ``video_url`` blocks) are extracted out so the caller
|
||||
can pass them via ``extra_body.mm_processor_kwargs`` rather than
|
||||
inside the content blocks (which transformers serve rejects).
|
||||
|
||||
File-URL video blocks are inlined as base64 data URLs.
|
||||
"""
|
||||
out_messages: list[dict[str, Any]] = []
|
||||
mm_kwargs: dict[str, Any] = {}
|
||||
for message in messages:
|
||||
content = message.get("content")
|
||||
if not isinstance(content, list):
|
||||
out_messages.append({"role": message["role"], "content": content})
|
||||
continue
|
||||
out_blocks: list[dict[str, Any]] = []
|
||||
for block in content:
|
||||
block_type = block.get("type") if isinstance(block, dict) else None
|
||||
if block_type == "text":
|
||||
out_blocks.append({"type": "text", "text": block.get("text", "")})
|
||||
elif block_type == "image":
|
||||
out_blocks.append(
|
||||
{"type": "image_url", "image_url": {"url": _pil_to_data_url(block["image"])}}
|
||||
)
|
||||
elif block_type == "video":
|
||||
frames = block.get("video", [])
|
||||
for img in frames:
|
||||
out_blocks.append({"type": "image_url", "image_url": {"url": _pil_to_data_url(img)}})
|
||||
elif block_type == "video_url":
|
||||
video_url = dict(block["video_url"])
|
||||
url = video_url.get("url", "")
|
||||
if url.startswith("file://"):
|
||||
video_url["url"] = _file_to_data_url(url[len("file://") :])
|
||||
out_blocks.append({"type": "video_url", "video_url": video_url})
|
||||
fps = block.get("fps")
|
||||
if fps is not None:
|
||||
mm_kwargs["fps"] = fps
|
||||
else:
|
||||
out_blocks.append(block)
|
||||
out_messages.append({"role": message["role"], "content": out_blocks})
|
||||
return out_messages, mm_kwargs
|
||||
|
||||
|
||||
def _file_to_data_url(path: str) -> str:
|
||||
"""Read a local video file and return a base64 ``data:video/mp4`` URL."""
|
||||
with open(path, "rb") as f:
|
||||
b64 = base64.b64encode(f.read()).decode("ascii")
|
||||
return f"data:video/mp4;base64,{b64}"
|
||||
|
||||
|
||||
def _pil_to_data_url(image: Any) -> str:
|
||||
"""Encode a PIL.Image as a base64 data URL."""
|
||||
buf = io.BytesIO()
|
||||
image.save(buf, format="PNG")
|
||||
b64 = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
return f"data:image/png;base64,{b64}"
|
||||
@@ -1,341 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Final parquet rewrite.
|
||||
|
||||
For every episode the writer:
|
||||
|
||||
1. reads the staged module outputs,
|
||||
2. partitions them into a persistent slice (PERSISTENT_STYLES) and an event
|
||||
slice (EVENT_ONLY_STYLES + style=None tool-call atoms),
|
||||
3. sorts each slice deterministically,
|
||||
4. broadcasts the persistent slice across every frame in the episode,
|
||||
5. for each frame, materializes the sublist of event rows whose timestamp
|
||||
exactly equals that frame's timestamp,
|
||||
6. drops the legacy ``subtask_index`` column,
|
||||
7. writes the parquet shard back in place.
|
||||
|
||||
The writer does NOT add a dataset-level ``tools`` column. Tool *calls* are
|
||||
emitted per-row via the existing ``tool_calls`` field on the v3.1 row
|
||||
struct for every speech atom. The tool *schema* (the description
|
||||
of the ``say`` function and its parameters) is a fixed code constant —
|
||||
``SAY_TOOL_SCHEMA`` below — and downstream chat-template consumers import
|
||||
it directly rather than reading a redundant per-row column.
|
||||
|
||||
Invariants enforced here (and re-checked by the validator):
|
||||
|
||||
- per-episode persistent slice is byte-identical across every frame;
|
||||
- ``language_events`` rows on a frame all have ``timestamp == frame_ts``
|
||||
(timestamps come straight from the source parquet — never recomputed);
|
||||
- every row passes ``column_for_style(style)``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
column_for_style,
|
||||
validate_camera_field,
|
||||
)
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
from .staging import EpisodeStaging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Tool schema constants live in lerobot.datasets.language — single
|
||||
# source of truth. Re-exported here so existing imports
|
||||
# (``from lerobot.annotations.steerable_pipeline.writer import SAY_TOOL_SCHEMA``)
|
||||
# keep working.
|
||||
from lerobot.datasets.language import DEFAULT_TOOLS, SAY_TOOL_SCHEMA # noqa: F401, E402
|
||||
|
||||
|
||||
def _row_persistent_sort_key(row: dict[str, Any]) -> tuple:
|
||||
return (float(row["timestamp"]), row.get("style") or "", row.get("role") or "")
|
||||
|
||||
|
||||
def _row_event_sort_key(row: dict[str, Any]) -> tuple:
|
||||
# events are bucketed per-frame, but within a frame we still want determinism
|
||||
return (
|
||||
row.get("style") or "",
|
||||
row.get("role") or "",
|
||||
row.get("camera") or "",
|
||||
)
|
||||
|
||||
|
||||
def _normalize_row(row: dict[str, Any], style: str | None, *, with_timestamp: bool) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the language-column struct shape.
|
||||
|
||||
Key order matches ``PERSISTENT_ROW_FIELDS`` / ``EVENT_ROW_FIELDS`` — the
|
||||
writer infers the parquet struct schema from insertion order, so
|
||||
``timestamp`` (persistent rows only) sits between ``style`` and ``camera``.
|
||||
"""
|
||||
camera = row.get("camera")
|
||||
validate_camera_field(style, camera)
|
||||
out: dict[str, Any] = {
|
||||
"role": str(row["role"]),
|
||||
"content": None if row.get("content") is None else str(row["content"]),
|
||||
"style": style,
|
||||
}
|
||||
if with_timestamp:
|
||||
out["timestamp"] = float(row["timestamp"])
|
||||
out["camera"] = None if camera is None else str(camera)
|
||||
out["tool_calls"] = _normalize_tool_calls(row.get("tool_calls"))
|
||||
return out
|
||||
|
||||
|
||||
def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the persistent column's struct shape."""
|
||||
style = row.get("style")
|
||||
if style not in PERSISTENT_STYLES:
|
||||
raise ValueError(
|
||||
f"persistent slice contains row with non-persistent style {style!r}; "
|
||||
"row would be misrouted under column_for_style()"
|
||||
)
|
||||
if "timestamp" not in row:
|
||||
raise ValueError(f"persistent row missing timestamp: {row!r}")
|
||||
if "role" not in row:
|
||||
# Friendly error from the writer instead of a raw KeyError below;
|
||||
# the validator doesn't check ``role`` yet.
|
||||
raise ValueError(f"persistent row missing role: {row!r}")
|
||||
return _normalize_row(row, style, with_timestamp=True)
|
||||
|
||||
|
||||
def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the event column's struct shape (no timestamp)."""
|
||||
style = row.get("style")
|
||||
if style is not None and style not in EVENT_ONLY_STYLES:
|
||||
raise ValueError(
|
||||
f"event slice contains row with style {style!r}; expected None or one of {EVENT_ONLY_STYLES}"
|
||||
)
|
||||
if column_for_style(style) != LANGUAGE_EVENTS:
|
||||
raise ValueError(f"event row with style {style!r} would not route to language_events")
|
||||
if "role" not in row:
|
||||
raise ValueError(f"event row missing role: {row!r}")
|
||||
return _normalize_row(row, style, with_timestamp=False)
|
||||
|
||||
|
||||
def _normalize_tool_calls(value: Any) -> list[Any] | None:
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, list):
|
||||
raise ValueError(f"tool_calls must be a list or None, got {type(value).__name__}")
|
||||
return list(value)
|
||||
|
||||
|
||||
def _validate_atom_invariants(row: dict[str, Any]) -> None:
|
||||
"""At-least-one of content/tool_calls; style=None implies tool_calls."""
|
||||
has_content = row.get("content") is not None
|
||||
has_tools = row.get("tool_calls") is not None
|
||||
if not (has_content or has_tools):
|
||||
raise ValueError(f"row has neither content nor tool_calls: {row!r}")
|
||||
if row.get("style") is None and not has_tools:
|
||||
raise ValueError(f"style=None requires tool_calls: {row!r}")
|
||||
|
||||
|
||||
def _validate_speech_atom(row: dict[str, Any]) -> None:
|
||||
"""Speech atoms: role=assistant, style=None, content=None, say tool call."""
|
||||
if row.get("style") is not None:
|
||||
return # not a speech atom
|
||||
if row.get("role") != "assistant":
|
||||
raise ValueError(f"speech atom must have role=assistant: {row!r}")
|
||||
if row.get("content") is not None:
|
||||
raise ValueError(f"speech atom must have content=null: {row!r}")
|
||||
tool_calls = row.get("tool_calls")
|
||||
if not tool_calls or not isinstance(tool_calls, list):
|
||||
raise ValueError(f"speech atom must have non-empty tool_calls list: {row!r}")
|
||||
first = tool_calls[0]
|
||||
if not isinstance(first, dict):
|
||||
raise ValueError(f"speech atom tool_calls[0] must be a dict: {row!r}")
|
||||
if first.get("type") != "function":
|
||||
raise ValueError(f"speech atom tool_calls[0].type must be 'function': {row!r}")
|
||||
fn = first.get("function") or {}
|
||||
if fn.get("name") != "say":
|
||||
raise ValueError(f"speech atom tool_calls[0].function.name must be 'say': {row!r}")
|
||||
args = fn.get("arguments") or {}
|
||||
if not isinstance(args, dict) or "text" not in args or not isinstance(args["text"], str):
|
||||
raise ValueError(f"speech atom must carry 'text' string in arguments: {row!r}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class LanguageColumnsWriter:
|
||||
"""Rewrite ``data/chunk-*/file-*.parquet`` with the two language columns."""
|
||||
|
||||
drop_existing_subtask_index: bool = True
|
||||
|
||||
def write_all(
|
||||
self,
|
||||
records: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
root: Path,
|
||||
) -> list[Path]:
|
||||
episodes_by_path: dict[Path, list[EpisodeRecord]] = defaultdict(list)
|
||||
for record in records:
|
||||
episodes_by_path[record.data_path].append(record)
|
||||
|
||||
written: list[Path] = []
|
||||
for path, eps in episodes_by_path.items():
|
||||
self._rewrite_one(path, eps, staging_dir, root)
|
||||
written.append(path)
|
||||
return written
|
||||
|
||||
def _rewrite_one(
|
||||
self,
|
||||
path: Path,
|
||||
episodes: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
root: Path,
|
||||
) -> None:
|
||||
table = pq.read_table(path)
|
||||
n_rows = table.num_rows
|
||||
|
||||
# Ensure we cover every episode in the file. Episodes that don't have
|
||||
# staging artifacts are passed through with empty annotation lists —
|
||||
# this keeps the writer idempotent and safe for partial reruns.
|
||||
staged_per_ep: dict[int, dict[str, list[dict[str, Any]]]] = {}
|
||||
for record in episodes:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
staged_per_ep[record.episode_index] = staging.read_all()
|
||||
|
||||
persistent_by_ep: dict[int, list[dict[str, Any]]] = {}
|
||||
events_by_ep_ts: dict[int, dict[float, list[dict[str, Any]]]] = {}
|
||||
|
||||
for ep_index, ep_staged in staged_per_ep.items():
|
||||
persistent_rows: list[dict[str, Any]] = []
|
||||
event_rows: list[dict[str, Any]] = [] # carry timestamp until bucketed
|
||||
for _module_name, rows in ep_staged.items():
|
||||
for row in rows:
|
||||
style = row.get("style")
|
||||
if column_for_style(style) == LANGUAGE_PERSISTENT:
|
||||
persistent_rows.append(row)
|
||||
else:
|
||||
event_rows.append(row)
|
||||
|
||||
persistent_rows.sort(key=_row_persistent_sort_key)
|
||||
normalized_persistent = []
|
||||
for r in persistent_rows:
|
||||
_validate_atom_invariants(r)
|
||||
_validate_speech_atom(r)
|
||||
normalized_persistent.append(_normalize_persistent_row(r))
|
||||
persistent_by_ep[ep_index] = normalized_persistent
|
||||
|
||||
buckets: dict[float, list[dict[str, Any]]] = defaultdict(list)
|
||||
for r in event_rows:
|
||||
_validate_atom_invariants(r)
|
||||
_validate_speech_atom(r)
|
||||
ts = float(r["timestamp"])
|
||||
buckets[ts].append(_normalize_event_row(r))
|
||||
for ts in list(buckets.keys()):
|
||||
buckets[ts].sort(key=_row_event_sort_key)
|
||||
events_by_ep_ts[ep_index] = buckets
|
||||
|
||||
episode_col = (
|
||||
table.column("episode_index").to_pylist() if "episode_index" in table.column_names else None
|
||||
)
|
||||
ts_col = table.column("timestamp").to_pylist() if "timestamp" in table.column_names else None
|
||||
if episode_col is None or ts_col is None:
|
||||
raise ValueError(f"{path} is missing 'episode_index' or 'timestamp' — required by the writer.")
|
||||
|
||||
per_row_persistent: list[list[dict[str, Any]]] = []
|
||||
per_row_events: list[list[dict[str, Any]]] = []
|
||||
for i in range(n_rows):
|
||||
ep = episode_col[i]
|
||||
ts = float(ts_col[i])
|
||||
per_row_persistent.append(persistent_by_ep.get(ep, []))
|
||||
buckets = events_by_ep_ts.get(ep, {})
|
||||
per_row_events.append(buckets.get(ts, []))
|
||||
|
||||
new_table = self._materialize_table(
|
||||
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
|
||||
)
|
||||
# Atomic replace: write to a sibling tmp path and rename so a crash
|
||||
# mid-write can't leave a half-written shard that ``pq.read_table``
|
||||
# would then fail to open. ``Path.replace`` is atomic on POSIX +
|
||||
# Windows when source and target sit on the same filesystem.
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
pq.write_table(new_table, tmp_path)
|
||||
tmp_path.replace(path)
|
||||
|
||||
def _materialize_table(
|
||||
self,
|
||||
table: pa.Table,
|
||||
persistent: list[list[dict[str, Any]]],
|
||||
events: list[list[dict[str, Any]]],
|
||||
*,
|
||||
drop_old: bool,
|
||||
) -> pa.Table:
|
||||
cols = []
|
||||
names = []
|
||||
for name in table.column_names:
|
||||
if drop_old and name == "subtask_index":
|
||||
continue
|
||||
if name in (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS):
|
||||
continue # we'll re-add canonical versions
|
||||
# Strip any legacy ``tools`` column previously emitted by older
|
||||
# writers — the schema no longer uses it (constant lives in
|
||||
# SAY_TOOL_SCHEMA / DEFAULT_TOOLS).
|
||||
if name == "tools":
|
||||
continue
|
||||
cols.append(table.column(name))
|
||||
names.append(name)
|
||||
|
||||
# We let pyarrow infer struct/list schema rather than passing the
|
||||
# canonical type from `lerobot.datasets.language` directly: that type
|
||||
# uses `pa.json_()` for the `tool_calls` element type, which
|
||||
# `pa.array(..., type=...)` cannot materialize from Python lists on
|
||||
# current pyarrow versions. The inferred schema round-trips through
|
||||
# parquet and `LeRobotDataset` correctly — `tests/datasets/test_language.py`
|
||||
# exercises the same flow.
|
||||
persistent_arr = pa.array(persistent)
|
||||
events_arr = pa.array(events)
|
||||
|
||||
cols.extend([persistent_arr, events_arr])
|
||||
names.extend([LANGUAGE_PERSISTENT, LANGUAGE_EVENTS])
|
||||
|
||||
return pa.Table.from_arrays(cols, names=names)
|
||||
|
||||
|
||||
def speech_atom(timestamp: float, text: str) -> dict[str, Any]:
|
||||
"""Build a canonical speech tool-call atom for the events column."""
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"style": None,
|
||||
"timestamp": float(timestamp),
|
||||
"camera": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "say",
|
||||
"arguments": {"text": text},
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Async inference server/client.
|
||||
|
||||
Requires: ``pip install 'lerobot[async]'``
|
||||
|
||||
Available modules (import directly)::
|
||||
|
||||
from lerobot.async_inference.policy_server import ...
|
||||
from lerobot.async_inference.robot_client import ...
|
||||
"""
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
require_package("grpcio", extra="async", import_name="grpc")
|
||||
|
||||
__all__: list[str] = []
|
||||
@@ -1,203 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.robots.config import RobotConfig
|
||||
|
||||
from .constants import (
|
||||
DEFAULT_FPS,
|
||||
DEFAULT_INFERENCE_LATENCY,
|
||||
DEFAULT_OBS_QUEUE_TIMEOUT,
|
||||
)
|
||||
|
||||
# Aggregate function registry for CLI usage
|
||||
AGGREGATE_FUNCTIONS = {
|
||||
"weighted_average": lambda old, new: 0.3 * old + 0.7 * new,
|
||||
"latest_only": lambda old, new: new,
|
||||
"average": lambda old, new: 0.5 * old + 0.5 * new,
|
||||
"conservative": lambda old, new: 0.7 * old + 0.3 * new,
|
||||
}
|
||||
|
||||
|
||||
def get_aggregate_function(name: str) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
|
||||
"""Get aggregate function by name from registry."""
|
||||
if name not in AGGREGATE_FUNCTIONS:
|
||||
available = list(AGGREGATE_FUNCTIONS.keys())
|
||||
raise ValueError(f"Unknown aggregate function '{name}'. Available: {available}")
|
||||
return AGGREGATE_FUNCTIONS[name]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyServerConfig:
|
||||
"""Configuration for PolicyServer.
|
||||
|
||||
This class defines all configurable parameters for the PolicyServer,
|
||||
including networking settings and action chunking specifications.
|
||||
"""
|
||||
|
||||
# Networking configuration
|
||||
host: str = field(default="localhost", metadata={"help": "Host address to bind the server to"})
|
||||
port: int = field(default=8080, metadata={"help": "Port number to bind the server to"})
|
||||
|
||||
# Timing configuration
|
||||
fps: int = field(default=DEFAULT_FPS, metadata={"help": "Frames per second"})
|
||||
inference_latency: float = field(
|
||||
default=DEFAULT_INFERENCE_LATENCY, metadata={"help": "Target inference latency in seconds"}
|
||||
)
|
||||
|
||||
obs_queue_timeout: float = field(
|
||||
default=DEFAULT_OBS_QUEUE_TIMEOUT, metadata={"help": "Timeout for observation queue in seconds"}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate configuration after initialization."""
|
||||
if self.port < 1 or self.port > 65535:
|
||||
raise ValueError(f"Port must be between 1 and 65535, got {self.port}")
|
||||
|
||||
if self.environment_dt <= 0:
|
||||
raise ValueError(f"environment_dt must be positive, got {self.environment_dt}")
|
||||
|
||||
if self.inference_latency < 0:
|
||||
raise ValueError(f"inference_latency must be non-negative, got {self.inference_latency}")
|
||||
|
||||
if self.obs_queue_timeout < 0:
|
||||
raise ValueError(f"obs_queue_timeout must be non-negative, got {self.obs_queue_timeout}")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: dict) -> "PolicyServerConfig":
|
||||
"""Create a PolicyServerConfig from a dictionary."""
|
||||
return cls(**config_dict)
|
||||
|
||||
@property
|
||||
def environment_dt(self) -> float:
|
||||
"""Environment time step, in seconds"""
|
||||
return 1 / self.fps
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert the configuration to a dictionary."""
|
||||
return {
|
||||
"host": self.host,
|
||||
"port": self.port,
|
||||
"fps": self.fps,
|
||||
"environment_dt": self.environment_dt,
|
||||
"inference_latency": self.inference_latency,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class RobotClientConfig:
|
||||
"""Configuration for RobotClient.
|
||||
|
||||
This class defines all configurable parameters for the RobotClient,
|
||||
including network connection, policy settings, and control behavior.
|
||||
"""
|
||||
|
||||
# Policy configuration
|
||||
policy_type: str = field(metadata={"help": "Type of policy to use"})
|
||||
pretrained_name_or_path: str = field(metadata={"help": "Pretrained model name or path"})
|
||||
|
||||
# Robot configuration (for CLI usage - robot instance will be created from this)
|
||||
robot: RobotConfig = field(metadata={"help": "Robot configuration"})
|
||||
|
||||
# Policies typically output K actions at max, but we can use less to avoid wasting bandwidth (as actions
|
||||
# would be aggregated on the client side anyway, depending on the value of `chunk_size_threshold`)
|
||||
actions_per_chunk: int = field(metadata={"help": "Number of actions per chunk"})
|
||||
|
||||
# Task instruction for the robot to execute (e.g., 'fold my tshirt')
|
||||
task: str = field(default="", metadata={"help": "Task instruction for the robot to execute"})
|
||||
|
||||
# Network configuration
|
||||
server_address: str = field(default="localhost:8080", metadata={"help": "Server address to connect to"})
|
||||
|
||||
# Device configuration
|
||||
policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"})
|
||||
client_device: str = field(
|
||||
default="cpu",
|
||||
metadata={
|
||||
"help": "Device to move actions to after receiving from server (e.g., for downstream planners)"
|
||||
},
|
||||
)
|
||||
|
||||
# Control behavior configuration
|
||||
chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"})
|
||||
fps: int = field(default=DEFAULT_FPS, metadata={"help": "Frames per second"})
|
||||
|
||||
# Aggregate function configuration (CLI-compatible)
|
||||
aggregate_fn_name: str = field(
|
||||
default="weighted_average",
|
||||
metadata={"help": f"Name of aggregate function to use. Options: {list(AGGREGATE_FUNCTIONS.keys())}"},
|
||||
)
|
||||
|
||||
# Debug configuration
|
||||
debug_visualize_queue_size: bool = field(
|
||||
default=False, metadata={"help": "Visualize the action queue size"}
|
||||
)
|
||||
|
||||
@property
|
||||
def environment_dt(self) -> float:
|
||||
"""Environment time step, in seconds"""
|
||||
return 1 / self.fps
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate configuration after initialization."""
|
||||
if not self.server_address:
|
||||
raise ValueError("server_address cannot be empty")
|
||||
|
||||
if not self.policy_type:
|
||||
raise ValueError("policy_type cannot be empty")
|
||||
|
||||
if not self.pretrained_name_or_path:
|
||||
raise ValueError("pretrained_name_or_path cannot be empty")
|
||||
|
||||
if not self.policy_device:
|
||||
raise ValueError("policy_device cannot be empty")
|
||||
|
||||
if not self.client_device:
|
||||
raise ValueError("client_device cannot be empty")
|
||||
|
||||
if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1:
|
||||
raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}")
|
||||
|
||||
if self.fps <= 0:
|
||||
raise ValueError(f"fps must be positive, got {self.fps}")
|
||||
|
||||
if self.actions_per_chunk <= 0:
|
||||
raise ValueError(f"actions_per_chunk must be positive, got {self.actions_per_chunk}")
|
||||
|
||||
self.aggregate_fn = get_aggregate_function(self.aggregate_fn_name)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: dict) -> "RobotClientConfig":
|
||||
"""Create a RobotClientConfig from a dictionary."""
|
||||
return cls(**config_dict)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert the configuration to a dictionary."""
|
||||
return {
|
||||
"server_address": self.server_address,
|
||||
"policy_type": self.policy_type,
|
||||
"pretrained_name_or_path": self.pretrained_name_or_path,
|
||||
"policy_device": self.policy_device,
|
||||
"client_device": self.client_device,
|
||||
"chunk_size_threshold": self.chunk_size_threshold,
|
||||
"fps": self.fps,
|
||||
"actions_per_chunk": self.actions_per_chunk,
|
||||
"task": self.task,
|
||||
"debug_visualize_queue_size": self.debug_visualize_queue_size,
|
||||
"aggregate_fn_name": self.aggregate_fn_name,
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Client side: The environment evolves with a time resolution equal to 1/fps"""
|
||||
|
||||
DEFAULT_FPS = 30
|
||||
|
||||
"""Server side: Running inference on (at most) 1/fps"""
|
||||
DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS
|
||||
|
||||
"""Server side: Timeout for observation queue in seconds"""
|
||||
DEFAULT_OBS_QUEUE_TIMEOUT = 2
|
||||
|
||||
# All action chunking policies
|
||||
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05", "groot"]
|
||||
|
||||
# TODO: Add all other robots
|
||||
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so_follower", "omx_follower"]
|
||||
@@ -1,297 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs import PolicyFeature
|
||||
|
||||
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
|
||||
from lerobot.policies import ( # noqa: F401
|
||||
ACTConfig,
|
||||
DiffusionConfig,
|
||||
PI0Config,
|
||||
PI05Config,
|
||||
SmolVLAConfig,
|
||||
VQBeTConfig,
|
||||
)
|
||||
from lerobot.robots.robot import Robot
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
Action = torch.Tensor
|
||||
|
||||
# observation as received from the robot (can be numpy arrays, floats, etc.)
|
||||
RawObservation = dict[str, Any]
|
||||
|
||||
# observation as those recorded in LeRobot dataset (keys are different)
|
||||
LeRobotObservation = dict[str, torch.Tensor]
|
||||
|
||||
# observation, ready for policy inference (image keys resized)
|
||||
Observation = dict[str, torch.Tensor]
|
||||
|
||||
|
||||
def visualize_action_queue_size(action_queue_size: list[int]) -> None:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
_, ax = plt.subplots()
|
||||
ax.set_title("Action Queue Size Over Time")
|
||||
ax.set_xlabel("Environment steps")
|
||||
ax.set_ylabel("Action Queue Size")
|
||||
ax.set_ylim(0, max(action_queue_size) * 1.1)
|
||||
ax.grid(True, alpha=0.3)
|
||||
ax.plot(range(len(action_queue_size)), action_queue_size)
|
||||
plt.show()
|
||||
|
||||
|
||||
def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]:
|
||||
return hw_to_dataset_features(robot.observation_features, OBS_STR, use_video=False)
|
||||
|
||||
|
||||
def is_image_key(k: str) -> bool:
|
||||
return k.startswith(OBS_IMAGES)
|
||||
|
||||
|
||||
def resize_robot_observation_image(image: torch.tensor, resize_dims: tuple[int, int, int]) -> torch.tensor:
|
||||
assert image.ndim == 3, f"Image must be (C, H, W)! Received {image.shape}"
|
||||
# (H, W, C) -> (C, H, W) for resizing from robot obsevation resolution to policy image resolution
|
||||
image = image.permute(2, 0, 1)
|
||||
dims = (resize_dims[1], resize_dims[2])
|
||||
# Add batch dimension for interpolate: (C, H, W) -> (1, C, H, W)
|
||||
image_batched = image.unsqueeze(0)
|
||||
# Interpolate and remove batch dimension: (1, C, H, W) -> (C, H, W)
|
||||
resized = torch.nn.functional.interpolate(image_batched, size=dims, mode="bilinear", align_corners=False)
|
||||
|
||||
return resized.squeeze(0)
|
||||
|
||||
|
||||
# TODO(Steven): Consider implementing a pipeline step for this
|
||||
def raw_observation_to_observation(
|
||||
raw_observation: RawObservation,
|
||||
lerobot_features: dict[str, dict],
|
||||
policy_image_features: dict[str, PolicyFeature],
|
||||
) -> Observation:
|
||||
observation = {}
|
||||
|
||||
observation = prepare_raw_observation(raw_observation, lerobot_features, policy_image_features)
|
||||
for k, v in observation.items():
|
||||
if isinstance(v, torch.Tensor): # VLAs present natural-language instructions in observations
|
||||
if "image" in k:
|
||||
# Policy expects images in shape (B, C, H, W)
|
||||
observation[k] = prepare_image(v).unsqueeze(0)
|
||||
else:
|
||||
observation[k] = v
|
||||
|
||||
return observation
|
||||
|
||||
|
||||
def prepare_image(image: torch.Tensor) -> torch.Tensor:
|
||||
"""Minimal preprocessing to turn int8 images to float32 in [0, 1], and create a memory-contiguous tensor"""
|
||||
image = image.type(torch.float32) / 255
|
||||
image = image.contiguous()
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def extract_state_from_raw_observation(
|
||||
lerobot_obs: RawObservation,
|
||||
) -> torch.Tensor:
|
||||
"""Extract the state from a raw observation."""
|
||||
state = torch.tensor(lerobot_obs[OBS_STATE])
|
||||
|
||||
if state.ndim == 1:
|
||||
state = state.unsqueeze(0)
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def extract_images_from_raw_observation(
|
||||
lerobot_obs: RawObservation,
|
||||
camera_key: str,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Extract the images from a raw observation."""
|
||||
return torch.tensor(lerobot_obs[camera_key])
|
||||
|
||||
|
||||
def make_lerobot_observation(
|
||||
robot_obs: RawObservation,
|
||||
lerobot_features: dict[str, dict],
|
||||
) -> LeRobotObservation:
|
||||
"""Make a lerobot observation from a raw observation."""
|
||||
return build_dataset_frame(lerobot_features, robot_obs, prefix=OBS_STR)
|
||||
|
||||
|
||||
def prepare_raw_observation(
|
||||
robot_obs: RawObservation,
|
||||
lerobot_features: dict[str, dict],
|
||||
policy_image_features: dict[str, PolicyFeature],
|
||||
) -> Observation:
|
||||
"""Matches keys from the raw robot_obs dict to the keys expected by a given policy (passed as
|
||||
policy_image_features)."""
|
||||
# 1. {motor.pos1:value1, motor.pos2:value2, ..., laptop:np.ndarray} ->
|
||||
# -> {observation.state:[value1,value2,...], observation.images.laptop:np.ndarray}
|
||||
lerobot_obs = make_lerobot_observation(robot_obs, lerobot_features)
|
||||
|
||||
# 2. Greps all observation.images.<> keys
|
||||
image_keys = list(filter(is_image_key, lerobot_obs))
|
||||
# state's shape is expected as (B, state_dim)
|
||||
state_dict = {OBS_STATE: extract_state_from_raw_observation(lerobot_obs)}
|
||||
image_dict = {
|
||||
image_k: extract_images_from_raw_observation(lerobot_obs, image_k) for image_k in image_keys
|
||||
}
|
||||
|
||||
# Turns the image features to (C, H, W) with H, W matching the policy image features.
|
||||
# This reduces the resolution of the images
|
||||
image_dict = {
|
||||
key: resize_robot_observation_image(torch.tensor(lerobot_obs[key]), policy_image_features[key].shape)
|
||||
for key in image_keys
|
||||
}
|
||||
|
||||
if "task" in robot_obs:
|
||||
state_dict["task"] = robot_obs["task"]
|
||||
|
||||
return {**state_dict, **image_dict}
|
||||
|
||||
|
||||
def get_logger(name: str, log_to_file: bool = True) -> logging.Logger:
|
||||
"""
|
||||
Get a logger using the standardized logging setup from utils.py.
|
||||
|
||||
Args:
|
||||
name: Logger name (e.g., 'policy_server', 'robot_client')
|
||||
log_to_file: Whether to also log to a file
|
||||
|
||||
Returns:
|
||||
Configured logger instance
|
||||
"""
|
||||
# Create logs directory if logging to file
|
||||
if log_to_file:
|
||||
os.makedirs("logs", exist_ok=True)
|
||||
log_file = Path(f"logs/{name}_{int(time.time())}.log")
|
||||
else:
|
||||
log_file = None
|
||||
|
||||
# Initialize the standardized logging
|
||||
init_logging(log_file=log_file, display_pid=False)
|
||||
|
||||
# Return a named logger
|
||||
return logging.getLogger(name)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimedData:
|
||||
"""A data object with timestamp and timestep information.
|
||||
|
||||
Args:
|
||||
timestamp: Unix timestamp relative to data's creation.
|
||||
data: The actual data to wrap a timestamp around.
|
||||
timestep: The timestep of the data.
|
||||
"""
|
||||
|
||||
timestamp: float
|
||||
timestep: int
|
||||
|
||||
def get_timestamp(self):
|
||||
return self.timestamp
|
||||
|
||||
def get_timestep(self):
|
||||
return self.timestep
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimedAction(TimedData):
|
||||
action: Action
|
||||
|
||||
def get_action(self):
|
||||
return self.action
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimedObservation(TimedData):
|
||||
observation: RawObservation
|
||||
must_go: bool = False
|
||||
|
||||
def get_observation(self):
|
||||
return self.observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class FPSTracker:
|
||||
"""Utility class to track FPS metrics over time."""
|
||||
|
||||
target_fps: float
|
||||
first_timestamp: float = None
|
||||
total_obs_count: int = 0
|
||||
|
||||
def calculate_fps_metrics(self, current_timestamp: float) -> dict[str, float]:
|
||||
"""Calculate average FPS vs target"""
|
||||
self.total_obs_count += 1
|
||||
|
||||
# Initialize first observation time
|
||||
if self.first_timestamp is None:
|
||||
self.first_timestamp = current_timestamp
|
||||
|
||||
# Calculate overall average FPS (since start)
|
||||
total_duration = current_timestamp - self.first_timestamp
|
||||
avg_fps = (self.total_obs_count - 1) / total_duration if total_duration > 1e-6 else 0.0
|
||||
|
||||
return {"avg_fps": avg_fps, "target_fps": self.target_fps}
|
||||
|
||||
def reset(self):
|
||||
"""Reset the FPS tracker state"""
|
||||
self.first_timestamp = None
|
||||
self.total_obs_count = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class RemotePolicyConfig:
|
||||
policy_type: str
|
||||
pretrained_name_or_path: str
|
||||
lerobot_features: dict[str, PolicyFeature]
|
||||
actions_per_chunk: int
|
||||
device: str = "cpu"
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
|
||||
"""Check if two observation states are similar, under a tolerance threshold"""
|
||||
return bool(torch.linalg.norm(obs1_state - obs2_state) < atol)
|
||||
|
||||
|
||||
def observations_similar(
|
||||
obs1: TimedObservation, obs2: TimedObservation, lerobot_features: dict[str, dict], atol: float = 1
|
||||
) -> bool:
|
||||
"""Check if two observations are similar, under a tolerance threshold. Measures distance between
|
||||
observations as the difference in joint-space between the two observations.
|
||||
|
||||
NOTE(fracapuano): This is a very simple check, and it is enough for the current use case.
|
||||
An immediate next step is to use (fast) perceptual difference metrics comparing some camera views,
|
||||
to surpass this joint-space similarity check.
|
||||
"""
|
||||
obs1_state = extract_state_from_raw_observation(
|
||||
make_lerobot_observation(obs1.get_observation(), lerobot_features)
|
||||
)
|
||||
obs2_state = extract_state_from_raw_observation(
|
||||
make_lerobot_observation(obs2.get_observation(), lerobot_features)
|
||||
)
|
||||
|
||||
return _compare_observation_states(obs1_state, obs2_state, atol=atol)
|
||||
@@ -1,439 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Example:
|
||||
```shell
|
||||
python -m lerobot.async_inference.policy_server \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080 \
|
||||
--fps=30 \
|
||||
--inference_latency=0.033 \
|
||||
--obs_queue_timeout=1
|
||||
```
|
||||
"""
|
||||
|
||||
import logging
|
||||
import pickle # nosec
|
||||
import threading
|
||||
import time
|
||||
from concurrent import futures
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
from queue import Empty, Queue
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
import grpc
|
||||
import torch
|
||||
|
||||
from lerobot.policies import get_policy_class, make_pre_post_processors
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
services_pb2_grpc, # type: ignore
|
||||
)
|
||||
from lerobot.transport.utils import receive_bytes_in_chunks
|
||||
from lerobot.types import PolicyAction
|
||||
|
||||
from .configs import PolicyServerConfig
|
||||
from .constants import SUPPORTED_POLICIES
|
||||
from .helpers import (
|
||||
FPSTracker,
|
||||
Observation,
|
||||
RemotePolicyConfig,
|
||||
TimedAction,
|
||||
TimedObservation,
|
||||
get_logger,
|
||||
observations_similar,
|
||||
raw_observation_to_observation,
|
||||
)
|
||||
|
||||
|
||||
class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
prefix = "policy_server"
|
||||
logger = get_logger(prefix)
|
||||
|
||||
def __init__(self, config: PolicyServerConfig):
|
||||
self.config = config
|
||||
self.shutdown_event = threading.Event()
|
||||
|
||||
# FPS measurement
|
||||
self.fps_tracker = FPSTracker(target_fps=config.fps)
|
||||
|
||||
self.observation_queue = Queue(maxsize=1)
|
||||
|
||||
self._predicted_timesteps_lock = threading.Lock()
|
||||
self._predicted_timesteps = set()
|
||||
|
||||
self.last_processed_obs = None
|
||||
|
||||
# Attributes will be set by SendPolicyInstructions
|
||||
self.device = None
|
||||
self.policy_type = None
|
||||
self.lerobot_features = None
|
||||
self.actions_per_chunk = None
|
||||
self.policy = None
|
||||
self.preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]] | None = None
|
||||
self.postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction] | None = None
|
||||
|
||||
@property
|
||||
def running(self):
|
||||
return not self.shutdown_event.is_set()
|
||||
|
||||
@property
|
||||
def policy_image_features(self):
|
||||
return self.policy.config.image_features
|
||||
|
||||
def _reset_server(self) -> None:
|
||||
"""Flushes server state when new client connects."""
|
||||
# only running inference on the latest observation received by the server
|
||||
self.shutdown_event.set()
|
||||
self.observation_queue = Queue(maxsize=1)
|
||||
|
||||
with self._predicted_timesteps_lock:
|
||||
self._predicted_timesteps = set()
|
||||
|
||||
def Ready(self, request, context): # noqa: N802
|
||||
client_id = context.peer()
|
||||
self.logger.info(f"Client {client_id} connected and ready")
|
||||
self._reset_server()
|
||||
self.shutdown_event.clear()
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
def SendPolicyInstructions(self, request, context): # noqa: N802
|
||||
"""Receive policy instructions from the robot client"""
|
||||
|
||||
if not self.running:
|
||||
self.logger.warning("Server is not running. Ignoring policy instructions.")
|
||||
return services_pb2.Empty()
|
||||
|
||||
client_id = context.peer()
|
||||
|
||||
policy_specs = pickle.loads(request.data) # nosec
|
||||
|
||||
if not isinstance(policy_specs, RemotePolicyConfig):
|
||||
raise TypeError(f"Policy specs must be a RemotePolicyConfig. Got {type(policy_specs)}")
|
||||
|
||||
if policy_specs.policy_type not in SUPPORTED_POLICIES:
|
||||
raise ValueError(
|
||||
f"Policy type {policy_specs.policy_type} not supported. "
|
||||
f"Supported policies: {SUPPORTED_POLICIES}"
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
f"Receiving policy instructions from {client_id} | "
|
||||
f"Policy type: {policy_specs.policy_type} | "
|
||||
f"Pretrained name or path: {policy_specs.pretrained_name_or_path} | "
|
||||
f"Actions per chunk: {policy_specs.actions_per_chunk} | "
|
||||
f"Device: {policy_specs.device}"
|
||||
)
|
||||
|
||||
self.device = policy_specs.device
|
||||
self.policy_type = policy_specs.policy_type # act, pi0, etc.
|
||||
self.lerobot_features = policy_specs.lerobot_features
|
||||
self.actions_per_chunk = policy_specs.actions_per_chunk
|
||||
|
||||
policy_class = get_policy_class(self.policy_type)
|
||||
|
||||
start = time.perf_counter()
|
||||
self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path)
|
||||
self.policy.to(self.device)
|
||||
|
||||
# Load preprocessor and postprocessor, overriding device to match requested device
|
||||
device_override = {"device": self.device}
|
||||
self.preprocessor, self.postprocessor = make_pre_post_processors(
|
||||
self.policy.config,
|
||||
pretrained_path=policy_specs.pretrained_name_or_path,
|
||||
preprocessor_overrides={
|
||||
"device_processor": device_override,
|
||||
"rename_observations_processor": {"rename_map": policy_specs.rename_map},
|
||||
},
|
||||
postprocessor_overrides={"device_processor": device_override},
|
||||
)
|
||||
|
||||
end = time.perf_counter()
|
||||
|
||||
self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
def SendObservations(self, request_iterator, context): # noqa: N802
|
||||
"""Receive observations from the robot client"""
|
||||
client_id = context.peer()
|
||||
self.logger.debug(f"Receiving observations from {client_id}")
|
||||
|
||||
receive_time = time.time() # comparing timestamps so need time.time()
|
||||
start_deserialize = time.perf_counter()
|
||||
received_bytes = receive_bytes_in_chunks(
|
||||
request_iterator, None, self.shutdown_event, self.logger
|
||||
) # blocking call while looping over request_iterator
|
||||
timed_observation = pickle.loads(received_bytes) # nosec
|
||||
deserialize_time = time.perf_counter() - start_deserialize
|
||||
|
||||
self.logger.debug(f"Received observation #{timed_observation.get_timestep()}")
|
||||
|
||||
obs_timestep = timed_observation.get_timestep()
|
||||
obs_timestamp = timed_observation.get_timestamp()
|
||||
|
||||
# Calculate FPS metrics
|
||||
fps_metrics = self.fps_tracker.calculate_fps_metrics(obs_timestamp)
|
||||
|
||||
self.logger.debug(
|
||||
f"Received observation #{obs_timestep} | "
|
||||
f"Avg FPS: {fps_metrics['avg_fps']:.2f} | " # fps at which observations are received from client
|
||||
f"Target: {fps_metrics['target_fps']:.2f} | "
|
||||
f"One-way latency: {(receive_time - obs_timestamp) * 1000:.2f}ms"
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Server timestamp: {receive_time:.6f} | "
|
||||
f"Client timestamp: {obs_timestamp:.6f} | "
|
||||
f"Deserialization time: {deserialize_time:.6f}s"
|
||||
)
|
||||
|
||||
if not self._enqueue_observation(
|
||||
timed_observation # wrapping a RawObservation
|
||||
):
|
||||
self.logger.debug(f"Observation #{obs_timestep} has been filtered out")
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
def GetActions(self, request, context): # noqa: N802
|
||||
"""Returns actions to the robot client. Actions are sent as a single
|
||||
chunk, containing multiple actions."""
|
||||
client_id = context.peer()
|
||||
self.logger.debug(f"Client {client_id} connected for action streaming")
|
||||
|
||||
# Generate action based on the most recent observation and its timestep
|
||||
try:
|
||||
getactions_starts = time.perf_counter()
|
||||
obs = self.observation_queue.get(timeout=self.config.obs_queue_timeout)
|
||||
self.logger.info(
|
||||
f"Running inference for observation #{obs.get_timestep()} (must_go: {obs.must_go})"
|
||||
)
|
||||
|
||||
with self._predicted_timesteps_lock:
|
||||
self._predicted_timesteps.add(obs.get_timestep())
|
||||
|
||||
start_time = time.perf_counter()
|
||||
action_chunk = self._predict_action_chunk(obs)
|
||||
inference_time = time.perf_counter() - start_time
|
||||
|
||||
start_time = time.perf_counter()
|
||||
actions_bytes = pickle.dumps(action_chunk) # nosec
|
||||
serialize_time = time.perf_counter() - start_time
|
||||
|
||||
# Create and return the action chunk
|
||||
actions = services_pb2.Actions(data=actions_bytes)
|
||||
|
||||
self.logger.info(
|
||||
f"Action chunk #{obs.get_timestep()} generated | "
|
||||
f"Total time: {(inference_time + serialize_time) * 1000:.2f}ms"
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Action chunk #{obs.get_timestep()} generated | "
|
||||
f"Inference time: {inference_time:.2f}s |"
|
||||
f"Serialize time: {serialize_time:.2f}s |"
|
||||
f"Total time: {inference_time + serialize_time:.2f}s"
|
||||
)
|
||||
|
||||
time.sleep(
|
||||
max(0, self.config.inference_latency - max(0, time.perf_counter() - getactions_starts))
|
||||
) # sleep controls inference latency
|
||||
|
||||
return actions
|
||||
|
||||
except Empty: # no observation added to queue in obs_queue_timeout
|
||||
return services_pb2.Empty()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in StreamActions: {e}")
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool:
|
||||
"""Check if the observation is valid to be processed by the policy"""
|
||||
with self._predicted_timesteps_lock:
|
||||
predicted_timesteps = self._predicted_timesteps
|
||||
|
||||
if obs.get_timestep() in predicted_timesteps:
|
||||
self.logger.debug(f"Skipping observation #{obs.get_timestep()} - Timestep predicted already!")
|
||||
return False
|
||||
|
||||
elif observations_similar(obs, previous_obs, lerobot_features=self.lerobot_features):
|
||||
self.logger.debug(
|
||||
f"Skipping observation #{obs.get_timestep()} - Observation too similar to last obs predicted!"
|
||||
)
|
||||
return False
|
||||
|
||||
else:
|
||||
return True
|
||||
|
||||
def _enqueue_observation(self, obs: TimedObservation) -> bool:
|
||||
"""Enqueue an observation if it must go through processing, otherwise skip it.
|
||||
Observations not in queue are never run through the policy network"""
|
||||
|
||||
if (
|
||||
obs.must_go
|
||||
or self.last_processed_obs is None
|
||||
or self._obs_sanity_checks(obs, self.last_processed_obs)
|
||||
):
|
||||
last_obs = self.last_processed_obs.get_timestep() if self.last_processed_obs else "None"
|
||||
self.logger.debug(
|
||||
f"Enqueuing observation. Must go: {obs.must_go} | Last processed obs: {last_obs}"
|
||||
)
|
||||
|
||||
# If queue is full, get the old observation to make room
|
||||
if self.observation_queue.full():
|
||||
# pops from queue
|
||||
_ = self.observation_queue.get_nowait()
|
||||
self.logger.debug("Observation queue was full, removed oldest observation")
|
||||
|
||||
# Now put the new observation (never blocks as queue is non-full here)
|
||||
self.observation_queue.put(obs)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]:
|
||||
"""Turn a chunk of actions into a list of TimedAction instances,
|
||||
with the first action corresponding to t_0 and the rest corresponding to
|
||||
t_0 + i*environment_dt for i in range(len(action_chunk))
|
||||
"""
|
||||
return [
|
||||
TimedAction(timestamp=t_0 + i * self.config.environment_dt, timestep=i_0 + i, action=action)
|
||||
for i, action in enumerate(action_chunk)
|
||||
]
|
||||
|
||||
def _get_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Get an action chunk from the policy. The chunk contains only"""
|
||||
chunk = self.policy.predict_action_chunk(observation)
|
||||
if chunk.ndim != 3:
|
||||
chunk = chunk.unsqueeze(0) # adding batch dimension, now shape is (B, chunk_size, action_dim)
|
||||
|
||||
return chunk[:, : self.actions_per_chunk, :]
|
||||
|
||||
def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
|
||||
"""Predict an action chunk based on an observation.
|
||||
|
||||
Pipeline:
|
||||
1. Convert raw observation to LeRobot format
|
||||
2. Apply preprocessor (tokenization, normalization, batching, device placement)
|
||||
3. Run policy inference to get action chunk
|
||||
4. Apply postprocessor (unnormalization, device movement)
|
||||
5. Convert to TimedAction list
|
||||
"""
|
||||
"""1. Prepare observation"""
|
||||
start_prepare = time.perf_counter()
|
||||
observation: Observation = raw_observation_to_observation(
|
||||
observation_t.get_observation(),
|
||||
self.lerobot_features,
|
||||
self.policy_image_features,
|
||||
)
|
||||
prepare_time = time.perf_counter() - start_prepare
|
||||
|
||||
"""2. Apply preprocessor"""
|
||||
start_preprocess = time.perf_counter()
|
||||
observation = self.preprocessor(observation)
|
||||
self.last_processed_obs: TimedObservation = observation_t
|
||||
preprocessing_time = time.perf_counter() - start_preprocess
|
||||
|
||||
"""3. Get action chunk"""
|
||||
start_inference = time.perf_counter()
|
||||
action_tensor = self._get_action_chunk(observation)
|
||||
inference_time = time.perf_counter() - start_inference
|
||||
self.logger.info(
|
||||
f"Preprocessing and inference took {inference_time:.4f}s, action shape: {action_tensor.shape}"
|
||||
)
|
||||
|
||||
"""4. Apply postprocessor"""
|
||||
# Apply postprocessor (handles unnormalization and device movement)
|
||||
# Postprocessor expects (B, action_dim) per action, but we have (B, chunk_size, action_dim)
|
||||
# So we process each action in the chunk individually
|
||||
start_postprocess = time.perf_counter()
|
||||
_, chunk_size, _ = action_tensor.shape
|
||||
|
||||
# Process each action in the chunk
|
||||
processed_actions = []
|
||||
for i in range(chunk_size):
|
||||
# Extract action at timestep i: (B, action_dim)
|
||||
single_action = action_tensor[:, i, :]
|
||||
processed_action = self.postprocessor(single_action)
|
||||
processed_actions.append(processed_action)
|
||||
|
||||
# Stack back to (B, chunk_size, action_dim), then remove batch dim
|
||||
action_tensor = torch.stack(processed_actions, dim=1).squeeze(0)
|
||||
self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}")
|
||||
|
||||
action_tensor = action_tensor.detach().cpu()
|
||||
|
||||
"""5. Convert to TimedAction list"""
|
||||
action_chunk = self._time_action_chunk(
|
||||
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
|
||||
)
|
||||
postprocess_stops = time.perf_counter()
|
||||
postprocessing_time = postprocess_stops - start_postprocess
|
||||
|
||||
self.logger.info(
|
||||
f"Observation {observation_t.get_timestep()} | "
|
||||
f"Total time: {1000 * (postprocess_stops - start_prepare):.2f}ms"
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Observation {observation_t.get_timestep()} | "
|
||||
f"Prepare time: {1000 * prepare_time:.2f}ms | "
|
||||
f"Preprocessing time: {1000 * preprocessing_time:.2f}ms | "
|
||||
f"Inference time: {1000 * inference_time:.2f}ms | "
|
||||
f"Postprocessing time: {1000 * postprocessing_time:.2f}ms | "
|
||||
f"Total time: {1000 * (postprocess_stops - start_prepare):.2f}ms"
|
||||
)
|
||||
|
||||
return action_chunk
|
||||
|
||||
def stop(self):
|
||||
"""Stop the server"""
|
||||
self._reset_server()
|
||||
self.logger.info("Server stopping...")
|
||||
|
||||
|
||||
@draccus.wrap()
|
||||
def serve(cfg: PolicyServerConfig):
|
||||
"""Start the PolicyServer with the given configuration.
|
||||
|
||||
Args:
|
||||
config: PolicyServerConfig instance. If None, uses default configuration.
|
||||
"""
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
# Create the server instance first
|
||||
policy_server = PolicyServer(cfg)
|
||||
|
||||
# Setup and start gRPC server
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
|
||||
services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
|
||||
server.add_insecure_port(f"{cfg.host}:{cfg.port}")
|
||||
|
||||
policy_server.logger.info(f"PolicyServer started on {cfg.host}:{cfg.port}")
|
||||
server.start()
|
||||
|
||||
server.wait_for_termination()
|
||||
|
||||
policy_server.logger.info("Server terminated")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
serve()
|
||||
@@ -1,517 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Example command:
|
||||
```shell
|
||||
python src/lerobot/async_inference/robot_client.py \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
|
||||
--robot.id=black \
|
||||
--task="dummy" \
|
||||
--server_address=127.0.0.1:8080 \
|
||||
--policy_type=act \
|
||||
--pretrained_name_or_path=user/model \
|
||||
--policy_device=mps \
|
||||
--client_device=cpu \
|
||||
--actions_per_chunk=50 \
|
||||
--chunk_size_threshold=0.5 \
|
||||
--aggregate_fn_name=weighted_average \
|
||||
--debug_visualize_queue_size=True
|
||||
```
|
||||
"""
|
||||
|
||||
import logging
|
||||
import pickle # nosec
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
from queue import Queue
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
import grpc
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_so_follower,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
so_follower,
|
||||
)
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
services_pb2_grpc, # type: ignore
|
||||
)
|
||||
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
|
||||
from .configs import RobotClientConfig
|
||||
from .helpers import (
|
||||
Action,
|
||||
FPSTracker,
|
||||
Observation,
|
||||
RawObservation,
|
||||
RemotePolicyConfig,
|
||||
TimedAction,
|
||||
TimedObservation,
|
||||
get_logger,
|
||||
map_robot_keys_to_lerobot_features,
|
||||
visualize_action_queue_size,
|
||||
)
|
||||
|
||||
|
||||
class RobotClient:
|
||||
prefix = "robot_client"
|
||||
logger = get_logger(prefix)
|
||||
|
||||
def __init__(self, config: RobotClientConfig):
|
||||
"""Initialize RobotClient with unified configuration.
|
||||
|
||||
Args:
|
||||
config: RobotClientConfig containing all configuration parameters
|
||||
"""
|
||||
# Store configuration
|
||||
self.config = config
|
||||
self.robot = make_robot_from_config(config.robot)
|
||||
self.robot.connect()
|
||||
|
||||
lerobot_features = map_robot_keys_to_lerobot_features(self.robot)
|
||||
|
||||
# Use environment variable if server_address is not provided in config
|
||||
self.server_address = config.server_address
|
||||
|
||||
self.policy_config = RemotePolicyConfig(
|
||||
config.policy_type,
|
||||
config.pretrained_name_or_path,
|
||||
lerobot_features,
|
||||
config.actions_per_chunk,
|
||||
config.policy_device,
|
||||
)
|
||||
self.channel = grpc.insecure_channel(
|
||||
self.server_address, grpc_channel_options(initial_backoff=f"{config.environment_dt:.4f}s")
|
||||
)
|
||||
self.stub = services_pb2_grpc.AsyncInferenceStub(self.channel)
|
||||
self.logger.info(f"Initializing client to connect to server at {self.server_address}")
|
||||
|
||||
self.shutdown_event = threading.Event()
|
||||
|
||||
# Initialize client side variables
|
||||
self.latest_action_lock = threading.Lock()
|
||||
self.latest_action = -1
|
||||
self.action_chunk_size = -1
|
||||
|
||||
self._chunk_size_threshold = config.chunk_size_threshold
|
||||
|
||||
self.action_queue = Queue()
|
||||
self.action_queue_lock = threading.Lock() # Protect queue operations
|
||||
self.action_queue_size = []
|
||||
self.start_barrier = threading.Barrier(2) # 2 threads: action receiver, control loop
|
||||
|
||||
# FPS measurement
|
||||
self.fps_tracker = FPSTracker(target_fps=self.config.fps)
|
||||
|
||||
self.logger.info("Robot connected and ready")
|
||||
|
||||
# Use an event for thread-safe coordination
|
||||
self.must_go = threading.Event()
|
||||
self.must_go.set() # Initially set - observations qualify for direct processing
|
||||
|
||||
@property
|
||||
def running(self):
|
||||
return not self.shutdown_event.is_set()
|
||||
|
||||
def start(self):
|
||||
"""Start the robot client and connect to the policy server"""
|
||||
try:
|
||||
# client-server handshake
|
||||
start_time = time.perf_counter()
|
||||
self.stub.Ready(services_pb2.Empty())
|
||||
end_time = time.perf_counter()
|
||||
self.logger.debug(f"Connected to policy server in {end_time - start_time:.4f}s")
|
||||
|
||||
# send policy instructions
|
||||
policy_config_bytes = pickle.dumps(self.policy_config)
|
||||
policy_setup = services_pb2.PolicySetup(data=policy_config_bytes)
|
||||
|
||||
self.logger.info("Sending policy instructions to policy server")
|
||||
self.logger.debug(
|
||||
f"Policy type: {self.policy_config.policy_type} | "
|
||||
f"Pretrained name or path: {self.policy_config.pretrained_name_or_path} | "
|
||||
f"Device: {self.policy_config.device}"
|
||||
)
|
||||
|
||||
self.stub.SendPolicyInstructions(policy_setup)
|
||||
|
||||
self.shutdown_event.clear()
|
||||
|
||||
return True
|
||||
|
||||
except grpc.RpcError as e:
|
||||
self.logger.error(f"Failed to connect to policy server: {e}")
|
||||
return False
|
||||
|
||||
def stop(self):
|
||||
"""Stop the robot client"""
|
||||
self.shutdown_event.set()
|
||||
|
||||
self.robot.disconnect()
|
||||
self.logger.debug("Robot disconnected")
|
||||
|
||||
self.channel.close()
|
||||
self.logger.debug("Client stopped, channel closed")
|
||||
|
||||
def send_observation(
|
||||
self,
|
||||
obs: TimedObservation,
|
||||
) -> bool:
|
||||
"""Send observation to the policy server.
|
||||
Returns True if the observation was sent successfully, False otherwise."""
|
||||
if not self.running:
|
||||
raise RuntimeError("Client not running. Run RobotClient.start() before sending observations.")
|
||||
|
||||
if not isinstance(obs, TimedObservation):
|
||||
raise ValueError("Input observation needs to be a TimedObservation!")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
observation_bytes = pickle.dumps(obs)
|
||||
serialize_time = time.perf_counter() - start_time
|
||||
self.logger.debug(f"Observation serialization time: {serialize_time:.6f}s")
|
||||
|
||||
try:
|
||||
observation_iterator = send_bytes_in_chunks(
|
||||
observation_bytes,
|
||||
services_pb2.Observation,
|
||||
log_prefix="[CLIENT] Observation",
|
||||
silent=True,
|
||||
)
|
||||
_ = self.stub.SendObservations(observation_iterator)
|
||||
obs_timestep = obs.get_timestep()
|
||||
self.logger.debug(f"Sent observation #{obs_timestep} | ")
|
||||
|
||||
return True
|
||||
|
||||
except grpc.RpcError as e:
|
||||
self.logger.error(f"Error sending observation #{obs.get_timestep()}: {e}")
|
||||
return False
|
||||
|
||||
def _inspect_action_queue(self):
|
||||
with self.action_queue_lock:
|
||||
queue_size = self.action_queue.qsize()
|
||||
timestamps = sorted([action.get_timestep() for action in self.action_queue.queue])
|
||||
self.logger.debug(f"Queue size: {queue_size}, Queue contents: {timestamps}")
|
||||
return queue_size, timestamps
|
||||
|
||||
def _aggregate_action_queues(
|
||||
self,
|
||||
incoming_actions: list[TimedAction],
|
||||
aggregate_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
|
||||
):
|
||||
"""Finds the same timestep actions in the queue and aggregates them using the aggregate_fn"""
|
||||
if aggregate_fn is None:
|
||||
# default aggregate function: take the latest action
|
||||
def aggregate_fn(x1, x2):
|
||||
return x2
|
||||
|
||||
future_action_queue = Queue()
|
||||
with self.action_queue_lock:
|
||||
internal_queue = self.action_queue.queue
|
||||
|
||||
current_action_queue = {action.get_timestep(): action.get_action() for action in internal_queue}
|
||||
|
||||
for new_action in incoming_actions:
|
||||
with self.latest_action_lock:
|
||||
latest_action = self.latest_action
|
||||
|
||||
# New action is older than the latest action in the queue, skip it
|
||||
if new_action.get_timestep() <= latest_action:
|
||||
continue
|
||||
|
||||
# If the new action's timestep is not in the current action queue, add it directly
|
||||
elif new_action.get_timestep() not in current_action_queue:
|
||||
future_action_queue.put(new_action)
|
||||
continue
|
||||
|
||||
# If the new action's timestep is in the current action queue, aggregate it
|
||||
# TODO: There is probably a way to do this with broadcasting of the two action tensors
|
||||
future_action_queue.put(
|
||||
TimedAction(
|
||||
timestamp=new_action.get_timestamp(),
|
||||
timestep=new_action.get_timestep(),
|
||||
action=aggregate_fn(
|
||||
current_action_queue[new_action.get_timestep()], new_action.get_action()
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
with self.action_queue_lock:
|
||||
self.action_queue = future_action_queue
|
||||
|
||||
def receive_actions(self, verbose: bool = False):
|
||||
"""Receive actions from the policy server"""
|
||||
# Wait at barrier for synchronized start
|
||||
self.start_barrier.wait()
|
||||
self.logger.info("Action receiving thread starting")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# Use StreamActions to get a stream of actions from the server
|
||||
actions_chunk = self.stub.GetActions(services_pb2.Empty())
|
||||
if len(actions_chunk.data) == 0:
|
||||
continue # received `Empty` from server, wait for next call
|
||||
|
||||
receive_time = time.time()
|
||||
|
||||
# Deserialize bytes back into list[TimedAction]
|
||||
deserialize_start = time.perf_counter()
|
||||
timed_actions = pickle.loads(actions_chunk.data) # nosec
|
||||
deserialize_time = time.perf_counter() - deserialize_start
|
||||
|
||||
# Log device type of received actions
|
||||
if len(timed_actions) > 0:
|
||||
received_device = timed_actions[0].get_action().device.type
|
||||
self.logger.debug(f"Received actions on device: {received_device}")
|
||||
|
||||
# Move actions to client_device (e.g., for downstream planners that need GPU)
|
||||
client_device = self.config.client_device
|
||||
if client_device != "cpu":
|
||||
for timed_action in timed_actions:
|
||||
if timed_action.get_action().device.type != client_device:
|
||||
timed_action.action = timed_action.get_action().to(client_device)
|
||||
self.logger.debug(f"Converted actions to device: {client_device}")
|
||||
else:
|
||||
self.logger.debug(f"Actions kept on device: {client_device}")
|
||||
|
||||
self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))
|
||||
|
||||
# Calculate network latency if we have matching observations
|
||||
if len(timed_actions) > 0 and verbose:
|
||||
with self.latest_action_lock:
|
||||
latest_action = self.latest_action
|
||||
|
||||
self.logger.debug(f"Current latest action: {latest_action}")
|
||||
|
||||
# Get queue state before changes
|
||||
old_size, old_timesteps = self._inspect_action_queue()
|
||||
if not old_timesteps:
|
||||
old_timesteps = [latest_action] # queue was empty
|
||||
|
||||
# Log incoming actions
|
||||
incoming_timesteps = [a.get_timestep() for a in timed_actions]
|
||||
|
||||
first_action_timestep = timed_actions[0].get_timestep()
|
||||
server_to_client_latency = (receive_time - timed_actions[0].get_timestamp()) * 1000
|
||||
|
||||
self.logger.info(
|
||||
f"Received action chunk for step #{first_action_timestep} | "
|
||||
f"Latest action: #{latest_action} | "
|
||||
f"Incoming actions: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
|
||||
f"Network latency (server->client): {server_to_client_latency:.2f}ms | "
|
||||
f"Deserialization time: {deserialize_time * 1000:.2f}ms"
|
||||
)
|
||||
|
||||
# Update action queue
|
||||
start_time = time.perf_counter()
|
||||
self._aggregate_action_queues(timed_actions, self.config.aggregate_fn)
|
||||
queue_update_time = time.perf_counter() - start_time
|
||||
|
||||
self.must_go.set() # after receiving actions, next empty queue triggers must-go processing!
|
||||
|
||||
if verbose:
|
||||
# Get queue state after changes
|
||||
new_size, new_timesteps = self._inspect_action_queue()
|
||||
|
||||
with self.latest_action_lock:
|
||||
latest_action = self.latest_action
|
||||
|
||||
self.logger.info(
|
||||
f"Latest action: {latest_action} | "
|
||||
f"Old action steps: {old_timesteps[0]}:{old_timesteps[-1]} | "
|
||||
f"Incoming action steps: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
|
||||
f"Updated action steps: {new_timesteps[0]}:{new_timesteps[-1]}"
|
||||
)
|
||||
self.logger.debug(
|
||||
f"Queue update complete ({queue_update_time:.6f}s) | "
|
||||
f"Before: {old_size} items | "
|
||||
f"After: {new_size} items | "
|
||||
)
|
||||
|
||||
except grpc.RpcError as e:
|
||||
self.logger.error(f"Error receiving actions: {e}")
|
||||
|
||||
def actions_available(self):
|
||||
"""Check if there are actions available in the queue"""
|
||||
with self.action_queue_lock:
|
||||
return not self.action_queue.empty()
|
||||
|
||||
def _action_tensor_to_action_dict(self, action_tensor: torch.Tensor) -> dict[str, float]:
|
||||
action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)}
|
||||
return action
|
||||
|
||||
def control_loop_action(self, verbose: bool = False) -> dict[str, Any]:
|
||||
"""Reading and performing actions in local queue"""
|
||||
|
||||
# Lock only for queue operations
|
||||
get_start = time.perf_counter()
|
||||
with self.action_queue_lock:
|
||||
self.action_queue_size.append(self.action_queue.qsize())
|
||||
# Get action from queue
|
||||
timed_action = self.action_queue.get_nowait()
|
||||
get_end = time.perf_counter() - get_start
|
||||
|
||||
_performed_action = self.robot.send_action(
|
||||
self._action_tensor_to_action_dict(timed_action.get_action())
|
||||
)
|
||||
with self.latest_action_lock:
|
||||
self.latest_action = timed_action.get_timestep()
|
||||
|
||||
if verbose:
|
||||
with self.action_queue_lock:
|
||||
current_queue_size = self.action_queue.qsize()
|
||||
|
||||
self.logger.debug(
|
||||
f"Ts={timed_action.get_timestamp()} | "
|
||||
f"Action #{timed_action.get_timestep()} performed | "
|
||||
f"Queue size: {current_queue_size}"
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Popping action from queue to perform took {get_end:.6f}s | Queue size: {current_queue_size}"
|
||||
)
|
||||
|
||||
return _performed_action
|
||||
|
||||
def _ready_to_send_observation(self):
|
||||
"""Flags when the client is ready to send an observation"""
|
||||
with self.action_queue_lock:
|
||||
return self.action_queue.qsize() / self.action_chunk_size <= self._chunk_size_threshold
|
||||
|
||||
def control_loop_observation(self, task: str, verbose: bool = False) -> RawObservation:
|
||||
try:
|
||||
# Get serialized observation bytes from the function
|
||||
start_time = time.perf_counter()
|
||||
|
||||
raw_observation: RawObservation = self.robot.get_observation()
|
||||
raw_observation["task"] = task
|
||||
|
||||
with self.latest_action_lock:
|
||||
latest_action = self.latest_action
|
||||
|
||||
observation = TimedObservation(
|
||||
timestamp=time.time(), # need time.time() to compare timestamps across client and server
|
||||
observation=raw_observation,
|
||||
timestep=max(latest_action, 0),
|
||||
)
|
||||
|
||||
obs_capture_time = time.perf_counter() - start_time
|
||||
|
||||
# If there are no actions left in the queue, the observation must go through processing!
|
||||
with self.action_queue_lock:
|
||||
observation.must_go = self.must_go.is_set() and self.action_queue.empty()
|
||||
current_queue_size = self.action_queue.qsize()
|
||||
|
||||
_ = self.send_observation(observation)
|
||||
|
||||
self.logger.debug(f"QUEUE SIZE: {current_queue_size} (Must go: {observation.must_go})")
|
||||
if observation.must_go:
|
||||
# must-go event will be set again after receiving actions
|
||||
self.must_go.clear()
|
||||
|
||||
if verbose:
|
||||
# Calculate comprehensive FPS metrics
|
||||
fps_metrics = self.fps_tracker.calculate_fps_metrics(observation.get_timestamp())
|
||||
|
||||
self.logger.info(
|
||||
f"Obs #{observation.get_timestep()} | "
|
||||
f"Avg FPS: {fps_metrics['avg_fps']:.2f} | "
|
||||
f"Target: {fps_metrics['target_fps']:.2f}"
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Ts={observation.get_timestamp():.6f} | Capturing observation took {obs_capture_time:.6f}s"
|
||||
)
|
||||
|
||||
return raw_observation
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in observation sender: {e}")
|
||||
|
||||
def control_loop(self, task: str, verbose: bool = False) -> tuple[Observation, Action]:
|
||||
"""Combined function for executing actions and streaming observations"""
|
||||
# Wait at barrier for synchronized start
|
||||
self.start_barrier.wait()
|
||||
self.logger.info("Control loop thread starting")
|
||||
|
||||
_performed_action = None
|
||||
_captured_observation = None
|
||||
|
||||
while self.running:
|
||||
control_loop_start = time.perf_counter()
|
||||
"""Control loop: (1) Performing actions, when available"""
|
||||
if self.actions_available():
|
||||
_performed_action = self.control_loop_action(verbose)
|
||||
|
||||
"""Control loop: (2) Streaming observations to the remote policy server"""
|
||||
if self._ready_to_send_observation():
|
||||
_captured_observation = self.control_loop_observation(task, verbose)
|
||||
|
||||
self.logger.debug(f"Control loop (ms): {(time.perf_counter() - control_loop_start) * 1000:.2f}")
|
||||
# Dynamically adjust sleep time to maintain the desired control frequency
|
||||
time.sleep(max(0, self.config.environment_dt - (time.perf_counter() - control_loop_start)))
|
||||
|
||||
return _captured_observation, _performed_action
|
||||
|
||||
|
||||
@draccus.wrap()
|
||||
def async_client(cfg: RobotClientConfig):
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
# TODO: Assert if checking robot support is still needed with the plugin system
|
||||
# if cfg.robot.type not in SUPPORTED_ROBOTS:
|
||||
# raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
|
||||
|
||||
client = RobotClient(cfg)
|
||||
|
||||
if client.start():
|
||||
client.logger.info("Starting action receiver thread...")
|
||||
|
||||
# Create and start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
|
||||
# Start action receiver thread
|
||||
action_receiver_thread.start()
|
||||
|
||||
try:
|
||||
# The main thread runs the control loop
|
||||
client.control_loop(task=cfg.task)
|
||||
|
||||
finally:
|
||||
client.stop()
|
||||
action_receiver_thread.join()
|
||||
if cfg.debug_visualize_queue_size:
|
||||
visualize_action_queue_size(client.action_queue_size)
|
||||
client.logger.info("Client stopped")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
register_third_party_plugins()
|
||||
async_client() # run the client
|
||||
@@ -49,19 +49,8 @@ def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Pa
|
||||
return output_dir / CHECKPOINTS_DIR / step_identifier
|
||||
|
||||
|
||||
def save_training_step(
|
||||
step: int, save_dir: Path, num_processes: int | None = None, batch_size: int | None = None
|
||||
) -> None:
|
||||
state: dict = {"step": step}
|
||||
# num_processes and batch_size are recorded so a resumed run can detect a changed world size or
|
||||
# batch size: the sampler's resume offset is computed from the (num_processes, batch_size) that
|
||||
# produced `step`, since both scale how many sampler positions a step consumes (see
|
||||
# compute_sampler_state).
|
||||
if num_processes is not None:
|
||||
state["num_processes"] = num_processes
|
||||
if batch_size is not None:
|
||||
state["batch_size"] = batch_size
|
||||
write_json(state, save_dir / TRAINING_STEP)
|
||||
def save_training_step(step: int, save_dir: Path) -> None:
|
||||
write_json({"step": step}, save_dir / TRAINING_STEP)
|
||||
|
||||
|
||||
def load_training_step(save_dir: Path) -> int:
|
||||
@@ -69,16 +58,6 @@ def load_training_step(save_dir: Path) -> int:
|
||||
return training_step["step"]
|
||||
|
||||
|
||||
def load_training_num_processes(checkpoint_dir: Path) -> int | None:
|
||||
"""World size recorded at checkpoint time, or None for checkpoints written before it was stored."""
|
||||
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("num_processes")
|
||||
|
||||
|
||||
def load_training_batch_size(checkpoint_dir: Path) -> int | None:
|
||||
"""Per-process batch size recorded at checkpoint time, or None for older checkpoints."""
|
||||
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("batch_size")
|
||||
|
||||
|
||||
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
|
||||
last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
|
||||
if last_checkpoint_dir.is_symlink():
|
||||
@@ -96,8 +75,6 @@ def save_checkpoint(
|
||||
scheduler: LRScheduler | None = None,
|
||||
preprocessor: PolicyProcessorPipeline | None = None,
|
||||
postprocessor: PolicyProcessorPipeline | None = None,
|
||||
num_processes: int | None = None,
|
||||
batch_size: int | None = None,
|
||||
) -> None:
|
||||
"""This function creates the following directory structure:
|
||||
|
||||
@@ -123,10 +100,6 @@ def save_checkpoint(
|
||||
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
||||
preprocessor: The preprocessor/pipeline to save. Defaults to None.
|
||||
postprocessor: The postprocessor/pipeline to save. Defaults to None.
|
||||
num_processes (int | None, optional): Distributed world size to record for sample-exact
|
||||
resume. Defaults to None (not recorded).
|
||||
batch_size (int | None, optional): Per-process batch size to record for sample-exact
|
||||
resume. Defaults to None (not recorded).
|
||||
"""
|
||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||
policy.save_pretrained(pretrained_dir)
|
||||
@@ -139,9 +112,7 @@ def save_checkpoint(
|
||||
preprocessor.save_pretrained(pretrained_dir)
|
||||
if postprocessor is not None:
|
||||
postprocessor.save_pretrained(pretrained_dir)
|
||||
save_training_state(
|
||||
checkpoint_dir, step, optimizer, scheduler, num_processes=num_processes, batch_size=batch_size
|
||||
)
|
||||
save_training_state(checkpoint_dir, step, optimizer, scheduler)
|
||||
|
||||
|
||||
def save_training_state(
|
||||
@@ -149,8 +120,6 @@ def save_training_state(
|
||||
train_step: int,
|
||||
optimizer: Optimizer | None = None,
|
||||
scheduler: LRScheduler | None = None,
|
||||
num_processes: int | None = None,
|
||||
batch_size: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Saves the training step, optimizer state, scheduler state, and rng state.
|
||||
@@ -162,12 +131,10 @@ def save_training_state(
|
||||
Defaults to None.
|
||||
scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict.
|
||||
Defaults to None.
|
||||
num_processes (int | None, optional): Distributed world size to record. Defaults to None.
|
||||
batch_size (int | None, optional): Per-process batch size to record. Defaults to None.
|
||||
"""
|
||||
save_dir = checkpoint_dir / TRAINING_STATE_DIR
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
save_training_step(train_step, save_dir, num_processes=num_processes, batch_size=batch_size)
|
||||
save_training_step(train_step, save_dir)
|
||||
save_rng_state(save_dir)
|
||||
if optimizer is not None:
|
||||
save_optimizer_state(optimizer, save_dir)
|
||||
|
||||
@@ -79,8 +79,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
|
||||
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
|
||||
pretrained_path: Path | None = None
|
||||
# Optional Hub revision (commit hash, branch, or tag) to pin the pretrained model version.
|
||||
pretrained_revision: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.device or not is_torch_device_available(self.device):
|
||||
|
||||
@@ -56,8 +56,6 @@ class RewardModelConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
device: str | None = None
|
||||
|
||||
pretrained_path: str | None = None
|
||||
# Optional Hub revision (commit hash, branch, or tag) to pin the pretrained reward model version.
|
||||
pretrained_revision: str | None = None
|
||||
|
||||
push_to_hub: bool = False
|
||||
repo_id: str | None = None
|
||||
|
||||
@@ -50,7 +50,7 @@ from .lerobot_dataset import LeRobotDataset
|
||||
from .multi_dataset import MultiLeRobotDataset
|
||||
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
|
||||
from .sampler import EpisodeAwareSampler, compute_sampler_state
|
||||
from .sampler import EpisodeAwareSampler
|
||||
from .streaming_dataset import StreamingLeRobotDataset
|
||||
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
||||
from .video_utils import VideoEncodingManager
|
||||
@@ -82,7 +82,6 @@ __all__ = [
|
||||
"aggregate_stats",
|
||||
"convert_image_to_video_dataset",
|
||||
"create_initial_features",
|
||||
"compute_sampler_state",
|
||||
"create_lerobot_dataset_card",
|
||||
"column_for_style",
|
||||
"delete_episodes",
|
||||
|
||||
@@ -286,8 +286,6 @@ def aggregate_datasets(
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
chunk_size: int | None = None,
|
||||
concatenate_videos: bool = True,
|
||||
concatenate_data: bool = True,
|
||||
):
|
||||
"""Aggregates multiple LeRobot datasets into a single unified dataset.
|
||||
|
||||
@@ -305,8 +303,6 @@ def aggregate_datasets(
|
||||
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
|
||||
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||
concatenate_videos: When False, keep one mp4 per source file instead of packing into shards.
|
||||
concatenate_data: When False, keep one parquet per source file instead of packing into shards.
|
||||
"""
|
||||
logging.info("Start aggregate_datasets")
|
||||
|
||||
@@ -355,12 +351,8 @@ def aggregate_datasets(
|
||||
dst_meta.episodes = {}
|
||||
|
||||
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
|
||||
videos_idx = aggregate_videos(
|
||||
src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size, concatenate_videos
|
||||
)
|
||||
data_idx = aggregate_data(
|
||||
src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size, concatenate_data
|
||||
)
|
||||
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
|
||||
data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
|
||||
|
||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
|
||||
|
||||
@@ -375,9 +367,7 @@ def aggregate_datasets(
|
||||
logging.info("Aggregation complete.")
|
||||
|
||||
|
||||
def aggregate_videos(
|
||||
src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size, concatenate_videos=True
|
||||
):
|
||||
def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size):
|
||||
"""Aggregates video chunks from a source dataset into the destination dataset.
|
||||
|
||||
Handles video file concatenation and rotation based on file size limits.
|
||||
@@ -389,7 +379,6 @@ def aggregate_videos(
|
||||
videos_idx: Dictionary tracking video chunk and file indices.
|
||||
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||
concatenate_videos: When False, keep one mp4 per source file instead of packing into shards.
|
||||
Returns:
|
||||
dict: Updated videos_idx with current chunk and file indices.
|
||||
"""
|
||||
@@ -450,7 +439,7 @@ def aggregate_videos(
|
||||
src_size = get_file_size_in_mb(src_path)
|
||||
dst_size = get_file_size_in_mb(dst_path)
|
||||
|
||||
if not concatenate_videos or dst_size + src_size >= video_files_size_in_mb:
|
||||
if dst_size + src_size >= video_files_size_in_mb:
|
||||
# Rotate to a new file - offset is 0
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
|
||||
dst_key = (chunk_idx, file_idx)
|
||||
@@ -488,7 +477,7 @@ def aggregate_videos(
|
||||
return videos_idx
|
||||
|
||||
|
||||
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size, concatenate_data=True):
|
||||
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size):
|
||||
"""Aggregates data chunks from a source dataset into the destination dataset.
|
||||
|
||||
Reads source data files, updates indices to match the aggregated dataset,
|
||||
@@ -504,7 +493,6 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
data_idx: Dictionary tracking data chunk and file indices.
|
||||
data_files_size_in_mb: Maximum size for data files in MB.
|
||||
chunk_size: Maximum number of files per chunk.
|
||||
concatenate_data: When False, keep one parquet per source file instead of packing into shards.
|
||||
|
||||
Returns:
|
||||
dict: Updated data_idx with current chunk and file indices.
|
||||
@@ -550,7 +538,6 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
contains_images=contains_images,
|
||||
aggr_root=dst_meta.root,
|
||||
hf_features=hf_features,
|
||||
concatenate=concatenate_data,
|
||||
)
|
||||
|
||||
# Record the mapping from source to actual destination
|
||||
@@ -627,7 +614,6 @@ def append_or_create_parquet_file(
|
||||
contains_images: bool = False,
|
||||
aggr_root: Path = None,
|
||||
hf_features: datasets.Features | None = None,
|
||||
concatenate: bool = True,
|
||||
) -> tuple[dict[str, int], tuple[int, int]]:
|
||||
"""Appends data to an existing parquet file or creates a new one based on size constraints.
|
||||
|
||||
@@ -644,7 +630,6 @@ def append_or_create_parquet_file(
|
||||
contains_images: Whether the data contains images requiring special handling.
|
||||
aggr_root: Root path for the aggregated dataset.
|
||||
hf_features: Optional HuggingFace Features schema for proper image typing.
|
||||
concatenate: When False, always rotate to a new file instead of appending to the current one.
|
||||
|
||||
Returns:
|
||||
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
|
||||
@@ -664,7 +649,7 @@ def append_or_create_parquet_file(
|
||||
src_size = get_parquet_file_size_in_mb(src_path)
|
||||
dst_size = get_parquet_file_size_in_mb(dst_path)
|
||||
|
||||
if not concatenate or dst_size + src_size >= max_mb:
|
||||
if dst_size + src_size >= max_mb:
|
||||
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
|
||||
dst_chunk, dst_file = idx["chunk"], idx["file"]
|
||||
new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
|
||||
|
||||
@@ -59,8 +59,6 @@ class RunningQuantileStats:
|
||||
batch: An array where all dimensions except the last are batch dimensions.
|
||||
"""
|
||||
batch = batch.reshape(-1, batch.shape[-1])
|
||||
# Promote integer and low-precision inputs before computing squared statistics.
|
||||
batch = batch.astype(np.result_type(batch.dtype, np.float32), copy=False)
|
||||
num_elements, vector_length = batch.shape
|
||||
|
||||
if self._count == 0:
|
||||
|
||||
@@ -261,8 +261,6 @@ def merge_datasets(
|
||||
datasets: list[LeRobotDataset],
|
||||
output_repo_id: str,
|
||||
output_dir: str | Path | None = None,
|
||||
concatenate_videos: bool = True,
|
||||
concatenate_data: bool = True,
|
||||
) -> LeRobotDataset:
|
||||
"""Merge multiple LeRobotDatasets into a single dataset.
|
||||
|
||||
@@ -272,8 +270,6 @@ def merge_datasets(
|
||||
datasets: List of LeRobotDatasets to merge.
|
||||
output_repo_id: Merged dataset identifier.
|
||||
output_dir: Root directory where the merged dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/output_repo_id.
|
||||
concatenate_videos: When False, keep one mp4 per source file instead of packing into shards.
|
||||
concatenate_data: When False, keep one parquet per source file instead of packing into shards.
|
||||
"""
|
||||
if not datasets:
|
||||
raise ValueError("No datasets to merge")
|
||||
@@ -288,8 +284,6 @@ def merge_datasets(
|
||||
aggr_repo_id=output_repo_id,
|
||||
roots=roots,
|
||||
aggr_root=output_dir,
|
||||
concatenate_videos=concatenate_videos,
|
||||
concatenate_data=concatenate_data,
|
||||
)
|
||||
|
||||
merged_dataset = LeRobotDataset(
|
||||
|
||||
+38
-122
@@ -14,36 +14,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import math
|
||||
from collections.abc import Iterator
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EpisodeAwareSampler:
|
||||
"""Sampler over episode frames that stores only per-episode boundaries.
|
||||
|
||||
Logical positions map to frame indices on the fly (O(num_episodes) construction memory)
|
||||
instead of materializing a Python list of every frame index.
|
||||
|
||||
Each epoch is shuffled with a `torch.randperm` seeded from `(seed, epoch)`, so the data order
|
||||
is a pure function of `(seed, epoch)`: it reproduces on every rank without synchronizing the
|
||||
global RNG (no `generator` to sync across distributed ranks), and `state_dict` /
|
||||
`load_state_dict` resume a run sample-exactly by regenerating the epoch's permutation and
|
||||
continuing from the saved offset. Each call to `__iter__` advances the epoch. During a
|
||||
resumed epoch, `__len__` still reports the full length.
|
||||
|
||||
Epoch advancement: `__iter__` eagerly advances the epoch, and `set_epoch` / `load_state_dict`
|
||||
set it explicitly. Within a single run callers should rely on exactly one of these mechanisms,
|
||||
not both: advancing the epoch by hand *and* letting `__iter__` auto-advance over the same
|
||||
iterations would skip or repeat epochs. The training loop drives it purely through `__iter__`
|
||||
(via `cycle`); `set_epoch` / `load_state_dict` are used only to (re)position before iteration
|
||||
starts (e.g. on resume or in tests).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_from_indices: list[int],
|
||||
@@ -52,125 +30,63 @@ class EpisodeAwareSampler:
|
||||
drop_n_first_frames: int = 0,
|
||||
drop_n_last_frames: int = 0,
|
||||
shuffle: bool = False,
|
||||
seed: int = 0,
|
||||
generator: torch.Generator | None = None,
|
||||
):
|
||||
"""
|
||||
"""Sampler that optionally incorporates episode boundary information.
|
||||
|
||||
Args:
|
||||
dataset_from_indices: Start index of each episode in the dataset.
|
||||
dataset_to_indices: End index of each episode in the dataset.
|
||||
episode_indices_to_use: Episode indices to use; None means all.
|
||||
drop_n_first_frames: Frames to drop from the start of each episode.
|
||||
drop_n_last_frames: Frames to drop from the end of each episode.
|
||||
dataset_from_indices: List of indices containing the start of each episode in the dataset.
|
||||
dataset_to_indices: List of indices containing the end of each episode in the dataset.
|
||||
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
|
||||
Assumes that episodes are indexed from 0 to N-1.
|
||||
drop_n_first_frames: Number of frames to drop from the start of each episode.
|
||||
drop_n_last_frames: Number of frames to drop from the end of each episode.
|
||||
shuffle: Whether to shuffle the indices.
|
||||
seed: Seed the permutation is derived from (together with the epoch).
|
||||
generator: Generator used for shuffling. Exposing this attribute (even when None) lets
|
||||
`accelerate` register it as the synchronized RNG in distributed training, so
|
||||
every rank draws the same permutation and batch shards stay disjoint. When
|
||||
None, shuffling falls back to the global torch RNG.
|
||||
"""
|
||||
if drop_n_first_frames < 0:
|
||||
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
|
||||
if drop_n_last_frames < 0:
|
||||
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
|
||||
|
||||
from_indices = np.asarray(dataset_from_indices, dtype=np.int64)
|
||||
to_indices = np.asarray(dataset_to_indices, dtype=np.int64)
|
||||
if from_indices.shape != to_indices.shape:
|
||||
raise ValueError(
|
||||
f"dataset_from_indices and dataset_to_indices must have the same length, "
|
||||
f"got {len(from_indices)} and {len(to_indices)}"
|
||||
)
|
||||
indices = []
|
||||
for episode_idx, (start_index, end_index) in enumerate(
|
||||
zip(dataset_from_indices, dataset_to_indices, strict=True)
|
||||
):
|
||||
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
|
||||
ep_length = end_index - start_index
|
||||
if drop_n_first_frames + drop_n_last_frames >= ep_length:
|
||||
logger.warning(
|
||||
"Episode %d has %d frames but drop_n_first_frames=%d and "
|
||||
"drop_n_last_frames=%d removes all frames. Skipping.",
|
||||
episode_idx,
|
||||
ep_length,
|
||||
drop_n_first_frames,
|
||||
drop_n_last_frames,
|
||||
)
|
||||
continue
|
||||
indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
|
||||
|
||||
used = np.ones(len(from_indices), dtype=bool)
|
||||
if episode_indices_to_use is not None:
|
||||
used = np.zeros(len(from_indices), dtype=bool)
|
||||
used[np.asarray(episode_indices_to_use, dtype=np.int64)] = True
|
||||
|
||||
starts = from_indices + drop_n_first_frames
|
||||
lengths = to_indices - drop_n_last_frames - starts
|
||||
for episode_idx in np.flatnonzero(used & (lengths <= 0)):
|
||||
logger.warning(
|
||||
"Episode %d has %d frames but drop_n_first_frames=%d and "
|
||||
"drop_n_last_frames=%d removes all frames. Skipping.",
|
||||
episode_idx,
|
||||
to_indices[episode_idx] - from_indices[episode_idx],
|
||||
drop_n_first_frames,
|
||||
drop_n_last_frames,
|
||||
)
|
||||
used &= lengths > 0
|
||||
if not used.any():
|
||||
if not indices:
|
||||
raise ValueError(
|
||||
"No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. "
|
||||
"All episodes were either filtered out or had too few frames."
|
||||
)
|
||||
|
||||
self._starts = starts[used]
|
||||
self._cum_lengths = np.cumsum(lengths[used])
|
||||
self._num_frames = int(self._cum_lengths[-1])
|
||||
self.indices = indices
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
self._epoch = 0
|
||||
self._start_index = 0
|
||||
|
||||
@property
|
||||
def indices(self) -> list[int]:
|
||||
"""Materialized frame indices in unshuffled order; O(num_frames), introspection only."""
|
||||
return [self._frame_index(k) for k in range(self._num_frames)]
|
||||
|
||||
def set_epoch(self, epoch: int) -> None:
|
||||
self._epoch = epoch
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
return {"epoch": self._epoch, "start_index": self._start_index}
|
||||
|
||||
def load_state_dict(self, state: dict) -> None:
|
||||
self._epoch = state["epoch"]
|
||||
self._start_index = state["start_index"]
|
||||
|
||||
def _epoch_generator(self, epoch: int) -> torch.Generator:
|
||||
# Derive a per-epoch seed from (seed, epoch) so the permutation is a pure function of both
|
||||
# and reproduces identically on every rank without touching the global RNG.
|
||||
epoch_seed = int(np.random.SeedSequence([self.seed, epoch]).generate_state(1, dtype=np.uint64)[0])
|
||||
return torch.Generator().manual_seed(epoch_seed)
|
||||
|
||||
def _frame_index(self, position: int) -> int:
|
||||
episode = int(np.searchsorted(self._cum_lengths, position, side="right"))
|
||||
position_in_episode = position - (int(self._cum_lengths[episode - 1]) if episode > 0 else 0)
|
||||
return int(self._starts[episode]) + position_in_episode
|
||||
self.generator = generator
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
# Advance epoch state eagerly, not on first consumption of the generator.
|
||||
epoch, start = self._epoch, self._start_index
|
||||
self._epoch += 1
|
||||
self._start_index = 0
|
||||
return self._iter_epoch(epoch, start)
|
||||
|
||||
def _iter_epoch(self, epoch: int, start: int) -> Iterator[int]:
|
||||
if self.shuffle:
|
||||
order = torch.randperm(self._num_frames, generator=self._epoch_generator(epoch))
|
||||
for k in range(start, self._num_frames):
|
||||
yield self._frame_index(int(order[k]))
|
||||
for i in torch.randperm(len(self.indices), generator=self.generator):
|
||||
yield self.indices[i]
|
||||
else:
|
||||
for k in range(start, self._num_frames):
|
||||
yield self._frame_index(k)
|
||||
for i in self.indices:
|
||||
yield i
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._num_frames
|
||||
|
||||
|
||||
def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_processes: int) -> dict:
|
||||
"""Map an optimization step to an `EpisodeAwareSampler` state for sample-exact resume.
|
||||
|
||||
Under accelerate's batch sharding, one step consumes `batch_size * num_processes` sampler
|
||||
positions and each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches
|
||||
per epoch (`even_batches` padding included). The start index provably stays below
|
||||
`num_frames`; the `min` is defensive.
|
||||
|
||||
Assumptions (resume is only sample-exact when they hold):
|
||||
- `num_processes` and `batch_size` match the run that wrote the checkpoint. Both scale how
|
||||
many positions a step consumes, so the epoch/offset are wrong if either changed. The
|
||||
caller passes the checkpoint's `num_processes` and `batch_size` and warns on a mismatch.
|
||||
- accelerate uses `even_batches=True` (its default). The `ceil(... / num_processes)` term
|
||||
mirrors that padding; with `even_batches=False` the per-epoch batch count differs and
|
||||
the boundary is off.
|
||||
"""
|
||||
batches_per_epoch = math.ceil(math.ceil(num_frames / batch_size) / num_processes)
|
||||
epoch, batches_into_epoch = divmod(step, batches_per_epoch)
|
||||
start_index = min(batches_into_epoch * batch_size * num_processes, num_frames)
|
||||
return {"epoch": epoch, "start_index": start_index}
|
||||
return len(self.indices)
|
||||
|
||||
@@ -481,10 +481,8 @@ def reencode_video(
|
||||
encoder_threads: int | None = None,
|
||||
log_level: int | None = av.logging.WARNING,
|
||||
overwrite: bool = False,
|
||||
start_time_s: float | None = None,
|
||||
end_time_s: float | None = None,
|
||||
) -> None:
|
||||
"""Re-encode a video file, optionally trimming it to ``[start_time_s, end_time_s)``.
|
||||
"""Re-encode a video file using the given encoder configuration.
|
||||
|
||||
Args:
|
||||
input_video_path: Existing video file to read.
|
||||
@@ -493,17 +491,10 @@ def reencode_video(
|
||||
encoder_threads: Optional thread count forwarded to :meth:`VideoEncoderConfig.get_codec_options`.
|
||||
log_level: libav log level while encoding, or ``None`` to leave logging unchanged. Defaults to WARNING.
|
||||
overwrite: When ``False`` and ``output_video_path`` already exists, skip and log a warning.
|
||||
start_time_s: When set, trim the output to start at this timestamp (seconds).
|
||||
end_time_s: When set, trim the output to end at this timestamp (seconds, exclusive).
|
||||
"""
|
||||
|
||||
camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
|
||||
if (start_time_s is not None and start_time_s < 0) or (end_time_s is not None and end_time_s < 0):
|
||||
raise ValueError(f"Trim times must be non-negative, got start={start_time_s}, end={end_time_s}.")
|
||||
if start_time_s is not None and end_time_s is not None and end_time_s <= start_time_s:
|
||||
raise ValueError(f"end_time_s ({end_time_s}) must be greater than start_time_s ({start_time_s}).")
|
||||
|
||||
output_video_path = Path(output_video_path)
|
||||
|
||||
if output_video_path.exists() and not overwrite:
|
||||
@@ -535,10 +526,6 @@ def reencode_video(
|
||||
width = int(in_stream.width)
|
||||
height = int(in_stream.height)
|
||||
|
||||
# Seek to the keyframe at or before start_time_s to avoid reading from the start.
|
||||
if start_time_s is not None:
|
||||
src.seek(int(start_time_s * av.time_base), backward=True)
|
||||
|
||||
with av.open(
|
||||
tmp_output_video_path,
|
||||
mode="w",
|
||||
@@ -552,14 +539,7 @@ def reencode_video(
|
||||
out_stream.height = height
|
||||
|
||||
for frame in src.decode(in_stream):
|
||||
frame_time_s = frame.time
|
||||
if start_time_s is not None and frame_time_s < start_time_s:
|
||||
continue
|
||||
if end_time_s is not None and frame_time_s >= end_time_s:
|
||||
break
|
||||
frame = frame.reformat(width=width, height=height, format=pix_fmt)
|
||||
if start_time_s is not None:
|
||||
frame.pts = None # reset timestamps so the trimmed output starts at t=0
|
||||
packet = out_stream.encode(frame)
|
||||
if packet:
|
||||
dst.mux(packet)
|
||||
|
||||
@@ -252,7 +252,6 @@ class ProcessorConfigKwargs(TypedDict, total=False):
|
||||
def make_pre_post_processors(
|
||||
policy_cfg: PreTrainedConfig,
|
||||
pretrained_path: str | None = None,
|
||||
pretrained_revision: str | None = None,
|
||||
**kwargs: Unpack[ProcessorConfigKwargs],
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
@@ -310,7 +309,6 @@ def make_pre_post_processors(
|
||||
overrides=kwargs.get("preprocessor_overrides", {}),
|
||||
to_transition=batch_to_transition,
|
||||
to_output=transition_to_batch,
|
||||
revision=pretrained_revision,
|
||||
)
|
||||
postprocessor = PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
@@ -320,7 +318,6 @@ def make_pre_post_processors(
|
||||
overrides=kwargs.get("postprocessor_overrides", {}),
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
revision=pretrained_revision,
|
||||
)
|
||||
_reconnect_relative_absolute_steps(preprocessor, postprocessor)
|
||||
return preprocessor, postprocessor
|
||||
@@ -560,7 +557,6 @@ def make_policy(
|
||||
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||
# hyperparameters that we want to vary).
|
||||
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
|
||||
kwargs["revision"] = cfg.pretrained_revision
|
||||
policy = policy_cls.from_pretrained(**kwargs)
|
||||
elif cfg.pretrained_path and cfg.use_peft:
|
||||
# Load a pretrained PEFT model on top of the policy. The pretrained path points to the folder/repo
|
||||
|
||||
@@ -29,7 +29,6 @@ from huggingface_hub.errors import HfHubHTTPError
|
||||
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.__version__ import __version__
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.utils.hub import HubMixin
|
||||
@@ -39,67 +38,6 @@ from .utils import log_model_loading_keys
|
||||
T = TypeVar("T", bound="PreTrainedPolicy")
|
||||
|
||||
|
||||
def _build_card_context(
|
||||
cfg: TrainPipelineConfig | None,
|
||||
dataset_repo_id: str | None,
|
||||
input_features: dict | None,
|
||||
output_features: dict | None,
|
||||
) -> dict:
|
||||
"""Collect optional data for the model-card template.
|
||||
|
||||
Returns plain values only (no Markdown) — the template in
|
||||
``lerobot/templates/lerobot_modelcard_template.md`` decides how and whether to show
|
||||
each one. Everything is best-effort: anything unavailable is left empty/None and the
|
||||
template simply skips that section, so this never breaks a Hub push.
|
||||
"""
|
||||
context = {
|
||||
"training": None,
|
||||
"input_features": input_features or {},
|
||||
"output_features": output_features or {},
|
||||
"dataset": None,
|
||||
"robot_type": None,
|
||||
"cameras": [],
|
||||
}
|
||||
|
||||
if cfg is not None:
|
||||
optimizer = getattr(cfg, "optimizer", None)
|
||||
context["training"] = {
|
||||
"steps": cfg.steps,
|
||||
"batch_size": cfg.batch_size,
|
||||
"seed": cfg.seed,
|
||||
"optimizer": getattr(optimizer, "type", None) if optimizer else None,
|
||||
"lr": getattr(optimizer, "lr", None) if optimizer else None,
|
||||
"lerobot_version": __version__,
|
||||
}
|
||||
|
||||
if dataset_repo_id:
|
||||
dataset_cfg = getattr(cfg, "dataset", None)
|
||||
try:
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
|
||||
meta = LeRobotDatasetMetadata(
|
||||
dataset_repo_id,
|
||||
root=getattr(dataset_cfg, "root", None),
|
||||
revision=getattr(dataset_cfg, "revision", None),
|
||||
)
|
||||
context["dataset"] = {
|
||||
"repo_id": dataset_repo_id,
|
||||
"episodes": meta.total_episodes,
|
||||
"frames": meta.total_frames,
|
||||
"fps": meta.fps,
|
||||
"tasks": [str(task) for task in meta.tasks.index],
|
||||
}
|
||||
context["robot_type"] = meta.robot_type
|
||||
context["cameras"] = [key.split(".")[-1] for key in meta.camera_keys]
|
||||
except Exception as e: # noqa: BLE001 — dataset details are optional, never fail the push
|
||||
logging.warning(
|
||||
f"Could not load dataset metadata for '{dataset_repo_id}'; those sections will be "
|
||||
f"omitted from the model card. ({e})"
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
|
||||
class ActionSelectKwargs(TypedDict, total=False):
|
||||
noise: Tensor | None
|
||||
|
||||
@@ -290,7 +228,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors
|
||||
|
||||
card = self.generate_model_card(
|
||||
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags, cfg=cfg
|
||||
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags
|
||||
)
|
||||
card.save(str(saved_path / "README.md"))
|
||||
|
||||
@@ -308,20 +246,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
logging.info(f"Model pushed to {commit_info.repo_url.url}")
|
||||
|
||||
def generate_model_card(
|
||||
self,
|
||||
dataset_repo_id: str,
|
||||
model_type: str,
|
||||
license: str | None,
|
||||
tags: list[str] | None,
|
||||
cfg: TrainPipelineConfig | None = None,
|
||||
self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None
|
||||
) -> ModelCard:
|
||||
base_model_mapping = {
|
||||
"smolvla": "lerobot/smolvla_base",
|
||||
"pi0": "lerobot/pi0_base",
|
||||
"pi05": "lerobot/pi05_base",
|
||||
"pi0_fast": "lerobot/pi0fast-base",
|
||||
"xvla": "lerobot/xvla-base",
|
||||
}
|
||||
base_model = "lerobot/smolvla_base" if model_type == "smolvla" else None # Set a base model
|
||||
|
||||
card_data = ModelCardData(
|
||||
license=license or "apache-2.0",
|
||||
@@ -330,20 +257,13 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
tags=list(set(tags or []).union({"robotics", "lerobot", model_type})),
|
||||
model_name=model_type,
|
||||
datasets=dataset_repo_id,
|
||||
base_model=base_model_mapping.get(model_type),
|
||||
base_model=base_model,
|
||||
)
|
||||
|
||||
context = _build_card_context(
|
||||
cfg, dataset_repo_id, self.config.input_features, self.config.output_features
|
||||
)
|
||||
# Used by the template to pre-fill commands and the "Fine-tuned from" line.
|
||||
context["policy_repo_id"] = getattr(self.config, "repo_id", None)
|
||||
context["base_model"] = base_model_mapping.get(model_type)
|
||||
|
||||
template_card = (
|
||||
files("lerobot.templates").joinpath("lerobot_modelcard_template.md").read_text(encoding="utf-8")
|
||||
)
|
||||
card = ModelCard.from_template(card_data, template_str=template_card, **context)
|
||||
card = ModelCard.from_template(card_data, template_str=template_card)
|
||||
card.validate()
|
||||
return card
|
||||
|
||||
|
||||
@@ -0,0 +1,286 @@
|
||||
# Remote Inference Architecture
|
||||
|
||||
How `lerobot-policy-server` and `lerobot-rollout --inference.type=remote` decouple GPU-bound policy inference from high-frequency robot control over Zenoh.
|
||||
|
||||
This document explains the **internals** — the wire protocol, threading models, state machines, and safety invariants. For the user-facing guide (CLI quickstarts, deployment), see [`docs/source/remote_inference.mdx`](../../../docs/source/remote_inference.mdx).
|
||||
|
||||
## 1. The problem and the shape of the solution
|
||||
|
||||
Running a large policy (Pi0-class, ~150 ms inference) inside a 33 ms control loop doesn't work, and putting a GPU next to every robot doesn't scale. LeRobot already solved the _local_ version of this problem: `RTCInferenceEngine` runs inference in a background **thread** that fills a thread-safe `ActionQueue`, while the control loop pops one action per tick.
|
||||
|
||||
Remote inference is **that same architecture with the thread boundary replaced by a network boundary**:
|
||||
|
||||
```
|
||||
local RTC: control loop ──ActionQueue── inference thread (same process, same GPU)
|
||||
remote: control loop ──ActionQueue── network worker ══zenoh══ policy server (GPU, elsewhere)
|
||||
```
|
||||
|
||||
Three design commitments follow from this:
|
||||
|
||||
- **The client is a backend, not a CLI.** `RemoteInferenceEngine` plugs into the existing `InferenceEngine` seam (`rollout/inference/base.py`), so every rollout strategy (base, sentry, highlight, dagger, episodic) gets network inference — including dataset recording, pause/resume, and safe teardown — without changing a line.
|
||||
- **The client is weightless.** No policy weights, no policy processors on the edge. `--policy.path` resolves to a config-only `PreTrainedConfig` used for pre-flight validation and action ordering.
|
||||
- **The server is stateless per request.** All chunk state (RTC prefixes, latency tracking, delay computation) lives client-side in the existing `ActionQueue`/`LatencyTracker`. The client ships prefixes + a delay hint with every observation, so a server crash loses zero control state.
|
||||
|
||||
## 2. Component map
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
subgraph EDGE["Edge (per robot, weightless)"]
|
||||
R[Robot HW] --> S["Rollout strategy<br/>(sentry / dagger / ...)"]
|
||||
S -->|"notify_observation()"| E[RemoteInferenceEngine]
|
||||
E -->|"get_action()"| S
|
||||
S -->|actions| R
|
||||
E --- AQ[("ActionQueue<br/>(chunk buffer)")]
|
||||
end
|
||||
|
||||
subgraph NET["Transport"]
|
||||
Z["zenohd router(s)<br/>(robots dial out, mTLS + ACL)"]
|
||||
end
|
||||
|
||||
subgraph GPU["GPU pod (one model · one device · one process)"]
|
||||
PS[PolicyServer]
|
||||
PS --- SR["SessionRegistry<br/>(per-client mailboxes + pipelines)"]
|
||||
PS --- W["Inference worker<br/>(1 thread, owns GPU)"]
|
||||
W --- P["PreTrainedPolicy<br/>(pre-warmed)"]
|
||||
end
|
||||
|
||||
E <-->|"obs ↑ / chunks ↓ (pub/sub)<br/>status · session · reset (queryables)"| Z
|
||||
Z <--> PS
|
||||
```
|
||||
|
||||
One server process = one pre-warmed `(model, revision, dtype, device)` serving up to `max_sessions` robots. Scaling out = more pods; clients rejected with the current load retry another replica.
|
||||
|
||||
## 3. Where the network cut goes
|
||||
|
||||
The local RTC pipeline is split at the cheapest, most hardware-coupled point. Everything policy-coupled (resize, normalize, tokenize) runs server-side with the **canonical training-time processors**, so serve-time preprocessing is byte-identical to train-time:
|
||||
|
||||
```
|
||||
robot obs (processed dict)
|
||||
→ build_dataset_frame(...) CLIENT cheap, hardware-coupled
|
||||
→ rename_map applied to keys CLIENT wire format = canonical policy keys
|
||||
══════════════════════ network (msgpack + JPEG) ══════════════════════
|
||||
→ prepare_observation_for_inference(...) SERVER tensors, batch dim, device
|
||||
→ per-session preprocessor(...) SERVER stateful within the request
|
||||
→ policy.predict_action_chunk(obs, delay, prefix) SERVER pure for allowlisted policies
|
||||
→ per-session postprocessor(...) SERVER reads state cached at preprocess
|
||||
══════════════════════ network ══════════════════════
|
||||
→ ActionQueue.merge(original, processed, delay, idx_before) CLIENT
|
||||
```
|
||||
|
||||
The reply carries **both** the model-space (`chunk_model`) and robot-space (`chunk_robot`) chunks because `ActionQueue.merge` needs both, and the next request's relative-action prefix re-anchoring needs the robot-space tail.
|
||||
|
||||
## 4. Wire protocol
|
||||
|
||||
### 4.1 Key-expression schema (Zenoh)
|
||||
|
||||
```
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/<client_uuid>/obs client → server pub/sub
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/<client_uuid>/action server → client pub/sub
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/status queryable (capabilities)
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/session queryable (open / close)
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/<client_uuid>/reset queryable (episode boundary)
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/<client_uuid>/alive liveliness token (client)
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/server/alive liveliness token (server)
|
||||
```
|
||||
|
||||
`@lerobot` is a **verbatim chunk**: wildcards never match it, so third-party `**` subscribers on a shared router cannot scrape the tree. User-supplied segments are sanitized (`sanitize_key_segment`), and the server subscribes with single-depth wildcards only (`.../*/obs`, never `**`).
|
||||
|
||||
Data plane = pub/sub (a late chunk is still usable; a timed-out query reply is not). Control plane = queryables with explicit timeouts (the rmw_zenoh pattern). QoS (`zenoh_utils.py`): actions are `RELIABLE + congestion DROP + express + INTERACTIVE_HIGH` — **never BLOCK**, so one dead robot uplink can never stall the server's publish path; a dropped chunk is recoverable because the client buffer keeps the robot moving.
|
||||
|
||||
### 4.2 Messages
|
||||
|
||||
Every data-plane message carries a **packed little-endian attachment header** (27 bytes, parsed without touching the body):
|
||||
|
||||
| field | type | meaning |
|
||||
| ---------------- | ---- | --------------------------------------------------------- |
|
||||
| `schema_version` | u16 | negotiated at session open; additive-only body evolution |
|
||||
| `msg_type` | u8 | OBS / CHUNK / EVENT |
|
||||
| `seq_id` | u64 | per-session monotonic; echoed in the chunk |
|
||||
| `episode_id` | u32 | bumped by `reset()` |
|
||||
| `client_mono_ns` | i64 | client monotonic clock — **opaque to the server, echoed** |
|
||||
| `session_epoch` | u32 | bumped per (re)connect; stale-epoch chunks dropped |
|
||||
|
||||
Bodies are msgpack (`codec.py`): tensors as raw little-endian bytes + dtype + shape, images JPEG (RGB convention enforced inside the codec; `jpeg_quality=0` = raw). No pickle anywhere — nothing on the wire can carry code.
|
||||
|
||||
**Clock iron rule:** wall-clock instants never cross machines. The client computes RTT from its own monotonic clock via the echoed `client_mono_ns`; the server reports only **durations** (`queue_wait_ms`, `inference_ms`).
|
||||
|
||||
### 4.3 Session lifecycle
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant C as RemoteInferenceEngine
|
||||
participant S as PolicyServer
|
||||
|
||||
C->>S: GET status (timeout 2s)
|
||||
S-->>C: capabilities (model, action_names, cameras, chunk_size, supports_rtc, ...)
|
||||
C->>S: GET session {op: open, action_names, cameras, state_dim, fps, rtc, task}
|
||||
Note over S: validate (hard: action name ORDER,<br/>cameras, state_dim, schema, capacity)
|
||||
S-->>C: SessionAck {session_id, warnings, rtc_execution_horizon, ...}
|
||||
Note over C,S: both declare liveliness tokens
|
||||
|
||||
loop self-clocked by buffer_time_s (one-in-flight)
|
||||
C->>S: PUB obs {state, images, delay_steps, prefix_model, prefix_robot} + header
|
||||
Note over S: latest-only mailbox → worker →<br/>preprocess → predict_action_chunk → postprocess
|
||||
S-->>C: PUB chunk {chunk_model, chunk_robot, durations} + echoed header
|
||||
Note over C: validate (episode, epoch) → ActionQueue.merge(..., idx_before)
|
||||
end
|
||||
|
||||
C->>S: GET reset {episode_id} (episode boundary, acked)
|
||||
C->>S: GET session {op: close} (graceful stop)
|
||||
```
|
||||
|
||||
The **action-name order check is a hard reject**: it is the contract that maps chunk columns to motors. A mismatch means wrong-joint commands, so the session never opens.
|
||||
|
||||
## 5. The client: `RemoteInferenceEngine`
|
||||
|
||||
File: `src/lerobot/rollout/inference/remote.py`, registered as `--inference.type=remote` (`RemoteInferenceConfig` in `factory.py`).
|
||||
|
||||
### 5.1 Threading model
|
||||
|
||||
| thread | role |
|
||||
| ---------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| main (strategy loop) | `notify_observation()` → latest-only slot; `get_action()` → `ActionQueue.get()` + staleness check + fallback. **Never any I/O.** |
|
||||
| network worker (1) | gate on `buffer_time_s` → snapshot `(seq, episode, epoch)` then `idx_before` + RTC prefixes → publish obs → await chunk (timeout) → revalidate → merge. Owns the state machine and reconnects. |
|
||||
| zenoh callback threads | deposit-only: chunk → bounded queue; server liveliness → event. |
|
||||
|
||||
**One-in-flight is a correctness requirement, not a tuning choice.** `merge(..., idx_before)` validates against the consumption index snapshotted at send time; two in-flight requests would carry conflicting snapshots and corrupt both RTC-replace and append modes. The worker therefore publishes one observation, waits for its chunk (or timeout), then sends the next. A late chunk is accepted only if it answers the latest outstanding `seq_id` _and_ the current `(episode, epoch)`.
|
||||
|
||||
### 5.2 The request cycle
|
||||
|
||||
```
|
||||
queue playback ≤ buffer_time_s? (self-clocking: ~1–4 Hz, not the 30 Hz control rate)
|
||||
├─ snapshot (seq, episode, epoch)
|
||||
├─ snapshot idx_before, prefix_model = queue.get_left_over()[:H],
|
||||
│ prefix_robot = queue.get_processed_left_over()[:H]
|
||||
├─ revalidate (episode, epoch) unchanged ← a reset racing the snapshot skips the cycle
|
||||
├─ delay_steps = ceil(LatencyTracker.max() / dt)
|
||||
├─ publish obs + header
|
||||
├─ await chunk (request_timeout_s)
|
||||
├─ revalidate (episode, epoch) under _anchor_lock ← a stale chunk can never survive a reset
|
||||
└─ merge(chunk_model, chunk_robot, ceil(measured_latency/dt), idx_before); update anchor
|
||||
```
|
||||
|
||||
Because the `LatencyTracker` samples are full network-inclusive cycle times, RTT compensation falls out for free — the same `delay`-trimming machinery local RTC uses absorbs network latency as just more delay.
|
||||
|
||||
### 5.3 Fail-safe state machine
|
||||
|
||||
```mermaid
|
||||
stateDiagram-v2
|
||||
[*] --> CONNECTING
|
||||
CONNECTING --> STREAMING: first merge
|
||||
STREAMING --> DEGRADED: request timeouts,<br/>queue still has actions
|
||||
DEGRADED --> STREAMING: merge
|
||||
DEGRADED --> STALLED: queue empty or<br/>max_action_age_s hit
|
||||
STALLED --> RECONNECTING: timeout streak /<br/>server liveliness drop
|
||||
DEGRADED --> RECONNECTING: timeout streak /<br/>server liveliness drop
|
||||
RECONNECTING --> STREAMING: re-handshake OK<br/>(epoch++)
|
||||
RECONNECTING --> DEAD: offline > max_offline_s,<br/>capability/model mismatch
|
||||
DEAD --> [*]: failed=True → shutdown_event<br/>→ strategy teardown
|
||||
```
|
||||
|
||||
- **DEGRADED**: the chunk buffer _is_ the fault tolerance — 1–3 s of buffered actions makes network blips and clean server drains invisible to the robot.
|
||||
- **Staleness bound** (`max_action_age_s`): `get_action` refuses any action whose source observation is too old, bounding open-loop execution after a stall. Then the **fallback ladder** applies: `hold` (return `None`; the robot holds), `repeat_last`, or `zero` (the safe stop for velocity-controlled robots).
|
||||
- **Watchdog layering**: per-request timeout (catches a _hung-but-connected_ server) → server liveliness token (catches a dead server/router) → staleness bound (the robot-side invariant that holds regardless of why data stopped).
|
||||
- **DEAD** is reserved for hard failures: offline beyond `max_offline_s` with no successful merge (a server that handshakes but never delivers chunks still runs out of budget), or a contract violation on reconnect (model/revision changed, RTC capability flipped — never execute wrong-model chunks). It triggers the exact mechanism local RTC uses: `failed=True` + the global `shutdown_event`, so the existing teardown (return-to-initial-pose) runs unchanged.
|
||||
- **Pause/resume** (DAgger): `pause()` stops publishing; the queue stays intact. A pause during an outage freezes the offline budget so a human correction can never be aborted by `max_offline_s`.
|
||||
|
||||
### 5.4 Episode boundaries
|
||||
|
||||
`reset()` (control thread) atomically — under the same lock the merge path takes — clears the `ActionQueue`, nulls the staleness anchor, bumps `episode_id`, and invalidates the observation slot (the previous episode's final frame must not seed the new one). The worker sends an acked `reset` query, and the next observation header carries the new `episode_id` anyway — so a lost ack costs nothing (the server is stateless per request).
|
||||
|
||||
## 6. The server: `PolicyServer`
|
||||
|
||||
Files: `src/lerobot/policy_server/`. Entry point: `lerobot-policy-server --manifest server.yaml` (draccus dataclasses in `manifest.py`).
|
||||
|
||||
### 6.1 Concurrency model
|
||||
|
||||
zenoh-python is thread-based (no asyncio); callbacks must be deposit-only:
|
||||
|
||||
```
|
||||
zenoh subscriber (.../*/obs) inference worker (1 thread, owns GPU)
|
||||
deposit-only callback: loop:
|
||||
session.deposit(header, body) ──► scheduler picks next session with pending obs
|
||||
(per-client latest-only mailbox) decode → episode-boundary check
|
||||
preprocess → predict_action_chunk(delay, prefix)
|
||||
control queryables (status / postprocess → encode
|
||||
session / reset): validate, publisher.put(.../<uuid>/action)
|
||||
mutate registry, reply inline
|
||||
|
||||
liveliness subscriber (.../*/alive): mark sessions for GC on token DELETE
|
||||
```
|
||||
|
||||
- **Latest-only mailboxes**: the newest observation wins; superseded requests are counted and reported in the next reply (`superseded_seqs`), so drops are visible client-side. The client decides _when_ to request; the server never second-guesses observation content.
|
||||
- **Single inference worker** + round-robin over ready sessions: every ready session gets exactly one inference per cycle — starvation is structurally impossible. Overload degrades into longer cycle times → larger (but correct) client `delay_steps` → eventually the client staleness bound trips and the robot holds. Safe by construction.
|
||||
- The `Scheduler` seam (`scheduler.py`) exists so cross-session micro-batching can land later without redesign (blocked today on `predict_action_chunk` taking a _scalar_ `inference_delay`).
|
||||
- `_inference_lock` serializes the worker's predict path against episode resets arriving on queryable threads (in exclusive mode a `policy.reset()` mid-predict would corrupt the in-flight request).
|
||||
|
||||
### 6.2 Multi-tenancy: engineered, not assumed
|
||||
|
||||
Sharing one policy instance across sessions is only safe when `predict_action_chunk` touches no cross-request instance state. That property is **verified per family and encoded as a registry** (`validation.py`) — never inferred:
|
||||
|
||||
| class | policies | mode | why |
|
||||
| --------------- | ------------------------------------------------- | ----------- | ----------------------------------------------------------------------------------- |
|
||||
| chunk-stateless | `act`, `pi0`, `pi05`, `smolvla` (`n_obs_steps=1`) | `shared` | chunk call is pure (smolvla overwrites its 1-deep queue with the request's own obs) |
|
||||
| chunk-stateful | `diffusion` (and `smolvla` with `n_obs_steps>1`) | `exclusive` | chunk call reads `select_action`-fed `_queues` → server populates them per request |
|
||||
| no chunk API | `sac`, `tdmpc`, ... (no `predict_action_chunk`) | refused | nothing to serve |
|
||||
| unverified | any other chunk-API policy | `exclusive` | a manifest can force `exclusive`, but never `shared` for an unverified policy |
|
||||
|
||||
The real multi-tenancy hazard is **processor state**, not just policy purity: `RelativeActionsProcessorStep` caches `_last_state` at preprocess and the postprocessor reads it back. The server therefore builds a **fresh pre/post pipeline pair per session** — two robots at different joint positions can never cross-contaminate each other's action conversions. `policy.reset()` is **never** called in shared mode (it is global to the shared instance).
|
||||
|
||||
### 6.3 Statelessness and the RTC prefix
|
||||
|
||||
The server holds no cross-request control state. Each observation ships everything inference needs:
|
||||
|
||||
- `inference_delay_steps` — computed client-side from network-inclusive latency.
|
||||
- `prefix_model` — the unexecuted tail of the previous chunk in model space (feeds `prev_chunk_left_over`).
|
||||
- `prefix_robot` — the same tail in robot space. For relative-action policies the server **re-anchors** it against the state cached by _this request's_ preprocess (`reanchor_relative_rtc_prefix`, mirroring `rtc.py`), so the prefix is expressed relative to where the robot actually is now.
|
||||
|
||||
Consequences: reconnects are trivial, horizontal scaling is trivial, and a `kill -9` on the server loses nothing the client can't re-send.
|
||||
|
||||
### 6.4 Episode and reconnect hygiene
|
||||
|
||||
- Fresh sessions start at the `episode_id = -1` sentinel: the **first** observation of any session always triggers the boundary branch (pipelines reset; exclusive policies `reset()`), so a mid-episode reconnect can never inherit stale state.
|
||||
- Session replacement is identity-checked (`SessionRegistry.remove(expected=...)`): a GC sweep that snapshotted an old session can never tear down its just-handshaked replacement.
|
||||
- Liveliness GC double-checks with an explicit liveliness `get` before closing: the token key is per-client (not per-epoch), so a _late_ DELETE from a previous incarnation must not kill the live session.
|
||||
- Drain (`SIGTERM`): drop the liveliness token first (clients ride their buffers), finish the in-flight inference, undeclare the control surface, then close. Clients reconnect to another replica invisibly.
|
||||
|
||||
## 7. Latency budget (why the transport is never the bottleneck)
|
||||
|
||||
| stage | LAN | WAN (50 ms RTT) |
|
||||
| ------------------------------ | ------------- | --------------- |
|
||||
| JPEG encode + serialize (edge) | 2–9 ms | 2–9 ms |
|
||||
| uplink | ~2 ms | ~54 ms |
|
||||
| decode + canonical preprocess | 4–10 ms | 4–10 ms |
|
||||
| **inference** | **15–150 ms** | **15–150 ms** |
|
||||
| postprocess + downlink + merge | ~2 ms | ~27 ms |
|
||||
|
||||
Inference dominates (60–85% on LAN). At 30 fps a WAN deployment lands `delay_steps ≈ 4–8`, comfortably inside RTC execution horizons: WAN degrades smoothness parameters, never correctness. Requests are self-clocked by `buffer_time_s` to ~1–4 Hz per robot, so 300 robots cost ~0.3–10 Mbps each.
|
||||
|
||||
Capacity per GPU: `N_max ≈ 0.8 / (request_rate × inference_time)` → ~40 ACT-class or ~5 Pi0-class clients; `max_sessions` enforces it at session open (rejected clients receive the current load and retry another replica).
|
||||
|
||||
## 8. Observability & reproducibility
|
||||
|
||||
The contract is **fully logged + replayable**, not "deterministic" (no seed controls hardware or network jitter):
|
||||
|
||||
- **Client = source of truth**: recording strategies persist observations + executed actions as usual; the engine tracks `(session_id, seq_id, episode_id)` and per-cycle stats.
|
||||
- **Server**: one JSON audit line per request on the `lerobot.policy_server.audit` logger — `{session_id, client_uuid, seq_id, episode_id, queue_wait_ms, inference_ms, superseded, outcome}` — plus `/healthz` and Prometheus-style `/metrics`, and an optional bounded raw request/response capture (`debug.capture_dir`) for byte-exact offline replay.
|
||||
- Every hop shares `(session_id, seq_id)`, so joining a robot-side stutter to a server-side cause is mechanical.
|
||||
|
||||
## 9. File map
|
||||
|
||||
| path | contents |
|
||||
| ---------------------------------- | ---------------------------------------------------------------------------------------------------- |
|
||||
| `policy_server/schema.py` | wire messages, packed header, key-expression schema + sanitizer |
|
||||
| `policy_server/codec.py` | msgpack bodies, tensor codec (LE bytes), JPEG image codec (RGB convention) |
|
||||
| `policy_server/manifest.py` | draccus config: model, zenoh endpoints/TLS, serving mode, capacity, RTC, health |
|
||||
| `policy_server/validation.py` | serving-mode registry + session-open capability matrix |
|
||||
| `policy_server/session.py` | per-client `Session` (pipelines, latest-only mailbox, stats) + identity-safe registry |
|
||||
| `policy_server/scheduler.py` | `Scheduler` seam; `RoundRobinScheduler` |
|
||||
| `policy_server/zenoh_utils.py` | config builder, QoS profiles, lazy import with install hint |
|
||||
| `policy_server/server.py` | `PolicyServer`: zenoh surface, inference worker, GC, warmup, drain, health/metrics |
|
||||
| `rollout/inference/remote.py` | `RemoteInferenceEngine` (the edge client) |
|
||||
| `rollout/inference/factory.py` | `RemoteInferenceConfig`, `FallbackMode`, factory dispatch |
|
||||
| `scripts/lerobot_policy_server.py` | console entry point (`--manifest` → draccus `--config_path`) |
|
||||
| `tests/policy_server/` | codec/schema/validation/scheduler/session units, server logic, zenoh loopback + chaos, golden parity |
|
||||
|
||||
The golden parity test (`tests/policy_server/test_golden_parity.py`) is the standing contract: the remote request path (encode → decode → `run_inference_request` → encode → decode → merge) must produce **byte-identical** action queues to the local RTC compute path on identical inputs.
|
||||
@@ -0,0 +1,53 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Multi-client GPU policy serving over Zenoh (``lerobot-policy-server``).
|
||||
|
||||
The wire schema (:mod:`.schema`) and codecs (:mod:`.codec`) are shared
|
||||
with the edge-side :class:`~lerobot.rollout.inference.remote.RemoteInferenceEngine`.
|
||||
Heavy/optional imports (msgpack, zenoh, torch server) are deferred so the
|
||||
schema stays importable without the ``async`` extra.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .manifest import (
|
||||
DebugSpec,
|
||||
ModelSpec,
|
||||
PolicyServerManifest,
|
||||
ZenohSpec,
|
||||
)
|
||||
from .schema import SCHEMA_VERSION, MsgHeader, service_prefix
|
||||
|
||||
__all__ = [
|
||||
"SCHEMA_VERSION",
|
||||
"DebugSpec",
|
||||
"ModelSpec",
|
||||
"MsgHeader",
|
||||
"PolicyServer",
|
||||
"PolicyServerManifest",
|
||||
"ZenohSpec",
|
||||
"codec",
|
||||
"service_prefix",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
import importlib
|
||||
|
||||
if name == "PolicyServer":
|
||||
return importlib.import_module(".server", __name__).PolicyServer
|
||||
if name == "codec":
|
||||
return importlib.import_module(".codec", __name__)
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
@@ -0,0 +1,262 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""MessagePack codecs for the remote-inference wire schema.
|
||||
|
||||
Encoding rules:
|
||||
- Tensors are raw little-endian bytes + dtype + shape (msgpack's ``bin``
|
||||
type), so decoding is a zero-parse ``np.frombuffer``.
|
||||
- Images are JPEG by default (``jpeg_quality=0`` sends raw bytes). The
|
||||
in-memory convention on both ends is **RGB** uint8 HWC; the OpenCV
|
||||
BGR↔RGB conversion happens inside this module only.
|
||||
- Decoders are tolerant: unknown keys are ignored, missing optional keys
|
||||
take dataclass defaults — schema evolution is additive-only.
|
||||
- No pickle anywhere: nothing in this codec can carry code.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import msgpack
|
||||
except ImportError as e: # pragma: no cover
|
||||
raise ImportError(
|
||||
"Remote inference requires the 'async' extra: pip install 'lerobot[async]' (eclipse-zenoh + msgpack)"
|
||||
) from e
|
||||
|
||||
from .schema import (
|
||||
IMAGE_CODEC_JPEG,
|
||||
IMAGE_CODEC_RAW,
|
||||
ActionChunkMsg,
|
||||
ObservationMsg,
|
||||
ResetAckMsg,
|
||||
ResetMsg,
|
||||
SessionAckMsg,
|
||||
SessionCloseMsg,
|
||||
SessionOpenMsg,
|
||||
StatusMsg,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tensor codec
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _to_little_endian(arr: np.ndarray) -> np.ndarray:
|
||||
if arr.dtype.byteorder == ">":
|
||||
arr = arr.astype(arr.dtype.newbyteorder("<"))
|
||||
return np.ascontiguousarray(arr)
|
||||
|
||||
|
||||
def encode_tensor(arr: np.ndarray | None) -> dict[str, Any] | None:
|
||||
"""Encode an ndarray as raw little-endian bytes + dtype + shape."""
|
||||
if arr is None:
|
||||
return None
|
||||
arr = np.asarray(arr)
|
||||
# Record the shape before ascontiguousarray, which promotes 0-d to 1-d.
|
||||
shape = list(arr.shape)
|
||||
arr = _to_little_endian(arr)
|
||||
return {"dtype": arr.dtype.str, "shape": shape, "data": arr.tobytes()}
|
||||
|
||||
|
||||
def decode_tensor(obj: dict[str, Any] | None) -> np.ndarray | None:
|
||||
if obj is None:
|
||||
return None
|
||||
dtype = np.dtype(obj["dtype"])
|
||||
if dtype.hasobject:
|
||||
raise ValueError(f"Refusing object dtype {dtype} on the wire")
|
||||
arr = np.frombuffer(obj["data"], dtype=dtype).reshape(obj["shape"])
|
||||
# frombuffer returns a read-only view; copy so downstream torch.from_numpy works.
|
||||
return arr.copy()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image codec (RGB uint8 HWC on both ends)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def encode_image(img: np.ndarray, jpeg_quality: int = 90) -> dict[str, Any]:
|
||||
"""Encode an RGB uint8 HWC image; ``jpeg_quality=0`` keeps it raw."""
|
||||
img = np.asarray(img)
|
||||
if img.dtype != np.uint8 or img.ndim != 3 or img.shape[2] != 3:
|
||||
raise ValueError(f"Expected uint8 HWC RGB image, got dtype={img.dtype} shape={img.shape}")
|
||||
if jpeg_quality <= 0:
|
||||
return {"codec": IMAGE_CODEC_RAW, "shape": list(img.shape), "data": _to_little_endian(img).tobytes()}
|
||||
ok, buf = cv2.imencode(
|
||||
".jpg", cv2.cvtColor(img, cv2.COLOR_RGB2BGR), [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_quality)]
|
||||
)
|
||||
if not ok:
|
||||
raise ValueError("JPEG encoding failed")
|
||||
return {"codec": IMAGE_CODEC_JPEG, "data": buf.tobytes()}
|
||||
|
||||
|
||||
def decode_image(obj: dict[str, Any]) -> np.ndarray:
|
||||
"""Decode to an RGB uint8 HWC image."""
|
||||
codec = obj.get("codec", IMAGE_CODEC_JPEG)
|
||||
if codec == IMAGE_CODEC_RAW:
|
||||
return np.frombuffer(obj["data"], dtype=np.uint8).reshape(obj["shape"]).copy()
|
||||
if codec == IMAGE_CODEC_JPEG:
|
||||
bgr = cv2.imdecode(np.frombuffer(obj["data"], dtype=np.uint8), cv2.IMREAD_COLOR)
|
||||
if bgr is None:
|
||||
raise ValueError("JPEG decoding failed")
|
||||
return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
||||
raise ValueError(f"Unknown image codec: {codec!r}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# msgpack helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _packb(obj: dict[str, Any]) -> bytes:
|
||||
return msgpack.packb(obj, use_bin_type=True)
|
||||
|
||||
|
||||
def _unpackb(data: bytes) -> dict[str, Any]:
|
||||
return msgpack.unpackb(data, raw=False)
|
||||
|
||||
|
||||
def decode_raw(data: bytes) -> dict[str, Any]:
|
||||
"""Decode a body to a plain dict (e.g. to peek a control-plane ``op``)."""
|
||||
return _unpackb(data)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data-plane messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def encode_observation(msg: ObservationMsg) -> bytes:
|
||||
return _packb(
|
||||
{
|
||||
"state": encode_tensor(msg.state),
|
||||
"images": {name: encode_image(img, msg.jpeg_quality) for name, img in msg.images.items()},
|
||||
"task": msg.task,
|
||||
"inference_delay_steps": int(msg.inference_delay_steps),
|
||||
"prefix_model": encode_tensor(msg.prefix_model),
|
||||
"prefix_robot": encode_tensor(msg.prefix_robot),
|
||||
"episode_start": bool(msg.episode_start),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def decode_observation(data: bytes) -> ObservationMsg:
|
||||
obj = _unpackb(data)
|
||||
return ObservationMsg(
|
||||
state=decode_tensor(obj.get("state")),
|
||||
images={name: decode_image(img) for name, img in obj.get("images", {}).items()},
|
||||
task=obj.get("task", ""),
|
||||
inference_delay_steps=obj.get("inference_delay_steps", 0),
|
||||
prefix_model=decode_tensor(obj.get("prefix_model")),
|
||||
prefix_robot=decode_tensor(obj.get("prefix_robot")),
|
||||
episode_start=obj.get("episode_start", False),
|
||||
)
|
||||
|
||||
|
||||
def encode_action_chunk(msg: ActionChunkMsg) -> bytes:
|
||||
return _packb(
|
||||
{
|
||||
"seq_id_echo": int(msg.seq_id_echo),
|
||||
"client_mono_ns_echo": int(msg.client_mono_ns_echo),
|
||||
"episode_id_echo": int(msg.episode_id_echo),
|
||||
"chunk_model": encode_tensor(msg.chunk_model),
|
||||
"chunk_robot": encode_tensor(msg.chunk_robot),
|
||||
"queue_wait_ms": float(msg.queue_wait_ms),
|
||||
"inference_ms": float(msg.inference_ms),
|
||||
"superseded_seqs": int(msg.superseded_seqs),
|
||||
"server_load": float(msg.server_load),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def decode_action_chunk(data: bytes) -> ActionChunkMsg:
|
||||
obj = _unpackb(data)
|
||||
return ActionChunkMsg(
|
||||
seq_id_echo=obj.get("seq_id_echo", 0),
|
||||
client_mono_ns_echo=obj.get("client_mono_ns_echo", 0),
|
||||
episode_id_echo=obj.get("episode_id_echo", 0),
|
||||
chunk_model=decode_tensor(obj.get("chunk_model")),
|
||||
chunk_robot=decode_tensor(obj.get("chunk_robot")),
|
||||
queue_wait_ms=obj.get("queue_wait_ms", 0.0),
|
||||
inference_ms=obj.get("inference_ms", 0.0),
|
||||
superseded_seqs=obj.get("superseded_seqs", 0),
|
||||
server_load=obj.get("server_load", 0.0),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Control-plane messages (flat scalar/list/dict fields → generic codec)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _encode_flat(msg: Any) -> bytes:
|
||||
return _packb(dict(vars(msg).items()))
|
||||
|
||||
|
||||
def _decode_flat(cls: type, data: bytes) -> Any:
|
||||
obj = _unpackb(data)
|
||||
known = set(cls.__dataclass_fields__)
|
||||
return cls(**{k: v for k, v in obj.items() if k in known})
|
||||
|
||||
|
||||
def encode_session_open(msg: SessionOpenMsg) -> bytes:
|
||||
return _encode_flat(msg)
|
||||
|
||||
|
||||
def decode_session_open(data: bytes) -> SessionOpenMsg:
|
||||
return _decode_flat(SessionOpenMsg, data)
|
||||
|
||||
|
||||
def encode_session_ack(msg: SessionAckMsg) -> bytes:
|
||||
return _encode_flat(msg)
|
||||
|
||||
|
||||
def decode_session_ack(data: bytes) -> SessionAckMsg:
|
||||
return _decode_flat(SessionAckMsg, data)
|
||||
|
||||
|
||||
def encode_status(msg: StatusMsg) -> bytes:
|
||||
return _encode_flat(msg)
|
||||
|
||||
|
||||
def decode_status(data: bytes) -> StatusMsg:
|
||||
return _decode_flat(StatusMsg, data)
|
||||
|
||||
|
||||
def encode_reset(msg: ResetMsg) -> bytes:
|
||||
return _encode_flat(msg)
|
||||
|
||||
|
||||
def decode_reset(data: bytes) -> ResetMsg:
|
||||
return _decode_flat(ResetMsg, data)
|
||||
|
||||
|
||||
def encode_reset_ack(msg: ResetAckMsg) -> bytes:
|
||||
return _encode_flat(msg)
|
||||
|
||||
|
||||
def decode_reset_ack(data: bytes) -> ResetAckMsg:
|
||||
return _decode_flat(ResetAckMsg, data)
|
||||
|
||||
|
||||
def encode_session_close(msg: SessionCloseMsg) -> bytes:
|
||||
return _encode_flat(msg)
|
||||
|
||||
|
||||
def decode_session_close(data: bytes) -> SessionCloseMsg:
|
||||
return _decode_flat(SessionCloseMsg, data)
|
||||
@@ -0,0 +1,139 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Policy-server manifest: one process = one (model, revision, dtype, device) on one GPU.
|
||||
|
||||
Loaded from YAML via ``lerobot-policy-server --manifest server.yaml`` (or
|
||||
individual ``--model.repo_or_path=...`` CLI overrides through draccus).
|
||||
Dynamic model loading is deliberately unsupported: pre-warmed processes
|
||||
keep capacity planning honest and keep code-carrying payloads off the wire.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SERVING_MODE_AUTO = "auto"
|
||||
SERVING_MODE_SHARED = "shared"
|
||||
SERVING_MODE_EXCLUSIVE = "exclusive"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelSpec:
|
||||
"""Which policy this process serves, and where it runs."""
|
||||
|
||||
repo_or_path: str = ""
|
||||
revision: str = "main"
|
||||
# Optional torch dtype cast applied after load (e.g. "bfloat16").
|
||||
dtype: str | None = None
|
||||
device: str = "cuda"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ZenohSpec:
|
||||
"""Transport endpoints and security.
|
||||
|
||||
Robots and servers both *dial out* to a ``zenohd`` router in
|
||||
production (``mode=client``). ``mode=peer`` + ``listen_endpoints``
|
||||
supports router-less LAN and loopback test deployments. Multicast
|
||||
scouting is always disabled: fleet discovery is configuration, not
|
||||
protocol magic.
|
||||
"""
|
||||
|
||||
mode: str = "client" # "client" (via router) | "peer" (direct)
|
||||
connect_endpoints: list[str] = field(default_factory=list)
|
||||
listen_endpoints: list[str] = field(default_factory=list)
|
||||
# mTLS material (PEM paths). All three are required for TLS endpoints.
|
||||
tls_root_ca_certificate: str | None = None
|
||||
tls_connect_certificate: str | None = None
|
||||
tls_connect_private_key: str | None = None
|
||||
# Escape hatch: raw JSON5 merged into the zenoh config last.
|
||||
extra_config_json5: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DebugSpec:
|
||||
"""Optional bounded request/response capture for offline replay."""
|
||||
|
||||
capture_dir: str | None = None
|
||||
capture_max: int = 256
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyServerManifest:
|
||||
"""Top-level config for ``lerobot-policy-server``."""
|
||||
|
||||
model: ModelSpec = field(default_factory=ModelSpec)
|
||||
zenoh: ZenohSpec = field(default_factory=ZenohSpec)
|
||||
|
||||
# The task namespace this service is published under. When
|
||||
# ``pin_task`` is true, session opens with a different task string
|
||||
# are rejected; otherwise VLA clients may set the task per session.
|
||||
default_task: str = ""
|
||||
pin_task: bool = False
|
||||
# Optional override for the <task_slug> key segment (defaults to a
|
||||
# slug of ``default_task``).
|
||||
service_name: str = ""
|
||||
|
||||
# "auto" resolves from the policy classification (shared for
|
||||
# chunk-stateless policies, exclusive otherwise). "exclusive" can be
|
||||
# forced; "shared" cannot override a chunk-stateful classification.
|
||||
serving_mode: str = SERVING_MODE_AUTO
|
||||
max_sessions: int = 5
|
||||
warmup_inferences: int = 2
|
||||
|
||||
# FPS contract: warn on mismatch unless strict.
|
||||
trained_fps: float = 30.0
|
||||
strict_fps: bool = False
|
||||
|
||||
# RTC behaviour for this server process (global to the shared policy:
|
||||
# ``init_rtc_processor`` mutates the policy instance, so it is a
|
||||
# per-process decision, not per-session).
|
||||
rtc: RTCConfig = field(default_factory=RTCConfig)
|
||||
|
||||
# Sessions with no liveliness token and no traffic for this long are
|
||||
# garbage-collected (belt-and-braces behind liveliness GC).
|
||||
session_idle_timeout_s: float = 300.0
|
||||
|
||||
# HTTP health + Prometheus metrics port; 0 disables the endpoint.
|
||||
health_port: int = 9100
|
||||
|
||||
debug: DebugSpec = field(default_factory=DebugSpec)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.model.repo_or_path:
|
||||
raise ValueError("--model.repo_or_path is required (the policy this server serves)")
|
||||
if self.serving_mode not in (SERVING_MODE_AUTO, SERVING_MODE_SHARED, SERVING_MODE_EXCLUSIVE):
|
||||
raise ValueError(f"serving_mode must be one of auto|shared|exclusive, got {self.serving_mode!r}")
|
||||
if self.max_sessions < 1:
|
||||
raise ValueError(f"max_sessions must be >= 1, got {self.max_sessions}")
|
||||
if self.zenoh.mode not in ("client", "peer"):
|
||||
raise ValueError(f"zenoh.mode must be 'client' or 'peer', got {self.zenoh.mode!r}")
|
||||
if self.zenoh.mode == "client" and not self.zenoh.connect_endpoints:
|
||||
raise ValueError("zenoh.connect_endpoints is required in client mode (router address)")
|
||||
tls_fields = (
|
||||
self.zenoh.tls_root_ca_certificate,
|
||||
self.zenoh.tls_connect_certificate,
|
||||
self.zenoh.tls_connect_private_key,
|
||||
)
|
||||
if any(tls_fields) and not all(tls_fields):
|
||||
raise ValueError(
|
||||
"TLS requires all of tls_root_ca_certificate, tls_connect_certificate, "
|
||||
"tls_connect_private_key"
|
||||
)
|
||||
@@ -0,0 +1,58 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Scheduling seam between the session registry and the inference worker.
|
||||
|
||||
The v1 scheduler is strict round-robin over sessions with a pending
|
||||
observation: every ready session gets exactly one inference per cycle,
|
||||
so starvation is structurally impossible. The seam exists so that
|
||||
cross-session micro-batching can land later without redesign (blocked
|
||||
today on ``predict_action_chunk`` taking a *scalar* ``inference_delay``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
|
||||
from .session import Session
|
||||
|
||||
|
||||
class Scheduler(abc.ABC):
|
||||
"""Pick which ready session(s) the worker serves next."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def select(self, ready: list[Session]) -> list[Session]:
|
||||
"""Return the sessions to serve this cycle (subset of ``ready``)."""
|
||||
|
||||
|
||||
class RoundRobinScheduler(Scheduler):
|
||||
"""Serve one session per cycle, fairly, in client_uuid order."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._last_served: str | None = None
|
||||
|
||||
def select(self, ready: list[Session]) -> list[Session]:
|
||||
if not ready:
|
||||
return []
|
||||
ring = sorted(ready, key=lambda s: s.client_uuid)
|
||||
if self._last_served is not None:
|
||||
for i, session in enumerate(ring):
|
||||
if session.client_uuid > self._last_served:
|
||||
ring = ring[i:] + ring[:i]
|
||||
break
|
||||
else:
|
||||
pass # wrap: everyone is <= last served, keep sorted order
|
||||
chosen = ring[0]
|
||||
self._last_served = chosen.client_uuid
|
||||
return [chosen]
|
||||
@@ -0,0 +1,340 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Wire schema for remote policy inference.
|
||||
|
||||
Message dataclasses, the packed attachment header, and the Zenoh
|
||||
key-expression layout shared by the policy server and the remote
|
||||
inference engine. This module is transport-free (no zenoh import) so
|
||||
codecs and validation can be unit-tested without the optional extra.
|
||||
|
||||
Schema discipline: bodies are MessagePack maps decoded tolerantly
|
||||
(unknown keys ignored, missing optional keys defaulted) so evolution is
|
||||
additive-only. Any change to the attachment layout requires a
|
||||
``SCHEMA_VERSION`` bump; versions are negotiated at session open.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import struct
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import numpy as np
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Versioning
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SCHEMA_VERSION = 1
|
||||
# Oldest schema version this build can still serve.
|
||||
MIN_SUPPORTED_SCHEMA_VERSION = 1
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Attachment header (fixed layout, packed little-endian)
|
||||
#
|
||||
# Parsed without touching the msgpack body so routing, correlation and
|
||||
# supersession decisions never pay deserialization costs. The
|
||||
# ``client_mono_ns`` field is a client-monotonic timestamp that is
|
||||
# OPAQUE to the server: it is echoed back verbatim so the client can
|
||||
# compute round-trip times on its own clock. Wall-clock instants never
|
||||
# cross machines (the clock iron rule).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_HEADER_STRUCT = struct.Struct("<HBQIqI") # schema_version, msg_type, seq_id, episode_id, mono_ns, epoch
|
||||
|
||||
MSG_TYPE_OBS = 1
|
||||
MSG_TYPE_CHUNK = 2
|
||||
MSG_TYPE_EVENT = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
class MsgHeader:
|
||||
"""Packed per-message header carried in the Zenoh attachment."""
|
||||
|
||||
schema_version: int = SCHEMA_VERSION
|
||||
msg_type: int = MSG_TYPE_OBS
|
||||
seq_id: int = 0
|
||||
episode_id: int = 0
|
||||
client_mono_ns: int = 0
|
||||
session_epoch: int = 0
|
||||
|
||||
def pack(self) -> bytes:
|
||||
return _HEADER_STRUCT.pack(
|
||||
self.schema_version,
|
||||
self.msg_type,
|
||||
self.seq_id,
|
||||
self.episode_id,
|
||||
self.client_mono_ns,
|
||||
self.session_epoch,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def unpack(cls, data: bytes) -> MsgHeader:
|
||||
if len(data) != _HEADER_STRUCT.size:
|
||||
raise ValueError(f"Bad header length: {len(data)} (expected {_HEADER_STRUCT.size})")
|
||||
version, msg_type, seq_id, episode_id, mono_ns, epoch = _HEADER_STRUCT.unpack(data)
|
||||
return cls(
|
||||
schema_version=version,
|
||||
msg_type=msg_type,
|
||||
seq_id=seq_id,
|
||||
episode_id=episode_id,
|
||||
client_mono_ns=mono_ns,
|
||||
session_epoch=epoch,
|
||||
)
|
||||
|
||||
|
||||
HEADER_SIZE = _HEADER_STRUCT.size
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message bodies
|
||||
#
|
||||
# ``np.ndarray`` fields travel as raw little-endian bytes + dtype + shape
|
||||
# (see codec.py). Images travel JPEG-compressed by default.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
IMAGE_CODEC_JPEG = "jpeg"
|
||||
IMAGE_CODEC_RAW = "raw"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ObservationMsg:
|
||||
"""Client → server: one inference request (data plane)."""
|
||||
|
||||
state: np.ndarray | None = None # float32 [state_dim]
|
||||
images: dict[str, np.ndarray] = field(default_factory=dict) # name -> uint8 HWC RGB
|
||||
task: str = ""
|
||||
inference_delay_steps: int = 0
|
||||
# RTC prefixes: the unexecuted tail of the previous chunk, in model
|
||||
# space (original) and robot space (postprocessed). Both are needed
|
||||
# because the server re-anchors relative-action prefixes against the
|
||||
# current state and the client's ActionQueue.merge needs both chunks.
|
||||
prefix_model: np.ndarray | None = None # float32 [T, action_dim]
|
||||
prefix_robot: np.ndarray | None = None # float32 [T, action_dim]
|
||||
episode_start: bool = False
|
||||
# JPEG quality the images were encoded with; 0 means raw.
|
||||
jpeg_quality: int = 90
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionChunkMsg:
|
||||
"""Server → client: one action chunk (data plane)."""
|
||||
|
||||
seq_id_echo: int = 0
|
||||
client_mono_ns_echo: int = 0
|
||||
episode_id_echo: int = 0
|
||||
chunk_model: np.ndarray | None = None # float32 [H, action_dim] (pre-postprocessor)
|
||||
chunk_robot: np.ndarray | None = None # float32 [H, action_dim] (postprocessed)
|
||||
# Durations only — measured on the server's monotonic clock, never
|
||||
# compared against client time (the clock iron rule).
|
||||
queue_wait_ms: float = 0.0
|
||||
inference_ms: float = 0.0
|
||||
# Observations from this client that were superseded (overwritten in
|
||||
# the latest-only mailbox) since the previous reply — makes drops visible.
|
||||
superseded_seqs: int = 0
|
||||
server_load: float = 0.0 # active_sessions / max_sessions
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionOpenMsg:
|
||||
"""Client → server (control plane): open and validate a session."""
|
||||
|
||||
op: str = "open"
|
||||
client_uuid: str = ""
|
||||
robot_type: str = ""
|
||||
policy_type: str = ""
|
||||
fps: float = 0.0
|
||||
# Hard sync-safety contract: must equal the server's action feature
|
||||
# names *and order* — this maps chunk columns to motors.
|
||||
action_names: list[str] = field(default_factory=list)
|
||||
camera_names: list[str] = field(default_factory=list) # canonical keys (post-rename)
|
||||
state_dim: int = 0
|
||||
schema_version: int = SCHEMA_VERSION
|
||||
rtc_enabled: bool = False
|
||||
task: str = ""
|
||||
tags: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionAckMsg:
|
||||
"""Server → client (control plane): session accept/reject + capabilities."""
|
||||
|
||||
accepted: bool = False
|
||||
error: str = ""
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
session_id: str = ""
|
||||
model_repo: str = ""
|
||||
model_revision: str = ""
|
||||
policy_type: str = ""
|
||||
action_names: list[str] = field(default_factory=list)
|
||||
expected_cameras: list[str] = field(default_factory=list)
|
||||
state_dim: int = 0
|
||||
chunk_size: int = 0
|
||||
trained_fps: float = 0.0
|
||||
supports_rtc: bool = False
|
||||
rtc_execution_horizon: int = 0
|
||||
serving_mode: str = ""
|
||||
warmed_up: bool = False
|
||||
schema_version: int = SCHEMA_VERSION
|
||||
server_load: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class StatusMsg:
|
||||
"""Server → client (control plane): pre-flight capability snapshot."""
|
||||
|
||||
model_repo: str = ""
|
||||
model_revision: str = ""
|
||||
policy_type: str = ""
|
||||
action_names: list[str] = field(default_factory=list)
|
||||
expected_cameras: list[str] = field(default_factory=list)
|
||||
state_dim: int = 0
|
||||
chunk_size: int = 0
|
||||
trained_fps: float = 0.0
|
||||
supports_rtc: bool = False
|
||||
rtc_execution_horizon: int = 0
|
||||
serving_mode: str = ""
|
||||
warmed_up: bool = False
|
||||
min_schema_version: int = MIN_SUPPORTED_SCHEMA_VERSION
|
||||
max_schema_version: int = SCHEMA_VERSION
|
||||
active_sessions: int = 0
|
||||
max_sessions: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResetMsg:
|
||||
"""Client → server (control plane): episode boundary (acknowledged)."""
|
||||
|
||||
client_uuid: str = ""
|
||||
episode_id: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResetAckMsg:
|
||||
"""Server → client: reset acknowledgement."""
|
||||
|
||||
ok: bool = True
|
||||
error: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionCloseMsg:
|
||||
"""Client → server (control plane): graceful session close."""
|
||||
|
||||
op: str = "close"
|
||||
client_uuid: str = ""
|
||||
session_id: str = ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Key-expression schema
|
||||
#
|
||||
# @lerobot/<service>/<client_uuid>/obs client → server (pub/sub)
|
||||
# @lerobot/<service>/<client_uuid>/action server → client (pub/sub)
|
||||
# @lerobot/<service>/status queryable (capabilities)
|
||||
# @lerobot/<service>/session queryable (open/close)
|
||||
# @lerobot/<service>/<client_uuid>/reset queryable (episode boundary)
|
||||
# @lerobot/<service>/<client_uuid>/alive liveliness token (client)
|
||||
# @lerobot/<service>/server/alive liveliness token (server)
|
||||
#
|
||||
# where <service> = <model_slug>/<revision_slug>/<task_slug>. The task
|
||||
# segment is a *namespace label* derived from the server's default task
|
||||
# (or an explicit service name) — the actual inference task string
|
||||
# travels in the session/observation messages.
|
||||
#
|
||||
# ``@lerobot`` is a verbatim chunk: it is only matched by an identical
|
||||
# chunk, so third-party ``**`` subscribers on a shared router can never
|
||||
# scrape this tree.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
KEY_ROOT = "@lerobot"
|
||||
|
||||
# Conservative allowlist for user-supplied key segments. Everything
|
||||
# else (including '/', '*', '$', '?', '#', whitespace) is folded to '-'.
|
||||
_SEGMENT_SANITIZE_RE = re.compile(r"[^A-Za-z0-9_.\-]+")
|
||||
|
||||
# Reserved final chunks of the key tree; a client UUID must never
|
||||
# collide with them.
|
||||
RESERVED_SEGMENTS = frozenset({"status", "session", "server", "obs", "action", "reset", "alive"})
|
||||
|
||||
|
||||
def sanitize_key_segment(segment: str) -> str:
|
||||
"""Fold an arbitrary string into a single safe Zenoh key chunk."""
|
||||
slug = _SEGMENT_SANITIZE_RE.sub("-", segment.strip()).strip("-.")
|
||||
if not slug:
|
||||
raise ValueError(f"Key segment {segment!r} sanitizes to an empty chunk")
|
||||
if slug in RESERVED_SEGMENTS:
|
||||
raise ValueError(f"Key segment {segment!r} collides with reserved chunk {slug!r}")
|
||||
return slug
|
||||
|
||||
|
||||
def service_prefix(model_id: str, revision: str, task: str) -> str:
|
||||
"""Build the shared key prefix for one served (model, revision, task) triple."""
|
||||
return "/".join(
|
||||
(
|
||||
KEY_ROOT,
|
||||
sanitize_key_segment(model_id),
|
||||
sanitize_key_segment(revision or "main"),
|
||||
sanitize_key_segment(task or "default"),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def obs_key(prefix: str, client_uuid: str) -> str:
|
||||
return f"{prefix}/{sanitize_key_segment(client_uuid)}/obs"
|
||||
|
||||
|
||||
def action_key(prefix: str, client_uuid: str) -> str:
|
||||
return f"{prefix}/{sanitize_key_segment(client_uuid)}/action"
|
||||
|
||||
|
||||
def reset_key(prefix: str, client_uuid: str) -> str:
|
||||
return f"{prefix}/{sanitize_key_segment(client_uuid)}/reset"
|
||||
|
||||
|
||||
def client_alive_key(prefix: str, client_uuid: str) -> str:
|
||||
return f"{prefix}/{sanitize_key_segment(client_uuid)}/alive"
|
||||
|
||||
|
||||
def status_key(prefix: str) -> str:
|
||||
return f"{prefix}/status"
|
||||
|
||||
|
||||
def session_key(prefix: str) -> str:
|
||||
return f"{prefix}/session"
|
||||
|
||||
|
||||
def server_alive_key(prefix: str) -> str:
|
||||
return f"{prefix}/server/alive"
|
||||
|
||||
|
||||
# Single-depth wildcards only — '**' would also match status/session/alive.
|
||||
def obs_wildcard(prefix: str) -> str:
|
||||
return f"{prefix}/*/obs"
|
||||
|
||||
|
||||
def reset_wildcard(prefix: str) -> str:
|
||||
return f"{prefix}/*/reset"
|
||||
|
||||
|
||||
def client_alive_wildcard(prefix: str) -> str:
|
||||
return f"{prefix}/*/alive"
|
||||
|
||||
|
||||
def client_uuid_from_key(key: str) -> str:
|
||||
"""Extract the client UUID chunk from an obs/reset/alive key."""
|
||||
parts = key.split("/")
|
||||
if len(parts) < 2:
|
||||
raise ValueError(f"Key {key!r} has no client chunk")
|
||||
return parts[-2]
|
||||
@@ -0,0 +1,934 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""``lerobot-policy-server``: multi-client GPU inference over Zenoh.
|
||||
|
||||
One process serves one pre-warmed (model, revision, dtype, device) to up
|
||||
to ``max_sessions`` robot clients. The process is **stateless per
|
||||
request**: clients ship RTC prefixes and a delay hint with every
|
||||
observation, so a server crash loses zero control state and reconnects
|
||||
are trivial.
|
||||
|
||||
Concurrency model (pure threads — zenoh-python has no asyncio API):
|
||||
|
||||
zenoh subscriber (.../*/obs) inference worker (1 thread, owns GPU)
|
||||
deposit-only callback: loop:
|
||||
session.deposit(header, body) ──► pick next session with pending obs (RR)
|
||||
(per-client latest-only slot) decode → per-session preprocess
|
||||
predict_action_chunk(delay, prefix)
|
||||
control queryables (status/session/ per-session postprocess → encode
|
||||
reset): validate, mutate session publisher.put(.../<uuid>/action)
|
||||
registry, reply inline
|
||||
|
||||
The single worker thread serializes GPU access; newest-wins mailboxes
|
||||
mean overload degrades into longer cycle times (larger but correct
|
||||
client delays), never into queue buildup.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import http.server
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import uuid as uuid_module
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs import FeatureType
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc.relative import reanchor_relative_rtc_prefix
|
||||
from lerobot.policies.utils import populate_queues, prepare_observation_for_inference
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
from . import codec
|
||||
from .manifest import PolicyServerManifest
|
||||
from .scheduler import RoundRobinScheduler, Scheduler
|
||||
from .schema import (
|
||||
SCHEMA_VERSION,
|
||||
ActionChunkMsg,
|
||||
MsgHeader,
|
||||
ObservationMsg,
|
||||
ResetAckMsg,
|
||||
SessionAckMsg,
|
||||
SessionCloseMsg,
|
||||
SessionOpenMsg,
|
||||
StatusMsg,
|
||||
action_key,
|
||||
client_alive_key,
|
||||
client_alive_wildcard,
|
||||
client_uuid_from_key,
|
||||
obs_wildcard,
|
||||
reset_wildcard,
|
||||
server_alive_key,
|
||||
service_prefix,
|
||||
session_key,
|
||||
status_key,
|
||||
)
|
||||
from .session import Session, SessionRegistry
|
||||
from .validation import (
|
||||
PolicyClassification,
|
||||
classify_policy,
|
||||
resolve_serving_mode,
|
||||
validate_session_open,
|
||||
)
|
||||
from .zenoh_utils import action_publisher_qos, build_zenoh_config, import_zenoh
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
audit_logger = logging.getLogger("lerobot.policy_server.audit")
|
||||
|
||||
# Grace period after a client liveliness token drops before its session
|
||||
# is garbage-collected (rides out router blips and reconnects).
|
||||
_LIVELINESS_GC_GRACE_S = 5.0
|
||||
# Worker idle wait between work-event checks (also paces the GC sweep).
|
||||
_WORKER_IDLE_WAIT_S = 0.05
|
||||
|
||||
|
||||
def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int) -> torch.Tensor:
|
||||
"""Pad or truncate RTC prefix actions to a fixed length (mirrors rtc.py)."""
|
||||
if prev_actions.ndim != 2:
|
||||
raise ValueError(f"Expected 2D [T, A] tensor, got shape={tuple(prev_actions.shape)}")
|
||||
steps, action_dim = prev_actions.shape
|
||||
if steps == target_steps:
|
||||
return prev_actions
|
||||
if steps > target_steps:
|
||||
return prev_actions[:target_steps]
|
||||
padded = torch.zeros((target_steps, action_dim), dtype=prev_actions.dtype, device=prev_actions.device)
|
||||
padded[:steps] = prev_actions
|
||||
return padded
|
||||
|
||||
|
||||
class PolicyServer:
|
||||
"""Zenoh policy server: control-plane queryables + data-plane pub/sub + one GPU worker."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manifest: PolicyServerManifest,
|
||||
*,
|
||||
policy: PreTrainedPolicy | None = None,
|
||||
policy_cfg: PreTrainedConfig | None = None,
|
||||
processor_factory: Callable[[], tuple[Any, Any]] | None = None,
|
||||
classification: PolicyClassification | None = None,
|
||||
scheduler: Scheduler | None = None,
|
||||
) -> None:
|
||||
"""``policy``/``policy_cfg``/``processor_factory``/``classification``
|
||||
are injection points for tests; production loads everything from
|
||||
the manifest via :meth:`load_policy`.
|
||||
"""
|
||||
self._manifest = manifest
|
||||
self._device = torch.device(manifest.model.device)
|
||||
self._policy = policy
|
||||
self._policy_cfg = policy_cfg
|
||||
self._processor_factory = processor_factory
|
||||
self._classification = classification
|
||||
self._scheduler = scheduler or RoundRobinScheduler()
|
||||
|
||||
self._serving_mode: str = ""
|
||||
self._max_sessions: int = manifest.max_sessions
|
||||
self._rtc_active = False
|
||||
self._warmed_up = False
|
||||
|
||||
self.registry = SessionRegistry()
|
||||
self._registry_lock = threading.Lock() # serializes open/close/GC decisions
|
||||
# Serializes inference against episode resets: in exclusive mode a
|
||||
# reset (policy.reset(), pipeline reset) arriving on a queryable
|
||||
# thread mid-predict would corrupt the in-flight request's state.
|
||||
self._inference_lock = threading.Lock()
|
||||
|
||||
self._zenoh = None
|
||||
self._declarations: list[Any] = []
|
||||
self._alive_token = None
|
||||
|
||||
self._work = threading.Event()
|
||||
self._shutdown = threading.Event()
|
||||
self._worker: threading.Thread | None = None
|
||||
self._health_server: http.server.ThreadingHTTPServer | None = None
|
||||
|
||||
self._unknown_clients_warned: set[str] = set()
|
||||
self._capture_count = 0
|
||||
|
||||
self.metrics: dict[str, float] = {
|
||||
"requests_total": 0,
|
||||
"errors_total": 0,
|
||||
"superseded_total": 0,
|
||||
"dropped_unknown_client_total": 0,
|
||||
"sessions_opened_total": 0,
|
||||
"sessions_closed_total": 0,
|
||||
}
|
||||
self._metrics_lock = threading.Lock()
|
||||
|
||||
task_slug_source = manifest.service_name or manifest.default_task or "default"
|
||||
self.prefix = service_prefix(manifest.model.repo_or_path, manifest.model.revision, task_slug_source)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Loading & warmup
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def load_policy(self) -> None:
|
||||
"""Load config + weights, apply RTC settings, classify, warm up."""
|
||||
manifest = self._manifest
|
||||
if self._policy is None:
|
||||
logger.info(
|
||||
"Loading policy from '%s' (revision=%s)...",
|
||||
manifest.model.repo_or_path,
|
||||
manifest.model.revision,
|
||||
)
|
||||
policy_cfg = PreTrainedConfig.from_pretrained(manifest.model.repo_or_path)
|
||||
policy_cfg.pretrained_path = manifest.model.repo_or_path
|
||||
policy_class = get_policy_class(policy_cfg.type)
|
||||
policy = policy_class.from_pretrained(manifest.model.repo_or_path, config=policy_cfg)
|
||||
self._policy = policy
|
||||
self._policy_cfg = policy_cfg
|
||||
elif self._policy_cfg is None:
|
||||
self._policy_cfg = self._policy.config
|
||||
|
||||
if self._classification is None:
|
||||
self._classification = classify_policy(self._policy)
|
||||
logger.info("Policy classification: %s", self._classification.reason)
|
||||
|
||||
self._serving_mode, self._max_sessions = resolve_serving_mode(self._classification, manifest)
|
||||
logger.info("Serving mode: %s (max_sessions=%d)", self._serving_mode, self._max_sessions)
|
||||
|
||||
# RTC is a per-process decision: init_rtc_processor mutates the
|
||||
# shared policy instance.
|
||||
self._rtc_active = (
|
||||
manifest.rtc.enabled
|
||||
and self._classification.supports_rtc
|
||||
and hasattr(self._policy.config, "rtc_config")
|
||||
)
|
||||
if self._rtc_active:
|
||||
self._policy.config.rtc_config = manifest.rtc
|
||||
if hasattr(self._policy, "init_rtc_processor"):
|
||||
self._policy.init_rtc_processor()
|
||||
logger.info("RTC active (execution_horizon=%d)", manifest.rtc.execution_horizon)
|
||||
|
||||
if manifest.model.dtype:
|
||||
self._policy = self._policy.to(getattr(torch, manifest.model.dtype))
|
||||
self._policy = self._policy.to(self._device)
|
||||
self._policy.eval()
|
||||
|
||||
if not self.action_names:
|
||||
logger.warning(
|
||||
"Policy config has no action_feature_names: the action-order contract "
|
||||
"cannot be enforced at session open. Clients are trusted to match training order."
|
||||
)
|
||||
|
||||
if manifest.warmup_inferences > 0:
|
||||
self._warmup(manifest.warmup_inferences)
|
||||
self._warmed_up = True
|
||||
|
||||
def make_session_processors(self) -> tuple[Any, Any]:
|
||||
"""Build a fresh per-session pre/post pipeline pair.
|
||||
|
||||
The rename step is forced to identity: clients apply their
|
||||
rename map before encoding, so the wire format is canonical
|
||||
policy-feature keys across heterogeneous robots.
|
||||
"""
|
||||
if self._processor_factory is not None:
|
||||
return self._processor_factory()
|
||||
return make_pre_post_processors(
|
||||
policy_cfg=self._policy_cfg,
|
||||
pretrained_path=self._policy_cfg.pretrained_path,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": str(self._device)},
|
||||
"rename_observations_processor": {"rename_map": {}},
|
||||
},
|
||||
)
|
||||
|
||||
def _warmup(self, n: int) -> None:
|
||||
"""Run dummy forwards through the full request path (covers compile/caches)."""
|
||||
logger.info("Warmup: %d inferences...", n)
|
||||
obs = self._synthetic_observation()
|
||||
preprocessor, postprocessor = self.make_session_processors()
|
||||
session = Session(
|
||||
session_id="warmup",
|
||||
client_uuid="warmup",
|
||||
task=self._manifest.default_task,
|
||||
robot_type="",
|
||||
rtc_enabled=self._rtc_active,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
reply = self.run_inference_request(session, MsgHeader(), obs)
|
||||
if self._rtc_active and reply.chunk_model is not None and n > 1:
|
||||
# Exercise the prefix-conditioned path too so its compile/cache
|
||||
# cost isn't paid by the first real RTC request.
|
||||
action_dim = reply.chunk_model.shape[-1]
|
||||
horizon = self._manifest.rtc.execution_horizon
|
||||
obs.prefix_model = np.zeros((horizon, action_dim), dtype=np.float32)
|
||||
obs.prefix_robot = np.zeros((horizon, action_dim), dtype=np.float32)
|
||||
obs.inference_delay_steps = 1
|
||||
for _ in range(n - 1):
|
||||
self.run_inference_request(session, MsgHeader(), obs)
|
||||
session.close()
|
||||
# Stateful policies must not carry warmup observations into real sessions.
|
||||
if self._serving_mode == "exclusive":
|
||||
self._policy.reset()
|
||||
logger.info("Warmup complete")
|
||||
|
||||
def _synthetic_observation(self) -> ObservationMsg:
|
||||
cfg = self._policy_cfg
|
||||
state_dim = self.state_dim or 1
|
||||
images = {}
|
||||
for key, feature in cfg.input_features.items():
|
||||
if feature.type == FeatureType.VISUAL:
|
||||
channels, height, width = feature.shape
|
||||
images[key] = np.zeros((height, width, channels), dtype=np.uint8)
|
||||
return ObservationMsg(
|
||||
state=np.zeros(state_dim, dtype=np.float32),
|
||||
images=images,
|
||||
task=self._manifest.default_task,
|
||||
jpeg_quality=0,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Capabilities
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def action_names(self) -> list[str]:
|
||||
names = getattr(self._policy_cfg, "action_feature_names", None)
|
||||
return list(names) if names else []
|
||||
|
||||
@property
|
||||
def state_dim(self) -> int:
|
||||
cfg = self._policy_cfg
|
||||
for key, feature in getattr(cfg, "input_features", {}).items():
|
||||
if key == OBS_STATE or feature.type == FeatureType.STATE:
|
||||
return int(feature.shape[0])
|
||||
return 0
|
||||
|
||||
@property
|
||||
def chunk_size(self) -> int:
|
||||
cfg = self._policy_cfg
|
||||
for attr in ("chunk_size", "n_action_steps", "horizon"):
|
||||
value = getattr(cfg, attr, None)
|
||||
if value:
|
||||
return int(value)
|
||||
return 0
|
||||
|
||||
def status_snapshot(self) -> StatusMsg:
|
||||
cfg = self._policy_cfg
|
||||
expected_cameras = [
|
||||
key
|
||||
for key, feature in getattr(cfg, "input_features", {}).items()
|
||||
if feature.type == FeatureType.VISUAL
|
||||
]
|
||||
return StatusMsg(
|
||||
model_repo=self._manifest.model.repo_or_path,
|
||||
model_revision=self._manifest.model.revision,
|
||||
policy_type=getattr(cfg, "type", "") or getattr(self._policy, "name", ""),
|
||||
action_names=self.action_names,
|
||||
expected_cameras=expected_cameras,
|
||||
state_dim=self.state_dim,
|
||||
chunk_size=self.chunk_size,
|
||||
trained_fps=self._manifest.trained_fps,
|
||||
supports_rtc=self._rtc_active,
|
||||
rtc_execution_horizon=self._manifest.rtc.execution_horizon if self._rtc_active else 0,
|
||||
serving_mode=self._serving_mode,
|
||||
warmed_up=self._warmed_up,
|
||||
active_sessions=len(self.registry),
|
||||
max_sessions=self._max_sessions,
|
||||
)
|
||||
|
||||
@property
|
||||
def server_load(self) -> float:
|
||||
return len(self.registry) / max(1, self._max_sessions)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# The per-request inference path (pure: no zenoh — parity-testable)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def run_inference_request(
|
||||
self, session: Session, header: MsgHeader, obs: ObservationMsg
|
||||
) -> ActionChunkMsg:
|
||||
"""Mirror of the local RTC loop's compute step (rtc.py), minus the queue merge."""
|
||||
t0 = time.perf_counter()
|
||||
|
||||
obs_np: dict[str, np.ndarray] = {}
|
||||
if obs.state is not None:
|
||||
obs_np[OBS_STATE] = np.asarray(obs.state, dtype=np.float32)
|
||||
for name, img in obs.images.items():
|
||||
obs_np[name] = img
|
||||
|
||||
task = obs.task or session.task or self._manifest.default_task
|
||||
batch = prepare_observation_for_inference(obs_np, self._device, task, session.robot_type)
|
||||
batch["task"] = [task]
|
||||
|
||||
preprocessed = session.preprocessor(batch)
|
||||
|
||||
use_rtc = self._rtc_active and session.rtc_enabled
|
||||
if use_rtc:
|
||||
delay = max(0, int(obs.inference_delay_steps))
|
||||
prev_actions: torch.Tensor | None = None
|
||||
if obs.prefix_model is not None and obs.prefix_model.size:
|
||||
prev_actions = torch.from_numpy(np.ascontiguousarray(obs.prefix_model)).to(self._device)
|
||||
|
||||
if prev_actions is not None and session.relative_step is not None:
|
||||
# Re-anchor the absolute leftover tail against the state
|
||||
# cached by THIS request's preprocess (mirrors rtc.py).
|
||||
raw_state = session.relative_step.get_cached_state()
|
||||
prefix_robot = obs.prefix_robot
|
||||
if raw_state is not None and prefix_robot is not None and prefix_robot.size:
|
||||
prev_actions = reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute=torch.from_numpy(np.ascontiguousarray(prefix_robot)),
|
||||
current_state=raw_state,
|
||||
relative_step=session.relative_step,
|
||||
normalizer_step=session.normalizer_step,
|
||||
policy_device=self._device,
|
||||
)
|
||||
|
||||
if prev_actions is not None:
|
||||
prev_actions = _normalize_prev_actions_length(
|
||||
prev_actions, target_steps=self._manifest.rtc.execution_horizon
|
||||
)
|
||||
|
||||
actions = self._policy.predict_action_chunk(
|
||||
preprocessed, inference_delay=delay, prev_chunk_left_over=prev_actions
|
||||
)
|
||||
else:
|
||||
if self._classification is not None and self._classification.needs_queue_population:
|
||||
preprocessed = self._populate_select_queues(preprocessed)
|
||||
actions = self._policy.predict_action_chunk(preprocessed)
|
||||
|
||||
original = actions.squeeze(0).clone()
|
||||
processed = session.postprocessor(actions).squeeze(0)
|
||||
inference_ms = (time.perf_counter() - t0) * 1e3
|
||||
|
||||
session.stats.requests += 1
|
||||
session.stats.last_inference_ms = inference_ms
|
||||
superseded = session.take_superseded()
|
||||
|
||||
return ActionChunkMsg(
|
||||
seq_id_echo=header.seq_id,
|
||||
client_mono_ns_echo=header.client_mono_ns,
|
||||
episode_id_echo=header.episode_id,
|
||||
chunk_model=original.detach().to("cpu", torch.float32).numpy(),
|
||||
chunk_robot=processed.detach().to("cpu", torch.float32).numpy(),
|
||||
inference_ms=inference_ms,
|
||||
superseded_seqs=superseded,
|
||||
server_load=self.server_load,
|
||||
)
|
||||
|
||||
def _populate_select_queues(self, batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Exclusive-mode shim for select_action-fed policies (diffusion family).
|
||||
|
||||
Mirrors ``DiffusionPolicy.select_action``: stack camera features
|
||||
into OBS_IMAGES, then populate the policy's observation queues so
|
||||
``predict_action_chunk`` sees the same history it would locally.
|
||||
"""
|
||||
policy = self._policy
|
||||
batch = {k: v for k, v in batch.items() if k != ACTION}
|
||||
if getattr(policy.config, "image_features", None):
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in policy.config.image_features], dim=-4)
|
||||
policy._queues = populate_queues(policy._queues, batch)
|
||||
return batch
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Zenoh wiring
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def start(self) -> None:
|
||||
"""Open zenoh, declare the service surface, start worker + health threads."""
|
||||
if self._policy is None or not self._warmed_up:
|
||||
self.load_policy()
|
||||
|
||||
zenoh = import_zenoh()
|
||||
spec = self._manifest.zenoh
|
||||
self._zenoh = zenoh.open(
|
||||
build_zenoh_config(
|
||||
mode=spec.mode,
|
||||
connect_endpoints=spec.connect_endpoints,
|
||||
listen_endpoints=spec.listen_endpoints,
|
||||
tls_root_ca_certificate=spec.tls_root_ca_certificate,
|
||||
tls_connect_certificate=spec.tls_connect_certificate,
|
||||
tls_connect_private_key=spec.tls_connect_private_key,
|
||||
extra_config_json5=spec.extra_config_json5,
|
||||
)
|
||||
)
|
||||
handlers = zenoh.handlers
|
||||
|
||||
# Data plane: wildcard subscriber, deposit-only callback.
|
||||
self._declarations.append(
|
||||
self._zenoh.declare_subscriber(obs_wildcard(self.prefix), handlers.Callback(self._on_obs))
|
||||
)
|
||||
# Control plane: queryables reply inline (low rate).
|
||||
self._declarations.append(
|
||||
self._zenoh.declare_queryable(status_key(self.prefix), handlers.Callback(self._on_status_query))
|
||||
)
|
||||
self._declarations.append(
|
||||
self._zenoh.declare_queryable(session_key(self.prefix), handlers.Callback(self._on_session_query))
|
||||
)
|
||||
self._declarations.append(
|
||||
self._zenoh.declare_queryable(
|
||||
reset_wildcard(self.prefix), handlers.Callback(self._on_reset_query)
|
||||
)
|
||||
)
|
||||
# Presence: watch client tokens; publish our own.
|
||||
self._declarations.append(
|
||||
self._zenoh.liveliness().declare_subscriber(
|
||||
client_alive_wildcard(self.prefix), handlers.Callback(self._on_liveliness), history=True
|
||||
)
|
||||
)
|
||||
self._alive_token = self._zenoh.liveliness().declare_token(server_alive_key(self.prefix))
|
||||
|
||||
self._shutdown.clear()
|
||||
self._worker = threading.Thread(target=self._worker_loop, daemon=True, name="PolicyServerWorker")
|
||||
self._worker.start()
|
||||
|
||||
if self._manifest.health_port:
|
||||
self._start_health_server(self._manifest.health_port)
|
||||
|
||||
logger.info(
|
||||
"Policy server up: prefix=%s mode=%s max_sessions=%d rtc=%s",
|
||||
self.prefix,
|
||||
self._serving_mode,
|
||||
self._max_sessions,
|
||||
self._rtc_active,
|
||||
)
|
||||
|
||||
def serve_forever(self) -> None:
|
||||
try:
|
||||
while not self._shutdown.is_set():
|
||||
self._shutdown.wait(timeout=0.5)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Interrupted — draining")
|
||||
finally:
|
||||
self.stop()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Drain: drop the liveliness token first (clients ride their buffers
|
||||
through the drain), finish the in-flight inference, then close."""
|
||||
if self._alive_token is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
self._alive_token.undeclare()
|
||||
self._alive_token = None
|
||||
|
||||
self._shutdown.set()
|
||||
self._work.set()
|
||||
if self._worker is not None and self._worker.is_alive():
|
||||
self._worker.join(timeout=10.0)
|
||||
if self._worker.is_alive():
|
||||
logger.warning("Inference worker did not join within 10s")
|
||||
self._worker = None
|
||||
|
||||
# Undeclare the control/data surface BEFORE closing sessions so a
|
||||
# late session open cannot be accepted by a server that has
|
||||
# already drained its worker.
|
||||
for declaration in self._declarations:
|
||||
with contextlib.suppress(Exception):
|
||||
declaration.undeclare()
|
||||
self._declarations.clear()
|
||||
|
||||
for session in self.registry.snapshot():
|
||||
self._close_session(session, reason="server shutdown")
|
||||
|
||||
if self._zenoh is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
self._zenoh.close()
|
||||
self._zenoh = None
|
||||
|
||||
if self._health_server is not None:
|
||||
self._health_server.shutdown()
|
||||
self._health_server = None
|
||||
logger.info("Policy server stopped")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Zenoh callbacks (deposit-only on the data plane)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _on_obs(self, sample: Any) -> None:
|
||||
try:
|
||||
attachment = sample.attachment
|
||||
if attachment is None:
|
||||
return
|
||||
header = MsgHeader.unpack(attachment.to_bytes())
|
||||
if header.schema_version != SCHEMA_VERSION:
|
||||
return
|
||||
client_uuid = client_uuid_from_key(str(sample.key_expr))
|
||||
session = self.registry.get(client_uuid)
|
||||
if session is None:
|
||||
self._bump("dropped_unknown_client_total")
|
||||
# Bounded: garbage publishers must not grow this set (or
|
||||
# the log) without limit.
|
||||
if (
|
||||
client_uuid not in self._unknown_clients_warned
|
||||
and len(self._unknown_clients_warned) < 256
|
||||
):
|
||||
self._unknown_clients_warned.add(client_uuid)
|
||||
logger.warning(
|
||||
"Observation from unknown client '%s' (no session) — dropping", client_uuid
|
||||
)
|
||||
return
|
||||
session.deposit(header, sample.payload.to_bytes())
|
||||
self._work.set()
|
||||
except Exception as e: # noqa: BLE001 — a malformed sample must never kill the subscriber
|
||||
logger.error("obs callback error: %s", e)
|
||||
|
||||
def _on_liveliness(self, sample: Any) -> None:
|
||||
try:
|
||||
import zenoh
|
||||
|
||||
client_uuid = client_uuid_from_key(str(sample.key_expr))
|
||||
session = self.registry.get(client_uuid)
|
||||
if session is None:
|
||||
return
|
||||
if sample.kind == zenoh.SampleKind.DELETE:
|
||||
session.alive = False
|
||||
session.token_dropped_mono = time.monotonic()
|
||||
logger.info(
|
||||
"Client '%s' liveliness dropped — GC in %.0fs", client_uuid, _LIVELINESS_GC_GRACE_S
|
||||
)
|
||||
else:
|
||||
session.alive = True
|
||||
session.token_dropped_mono = None
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("liveliness callback error: %s", e)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Control-plane queryables
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _on_status_query(self, query: Any) -> None:
|
||||
try:
|
||||
query.reply(status_key(self.prefix), codec.encode_status(self.status_snapshot()))
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("status query error: %s", e)
|
||||
|
||||
def _on_session_query(self, query: Any) -> None:
|
||||
try:
|
||||
payload = query.payload.to_bytes() if query.payload is not None else b""
|
||||
op = codec.decode_raw(payload).get("op", "open") if payload else "open"
|
||||
if op == "close":
|
||||
self._handle_session_close(codec.decode_session_close(payload))
|
||||
query.reply(session_key(self.prefix), codec.encode_reset_ack(ResetAckMsg(ok=True)))
|
||||
return
|
||||
ack = self._handle_session_open(codec.decode_session_open(payload))
|
||||
query.reply(session_key(self.prefix), codec.encode_session_ack(ack))
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("session query error: %s\n%s", e, traceback.format_exc())
|
||||
with contextlib.suppress(Exception):
|
||||
query.reply(
|
||||
session_key(self.prefix),
|
||||
codec.encode_session_ack(SessionAckMsg(accepted=False, error=f"server error: {e}")),
|
||||
)
|
||||
|
||||
def _on_reset_query(self, query: Any) -> None:
|
||||
try:
|
||||
payload = query.payload.to_bytes() if query.payload is not None else b""
|
||||
msg = codec.decode_reset(payload)
|
||||
session = self.registry.get(msg.client_uuid)
|
||||
if session is None:
|
||||
ack = ResetAckMsg(ok=False, error=f"unknown client '{msg.client_uuid}'")
|
||||
else:
|
||||
# Serialize with the worker: resetting pipelines/policy
|
||||
# mid-predict would corrupt the in-flight request.
|
||||
with self._inference_lock:
|
||||
session.reset_episode(msg.episode_id)
|
||||
if self._serving_mode == "exclusive":
|
||||
# Safe: max_sessions=1, the policy belongs to this client.
|
||||
self._policy.reset()
|
||||
ack = ResetAckMsg(ok=True)
|
||||
query.reply(str(query.key_expr), codec.encode_reset_ack(ack))
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("reset query error: %s", e)
|
||||
|
||||
def _handle_session_open(self, msg: SessionOpenMsg) -> SessionAckMsg:
|
||||
capabilities = self.status_snapshot()
|
||||
with self._registry_lock:
|
||||
# A re-handshake from a known client replaces its session and
|
||||
# does not count against capacity.
|
||||
existing = self.registry.get(msg.client_uuid)
|
||||
active = len(self.registry) - (1 if existing else 0)
|
||||
result = validate_session_open(msg, capabilities, self._manifest, active)
|
||||
if not result.ok:
|
||||
logger.warning("Session rejected for '%s': %s", msg.client_uuid, result.error)
|
||||
return SessionAckMsg(accepted=False, error=result.error, server_load=self.server_load)
|
||||
|
||||
preprocessor, postprocessor = self.make_session_processors()
|
||||
session = Session(
|
||||
session_id=uuid_module.uuid4().hex,
|
||||
client_uuid=msg.client_uuid,
|
||||
task=msg.task or self._manifest.default_task,
|
||||
robot_type=msg.robot_type,
|
||||
rtc_enabled=msg.rtc_enabled and not result.rtc_downgraded,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
action_publisher=self._declare_action_publisher(msg.client_uuid),
|
||||
)
|
||||
if session.relative_step is not None and session.relative_step.action_names is None:
|
||||
session.relative_step.action_names = self.action_names or list(msg.action_names)
|
||||
# Sentinel: the FIRST observation of a fresh session always
|
||||
# triggers the episode-boundary branch in _serve_one, so a
|
||||
# mid-episode reconnect can never inherit stale state.
|
||||
session.episode_id = -1
|
||||
displaced = self.registry.add(session)
|
||||
if displaced is not None:
|
||||
displaced.close()
|
||||
self._bump("sessions_closed_total")
|
||||
logger.info("Client '%s' re-handshake: previous session replaced", msg.client_uuid)
|
||||
if self._serving_mode == "exclusive":
|
||||
# A new exclusive session must start from fresh policy state.
|
||||
with self._inference_lock:
|
||||
self._policy.reset()
|
||||
self._bump("sessions_opened_total")
|
||||
self._unknown_clients_warned.discard(msg.client_uuid)
|
||||
logger.info(
|
||||
"Session opened: client=%s session=%s task=%r rtc=%s (%d/%d)",
|
||||
msg.client_uuid,
|
||||
session.session_id,
|
||||
session.task,
|
||||
session.rtc_enabled,
|
||||
len(self.registry),
|
||||
self._max_sessions,
|
||||
)
|
||||
return SessionAckMsg(
|
||||
accepted=True,
|
||||
warnings=result.warnings,
|
||||
session_id=session.session_id,
|
||||
model_repo=capabilities.model_repo,
|
||||
model_revision=capabilities.model_revision,
|
||||
policy_type=capabilities.policy_type,
|
||||
action_names=capabilities.action_names,
|
||||
expected_cameras=capabilities.expected_cameras,
|
||||
state_dim=capabilities.state_dim,
|
||||
chunk_size=capabilities.chunk_size,
|
||||
trained_fps=capabilities.trained_fps,
|
||||
supports_rtc=capabilities.supports_rtc and session.rtc_enabled,
|
||||
rtc_execution_horizon=capabilities.rtc_execution_horizon,
|
||||
serving_mode=capabilities.serving_mode,
|
||||
warmed_up=capabilities.warmed_up,
|
||||
server_load=self.server_load,
|
||||
)
|
||||
|
||||
def _declare_action_publisher(self, client_uuid: str) -> Any:
|
||||
if self._zenoh is None: # pure-logic tests run without transport
|
||||
return None
|
||||
zenoh = import_zenoh()
|
||||
return self._zenoh.declare_publisher(
|
||||
action_key(self.prefix, client_uuid), **action_publisher_qos(zenoh)
|
||||
)
|
||||
|
||||
def _handle_session_close(self, msg: SessionCloseMsg) -> None:
|
||||
session = self.registry.get(msg.client_uuid)
|
||||
if session is not None and (not msg.session_id or msg.session_id == session.session_id):
|
||||
self._close_session(session, reason="client close")
|
||||
|
||||
def _close_session(self, session: Session, reason: str) -> None:
|
||||
# Identity-checked removal: never tear down a same-uuid session
|
||||
# that replaced this one via a re-handshake.
|
||||
removed = self.registry.remove(session.client_uuid, expected=session)
|
||||
if removed is not None:
|
||||
removed.close()
|
||||
self._bump("sessions_closed_total")
|
||||
logger.info("Session closed: client=%s (%s)", session.client_uuid, reason)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Inference worker
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _worker_loop(self) -> None:
|
||||
last_gc = time.monotonic()
|
||||
while not self._shutdown.is_set():
|
||||
ready = [s for s in self.registry.snapshot() if s.has_pending()]
|
||||
if not ready:
|
||||
self._work.wait(timeout=_WORKER_IDLE_WAIT_S)
|
||||
self._work.clear()
|
||||
else:
|
||||
for session in self._scheduler.select(ready):
|
||||
self._serve_one(session)
|
||||
|
||||
now = time.monotonic()
|
||||
if now - last_gc > 1.0:
|
||||
last_gc = now
|
||||
self._gc_sessions(now)
|
||||
|
||||
def _serve_one(self, session: Session) -> None:
|
||||
item = session.take()
|
||||
if item is None:
|
||||
return
|
||||
queue_wait_ms = (time.monotonic() - item.recv_mono) * 1e3
|
||||
outcome = "ok"
|
||||
try:
|
||||
obs = codec.decode_observation(item.payload)
|
||||
self._capture("req", item.payload)
|
||||
|
||||
with self._inference_lock:
|
||||
# Belt-and-braces episode ordering: the first observation of
|
||||
# an episode also announces the boundary (one-in-flight makes
|
||||
# the reset query race-free, but a lost ack must not desync
|
||||
# us; fresh sessions start at the -1 sentinel so their first
|
||||
# request always lands here).
|
||||
if obs.episode_start or item.header.episode_id != session.episode_id:
|
||||
session.preprocessor.reset()
|
||||
session.postprocessor.reset()
|
||||
session.episode_id = item.header.episode_id
|
||||
if self._serving_mode == "exclusive":
|
||||
self._policy.reset()
|
||||
|
||||
reply = self.run_inference_request(session, item.header, obs)
|
||||
reply.queue_wait_ms = queue_wait_ms
|
||||
session.stats.last_queue_wait_ms = queue_wait_ms
|
||||
|
||||
body = codec.encode_action_chunk(reply)
|
||||
self._capture("rep", body)
|
||||
# Local ref: a re-handshake can null session.action_publisher
|
||||
# between the check and the put.
|
||||
publisher = session.action_publisher
|
||||
if publisher is not None:
|
||||
reply_header = MsgHeader(
|
||||
schema_version=SCHEMA_VERSION,
|
||||
msg_type=2, # MSG_TYPE_CHUNK
|
||||
seq_id=item.header.seq_id,
|
||||
episode_id=item.header.episode_id,
|
||||
client_mono_ns=item.header.client_mono_ns,
|
||||
session_epoch=item.header.session_epoch,
|
||||
)
|
||||
publisher.put(body, attachment=reply_header.pack())
|
||||
self._bump("requests_total")
|
||||
self._bump("superseded_total", reply.superseded_seqs)
|
||||
except Exception as e: # noqa: BLE001 — one bad request must not kill the worker
|
||||
outcome = f"error: {e}"
|
||||
session.stats.errors += 1
|
||||
self._bump("errors_total")
|
||||
logger.error(
|
||||
"Inference error for client '%s' seq=%d: %s\n%s",
|
||||
session.client_uuid,
|
||||
item.header.seq_id,
|
||||
e,
|
||||
traceback.format_exc(),
|
||||
)
|
||||
finally:
|
||||
audit_logger.info(
|
||||
json.dumps(
|
||||
{
|
||||
"session_id": session.session_id,
|
||||
"client_uuid": session.client_uuid,
|
||||
"seq_id": item.header.seq_id,
|
||||
"episode_id": item.header.episode_id,
|
||||
"queue_wait_ms": round(queue_wait_ms, 3),
|
||||
"inference_ms": round(session.stats.last_inference_ms, 3),
|
||||
"superseded": session.stats.superseded,
|
||||
"outcome": outcome,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
def _gc_sessions(self, now: float) -> None:
|
||||
for session in self.registry.snapshot():
|
||||
if (
|
||||
session.token_dropped_mono is not None
|
||||
and now - session.token_dropped_mono > _LIVELINESS_GC_GRACE_S
|
||||
):
|
||||
if self._client_token_alive(session.client_uuid):
|
||||
# The DELETE was a late echo of a previous incarnation
|
||||
# (the token key is per client, not per epoch) — the
|
||||
# client re-declared and is alive.
|
||||
session.token_dropped_mono = None
|
||||
session.alive = True
|
||||
continue
|
||||
self._close_session(session, reason="liveliness token dropped")
|
||||
elif now - session.last_seen_mono > self._manifest.session_idle_timeout_s:
|
||||
self._close_session(session, reason="idle timeout")
|
||||
|
||||
def _client_token_alive(self, client_uuid: str) -> bool:
|
||||
"""Confirm a client's liveliness token via an explicit get (GC double-check)."""
|
||||
if self._zenoh is None:
|
||||
return False
|
||||
try:
|
||||
zenoh = import_zenoh()
|
||||
replies = self._zenoh.liveliness().get(
|
||||
client_alive_key(self.prefix, client_uuid),
|
||||
handler=zenoh.handlers.FifoChannel(4),
|
||||
timeout=0.5,
|
||||
)
|
||||
deadline = time.monotonic() + 1.0
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
reply = replies.try_recv()
|
||||
except Exception: # channel closed: no token found # noqa: BLE001
|
||||
return False
|
||||
if reply is None:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
if reply.ok is not None:
|
||||
return True
|
||||
return False
|
||||
except Exception: # noqa: BLE001 — treat transport trouble as "not alive"
|
||||
return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Misc
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _bump(self, key: str, amount: float = 1) -> None:
|
||||
with self._metrics_lock:
|
||||
self.metrics[key] = self.metrics.get(key, 0) + amount
|
||||
|
||||
def _capture(self, kind: str, data: bytes) -> None:
|
||||
capture_dir = self._manifest.debug.capture_dir
|
||||
if not capture_dir:
|
||||
return
|
||||
try:
|
||||
directory = Path(capture_dir)
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
index = self._capture_count % max(1, self._manifest.debug.capture_max)
|
||||
(directory / f"{kind}_{index:05d}.bin").write_bytes(data)
|
||||
if kind == "rep":
|
||||
self._capture_count += 1
|
||||
except OSError as e:
|
||||
logger.warning("debug capture failed: %s", e)
|
||||
|
||||
def _start_health_server(self, port: int) -> None:
|
||||
server_ref = self
|
||||
|
||||
class Handler(http.server.BaseHTTPRequestHandler):
|
||||
def do_GET(self) -> None: # noqa: N802 — http.server API
|
||||
if self.path == "/healthz":
|
||||
worker = server_ref._worker # local ref: stop() may null it mid-read
|
||||
healthy = worker is not None and worker.is_alive()
|
||||
self.send_response(200 if healthy else 503)
|
||||
self.end_headers()
|
||||
self.wfile.write(b"ok" if healthy else b"worker dead")
|
||||
elif self.path == "/metrics":
|
||||
with server_ref._metrics_lock:
|
||||
counters = dict(server_ref.metrics)
|
||||
counters["active_sessions"] = len(server_ref.registry)
|
||||
counters["server_load"] = server_ref.server_load
|
||||
body = "".join(
|
||||
f"lerobot_policy_server_{name} {value}\n" for name, value in sorted(counters.items())
|
||||
).encode()
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/plain; version=0.0.4")
|
||||
self.end_headers()
|
||||
self.wfile.write(body)
|
||||
else:
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
|
||||
def log_message(self, *args: Any) -> None: # silence per-request logging
|
||||
pass
|
||||
|
||||
self._health_server = http.server.ThreadingHTTPServer(("0.0.0.0", port), Handler) # nosec B104
|
||||
threading.Thread(target=self._health_server.serve_forever, daemon=True, name="HealthHTTP").start()
|
||||
logger.info("Health/metrics on :%d (/healthz, /metrics)", port)
|
||||
@@ -0,0 +1,203 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Per-client session state and the latest-only observation mailbox.
|
||||
|
||||
The server holds **no cross-request control state**: RTC prefixes and
|
||||
delay hints arrive with every observation. What a session does hold:
|
||||
|
||||
- Per-session processor pipeline instances. Mandatory:
|
||||
``RelativeActionsProcessorStep`` caches ``_last_state`` at preprocess
|
||||
and the postprocessor reads it back — a pipeline shared across clients
|
||||
would be a race.
|
||||
- A one-slot mailbox: the newest observation wins; superseded requests
|
||||
are counted so drops stay visible to the client.
|
||||
- Counters for the audit log and ``/metrics``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from lerobot.processor import (
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
RelativeActionsProcessorStep,
|
||||
)
|
||||
|
||||
from .schema import MsgHeader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MailboxItem:
|
||||
header: MsgHeader
|
||||
payload: bytes
|
||||
recv_mono: float # server-local monotonic deposit time (for queue_wait_ms)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionStats:
|
||||
requests: int = 0
|
||||
errors: int = 0
|
||||
superseded: int = 0 # observations overwritten before inference (lifetime)
|
||||
superseded_since_reply: int = 0
|
||||
last_inference_ms: float = 0.0
|
||||
last_queue_wait_ms: float = 0.0
|
||||
|
||||
|
||||
class Session:
|
||||
"""One connected robot client."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
client_uuid: str,
|
||||
task: str,
|
||||
robot_type: str,
|
||||
rtc_enabled: bool,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
action_publisher: Any = None, # zenoh.Publisher (Any: zenoh optional at import)
|
||||
) -> None:
|
||||
self.session_id = session_id
|
||||
self.client_uuid = client_uuid
|
||||
self.task = task
|
||||
self.robot_type = robot_type
|
||||
self.rtc_enabled = rtc_enabled
|
||||
self.preprocessor = preprocessor
|
||||
self.postprocessor = postprocessor
|
||||
self.action_publisher = action_publisher
|
||||
|
||||
self.episode_id = 0
|
||||
self.stats = SessionStats()
|
||||
self.alive = True
|
||||
self.last_seen_mono = time.monotonic()
|
||||
# Set when the client's liveliness token drops; GC after grace.
|
||||
self.token_dropped_mono: float | None = None
|
||||
|
||||
# Processor introspection for relative-action prefix re-anchoring
|
||||
# (mirrors RTCInferenceEngine.__init__).
|
||||
self.relative_step = next(
|
||||
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
|
||||
None,
|
||||
)
|
||||
self.normalizer_step = next(
|
||||
(s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)),
|
||||
None,
|
||||
)
|
||||
|
||||
self._mailbox: MailboxItem | None = None
|
||||
self._mailbox_lock = Lock()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Mailbox (deposit-only callbacks write, the inference worker reads)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def deposit(self, header: MsgHeader, payload: bytes) -> None:
|
||||
"""Latest-only deposit; counts superseded observations."""
|
||||
item = MailboxItem(header=header, payload=payload, recv_mono=time.monotonic())
|
||||
with self._mailbox_lock:
|
||||
if self._mailbox is not None:
|
||||
self.stats.superseded += 1
|
||||
self.stats.superseded_since_reply += 1
|
||||
self._mailbox = item
|
||||
self.alive = True
|
||||
self.token_dropped_mono = None
|
||||
self.last_seen_mono = item.recv_mono
|
||||
|
||||
def take(self) -> MailboxItem | None:
|
||||
with self._mailbox_lock:
|
||||
item, self._mailbox = self._mailbox, None
|
||||
return item
|
||||
|
||||
def take_superseded(self) -> int:
|
||||
"""Atomically read-and-reset the per-reply supersession counter."""
|
||||
with self._mailbox_lock:
|
||||
count = self.stats.superseded_since_reply
|
||||
self.stats.superseded_since_reply = 0
|
||||
return count
|
||||
|
||||
def has_pending(self) -> bool:
|
||||
with self._mailbox_lock:
|
||||
return self._mailbox is not None
|
||||
|
||||
def clear_mailbox(self) -> None:
|
||||
with self._mailbox_lock:
|
||||
self._mailbox = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Episode boundary
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def reset_episode(self, episode_id: int | None = None) -> None:
|
||||
"""Clear per-episode state. The shared policy is NOT touched here."""
|
||||
self.clear_mailbox()
|
||||
self.preprocessor.reset()
|
||||
self.postprocessor.reset()
|
||||
self.episode_id = episode_id if episode_id is not None else self.episode_id + 1
|
||||
|
||||
def close(self) -> None:
|
||||
self.clear_mailbox()
|
||||
publisher = self.action_publisher
|
||||
self.action_publisher = None
|
||||
if publisher is not None:
|
||||
# Already-closed transport is fine on teardown.
|
||||
with contextlib.suppress(Exception):
|
||||
publisher.undeclare()
|
||||
|
||||
|
||||
class SessionRegistry:
|
||||
"""Thread-safe map of client_uuid → Session."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._sessions: dict[str, Session] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def add(self, session: Session) -> Session | None:
|
||||
"""Register, returning a displaced same-client session (caller closes it)."""
|
||||
with self._lock:
|
||||
old = self._sessions.get(session.client_uuid)
|
||||
self._sessions[session.client_uuid] = session
|
||||
return old
|
||||
|
||||
def get(self, client_uuid: str) -> Session | None:
|
||||
with self._lock:
|
||||
return self._sessions.get(client_uuid)
|
||||
|
||||
def remove(self, client_uuid: str, expected: Session | None = None) -> Session | None:
|
||||
"""Remove by uuid; with ``expected``, only if it is still that exact session.
|
||||
|
||||
The identity check stops a GC sweep that snapshotted an old
|
||||
session from tearing down its just-handshaked replacement.
|
||||
"""
|
||||
with self._lock:
|
||||
current = self._sessions.get(client_uuid)
|
||||
if current is None or (expected is not None and current is not expected):
|
||||
return None
|
||||
return self._sessions.pop(client_uuid)
|
||||
|
||||
def snapshot(self) -> list[Session]:
|
||||
with self._lock:
|
||||
return list(self._sessions.values())
|
||||
|
||||
def __len__(self) -> int:
|
||||
with self._lock:
|
||||
return len(self._sessions)
|
||||
@@ -0,0 +1,265 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Serving-mode classification and session capability validation.
|
||||
|
||||
Multi-tenancy is engineered, not assumed: sharing one policy instance
|
||||
across sessions is only safe when ``predict_action_chunk`` touches no
|
||||
instance state. That property has been verified per policy family and
|
||||
is encoded here as an explicit registry — never inferred.
|
||||
|
||||
- ``act``/``pi0``/``pi05``: chunk-stateless (verified in-tree).
|
||||
- ``smolvla``: populates its ``_queues`` *inside* ``predict_action_chunk``;
|
||||
with ``n_obs_steps == 1`` the queue is overwritten with the request's
|
||||
own observation before being read, so sharing is safe. With history
|
||||
(``n_obs_steps > 1``) requests would read other sessions' frames →
|
||||
exclusive.
|
||||
- ``diffusion``: ``predict_action_chunk`` reads ``_queues`` that only
|
||||
``select_action`` populates → exclusive, with the server populating
|
||||
the observation queues per request (mirroring ``select_action``).
|
||||
- Policies without a ``predict_action_chunk`` override are refused.
|
||||
- Unverified chunk-API policies default to exclusive; ``shared`` cannot
|
||||
be forced for them (the roadmap upstreams a
|
||||
``supports_stateless_chunking`` attribute to policy classes).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
from .manifest import (
|
||||
SERVING_MODE_EXCLUSIVE,
|
||||
SERVING_MODE_SHARED,
|
||||
PolicyServerManifest,
|
||||
)
|
||||
from .schema import SessionOpenMsg, StatusMsg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ServingClass(Enum):
|
||||
SHARED = "shared"
|
||||
EXCLUSIVE = "exclusive"
|
||||
REFUSED = "refused"
|
||||
|
||||
|
||||
# Verified chunk-stateless families (predict_action_chunk touches no
|
||||
# cross-request instance state).
|
||||
VERIFIED_CHUNK_STATELESS: frozenset[str] = frozenset({"act", "pi0", "pi05"})
|
||||
|
||||
# Families whose predict_action_chunk reads select_action-fed queues:
|
||||
# the server must populate the observation queues per request.
|
||||
QUEUE_POPULATED_IN_SELECT: frozenset[str] = frozenset({"diffusion"})
|
||||
|
||||
# Families whose predict_action_chunk accepts the RTC kwargs
|
||||
# (inference_delay / prev_chunk_left_over) — see each family's
|
||||
# ActionSelectKwargs TypedDict.
|
||||
RTC_CAPABLE: frozenset[str] = frozenset({"pi0", "pi05", "smolvla"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyClassification:
|
||||
serving_class: ServingClass
|
||||
supports_rtc: bool
|
||||
needs_queue_population: bool
|
||||
reason: str
|
||||
|
||||
|
||||
def _has_chunk_api(policy: PreTrainedPolicy) -> bool:
|
||||
method = getattr(type(policy), "predict_action_chunk", None)
|
||||
return method is not None and method is not PreTrainedPolicy.predict_action_chunk
|
||||
|
||||
|
||||
def classify_policy(policy: PreTrainedPolicy) -> PolicyClassification:
|
||||
"""Classify a loaded policy into a serving class. Registry-driven, never inferred."""
|
||||
name = getattr(policy, "name", type(policy).__name__)
|
||||
|
||||
if not _has_chunk_api(policy):
|
||||
return PolicyClassification(
|
||||
ServingClass.REFUSED,
|
||||
supports_rtc=False,
|
||||
needs_queue_population=False,
|
||||
reason=f"policy '{name}' does not implement predict_action_chunk",
|
||||
)
|
||||
|
||||
supports_rtc = name in RTC_CAPABLE
|
||||
|
||||
if name in VERIFIED_CHUNK_STATELESS:
|
||||
return PolicyClassification(
|
||||
ServingClass.SHARED, supports_rtc, False, f"'{name}' is verified chunk-stateless"
|
||||
)
|
||||
|
||||
if name == "smolvla":
|
||||
n_obs_steps = getattr(policy.config, "n_obs_steps", 1)
|
||||
if n_obs_steps == 1:
|
||||
return PolicyClassification(
|
||||
ServingClass.SHARED,
|
||||
supports_rtc,
|
||||
False,
|
||||
"'smolvla' with n_obs_steps=1 overwrites its queues per request",
|
||||
)
|
||||
return PolicyClassification(
|
||||
ServingClass.EXCLUSIVE,
|
||||
supports_rtc,
|
||||
False,
|
||||
f"'smolvla' with n_obs_steps={n_obs_steps} keeps observation history across requests",
|
||||
)
|
||||
|
||||
if name in QUEUE_POPULATED_IN_SELECT:
|
||||
return PolicyClassification(
|
||||
ServingClass.EXCLUSIVE,
|
||||
supports_rtc,
|
||||
True,
|
||||
f"'{name}' predict_action_chunk reads select_action-fed queues",
|
||||
)
|
||||
|
||||
return PolicyClassification(
|
||||
ServingClass.EXCLUSIVE,
|
||||
supports_rtc,
|
||||
False,
|
||||
f"'{name}' has a chunk API but is not in the verified chunk-stateless registry",
|
||||
)
|
||||
|
||||
|
||||
def resolve_serving_mode(
|
||||
classification: PolicyClassification, manifest: PolicyServerManifest
|
||||
) -> tuple[str, int]:
|
||||
"""Resolve the final (serving_mode, max_sessions) from classification + manifest.
|
||||
|
||||
The manifest may force ``exclusive`` but can never force ``shared``
|
||||
for a policy that is not verified chunk-stateless.
|
||||
"""
|
||||
if classification.serving_class is ServingClass.REFUSED:
|
||||
raise ValueError(f"Refusing to serve this policy: {classification.reason}")
|
||||
|
||||
if manifest.serving_mode == SERVING_MODE_SHARED:
|
||||
if classification.serving_class is not ServingClass.SHARED:
|
||||
raise ValueError(
|
||||
f"serving_mode=shared is unsafe for this policy: {classification.reason}. "
|
||||
"Use serving_mode=exclusive (or auto)."
|
||||
)
|
||||
mode = SERVING_MODE_SHARED
|
||||
elif manifest.serving_mode == SERVING_MODE_EXCLUSIVE:
|
||||
mode = SERVING_MODE_EXCLUSIVE
|
||||
else: # auto
|
||||
mode = (
|
||||
SERVING_MODE_SHARED
|
||||
if classification.serving_class is ServingClass.SHARED
|
||||
else SERVING_MODE_EXCLUSIVE
|
||||
)
|
||||
|
||||
max_sessions = manifest.max_sessions
|
||||
if mode == SERVING_MODE_EXCLUSIVE and max_sessions != 1:
|
||||
logger.warning(
|
||||
"serving_mode=exclusive forces max_sessions=1 (manifest had %d)", manifest.max_sessions
|
||||
)
|
||||
max_sessions = 1
|
||||
return mode, max_sessions
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session-open validation (fail fast, fail loud)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
error: str = "" # non-empty → hard reject
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
# RTC requested but unsupported → downgrade to plain chunk-append.
|
||||
rtc_downgraded: bool = False
|
||||
|
||||
@property
|
||||
def ok(self) -> bool:
|
||||
return not self.error
|
||||
|
||||
|
||||
def validate_session_open(
|
||||
msg: SessionOpenMsg,
|
||||
capabilities: StatusMsg,
|
||||
manifest: PolicyServerManifest,
|
||||
active_sessions: int,
|
||||
) -> ValidationResult:
|
||||
"""Apply the capability matrix from the design doc (§8.4)."""
|
||||
result = ValidationResult()
|
||||
|
||||
# Schema version: client must be within the server's supported range.
|
||||
if not (capabilities.min_schema_version <= msg.schema_version <= capabilities.max_schema_version):
|
||||
result.error = (
|
||||
f"schema_version {msg.schema_version} outside supported range "
|
||||
f"[{capabilities.min_schema_version}, {capabilities.max_schema_version}]"
|
||||
)
|
||||
return result
|
||||
|
||||
# Capacity: reject with current load so the client can retry another replica.
|
||||
if active_sessions >= capabilities.max_sessions:
|
||||
result.error = f"server full: {active_sessions}/{capabilities.max_sessions} sessions active"
|
||||
return result
|
||||
|
||||
# Action names AND order: the hard sync-safety contract mapping
|
||||
# chunk columns to motors.
|
||||
if capabilities.action_names and msg.action_names != capabilities.action_names:
|
||||
result.error = (
|
||||
"action feature names/order mismatch — refusing to map chunk columns to motors.\n"
|
||||
f" server: {capabilities.action_names}\n"
|
||||
f" client: {msg.action_names}"
|
||||
)
|
||||
return result
|
||||
|
||||
# State dim.
|
||||
if capabilities.state_dim and msg.state_dim and msg.state_dim != capabilities.state_dim:
|
||||
result.error = f"state dim mismatch: server={capabilities.state_dim}, client={msg.state_dim}"
|
||||
return result
|
||||
|
||||
# Camera names: the client set must cover the policy's visual features.
|
||||
missing = set(capabilities.expected_cameras) - set(msg.camera_names)
|
||||
if missing:
|
||||
result.error = (
|
||||
f"missing camera features {sorted(missing)} "
|
||||
f"(client provides {sorted(msg.camera_names)}; resolution may differ — names may not)"
|
||||
)
|
||||
return result
|
||||
|
||||
# Task pinning.
|
||||
if manifest.pin_task and msg.task and msg.task != manifest.default_task:
|
||||
result.error = f"task is pinned to {manifest.default_task!r} on this server, got {msg.task!r}"
|
||||
return result
|
||||
|
||||
# fps: warn unless strict.
|
||||
if capabilities.trained_fps and msg.fps and abs(msg.fps - capabilities.trained_fps) > 1e-6:
|
||||
fps_msg = f"client fps={msg.fps:g} != trained fps={capabilities.trained_fps:g}"
|
||||
if manifest.strict_fps:
|
||||
result.error = fps_msg + " (strict_fps=true)"
|
||||
return result
|
||||
result.warnings.append(fps_msg)
|
||||
|
||||
# Policy type sanity (informational mismatch is a warning, not fatal:
|
||||
# the action/state/camera contracts above are the binding ones).
|
||||
if msg.policy_type and capabilities.policy_type and msg.policy_type != capabilities.policy_type:
|
||||
result.warnings.append(
|
||||
f"client expected policy_type={msg.policy_type!r}, server runs {capabilities.policy_type!r}"
|
||||
)
|
||||
|
||||
# RTC: requested but unsupported → serve plain chunks, client appends.
|
||||
if msg.rtc_enabled and not capabilities.supports_rtc:
|
||||
result.rtc_downgraded = True
|
||||
result.warnings.append(
|
||||
"RTC requested but this server/policy does not support it — downgrading to chunk-append"
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,101 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Zenoh session construction shared by the policy server and the remote engine.
|
||||
|
||||
Verified against eclipse-zenoh 1.9 (thread-based; no asyncio API).
|
||||
Multicast scouting is always disabled — fleet "discovery" is static
|
||||
endpoint configuration plus liveliness tokens, never protocol magic.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ZENOH_IMPORT_HINT = (
|
||||
"Remote inference requires the 'async' extra: pip install 'lerobot[async]' (eclipse-zenoh + msgpack)"
|
||||
)
|
||||
|
||||
|
||||
def import_zenoh():
|
||||
"""Import zenoh lazily with an actionable error message."""
|
||||
try:
|
||||
import zenoh
|
||||
except ImportError as e:
|
||||
raise ImportError(_ZENOH_IMPORT_HINT) from e
|
||||
return zenoh
|
||||
|
||||
|
||||
def build_zenoh_config(
|
||||
*,
|
||||
mode: str = "client",
|
||||
connect_endpoints: list[str] | None = None,
|
||||
listen_endpoints: list[str] | None = None,
|
||||
tls_root_ca_certificate: str | None = None,
|
||||
tls_connect_certificate: str | None = None,
|
||||
tls_connect_private_key: str | None = None,
|
||||
extra_config_json5: str | None = None,
|
||||
):
|
||||
"""Build a zenoh.Config (values are JSON5 strings — note the inner quoting)."""
|
||||
zenoh = import_zenoh()
|
||||
cfg = zenoh.Config()
|
||||
cfg.insert_json5("mode", json.dumps(mode))
|
||||
cfg.insert_json5("scouting/multicast/enabled", "false")
|
||||
if connect_endpoints:
|
||||
cfg.insert_json5("connect/endpoints", json.dumps(list(connect_endpoints)))
|
||||
if listen_endpoints:
|
||||
cfg.insert_json5("listen/endpoints", json.dumps(list(listen_endpoints)))
|
||||
if tls_root_ca_certificate:
|
||||
cfg.insert_json5("transport/link/tls/root_ca_certificate", json.dumps(tls_root_ca_certificate))
|
||||
if tls_connect_certificate:
|
||||
cfg.insert_json5("transport/link/tls/connect_certificate", json.dumps(tls_connect_certificate))
|
||||
if tls_connect_private_key:
|
||||
cfg.insert_json5("transport/link/tls/connect_private_key", json.dumps(tls_connect_private_key))
|
||||
if extra_config_json5:
|
||||
merged = json.loads(extra_config_json5)
|
||||
for key, value in merged.items():
|
||||
cfg.insert_json5(key, json.dumps(value))
|
||||
return cfg
|
||||
|
||||
|
||||
def action_publisher_qos(zenoh) -> dict:
|
||||
"""QoS for the action topic: RELIABLE + congestion DROP (never BLOCK) + express.
|
||||
|
||||
DROP so one dead robot uplink can never stall the server's publish
|
||||
path; a dropped chunk is recoverable by design — the client's action
|
||||
buffer keeps the robot moving and the next chunk replaces it.
|
||||
"""
|
||||
return {
|
||||
"reliability": zenoh.Reliability.RELIABLE,
|
||||
"congestion_control": zenoh.CongestionControl.DROP,
|
||||
"express": True,
|
||||
"priority": zenoh.Priority.INTERACTIVE_HIGH,
|
||||
}
|
||||
|
||||
|
||||
def obs_publisher_qos(zenoh) -> dict:
|
||||
"""QoS for the observation topic: best-effort drop, default priority.
|
||||
|
||||
Intentional drop already happened at the client's one-slot holder;
|
||||
if the uplink stalls, dropping a frame protects the control loop.
|
||||
"""
|
||||
return {
|
||||
"reliability": zenoh.Reliability.BEST_EFFORT,
|
||||
"congestion_control": zenoh.CongestionControl.DROP,
|
||||
"express": False,
|
||||
"priority": zenoh.Priority.DATA,
|
||||
}
|
||||
@@ -124,7 +124,6 @@ def make_reward_model(cfg: RewardModelConfig, **kwargs) -> PreTrainedRewardModel
|
||||
|
||||
if cfg.pretrained_path:
|
||||
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
|
||||
kwargs["revision"] = cfg.pretrained_revision
|
||||
reward_model = reward_cls.from_pretrained(**kwargs)
|
||||
else:
|
||||
reward_model = reward_cls(**kwargs)
|
||||
|
||||
@@ -39,8 +39,10 @@ from .context import (
|
||||
build_rollout_context,
|
||||
)
|
||||
from .inference import (
|
||||
FallbackMode,
|
||||
InferenceEngine,
|
||||
InferenceEngineConfig,
|
||||
RemoteInferenceConfig,
|
||||
RTCInferenceConfig,
|
||||
RTCInferenceEngine,
|
||||
SyncInferenceConfig,
|
||||
@@ -70,12 +72,14 @@ __all__ = [
|
||||
"HighlightStrategyConfig",
|
||||
"EpisodicStrategy",
|
||||
"EpisodicStrategyConfig",
|
||||
"FallbackMode",
|
||||
"InferenceEngine",
|
||||
"InferenceEngineConfig",
|
||||
"PolicyContext",
|
||||
"ProcessorContext",
|
||||
"RTCInferenceConfig",
|
||||
"RTCInferenceEngine",
|
||||
"RemoteInferenceConfig",
|
||||
"RolloutConfig",
|
||||
"RolloutContext",
|
||||
"RolloutStrategy",
|
||||
|
||||
@@ -51,6 +51,7 @@ from lerobot.utils.feature_utils import combine_feature_dicts, hw_to_dataset_fea
|
||||
from .configs import BaseStrategyConfig, DAggerStrategyConfig, RolloutConfig
|
||||
from .inference import (
|
||||
InferenceEngine,
|
||||
RemoteInferenceConfig,
|
||||
RTCInferenceConfig,
|
||||
SyncInferenceConfig,
|
||||
create_inference_engine,
|
||||
@@ -113,11 +114,17 @@ class HardwareContext:
|
||||
|
||||
@dataclass
|
||||
class PolicyContext:
|
||||
"""Loaded policy and its inference engine."""
|
||||
"""Loaded policy and its inference engine.
|
||||
|
||||
policy: PreTrainedPolicy
|
||||
preprocessor: PolicyProcessorPipeline
|
||||
postprocessor: PolicyProcessorPipeline
|
||||
``policy``/``preprocessor``/``postprocessor`` are ``None`` for the
|
||||
weightless remote backend (``--inference.type=remote``): inference
|
||||
runs on a ``lerobot-policy-server`` and strategies only ever consume
|
||||
``inference``.
|
||||
"""
|
||||
|
||||
policy: PreTrainedPolicy | None
|
||||
preprocessor: PolicyProcessorPipeline | None
|
||||
postprocessor: PolicyProcessorPipeline | None
|
||||
inference: InferenceEngine
|
||||
|
||||
|
||||
@@ -172,54 +179,66 @@ def build_rollout_context(
|
||||
fails fast without touching the robot.
|
||||
"""
|
||||
is_rtc = isinstance(cfg.inference, RTCInferenceConfig)
|
||||
is_remote = isinstance(cfg.inference, RemoteInferenceConfig)
|
||||
|
||||
# --- 1. Policy (heavy I/O, but no hardware yet) -------------------
|
||||
logger.info("Loading policy from '%s'...", cfg.policy.pretrained_path)
|
||||
# Remote inference keeps the edge weightless: the config-only
|
||||
# PreTrainedConfig (already loaded by RolloutConfig.__post_init__,
|
||||
# no weight download) is all the client needs for pre-flight
|
||||
# validation and action ordering.
|
||||
policy_config = cfg.policy
|
||||
policy_class = get_policy_class(policy_config.type)
|
||||
|
||||
if hasattr(policy_config, "compile_model"):
|
||||
policy_config.compile_model = cfg.use_torch_compile
|
||||
|
||||
if policy_config.type == "vqbet" and cfg.device == "mps":
|
||||
raise NotImplementedError(
|
||||
"Current implementation of VQBeT does not support `mps` backend. "
|
||||
"Please use `cpu` or `cuda` backend."
|
||||
policy = None
|
||||
if is_remote:
|
||||
logger.info(
|
||||
"Remote inference: weightless client for '%s' (no weights downloaded)",
|
||||
cfg.policy.pretrained_path,
|
||||
)
|
||||
|
||||
if policy_config.use_peft:
|
||||
from peft import PeftConfig, PeftModel
|
||||
|
||||
peft_path = policy_config.pretrained_path
|
||||
peft_config = PeftConfig.from_pretrained(peft_path)
|
||||
policy = policy_class.from_pretrained(
|
||||
pretrained_name_or_path=peft_config.base_model_name_or_path, config=policy_config
|
||||
)
|
||||
policy = PeftModel.from_pretrained(policy, peft_path, config=peft_config)
|
||||
else:
|
||||
policy = policy_class.from_pretrained(policy_config.pretrained_path, config=policy_config)
|
||||
logger.info("Loading policy from '%s'...", cfg.policy.pretrained_path)
|
||||
policy_class = get_policy_class(policy_config.type)
|
||||
|
||||
if is_rtc:
|
||||
policy.config.rtc_config = cfg.inference.rtc
|
||||
if hasattr(policy, "init_rtc_processor"):
|
||||
policy.init_rtc_processor()
|
||||
if hasattr(policy_config, "compile_model"):
|
||||
policy_config.compile_model = cfg.use_torch_compile
|
||||
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
logger.info("Policy loaded: type=%s, device=%s", policy_config.type, cfg.device)
|
||||
if policy_config.type == "vqbet" and cfg.device == "mps":
|
||||
raise NotImplementedError(
|
||||
"Current implementation of VQBeT does not support `mps` backend. "
|
||||
"Please use `cpu` or `cuda` backend."
|
||||
)
|
||||
|
||||
if cfg.use_torch_compile and policy.type not in ("pi0", "pi05"):
|
||||
try:
|
||||
if hasattr(torch, "compile"):
|
||||
compile_kwargs = {
|
||||
"backend": cfg.torch_compile_backend,
|
||||
"mode": cfg.torch_compile_mode,
|
||||
"options": {"triton.cudagraphs": False},
|
||||
}
|
||||
policy.predict_action_chunk = torch.compile(policy.predict_action_chunk, **compile_kwargs)
|
||||
logger.info("torch.compile applied to predict_action_chunk")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to apply torch.compile: %s", e)
|
||||
if policy_config.use_peft:
|
||||
from peft import PeftConfig, PeftModel
|
||||
|
||||
peft_path = policy_config.pretrained_path
|
||||
peft_config = PeftConfig.from_pretrained(peft_path)
|
||||
policy = policy_class.from_pretrained(
|
||||
pretrained_name_or_path=peft_config.base_model_name_or_path, config=policy_config
|
||||
)
|
||||
policy = PeftModel.from_pretrained(policy, peft_path, config=peft_config)
|
||||
else:
|
||||
policy = policy_class.from_pretrained(policy_config.pretrained_path, config=policy_config)
|
||||
|
||||
if is_rtc:
|
||||
policy.config.rtc_config = cfg.inference.rtc
|
||||
if hasattr(policy, "init_rtc_processor"):
|
||||
policy.init_rtc_processor()
|
||||
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
logger.info("Policy loaded: type=%s, device=%s", policy_config.type, cfg.device)
|
||||
|
||||
if cfg.use_torch_compile and policy.type not in ("pi0", "pi05"):
|
||||
try:
|
||||
if hasattr(torch, "compile"):
|
||||
compile_kwargs = {
|
||||
"backend": cfg.torch_compile_backend,
|
||||
"mode": cfg.torch_compile_mode,
|
||||
"options": {"triton.cudagraphs": False},
|
||||
}
|
||||
policy.predict_action_chunk = torch.compile(policy.predict_action_chunk, **compile_kwargs)
|
||||
logger.info("torch.compile applied to predict_action_chunk")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to apply torch.compile: %s", e)
|
||||
|
||||
# --- 2. Robot-side processors (user-supplied or defaults) --------
|
||||
if (
|
||||
@@ -378,31 +397,36 @@ def build_rollout_context(
|
||||
logger.info("Dataset ready: %s (%d existing episodes)", dataset.repo_id, dataset.num_episodes)
|
||||
|
||||
# --- 6. Policy pre/post processors (needs dataset stats if any) ---
|
||||
dataset_stats = None
|
||||
if dataset is not None:
|
||||
dataset_stats = rename_stats(
|
||||
dataset.meta.stats,
|
||||
cfg.rename_map,
|
||||
# Remote inference runs the policy processors server-side (per
|
||||
# session); the edge ships canonical dataset-format observations.
|
||||
preprocessor = None
|
||||
postprocessor = None
|
||||
if not is_remote:
|
||||
dataset_stats = None
|
||||
if dataset is not None:
|
||||
dataset_stats = rename_stats(
|
||||
dataset.meta.stats,
|
||||
cfg.rename_map,
|
||||
)
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy_config,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=dataset_stats,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy_config,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=dataset_stats,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
if isinstance(cfg.inference, SyncInferenceConfig) and any(
|
||||
isinstance(step, RelativeActionsProcessorStep) and step.enabled
|
||||
for step in getattr(preprocessor, "steps", ())
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"SyncInferenceEngine does not support policies with relative actions for now."
|
||||
"Use --inference.type=rtc or remove relative action processor steps from the policy pipeline."
|
||||
)
|
||||
if isinstance(cfg.inference, SyncInferenceConfig) and any(
|
||||
isinstance(step, RelativeActionsProcessorStep) and step.enabled
|
||||
for step in getattr(preprocessor, "steps", ())
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"SyncInferenceEngine does not support policies with relative actions for now."
|
||||
"Use --inference.type=rtc or remove relative action processor steps from the policy pipeline."
|
||||
)
|
||||
|
||||
# --- 7. Inference strategy (needs policy + pre/post + hardware) --
|
||||
logger.info(
|
||||
@@ -425,6 +449,8 @@ def build_rollout_context(
|
||||
use_torch_compile=cfg.use_torch_compile,
|
||||
compile_warmup_inferences=cfg.compile_warmup_inferences,
|
||||
shutdown_event=shutdown_event,
|
||||
policy_config=policy_config,
|
||||
rename_map=cfg.rename_map,
|
||||
)
|
||||
|
||||
# --- 8. Assemble ---------------------------------------------------
|
||||
|
||||
@@ -14,13 +14,18 @@
|
||||
|
||||
"""Inference engine package — backend-agnostic action production.
|
||||
|
||||
Concrete backends (``sync``, ``rtc``, ...) expose the same small interface so
|
||||
rollout strategies never branch on which backend is in use.
|
||||
Concrete backends (``sync``, ``rtc``, ``remote``, ...) expose the same
|
||||
small interface so rollout strategies never branch on which backend is
|
||||
in use.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .base import InferenceEngine
|
||||
from .factory import (
|
||||
FallbackMode,
|
||||
InferenceEngineConfig,
|
||||
RemoteInferenceConfig,
|
||||
RTCInferenceConfig,
|
||||
SyncInferenceConfig,
|
||||
create_inference_engine,
|
||||
@@ -29,11 +34,23 @@ from .rtc import RTCInferenceEngine
|
||||
from .sync import SyncInferenceEngine
|
||||
|
||||
__all__ = [
|
||||
"FallbackMode",
|
||||
"InferenceEngine",
|
||||
"InferenceEngineConfig",
|
||||
"RTCInferenceConfig",
|
||||
"RTCInferenceEngine",
|
||||
"RemoteInferenceConfig",
|
||||
"RemoteInferenceEngine",
|
||||
"SyncInferenceConfig",
|
||||
"SyncInferenceEngine",
|
||||
"create_inference_engine",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
# Lazy: RemoteInferenceEngine pulls in msgpack/zenoh ('async' extra).
|
||||
if name == "RemoteInferenceEngine":
|
||||
from .remote import RemoteInferenceEngine
|
||||
|
||||
return RemoteInferenceEngine
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@@ -14,9 +14,9 @@
|
||||
|
||||
"""Inference engine configs and factory.
|
||||
|
||||
Selection is explicit via ``--inference.type=sync|rtc``. Adding a new
|
||||
backend requires registering its config subclass and dispatching it in
|
||||
:func:`create_inference_engine`.
|
||||
Selection is explicit via ``--inference.type=sync|rtc|remote``. Adding a
|
||||
new backend requires registering its config subclass and dispatching it
|
||||
in :func:`create_inference_engine`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -24,10 +24,12 @@ from __future__ import annotations
|
||||
import abc
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
from threading import Event
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
@@ -74,6 +76,73 @@ class RTCInferenceConfig(InferenceEngineConfig):
|
||||
queue_threshold: int = 30
|
||||
|
||||
|
||||
class FallbackMode(StrEnum):
|
||||
"""What ``get_action`` returns when the remote queue runs dry (STALLED)."""
|
||||
|
||||
HOLD = "hold" # return None: the robot holds its last commanded position
|
||||
REPEAT_LAST = "repeat_last" # re-send the last executed action
|
||||
ZERO = "zero" # explicit zero command (required for velocity-controlled robots)
|
||||
|
||||
|
||||
@InferenceEngineConfig.register_subclass("remote")
|
||||
@dataclass
|
||||
class RemoteInferenceConfig(InferenceEngineConfig):
|
||||
"""Network inference against a ``lerobot-policy-server`` over Zenoh.
|
||||
|
||||
The edge stays weightless: ``--policy.path`` resolves to a
|
||||
config-only ``PreTrainedConfig`` (no weight download) used for
|
||||
pre-flight validation and action ordering. Requires the ``async``
|
||||
extra (``pip install 'lerobot[async]'``).
|
||||
"""
|
||||
|
||||
# Transport: robots dial out to a zenoh router (NAT-friendly).
|
||||
connect_endpoint: str = "tcp/localhost:7447"
|
||||
# "client" via a zenohd router (production) | "peer" direct (LAN/tests).
|
||||
zenoh_mode: str = "client"
|
||||
tls_ca: str | None = None
|
||||
tls_cert: str | None = None
|
||||
tls_key: str | None = None
|
||||
|
||||
# Service addressing: which (model, revision, task) key tree to dial.
|
||||
# service_model_id defaults to --policy.path; service_task to the
|
||||
# rollout task. These must match the server manifest's namespace.
|
||||
service_model_id: str = ""
|
||||
service_revision: str = "main"
|
||||
service_task: str = ""
|
||||
|
||||
# Identity: "" → a fresh uuid4 per run. Set a stable ID per robot for
|
||||
# fleet-wide log correlation and per-robot router ACLs.
|
||||
client_uuid: str = ""
|
||||
|
||||
# Observation encoding: JPEG quality (0 = raw, LAN/debug only).
|
||||
jpeg_quality: int = 90
|
||||
|
||||
# Self-clocking: request the next chunk when the local queue holds
|
||||
# less than this many seconds of playback.
|
||||
buffer_time_s: float = 0.5
|
||||
|
||||
# Safety: never execute an action whose source observation is older
|
||||
# than this (bounds open-loop execution after a network stall).
|
||||
max_action_age_s: float = 3.0
|
||||
# Fallback when the queue runs dry (see FallbackMode).
|
||||
fallback: FallbackMode = FallbackMode.HOLD
|
||||
|
||||
# Watchdogs & reconnection.
|
||||
degraded_after_s: float = 1.0
|
||||
request_timeout_s: float = 5.0
|
||||
handshake_timeout_s: float = 2.0
|
||||
reconnect_initial_backoff_s: float = 0.5
|
||||
reconnect_max_backoff_s: float = 10.0
|
||||
max_offline_s: float = 60.0
|
||||
|
||||
# RTC settings (enabled → replace-merge with prefix conditioning when
|
||||
# the server supports it; otherwise downgraded to chunk-append).
|
||||
rtc: RTCConfig = field(default_factory=RTCConfig)
|
||||
|
||||
# Free-form labels forwarded in the session handshake (telemetry only).
|
||||
tags: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Factory
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -82,9 +151,9 @@ class RTCInferenceConfig(InferenceEngineConfig):
|
||||
def create_inference_engine(
|
||||
config: InferenceEngineConfig,
|
||||
*,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
policy: PreTrainedPolicy | None,
|
||||
preprocessor: PolicyProcessorPipeline | None,
|
||||
postprocessor: PolicyProcessorPipeline | None,
|
||||
robot_wrapper: ThreadSafeRobot,
|
||||
hw_features: dict,
|
||||
dataset_features: dict,
|
||||
@@ -95,10 +164,19 @@ def create_inference_engine(
|
||||
use_torch_compile: bool = False,
|
||||
compile_warmup_inferences: int = 2,
|
||||
shutdown_event: Event | None = None,
|
||||
policy_config: PreTrainedConfig | None = None,
|
||||
rename_map: dict[str, str] | None = None,
|
||||
) -> InferenceEngine:
|
||||
"""Instantiate the appropriate inference engine from a config object."""
|
||||
"""Instantiate the appropriate inference engine from a config object.
|
||||
|
||||
``policy``/``preprocessor``/``postprocessor`` are required for the
|
||||
local backends (``sync``, ``rtc``) and must be ``None``-free there;
|
||||
the ``remote`` backend is weightless and needs only ``policy_config``.
|
||||
"""
|
||||
logger.info("Creating inference engine: %s", config.type)
|
||||
if isinstance(config, SyncInferenceConfig):
|
||||
if policy is None or preprocessor is None or postprocessor is None:
|
||||
raise ValueError("sync inference requires a loaded policy and processors")
|
||||
return SyncInferenceEngine(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
@@ -110,6 +188,8 @@ def create_inference_engine(
|
||||
robot_type=robot_wrapper.robot_type,
|
||||
)
|
||||
if isinstance(config, RTCInferenceConfig):
|
||||
if policy is None or preprocessor is None or postprocessor is None:
|
||||
raise ValueError("rtc inference requires a loaded policy and processors")
|
||||
return RTCInferenceEngine(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
@@ -125,4 +205,25 @@ def create_inference_engine(
|
||||
rtc_queue_threshold=config.queue_threshold,
|
||||
shutdown_event=shutdown_event,
|
||||
)
|
||||
if isinstance(config, RemoteInferenceConfig):
|
||||
if policy_config is None:
|
||||
raise ValueError("remote inference requires policy_config (from config-only --policy.path)")
|
||||
if use_torch_compile:
|
||||
logger.warning("--use_torch_compile is ignored with remote inference (server-side concern)")
|
||||
if device not in (None, "cpu"):
|
||||
logger.warning("--device=%s is ignored with remote inference (server-side concern)", device)
|
||||
# Lazy import: eclipse-zenoh/msgpack live behind the 'async' extra.
|
||||
from .remote import RemoteInferenceEngine
|
||||
|
||||
return RemoteInferenceEngine(
|
||||
config=config,
|
||||
policy_config=policy_config,
|
||||
hw_features=hw_features,
|
||||
ordered_action_keys=ordered_action_keys,
|
||||
task=task,
|
||||
fps=fps,
|
||||
robot_type=robot_wrapper.robot_type,
|
||||
rename_map=rename_map,
|
||||
shutdown_event=shutdown_event,
|
||||
)
|
||||
raise ValueError(f"Unknown inference engine type: {type(config).__name__}")
|
||||
|
||||
@@ -0,0 +1,851 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Remote inference engine: network-decoupled policy inference over Zenoh.
|
||||
|
||||
The same architecture as :class:`RTCInferenceEngine` with the thread
|
||||
boundary replaced by a network boundary. The edge stays **weightless**
|
||||
(no policy weights, no policy processors); a ``lerobot-policy-server``
|
||||
runs the heavy half. All chunk state — leftover prefixes, latency
|
||||
tracking, delay computation — lives client-side in the existing
|
||||
``ActionQueue``/``LatencyTracker`` machinery, so the server is stateless
|
||||
per request and a server crash loses zero control state.
|
||||
|
||||
Threading model:
|
||||
- **Main thread** (strategy loop): ``notify_observation`` writes a
|
||||
latest-only slot; ``get_action`` pops the local queue and applies the
|
||||
staleness bound + fallback ladder. Never any I/O.
|
||||
- **Network worker** (one daemon thread): self-clocked by
|
||||
``buffer_time_s``, publishes one observation, awaits its chunk (or
|
||||
timeout), merges, repeats. One-in-flight is a *correctness*
|
||||
requirement: ``idx_before``/prefix snapshots must serialize with
|
||||
merges.
|
||||
- **Zenoh threads**: deposit-only callbacks (chunk → bounded queue,
|
||||
liveliness → event).
|
||||
|
||||
Clock iron rule: wall-clock instants never cross machines. The header's
|
||||
``client_mono_ns`` is opaque to the server and echoed back; the server
|
||||
reports only durations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import math
|
||||
import queue as queue_module
|
||||
import time
|
||||
import traceback
|
||||
import uuid as uuid_module
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.policies.rtc import ActionQueue, LatencyTracker
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policy_server import codec
|
||||
from lerobot.policy_server.schema import (
|
||||
MSG_TYPE_OBS,
|
||||
SCHEMA_VERSION,
|
||||
MsgHeader,
|
||||
ObservationMsg,
|
||||
ResetMsg,
|
||||
SessionAckMsg,
|
||||
SessionCloseMsg,
|
||||
SessionOpenMsg,
|
||||
action_key,
|
||||
client_alive_key,
|
||||
obs_key,
|
||||
reset_key,
|
||||
sanitize_key_segment,
|
||||
server_alive_key,
|
||||
service_prefix,
|
||||
session_key,
|
||||
status_key,
|
||||
)
|
||||
from lerobot.policy_server.zenoh_utils import build_zenoh_config, import_zenoh, obs_publisher_qos
|
||||
from lerobot.utils.constants import OBS_STATE, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
|
||||
from .base import InferenceEngine
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
from .factory import RemoteInferenceConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_IDLE_SLEEP_S = 0.01
|
||||
_MAX_CONSECUTIVE_WORKER_ERRORS = 10
|
||||
|
||||
|
||||
class ClientState:
|
||||
"""Fail-safe state machine states (see the design doc §9.2)."""
|
||||
|
||||
CONNECTING = "CONNECTING"
|
||||
STREAMING = "STREAMING"
|
||||
DEGRADED = "DEGRADED"
|
||||
STALLED = "STALLED"
|
||||
RECONNECTING = "RECONNECTING"
|
||||
DEAD = "DEAD"
|
||||
|
||||
|
||||
class RemoteInferenceEngine(InferenceEngine):
|
||||
"""``--inference.type=remote``: weightless edge client of a policy server."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: RemoteInferenceConfig,
|
||||
policy_config: PreTrainedConfig,
|
||||
hw_features: dict,
|
||||
ordered_action_keys: list[str],
|
||||
task: str,
|
||||
fps: float,
|
||||
robot_type: str,
|
||||
rename_map: dict[str, str] | None = None,
|
||||
shutdown_event: Event | None = None,
|
||||
) -> None:
|
||||
self._config = config
|
||||
self._policy_config = policy_config
|
||||
self._hw_features = hw_features
|
||||
self._ordered_action_keys = list(ordered_action_keys)
|
||||
self._task = task
|
||||
self._fps = float(fps)
|
||||
self._dt = 1.0 / self._fps
|
||||
self._robot_type = robot_type
|
||||
self._rename_map = dict(rename_map or {})
|
||||
self._global_shutdown_event = shutdown_event
|
||||
|
||||
self._client_uuid = sanitize_key_segment(config.client_uuid or uuid_module.uuid4().hex)
|
||||
model_id = config.service_model_id or getattr(policy_config, "pretrained_path", "") or "model"
|
||||
self._prefix = service_prefix(model_id, config.service_revision, config.service_task or task)
|
||||
|
||||
# Latest-only observation slot (identical to rtc.py's _obs_holder).
|
||||
self._obs_holder: dict[str, Any] = {"obs": None}
|
||||
self._obs_lock = Lock()
|
||||
|
||||
self._action_queue: ActionQueue | None = None
|
||||
self._latency_tracker = LatencyTracker()
|
||||
self._effective_rtc: RTCConfig = config.rtc
|
||||
|
||||
# Replies deposited by the zenoh callback, consumed by the worker.
|
||||
self._reply_queue: queue_module.Queue[tuple[MsgHeader, bytes]] = queue_module.Queue(maxsize=4)
|
||||
|
||||
self._zenoh = None
|
||||
self._obs_publisher = None
|
||||
self._declarations: list[Any] = []
|
||||
self._alive_token = None
|
||||
self._server_alive = Event()
|
||||
|
||||
self._worker: Thread | None = None
|
||||
self._stop_event = Event()
|
||||
self._active = Event()
|
||||
self._dead = Event()
|
||||
self._session_ack: SessionAckMsg | None = None
|
||||
|
||||
self.state = ClientState.CONNECTING
|
||||
self._state_lock = Lock()
|
||||
self._seq_id = 0
|
||||
self._epoch = 0
|
||||
self._episode_id = 0
|
||||
self._pending_reset = False
|
||||
|
||||
# Staleness bookkeeping: client-monotonic send time of the
|
||||
# observation that produced the current queue contents.
|
||||
# _anchor_lock serializes {merge + anchor update} (worker),
|
||||
# {staleness clear} (control thread), and {reset clear} so a
|
||||
# stale chunk can never merge into a freshly-reset queue and the
|
||||
# safety path can never clear a just-merged one.
|
||||
self._anchor_lock = Lock()
|
||||
self._chunk_anchor_mono: float | None = None
|
||||
self._last_chunk_mono: float | None = None
|
||||
self._offline_since_mono: float | None = None
|
||||
self._last_action: torch.Tensor | None = None
|
||||
|
||||
self.stats: dict[str, float] = {
|
||||
"requests": 0,
|
||||
"timeouts": 0,
|
||||
"merges": 0,
|
||||
"stale_drops": 0,
|
||||
"fallback_ticks": 0,
|
||||
"reconnects": 0,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def ready(self) -> bool:
|
||||
"""Session opened, capabilities validated, server warmed up."""
|
||||
ack = self._session_ack
|
||||
return ack is not None and ack.warmed_up and not self._dead.is_set()
|
||||
|
||||
@property
|
||||
def failed(self) -> bool:
|
||||
return self._dead.is_set()
|
||||
|
||||
@property
|
||||
def action_queue(self) -> ActionQueue | None:
|
||||
return self._action_queue
|
||||
|
||||
def start(self) -> None:
|
||||
"""Open transport, handshake, start the network worker.
|
||||
|
||||
Raises on initial connection/validation failure so a bad
|
||||
deployment aborts before the robot moves (reconnect logic only
|
||||
guards established sessions).
|
||||
"""
|
||||
zenoh = import_zenoh()
|
||||
cfg = self._config
|
||||
self._zenoh = zenoh.open(
|
||||
build_zenoh_config(
|
||||
mode=cfg.zenoh_mode,
|
||||
connect_endpoints=[cfg.connect_endpoint] if cfg.connect_endpoint else None,
|
||||
tls_root_ca_certificate=cfg.tls_ca,
|
||||
tls_connect_certificate=cfg.tls_cert,
|
||||
tls_connect_private_key=cfg.tls_key,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
ack = self._handshake(initial=True)
|
||||
except Exception:
|
||||
# Fail fast without leaking the transport session.
|
||||
with contextlib.suppress(Exception):
|
||||
self._zenoh.close()
|
||||
self._zenoh = None
|
||||
raise
|
||||
self._configure_from_ack(ack)
|
||||
|
||||
handlers = zenoh.handlers
|
||||
self._declarations.append(
|
||||
self._zenoh.declare_subscriber(
|
||||
action_key(self._prefix, self._client_uuid), handlers.Callback(self._on_chunk)
|
||||
)
|
||||
)
|
||||
self._obs_publisher = self._zenoh.declare_publisher(
|
||||
obs_key(self._prefix, self._client_uuid), **obs_publisher_qos(zenoh)
|
||||
)
|
||||
self._declarations.append(self._obs_publisher)
|
||||
self._server_alive.set()
|
||||
self._declarations.append(
|
||||
self._zenoh.liveliness().declare_subscriber(
|
||||
server_alive_key(self._prefix), handlers.Callback(self._on_server_liveliness), history=True
|
||||
)
|
||||
)
|
||||
self._alive_token = self._zenoh.liveliness().declare_token(
|
||||
client_alive_key(self._prefix, self._client_uuid)
|
||||
)
|
||||
|
||||
self._stop_event.clear()
|
||||
self._dead.clear()
|
||||
self._active.set()
|
||||
self._worker = Thread(target=self._worker_loop, daemon=True, name="RemoteInference")
|
||||
self._worker.start()
|
||||
logger.info(
|
||||
"Remote inference started: prefix=%s client=%s rtc=%s",
|
||||
self._prefix,
|
||||
self._client_uuid,
|
||||
self._effective_rtc.enabled,
|
||||
)
|
||||
|
||||
def stop(self) -> None:
|
||||
logger.info("Stopping remote inference engine...")
|
||||
self._stop_event.set()
|
||||
self._active.clear()
|
||||
if self._worker is not None and self._worker.is_alive():
|
||||
# Worst case the worker is mid-handshake inside _enter_reconnect.
|
||||
join_timeout = max(3.0, self._config.handshake_timeout_s + self._config.request_timeout_s + 2.0)
|
||||
self._worker.join(timeout=join_timeout)
|
||||
if self._worker.is_alive():
|
||||
logger.warning("Remote inference worker did not join")
|
||||
self._worker = None
|
||||
|
||||
if self._zenoh is not None:
|
||||
# Best-effort graceful close; the server also GCs on liveliness drop.
|
||||
with contextlib.suppress(Exception):
|
||||
self._control_query(
|
||||
session_key(self._prefix),
|
||||
codec.encode_session_close(
|
||||
SessionCloseMsg(
|
||||
client_uuid=self._client_uuid,
|
||||
session_id=self._session_ack.session_id if self._session_ack else "",
|
||||
)
|
||||
),
|
||||
timeout=1.0,
|
||||
)
|
||||
if self._alive_token is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
self._alive_token.undeclare()
|
||||
self._alive_token = None
|
||||
for declaration in self._declarations:
|
||||
with contextlib.suppress(Exception):
|
||||
declaration.undeclare()
|
||||
self._declarations.clear()
|
||||
self._obs_publisher = None
|
||||
with contextlib.suppress(Exception):
|
||||
self._zenoh.close()
|
||||
self._zenoh = None
|
||||
logger.info("Remote inference engine stopped")
|
||||
|
||||
def pause(self) -> None:
|
||||
"""Stop publishing observations; the local queue stays intact."""
|
||||
logger.info("Pausing remote inference (publishing stops, queue intact)")
|
||||
self._active.clear()
|
||||
|
||||
def resume(self) -> None:
|
||||
logger.info("Resuming remote inference")
|
||||
self._active.set()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Episode boundary: clear local chunk state, notify the server.
|
||||
|
||||
The acked reset query runs on the worker thread (never I/O on the
|
||||
control thread); thanks to per-request server statelessness a
|
||||
lost ack only costs a warning — the next observation announces
|
||||
the new episode in its header anyway.
|
||||
"""
|
||||
logger.info("Resetting remote inference state (queue + episode)")
|
||||
with self._anchor_lock:
|
||||
if self._action_queue is not None:
|
||||
self._action_queue.clear()
|
||||
self._chunk_anchor_mono = None
|
||||
with self._state_lock:
|
||||
self._episode_id += 1
|
||||
self._pending_reset = True
|
||||
with self._obs_lock:
|
||||
# The previous episode's final frame must not seed the new
|
||||
# episode's first request.
|
||||
self._obs_holder["obs"] = None
|
||||
self._last_action = None
|
||||
# LatencyTracker intentionally survives reset: latency is
|
||||
# episode-invariant (parity with local RTC).
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Action production (main thread — never any I/O)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def notify_observation(self, obs: dict) -> None:
|
||||
with self._obs_lock:
|
||||
self._obs_holder["obs"] = obs
|
||||
|
||||
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
|
||||
queue = self._action_queue
|
||||
if queue is None:
|
||||
return None
|
||||
|
||||
# Staleness bound (sync safety): never execute an action whose
|
||||
# source observation is older than max_action_age_s. The lock
|
||||
# makes the check-and-clear atomic with the worker's merge.
|
||||
with self._anchor_lock:
|
||||
anchor = self._chunk_anchor_mono
|
||||
if (
|
||||
anchor is not None
|
||||
and queue.qsize() > 0
|
||||
and time.monotonic() - anchor > self._config.max_action_age_s
|
||||
):
|
||||
logger.warning(
|
||||
"Dropping %d stale actions (older than %.1fs) — applying fallback",
|
||||
queue.qsize(),
|
||||
self._config.max_action_age_s,
|
||||
)
|
||||
self.stats["stale_drops"] += 1
|
||||
queue.clear()
|
||||
self._chunk_anchor_mono = None
|
||||
|
||||
action = queue.get()
|
||||
if action is not None:
|
||||
self._last_action = action
|
||||
return action
|
||||
|
||||
self._set_state(ClientState.STALLED if self.state == ClientState.DEGRADED else self.state)
|
||||
return self._fallback_action()
|
||||
|
||||
def _fallback_action(self) -> torch.Tensor | None:
|
||||
from .factory import FallbackMode
|
||||
|
||||
mode = self._config.fallback
|
||||
if mode == FallbackMode.REPEAT_LAST and self._last_action is not None:
|
||||
self.stats["fallback_ticks"] += 1
|
||||
return self._last_action.clone()
|
||||
if mode == FallbackMode.ZERO:
|
||||
# For velocity-controlled robots "send nothing" means "keep
|
||||
# last velocity" — an explicit zero command is the safe stop.
|
||||
self.stats["fallback_ticks"] += 1
|
||||
return torch.zeros(len(self._ordered_action_keys))
|
||||
return None # HOLD: send_next_action tolerates None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Handshake & control plane
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _handshake(self, initial: bool) -> SessionAckMsg:
|
||||
"""status (pre-flight) + session open; raises on rejection."""
|
||||
cfg = self._config
|
||||
status_data = self._control_query(status_key(self._prefix), b"", timeout=cfg.handshake_timeout_s)
|
||||
if status_data is None:
|
||||
raise ConnectionError(
|
||||
f"No policy server answered status query at {status_key(self._prefix)!r} "
|
||||
f"via {cfg.connect_endpoint!r} (timeout {cfg.handshake_timeout_s}s)"
|
||||
)
|
||||
status = codec.decode_status(status_data)
|
||||
logger.info(
|
||||
"Server status: model=%s@%s policy=%s sessions=%d/%d warmed_up=%s",
|
||||
status.model_repo,
|
||||
status.model_revision,
|
||||
status.policy_type,
|
||||
status.active_sessions,
|
||||
status.max_sessions,
|
||||
status.warmed_up,
|
||||
)
|
||||
|
||||
open_msg = SessionOpenMsg(
|
||||
client_uuid=self._client_uuid,
|
||||
robot_type=self._robot_type,
|
||||
policy_type=getattr(self._policy_config, "type", ""),
|
||||
fps=self._fps,
|
||||
action_names=self._ordered_action_keys,
|
||||
camera_names=self._wire_camera_names(),
|
||||
state_dim=self._state_dim(),
|
||||
schema_version=SCHEMA_VERSION,
|
||||
rtc_enabled=cfg.rtc.enabled,
|
||||
task=self._task,
|
||||
tags=cfg.tags,
|
||||
)
|
||||
ack_data = self._control_query(
|
||||
session_key(self._prefix), codec.encode_session_open(open_msg), timeout=cfg.request_timeout_s
|
||||
)
|
||||
if ack_data is None:
|
||||
raise ConnectionError("Session open query timed out")
|
||||
ack = codec.decode_session_ack(ack_data)
|
||||
if not ack.accepted:
|
||||
raise ConnectionError(f"Policy server rejected the session: {ack.error}")
|
||||
for warning in ack.warnings:
|
||||
logger.warning("Server warning: %s", warning)
|
||||
|
||||
# Hard sync-safety contract: chunk columns map to motors by order.
|
||||
if ack.action_names and ack.action_names != self._ordered_action_keys:
|
||||
raise ValueError(
|
||||
"Action name/order mismatch between server policy and this robot.\n"
|
||||
f" server: {ack.action_names}\n client: {self._ordered_action_keys}"
|
||||
)
|
||||
if not initial and self._session_ack is not None:
|
||||
previous = self._session_ack
|
||||
if (ack.model_repo, ack.model_revision) != (previous.model_repo, previous.model_revision):
|
||||
raise ValueError(
|
||||
f"Server model changed across reconnect "
|
||||
f"({previous.model_repo}@{previous.model_revision} → "
|
||||
f"{ack.model_repo}@{ack.model_revision}) — refusing to execute wrong-model chunks"
|
||||
)
|
||||
return ack
|
||||
|
||||
def _configure_from_ack(self, ack: SessionAckMsg) -> None:
|
||||
rtc_requested = self._config.rtc.enabled
|
||||
rtc_effective = rtc_requested and ack.supports_rtc
|
||||
if rtc_requested and not rtc_effective:
|
||||
logger.warning("RTC downgraded to chunk-append (server does not support RTC)")
|
||||
if self._action_queue is not None and self._action_queue.cfg.enabled != rtc_effective:
|
||||
# The queue's merge semantics (replace vs append) were fixed at
|
||||
# session start; a server whose RTC capability changed across a
|
||||
# reconnect would corrupt them.
|
||||
raise ValueError(
|
||||
"Server RTC capability changed across reconnect "
|
||||
f"(queue merge mode {'replace' if self._action_queue.cfg.enabled else 'append'} "
|
||||
f"vs server RTC={rtc_effective}) — refusing to continue"
|
||||
)
|
||||
self._effective_rtc = RTCConfig(
|
||||
enabled=rtc_effective,
|
||||
prefix_attention_schedule=self._config.rtc.prefix_attention_schedule,
|
||||
max_guidance_weight=self._config.rtc.max_guidance_weight,
|
||||
execution_horizon=ack.rtc_execution_horizon or self._config.rtc.execution_horizon,
|
||||
debug=self._config.rtc.debug,
|
||||
debug_maxlen=self._config.rtc.debug_maxlen,
|
||||
)
|
||||
if self._action_queue is None:
|
||||
self._action_queue = ActionQueue(self._effective_rtc)
|
||||
self._session_ack = ack
|
||||
|
||||
def _control_query(self, key: str, payload: bytes, timeout: float) -> bytes | None:
|
||||
"""One request/reply on the control plane; None on timeout/no-server."""
|
||||
zenoh = import_zenoh()
|
||||
try:
|
||||
replies = self._zenoh.get(
|
||||
key,
|
||||
handler=zenoh.handlers.FifoChannel(4),
|
||||
payload=payload,
|
||||
timeout=timeout,
|
||||
)
|
||||
deadline = time.monotonic() + timeout + 0.5
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
reply = replies.try_recv()
|
||||
except Exception: # zenoh.ZError: channel closed (no queryable / finished)
|
||||
return None
|
||||
if reply is None:
|
||||
time.sleep(0.005)
|
||||
continue
|
||||
if reply.ok is not None:
|
||||
return reply.ok.payload.to_bytes()
|
||||
return None # Reply.err (e.g. b"Timeout")
|
||||
return None
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning("Control query %s failed: %s", key, e)
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Zenoh callbacks (deposit-only)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _on_chunk(self, sample: Any) -> None:
|
||||
try:
|
||||
attachment = sample.attachment
|
||||
if attachment is None:
|
||||
return
|
||||
header = MsgHeader.unpack(attachment.to_bytes())
|
||||
item = (header, sample.payload.to_bytes())
|
||||
try:
|
||||
self._reply_queue.put_nowait(item)
|
||||
except queue_module.Full:
|
||||
# Drop oldest, keep newest.
|
||||
with contextlib.suppress(queue_module.Empty):
|
||||
self._reply_queue.get_nowait()
|
||||
with contextlib.suppress(queue_module.Full):
|
||||
self._reply_queue.put_nowait(item)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("chunk callback error: %s", e)
|
||||
|
||||
def _on_server_liveliness(self, sample: Any) -> None:
|
||||
try:
|
||||
import zenoh
|
||||
|
||||
if sample.kind == zenoh.SampleKind.DELETE:
|
||||
logger.warning("Server liveliness token dropped")
|
||||
self._server_alive.clear()
|
||||
else:
|
||||
self._server_alive.set()
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("liveliness callback error: %s", e)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Network worker
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _worker_loop(self) -> None:
|
||||
consecutive_errors = 0
|
||||
try:
|
||||
while not self._stop_event.is_set():
|
||||
if not self._active.is_set():
|
||||
time.sleep(_IDLE_SLEEP_S)
|
||||
continue
|
||||
try:
|
||||
self._maybe_send_reset()
|
||||
|
||||
if not self._server_alive.is_set():
|
||||
self._enter_reconnect("server liveliness dropped")
|
||||
continue
|
||||
|
||||
queue = self._action_queue
|
||||
if queue is not None and queue.qsize() * self._dt > self._config.buffer_time_s:
|
||||
time.sleep(_IDLE_SLEEP_S)
|
||||
continue
|
||||
|
||||
with self._obs_lock:
|
||||
obs = self._obs_holder.get("obs")
|
||||
if obs is None:
|
||||
time.sleep(_IDLE_SLEEP_S)
|
||||
continue
|
||||
|
||||
self._request_cycle(obs)
|
||||
consecutive_errors = 0
|
||||
except ConnectionError as e:
|
||||
# Raised by reconnect on hard contract violations.
|
||||
raise e
|
||||
except Exception as e: # noqa: BLE001 — transient worker errors retry
|
||||
consecutive_errors += 1
|
||||
logger.error(
|
||||
"Remote inference worker error (%d/%d): %s",
|
||||
consecutive_errors,
|
||||
_MAX_CONSECUTIVE_WORKER_ERRORS,
|
||||
e,
|
||||
)
|
||||
logger.debug(traceback.format_exc())
|
||||
if consecutive_errors >= _MAX_CONSECUTIVE_WORKER_ERRORS:
|
||||
raise
|
||||
time.sleep(0.5)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("Fatal error in remote inference worker: %s", e)
|
||||
logger.error(traceback.format_exc())
|
||||
self._go_dead(str(e))
|
||||
|
||||
def _request_cycle(self, obs: dict) -> None:
|
||||
"""Publish one observation and merge its chunk (one-in-flight)."""
|
||||
cfg = self._config
|
||||
queue = self._action_queue
|
||||
|
||||
obs_frame = build_dataset_frame(self._hw_features, obs, prefix=OBS_STR)
|
||||
if self._rename_map:
|
||||
obs_frame = {self._rename_map.get(k, k): v for k, v in obs_frame.items()}
|
||||
|
||||
state = obs_frame.pop(OBS_STATE, None)
|
||||
images = {k: v for k, v in obs_frame.items() if isinstance(v, np.ndarray) and v.ndim == 3}
|
||||
|
||||
with self._state_lock:
|
||||
self._seq_id += 1
|
||||
seq_id = self._seq_id
|
||||
episode_id = self._episode_id
|
||||
epoch = self._epoch
|
||||
|
||||
# Snapshot RTC state (must precede the publish; merge validates
|
||||
# against idx_before).
|
||||
idx_before = queue.get_action_index()
|
||||
prefix_model: np.ndarray | None = None
|
||||
prefix_robot: np.ndarray | None = None
|
||||
delay_steps = 0
|
||||
if self._effective_rtc.enabled:
|
||||
horizon = self._effective_rtc.execution_horizon
|
||||
left_over = queue.get_left_over()
|
||||
if left_over is not None and left_over.numel():
|
||||
prefix_model = left_over[:horizon].to(torch.float32).numpy()
|
||||
processed_left_over = queue.get_processed_left_over()
|
||||
if processed_left_over is not None and processed_left_over.numel():
|
||||
prefix_robot = processed_left_over[:horizon].to(torch.float32).numpy()
|
||||
max_latency = self._latency_tracker.max() if len(self._latency_tracker) else 0.0
|
||||
delay_steps = math.ceil(max_latency / self._dt) if max_latency else 0
|
||||
|
||||
# A reset/reconnect between the counter snapshot and the prefix
|
||||
# snapshot would pair a new episode id with old-episode prefixes
|
||||
# — skip the cycle instead.
|
||||
with self._state_lock:
|
||||
if (self._episode_id, self._epoch) != (episode_id, epoch):
|
||||
return
|
||||
|
||||
header = MsgHeader(
|
||||
schema_version=SCHEMA_VERSION,
|
||||
msg_type=MSG_TYPE_OBS,
|
||||
seq_id=seq_id,
|
||||
episode_id=episode_id,
|
||||
client_mono_ns=time.monotonic_ns(),
|
||||
session_epoch=epoch,
|
||||
)
|
||||
msg = ObservationMsg(
|
||||
state=state,
|
||||
images=images,
|
||||
task=self._task,
|
||||
inference_delay_steps=delay_steps,
|
||||
prefix_model=prefix_model,
|
||||
prefix_robot=prefix_robot,
|
||||
episode_start=(queue.qsize() == 0 and idx_before == 0 and self._chunk_anchor_mono is None),
|
||||
jpeg_quality=cfg.jpeg_quality,
|
||||
)
|
||||
|
||||
t_send = time.perf_counter()
|
||||
self._obs_publisher.put(codec.encode_observation(msg), attachment=header.pack())
|
||||
self.stats["requests"] += 1
|
||||
|
||||
reply = self._await_chunk(seq_id, episode_id, epoch, timeout=cfg.request_timeout_s)
|
||||
if reply is None:
|
||||
self.stats["timeouts"] += 1
|
||||
self._on_request_timeout()
|
||||
return
|
||||
|
||||
chunk = codec.decode_action_chunk(reply)
|
||||
if chunk.chunk_model is None or chunk.chunk_robot is None:
|
||||
# A persistently malformed server must still trip the
|
||||
# degradation ladder, not stall in nominal state.
|
||||
logger.warning("Chunk for seq=%d had empty tensors — dropping", seq_id)
|
||||
self.stats["timeouts"] += 1
|
||||
self._on_request_timeout()
|
||||
return
|
||||
|
||||
latency = time.perf_counter() - t_send
|
||||
real_delay = math.ceil(latency / self._dt)
|
||||
with self._anchor_lock:
|
||||
# reset() takes the same lock before clearing: either the
|
||||
# reset fully precedes this merge (episode changed → drop the
|
||||
# stale chunk) or the merge completes first (and the reset
|
||||
# then clears it) — a stale chunk can never survive a reset.
|
||||
with self._state_lock:
|
||||
if (self._episode_id, self._epoch) != (episode_id, epoch):
|
||||
logger.debug("Dropping chunk seq=%d: episode/epoch changed mid-flight", seq_id)
|
||||
return
|
||||
queue.merge(
|
||||
torch.from_numpy(np.ascontiguousarray(chunk.chunk_model)),
|
||||
torch.from_numpy(np.ascontiguousarray(chunk.chunk_robot)),
|
||||
real_delay,
|
||||
idx_before,
|
||||
)
|
||||
self._chunk_anchor_mono = time.monotonic() - latency # ≈ when the source obs was sent
|
||||
self._latency_tracker.add(latency)
|
||||
self._last_chunk_mono = time.monotonic()
|
||||
self._offline_since_mono = None
|
||||
self.stats["merges"] += 1
|
||||
self._set_state(ClientState.STREAMING)
|
||||
logger.debug(
|
||||
"merge: seq=%d latency=%.0fms delay=%d queue=%d server(inf=%.0fms wait=%.0fms load=%.2f)",
|
||||
seq_id,
|
||||
latency * 1e3,
|
||||
real_delay,
|
||||
queue.qsize(),
|
||||
chunk.inference_ms,
|
||||
chunk.queue_wait_ms,
|
||||
chunk.server_load,
|
||||
)
|
||||
|
||||
def _await_chunk(self, seq_id: int, episode_id: int, epoch: int, timeout: float) -> bytes | None:
|
||||
"""Wait for the chunk answering the latest outstanding request.
|
||||
|
||||
Stale replies (older seq/episode/epoch) are dropped — under
|
||||
one-in-flight a late chunk can only ever answer an older request.
|
||||
"""
|
||||
deadline = time.monotonic() + timeout
|
||||
while not self._stop_event.is_set():
|
||||
remaining = deadline - time.monotonic()
|
||||
if remaining <= 0:
|
||||
return None
|
||||
try:
|
||||
header, payload = self._reply_queue.get(timeout=min(remaining, 0.1))
|
||||
except queue_module.Empty:
|
||||
continue
|
||||
if header.session_epoch != epoch or header.episode_id != episode_id:
|
||||
continue # stale epoch/episode (reset or reconnect happened)
|
||||
if header.seq_id != seq_id:
|
||||
continue # late reply to a superseded request
|
||||
return payload
|
||||
return None
|
||||
|
||||
def _maybe_send_reset(self) -> None:
|
||||
with self._state_lock:
|
||||
pending, episode_id = self._pending_reset, self._episode_id
|
||||
self._pending_reset = False
|
||||
if pending and self._zenoh is not None:
|
||||
ack_data = self._control_query(
|
||||
reset_key(self._prefix, self._client_uuid),
|
||||
codec.encode_reset(ResetMsg(client_uuid=self._client_uuid, episode_id=episode_id)),
|
||||
timeout=1.0,
|
||||
)
|
||||
if ack_data is None:
|
||||
# Harmless: the server is stateless per request and the next
|
||||
# observation header announces the new episode anyway.
|
||||
logger.warning("Reset ack not received (continuing — header carries the episode bump)")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Degradation / reconnect / death
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _on_request_timeout(self) -> None:
|
||||
if self._stop_event.is_set():
|
||||
# _await_chunk aborted by a normal stop(), not by the network.
|
||||
return
|
||||
now = time.monotonic()
|
||||
if self._offline_since_mono is None:
|
||||
self._offline_since_mono = now
|
||||
offline_for = now - self._offline_since_mono
|
||||
|
||||
queue = self._action_queue
|
||||
if queue is not None and queue.qsize() > 0:
|
||||
self._set_state(ClientState.DEGRADED)
|
||||
else:
|
||||
self._set_state(ClientState.STALLED)
|
||||
|
||||
last = self._last_chunk_mono or 0.0
|
||||
if (now - last if last else offline_for) >= self._config.degraded_after_s:
|
||||
logger.warning(
|
||||
"No chunk for %.1fs (queue=%d) — %s",
|
||||
offline_for,
|
||||
queue.qsize() if queue else 0,
|
||||
self.state,
|
||||
)
|
||||
if offline_for > self._config.max_offline_s:
|
||||
self._go_dead(f"offline for {offline_for:.0f}s (> max_offline_s)")
|
||||
return
|
||||
if not self._server_alive.is_set() or offline_for >= 2 * self._config.request_timeout_s:
|
||||
self._enter_reconnect(f"request timeouts for {offline_for:.0f}s")
|
||||
|
||||
def _enter_reconnect(self, reason: str) -> None:
|
||||
"""Backoff + re-handshake loop. Hard contract violations → DEAD."""
|
||||
self._set_state(ClientState.RECONNECTING)
|
||||
logger.warning("Reconnecting: %s", reason)
|
||||
if self._offline_since_mono is None:
|
||||
self._offline_since_mono = time.monotonic()
|
||||
backoff = self._config.reconnect_initial_backoff_s
|
||||
while not self._stop_event.is_set():
|
||||
if not self._active.is_set():
|
||||
# Paused (e.g. DAgger human correction): keep trying to
|
||||
# reconnect, but a pause must never burn the offline budget
|
||||
# into a mid-correction shutdown.
|
||||
self._offline_since_mono = time.monotonic()
|
||||
offline_for = time.monotonic() - self._offline_since_mono
|
||||
if offline_for > self._config.max_offline_s:
|
||||
self._go_dead(f"offline for {offline_for:.0f}s (> max_offline_s)")
|
||||
return
|
||||
self._stop_event.wait(timeout=backoff)
|
||||
if self._stop_event.is_set():
|
||||
return
|
||||
backoff = min(backoff * 2, self._config.reconnect_max_backoff_s)
|
||||
try:
|
||||
with self._state_lock:
|
||||
self._epoch += 1
|
||||
ack = self._handshake(initial=False)
|
||||
self._configure_from_ack(ack)
|
||||
except ValueError as e:
|
||||
# Capability/schema/model mismatch: never execute wrong-model chunks.
|
||||
self._go_dead(str(e))
|
||||
return
|
||||
except Exception as e: # noqa: BLE001 — server still down, keep trying
|
||||
logger.info("Re-handshake failed (%s) — retrying in %.1fs", e, backoff)
|
||||
continue
|
||||
# A successful handshake is proof of life even if the liveliness
|
||||
# PUT was missed or hasn't been delivered yet.
|
||||
self._server_alive.set()
|
||||
# The offline budget is only reset by the next successful merge:
|
||||
# a server that handshakes but never delivers chunks must still
|
||||
# run out of budget and go DEAD.
|
||||
self.stats["reconnects"] += 1
|
||||
self._set_state(ClientState.STREAMING)
|
||||
logger.info("Reconnected (epoch=%d, session=%s)", self._epoch, ack.session_id)
|
||||
return
|
||||
|
||||
def _go_dead(self, reason: str) -> None:
|
||||
if self._dead.is_set():
|
||||
return
|
||||
logger.error("Remote inference DEAD: %s", reason)
|
||||
self._set_state(ClientState.DEAD)
|
||||
self._dead.set()
|
||||
if self._global_shutdown_event is not None:
|
||||
self._global_shutdown_event.set()
|
||||
|
||||
def _set_state(self, new_state: str) -> None:
|
||||
if new_state != self.state:
|
||||
logger.info("Client state: %s → %s", self.state, new_state)
|
||||
self.state = new_state
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Feature helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _wire_camera_names(self) -> list[str]:
|
||||
names = [
|
||||
key for key, feature in self._hw_features.items() if feature.get("dtype") in ("image", "video")
|
||||
]
|
||||
return [self._rename_map.get(name, name) for name in names]
|
||||
|
||||
def _state_dim(self) -> int:
|
||||
state_feature = self._hw_features.get(OBS_STATE)
|
||||
if state_feature and state_feature.get("names"):
|
||||
return len(state_feature["names"])
|
||||
return 0
|
||||
@@ -1,206 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``lerobot-annotate`` — populate ``language_persistent`` and
|
||||
``language_events`` columns on a LeRobot dataset.
|
||||
|
||||
Annotations live directly in ``data/chunk-*/file-*.parquet``.
|
||||
|
||||
Example:
|
||||
|
||||
uv run lerobot-annotate \\
|
||||
--root=/path/to/dataset \\
|
||||
--vlm.model_id=Qwen/Qwen2.5-VL-7B-Instruct
|
||||
|
||||
For distributed runs, see ``examples/annotations/run_hf_job.py``.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
|
||||
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||
from lerobot.annotations.steerable_pipeline.frames import make_frame_provider
|
||||
from lerobot.annotations.steerable_pipeline.modules import (
|
||||
GeneralVqaModule,
|
||||
InterjectionsAndSpeechModule,
|
||||
PlanSubtasksMemoryModule,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import make_vlm_client
|
||||
from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter
|
||||
from lerobot.configs import parser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_root(cfg: AnnotationPipelineConfig) -> Path:
|
||||
if cfg.root is not None:
|
||||
return Path(cfg.root)
|
||||
if cfg.repo_id is not None:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
return Path(snapshot_download(repo_id=cfg.repo_id, repo_type="dataset"))
|
||||
raise ValueError("Either --root or --repo_id must be provided.")
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def annotate(cfg: AnnotationPipelineConfig) -> None:
|
||||
"""Run the steerable annotation pipeline against a dataset."""
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
root = _resolve_root(cfg)
|
||||
logger.info("annotate: root=%s", root)
|
||||
|
||||
vlm = make_vlm_client(cfg.vlm)
|
||||
frame_provider = make_frame_provider(root, camera_key=cfg.vlm.camera_key, video_backend=cfg.video_backend)
|
||||
# Surface the resolved cameras up front so a silent vqa-module no-op
|
||||
# is obvious in job output rather than discovered post-hoc by counting
|
||||
# parquet rows.
|
||||
cam_keys = list(getattr(frame_provider, "camera_keys", []) or [])
|
||||
logger.info(
|
||||
"annotate: frame_provider default camera=%r, all cameras=%s",
|
||||
getattr(frame_provider, "camera_key", None),
|
||||
cam_keys,
|
||||
)
|
||||
if cfg.vqa.enabled and not cam_keys:
|
||||
logger.warning(
|
||||
"annotate: the vqa module is enabled but no cameras were "
|
||||
"resolved — it will produce zero VQA rows. Check "
|
||||
"meta/info.json for observation.images.* features, or pass "
|
||||
"--vlm.camera_key=<key> to seed the cameras list."
|
||||
)
|
||||
plan = PlanSubtasksMemoryModule(vlm=vlm, config=cfg.plan, frame_provider=frame_provider)
|
||||
interjections = InterjectionsAndSpeechModule(
|
||||
vlm=vlm, config=cfg.interjections, seed=cfg.seed, frame_provider=frame_provider
|
||||
)
|
||||
vqa = GeneralVqaModule(vlm=vlm, config=cfg.vqa, seed=cfg.seed, frame_provider=frame_provider)
|
||||
writer = LanguageColumnsWriter()
|
||||
validator = StagingValidator(
|
||||
dataset_camera_keys=tuple(getattr(frame_provider, "camera_keys", []) or []) or None,
|
||||
)
|
||||
|
||||
executor = Executor(
|
||||
config=cfg,
|
||||
plan=plan,
|
||||
interjections=interjections,
|
||||
vqa=vqa,
|
||||
writer=writer,
|
||||
validator=validator,
|
||||
)
|
||||
summary = executor.run(root)
|
||||
logger.info("annotate: wrote %d shard(s)", len(summary.written_paths))
|
||||
for phase in summary.phases:
|
||||
logger.info(
|
||||
"annotate: phase=%s processed=%d skipped=%d",
|
||||
phase.name,
|
||||
phase.episodes_processed,
|
||||
phase.episodes_skipped,
|
||||
)
|
||||
if summary.validation_report.warnings:
|
||||
for w in summary.validation_report.warnings:
|
||||
logger.warning(w)
|
||||
|
||||
if cfg.push_to_hub:
|
||||
if cfg.repo_id is None and cfg.new_repo_id is None:
|
||||
raise ValueError(
|
||||
"--push_to_hub requires --repo_id or --new_repo_id (the dataset repo to push to)."
|
||||
)
|
||||
_push_to_hub(root, cfg)
|
||||
|
||||
|
||||
def _push_to_hub(root: Path, cfg: AnnotationPipelineConfig) -> None:
|
||||
"""Upload the annotated dataset directory to the Hub.
|
||||
|
||||
Pushes to ``cfg.new_repo_id`` when set, otherwise back to ``cfg.repo_id``.
|
||||
"""
|
||||
from huggingface_hub import HfApi # noqa: PLC0415
|
||||
|
||||
repo_id = cfg.new_repo_id or cfg.repo_id
|
||||
commit_message = cfg.push_commit_message or "Add steerable annotations (lerobot-annotate)"
|
||||
api = HfApi()
|
||||
print(f"[lerobot-annotate] creating/locating dataset repo {repo_id}...", flush=True)
|
||||
api.create_repo(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
private=cfg.push_private,
|
||||
exist_ok=True,
|
||||
)
|
||||
print(f"[lerobot-annotate] uploading {root} -> {repo_id}...", flush=True)
|
||||
commit_info = api.upload_folder(
|
||||
folder_path=str(root),
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
commit_message=commit_message,
|
||||
ignore_patterns=[".annotate_staging/**", "**/.DS_Store"],
|
||||
)
|
||||
print(f"[lerobot-annotate] uploaded to https://huggingface.co/datasets/{repo_id}", flush=True)
|
||||
|
||||
# Tag the upload with the codebase version. ``LeRobotDatasetMetadata``
|
||||
# resolves the dataset revision via ``get_safe_version`` which scans
|
||||
# for tags like ``v3.0``; without a tag it raises
|
||||
# ``RevisionNotFoundError``. Read the version straight from the
|
||||
# dataset's own ``meta/info.json`` so we tag whatever the writer
|
||||
# actually wrote (no accidental drift if the codebase floor moves).
|
||||
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION # noqa: PLC0415
|
||||
|
||||
info_path = root / "meta" / "info.json"
|
||||
version_tag = CODEBASE_VERSION
|
||||
if info_path.exists():
|
||||
try:
|
||||
from lerobot.utils.io_utils import load_json # noqa: PLC0415
|
||||
|
||||
info = load_json(info_path)
|
||||
ds_version = info.get("codebase_version")
|
||||
if isinstance(ds_version, str) and ds_version.startswith("v"):
|
||||
version_tag = ds_version
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(
|
||||
f"[lerobot-annotate] could not read codebase_version from info.json ({exc}); falling back to {version_tag}",
|
||||
flush=True,
|
||||
)
|
||||
revision = getattr(commit_info, "oid", None)
|
||||
tag_kwargs = {
|
||||
"repo_id": repo_id,
|
||||
"tag": version_tag,
|
||||
"repo_type": "dataset",
|
||||
}
|
||||
if revision is not None:
|
||||
tag_kwargs["revision"] = revision
|
||||
|
||||
try:
|
||||
from contextlib import suppress # noqa: PLC0415
|
||||
|
||||
from huggingface_hub.errors import RevisionNotFoundError # noqa: PLC0415
|
||||
|
||||
with suppress(RevisionNotFoundError):
|
||||
api.delete_tag(repo_id, tag=version_tag, repo_type="dataset")
|
||||
api.create_tag(**tag_kwargs)
|
||||
print(f"[lerobot-annotate] tagged {repo_id} as {version_tag}", flush=True)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(
|
||||
f"[lerobot-annotate] WARNING: could not create tag {version_tag!r} on {repo_id}: {exc}. "
|
||||
"Dataset is uploaded but ``LeRobotDataset`` won't be able to load it until it's tagged. "
|
||||
"Run: from huggingface_hub import HfApi; "
|
||||
f"HfApi().create_tag({repo_id!r}, tag={version_tag!r}, repo_type='dataset', exist_ok=True)",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
annotate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -94,14 +94,6 @@ Merge multiple datasets from a list of local dataset paths:
|
||||
--operation.repo_ids "['pusht_train', 'pusht_val']" \
|
||||
--operation.roots "['/path/to/pusht_train', '/path/to/pusht_val']"
|
||||
|
||||
Merge multiple datasets while keeping one file per source file (no video/data stitching):
|
||||
lerobot-edit-dataset \
|
||||
--new_repo_id lerobot/pusht_merged \
|
||||
--operation.type merge \
|
||||
--operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']" \
|
||||
--operation.concatenate_videos false \
|
||||
--operation.concatenate_data false
|
||||
|
||||
Remove camera feature:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
@@ -265,9 +257,6 @@ class SplitConfig(OperationConfig):
|
||||
class MergeConfig(OperationConfig):
|
||||
repo_ids: list[str] | None = None
|
||||
roots: list[str] | None = None
|
||||
# When False, keep one file per source file instead of packing into shards.
|
||||
concatenate_videos: bool = True
|
||||
concatenate_data: bool = True
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("remove_feature")
|
||||
@@ -472,8 +461,6 @@ def handle_merge(cfg: EditDatasetConfig) -> None:
|
||||
datasets,
|
||||
output_repo_id=cfg.new_repo_id,
|
||||
output_dir=output_dir,
|
||||
concatenate_videos=cfg.operation.concatenate_videos,
|
||||
concatenate_data=cfg.operation.concatenate_data,
|
||||
)
|
||||
|
||||
logging.info(f"Merged dataset saved to {output_dir}")
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Serve a pretrained policy to remote ``lerobot-rollout`` clients over Zenoh.
|
||||
|
||||
One process = one pre-warmed (model, revision, dtype, device) on one GPU.
|
||||
Robots connect with ``lerobot-rollout --inference.type=remote``.
|
||||
|
||||
Usage examples
|
||||
--------------
|
||||
|
||||
Serve a model from a YAML manifest::
|
||||
|
||||
lerobot-policy-server --manifest server.yaml
|
||||
|
||||
Minimal manifest::
|
||||
|
||||
model:
|
||||
repo_or_path: lerobot/pi0_towels
|
||||
device: cuda
|
||||
default_task: "fold the towel"
|
||||
max_sessions: 5
|
||||
zenoh:
|
||||
mode: client
|
||||
connect_endpoints: ["tcp/router.gpu-cluster.internal:7447"]
|
||||
|
||||
Everything in the manifest can also be set directly on the CLI::
|
||||
|
||||
lerobot-policy-server \\
|
||||
--model.repo_or_path=lerobot/pi0_towels \\
|
||||
--default_task="fold the towel" \\
|
||||
--zenoh.mode=peer --zenoh.listen_endpoints='["tcp/0.0.0.0:7447"]'
|
||||
|
||||
SIGTERM/SIGINT drains gracefully: the server drops its liveliness token
|
||||
(clients ride their action buffers through the drain), finishes the
|
||||
in-flight inference, and exits.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.policy_server.manifest import PolicyServerManifest
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def policy_server(manifest: PolicyServerManifest):
|
||||
init_logging()
|
||||
from lerobot.policy_server.server import PolicyServer
|
||||
|
||||
server = PolicyServer(manifest)
|
||||
server.load_policy()
|
||||
|
||||
def _drain(signum, frame): # noqa: ARG001
|
||||
logger.info("Signal %s received — draining", signum)
|
||||
server.stop()
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGTERM, _drain)
|
||||
server.start()
|
||||
server.serve_forever()
|
||||
|
||||
|
||||
def main():
|
||||
# `--manifest path.yaml` is sugar for draccus's `--config_path`.
|
||||
sys.argv = [
|
||||
arg.replace("--manifest=", "--config_path=") if arg.startswith("--manifest=") else arg
|
||||
for arg in sys.argv
|
||||
]
|
||||
if "--manifest" in sys.argv:
|
||||
sys.argv[sys.argv.index("--manifest")] = "--config_path"
|
||||
policy_server()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -31,6 +31,7 @@ Inference backends
|
||||
------------------
|
||||
--inference.type=sync One policy call per control tick (default)
|
||||
--inference.type=rtc Real-Time Chunking for slow VLA models
|
||||
--inference.type=remote Network inference via lerobot-policy-server (weightless edge)
|
||||
|
||||
Usage examples
|
||||
--------------
|
||||
@@ -145,6 +146,19 @@ Usage examples
|
||||
--dataset.camera_encoder.vcodec=h264 \\
|
||||
--dataset.camera_encoder.preset=fast \\
|
||||
--dataset.camera_encoder.extra_options={"tune": "film", "profile:v": "high", "bf": 2}
|
||||
|
||||
# Sentry mode — remote inference against a lerobot-policy-server (weightless edge)
|
||||
lerobot-rollout \\
|
||||
--strategy.type=sentry \\
|
||||
--strategy.upload_every_n_episodes=5 \\
|
||||
--policy.path=lerobot/pi0_base \\
|
||||
--inference.type=remote \\
|
||||
--inference.connect_endpoint=tcp/router.gpu-cluster.internal:7447 \\
|
||||
--inference.rtc.execution_horizon=10 \\
|
||||
--robot.type=so100_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--dataset.repo_id=user/rollout_sentry_data \\
|
||||
--dataset.single_task="patrol" --duration=3600
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
@@ -36,8 +36,6 @@ from tqdm import tqdm
|
||||
from lerobot.common.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_training_batch_size,
|
||||
load_training_num_processes,
|
||||
load_training_state,
|
||||
save_checkpoint,
|
||||
update_last_checkpoint,
|
||||
@@ -45,7 +43,7 @@ from lerobot.common.train_utils import (
|
||||
from lerobot.common.wandb_utils import WandBLogger
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state, make_dataset
|
||||
from lerobot.datasets import EpisodeAwareSampler, make_dataset
|
||||
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
@@ -239,17 +237,18 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
# Dataset loading synchronization: the global main process downloads once to the shared
|
||||
# dataset root, then a barrier lets every other rank read the already-populated copy.
|
||||
# LeRobotDataset skips its snapshot_download when try_load() succeeds, so no rank re-downloads.
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
# Dataset loading synchronization: each node's local main process downloads first to avoid
|
||||
# race conditions (the global main process only exists on node 0, so gating on it would let
|
||||
# all ranks of the other nodes download and build the Arrow cache concurrently).
|
||||
if accelerator.is_local_main_process:
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Other ranks read from the shared copy populated by the main process.
|
||||
if not is_main_process:
|
||||
# Now all other processes can safely load the dataset from the local cache
|
||||
if not accelerator.is_local_main_process:
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||
@@ -345,7 +344,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=processor_pretrained_path,
|
||||
pretrained_revision=getattr(cfg.policy, "pretrained_revision", None),
|
||||
**processor_kwargs,
|
||||
)
|
||||
|
||||
@@ -394,47 +392,22 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
# create dataloader for offline training
|
||||
if not cfg.dataset.streaming:
|
||||
# All non-streaming (map-style) datasets use EpisodeAwareSampler.
|
||||
# The order is a pure function of (seed, epoch), so every rank independently produces the
|
||||
# same permutation. accelerate then shards it disjointly across ranks via BatchSamplerShard
|
||||
# without needing a `generator` attribute to synchronize an RNG, and resume is sample-exact.
|
||||
if hasattr(active_cfg, "drop_n_last_frames"):
|
||||
shuffle = False
|
||||
# A dedicated generator (rather than the global torch RNG) lets accelerator.prepare
|
||||
# synchronize the shuffle permutation across ranks, keeping batch shards disjoint even
|
||||
# when ranks consume the global RNG asymmetrically (e.g. eval on the main process only).
|
||||
sampler_generator = torch.Generator()
|
||||
if cfg.seed is not None:
|
||||
sampler_generator.manual_seed(cfg.seed)
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0),
|
||||
drop_n_last_frames=active_cfg.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
seed=cfg.seed if cfg.seed is not None else 0,
|
||||
generator=sampler_generator,
|
||||
)
|
||||
if cfg.resume and step > 0:
|
||||
# The resume offset depends on the (num_processes, batch_size) that produced `step`, so
|
||||
# use the values recorded in the checkpoint (falling back to the current ones for older
|
||||
# ckpts that did not store them).
|
||||
saved_num_processes = load_training_num_processes(cfg.checkpoint_path)
|
||||
saved_batch_size = load_training_batch_size(cfg.checkpoint_path)
|
||||
ckpt_num_processes = saved_num_processes or accelerator.num_processes
|
||||
ckpt_batch_size = saved_batch_size or cfg.batch_size
|
||||
if is_main_process and saved_num_processes not in (None, accelerator.num_processes):
|
||||
logging.warning(
|
||||
f"Resuming with num_processes={accelerator.num_processes} but the checkpoint was "
|
||||
f"written with num_processes={saved_num_processes}. The data order resumes at the "
|
||||
"right epoch/offset, but per-rank sample-exactness requires the same world size."
|
||||
)
|
||||
if is_main_process and saved_batch_size not in (None, cfg.batch_size):
|
||||
logging.warning(
|
||||
f"Resuming with batch_size={cfg.batch_size} but the checkpoint was written with "
|
||||
f"batch_size={saved_batch_size}. The data order resumes at the right epoch/offset, "
|
||||
"but per-rank sample-exactness requires the same batch size."
|
||||
)
|
||||
sampler_state = compute_sampler_state(step, len(sampler), ckpt_batch_size, ckpt_num_processes)
|
||||
sampler.load_state_dict(sampler_state)
|
||||
if is_main_process:
|
||||
logging.info(
|
||||
f"Resuming data order at epoch {sampler_state['epoch']}, "
|
||||
f"sample {sampler_state['start_index']}"
|
||||
)
|
||||
else:
|
||||
shuffle = True
|
||||
sampler = None
|
||||
@@ -571,8 +544,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
scheduler=lr_scheduler,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
num_processes=accelerator.num_processes,
|
||||
batch_size=cfg.batch_size,
|
||||
)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
if wandb_logger:
|
||||
|
||||
@@ -13,213 +13,77 @@
|
||||
[SmolVLA](https://huggingface.co/papers/2506.01844) is a compact, efficient vision-language-action model that achieves competitive performance at reduced computational costs and can be deployed on consumer-grade hardware.
|
||||
{% elif model_name == "act" %}
|
||||
[Action Chunking with Transformers (ACT)](https://huggingface.co/papers/2304.13705) is an imitation-learning method that predicts short action chunks instead of single steps. It learns from teleoperated data and often achieves high success rates.
|
||||
{% elif model_name == "tdmpc" %}
|
||||
[TD-MPC](https://huggingface.co/papers/2203.04955) combines model-free and model-based approaches to improve sample efficiency and performance in continuous control tasks by using a learned latent dynamics model and terminal value function.
|
||||
{% elif model_name == "diffusion" %}
|
||||
[Diffusion Policy](https://huggingface.co/papers/2303.04137) treats visuomotor control as a generative diffusion process, producing smooth, multi-step action trajectories that excel at contact-rich manipulation.
|
||||
{% elif model_name == "vqbet" %}
|
||||
[VQ-BET](https://huggingface.co/papers/2403.03181) combines vector-quantised action tokens with Behaviour Transformers to discretise control and achieve data-efficient imitation across diverse skills.
|
||||
{% elif model_name == "pi0" %}
|
||||
[π₀ (Pi0)](https://www.physicalintelligence.company/blog/pi0) is a general-purpose robot foundation model from Physical Intelligence: a generalist Vision-Language-Action policy that understands visual inputs, interprets natural language instructions, and controls a variety of different robots across diverse tasks. The LeRobot implementation is adapted from their open-source OpenPI repository.
|
||||
**π₀ (Pi0)**
|
||||
|
||||
π₀ is a Vision-Language-Action model for general robot control, from Physical Intelligence. The LeRobot implementation is adapted from their open source OpenPI repository.
|
||||
|
||||
**Model Overview**
|
||||
|
||||
π₀ represents a breakthrough in robotics as the first general-purpose robot foundation model developed by Physical Intelligence. Unlike traditional robots that are narrow specialists programmed for repetitive motions, π₀ is designed to be a generalist policy that can understand visual inputs, interpret natural language instructions, and control a variety of different robots across diverse tasks.
|
||||
|
||||
For more details, see the [Physical Intelligence π₀ blog post](https://www.physicalintelligence.company/blog/pi0).
|
||||
{% elif model_name == "pi05" %}
|
||||
[π₀.₅ (Pi05)](https://www.physicalintelligence.company/blog/pi05) is a Vision-Language-Action model from Physical Intelligence designed for open-world generalization: it evolves π₀ to generalize to entirely new environments and situations that were never seen during training. The LeRobot implementation is adapted from their open-source OpenPI repository.
|
||||
{% elif model_name == "molmoact2" %}
|
||||
[MolmoAct2](https://allenai.org/blog/molmoact2) is an open robotics foundation model from the Allen Institute for AI (Ai2) that maps camera images and language instructions to robot action chunks. The LeRobot implementation supports training and evaluation of the regular MolmoAct2 model.
|
||||
{% elif model_name == "vla_jepa" %}
|
||||
[VLA-JEPA](https://arxiv.org/abs/2602.10098) is a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
|
||||
**π₀.₅ (Pi05) Policy**
|
||||
|
||||
π₀.₅ is a Vision-Language-Action model with open-world generalization, from Physical Intelligence. The LeRobot implementation is adapted from their open source OpenPI repository.
|
||||
|
||||
**Model Overview**
|
||||
|
||||
π₀.₅ represents a significant evolution from π₀, developed by Physical Intelligence to address a big challenge in robotics: open-world generalization. While robots can perform impressive tasks in controlled environments, π₀.₅ is designed to generalize to entirely new environments and situations that were never seen during training.
|
||||
|
||||
For more details, see the [Physical Intelligence π₀.₅ blog post](https://www.physicalintelligence.company/blog/pi05).
|
||||
{% elif model_name == "gaussian_actor" %}
|
||||
This is a Gaussian Actor policy (Gaussian policy with a tanh squash) — the policy-side component used by [Soft Actor-Critic (SAC)](https://huggingface.co/papers/1801.01290) and related maximum-entropy continuous-control algorithms.
|
||||
{% elif model_name == "pi0_fast" %}
|
||||
[π₀-FAST (Pi0-FAST)](https://www.physicalintelligence.company/research/fast) is a Vision-Language-Action model for general robot control, from Physical Intelligence. It models continuous robot actions with autoregressive next-token prediction using FAST (Frequency-space Action Sequence Tokenization), training up to 5x faster than diffusion-based π₀.
|
||||
{% elif model_name == "eo1" %}
|
||||
[EO-1](https://huggingface.co/papers/2508.21112) is a Vision-Language-Action model for general robot control. It pairs a Qwen2.5-VL backbone for vision-language understanding with a continuous flow-matching action head that denoises action chunks.
|
||||
{% elif model_name == "groot" %}
|
||||
[GR00T N1.5](https://github.com/NVIDIA/Isaac-GR00T) is an open, cross-embodiment foundation model from NVIDIA for generalized humanoid robot reasoning and skills. It takes language and images as input and uses a flow-matching action transformer to predict actions conditioned on vision, language, and proprioception.
|
||||
{% elif model_name == "multi_task_dit" %}
|
||||
[Multi-Task Diffusion Transformer (DiT)](https://huggingface.co/papers/2507.05331) extends Diffusion Policy with a large Diffusion Transformer and text + vision conditioning for multi-task robot learning. It supports both diffusion and flow-matching objectives and reaches high dexterity with only ~450M parameters.
|
||||
{% elif model_name == "wall_x" %}
|
||||
[WALL-OSS](https://huggingface.co/papers/2509.11766) is an open-source foundation model for embodied intelligence from XSquare Robot. Built on Qwen2.5-VL, it uses a tightly-coupled multimodal architecture with flow matching to unify semantic reasoning and high-frequency action generation for cross-embodiment control.
|
||||
{% elif model_name == "xvla" %}
|
||||
[X-VLA](https://huggingface.co/papers/2510.10274) is a soft-prompted, flow-matching Vision-Language-Action framework that treats each robot or hardware setup as a "task" encoded with a small set of learnable Soft Prompt embeddings, letting a single model reconcile diverse robot morphologies, sensors, and action spaces.
|
||||
{% else %}
|
||||
This is a **{{ model_name }}** policy trained with [LeRobot](https://github.com/huggingface/lerobot).
|
||||
_Model type not recognized — please update this template._
|
||||
{% endif %}
|
||||
{% set diagrams = {
|
||||
"smolvla": "https://cdn-uploads.huggingface.co/production/uploads/640e21ef3c82bd463ee5a76d/aooU0a3DMtYmy_1IWMaIM.png",
|
||||
"pi0": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-pi0%20(1).png",
|
||||
"pi0_fast": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-pifast.png",
|
||||
"eo1": "https://huggingface.co/datasets/HaomingSong/lerobot-documentation-images/resolve/main/lerobot/eo_pipeline.png",
|
||||
"groot": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-groot-paper1%20(1).png",
|
||||
"wall_x": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/walloss-lerobot-paper.png",
|
||||
"xvla": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture.png"
|
||||
} %}
|
||||
{% if diagrams.get(model_name) %}
|
||||
<p align="center">
|
||||
<img src="{{ diagrams[model_name] }}" alt="{{ model_name }} architecture" width="85%"/>
|
||||
</p>
|
||||
{% endif %}
|
||||
|
||||
<!-- A short demo is worth more than any description! Record a GIF/video of the policy
|
||||
running on your robot, upload it to this repo, and embed it here:
|
||||
<p align="center">
|
||||
<img src="https://huggingface.co/<hf_user>/<policy_repo_id>/resolve/main/demo.gif" width="60%"/>
|
||||
</p>
|
||||
-->
|
||||
|
||||
This policy has been trained and pushed to the Hub using [LeRobot](https://github.com/huggingface/lerobot).
|
||||
{% set policy_docs = {
|
||||
"act": "act",
|
||||
"smolvla": "smolvla",
|
||||
"pi0": "pi0",
|
||||
"pi0_fast": "pi0fast",
|
||||
"pi05": "pi05",
|
||||
"molmoact2": "molmoact2",
|
||||
"vla_jepa": "vla_jepa",
|
||||
"eo1": "eo1",
|
||||
"groot": "groot",
|
||||
"xvla": "xvla",
|
||||
"multi_task_dit": "multi_task_dit",
|
||||
"wall_x": "walloss"
|
||||
} %}
|
||||
{% if policy_docs.get(model_name) %}Learn how to train and run it in the [LeRobot {{ model_name }} guide](https://huggingface.co/docs/lerobot/main/en/{{ policy_docs[model_name] }}), or browse the [full documentation](https://huggingface.co/docs/lerobot/index).
|
||||
{% else %}See the [full LeRobot documentation](https://huggingface.co/docs/lerobot/index).
|
||||
{% endif %}
|
||||
See the full documentation at [LeRobot Docs](https://huggingface.co/docs/lerobot/index).
|
||||
|
||||
---
|
||||
|
||||
## How to Get Started with the Model
|
||||
|
||||
For a complete walkthrough, see the [training guide](https://huggingface.co/docs/lerobot/il_robots#train-a-policy).
|
||||
Below is the short version on how to train and run inference/eval:
|
||||
|
||||
### Train from scratch
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/<dataset> \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/<desired_policy_repo_id> \
|
||||
--job_name=lerobot_training \
|
||||
--policy.device=cuda \
|
||||
--policy.repo_id=${HF_USER}/<desired_policy_repo_id>
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
_Writes checkpoints to `outputs/train/<desired_policy_repo_id>/checkpoints/`._
|
||||
|
||||
### Evaluate the policy/run inference
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
--dataset.repo_id=<hf_user>/eval_<dataset> \
|
||||
--policy.path=<hf_user>/<desired_policy_repo_id> \
|
||||
--episodes=10
|
||||
```
|
||||
|
||||
Prefix the dataset repo with **eval\_** and supply `--policy.path` pointing to a local or hub checkpoint.
|
||||
|
||||
---
|
||||
|
||||
## Model Details
|
||||
|
||||
- **License:** {{ license | default("\[More Information Needed]", true) }}
|
||||
{% if base_model %}- **Fine-tuned from:** [{{ base_model }}](https://huggingface.co/{{ base_model }})
|
||||
{% endif %}{% if robot_type %}- **Robot type:** `{{ robot_type }}`
|
||||
{% endif %}{% if cameras %}- **Cameras:** {% for camera in cameras %}`{{ camera }}`{% if not loop.last %}, {% endif %}{% endfor %}
|
||||
{% endif %}
|
||||
{% if input_features or output_features %}
|
||||
## Inputs & Outputs
|
||||
|
||||
The policy consumes these observation features and produces these action features.
|
||||
{% if input_features %}
|
||||
**Inputs**
|
||||
|
||||
| Feature | Type | Shape |
|
||||
| --- | --- | --- |
|
||||
{% for name, feature in input_features.items() %}| `{{ name }}` | {{ feature.type.value }} | `{{ feature.shape }}` |
|
||||
{% endfor %}{% endif %}{% if output_features %}
|
||||
**Outputs**
|
||||
|
||||
| Feature | Type | Shape |
|
||||
| --- | --- | --- |
|
||||
{% for name, feature in output_features.items() %}| `{{ name }}` | {{ feature.type.value }} | `{{ feature.shape }}` |
|
||||
{% endfor %}{% endif %}{% endif %}
|
||||
{% if dataset %}
|
||||
## Training Dataset
|
||||
|
||||
- **Repository:** [{{ dataset.repo_id }}](https://huggingface.co/datasets/{{ dataset.repo_id }})
|
||||
- **Episodes:** {{ dataset.episodes }}
|
||||
- **Frames:** {{ dataset.frames }}
|
||||
- **Frame rate:** {{ dataset.fps }} FPS
|
||||
{% if dataset.tasks %}- **Task(s):** {% for task in dataset.tasks %}"{{ task }}"{% if not loop.last %}, {% endif %}{% endfor %}
|
||||
{% endif %}
|
||||
<a class="flex" href="https://huggingface.co/spaces/lerobot/visualize_dataset?path={{ dataset.repo_id }}">
|
||||
<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl.svg"/>
|
||||
<img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl-dark.svg"/>
|
||||
</a>
|
||||
{% endif %}
|
||||
{% if training %}
|
||||
## Training Configuration
|
||||
|
||||
| Setting | Value |
|
||||
| --- | --- |
|
||||
| Training steps | {{ training.steps }} |
|
||||
| Batch size | {{ training.batch_size }} |
|
||||
{% if training.optimizer %}| Optimizer | {{ training.optimizer }} |
|
||||
{% endif %}{% if training.lr %}| Learning rate | {{ training.lr }} |
|
||||
{% endif %}{% if training.seed is not none %}| Seed | {{ training.seed }} |
|
||||
{% endif %}| LeRobot version | {{ training.lerobot_version }} |
|
||||
{% endif %}
|
||||
---
|
||||
|
||||
## How to Get Started with the Model
|
||||
|
||||
New to LeRobot? These guides cover the full workflow:
|
||||
|
||||
- **[Install LeRobot](https://huggingface.co/docs/lerobot/main/en/installation)** — set up the `lerobot` package.
|
||||
- **[Hardware setup](https://huggingface.co/docs/lerobot/main/en/hardware_guide)** — assemble, wire, and calibrate your robot and cameras.
|
||||
- **[Record data & train a policy](https://huggingface.co/docs/lerobot/en/il_robots)** — the end-to-end imitation-learning walkthrough.
|
||||
- **[CLI cheat-sheet](https://huggingface.co/docs/lerobot/main/en/cheat-sheet)** — quick reference for the `lerobot-*` commands.
|
||||
|
||||
The short version to run and train this policy:
|
||||
|
||||
### Run the policy on your robot
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--robot.type={{ robot_type | default("<your_robot_type>", true) }} \
|
||||
--robot.port=<your_robot_port> \
|
||||
--robot.cameras="{ <camera_1>: {type: opencv, index_or_path: <index_or_path>, width: 640, height: 480, fps: 30}, <camera_2>: {type: opencv, index_or_path: <index_or_path>, width: 640, height: 480, fps: 30}}" \
|
||||
--policy.path={{ policy_repo_id | default("<hf_user>/<policy_repo_id>", true) }} \
|
||||
--task="{% if dataset and dataset.tasks %}{{ dataset.tasks[0] }}{% else %}<your_task_description>{% endif %}" \
|
||||
--duration=60
|
||||
```
|
||||
|
||||
Replace the remaining `<...>` placeholders with your own values: `--robot.port` and the camera names/indices are specific to your machine, and the camera names must match the observation keys this policy was trained on.
|
||||
|
||||
When `--strategy.type=base` is used the script doesn't record the episodes. Skipping duration will make the policy run indefinitely. For more information look at [rollout documentation](https://huggingface.co/docs/lerobot/main/en/inference).
|
||||
|
||||
{% if base_model %}### Train your own policy
|
||||
|
||||
This policy type is usually fine-tuned from the pretrained base model [{{ base_model }}](https://huggingface.co/{{ base_model }}):
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/<dataset> \
|
||||
--policy.path={{ base_model }} \
|
||||
--output_dir=outputs/train/<policy_repo_id> \
|
||||
--job_name=lerobot_training \
|
||||
--policy.device=cuda \
|
||||
--policy.repo_id=${HF_USER}/<policy_repo_id> \
|
||||
--wandb.enable=true
|
||||
```
|
||||
{% else %}### Train your own policy
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/<dataset> \
|
||||
--policy.type={{ model_name }} \
|
||||
--output_dir=outputs/train/<policy_repo_id> \
|
||||
--job_name=lerobot_training \
|
||||
--policy.device=cuda \
|
||||
--policy.repo_id=${HF_USER}/<policy_repo_id> \
|
||||
--wandb.enable=true
|
||||
```
|
||||
{% endif %}
|
||||
_Writes checkpoints to `outputs/train/<policy_repo_id>/checkpoints/`._
|
||||
|
||||
---
|
||||
|
||||
## Evaluation
|
||||
|
||||
<!-- Report real-robot results here: run the policy several times per task and count the
|
||||
successes. Delete the "No evaluation results" line and fill in this table instead:
|
||||
|
||||
| Task | Trials | Successes | Success rate |
|
||||
| ---- | ------ | --------- | ------------ |
|
||||
| pick the lego brick | 10 | 8 | 80% |
|
||||
|
||||
Also worth noting: anything that affects difficulty (new object positions, lighting,
|
||||
distractors, a different robot of the same type, ...).
|
||||
-->
|
||||
|
||||
_No evaluation results have been provided for this policy yet._
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this policy, please cite the method linked in the description above, along with LeRobot:
|
||||
|
||||
```bibtex
|
||||
@misc{cadene2024lerobot,
|
||||
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
|
||||
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
|
||||
howpublished = "\url{https://github.com/huggingface/lerobot}",
|
||||
year = {2024}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
gRPC transport layer for async inference.
|
||||
gRPC transport layer for the HIL-SERL RL stack (actor ↔ learner).
|
||||
|
||||
Requires: ``pip install 'lerobot[grpcio-dep]'``
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.python -m grpc_tools.protoc -I src --python_out=src --grpc_python_out=src src/lerobot/transport/services.proto
|
||||
// limitations under the License.
|
||||
|
||||
// To generate a classes for transport part (services_pb2.py and services_pb2_grpc.py) use the following command:
|
||||
//
|
||||
@@ -33,17 +33,6 @@ service LearnerService {
|
||||
rpc Ready(Empty) returns (Empty);
|
||||
}
|
||||
|
||||
// AsyncInference: from Robot perspective
|
||||
// Robot send observations to & executes action received from a remote Policy server
|
||||
service AsyncInference {
|
||||
// Robot -> Policy to share observations with a remote inference server
|
||||
// Policy -> Robot to share actions predicted for given observations
|
||||
rpc SendObservations(stream Observation) returns (Empty);
|
||||
rpc GetActions(Empty) returns (Actions);
|
||||
rpc SendPolicyInstructions(PolicySetup) returns (Empty);
|
||||
rpc Ready(Empty) returns (Empty);
|
||||
}
|
||||
|
||||
enum TransferState {
|
||||
TRANSFER_UNKNOWN = 0;
|
||||
TRANSFER_BEGIN = 1;
|
||||
@@ -67,21 +56,4 @@ message InteractionMessage {
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
// Messages
|
||||
message Observation {
|
||||
// sent by Robot, to remote Policy
|
||||
TransferState transfer_state = 1; // Observations can be streamed exceeding 4MB of size
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
message Actions {
|
||||
// sent by remote Policy, to Robot
|
||||
bytes data = 1;
|
||||
}
|
||||
|
||||
message PolicySetup {
|
||||
// sent by Robot to remote server, to init Policy
|
||||
bytes data = 1;
|
||||
}
|
||||
|
||||
message Empty {}
|
||||
|
||||
@@ -23,31 +23,23 @@ _sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"M\n\x0bObservation\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x17\n\x07\x41\x63tions\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x1b\n\x0bPolicySetup\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Empty2\xf5\x01\n\x0e\x41syncInference\x12>\n\x10SendObservations\x12\x16.transport.Observation\x1a\x10.transport.Empty(\x01\x12\x32\n\nGetActions\x12\x10.transport.Empty\x1a\x12.transport.Actions\x12\x42\n\x16SendPolicyInstructions\x12\x16.transport.PolicySetup\x1a\x10.transport.Empty\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'lerobot.transport.services_pb2', _globals)
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
DESCRIPTOR._loaded_options = None
|
||||
_globals['_TRANSFERSTATE']._serialized_start=431
|
||||
_globals['_TRANSFERSTATE']._serialized_end=527
|
||||
_globals['_TRANSFERSTATE']._serialized_start=298
|
||||
_globals['_TRANSFERSTATE']._serialized_end=394
|
||||
_globals['_TRANSITION']._serialized_start=47
|
||||
_globals['_TRANSITION']._serialized_end=123
|
||||
_globals['_PARAMETERS']._serialized_start=125
|
||||
_globals['_PARAMETERS']._serialized_end=201
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_start=203
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_end=287
|
||||
_globals['_OBSERVATION']._serialized_start=289
|
||||
_globals['_OBSERVATION']._serialized_end=366
|
||||
_globals['_ACTIONS']._serialized_start=368
|
||||
_globals['_ACTIONS']._serialized_end=391
|
||||
_globals['_POLICYSETUP']._serialized_start=393
|
||||
_globals['_POLICYSETUP']._serialized_end=420
|
||||
_globals['_EMPTY']._serialized_start=422
|
||||
_globals['_EMPTY']._serialized_end=429
|
||||
_globals['_LEARNERSERVICE']._serialized_start=530
|
||||
_globals['_LEARNERSERVICE']._serialized_end=787
|
||||
_globals['_ASYNCINFERENCE']._serialized_start=790
|
||||
_globals['_ASYNCINFERENCE']._serialized_end=1035
|
||||
_globals['_EMPTY']._serialized_start=289
|
||||
_globals['_EMPTY']._serialized_end=296
|
||||
_globals['_LEARNERSERVICE']._serialized_start=397
|
||||
_globals['_LEARNERSERVICE']._serialized_end=654
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
||||
@@ -231,212 +231,3 @@ class LearnerService:
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class AsyncInferenceStub:
|
||||
"""AsyncInference: from Robot perspective
|
||||
Robot send observations to & executes action received from a remote Policy server
|
||||
"""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.SendObservations = channel.stream_unary(
|
||||
'/transport.AsyncInference/SendObservations',
|
||||
request_serializer=lerobot_dot_transport_dot_services__pb2.Observation.SerializeToString,
|
||||
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.GetActions = channel.unary_unary(
|
||||
'/transport.AsyncInference/GetActions',
|
||||
request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
response_deserializer=lerobot_dot_transport_dot_services__pb2.Actions.FromString,
|
||||
_registered_method=True)
|
||||
self.SendPolicyInstructions = channel.unary_unary(
|
||||
'/transport.AsyncInference/SendPolicyInstructions',
|
||||
request_serializer=lerobot_dot_transport_dot_services__pb2.PolicySetup.SerializeToString,
|
||||
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.Ready = channel.unary_unary(
|
||||
'/transport.AsyncInference/Ready',
|
||||
request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class AsyncInferenceServicer:
|
||||
"""AsyncInference: from Robot perspective
|
||||
Robot send observations to & executes action received from a remote Policy server
|
||||
"""
|
||||
|
||||
def SendObservations(self, request_iterator, context):
|
||||
"""Robot -> Policy to share observations with a remote inference server
|
||||
Policy -> Robot to share actions predicted for given observations
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def GetActions(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendPolicyInstructions(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Ready(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_AsyncInferenceServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'SendObservations': grpc.stream_unary_rpc_method_handler(
|
||||
servicer.SendObservations,
|
||||
request_deserializer=lerobot_dot_transport_dot_services__pb2.Observation.FromString,
|
||||
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'GetActions': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.GetActions,
|
||||
request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
response_serializer=lerobot_dot_transport_dot_services__pb2.Actions.SerializeToString,
|
||||
),
|
||||
'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendPolicyInstructions,
|
||||
request_deserializer=lerobot_dot_transport_dot_services__pb2.PolicySetup.FromString,
|
||||
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'Ready': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Ready,
|
||||
request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'transport.AsyncInference', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
server.add_registered_method_handlers('transport.AsyncInference', rpc_method_handlers)
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class AsyncInference:
|
||||
"""AsyncInference: from Robot perspective
|
||||
Robot send observations to & executes action received from a remote Policy server
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def SendObservations(request_iterator,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.stream_unary(
|
||||
request_iterator,
|
||||
target,
|
||||
'/transport.AsyncInference/SendObservations',
|
||||
lerobot_dot_transport_dot_services__pb2.Observation.SerializeToString,
|
||||
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def GetActions(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/transport.AsyncInference/GetActions',
|
||||
lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
lerobot_dot_transport_dot_services__pb2.Actions.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SendPolicyInstructions(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/transport.AsyncInference/SendPolicyInstructions',
|
||||
lerobot_dot_transport_dot_services__pb2.PolicySetup.SerializeToString,
|
||||
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def Ready(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/transport.AsyncInference/Ready',
|
||||
lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Helpers shared across annotation-pipeline tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
|
||||
|
||||
|
||||
def make_canned_responder(
|
||||
responses_by_marker: dict[str, Any],
|
||||
default: Any = None,
|
||||
) -> StubVlmClient:
|
||||
"""Return a stub that picks a response by inspecting the user prompt.
|
||||
|
||||
For each call the responder examines the last user-message text and
|
||||
returns the response keyed by the first marker substring it contains.
|
||||
Falls back to ``default`` if no marker matches.
|
||||
"""
|
||||
|
||||
def responder(messages: list[dict[str, Any]]) -> Any:
|
||||
last_user_text = ""
|
||||
for message in messages:
|
||||
if message.get("role") != "user":
|
||||
continue
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
last_user_text = content
|
||||
elif isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
last_user_text = block.get("text", "")
|
||||
for marker, response in responses_by_marker.items():
|
||||
if marker in last_user_text:
|
||||
return response
|
||||
return default
|
||||
|
||||
return StubVlmClient(responder=responder)
|
||||
|
||||
|
||||
def encode_vqa_answer(payload: dict[str, Any]) -> str:
|
||||
return json.dumps(payload, sort_keys=True)
|
||||
@@ -1,58 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Shared fixtures for annotation-pipeline tests.
|
||||
|
||||
The on-disk dataset builder lives with the other dataset factories in
|
||||
``tests/fixtures/dataset_factories.py`` (:func:`build_annotation_dataset`);
|
||||
these fixtures only wire it into pytest.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# ``build_annotation_dataset`` pulls in ``lerobot.datasets`` (HF ``datasets``
|
||||
# + ``pandas``, only in the ``dataset`` extra), so it's imported lazily inside
|
||||
# each fixture — this conftest stays importable without that extra. The test
|
||||
# modules ``pytest.importorskip("datasets")`` so they skip rather than error.
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fixture_dataset_root(tmp_path: Path) -> Path:
|
||||
"""A tiny dataset with two episodes, 12 frames each at 10 fps."""
|
||||
from tests.fixtures.dataset_factories import build_annotation_dataset
|
||||
|
||||
return build_annotation_dataset(
|
||||
tmp_path / "ds",
|
||||
episode_specs=[
|
||||
(0, 12, "Could you tidy the kitchen please?"),
|
||||
(1, 12, "Please clean up the kitchen"),
|
||||
],
|
||||
fps=10,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def single_episode_root(tmp_path: Path) -> Path:
|
||||
from tests.fixtures.dataset_factories import build_annotation_dataset
|
||||
|
||||
return build_annotation_dataset(
|
||||
tmp_path / "ds_one",
|
||||
episode_specs=[(0, 30, "Pour water from the bottle into the cup.")],
|
||||
fps=10,
|
||||
)
|
||||
@@ -1,116 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Opt-in E2E smoke run for ``make annotation-e2e``.
|
||||
|
||||
Builds the shared annotation fixture (:func:`build_annotation_dataset`),
|
||||
runs the full annotation pipeline against it with a stub VLM, and prints a
|
||||
short report. This is intentionally not a pytest test — it exercises the
|
||||
CLI plumbing — but it reuses the same on-disk dataset builder as the pytest
|
||||
fixtures so there is no duplicated fixture code.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
|
||||
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||
from lerobot.annotations.steerable_pipeline.modules import (
|
||||
GeneralVqaModule,
|
||||
InterjectionsAndSpeechModule,
|
||||
PlanSubtasksMemoryModule,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
|
||||
from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter
|
||||
from tests.fixtures.dataset_factories import build_annotation_dataset
|
||||
|
||||
|
||||
def _stub_responder(messages):
|
||||
text = ""
|
||||
for m in messages:
|
||||
if m.get("role") == "user":
|
||||
content = m.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
text = block.get("text", "")
|
||||
elif isinstance(content, str):
|
||||
text = content
|
||||
if "atomic subtasks" in text:
|
||||
return {
|
||||
"subtasks": [
|
||||
{"text": "grasp the bottle", "start": 0.0, "end": 1.0},
|
||||
{"text": "pour into the cup", "start": 1.0, "end": 2.0},
|
||||
{"text": "place the bottle down", "start": 2.0, "end": 3.0},
|
||||
]
|
||||
}
|
||||
if "compressed semantic memory" in text:
|
||||
return {"memory": "poured once"}
|
||||
if "acknowledgement the robot" in text:
|
||||
return {"text": "Sure."}
|
||||
if "compact interjection" in text:
|
||||
return {"interjection": "use less water", "speech": "Using less water."}
|
||||
if "frame-grounded visual question" in text:
|
||||
return {"question": "How many cups?", "answer": {"label": "cup", "count": 1}}
|
||||
return None
|
||||
|
||||
|
||||
def main() -> int:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
root = build_annotation_dataset(
|
||||
Path(tmp) / "ds",
|
||||
episode_specs=[(0, 30, "Pour water into the cup.")],
|
||||
fps=10,
|
||||
)
|
||||
vlm = StubVlmClient(responder=_stub_responder)
|
||||
cfg = AnnotationPipelineConfig()
|
||||
executor = Executor(
|
||||
config=cfg,
|
||||
plan=PlanSubtasksMemoryModule(vlm=vlm, config=cfg.plan),
|
||||
interjections=InterjectionsAndSpeechModule(vlm=vlm, config=cfg.interjections, seed=cfg.seed),
|
||||
vqa=GeneralVqaModule(vlm=vlm, config=cfg.vqa, seed=cfg.seed),
|
||||
writer=LanguageColumnsWriter(),
|
||||
validator=StagingValidator(),
|
||||
)
|
||||
summary = executor.run(root)
|
||||
print(f"phases={[(p.name, p.episodes_processed) for p in summary.phases]}")
|
||||
print(f"validation: {summary.validation_report.summary()}")
|
||||
print(f"shards rewritten: {len(summary.written_paths)}")
|
||||
|
||||
# Assert the interjection code path actually fired — otherwise a stale
|
||||
# canned-VLM marker would silently produce zero interjections and this
|
||||
# smoke run would still "pass" by only printing.
|
||||
import pyarrow.parquet as pq # noqa: PLC0415
|
||||
|
||||
events = [
|
||||
r
|
||||
for shard in summary.written_paths
|
||||
for ev in pq.read_table(shard).column("language_events").to_pylist()
|
||||
for r in ev
|
||||
]
|
||||
n_interjections = sum(1 for r in events if r.get("style") == "interjection")
|
||||
n_speech = sum(1 for r in events if r.get("style") is None and r.get("role") == "assistant")
|
||||
print(f"interjections={n_interjections} speech_atoms={n_speech}")
|
||||
assert n_interjections > 0, "no interjection rows produced — check the interjection prompt marker"
|
||||
assert n_speech > 0, "no speech tool-call atoms produced — check the speech prompt marker"
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -1,246 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit tests for :class:`VideoFrameProvider` method bindings.
|
||||
|
||||
These were prompted by a real regression: ``video_for_episode`` was once
|
||||
indented one level too deep so it ended up nested *inside* a module-level
|
||||
helper (after that function's ``return`` statement) — silently dead code
|
||||
that meant production runs with ``use_video_url=False`` would
|
||||
``AttributeError`` on ``self.frame_provider.video_for_episode(...)``. The
|
||||
existing module tests didn't catch it because they exercise stub providers.
|
||||
|
||||
The tests below assert on the class itself (not on an instance), so a
|
||||
future reindent regression flips them to red without needing a real
|
||||
LeRobot dataset on disk.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.frames import VideoFrameProvider # noqa: E402
|
||||
|
||||
|
||||
class _FakeMeta:
|
||||
"""Minimal metadata stub exposing ``video_keys`` / ``camera_keys``."""
|
||||
|
||||
def __init__(self, video_keys: list[str], image_keys: list[str], video_path: Path | None = None) -> None:
|
||||
self.video_keys = video_keys
|
||||
self.camera_keys = [*video_keys, *image_keys]
|
||||
self._video_path = video_path
|
||||
self.episodes = {0: {f"videos/{key}/from_timestamp": 0.0 for key in video_keys}}
|
||||
|
||||
def get_video_file_path(self, episode_index: int, camera_key: str) -> Path:
|
||||
return self._video_path
|
||||
|
||||
|
||||
def test_default_camera_key_skips_image_only_cameras(tmp_path: Path, monkeypatch) -> None:
|
||||
"""The default camera must be a *video* key — image-stored cameras have no
|
||||
``videos/<key>/from_timestamp`` and would KeyError in the clip/decode path.
|
||||
|
||||
Regression: a dataset whose first ``camera_keys`` entry was an image-stored
|
||||
camera (e.g. ``observation.images.wrist``) crashed at clip extraction.
|
||||
"""
|
||||
fake = _FakeMeta(
|
||||
video_keys=["observation.images.robot0_agentview_right"],
|
||||
image_keys=["observation.images.wrist"],
|
||||
)
|
||||
import lerobot.datasets.dataset_metadata as meta_mod
|
||||
|
||||
monkeypatch.setattr(meta_mod, "LeRobotDatasetMetadata", lambda *a, **k: fake, raising=True)
|
||||
provider = VideoFrameProvider(root=tmp_path)
|
||||
assert provider.camera_key == "observation.images.robot0_agentview_right"
|
||||
assert "observation.images.wrist" not in provider.camera_keys
|
||||
|
||||
|
||||
def test_video_for_episode_is_a_method_of_videoframeprovider():
|
||||
"""``video_for_episode`` must be a bound method, not nested dead code."""
|
||||
assert callable(getattr(VideoFrameProvider, "video_for_episode", None))
|
||||
|
||||
|
||||
def test_episode_clip_path_is_a_method_of_videoframeprovider():
|
||||
"""``episode_clip_path`` is now a method (was a free function reaching
|
||||
into ``provider._meta`` from outside the class)."""
|
||||
assert callable(getattr(VideoFrameProvider, "episode_clip_path", None))
|
||||
|
||||
|
||||
def test_videoframeprovider_has_a_lock_for_concurrent_use():
|
||||
"""A ``ThreadPoolExecutor`` runs the plan / interjections / vqa phases
|
||||
concurrently; the cache + warn-flag accesses must be guarded.
|
||||
"""
|
||||
import threading
|
||||
|
||||
# Fresh-instance check via a minimal fake to avoid touching the hub.
|
||||
# The lock is declared with ``init=False`` and has a default factory,
|
||||
# so a constructed instance must own a real ``threading.Lock``.
|
||||
lock_field = next(
|
||||
(f for f in VideoFrameProvider.__dataclass_fields__.values() if f.name == "_lock"),
|
||||
None,
|
||||
)
|
||||
assert lock_field is not None
|
||||
assert lock_field.default_factory is threading.Lock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_video(tmp_path: Path) -> Path:
|
||||
"""A 3 s 10 fps test-pattern mp4, written with ffmpeg."""
|
||||
if shutil.which("ffmpeg") is None:
|
||||
pytest.skip("ffmpeg not available")
|
||||
out = tmp_path / "sample.mp4"
|
||||
subprocess.run(
|
||||
[
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-f",
|
||||
"lavfi",
|
||||
"-i",
|
||||
"testsrc=duration=3:size=160x120:rate=10",
|
||||
"-pix_fmt",
|
||||
"yuv420p",
|
||||
str(out),
|
||||
],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def _provider_for_video(tmp_path: Path, video: Path, monkeypatch) -> VideoFrameProvider:
|
||||
"""A provider whose single camera resolves to ``video`` via fake metadata."""
|
||||
fake = _FakeMeta(video_keys=["observation.images.cam"], image_keys=[], video_path=video)
|
||||
import lerobot.datasets.dataset_metadata as meta_mod
|
||||
|
||||
monkeypatch.setattr(meta_mod, "LeRobotDatasetMetadata", lambda *a, **k: fake, raising=True)
|
||||
return VideoFrameProvider(root=tmp_path, tolerance_s=0.2)
|
||||
|
||||
|
||||
def test_decode_returns_one_uint8_frame_per_timestamp(
|
||||
sample_video: Path, tmp_path: Path, monkeypatch
|
||||
) -> None:
|
||||
"""``_decode`` routes through ``decode_video_frames`` (torchcodec when
|
||||
available, PyAV otherwise) — no subprocess fallback.
|
||||
"""
|
||||
provider = _provider_for_video(tmp_path, sample_video, monkeypatch)
|
||||
timestamps = [0.0, 1.0, 2.5]
|
||||
frames = provider._decode(0, timestamps, "observation.images.cam")
|
||||
|
||||
assert len(frames) == len(timestamps)
|
||||
for frame in frames:
|
||||
assert isinstance(frame, torch.Tensor)
|
||||
assert frame.dtype == torch.uint8
|
||||
assert frame.shape == (3, 120, 160)
|
||||
|
||||
|
||||
def test_frames_at_snaps_mid_frame_grid_to_real_frames(
|
||||
sample_video: Path, tmp_path: Path, monkeypatch
|
||||
) -> None:
|
||||
"""Uniform sampling grids land mid-frame; ``frames_at`` must snap them to
|
||||
real frame timestamps before decoding.
|
||||
|
||||
Regression: ``decode_video_frames`` rejects queries farther than
|
||||
``tolerance_s`` (default 10 ms) from a decodable frame, so un-snapped
|
||||
mid-frame queries raised ``FrameTimestampError`` wholesale and the plan
|
||||
module silently lost its contact sheets for most episodes.
|
||||
"""
|
||||
from types import SimpleNamespace
|
||||
|
||||
fake = _FakeMeta(video_keys=["observation.images.cam"], image_keys=[], video_path=sample_video)
|
||||
import lerobot.datasets.dataset_metadata as meta_mod
|
||||
|
||||
monkeypatch.setattr(meta_mod, "LeRobotDatasetMetadata", lambda *a, **k: fake, raising=True)
|
||||
provider = VideoFrameProvider(root=tmp_path) # default 10 ms tolerance
|
||||
# 10 fps fixture -> frames at 0.0, 0.1, ...; queries sit mid-frame.
|
||||
record = SimpleNamespace(episode_index=0, frame_timestamps=[i / 10 for i in range(30)])
|
||||
|
||||
frames = provider.frames_at(record, [0.149, 1.234, 2.04], camera_key="observation.images.cam")
|
||||
|
||||
assert len(frames) == 3
|
||||
for frame in frames:
|
||||
assert isinstance(frame, torch.Tensor)
|
||||
assert frame.shape == (3, 120, 160)
|
||||
|
||||
|
||||
def test_decode_returns_empty_list_on_missing_file(tmp_path: Path, monkeypatch) -> None:
|
||||
"""A missing video is a recoverable no-frames condition, never a crash."""
|
||||
provider = _provider_for_video(tmp_path, tmp_path / "does_not_exist.mp4", monkeypatch)
|
||||
assert provider._decode(0, [0.0], "observation.images.cam") == []
|
||||
|
||||
|
||||
def test_episode_clip_path_trims_via_reencode_video(tmp_path: Path, monkeypatch) -> None:
|
||||
"""Clip extraction delegates to ``video_utils.reencode_video`` with the
|
||||
episode's ``[from_timestamp, to_timestamp)`` trim window — no subprocess.
|
||||
"""
|
||||
from types import SimpleNamespace
|
||||
|
||||
import lerobot.annotations.steerable_pipeline.frames as frames_mod
|
||||
|
||||
src = tmp_path / "src.mp4"
|
||||
src.write_bytes(b"src")
|
||||
fake = _FakeMeta(video_keys=["observation.images.cam"], image_keys=[], video_path=src)
|
||||
fake.episodes[0]["videos/observation.images.cam/from_timestamp"] = 1.5
|
||||
fake.episodes[0]["videos/observation.images.cam/to_timestamp"] = 4.0
|
||||
import lerobot.datasets.dataset_metadata as meta_mod
|
||||
|
||||
monkeypatch.setattr(meta_mod, "LeRobotDatasetMetadata", lambda *a, **k: fake, raising=True)
|
||||
|
||||
captured = {}
|
||||
|
||||
def fake_reencode(
|
||||
input_video_path,
|
||||
output_video_path,
|
||||
camera_encoder=None,
|
||||
overwrite=False,
|
||||
start_time_s=None,
|
||||
end_time_s=None,
|
||||
):
|
||||
captured.update(
|
||||
src=Path(input_video_path),
|
||||
encoder=camera_encoder,
|
||||
start_time_s=start_time_s,
|
||||
end_time_s=end_time_s,
|
||||
)
|
||||
Path(output_video_path).write_bytes(b"clip")
|
||||
|
||||
monkeypatch.setattr(frames_mod, "reencode_video", fake_reencode, raising=True)
|
||||
provider = VideoFrameProvider(root=tmp_path)
|
||||
record = SimpleNamespace(episode_index=0, frame_timestamps=[0.0, 1.0])
|
||||
|
||||
out = provider.episode_clip_path(record, tmp_path / "clips")
|
||||
|
||||
assert out == tmp_path / "clips" / "ep_000000.mp4"
|
||||
assert captured["src"] == src
|
||||
assert captured["start_time_s"] == 1.5
|
||||
assert captured["end_time_s"] == 4.0
|
||||
# H.264 so the clip is decodable by vllm's libav build (sources are often AV1).
|
||||
assert captured["encoder"].vcodec == "h264"
|
||||
|
||||
|
||||
def test_videoframeprovider_serializes_decodes_with_a_lock() -> None:
|
||||
"""torchcodec's cached per-file decoder is single-threaded; the provider
|
||||
must own a dedicated lock that ``_decode`` holds around the decoder call.
|
||||
"""
|
||||
import threading
|
||||
|
||||
lock_field = VideoFrameProvider.__dataclass_fields__.get("_decode_lock")
|
||||
assert lock_field is not None
|
||||
assert lock_field.default_factory is threading.Lock
|
||||
@@ -1,390 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module 1/2/3 unit tests with stubbed VLMs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import PIL.Image
|
||||
import pytest
|
||||
|
||||
# ``lerobot.annotations`` imports pull in ``lerobot.datasets`` (-> the HF
|
||||
# ``datasets`` library), which only ships under the ``dataset`` extra. Skip
|
||||
# this module in tiers without it instead of erroring at import.
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import ( # noqa: E402
|
||||
InterjectionsConfig,
|
||||
PlanConfig,
|
||||
VqaConfig,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.modules import ( # noqa: E402
|
||||
GeneralVqaModule,
|
||||
InterjectionsAndSpeechModule,
|
||||
PlanSubtasksMemoryModule,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.reader import iter_episodes # noqa: E402
|
||||
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging # noqa: E402
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient # noqa: E402
|
||||
|
||||
from ._helpers import make_canned_responder # noqa: E402
|
||||
|
||||
|
||||
@dataclass
|
||||
class _StubFrameProvider:
|
||||
"""Returns one sentinel object per requested timestamp."""
|
||||
|
||||
# A real (tiny) PIL image so the contact-sheet builder, which resizes and
|
||||
# tiles frames, has something to draw. VQA still passes it through by
|
||||
# identity via ``to_image_blocks``.
|
||||
sentinel: Any = field(default_factory=lambda: PIL.Image.new("RGB", (32, 24)))
|
||||
cameras: tuple[str, ...] = ("observation.images.top",)
|
||||
calls: list[tuple[int, tuple[float, ...], str | None]] = field(default_factory=list)
|
||||
video_calls: list[tuple[int, int, str | None]] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
return list(self.cameras)
|
||||
|
||||
def frames_at(self, record, timestamps, camera_key=None):
|
||||
self.calls.append((record.episode_index, tuple(timestamps), camera_key))
|
||||
return [self.sentinel] * len(timestamps)
|
||||
|
||||
def video_for_episode(self, record, max_frames, camera_key=None):
|
||||
self.video_calls.append((record.episode_index, max_frames, camera_key))
|
||||
n = min(max_frames, len(record.frame_timestamps))
|
||||
return [self.sentinel] * n
|
||||
|
||||
|
||||
def _spy_responder(captured: list[list[dict[str, Any]]], reply: Any):
|
||||
def responder(messages):
|
||||
captured.append(list(messages))
|
||||
return reply
|
||||
|
||||
return StubVlmClient(responder=responder)
|
||||
|
||||
|
||||
def test_module1_plan_memory_subtask_smoke(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
vlm = make_canned_responder(
|
||||
{
|
||||
"atomic subtasks": {
|
||||
"subtasks": [
|
||||
{"text": "grasp the handle of the sponge", "start": 0.0, "end": 0.4},
|
||||
{"text": "wipe the counter from left to right", "start": 0.4, "end": 0.8},
|
||||
{"text": "place the sponge into the sink", "start": 0.8, "end": 1.1},
|
||||
]
|
||||
},
|
||||
"compressed semantic memory": {"memory": "wiped the counter once"},
|
||||
},
|
||||
)
|
||||
module = PlanSubtasksMemoryModule(vlm=vlm, config=PlanConfig())
|
||||
record = next(iter_episodes(fixture_dataset_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("plan")
|
||||
|
||||
styles = {r["style"] for r in rows}
|
||||
assert {"subtask", "plan", "memory"}.issubset(styles)
|
||||
# subtask timestamps must be exact frame timestamps
|
||||
frame_set = set(record.frame_timestamps)
|
||||
for row in rows:
|
||||
assert row["timestamp"] in frame_set
|
||||
# one plan row per subtask boundary; the first lands at t0 and each
|
||||
# plan is the deterministic numbered list of still-todo subtasks
|
||||
plan_rows = sorted((r for r in rows if r["style"] == "plan"), key=lambda r: r["timestamp"])
|
||||
subtask_rows = [r for r in rows if r["style"] == "subtask"]
|
||||
assert len(plan_rows) == len(subtask_rows)
|
||||
assert plan_rows[0]["timestamp"] == record.frame_timestamps[0]
|
||||
# the t0 plan enumerates all subtasks; later plans shrink
|
||||
assert plan_rows[0]["content"].startswith("1. ")
|
||||
assert len(plan_rows[0]["content"].splitlines()) == len(subtask_rows)
|
||||
assert len(plan_rows[-1]["content"].splitlines()) == 1
|
||||
|
||||
|
||||
def test_module1_emit_memory_false_skips_memory_keeps_subtasks_and_plan(
|
||||
fixture_dataset_root: Path, tmp_path: Path
|
||||
) -> None:
|
||||
"""``emit_memory=False`` drops ``memory`` rows (and their VLM calls) while
|
||||
leaving subtask + plan generation intact — symmetric to ``emit_plan``."""
|
||||
vlm = make_canned_responder(
|
||||
{
|
||||
"atomic subtasks": {
|
||||
"subtasks": [
|
||||
{"text": "grasp the handle of the sponge", "start": 0.0, "end": 0.4},
|
||||
{"text": "wipe the counter from left to right", "start": 0.4, "end": 0.8},
|
||||
{"text": "place the sponge into the sink", "start": 0.8, "end": 1.1},
|
||||
]
|
||||
},
|
||||
"compressed semantic memory": {"memory": "wiped the counter once"},
|
||||
},
|
||||
)
|
||||
module = PlanSubtasksMemoryModule(vlm=vlm, config=PlanConfig(emit_memory=False))
|
||||
record = next(iter_episodes(fixture_dataset_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("plan")
|
||||
|
||||
styles = {r["style"] for r in rows}
|
||||
assert "memory" not in styles
|
||||
assert {"subtask", "plan"}.issubset(styles)
|
||||
|
||||
|
||||
def test_module2_at_t0_emits_speech_only_no_interjection(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
vlm = make_canned_responder(
|
||||
{"acknowledgement the robot": {"text": "Sure, on it."}},
|
||||
)
|
||||
module = InterjectionsAndSpeechModule(
|
||||
vlm=vlm,
|
||||
config=InterjectionsConfig(max_interjections_per_episode=0),
|
||||
)
|
||||
record = next(iter_episodes(fixture_dataset_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("interjections")
|
||||
assert len(rows) == 1
|
||||
only = rows[0]
|
||||
assert only["role"] == "assistant"
|
||||
assert only["style"] is None
|
||||
assert only["content"] is None
|
||||
assert only["timestamp"] == record.frame_timestamps[0]
|
||||
assert only["tool_calls"][0]["function"]["name"] == "say"
|
||||
|
||||
|
||||
def test_module2_mid_episode_emits_paired_interjection_and_speech(
|
||||
fixture_dataset_root: Path, tmp_path: Path
|
||||
) -> None:
|
||||
"""Module 2 anchors interjections on Module 1's subtask boundaries.
|
||||
|
||||
The executor runs Module 1 first, then Module 2 reads the subtask
|
||||
rows back from the same staging tree (see
|
||||
``_mid_episode_interjections``). Reproduce that contract here by
|
||||
seeding the staging with two subtask rows so a single ``0 → 1``
|
||||
boundary exists for Module 2 to anchor on.
|
||||
"""
|
||||
vlm = make_canned_responder(
|
||||
{
|
||||
"acknowledgement the robot": {"text": "OK."},
|
||||
# Marker matches the distinctive line of
|
||||
# ``interjections_interjection.txt`` ("Write ONE compact
|
||||
# interjection ..."). Keep this in sync with that prompt's
|
||||
# wording — the canned responder matches on substring.
|
||||
"Write ONE compact interjection": {
|
||||
"interjection": "now wipe the counter please",
|
||||
"speech": "On it.",
|
||||
},
|
||||
},
|
||||
)
|
||||
module = InterjectionsAndSpeechModule(
|
||||
vlm=vlm,
|
||||
config=InterjectionsConfig(max_interjections_per_episode=1, interjection_min_t=0.2),
|
||||
seed=7,
|
||||
)
|
||||
record = next(iter_episodes(fixture_dataset_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
# Seed Module 1's subtask staging so Module 2 has a boundary to
|
||||
# anchor on (it bails with zero rows when no spans exist — the
|
||||
# production executor guarantees Module 1 ran first).
|
||||
boundary_ts = float(record.frame_timestamps[len(record.frame_timestamps) // 2])
|
||||
staging.write(
|
||||
"plan",
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "grasp the sponge",
|
||||
"style": "subtask",
|
||||
"timestamp": float(record.frame_timestamps[0]),
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "wipe the counter",
|
||||
"style": "subtask",
|
||||
"timestamp": boundary_ts,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("interjections")
|
||||
|
||||
interjections = [r for r in rows if r["style"] == "interjection"]
|
||||
speeches = [r for r in rows if r["style"] is None and r["role"] == "assistant"]
|
||||
assert len(interjections) == 1
|
||||
assert len(speeches) >= 2 # initial t=0 + one paired with the interjection
|
||||
inter_t = interjections[0]["timestamp"]
|
||||
assert any(abs(s["timestamp"] - inter_t) < 1e-9 for s in speeches)
|
||||
|
||||
|
||||
def test_module3_vqa_unique_per_frame_and_camera(single_episode_root: Path, tmp_path: Path) -> None:
|
||||
payload = {
|
||||
"question": "How many cups?",
|
||||
"answer": {"label": "cup", "count": 2, "note": "white & blue"},
|
||||
}
|
||||
vlm = make_canned_responder({"frame-grounded visual question": payload})
|
||||
module = GeneralVqaModule(
|
||||
vlm=vlm,
|
||||
config=VqaConfig(vqa_emission_hz=1.0, K=3),
|
||||
seed=1,
|
||||
frame_provider=_StubFrameProvider(cameras=("observation.images.top", "observation.images.wrist")),
|
||||
)
|
||||
record = next(iter_episodes(single_episode_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("vqa")
|
||||
# every vqa row must carry a camera tag and one of the configured cameras
|
||||
for r in rows:
|
||||
assert r["style"] == "vqa"
|
||||
assert r.get("camera") in {"observation.images.top", "observation.images.wrist"}
|
||||
# at most one (vqa, user) and one (vqa, assistant) per (timestamp, camera)
|
||||
user_keys = [(r["timestamp"], r["camera"]) for r in rows if r["role"] == "user" and r["style"] == "vqa"]
|
||||
assistant_keys = [
|
||||
(r["timestamp"], r["camera"]) for r in rows if r["role"] == "assistant" and r["style"] == "vqa"
|
||||
]
|
||||
assert len(user_keys) == len(set(user_keys))
|
||||
assert len(assistant_keys) == len(set(assistant_keys))
|
||||
# both cameras must be represented
|
||||
assert {c for _, c in user_keys} == {"observation.images.top", "observation.images.wrist"}
|
||||
# every emitted timestamp must be an exact source frame timestamp
|
||||
frame_set = set(record.frame_timestamps)
|
||||
for ts, _ in user_keys + assistant_keys:
|
||||
assert ts in frame_set
|
||||
|
||||
|
||||
def test_module1_attaches_contact_sheets_to_subtask_prompt(
|
||||
fixture_dataset_root: Path, tmp_path: Path
|
||||
) -> None:
|
||||
"""Module 1 sends timestamped contact-sheet image blocks (not a raw video block)."""
|
||||
captured: list[list[dict[str, Any]]] = []
|
||||
payload = {
|
||||
"subtasks": [
|
||||
{"text": "grasp the handle of the sponge", "start": 0.0, "end": 0.5},
|
||||
{"text": "wipe the counter", "start": 0.5, "end": 1.1},
|
||||
]
|
||||
}
|
||||
memory_payload = {"memory": "wiped once"}
|
||||
|
||||
def responder(messages):
|
||||
captured.append(list(messages))
|
||||
text = ""
|
||||
for m in messages:
|
||||
for block in m.get("content", []):
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
text = block.get("text", "")
|
||||
if "compressed semantic memory" in text:
|
||||
return memory_payload
|
||||
return payload
|
||||
|
||||
provider = _StubFrameProvider()
|
||||
module = PlanSubtasksMemoryModule(
|
||||
vlm=StubVlmClient(responder=responder),
|
||||
# Disable the rephrasings sub-prompt so the test's only video-bearing
|
||||
# call is the subtask one — keeps the assertions below focused on
|
||||
# ``_generate_subtasks`` rather than fighting the order of unrelated
|
||||
# text-only Module-1 sub-prompts.
|
||||
config=PlanConfig(frames_per_second=2.0, max_frames_per_prompt=60, n_task_rephrasings=0),
|
||||
frame_provider=provider,
|
||||
)
|
||||
record = next(iter_episodes(fixture_dataset_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
|
||||
# Find the call carrying the subtask prompt rather than blindly taking
|
||||
# captured[0] — Module 1 issues several sub-prompts and their order is
|
||||
# not part of the contract.
|
||||
assert captured, "no VLM calls made"
|
||||
|
||||
def _prompt_text(messages):
|
||||
for m in messages:
|
||||
for block in m.get("content", []):
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
return block.get("text", "")
|
||||
return ""
|
||||
|
||||
subtask_calls = [m for m in captured if "atomic subtasks" in _prompt_text(m)]
|
||||
assert len(subtask_calls) == 1, "expected exactly one subtask-prompt VLM call"
|
||||
content = subtask_calls[0][0]["content"]
|
||||
video_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "video"]
|
||||
image_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "image"]
|
||||
text_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "text"]
|
||||
assert video_blocks == [], "contact-sheet mode must not emit a raw video block"
|
||||
assert len(image_blocks) >= 1, f"expected >=1 contact-sheet image block, got {content}"
|
||||
assert all(isinstance(b["image"], PIL.Image.Image) for b in image_blocks)
|
||||
assert len(text_blocks) == 1
|
||||
# the prompt is prefixed with the contact-sheet reading instructions
|
||||
assert text_blocks[0]["text"].startswith("CONTACT SHEETS")
|
||||
# frames were decoded for this episode at episode-relative timestamps
|
||||
assert provider.calls and provider.calls[0][0] == record.episode_index
|
||||
|
||||
|
||||
def test_module3_attaches_frame_image_block_to_prompt(single_episode_root: Path, tmp_path: Path) -> None:
|
||||
"""Each VQA prompt must carry a single image block at the emission frame."""
|
||||
captured: list[list[dict[str, Any]]] = []
|
||||
payload = {
|
||||
"question": "How many cups?",
|
||||
"answer": {"label": "cup", "count": 1},
|
||||
}
|
||||
provider = _StubFrameProvider()
|
||||
module = GeneralVqaModule(
|
||||
vlm=_spy_responder(captured, payload),
|
||||
config=VqaConfig(vqa_emission_hz=1.0, K=1),
|
||||
seed=0,
|
||||
frame_provider=provider,
|
||||
)
|
||||
record = next(iter_episodes(single_episode_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
|
||||
assert captured, "no VLM calls made"
|
||||
for messages in captured:
|
||||
content = messages[0]["content"]
|
||||
image_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "image"]
|
||||
text_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "text"]
|
||||
assert len(image_blocks) == 1, f"expected 1 image block per VQA prompt, got {content}"
|
||||
assert image_blocks[0]["image"] is provider.sentinel
|
||||
assert len(text_blocks) == 1
|
||||
# provider was called once per emission per camera with the exact emission timestamp
|
||||
for ep_idx, ts_tuple, camera in provider.calls:
|
||||
assert ep_idx == record.episode_index
|
||||
assert len(ts_tuple) == 1
|
||||
assert ts_tuple[0] in record.frame_timestamps
|
||||
assert camera in provider.cameras
|
||||
|
||||
|
||||
def test_module3_assistant_content_is_valid_json(single_episode_root: Path, tmp_path: Path) -> None:
|
||||
payload = {
|
||||
"question": "Where is the cup?",
|
||||
"answer": {"detections": [{"label": "cup", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]},
|
||||
}
|
||||
vlm = make_canned_responder({"frame-grounded visual question": payload})
|
||||
module = GeneralVqaModule(
|
||||
vlm=vlm,
|
||||
config=VqaConfig(vqa_emission_hz=1.0, K=2),
|
||||
seed=2,
|
||||
frame_provider=_StubFrameProvider(),
|
||||
)
|
||||
record = next(iter_episodes(single_episode_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("vqa")
|
||||
for row in rows:
|
||||
if row["role"] == "assistant" and row["style"] == "vqa":
|
||||
decoded = json.loads(row["content"])
|
||||
assert "detections" in decoded
|
||||
@@ -1,183 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""End-to-end smoke: pipeline output → canonical recipe rendering."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# ``pyarrow`` and the ``lerobot.datasets`` chain (-> the HF ``datasets``
|
||||
# library) only ship under the ``dataset`` extra. Skip this module in
|
||||
# tiers without it instead of erroring at import.
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
|
||||
|
||||
import pyarrow.parquet as pq # noqa: E402
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import ( # noqa: E402
|
||||
AnnotationPipelineConfig,
|
||||
InterjectionsConfig,
|
||||
PlanConfig,
|
||||
VqaConfig,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.executor import Executor # noqa: E402
|
||||
from lerobot.annotations.steerable_pipeline.modules import ( # noqa: E402
|
||||
GeneralVqaModule,
|
||||
InterjectionsAndSpeechModule,
|
||||
PlanSubtasksMemoryModule,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.validator import StagingValidator # noqa: E402
|
||||
from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter # noqa: E402
|
||||
from lerobot.configs.recipe import MessageTurn, TrainingRecipe # noqa: E402
|
||||
from lerobot.datasets.language_render import render_sample # noqa: E402
|
||||
|
||||
from ._helpers import make_canned_responder # noqa: E402
|
||||
|
||||
|
||||
def _build_style_blend_recipe() -> TrainingRecipe:
|
||||
"""Inline blend recipe that consumes every style this pipeline produces.
|
||||
|
||||
The language schema/DSL work used to ship
|
||||
``src/lerobot/configs/recipes/pi05_hirobot.yaml`` as a canonical
|
||||
example, but that file was dropped during review. The contract this
|
||||
test guards is "the recipe DSL can render non-empty messages from
|
||||
pipeline output", which doesn't require a specific YAML — so we build
|
||||
the equivalent blend in code.
|
||||
"""
|
||||
return TrainingRecipe(
|
||||
blend={
|
||||
"low_level_execution": TrainingRecipe(
|
||||
weight=0.35,
|
||||
messages=[
|
||||
MessageTurn(
|
||||
role="user",
|
||||
content="${task}\nPlan: ${plan}\nMemory: ${memory}",
|
||||
stream="high_level",
|
||||
),
|
||||
MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True),
|
||||
],
|
||||
),
|
||||
"user_interjection_response": TrainingRecipe(
|
||||
weight=0.16,
|
||||
bindings={
|
||||
"speech": "emitted_at(t, role=assistant, tool_name=say)",
|
||||
"interjection": "emitted_at(t, style=interjection)",
|
||||
},
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(
|
||||
role="user",
|
||||
content="${interjection}",
|
||||
stream="high_level",
|
||||
if_present="interjection",
|
||||
),
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${plan}",
|
||||
stream="high_level",
|
||||
target=True,
|
||||
if_present="plan",
|
||||
tool_calls_from="speech",
|
||||
),
|
||||
],
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _build_executor() -> Executor:
|
||||
vlm = make_canned_responder(
|
||||
{
|
||||
"atomic subtasks": {
|
||||
"subtasks": [
|
||||
{"text": "grasp the bottle", "start": 0.0, "end": 0.5},
|
||||
{"text": "pour into the cup", "start": 0.5, "end": 1.0},
|
||||
{"text": "place the bottle down", "start": 1.0, "end": 1.5},
|
||||
]
|
||||
},
|
||||
"compressed semantic memory": {"memory": "poured once"},
|
||||
"acknowledgement the robot": {"text": "Sure."},
|
||||
"compact interjection": {
|
||||
"interjection": "use less water",
|
||||
"speech": "Using less water.",
|
||||
},
|
||||
"frame-grounded visual question": {
|
||||
"question": "How many cups?",
|
||||
"answer": {"label": "cup", "count": 1},
|
||||
},
|
||||
},
|
||||
)
|
||||
config = AnnotationPipelineConfig(
|
||||
plan=PlanConfig(),
|
||||
interjections=InterjectionsConfig(max_interjections_per_episode=1, interjection_min_t=0.5),
|
||||
vqa=VqaConfig(vqa_emission_hz=1.0, K=2),
|
||||
)
|
||||
return Executor(
|
||||
config=config,
|
||||
plan=PlanSubtasksMemoryModule(vlm=vlm, config=config.plan),
|
||||
interjections=InterjectionsAndSpeechModule(vlm=vlm, config=config.interjections, seed=config.seed),
|
||||
vqa=GeneralVqaModule(vlm=vlm, config=config.vqa, seed=config.seed),
|
||||
writer=LanguageColumnsWriter(),
|
||||
validator=StagingValidator(),
|
||||
)
|
||||
|
||||
|
||||
def test_canonical_recipe_renders_nonempty_from_pipeline_output(
|
||||
single_episode_root: Path,
|
||||
) -> None:
|
||||
executor = _build_executor()
|
||||
summary = executor.run(single_episode_root)
|
||||
# validator may emit warnings but no errors for the synthetic fixture
|
||||
assert summary.validation_report.ok, summary.validation_report.summary()
|
||||
|
||||
table = pq.read_table(single_episode_root / "data" / "chunk-000" / "file-000.parquet")
|
||||
persistent_lists = table.column("language_persistent").to_pylist()
|
||||
events_lists = table.column("language_events").to_pylist()
|
||||
timestamps = table.column("timestamp").to_pylist()
|
||||
|
||||
recipe = _build_style_blend_recipe()
|
||||
|
||||
rendered_any = False
|
||||
for ts, persistent, events in zip(timestamps, persistent_lists, events_lists, strict=True):
|
||||
result = render_sample(
|
||||
recipe=recipe,
|
||||
persistent=persistent,
|
||||
events=events,
|
||||
t=float(ts),
|
||||
sample_idx=0,
|
||||
dataset_ctx={"task": "Pour water from the bottle into the cup."},
|
||||
)
|
||||
if result is None:
|
||||
continue
|
||||
if result["messages"]:
|
||||
rendered_any = True
|
||||
assert result["target_message_indices"]
|
||||
break
|
||||
assert rendered_any, "recipe rendered no messages from pipeline output"
|
||||
|
||||
# Sanity: speech atom appears in events column intact
|
||||
flat_events = [r for ev in events_lists for r in ev]
|
||||
speech_rows = [r for r in flat_events if r.get("style") is None and r.get("role") == "assistant"]
|
||||
assert speech_rows
|
||||
say = speech_rows[0]["tool_calls"][0]
|
||||
assert say["function"]["name"] == "say"
|
||||
assert isinstance(say["function"]["arguments"]["text"], str)
|
||||
# The pipeline does not write a ``tools`` column — the say schema lives
|
||||
# as a constant (``SAY_TOOL_SCHEMA``) so the language row struct is the
|
||||
# single source of truth for the v3.1 schema.
|
||||
assert "tools" not in table.column_names
|
||||
@@ -1,133 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Validator behavior tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# ``lerobot.annotations`` imports pull in ``lerobot.datasets`` (-> the HF
|
||||
# ``datasets`` library), which only ships under the ``dataset`` extra. Skip
|
||||
# this module in tiers without it instead of erroring at import.
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.reader import iter_episodes # noqa: E402
|
||||
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging # noqa: E402
|
||||
from lerobot.annotations.steerable_pipeline.validator import StagingValidator # noqa: E402
|
||||
from lerobot.annotations.steerable_pipeline.writer import speech_atom # noqa: E402
|
||||
|
||||
|
||||
def _validate(root: Path, staging_dir: Path):
|
||||
records = list(iter_episodes(root))
|
||||
return StagingValidator().validate(records, staging_dir)
|
||||
|
||||
|
||||
def test_validator_catches_misaligned_timestamps(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
EpisodeStaging(staging_dir, 0).write(
|
||||
"vqa",
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": json.dumps({"label": "cup", "count": 2}, sort_keys=True),
|
||||
"style": "vqa",
|
||||
"timestamp": 9.999, # not on any 10 fps frame
|
||||
"tool_calls": None,
|
||||
}
|
||||
],
|
||||
)
|
||||
report = _validate(fixture_dataset_root, staging_dir)
|
||||
assert not report.ok
|
||||
assert any("does not match any source frame timestamp" in e for e in report.errors)
|
||||
|
||||
|
||||
def test_validator_catches_orphan_speech(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
EpisodeStaging(staging_dir, 0).write(
|
||||
"interjections",
|
||||
[
|
||||
speech_atom(0.0, "Got it."),
|
||||
# interjection at 0.3s with NO paired speech
|
||||
{
|
||||
"role": "user",
|
||||
"content": "skip it",
|
||||
"style": "interjection",
|
||||
"timestamp": 0.3,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
report = _validate(fixture_dataset_root, staging_dir)
|
||||
assert not report.ok
|
||||
assert any("paired speech" in e for e in report.errors)
|
||||
|
||||
|
||||
def test_validator_catches_inconsistent_plan_memory(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
EpisodeStaging(staging_dir, 0).write(
|
||||
"plan",
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "1. do x",
|
||||
"style": "plan",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "do x",
|
||||
"style": "subtask",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
EpisodeStaging(staging_dir, 0).write(
|
||||
"interjections",
|
||||
[
|
||||
speech_atom(0.0, "Got it."),
|
||||
speech_atom(0.4, "Replanning."),
|
||||
{
|
||||
"role": "user",
|
||||
"content": "replan",
|
||||
"style": "interjection",
|
||||
"timestamp": 0.4,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
report = _validate(fixture_dataset_root, staging_dir)
|
||||
# missing co-timestamped plan refresh at 0.4s → error
|
||||
assert not report.ok
|
||||
assert any("co-timestamped plan update" in e for e in report.errors)
|
||||
|
||||
|
||||
def test_validator_catches_wrong_column(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
EpisodeStaging(staging_dir, 0).write(
|
||||
"plan",
|
||||
[
|
||||
{"role": "user", "content": "where?", "style": "vqa", "timestamp": 0.0, "tool_calls": None},
|
||||
],
|
||||
)
|
||||
report = _validate(fixture_dataset_root, staging_dir)
|
||||
assert not report.ok
|
||||
assert any("plan emitted style 'vqa'" in e or "must be persistent" in e for e in report.errors)
|
||||
@@ -1,41 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit tests for ``vlm_client`` helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import _bind_serve_port # noqa: E402
|
||||
|
||||
|
||||
def test_bind_serve_port_substitutes_placeholder() -> None:
|
||||
# The {port} placeholder is replaced everywhere it appears, regardless of
|
||||
# parallel vs single server — the bug was the single-server path passing
|
||||
# it through unsubstituted.
|
||||
cmd = "vllm serve M --max-model-len 32768 --port {port}"
|
||||
assert _bind_serve_port(cmd, 8000) == "vllm serve M --max-model-len 32768 --port 8000"
|
||||
|
||||
|
||||
def test_bind_serve_port_appends_when_missing() -> None:
|
||||
assert _bind_serve_port("vllm serve M", 8001) == "vllm serve M --port 8001"
|
||||
|
||||
|
||||
def test_bind_serve_port_leaves_explicit_port_untouched() -> None:
|
||||
cmd = "vllm serve M --port 9000"
|
||||
assert _bind_serve_port(cmd, 8000) == cmd
|
||||
@@ -1,357 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Writer correctness tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# ``pyarrow`` and the ``lerobot.annotations`` -> ``lerobot.datasets`` chain
|
||||
# (-> the HF ``datasets`` library) only ship under the ``dataset`` extra.
|
||||
# Skip this module in tiers without it instead of erroring at import.
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
|
||||
|
||||
import pyarrow.parquet as pq # noqa: E402
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.reader import iter_episodes # noqa: E402
|
||||
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging # noqa: E402
|
||||
from lerobot.annotations.steerable_pipeline.writer import ( # noqa: E402
|
||||
LanguageColumnsWriter,
|
||||
speech_atom,
|
||||
)
|
||||
|
||||
|
||||
def _stage_episode(
|
||||
staging_dir: Path,
|
||||
episode_index: int,
|
||||
*,
|
||||
plan: list[dict] | None = None,
|
||||
interjections: list[dict] | None = None,
|
||||
vqa: list[dict] | None = None,
|
||||
) -> None:
|
||||
staging = EpisodeStaging(staging_dir, episode_index)
|
||||
if plan is not None:
|
||||
staging.write("plan", plan)
|
||||
if interjections is not None:
|
||||
staging.write("interjections", interjections)
|
||||
if vqa is not None:
|
||||
staging.write("vqa", vqa)
|
||||
|
||||
|
||||
def test_writer_persistence_identity(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
"""Every frame in an episode has a byte-identical persistent list."""
|
||||
staging_dir = tmp_path / "stage"
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
plan=[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "grasp the sponge",
|
||||
"style": "subtask",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "1. wipe\n2. dry",
|
||||
"style": "plan",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "wiped the counter",
|
||||
"style": "memory",
|
||||
"timestamp": 0.5,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
records = list(iter_episodes(fixture_dataset_root))
|
||||
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
|
||||
|
||||
table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
|
||||
persistent = table.column("language_persistent").to_pylist()
|
||||
first = persistent[0]
|
||||
assert first # non-empty
|
||||
for row in persistent:
|
||||
assert row == first, "persistent slice must be byte-identical across all frames"
|
||||
|
||||
|
||||
def test_writer_events_exact_timestamp(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
interjections=[
|
||||
speech_atom(0.0, "Got it."),
|
||||
{
|
||||
"role": "user",
|
||||
"content": "skip the dishes",
|
||||
"style": "interjection",
|
||||
"timestamp": 0.5,
|
||||
"tool_calls": None,
|
||||
},
|
||||
speech_atom(0.5, "Skipping the dishes."),
|
||||
],
|
||||
)
|
||||
records = list(iter_episodes(fixture_dataset_root))
|
||||
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
|
||||
|
||||
table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
|
||||
timestamps = table.column("timestamp").to_pylist()
|
||||
events = table.column("language_events").to_pylist()
|
||||
for ts, ev in zip(timestamps, events, strict=True):
|
||||
if abs(ts - 0.0) < 1e-9:
|
||||
assert any(r["role"] == "assistant" and r.get("style") is None for r in ev), ev
|
||||
elif abs(ts - 0.5) < 1e-9:
|
||||
assert any(r.get("style") == "interjection" for r in ev), ev
|
||||
assert any(r.get("style") is None for r in ev), ev
|
||||
else:
|
||||
assert ev == []
|
||||
|
||||
|
||||
def test_writer_column_routing(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
plan=[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "do X",
|
||||
"style": "subtask",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "1. do X",
|
||||
"style": "plan",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "did X",
|
||||
"style": "memory",
|
||||
"timestamp": 0.3,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
interjections=[
|
||||
speech_atom(0.0, "OK"),
|
||||
{
|
||||
"role": "user",
|
||||
"content": "wait",
|
||||
"style": "interjection",
|
||||
"timestamp": 0.2,
|
||||
"tool_calls": None,
|
||||
},
|
||||
speech_atom(0.2, "Waiting"),
|
||||
],
|
||||
vqa=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "where is the cup?",
|
||||
"style": "vqa",
|
||||
"timestamp": 0.4,
|
||||
"camera": "observation.images.front",
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": json.dumps(
|
||||
{"detections": [{"label": "cup", "bbox_format": "xyxy", "bbox": [1, 2, 3, 4]}]},
|
||||
sort_keys=True,
|
||||
),
|
||||
"style": "vqa",
|
||||
"timestamp": 0.4,
|
||||
"camera": "observation.images.front",
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
records = list(iter_episodes(fixture_dataset_root))
|
||||
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
|
||||
table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
|
||||
|
||||
persistent = table.column("language_persistent").to_pylist()[0]
|
||||
persistent_styles = {r["style"] for r in persistent}
|
||||
assert persistent_styles == {"subtask", "plan", "memory"}
|
||||
|
||||
all_events = [r for ev in table.column("language_events").to_pylist() for r in ev]
|
||||
event_styles = {r.get("style") for r in all_events}
|
||||
assert event_styles == {None, "interjection", "vqa"}
|
||||
|
||||
|
||||
def test_writer_drops_subtask_index_idempotent(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
plan=[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "do X",
|
||||
"style": "subtask",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
records = list(iter_episodes(fixture_dataset_root))
|
||||
writer = LanguageColumnsWriter()
|
||||
writer.write_all(records, staging_dir, fixture_dataset_root)
|
||||
|
||||
path = fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet"
|
||||
table_a = pq.read_table(path)
|
||||
assert "subtask_index" not in table_a.column_names
|
||||
assert "language_persistent" in table_a.column_names
|
||||
assert "language_events" in table_a.column_names
|
||||
# The writer no longer emits a dataset-level ``tools`` column; the
|
||||
# ``say`` tool schema lives as a code constant (``SAY_TOOL_SCHEMA``)
|
||||
# so the parquet stays small and the pipeline doesn't extend the schema.
|
||||
assert "tools" not in table_a.column_names
|
||||
|
||||
# second pass — must produce identical bytes for the language columns
|
||||
records_again = list(iter_episodes(fixture_dataset_root))
|
||||
writer.write_all(records_again, staging_dir, fixture_dataset_root)
|
||||
table_b = pq.read_table(path)
|
||||
assert (
|
||||
table_a.column("language_persistent").to_pylist() == table_b.column("language_persistent").to_pylist()
|
||||
)
|
||||
assert table_a.column("language_events").to_pylist() == table_b.column("language_events").to_pylist()
|
||||
|
||||
|
||||
def test_writer_normalize_rejects_misrouted_persistent_style() -> None:
|
||||
"""``_normalize_persistent_row`` must reject any non-persistent style."""
|
||||
from lerobot.annotations.steerable_pipeline.writer import _normalize_persistent_row
|
||||
|
||||
with pytest.raises(ValueError, match="non-persistent style"):
|
||||
_normalize_persistent_row(
|
||||
{"role": "assistant", "content": "oops", "style": "vqa", "timestamp": 0.0, "tool_calls": None}
|
||||
)
|
||||
|
||||
|
||||
def test_writer_normalize_rejects_misrouted_event_style() -> None:
|
||||
"""``_normalize_event_row`` must reject any persistent style."""
|
||||
from lerobot.annotations.steerable_pipeline.writer import _normalize_event_row
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_normalize_event_row({"role": "assistant", "content": "oops", "style": "subtask", "tool_calls": None})
|
||||
|
||||
|
||||
def test_say_tool_schema_constant_is_well_formed() -> None:
|
||||
"""``SAY_TOOL_SCHEMA`` (and ``DEFAULT_TOOLS``) replace the parquet
|
||||
``tools`` column — chat-template consumers import them directly.
|
||||
"""
|
||||
from lerobot.annotations.steerable_pipeline.writer import (
|
||||
DEFAULT_TOOLS,
|
||||
SAY_TOOL_SCHEMA,
|
||||
)
|
||||
|
||||
assert DEFAULT_TOOLS == [SAY_TOOL_SCHEMA]
|
||||
assert SAY_TOOL_SCHEMA["function"]["name"] == "say"
|
||||
params = SAY_TOOL_SCHEMA["function"]["parameters"]
|
||||
assert params["properties"]["text"]["type"] == "string"
|
||||
assert params["required"] == ["text"]
|
||||
|
||||
|
||||
def test_writer_does_not_add_tools_column(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
"""Re-running on a parquet that already has a legacy ``tools`` column
|
||||
must drop it cleanly so reruns converge to the v3.1 schema.
|
||||
"""
|
||||
staging_dir = tmp_path / "stage"
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
plan=[
|
||||
{"role": "assistant", "content": "x", "style": "subtask", "timestamp": 0.0, "tool_calls": None}
|
||||
],
|
||||
)
|
||||
records = list(iter_episodes(fixture_dataset_root))
|
||||
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
|
||||
table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
|
||||
assert "tools" not in table.column_names
|
||||
|
||||
|
||||
def test_annotation_metadata_sync_allows_non_streaming_load(
|
||||
fixture_dataset_root: Path, tmp_path: Path
|
||||
) -> None:
|
||||
"""Annotated parquet columns must be declared in ``meta/info.json``.
|
||||
|
||||
``LeRobotDataset`` loads non-streaming datasets by casting parquet
|
||||
against metadata-derived HF features. If the annotation writer adds
|
||||
language columns but metadata stays stale, that cast fails with a column
|
||||
mismatch.
|
||||
"""
|
||||
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||
from lerobot.datasets.feature_utils import get_hf_features_from_features
|
||||
from lerobot.datasets.io_utils import load_info, load_nested_dataset
|
||||
from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT, language_feature_info
|
||||
|
||||
info_path = fixture_dataset_root / "meta" / "info.json"
|
||||
info = json.loads(info_path.read_text())
|
||||
info["features"] = {
|
||||
"episode_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
"frame_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
"timestamp": {"dtype": "float32", "shape": (1,), "names": None},
|
||||
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
}
|
||||
info_path.write_text(json.dumps(info, indent=2))
|
||||
|
||||
staging_dir = tmp_path / "stage"
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
plan=[
|
||||
{"role": "assistant", "content": "do X", "style": "subtask", "timestamp": 0.0, "tool_calls": None}
|
||||
],
|
||||
)
|
||||
records = list(iter_episodes(fixture_dataset_root))
|
||||
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
|
||||
|
||||
Executor._ensure_annotation_metadata_in_info(fixture_dataset_root)
|
||||
|
||||
synced = load_info(fixture_dataset_root)
|
||||
for key, feature in language_feature_info().items():
|
||||
assert synced["features"][key] == feature
|
||||
|
||||
hf_features = get_hf_features_from_features(synced["features"])
|
||||
dataset = load_nested_dataset(fixture_dataset_root / "data", features=hf_features)
|
||||
|
||||
assert LANGUAGE_PERSISTENT in dataset.column_names
|
||||
assert LANGUAGE_EVENTS in dataset.column_names
|
||||
assert len(dataset) == 24
|
||||
|
||||
|
||||
def test_speech_atom_shape_matches_plan_spec() -> None:
|
||||
atom = speech_atom(2.5, "I'm cleaning up!")
|
||||
assert atom["role"] == "assistant"
|
||||
assert atom["style"] is None
|
||||
assert atom["content"] is None
|
||||
assert atom["timestamp"] == 2.5
|
||||
assert isinstance(atom["tool_calls"], list)
|
||||
call = atom["tool_calls"][0]
|
||||
assert call["type"] == "function"
|
||||
assert call["function"]["name"] == "say"
|
||||
assert call["function"]["arguments"]["text"] == "I'm cleaning up!"
|
||||
@@ -1,187 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""End-to-end test of the asynchronous inference stack (client ↔ server).
|
||||
|
||||
This test spins up a lightweight gRPC `PolicyServer` instance with a stubbed
|
||||
policy network and launches a `RobotClient` that uses a `MockRobot`. The goal
|
||||
is to exercise the full communication loop:
|
||||
|
||||
1. Client sends policy specification → Server
|
||||
2. Client streams observations → Server
|
||||
3. Server streams action chunks → Client
|
||||
4. Client executes received actions
|
||||
|
||||
The test succeeds if at least one action is executed and the server records at
|
||||
least one predicted timestep - demonstrating that the gRPC round-trip works
|
||||
end-to-end using real (but lightweight) protocol messages.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from concurrent import futures
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip entire module if required deps are not available
|
||||
pytest.importorskip("grpc")
|
||||
pytest.importorskip("serial", reason="pyserial is required (install lerobot[hardware])")
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# End-to-end test
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_async_inference_e2e(monkeypatch):
|
||||
"""Tests the full asynchronous inference pipeline."""
|
||||
# Import grpc-dependent modules inside the test function
|
||||
import grpc
|
||||
|
||||
from lerobot.async_inference.configs import PolicyServerConfig, RobotClientConfig
|
||||
from lerobot.async_inference.helpers import map_robot_keys_to_lerobot_features
|
||||
from lerobot.async_inference.policy_server import PolicyServer
|
||||
from lerobot.async_inference.robot_client import RobotClient
|
||||
from lerobot.robots.utils import make_robot_from_config
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
services_pb2_grpc, # type: ignore
|
||||
)
|
||||
from tests.mocks.mock_robot import MockRobotConfig
|
||||
|
||||
# Create a stub policy similar to test_policy_server.py
|
||||
class MockPolicy:
|
||||
"""A minimal mock for an actual policy, returning zeros."""
|
||||
|
||||
class _Config:
|
||||
robot_type = "dummy_robot"
|
||||
|
||||
@property
|
||||
def image_features(self):
|
||||
"""Empty image features since this test doesn't use images."""
|
||||
return {}
|
||||
|
||||
def __init__(self):
|
||||
self.config = self._Config()
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def model(self, batch):
|
||||
# Return a chunk of 20 dummy actions.
|
||||
batch_size = len(batch["robot_type"])
|
||||
return torch.zeros(batch_size, 20, 6)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 1. Create PolicyServer instance with mock policy
|
||||
# ------------------------------------------------------------------
|
||||
policy_server_config = PolicyServerConfig(host="localhost", port=9999)
|
||||
policy_server = PolicyServer(policy_server_config)
|
||||
# Replace the real policy with our fast, deterministic stub.
|
||||
policy_server.policy = MockPolicy()
|
||||
policy_server.actions_per_chunk = 20
|
||||
policy_server.device = "cpu"
|
||||
# NOTE(Steven): Smelly tests as the Server is a state machine being partially mocked. Adding these processors as a quick fix.
|
||||
policy_server.preprocessor = lambda obs: obs
|
||||
policy_server.postprocessor = lambda tensor: tensor
|
||||
|
||||
# Set up robot config and features
|
||||
robot_config = MockRobotConfig()
|
||||
mock_robot = make_robot_from_config(robot_config)
|
||||
|
||||
lerobot_features = map_robot_keys_to_lerobot_features(mock_robot)
|
||||
policy_server.lerobot_features = lerobot_features
|
||||
|
||||
# Force server to produce deterministic action chunks in test mode
|
||||
policy_server.policy_type = "act"
|
||||
|
||||
def _fake_get_action_chunk(_self, _obs, _type="test"):
|
||||
action_dim = 6
|
||||
batch_size = 1
|
||||
actions_per_chunk = policy_server.actions_per_chunk
|
||||
|
||||
return torch.zeros(batch_size, actions_per_chunk, action_dim)
|
||||
|
||||
monkeypatch.setattr(PolicyServer, "_get_action_chunk", _fake_get_action_chunk, raising=True)
|
||||
|
||||
# Bypass potentially heavy model loading inside SendPolicyInstructions
|
||||
def _fake_send_policy_instructions(self, request, context): # noqa: N802
|
||||
return services_pb2.Empty()
|
||||
|
||||
monkeypatch.setattr(PolicyServer, "SendPolicyInstructions", _fake_send_policy_instructions, raising=True)
|
||||
|
||||
# Build gRPC server running a PolicyServer
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="policy_server"))
|
||||
services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
|
||||
|
||||
# Use the host/port specified in the fixture's config
|
||||
server_address = f"{policy_server.config.host}:{policy_server.config.port}"
|
||||
server.add_insecure_port(server_address)
|
||||
server.start()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 2. Create a RobotClient around the MockRobot
|
||||
# ------------------------------------------------------------------
|
||||
client_config = RobotClientConfig(
|
||||
server_address=server_address,
|
||||
robot=robot_config,
|
||||
chunk_size_threshold=0.0,
|
||||
policy_type="test",
|
||||
pretrained_name_or_path="test",
|
||||
actions_per_chunk=20,
|
||||
)
|
||||
|
||||
client = RobotClient(client_config)
|
||||
assert client.start(), "Client failed initial handshake with the server"
|
||||
|
||||
# Track action chunks received and verify device type
|
||||
action_chunks_received = {"count": 0, "actions_on_cpu": True}
|
||||
original_aggregate = client._aggregate_action_queues
|
||||
|
||||
def counting_aggregate(*args, **kwargs):
|
||||
action_chunks_received["count"] += 1
|
||||
# Check that all received actions are on CPU
|
||||
if args:
|
||||
for timed_action in args[0]: # args[0] is the list of TimedAction
|
||||
action_tensor = timed_action.get_action()
|
||||
if action_tensor.device.type != "cpu":
|
||||
action_chunks_received["actions_on_cpu"] = False
|
||||
return original_aggregate(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(client, "_aggregate_action_queues", counting_aggregate)
|
||||
|
||||
# Start client threads
|
||||
action_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
control_thread = threading.Thread(target=client.control_loop, args=({"task": ""}), daemon=True)
|
||||
action_thread.start()
|
||||
control_thread.start()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 3. System exchanges a few messages
|
||||
# ------------------------------------------------------------------
|
||||
# Wait for 5 seconds
|
||||
server.wait_for_termination(timeout=5)
|
||||
|
||||
assert action_chunks_received["count"] > 0, "Client did not receive any action chunks"
|
||||
assert len(policy_server._predicted_timesteps) > 0, "Server did not record any predicted timesteps"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 4. Stop the system
|
||||
# ------------------------------------------------------------------
|
||||
client.stop()
|
||||
action_thread.join()
|
||||
control_thread.join()
|
||||
policy_server.stop()
|
||||
server.stop(grace=None)
|
||||
@@ -1,454 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import pickle
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("grpc")
|
||||
|
||||
import numpy as np # noqa: E402
|
||||
import torch # noqa: E402
|
||||
|
||||
from lerobot.async_inference.helpers import ( # noqa: E402
|
||||
FPSTracker,
|
||||
TimedAction,
|
||||
TimedObservation,
|
||||
observations_similar,
|
||||
prepare_image,
|
||||
prepare_raw_observation,
|
||||
raw_observation_to_observation,
|
||||
resize_robot_observation_image,
|
||||
)
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# FPSTracker
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_fps_tracker_first_observation():
|
||||
"""First observation should initialize timestamp and return 0 FPS."""
|
||||
tracker = FPSTracker(target_fps=30.0)
|
||||
timestamp = 1000.0
|
||||
|
||||
metrics = tracker.calculate_fps_metrics(timestamp)
|
||||
|
||||
assert tracker.first_timestamp == timestamp
|
||||
assert tracker.total_obs_count == 1
|
||||
assert metrics["avg_fps"] == 0.0
|
||||
assert metrics["target_fps"] == 30.0
|
||||
|
||||
|
||||
def test_fps_tracker_single_interval():
|
||||
"""Two observations 1 second apart should give 1 FPS."""
|
||||
tracker = FPSTracker(target_fps=30.0)
|
||||
|
||||
# First observation at t=0
|
||||
metrics1 = tracker.calculate_fps_metrics(0.0)
|
||||
assert metrics1["avg_fps"] == 0.0
|
||||
|
||||
# Second observation at t=1 (1 second later)
|
||||
metrics2 = tracker.calculate_fps_metrics(1.0)
|
||||
expected_fps = 1.0 # (2-1) observations / 1.0 seconds = 1 FPS
|
||||
assert math.isclose(metrics2["avg_fps"], expected_fps, rel_tol=1e-6)
|
||||
|
||||
|
||||
def test_fps_tracker_multiple_intervals():
|
||||
"""Multiple observations should calculate correct average FPS."""
|
||||
tracker = FPSTracker(target_fps=30.0)
|
||||
|
||||
# Simulate 5 observations over 2 seconds (should be 2 FPS average)
|
||||
timestamps = [0.0, 0.5, 1.0, 1.5, 2.0]
|
||||
|
||||
for i, ts in enumerate(timestamps):
|
||||
metrics = tracker.calculate_fps_metrics(ts)
|
||||
|
||||
if i == 0:
|
||||
assert metrics["avg_fps"] == 0.0
|
||||
elif i == len(timestamps) - 1:
|
||||
# After 5 observations over 2 seconds: (5-1)/2 = 2 FPS
|
||||
expected_fps = 2.0
|
||||
assert math.isclose(metrics["avg_fps"], expected_fps, rel_tol=1e-6)
|
||||
|
||||
|
||||
def test_fps_tracker_irregular_intervals():
|
||||
"""FPS calculation should work with irregular time intervals."""
|
||||
tracker = FPSTracker(target_fps=30.0)
|
||||
|
||||
# Irregular timestamps: 0, 0.1, 0.5, 2.0, 3.0 seconds
|
||||
timestamps = [0.0, 0.1, 0.5, 2.0, 3.0]
|
||||
|
||||
for ts in timestamps:
|
||||
metrics = tracker.calculate_fps_metrics(ts)
|
||||
|
||||
# 5 observations over 3 seconds: (5-1)/3 = 1.333... FPS
|
||||
expected_fps = 4.0 / 3.0
|
||||
assert math.isclose(metrics["avg_fps"], expected_fps, rel_tol=1e-6)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# TimedData helpers
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_timed_action_getters():
|
||||
"""TimedAction stores & returns timestamp, action tensor and timestep."""
|
||||
ts = time.time()
|
||||
action = torch.arange(10)
|
||||
ta = TimedAction(timestamp=ts, action=action, timestep=0)
|
||||
|
||||
assert math.isclose(ta.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
|
||||
torch.testing.assert_close(ta.get_action(), action)
|
||||
assert ta.get_timestep() == 0
|
||||
|
||||
|
||||
def test_timed_observation_getters():
|
||||
"""TimedObservation stores & returns timestamp, dict and timestep."""
|
||||
ts = time.time()
|
||||
obs_dict = {OBS_STATE: torch.ones(6)}
|
||||
to = TimedObservation(timestamp=ts, observation=obs_dict, timestep=0)
|
||||
|
||||
assert math.isclose(to.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
|
||||
assert to.get_observation() is obs_dict
|
||||
assert to.get_timestep() == 0
|
||||
|
||||
|
||||
def test_timed_data_deserialization_data_getters():
|
||||
"""TimedAction / TimedObservation survive a round-trip through ``pickle``.
|
||||
|
||||
The async-inference stack uses ``pickle.dumps`` to move these objects across
|
||||
the gRPC boundary (see RobotClient.send_observation and PolicyServer.StreamActions).
|
||||
This test ensures that the payload keeps its content intact after
|
||||
the (de)serialization round-trip.
|
||||
"""
|
||||
ts = time.time()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# TimedAction
|
||||
# ------------------------------------------------------------------
|
||||
original_action = torch.randn(6)
|
||||
ta_in = TimedAction(timestamp=ts, action=original_action, timestep=13)
|
||||
|
||||
# Serialize → bytes → deserialize
|
||||
ta_bytes = pickle.dumps(ta_in) # nosec
|
||||
ta_out: TimedAction = pickle.loads(ta_bytes) # nosec B301
|
||||
|
||||
# Identity & content checks
|
||||
assert math.isclose(ta_out.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
|
||||
assert ta_out.get_timestep() == 13
|
||||
torch.testing.assert_close(ta_out.get_action(), original_action)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# TimedObservation
|
||||
# ------------------------------------------------------------------
|
||||
obs_dict = {OBS_STATE: torch.arange(4).float()}
|
||||
to_in = TimedObservation(timestamp=ts, observation=obs_dict, timestep=7, must_go=True)
|
||||
|
||||
to_bytes = pickle.dumps(to_in) # nosec
|
||||
to_out: TimedObservation = pickle.loads(to_bytes) # nosec B301
|
||||
|
||||
assert math.isclose(to_out.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
|
||||
assert to_out.get_timestep() == 7
|
||||
assert to_out.must_go is True
|
||||
assert to_out.get_observation().keys() == obs_dict.keys()
|
||||
torch.testing.assert_close(to_out.get_observation()[OBS_STATE], obs_dict[OBS_STATE])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# observations_similar()
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_obs(state: torch.Tensor) -> TimedObservation:
|
||||
"""Create a TimedObservation with raw robot observation format."""
|
||||
return TimedObservation(
|
||||
timestamp=time.time(),
|
||||
observation={
|
||||
"shoulder": state[0].item() if len(state) > 0 else 0.0,
|
||||
"elbow": state[1].item() if len(state) > 1 else 0.0,
|
||||
"wrist": state[2].item() if len(state) > 2 else 0.0,
|
||||
"gripper": state[3].item() if len(state) > 3 else 0.0,
|
||||
},
|
||||
timestep=0,
|
||||
)
|
||||
|
||||
|
||||
def test_observations_similar_true():
|
||||
"""Distance below atol → observations considered similar."""
|
||||
# Create mock lerobot features for the similarity check
|
||||
lerobot_features = {
|
||||
OBS_STATE: {
|
||||
"dtype": "float32",
|
||||
"shape": [4],
|
||||
"names": ["shoulder", "elbow", "wrist", "gripper"],
|
||||
}
|
||||
}
|
||||
|
||||
obs1 = _make_obs(torch.zeros(4))
|
||||
obs2 = _make_obs(0.5 * torch.ones(4))
|
||||
assert observations_similar(obs1, obs2, lerobot_features, atol=2.0)
|
||||
|
||||
obs3 = _make_obs(2.0 * torch.ones(4))
|
||||
assert not observations_similar(obs1, obs3, lerobot_features, atol=2.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# raw_observation_to_observation and helpers
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
def _create_mock_robot_observation():
|
||||
"""Create a mock robot observation with motor positions and camera images."""
|
||||
return {
|
||||
"shoulder": 1.0,
|
||||
"elbow": 2.0,
|
||||
"wrist": 3.0,
|
||||
"gripper": 0.5,
|
||||
"laptop": np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8),
|
||||
"phone": np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8),
|
||||
}
|
||||
|
||||
|
||||
def _create_mock_lerobot_features():
|
||||
"""Create mock lerobot features mapping similar to what hw_to_dataset_features returns."""
|
||||
return {
|
||||
OBS_STATE: {
|
||||
"dtype": "float32",
|
||||
"shape": [4],
|
||||
"names": ["shoulder", "elbow", "wrist", "gripper"],
|
||||
},
|
||||
f"{OBS_IMAGES}.laptop": {
|
||||
"dtype": "image",
|
||||
"shape": [480, 640, 3],
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
f"{OBS_IMAGES}.phone": {
|
||||
"dtype": "image",
|
||||
"shape": [480, 640, 3],
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _create_mock_policy_image_features():
|
||||
"""Create mock policy image features with different resolutions."""
|
||||
return {
|
||||
f"{OBS_IMAGES}.laptop": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 224, 224), # Policy expects smaller resolution
|
||||
),
|
||||
f"{OBS_IMAGES}.phone": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 160, 160), # Different resolution for second camera
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def test_prepare_image():
|
||||
"""Test image preprocessing: int8 → float32, normalization to [0,1]."""
|
||||
# Create mock int8 image data
|
||||
image_int8 = torch.randint(0, 256, size=(3, 224, 224), dtype=torch.uint8)
|
||||
|
||||
processed = prepare_image(image_int8)
|
||||
|
||||
# Check dtype conversion
|
||||
assert processed.dtype == torch.float32
|
||||
|
||||
# Check normalization range
|
||||
assert processed.min() >= 0.0
|
||||
assert processed.max() <= 1.0
|
||||
|
||||
# Check that values are scaled correctly (255 → 1.0, 0 → 0.0)
|
||||
if image_int8.max() == 255:
|
||||
assert torch.isclose(processed.max(), torch.tensor(1.0), atol=1e-6)
|
||||
if image_int8.min() == 0:
|
||||
assert torch.isclose(processed.min(), torch.tensor(0.0), atol=1e-6)
|
||||
|
||||
# Check memory contiguity
|
||||
assert processed.is_contiguous()
|
||||
|
||||
|
||||
def test_resize_robot_observation_image():
|
||||
"""Test image resizing from robot resolution to policy resolution."""
|
||||
# Create mock image: (H=480, W=640, C=3)
|
||||
original_image = torch.randint(0, 256, size=(480, 640, 3), dtype=torch.uint8)
|
||||
target_shape = (3, 224, 224) # (C, H, W)
|
||||
|
||||
resized = resize_robot_observation_image(original_image, target_shape)
|
||||
|
||||
# Check output shape matches target
|
||||
assert resized.shape == target_shape
|
||||
|
||||
# Check that original image had different dimensions
|
||||
assert original_image.shape != resized.shape
|
||||
|
||||
# Check that resizing preserves value range
|
||||
assert resized.min() >= 0
|
||||
assert resized.max() <= 255
|
||||
|
||||
|
||||
def test_prepare_raw_observation():
|
||||
"""Test the preparation of raw robot observation to lerobot format."""
|
||||
robot_obs = _create_mock_robot_observation()
|
||||
lerobot_features = _create_mock_lerobot_features()
|
||||
policy_image_features = _create_mock_policy_image_features()
|
||||
|
||||
prepared = prepare_raw_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
|
||||
# Check that state is properly extracted and batched
|
||||
assert OBS_STATE in prepared
|
||||
state = prepared[OBS_STATE]
|
||||
assert isinstance(state, torch.Tensor)
|
||||
assert state.shape == (1, 4) # Batched state
|
||||
|
||||
# Check that images are processed and resized
|
||||
assert f"{OBS_IMAGES}.laptop" in prepared
|
||||
assert f"{OBS_IMAGES}.phone" in prepared
|
||||
|
||||
laptop_img = prepared[f"{OBS_IMAGES}.laptop"]
|
||||
phone_img = prepared[f"{OBS_IMAGES}.phone"]
|
||||
|
||||
# Check image shapes match policy requirements
|
||||
assert laptop_img.shape == policy_image_features[f"{OBS_IMAGES}.laptop"].shape
|
||||
assert phone_img.shape == policy_image_features[f"{OBS_IMAGES}.phone"].shape
|
||||
|
||||
# Check that images are tensors
|
||||
assert isinstance(laptop_img, torch.Tensor)
|
||||
assert isinstance(phone_img, torch.Tensor)
|
||||
|
||||
|
||||
def test_raw_observation_to_observation_basic():
|
||||
"""Test the main raw_observation_to_observation function."""
|
||||
robot_obs = _create_mock_robot_observation()
|
||||
lerobot_features = _create_mock_lerobot_features()
|
||||
policy_image_features = _create_mock_policy_image_features()
|
||||
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
|
||||
# Check that all expected keys are present
|
||||
assert OBS_STATE in observation
|
||||
assert f"{OBS_IMAGES}.laptop" in observation
|
||||
assert f"{OBS_IMAGES}.phone" in observation
|
||||
|
||||
# Check state processing
|
||||
state = observation[OBS_STATE]
|
||||
assert isinstance(state, torch.Tensor)
|
||||
assert state.shape == (1, 4) # Batched
|
||||
|
||||
# Check image processing
|
||||
laptop_img = observation[f"{OBS_IMAGES}.laptop"]
|
||||
phone_img = observation[f"{OBS_IMAGES}.phone"]
|
||||
|
||||
# Images should have batch dimension: (B, C, H, W)
|
||||
assert laptop_img.shape == (1, 3, 224, 224)
|
||||
assert phone_img.shape == (1, 3, 160, 160)
|
||||
|
||||
# Check image dtype and range (should be float32 in [0, 1])
|
||||
assert laptop_img.dtype == torch.float32
|
||||
assert phone_img.dtype == torch.float32
|
||||
assert laptop_img.min() >= 0.0 and laptop_img.max() <= 1.0
|
||||
assert phone_img.min() >= 0.0 and phone_img.max() <= 1.0
|
||||
|
||||
|
||||
def test_raw_observation_to_observation_with_non_tensor_data():
|
||||
"""Test that non-tensor data (like task strings) is preserved."""
|
||||
robot_obs = _create_mock_robot_observation()
|
||||
robot_obs["task"] = "pick up the red cube" # Add string instruction
|
||||
|
||||
lerobot_features = _create_mock_lerobot_features()
|
||||
policy_image_features = _create_mock_policy_image_features()
|
||||
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
|
||||
# Check that task string is preserved
|
||||
assert "task" in observation
|
||||
assert observation["task"] == "pick up the red cube"
|
||||
assert isinstance(observation["task"], str)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_raw_observation_to_observation_device_handling():
|
||||
"""Test that tensors are created (device placement is handled by preprocessor)."""
|
||||
robot_obs = _create_mock_robot_observation()
|
||||
lerobot_features = _create_mock_lerobot_features()
|
||||
policy_image_features = _create_mock_policy_image_features()
|
||||
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
|
||||
# Check that all expected keys produce tensors (device placement handled by preprocessor later)
|
||||
for key, value in observation.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
assert value.device.type in ["cpu", "cuda", "mps", "xpu"], f"Tensor {key} on unexpected device"
|
||||
|
||||
|
||||
def test_raw_observation_to_observation_deterministic():
|
||||
"""Test that the function produces consistent results for the same input."""
|
||||
robot_obs = _create_mock_robot_observation()
|
||||
lerobot_features = _create_mock_lerobot_features()
|
||||
policy_image_features = _create_mock_policy_image_features()
|
||||
|
||||
# Run twice with same input
|
||||
obs1 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
obs2 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
|
||||
# Results should be identical
|
||||
assert set(obs1.keys()) == set(obs2.keys())
|
||||
|
||||
for key in obs1:
|
||||
if isinstance(obs1[key], torch.Tensor):
|
||||
torch.testing.assert_close(obs1[key], obs2[key])
|
||||
else:
|
||||
assert obs1[key] == obs2[key]
|
||||
|
||||
|
||||
def test_image_processing_pipeline_preserves_content():
|
||||
"""Test that the image processing pipeline preserves recognizable patterns."""
|
||||
# Create an image with a specific pattern
|
||||
original_img = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
original_img[25:75, 25:75, :] = 255 # White square in center
|
||||
|
||||
robot_obs = {"shoulder": 1.0, "elbow": 1.0, "wrist": 1.0, "gripper": 1.0, "laptop": original_img}
|
||||
lerobot_features = {
|
||||
OBS_STATE: {
|
||||
"dtype": "float32",
|
||||
"shape": [4],
|
||||
"names": ["shoulder", "elbow", "wrist", "gripper"],
|
||||
},
|
||||
f"{OBS_IMAGES}.laptop": {
|
||||
"dtype": "image",
|
||||
"shape": [100, 100, 3],
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
}
|
||||
policy_image_features = {
|
||||
f"{OBS_IMAGES}.laptop": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 50, 50), # Downsamples from 100x100
|
||||
)
|
||||
}
|
||||
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
|
||||
processed_img = observation[f"{OBS_IMAGES}.laptop"].squeeze(0) # Remove batch dim
|
||||
|
||||
# Check that the center region has higher values than corners
|
||||
# Due to bilinear interpolation, exact values will change but pattern should remain
|
||||
center_val = processed_img[:, 25, 25].mean() # Center of 50x50 image
|
||||
corner_val = processed_img[:, 5, 5].mean() # Corner
|
||||
|
||||
assert center_val > corner_val, "Image processing should preserve recognizable patterns"
|
||||
@@ -1,219 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit-tests for the `PolicyServer` core logic.
|
||||
Monkey-patch the `policy` attribute with a stub so that no real model inference is performed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
from tests.utils import skip_if_package_missing
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Test fixtures
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockPolicy:
|
||||
"""A minimal mock for an actual policy, returning zeros.
|
||||
Refer to tests/policies for tests of the individual policies supported."""
|
||||
|
||||
class _Config:
|
||||
robot_type = "dummy_robot"
|
||||
|
||||
@property
|
||||
def image_features(self) -> dict[str, PolicyFeature]:
|
||||
"""Empty image features since this test doesn't use images."""
|
||||
return {}
|
||||
|
||||
def predict_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Return a chunk of 20 dummy actions."""
|
||||
batch_size = len(observation[OBS_STATE])
|
||||
return torch.zeros(batch_size, 20, 6)
|
||||
|
||||
def __init__(self):
|
||||
self.config = self._Config()
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
# The server calls `policy.to(device)`. This stub ignores it.
|
||||
return self
|
||||
|
||||
def model(self, batch: dict) -> torch.Tensor:
|
||||
# Return a chunk of 20 dummy actions.
|
||||
batch_size = len(batch["robot_type"])
|
||||
return torch.zeros(batch_size, 20, 6)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@skip_if_package_missing("grpcio", "grpc")
|
||||
def policy_server():
|
||||
"""Fresh `PolicyServer` instance with a stubbed-out policy model."""
|
||||
# Import only when the test actually runs (after decorator check)
|
||||
from lerobot.async_inference.configs import PolicyServerConfig
|
||||
from lerobot.async_inference.policy_server import PolicyServer
|
||||
|
||||
test_config = PolicyServerConfig(host="localhost", port=9999)
|
||||
server = PolicyServer(test_config)
|
||||
# Replace the real policy with our fast, deterministic stub.
|
||||
server.policy = MockPolicy()
|
||||
server.actions_per_chunk = 20
|
||||
server.device = "cpu"
|
||||
|
||||
# Add mock lerobot_features that the observation similarity functions need
|
||||
server.lerobot_features = {
|
||||
OBS_STATE: {
|
||||
"dtype": "float32",
|
||||
"shape": [6],
|
||||
"names": ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"],
|
||||
}
|
||||
}
|
||||
|
||||
return server
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Helper utilities for tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_obs(state: torch.Tensor, timestep: int = 0, must_go: bool = False):
|
||||
"""Create a TimedObservation with a given state vector."""
|
||||
# Import only when needed
|
||||
from lerobot.async_inference.helpers import TimedObservation
|
||||
|
||||
return TimedObservation(
|
||||
observation={
|
||||
"joint1": state[0].item() if len(state) > 0 else 0.0,
|
||||
"joint2": state[1].item() if len(state) > 1 else 0.0,
|
||||
"joint3": state[2].item() if len(state) > 2 else 0.0,
|
||||
"joint4": state[3].item() if len(state) > 3 else 0.0,
|
||||
"joint5": state[4].item() if len(state) > 4 else 0.0,
|
||||
"joint6": state[5].item() if len(state) > 5 else 0.0,
|
||||
},
|
||||
timestamp=time.time(),
|
||||
timestep=timestep,
|
||||
must_go=must_go,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_time_action_chunk(policy_server):
|
||||
"""Verify that `_time_action_chunk` assigns correct timestamps and timesteps."""
|
||||
start_ts = time.time()
|
||||
start_t = 10
|
||||
# A chunk of 3 action tensors.
|
||||
action_tensors = [torch.randn(6) for _ in range(3)]
|
||||
|
||||
timed_actions = policy_server._time_action_chunk(start_ts, action_tensors, start_t)
|
||||
|
||||
assert len(timed_actions) == 3
|
||||
# Check timesteps
|
||||
assert [ta.get_timestep() for ta in timed_actions] == [10, 11, 12]
|
||||
# Check timestamps
|
||||
expected_timestamps = [
|
||||
start_ts,
|
||||
start_ts + policy_server.config.environment_dt,
|
||||
start_ts + 2 * policy_server.config.environment_dt,
|
||||
]
|
||||
for ta, expected_ts in zip(timed_actions, expected_timestamps, strict=True):
|
||||
assert abs(ta.get_timestamp() - expected_ts) < 1e-6
|
||||
|
||||
|
||||
def test_maybe_enqueue_observation_must_go(policy_server):
|
||||
"""An observation with `must_go=True` is always enqueued."""
|
||||
obs = _make_obs(torch.zeros(6), must_go=True)
|
||||
assert policy_server._enqueue_observation(obs) is True
|
||||
assert policy_server.observation_queue.qsize() == 1
|
||||
assert policy_server.observation_queue.get_nowait() is obs
|
||||
|
||||
|
||||
def test_maybe_enqueue_observation_dissimilar(policy_server):
|
||||
"""A dissimilar observation (not `must_go`) is enqueued."""
|
||||
# Set a last predicted observation.
|
||||
policy_server.last_processed_obs = _make_obs(torch.zeros(6))
|
||||
# Create a new, dissimilar observation.
|
||||
new_obs = _make_obs(torch.ones(6) * 5) # High norm difference
|
||||
|
||||
assert policy_server._enqueue_observation(new_obs) is True
|
||||
assert policy_server.observation_queue.qsize() == 1
|
||||
|
||||
|
||||
def test_maybe_enqueue_observation_is_skipped(policy_server):
|
||||
"""A similar observation (not `must_go`) is skipped."""
|
||||
# Set a last predicted observation.
|
||||
policy_server.last_processed_obs = _make_obs(torch.zeros(6))
|
||||
# Create a new, very similar observation.
|
||||
new_obs = _make_obs(torch.zeros(6) + 1e-4)
|
||||
|
||||
assert policy_server._enqueue_observation(new_obs) is False
|
||||
assert policy_server.observation_queue.empty() is True
|
||||
|
||||
|
||||
def test_obs_sanity_checks(policy_server):
|
||||
"""Unit-test the private `_obs_sanity_checks` helper."""
|
||||
prev = _make_obs(torch.zeros(6), timestep=0)
|
||||
|
||||
# Case 1 – timestep already predicted
|
||||
policy_server._predicted_timesteps.add(1)
|
||||
obs_same_ts = _make_obs(torch.ones(6), timestep=1)
|
||||
assert policy_server._obs_sanity_checks(obs_same_ts, prev) is False
|
||||
|
||||
# Case 2 – observation too similar
|
||||
policy_server._predicted_timesteps.clear()
|
||||
obs_similar = _make_obs(torch.zeros(6) + 1e-4, timestep=2)
|
||||
assert policy_server._obs_sanity_checks(obs_similar, prev) is False
|
||||
|
||||
# Case 3 – genuinely new & dissimilar observation passes
|
||||
obs_ok = _make_obs(torch.ones(6) * 5, timestep=3)
|
||||
assert policy_server._obs_sanity_checks(obs_ok, prev) is True
|
||||
|
||||
|
||||
def test_predict_action_chunk(monkeypatch, policy_server):
|
||||
"""End-to-end test of `_predict_action_chunk` with a stubbed _get_action_chunk."""
|
||||
# Import only when needed
|
||||
from lerobot.async_inference.policy_server import PolicyServer
|
||||
|
||||
# Force server to act-style policy; patch method to return deterministic tensor
|
||||
policy_server.policy_type = "act"
|
||||
# NOTE(Steven): Smelly tests as the Server is a state machine being partially mocked. Adding these processors as a quick fix.
|
||||
policy_server.preprocessor = lambda obs: obs
|
||||
policy_server.postprocessor = lambda tensor: tensor
|
||||
action_dim = 6
|
||||
batch_size = 1
|
||||
actions_per_chunk = policy_server.actions_per_chunk
|
||||
|
||||
def _fake_get_action_chunk(_self, _obs, _type="act"):
|
||||
return torch.zeros(batch_size, actions_per_chunk, action_dim)
|
||||
|
||||
monkeypatch.setattr(PolicyServer, "_get_action_chunk", _fake_get_action_chunk, raising=True)
|
||||
|
||||
obs = _make_obs(torch.zeros(6), timestep=5)
|
||||
timed_actions = policy_server._predict_action_chunk(obs)
|
||||
|
||||
assert len(timed_actions) == actions_per_chunk
|
||||
assert [ta.get_timestep() for ta in timed_actions] == list(range(5, 5 + actions_per_chunk))
|
||||
|
||||
for i, ta in enumerate(timed_actions):
|
||||
expected_ts = obs.get_timestamp() + i * policy_server.config.environment_dt
|
||||
assert abs(ta.get_timestamp() - expected_ts) < 1e-6
|
||||
@@ -1,271 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit-tests for the `RobotClient` action-queue logic (pure Python, no gRPC).
|
||||
|
||||
We monkey-patch `lerobot.robots.utils.make_robot_from_config` so that
|
||||
no real hardware is accessed. Only the queue-update mechanism is verified.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from queue import Queue
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip entire module if required deps are not available
|
||||
pytest.importorskip("grpc")
|
||||
pytest.importorskip("serial", reason="pyserial is required (install lerobot[hardware])")
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Test fixtures
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def robot_client():
|
||||
"""Fresh `RobotClient` instance for each test case (no threads started).
|
||||
Uses DummyRobot."""
|
||||
# Import only when the test actually runs (after decorator check)
|
||||
from lerobot.async_inference.configs import RobotClientConfig
|
||||
from lerobot.async_inference.robot_client import RobotClient
|
||||
from tests.mocks.mock_robot import MockRobotConfig
|
||||
|
||||
test_config = MockRobotConfig()
|
||||
|
||||
# gRPC channel is not actually used in tests, so using a dummy address
|
||||
test_config = RobotClientConfig(
|
||||
robot=test_config,
|
||||
server_address="localhost:9999",
|
||||
policy_type="test",
|
||||
pretrained_name_or_path="test",
|
||||
actions_per_chunk=20,
|
||||
)
|
||||
|
||||
client = RobotClient(test_config)
|
||||
|
||||
# Initialize attributes that are normally set in start() method
|
||||
client.chunks_received = 0
|
||||
client.available_actions_size = []
|
||||
|
||||
yield client
|
||||
|
||||
if client.robot.is_connected:
|
||||
client.stop()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Helper utilities for tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_actions(start_ts: float, start_t: int, count: int):
|
||||
"""Generate `count` consecutive TimedAction objects starting at timestep `start_t`."""
|
||||
from lerobot.async_inference.helpers import TimedAction
|
||||
|
||||
fps = 30 # emulates most common frame-rate
|
||||
actions = []
|
||||
for i in range(count):
|
||||
timestep = start_t + i
|
||||
timestamp = start_ts + i * (1 / fps)
|
||||
action_tensor = torch.full((6,), timestep, dtype=torch.float32)
|
||||
actions.append(TimedAction(action=action_tensor, timestep=timestep, timestamp=timestamp))
|
||||
return actions
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_update_action_queue_discards_stale(robot_client):
|
||||
"""`_update_action_queue` must drop actions with `timestep` <= `latest_action`."""
|
||||
|
||||
# Pretend we already executed up to action #4
|
||||
robot_client.latest_action = 4
|
||||
|
||||
# Incoming chunk contains timesteps 3..7 -> expect 5,6,7 kept.
|
||||
incoming = _make_actions(start_ts=time.time(), start_t=3, count=5) # 3,4,5,6,7
|
||||
|
||||
robot_client._aggregate_action_queues(incoming)
|
||||
|
||||
# Extract timesteps from queue
|
||||
resulting_timesteps = [a.get_timestep() for a in robot_client.action_queue.queue]
|
||||
|
||||
assert resulting_timesteps == [5, 6, 7]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"weight_old, weight_new",
|
||||
[
|
||||
(1.0, 0.0),
|
||||
(0.0, 1.0),
|
||||
(0.5, 0.5),
|
||||
(0.2, 0.8),
|
||||
(0.8, 0.2),
|
||||
(0.1, 0.9),
|
||||
(0.9, 0.1),
|
||||
],
|
||||
)
|
||||
def test_aggregate_action_queues_combines_actions_in_overlap(
|
||||
robot_client, weight_old: float, weight_new: float
|
||||
):
|
||||
"""`_aggregate_action_queues` must combine actions on overlapping timesteps according
|
||||
to the provided aggregate_fn, here tested with multiple coefficients."""
|
||||
from lerobot.async_inference.helpers import TimedAction
|
||||
|
||||
robot_client.chunks_received = 0
|
||||
|
||||
# Pretend we already executed up to action #4, and queue contains actions for timesteps 5..6
|
||||
robot_client.latest_action = 4
|
||||
current_actions = _make_actions(
|
||||
start_ts=time.time(), start_t=5, count=2
|
||||
) # actions are [torch.ones(6), torch.ones(6), ...]
|
||||
current_actions = [
|
||||
TimedAction(action=10 * a.get_action(), timestep=a.get_timestep(), timestamp=a.get_timestamp())
|
||||
for a in current_actions
|
||||
]
|
||||
|
||||
for a in current_actions:
|
||||
robot_client.action_queue.put(a)
|
||||
|
||||
# Incoming chunk contains timesteps 3..7 -> expect 5,6,7 kept.
|
||||
incoming = _make_actions(start_ts=time.time(), start_t=3, count=5) # 3,4,5,6,7
|
||||
|
||||
overlap_timesteps = [5, 6] # properly tested in test_aggregate_action_queues_discards_stale
|
||||
nonoverlap_timesteps = [7]
|
||||
|
||||
robot_client._aggregate_action_queues(
|
||||
incoming, aggregate_fn=lambda x1, x2: weight_old * x1 + weight_new * x2
|
||||
)
|
||||
|
||||
queue_overlap_actions = []
|
||||
queue_non_overlap_actions = []
|
||||
for a in robot_client.action_queue.queue:
|
||||
if a.get_timestep() in overlap_timesteps:
|
||||
queue_overlap_actions.append(a)
|
||||
elif a.get_timestep() in nonoverlap_timesteps:
|
||||
queue_non_overlap_actions.append(a)
|
||||
|
||||
queue_overlap_actions = sorted(queue_overlap_actions, key=lambda x: x.get_timestep())
|
||||
queue_non_overlap_actions = sorted(queue_non_overlap_actions, key=lambda x: x.get_timestep())
|
||||
|
||||
assert torch.allclose(
|
||||
queue_overlap_actions[0].get_action(),
|
||||
weight_old * current_actions[0].get_action() + weight_new * incoming[-3].get_action(),
|
||||
)
|
||||
assert torch.allclose(
|
||||
queue_overlap_actions[1].get_action(),
|
||||
weight_old * current_actions[1].get_action() + weight_new * incoming[-2].get_action(),
|
||||
)
|
||||
assert torch.allclose(queue_non_overlap_actions[0].get_action(), incoming[-1].get_action())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"chunk_size, queue_len, expected",
|
||||
[
|
||||
(20, 12, False), # 12 / 20 = 0.6 > g=0.5 threshold, not ready to send
|
||||
(20, 8, True), # 8 / 20 = 0.4 <= g=0.5, ready to send
|
||||
(10, 5, True),
|
||||
(10, 6, False),
|
||||
],
|
||||
)
|
||||
def test_ready_to_send_observation(robot_client, chunk_size: int, queue_len: int, expected: bool):
|
||||
"""Validate `_ready_to_send_observation` ratio logic for various sizes."""
|
||||
|
||||
robot_client.action_chunk_size = chunk_size
|
||||
|
||||
# Clear any existing actions then fill with `queue_len` dummy entries ----
|
||||
robot_client.action_queue = Queue()
|
||||
|
||||
dummy_actions = _make_actions(start_ts=time.time(), start_t=0, count=queue_len)
|
||||
for act in dummy_actions:
|
||||
robot_client.action_queue.put(act)
|
||||
|
||||
assert robot_client._ready_to_send_observation() is expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"g_threshold, expected",
|
||||
[
|
||||
# The condition is `queue_size / chunk_size <= g`.
|
||||
# Here, ratio = 6 / 10 = 0.6.
|
||||
(0.0, False), # 0.6 <= 0.0 is False
|
||||
(0.1, False),
|
||||
(0.2, False),
|
||||
(0.3, False),
|
||||
(0.4, False),
|
||||
(0.5, False),
|
||||
(0.6, True), # 0.6 <= 0.6 is True
|
||||
(0.7, True),
|
||||
(0.8, True),
|
||||
(0.9, True),
|
||||
(1.0, True),
|
||||
],
|
||||
)
|
||||
def test_ready_to_send_observation_with_varying_threshold(robot_client, g_threshold: float, expected: bool):
|
||||
"""Validate `_ready_to_send_observation` with fixed sizes and varying `g`."""
|
||||
# Fixed sizes for this test: ratio = 6 / 10 = 0.6
|
||||
chunk_size = 10
|
||||
queue_len = 6
|
||||
|
||||
robot_client.action_chunk_size = chunk_size
|
||||
# This is the parameter we are testing
|
||||
robot_client._chunk_size_threshold = g_threshold
|
||||
|
||||
# Fill queue with dummy actions
|
||||
robot_client.action_queue = Queue()
|
||||
dummy_actions = _make_actions(start_ts=time.time(), start_t=0, count=queue_len)
|
||||
for act in dummy_actions:
|
||||
robot_client.action_queue.put(act)
|
||||
|
||||
assert robot_client._ready_to_send_observation() is expected
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Regression test: robot type registry populated by robot_client imports
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_robot_client_registers_builtin_robot_types():
|
||||
"""Importing robot_client must populate RobotConfig's ChoiceRegistry.
|
||||
|
||||
This is a regression test for a bug introduced in #2425, where removing
|
||||
robot module imports from robot_client.py caused RobotConfig's registry to
|
||||
be empty, breaking CLI argument parsing with:
|
||||
error: argument --robot.type: invalid choice: 'so101_follower' (choose from )
|
||||
|
||||
Robot types are registered via @RobotConfig.register_subclass() decorators
|
||||
at import time, so all supported modules must be explicitly imported.
|
||||
"""
|
||||
import lerobot.async_inference.robot_client # noqa: F401
|
||||
from lerobot.robots.config import RobotConfig
|
||||
|
||||
known_choices = RobotConfig.get_known_choices()
|
||||
|
||||
expected_robot_types = [
|
||||
"so100_follower",
|
||||
"so101_follower",
|
||||
"koch_follower",
|
||||
"omx_follower",
|
||||
"bi_so_follower",
|
||||
]
|
||||
for robot_type in expected_robot_types:
|
||||
assert robot_type in known_choices, (
|
||||
f"Robot type '{robot_type}' is not registered in RobotConfig's ChoiceRegistry. "
|
||||
f"Ensure the corresponding module is imported in robot_client.py. "
|
||||
f"Known choices: {sorted(known_choices)}"
|
||||
)
|
||||
@@ -289,52 +289,6 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
||||
assert_dataset_iteration_works(aggr_ds)
|
||||
|
||||
|
||||
def test_aggregate_datasets_without_concatenation(tmp_path, lerobot_dataset_factory):
|
||||
"""With concatenation disabled, each source file is kept as its own destination file."""
|
||||
ds_0 = lerobot_dataset_factory(
|
||||
root=tmp_path / "no_stitch_0",
|
||||
repo_id=f"{DUMMY_REPO_ID}_no_stitch_0",
|
||||
total_episodes=3,
|
||||
total_frames=60,
|
||||
)
|
||||
ds_1 = lerobot_dataset_factory(
|
||||
root=tmp_path / "no_stitch_1",
|
||||
repo_id=f"{DUMMY_REPO_ID}_no_stitch_1",
|
||||
total_episodes=4,
|
||||
total_frames=80,
|
||||
)
|
||||
|
||||
aggr_root = tmp_path / "no_stitch_aggr"
|
||||
aggregate_datasets(
|
||||
repo_ids=[ds_0.repo_id, ds_1.repo_id],
|
||||
roots=[ds_0.root, ds_1.root],
|
||||
aggr_repo_id=f"{DUMMY_REPO_ID}_no_stitch_aggr",
|
||||
aggr_root=aggr_root,
|
||||
concatenate_videos=False,
|
||||
concatenate_data=False,
|
||||
)
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(aggr_root)
|
||||
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_no_stitch_aggr", root=aggr_root)
|
||||
|
||||
assert_episode_and_frame_counts(
|
||||
aggr_ds, ds_0.num_episodes + ds_1.num_episodes, ds_0.num_frames + ds_1.num_frames
|
||||
)
|
||||
assert_dataset_iteration_works(aggr_ds)
|
||||
assert_video_timestamps_within_bounds(aggr_ds)
|
||||
|
||||
# Two single-file sources stay as two files each, instead of being packed together.
|
||||
assert len(list((aggr_root / "data").rglob("*.parquet"))) == 2
|
||||
assert aggr_ds.meta.video_keys, "Test fixture should produce at least one video feature"
|
||||
for key in aggr_ds.meta.video_keys:
|
||||
assert len(list((aggr_root / "videos" / key).rglob("*.mp4"))) == 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mutation", ["mismatched_value", "missing_key"])
|
||||
def test_aggregate_incomplete_video_encoder_info_warns_and_nuls_encoders(
|
||||
tmp_path, lerobot_dataset_factory, caplog, mutation
|
||||
|
||||
@@ -83,29 +83,6 @@ def test_get_feature_stats_images():
|
||||
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
|
||||
|
||||
|
||||
def test_get_feature_stats_uint8_images_preserves_std():
|
||||
data = np.array(
|
||||
[
|
||||
[
|
||||
[[0, 64], [128, 255]],
|
||||
[[255, 128], [64, 0]],
|
||||
[[32, 96], [160, 224]],
|
||||
],
|
||||
[
|
||||
[[16, 80], [144, 240]],
|
||||
[[240, 144], [80, 16]],
|
||||
[[48, 112], [176, 208]],
|
||||
],
|
||||
],
|
||||
dtype=np.uint8,
|
||||
)
|
||||
|
||||
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
|
||||
|
||||
expected_std = data.transpose(0, 2, 3, 1).reshape(-1, 3).std(axis=0).reshape(1, 3, 1, 1)
|
||||
np.testing.assert_allclose(stats["std"], expected_std)
|
||||
|
||||
|
||||
def test_get_feature_stats_axis_0_keepdims(sample_array):
|
||||
expected = {
|
||||
"min": np.array([[1, 2, 3]]),
|
||||
|
||||
@@ -114,17 +114,28 @@ def test_shuffle():
|
||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
||||
|
||||
|
||||
def test_shuffle_is_reproducible_across_instances():
|
||||
# The order is a pure function of (seed, epoch), so two fresh samplers (e.g. two ranks)
|
||||
# produce the same permutation without any generator synchronization.
|
||||
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
|
||||
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
|
||||
epoch_0 = list(sampler_a)
|
||||
assert list(sampler_b) == epoch_0
|
||||
def test_shuffle_with_generator_is_deterministic():
|
||||
# Two samplers shuffling with same-seed generators must yield identical permutations.
|
||||
# This is what keeps batch shards disjoint across ranks in distributed training, where
|
||||
# accelerate synchronizes the sampler's generator state instead of the global torch RNG.
|
||||
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
|
||||
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
|
||||
assert list(sampler_a) == list(sampler_b)
|
||||
|
||||
# Desyncing the global RNG must not affect the permutation.
|
||||
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
|
||||
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
|
||||
order_before = list(sampler_c)
|
||||
sampler_c.generator.manual_seed(42)
|
||||
torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would
|
||||
assert list(sampler_c) == epoch_0
|
||||
assert list(sampler_c) == order_before
|
||||
|
||||
|
||||
def test_generator_attribute_defaults_to_none():
|
||||
# accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`,
|
||||
# so the attribute must exist even when no generator is passed.
|
||||
sampler = EpisodeAwareSampler([0], [6], shuffle=True)
|
||||
assert sampler.generator is None
|
||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
||||
|
||||
|
||||
def test_negative_drop_first_frames_raises():
|
||||
@@ -150,87 +161,3 @@ def test_partial_episode_drop_warns(caplog):
|
||||
# Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5
|
||||
assert sampler.indices == [2, 3, 4, 5]
|
||||
assert "Episode 0" in caplog.text
|
||||
|
||||
|
||||
# --- seeded (seed, epoch) shuffling, resume, and state ---
|
||||
|
||||
from lerobot.datasets.sampler import compute_sampler_state # noqa: E402
|
||||
|
||||
EPISODE_BOUNDS = ([0, 2, 3], [2, 3, 6]) # episodes of 2, 1 and 3 frames
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_frames", [1, 2, 3, 37, 64, 100])
|
||||
def test_deterministic_sampler_shuffle_is_permutation(num_frames):
|
||||
for seed in (0, 1, 1234):
|
||||
sampler = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=seed)
|
||||
assert sorted(sampler) == list(range(num_frames))
|
||||
|
||||
|
||||
def test_deterministic_sampler_epochs_reproduce_and_differ():
|
||||
sampler_a = EpisodeAwareSampler([0], [100], shuffle=True, seed=42)
|
||||
sampler_b = EpisodeAwareSampler([0], [100], shuffle=True, seed=42)
|
||||
epoch_0 = list(sampler_a)
|
||||
assert list(sampler_b) == epoch_0 # same (seed, epoch) -> same order on any process
|
||||
epoch_1 = list(sampler_a) # __iter__ auto-advances the epoch
|
||||
assert epoch_1 != epoch_0
|
||||
assert sorted(epoch_1) == sorted(epoch_0)
|
||||
sampler_a.set_epoch(0)
|
||||
assert list(sampler_a) == epoch_0
|
||||
assert list(EpisodeAwareSampler([0], [100], shuffle=True, seed=7)) != epoch_0
|
||||
|
||||
|
||||
def test_deterministic_sampler_resume_mid_epoch():
|
||||
reference = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
|
||||
epoch_0 = list(reference)
|
||||
epoch_1 = list(reference)
|
||||
for start in (0, 1, 4, len(epoch_0)):
|
||||
resumed = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
|
||||
resumed.load_state_dict({"epoch": 0, "start_index": start})
|
||||
assert list(resumed) == epoch_0[start:]
|
||||
# the resumed sampler continues into the same epoch 1 as the uninterrupted one
|
||||
assert list(resumed) == epoch_1
|
||||
|
||||
|
||||
def test_deterministic_sampler_construction_stores_only_boundaries():
|
||||
# Construction is O(num_episodes), not O(num_frames): a million-frame single episode
|
||||
# instantiates from just its boundaries without materializing a per-frame index list.
|
||||
num_frames = 1_000_000
|
||||
sampler = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
|
||||
assert len(sampler) == num_frames
|
||||
assert sampler._starts.shape == (1,) and sampler._cum_lengths.shape == (1,)
|
||||
|
||||
|
||||
def test_deterministic_sampler_resume_is_exact_at_scale():
|
||||
# Seeded randperm makes resume sample-exact at non-trivial sizes: regenerating the epoch's
|
||||
# permutation and slicing from the saved offset reproduces the remaining order exactly.
|
||||
num_frames = 100_000
|
||||
reference = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
|
||||
epoch_0 = list(reference)
|
||||
assert sorted(epoch_0) == list(range(num_frames))
|
||||
start = num_frames - 5
|
||||
resumed = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
|
||||
resumed.load_state_dict({"epoch": 0, "start_index": start})
|
||||
assert list(resumed) == epoch_0[start:]
|
||||
|
||||
|
||||
def test_compute_sampler_state():
|
||||
# 100 frames, batch 10, 2 ranks -> 10 underlying batches, 5 per rank per epoch.
|
||||
assert compute_sampler_state(step=0, num_frames=100, batch_size=10, num_processes=2) == {
|
||||
"epoch": 0,
|
||||
"start_index": 0,
|
||||
}
|
||||
# step 7 -> epoch 1, 2 per-rank batches in = 2 * 10 * 2 = 40 samples in
|
||||
assert compute_sampler_state(step=7, num_frames=100, batch_size=10, num_processes=2) == {
|
||||
"epoch": 1,
|
||||
"start_index": 40,
|
||||
}
|
||||
# uneven epoch: 95 frames -> 10 underlying batches (last short), still 5 per rank
|
||||
assert compute_sampler_state(step=12, num_frames=95, batch_size=10, num_processes=2) == {
|
||||
"epoch": 2,
|
||||
"start_index": 40,
|
||||
}
|
||||
# uneven sharding: 105 frames -> 11 underlying batches, 6 per rank (even_batches pads)
|
||||
assert compute_sampler_state(step=11, num_frames=105, batch_size=10, num_processes=2) == {
|
||||
"epoch": 1,
|
||||
"start_index": 100,
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user