mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-16 15:57:03 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| bec7d668a6 |
@@ -0,0 +1,11 @@
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
cooldown:
|
||||
default-days: 7
|
||||
groups:
|
||||
actions:
|
||||
patterns: ["*"]
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 445 KiB |
@@ -58,7 +58,7 @@ action = model.select_action(obs)
|
||||
robot.send_action(action)
|
||||
```
|
||||
|
||||
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1, reBot B601.
|
||||
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1.
|
||||
|
||||
While these devices are natively integrated into the LeRobot codebase, the library is designed to be extensible. You can easily implement the Robot interface to utilize LeRobot's data collection, training, and visualization tools for your own custom robot.
|
||||
|
||||
@@ -101,13 +101,11 @@ lerobot-train \
|
||||
--dataset.repo_id=lerobot/aloha_mobile_cabinet
|
||||
```
|
||||
|
||||
| Category | Models |
|
||||
| -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
|
||||
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
||||
| **VLAs Models** | [Pi0](./docs/source/pi0.mdx), [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx), [EO-1](./docs/source/eo1.mdx), [MolmoAct2](./docs/source/molmoact2.mdx), [WALL-OSS](./docs/source/walloss.mdx) |
|
||||
| **World Models** | [VLA-JEPA](./docs/source/vla_jepa.mdx) (more coming soon) |
|
||||
| **Reward Models** | [SARM](./docs/source/sarm.mdx), [TOPReward](./docs/source/topreward.mdx), [Robometer](./docs/source/robometer.mdx) |
|
||||
| Category | Models |
|
||||
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
|
||||
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
||||
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
|
||||
|
||||
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
|
||||
|
||||
@@ -135,7 +133,6 @@ Learn how to implement your own simulation environment or benchmark and distribu
|
||||
- **[Discord](https://discord.gg/q8Dzzpym3f):** Join the `LeRobot` server to discuss with the community.
|
||||
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
|
||||
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
|
||||
- **[T-Shirt Folding Experiment](https://huggingface.co/spaces/lerobot/robot-folding):** An end-to-end demonstration of folding t-shirts with LeRobot.
|
||||
|
||||
## Citation
|
||||
|
||||
@@ -143,7 +140,7 @@ If you use LeRobot in your project, please cite the GitHub repository to acknowl
|
||||
|
||||
```bibtex
|
||||
@misc{cadene2024lerobot,
|
||||
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Meftah, Khalil and Ellerbach, Maxime and Moss, Jess and Wolf, Thomas},
|
||||
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
|
||||
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
|
||||
howpublished = "\url{https://github.com/huggingface/lerobot}",
|
||||
year = {2024}
|
||||
|
||||
@@ -1,417 +0,0 @@
|
||||
# Decoupled VLA Inference & Edge Control: System Design Proposal
|
||||
|
||||
## 1. Executive Summary
|
||||
|
||||
This document proposes a production-grade system for decoupling GPU-bound VLA (Vision-Language-Action) policy inference from high-frequency, CPU-bound robot control in LeRobot. The system adopts a **Model-as-a-Service (MaaS)** paradigm using **Zenoh** as the sole transport protocol, enabling multiple edge devices to be served by centralized GPU servers with minimal latency and high reliability.
|
||||
|
||||
An initial prototype exists in `src/lerobot/async_inference/` (gRPC-based, single-client). This proposal defines the target architecture, identifies gaps between the prototype and production requirements, documents known bugs, and establishes the design for the new system.
|
||||
|
||||
---
|
||||
|
||||
## 2. Motivation
|
||||
|
||||
LeRobot's standard control loop runs policy inference and robot I/O in the same process. This works for lightweight policies on local GPUs, but breaks down when:
|
||||
|
||||
- **The policy is too large for edge hardware** (e.g., Pi0 at ~3B parameters requires a dedicated GPU).
|
||||
- **Multiple robots need the same policy** (redundant GPU allocation per robot).
|
||||
- **Inference latency exceeds the control deadline** (e.g., 200ms inference on a 33ms control loop at 30 FPS).
|
||||
|
||||
Decoupling inference from control solves all three: the edge device runs a tight I/O loop on a CPU, while a GPU server handles inference for one or more clients.
|
||||
|
||||
---
|
||||
|
||||
## 3. Core Architectural Principles
|
||||
|
||||
### 3.1 Model-as-a-Service (MaaS)
|
||||
|
||||
Servers initialize models **once at startup** from a configuration manifest. Edge devices do **not** trigger dynamic model loading — they route to pre-warmed servers and validate compatibility via a status endpoint.
|
||||
|
||||
### 3.2 Multi-Tenant & Stateless Inference
|
||||
|
||||
A single GPU server handles multiple edge devices executing the same task. The server is stateless per inference call — `predict_action_chunk()` is a pure function with no side effects on the model. Client isolation is achieved through per-client observation slots and Zenoh key-expression routing.
|
||||
|
||||
> **Invariant**: `predict_action_chunk()` must remain a pure function (no mutation of `self`) for all supported policies. This is what enables safe multi-tenant sharing of a single model instance. This invariant must be documented and tested.
|
||||
|
||||
### 3.3 Zenoh as primary Transport
|
||||
|
||||
The system uses Zenoh's pub/sub model, replacing the current gRPC implementation. Zenoh provides:
|
||||
|
||||
- **Hierarchical key expressions** for routing (natural fit for the cluster/experiment/model/task topology).
|
||||
- **Built-in discovery** (no external service discovery needed).
|
||||
- **Non-blocking publish** for observations (fire-and-forget with best-effort QoS).
|
||||
- **Reliable delivery** configurable per-topic (required for action chunks).
|
||||
- **Shared-memory transport** for same-machine deployments (zero-copy) (if available).
|
||||
|
||||
### 3.4 Local Edge CPU
|
||||
|
||||
Edge devices rely on standard CPUs for sensor polling, image compression, payload serialization, motor control, and data logging. No edge-GPU dependency.
|
||||
|
||||
---
|
||||
|
||||
## 4. System Topology
|
||||
|
||||

|
||||
|
||||
- **Cluster**: A set of GPU machines. Identified by `cluster_uuid`.
|
||||
- **Experiment**: A logical grouping of servers and clients. Identified by `experiment_tag`.
|
||||
- **Server**: One model + one task, pre-warmed. Serves N clients for that model/task combination.
|
||||
- **Client**: One robot, one task. Publishes observations, subscribes to actions.
|
||||
|
||||
The number of clients a single server can handle is a **user decision** based on model inference time and acceptable latency.
|
||||
|
||||
---
|
||||
|
||||
## 5. Component Specifications
|
||||
|
||||
### 5.1 The Edge Device (Client)
|
||||
|
||||
**Responsibilities:**
|
||||
|
||||
1. **Observation capture**: Read sensors (cameras, motors) at the control loop frequency.
|
||||
2. **Image compression**: JPEG-encode RGB images before transmission.
|
||||
3. **Observation publishing**: Non-blocking Zenoh put to the observation topic.
|
||||
4. **Action subscription**: Zenoh callback receives action chunks, deposits into local buffer.
|
||||
5. **Action execution**: Pop actions from buffer, send to robot at control frequency.
|
||||
6. **Action blending**: When a new action chunk overlaps with the current buffer, blend via configurable aggregation function (weighted average, latest-only, etc.).
|
||||
7. **Latency compensation**: Calculate one-way latency from RTT, discard expired initial steps of incoming action chunks.
|
||||
8. **Fail-safe**: If action buffer empties, logs a warning.
|
||||
9. **Data logging**: Record raw observations and executed actions to local `LeRobotDataset` storage for deferred upload.
|
||||
|
||||
**Threading model:**
|
||||
|
||||
- **Control loop thread** (main): Capture observation → deposit in outbox → pop action from buffer → send to robot → sleep to maintain frequency.
|
||||
- **Zenoh action callback** (Zenoh-managed): Receives action chunks, processes RTT, trims stale steps, deposits into action buffer.
|
||||
- **Observation publisher thread**: Drains the outbox, compresses images, serializes, publishes via Zenoh.
|
||||
|
||||
> **Design note**: The current prototype blocks on `send_observation` inside the control loop (BUG-1, see Section 9). The new design decouples observation publishing from the control loop entirely, using a separate thread and Zenoh's non-blocking put.
|
||||
|
||||
### 5.2 The Inference Server (GPU Pod)
|
||||
|
||||
**Responsibilities:**
|
||||
|
||||
1. **Model pre-warming**: Load model and processor pipelines at startup from config manifest (including expected clients & policy parameters).
|
||||
2. **Status publishing**: Expose model capabilities (policy type, expected camera names, resolutions, action dimensions) via Zenoh queryable.
|
||||
3. **Observation subscription**: Subscribe to observation topics for all clients of this model/task. Maintain per-client observation slots (newest-only semantics).
|
||||
4. **Inference**: Single inference thread processes observations sequentially (round-robin across clients). Calls `policy.predict_action_chunk()`.
|
||||
5. **Action publishing**: Publish action chunks to per-client action topics with reliable QoS.
|
||||
|
||||
> **Thread safety**: PyTorch's `model.forward()` is not guaranteed thread-safe. Inference will be sequential, latency is mostly about the capabilities of the server to serve multiple requests.
|
||||
|
||||
---
|
||||
|
||||
## 6. Zenoh Routing & Key Expressions
|
||||
|
||||
### 6.1 Key Expression Schema
|
||||
|
||||
```
|
||||
[cluster_uuid] / [experiment_tag] / [model_id] / [model_version] / [application_tag] / [client_uuid] / [topic]
|
||||
```
|
||||
|
||||
**Example key expressions:**
|
||||
|
||||
| Key Expression | Direction | Purpose |
|
||||
| ------------------------------------------------ | ----------------- | ---------------------------------- |
|
||||
| `jupiter/fabio2/pi0/v1/cookie/robot_a4b9/obs` | Client → Server | Observation payload |
|
||||
| `jupiter/fabio2/pi0/v1/cookie/robot_a4b9/action` | Server → Client | Action chunk |
|
||||
| `jupiter/fabio2/pi0/v1/cookie/*/obs` | Server subscribes | All observations for pi0/v1/cookie |
|
||||
| `jupiter/fabio2/pi0/v1/cookie/status` | Server publishes | Model capabilities (queryable) |
|
||||
|
||||
### 6.2 QoS Configuration
|
||||
|
||||
| Topic | Reliability | Rationale |
|
||||
| -------- | ----------- | -------------------------------------------------------------------- |
|
||||
| `obs` | Best-effort | Dropping stale observations is expected behavior. |
|
||||
| `action` | Reliable | Every action chunk must be delivered; loss causes action starvation. |
|
||||
| `status` | Reliable | Client needs accurate capability info before starting. |
|
||||
|
||||
### 6.3 Discovery Flow
|
||||
|
||||
0. Server goes up with the static configuration.
|
||||
1. Client constructs its target key prefix: `cluster/experiment/model/version/task/`.
|
||||
2. Client queries `cluster/experiment/model/version/task/status` (Zenoh queryable).
|
||||
3. Server responds with its capabilities (expected camera names, image resolutions, action dimensions, model metadata).
|
||||
4. Client validates its own configuration against server capabilities.
|
||||
5. On match: client starts publishing observations and subscribing to actions.
|
||||
6. On mismatch: client logs an error and refuses to start.
|
||||
|
||||
No dynamic client discovery for now.
|
||||
|
||||
---
|
||||
|
||||
## 7. Message Schema
|
||||
|
||||
### 7.1 Observation Payload (Client → Server)
|
||||
|
||||
| Field | Type | Purpose |
|
||||
| ------------- | ------------------ | ----------------------------------------------------------- |
|
||||
| `seq_id` | `uint64` | Incrementing ID for causality tracking and RTT computation. |
|
||||
| `client_uuid` | `string` | Identifies the sending client. |
|
||||
| `state` | `bytes` | Proprioceptive state vector (`numpy.tobytes()`). |
|
||||
| `images` | `dict[str, bytes]` | JPEG-compressed camera images, keyed by camera name. |
|
||||
| `task` | `string` | Natural-language task instruction (for VLA conditioning). |
|
||||
|
||||
### 7.2 Action Payload (Server → Client)
|
||||
|
||||
| Field | Type | Purpose |
|
||||
| -------------------- | --------- | --------------------------------------------------------------- |
|
||||
| `response_to_seq_id` | `uint64` | Echoes the observation `seq_id` this action corresponds to. |
|
||||
| `inference_time_ms` | `float32` | Server-side compute duration (for edge RTT math). |
|
||||
| `actions` | `bytes` | Action chunk as numpy array bytes (`(chunk_size, action_dim)`). |
|
||||
|
||||
### 7.3 Status Payload (Server, Queryable)
|
||||
|
||||
| Field | Type | Purpose |
|
||||
| ----------------------- | ------------------- | ------------------------------------------ |
|
||||
| `model_id` | `string` | Policy identifier (e.g., `pi0`). |
|
||||
| `model_version` | `string` | Model version or checkpoint path. |
|
||||
| `expected_cameras` | `dict[str, (H, W)]` | Expected camera names and shapes. |
|
||||
| `action_dim` | `int` | Dimensionality of the action space. |
|
||||
| `max_actions_per_chunk` | `int` | Maximum chunk size the model supports. |
|
||||
| `observation_features` | `dict` | Full feature specification for validation. |
|
||||
|
||||
### 7.4 Serialization Format
|
||||
|
||||
**MessagePack** for all structured metadata (compact, fast, cross-language). Image payloads are raw JPEG bytes embedded in the MessagePack structure. State vectors use `numpy.tobytes()` with shape/dtype metadata for zero-copy reconstruction.
|
||||
|
||||
**No pickle.** The current prototype uses `pickle.dumps`/`pickle.loads` throughout, which allows arbitrary code execution. This is replaced entirely.
|
||||
|
||||
---
|
||||
|
||||
## 8. Latency Compensation
|
||||
|
||||
### 8.1 RTT Calculation
|
||||
|
||||
The edge device tracks in-flight observations:
|
||||
|
||||
```python
|
||||
in_flight: dict[int, float] = {} # seq_id -> time.perf_counter() at send
|
||||
|
||||
# On send:
|
||||
in_flight[seq_id] = time.perf_counter()
|
||||
|
||||
# On receive action chunk:
|
||||
rtt = time.perf_counter() - in_flight[response_to_seq_id]
|
||||
# delete older keys than the one received
|
||||
```
|
||||
|
||||
> **Important**: Delete only the exact `response_to_seq_id` key from `in_flight`, not all keys `<= response_to_seq_id`. With Zenoh's best-effort transport, messages can arrive out of order. Clearing earlier keys would make their RTT unmeasurable.
|
||||
|
||||
### 8.2 Stale Action Trimming
|
||||
|
||||
When an action chunk arrives, the edge calculates how many initial steps have already expired:
|
||||
|
||||
```python
|
||||
expired_steps = int(rtt / environment_dt)
|
||||
valid_actions = action_chunk[expired_steps:]
|
||||
```
|
||||
|
||||
The valid actions are then blended into the action buffer using the configured aggregation function.
|
||||
|
||||
### 8.3 Edge Cases
|
||||
|
||||
| Scenario | Behavior |
|
||||
| -------------------------------------- | -------------------------------------------------------------------------------------- |
|
||||
| **First observation** (no RTT history) | Apply all action steps without trimming. |
|
||||
| **Dropped observations** | Server infers on next received observation. No special handling needed. |
|
||||
| **Dropped action chunks** | Edge continues executing current buffer. If buffer empties, warn & hold last position. |
|
||||
| **Server crash** | Edge exhausts buffer, holds position, warns & re-validates via status query. |
|
||||
|
||||
> **Assumption**: All currently supported robots are position-controlled (SO100, SO101, OMX). For velocity-controlled robots, the fail-safe must send zero-velocity instead of holding position. This should be configurable per-robot.
|
||||
|
||||
---
|
||||
|
||||
## 9. Known Bugs in Current Prototype
|
||||
|
||||
These issues exist in `src/lerobot/async_inference/` and must be addressed in the new implementation.
|
||||
|
||||
### BUG-1: `send_observation` Blocks the Control Loop (Critical)
|
||||
|
||||
**Location**: `robot_client.py:207`
|
||||
|
||||
`self.stub.SendObservations(observation_iterator)` is a synchronous gRPC call inside the 33ms control loop. For multi-camera observations (several MB after pickle), this consumes 10-20ms on the network, leaving no headroom for sensor capture and motor commands. The robot stutters.
|
||||
|
||||
**Resolution in new design**: Observation publishing is moved to a dedicated thread. Zenoh's `session.put()` is non-blocking by default. The control loop only deposits observations into a local outbox.
|
||||
|
||||
### BUG-2: Race Condition in Action Queue Aggregation (Correctness)
|
||||
|
||||
**Location**: `robot_client.py:236-267`
|
||||
|
||||
The lock on `self.action_queue` is acquired to read `internal_queue = self.action_queue.queue` (a reference to the internal deque), then **released** at line 238. The aggregation logic iterates over this reference outside the lock. Meanwhile, the control loop thread can `get_nowait()` from the same queue, mutating the deque during iteration. At line 267, the entire queue is replaced, but actions popped between 238-267 are silently lost.
|
||||
|
||||
**Fix**: Either hold the lock for the entire aggregation, or `list(self.action_queue.queue)` to copy contents before releasing.
|
||||
|
||||
### BUG-3: No RPC Deadlines (Reliability)
|
||||
|
||||
**Location**: `robot_client.py:278`
|
||||
|
||||
`GetActions` blocks indefinitely if the server hangs (GPU OOM, deadlock). The retry policy handles `UNAVAILABLE` but not a hung connection.
|
||||
|
||||
**Resolution in new design**: The polling `GetActions` pattern is replaced by Zenoh subscription callbacks. The client needs a watchdog timer or check when action queue is empty: if no actions are received for `T` seconds, trigger re-validation via the status service.
|
||||
|
||||
### BUG-4: Similarity Check Ignores Images (Correctness for VLAs)
|
||||
|
||||
**Location**: `helpers.py:280-297`
|
||||
|
||||
`observations_similar()` + `must_go` is a workaround for current architecure limitations to avoid filling up the server queue the first seconds of the task & the robot remaining idle.
|
||||
|
||||
**Resolution in new design**: the server always processes the latest observation per client in its inference loop, and doesn't need similarity gating at all. The client can always push.
|
||||
|
||||
---
|
||||
|
||||
## 10. Gaps Between Prototype and Target Architecture
|
||||
|
||||
### 10.1 Critical (Must Address)
|
||||
|
||||
| # | Gap | Current State | Target State |
|
||||
| --- | ------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| G1 | **Single-client server** | One `observation_queue(maxsize=1)`, one `last_processed_obs`, one `_predicted_timesteps`. `_reset_server()` flushes all state on any new connection. | Per-client state (`ClientState` dataclass) keyed by `client_uuid`. Zenoh key-expression routing provides client isolation. |
|
||||
| G2 | **Dynamic model loading** | Client sends `RemotePolicyConfig` → server calls `from_pretrained()` on demand. | Server loads models at startup from config manifest. `SendPolicyInstructions` RPC eliminated. Client validates via status query. |
|
||||
| G3 | **gRPC transport** | Entire `transport/` directory: proto definitions, generated stubs, chunking utils. 4 RPCs: `Ready`, `SendPolicyInstructions`, `SendObservations`, `GetActions`. | Zenoh pub/sub. Client publishes obs, subscribes to actions. Server subscribes to obs, publishes actions. Dispatching via key expressions. |
|
||||
| G4 | **Pickle serialization** | `pickle.dumps`/`pickle.loads` throughout (arbitrary code execution risk, `# nosec` suppression). | MessagePack for structured metadata + raw JPEG bytes for images + `numpy.tobytes()` for state vectors. |
|
||||
|
||||
### 10.2 Important
|
||||
|
||||
| # | Gap | Current State | Target State |
|
||||
| --- | -------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| G5 | **No RTT/latency compensation** | No `seq_id`, no `response_to_seq_id`, no `inference_time_ms`. Timestamps use `time.time()` (unreliable across machines). | Edge-local `perf_counter` + echoed `seq_id` + server inference duration. Stale action step trimming. |
|
||||
| G6 | **No hierarchical routing** | Direct gRPC channel to `host:port`. | Zenoh key expressions: `cluster/experiment/model/version/task/client/topic`. |
|
||||
| G7 | **No data logging** | `control_loop` has access to obs and actions but doesn't persist them. | Edge records via `LeRobotDataset` (`build_dataset_frame` + `dataset.add_frame`). |
|
||||
| G8 | **No authentication** | `grpc.insecure_channel`. | Zenoh TLS + access control lists on key expressions. |
|
||||
| G9 | **ProcessorPipeline divergence** | Server reimplements observation prep in `helpers.py` (custom `resize_robot_observation_image` with `F.interpolate` bilinear). Diverges from standard `RobotProcessorPipeline`. | Use the standard `RobotProcessorPipeline` + `build_dataset_frame` to ensure behavioral equivalence between record and async inference. |
|
||||
|
||||
### 10.3 Nice-to-Have
|
||||
|
||||
| # | Gap | Current State | Target State |
|
||||
| --- | ------------------------------------- | --------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| G11 | **No status/discovery service** | Bare `Ready()` ping. | Zenoh queryable at `cluster/exp/model/version/task/status`. |
|
||||
| G12 | **No monitoring** | `FPSTracker` + `logging.debug`. | Structured metrics via Zenoh telemetry topics. Wildcard subscriptions for centralized monitoring. |
|
||||
| G13 | **No entry points** | Module-level `__main__`. | `lerobot-policy-server` and `lerobot-robot-client` console scripts in `pyproject.toml`. |
|
||||
| G14 | **Ratio-based observation threshold** | `chunk_size_threshold` (0-1 ratio of queue fill). Scales oddly with different `actions_per_chunk` values. | Absolute time threshold: `buffer_time_s` calibrated to observed RTT. Send observation when `queue_size * environment_dt < buffer_time_s`. |
|
||||
|
||||
---
|
||||
|
||||
## 11. Design Decisions & Rationale
|
||||
|
||||
### 11.1 Why Zenoh Over gRPC
|
||||
|
||||
| Aspect | Zenoh | gRPC |
|
||||
| ------------------------- | -------------------------------------------------------------------------- | ---------------------------------------------------------------------------------- |
|
||||
| Communication model | Pub/sub — natural fit for "client publishes obs, server publishes actions" | Request/response — requires polling (`GetActions` loop) or bidirectional streaming |
|
||||
| Multi-tenant routing | Hierarchical key expressions provide built-in per-client topic isolation | Requires manual per-client channel/stream management |
|
||||
| Discovery | Built-in discovery | Requires external service (mDNS, Consul, etc.) |
|
||||
| Observation publishing | Non-blocking put (fire-and-forget) — resolves BUG-1 automatically | Synchronous stream-unary call — blocks the control loop |
|
||||
| Same-machine optimization | Shared-memory transport (zero-copy) | Loopback TCP |
|
||||
| Telemetry | Wildcard subscriptions (`+/+/+/+/+/metrics`) | Requires separate monitoring infrastructure |
|
||||
|
||||
**Tradeoffs of going Zenoh-only:**
|
||||
|
||||
- Smaller community, less tooling for monitoring/tracing vs. gRPC's mature ecosystem.
|
||||
- No built-in schema enforcement (Zenoh sends raw bytes) — serialization correctness is entirely on us.
|
||||
- Default QoS is best-effort (like UDP). Must explicitly configure reliable delivery for action chunks.
|
||||
- `zenoh-python` bindings are less battle-tested than `grpcio`. Needs integration testing under network stress.
|
||||
|
||||
### 11.2 Why Single Inference Thread (Not Batching)
|
||||
|
||||
True GPU batching across clients requires collecting observations from multiple clients and running a single forward pass. This is difficult because:
|
||||
|
||||
- Clients send observations at different times — waiting to batch adds latency.
|
||||
- Different clients may have slightly different image resolutions.
|
||||
- Error in one client's observation shouldn't affect others.
|
||||
|
||||
**Decision**: Start with sequential processing (single inference thread, round-robin across clients). Profile GPU utilization.
|
||||
|
||||
### 11.4 Why MessagePack (Not Protobuf, Not FlatBuffers)
|
||||
|
||||
- **Protobuf**: Strong schema enforcement but heavier toolchain (proto compilation, generated code). Since we're dropping gRPC, the protobuf dependency becomes unnecessary overhead.
|
||||
- **MessagePack**: Fast, compact, schema-less (enforced by application), excellent Python support (`msgpack` package), good for nested dicts with mixed types. Natural fit for observation/action payloads.
|
||||
|
||||
Images are embedded as raw JPEG bytes within the MessagePack structure. State vectors use `numpy.tobytes()` with shape/dtype metadata for zero-copy reconstruction.
|
||||
|
||||
### 11.5 Action Aggregation Strategy
|
||||
|
||||
When a new action chunk overlaps with the existing buffer, the overlapping timesteps must be blended. The current prototype supports configurable aggregation functions:
|
||||
|
||||
| Function | Formula | Character |
|
||||
| ------------------ | ----------------------- | ------------------------------------------ |
|
||||
| `weighted_average` | `0.3 * old + 0.7 * new` | Smooth transitions, favors new predictions |
|
||||
| `latest_only` | `new` | Most responsive, can cause discontinuities |
|
||||
| `average` | `0.5 * old + 0.5 * new` | Equal weight |
|
||||
| `conservative` | `0.7 * old + 0.3 * new` | Smooth, slow to adapt |
|
||||
|
||||
Ultimately, this should be the user's decision. Default to `weighted_average`. The goal of async is not to do temporal ensembling, but to provide a solution when we want to decouple inference and execution.
|
||||
|
||||
---
|
||||
|
||||
## 12. Configuration
|
||||
|
||||
### 12.1 Server Configuration (Manifest)
|
||||
|
||||
Servers are configured via a YAML manifest that declares which models to pre-warm & clients to serve:
|
||||
|
||||
```yaml
|
||||
cluster_uuid: jupiter
|
||||
experiment_tag: fabio2
|
||||
server:
|
||||
- model_id: pi0
|
||||
model_version: v1
|
||||
pretrained_path: lerobot/pi0-cookie-v1
|
||||
application_tag: cookie
|
||||
device: cuda:0
|
||||
fps: 30
|
||||
endpoint: tcp/192.168.1.50:7447
|
||||
clients:
|
||||
- client_uuid: cookie-worker-4269
|
||||
```
|
||||
|
||||
### 12.2 Client Configuration
|
||||
|
||||
Clients are configured via draccus dataclass (CLI-compatible):
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class AsyncClientConfig:
|
||||
# Zenoh routing
|
||||
cluster_uuid: str
|
||||
experiment_tag: str
|
||||
model_id: str
|
||||
model_version: str
|
||||
application_tag: str
|
||||
client_uuid: str
|
||||
endpoint: str
|
||||
|
||||
# Robot
|
||||
robot: RobotConfig
|
||||
|
||||
# Control
|
||||
fps: int = 30
|
||||
actions_per_chunk: int = 50
|
||||
aggregate_fn_name: str = "weighted_average"
|
||||
jpeg_quality: int = 90
|
||||
|
||||
# Fail-safe
|
||||
max_empty_cycles_before_warning: int = 10
|
||||
|
||||
# Datset recording
|
||||
dataset_repo_id: str | None = None # None = no logging
|
||||
|
||||
# Task
|
||||
task: str = ""
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 14. Data Logging Integration
|
||||
|
||||
The client records observations and executed actions into a local `LeRobotDataset` for deferred upload to the training dataset:
|
||||
|
||||
```python
|
||||
# In control_loop, after executing an action:
|
||||
if self.dataset is not None:
|
||||
frame = build_dataset_frame(
|
||||
self.dataset.features,
|
||||
processed_observation,
|
||||
prefix=OBS_STR,
|
||||
)
|
||||
frame["action"] = executed_action_tensor
|
||||
self.dataset.add_frame(frame)
|
||||
```
|
||||
@@ -1,498 +0,0 @@
|
||||
# Decoupled VLA Inference & Edge Control v2: Async Network Inference for `lerobot-rollout`
|
||||
|
||||
> **Status**: supersedes the v1 proposal in full. v1 was written against the standalone `src/lerobot/async_inference/` prototype, before `lerobot-rollout` existed. This revision re-grounds the design in the current codebase, keeps v1's decisions that survived contact with it (marked **KEPT** throughout), reverses the ones that didn't, and adds the safety, multi-tenancy, and operations specifications v1 lacked.
|
||||
|
||||
## 1. Executive Summary
|
||||
|
||||
This document specifies a production-grade system for decoupling GPU-bound policy inference from high-frequency robot control, targeting power users running **hundreds of robots** against centralized GPU clusters. The system keeps v1's **Model-as-a-Service (MaaS)** paradigm and **Zenoh** transport, but changes the integration architecture fundamentally:
|
||||
|
||||
- **The client is not a standalone CLI.** It is `--inference.type=remote`, a new `InferenceEngine` backend inside `lerobot-rollout` (`src/lerobot/rollout/inference/`). Every rollout strategy (base, sentry, highlight, dagger, episodic) gets network inference for free — including dataset recording, DAgger pause/resume, Rerun visualization, and safe teardown.
|
||||
- **The client is weightless.** No policy weights, no policy processors on the edge. `--policy.path` resolves to a config-only `PreTrainedConfig` (no weight download) used for pre-flight validation and action ordering.
|
||||
- **The server is stateless per request.** All RTC chunk state (leftover prefixes, latency tracking, delay computation) lives client-side in the existing `ActionQueue`/`LatencyTracker` machinery — the client ships prefixes + a delay hint with each observation. A server crash loses zero control state; reconnects and horizontal scaling are trivial.
|
||||
- **Multi-tenancy is engineered, not assumed.** The real hazards are stateful processor pipelines and episode-scoped policy state — not `predict_action_chunk` purity (which holds for ACT/Pi0/Pi0.5/SmolVLA but _not_ diffusion). The server uses per-session processor instances, a chunk-stateless allowlist, and an exclusive serving mode for policies that need it.
|
||||
- **The legacy module dies.** `src/lerobot/async_inference/` (~1,900 lines, pickle-over-gRPC, single-client, four confirmed bugs) is deleted in the same PR that lands the new backend. No deprecation cycle: the module is experimental, its CLI undocumented in the main flow, and every config field has a mapped successor (§13.4).
|
||||
|
||||
---
|
||||
|
||||
## 2. Motivation (unchanged from v1) — **KEPT**
|
||||
|
||||
LeRobot's standard control loop runs policy inference and robot I/O in the same process. This breaks down when:
|
||||
|
||||
- **The policy is too large for edge hardware** (Pi0-class models need a dedicated GPU).
|
||||
- **Multiple robots need the same policy** (redundant GPU allocation per robot).
|
||||
- **Inference latency exceeds the control deadline** (e.g. 150 ms inference on a 33 ms control tick).
|
||||
|
||||
Decoupling solves all three: the edge runs a tight CPU loop; a GPU server performs inference for N clients.
|
||||
|
||||
What changed since v1: the _local_ version of this decoupling already shipped. `RTCInferenceEngine` (`src/lerobot/rollout/inference/rtc.py`) runs inference in a background thread against a thread-safe `ActionQueue` with latency-aware chunk merging. **The network system is that same architecture with the thread boundary replaced by a network boundary.** This is the design's central simplification: reuse, don't reinvent.
|
||||
|
||||
---
|
||||
|
||||
## 3. Gap Analysis: v1 Proposal vs. Modern Codebase
|
||||
|
||||
| Topic | v1 assumed | Modern reality | Verdict |
|
||||
| ----------------------------------------- | --------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------- | --------------------------------------- |
|
||||
| Client architecture | Standalone robot-client CLI (§5.1 of v1) | `InferenceEngine` ABC seam in `lerobot-rollout` (`rollout/inference/base.py`); strategies are backend-agnostic | **Superseded** — backend, not CLI |
|
||||
| Chunk blending | Configurable aggregation zoo (`weighted_average`, …) | `ActionQueue` replace-with-delay-trim (RTC) / append (non-RTC) (`policies/rtc/action_queue.py:147-217`) | **Superseded** — drop blending entirely |
|
||||
| Latency compensation | Hand-rolled RTT trim (`expired_steps = int(rtt/dt)`, v1 §8.2) | `ActionQueue.merge(..., real_delay, idx_before)` + `LatencyTracker` already do this, validated | **Superseded** |
|
||||
| Multi-tenancy invariant | "`predict_action_chunk()` pure ⇒ safe to share" | Processor state + episode-scoped policy state are the real hazards (§7) | **Incomplete** — fixed in §8.3 |
|
||||
| Data logging | Client-side `build_dataset_frame` + `add_frame` sketch (v1 §14) | Recording strategies (sentry/episodic/dagger) already log obs + executed actions | **Superseded** — free via rollout |
|
||||
| MaaS pre-warm, no dynamic loading | ✓ | Still right; legacy `SendPolicyInstructions` is a pickle/RCE + capacity-planning disaster | **KEPT** |
|
||||
| JPEG observation compression | ✓ | Still right (§10.1) | **KEPT** |
|
||||
| Status/capability validation before start | ✓ (Zenoh queryable) | Still right; extended into a hard sync-safety contract (§8.4) | **KEPT, extended** |
|
||||
| Time-based send threshold (v1 G14) | ✓ | Adopted as `buffer_time_s` | **KEPT** |
|
||||
| Zenoh pub/sub data plane | ✓ | Confirmed; QoS corrected (§6.3), control plane moved to queryables, liveliness added | **KEPT, hardened** |
|
||||
| MessagePack serialization | ✓ | Endorsed (zenoh's `ext` serializer cannot encode numpy); must be version-gated (§10.4) | **KEPT, with schema discipline** |
|
||||
| QoS table (v1 §6.2) | "obs best-effort, actions reliable" | Conflates transport reliability with congestion control; BLOCK on actions is dangerous | **Revised** (§6.3) |
|
||||
| Bugs BUG-1…BUG-4, gaps G1…G14 | Listed as work items | Every one resolved _structurally_ by this design (§13.5 mapping) | **Resolved by design** |
|
||||
|
||||
---
|
||||
|
||||
## 4. Critical Pushbacks on v1
|
||||
|
||||
Each pushback: claim → evidence → consequence for this design.
|
||||
|
||||
**P1 — A standalone client duplicates `lerobot-rollout`.**
|
||||
v1 §5.1 assigns the client: observation capture, action execution at frequency, fail-safe, data logging. Every one of those is already owned by rollout strategies and `send_next_action` (`rollout/strategies/core.py:269-304`), which tolerates `None` actions, runs the interpolator, and routes through the canonical robot processors. A standalone client re-implements loop timing, recording, DAgger UX, Rerun, and teardown safety — and then drifts. _Consequence_: the client is `RemoteInferenceEngine`, registered as `--inference.type=remote` next to `sync` and `rtc`.
|
||||
|
||||
**P2 — The aggregation-function zoo fabricates actions no policy predicted.**
|
||||
`0.3*old + 0.7*new` produces hybrid actions that exist in no policy's output distribution; the logged action becomes unexplainable (bad for the reproducibility story) and the implementation hosted a real lock-release race (BUG-2, `async_inference/robot_client.py:236-267`). RTC's prefix-conditioned chunk generation is the principled mechanism for smooth chunk transitions; plain append covers non-RTC chunking. _Consequence_: `ActionQueue` replace/append are the only two merge semantics. The zoo is deleted.
|
||||
|
||||
**P3 — "predict_action_chunk pure ⇒ multi-tenant safe" is incomplete.**
|
||||
Verified in-tree: (a) `RelativeActionsProcessorStep` caches `_last_state` at preprocess (`processor/relative_action_processor.py:131`) and the postprocessor reads it back (`:189`) — a shared pipeline across clients is a race; (b) `DiffusionPolicy.predict_action_chunk` reads `self._queues`, which only `select_action` populates (`policies/diffusion/modeling_diffusion.py:90-108`) — it is **not** chunk-stateless; (c) SAC/SARM have no `predict_action_chunk` at all. _Consequence_: per-session processor instances (mandatory), a chunk-stateless allowlist, `serving_mode: exclusive` for diffusion-family, refusal at startup for SAC/SARM, and `policy.reset()` is **never** called in shared mode (§8.3).
|
||||
|
||||
**P4 — v1 re-derives latency compensation that already exists, on top of broken clocks.**
|
||||
v1 §8 specifies an in-flight RTT dict and manual stale-step trimming. `ActionQueue.merge(original, processed, real_delay, idx_before)` already trims `real_delay` stale steps and cross-validates against actions consumed in flight (`action_queue.py:219-246`). Worse, the legacy code compares wall clocks across machines (`robot_client.py:420` stamps `time.time()` "to compare timestamps across client and server"; `policy_server.py:178` compares it) — NTP skew is the same order as the latencies being measured. _Consequence_: the **monotonic iron rule** (§11): instants never cross machines; client timestamps are opaque echoed tokens; servers report only durations. `delay_steps = ceil((rtt + inference)/dt)` is computed client-side from client-local `perf_counter` samples and shipped per request.
|
||||
|
||||
**P5 — One-in-flight per client is a correctness requirement, not a tuning choice.**
|
||||
At send time the client snapshots `idx_before = queue.get_action_index()` and the leftover prefixes; `merge` validates against them. Two in-flight requests carry conflicting snapshots — the second merge corrupts both RTC replace mode and append mode. The local RTC thread is also strictly one-inference-at-a-time; one-in-flight preserves exact parity. _Consequence_: the worker publishes one observation, waits for its chunk (or timeout), then sends the next. v1 §8.1's out-of-order in-flight dict is dead weight; a late chunk is accepted only if it answers the _latest_ outstanding `seq_id`, otherwise dropped.
|
||||
|
||||
**P6 — v1's QoS table conflates transport reliability with congestion behavior.**
|
||||
"Reliable delivery for actions" sounds right but the dangerous knob is congestion control: a publisher configured `BLOCK` on the action topic can stall the **server's** publish path on one robot's dead uplink (Zenoh blocks up to `wait_before_close`, then may close the transport). A dropped action chunk is _recoverable by design_ — the client's queue keeps the robot moving and the next chunk replaces it. _Consequence_ (§6.3): actions = `reliability=RELIABLE` (hop-level) + `congestion_control=DROP` + `express=True` + `priority=INTERACTIVE_HIGH`; observations = `DROP` + `DATA`. If WAN loss proves material, upgrade the action topic to Zenoh Advanced Pub/Sub (cache + recovery, zenoh ≥ 1.5) rather than BLOCK.
|
||||
|
||||
**P7 — Schema-less MessagePack invites silent version drift across a 300-robot fleet.**
|
||||
msgpack stays (zenoh's `ext` serializer cannot encode numpy/dataclasses, and the team's choice stands), but naked msgpack dicts across heterogeneous fleet versions fail at runtime, on the robot. _Consequence_ (§10.4): a packed little-endian **attachment header** (`schema_version`, `seq_id`, `episode_id`, `client_mono_ns` — the rmw_zenoh pattern) so routing/correlation never deserializes the body; `schema_version` negotiated at the session handshake; additive-only evolution; golden codec tests. Protobuf-over-ZBytes is the documented fallback if drift bites in practice.
|
||||
|
||||
**P8 — "Deterministic rollout reproducibility" is unattainable on real robots.**
|
||||
No seed controls hardware, sensor noise, or network jitter; RTC's latency-driven trimming is inherently timing-dependent. _Consequence_: the contract is **fully logged + replayable** (§12): recording strategies already persist observations and executed actions; the remote engine adds `(session_id, seq_id, episode_id)` provenance so client datasets join server audit logs mechanically.
|
||||
|
||||
**P9 — v1 has no safety specification.**
|
||||
"Log a warning when the buffer empties" is not a fail-safe for a 300-robot fleet. _Consequence_ (§9): a staleness bound (`max_action_age_s` — never execute an action older than X relative to its source observation), an explicit fallback ladder (`hold` / `repeat_last` / `zero` — zero-command required for future velocity-controlled robots), and a DEAD state that triggers the existing strategy shutdown path (return-to-initial-pose, disconnect) via the same `shutdown_event` mechanism RTC uses (`rtc.py:359-360`).
|
||||
|
||||
**P10 — Capacity must be formula-driven, not "a user decision".**
|
||||
v1 §4 says clients-per-server "is a user decision". With `t` = server time per request, `r` = per-client request rate, `H` = RTC execution horizon, `dt` = control period:
|
||||
`N_max = min( 0.8 / (r·t), (H·dt/2 − RTT_net) / t )`
|
||||
→ ACT @ 20 ms, 1 Hz: ~40 clients/GPU. Pi0 @ 150 ms, 1 Hz: ~5 clients/GPU. 300 robots on Pi0 ≈ 60 GPU pods. _Consequence_: the manifest carries `max_sessions`; the server rejects session opens beyond it (with current load in the reply) so clients retry another replica. Micro-batching is deferred — blocked on a real API issue (`predict_action_chunk` takes a _scalar_ `inference_delay`; batched clients have different delays) — behind a `Scheduler` seam so it can land later without redesign (§8.5).
|
||||
|
||||
**P11 — Discovery ≠ multicast.**
|
||||
Zenoh's multicast scouting does not cross WAN, NAT, or most k8s CNIs. _Consequence_: multicast scouting disabled; clients use static `connect.endpoints` (DNS name of the router) + gossip; presence and liveness come from Zenoh **liveliness tokens** (§6.4), not discovery. "Discovery" for a robot fleet is configuration.
|
||||
|
||||
---
|
||||
|
||||
## 5. System Topology
|
||||
|
||||

|
||||
_(Diagram unchanged from v1 — the topology survives; transport/QoS/session details in it are superseded by §6.)_
|
||||
|
||||
- **Router tier**: one or more `zenohd` routers (k8s Deployment + Service, TLS on 7447). Robots **dial out** to the router (NAT-friendly: labs only need outbound 7447/443). GPU servers join as peers via cluster DNS.
|
||||
- **Server**: one process = one `(model_repo, revision, dtype, device)` on one GPU, pre-warmed from a YAML manifest (**KEPT** from v1, amended: `pin_task: bool` — VLA prompts may vary per session unless pinned).
|
||||
- **Client**: one robot running `lerobot-rollout --inference.type=remote`. Weightless: config-only policy metadata.
|
||||
- **Identity**: `client_uuid` per robot; `session_id` per connection epoch; both in every log line on both sides.
|
||||
|
||||
---
|
||||
|
||||
## 6. Zenoh Design
|
||||
|
||||
All Zenoh claims below were verified against zenoh / zenoh-python 1.x (eclipse-zenoh 1.9.0). Pin: `eclipse-zenoh>=1.9,<2.0`; keep `zenohd` on the same minor as the Python binding. Wheels cover manylinux x86_64/aarch64/armv7l/armv6l + macOS — Raspberry Pi edge clients are covered.
|
||||
|
||||
### 6.1 Key-expression schema
|
||||
|
||||
```
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/<client_uuid>/obs client → server
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/<client_uuid>/action server → client
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/status queryable (capabilities)
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/session queryable (open/validate)
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/<client_uuid>/reset queryable (episode boundary)
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/<client_uuid>/alive liveliness token (client)
|
||||
@lerobot/<model_id>/<revision>/<task_slug>/server/alive liveliness token (server)
|
||||
```
|
||||
|
||||
Rules (hard, enforced by a `sanitize_keyexpr()` helper):
|
||||
|
||||
- Root at the **verbatim chunk** `@lerobot` — verbatim chunks are only matched by identical chunks, so third-party `**` subscribers on a shared router can never scrape the tree.
|
||||
- Sanitize every user-supplied segment (model ids, task strings, uuids): non-empty, no `* $ ? # /`, no leading/trailing/double `/`. A task string containing `/` must be slugified before it becomes a key chunk.
|
||||
- Server subscribes with a **single-depth** wildcard (`.../*/obs`) — never `**` (it would also match `status`, `alive`, …).
|
||||
- v1's `cluster/experiment` prefix segments are dropped from the key schema; they return as free-form `tags` metadata in the session handshake (telemetry/labeling, not routing). Routing topology belongs to deployment (which router you dial), not to key depth.
|
||||
|
||||
### 6.2 Data plane vs. control plane (the rmw_zenoh split)
|
||||
|
||||
- **Data plane = pub/sub** (KEPT from v1): observations up, action chunks down, correlated by `seq_id` in **attachments** (§10.4). Pub/sub rather than query-per-inference because: a timed-out query's late reply is _dropped by the transport_ (wasted inference), whereas a late pub/sub chunk is still mergeable if it answers the latest outstanding seq; and pub/sub leaves room for server-initiated messages (drain notices). The one-in-flight discipline (P5) is enforced in the client worker, not by the transport.
|
||||
- **Control plane = queryables** (request/reply with explicit timeouts; the pattern rmw*zenoh uses for ROS 2 services): `status` (pre-flight capability fetch, 2 s timeout), `session` (open/validate → ack with capabilities + `session_id`), `reset` (episode boundary — \_acknowledged*, so episodic strategies know the server-side episode state is clean). Always pass an explicit `timeout` to `session.get()` — the config default is 10 s, far too long for our watchdogs.
|
||||
- **Episode ordering**: under one-in-flight there is no obs/reset race window in the data plane, but as belt-and-braces the first observation of each episode also carries `episode_start=True` + the new `episode_id` in its header.
|
||||
|
||||
### 6.3 QoS (revised from v1 §6.2 — see P6)
|
||||
|
||||
| Topic | reliability | congestion_control | express | priority | Why |
|
||||
| ------------------ | ----------- | ---------------------- | -------- | ---------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| `obs` | default | **DROP** | false | DATA | Intentional drop already happened at the client's one-slot holder; if the uplink stalls, dropping a frame protects the control loop. |
|
||||
| `action` | RELIABLE | **DROP** (never BLOCK) | **true** | INTERACTIVE_HIGH | Hop-level reliability over TCP; express skips batching for the small (4–50 KB) latency-critical payload; DROP so one dead robot uplink can never stall the server's publish path. Chunk loss is recoverable: the client buffer rides through it. |
|
||||
| control queryables | RELIABLE | default | — | — | Correctness over latency; explicit timeouts bound them. |
|
||||
|
||||
Upgrade path if WAN chunk loss proves material: `AdvancedPublisher`/`AdvancedSubscriber` (zenoh ≥ 1.5) with a small cache + heartbeat-based recovery **on the action topic only**. Hop-by-hop RELIABLE is not end-to-end reliability — Zenoh has no broker persistence; a disconnected subscriber's data is gone. The design assumes this (client state machine, §9).
|
||||
|
||||
### 6.4 Liveliness (presence + watchdogs)
|
||||
|
||||
- Client declares a liveliness token on `.../<client_uuid>/alive`. The server liveliness-subscribes with `history=True`: token appear → ensure session state; token drop → GC the session (mailbox, processor instances) after a grace period.
|
||||
- Server declares `.../server/alive`. The client liveliness-subscribes: on drop → treat as RECONNECTING (§9), hold/fallback per config, re-run the `status`/`session` handshake when the token reappears.
|
||||
- Tune the transport lease down from its default so ungraceful-death detection is seconds, not tens of seconds (verify the default in the pinned version; it is config `transport/link/tx/lease`).
|
||||
- Liveliness cannot detect a _hung-but-connected_ server. The client's per-request timeout (`request_timeout_s`) is the authoritative watchdog — this is the structural fix for legacy BUG-3 (no deadlines on `GetActions`).
|
||||
|
||||
### 6.5 Threading constraints (zenoh-python facts that shape both processes)
|
||||
|
||||
- **No asyncio API** in zenoh-python — both client and server are thread-based. This matches the existing RTC engine pattern exactly.
|
||||
- Each callback-based subscriber spawns a dedicated Python thread; **blocking Zenoh calls inside callbacks are disallowed**. Callbacks must be deposit-only (write a slot, set an event, return).
|
||||
- Channel handlers (`FifoChannel`, `RingChannel`) are Rust-side; `try_recv()` polls without spawning Python threads. `RingChannel(1)` is native latest-only semantics.
|
||||
- No zero-copy path for our payloads (SHM API is `@_unstable` and same-host-only; `ZBytes` copy behavior undocumented). At ~200 KB × a few Hz per robot, one memcpy is irrelevant.
|
||||
|
||||
### 6.6 Router deployment
|
||||
|
||||
- `zenohd` official image as a k8s Deployment (1–N replicas; routers mesh and reroute around failures) behind a `LoadBalancer`/`NodePort` Service exposing TLS 7447. No official Helm chart exists — roll-your-own manifests.
|
||||
- `scouting.multicast.enabled: false`; `scouting.gossip.enabled: true`; clients/servers use static `connect.endpoints`.
|
||||
- **Auth**: mTLS per robot (`transport.link.tls` with `enable_mtls`) + router **ACL** keyed on `cert_common_names`: a robot's cert may only `put` to `@lerobot/**/<its-uuid>/obs` and receive on `.../<its-uuid>/action`. Caveat (flagged): ACL config reloads require a router restart — plan cert/ACL changes as rolling router restarts.
|
||||
- Security review input: the third-party Zenoh protocol security analysis (Census Labs, 2025) should be read before exposing 7447 publicly.
|
||||
|
||||
---
|
||||
|
||||
## 7. The Statelessness Boundary (the load-bearing section)
|
||||
|
||||
**Where the network cut goes.** The local RTC pipeline is:
|
||||
|
||||
```
|
||||
obs (robot-processed dict)
|
||||
→ build_dataset_frame(hw_features, obs, "observation") CLIENT (cheap, hardware-coupled)
|
||||
─────────────────────────── network ───────────────────────────
|
||||
→ prepare_observation_for_inference(...) SERVER (policy-coupled, heavy)
|
||||
→ per-session preprocessor(...) SERVER (stateful within the request)
|
||||
→ policy.predict_action_chunk(obs, inference_delay, prefix) SERVER (pure for allowlisted policies)
|
||||
→ per-session postprocessor(...) SERVER (reads state cached at preprocess)
|
||||
─────────────────────────── network ───────────────────────────
|
||||
→ ActionQueue.merge(original, processed, real_delay, idx_before) CLIENT
|
||||
```
|
||||
|
||||
Three consequences:
|
||||
|
||||
1. **The server needs no cross-request state.** `RelativeActionsProcessorStep` writes `_last_state` at preprocess and the postprocessor reads it back _within the same request_. Per-session pipeline instances + one-request-at-a-time-per-session give correctness with zero persistent state.
|
||||
2. **RTC state stays client-side**, exactly where `RTCInferenceEngine` already keeps it. Each request ships: `inference_delay_steps = ceil(L_max/dt)` (from the client `LatencyTracker`, whose samples are full network-inclusive cycle times — RTT compensation falls out for free), `prefix_model = queue.get_left_over()[:H]`, and `prefix_robot = queue.get_processed_left_over()[:H]` (needed for server-side relative-prefix re-anchoring, mirroring `rtc.py:287-305`). The response returns **both** the model-space and robot-space chunks because `merge` needs both. ≤ `execution_horizon × action_dim` float32 each — a few hundred bytes.
|
||||
3. **G9 dies structurally.** No bespoke client resize (`F.interpolate` in legacy `helpers.py`), no client-side normalization. Clients ship native camera resolution; the server's canonical processor path does everything — serve-time preprocessing is byte-identical to train-time.
|
||||
|
||||
**What the server _does_ hold** (and what it means):
|
||||
|
||||
- Per-session processor instances (cheap; normalization stat tensors shared read-only).
|
||||
- Per-session episode counter + stats. Episode reset = reset the session's pipelines, clear its mailbox. **`policy.reset()` is never called in shared mode** — it is global to the shared policy instance and unnecessary for chunk-pure policies (ACT's ensembler and Pi0/SmolVLA's queues live in `select_action`, not `predict_action_chunk` — verified).
|
||||
- Policies that are _not_ chunk-pure get `serving_mode: exclusive` (§8.3).
|
||||
|
||||
---
|
||||
|
||||
## 8. The Inference Server: `lerobot-policy-server`
|
||||
|
||||
New package `src/lerobot/policy_server/`; console script `lerobot-policy-server --manifest manifest.yaml`.
|
||||
|
||||
### 8.1 Process model — **KEPT** from v1, amended
|
||||
|
||||
One process = one model+task on one GPU, loaded and warmed at startup (`warmup_inferences` dummy forwards; covers torch.compile). Multi-GPU nodes run N processes (`CUDA_VISIBLE_DEVICES` pinning). Dynamic model loading (`SendPolicyInstructions`) is **rejected**: pickle/RCE surface, arbitrary-download surface, and it destroys capacity planning. Amendment: `pin_task: false` (default) lets VLA clients set the task per session; `pin_task: true` rejects mismatched tasks at session open.
|
||||
|
||||
### 8.2 Concurrency (pure threads — no asyncio in zenoh-python)
|
||||
|
||||
```
|
||||
zenoh subscriber (.../*/obs) inference worker (1 thread, owns GPU)
|
||||
deposit-only callback: loop:
|
||||
slots[client_uuid] = sample ──► pick next session with pending obs (RR ring)
|
||||
(per-client latest-only) decode JPEG → per-session preprocess
|
||||
predict_action_chunk(delay, prefix)
|
||||
control queryables (status/session/ per-session postprocess → encode
|
||||
reset): validate, mutate session publisher.put(.../<uuid>/action)
|
||||
registry, reply (publishing from the worker thread is fine)
|
||||
```
|
||||
|
||||
- **Per-client latest-only mailbox**: a wildcard subscriber with a deposit-only callback writing per-client slots (scales to dynamic fleets), or — when the manifest enumerates clients — one `RingChannel(1)` subscriber per client polled via `try_recv()`. Either way: newest observation wins; a superseded request is counted (`superseded_seqs` in the next response) so drops are visible. This deletes legacy BUG-4 (`observations_similar` + `must_go`) by construction — the **client** decides when to request; the server never second-guesses observation content.
|
||||
- **Single inference worker**: torch releases the GIL inside `forward`, callbacks stay responsive. Strict round-robin over sessions with pending observations: each gets exactly one inference per cycle; starvation is structurally impossible. Overload degrades into longer cycle times → larger (but correct) client `delay_steps` → eventually the client staleness bound trips and the robot holds — safe by construction.
|
||||
|
||||
### 8.3 Chunk-stateless allowlist and serving modes
|
||||
|
||||
At startup the server classifies the loaded policy:
|
||||
|
||||
| Class | Policies (verified) | Mode |
|
||||
| --------------- | ------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| chunk-stateless | ACT, Pi0, Pi0.5, SmolVLA (and any policy whose `predict_action_chunk` touches no instance state) | `shared`: N sessions, per-session pipelines, `policy.reset()` never called |
|
||||
| chunk-stateful | Diffusion family (`predict_action_chunk` reads `select_action`-fed `self._queues`) | `exclusive`: `max_sessions=1` enforced; episode reset additionally calls `policy.reset()`; second session open → rejected with a self-explanatory error |
|
||||
| no chunk API | SAC, SARM | refused at startup |
|
||||
|
||||
Implemented as a registry in `policy_server/validation.py`; the cleaner follow-up is a `supports_stateless_chunking` class attribute on `PreTrainedPolicy` (needs a pass over policy families — roadmap §14).
|
||||
|
||||
### 8.4 Session open & capability validation (fail fast, fail loud)
|
||||
|
||||
`session` queryable payload: `client_uuid`, `policy_type`, `fps`, feature summary (post-rename observation feature names + shapes, ordered action keys), `schema_version`, RTC intent, `tags`. Checks:
|
||||
|
||||
| Check | Rule | On mismatch |
|
||||
| -------------------------- | --------------------------------------------------------------- | ---------------------------------------------------------------------------------- |
|
||||
| Action names **and order** | must equal server's `action_feature_names` exactly | **hard reject** — this is the sync-safety contract mapping chunk columns to motors |
|
||||
| Camera names | client set must cover `policy.config.input_features` image keys | hard reject |
|
||||
| Resolution | any H×W accepted (server resizes canonically) | warn if aspect ratio differs from training |
|
||||
| State dim | flattened dim must match | hard reject |
|
||||
| `schema_version` | client within server's supported range | hard reject |
|
||||
| fps | vs. manifest `trained_fps` | warn (reject only when `strict_fps: true`) |
|
||||
| Task | when `pin_task: true`, must equal `default_task` | reject |
|
||||
| RTC | client RTC requires policy RTC kwargs support | downgrade to append mode + warning |
|
||||
| Capacity | `active_sessions < max_sessions` | reject with current load → client retries another replica |
|
||||
|
||||
Reply: `session_id`, model info (repo, revision — consider a checkpoint hash, §15), `action_feature_names`, `chunk_size`, `trained_fps`, `supports_rtc`, `serving_mode`, `warmed_up`, `schema_version`, warnings. **rename_map is applied client-side** so the wire format is canonical policy-feature keys across heterogeneous robots (also a prerequisite for future batching).
|
||||
|
||||
### 8.5 Scheduler seam (micro-batching later, not in v1)
|
||||
|
||||
The worker calls a `Scheduler.select(ready: list[Session]) -> list[Session]`; v1 ships `RoundRobin` (`return ready[:1]`). Cross-session batching is blocked on the policy API (`inference_delay` is scalar; batched clients have different delays/prefixes) — when that lands, a `MicroBatch` scheduler groups same-shape sessions. The seam costs nothing now and prevents a redesign later.
|
||||
|
||||
### 8.6 Manifest
|
||||
|
||||
```yaml
|
||||
model:
|
||||
{
|
||||
repo_or_path: lerobot/pi0_towels,
|
||||
revision: main,
|
||||
dtype: bfloat16,
|
||||
device: cuda,
|
||||
}
|
||||
default_task: "fold the towel"
|
||||
pin_task: false
|
||||
serving_mode: shared # forced to exclusive for chunk-stateful policies
|
||||
max_sessions: 5 # from the §P10 formula: Pi0 @150ms, 1 Hz refresh
|
||||
warmup_inferences: 2
|
||||
strict_fps: false
|
||||
zenoh:
|
||||
connect_endpoints: ["tls/router.gpu-cluster.internal:7447"]
|
||||
tls:
|
||||
{
|
||||
connect_certificate: ...,
|
||||
connect_private_key: ...,
|
||||
root_ca_certificate: ...,
|
||||
}
|
||||
health_port: 9100 # HTTP health + Prometheus metrics
|
||||
debug: { capture_dir: null, capture_max: 256 }
|
||||
```
|
||||
|
||||
Draccus dataclass in `policy_server/manifest.py`; YAML via `--manifest`, individual overrides via CLI.
|
||||
|
||||
---
|
||||
|
||||
## 9. The Edge Client: `RemoteInferenceEngine`
|
||||
|
||||
New file `src/lerobot/rollout/inference/remote.py`, registered `@InferenceEngineConfig.register_subclass("remote")`.
|
||||
|
||||
### 9.1 Threading model
|
||||
|
||||
| Thread | Role |
|
||||
| -------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Main (strategy loop) | `notify_observation(obs)` → lock-protected latest-only slot (identical to `rtc.py` `_obs_holder`). `get_action()` → `ActionQueue.get()` + staleness check. **Never any I/O.** Structurally fixes legacy BUG-1 (blocking send inside the 33 ms loop). |
|
||||
| Network worker (1 daemon thread) | Cycle: wait until `queue_remaining·dt ≤ buffer_time_s` and active → snapshot `idx_before`, prefixes, `delay_steps = ceil(L_max/dt)` → encode (JPEG q=`jpeg_quality`) → `publisher.put(obs, attachment=header)` → await chunk on the action subscriber channel (timeout `request_timeout_s`) → `merge(original, processed, ceil(L/dt), idx_before)` → `latency_tracker.add(L)`. Owns the state machine, reconnects, and control queries. One-in-flight (P5). |
|
||||
| Zenoh action subscriber | `FifoChannel(2)` handler drained by the worker (no Python callback thread on the hot path); liveliness subscriber callback is deposit-only (sets an event). |
|
||||
|
||||
Reused unchanged: `ActionQueue` (`policies/rtc/action_queue.py`), `LatencyTracker`, `ActionInterpolator` (lives in strategies — `interpolation_multiplier` works with remote for free). Deleted concepts: aggregation zoo, `observations_similar`, `must_go`, `TimedObservation`/`TimedAction` pickles.
|
||||
|
||||
### 9.2 Fail-safe state machine
|
||||
|
||||
```
|
||||
ok no chunk for degraded_after_s
|
||||
CONNECTING ─────► STREAMING ───────────────────────────────► DEGRADED
|
||||
│ ▲ ▲ │ queue empty OR max_action_age_s hit │
|
||||
│ │ backoff, │ └───────────────────────────────────► STALLED ◄──┘
|
||||
│ │ re-handshake │ first successful merge │
|
||||
│ └─ RECONNECTING ◄── timeout streak / server liveliness drop ◄─┘
|
||||
│ │ offline > max_offline_s, capability/schema mismatch, auth failure
|
||||
└──────► DEAD (failed=True → shutdown_event → strategy teardown: return-to-initial-pose)
|
||||
```
|
||||
|
||||
- **DEGRADED**: requests failing but the queue still holds actions — the robot keeps executing; chunks _are_ the fault-tolerance buffer (1–3 s of coverage makes blips and clean server drains invisible).
|
||||
- **STALLED**: queue empty or staleness bound hit → apply `fallback`: `hold` (`get_action` → `None`; `send_next_action` already tolerates it), `repeat_last`, or `zero` (required for velocity-controlled robots, where "send nothing" means "keep last velocity").
|
||||
- **Staleness bound** (sync safety): every merge records `(chunk_start_index, t_send)`; `get_action` refuses any action whose source observation is older than `max_action_age_s` (default 3.0 s ≈ 90 steps @ 30 fps). Bounds open-loop execution after a network stall.
|
||||
- **DEAD**: only after `max_offline_s` (default 60 s) or a hard contract violation (capability/schema mismatch on reconnect — e.g. the server restarted with a different model; never execute wrong-model chunks). Uses the exact mechanism RTC uses (`failed=True` + global `shutdown_event`) so existing teardown runs unchanged.
|
||||
- **Watchdog layering**: per-request timeout (hung server — the BUG-3 fix) → server liveliness token (dead server/router) → staleness bound (the robot-side invariant that holds regardless of why data stopped).
|
||||
- **Pause/resume (DAgger)**: `pause()` stops the worker publishing (slot keeps refreshing, ignored); queue intact — parity with `RTCInferenceEngine.pause`. DAgger's existing `interpolator.reset(); engine.reset(); engine.resume()` sequence works unchanged.
|
||||
- **`reset()` (episode boundary)**: clear `ActionQueue` + staleness bookkeeping, bump `episode_id`, fire the acked `reset` query (1 s timeout, failure logged — the server has nothing it _must_ do thanks to per-request statelessness), flag `episode_start` on the next observation. `LatencyTracker` intentionally survives reset (latency is episode-invariant; parity with local RTC).
|
||||
- **`ready`** = session opened ∧ capabilities validated ∧ server `warmed_up`. First-chunk gating is implicit (`get_action` → `None` until the first merge).
|
||||
|
||||
### 9.3 Weightless client — exact integration changes
|
||||
|
||||
- `rollout/context.py`: `PolicyContext.{policy, preprocessor, postprocessor}` become `| None`. For remote configs, skip step 1 (weight load / PEFT / `.to(device)` / torch.compile / `init_rtc_processor`) and step 6 (`make_pre_post_processors`). Verified safe: strategies only consume `ctx.policy.inference`. Keep steps 2–5 (robot processors, hardware, features, dataset) — they are robot-derived. Keep the visual pre-flight check (`context.py:309-324`): `--policy.path` already loads config-only (`rollout/configs.py:324-328`, no weight download) and failing before dialing the server is free. `use_torch_compile` / explicit `--device` → warn-and-ignore for remote.
|
||||
- `rollout/inference/factory.py`: signature loosens to `policy: PreTrainedPolicy | None` (+ `policy_config: PreTrainedConfig`); `sync`/`rtc` branches guard `policy is None`; the `remote` branch lazy-imports (`eclipse-zenoh` stays an optional extra).
|
||||
- The authoritative validation moves to session open (§8.4); the local check becomes a fast-fail convenience.
|
||||
|
||||
### 9.4 Config
|
||||
|
||||
```python
|
||||
@InferenceEngineConfig.register_subclass("remote")
|
||||
@dataclass
|
||||
class RemoteInferenceConfig(InferenceEngineConfig):
|
||||
connect_endpoint: str = "tls/localhost:7447" # zenoh router endpoint
|
||||
tls_cert: str | None = None; tls_key: str | None = None; tls_ca: str | None = None
|
||||
client_uuid: str = "" # "" → uuid4 at start()
|
||||
jpeg_quality: int = 90 # 0 = raw (LAN/debug)
|
||||
buffer_time_s: float = 0.5 # send next obs when queue playback ≤ this (v1 G14) — KEPT
|
||||
max_action_age_s: float = 3.0 # staleness bound (safety)
|
||||
degraded_after_s: float = 1.0
|
||||
request_timeout_s: float = 5.0
|
||||
reconnect_initial_backoff_s: float = 0.5
|
||||
reconnect_max_backoff_s: float = 10.0
|
||||
max_offline_s: float = 60.0
|
||||
fallback: FallbackBehavior = FallbackBehavior.HOLD # hold | repeat_last | zero
|
||||
rtc: RTCConfig = field(default_factory=RTCConfig) # enabled → replace mode; horizon caps prefix
|
||||
tags: dict[str, str] = field(default_factory=dict) # ex-cluster/experiment labels
|
||||
```
|
||||
|
||||
```bash
|
||||
# Remote RTC + sentry recording (the reproducibility path)
|
||||
lerobot-rollout \
|
||||
--strategy.type=sentry \
|
||||
--policy.path=lerobot/pi0_towels \ # config-only: no weights downloaded
|
||||
--inference.type=remote \
|
||||
--inference.connect_endpoint=tls/router.gpu-cluster.internal:7447 \
|
||||
--inference.rtc.execution_horizon=10 \
|
||||
--robot.type=so100_follower --robot.port=/dev/ttyACM0 \
|
||||
--robot.cameras="{front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--dataset.repo_id=user/rollout_fleet_a --dataset.single_task="fold the towel"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 10. Wire Schema
|
||||
|
||||
### 10.1 Payload anatomy & rates — **KEPT** (JPEG) with numbers
|
||||
|
||||
Upstream per request: joints (24–128 B) + JPEG frames (480p q90 ≈ 40–90 KB each; 720p ≈ 110–230 KB) + RTC prefixes (≤ a few KB) → 60–450 KB depending on cameras. Downstream: `2 × chunk_size × action_dim × 4 B` + metadata → 3–50 KB. Effective request rate is self-clocked by `buffer_time_s` to ~1–4 Hz per robot (not the 30 Hz control rate). 300 robots ≈ 0.3–10 Mbps each — the wire is never the bottleneck; bandwidth budgeting is about camera count/resolution, and each GPU pod only ever sees its own ≤ `max_sessions` clients. Zenoh fragments >64 KiB payloads transparently; multi-MB messages are fine.
|
||||
|
||||
### 10.2 Attachment header (fixed-layout, packed little-endian — parsed without touching the body)
|
||||
|
||||
| Field | Type | Notes |
|
||||
| ---------------- | ---- | -------------------------------------------------------------- |
|
||||
| `schema_version` | u16 | negotiated at session open |
|
||||
| `msg_type` | u8 | OBS / CHUNK / EVENT |
|
||||
| `seq_id` | u64 | per-session monotonic; echoed in the chunk |
|
||||
| `episode_id` | u32 | bumped by `reset()` |
|
||||
| `client_mono_ns` | i64 | client `monotonic_ns()`; **opaque to the server, echoed back** |
|
||||
| `session_epoch` | u32 | bumped per (re)connect; stale-epoch chunks dropped |
|
||||
|
||||
### 10.3 msgpack bodies
|
||||
|
||||
**ObservationMsg** (client → server): `state: {names_ref, data: f32 LE bytes}`, `images: {name: {codec: jpeg|raw, bytes, (h,w,c) if raw}}`, `task: str`, `inference_delay_steps: int`, `prefix_model: tensor?`, `prefix_robot: tensor?` (tensors = raw LE bytes + dtype + shape), `episode_start: bool`.
|
||||
**ActionChunkMsg** (server → client): `seq_id_echo`, `client_mono_ns_echo`, `chunk_model: tensor`, `chunk_robot: tensor`, `queue_wait_ms: f32`, `inference_ms: f32`, `superseded_seqs: u32`, `server_load: f32`.
|
||||
**Status / SessionOpen / SessionAck / ResetMsg**: as specified in §8.4.
|
||||
|
||||
### 10.4 Schema discipline (P7)
|
||||
|
||||
`schema_version` gates at handshake; evolution is additive-only (new optional msgpack keys; unknown keys ignored); attachment layout changes require a version bump; golden codec round-trip tests (tensor exactness, JPEG RGB-channel-order regression — a silent BGR swap poisons every VLA in the fleet) are part of the test suite. **No pickle anywhere** — KEPT from v1 and now structural: nothing in the schema can carry code.
|
||||
|
||||
---
|
||||
|
||||
## 11. Latency Budget & the Clock Iron Rule
|
||||
|
||||
| Stage | LAN | WAN (50 ms RTT) |
|
||||
| ------------------------------ | --------------- | --------------- |
|
||||
| JPEG encode ×3 (edge CPU) | 2–9 ms | 2–9 ms |
|
||||
| Serialize | <1 ms | <1 ms |
|
||||
| Uplink (tx + ½RTT) | ~2 ms | ~54 ms |
|
||||
| Server queue wait | 0 → 1×inference | 0 → 1×inference |
|
||||
| Decode + canonical preprocess | 4–10 ms | 4–10 ms |
|
||||
| **Inference** | **15–150 ms** | **15–150 ms** |
|
||||
| Postprocess + downlink + merge | ~2 ms | ~27 ms |
|
||||
| **Total (Pi0-class)** | **~110–175 ms** | **~190–250 ms** |
|
||||
|
||||
Inference is 60–85 % of end-to-end on LAN; the entire transport+serialization stack is <10 ms. WAN adds propagation + uplink bandwidth — identical under any transport. At 30 fps this lands `delay_steps` ≈ 4–8, comfortably inside RTC execution horizons: WAN degrades smoothness parameters, never correctness. _This table is the standing answer to transport-performance bikeshedding._
|
||||
|
||||
**Clock iron rule** (P4): wall-clock instants never cross machines. Client stamps `monotonic_ns`, the server echoes it opaquely; `RTT = now − echo`. The server reports only **durations** (`queue_wait_ms`, `inference_ms`) measured on its own monotonic clock; `network_time = RTT − queue_wait − inference` for diagnostics. The schema has no field in which a foreign wall-clock instant can be compared — the legacy `time.time()` bug is unrepresentable.
|
||||
|
||||
---
|
||||
|
||||
## 12. Reproducibility & Audit (P8)
|
||||
|
||||
The contract is **fully logged + replayable**, not "deterministic":
|
||||
|
||||
- **Client = source of truth.** Recording strategies already persist observations + executed actions to `LeRobotDataset`. The remote engine logs, per executed action, the `(session_id, seq_id, episode_id)` of its source chunk plus the echoed `queue_wait_ms`/`inference_ms` (dataset-extras columns are a follow-up; client logs in v1).
|
||||
- **Server audit line per request** (structured JSON): `{ts, session_id, client_uuid, seq_id, episode_id, queue_wait_ms, inference_ms, chunk_range, superseded_seqs, outcome}`.
|
||||
- **Optional bounded capture**: `debug.capture_dir` writes a ring of request/response pairs (safetensors) for byte-exact offline replay through the same server pipeline.
|
||||
- **Runbook — "robot #217 stuttered at 14:03"**: (1) Grafana `session_staleness{client="217"}` — spike ⇒ server side, flat ⇒ client/network. (2) Server side: audit lines — `queue_wait_ms` rising across _all_ sessions ⇒ overloaded replica (check `active_sessions` vs `max_sessions`); `superseded_seqs` streak on 217 only ⇒ that client over-requesting; `outcome=error` ⇒ adjacent stack trace. (3) Client side: state-machine transitions + reconnects in the client log; dataset rows show which seq's chunk was executing and where `None` ticks occurred. Every hop shares `(session_id, seq_id)` — the join is mechanical.
|
||||
|
||||
---
|
||||
|
||||
## 13. Integration & Migration Plan
|
||||
|
||||
### 13.1 New
|
||||
|
||||
| Path | Content |
|
||||
| --------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `src/lerobot/policy_server/{__init__,schema,codec,manifest,session,scheduler,validation,server}.py` | wire schema constants, msgpack/attachment codecs, manifest dataclasses, `Session` + mailbox, `Scheduler` seam, capability rules + chunk-stateless registry, zenoh servicer + inference worker + drain + HTTP health/metrics |
|
||||
| `src/lerobot/rollout/inference/remote.py` | `RemoteInferenceEngine` (~600 lines; mirrors `rtc.py` structure) |
|
||||
| `src/lerobot/scripts/lerobot_policy_server.py` + `[project.scripts]` entry | thin `main()` |
|
||||
| `docker/Dockerfile.policy-server` | CUDA runtime base + uv; manifest via ConfigMap |
|
||||
| `docs/source/remote_inference.mdx` (+ `_toctree.yml`) | replaces `async.mdx` |
|
||||
|
||||
### 13.2 Modified
|
||||
|
||||
`rollout/inference/factory.py` (config + Optional-typed signature + lazy import) · `rollout/context.py` (weightless branch) · `rollout/inference/__init__.py` · `scripts/lerobot_rollout.py` docstring · `pyproject.toml`: `[async]` extra becomes `eclipse-zenoh>=1.9,<2.0` + `msgpack` (grpcio/matplotlib leave it; grpcio remains under `[hilserl]`/`dev` for the RL stack).
|
||||
|
||||
### 13.3 Removed — same landing PR
|
||||
|
||||
`src/lerobot/async_inference/` · `tests/async_inference/` · `docs/source/async.mdx` + its `_toctree.yml` entry · the `AsyncInference` service + `Observation`/`Actions`/`PolicySetup` messages from `src/lerobot/transport/services.proto` (regenerate pb2; **`LearnerService` untouched** — `transport/` is shared with HIL-SERL (`src/lerobot/rl/`); the RL test suite gates this change).
|
||||
|
||||
### 13.4 Legacy config → successor mapping
|
||||
|
||||
| Legacy (`RobotClientConfig`/`PolicyServerConfig`) | Successor |
|
||||
| ------------------------------------------------- | ---------------------------------------------------------- |
|
||||
| `server_address` | `--inference.connect_endpoint` (zenoh router) |
|
||||
| `policy_type`, `pretrained_name_or_path` | `--policy.path` (config-only) + server manifest |
|
||||
| `chunk_size_threshold` (0–1 ratio) | `--inference.buffer_time_s` (seconds) |
|
||||
| `actions_per_chunk` | server manifest (validated at session open) |
|
||||
| `aggregate_fn_name` + `AGGREGATE_FUNCTIONS` | **dropped** — `ActionQueue` replace/append |
|
||||
| `policy_device`, `client_device` | **dropped** — server concern / chunks arrive CPU f32 |
|
||||
| `debug_visualize_queue_size` | **dropped** — Rerun (`--display_data`) + engine stats |
|
||||
| `PolicyServerConfig.{host,port}` | manifest `zenoh.connect_endpoints` |
|
||||
| `inference_latency`, `obs_queue_timeout` | **dropped** — latency client-measured; no server obs queue |
|
||||
| `SendPolicyInstructions` | **dropped** — MaaS manifest + session validation |
|
||||
| `observations_similar` / `must_go` | **dropped** — latest-only slots + client send gate |
|
||||
| pickle envelopes | **dropped** — msgpack + attachment headers |
|
||||
|
||||
### 13.5 Legacy bugs/gaps → structural resolution
|
||||
|
||||
BUG-1 → worker thread owns all I/O. BUG-2 → aggregation deleted; `ActionQueue` is internally locked. BUG-3 → per-request timeout + liveliness. BUG-4 → client-side send gating; server newest-wins. G1 → per-session registry. G2 → manifest. G4 → msgpack+attachments. G5 → monotonic echo + `delay_steps`. G7 → recording strategies. G8 → mTLS + ACL. G9 → server-side canonical processors. G11 → `status` queryable. G12 → Prometheus + audit logs. G13 → `lerobot-policy-server` console script. G14 → `buffer_time_s`.
|
||||
|
||||
### 13.6 Tests
|
||||
|
||||
- **Unit**: codec round-trips (tensor exact; JPEG RGB-order regression), capability-validation matrix (§8.4 as parametrized cases), scheduler fairness + newest-wins supersession (mock policy with configurable sleep), manifest parsing, key-expr sanitization.
|
||||
- **Loopback integration** (CPU, fast CI): client+server in one process over zenoh peer-to-peer (or a localhost `zenohd` started by the fixture), tiny-ACT, fake 2-camera robot, N=8 concurrent sessions. The headline regression: two sessions with different joint states must not cross-contaminate `RelativeActionsProcessorStep` postprocessing — the test that proves the multi-tenancy claim.
|
||||
- **Chaos**: kill the server mid-episode → client returns `None`, never raises into the control loop, `failed` stays False within `max_offline_s`, resumes on restart; `docker kill zenohd` → liveliness flap → safe state → re-handshake (explicitly tests re-declaration behavior, flagged unverified upstream); SIGTERM drain → in-flight chunk completes, clients reconnect invisibly.
|
||||
- **Golden parity**: remote RTC vs local `RTCInferenceEngine` on identical observation sequences → byte-identical merged queues (the re-anchoring contract test). Gate for any real-robot remote-RTC use.
|
||||
|
||||
---
|
||||
|
||||
## 14. Roadmap
|
||||
|
||||
1. **PR1 — schema & codecs** (no torch deps): `policy_server/{schema,codec,manifest}.py`, key-expr sanitizer, golden codec tests.
|
||||
2. **PR2 — server core**: session registry, scheduler, validation/allowlist, inference worker with mock policy, loopback harness.
|
||||
3. **PR3 — client engine**: `RemoteInferenceEngine`, factory/context weightless integration, loopback integration + chaos + golden-parity tests.
|
||||
4. **PR4 — ops & docs**: Dockerfile, health/metrics, drain, ACL examples, `remote_inference.mdx`, rollout docstring.
|
||||
5. **Landing PR — legacy deletion**: remove `async_inference/` + tests + docs + proto service (RL suite gates), `[async]` extra swap.
|
||||
6. **Pre-release field validation**: one real robot on a lossy network (watchdog default tuning); JPEG q90 vs raw A/B on one policy (train/serve shift).
|
||||
7. **Future**: micro-batching (needs per-sample `inference_delay` across policy families), client-side downscale-to-policy-resolution (config-only shapes make it possible), Advanced Pub/Sub on the action topic, per-robot quotas, dataset provenance columns, `supports_stateless_chunking` attribute upstreamed to policy classes.
|
||||
|
||||
---
|
||||
|
||||
## 15. Open Risks
|
||||
|
||||
| Risk | Mitigation / decision needed |
|
||||
| ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Re-anchoring parity (server-side relative-prefix re-anchor vs `rtc.py`) | Golden parity test (§13.6) is a hard gate before robot use; likely failure mode is normalizer dtype/device drift |
|
||||
| First-chunk over-trim when idle: `merge` trims `ceil(L/dt)` even when nothing was consumed (queue empty at episode start) — wasteful at network latencies (600 ms ⇒ 18 steps) | Proposed clamp `real_delay = min(real_delay, last_index - idx_before)` touches the shared `ActionQueue` used by local RTC — needs sign-off + regression tests |
|
||||
| JPEG train/serve distribution shift | Unmeasured; A/B before locking q90 default (roadmap §14.6) |
|
||||
| Watchdog defaults untuned (`request_timeout_s=5`, `degraded_after_s=1`, `max_action_age_s=3`) | Field validation on wired and Wi-Fi; consider named profiles |
|
||||
| Capability check can pass while semantics differ (different finetune, different normalization stats, identical feature names) | Add checkpoint hash/revision pinning to SessionAck — decide in PR2 |
|
||||
| zenoh-python long-session maturity: re-declaration after router restart partially verified; SHM unstable; no asyncio | Chaos tests own this; thread-based design avoids the asyncio gap entirely |
|
||||
| Router ACL reload requires restart | Operational runbook: cert/ACL changes = rolling router restart |
|
||||
| `fallback=zero` has no consumer until velocity actions land in rollout (only `.pos` features routed today) | Validate the enum against robot capabilities when velocity support lands |
|
||||
| Per-client mailbox memory under fleet-scale wildcard subscription | One decoded-obs slot per client is small; add an LRU GC tied to liveliness drops |
|
||||
@@ -1,82 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This Dockerfile builds a GPU inference pod for `lerobot-policy-server`
|
||||
# (remote inference over Zenoh). It starts from an NVIDIA CUDA base image;
|
||||
# the cu128 PyTorch wheels bundle their own CUDA runtime (driver floor 570.86,
|
||||
# see pyproject.toml [tool.uv]).
|
||||
|
||||
# docker build -f docker/Dockerfile.policy-server -t lerobot-policy-server .
|
||||
# docker run --gpus all -v ./server.yaml:/etc/lerobot/server.yaml lerobot-policy-server
|
||||
#
|
||||
# Extra policy-family dependencies (e.g. pi0/smolvla need transformers) can be
|
||||
# added at build time:
|
||||
# docker build -f docker/Dockerfile.policy-server \
|
||||
# --build-arg LEROBOT_EXTRAS="async pi0" -t lerobot-policy-server .
|
||||
|
||||
# Configure the base image (same CUDA family as Dockerfile.internal)
|
||||
ARG CUDA_VERSION=12.8.1
|
||||
ARG OS_VERSION=24.04
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
|
||||
|
||||
# Define Python version and lerobot extras arguments
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ARG LEROBOT_EXTRAS="async"
|
||||
|
||||
# Configure environment variables
|
||||
ENV DEBIAN_FRONTEND=noninteractive \
|
||||
PATH=/lerobot/.venv/bin:$PATH
|
||||
|
||||
# Install system dependencies and uv (as root).
|
||||
# Kept lean: no hardware/teleop libraries — this image only serves policies.
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
git curl ca-certificates libglib2.0-0 ffmpeg \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
|
||||
&& mv /root/.local/bin/uv /usr/local/bin/uv \
|
||||
&& useradd --create-home --shell /bin/bash user_lerobot \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create application directory and set permissions
|
||||
WORKDIR /lerobot
|
||||
RUN chown -R user_lerobot:user_lerobot /lerobot
|
||||
|
||||
# Switch to the non-root user
|
||||
USER user_lerobot
|
||||
|
||||
# Model checkpoints are cached under HF_HOME — mount it as a volume
|
||||
# (or a PVC in Kubernetes) so warm restarts skip the Hub download.
|
||||
ENV HOME=/home/user_lerobot \
|
||||
HF_HOME=/home/user_lerobot/.cache/huggingface \
|
||||
HF_LEROBOT_HOME=/home/user_lerobot/.cache/huggingface/lerobot \
|
||||
TORCH_HOME=/home/user_lerobot/.cache/torch \
|
||||
TRITON_CACHE_DIR=/home/user_lerobot/.cache/triton
|
||||
|
||||
# Create the virtual environment (Python provisioned by uv)
|
||||
RUN uv venv --python ${PYTHON_VERSION}
|
||||
|
||||
# Install lerobot from the build context with the async extra
|
||||
# (eclipse-zenoh + msgpack — see pyproject.toml [project.optional-dependencies])
|
||||
COPY --chown=user_lerobot:user_lerobot setup.py pyproject.toml uv.lock README.md MANIFEST.in ./
|
||||
COPY --chown=user_lerobot:user_lerobot src/ src/
|
||||
|
||||
RUN uv sync --locked --no-cache $(printf -- '--extra %s ' ${LEROBOT_EXTRAS})
|
||||
|
||||
# HTTP health + Prometheus metrics (manifest `health_port`, 0 disables)
|
||||
EXPOSE 9100
|
||||
|
||||
# The manifest is typically mounted as a ConfigMap (Kubernetes) or a bind
|
||||
# mount (docker run -v) at /etc/lerobot/server.yaml; any field can also be
|
||||
# overridden on the command line, e.g. --model.repo_or_path=lerobot/pi0_towels
|
||||
ENTRYPOINT ["lerobot-policy-server"]
|
||||
CMD ["--manifest", "/etc/lerobot/server.yaml"]
|
||||
@@ -9,8 +9,6 @@
|
||||
- sections:
|
||||
- local: il_robots
|
||||
title: Imitation Learning for Robots
|
||||
- local: lelab
|
||||
title: LeLab - Lerobot GUI
|
||||
- local: bring_your_own_policies
|
||||
title: Adding a Policy
|
||||
- local: integrate_hardware
|
||||
@@ -61,10 +59,6 @@
|
||||
title: π₀-FAST (Pi0Fast)
|
||||
- local: pi05
|
||||
title: π₀.₅ (Pi05)
|
||||
- local: molmoact2
|
||||
title: MolmoAct2
|
||||
- local: vla_jepa
|
||||
title: VLA-JEPA
|
||||
- local: eo1
|
||||
title: EO-1
|
||||
- local: groot
|
||||
@@ -79,16 +73,12 @@
|
||||
- sections:
|
||||
- local: sarm
|
||||
title: SARM
|
||||
- local: robometer
|
||||
title: ROBOMETER
|
||||
- local: topreward
|
||||
title: TOPReward
|
||||
title: "Reward Models"
|
||||
- sections:
|
||||
- local: inference
|
||||
title: Policy Deployment (lerobot-rollout)
|
||||
- local: remote_inference
|
||||
title: Remote Inference (lerobot-policy-server)
|
||||
- local: async
|
||||
title: Use Async Inference
|
||||
- local: rtc
|
||||
title: Real-Time Chunking (RTC)
|
||||
title: "Inference"
|
||||
|
||||
@@ -0,0 +1,313 @@
|
||||
# Asynchronous Inference
|
||||
|
||||
With our [SmolVLA](https://huggingface.co/papers/2506.01844) we introduced a new way to run inference on real-world robots, **decoupling action prediction from action execution**.
|
||||
In this tutorial, we'll show how to use asynchronous inference (_async inference_) using a finetuned version of SmolVLA, and all the policies supported by LeRobot.
|
||||
**Try async inference with all the policies** supported by LeRobot!
|
||||
|
||||
**What you'll learn:**
|
||||
|
||||
1. Why asynchronous inference matters and how it compares to, more traditional, sequential inference.
|
||||
2. How to spin-up a `PolicyServer` and connect a `RobotClient` from the same machine, and even over the network.
|
||||
3. How to tune key parameters (`actions_per_chunk`, `chunk_size_threshold`) for your robot and policy.
|
||||
|
||||
If you get stuck, hop into our [Discord community](https://discord.gg/s3KuuzsPFb)!
|
||||
|
||||
In a nutshell: with _async inference_, your robot keeps acting while the policy server is already busy computing the next chunk of actions---eliminating "wait-for-inference" lags and unlocking smoother, more reactive behaviours.
|
||||
This is fundamentally different from synchronous inference (sync), where the robot stays idle while the policy computes the next chunk of actions.
|
||||
|
||||
---
|
||||
|
||||
## Getting started with async inference
|
||||
|
||||
You can read more information on asynchronous inference in our [blogpost](https://huggingface.co/blog/async-robot-inference). This guide is designed to help you quickly set up and run asynchronous inference in your environment.
|
||||
|
||||
First, install `lerobot` with the `async` tag, to install the extra dependencies required to run async inference.
|
||||
|
||||
```shell
|
||||
pip install -e ".[async]"
|
||||
```
|
||||
|
||||
Then, spin up a policy server (in one terminal, or in a separate machine) specifying the host address and port for the client to connect to.
|
||||
You can spin up a policy server running:
|
||||
|
||||
```shell
|
||||
python -m lerobot.async_inference.policy_server \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080
|
||||
```
|
||||
|
||||
This will start a policy server listening on `127.0.0.1:8080` (`localhost`, port 8080). At this stage, the policy server is empty, as all information related to which policy to run and with which parameters are specified during the first handshake with the client. Spin up a client with:
|
||||
|
||||
```shell
|
||||
python -m lerobot.async_inference.robot_client \
|
||||
--server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
|
||||
--robot.type=so100_follower \ # ROBOT: your robot type
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
|
||||
--robot.id=follower_so100 \ # ROBOT: your robot id, to load calibration file
|
||||
--robot.cameras="{ laptop: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}, phone: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ # POLICY: the cameras used to acquire frames, with keys matching the keys expected by the policy
|
||||
--task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
|
||||
--policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
|
||||
--pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
|
||||
--policy_device=mps \ # POLICY: the device to run the policy on, on the server (cuda, mps, xpu, cpu)
|
||||
--actions_per_chunk=50 \ # POLICY: the number of actions to output at once
|
||||
--chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
|
||||
--aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
|
||||
--debug_visualize_queue_size=True # CLIENT: whether to visualize the queue size at runtime
|
||||
```
|
||||
|
||||
In summary, you need to specify instructions for:
|
||||
|
||||
- `SERVER`: the address and port of the policy server
|
||||
- `ROBOT`: the type of robot to connect to, the port to connect to, and the local `id` of the robot
|
||||
- `POLICY`: the type of policy to run, and the model name/path on server to the checkpoint to run. You also need to specify which device should the sever be using, and how many actions to output at once (capped at the policy max actions value).
|
||||
- `CLIENT`: the threshold for the chunk size before sending a new observation to the server, and the function to aggregate actions on overlapping portions. Optionally, you can also visualize the queue size at runtime, to help you tune the `CLIENT` parameters.
|
||||
|
||||
Importantly,
|
||||
|
||||
- `actions_per_chunk` and `chunk_size_threshold` are key parameters to tune for your setup.
|
||||
- `aggregate_fn_name` is the function to aggregate actions on overlapping portions. You can either add a new one to a registry of functions, or add your own in `robot_client.py` (see [here](NOTE:addlinktoLOC))
|
||||
- `debug_visualize_queue_size` is a useful tool to tune the `CLIENT` parameters.
|
||||
|
||||
## Done! You should see your robot moving around by now 😉
|
||||
|
||||
## Async vs. synchronous inference
|
||||
|
||||
Synchronous inference relies on interleaving action chunk prediction and action execution. This inherently results in _idle frames_, frames where the robot awaits idle the policy's output: a new action chunk.
|
||||
In turn, inference is plagued by evident real-time lags, where the robot simply stops acting due to the lack of available actions.
|
||||
With robotics models increasing in size, this problem risks becoming only more severe.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/async-inference/sync.png"
|
||||
width="80%"
|
||||
></img>
|
||||
</p>
|
||||
<p align="center">
|
||||
<i>Synchronous inference</i> makes the robot idle while the policy is
|
||||
computing the next chunk of actions.
|
||||
</p>
|
||||
|
||||
To overcome this, we design async inference, a paradigm where action planning and execution are decoupled, resulting in (1) higher adaptability and, most importantly, (2) no idle frames.
|
||||
Crucially, with async inference, the next action chunk is computed _before_ the current one is exhausted, resulting in no idleness.
|
||||
Higher adaptability is ensured by aggregating the different action chunks on overlapping portions, obtaining an up-to-date plan and a tighter control loop.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/async-inference/async.png"
|
||||
width="80%"
|
||||
></img>
|
||||
</p>
|
||||
<p align="center">
|
||||
<i>Asynchronous inference</i> results in no idleness because the next chunk is
|
||||
computed before the current chunk is exhausted.
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
## Start the Policy Server
|
||||
|
||||
Policy servers are wrappers around a `PreTrainedPolicy` interfacing them with observations coming from a robot client.
|
||||
Policy servers are initialized as empty containers which are populated with the requested policy specified in the initial handshake between the robot client and the policy server.
|
||||
As such, spinning up a policy server is as easy as specifying the host address and port. If you're running the policy server on the same machine as the robot client, you can use `localhost` as the host address.
|
||||
|
||||
<hfoptions id="start_policy_server">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.async_inference.policy_server \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.async_inference.configs import PolicyServerConfig
|
||||
from lerobot.async_inference.policy_server import serve
|
||||
|
||||
config = PolicyServerConfig(
|
||||
host="localhost",
|
||||
port=8080,
|
||||
)
|
||||
serve(config)
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
This listens on `localhost:8080` for an incoming connection from the associated`RobotClient`, which will communicate which policy to run during the first client-server handshake.
|
||||
|
||||
---
|
||||
|
||||
## Launch the Robot Client
|
||||
|
||||
`RobotClient` is a wrapper around a `Robot` instance, which `RobotClient` connects to the (possibly remote) `PolicyServer`.
|
||||
The `RobotClient` streams observations to the `PolicyServer`, and receives action chunks obtained running inference on the server (which we assume to have better computational resources than the robot controller).
|
||||
|
||||
<hfoptions id="start_robot_client">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.async_inference.robot_client \
|
||||
--server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
|
||||
--robot.type=so100_follower \ # ROBOT: your robot type
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
|
||||
--robot.id=follower_so100 \ # ROBOT: your robot id, to load calibration file
|
||||
--robot.cameras="{ laptop: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}, phone: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ # POLICY: the cameras used to acquire frames, with keys matching the keys expected by the policy
|
||||
--task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
|
||||
--policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
|
||||
--pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
|
||||
--policy_device=mps \ # POLICY: the device to run the policy on, on the server
|
||||
--actions_per_chunk=50 \ # POLICY: the number of actions to output at once
|
||||
--chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
|
||||
--aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
|
||||
--debug_visualize_queue_size=True # CLIENT: whether to visualize the queue size at runtime
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
import threading
|
||||
from lerobot.robots.so_follower import SO100FollowerConfig
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.async_inference.configs import RobotClientConfig
|
||||
from lerobot.async_inference.robot_client import RobotClient
|
||||
from lerobot.async_inference.helpers import visualize_action_queue_size
|
||||
|
||||
# 1. Create the robot instance
|
||||
"""Check out the cameras available in your setup by running `python lerobot/find_cameras.py`"""
|
||||
# these cameras must match the ones expected by the policy
|
||||
# check the config.json on the Hub for the policy you are using
|
||||
camera_cfg = {
|
||||
"top": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem585A0076841",
|
||||
id="follower_so100",
|
||||
cameras=camera_cfg
|
||||
)
|
||||
|
||||
# 3. Create client configuration
|
||||
client_cfg = RobotClientConfig(
|
||||
robot=robot_cfg,
|
||||
server_address="localhost:8080",
|
||||
policy_device="mps",
|
||||
client_device="cpu",
|
||||
policy_type="smolvla",
|
||||
pretrained_name_or_path="<user>/smolvla_async",
|
||||
chunk_size_threshold=0.5,
|
||||
actions_per_chunk=50, # make sure this is less than the max actions of the policy
|
||||
)
|
||||
|
||||
# 4. Create and start client
|
||||
client = RobotClient(client_cfg)
|
||||
|
||||
# 5. Specify the task
|
||||
task = "Don't do anything, stay still"
|
||||
|
||||
if client.start():
|
||||
# Start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
action_receiver_thread.start()
|
||||
|
||||
try:
|
||||
# Run the control loop
|
||||
client.control_loop(task)
|
||||
except KeyboardInterrupt:
|
||||
client.stop()
|
||||
action_receiver_thread.join()
|
||||
# (Optionally) plot the action queue size
|
||||
visualize_action_queue_size(client.action_queue_size)
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
The following two parameters are key in every setup:
|
||||
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Hyperparameter</th>
|
||||
<th>Default</th>
|
||||
<th>What it does</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td>
|
||||
<code>actions_per_chunk</code>
|
||||
</td>
|
||||
<td>50</td>
|
||||
<td>
|
||||
How many actions the policy outputs at once. Typical values: 10-50.
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<code>chunk_size_threshold</code>
|
||||
</td>
|
||||
<td>0.7</td>
|
||||
<td>
|
||||
When the queue is ≤ 50% full, the client sends a fresh observation.
|
||||
Value in [0, 1].
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
<Tip>
|
||||
Different values of `actions_per_chunk` and `chunk_size_threshold` do result
|
||||
in different behaviours.
|
||||
</Tip>
|
||||
|
||||
On the one hand, increasing the value of `actions_per_chunk` will result in reducing the likelihood of ending up with no actions to execute, as more actions will be available when the new chunk is computed.
|
||||
However, larger values of `actions_per_chunk` might also result in less precise actions, due to the compounding errors consequent to predicting actions over longer timespans.
|
||||
|
||||
On the other hand, increasing the value of `chunk_size_threshold` will result in sending out to the `PolicyServer` observations for inference more often, resulting in a larger number of updates action chunks, overlapping on significant portions. This results in high adaptability, in the limit predicting one action chunk for each observation, which is in turn only marginally consumed while a new one is produced.
|
||||
This option does also put more pressure on the inference pipeline, as a consequence of the many requests. Conversely, values of `chunk_size_threshold` close to 0.0 collapse to the synchronous edge case, whereby new observations are only sent out whenever the current chunk is exhausted.
|
||||
|
||||
We found the default values of `actions_per_chunk` and `chunk_size_threshold` to work well in the experiments we developed for the [SmolVLA paper](https://huggingface.co/papers/2506.01844), but recommend experimenting with different values to find the best fit for your setup.
|
||||
|
||||
### Tuning async inference for your setup
|
||||
|
||||
1. **Choose your computational resources carefully.** [PI0](https://huggingface.co/lerobot/pi0) occupies 14GB of memory at inference time, while [SmolVLA](https://huggingface.co/lerobot/smolvla_base) requires only ~2GB. You should identify the best computational resource for your use case keeping in mind smaller policies require less computational resources. The combination of policy and device used (CPU-intensive, using MPS, or the number of CUDA cores on a given NVIDIA GPU) directly impacts the average inference latency you should expect.
|
||||
2. **Adjust your `fps` based on inference latency.** While the server generates a new action chunk, the client is not idle and is stepping through its current action queue. If the two processes happen at fundamentally different speeds, the client might end up with an empty queue. As such, you should reduce your fps if you consistently run out of actions in queue.
|
||||
3. **Adjust `chunk_size_threshold`**.
|
||||
- Values closer to `0.0` result in almost sequential behavior. Values closer to `1.0` → send observation every step (more bandwidth, relies on good world-model).
|
||||
- We found values around 0.5-0.6 to work well. If you want to tweak this, spin up a `RobotClient` setting the `--debug_visualize_queue_size` to `True`. This will plot the action queue size evolution at runtime, and you can use it to find the value of `chunk_size_threshold` that works best for your setup.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/async-inference/queues.png"
|
||||
width="80%"
|
||||
></img>
|
||||
</p>
|
||||
<p align="center">
|
||||
<i>
|
||||
The action queue size is plotted at runtime when the
|
||||
`--debug_visualize_queue_size` flag is passed, for various levels of
|
||||
`chunk_size_threshold` (`g` in the SmolVLA paper).
|
||||
</i>
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
## Conclusion
|
||||
|
||||
Asynchronous inference represents a significant advancement in real-time robotics control, addressing the fundamental challenge of inference latency that has long plagued robotics applications. Through this tutorial, you've learned how to implement a complete async inference pipeline that eliminates idle frames and enables smoother, more reactive robot behaviors.
|
||||
|
||||
**Key Takeaways:**
|
||||
|
||||
- **Paradigm Shift**: Async inference decouples action prediction from execution, allowing robots to continue acting while new action chunks are computed in parallel
|
||||
- **Performance Benefits**: Eliminates "wait-for-inference" lags that are inherent in synchronous approaches, becoming increasingly important as policy models grow larger
|
||||
- **Flexible Architecture**: The server-client design enables distributed computing, where inference can run on powerful remote hardware while maintaining real-time robot control
|
||||
- **Tunable Parameters**: Success depends on properly configuring `actions_per_chunk` and `chunk_size_threshold` for your specific hardware, policy, and task requirements
|
||||
- **Universal Compatibility**: Works with all LeRobot-supported policies, from lightweight ACT models to vision-language models like SmolVLA
|
||||
|
||||
Start experimenting with the default parameters, monitor your action queue sizes, and iteratively refine your setup to achieve optimal performance for your specific use case.
|
||||
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/huggingface/lerobot/issues).
|
||||
@@ -647,6 +647,5 @@ The `--strategy.type` flag selects the execution mode:
|
||||
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
|
||||
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
|
||||
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
|
||||
- `episodic`: Episode-oriented policy recording with reset phases between episodes
|
||||
|
||||
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
|
||||
|
||||
@@ -157,44 +157,6 @@ Foot pedal input is also supported via `--strategy.input_device=pedal`. Configur
|
||||
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
|
||||
| `--teleop.type` | **Required.** Teleoperator type |
|
||||
|
||||
### Episodic (`--strategy.type=episodic`)
|
||||
|
||||
Episode-oriented recording that mirrors the behavior of `lerobot-record`. The policy drives the robot for each episode; an optional teleoperator can drive the robot during the reset phase between episodes.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=episodic \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/ttyACM1 \
|
||||
--dataset.repo_id=${HF_USER}/my_eval_data \
|
||||
--dataset.num_episodes=20 \
|
||||
--dataset.episode_time_s=30 \
|
||||
--dataset.reset_time_s=10 \
|
||||
--dataset.single_task="Pick up the red cube"
|
||||
```
|
||||
|
||||
Teleop is optional — if omitted the robot holds its position during the reset phase.
|
||||
|
||||
**Keyboard controls:**
|
||||
|
||||
| Key | Action |
|
||||
| ----------- | -------------------------------- |
|
||||
| `→` (right) | End the current episode early |
|
||||
| `←` (left) | Discard episode and re-record it |
|
||||
| `ESC` | Stop the recording session |
|
||||
|
||||
| Flag | Description |
|
||||
| ----------------------------------------------- | -------------------------------------------------------------------------- |
|
||||
| `--dataset.num_episodes` | Number of episodes to record |
|
||||
| `--dataset.episode_time_s` | Duration of each recording episode in seconds |
|
||||
| `--dataset.reset_time_s` | Duration of the reset phase between episodes in seconds |
|
||||
| `--teleop.type` | Optional. Teleoperator to drive the robot during resets |
|
||||
| `--strategy.reset_to_initial_position` | Whether to reset the robot to its initial position between episodes |
|
||||
| `--strategy.smooth_leader_to_follower_handover` | Whether to turn on or off the leader -> follower smooth handover behavior. |
|
||||
|
||||
---
|
||||
|
||||
## Inference Backends
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
# LeLab - LeRobot Guide
|
||||
|
||||
LeLab is a graphical user interface built on top of the LeRobot library, designed to make robotics accessible without needing to memorize CLI commands. From a single app you can configure your robot, teleoperate it, collect datasets, train policies locally or on cloud GPUs via HF Jobs, and deploy trained models back onto your robot. It's the easiest way to go from an unboxed SO-101 to a working policy, and a great companion for anyone learning the LeRobot workflow. Source code and issues live on GitHub: [huggingface/leLab](https://github.com/huggingface/leLab).
|
||||
|
||||
> [!TIP]
|
||||
> For now LeLab is compatible only with SO-ARM101
|
||||
|
||||
<Youtube id="VqyKUuW9V1g" />
|
||||
|
||||
### Installation
|
||||
|
||||
Requires [`uv`](https://docs.astral.sh/uv/getting-started/installation/). Install and launch in one command:
|
||||
|
||||
```
|
||||
uv tool install git+https://github.com/huggingface/leLab.git && lelab
|
||||
```
|
||||
|
||||
After install, run `lelab` from your terminal anytime to start the app.
|
||||
|
||||
### Features
|
||||
|
||||
- **Add robots** — Select arm type (leader/follower), calibrate each joint from the middle position, and attach cameras.
|
||||
- **Teleoperation** — Control the follower arm with the leader and see a live 3D visualization of the arms.
|
||||
- **Dataset recording** — Define a task description, number of episodes, and episode/reset durations. Press spacebar to advance between episodes. 30+ episodes recommended.
|
||||
- **Local training** — Train a policy directly on your own machine with a selected dataset, policy type, batch size, and step count.
|
||||
- **Cloud training with HF Jobs** — Train on powerful GPUs via [HF Jobs](https://huggingface.co/docs/huggingface_hub/en/guides/jobs) with transparent pricing. Run `hf auth login` first. See the [Compute HW Guide](hardware_guide) for hardware/batch size tips.
|
||||
- **Training visualization** — Watch progress live in the app, with checkpoints saved automatically.
|
||||
- **Run trained policies** — Pick any model from your jobs list and run inference on your robot with one click.
|
||||
- **Use community datasets** — Provide any Hugging Face dataset ID to train on datasets you didn't record yourself.
|
||||
@@ -275,7 +275,7 @@ A converter aggregates per‑episode files into larger shards and writes episode
|
||||
pip install "https://github.com/huggingface/lerobot/archive/33cad37054c2b594ceba57463e8f11ee374fa93c.zip"
|
||||
|
||||
# Convert an existing v2.1 dataset hosted on the Hub:
|
||||
python -m lerobot.scripts.convert_dataset_v21_to_v30 --repo-id=<HF_USER/DATASET_ID>
|
||||
python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id=<HF_USER/DATASET_ID>
|
||||
```
|
||||
|
||||
**What it does**
|
||||
|
||||
@@ -1,433 +0,0 @@
|
||||
# MolmoAct2 Policy
|
||||
|
||||
MolmoAct2 is the LeRobot policy implementation of
|
||||
[MolmoAct2](https://allenai.org/blog/molmoact2), ported into the LeRobot
|
||||
training, evaluation, checkpointing, and dataset interfaces for easier use with
|
||||
LeRobot datasets.
|
||||
|
||||
This implementation currently supports training and evaluation for the regular
|
||||
MolmoAct2 model. MolmoAct2-Think, which supports adaptive depth reasoning, is
|
||||
not included in this LeRobot policy yet and is coming soon.
|
||||
|
||||
For the original MolmoAct2 training code used for the experiments reported in
|
||||
the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2).
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
Install LeRobot with the MolmoAct2 optional dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e ".[molmoact2]"
|
||||
```
|
||||
|
||||
To run the models in this repository, you need an NVIDIA GPU. The measurements
|
||||
below were taken on a single NVIDIA H100 80GB with bf16 model loading, LIBERO with two RGB cameras. MolmoAct2 rows use `chunk_size=10`, action dim 7
|
||||
padded to `expected_max_action_dim=32`, and `num_flow_timesteps=8`. Training measurements use
|
||||
`gradient_checkpointing=true` and include the forward pass, backward pass,
|
||||
gradient clipping, optimizer step, and optimizer state allocation. Values are
|
||||
peak GPU memory sampled with `nvidia-smi`. Leave a few GiB of headroom for
|
||||
dataloader workers, CUDA context, and fragmentation.
|
||||
|
||||
Multi-GPU training through `accelerate` increases throughput and global batch
|
||||
size, but this LeRobot port does not currently expose the original MolmoAct2
|
||||
`fsdp_devices` model-parallel training path. The current training script has
|
||||
not been tested for multi-node training.
|
||||
|
||||
| Mode | Peak Memory, bs=8 | Peak Memory, bs=16 | Peak Memory, bs=32 |
|
||||
| ------------------------------------------------ | ----------------: | -----------------: | -----------------: |
|
||||
| Inference, continuous, CUDA graph enabled (bs=1) | 12.1 GiB | - | - |
|
||||
| Fine-tuning, action expert only, continuous | 16.5 GiB | 18.3 GiB | 21.4 GiB |
|
||||
| Fine-tuning, LoRA VLM, both action modes | 20.2 GiB | 26.8 GiB | 41.3 GiB |
|
||||
| Fine-tuning, full model, both action modes | 48.3 GiB | 49.8 GiB | 60.1 GiB |
|
||||
|
||||
The repo has been tested with Ubuntu 22.04.
|
||||
|
||||
## Usage
|
||||
|
||||
To use MolmoAct2 in a LeRobot training config, set:
|
||||
|
||||
```python
|
||||
policy.type=molmoact2
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
MolmoAct2 can be fine-tuned from either the released MolmoAct2 Hugging Face
|
||||
checkpoint format or from a checkpoint already saved by LeRobot. Both routes use
|
||||
the same LeRobot training loop, dataset transforms, checkpoint saving, and
|
||||
logging. The difference is only how the initial policy weights and processor
|
||||
state are loaded.
|
||||
|
||||
### Training With Original MolmoAct2 Weight
|
||||
|
||||
Use `policy.checkpoint_path` when starting from a released MolmoAct2 checkpoint,
|
||||
for example `allenai/MolmoAct2` or `allenai/MolmoAct2-LIBERO`. LeRobot will load
|
||||
the original HF model files, then build its own policy processor from the
|
||||
dataset metadata and the policy options below.
|
||||
|
||||
The command below shows full fine-tuning on the merged LIBERO dataset. It uses
|
||||
bf16 model loading, 8 flow timesteps, LeRobot dataset statistics, image
|
||||
augmentation, and LeRobot's checkpointing/logging path.
|
||||
|
||||
```bash
|
||||
accelerate launch \
|
||||
--num_processes=8 \
|
||||
--mixed_precision=bf16 \
|
||||
-m lerobot.scripts.lerobot_train \
|
||||
--dataset.repo_id=allenai/MolmoAct2-LIBERO-Dataset \
|
||||
--dataset.root=/path/to/lerobot/data/allenai/MolmoAct2-LIBERO-Dataset \
|
||||
--dataset.video_backend=pyav \
|
||||
--dataset.image_transforms.enable=true \
|
||||
--policy.type=molmoact2 \
|
||||
--policy.checkpoint_path=allenai/MolmoAct2-LIBERO \
|
||||
--policy.device=cuda \
|
||||
--policy.action_mode=both \
|
||||
--policy.chunk_size=10 \
|
||||
--policy.n_action_steps=10 \
|
||||
--policy.setup_type="single franka robotic arm in libero" \
|
||||
--policy.control_mode="delta end-effector pose" \
|
||||
--policy.image_keys='["observation.images.image","observation.images.wrist_image"]' \
|
||||
--policy.model_dtype=bfloat16 \
|
||||
--policy.num_flow_timesteps=8 \
|
||||
--policy.gradient_checkpointing=true \
|
||||
--policy.freeze_embedding=true \
|
||||
--policy.normalize_gripper=false \
|
||||
--policy.enable_knowledge_insulation=false \
|
||||
--policy.push_to_hub=false \
|
||||
--wandb.enable=true \
|
||||
--wandb.entity=<wandb_entity> \
|
||||
--wandb.project=<wandb_project> \
|
||||
--job_name=<job_name> \
|
||||
--output_dir=outputs/<job_name> \
|
||||
--steps=10000 \
|
||||
--batch_size=32 \
|
||||
--num_workers=4 \
|
||||
--log_freq=20 \
|
||||
--eval_freq=-1 \
|
||||
--save_checkpoint=true \
|
||||
--save_freq=2000
|
||||
```
|
||||
|
||||
### Training With LeRobot MolmoAct2 Weight
|
||||
|
||||
Use `policy.path` when starting from a MolmoAct2 checkpoint that was saved by
|
||||
LeRobot, either from a local `pretrained_model` directory or from the Hub. This
|
||||
restores the saved LeRobot policy config, model weights, processor, and
|
||||
normalization statistics. You can still override training-time options such as
|
||||
`batch_size`, `steps`, LoRA flags, or `policy.action_mode`.
|
||||
|
||||
```bash
|
||||
accelerate launch \
|
||||
--num_processes=8 \
|
||||
--mixed_precision=bf16 \
|
||||
-m lerobot.scripts.lerobot_train \
|
||||
--dataset.repo_id=allenai/MolmoAct2-LIBERO-Dataset \
|
||||
--dataset.root=/path/to/lerobot/data/allenai/MolmoAct2-LIBERO-Dataset \
|
||||
--dataset.video_backend=pyav \
|
||||
--dataset.image_transforms.enable=true \
|
||||
--policy.path=/path/to/pretrained_model \
|
||||
--policy.device=cuda \
|
||||
--policy.action_mode=both \
|
||||
--policy.chunk_size=10 \
|
||||
--policy.n_action_steps=10 \
|
||||
--policy.model_dtype=bfloat16 \
|
||||
--policy.num_flow_timesteps=8 \
|
||||
--policy.gradient_checkpointing=true \
|
||||
--wandb.enable=true \
|
||||
--wandb.entity=<wandb_entity> \
|
||||
--wandb.project=<wandb_project> \
|
||||
--job_name=<job_name> \
|
||||
--output_dir=outputs/<job_name> \
|
||||
--steps=10000 \
|
||||
--batch_size=32 \
|
||||
--num_workers=4 \
|
||||
--log_freq=20 \
|
||||
--eval_freq=-1 \
|
||||
--save_checkpoint=true \
|
||||
--save_freq=2000
|
||||
```
|
||||
|
||||
### Common Practices
|
||||
|
||||
For fine-tuning on a comparatively small dataset, such as a single LIBERO suite
|
||||
or a real-world dataset with less than 200 demonstrations, a global batch size of
|
||||
16 to 32 is a good starting point. In these settings, `policy.enable_lora_vlm=true` or `policy.train_action_expert_only=true` is also a practical choice. In both
|
||||
cases, we intentionally keep the action expert fully trainable, which we found
|
||||
to be crucial for model performance. For larger fine-tuning datasets, larger
|
||||
global batch sizes and full fine-tuning are usually preferred.
|
||||
|
||||
### Common Policy Options
|
||||
|
||||
- `policy.checkpoint_path`: original MolmoAct2 HF checkpoint to initialize from.
|
||||
Use this for released MolmoAct2 weights.
|
||||
- `policy.path`: LeRobot checkpoint to initialize from. Use this for checkpoints
|
||||
created by LeRobot training.
|
||||
- `policy.action_mode`: training target, one of `continuous`, `discrete`, or
|
||||
`both`. `both` trains the flow-matching action expert and the discrete
|
||||
action-token loss.
|
||||
- `policy.train_action_expert_only`: trains only parameters whose names contain
|
||||
`action_expert`. It requires `policy.action_mode=continuous`.
|
||||
- `policy.enable_lora_vlm`: enables LoRA on VLM linear layers. Use
|
||||
`policy.enable_lora_action_expert=true` only if LoRA should also cover action
|
||||
expert linear layers. When `policy.enable_lora_action_expert=false`, the
|
||||
action expert base weights remain fully trainable while the VLM is trained
|
||||
through LoRA adapters. When `policy.enable_lora_action_expert=true`, the
|
||||
action expert is also adapter-tuned instead of fully fine-tuned.
|
||||
- `policy.enable_knowledge_insulation`: when `true`, detaches action-expert
|
||||
context K/V states before the action loss. The default is `false`.
|
||||
- `policy.chunk_size`: action horizon used by the policy. For LIBERO we use
|
||||
`10`. This LeRobot port overrides the loaded checkpoint's
|
||||
`max_action_horizon` with this value.
|
||||
- `policy.n_action_steps`: number of actions consumed from each predicted
|
||||
chunk before querying the policy again. For LIBERO, set it to `chunk_size`.
|
||||
- `policy.setup_type`: text inserted into the prompt to describe the robot and
|
||||
scene, e.g. `single franka robotic arm in libero`. More examples are listed
|
||||
in the `metadata_by_tag` entries of
|
||||
[`norm_stats.json`](https://huggingface.co/allenai/MolmoAct2/blob/main/norm_stats.json).
|
||||
- `policy.control_mode`: text inserted into the prompt to describe the action
|
||||
space, e.g. `delta end-effector pose` or `absolute joint pose`.
|
||||
- `policy.image_keys`: ordered LeRobot image observation keys passed to the
|
||||
processor.
|
||||
- `policy.model_dtype`: checkpoint/forward dtype, one of `float32`,
|
||||
`bfloat16`, or `float16`. Use `bfloat16` for normal training.
|
||||
- `policy.num_flow_timesteps`: number of flow-matching timesteps sampled per
|
||||
example during training. We use `8` for fine-tuning.
|
||||
- `policy.num_inference_steps`: optional override for continuous action
|
||||
generation steps at inference time.
|
||||
- `policy.gradient_checkpointing`: enables checkpointing in the VLM/action path
|
||||
to reduce activation memory.
|
||||
- `policy.freeze_embedding`: freezes input embeddings. The default is `true`.
|
||||
- `policy.normalize_gripper`: controls whether gripper dimensions are included
|
||||
in state/action quantile normalization. The default is `false`.
|
||||
- `policy.normalize_language`: normalizes task strings before prompt
|
||||
construction. The default is `true`.
|
||||
- `policy.mask_action_dim_padding`: masks padded dimensions in the flow loss.
|
||||
Released checkpoints use `policy.expected_max_action_dim=32`.
|
||||
- `policy.max_sequence_length`: optional manual sequence cap. Leave unset to
|
||||
infer it from images, state dimension, action dimension, action horizon, and
|
||||
discrete-action mode.
|
||||
|
||||
### Learning Rates
|
||||
|
||||
MolmoAct2 uses parameter-group learning rates to match the original MolmoAct2
|
||||
fine-tuning experiments.
|
||||
|
||||
- Full fine-tuning uses `policy.optimizer_lr=1e-5` for the VLM,
|
||||
`policy.optimizer_vit_lr=5e-6` for the vision tower,
|
||||
`policy.optimizer_connector_lr=5e-6` for image connector layers, and
|
||||
`policy.optimizer_action_expert_lr=5e-5` for the action expert.
|
||||
- LoRA VLM fine-tuning sets the VLM, vision, and connector LoRA parameter
|
||||
groups to `5e-5` when `policy.enable_lora_vlm=true`. By default,
|
||||
`policy.enable_lora_action_expert=false`, so the action expert is still fully
|
||||
fine-tuned with `policy.optimizer_action_expert_lr`. If
|
||||
`policy.enable_lora_action_expert=true`, the action expert is trained through
|
||||
LoRA adapters instead.
|
||||
- Action-expert-only fine-tuning trains only the action expert and uses
|
||||
`policy.optimizer_action_expert_lr=5e-5`.
|
||||
|
||||
You can override the full fine-tuning and action-expert learning rates with
|
||||
`policy.optimizer_lr`, `policy.optimizer_vit_lr`,
|
||||
`policy.optimizer_connector_lr`, and `policy.optimizer_action_expert_lr`.
|
||||
Scheduler settings can be changed with `policy.scheduler_warmup_steps`,
|
||||
`policy.scheduler_decay_steps`, and `policy.scheduler_decay_lr`.
|
||||
|
||||
### Dataset Quantile Statistics
|
||||
|
||||
MolmoAct2 defaults to quantile normalization for state and action features. If
|
||||
your dataset has not been converted with quantile statistics, you can add them
|
||||
with:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/augment_dataset_quantile_stats.py \
|
||||
--repo-id=your_dataset
|
||||
```
|
||||
|
||||
Alternatively, train MolmoAct2 with mean/std normalization:
|
||||
|
||||
```bash
|
||||
--policy.normalization_mapping='{"ACTION": "MEAN_STD", "STATE": "MEAN_STD", "VISUAL": "IDENTITY"}'
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
|
||||
Evaluation also supports both LeRobot-saved checkpoints and original MolmoAct2
|
||||
HF checkpoints. For LIBERO replication, keep the EGL rendering environment
|
||||
fixed and use `policy.per_episode_seed=true`.
|
||||
|
||||
**Important:** We found that `num_steps_wait=10` does not reliably let the
|
||||
LIBERO scene stabilize and can degrade measured success. All LIBERO evaluation
|
||||
results reported here use `num_steps_wait=50`.
|
||||
|
||||
### Evaluation With LeRobot MolmoAct2 Weight
|
||||
|
||||
Use `policy.path` for a checkpoint saved by LeRobot. The saved processor and
|
||||
normalization statistics are restored together with the model.
|
||||
|
||||
```bash
|
||||
export MUJOCO_GL=egl
|
||||
export PYOPENGL_PLATFORM=egl
|
||||
export OMP_NUM_THREADS=1
|
||||
export MKL_NUM_THREADS=1
|
||||
|
||||
lerobot-eval \
|
||||
--policy.path=allenai/MolmoAct2-LIBERO-LeRobot \
|
||||
--policy.inference_action_mode=continuous \
|
||||
--policy.model_dtype=bfloat16 \
|
||||
--policy.use_amp=true \
|
||||
--policy.enable_inference_cuda_graph=true \
|
||||
--policy.device=cuda \
|
||||
--policy.per_episode_seed=true \
|
||||
--policy.eval_seed=1000 \
|
||||
--env.type=libero \
|
||||
--env.task=libero_10,libero_goal,libero_object,libero_spatial \
|
||||
--env.camera_name_mapping='{"agentview_image":"image","robot0_eye_in_hand_image":"wrist_image"}' \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=50 \
|
||||
--seed=1000
|
||||
```
|
||||
|
||||
### Evaluation With Original MolmoAct2 Weight
|
||||
|
||||
You can evaluate a released Hugging Face checkpoint directly without first
|
||||
converting it to a LeRobot checkpoint. In this case, set
|
||||
`policy.checkpoint_path` to the HF model repo and provide `policy.norm_tag`.
|
||||
For LIBERO, `policy.norm_tag=libero` loads the LIBERO action/state
|
||||
normalization statistics, action horizon, prompt metadata, and image-key order
|
||||
from the checkpoint's `norm_stats.json`.
|
||||
|
||||
To fully replicate the MolmoAct2 paper results with released Hugging Face
|
||||
checkpoints, we recommend using the v0.5.1-pinned
|
||||
[`allenai/lerobot` `molmoact2-hf-inference`](https://github.com/allenai/lerobot/tree/molmoact2-hf-inference)
|
||||
branch. That branch matches the original evaluation settings used for the
|
||||
reported numbers.
|
||||
|
||||
```bash
|
||||
export MUJOCO_GL=egl
|
||||
export PYOPENGL_PLATFORM=egl
|
||||
export OMP_NUM_THREADS=1
|
||||
export MKL_NUM_THREADS=1
|
||||
|
||||
lerobot-eval \
|
||||
--policy.type=molmoact2 \
|
||||
--policy.checkpoint_path=allenai/MolmoAct2-LIBERO \
|
||||
--policy.norm_tag=libero \
|
||||
--policy.inference_action_mode=continuous \
|
||||
--policy.model_dtype=float32 \
|
||||
--policy.use_amp=false \
|
||||
--policy.enable_inference_cuda_graph=true \
|
||||
--policy.device=cuda \
|
||||
--policy.per_episode_seed=true \
|
||||
--policy.eval_seed=1000 \
|
||||
--env.type=libero \
|
||||
--env.task=libero_goal \
|
||||
--env.camera_name_mapping='{"agentview_image":"image","robot0_eye_in_hand_image":"wrist_image"}' \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=50 \
|
||||
--seed=1000
|
||||
```
|
||||
|
||||
Use `--env.task=libero_10,libero_goal,libero_object,libero_spatial` to run the
|
||||
full LIBERO suite. The same command works for other released MolmoAct2
|
||||
checkpoints as long as the requested `policy.norm_tag` exists in that
|
||||
checkpoint's `norm_stats.json`.
|
||||
|
||||
### Common Evaluation Options
|
||||
|
||||
- `policy.inference_action_mode`: required for rollout. Use `continuous` for
|
||||
flow-matching inference or `discrete` for action-token inference. It must be
|
||||
compatible with the training-time `policy.action_mode` saved in the
|
||||
checkpoint.
|
||||
- `policy.path`: LeRobot checkpoint path or Hub repo. Use this for checkpoints
|
||||
saved by LeRobot.
|
||||
- `policy.checkpoint_path`: original MolmoAct2 HF checkpoint path or Hub repo.
|
||||
Use this with `policy.type=molmoact2` and `policy.norm_tag`.
|
||||
- `policy.norm_tag`: selects normalization statistics, prompt metadata,
|
||||
image-key order, and action horizon from the original checkpoint's
|
||||
`norm_stats.json`. It is required for direct original-HF checkpoint
|
||||
evaluation.
|
||||
- `policy.model_dtype`: model load/forward dtype. Use `bfloat16` for normal
|
||||
GPU evaluation. Use `float32` only when you explicitly want fp32 inference.
|
||||
- `policy.use_amp`: runs the policy forward under autocast during eval. For
|
||||
`model_dtype=bfloat16`, keep this enabled.
|
||||
- `policy.enable_inference_cuda_graph`: enables the MolmoAct2 inference CUDA
|
||||
graph path for faster repeated continuous-action rollout.
|
||||
- `policy.per_episode_seed` and `policy.eval_seed`: make stochastic continuous
|
||||
action generation deterministic per episode for replication.
|
||||
- `env.task`: comma-separated LIBERO suites or a single suite. Use
|
||||
`libero_10,libero_goal,libero_object,libero_spatial` for the full benchmark.
|
||||
- `env.camera_name_mapping`: maps LIBERO camera names to the image keys expected
|
||||
by the policy processor.
|
||||
|
||||
## Performance Results
|
||||
|
||||
### LIBERO Benchmark Results
|
||||
|
||||
MolmoAct2 has demonstrated strong performance on the LIBERO benchmark suite. To
|
||||
compare and test its LeRobot implementation, we fine-tuned
|
||||
[`allenai/MolmoAct2-LIBERO`](https://huggingface.co/allenai/MolmoAct2-LIBERO)
|
||||
for an additional 10k steps on the LIBERO dataset with per-GPU batch size 32 on
|
||||
8 H100 GPUs, then compared the results to the original MolmoAct2 reference
|
||||
results.
|
||||
|
||||
The LeRobot fine-tuned checkpoint reported here is available at
|
||||
[`allenai/MolmoAct2-LIBERO-LeRobot`](https://huggingface.co/allenai/MolmoAct2-LIBERO-LeRobot)
|
||||
and was trained on
|
||||
[`allenai/MolmoAct2-LIBERO-Dataset`](https://huggingface.co/datasets/allenai/MolmoAct2-LIBERO-Dataset).
|
||||
|
||||
| Benchmark | LeRobot Implementation | MolmoAct2 Original |
|
||||
| -------------- | ---------------------: | -----------------: |
|
||||
| LIBERO Spatial | 98.4% | 97.8% |
|
||||
| LIBERO Object | 100.0% | 100.0% |
|
||||
| LIBERO Goal | 98.0% | 97.8% |
|
||||
| LIBERO 10 | 96.6% | 93.2% |
|
||||
| Average | 98.25% | 97.20% |
|
||||
|
||||
These results demonstrate MolmoAct2's strong performance across diverse robotic
|
||||
manipulation tasks. To reproduce them, follow the instructions in the LIBERO
|
||||
evaluation section.
|
||||
|
||||
## Differences From the Original Implementation
|
||||
|
||||
This LeRobot port is intended to match MolmoAct2 behavior while using LeRobot's
|
||||
dataset, training, evaluation, checkpoint, and logging infrastructure. The main
|
||||
differences from the original training repository are:
|
||||
|
||||
- The original paper training stack loads the model in fp32 and trains under
|
||||
mixed precision. This LeRobot port usually loads the checkpoint directly in
|
||||
`policy.model_dtype=bfloat16` for lower memory use.
|
||||
- The original repository uses its own FSDP/model-parallel training path. The
|
||||
LeRobot port uses the standard LeRobot/Accelerate training path and has not
|
||||
been tested for multi-node training.
|
||||
- The original repository supports sequence packing. The LeRobot port trains on
|
||||
one LeRobot sample per item and pads to an inferred fixed sequence budget.
|
||||
- The LeRobot port follows LeRobot's optimizer, scheduler, checkpoint saving,
|
||||
dataset transforms, image augmentation, and Weights & Biases logging
|
||||
conventions.
|
||||
- The original training path supports mixed action horizons by padding to
|
||||
`max_action_horizon` and masking padded horizon slots in the action expert
|
||||
self-attention. This is useful when training across datasets with different
|
||||
control frequencies. The LeRobot port currently targets single-dataset
|
||||
fine-tuning, so `policy.chunk_size` overrides the checkpoint
|
||||
`max_action_horizon` and horizon masking is not implemented yet. Support for
|
||||
this mixed-horizon path is planned.
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{fang2026molmoact2actionreasoningmodels,
|
||||
title={MolmoAct2: Action Reasoning Models for Real-world Deployment},
|
||||
author={Haoquan Fang and Jiafei Duan and Donovan Clay and Sam Wang and Shuo Liu and Weikai Huang and Xiang Fan and Wei-Chuan Tsai and Shirui Chen and Yi Ru Wang and Shanli Xing and Jaemin Cho and Jae Sung Park and Ainaz Eftekhar and Peter Sushko and Karen Farley and Angad Wadhwa and Cole Harrison and Winson Han and Ying-Chun Lee and Eli VanderBilt and Rose Hendrix and Suveen Ellawela and Lucas Ngoo and Joyce Chai and Zhongzheng Ren and Ali Farhadi and Dieter Fox and Ranjay Krishna},
|
||||
year={2026},
|
||||
eprint={2605.02881},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.RO},
|
||||
url={https://arxiv.org/abs/2605.02881},
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This model is licensed under Apache 2.0. It is intended for research and
|
||||
educational use in accordance with
|
||||
[Ai2's Responsible Use Guidelines](https://allenai.org/responsible-use),
|
||||
consistent with [allenai/molmoact2](https://github.com/allenai/molmoact2).
|
||||
@@ -91,7 +91,7 @@ lerobot-train \
|
||||
If your dataset is not converted with `quantiles`, you can convert it with the following command:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/augment_dataset_quantile_stats.py \
|
||||
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
|
||||
--repo-id=your_dataset \
|
||||
```
|
||||
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
# MolmoAct2
|
||||
|
||||
This repository contains the LeRobot policy implementation of
|
||||
[MolmoAct2](https://allenai.org/blog/molmoact2), ported into LeRobot for
|
||||
training, evaluation, checkpointing, and dataset compatibility.
|
||||
|
||||
This implementation currently supports training and evaluation for the regular
|
||||
MolmoAct2 model. MolmoAct2-Think, which supports adaptive depth reasoning, is
|
||||
not included in this LeRobot policy yet and is coming soon.
|
||||
|
||||
For the original MolmoAct2 training code used for the experiments reported in
|
||||
the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2).
|
||||
|
||||
## LIBERO Evaluation
|
||||
|
||||
Important: we found that `num_steps_wait=10` does not reliably let the LIBERO
|
||||
scene stabilize and can degrade measured success. All LIBERO evaluation results
|
||||
reported for this LeRobot implementation use `num_steps_wait=50`.
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{fang2026molmoact2actionreasoningmodels,
|
||||
title={MolmoAct2: Action Reasoning Models for Real-world Deployment},
|
||||
author={Haoquan Fang and Jiafei Duan and Donovan Clay and Sam Wang and Shuo Liu and Weikai Huang and Xiang Fan and Wei-Chuan Tsai and Shirui Chen and Yi Ru Wang and Shanli Xing and Jaemin Cho and Jae Sung Park and Ainaz Eftekhar and Peter Sushko and Karen Farley and Angad Wadhwa and Cole Harrison and Winson Han and Ying-Chun Lee and Eli VanderBilt and Rose Hendrix and Suveen Ellawela and Lucas Ngoo and Joyce Chai and Zhongzheng Ren and Ali Farhadi and Dieter Fox and Ranjay Krishna},
|
||||
year={2026},
|
||||
eprint={2605.02881},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.RO},
|
||||
url={https://arxiv.org/abs/2605.02881},
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This model is licensed under Apache 2.0. It is intended for research and
|
||||
educational use in accordance with
|
||||
[Ai2's Responsible Use Guidelines](https://allenai.org/responsible-use),
|
||||
consistent with [allenai/molmoact2](https://github.com/allenai/molmoact2).
|
||||
@@ -1,39 +0,0 @@
|
||||
# VLA-JEPA
|
||||
|
||||
This repository contains the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
|
||||
|
||||
Converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA).
|
||||
|
||||
---
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
| Component | Module | Role |
|
||||
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
|
||||
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
|
||||
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
|
||||
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
|
||||
|
||||
At inference time only the Qwen backbone and action head are used; the world model is not needed.
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
|
||||
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
|
||||
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
|
||||
year = {2026},
|
||||
eprint = {2602.10098},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.RO},
|
||||
url = {https://arxiv.org/abs/2602.10098},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.
|
||||
@@ -300,7 +300,7 @@ This replaces the old episode-per-file structure with efficient, optimally-sized
|
||||
If you have existing datasets in v2.1 format, use the migration tool:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/convert_dataset_v21_to_v30.py \
|
||||
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
|
||||
--repo-id your_id/existing_dataset
|
||||
```
|
||||
|
||||
|
||||
@@ -1,250 +0,0 @@
|
||||
# Remote Inference (lerobot-policy-server)
|
||||
|
||||
Remote inference decouples GPU policy inference from robot control. A `lerobot-policy-server` process runs the policy on a GPU machine; the robot runs `lerobot-rollout --inference.type=remote` as a **weightless edge client** — no policy weights, no GPU, no policy processors on the robot. One GPU server can serve several robots at once, and the remote backend works with every rollout strategy (`base`, `sentry`, `highlight`, `dagger`, `episodic`).
|
||||
|
||||
Use remote inference when:
|
||||
|
||||
- The policy is too large or too slow for the machine attached to the robot (e.g. Pi0/Pi0.5 on a Raspberry Pi or laptop edge).
|
||||
- You want one GPU to serve a fleet of robots running the same policy.
|
||||
- You want to update or restart the inference side without touching the robots.
|
||||
|
||||
<Tip>
|
||||
|
||||
Remote inference requires the `async` extra on **both** sides: `pip install 'lerobot[async]'` (installs `eclipse-zenoh` and `msgpack`). The server additionally needs the extras of the policy it serves (e.g. `lerobot[pi]`, `lerobot[smolvla]`).
|
||||
|
||||
</Tip>
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
robot (edge, weightless) GPU machine
|
||||
┌───────────────────────────┐ ┌────────────────────────────┐
|
||||
│ lerobot-rollout │ │ lerobot-policy-server │
|
||||
│ --inference.type=remote │ zenoh │ one process = one │
|
||||
│ │ router │ (model, revision, GPU) │
|
||||
│ control loop @ fps │ ┌────────┐ │ │
|
||||
│ └─ pops local action ◄──┼───┤ zenohd ├─────┼─► inference worker thread │
|
||||
│ buffer (chunks) │ └────────┘ │ (round-robin over │
|
||||
│ │ observations ► │ client sessions) │
|
||||
│ network worker thread ───┼──► ◄ action │ │
|
||||
│ (publishes obs, merges │ chunks │ stateless per request │
|
||||
│ chunks into buffer) │ │ │
|
||||
└───────────────────────────┘ └────────────────────────────┘
|
||||
```
|
||||
|
||||
The client keeps a local **action buffer** filled with chunks of future actions, so the control loop never blocks on the network: short network blips are absorbed by the buffer and the robot keeps moving. The client self-clocks — it requests a new chunk whenever the buffer holds less than `--inference.buffer_time_s` seconds of playback.
|
||||
|
||||
The server is **stateless per request**: clients ship their RTC prefixes and a delay hint with every observation, so a server crash or restart loses zero control state and reconnects are trivial. In production both robots and servers _dial out_ to a `zenohd` router (NAT-friendly: nothing on the robot network needs an open inbound port).
|
||||
|
||||
## Quickstart on a LAN (peer mode, no router)
|
||||
|
||||
For a quick test on one network you can skip the router: the server listens directly and the robot connects to it.
|
||||
|
||||
On the GPU machine:
|
||||
|
||||
```bash
|
||||
lerobot-policy-server \
|
||||
--model.repo_or_path=${HF_USER}/my_pi0_policy \
|
||||
--default_task="pick up the cube" \
|
||||
--zenoh.mode=peer \
|
||||
--zenoh.listen_endpoints='["tcp/0.0.0.0:7447"]'
|
||||
```
|
||||
|
||||
Wait for `Policy server up: ...` (the model is downloaded, loaded, and warmed up first).
|
||||
|
||||
On the robot machine (replace `192.168.1.42` with the GPU machine's IP):
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=${HF_USER}/my_pi0_policy \
|
||||
--inference.type=remote \
|
||||
--inference.zenoh_mode=peer \
|
||||
--inference.connect_endpoint=tcp/192.168.1.42:7447 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="pick up the cube" \
|
||||
--duration=60
|
||||
```
|
||||
|
||||
`--policy.path` on the client resolves to a config-only download (no weights): it is used for pre-flight validation and action ordering, and doubles as the default service address. The client's `--policy.path` and `--task` must match the server's `--model.repo_or_path` and `--default_task` — that pair is the namespace the service is published under (see [Troubleshooting](#troubleshooting)).
|
||||
|
||||
## Production deployment (router)
|
||||
|
||||
In production, run a [zenoh router](https://zenoh.io/docs/getting-started/installation/) (`zenohd`) somewhere both sides can reach, and have robots and servers dial out to it:
|
||||
|
||||
```bash
|
||||
zenohd # listens on tcp/0.0.0.0:7447 by default
|
||||
```
|
||||
|
||||
Configure the server with a YAML manifest:
|
||||
|
||||
```yaml
|
||||
# server.yaml
|
||||
model:
|
||||
repo_or_path: lerobot/pi0_towels
|
||||
revision: main
|
||||
dtype: bfloat16 # optional cast after load
|
||||
device: cuda
|
||||
default_task: "fold the towel"
|
||||
serving_mode: auto # shared for verified chunk-stateless policies, exclusive otherwise
|
||||
max_sessions: 5
|
||||
warmup_inferences: 2
|
||||
trained_fps: 30.0
|
||||
rtc:
|
||||
enabled: true
|
||||
execution_horizon: 10
|
||||
max_guidance_weight: 10.0
|
||||
health_port: 9100 # /healthz + /metrics; 0 disables
|
||||
zenoh:
|
||||
mode: client
|
||||
connect_endpoints: ["tcp/router.gpu-cluster.internal:7447"]
|
||||
```
|
||||
|
||||
```bash
|
||||
lerobot-policy-server --manifest server.yaml
|
||||
```
|
||||
|
||||
Everything in the manifest can also be set directly on the CLI (`--model.repo_or_path=...`, `--max_sessions=...`, etc.). One process serves exactly one `(model, revision, dtype, device)` — to serve two models, or one model on two GPUs, run two processes. Dynamic model loading is deliberately unsupported: pre-warmed processes keep capacity planning honest.
|
||||
|
||||
On the robot, only the endpoint changes (the default `--inference.zenoh_mode=client` is already router mode):
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=lerobot/pi0_towels \
|
||||
--inference.type=remote \
|
||||
--inference.connect_endpoint=tcp/router.gpu-cluster.internal:7447 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="fold the towel" \
|
||||
--duration=600
|
||||
```
|
||||
|
||||
### TLS / mTLS
|
||||
|
||||
For traffic that leaves a trusted network, terminate TLS at the router and give both sides client certificates (all three PEM paths are required together):
|
||||
|
||||
```yaml
|
||||
# server.yaml (zenoh section)
|
||||
zenoh:
|
||||
mode: client
|
||||
connect_endpoints: ["tls/router.gpu-cluster.internal:7447"]
|
||||
tls_root_ca_certificate: /etc/lerobot/ca.pem
|
||||
tls_connect_certificate: /etc/lerobot/server.pem
|
||||
tls_connect_private_key: /etc/lerobot/server.key
|
||||
```
|
||||
|
||||
On the robot the equivalent flags are `--inference.tls_ca`, `--inference.tls_cert`, and `--inference.tls_key`, with `--inference.connect_endpoint=tls/...`.
|
||||
|
||||
<Tip>
|
||||
|
||||
Multicast scouting is always disabled: discovery is configuration, not protocol magic. If nothing connects, check the endpoints — there is no fallback discovery mechanism.
|
||||
|
||||
</Tip>
|
||||
|
||||
## RTC over the network
|
||||
|
||||
The remote engine reuses the [Real-Time Chunking](./rtc) machinery: the client keeps the chunk leftover and latency tracking locally and ships an action prefix plus a delay hint with every observation; the server runs prefix-conditioned chunk generation. This gives the same smooth chunk-to-chunk transitions as local RTC, with network latency folded into the delay computation.
|
||||
|
||||
RTC is enabled by default on both sides (`rtc.enabled: true`). Tune it from the client:
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
... \
|
||||
--inference.type=remote \
|
||||
--inference.rtc.execution_horizon=10 \
|
||||
--inference.rtc.max_guidance_weight=10.0
|
||||
```
|
||||
|
||||
If the server or its policy does not support RTC (only `pi0`, `pi05`, and `smolvla` are RTC-capable, and the server manifest must have `rtc.enabled: true`), the session is **downgraded to plain chunk-append** and the client logs:
|
||||
|
||||
```
|
||||
RTC downgraded to chunk-append (server does not support RTC)
|
||||
```
|
||||
|
||||
The robot still runs — chunks are simply appended to the buffer without prefix blending, which can produce visible seams between chunks on slow policies.
|
||||
|
||||
## Fail-safe behavior
|
||||
|
||||
The client runs a fail-safe state machine (`CONNECTING → STREAMING → DEGRADED → STALLED → RECONNECTING → DEAD`). A bad initial deployment fails fast: `lerobot-rollout` aborts before the robot moves if the handshake or validation fails. Once streaming, faults degrade in stages:
|
||||
|
||||
| Condition | Behavior |
|
||||
| -------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Short network blip / late chunk | The robot rides its action buffer; state goes `DEGRADED` after `--inference.degraded_after_s` (default 1.0 s) without a fresh chunk |
|
||||
| Buffered actions older than `max_action_age_s` | Stale actions are dropped (never executed); default `--inference.max_action_age_s=3.0` |
|
||||
| Buffer runs dry (`STALLED`) | Fallback per `--inference.fallback`: `hold` (default — robot holds its last commanded position), `repeat_last`, or `zero` |
|
||||
| Server liveliness lost / repeated request timeouts | `RECONNECTING`: re-handshake with exponential backoff (`reconnect_initial_backoff_s=0.5` doubling up to `reconnect_max_backoff_s=10.0`) |
|
||||
| Reconnected server runs a different model/revision | Hard refusal (`DEAD`) — the client never executes wrong-model chunks |
|
||||
| Offline longer than `max_offline_s` (default 60 s) | `DEAD`: the engine signals the rollout's shutdown event for a clean stop |
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
`--inference.fallback=zero` is required for velocity-controlled robots: for them "send nothing" means "keep the last velocity", so an explicit zero command is the only safe stop. For position-controlled arms the default `hold` is safe.
|
||||
|
||||
</Tip>
|
||||
|
||||
Server restarts are equally graceful: on SIGTERM the server drops its liveliness token first (clients ride their buffers through the drain), finishes the in-flight inference, and exits. Clients reconnect when the replacement comes up.
|
||||
|
||||
## Serving multiple robots
|
||||
|
||||
`max_sessions` caps concurrent clients per server process. A single inference worker thread serializes GPU access and round-robins over sessions with a pending observation; per-client newest-wins mailboxes mean overload degrades into longer cycle times (larger but correct client-side delays), never into queue buildup.
|
||||
|
||||
A rough capacity estimate, keeping ~20% headroom:
|
||||
|
||||
```
|
||||
N_robots ≈ 0.8 / (rate × inference_time)
|
||||
```
|
||||
|
||||
where `rate` is each robot's chunk-request rate in Hz (how often the client's buffer dips below `buffer_time_s`) and `inference_time` is the server's seconds per chunk. For example, at 100 ms per chunk and ~2 chunk requests per second per robot: `N ≈ 0.8 / (2 × 0.1) = 4` robots.
|
||||
|
||||
The actual serving mode is classified per policy family, never inferred:
|
||||
|
||||
- **shared** — verified chunk-stateless policies (`act`, `pi0`, `pi05`, and `smolvla` with `n_obs_steps=1`) serve up to `max_sessions` clients from one policy instance.
|
||||
- **exclusive** — stateful families (diffusion-family policies, `smolvla` with observation history, and any unverified policy) are forced to `max_sessions=1`. Run one server process per robot for these.
|
||||
|
||||
`serving_mode: auto` (the default) resolves this automatically; you may force `exclusive`, but `shared` can never override a stateful classification.
|
||||
|
||||
## Observability
|
||||
|
||||
With `health_port` set (default 9100), the server exposes:
|
||||
|
||||
- `GET /healthz` — `200 ok` while the inference worker is alive, `503` otherwise. Wire this to your orchestrator's liveness probe.
|
||||
- `GET /metrics` — Prometheus text format: `lerobot_policy_server_requests_total`, `errors_total`, `superseded_total`, `dropped_unknown_client_total`, `sessions_opened_total`, `sessions_closed_total`, `active_sessions`, `server_load`.
|
||||
|
||||
Every inference request also emits one structured audit line on the `lerobot.policy_server.audit` logger:
|
||||
|
||||
```json
|
||||
{
|
||||
"session_id": "9f2c...",
|
||||
"client_uuid": "robot-07",
|
||||
"seq_id": 412,
|
||||
"episode_id": 3,
|
||||
"queue_wait_ms": 1.8,
|
||||
"inference_ms": 93.2,
|
||||
"superseded": 0,
|
||||
"outcome": "ok"
|
||||
}
|
||||
```
|
||||
|
||||
`(session_id, seq_id)` correlates a server-side audit line with the client's request. Set a stable `--inference.client_uuid` per robot (instead of the default fresh UUID per run) for fleet-wide log correlation, and use `--inference.tags` to forward free-form labels in the handshake.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**`No policy server answered status query at '@lerobot/...'`**
|
||||
|
||||
The client found no server under the key it dialed. Either the endpoint is wrong (check `--inference.connect_endpoint`, the router, and firewalls), or the **service namespace** does not match. The namespace is the `(model_id, revision, task)` triple: on the client it comes from `--inference.service_model_id` (default: `--policy.path`), `--inference.service_revision` (default: `main`), and `--inference.service_task` (default: the rollout `--task`); on the server from `model.repo_or_path`, `model.revision`, and `service_name` (default: a slug of `default_task`). A robot task string that differs from the server's `default_task` is the most common cause — fix the task, or pin the namespace explicitly with `--inference.service_task` on the client / `service_name` in the manifest.
|
||||
|
||||
**`Action name/order mismatch between server policy and this robot`**
|
||||
|
||||
The hard sync-safety contract: chunk columns map to motors **by order**, so the robot's ordered action keys must exactly equal the policy's `action_feature_names`. This fires when the robot type, motor naming, or rename map differs from the training setup. Use the same robot type (and rename map) the policy was trained with.
|
||||
|
||||
**`RTC requested but this server/policy does not support it — downgrading to chunk-append`**
|
||||
|
||||
Informational, not fatal. Enable RTC in the server manifest (`rtc.enabled: true`) and make sure the policy family is RTC-capable (`pi0`, `pi05`, `smolvla`). Otherwise, expect chunk-append behavior (see [RTC over the network](#rtc-over-the-network)).
|
||||
|
||||
**`server full: N/N sessions active`**
|
||||
|
||||
The session-open was rejected at capacity. Raise `max_sessions` (shared mode only), or point the robot at another server replica — the rejection includes the current load so orchestration can retry elsewhere.
|
||||
@@ -1,185 +0,0 @@
|
||||
# ROBOMETER
|
||||
|
||||
ROBOMETER is a **general-purpose video-language robotic reward model**. It predicts dense, frame-level task progress and frame-level success from a trajectory video and a task description.
|
||||
|
||||
**Paper**: [ROBOMETER: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons](https://arxiv.org/abs/2603.02115)
|
||||
**Project**: [robometer.github.io](https://robometer.github.io/)
|
||||
**Original code**: [github.com/robometer/robometer](https://github.com/robometer/robometer)
|
||||
**Checkpoint**: [lerobot/Robometer-4B](https://huggingface.co/lerobot/Robometer-4B)
|
||||
|
||||
## Overview
|
||||
|
||||
ROBOMETER builds on `Qwen/Qwen3-VL-4B-Instruct` and adds three lightweight prediction heads:
|
||||
|
||||
- **Progress head**: predicts per-frame task progress in `[0, 1]`.
|
||||
- **Success head**: predicts per-frame task success probability.
|
||||
- **Preference head**: predicts which of two trajectories better completes the task during training.
|
||||
|
||||
The paper trains ROBOMETER with a composite objective:
|
||||
|
||||
```text
|
||||
L = L_pref + L_prog + L_succ
|
||||
```
|
||||
|
||||
The LeRobot integration is currently **inference-only**. It preserves the preference head so that the published `Robometer-4B` checkpoint loads without remapping, but `compute_reward()` queries the progress or success head only.
|
||||
|
||||
## What the LeRobot Integration Covers
|
||||
|
||||
- Standard `reward_model.type=robometer` configuration through LeRobot.
|
||||
- Qwen3-VL image and text preprocessing through `RobometerEncoderProcessorStep`.
|
||||
- LeRobot reward-model save/load APIs through `PreTrainedRewardModel`.
|
||||
- Dense, frame-level progress and success predictions internally.
|
||||
- A scalar reward through `compute_reward()` for downstream LeRobot reward-model usage.
|
||||
|
||||
This page focuses on using the published ROBOMETER checkpoint as a zero-shot reward model. Training ROBOMETER from scratch is outside the current LeRobot integration.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following the [Installation Guide](./installation).
|
||||
2. Install the ROBOMETER dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e ".[robometer]"
|
||||
```
|
||||
|
||||
If you use `uv` directly from a source checkout:
|
||||
|
||||
```bash
|
||||
uv sync --extra robometer
|
||||
```
|
||||
|
||||
ROBOMETER uses a Qwen3-VL-4B backbone, so GPU inference is strongly recommended.
|
||||
|
||||
## Model Inputs and Outputs
|
||||
|
||||
ROBOMETER expects:
|
||||
|
||||
- A trajectory video or sequence of frames.
|
||||
- A natural-language task description.
|
||||
|
||||
In LeRobot datasets, the preprocessor reads:
|
||||
|
||||
| Config field | Default | Meaning |
|
||||
| ------------------------- | ------------------------ | ----------------------------------------------------- |
|
||||
| `reward_model.image_key` | `observation.images.top` | Camera/video observation used by ROBOMETER |
|
||||
| `reward_model.task_key` | `task` | Key in complementary data that stores the task string |
|
||||
| `reward_model.max_frames` | `8` | Maximum number of frames passed to ROBOMETER |
|
||||
|
||||
The model predicts per-frame progress and success internally. The LeRobot reward API returns a scalar per sample:
|
||||
|
||||
- `reward_output="progress"` (default): return the last-frame progress, clamped to `[0, 1]`.
|
||||
- `reward_output="success"`: return `1.0` if the last-frame success probability is above `success_threshold`, otherwise `0.0`.
|
||||
|
||||
## Usage
|
||||
|
||||
### Load the Reward Model Directly
|
||||
|
||||
```python
|
||||
from lerobot.rewards.robometer import RobometerConfig, RobometerRewardModel
|
||||
|
||||
cfg = RobometerConfig(
|
||||
pretrained_path="lerobot/Robometer-4B",
|
||||
device="cuda",
|
||||
reward_output="progress",
|
||||
)
|
||||
reward_model = RobometerRewardModel.from_pretrained(cfg.pretrained_path, config=cfg)
|
||||
```
|
||||
|
||||
### Encode Frames and Compute a Reward
|
||||
|
||||
For a direct Python call, provide frames as `uint8` arrays with shape `(T, H, W, C)` and a task string:
|
||||
|
||||
```python
|
||||
from lerobot.rewards.robometer.modeling_robometer import ROBOMETER_FEATURE_PREFIX
|
||||
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
|
||||
|
||||
# frames: np.ndarray, shape (T, H, W, C), dtype uint8
|
||||
# task: str
|
||||
encoder = RobometerEncoderProcessorStep(
|
||||
base_model_id=cfg.base_model_id,
|
||||
use_multi_image=cfg.use_multi_image,
|
||||
use_per_frame_progress_token=cfg.use_per_frame_progress_token,
|
||||
max_frames=cfg.max_frames,
|
||||
)
|
||||
|
||||
encoded = encoder.encode_samples([(frames, task)])
|
||||
batch = {f"{ROBOMETER_FEATURE_PREFIX}{key}": value for key, value in encoded.items()}
|
||||
|
||||
reward = reward_model.compute_reward(batch)
|
||||
```
|
||||
|
||||
`reward` is a tensor of shape `(batch_size,)`.
|
||||
|
||||
### Use the Reward Factory
|
||||
|
||||
You can also instantiate ROBOMETER through the reward factory:
|
||||
|
||||
```python
|
||||
from lerobot.rewards import make_reward_model, make_reward_model_config, make_reward_pre_post_processors
|
||||
|
||||
cfg = make_reward_model_config(
|
||||
"robometer",
|
||||
pretrained_path="lerobot/Robometer-4B",
|
||||
device="cuda",
|
||||
image_key="observation.images.top",
|
||||
)
|
||||
reward_model = make_reward_model(cfg)
|
||||
preprocessor, postprocessor = make_reward_pre_post_processors(cfg)
|
||||
```
|
||||
|
||||
The preprocessor writes Qwen-VL tensors under the `observation.robometer.*` namespace, and `compute_reward()` reads those encoded tensors.
|
||||
|
||||
## Configuration Notes
|
||||
|
||||
### Backbone and Vocabulary
|
||||
|
||||
The published checkpoint uses a Qwen3-VL-4B backbone. ROBOMETER adds five special tokens to the tokenizer in a fixed order:
|
||||
|
||||
```text
|
||||
<|split_token|>
|
||||
<|reward_token|>
|
||||
<|pref_token|>
|
||||
<|sim_token|>
|
||||
<|prog_token|>
|
||||
```
|
||||
|
||||
`<|prog_token|>` is inserted after each frame and is the hidden-state position used for per-frame progress and success prediction. `<|split_token|>` and `<|pref_token|>` are used by the paper's pairwise trajectory preference objective. `<|reward_token|>` and `<|sim_token|>` are preserved for checkpoint compatibility.
|
||||
|
||||
The LeRobot config stores a serialized `vlm_config` with the post-resize vocabulary so the model can reload from `config.json` without downloading the base Qwen weights first. For `Qwen/Qwen3-VL-4B-Instruct`, the tokenizer length is `151669`, and the five ROBOMETER tokens produce the checkpoint vocabulary size `151674`.
|
||||
|
||||
### Progress Prediction
|
||||
|
||||
In the published checkpoint, progress is discrete. The progress head outputs logits over `progress_discrete_bins=10` uniformly spaced bin centers in `[0, 1]`. LeRobot converts these logits into a continuous value by applying a softmax and taking the expectation over bin centers, matching the upstream ROBOMETER implementation.
|
||||
|
||||
### Success Prediction
|
||||
|
||||
The success head outputs raw logits per frame. LeRobot converts them to probabilities with `sigmoid`. When `reward_output="success"`, `compute_reward()` thresholds the last-frame success probability using `success_threshold`.
|
||||
|
||||
## Limitations
|
||||
|
||||
- The current LeRobot integration is inference-only; it does not implement ROBOMETER training or preference-pair training.
|
||||
- `compute_reward()` returns a scalar per sample for the LeRobot reward-model API, even though ROBOMETER predicts per-frame progress and success internally.
|
||||
- ROBOMETER is video-language based; it does not use privileged robot state such as contact forces or object poses.
|
||||
|
||||
## References
|
||||
|
||||
- [ROBOMETER project](https://robometer.github.io/)
|
||||
- [ROBOMETER paper](https://arxiv.org/abs/2603.02115)
|
||||
- [Original ROBOMETER code](https://github.com/robometer/robometer)
|
||||
- [Published ROBOMETER-4B checkpoint](https://huggingface.co/lerobot/Robometer-4B)
|
||||
- [Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct)
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@inproceedings{liang2026robometer,
|
||||
title = {Robometer: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons},
|
||||
author={Anthony Liang and Yigit Korkmaz and Jiahui Zhang and Minyoung Hwang and Abrar Anwar and Sidhant Kaushik and Aditya Shah and Alex S. Huang and Luke Zettlemoyer and Dieter Fox and Yu Xiang and Anqi Li and Andreea Bobu and Abhishek Gupta and Stephen Tu and Erdem Biyik and Jesse Zhang},
|
||||
year={2026},
|
||||
booktitle={Robotics: Science and Systems 2026},
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This LeRobot integration follows the **Apache 2.0 License** used by LeRobot. Check the upstream ROBOMETER code and model pages for the licenses of the original implementation and released checkpoints.
|
||||
+9
-9
@@ -151,18 +151,18 @@ lerobot-rollout \
|
||||
--device=cuda
|
||||
```
|
||||
|
||||
## How It Relates to Remote Inference
|
||||
## How It Differs from the Async Inference in LeRobot
|
||||
|
||||
Both RTC and [remote inference](./remote_inference) improve real-time robot control, but they solve different problems.
|
||||
Both RTC and [async inference](./async) improve real-time robot control, but they solve different problems.
|
||||
|
||||
| Aspect | Remote Inference | RTC |
|
||||
| ------------- | ------------------------------------------------------------------------ | --------------------------------------------------- |
|
||||
| **Problem** | The policy is too large (or too slow) for the edge machine | Discontinuities between action chunks |
|
||||
| **Solution** | Run inference on a GPU server; the robot executes buffered action chunks | Guide new chunks to continue smoothly from previous |
|
||||
| **Benefit** | Weightless edge clients, one GPU serves many robots | Smooth transitions, natural motion |
|
||||
| **Best Used** | Large models with high inference latency, robot fleets | Flow-matching based policies |
|
||||
| Aspect | Async Inference | RTC |
|
||||
| ------------- | -------------------------------------------------------------------------- | --------------------------------------------------- |
|
||||
| **Problem** | Idle frames while waiting for inference | Discontinuities between action chunks |
|
||||
| **Solution** | Decouple prediction from execution | Guide new chunks to continue smoothly from previous |
|
||||
| **Benefit** | No waiting, continuous action | Smooth transitions, natural motion |
|
||||
| **Best Used** | Async inference is best used with large models with high inference latency | Flow-matching based policies |
|
||||
|
||||
**Use both together** (`--inference.type=remote` with `--inference.rtc.execution_horizon=...`) for maximum smoothness and reactivity: the remote engine reuses RTC's chunk-merging machinery client-side while the server runs prefix-conditioned chunk generation.
|
||||
**Use both together** for maximum smoothness and reactivity!
|
||||
|
||||
## Advanced: Debug Tracking
|
||||
|
||||
|
||||
@@ -1,177 +0,0 @@
|
||||
# TOPReward
|
||||
|
||||
TOPReward is a **zero-shot reward model** that extracts token log-probabilities from an off-the-shelf vision-language model (VLM) as a robotic reward signal. Given a video trajectory and a task instruction, it returns the VLM's log-likelihood that the instruction is true — no fine-tuning required.
|
||||
|
||||
**Paper**: [TOPReward: Token Probabilities as Hidden Zero-Shot Rewards for Robotics](https://arxiv.org/abs/2602.19313)
|
||||
**Project**: [topreward.github.io](https://topreward.github.io/webpage/)
|
||||
**Original code**: [github.com/TOPReward/TOPReward](https://github.com/TOPReward/TOPReward)
|
||||
**Default backbone**: [Qwen/Qwen3-VL-8B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct)
|
||||
|
||||
## Overview
|
||||
|
||||
TOPReward asks a generic VLM how likely a task instruction is, **conditioned on the video** of a robot trying to complete that task. Concretely, given:
|
||||
|
||||
- A trajectory video (a sequence of frames).
|
||||
- A task instruction (e.g. _"open the drawer"_).
|
||||
|
||||
it builds a chat prompt of the form
|
||||
|
||||
```text
|
||||
<video>
|
||||
"The above video shows a robot manipulation trajectory that completes the
|
||||
following task: <instruction> Decide whether the above statement is True
|
||||
or not. The answer is: True"
|
||||
```
|
||||
|
||||
forwards it through the VLM, label-masks everything except the very last token, and reads back the log-probability of that token — by default the literal `"True"` that closes the suffix template. The resulting `log P("True" | video + prompt + instruction)` is the reward.
|
||||
|
||||
Because the method only depends on a frozen VLM, TOPReward is **zero-shot**: there are no fine-tuned weights to host. The "model" in LeRobot is a small wrapper around `transformers`' `Qwen3VLForConditionalGeneration` plus the label-masking logic. The processor owns the tokeniser and builds the full chat prompt (EO-1/Robometer pattern).
|
||||
|
||||
## What the LeRobot integration covers
|
||||
|
||||
- Standard `reward_model.type=topreward` configuration through LeRobot.
|
||||
- VLM loading via the `transformers` `Qwen3VLForConditionalGeneration` API.
|
||||
- Prompt assembly + tokenisation in the processor (matching upstream `QwenClient.compute_instruction_reward`).
|
||||
- `compute_reward()` returns one scalar log-prob per sample.
|
||||
- LeRobot reward-model save/load — `save_pretrained` writes only `config.json` (the VLM is identified by `vlm_name`).
|
||||
- An offline labeling script that writes a `topreward_progress.parquet` (SARM-compatible schema) for RA-BC and overlay.
|
||||
|
||||
The current LeRobot port supports the **Qwen3-VL client only**. Other upstream clients (Gemini, OpenAI, Gemma, Molmo) can be added as follow-up extras.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot following the [Installation Guide](./installation).
|
||||
2. Install the TOPReward optional extra:
|
||||
|
||||
```bash
|
||||
pip install -e ".[topreward]"
|
||||
```
|
||||
|
||||
or, with `uv` from a source checkout:
|
||||
|
||||
```bash
|
||||
uv sync --extra topreward
|
||||
```
|
||||
|
||||
This pulls in `transformers`. The first time you run TOPReward, Hugging Face will also download the VLM weights from the Hub (~16 GB for Qwen3-VL-8B-Instruct). A GPU is strongly recommended.
|
||||
|
||||
## Model Inputs and Outputs
|
||||
|
||||
TOPReward expects:
|
||||
|
||||
- A trajectory video or sequence of frames.
|
||||
- A natural-language task description.
|
||||
|
||||
In LeRobot datasets the preprocessor reads:
|
||||
|
||||
| Config field | Default | Meaning |
|
||||
| ------------------------- | --------------------------- | --------------------------------------------- |
|
||||
| `reward_model.image_key` | `observation.images.top` | Camera observation used by TOPReward |
|
||||
| `reward_model.task_key` | `task` | Key in complementary data for the task string |
|
||||
| `reward_model.max_frames` | `16` | Cap on frames per sample |
|
||||
| `reward_model.fps` | `2.0` | Metadata passed to the Qwen video processor |
|
||||
| `reward_model.vlm_name` | `Qwen/Qwen3-VL-8B-Instruct` | Hugging Face Hub id of the underlying VLM |
|
||||
|
||||
The model returns:
|
||||
|
||||
- `compute_reward(batch)`: one log-probability per sample. Higher = better task-video alignment. When `success_threshold` is finite, returns the binary thresholded value instead.
|
||||
|
||||
## Usage
|
||||
|
||||
### Load the reward model directly
|
||||
|
||||
```python
|
||||
from lerobot.rewards.topreward import TOPRewardConfig, TOPRewardModel
|
||||
|
||||
cfg = TOPRewardConfig(
|
||||
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
|
||||
device="cuda",
|
||||
)
|
||||
reward_model = TOPRewardModel(cfg)
|
||||
```
|
||||
|
||||
### Use the reward factory
|
||||
|
||||
```python
|
||||
from lerobot.rewards import make_reward_model, make_reward_model_config, make_reward_pre_post_processors
|
||||
|
||||
cfg = make_reward_model_config(
|
||||
"topreward",
|
||||
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
|
||||
device="cuda",
|
||||
image_key="observation.images.top",
|
||||
)
|
||||
reward_model = make_reward_model(cfg)
|
||||
preprocessor, postprocessor = make_reward_pre_post_processors(cfg)
|
||||
```
|
||||
|
||||
The preprocessor tokenises the full prompt (video + prefix + instruction suffix), writes Qwen-VL tensors + `prompt_length` under `observation.topreward.*`. The model reads those tensors, label-masks based on `prompt_length`, and extracts the log-prob reward.
|
||||
|
||||
### Offline dataset labeling
|
||||
|
||||
Write a `topreward_progress.parquet` for RA-BC training and overlay videos:
|
||||
|
||||
```bash
|
||||
# Sparse-dense (15 anchors per episode, matches upstream)
|
||||
uv run python -m lerobot.rewards.topreward.compute_rabc_weights \
|
||||
--dataset-repo-id lerobot/libero_10_image \
|
||||
--num-samples 15 \
|
||||
--device cuda
|
||||
```
|
||||
|
||||
Then render the progress overlay for any episode:
|
||||
|
||||
```bash
|
||||
uv run examples/dataset/create_progress_videos.py \
|
||||
--repo-id lerobot/libero_10_image \
|
||||
--episode 0 \
|
||||
--progress-file topreward_progress.parquet \
|
||||
--gif
|
||||
```
|
||||
|
||||
## Configuration Notes
|
||||
|
||||
### Prompt knobs
|
||||
|
||||
The default prompt mirrors the upstream paper:
|
||||
|
||||
```text
|
||||
prompt_prefix = "The above video shows a robot manipulation trajectory that completes the following task: "
|
||||
prompt_suffix_template = "{instruction} Decide whether the above statement is True or not. The answer is: True"
|
||||
```
|
||||
|
||||
Both are exposed on `TOPRewardConfig` for ablation. The suffix template **must** contain `{instruction}`.
|
||||
|
||||
### Chat template
|
||||
|
||||
`add_chat_template=True` wraps the full prompt (including instruction) with the tokenizer's chat template before tokenisation. Default is `False`, matching the upstream paper's main experiments.
|
||||
|
||||
## Limitations
|
||||
|
||||
- The current LeRobot port is **inference-only and zero-shot**; `forward()` is not overridden and `is_trainable` returns `False`.
|
||||
- Only the **Qwen3-VL family** is supported; other upstream clients are out of scope.
|
||||
- TOPReward inherits the underlying VLM's biases.
|
||||
|
||||
## References
|
||||
|
||||
- [TOPReward project page](https://topreward.github.io/webpage/)
|
||||
- [TOPReward paper](https://arxiv.org/abs/2602.19313)
|
||||
- [Original TOPReward code](https://github.com/TOPReward/TOPReward)
|
||||
- [Qwen3-VL-8B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct)
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{chen2026topreward,
|
||||
title={TOPReward: Token Probabilities as Hidden Zero-Shot Rewards for Robotics},
|
||||
author={Chen, Shirui and Harrison, Cole and Lee, Ying-Chun and Yang, Angela Jin and
|
||||
Ren, Zhongzheng and Ratliff, Lillian J and Duan, Jiafei and Fox, Dieter and
|
||||
Krishna, Ranjay},
|
||||
journal={arXiv preprint arXiv:2602.19313},
|
||||
year={2026}
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
The original TOPReward codebase is MIT-licensed. The LeRobot port follows the LeRobot Apache 2.0 license; the wrapped Qwen3-VL weights are subject to the original Qwen license.
|
||||
@@ -1,235 +0,0 @@
|
||||
# VLA-JEPA
|
||||
|
||||
This is the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
|
||||
|
||||
---
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
VLA-JEPA has three main components:
|
||||
|
||||
| Component | Module | Role |
|
||||
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
|
||||
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
|
||||
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
|
||||
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
|
||||
|
||||
### Data flow
|
||||
|
||||
**Training:**
|
||||
|
||||
1. A video clip of `num_video_frames` frames is encoded by V-JEPA2 into per-frame patch tokens.
|
||||
2. The Qwen3-VL backbone processes multi-view images + the task instruction and produces a sequence of context tokens that includes special action tokens (for world model conditioning) and embodied tokens.
|
||||
3. The action head receives those context tokens as cross-attention keys/values and predicts a denoised action chunk via flow matching.
|
||||
4. The world model predictor uses the action tokens extracted from Qwen to predict future V-JEPA2 frame embeddings; a regression loss on those predictions is added to the action loss.
|
||||
|
||||
**Inference:**
|
||||
Only Qwen + the action head are used. The world model is not needed at inference time.
|
||||
|
||||
### Action head details
|
||||
|
||||
Available presets via `action_model_type`:
|
||||
|
||||
| Preset | Hidden dim | Heads | Head dim |
|
||||
| ------- | ---------- | ----- | -------- |
|
||||
| `DiT-B` | 768 | 12 | 64 |
|
||||
| `DiT-L` | 1536 | 32 | 48 |
|
||||
|
||||
### World model details
|
||||
|
||||
The video predictor is a ViT-style transformer (`ActionConditionedVideoPredictor`) that takes:
|
||||
|
||||
- **Frame tokens**: V-JEPA2 patch embeddings projected to `predictor_embed_dim`
|
||||
- **Action tokens**: Qwen action token embeddings projected to `predictor_embed_dim`
|
||||
|
||||
It uses block-causal attention so each temporal step can attend to all previous steps. The predictor's input `embed_dim` equals `num_views × video_encoder_hidden_size` (e.g. 2 views × 1024 = 2048 for the pretrained checkpoints).
|
||||
|
||||
---
|
||||
|
||||
## Pretrained Checkpoints
|
||||
|
||||
Three checkpoints are available directly inside the LeRobot org here: [`lerobot/VLA-JEPA`](https://huggingface.co/collections/lerobot/vla-jepa), converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA):
|
||||
|
||||
| Checkpoint | Dataset | Cameras | World model | Action dim |
|
||||
| ----------------------------- | ----------------- | ----------------------- | ----------- | ---------- |
|
||||
| `lerobot/VLA-JEPA-LIBERO` | LIBERO-10 | 2 (agentview + wrist) | Enabled | 7 |
|
||||
| `lerobot/VLA-JEPA-Pretrain` | DROID 1.0.1 | 2 (exterior left views) | Enabled | 7 |
|
||||
| `lerobot/VLA-JEPA-SimplerEnv` | OXE Bridge / RT-1 | 1 (view duplicated ×2) | Enabled | 7 |
|
||||
|
||||
All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone.
|
||||
|
||||
---
|
||||
|
||||
## Configuration
|
||||
|
||||
Key parameters in `VLAJEPAConfig`:
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| ------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `chunk_size` | 7 | Number of actions predicted per inference call |
|
||||
| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning |
|
||||
| `num_video_frames` | 8 | Video clip length fed to the world model |
|
||||
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
|
||||
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
|
||||
| `num_inference_timesteps` | 4 | Euler integration steps for action denoising |
|
||||
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
|
||||
| `reinit_modules` | `None` | Key prefixes allowed to be randomly re-initialised on load (for cross-embodiment transfer, see [Fine-tuning on a different embodiment](#fine-tuning-on-a-different-embodiment)) |
|
||||
| `gripper_dim` | 6 | Index of the gripper dimension in the action vector (e.g. 6 for a 7-DoF arm with gripper as the last joint) |
|
||||
| `gripper_threshold` | 0.5 | Threshold used by `pre_snap_gripper_action` and `binarize_gripper_action` to binarize the gripper dimension |
|
||||
| `pre_snap_gripper_action` | `True` | Snap the gripper dim to {0, 1} before unnormalization. Set to `False` for robots without a binary gripper |
|
||||
| `binarize_gripper_action` | `True` | Binarize the gripper dim to {-1, 1} after unnormalization. Set to `False` for robots without a binary gripper |
|
||||
|
||||
---
|
||||
|
||||
## Training
|
||||
|
||||
Number of training steps may vary based on dataset size and compute budget. The original paper pretrained for 50k on ssv2 + droid jointly, then additional 30k steps for LIBERO, but fewer steps may still yield good performance when fine-tuning from the provided pretrained checkpoints.
|
||||
|
||||
### Full training from scratch
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
policy.type=vla_jepa \
|
||||
policy.repo_id=your_org/your_repo \
|
||||
dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
### Fine-tuning from a pretrained checkpoint
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
If you want to freeze the Qwen backbone and only train the action head, set `policy.freeze_qwen=True`:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--policy.freeze_qwen=true \
|
||||
--dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
### Fine-tuning on a different embodiment
|
||||
|
||||
When the target robot has a different action or state dimensionality than the pretrained checkpoint, the input/output projection layers of the action head will have mismatched shapes and cannot be loaded directly. `reinit_modules` lets you list the key prefixes that are allowed to mismatch — those layers are randomly re-initialised while every other weight is reused from the checkpoint. Any shape mismatch outside the listed prefixes raises an error.
|
||||
|
||||
The layers that depend on `action_dim` and `state_dim` are:
|
||||
|
||||
| Layer | Key prefix |
|
||||
| ----------------------------------------- | ----------------------------------- |
|
||||
| Action encoder (action_dim → inner_dim) | `model.action_model.action_encoder` |
|
||||
| Action decoder (hidden_size → action_dim) | `model.action_model.action_decoder` |
|
||||
| State encoder (state_dim → inner_dim) | `model.action_model.state_encoder` |
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--policy.freeze_qwen=true \
|
||||
--policy.reinit_modules='["model.action_model.action_encoder", "model.action_model.action_decoder", "model.action_model.state_encoder"]' \
|
||||
--dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
If your robot has no proprioceptive state, omit `model.action_model.state_encoder` from the list.
|
||||
|
||||
### Reproducing the LIBERO results
|
||||
|
||||
**Training on LIBERO:**
|
||||
starts the training from the Pretrain checkpoint, trains for 30k steps on the LIBERO dataset.
|
||||
Original paper mentions training across 8 GPUs with a batch size of 32, meaning global batch size of 256.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||
--steps=30000
|
||||
```
|
||||
|
||||
**Evaluating the pretrained LIBERO-10 checkpoint:**
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/VLA-JEPA-LIBERO \
|
||||
--env.type=libero \
|
||||
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
|
||||
--eval.n_episodes=10 \
|
||||
--eval.batch_size=5
|
||||
```
|
||||
|
||||
To evaluate a subset of tasks only:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/VLA-JEPA-LIBERO \
|
||||
--env.type=libero \
|
||||
--env.task=libero_10 \
|
||||
--env.task_ids='[0,1,2]' \
|
||||
--eval.n_episodes=10 \
|
||||
--eval.batch_size=5
|
||||
```
|
||||
|
||||
**Expected results:**
|
||||
|
||||
| Suite | Episodes | Successes | Success Rate |
|
||||
| -------------- | -------- | --------- | ------------ |
|
||||
| libero_spatial | 100 | 93 | **95.0%** |
|
||||
| libero_object | 100 | 100 | **100.0%** |
|
||||
| libero_goal | 100 | 98 | **98.0%** |
|
||||
| libero_10 | 100 | 96 | **93.0%** |
|
||||
| **Overall** | **400** | **387** | **96.5%** |
|
||||
|
||||
---
|
||||
|
||||
## Fine-tuning on datasets with a different number of cameras
|
||||
|
||||
The pretrained world model predictor was trained with `embed_dim = jepa_tubelet_size × 1024` (default `jepa_tubelet_size=2`).
|
||||
|
||||
**Default behaviour — view padding / trimming (no action required)**
|
||||
|
||||
When fine-tuning from `VLA-JEPA-Pretrain` the model automatically adjusts the number of views fed to the world model to match `jepa_tubelet_size`:
|
||||
|
||||
- **Single-view datasets (e.g. BridgeV2):** the single-view latent is duplicated to produce a two-view world-model input, preserving the JEPA self-supervised signal without any weight mismatch.
|
||||
- **>2-view datasets (e.g. DROID with 3 views):** all views are passed to the Qwen backbone (for richer context), but only the first `jepa_tubelet_size` views (one wrist + one third-person, following the configured view order) are used for the world model.
|
||||
|
||||
**Option 1 — Disable the world model**
|
||||
|
||||
Set `enable_world_model=False` to skip the JEPA loss entirely. Only the Qwen backbone and action head are loaded and trained. This is sufficient for good action performance.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.enable_world_model=false \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--dataset.repo_id=your_org/single_camera_dataset
|
||||
```
|
||||
|
||||
**Option 2 — Reinitialize the predictor input projection**
|
||||
|
||||
If you want to change `jepa_tubelet_size` to a value other than 2, load the checkpoint with `strict=False` and reinitialize `model.video_predictor.predictor_embed` for the new `embed_dim`. All other predictor block weights (attention, MLP, norm, output projection) are camera-count-agnostic and can be reused from the pretrained checkpoint.
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
|
||||
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
|
||||
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
|
||||
year = {2026},
|
||||
eprint = {2602.10098},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.RO},
|
||||
url = {https://arxiv.org/abs/2602.10098},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.
|
||||
@@ -1,115 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Example manifest for `lerobot-policy-server --manifest server.yaml`.
|
||||
#
|
||||
# One process = one (model, revision, dtype, device) on one GPU. Dynamic
|
||||
# model loading is deliberately unsupported: pre-warmed processes keep
|
||||
# capacity planning honest. Every field below can also be overridden on
|
||||
# the command line via draccus, e.g. --model.repo_or_path=... or
|
||||
# --zenoh.connect_endpoints='["tcp/other-router:7447"]'.
|
||||
#
|
||||
# Field names mirror the dataclasses in src/lerobot/policy_server/manifest.py.
|
||||
|
||||
# --- Which policy this process serves, and where it runs ------------------
|
||||
model:
|
||||
# Hub repo id (org/name) or a local checkpoint directory. Required.
|
||||
repo_or_path: lerobot/pi0_towels
|
||||
# Hub revision: branch, tag, or commit sha.
|
||||
revision: main
|
||||
# Optional torch dtype cast applied after load (e.g. "bfloat16",
|
||||
# "float16"). null keeps the checkpoint's native dtype.
|
||||
dtype: bfloat16
|
||||
# Inference device, e.g. "cuda", "cuda:1", "cpu".
|
||||
device: cuda
|
||||
|
||||
# --- Task namespace --------------------------------------------------------
|
||||
# The task this service is published under. VLA clients may override the
|
||||
# task per session unless `pin_task` is true, in which case session opens
|
||||
# with a different task string are rejected.
|
||||
default_task: "fold the towel"
|
||||
pin_task: false
|
||||
# Optional override for the <task_slug> key segment of the Zenoh prefix
|
||||
# (defaults to a slug of `default_task`).
|
||||
service_name: ""
|
||||
|
||||
# --- Serving mode & capacity ------------------------------------------------
|
||||
# "auto" resolves from the policy classification: shared for verified
|
||||
# chunk-stateless policies (act/pi0/pi05, smolvla with n_obs_steps=1),
|
||||
# exclusive otherwise. Chunk-stateful policies — e.g. diffusion, whose
|
||||
# predict_action_chunk reads select_action-fed queues — are always forced
|
||||
# to "exclusive" (max_sessions=1); "shared" cannot override that.
|
||||
serving_mode: auto
|
||||
|
||||
# Capacity rule-of-thumb: with t = server seconds per inference, r = each
|
||||
# client's request rate (self-clocked to ~1-4 Hz, not the control rate),
|
||||
# H = RTC execution horizon, and dt = control period:
|
||||
# max_sessions ~= min( 0.8 / (r*t), (H*dt/2 - network RTT) / t )
|
||||
# e.g. ACT @ 20 ms, 1 Hz refresh -> ~40 clients/GPU; Pi0 @ 150 ms -> ~5.
|
||||
# Session opens beyond this are rejected with the current load in the
|
||||
# reply, so clients retry another replica.
|
||||
max_sessions: 5
|
||||
|
||||
# Dummy inferences run at startup so the first real request does not pay
|
||||
# for CUDA graph/kernel warmup.
|
||||
warmup_inferences: 2
|
||||
|
||||
# --- FPS contract -----------------------------------------------------------
|
||||
# Control rate the policy was trained at. Clients reporting a different
|
||||
# fps get a warning — or a hard reject when `strict_fps` is true.
|
||||
trained_fps: 30.0
|
||||
strict_fps: false
|
||||
|
||||
# --- Real Time Chunking (RTC) -----------------------------------------------
|
||||
# Global to this process: init_rtc_processor mutates the policy instance,
|
||||
# so RTC is a per-process decision, not per-session. Only rtc-capable
|
||||
# families (pi0/pi05/smolvla) honor it; others are downgraded to plain
|
||||
# chunk-append at session open.
|
||||
rtc:
|
||||
enabled: true
|
||||
# Number of actions executed from each chunk before the next chunk is
|
||||
# blended in (the H in the capacity formula above).
|
||||
execution_horizon: 10
|
||||
|
||||
# --- Housekeeping ------------------------------------------------------------
|
||||
# Sessions with no liveliness token and no traffic for this long are
|
||||
# garbage-collected (belt-and-braces behind liveliness GC).
|
||||
session_idle_timeout_s: 300.0
|
||||
|
||||
# --- Transport ----------------------------------------------------------------
|
||||
# Robots and servers both *dial out* to a zenohd router in production
|
||||
# (mode: client). mode: peer + listen_endpoints supports router-less LAN
|
||||
# and loopback test deployments. Multicast scouting is always disabled:
|
||||
# fleet discovery is configuration, not protocol magic.
|
||||
zenoh:
|
||||
mode: client
|
||||
connect_endpoints:
|
||||
- tcp/router.gpu-cluster.internal:7447
|
||||
listen_endpoints: []
|
||||
# mTLS material (PEM paths). All three are required for tls/ endpoints;
|
||||
# leave them null for plain tcp/ inside a trusted network.
|
||||
# tls_root_ca_certificate: /etc/lerobot/tls/ca.pem
|
||||
# tls_connect_certificate: /etc/lerobot/tls/server.pem
|
||||
# tls_connect_private_key: /etc/lerobot/tls/server.key
|
||||
# Escape hatch: raw JSON5 merged into the zenoh config last.
|
||||
# extra_config_json5: '{transport: {link: {tx: {queue: {size: {data: 4}}}}}}'
|
||||
|
||||
# --- Observability -------------------------------------------------------------
|
||||
# HTTP health + Prometheus metrics port; 0 disables the endpoint.
|
||||
health_port: 9100
|
||||
|
||||
# Optional bounded request/response capture for offline replay.
|
||||
debug:
|
||||
capture_dir: null
|
||||
capture_max: 256
|
||||
@@ -0,0 +1,17 @@
|
||||
from lerobot.async_inference.configs import PolicyServerConfig
|
||||
from lerobot.async_inference.policy_server import serve
|
||||
|
||||
|
||||
def main():
|
||||
host = ... # something like "127.0.0.1" if you're exposing to localhost
|
||||
port = ... # something like 8080
|
||||
|
||||
config = PolicyServerConfig(
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
serve(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,62 @@
|
||||
import threading
|
||||
|
||||
from lerobot.async_inference.configs import RobotClientConfig
|
||||
from lerobot.async_inference.helpers import visualize_action_queue_size
|
||||
from lerobot.async_inference.robot_client import RobotClient
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.robots.so_follower import SO100FollowerConfig
|
||||
|
||||
|
||||
def main():
|
||||
# these cameras must match the ones expected by the policy - find your cameras with lerobot-find-cameras
|
||||
# check the config.json on the Hub for the policy you are using to see the expected camera specs
|
||||
camera_cfg = {
|
||||
"up": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_cfg)
|
||||
|
||||
server_address = ... # something like "127.0.0.1:8080" if using localhost
|
||||
|
||||
# 3. Create client configuration
|
||||
client_cfg = RobotClientConfig(
|
||||
robot=robot_cfg,
|
||||
server_address=server_address,
|
||||
policy_device="mps",
|
||||
client_device="cpu",
|
||||
policy_type="act",
|
||||
pretrained_name_or_path="<user>/robot_learning_tutorial_act",
|
||||
chunk_size_threshold=0.5, # g
|
||||
actions_per_chunk=50, # make sure this is less than the max actions of the policy
|
||||
)
|
||||
|
||||
# 4. Create and start client
|
||||
client = RobotClient(client_cfg)
|
||||
|
||||
# 5. Provide a textual description of the task
|
||||
task = ...
|
||||
|
||||
if client.start():
|
||||
# Start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
action_receiver_thread.start()
|
||||
|
||||
try:
|
||||
# Run the control loop
|
||||
client.control_loop(task)
|
||||
except KeyboardInterrupt:
|
||||
client.stop()
|
||||
action_receiver_thread.join()
|
||||
# (Optionally) plot the action queue size
|
||||
visualize_action_queue_size(client.action_queue_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+16
-31
@@ -115,8 +115,8 @@ dataset = [
|
||||
]
|
||||
training = [
|
||||
"lerobot[dataset]",
|
||||
"wandb>=0.24.0,<0.28.0",
|
||||
"lerobot[accelerate-dep]",
|
||||
"accelerate>=1.10.0,<2.0.0",
|
||||
"wandb>=0.24.0,<0.25.0",
|
||||
]
|
||||
hardware = [
|
||||
"lerobot[pynput-dep]",
|
||||
@@ -142,8 +142,7 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
|
||||
placo-dep = ["placo>=0.9.6,<0.9.16"]
|
||||
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
||||
grpcio-dep = ["grpcio>=1.73.1,<2.0.0", "protobuf>=6.31.1,<8.0.0"]
|
||||
accelerate-dep = ["accelerate>=1.14.0,<2.0.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
|
||||
@@ -178,12 +177,7 @@ unitree_g1 = [
|
||||
"lerobot[matplotlib-dep]",
|
||||
"lerobot[pygame-dep]",
|
||||
]
|
||||
# reachy2-sdk caps grpcio<=1.73.1 and protobuf<=6.32.0; quarantined here so downstream users aren't held back. reachy2-sdk is unlikely to release new versions.
|
||||
reachy2 = [
|
||||
"reachy2_sdk>=1.0.15,<1.1.0",
|
||||
"grpcio<=1.73.1",
|
||||
"protobuf<=6.32.0",
|
||||
]
|
||||
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
||||
# Seeed Studio reBot B601-DM follower (motorbridge / CAN) + StarArm102 / reBot Arm 102
|
||||
# leader (motorbridge-smart-servo / FashionStar UART servos).
|
||||
rebot = ["lerobot[motorbridge-dep]", "lerobot[motorbridge-smart-servo-dep]"]
|
||||
@@ -204,8 +198,7 @@ wallx = [
|
||||
"lerobot[qwen-vl-utils-dep]",
|
||||
]
|
||||
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
|
||||
molmoact2 = ["lerobot[transformers-dep]", "lerobot[peft-dep]", "lerobot[scipy-dep]"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "lerobot[accelerate-dep]"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
|
||||
multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
|
||||
groot = [
|
||||
"lerobot[transformers-dep]",
|
||||
@@ -218,30 +211,25 @@ groot = [
|
||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||
]
|
||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
|
||||
topreward = ["lerobot[transformers-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
# Remote inference over Zenoh: lerobot-policy-server + lerobot-rollout --inference.type=remote.
|
||||
# Keep zenohd routers on the same minor version as the Python binding.
|
||||
async = ["eclipse-zenoh>=1.9,<2.0", "msgpack>=1.0.0,<2.0.0"]
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools>=1.73.1,<2.0.0", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
||||
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
|
||||
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
|
||||
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
|
||||
# Simulation
|
||||
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
|
||||
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.4,<0.2.0", "lerobot[scipy-dep]"]
|
||||
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
|
||||
pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.4,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"]
|
||||
# NOTE: vlabench is NOT exposed as a `lerobot` extra. Its only distribution
|
||||
# is the OpenMOSS/VLABench GitHub repo (package name `VLABench`, no PyPI
|
||||
@@ -286,12 +274,10 @@ all = [
|
||||
"lerobot[multi_task_dit]",
|
||||
"lerobot[wallx]",
|
||||
"lerobot[pi]",
|
||||
"lerobot[molmoact2]",
|
||||
"lerobot[smolvla]",
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[xvla]",
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[vla_jepa]",
|
||||
"lerobot[async]",
|
||||
"lerobot[dev]",
|
||||
"lerobot[test]",
|
||||
@@ -302,8 +288,6 @@ all = [
|
||||
"lerobot[libero]; sys_platform == 'linux'",
|
||||
"lerobot[metaworld]",
|
||||
"lerobot[sarm]",
|
||||
"lerobot[robometer]",
|
||||
"lerobot[topreward]",
|
||||
"lerobot[peft]",
|
||||
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
|
||||
]
|
||||
@@ -326,7 +310,6 @@ lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
|
||||
lerobot-policy-server="lerobot.scripts.lerobot_policy_server:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
|
||||
@@ -420,11 +403,8 @@ default.extend-ignore-identifiers-re = [
|
||||
"ein",
|
||||
"thw",
|
||||
"inpt",
|
||||
"arange",
|
||||
"is_compileable",
|
||||
"ROBOTIS",
|
||||
"OT_VALUE",
|
||||
"VanderBilt"
|
||||
"OT_VALUE"
|
||||
]
|
||||
|
||||
# TODO: Uncomment when ready to use
|
||||
@@ -519,6 +499,11 @@ ignore_errors = false
|
||||
# module = "lerobot.rl.*"
|
||||
# ignore_errors = false
|
||||
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.async_inference.*"
|
||||
# ignore_errors = false
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.transport.*"
|
||||
ignore_errors = false
|
||||
|
||||
+16
-9
@@ -1,4 +1,4 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -12,12 +12,19 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
from .modeling_vla_jepa import VLAJEPAPolicy
|
||||
from .processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||
"""
|
||||
Async inference server/client.
|
||||
|
||||
__all__ = [
|
||||
"VLAJEPAConfig",
|
||||
"VLAJEPAPolicy",
|
||||
"make_vla_jepa_pre_post_processors",
|
||||
]
|
||||
Requires: ``pip install 'lerobot[async]'``
|
||||
|
||||
Available modules (import directly)::
|
||||
|
||||
from lerobot.async_inference.policy_server import ...
|
||||
from lerobot.async_inference.robot_client import ...
|
||||
"""
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
require_package("grpcio", extra="async", import_name="grpc")
|
||||
|
||||
__all__: list[str] = []
|
||||
@@ -0,0 +1,203 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.robots.config import RobotConfig
|
||||
|
||||
from .constants import (
|
||||
DEFAULT_FPS,
|
||||
DEFAULT_INFERENCE_LATENCY,
|
||||
DEFAULT_OBS_QUEUE_TIMEOUT,
|
||||
)
|
||||
|
||||
# Aggregate function registry for CLI usage
|
||||
AGGREGATE_FUNCTIONS = {
|
||||
"weighted_average": lambda old, new: 0.3 * old + 0.7 * new,
|
||||
"latest_only": lambda old, new: new,
|
||||
"average": lambda old, new: 0.5 * old + 0.5 * new,
|
||||
"conservative": lambda old, new: 0.7 * old + 0.3 * new,
|
||||
}
|
||||
|
||||
|
||||
def get_aggregate_function(name: str) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
|
||||
"""Get aggregate function by name from registry."""
|
||||
if name not in AGGREGATE_FUNCTIONS:
|
||||
available = list(AGGREGATE_FUNCTIONS.keys())
|
||||
raise ValueError(f"Unknown aggregate function '{name}'. Available: {available}")
|
||||
return AGGREGATE_FUNCTIONS[name]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyServerConfig:
|
||||
"""Configuration for PolicyServer.
|
||||
|
||||
This class defines all configurable parameters for the PolicyServer,
|
||||
including networking settings and action chunking specifications.
|
||||
"""
|
||||
|
||||
# Networking configuration
|
||||
host: str = field(default="localhost", metadata={"help": "Host address to bind the server to"})
|
||||
port: int = field(default=8080, metadata={"help": "Port number to bind the server to"})
|
||||
|
||||
# Timing configuration
|
||||
fps: int = field(default=DEFAULT_FPS, metadata={"help": "Frames per second"})
|
||||
inference_latency: float = field(
|
||||
default=DEFAULT_INFERENCE_LATENCY, metadata={"help": "Target inference latency in seconds"}
|
||||
)
|
||||
|
||||
obs_queue_timeout: float = field(
|
||||
default=DEFAULT_OBS_QUEUE_TIMEOUT, metadata={"help": "Timeout for observation queue in seconds"}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate configuration after initialization."""
|
||||
if self.port < 1 or self.port > 65535:
|
||||
raise ValueError(f"Port must be between 1 and 65535, got {self.port}")
|
||||
|
||||
if self.environment_dt <= 0:
|
||||
raise ValueError(f"environment_dt must be positive, got {self.environment_dt}")
|
||||
|
||||
if self.inference_latency < 0:
|
||||
raise ValueError(f"inference_latency must be non-negative, got {self.inference_latency}")
|
||||
|
||||
if self.obs_queue_timeout < 0:
|
||||
raise ValueError(f"obs_queue_timeout must be non-negative, got {self.obs_queue_timeout}")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: dict) -> "PolicyServerConfig":
|
||||
"""Create a PolicyServerConfig from a dictionary."""
|
||||
return cls(**config_dict)
|
||||
|
||||
@property
|
||||
def environment_dt(self) -> float:
|
||||
"""Environment time step, in seconds"""
|
||||
return 1 / self.fps
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert the configuration to a dictionary."""
|
||||
return {
|
||||
"host": self.host,
|
||||
"port": self.port,
|
||||
"fps": self.fps,
|
||||
"environment_dt": self.environment_dt,
|
||||
"inference_latency": self.inference_latency,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class RobotClientConfig:
|
||||
"""Configuration for RobotClient.
|
||||
|
||||
This class defines all configurable parameters for the RobotClient,
|
||||
including network connection, policy settings, and control behavior.
|
||||
"""
|
||||
|
||||
# Policy configuration
|
||||
policy_type: str = field(metadata={"help": "Type of policy to use"})
|
||||
pretrained_name_or_path: str = field(metadata={"help": "Pretrained model name or path"})
|
||||
|
||||
# Robot configuration (for CLI usage - robot instance will be created from this)
|
||||
robot: RobotConfig = field(metadata={"help": "Robot configuration"})
|
||||
|
||||
# Policies typically output K actions at max, but we can use less to avoid wasting bandwidth (as actions
|
||||
# would be aggregated on the client side anyway, depending on the value of `chunk_size_threshold`)
|
||||
actions_per_chunk: int = field(metadata={"help": "Number of actions per chunk"})
|
||||
|
||||
# Task instruction for the robot to execute (e.g., 'fold my tshirt')
|
||||
task: str = field(default="", metadata={"help": "Task instruction for the robot to execute"})
|
||||
|
||||
# Network configuration
|
||||
server_address: str = field(default="localhost:8080", metadata={"help": "Server address to connect to"})
|
||||
|
||||
# Device configuration
|
||||
policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"})
|
||||
client_device: str = field(
|
||||
default="cpu",
|
||||
metadata={
|
||||
"help": "Device to move actions to after receiving from server (e.g., for downstream planners)"
|
||||
},
|
||||
)
|
||||
|
||||
# Control behavior configuration
|
||||
chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"})
|
||||
fps: int = field(default=DEFAULT_FPS, metadata={"help": "Frames per second"})
|
||||
|
||||
# Aggregate function configuration (CLI-compatible)
|
||||
aggregate_fn_name: str = field(
|
||||
default="weighted_average",
|
||||
metadata={"help": f"Name of aggregate function to use. Options: {list(AGGREGATE_FUNCTIONS.keys())}"},
|
||||
)
|
||||
|
||||
# Debug configuration
|
||||
debug_visualize_queue_size: bool = field(
|
||||
default=False, metadata={"help": "Visualize the action queue size"}
|
||||
)
|
||||
|
||||
@property
|
||||
def environment_dt(self) -> float:
|
||||
"""Environment time step, in seconds"""
|
||||
return 1 / self.fps
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate configuration after initialization."""
|
||||
if not self.server_address:
|
||||
raise ValueError("server_address cannot be empty")
|
||||
|
||||
if not self.policy_type:
|
||||
raise ValueError("policy_type cannot be empty")
|
||||
|
||||
if not self.pretrained_name_or_path:
|
||||
raise ValueError("pretrained_name_or_path cannot be empty")
|
||||
|
||||
if not self.policy_device:
|
||||
raise ValueError("policy_device cannot be empty")
|
||||
|
||||
if not self.client_device:
|
||||
raise ValueError("client_device cannot be empty")
|
||||
|
||||
if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1:
|
||||
raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}")
|
||||
|
||||
if self.fps <= 0:
|
||||
raise ValueError(f"fps must be positive, got {self.fps}")
|
||||
|
||||
if self.actions_per_chunk <= 0:
|
||||
raise ValueError(f"actions_per_chunk must be positive, got {self.actions_per_chunk}")
|
||||
|
||||
self.aggregate_fn = get_aggregate_function(self.aggregate_fn_name)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: dict) -> "RobotClientConfig":
|
||||
"""Create a RobotClientConfig from a dictionary."""
|
||||
return cls(**config_dict)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert the configuration to a dictionary."""
|
||||
return {
|
||||
"server_address": self.server_address,
|
||||
"policy_type": self.policy_type,
|
||||
"pretrained_name_or_path": self.pretrained_name_or_path,
|
||||
"policy_device": self.policy_device,
|
||||
"client_device": self.client_device,
|
||||
"chunk_size_threshold": self.chunk_size_threshold,
|
||||
"fps": self.fps,
|
||||
"actions_per_chunk": self.actions_per_chunk,
|
||||
"task": self.task,
|
||||
"debug_visualize_queue_size": self.debug_visualize_queue_size,
|
||||
"aggregate_fn_name": self.aggregate_fn_name,
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Client side: The environment evolves with a time resolution equal to 1/fps"""
|
||||
|
||||
DEFAULT_FPS = 30
|
||||
|
||||
"""Server side: Running inference on (at most) 1/fps"""
|
||||
DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS
|
||||
|
||||
"""Server side: Timeout for observation queue in seconds"""
|
||||
DEFAULT_OBS_QUEUE_TIMEOUT = 2
|
||||
|
||||
# All action chunking policies
|
||||
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05", "groot"]
|
||||
|
||||
# TODO: Add all other robots
|
||||
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so_follower", "omx_follower"]
|
||||
@@ -0,0 +1,297 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs import PolicyFeature
|
||||
|
||||
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
|
||||
from lerobot.policies import ( # noqa: F401
|
||||
ACTConfig,
|
||||
DiffusionConfig,
|
||||
PI0Config,
|
||||
PI05Config,
|
||||
SmolVLAConfig,
|
||||
VQBeTConfig,
|
||||
)
|
||||
from lerobot.robots.robot import Robot
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
Action = torch.Tensor
|
||||
|
||||
# observation as received from the robot (can be numpy arrays, floats, etc.)
|
||||
RawObservation = dict[str, Any]
|
||||
|
||||
# observation as those recorded in LeRobot dataset (keys are different)
|
||||
LeRobotObservation = dict[str, torch.Tensor]
|
||||
|
||||
# observation, ready for policy inference (image keys resized)
|
||||
Observation = dict[str, torch.Tensor]
|
||||
|
||||
|
||||
def visualize_action_queue_size(action_queue_size: list[int]) -> None:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
_, ax = plt.subplots()
|
||||
ax.set_title("Action Queue Size Over Time")
|
||||
ax.set_xlabel("Environment steps")
|
||||
ax.set_ylabel("Action Queue Size")
|
||||
ax.set_ylim(0, max(action_queue_size) * 1.1)
|
||||
ax.grid(True, alpha=0.3)
|
||||
ax.plot(range(len(action_queue_size)), action_queue_size)
|
||||
plt.show()
|
||||
|
||||
|
||||
def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]:
|
||||
return hw_to_dataset_features(robot.observation_features, OBS_STR, use_video=False)
|
||||
|
||||
|
||||
def is_image_key(k: str) -> bool:
|
||||
return k.startswith(OBS_IMAGES)
|
||||
|
||||
|
||||
def resize_robot_observation_image(image: torch.tensor, resize_dims: tuple[int, int, int]) -> torch.tensor:
|
||||
assert image.ndim == 3, f"Image must be (C, H, W)! Received {image.shape}"
|
||||
# (H, W, C) -> (C, H, W) for resizing from robot obsevation resolution to policy image resolution
|
||||
image = image.permute(2, 0, 1)
|
||||
dims = (resize_dims[1], resize_dims[2])
|
||||
# Add batch dimension for interpolate: (C, H, W) -> (1, C, H, W)
|
||||
image_batched = image.unsqueeze(0)
|
||||
# Interpolate and remove batch dimension: (1, C, H, W) -> (C, H, W)
|
||||
resized = torch.nn.functional.interpolate(image_batched, size=dims, mode="bilinear", align_corners=False)
|
||||
|
||||
return resized.squeeze(0)
|
||||
|
||||
|
||||
# TODO(Steven): Consider implementing a pipeline step for this
|
||||
def raw_observation_to_observation(
|
||||
raw_observation: RawObservation,
|
||||
lerobot_features: dict[str, dict],
|
||||
policy_image_features: dict[str, PolicyFeature],
|
||||
) -> Observation:
|
||||
observation = {}
|
||||
|
||||
observation = prepare_raw_observation(raw_observation, lerobot_features, policy_image_features)
|
||||
for k, v in observation.items():
|
||||
if isinstance(v, torch.Tensor): # VLAs present natural-language instructions in observations
|
||||
if "image" in k:
|
||||
# Policy expects images in shape (B, C, H, W)
|
||||
observation[k] = prepare_image(v).unsqueeze(0)
|
||||
else:
|
||||
observation[k] = v
|
||||
|
||||
return observation
|
||||
|
||||
|
||||
def prepare_image(image: torch.Tensor) -> torch.Tensor:
|
||||
"""Minimal preprocessing to turn int8 images to float32 in [0, 1], and create a memory-contiguous tensor"""
|
||||
image = image.type(torch.float32) / 255
|
||||
image = image.contiguous()
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def extract_state_from_raw_observation(
|
||||
lerobot_obs: RawObservation,
|
||||
) -> torch.Tensor:
|
||||
"""Extract the state from a raw observation."""
|
||||
state = torch.tensor(lerobot_obs[OBS_STATE])
|
||||
|
||||
if state.ndim == 1:
|
||||
state = state.unsqueeze(0)
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def extract_images_from_raw_observation(
|
||||
lerobot_obs: RawObservation,
|
||||
camera_key: str,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Extract the images from a raw observation."""
|
||||
return torch.tensor(lerobot_obs[camera_key])
|
||||
|
||||
|
||||
def make_lerobot_observation(
|
||||
robot_obs: RawObservation,
|
||||
lerobot_features: dict[str, dict],
|
||||
) -> LeRobotObservation:
|
||||
"""Make a lerobot observation from a raw observation."""
|
||||
return build_dataset_frame(lerobot_features, robot_obs, prefix=OBS_STR)
|
||||
|
||||
|
||||
def prepare_raw_observation(
|
||||
robot_obs: RawObservation,
|
||||
lerobot_features: dict[str, dict],
|
||||
policy_image_features: dict[str, PolicyFeature],
|
||||
) -> Observation:
|
||||
"""Matches keys from the raw robot_obs dict to the keys expected by a given policy (passed as
|
||||
policy_image_features)."""
|
||||
# 1. {motor.pos1:value1, motor.pos2:value2, ..., laptop:np.ndarray} ->
|
||||
# -> {observation.state:[value1,value2,...], observation.images.laptop:np.ndarray}
|
||||
lerobot_obs = make_lerobot_observation(robot_obs, lerobot_features)
|
||||
|
||||
# 2. Greps all observation.images.<> keys
|
||||
image_keys = list(filter(is_image_key, lerobot_obs))
|
||||
# state's shape is expected as (B, state_dim)
|
||||
state_dict = {OBS_STATE: extract_state_from_raw_observation(lerobot_obs)}
|
||||
image_dict = {
|
||||
image_k: extract_images_from_raw_observation(lerobot_obs, image_k) for image_k in image_keys
|
||||
}
|
||||
|
||||
# Turns the image features to (C, H, W) with H, W matching the policy image features.
|
||||
# This reduces the resolution of the images
|
||||
image_dict = {
|
||||
key: resize_robot_observation_image(torch.tensor(lerobot_obs[key]), policy_image_features[key].shape)
|
||||
for key in image_keys
|
||||
}
|
||||
|
||||
if "task" in robot_obs:
|
||||
state_dict["task"] = robot_obs["task"]
|
||||
|
||||
return {**state_dict, **image_dict}
|
||||
|
||||
|
||||
def get_logger(name: str, log_to_file: bool = True) -> logging.Logger:
|
||||
"""
|
||||
Get a logger using the standardized logging setup from utils.py.
|
||||
|
||||
Args:
|
||||
name: Logger name (e.g., 'policy_server', 'robot_client')
|
||||
log_to_file: Whether to also log to a file
|
||||
|
||||
Returns:
|
||||
Configured logger instance
|
||||
"""
|
||||
# Create logs directory if logging to file
|
||||
if log_to_file:
|
||||
os.makedirs("logs", exist_ok=True)
|
||||
log_file = Path(f"logs/{name}_{int(time.time())}.log")
|
||||
else:
|
||||
log_file = None
|
||||
|
||||
# Initialize the standardized logging
|
||||
init_logging(log_file=log_file, display_pid=False)
|
||||
|
||||
# Return a named logger
|
||||
return logging.getLogger(name)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimedData:
|
||||
"""A data object with timestamp and timestep information.
|
||||
|
||||
Args:
|
||||
timestamp: Unix timestamp relative to data's creation.
|
||||
data: The actual data to wrap a timestamp around.
|
||||
timestep: The timestep of the data.
|
||||
"""
|
||||
|
||||
timestamp: float
|
||||
timestep: int
|
||||
|
||||
def get_timestamp(self):
|
||||
return self.timestamp
|
||||
|
||||
def get_timestep(self):
|
||||
return self.timestep
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimedAction(TimedData):
|
||||
action: Action
|
||||
|
||||
def get_action(self):
|
||||
return self.action
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimedObservation(TimedData):
|
||||
observation: RawObservation
|
||||
must_go: bool = False
|
||||
|
||||
def get_observation(self):
|
||||
return self.observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class FPSTracker:
|
||||
"""Utility class to track FPS metrics over time."""
|
||||
|
||||
target_fps: float
|
||||
first_timestamp: float = None
|
||||
total_obs_count: int = 0
|
||||
|
||||
def calculate_fps_metrics(self, current_timestamp: float) -> dict[str, float]:
|
||||
"""Calculate average FPS vs target"""
|
||||
self.total_obs_count += 1
|
||||
|
||||
# Initialize first observation time
|
||||
if self.first_timestamp is None:
|
||||
self.first_timestamp = current_timestamp
|
||||
|
||||
# Calculate overall average FPS (since start)
|
||||
total_duration = current_timestamp - self.first_timestamp
|
||||
avg_fps = (self.total_obs_count - 1) / total_duration if total_duration > 1e-6 else 0.0
|
||||
|
||||
return {"avg_fps": avg_fps, "target_fps": self.target_fps}
|
||||
|
||||
def reset(self):
|
||||
"""Reset the FPS tracker state"""
|
||||
self.first_timestamp = None
|
||||
self.total_obs_count = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class RemotePolicyConfig:
|
||||
policy_type: str
|
||||
pretrained_name_or_path: str
|
||||
lerobot_features: dict[str, PolicyFeature]
|
||||
actions_per_chunk: int
|
||||
device: str = "cpu"
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
|
||||
"""Check if two observation states are similar, under a tolerance threshold"""
|
||||
return bool(torch.linalg.norm(obs1_state - obs2_state) < atol)
|
||||
|
||||
|
||||
def observations_similar(
|
||||
obs1: TimedObservation, obs2: TimedObservation, lerobot_features: dict[str, dict], atol: float = 1
|
||||
) -> bool:
|
||||
"""Check if two observations are similar, under a tolerance threshold. Measures distance between
|
||||
observations as the difference in joint-space between the two observations.
|
||||
|
||||
NOTE(fracapuano): This is a very simple check, and it is enough for the current use case.
|
||||
An immediate next step is to use (fast) perceptual difference metrics comparing some camera views,
|
||||
to surpass this joint-space similarity check.
|
||||
"""
|
||||
obs1_state = extract_state_from_raw_observation(
|
||||
make_lerobot_observation(obs1.get_observation(), lerobot_features)
|
||||
)
|
||||
obs2_state = extract_state_from_raw_observation(
|
||||
make_lerobot_observation(obs2.get_observation(), lerobot_features)
|
||||
)
|
||||
|
||||
return _compare_observation_states(obs1_state, obs2_state, atol=atol)
|
||||
@@ -0,0 +1,439 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Example:
|
||||
```shell
|
||||
python -m lerobot.async_inference.policy_server \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080 \
|
||||
--fps=30 \
|
||||
--inference_latency=0.033 \
|
||||
--obs_queue_timeout=1
|
||||
```
|
||||
"""
|
||||
|
||||
import logging
|
||||
import pickle # nosec
|
||||
import threading
|
||||
import time
|
||||
from concurrent import futures
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
from queue import Empty, Queue
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
import grpc
|
||||
import torch
|
||||
|
||||
from lerobot.policies import get_policy_class, make_pre_post_processors
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
services_pb2_grpc, # type: ignore
|
||||
)
|
||||
from lerobot.transport.utils import receive_bytes_in_chunks
|
||||
from lerobot.types import PolicyAction
|
||||
|
||||
from .configs import PolicyServerConfig
|
||||
from .constants import SUPPORTED_POLICIES
|
||||
from .helpers import (
|
||||
FPSTracker,
|
||||
Observation,
|
||||
RemotePolicyConfig,
|
||||
TimedAction,
|
||||
TimedObservation,
|
||||
get_logger,
|
||||
observations_similar,
|
||||
raw_observation_to_observation,
|
||||
)
|
||||
|
||||
|
||||
class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
prefix = "policy_server"
|
||||
logger = get_logger(prefix)
|
||||
|
||||
def __init__(self, config: PolicyServerConfig):
|
||||
self.config = config
|
||||
self.shutdown_event = threading.Event()
|
||||
|
||||
# FPS measurement
|
||||
self.fps_tracker = FPSTracker(target_fps=config.fps)
|
||||
|
||||
self.observation_queue = Queue(maxsize=1)
|
||||
|
||||
self._predicted_timesteps_lock = threading.Lock()
|
||||
self._predicted_timesteps = set()
|
||||
|
||||
self.last_processed_obs = None
|
||||
|
||||
# Attributes will be set by SendPolicyInstructions
|
||||
self.device = None
|
||||
self.policy_type = None
|
||||
self.lerobot_features = None
|
||||
self.actions_per_chunk = None
|
||||
self.policy = None
|
||||
self.preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]] | None = None
|
||||
self.postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction] | None = None
|
||||
|
||||
@property
|
||||
def running(self):
|
||||
return not self.shutdown_event.is_set()
|
||||
|
||||
@property
|
||||
def policy_image_features(self):
|
||||
return self.policy.config.image_features
|
||||
|
||||
def _reset_server(self) -> None:
|
||||
"""Flushes server state when new client connects."""
|
||||
# only running inference on the latest observation received by the server
|
||||
self.shutdown_event.set()
|
||||
self.observation_queue = Queue(maxsize=1)
|
||||
|
||||
with self._predicted_timesteps_lock:
|
||||
self._predicted_timesteps = set()
|
||||
|
||||
def Ready(self, request, context): # noqa: N802
|
||||
client_id = context.peer()
|
||||
self.logger.info(f"Client {client_id} connected and ready")
|
||||
self._reset_server()
|
||||
self.shutdown_event.clear()
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
def SendPolicyInstructions(self, request, context): # noqa: N802
|
||||
"""Receive policy instructions from the robot client"""
|
||||
|
||||
if not self.running:
|
||||
self.logger.warning("Server is not running. Ignoring policy instructions.")
|
||||
return services_pb2.Empty()
|
||||
|
||||
client_id = context.peer()
|
||||
|
||||
policy_specs = pickle.loads(request.data) # nosec
|
||||
|
||||
if not isinstance(policy_specs, RemotePolicyConfig):
|
||||
raise TypeError(f"Policy specs must be a RemotePolicyConfig. Got {type(policy_specs)}")
|
||||
|
||||
if policy_specs.policy_type not in SUPPORTED_POLICIES:
|
||||
raise ValueError(
|
||||
f"Policy type {policy_specs.policy_type} not supported. "
|
||||
f"Supported policies: {SUPPORTED_POLICIES}"
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
f"Receiving policy instructions from {client_id} | "
|
||||
f"Policy type: {policy_specs.policy_type} | "
|
||||
f"Pretrained name or path: {policy_specs.pretrained_name_or_path} | "
|
||||
f"Actions per chunk: {policy_specs.actions_per_chunk} | "
|
||||
f"Device: {policy_specs.device}"
|
||||
)
|
||||
|
||||
self.device = policy_specs.device
|
||||
self.policy_type = policy_specs.policy_type # act, pi0, etc.
|
||||
self.lerobot_features = policy_specs.lerobot_features
|
||||
self.actions_per_chunk = policy_specs.actions_per_chunk
|
||||
|
||||
policy_class = get_policy_class(self.policy_type)
|
||||
|
||||
start = time.perf_counter()
|
||||
self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path)
|
||||
self.policy.to(self.device)
|
||||
|
||||
# Load preprocessor and postprocessor, overriding device to match requested device
|
||||
device_override = {"device": self.device}
|
||||
self.preprocessor, self.postprocessor = make_pre_post_processors(
|
||||
self.policy.config,
|
||||
pretrained_path=policy_specs.pretrained_name_or_path,
|
||||
preprocessor_overrides={
|
||||
"device_processor": device_override,
|
||||
"rename_observations_processor": {"rename_map": policy_specs.rename_map},
|
||||
},
|
||||
postprocessor_overrides={"device_processor": device_override},
|
||||
)
|
||||
|
||||
end = time.perf_counter()
|
||||
|
||||
self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
def SendObservations(self, request_iterator, context): # noqa: N802
|
||||
"""Receive observations from the robot client"""
|
||||
client_id = context.peer()
|
||||
self.logger.debug(f"Receiving observations from {client_id}")
|
||||
|
||||
receive_time = time.time() # comparing timestamps so need time.time()
|
||||
start_deserialize = time.perf_counter()
|
||||
received_bytes = receive_bytes_in_chunks(
|
||||
request_iterator, None, self.shutdown_event, self.logger
|
||||
) # blocking call while looping over request_iterator
|
||||
timed_observation = pickle.loads(received_bytes) # nosec
|
||||
deserialize_time = time.perf_counter() - start_deserialize
|
||||
|
||||
self.logger.debug(f"Received observation #{timed_observation.get_timestep()}")
|
||||
|
||||
obs_timestep = timed_observation.get_timestep()
|
||||
obs_timestamp = timed_observation.get_timestamp()
|
||||
|
||||
# Calculate FPS metrics
|
||||
fps_metrics = self.fps_tracker.calculate_fps_metrics(obs_timestamp)
|
||||
|
||||
self.logger.debug(
|
||||
f"Received observation #{obs_timestep} | "
|
||||
f"Avg FPS: {fps_metrics['avg_fps']:.2f} | " # fps at which observations are received from client
|
||||
f"Target: {fps_metrics['target_fps']:.2f} | "
|
||||
f"One-way latency: {(receive_time - obs_timestamp) * 1000:.2f}ms"
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Server timestamp: {receive_time:.6f} | "
|
||||
f"Client timestamp: {obs_timestamp:.6f} | "
|
||||
f"Deserialization time: {deserialize_time:.6f}s"
|
||||
)
|
||||
|
||||
if not self._enqueue_observation(
|
||||
timed_observation # wrapping a RawObservation
|
||||
):
|
||||
self.logger.debug(f"Observation #{obs_timestep} has been filtered out")
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
def GetActions(self, request, context): # noqa: N802
|
||||
"""Returns actions to the robot client. Actions are sent as a single
|
||||
chunk, containing multiple actions."""
|
||||
client_id = context.peer()
|
||||
self.logger.debug(f"Client {client_id} connected for action streaming")
|
||||
|
||||
# Generate action based on the most recent observation and its timestep
|
||||
try:
|
||||
getactions_starts = time.perf_counter()
|
||||
obs = self.observation_queue.get(timeout=self.config.obs_queue_timeout)
|
||||
self.logger.info(
|
||||
f"Running inference for observation #{obs.get_timestep()} (must_go: {obs.must_go})"
|
||||
)
|
||||
|
||||
with self._predicted_timesteps_lock:
|
||||
self._predicted_timesteps.add(obs.get_timestep())
|
||||
|
||||
start_time = time.perf_counter()
|
||||
action_chunk = self._predict_action_chunk(obs)
|
||||
inference_time = time.perf_counter() - start_time
|
||||
|
||||
start_time = time.perf_counter()
|
||||
actions_bytes = pickle.dumps(action_chunk) # nosec
|
||||
serialize_time = time.perf_counter() - start_time
|
||||
|
||||
# Create and return the action chunk
|
||||
actions = services_pb2.Actions(data=actions_bytes)
|
||||
|
||||
self.logger.info(
|
||||
f"Action chunk #{obs.get_timestep()} generated | "
|
||||
f"Total time: {(inference_time + serialize_time) * 1000:.2f}ms"
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Action chunk #{obs.get_timestep()} generated | "
|
||||
f"Inference time: {inference_time:.2f}s |"
|
||||
f"Serialize time: {serialize_time:.2f}s |"
|
||||
f"Total time: {inference_time + serialize_time:.2f}s"
|
||||
)
|
||||
|
||||
time.sleep(
|
||||
max(0, self.config.inference_latency - max(0, time.perf_counter() - getactions_starts))
|
||||
) # sleep controls inference latency
|
||||
|
||||
return actions
|
||||
|
||||
except Empty: # no observation added to queue in obs_queue_timeout
|
||||
return services_pb2.Empty()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in StreamActions: {e}")
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool:
|
||||
"""Check if the observation is valid to be processed by the policy"""
|
||||
with self._predicted_timesteps_lock:
|
||||
predicted_timesteps = self._predicted_timesteps
|
||||
|
||||
if obs.get_timestep() in predicted_timesteps:
|
||||
self.logger.debug(f"Skipping observation #{obs.get_timestep()} - Timestep predicted already!")
|
||||
return False
|
||||
|
||||
elif observations_similar(obs, previous_obs, lerobot_features=self.lerobot_features):
|
||||
self.logger.debug(
|
||||
f"Skipping observation #{obs.get_timestep()} - Observation too similar to last obs predicted!"
|
||||
)
|
||||
return False
|
||||
|
||||
else:
|
||||
return True
|
||||
|
||||
def _enqueue_observation(self, obs: TimedObservation) -> bool:
|
||||
"""Enqueue an observation if it must go through processing, otherwise skip it.
|
||||
Observations not in queue are never run through the policy network"""
|
||||
|
||||
if (
|
||||
obs.must_go
|
||||
or self.last_processed_obs is None
|
||||
or self._obs_sanity_checks(obs, self.last_processed_obs)
|
||||
):
|
||||
last_obs = self.last_processed_obs.get_timestep() if self.last_processed_obs else "None"
|
||||
self.logger.debug(
|
||||
f"Enqueuing observation. Must go: {obs.must_go} | Last processed obs: {last_obs}"
|
||||
)
|
||||
|
||||
# If queue is full, get the old observation to make room
|
||||
if self.observation_queue.full():
|
||||
# pops from queue
|
||||
_ = self.observation_queue.get_nowait()
|
||||
self.logger.debug("Observation queue was full, removed oldest observation")
|
||||
|
||||
# Now put the new observation (never blocks as queue is non-full here)
|
||||
self.observation_queue.put(obs)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]:
|
||||
"""Turn a chunk of actions into a list of TimedAction instances,
|
||||
with the first action corresponding to t_0 and the rest corresponding to
|
||||
t_0 + i*environment_dt for i in range(len(action_chunk))
|
||||
"""
|
||||
return [
|
||||
TimedAction(timestamp=t_0 + i * self.config.environment_dt, timestep=i_0 + i, action=action)
|
||||
for i, action in enumerate(action_chunk)
|
||||
]
|
||||
|
||||
def _get_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Get an action chunk from the policy. The chunk contains only"""
|
||||
chunk = self.policy.predict_action_chunk(observation)
|
||||
if chunk.ndim != 3:
|
||||
chunk = chunk.unsqueeze(0) # adding batch dimension, now shape is (B, chunk_size, action_dim)
|
||||
|
||||
return chunk[:, : self.actions_per_chunk, :]
|
||||
|
||||
def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
|
||||
"""Predict an action chunk based on an observation.
|
||||
|
||||
Pipeline:
|
||||
1. Convert raw observation to LeRobot format
|
||||
2. Apply preprocessor (tokenization, normalization, batching, device placement)
|
||||
3. Run policy inference to get action chunk
|
||||
4. Apply postprocessor (unnormalization, device movement)
|
||||
5. Convert to TimedAction list
|
||||
"""
|
||||
"""1. Prepare observation"""
|
||||
start_prepare = time.perf_counter()
|
||||
observation: Observation = raw_observation_to_observation(
|
||||
observation_t.get_observation(),
|
||||
self.lerobot_features,
|
||||
self.policy_image_features,
|
||||
)
|
||||
prepare_time = time.perf_counter() - start_prepare
|
||||
|
||||
"""2. Apply preprocessor"""
|
||||
start_preprocess = time.perf_counter()
|
||||
observation = self.preprocessor(observation)
|
||||
self.last_processed_obs: TimedObservation = observation_t
|
||||
preprocessing_time = time.perf_counter() - start_preprocess
|
||||
|
||||
"""3. Get action chunk"""
|
||||
start_inference = time.perf_counter()
|
||||
action_tensor = self._get_action_chunk(observation)
|
||||
inference_time = time.perf_counter() - start_inference
|
||||
self.logger.info(
|
||||
f"Preprocessing and inference took {inference_time:.4f}s, action shape: {action_tensor.shape}"
|
||||
)
|
||||
|
||||
"""4. Apply postprocessor"""
|
||||
# Apply postprocessor (handles unnormalization and device movement)
|
||||
# Postprocessor expects (B, action_dim) per action, but we have (B, chunk_size, action_dim)
|
||||
# So we process each action in the chunk individually
|
||||
start_postprocess = time.perf_counter()
|
||||
_, chunk_size, _ = action_tensor.shape
|
||||
|
||||
# Process each action in the chunk
|
||||
processed_actions = []
|
||||
for i in range(chunk_size):
|
||||
# Extract action at timestep i: (B, action_dim)
|
||||
single_action = action_tensor[:, i, :]
|
||||
processed_action = self.postprocessor(single_action)
|
||||
processed_actions.append(processed_action)
|
||||
|
||||
# Stack back to (B, chunk_size, action_dim), then remove batch dim
|
||||
action_tensor = torch.stack(processed_actions, dim=1).squeeze(0)
|
||||
self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}")
|
||||
|
||||
action_tensor = action_tensor.detach().cpu()
|
||||
|
||||
"""5. Convert to TimedAction list"""
|
||||
action_chunk = self._time_action_chunk(
|
||||
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
|
||||
)
|
||||
postprocess_stops = time.perf_counter()
|
||||
postprocessing_time = postprocess_stops - start_postprocess
|
||||
|
||||
self.logger.info(
|
||||
f"Observation {observation_t.get_timestep()} | "
|
||||
f"Total time: {1000 * (postprocess_stops - start_prepare):.2f}ms"
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Observation {observation_t.get_timestep()} | "
|
||||
f"Prepare time: {1000 * prepare_time:.2f}ms | "
|
||||
f"Preprocessing time: {1000 * preprocessing_time:.2f}ms | "
|
||||
f"Inference time: {1000 * inference_time:.2f}ms | "
|
||||
f"Postprocessing time: {1000 * postprocessing_time:.2f}ms | "
|
||||
f"Total time: {1000 * (postprocess_stops - start_prepare):.2f}ms"
|
||||
)
|
||||
|
||||
return action_chunk
|
||||
|
||||
def stop(self):
|
||||
"""Stop the server"""
|
||||
self._reset_server()
|
||||
self.logger.info("Server stopping...")
|
||||
|
||||
|
||||
@draccus.wrap()
|
||||
def serve(cfg: PolicyServerConfig):
|
||||
"""Start the PolicyServer with the given configuration.
|
||||
|
||||
Args:
|
||||
config: PolicyServerConfig instance. If None, uses default configuration.
|
||||
"""
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
# Create the server instance first
|
||||
policy_server = PolicyServer(cfg)
|
||||
|
||||
# Setup and start gRPC server
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
|
||||
services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
|
||||
server.add_insecure_port(f"{cfg.host}:{cfg.port}")
|
||||
|
||||
policy_server.logger.info(f"PolicyServer started on {cfg.host}:{cfg.port}")
|
||||
server.start()
|
||||
|
||||
server.wait_for_termination()
|
||||
|
||||
policy_server.logger.info("Server terminated")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
serve()
|
||||
@@ -0,0 +1,517 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Example command:
|
||||
```shell
|
||||
python src/lerobot/async_inference/robot_client.py \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
|
||||
--robot.id=black \
|
||||
--task="dummy" \
|
||||
--server_address=127.0.0.1:8080 \
|
||||
--policy_type=act \
|
||||
--pretrained_name_or_path=user/model \
|
||||
--policy_device=mps \
|
||||
--client_device=cpu \
|
||||
--actions_per_chunk=50 \
|
||||
--chunk_size_threshold=0.5 \
|
||||
--aggregate_fn_name=weighted_average \
|
||||
--debug_visualize_queue_size=True
|
||||
```
|
||||
"""
|
||||
|
||||
import logging
|
||||
import pickle # nosec
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
from queue import Queue
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
import grpc
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_so_follower,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
so_follower,
|
||||
)
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
services_pb2_grpc, # type: ignore
|
||||
)
|
||||
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
|
||||
from .configs import RobotClientConfig
|
||||
from .helpers import (
|
||||
Action,
|
||||
FPSTracker,
|
||||
Observation,
|
||||
RawObservation,
|
||||
RemotePolicyConfig,
|
||||
TimedAction,
|
||||
TimedObservation,
|
||||
get_logger,
|
||||
map_robot_keys_to_lerobot_features,
|
||||
visualize_action_queue_size,
|
||||
)
|
||||
|
||||
|
||||
class RobotClient:
|
||||
prefix = "robot_client"
|
||||
logger = get_logger(prefix)
|
||||
|
||||
def __init__(self, config: RobotClientConfig):
|
||||
"""Initialize RobotClient with unified configuration.
|
||||
|
||||
Args:
|
||||
config: RobotClientConfig containing all configuration parameters
|
||||
"""
|
||||
# Store configuration
|
||||
self.config = config
|
||||
self.robot = make_robot_from_config(config.robot)
|
||||
self.robot.connect()
|
||||
|
||||
lerobot_features = map_robot_keys_to_lerobot_features(self.robot)
|
||||
|
||||
# Use environment variable if server_address is not provided in config
|
||||
self.server_address = config.server_address
|
||||
|
||||
self.policy_config = RemotePolicyConfig(
|
||||
config.policy_type,
|
||||
config.pretrained_name_or_path,
|
||||
lerobot_features,
|
||||
config.actions_per_chunk,
|
||||
config.policy_device,
|
||||
)
|
||||
self.channel = grpc.insecure_channel(
|
||||
self.server_address, grpc_channel_options(initial_backoff=f"{config.environment_dt:.4f}s")
|
||||
)
|
||||
self.stub = services_pb2_grpc.AsyncInferenceStub(self.channel)
|
||||
self.logger.info(f"Initializing client to connect to server at {self.server_address}")
|
||||
|
||||
self.shutdown_event = threading.Event()
|
||||
|
||||
# Initialize client side variables
|
||||
self.latest_action_lock = threading.Lock()
|
||||
self.latest_action = -1
|
||||
self.action_chunk_size = -1
|
||||
|
||||
self._chunk_size_threshold = config.chunk_size_threshold
|
||||
|
||||
self.action_queue = Queue()
|
||||
self.action_queue_lock = threading.Lock() # Protect queue operations
|
||||
self.action_queue_size = []
|
||||
self.start_barrier = threading.Barrier(2) # 2 threads: action receiver, control loop
|
||||
|
||||
# FPS measurement
|
||||
self.fps_tracker = FPSTracker(target_fps=self.config.fps)
|
||||
|
||||
self.logger.info("Robot connected and ready")
|
||||
|
||||
# Use an event for thread-safe coordination
|
||||
self.must_go = threading.Event()
|
||||
self.must_go.set() # Initially set - observations qualify for direct processing
|
||||
|
||||
@property
|
||||
def running(self):
|
||||
return not self.shutdown_event.is_set()
|
||||
|
||||
def start(self):
|
||||
"""Start the robot client and connect to the policy server"""
|
||||
try:
|
||||
# client-server handshake
|
||||
start_time = time.perf_counter()
|
||||
self.stub.Ready(services_pb2.Empty())
|
||||
end_time = time.perf_counter()
|
||||
self.logger.debug(f"Connected to policy server in {end_time - start_time:.4f}s")
|
||||
|
||||
# send policy instructions
|
||||
policy_config_bytes = pickle.dumps(self.policy_config)
|
||||
policy_setup = services_pb2.PolicySetup(data=policy_config_bytes)
|
||||
|
||||
self.logger.info("Sending policy instructions to policy server")
|
||||
self.logger.debug(
|
||||
f"Policy type: {self.policy_config.policy_type} | "
|
||||
f"Pretrained name or path: {self.policy_config.pretrained_name_or_path} | "
|
||||
f"Device: {self.policy_config.device}"
|
||||
)
|
||||
|
||||
self.stub.SendPolicyInstructions(policy_setup)
|
||||
|
||||
self.shutdown_event.clear()
|
||||
|
||||
return True
|
||||
|
||||
except grpc.RpcError as e:
|
||||
self.logger.error(f"Failed to connect to policy server: {e}")
|
||||
return False
|
||||
|
||||
def stop(self):
|
||||
"""Stop the robot client"""
|
||||
self.shutdown_event.set()
|
||||
|
||||
self.robot.disconnect()
|
||||
self.logger.debug("Robot disconnected")
|
||||
|
||||
self.channel.close()
|
||||
self.logger.debug("Client stopped, channel closed")
|
||||
|
||||
def send_observation(
|
||||
self,
|
||||
obs: TimedObservation,
|
||||
) -> bool:
|
||||
"""Send observation to the policy server.
|
||||
Returns True if the observation was sent successfully, False otherwise."""
|
||||
if not self.running:
|
||||
raise RuntimeError("Client not running. Run RobotClient.start() before sending observations.")
|
||||
|
||||
if not isinstance(obs, TimedObservation):
|
||||
raise ValueError("Input observation needs to be a TimedObservation!")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
observation_bytes = pickle.dumps(obs)
|
||||
serialize_time = time.perf_counter() - start_time
|
||||
self.logger.debug(f"Observation serialization time: {serialize_time:.6f}s")
|
||||
|
||||
try:
|
||||
observation_iterator = send_bytes_in_chunks(
|
||||
observation_bytes,
|
||||
services_pb2.Observation,
|
||||
log_prefix="[CLIENT] Observation",
|
||||
silent=True,
|
||||
)
|
||||
_ = self.stub.SendObservations(observation_iterator)
|
||||
obs_timestep = obs.get_timestep()
|
||||
self.logger.debug(f"Sent observation #{obs_timestep} | ")
|
||||
|
||||
return True
|
||||
|
||||
except grpc.RpcError as e:
|
||||
self.logger.error(f"Error sending observation #{obs.get_timestep()}: {e}")
|
||||
return False
|
||||
|
||||
def _inspect_action_queue(self):
|
||||
with self.action_queue_lock:
|
||||
queue_size = self.action_queue.qsize()
|
||||
timestamps = sorted([action.get_timestep() for action in self.action_queue.queue])
|
||||
self.logger.debug(f"Queue size: {queue_size}, Queue contents: {timestamps}")
|
||||
return queue_size, timestamps
|
||||
|
||||
def _aggregate_action_queues(
|
||||
self,
|
||||
incoming_actions: list[TimedAction],
|
||||
aggregate_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
|
||||
):
|
||||
"""Finds the same timestep actions in the queue and aggregates them using the aggregate_fn"""
|
||||
if aggregate_fn is None:
|
||||
# default aggregate function: take the latest action
|
||||
def aggregate_fn(x1, x2):
|
||||
return x2
|
||||
|
||||
future_action_queue = Queue()
|
||||
with self.action_queue_lock:
|
||||
internal_queue = self.action_queue.queue
|
||||
|
||||
current_action_queue = {action.get_timestep(): action.get_action() for action in internal_queue}
|
||||
|
||||
for new_action in incoming_actions:
|
||||
with self.latest_action_lock:
|
||||
latest_action = self.latest_action
|
||||
|
||||
# New action is older than the latest action in the queue, skip it
|
||||
if new_action.get_timestep() <= latest_action:
|
||||
continue
|
||||
|
||||
# If the new action's timestep is not in the current action queue, add it directly
|
||||
elif new_action.get_timestep() not in current_action_queue:
|
||||
future_action_queue.put(new_action)
|
||||
continue
|
||||
|
||||
# If the new action's timestep is in the current action queue, aggregate it
|
||||
# TODO: There is probably a way to do this with broadcasting of the two action tensors
|
||||
future_action_queue.put(
|
||||
TimedAction(
|
||||
timestamp=new_action.get_timestamp(),
|
||||
timestep=new_action.get_timestep(),
|
||||
action=aggregate_fn(
|
||||
current_action_queue[new_action.get_timestep()], new_action.get_action()
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
with self.action_queue_lock:
|
||||
self.action_queue = future_action_queue
|
||||
|
||||
def receive_actions(self, verbose: bool = False):
|
||||
"""Receive actions from the policy server"""
|
||||
# Wait at barrier for synchronized start
|
||||
self.start_barrier.wait()
|
||||
self.logger.info("Action receiving thread starting")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# Use StreamActions to get a stream of actions from the server
|
||||
actions_chunk = self.stub.GetActions(services_pb2.Empty())
|
||||
if len(actions_chunk.data) == 0:
|
||||
continue # received `Empty` from server, wait for next call
|
||||
|
||||
receive_time = time.time()
|
||||
|
||||
# Deserialize bytes back into list[TimedAction]
|
||||
deserialize_start = time.perf_counter()
|
||||
timed_actions = pickle.loads(actions_chunk.data) # nosec
|
||||
deserialize_time = time.perf_counter() - deserialize_start
|
||||
|
||||
# Log device type of received actions
|
||||
if len(timed_actions) > 0:
|
||||
received_device = timed_actions[0].get_action().device.type
|
||||
self.logger.debug(f"Received actions on device: {received_device}")
|
||||
|
||||
# Move actions to client_device (e.g., for downstream planners that need GPU)
|
||||
client_device = self.config.client_device
|
||||
if client_device != "cpu":
|
||||
for timed_action in timed_actions:
|
||||
if timed_action.get_action().device.type != client_device:
|
||||
timed_action.action = timed_action.get_action().to(client_device)
|
||||
self.logger.debug(f"Converted actions to device: {client_device}")
|
||||
else:
|
||||
self.logger.debug(f"Actions kept on device: {client_device}")
|
||||
|
||||
self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))
|
||||
|
||||
# Calculate network latency if we have matching observations
|
||||
if len(timed_actions) > 0 and verbose:
|
||||
with self.latest_action_lock:
|
||||
latest_action = self.latest_action
|
||||
|
||||
self.logger.debug(f"Current latest action: {latest_action}")
|
||||
|
||||
# Get queue state before changes
|
||||
old_size, old_timesteps = self._inspect_action_queue()
|
||||
if not old_timesteps:
|
||||
old_timesteps = [latest_action] # queue was empty
|
||||
|
||||
# Log incoming actions
|
||||
incoming_timesteps = [a.get_timestep() for a in timed_actions]
|
||||
|
||||
first_action_timestep = timed_actions[0].get_timestep()
|
||||
server_to_client_latency = (receive_time - timed_actions[0].get_timestamp()) * 1000
|
||||
|
||||
self.logger.info(
|
||||
f"Received action chunk for step #{first_action_timestep} | "
|
||||
f"Latest action: #{latest_action} | "
|
||||
f"Incoming actions: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
|
||||
f"Network latency (server->client): {server_to_client_latency:.2f}ms | "
|
||||
f"Deserialization time: {deserialize_time * 1000:.2f}ms"
|
||||
)
|
||||
|
||||
# Update action queue
|
||||
start_time = time.perf_counter()
|
||||
self._aggregate_action_queues(timed_actions, self.config.aggregate_fn)
|
||||
queue_update_time = time.perf_counter() - start_time
|
||||
|
||||
self.must_go.set() # after receiving actions, next empty queue triggers must-go processing!
|
||||
|
||||
if verbose:
|
||||
# Get queue state after changes
|
||||
new_size, new_timesteps = self._inspect_action_queue()
|
||||
|
||||
with self.latest_action_lock:
|
||||
latest_action = self.latest_action
|
||||
|
||||
self.logger.info(
|
||||
f"Latest action: {latest_action} | "
|
||||
f"Old action steps: {old_timesteps[0]}:{old_timesteps[-1]} | "
|
||||
f"Incoming action steps: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
|
||||
f"Updated action steps: {new_timesteps[0]}:{new_timesteps[-1]}"
|
||||
)
|
||||
self.logger.debug(
|
||||
f"Queue update complete ({queue_update_time:.6f}s) | "
|
||||
f"Before: {old_size} items | "
|
||||
f"After: {new_size} items | "
|
||||
)
|
||||
|
||||
except grpc.RpcError as e:
|
||||
self.logger.error(f"Error receiving actions: {e}")
|
||||
|
||||
def actions_available(self):
|
||||
"""Check if there are actions available in the queue"""
|
||||
with self.action_queue_lock:
|
||||
return not self.action_queue.empty()
|
||||
|
||||
def _action_tensor_to_action_dict(self, action_tensor: torch.Tensor) -> dict[str, float]:
|
||||
action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)}
|
||||
return action
|
||||
|
||||
def control_loop_action(self, verbose: bool = False) -> dict[str, Any]:
|
||||
"""Reading and performing actions in local queue"""
|
||||
|
||||
# Lock only for queue operations
|
||||
get_start = time.perf_counter()
|
||||
with self.action_queue_lock:
|
||||
self.action_queue_size.append(self.action_queue.qsize())
|
||||
# Get action from queue
|
||||
timed_action = self.action_queue.get_nowait()
|
||||
get_end = time.perf_counter() - get_start
|
||||
|
||||
_performed_action = self.robot.send_action(
|
||||
self._action_tensor_to_action_dict(timed_action.get_action())
|
||||
)
|
||||
with self.latest_action_lock:
|
||||
self.latest_action = timed_action.get_timestep()
|
||||
|
||||
if verbose:
|
||||
with self.action_queue_lock:
|
||||
current_queue_size = self.action_queue.qsize()
|
||||
|
||||
self.logger.debug(
|
||||
f"Ts={timed_action.get_timestamp()} | "
|
||||
f"Action #{timed_action.get_timestep()} performed | "
|
||||
f"Queue size: {current_queue_size}"
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Popping action from queue to perform took {get_end:.6f}s | Queue size: {current_queue_size}"
|
||||
)
|
||||
|
||||
return _performed_action
|
||||
|
||||
def _ready_to_send_observation(self):
|
||||
"""Flags when the client is ready to send an observation"""
|
||||
with self.action_queue_lock:
|
||||
return self.action_queue.qsize() / self.action_chunk_size <= self._chunk_size_threshold
|
||||
|
||||
def control_loop_observation(self, task: str, verbose: bool = False) -> RawObservation:
|
||||
try:
|
||||
# Get serialized observation bytes from the function
|
||||
start_time = time.perf_counter()
|
||||
|
||||
raw_observation: RawObservation = self.robot.get_observation()
|
||||
raw_observation["task"] = task
|
||||
|
||||
with self.latest_action_lock:
|
||||
latest_action = self.latest_action
|
||||
|
||||
observation = TimedObservation(
|
||||
timestamp=time.time(), # need time.time() to compare timestamps across client and server
|
||||
observation=raw_observation,
|
||||
timestep=max(latest_action, 0),
|
||||
)
|
||||
|
||||
obs_capture_time = time.perf_counter() - start_time
|
||||
|
||||
# If there are no actions left in the queue, the observation must go through processing!
|
||||
with self.action_queue_lock:
|
||||
observation.must_go = self.must_go.is_set() and self.action_queue.empty()
|
||||
current_queue_size = self.action_queue.qsize()
|
||||
|
||||
_ = self.send_observation(observation)
|
||||
|
||||
self.logger.debug(f"QUEUE SIZE: {current_queue_size} (Must go: {observation.must_go})")
|
||||
if observation.must_go:
|
||||
# must-go event will be set again after receiving actions
|
||||
self.must_go.clear()
|
||||
|
||||
if verbose:
|
||||
# Calculate comprehensive FPS metrics
|
||||
fps_metrics = self.fps_tracker.calculate_fps_metrics(observation.get_timestamp())
|
||||
|
||||
self.logger.info(
|
||||
f"Obs #{observation.get_timestep()} | "
|
||||
f"Avg FPS: {fps_metrics['avg_fps']:.2f} | "
|
||||
f"Target: {fps_metrics['target_fps']:.2f}"
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Ts={observation.get_timestamp():.6f} | Capturing observation took {obs_capture_time:.6f}s"
|
||||
)
|
||||
|
||||
return raw_observation
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in observation sender: {e}")
|
||||
|
||||
def control_loop(self, task: str, verbose: bool = False) -> tuple[Observation, Action]:
|
||||
"""Combined function for executing actions and streaming observations"""
|
||||
# Wait at barrier for synchronized start
|
||||
self.start_barrier.wait()
|
||||
self.logger.info("Control loop thread starting")
|
||||
|
||||
_performed_action = None
|
||||
_captured_observation = None
|
||||
|
||||
while self.running:
|
||||
control_loop_start = time.perf_counter()
|
||||
"""Control loop: (1) Performing actions, when available"""
|
||||
if self.actions_available():
|
||||
_performed_action = self.control_loop_action(verbose)
|
||||
|
||||
"""Control loop: (2) Streaming observations to the remote policy server"""
|
||||
if self._ready_to_send_observation():
|
||||
_captured_observation = self.control_loop_observation(task, verbose)
|
||||
|
||||
self.logger.debug(f"Control loop (ms): {(time.perf_counter() - control_loop_start) * 1000:.2f}")
|
||||
# Dynamically adjust sleep time to maintain the desired control frequency
|
||||
time.sleep(max(0, self.config.environment_dt - (time.perf_counter() - control_loop_start)))
|
||||
|
||||
return _captured_observation, _performed_action
|
||||
|
||||
|
||||
@draccus.wrap()
|
||||
def async_client(cfg: RobotClientConfig):
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
# TODO: Assert if checking robot support is still needed with the plugin system
|
||||
# if cfg.robot.type not in SUPPORTED_ROBOTS:
|
||||
# raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
|
||||
|
||||
client = RobotClient(cfg)
|
||||
|
||||
if client.start():
|
||||
client.logger.info("Starting action receiver thread...")
|
||||
|
||||
# Create and start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
|
||||
# Start action receiver thread
|
||||
action_receiver_thread.start()
|
||||
|
||||
try:
|
||||
# The main thread runs the control loop
|
||||
client.control_loop(task=cfg.task)
|
||||
|
||||
finally:
|
||||
client.stop()
|
||||
action_receiver_thread.join()
|
||||
if cfg.debug_visualize_queue_size:
|
||||
visualize_action_queue_size(client.action_queue_size)
|
||||
client.logger.info("Client stopped")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
register_third_party_plugins()
|
||||
async_client() # run the client
|
||||
@@ -18,7 +18,6 @@ from __future__ import annotations
|
||||
# Utilities
|
||||
########################################################################################
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from contextlib import nullcontext
|
||||
from copy import copy
|
||||
@@ -244,72 +243,3 @@ def sanity_check_dataset_robot_compatibility(
|
||||
raise ValueError(
|
||||
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
|
||||
)
|
||||
|
||||
|
||||
########################################################################################
|
||||
# Teleoperator smooth handover helpers
|
||||
# NOTE(Maxime): These functions use minimal type hints to maintain compatibility with utils
|
||||
# being a root module.
|
||||
########################################################################################
|
||||
|
||||
|
||||
def teleop_supports_feedback(teleop) -> bool:
|
||||
"""Return True when the teleop can receive position feedback (is actuated).
|
||||
|
||||
Actuated teleops (e.g. SO-101, OpenArmMini) have non-empty ``feedback_features``
|
||||
and expose ``enable_torque`` / ``disable_torque`` motor-control methods.
|
||||
|
||||
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
|
||||
"""
|
||||
return (
|
||||
bool(teleop.feedback_features)
|
||||
and hasattr(teleop, "disable_torque")
|
||||
and hasattr(teleop, "enable_torque")
|
||||
)
|
||||
|
||||
|
||||
def teleop_smooth_move_to(teleop, target_pos: dict, duration_s: float = 2.0, fps: int = 30) -> None:
|
||||
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
|
||||
|
||||
Requires the teleoperator to support feedback (i.e. have non-empty
|
||||
``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
|
||||
|
||||
``target_pos`` is expected to be in the teleop's action/feedback key space.
|
||||
For homogeneous setups (e.g. SO-101 leader + SO-101 follower) this matches
|
||||
the robot action key space directly.
|
||||
|
||||
TODO(Maxime): This blocks up to ``duration_s`` seconds; during this time the
|
||||
follower robot does not receive new actions, which could be an issue on LeKiwi.
|
||||
"""
|
||||
teleop.enable_torque()
|
||||
current = teleop.get_action()
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {
|
||||
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
|
||||
}
|
||||
teleop.send_feedback(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
def follower_smooth_move_to(
|
||||
robot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
|
||||
) -> None:
|
||||
"""Smoothly move the follower robot from ``current`` to ``target`` action.
|
||||
|
||||
Used when the teleop is non-actuated: instead of driving the leader arm to
|
||||
the follower, the follower is brought to the teleop's current pose so the
|
||||
robot meets the operator's hand rather than jumping to it on the first frame.
|
||||
|
||||
Both ``current`` and ``target`` must be in the robot action key space
|
||||
(i.e. the output of ``robot_action_processor``).
|
||||
"""
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
|
||||
robot.send_action(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
@@ -41,8 +41,8 @@ class DatasetRecordConfig:
|
||||
video: bool = True
|
||||
# Upload dataset to Hugging Face hub.
|
||||
push_to_hub: bool = True
|
||||
# If True, upload as private; if None, defer to the org default on the Hub (only affects orgs).
|
||||
private: bool | None = None
|
||||
# Upload on private repository on the Hugging Face hub.
|
||||
private: bool = False
|
||||
# Add tags to your dataset on the hub.
|
||||
tags: list[str] | None = None
|
||||
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
|
||||
|
||||
@@ -255,7 +255,8 @@ def extract_path_fields_from_config(config_path: str, path_fields: list[str]) ->
|
||||
remaining = config_data[field]
|
||||
if remaining:
|
||||
_config_yaml_overrides[field] = _flatten_to_cli_args(remaining)
|
||||
del config_data[field]
|
||||
else:
|
||||
del config_data[field]
|
||||
modified = True
|
||||
|
||||
if not modified:
|
||||
@@ -310,13 +311,7 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
||||
cli_args = filter_arg("config_path", cli_args)
|
||||
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
|
||||
else:
|
||||
if config_path_cli:
|
||||
cli_args = filter_arg("config_path", cli_args)
|
||||
cfg = draccus.parse(
|
||||
config_class=argtype,
|
||||
config_path=config_path_cli or config_path,
|
||||
args=cli_args,
|
||||
)
|
||||
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
|
||||
response = fn(cfg, *args, **kwargs)
|
||||
return response
|
||||
|
||||
|
||||
@@ -177,12 +177,6 @@ class TrainPipelineConfig(HubMixin):
|
||||
)
|
||||
|
||||
active_cfg = self.trainable_config
|
||||
if self.rename_map and active_cfg.pretrained_path is None:
|
||||
raise ValueError(
|
||||
"`rename_map` requires a pretrained policy checkpoint. "
|
||||
"Fresh initialization derives feature names from the current dataset, so no rename is applied."
|
||||
)
|
||||
|
||||
if not self.job_name:
|
||||
if self.env is None:
|
||||
self.job_name = f"{active_cfg.type}"
|
||||
|
||||
@@ -524,7 +524,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
license: str | None = "apache-2.0",
|
||||
tag_version: bool = True,
|
||||
push_videos: bool = True,
|
||||
private: bool | None = None,
|
||||
private: bool = False,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
upload_large_folder: bool = False,
|
||||
**card_kwargs,
|
||||
@@ -543,8 +543,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
tag_version: If ``True``, create a Git tag for the current codebase
|
||||
version.
|
||||
push_videos: If ``False``, skip uploading the ``videos/`` directory.
|
||||
private: If ``True``, create a private repository. If ``None``
|
||||
(default), defer to the org default on the Hub (only affects orgs).
|
||||
private: If ``True``, create a private repository.
|
||||
allow_patterns: Glob pattern(s) restricting which files to upload.
|
||||
upload_large_folder: If ``True``, use ``upload_large_folder`` instead
|
||||
of ``upload_folder`` for very large datasets.
|
||||
|
||||
@@ -30,7 +30,6 @@ class EpisodeAwareSampler:
|
||||
drop_n_first_frames: int = 0,
|
||||
drop_n_last_frames: int = 0,
|
||||
shuffle: bool = False,
|
||||
generator: torch.Generator | None = None,
|
||||
):
|
||||
"""Sampler that optionally incorporates episode boundary information.
|
||||
|
||||
@@ -42,10 +41,6 @@ class EpisodeAwareSampler:
|
||||
drop_n_first_frames: Number of frames to drop from the start of each episode.
|
||||
drop_n_last_frames: Number of frames to drop from the end of each episode.
|
||||
shuffle: Whether to shuffle the indices.
|
||||
generator: Generator used for shuffling. Exposing this attribute (even when None) lets
|
||||
`accelerate` register it as the synchronized RNG in distributed training, so
|
||||
every rank draws the same permutation and batch shards stay disjoint. When
|
||||
None, shuffling falls back to the global torch RNG.
|
||||
"""
|
||||
if drop_n_first_frames < 0:
|
||||
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
|
||||
@@ -78,11 +73,10 @@ class EpisodeAwareSampler:
|
||||
|
||||
self.indices = indices
|
||||
self.shuffle = shuffle
|
||||
self.generator = generator
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
if self.shuffle:
|
||||
for i in torch.randperm(len(self.indices), generator=self.generator):
|
||||
for i in torch.randperm(len(self.indices)):
|
||||
yield self.indices[i]
|
||||
else:
|
||||
for i in self.indices:
|
||||
|
||||
@@ -20,7 +20,6 @@ from .eo1.configuration_eo1 import EO1Config as EO1Config
|
||||
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
|
||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||
from .molmoact2.configuration_molmoact2 import MolmoAct2Config as MolmoAct2Config
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
||||
@@ -44,7 +43,6 @@ __all__ = [
|
||||
"EO1Config",
|
||||
"GaussianActorConfig",
|
||||
"GrootConfig",
|
||||
"MolmoAct2Config",
|
||||
"MultiTaskDiTConfig",
|
||||
"PI0Config",
|
||||
"PI0FastConfig",
|
||||
|
||||
@@ -49,7 +49,6 @@ from .diffusion.configuration_diffusion import DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from .groot.configuration_groot import GrootConfig
|
||||
from .molmoact2.configuration_molmoact2 import MolmoAct2Config
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config
|
||||
from .pi05.configuration_pi05 import PI05Config
|
||||
@@ -57,7 +56,6 @@ from .pretrained import PreTrainedPolicy
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from .utils import validate_visual_features_consistency
|
||||
from .vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig
|
||||
from .wall_x.configuration_wall_x import WallXConfig
|
||||
from .xvla.configuration_xvla import XVLAConfig
|
||||
@@ -90,8 +88,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
Args:
|
||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x",
|
||||
"molmoact2".
|
||||
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x".
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
|
||||
@@ -154,14 +151,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .eo1.modeling_eo1 import EO1Policy
|
||||
|
||||
return EO1Policy
|
||||
elif name == "molmoact2":
|
||||
from .molmoact2.modeling_molmoact2 import MolmoAct2Policy
|
||||
|
||||
return MolmoAct2Policy
|
||||
elif name == "vla_jepa":
|
||||
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
|
||||
|
||||
return VLAJEPAPolicy
|
||||
else:
|
||||
try:
|
||||
return _get_policy_cls_from_policy_name(name=name)
|
||||
@@ -179,7 +168,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
|
||||
"smolvla", "wall_x", "molmoact2".
|
||||
"smolvla", "wall_x".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -214,10 +203,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return WallXConfig(**kwargs)
|
||||
elif policy_type == "eo1":
|
||||
return EO1Config(**kwargs)
|
||||
elif policy_type == "molmoact2":
|
||||
return MolmoAct2Config(**kwargs)
|
||||
elif policy_type == "vla_jepa":
|
||||
return VLAJEPAConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||
@@ -246,7 +231,6 @@ class ProcessorConfigKwargs(TypedDict, total=False):
|
||||
preprocessor_overrides: dict[str, Any] | None
|
||||
postprocessor_overrides: dict[str, Any] | None
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
|
||||
dataset_meta: Any | None
|
||||
|
||||
|
||||
def make_pre_post_processors(
|
||||
@@ -422,7 +406,6 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, EO1Config):
|
||||
from .eo1.processor_eo1 import make_eo1_pre_post_processors
|
||||
|
||||
@@ -431,23 +414,6 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, MolmoAct2Config):
|
||||
from .molmoact2.processor_molmoact2 import make_molmoact2_pre_post_processors
|
||||
|
||||
processors = make_molmoact2_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, VLAJEPAConfig):
|
||||
from .vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||
|
||||
processors = make_vla_jepa_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
processors = _make_processors_from_policy_config(
|
||||
@@ -533,10 +499,6 @@ def make_policy(
|
||||
action_names = ds_meta.features.get(ACTION, {}).get("names")
|
||||
if action_names is not None:
|
||||
cfg.action_feature_names = list(action_names)
|
||||
if ds_meta is not None:
|
||||
set_dataset_feature_metadata = getattr(cfg, "set_dataset_feature_metadata", None)
|
||||
if callable(set_dataset_feature_metadata):
|
||||
set_dataset_feature_metadata(ds_meta.features)
|
||||
|
||||
kwargs["config"] = cfg
|
||||
|
||||
|
||||
@@ -60,7 +60,6 @@ class Eagle25VLPreTrainedModel(PreTrainedModel):
|
||||
"SiglipEncoderLayer",
|
||||
]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
|
||||
@@ -124,6 +124,7 @@ class Eagle25VLProcessor(ProcessorMixin):
|
||||
"videos_kwargs",
|
||||
"text_kwargs",
|
||||
]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -206,11 +206,7 @@ def _build_eagle_processor(tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS
|
||||
"Vendor files are copied during model creation. Create the policy/model first, "
|
||||
"or call ensure_eagle_cache_ready() before building processors."
|
||||
)
|
||||
proc = AutoProcessor.from_pretrained(
|
||||
str(cache_dir),
|
||||
trust_remote_code=True,
|
||||
fix_mistral_regex=False,
|
||||
)
|
||||
proc = AutoProcessor.from_pretrained(str(cache_dir), trust_remote_code=True, use_fast=True)
|
||||
proc.tokenizer.padding_side = "left"
|
||||
return proc
|
||||
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
../../../../docs/source/policy_molmoact2_README.md
|
||||
@@ -1,21 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_molmoact2 import MolmoAct2Config
|
||||
from .modeling_molmoact2 import MolmoAct2Policy
|
||||
from .processor_molmoact2 import make_molmoact2_pre_post_processors
|
||||
|
||||
__all__ = ["MolmoAct2Config", "MolmoAct2Policy", "make_molmoact2_pre_post_processors"]
|
||||
@@ -1,519 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
|
||||
from lerobot.optim import (
|
||||
AdamWConfig,
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
LRSchedulerConfig,
|
||||
OptimizerConfig,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
from ..rtc.configuration_rtc import RTCConfig
|
||||
|
||||
MOLMOACT2_DEFAULT_NUM_IMAGES = 2
|
||||
MOLMOACT2_IMAGE_TOKENS_PER_IMAGE = 196
|
||||
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET = 80
|
||||
MOLMOACT2_TASK_TOKEN_BUDGET = 32
|
||||
MOLMOACT2_SEQUENCE_LENGTH_MARGIN = 32
|
||||
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE = 64
|
||||
MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS = 4
|
||||
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP = 6
|
||||
MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM = 0.95
|
||||
|
||||
|
||||
def _hf_token() -> str | None:
|
||||
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
|
||||
|
||||
|
||||
def _resolve_checkpoint_location(
|
||||
checkpoint_path: str,
|
||||
*,
|
||||
revision: str | None = None,
|
||||
force_download: bool = False,
|
||||
) -> str:
|
||||
checkpoint_path = str(checkpoint_path or "").strip()
|
||||
if not checkpoint_path:
|
||||
raise ValueError("MolmoAct2 policy requires `checkpoint_path`.")
|
||||
local_path = Path(checkpoint_path).expanduser()
|
||||
if local_path.exists():
|
||||
return str(local_path)
|
||||
return snapshot_download(
|
||||
repo_id=checkpoint_path,
|
||||
repo_type="model",
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
ignore_patterns=["*.py", "*.pyc", "__pycache__/*"],
|
||||
token=_hf_token(),
|
||||
)
|
||||
|
||||
|
||||
def _load_hf_norm_metadata_for_tag(
|
||||
checkpoint_path: str,
|
||||
*,
|
||||
revision: str | None,
|
||||
force_download: bool,
|
||||
norm_tag: str | None,
|
||||
) -> dict[str, Any]:
|
||||
norm_tag = str(norm_tag or "").strip()
|
||||
if not norm_tag:
|
||||
return {}
|
||||
checkpoint_location = Path(
|
||||
_resolve_checkpoint_location(
|
||||
checkpoint_path,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
)
|
||||
norm_stats_filename = "norm_stats.json"
|
||||
config_path = checkpoint_location / "config.json"
|
||||
if config_path.exists():
|
||||
with suppress(OSError, json.JSONDecodeError):
|
||||
norm_stats_filename = str(
|
||||
json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename
|
||||
)
|
||||
stats_path = checkpoint_location / norm_stats_filename
|
||||
if not stats_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}."
|
||||
)
|
||||
payload = json.loads(stats_path.read_text())
|
||||
metadata_by_tag = payload.get("metadata_by_tag")
|
||||
if not isinstance(metadata_by_tag, dict):
|
||||
raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.")
|
||||
metadata = metadata_by_tag.get(norm_tag)
|
||||
if not isinstance(metadata, dict):
|
||||
available = sorted(str(tag) for tag in metadata_by_tag)
|
||||
raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.")
|
||||
return metadata
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("molmoact2_cosine_decay_with_warmup")
|
||||
@dataclass
|
||||
class MolmoAct2CosineDecayWithWarmupSchedulerConfig(CosineDecayWithWarmupSchedulerConfig):
|
||||
"""MolmoAct2-local cosine scheduler with optional decay-step auto-match.
|
||||
|
||||
LeRobot's generic cosine scheduler keeps an explicit integer decay length.
|
||||
For MolmoAct2, leaving num_decay_steps unset means "decay across this run's
|
||||
training steps"; build() is the first point where num_training_steps is known.
|
||||
"""
|
||||
|
||||
num_decay_steps: int | None
|
||||
|
||||
def build(self, optimizer, num_training_steps: int):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.peak_lr,
|
||||
decay_lr=self.decay_lr,
|
||||
num_warmup_steps=self.num_warmup_steps,
|
||||
num_decay_steps=num_training_steps if self.num_decay_steps is None else self.num_decay_steps,
|
||||
).build(optimizer, num_training_steps=num_training_steps)
|
||||
|
||||
|
||||
def _round_up(value: int, multiple: int) -> int:
|
||||
return int(math.ceil(value / multiple) * multiple)
|
||||
|
||||
|
||||
def infer_molmoact2_max_sequence_length(
|
||||
*,
|
||||
num_images: int,
|
||||
state_dim: int,
|
||||
action_dim: int,
|
||||
action_horizon: int,
|
||||
include_discrete_action: bool,
|
||||
) -> int:
|
||||
"""Infer the padded text/image sequence cap from MolmoAct2's fixed token layout."""
|
||||
if num_images < 1:
|
||||
num_images = MOLMOACT2_DEFAULT_NUM_IMAGES
|
||||
if state_dim < 0:
|
||||
state_dim = 0
|
||||
if action_dim < 1:
|
||||
action_dim = 1
|
||||
if action_horizon < 1:
|
||||
action_horizon = 1
|
||||
|
||||
image_tokens = num_images * MOLMOACT2_IMAGE_TOKENS_PER_IMAGE
|
||||
prompt_tokens = (
|
||||
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET
|
||||
+ MOLMOACT2_TASK_TOKEN_BUDGET
|
||||
+ state_dim
|
||||
+ MOLMOACT2_SEQUENCE_LENGTH_MARGIN
|
||||
)
|
||||
action_tokens = 0
|
||||
if include_discrete_action:
|
||||
action_tokens_per_step = max(
|
||||
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP,
|
||||
math.ceil(action_dim * MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM),
|
||||
)
|
||||
action_tokens = MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS + action_horizon * action_tokens_per_step
|
||||
|
||||
return _round_up(
|
||||
image_tokens + prompt_tokens + action_tokens,
|
||||
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE,
|
||||
)
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("molmoact2")
|
||||
@dataclass
|
||||
class MolmoAct2Config(PreTrainedConfig):
|
||||
"""MolmoAct2 policy backed by the converted HF checkpoint implementation."""
|
||||
|
||||
checkpoint_path: str = "allenai/MolmoAct2"
|
||||
checkpoint_revision: str | None = None
|
||||
checkpoint_force_download: bool = False
|
||||
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 30
|
||||
n_action_steps: int = 30
|
||||
|
||||
action_mode: str = "both"
|
||||
inference_action_mode: str | None = None
|
||||
discrete_action_tokenizer: str = "allenai/MolmoAct2-FAST-Tokenizer"
|
||||
discrete_generation_max_steps: int | None = None
|
||||
norm_tag: str | None = None
|
||||
|
||||
setup_type: str = ""
|
||||
control_mode: str = ""
|
||||
image_keys: list[str] = field(default_factory=list)
|
||||
normalize_language: bool = True
|
||||
add_setup_tokens: bool = True
|
||||
add_control_tokens: bool = True
|
||||
normalize_gripper: bool = False
|
||||
num_state_tokens: int = 256
|
||||
# Leave unset for the default MolmoAct2 sequence budget inferred from the fixed
|
||||
# image/prompt/state/action token layout. Override only for unusual long prompts.
|
||||
max_sequence_length: int | None = None
|
||||
|
||||
# Fixed by released MolmoAct2 checkpoints. We validate this at model load.
|
||||
expected_max_action_dim: int = 32
|
||||
|
||||
# Flow-matching training knobs copied from the original MolmoAct2 training path.
|
||||
num_flow_timesteps: int = 8
|
||||
flow_matching_cutoff: float = 1.0
|
||||
flow_matching_time_offset: float = 0.001
|
||||
flow_matching_time_scale: float = 0.999
|
||||
flow_matching_beta_alpha: float = 1.0
|
||||
flow_matching_beta_beta: float = 1.5
|
||||
num_inference_steps: int | None = None
|
||||
mask_action_dim_padding: bool = True
|
||||
enable_inference_cuda_graph: bool = True
|
||||
# MolmoAct2-local eval option. When enabled, stochastic continuous action
|
||||
# generation uses a rollout-local generator derived from eval_seed.
|
||||
per_episode_seed: bool = False
|
||||
eval_seed: int | None = None
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
# Default is full finetuning with gradients from the action expert flowing into the VLM.
|
||||
enable_lora_vlm: bool = False
|
||||
lora_rank: int = 64
|
||||
lora_alpha: int = 16
|
||||
lora_dropout: float = 0.05
|
||||
lora_bias: str = "none"
|
||||
enable_lora_action_expert: bool = False
|
||||
enable_knowledge_insulation: bool = False
|
||||
freeze_embedding: bool = True
|
||||
train_action_expert_only: bool = False
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
model_dtype: str = "bfloat16"
|
||||
softmax_auxiliary_loss: bool = True
|
||||
softmax_auxiliary_loss_scale: float = 1e-4
|
||||
discrete_loss_token_weighting: str = "root_subsegments_root_tokens"
|
||||
|
||||
optimizer_lr: float = 1e-5
|
||||
optimizer_vit_lr: float = 5e-6
|
||||
optimizer_connector_lr: float = 5e-6
|
||||
optimizer_action_expert_lr: float = 5e-5
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-6
|
||||
optimizer_weight_decay: float = 0.0
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
scheduler_warmup_steps: int = 200
|
||||
scheduler_decay_steps: int | None = None
|
||||
scheduler_decay_lr: float = 1e-6
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.QUANTILES,
|
||||
"ACTION": NormalizationMode.QUANTILES,
|
||||
}
|
||||
)
|
||||
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
dataset_feature_names: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
if self.action_mode not in {"continuous", "discrete", "both"}:
|
||||
raise ValueError(
|
||||
f"Unsupported action_mode={self.action_mode!r}. "
|
||||
"Expected one of {'continuous', 'discrete', 'both'}."
|
||||
)
|
||||
if self.inference_action_mode not in {None, "continuous", "discrete"}:
|
||||
raise ValueError(
|
||||
f"Unsupported inference_action_mode={self.inference_action_mode!r}. "
|
||||
"Expected one of {None, 'continuous', 'discrete'}."
|
||||
)
|
||||
if self.inference_action_mode == "continuous" and self.action_mode == "discrete":
|
||||
raise ValueError("MolmoAct2 action_mode='discrete' cannot run continuous inference.")
|
||||
if self.inference_action_mode == "discrete" and self.action_mode == "continuous":
|
||||
raise ValueError("MolmoAct2 action_mode='continuous' cannot run discrete inference.")
|
||||
if self.train_action_expert_only and self.action_mode != "continuous":
|
||||
raise ValueError("MolmoAct2 train_action_expert_only requires action_mode='continuous'.")
|
||||
if self.train_action_expert_only and self.enable_lora_vlm:
|
||||
raise ValueError("MolmoAct2 train_action_expert_only is incompatible with enable_lora_vlm.")
|
||||
if self.enable_lora_action_expert and not self.enable_lora_vlm:
|
||||
raise ValueError("MolmoAct2 enable_lora_action_expert requires enable_lora_vlm.")
|
||||
if self.chunk_size < 1:
|
||||
raise ValueError(f"chunk_size must be >= 1, got {self.chunk_size}.")
|
||||
if self.n_action_steps < 1:
|
||||
raise ValueError(f"n_action_steps must be >= 1, got {self.n_action_steps}.")
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) cannot exceed chunk_size ({self.chunk_size})."
|
||||
)
|
||||
if self.expected_max_action_dim != 32:
|
||||
raise ValueError("MolmoAct2 released checkpoints use expected_max_action_dim=32.")
|
||||
if self.model_dtype not in {"float32", "bfloat16", "float16"}:
|
||||
raise ValueError(
|
||||
f"Unsupported model_dtype={self.model_dtype!r}. Expected 'float32', 'bfloat16', or 'float16'."
|
||||
)
|
||||
if self.lora_rank < 1:
|
||||
raise ValueError(f"lora_rank must be >= 1, got {self.lora_rank}.")
|
||||
if self.lora_alpha < 1:
|
||||
raise ValueError(f"lora_alpha must be >= 1, got {self.lora_alpha}.")
|
||||
if not 0 <= self.lora_dropout <= 1:
|
||||
raise ValueError(f"lora_dropout must be in [0, 1], got {self.lora_dropout}.")
|
||||
if self.lora_bias not in {"none", "all", "lora_only"}:
|
||||
raise ValueError(
|
||||
f"Unsupported lora_bias={self.lora_bias!r}. Expected one of 'none', 'all', or 'lora_only'."
|
||||
)
|
||||
if self.discrete_loss_token_weighting not in {
|
||||
"none",
|
||||
"token",
|
||||
"root_tokens",
|
||||
"root_subsegments",
|
||||
"root_subsegments_root_tokens",
|
||||
}:
|
||||
raise ValueError(
|
||||
f"Unsupported discrete_loss_token_weighting={self.discrete_loss_token_weighting!r}."
|
||||
)
|
||||
if self.discrete_generation_max_steps is not None and self.discrete_generation_max_steps < 1:
|
||||
raise ValueError(
|
||||
f"discrete_generation_max_steps must be >= 1 or None, got {self.discrete_generation_max_steps}."
|
||||
)
|
||||
if self.max_sequence_length is not None and self.max_sequence_length < 1:
|
||||
raise ValueError(f"max_sequence_length must be >= 1 or None, got {self.max_sequence_length}.")
|
||||
|
||||
def inferred_max_sequence_length(
|
||||
self,
|
||||
*,
|
||||
num_images: int | None = None,
|
||||
state_dim: int | None = None,
|
||||
action_dim: int | None = None,
|
||||
action_horizon: int | None = None,
|
||||
include_discrete_action: bool | None = None,
|
||||
) -> int:
|
||||
if self.max_sequence_length is not None:
|
||||
return int(self.max_sequence_length)
|
||||
|
||||
if num_images is None:
|
||||
num_images = len(self.image_keys) or len(self.image_features) or MOLMOACT2_DEFAULT_NUM_IMAGES
|
||||
if state_dim is None:
|
||||
state_feature = self.robot_state_feature
|
||||
state_dim = int(state_feature.shape[0]) if state_feature is not None else 0
|
||||
if action_dim is None:
|
||||
action_feature = self.action_feature
|
||||
action_dim = (
|
||||
int(action_feature.shape[0]) if action_feature is not None else self.expected_max_action_dim
|
||||
)
|
||||
if action_horizon is None:
|
||||
action_horizon = self.chunk_size
|
||||
if include_discrete_action is None:
|
||||
include_discrete_action = self.action_mode in {"discrete", "both"}
|
||||
|
||||
return infer_molmoact2_max_sequence_length(
|
||||
num_images=int(num_images),
|
||||
state_dim=int(state_dim),
|
||||
action_dim=int(action_dim),
|
||||
action_horizon=int(action_horizon),
|
||||
include_discrete_action=bool(include_discrete_action),
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
def get_optimizer_preset(self) -> OptimizerConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
||||
return MolmoAct2CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
def set_dataset_feature_metadata(self, features: dict[str, Any]) -> None:
|
||||
self.dataset_feature_names = {}
|
||||
for key in (ACTION, OBS_STATE):
|
||||
feature = features.get(key) if isinstance(features, dict) else None
|
||||
if isinstance(feature, dict) and feature.get("names") is not None:
|
||||
self.dataset_feature_names[key] = feature["names"]
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up MolmoAct2 input and output features."""
|
||||
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
|
||||
if not image_features:
|
||||
raise ValueError(
|
||||
"MolmoAct2 policy requires at least one visual input feature. "
|
||||
"No features of type FeatureType.VISUAL found in input_features."
|
||||
)
|
||||
|
||||
if OBS_STATE not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(0,),
|
||||
)
|
||||
self.input_features[OBS_STATE] = state_feature
|
||||
|
||||
if ACTION not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.expected_max_action_dim,),
|
||||
)
|
||||
self.output_features[ACTION] = action_feature
|
||||
|
||||
def apply_norm_tag_metadata(self) -> None:
|
||||
if not str(self.norm_tag or "").strip():
|
||||
return
|
||||
metadata = _load_hf_norm_metadata_for_tag(
|
||||
self.checkpoint_path,
|
||||
revision=self.checkpoint_revision,
|
||||
force_download=bool(self.checkpoint_force_download),
|
||||
norm_tag=self.norm_tag,
|
||||
)
|
||||
if metadata.get("action_horizon") is not None:
|
||||
self.chunk_size = int(metadata["action_horizon"])
|
||||
if metadata.get("n_action_steps") is not None:
|
||||
self.n_action_steps = int(metadata["n_action_steps"])
|
||||
if not self.setup_type and metadata.get("setup_type") is not None:
|
||||
self.setup_type = str(metadata["setup_type"])
|
||||
if not self.control_mode and metadata.get("control_mode") is not None:
|
||||
self.control_mode = str(metadata["control_mode"])
|
||||
|
||||
def saved_policy_action_mode(self) -> str | None:
|
||||
pretrained_path = getattr(self, "pretrained_path", None)
|
||||
if pretrained_path is None:
|
||||
return None
|
||||
config_path = Path(pretrained_path) / "config.json"
|
||||
if not config_path.exists():
|
||||
return None
|
||||
try:
|
||||
mode = json.loads(config_path.read_text()).get("action_mode")
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
if mode in {"continuous", "discrete", "both"}:
|
||||
return str(mode)
|
||||
return None
|
||||
|
||||
def training_action_mode(self, saved_policy_action_mode: str | None = None) -> str:
|
||||
return saved_policy_action_mode or self.action_mode
|
||||
|
||||
def validate_inference_action_mode(self, saved_policy_action_mode: str | None = None) -> None:
|
||||
requested_mode = self.inference_action_mode
|
||||
if requested_mode is None:
|
||||
return
|
||||
training_mode = self.training_action_mode(saved_policy_action_mode)
|
||||
if requested_mode == "continuous" and training_mode == "discrete":
|
||||
raise ValueError(
|
||||
"MolmoAct2 checkpoint was trained with action_mode='discrete' and cannot run "
|
||||
"continuous inference."
|
||||
)
|
||||
if requested_mode == "discrete" and training_mode == "continuous":
|
||||
raise ValueError(
|
||||
"MolmoAct2 checkpoint was trained with action_mode='continuous' and cannot run "
|
||||
"discrete inference. Train with action_mode='both' or action_mode='discrete' first."
|
||||
)
|
||||
|
||||
def validate_checkpoint_action_mode(
|
||||
self,
|
||||
checkpoint_action_mode: str,
|
||||
*,
|
||||
has_action_expert: bool,
|
||||
) -> None:
|
||||
if self.action_mode == "both" and checkpoint_action_mode != "both":
|
||||
raise ValueError(
|
||||
f"action_mode='both' requires checkpoint action_mode='both', got {checkpoint_action_mode!r}."
|
||||
)
|
||||
if self.action_mode == "discrete" and checkpoint_action_mode not in {"discrete", "both"}:
|
||||
raise ValueError(
|
||||
f"action_mode='discrete' requires checkpoint action_mode in {{'discrete', 'both'}}, "
|
||||
f"got {checkpoint_action_mode!r}."
|
||||
)
|
||||
if self.action_mode in {"continuous", "both"} and not has_action_expert:
|
||||
raise ValueError("Continuous MolmoAct2 training requires an action expert checkpoint.")
|
||||
|
||||
def resolve_inference_action_mode(
|
||||
self,
|
||||
requested_mode: str | None,
|
||||
saved_policy_action_mode: str | None = None,
|
||||
) -> str:
|
||||
training_mode = self.training_action_mode(saved_policy_action_mode)
|
||||
if requested_mode is None:
|
||||
requested_mode = self.inference_action_mode
|
||||
if requested_mode is None:
|
||||
raise ValueError(
|
||||
"MolmoAct2 inference requires `inference_action_mode` to be set explicitly "
|
||||
"to either 'continuous' or 'discrete'."
|
||||
)
|
||||
if requested_mode not in {"continuous", "discrete"}:
|
||||
raise ValueError("MolmoAct2 inference_action_mode must be either 'continuous' or 'discrete'.")
|
||||
if requested_mode == "continuous" and training_mode == "discrete":
|
||||
raise ValueError("MolmoAct2 action_mode='discrete' checkpoint cannot run continuous inference.")
|
||||
if requested_mode == "discrete" and training_mode == "continuous":
|
||||
raise ValueError("MolmoAct2 action_mode='continuous' checkpoint cannot run discrete inference.")
|
||||
return requested_mode
|
||||
@@ -1,17 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
@@ -1,237 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
|
||||
import numpy as np
|
||||
from tokenizers import ByteLevelBPETokenizer
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
|
||||
|
||||
def _hf_token() -> str | None:
|
||||
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
|
||||
|
||||
|
||||
def _resolve_tokenizer_location(
|
||||
tokenizer_path: str,
|
||||
*,
|
||||
revision: str | None = None,
|
||||
force_download: bool = False,
|
||||
) -> str:
|
||||
local_path = Path(str(tokenizer_path)).expanduser()
|
||||
if local_path.exists():
|
||||
return str(local_path)
|
||||
return snapshot_download(
|
||||
repo_id=str(tokenizer_path),
|
||||
repo_type="model",
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
ignore_patterns=["*.py", "*.pyc", "__pycache__/*"],
|
||||
token=_hf_token(),
|
||||
)
|
||||
|
||||
|
||||
class UniversalActionProcessor(ProcessorMixin):
|
||||
attributes: ClassVar[list[str]] = ["tokenizer"]
|
||||
tokenizer_class: str = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerFast,
|
||||
scale: float = 10,
|
||||
vocab_size: int = 1024,
|
||||
min_token: int = 0,
|
||||
*,
|
||||
action_dim: int | None = None,
|
||||
time_horizon: int | None = None,
|
||||
):
|
||||
self.scale = scale
|
||||
self.vocab_size = vocab_size
|
||||
self.min_token = min_token
|
||||
|
||||
# Action horizon and dimension needed during decoding. These can be specified
|
||||
# in three ways (in order of priority):
|
||||
# 1. passed in as kwargs to decode()
|
||||
# 2. in the constructor
|
||||
# 3. cached from the last time decode() was called
|
||||
self.time_horizon = time_horizon
|
||||
self.action_dim = action_dim
|
||||
self.called_time_horizon = time_horizon
|
||||
self.called_action_dim = action_dim
|
||||
|
||||
super().__init__(tokenizer)
|
||||
self.bpe_tokenizer = self.tokenizer
|
||||
|
||||
def __call__(self, action_chunk: np.array) -> np.array:
|
||||
from scipy.fft import dct
|
||||
|
||||
assert action_chunk.ndim <= 3, "Only 3 dimensions supported: [batch, timesteps, action_dim]"
|
||||
if action_chunk.ndim == 2:
|
||||
action_chunk = action_chunk[None, ...]
|
||||
|
||||
# Cache the time horizon and action dimension for decoding
|
||||
self.called_time_horizon = action_chunk.shape[-2]
|
||||
self.called_action_dim = action_chunk.shape[-1]
|
||||
|
||||
dct_coeff = dct(action_chunk, axis=1, norm="ortho")
|
||||
dct_coeff = np.around(dct_coeff * self.scale)
|
||||
tokens = []
|
||||
for elem in dct_coeff:
|
||||
token_str = "".join(map(chr, np.maximum(elem.flatten() - self.min_token, 0).astype(int)))
|
||||
tokens.append(self.bpe_tokenizer(token_str)["input_ids"])
|
||||
return tokens
|
||||
|
||||
def decode(
|
||||
self,
|
||||
tokens: list[list[int]],
|
||||
*,
|
||||
time_horizon: int | None = None,
|
||||
action_dim: int | None = None,
|
||||
) -> np.array:
|
||||
from scipy.fft import idct
|
||||
|
||||
self.time_horizon = time_horizon or self.time_horizon or self.called_time_horizon
|
||||
self.action_dim = action_dim or self.action_dim or self.called_action_dim
|
||||
|
||||
# Cache the time horizon and action dimension for the next call
|
||||
self.called_time_horizon = self.time_horizon
|
||||
self.called_action_dim = self.action_dim
|
||||
|
||||
assert self.time_horizon is not None and self.action_dim is not None, (
|
||||
"Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
|
||||
)
|
||||
|
||||
decoded_actions = []
|
||||
for token in tokens:
|
||||
try:
|
||||
decoded_tokens = self.bpe_tokenizer.decode(token)
|
||||
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.min_token
|
||||
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
|
||||
assert decoded_dct_coeff.shape == (
|
||||
self.time_horizon,
|
||||
self.action_dim,
|
||||
), (
|
||||
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error decoding tokens: {e}")
|
||||
print(f"Tokens: {token}")
|
||||
decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
|
||||
decoded_actions.append(idct(decoded_dct_coeff / self.scale, axis=0, norm="ortho"))
|
||||
return np.stack(decoded_actions)
|
||||
|
||||
@classmethod
|
||||
def fit(
|
||||
cls,
|
||||
action_data: list[np.array],
|
||||
scale: float = 10,
|
||||
vocab_size: int = 1024,
|
||||
*,
|
||||
time_horizon: int | None = None,
|
||||
action_dim: int | None = None,
|
||||
) -> "UniversalActionProcessor":
|
||||
from scipy.fft import dct
|
||||
|
||||
# Run DCT over all inputs
|
||||
dct_tokens = [dct(a, axis=0, norm="ortho").flatten() for a in action_data]
|
||||
|
||||
# Quantize and find min token
|
||||
max_token = int(np.around(np.concatenate(dct_tokens) * scale).max())
|
||||
min_token = int(np.around(np.concatenate(dct_tokens) * scale).min())
|
||||
min_vocab_size = max_token - min_token
|
||||
|
||||
assert min_vocab_size <= vocab_size, (
|
||||
f"Vocab size {vocab_size} is too small for the range of tokens {min_vocab_size}"
|
||||
)
|
||||
if min_vocab_size + 100 > vocab_size:
|
||||
logging.warning(
|
||||
f"Initial alphabet size {min_vocab_size} is almost as large as the vocab"
|
||||
f"size {vocab_size}, consider increasing vocab size"
|
||||
)
|
||||
|
||||
# Make token iterator for BPE training
|
||||
def _token_iter():
|
||||
for tokens in dct_tokens:
|
||||
rounded_tokens = np.around(tokens * scale) - min_token
|
||||
rounded_tokens = rounded_tokens.astype(int)
|
||||
string = "".join(map(chr, rounded_tokens))
|
||||
yield string
|
||||
|
||||
# Train BPE tokenizer
|
||||
bpe = ByteLevelBPETokenizer()
|
||||
|
||||
# Set up the entire range of possible tokens as the initial alphabet
|
||||
alphabet = [chr(i) for i in range(max_token - min_token + 1)]
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=vocab_size,
|
||||
min_frequency=2,
|
||||
show_progress=True,
|
||||
special_tokens=[],
|
||||
initial_alphabet=alphabet,
|
||||
max_token_length=10000,
|
||||
)
|
||||
|
||||
# Train the inner tokenizer (don't use ByteLevelBPETokenizer.train_from_iterator()
|
||||
# because it doesn't support custom alphabets)
|
||||
bpe._tokenizer.train_from_iterator(_token_iter(), trainer=trainer)
|
||||
|
||||
return cls(
|
||||
PreTrainedTokenizerFast(tokenizer_object=bpe, clean_up_tokenization_spaces=False),
|
||||
scale=scale,
|
||||
vocab_size=vocab_size,
|
||||
min_token=min_token,
|
||||
time_horizon=time_horizon,
|
||||
action_dim=action_dim,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained_local(
|
||||
cls,
|
||||
pretrained_model_name_or_path: str,
|
||||
*,
|
||||
revision: str | None = None,
|
||||
force_download: bool = False,
|
||||
) -> "UniversalActionProcessor":
|
||||
location = Path(
|
||||
_resolve_tokenizer_location(
|
||||
pretrained_model_name_or_path,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
)
|
||||
processor_config = {}
|
||||
processor_config_path = location / "processor_config.json"
|
||||
if processor_config_path.exists():
|
||||
import json
|
||||
|
||||
processor_config = json.loads(processor_config_path.read_text())
|
||||
tokenizer = PreTrainedTokenizerFast.from_pretrained(str(location))
|
||||
return cls(
|
||||
tokenizer,
|
||||
scale=processor_config.get("scale", 10),
|
||||
vocab_size=processor_config.get("vocab_size", 1024),
|
||||
min_token=processor_config.get("min_token", 0),
|
||||
action_dim=processor_config.get("action_dim"),
|
||||
time_horizon=processor_config.get("time_horizon"),
|
||||
)
|
||||
@@ -1,553 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
"""
|
||||
MolmoAct2 configuration
|
||||
"""
|
||||
|
||||
from typing import Optional, Any
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class MolmoAct2VitConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`MolmoAct2VisionTransformer`].
|
||||
It is used to instantiate a `MolmoAct2VisionTransformer` according to the specified arguments,
|
||||
defining the model architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> from transformers import MolmoAct2VitConfig, MolmoAct2VisionTransformer
|
||||
|
||||
>>> # Initializing a MolmoAct2VitConfig
|
||||
>>> configuration = MolmoAct2VitConfig()
|
||||
|
||||
>>> # Initializing a MolmoAct2VisionTransformer (with random weights)
|
||||
>>> model = MolmoAct2VisionTransformer(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "molmoact2"
|
||||
base_config_key = "vit_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 1152,
|
||||
intermediate_size: int = 4304,
|
||||
num_hidden_layers: int = 27,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int = 16,
|
||||
head_dim: int = 72,
|
||||
hidden_act: str = "gelu_pytorch_tanh",
|
||||
layer_norm_eps: float = 1e-6,
|
||||
image_default_input_size: tuple[int, int] = (378, 378),
|
||||
image_patch_size: int = 14,
|
||||
image_num_pos: int = 577,
|
||||
attention_dropout: float = 0.0,
|
||||
residual_dropout: float = 0.0,
|
||||
initializer_range: float = 0.02,
|
||||
float32_attention: bool = True,
|
||||
attn_implementation: str = "eager",
|
||||
**kwargs,
|
||||
):
|
||||
self.attn_implementation = attn_implementation
|
||||
super().__init__(attn_implementation=attn_implementation, **kwargs)
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.hidden_act = hidden_act
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.image_default_input_size = image_default_input_size
|
||||
self.image_patch_size = image_patch_size
|
||||
self.image_num_pos = image_num_pos
|
||||
self.attention_dropout = attention_dropout
|
||||
self.residual_dropout = residual_dropout
|
||||
self.initializer_range = initializer_range
|
||||
self.float32_attention = float32_attention
|
||||
|
||||
@property
|
||||
def image_num_patch(self):
|
||||
h, w = self.image_default_input_size
|
||||
return h // self.image_patch_size, w // self.image_patch_size
|
||||
|
||||
|
||||
class MolmoAct2AdapterConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of MolmoAct2Adapter. With MolmoAct2VitConfig,
|
||||
It is used to instantiate an MolmoAct2VisionBackbone according to the specified arguments,
|
||||
defining the model architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2VisionBackbone
|
||||
|
||||
>>> # Initializing a MolmoAct2VitConfig and a MolmoAct2AdapterConfig
|
||||
>>> vit_config = MolmoAct2VitConfig()
|
||||
>>> adapter_config = MolmoPoolingConfig()
|
||||
|
||||
>>> # Initializing a MolmoAct2VisionBackbone (with random weights)
|
||||
>>> model = MolmoAct2VisionBackbone(vit_config, adapter_config)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> vit_configuration = model.vit_config
|
||||
>>> adapter_configuration = model.adapter_config
|
||||
```"""
|
||||
|
||||
model_type = "molmoact2"
|
||||
base_config_key = "adapter_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vit_layers: tuple = (-3, -9),
|
||||
pooling_attention_mask: bool = False,
|
||||
hidden_size: int = 1152,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int = 16,
|
||||
head_dim: int = 72,
|
||||
float32_attention: bool = True,
|
||||
attention_dropout: float = 0.0,
|
||||
residual_dropout: float = 0.0,
|
||||
hidden_act: str = "silu",
|
||||
intermediate_size: int = 18944,
|
||||
text_hidden_size: int = 3584,
|
||||
image_feature_dropout: float = 0.0,
|
||||
initializer_range: float = 0.02,
|
||||
attn_implementation: str = "eager",
|
||||
**kwargs,
|
||||
):
|
||||
self.attn_implementation = attn_implementation
|
||||
super().__init__(attn_implementation=attn_implementation, **kwargs)
|
||||
self.vit_layers = vit_layers
|
||||
self.pooling_attention_mask = pooling_attention_mask
|
||||
self.hidden_size = hidden_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.float32_attention = float32_attention
|
||||
self.attention_dropout = attention_dropout
|
||||
self.residual_dropout = residual_dropout
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.text_hidden_size = text_hidden_size
|
||||
self.image_feature_dropout = image_feature_dropout
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
|
||||
class MolmoAct2TextConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`MolmoAct2TextModel`]. It is used to instantiate a
|
||||
`MolmoAct2TextModel` according to the specified arguments, defining the model architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> from transformers import MolmoAct2TextConfig, MolmoAct2TextModel
|
||||
|
||||
>>> # Initializing a MolmoAct2TextConfig
|
||||
>>> configuration = MolmoAct2TextConfig()
|
||||
|
||||
>>> # Initializing a MolmoAct2TextModel (with random weights)
|
||||
>>> model = MolmoAct2TextModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "molmoact2_text"
|
||||
base_config_key = "text_config"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
base_model_tp_plan = {
|
||||
"blocks.*.self_attn.att_proj": "colwise",
|
||||
"blocks.*.self_attn.attn_out": "rowwise",
|
||||
"blocks.*.mlp.ff_proj": "colwise",
|
||||
"blocks.*.mlp.ff_out": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"wte": (["input_ids"], ["inputs_embeds"]),
|
||||
"blocks": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"ln_f": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 3584,
|
||||
num_attention_heads: int = 28,
|
||||
num_key_value_heads: int | None = 4,
|
||||
head_dim: int = 128,
|
||||
vocab_size: int = 152064,
|
||||
additional_vocab_size: int = 128,
|
||||
qkv_bias: bool = True,
|
||||
num_hidden_layers: int = 48,
|
||||
intermediate_size: int = 18944,
|
||||
hidden_act: str = "silu",
|
||||
embedding_dropout: float = 0.0,
|
||||
attention_dropout: float = 0.0,
|
||||
residual_dropout: float = 0.0,
|
||||
max_position_embeddings: int = 4096,
|
||||
rope_theta: float = 1000000.0,
|
||||
rope_scaling: dict[str, Any] = None,
|
||||
rope_scaling_layers: list[int] | None = None,
|
||||
use_qk_norm: bool = False,
|
||||
qk_norm_type: str = "olmo",
|
||||
layer_norm_eps: int = 1e-6,
|
||||
norm_after: bool = False,
|
||||
initializer_range: float = 0.02,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
attn_implementation: str = "eager",
|
||||
**kwargs,
|
||||
):
|
||||
self.attn_implementation = attn_implementation
|
||||
super().__init__(
|
||||
tie_word_embeddings=tie_word_embeddings, attn_implementation=attn_implementation, **kwargs
|
||||
)
|
||||
self.hidden_size = hidden_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.vocab_size = vocab_size
|
||||
self.additional_vocab_size = additional_vocab_size
|
||||
self.qkv_bias = qkv_bias
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.embedding_dropout = embedding_dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.residual_dropout = residual_dropout
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.rope_scaling_layers = rope_scaling_layers
|
||||
self.use_qk_norm = use_qk_norm
|
||||
self.qk_norm_type = qk_norm_type
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.norm_after = norm_after
|
||||
self.initializer_range = initializer_range
|
||||
self.use_cache = use_cache
|
||||
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
rope_config_validation(self)
|
||||
|
||||
|
||||
class MolmoAct2ActionExpertConfig(PretrainedConfig):
|
||||
r"""Configuration for the MolmoAct2 modern action expert."""
|
||||
|
||||
model_type = "molmoact2_action_expert"
|
||||
base_config_key = "action_expert_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_action_horizon: int = 32,
|
||||
max_action_dim: int = 32,
|
||||
hidden_size: int = 1024,
|
||||
num_layers: int = 32,
|
||||
num_heads: int = 16,
|
||||
mlp_ratio: float = 8.0 / 3.0,
|
||||
ffn_multiple_of: int = 256,
|
||||
timestep_embed_dim: int = 256,
|
||||
dropout: float = 0.0,
|
||||
attn_dropout: float = 0.0,
|
||||
context_layer_norm: bool = True,
|
||||
qk_norm: bool = True,
|
||||
qk_norm_eps: float = 1e-6,
|
||||
rope: bool = True,
|
||||
causal_attn: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.max_action_horizon = max_action_horizon
|
||||
self.max_action_dim = max_action_dim
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.num_heads = num_heads
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.ffn_multiple_of = ffn_multiple_of
|
||||
self.timestep_embed_dim = timestep_embed_dim
|
||||
self.dropout = dropout
|
||||
self.attn_dropout = attn_dropout
|
||||
self.context_layer_norm = context_layer_norm
|
||||
self.qk_norm = qk_norm
|
||||
self.qk_norm_eps = qk_norm_eps
|
||||
self.rope = rope
|
||||
self.causal_attn = causal_attn
|
||||
|
||||
def to_dict(self):
|
||||
output = super().to_dict()
|
||||
# These are derived from the parent MolmoAct2Config for HF exports. Keeping
|
||||
# them out of the public nested config avoids duplicated sources of truth.
|
||||
output.pop("max_action_horizon", None)
|
||||
output.pop("max_action_dim", None)
|
||||
return output
|
||||
|
||||
|
||||
class MolmoAct2Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`MolmoAct2ForConditionalGeneration`].
|
||||
It is used to instantiate an MolmoAct2 model according to the specified arguments, defining the model architecture.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import MolmoAct2Config, MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2TextConfig
|
||||
|
||||
>>> # Initializing a MolmoAct2VitConfig
|
||||
>>> vit_config = MolmoAct2VitConfig()
|
||||
|
||||
>>> # Initializing a MolmoAct2AdapterConfig
|
||||
>>> adapter_config = MolmoAct2AdapterConfig()
|
||||
|
||||
>>> # Initializing a MolmoAct2TextConfig
|
||||
>>> text_config = MolmoAct2TextConfig()
|
||||
|
||||
>>> # Initializing a MolmoAct2Config
|
||||
>>> configuration = MolmoAct2Config(
|
||||
>>> vit_config=vit_config,
|
||||
>>> adapter_config=adapter_config,
|
||||
>>> text_config=text_config,
|
||||
>>> image_start_token_id=151936,
|
||||
>>> image_end_token_id=151937,
|
||||
>>> image_patch_id=151938,
|
||||
>>> image_col_id=151939,
|
||||
>>> low_res_image_start_token_id=151940,
|
||||
>>> image_low_res_id=151942,
|
||||
>>> frame_start_token_id=151943,
|
||||
>>> frame_end_token_id=151944,
|
||||
>>> )
|
||||
|
||||
>>> # Initializing a model
|
||||
>>> model = MolmoAct2ForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "molmoact2"
|
||||
sub_configs = {
|
||||
"text_config": MolmoAct2TextConfig,
|
||||
"vit_config": MolmoAct2VitConfig,
|
||||
"adapter_config": MolmoAct2AdapterConfig,
|
||||
"action_expert_config": MolmoAct2ActionExpertConfig,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vit_config: MolmoAct2VitConfig = None,
|
||||
adapter_config: MolmoAct2AdapterConfig = None,
|
||||
text_config: MolmoAct2TextConfig = None,
|
||||
action_expert_config: MolmoAct2ActionExpertConfig = None,
|
||||
image_start_token_id: int = None,
|
||||
low_res_image_start_token_id: int = None,
|
||||
image_end_token_id: int = None,
|
||||
image_low_res_id: int = None,
|
||||
image_patch_id: int = None,
|
||||
image_col_id: int = None,
|
||||
frame_start_token_id: int = None,
|
||||
frame_end_token_id: int = None,
|
||||
use_frame_special_tokens: bool = True,
|
||||
initializer_range: float = 0.02,
|
||||
add_action_expert: bool = True,
|
||||
max_action_dim: int = 32,
|
||||
max_action_horizon: int = 30,
|
||||
n_obs_steps: int = 30,
|
||||
action_mode: str = "both",
|
||||
state_format: str = "discrete",
|
||||
flow_matching_num_steps: int = 10,
|
||||
flow_matching_cutoff: float = 1.0,
|
||||
flow_matching_time_offset: float = 0.001,
|
||||
flow_matching_time_scale: float = 0.999,
|
||||
flow_matching_beta_alpha: float = 1.0,
|
||||
flow_matching_beta_beta: float = 1.5,
|
||||
mask_action_dim_padding: bool = True,
|
||||
enable_depth_reasoning: bool = False,
|
||||
depth_mode: int = 2,
|
||||
num_depth_codes: int = 100,
|
||||
action_expert_depth_gate: bool = False,
|
||||
action_expert_depth_gate_per_layer: bool = False,
|
||||
action_expert_depth_gate_init_bias: float = -4.0,
|
||||
action_output_token_id: int = None,
|
||||
action_start_token_id: int = None,
|
||||
action_end_token_id: int = None,
|
||||
action_token_start_id: int = None,
|
||||
num_action_tokens: int = 0,
|
||||
depth_output_token_id: int = None,
|
||||
depth_start_token_id: int = None,
|
||||
depth_end_token_id: int = None,
|
||||
depth_token_start_id: int = None,
|
||||
num_depth_tokens: int = 0,
|
||||
state_start_token_id: int = None,
|
||||
state_end_token_id: int = None,
|
||||
state_token_start_id: int = None,
|
||||
num_state_tokens: int = 0,
|
||||
add_setup_tokens: bool = True,
|
||||
add_control_tokens: bool = True,
|
||||
norm_stats_filename: str = "norm_stats.json",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
if vit_config is None:
|
||||
self.vit_config = MolmoAct2VitConfig()
|
||||
elif isinstance(vit_config, dict):
|
||||
self.vit_config = MolmoAct2VitConfig(**vit_config)
|
||||
else:
|
||||
self.vit_config = vit_config
|
||||
if adapter_config is None:
|
||||
self.adapter_config = MolmoAct2AdapterConfig()
|
||||
elif isinstance(adapter_config, dict):
|
||||
self.adapter_config = MolmoAct2AdapterConfig(**adapter_config)
|
||||
else:
|
||||
self.adapter_config = adapter_config
|
||||
if text_config is None:
|
||||
self.text_config = MolmoAct2TextConfig()
|
||||
elif isinstance(text_config, dict):
|
||||
self.text_config = MolmoAct2TextConfig(**text_config)
|
||||
else:
|
||||
self.text_config = text_config
|
||||
self.add_action_expert = bool(add_action_expert)
|
||||
if not self.add_action_expert:
|
||||
self.action_expert_config = None
|
||||
elif action_expert_config is None:
|
||||
self.action_expert_config = MolmoAct2ActionExpertConfig(
|
||||
max_action_horizon=max_action_horizon,
|
||||
max_action_dim=max_action_dim,
|
||||
num_layers=self.text_config.num_hidden_layers,
|
||||
)
|
||||
elif isinstance(action_expert_config, dict):
|
||||
self.action_expert_config = MolmoAct2ActionExpertConfig(**action_expert_config)
|
||||
else:
|
||||
self.action_expert_config = action_expert_config
|
||||
if self.add_action_expert:
|
||||
self.action_expert_config.max_action_dim = int(max_action_dim)
|
||||
self.action_expert_config.max_action_horizon = int(max_action_horizon)
|
||||
self._validate_release_action_config(
|
||||
state_format=state_format,
|
||||
)
|
||||
self.image_start_token_id = image_start_token_id
|
||||
self.low_res_image_start_token_id = low_res_image_start_token_id
|
||||
self.image_end_token_id = image_end_token_id
|
||||
self.image_low_res_id = image_low_res_id
|
||||
self.image_high_res_id = image_patch_id
|
||||
self.image_patch_id = image_patch_id
|
||||
self.image_col_id = image_col_id
|
||||
self.frame_start_token_id = frame_start_token_id
|
||||
self.frame_end_token_id = frame_end_token_id
|
||||
self.use_frame_special_tokens = use_frame_special_tokens
|
||||
self.initializer_range = initializer_range
|
||||
self.max_action_dim = max_action_dim
|
||||
self.max_action_horizon = max_action_horizon
|
||||
self.n_obs_steps = n_obs_steps
|
||||
self.action_mode = action_mode
|
||||
self.state_format = state_format
|
||||
self.flow_matching_num_steps = flow_matching_num_steps
|
||||
self.flow_matching_cutoff = flow_matching_cutoff
|
||||
self.flow_matching_time_offset = flow_matching_time_offset
|
||||
self.flow_matching_time_scale = flow_matching_time_scale
|
||||
self.flow_matching_beta_alpha = flow_matching_beta_alpha
|
||||
self.flow_matching_beta_beta = flow_matching_beta_beta
|
||||
self.mask_action_dim_padding = mask_action_dim_padding
|
||||
self.enable_depth_reasoning = enable_depth_reasoning
|
||||
self.depth_mode = depth_mode
|
||||
self.num_depth_codes = num_depth_codes
|
||||
self.action_expert_depth_gate = action_expert_depth_gate
|
||||
self.action_expert_depth_gate_per_layer = action_expert_depth_gate_per_layer
|
||||
self.action_expert_depth_gate_init_bias = action_expert_depth_gate_init_bias
|
||||
self.action_output_token_id = action_output_token_id
|
||||
self.action_start_token_id = action_start_token_id
|
||||
self.action_end_token_id = action_end_token_id
|
||||
self.action_token_start_id = action_token_start_id
|
||||
self.num_action_tokens = num_action_tokens
|
||||
self.depth_output_token_id = depth_output_token_id
|
||||
self.depth_start_token_id = depth_start_token_id
|
||||
self.depth_end_token_id = depth_end_token_id
|
||||
self.depth_token_start_id = depth_token_start_id
|
||||
self.num_depth_tokens = num_depth_tokens
|
||||
self.state_start_token_id = state_start_token_id
|
||||
self.state_end_token_id = state_end_token_id
|
||||
self.state_token_start_id = state_token_start_id
|
||||
self.num_state_tokens = num_state_tokens
|
||||
self.add_setup_tokens = add_setup_tokens
|
||||
self.add_control_tokens = add_control_tokens
|
||||
self.norm_stats_filename = norm_stats_filename
|
||||
|
||||
@staticmethod
|
||||
def _validate_release_action_config(
|
||||
*,
|
||||
state_format: str,
|
||||
) -> None:
|
||||
if state_format != "discrete":
|
||||
raise ValueError("MolmoAct2 HF export supports only state_format='discrete'.")
|
||||
|
||||
@property
|
||||
def image_num_patch(self):
|
||||
assert self.vit_config is not None
|
||||
return self.vit_config.image_num_patch
|
||||
|
||||
@property
|
||||
def num_attention_heads(self):
|
||||
return self.text_config.num_attention_heads
|
||||
|
||||
@property
|
||||
def num_key_value_heads(self):
|
||||
return self.text_config.num_key_value_heads
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.text_config.head_dim
|
||||
|
||||
@property
|
||||
def num_hidden_layers(self):
|
||||
return self.text_config.num_hidden_layers
|
||||
|
||||
@property
|
||||
def hidden_size(self):
|
||||
return self.text_config.hidden_size
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self.text_config.vocab_size
|
||||
|
||||
@property
|
||||
def max_position_embeddings(self):
|
||||
return self.text_config.max_position_embeddings
|
||||
|
||||
|
||||
MolmoAct2VitConfig.register_for_auto_class()
|
||||
MolmoAct2AdapterConfig.register_for_auto_class()
|
||||
MolmoAct2TextConfig.register_for_auto_class()
|
||||
MolmoAct2ActionExpertConfig.register_for_auto_class()
|
||||
MolmoAct2Config.register_for_auto_class()
|
||||
@@ -1,564 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
"""Image processor class for MolmoAct2"""
|
||||
|
||||
from typing import Optional, Union
|
||||
import numpy as np
|
||||
import einops
|
||||
import torch
|
||||
import torchvision.transforms
|
||||
|
||||
from transformers.image_utils import (
|
||||
IMAGENET_STANDARD_MEAN,
|
||||
IMAGENET_STANDARD_STD,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
make_flat_list_of_images,
|
||||
valid_images,
|
||||
to_numpy_array,
|
||||
)
|
||||
from transformers.image_transforms import convert_to_rgb
|
||||
from transformers.processing_utils import ImagesKwargs
|
||||
from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
|
||||
from transformers.utils import logging
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.utils import TensorType, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def normalize_image(
|
||||
image: np.ndarray,
|
||||
image_mean: list[float],
|
||||
image_std: list[float],
|
||||
) -> np.ndarray:
|
||||
if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
|
||||
return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
|
||||
image -= np.array(image_mean, dtype=np.float32)[None, None, :]
|
||||
image /= np.array(image_std, dtype=np.float32)[None, None, :]
|
||||
return image
|
||||
|
||||
|
||||
def resize_image(
|
||||
image: np.ndarray,
|
||||
desired_output_size: list[int],
|
||||
resample: PILImageResampling,
|
||||
) -> np.ndarray:
|
||||
image = torch.permute(torch.from_numpy(image), [2, 0, 1])
|
||||
dtype = image.dtype
|
||||
if torch.is_floating_point(image):
|
||||
in_min = 0.0
|
||||
in_max = 1.0
|
||||
resized = torchvision.transforms.Resize(
|
||||
desired_output_size,
|
||||
resample,
|
||||
antialias=False,
|
||||
)(image)
|
||||
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
|
||||
else:
|
||||
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(
|
||||
image.dtype
|
||||
)
|
||||
in_min = 0.0
|
||||
in_max = 255.0
|
||||
resized = torchvision.transforms.Resize(
|
||||
desired_output_size,
|
||||
resample,
|
||||
antialias=False,
|
||||
)(image)
|
||||
resized = torch.clip(resized, 0, 255).to(dtype)
|
||||
|
||||
resized = resized.to(torch.float32)
|
||||
resized = (resized - in_min) / (in_max - in_min)
|
||||
|
||||
resized = torch.permute(resized, [1, 2, 0]).numpy()
|
||||
|
||||
return resized
|
||||
|
||||
|
||||
def select_tiling(h, w, patch_size, max_num_crops):
|
||||
"""Divide in image of size [w, h] in up to max_num_patches of size patch_size"""
|
||||
original_size = np.stack([h, w]) # [1, 2]
|
||||
original_res = h * w
|
||||
tilings = []
|
||||
for i in range(1, max_num_crops + 1):
|
||||
for j in range(1, max_num_crops + 1):
|
||||
if i * j <= max_num_crops:
|
||||
tilings.append((i, j))
|
||||
# sort so argmin and argmax favour smaller tilings in the event of a tie
|
||||
tilings.sort(key=lambda x: (x[0] * x[1], x[0]))
|
||||
candidate_tilings = np.array(tilings, dtype=np.int32) # [n_resolutions, 2]
|
||||
candidate_resolutions = candidate_tilings * patch_size # [n_resolutions, 2]
|
||||
|
||||
# How much we would need to scale the image to fit exactly in each tiling
|
||||
original_size = np.stack([h, w], dtype=np.float32) # [1, 2]
|
||||
|
||||
# The original size can be zero in rare cases if the image is smaller than the margin
|
||||
# In those cases letting the scale become infinite means the tiling is based on the
|
||||
# other side, or falls back to the smallest tiling
|
||||
with np.errstate(divide="ignore"):
|
||||
required_scale_d = (candidate_resolutions.astype(np.float32) / original_size,)
|
||||
required_scale = np.min(required_scale_d, axis=-1, keepdims=True) # [n_resolutions, 1]
|
||||
if np.all(required_scale < 1):
|
||||
# We are forced to downscale, so try to minimize the amount of downscaling
|
||||
ix = np.argmax(required_scale)
|
||||
else:
|
||||
# Pick the resolution that required the least upscaling so that it most closely fits the image
|
||||
required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
|
||||
ix = np.argmin(required_scale)
|
||||
return candidate_tilings[ix]
|
||||
|
||||
|
||||
def build_resized_image(
|
||||
image: np.ndarray,
|
||||
base_image_input_size: list[int],
|
||||
resample: PILImageResampling,
|
||||
image_mean: list[float],
|
||||
image_std: list[float],
|
||||
image_patch_size: int,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
resized = resize_image(
|
||||
image,
|
||||
base_image_input_size,
|
||||
resample,
|
||||
)
|
||||
resized = normalize_image(resized, image_mean, image_std)
|
||||
if len(resized.shape) == 3:
|
||||
resized = np.expand_dims(resized, 0)
|
||||
crop_patch_w = base_image_input_size[1] // image_patch_size
|
||||
crop_patch_h = base_image_input_size[0] // image_patch_size
|
||||
resize_idx = np.arange(crop_patch_w * crop_patch_h).reshape([crop_patch_h, crop_patch_w])
|
||||
return resized, resize_idx
|
||||
|
||||
|
||||
def build_overlapping_crops(
|
||||
image: np.ndarray,
|
||||
max_crops: int,
|
||||
overlap_margins: list[int],
|
||||
base_image_input_size: list[int],
|
||||
resample: PILImageResampling,
|
||||
image_mean: list[float],
|
||||
image_std: list[float],
|
||||
image_patch_size: int,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Decompose an image into a set of overlapping crops
|
||||
|
||||
:return crop_arr: [n_crops, h, w, 3] The crops
|
||||
:return patch_idx: [overlap_patch_h, overlap_patch_w] For each patch in the resized image
|
||||
the crops were extracted from, what patch in `crop_arr` it corresponds to
|
||||
"""
|
||||
original_image_h, original_image_w = image.shape[:2]
|
||||
crop_size = base_image_input_size[0]
|
||||
assert base_image_input_size[0] == base_image_input_size[1]
|
||||
|
||||
left_margin, right_margin = overlap_margins
|
||||
total_margin_pixels = image_patch_size * (right_margin + left_margin) # pixels removed per dim
|
||||
crop_patches = base_image_input_size[0] // image_patch_size # patches per crop dim
|
||||
crop_window_patches = crop_patches - (right_margin + left_margin) # usable patches
|
||||
crop_window_size = crop_window_patches * image_patch_size
|
||||
crop_patch_w = base_image_input_size[1] // image_patch_size
|
||||
crop_patch_h = base_image_input_size[0] // image_patch_size
|
||||
original_image_h, original_image_w = image.shape[:2]
|
||||
crop_size = base_image_input_size[0]
|
||||
|
||||
# Decide how to tile the image, to account for the overlap margins we compute the tiling
|
||||
# as if we had an image without the margins and were using a crop size without the margins
|
||||
tiling = select_tiling(
|
||||
original_image_h - total_margin_pixels,
|
||||
original_image_w - total_margin_pixels,
|
||||
crop_window_size,
|
||||
max_crops,
|
||||
)
|
||||
|
||||
src = resize_image(
|
||||
image,
|
||||
[
|
||||
tiling[0] * crop_window_size + total_margin_pixels,
|
||||
tiling[1] * crop_window_size + total_margin_pixels,
|
||||
],
|
||||
resample,
|
||||
)
|
||||
src = normalize_image(src, image_mean, image_std)
|
||||
|
||||
# Now we have to split the image into crops, and track what patches came from
|
||||
# where in `patch_idx_arr`
|
||||
n_crops = tiling[0] * tiling[1]
|
||||
crop_arr = np.zeros([n_crops, crop_size, crop_size, 3], dtype=src.dtype)
|
||||
patch_idx_arr = np.zeros([n_crops, crop_patch_h, crop_patch_w], dtype=np.int32)
|
||||
on_crop = 0
|
||||
for i in range(tiling[0]):
|
||||
# Slide over `src` by `crop_window_size` steps, but extract crops of size `crops_size`
|
||||
# which results in overlapping crop windows
|
||||
y0 = i * crop_window_size
|
||||
for j in range(tiling[1]):
|
||||
x0 = j * crop_window_size
|
||||
crop_arr[on_crop] = src[y0 : y0 + crop_size, x0 : x0 + crop_size]
|
||||
patch_idx = np.arange(crop_patch_w * crop_patch_h).reshape(crop_patch_h, crop_patch_w)
|
||||
patch_idx += on_crop * crop_patch_h * crop_patch_w
|
||||
|
||||
# Mask out idx that are in the overlap region
|
||||
if i != 0:
|
||||
patch_idx[:left_margin, :] = -1
|
||||
if j != 0:
|
||||
patch_idx[:, :left_margin] = -1
|
||||
if i != tiling[0] - 1:
|
||||
patch_idx[-right_margin:, :] = -1
|
||||
if j != tiling[1] - 1:
|
||||
patch_idx[:, -right_margin:] = -1
|
||||
patch_idx_arr[on_crop] = patch_idx
|
||||
on_crop += 1
|
||||
|
||||
# `patch_idx_arr` is ordered crop-by-crop, here we transpose `patch_idx_arr`
|
||||
# so it is ordered left-to-right order
|
||||
patch_idx_arr = np.reshape(patch_idx_arr, [tiling[0], tiling[1], crop_patch_h, crop_patch_w])
|
||||
patch_idx_arr = np.transpose(patch_idx_arr, [0, 2, 1, 3])
|
||||
patch_idx_arr = np.reshape(patch_idx_arr, [-1])
|
||||
|
||||
# Now get the parts not in the overlap region, so it should map each patch in `src`
|
||||
# to the correct patch it should come from in `crop_arr`
|
||||
patch_idx_arr = patch_idx_arr[patch_idx_arr >= 0].reshape(
|
||||
src.shape[0] // image_patch_size,
|
||||
src.shape[1] // image_patch_size,
|
||||
)
|
||||
return crop_arr, patch_idx_arr
|
||||
|
||||
|
||||
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
|
||||
"""Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
|
||||
if len(array.shape) == 3:
|
||||
n_crops, h, w = array.shape
|
||||
h_patches = h // patch_size
|
||||
w_patches = w // patch_size
|
||||
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
|
||||
array = np.transpose(array, [0, 1, 3, 2, 4])
|
||||
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size])
|
||||
return array
|
||||
else:
|
||||
n_crops, h, w, c = array.shape
|
||||
h_patches = h // patch_size
|
||||
w_patches = w // patch_size
|
||||
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
|
||||
array = np.transpose(array, [0, 1, 3, 2, 4, 5])
|
||||
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size * c])
|
||||
return array
|
||||
|
||||
|
||||
def arange_for_pooling(
|
||||
idx_arr: np.ndarray,
|
||||
pool_h: int,
|
||||
pool_w: int,
|
||||
) -> np.ndarray:
|
||||
h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
|
||||
w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
|
||||
idx_arr = np.pad(
|
||||
idx_arr,
|
||||
[[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]],
|
||||
mode="constant",
|
||||
constant_values=-1,
|
||||
)
|
||||
return einops.rearrange(idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
|
||||
|
||||
|
||||
def image_to_patches_and_grids(
|
||||
image: np.ndarray,
|
||||
max_crops: int,
|
||||
overlap_margins: list[int],
|
||||
base_image_input_size: list[int],
|
||||
resample: PILImageResampling,
|
||||
image_mean: list[float],
|
||||
image_std: list[float],
|
||||
image_patch_size: int,
|
||||
image_pooling_w: int,
|
||||
image_pooling_h: int,
|
||||
crop_mode: str = "overlap-and-resize-c2",
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
:return image_grids, the shape of each (low-res, high-res) image after pooling
|
||||
:return crops, the image crops to processes with the ViT
|
||||
:return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
|
||||
patches in `crops` to pool for that token, masked with -1
|
||||
"""
|
||||
if isinstance(base_image_input_size, int):
|
||||
base_image_input_size = (base_image_input_size, base_image_input_size)
|
||||
|
||||
base_image_input_d = image_patch_size
|
||||
pooling_w = image_pooling_w
|
||||
pooling_h = image_pooling_h
|
||||
crop_patch_w = base_image_input_size[1] // base_image_input_d
|
||||
crop_patch_h = base_image_input_size[0] // base_image_input_d
|
||||
|
||||
if crop_mode == "resize":
|
||||
resized, resize_idx = build_resized_image(
|
||||
image,
|
||||
base_image_input_size,
|
||||
resample,
|
||||
image_mean,
|
||||
image_std,
|
||||
image_patch_size,
|
||||
)
|
||||
resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
|
||||
resized_h, resized_w = resize_idx.shape[:2]
|
||||
resize_idx = resize_idx.reshape([-1, pooling_h * pooling_w])
|
||||
image_grid = [np.array([resized_h, resized_w, 0, 0])]
|
||||
return (
|
||||
np.stack(image_grid, 0),
|
||||
batch_pixels_to_patches(resized, image_patch_size),
|
||||
resize_idx,
|
||||
)
|
||||
|
||||
if crop_mode not in {"overlap-and-resize-c2", "overlap-and-resize"}:
|
||||
raise ValueError(f"Unsupported MolmoAct2 image crop_mode {crop_mode!r}.")
|
||||
|
||||
crop_arr, patch_idx_arr = build_overlapping_crops(
|
||||
image,
|
||||
max_crops,
|
||||
overlap_margins,
|
||||
base_image_input_size,
|
||||
resample,
|
||||
image_mean,
|
||||
image_std,
|
||||
image_patch_size,
|
||||
)
|
||||
pooling_idx = arange_for_pooling(patch_idx_arr, pooling_h, pooling_w)
|
||||
h, w = pooling_idx.shape[:2]
|
||||
pooling_idx = pooling_idx.reshape([-1, pooling_h * pooling_w])
|
||||
|
||||
# Finally do the same for the global image
|
||||
resized, resize_idx = build_resized_image(
|
||||
image,
|
||||
base_image_input_size,
|
||||
resample,
|
||||
image_mean,
|
||||
image_std,
|
||||
image_patch_size,
|
||||
)
|
||||
crop_arr = np.concatenate([resized, crop_arr], 0)
|
||||
|
||||
resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
|
||||
resized_h, resized_w = resize_idx.shape[:2]
|
||||
resize_idx = resize_idx.reshape([-1, pooling_h * pooling_w])
|
||||
|
||||
# Global image goes first, so the order of patches in previous crops gets increased
|
||||
pooling_idx = np.where(pooling_idx >= 0, pooling_idx + crop_patch_h * crop_patch_w, -1)
|
||||
pooling_idx = np.concatenate([resize_idx, pooling_idx])
|
||||
image_grid = [np.array([resized_h, resized_w, h, w])]
|
||||
|
||||
return (np.stack(image_grid, 0), batch_pixels_to_patches(crop_arr, image_patch_size), pooling_idx)
|
||||
|
||||
|
||||
class MolmoAct2ImagesKwargs(ImagesKwargs, total=False):
|
||||
max_crops: int | None
|
||||
overlap_margins: list[int] | None
|
||||
crop_mode: str | None
|
||||
patch_size: int | None
|
||||
pooling_size: list[int] | None
|
||||
|
||||
|
||||
class MolmoAct2ImageProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
Constructs a MolmoAct2 image processor that preprocesses images for the model.
|
||||
|
||||
Args:
|
||||
size (`dict[str, int]` *optional*, defaults to `{"height": 378, "width": 378}`):
|
||||
Size of the image after resizing.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
||||
Resampling filter to use when resizing the image.
|
||||
image_mean (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
||||
Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
||||
image_std (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
||||
Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert the image to RGB.
|
||||
max_crops (`int`, *optional*, defaults to `8`):
|
||||
Maximum number of crops to use per image.
|
||||
overlap_margins (`list[int]`, *optional*, defaults to `[4, 4]`):
|
||||
Overlap margins to use.
|
||||
patch_size (`int`, *optional*, defaults to 14):
|
||||
The spatial patch size of the vision encoder.
|
||||
pooling_size (`list[int]`, *optional*, defaults to `[2, 2]`):
|
||||
The pooling size of the vision adapter.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values", "image_token_pooling", "image_grids", "image_num_crops"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: dict[str, int] | None = None,
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
image_mean: float | list[float] | None = None,
|
||||
image_std: float | list[float] | None = None,
|
||||
do_convert_rgb: bool = True,
|
||||
max_crops: int = 8,
|
||||
overlap_margins: list[int] = [4, 4],
|
||||
crop_mode: str = "overlap-and-resize-c2",
|
||||
patch_size: int = 14,
|
||||
pooling_size: list[int] = [2, 2],
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
size = size if size is not None else {"height": 378, "width": 378}
|
||||
size = get_size_dict(size, default_to_square=True)
|
||||
self.size = size
|
||||
|
||||
self.resample = resample
|
||||
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
||||
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
|
||||
self.max_crops = max_crops
|
||||
self.overlap_margins = overlap_margins
|
||||
self.crop_mode = crop_mode
|
||||
self.patch_size = patch_size
|
||||
self.pooling_size = pooling_size
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
size: dict[str, int] | None = None,
|
||||
resample: PILImageResampling | None = None,
|
||||
image_mean: float | list[float] | None = None,
|
||||
image_std: float | list[float] | None = None,
|
||||
do_convert_rgb: bool | None = None,
|
||||
max_crops: int | None = None,
|
||||
overlap_margins: list[int] | None = None,
|
||||
crop_mode: str | None = None,
|
||||
patch_size: int | None = None,
|
||||
pooling_size: list[int] | None = None,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image to preprocess.
|
||||
size (`dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Size of the image after resizing.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||
`True`.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
max_crops (`int`, *optional*, defaults to `self.max_crops`):
|
||||
Maximum number of crops to use per image.
|
||||
overlap_margins (`list[int]`, *optional*, defaults to `self.overlap_margins`):
|
||||
Overlap margins to use.
|
||||
patch_size (`int`, *optional*, defaults to `self.patch_size`):
|
||||
The spatial patch size of the vision encoder.
|
||||
pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
|
||||
The pooling size of the vision adapter.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
|
||||
Returns:
|
||||
A `BatchFeature` containing the following keys:
|
||||
- `pixel_values`: The preprocessed images.
|
||||
- `image_token_pooling`: The indices of the patches in `crops` to pool for each token in `image_tokens`.
|
||||
- `image_grids`: The image grids.
|
||||
- `image_num_crops`: The number of crops for each image.
|
||||
"""
|
||||
if size is not None:
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError("size must contain 'height' and 'width' keys.")
|
||||
else:
|
||||
size = {**self.size}
|
||||
|
||||
base_image_input_size = [size["height"], size["width"]]
|
||||
|
||||
resample = resample or self.resample
|
||||
image_mean = image_mean or self.image_mean
|
||||
image_std = image_std or self.image_std
|
||||
do_convert_rgb = do_convert_rgb or self.do_convert_rgb
|
||||
|
||||
max_crops = max_crops or self.max_crops
|
||||
overlap_margins = overlap_margins or self.overlap_margins
|
||||
crop_mode = crop_mode or self.crop_mode
|
||||
patch_size = patch_size or self.patch_size
|
||||
pooling_size = pooling_size or self.pooling_size
|
||||
|
||||
image_pooling_h, image_pooling_w = pooling_size
|
||||
|
||||
if images is not None:
|
||||
images = self.fetch_images(images)
|
||||
images = make_flat_list_of_images(images)
|
||||
|
||||
if images is not None and not valid_images(images):
|
||||
raise ValueError(
|
||||
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
|
||||
if do_convert_rgb:
|
||||
images = [convert_to_rgb(image) for image in images]
|
||||
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
data = {}
|
||||
if images is not None:
|
||||
batch_grids = []
|
||||
batch_crops = []
|
||||
batch_pooled_patches_idx = []
|
||||
batch_num_crops = []
|
||||
|
||||
for image in images:
|
||||
image_grid, crops, pooled_idx = image_to_patches_and_grids(
|
||||
image,
|
||||
max_crops,
|
||||
overlap_margins,
|
||||
base_image_input_size,
|
||||
resample,
|
||||
image_mean,
|
||||
image_std,
|
||||
patch_size,
|
||||
image_pooling_w,
|
||||
image_pooling_h,
|
||||
crop_mode,
|
||||
)
|
||||
batch_grids.append(image_grid)
|
||||
batch_crops.append(crops)
|
||||
batch_pooled_patches_idx.append(pooled_idx)
|
||||
batch_num_crops.append(crops.shape[0])
|
||||
|
||||
pixel_values = np.concatenate(batch_crops, 0)
|
||||
image_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
|
||||
image_grids = np.concatenate(batch_grids, 0)
|
||||
image_num_crops = np.array(batch_num_crops)
|
||||
|
||||
data.update(
|
||||
pixel_values=pixel_values,
|
||||
image_token_pooling=image_token_pooling,
|
||||
image_grids=image_grids,
|
||||
image_num_crops=image_num_crops,
|
||||
)
|
||||
|
||||
return BatchFeature(data, tensor_type=return_tensors)
|
||||
|
||||
|
||||
MolmoAct2ImageProcessor.register_for_auto_class()
|
||||
@@ -1,748 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
"""Inference utilities for MolmoAct2"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ActionFlowInputs:
|
||||
trajectory: torch.Tensor
|
||||
context: Any
|
||||
modulations: Sequence[Any]
|
||||
action_dim_is_pad: torch.Tensor | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ActionFlowCudaGraph:
|
||||
key: tuple[Any, ...]
|
||||
graph: torch.cuda.CUDAGraph
|
||||
static_inputs: _ActionFlowInputs
|
||||
output: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DepthDecodeCudaGraphLayerStage:
|
||||
residual: torch.Tensor
|
||||
query: torch.Tensor
|
||||
key: torch.Tensor
|
||||
value: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DepthDecodeCudaGraphPostStage:
|
||||
graph: torch.cuda.CUDAGraph
|
||||
attn_context: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DepthDecodeCudaGraph:
|
||||
cache_key: tuple[Any, ...]
|
||||
pre_graph: torch.cuda.CUDAGraph
|
||||
token_ids: torch.Tensor
|
||||
cos: torch.Tensor
|
||||
sin: torch.Tensor
|
||||
positions: torch.Tensor
|
||||
stages: Sequence[_DepthDecodeCudaGraphLayerStage]
|
||||
post_graphs: Sequence[_DepthDecodeCudaGraphPostStage]
|
||||
output: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DepthDecodeCudaGraphSpec:
|
||||
eligible: bool
|
||||
cache_key_prefix: tuple[Any, ...]
|
||||
num_hidden_layers: int
|
||||
head_dim: int
|
||||
num_attention_heads: int
|
||||
|
||||
|
||||
def _cache_seq_len_int(past_key_values: Cache | None) -> int:
|
||||
if past_key_values is None:
|
||||
return 0
|
||||
seq_len = past_key_values.get_seq_length()
|
||||
if torch.is_tensor(seq_len):
|
||||
return int(seq_len.item())
|
||||
return int(seq_len)
|
||||
|
||||
|
||||
def _cache_max_len_int(past_key_values: Cache | None) -> int:
|
||||
if past_key_values is None:
|
||||
return -1
|
||||
max_len = past_key_values.get_max_cache_shape()
|
||||
if torch.is_tensor(max_len):
|
||||
return int(max_len.item())
|
||||
return int(max_len)
|
||||
|
||||
|
||||
def _iter_cache_key_values(
|
||||
past_key_values: Cache,
|
||||
) -> Iterable[tuple[torch.Tensor | None, torch.Tensor | None]]:
|
||||
layers = getattr(past_key_values, "layers", None)
|
||||
if layers is not None:
|
||||
for layer in layers:
|
||||
yield getattr(layer, "keys", None), getattr(layer, "values", None)
|
||||
return
|
||||
for layer in past_key_values:
|
||||
yield layer[0], layer[1]
|
||||
|
||||
|
||||
class _DepthDecodeStaticLayerCache:
|
||||
is_compileable = False
|
||||
is_sliding = False
|
||||
|
||||
def __init__(self, max_cache_len: int) -> None:
|
||||
self.max_cache_len = int(max_cache_len)
|
||||
self.cumulative_length = 0
|
||||
self.keys: torch.Tensor | None = None
|
||||
self.values: torch.Tensor | None = None
|
||||
|
||||
def _allocate(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
|
||||
bsz, n_heads = key_states.shape[:2]
|
||||
self.keys = torch.empty(
|
||||
(bsz, n_heads, self.max_cache_len, key_states.shape[-1]),
|
||||
dtype=key_states.dtype,
|
||||
device=key_states.device,
|
||||
)
|
||||
self.values = torch.empty(
|
||||
(bsz, n_heads, self.max_cache_len, value_states.shape[-1]),
|
||||
dtype=value_states.dtype,
|
||||
device=value_states.device,
|
||||
)
|
||||
|
||||
def update(
|
||||
self,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.keys is None:
|
||||
self._allocate(key_states, value_states)
|
||||
start = self.cumulative_length
|
||||
end = start + key_states.shape[-2]
|
||||
if end > self.max_cache_len:
|
||||
raise RuntimeError(f"KV cache length {end} exceeds max_cache_len={self.max_cache_len}.")
|
||||
self.keys[:, :, start:end, :].copy_(key_states)
|
||||
self.values[:, :, start:end, :].copy_(value_states)
|
||||
self.cumulative_length = end
|
||||
return self.keys[:, :, :end, :], self.values[:, :, :end, :]
|
||||
|
||||
def get_seq_length(self) -> int:
|
||||
return self.cumulative_length
|
||||
|
||||
def get_max_cache_shape(self) -> int:
|
||||
return -1
|
||||
|
||||
def reset(self) -> None:
|
||||
self.cumulative_length = 0
|
||||
|
||||
|
||||
class _DepthDecodeStaticCache(Cache):
|
||||
def __init__(self, config: PretrainedConfig, max_cache_len: int) -> None:
|
||||
text_config = config.get_text_config(decoder=True)
|
||||
super().__init__(
|
||||
layers=[
|
||||
_DepthDecodeStaticLayerCache(max_cache_len=max_cache_len)
|
||||
for _ in range(text_config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def get_seq_length(self, layer_idx: int = 0) -> int:
|
||||
return self.layers[layer_idx].get_seq_length()
|
||||
|
||||
def get_max_cache_shape(self, layer_idx: int = 0) -> int:
|
||||
return self.layers[layer_idx].get_max_cache_shape()
|
||||
|
||||
def reset(self) -> None:
|
||||
for layer in self.layers:
|
||||
layer.reset()
|
||||
|
||||
|
||||
class ActionCudaGraphManager:
|
||||
def __init__(self, model: Any) -> None:
|
||||
self.model = model
|
||||
self.enabled = True
|
||||
self.action_flow_graph: _ActionFlowCudaGraph | None = None
|
||||
|
||||
def set_enabled(self, enabled: bool) -> None:
|
||||
self.enabled = bool(enabled)
|
||||
|
||||
def can_use_action_flow(self, inputs: _ActionFlowInputs) -> bool:
|
||||
action_model = self.model
|
||||
if not self.enabled:
|
||||
return False
|
||||
if action_model.training or action_model._require_action_expert().training:
|
||||
return False
|
||||
if inputs.trajectory.device.type != "cuda":
|
||||
return False
|
||||
|
||||
def all_on_cuda():
|
||||
yield inputs.trajectory
|
||||
for k, v in inputs.context.kv_contexts:
|
||||
yield k
|
||||
yield v
|
||||
for t in (
|
||||
inputs.context.cross_mask,
|
||||
inputs.context.self_mask,
|
||||
inputs.context.valid_action,
|
||||
inputs.action_dim_is_pad,
|
||||
):
|
||||
if t is not None:
|
||||
yield t
|
||||
if inputs.context.rope_cache is not None:
|
||||
yield from inputs.context.rope_cache
|
||||
for step in inputs.modulations:
|
||||
yield step.conditioning
|
||||
for block_modulation in step.block_modulations:
|
||||
yield from block_modulation
|
||||
yield from step.final_modulation
|
||||
|
||||
return all(t.device.type == "cuda" for t in all_on_cuda())
|
||||
|
||||
def run_action_flow(
|
||||
self,
|
||||
inputs: _ActionFlowInputs,
|
||||
steps: int,
|
||||
run_loop,
|
||||
) -> torch.Tensor:
|
||||
key = _cuda_graph_key(inputs, steps)
|
||||
cache = self.action_flow_graph
|
||||
if cache is None or cache.key != key:
|
||||
static_inputs = _clone_static_inputs(inputs)
|
||||
graph, output = _capture_cuda_graph(
|
||||
lambda: run_loop(static_inputs, steps),
|
||||
inputs.trajectory.device,
|
||||
after_warmup=lambda: static_inputs.trajectory.copy_(inputs.trajectory),
|
||||
)
|
||||
cache = _ActionFlowCudaGraph(
|
||||
key=key,
|
||||
graph=graph,
|
||||
static_inputs=static_inputs,
|
||||
output=output,
|
||||
)
|
||||
self.action_flow_graph = cache
|
||||
else:
|
||||
_copy_inputs_(cache.static_inputs, inputs)
|
||||
|
||||
cache.graph.replay()
|
||||
return cache.output.clone()
|
||||
|
||||
|
||||
class DepthDecodeCudaGraphManager:
|
||||
def __init__(self, model: Any) -> None:
|
||||
self.model = model
|
||||
self.backbone = model.model
|
||||
self.enabled = True
|
||||
self.graph: _DepthDecodeCudaGraph | None = None
|
||||
self.graph_spec: _DepthDecodeCudaGraphSpec | None = None
|
||||
|
||||
def set_enabled(self, enabled: bool) -> None:
|
||||
self.enabled = bool(enabled)
|
||||
|
||||
def make_static_cache(self, max_cache_len: int) -> _DepthDecodeStaticCache:
|
||||
return _DepthDecodeStaticCache(
|
||||
config=self.model.config.text_config,
|
||||
max_cache_len=max_cache_len,
|
||||
)
|
||||
|
||||
def _depth_decode_spec(self) -> _DepthDecodeCudaGraphSpec:
|
||||
static = self.graph_spec
|
||||
if static is None:
|
||||
cfg = self.backbone.transformer.config
|
||||
rotary_emb = getattr(self.backbone.transformer, "rotary_emb", None)
|
||||
static = _DepthDecodeCudaGraphSpec(
|
||||
eligible=(
|
||||
not cfg.norm_after
|
||||
and cfg.rope_scaling_layers is None
|
||||
and getattr(rotary_emb, "rope_type", None) == "default"
|
||||
and cfg._attn_implementation == "sdpa"
|
||||
),
|
||||
cache_key_prefix=(
|
||||
cfg.hidden_size,
|
||||
cfg.num_attention_heads,
|
||||
cfg.num_key_value_heads,
|
||||
cfg.head_dim,
|
||||
cfg.num_hidden_layers,
|
||||
cfg.use_qk_norm,
|
||||
cfg.qk_norm_type,
|
||||
cfg._attn_implementation,
|
||||
),
|
||||
num_hidden_layers=cfg.num_hidden_layers,
|
||||
head_dim=cfg.head_dim,
|
||||
num_attention_heads=cfg.num_attention_heads,
|
||||
)
|
||||
self.graph_spec = static
|
||||
return static
|
||||
|
||||
def can_use(
|
||||
self,
|
||||
next_input_ids: torch.Tensor,
|
||||
*,
|
||||
past_key_values: Cache,
|
||||
attention_bias: torch.Tensor,
|
||||
) -> bool:
|
||||
if not self.enabled or self.model.training or self.backbone.transformer.training:
|
||||
return False
|
||||
if next_input_ids.device.type != "cuda":
|
||||
return False
|
||||
if next_input_ids.ndim != 2 or next_input_ids.shape[0] != 1 or next_input_ids.shape[1] != 1:
|
||||
return False
|
||||
if not isinstance(past_key_values, _DepthDecodeStaticCache):
|
||||
return False
|
||||
if not torch.is_tensor(attention_bias) or attention_bias.device != next_input_ids.device:
|
||||
return False
|
||||
return self._depth_decode_spec().eligible
|
||||
|
||||
def _depth_decode_key(
|
||||
self,
|
||||
next_input_ids: torch.Tensor,
|
||||
attention_bias: torch.Tensor,
|
||||
) -> tuple[Any, ...]:
|
||||
device = next_input_ids.device
|
||||
return (
|
||||
self._depth_decode_spec().cache_key_prefix,
|
||||
device.type,
|
||||
device.index,
|
||||
self.model.lm_head.weight.dtype,
|
||||
attention_bias.shape[-1],
|
||||
)
|
||||
|
||||
def _select_depth_decode_rope(self, cos: torch.Tensor, sin: torch.Tensor, *, past_length: int) -> None:
|
||||
emb = self.backbone.transformer.rotary_emb
|
||||
cos.copy_(emb._pos_cos_cache[0, :, past_length : past_length + 1, :])
|
||||
sin.copy_(emb._pos_sin_cache[0, :, past_length : past_length + 1, :])
|
||||
|
||||
def _depth_decode_pre_layer(
|
||||
self,
|
||||
layer_idx: int,
|
||||
hidden_states: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
block = self.backbone.transformer.blocks[layer_idx]
|
||||
attention = block.self_attn
|
||||
residual = hidden_states
|
||||
hidden_states = block.attn_norm(hidden_states)
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, attention.head_dim)
|
||||
qkv = attention.att_proj(hidden_states)
|
||||
query_states, key_states, value_states = qkv.split(attention.fused_dims, dim=-1)
|
||||
value_states = value_states.view(hidden_shape)
|
||||
|
||||
apply_qk_norm = attention.q_norm is not None and attention.k_norm is not None
|
||||
norm_after_view = apply_qk_norm and attention.qk_norm_type == "qwen3"
|
||||
|
||||
if apply_qk_norm and not norm_after_view:
|
||||
query_states = attention.q_norm(query_states)
|
||||
key_states = attention.k_norm(key_states)
|
||||
|
||||
query_states = query_states.view(hidden_shape)
|
||||
key_states = key_states.view(hidden_shape)
|
||||
|
||||
if norm_after_view:
|
||||
query_states = attention.q_norm(query_states)
|
||||
key_states = attention.k_norm(key_states)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
query_states, key_states = _apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
return residual, query_states, key_states, value_states
|
||||
|
||||
def _depth_decode_pre0(
|
||||
self,
|
||||
token_ids: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
inputs_embeds = self.model._embed_base_tokens(token_ids)
|
||||
return self._depth_decode_pre_layer(0, inputs_embeds, cos, sin)
|
||||
|
||||
def _depth_decode_post_layer(
|
||||
self,
|
||||
layer_idx: int,
|
||||
residual: torch.Tensor,
|
||||
attn_context: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
block = self.backbone.transformer.blocks[layer_idx]
|
||||
attention = block.self_attn
|
||||
input_shape = residual.shape[:-1]
|
||||
attn_output = attn_context.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = attention.attn_out(attn_output)
|
||||
hidden_states = residual + block.dropout(attn_output)
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = block.ff_norm(hidden_states)
|
||||
hidden_states = block.mlp(hidden_states)
|
||||
hidden_states = residual + block.dropout(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def _depth_decode_post_and_pre_next(
|
||||
self,
|
||||
layer_idx: int,
|
||||
residual: torch.Tensor,
|
||||
attn_context: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
|
||||
return self._depth_decode_pre_layer(layer_idx + 1, hidden_states, cos, sin)
|
||||
|
||||
def _depth_decode_last_post(
|
||||
self,
|
||||
layer_idx: int,
|
||||
residual: torch.Tensor,
|
||||
attn_context: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
|
||||
return self.backbone.transformer.ln_f(hidden_states)
|
||||
|
||||
def _build_depth_decode_graph(
|
||||
self,
|
||||
next_input_ids: torch.Tensor,
|
||||
*,
|
||||
past_length: int,
|
||||
attention_bias: torch.Tensor,
|
||||
) -> _DepthDecodeCudaGraph:
|
||||
text_config = self.backbone.transformer.config
|
||||
device = next_input_ids.device
|
||||
dtype = self.model.lm_head.weight.dtype
|
||||
static = self._depth_decode_spec()
|
||||
num_layers = static.num_hidden_layers
|
||||
head_dim = static.head_dim
|
||||
max_cache_len = int(attention_bias.shape[-1])
|
||||
max_rope_len = max(int(text_config.max_position_embeddings or 0), max_cache_len)
|
||||
self.backbone.transformer.prepare_rope_cache(device=device, max_seq_len=max_rope_len)
|
||||
|
||||
token_ids = torch.empty((1, 1), device=device, dtype=torch.long)
|
||||
cos = torch.empty((1, 1, head_dim), device=device, dtype=dtype)
|
||||
sin = torch.empty_like(cos)
|
||||
positions = torch.arange(max_cache_len, device=device, dtype=torch.long)
|
||||
context_shape = (1, 1, static.num_attention_heads, head_dim)
|
||||
|
||||
token_ids.copy_(next_input_ids)
|
||||
self._select_depth_decode_rope(cos, sin, past_length=past_length)
|
||||
|
||||
pre_graph, pre_output = _capture_cuda_graph(
|
||||
lambda: self._depth_decode_pre0(token_ids, cos, sin),
|
||||
device,
|
||||
)
|
||||
stages = [_DepthDecodeCudaGraphLayerStage(*pre_output)]
|
||||
post_graphs = []
|
||||
for layer_idx in range(num_layers - 1):
|
||||
stage = stages[-1]
|
||||
attn_context = torch.empty(context_shape, device=device, dtype=dtype)
|
||||
graph, output = _capture_cuda_graph(
|
||||
lambda layer_idx=layer_idx, stage=stage, attn_context=attn_context: (
|
||||
self._depth_decode_post_and_pre_next(
|
||||
layer_idx,
|
||||
stage.residual,
|
||||
attn_context,
|
||||
cos,
|
||||
sin,
|
||||
)
|
||||
),
|
||||
device,
|
||||
)
|
||||
post_graphs.append(_DepthDecodeCudaGraphPostStage(graph=graph, attn_context=attn_context))
|
||||
stages.append(_DepthDecodeCudaGraphLayerStage(*output))
|
||||
|
||||
last_stage = stages[-1]
|
||||
last_attn_context = torch.empty(context_shape, device=device, dtype=dtype)
|
||||
last_graph, last_output = _capture_cuda_graph(
|
||||
lambda: self._depth_decode_last_post(
|
||||
num_layers - 1,
|
||||
last_stage.residual,
|
||||
last_attn_context,
|
||||
),
|
||||
device,
|
||||
)
|
||||
post_graphs.append(_DepthDecodeCudaGraphPostStage(graph=last_graph, attn_context=last_attn_context))
|
||||
return _DepthDecodeCudaGraph(
|
||||
cache_key=self._depth_decode_key(next_input_ids, attention_bias),
|
||||
pre_graph=pre_graph,
|
||||
token_ids=token_ids,
|
||||
cos=cos,
|
||||
sin=sin,
|
||||
positions=positions,
|
||||
stages=tuple(stages),
|
||||
post_graphs=tuple(post_graphs),
|
||||
output=last_output,
|
||||
)
|
||||
|
||||
def _get_depth_decode_graph(
|
||||
self,
|
||||
next_input_ids: torch.Tensor,
|
||||
*,
|
||||
past_length: int,
|
||||
attention_bias: torch.Tensor,
|
||||
) -> _DepthDecodeCudaGraph:
|
||||
key = self._depth_decode_key(next_input_ids, attention_bias)
|
||||
decode_graph = self.graph
|
||||
if decode_graph is None or decode_graph.cache_key != key:
|
||||
decode_graph = self._build_depth_decode_graph(
|
||||
next_input_ids,
|
||||
past_length=past_length,
|
||||
attention_bias=attention_bias,
|
||||
)
|
||||
self.graph = decode_graph
|
||||
else:
|
||||
decode_graph.token_ids.copy_(next_input_ids)
|
||||
self._select_depth_decode_rope(decode_graph.cos, decode_graph.sin, past_length=past_length)
|
||||
return decode_graph
|
||||
|
||||
def _run_depth_decode_attention_core(
|
||||
self,
|
||||
layer_idx: int,
|
||||
stage: _DepthDecodeCudaGraphLayerStage,
|
||||
*,
|
||||
past_key_values: Cache,
|
||||
attention_bias: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
attention = self.backbone.transformer.blocks[layer_idx].self_attn
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_values.update(
|
||||
stage.key,
|
||||
stage.value,
|
||||
layer_idx,
|
||||
cache_kwargs,
|
||||
)
|
||||
key_states = _repeat_kv(key_states, attention.num_key_value_groups)
|
||||
value_states = _repeat_kv(value_states, attention.num_key_value_groups)
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
stage.query,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_bias,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
)
|
||||
return attn_output.transpose(1, 2)
|
||||
|
||||
def run(
|
||||
self,
|
||||
next_input_ids: torch.Tensor,
|
||||
*,
|
||||
past_key_values: Cache,
|
||||
attention_bias: torch.Tensor,
|
||||
past_length: int,
|
||||
) -> tuple[torch.Tensor, Cache]:
|
||||
end = past_length + 1
|
||||
decode_graph = self._get_depth_decode_graph(
|
||||
next_input_ids,
|
||||
past_length=past_length,
|
||||
attention_bias=attention_bias,
|
||||
)
|
||||
cache_position = decode_graph.positions[past_length:end]
|
||||
attention_bias_q = attention_bias[:, :, past_length:end, :end]
|
||||
|
||||
decode_graph.pre_graph.replay()
|
||||
|
||||
for layer_idx, post_graph in enumerate(decode_graph.post_graphs):
|
||||
attn_context = self._run_depth_decode_attention_core(
|
||||
layer_idx,
|
||||
decode_graph.stages[layer_idx],
|
||||
past_key_values=past_key_values,
|
||||
attention_bias=attention_bias_q,
|
||||
cache_position=cache_position,
|
||||
cos=decode_graph.cos,
|
||||
sin=decode_graph.sin,
|
||||
)
|
||||
post_graph.attn_context.copy_(attn_context)
|
||||
post_graph.graph.replay()
|
||||
|
||||
return decode_graph.output, past_key_values
|
||||
|
||||
|
||||
def _cuda_graph_tensor_signature(
|
||||
tensor: torch.Tensor | None,
|
||||
) -> tuple[Any, ...] | None:
|
||||
if tensor is None:
|
||||
return None
|
||||
return (
|
||||
tuple(tensor.shape),
|
||||
tuple(tensor.stride()),
|
||||
str(tensor.dtype),
|
||||
str(tensor.device),
|
||||
)
|
||||
|
||||
|
||||
def _cuda_graph_context_signature(context: Any) -> tuple[Any, ...]:
|
||||
sig = _cuda_graph_tensor_signature
|
||||
return (
|
||||
tuple((sig(k), sig(v)) for k, v in context.kv_contexts),
|
||||
sig(context.cross_mask),
|
||||
sig(context.self_mask),
|
||||
sig(context.valid_action),
|
||||
None if context.rope_cache is None else tuple(sig(t) for t in context.rope_cache),
|
||||
)
|
||||
|
||||
|
||||
def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> tuple[Any, ...]:
|
||||
sig = _cuda_graph_tensor_signature
|
||||
return tuple(
|
||||
(
|
||||
sig(step.conditioning),
|
||||
tuple(tuple(sig(t) for t in block_modulation) for block_modulation in step.block_modulations),
|
||||
tuple(sig(t) for t in step.final_modulation),
|
||||
)
|
||||
for step in modulations
|
||||
)
|
||||
|
||||
|
||||
def _cuda_graph_key(inputs: _ActionFlowInputs, steps: int) -> tuple[Any, ...]:
|
||||
sig = _cuda_graph_tensor_signature
|
||||
return (
|
||||
sig(inputs.trajectory),
|
||||
_cuda_graph_context_signature(inputs.context),
|
||||
_cuda_graph_modulation_signature(inputs.modulations),
|
||||
sig(inputs.action_dim_is_pad),
|
||||
int(steps),
|
||||
)
|
||||
|
||||
|
||||
def _clone_static_tensor(tensor: torch.Tensor | None) -> torch.Tensor | None:
|
||||
if tensor is None:
|
||||
return None
|
||||
static = torch.empty_strided(
|
||||
tuple(tensor.shape),
|
||||
tuple(tensor.stride()),
|
||||
device=tensor.device,
|
||||
dtype=tensor.dtype,
|
||||
)
|
||||
static.copy_(tensor)
|
||||
return static
|
||||
|
||||
|
||||
def _clone_static_context(context: Any) -> Any:
|
||||
rope_cache = None
|
||||
if context.rope_cache is not None:
|
||||
rope_cache = tuple(_clone_static_tensor(t) for t in context.rope_cache)
|
||||
return context.__class__(
|
||||
kv_contexts=tuple((_clone_static_tensor(k), _clone_static_tensor(v)) for k, v in context.kv_contexts),
|
||||
cross_mask=_clone_static_tensor(context.cross_mask),
|
||||
self_mask=_clone_static_tensor(context.self_mask),
|
||||
valid_action=_clone_static_tensor(context.valid_action),
|
||||
rope_cache=rope_cache,
|
||||
)
|
||||
|
||||
|
||||
def _clone_static_modulations(modulations: Sequence[Any]) -> Sequence[Any]:
|
||||
return tuple(
|
||||
step.__class__(
|
||||
conditioning=_clone_static_tensor(step.conditioning),
|
||||
block_modulations=tuple(
|
||||
tuple(_clone_static_tensor(t) for t in block_modulation)
|
||||
for block_modulation in step.block_modulations
|
||||
),
|
||||
final_modulation=tuple(_clone_static_tensor(t) for t in step.final_modulation),
|
||||
)
|
||||
for step in modulations
|
||||
)
|
||||
|
||||
|
||||
def _clone_static_inputs(inputs: _ActionFlowInputs) -> _ActionFlowInputs:
|
||||
return _ActionFlowInputs(
|
||||
trajectory=_clone_static_tensor(inputs.trajectory),
|
||||
context=_clone_static_context(inputs.context),
|
||||
modulations=_clone_static_modulations(inputs.modulations),
|
||||
action_dim_is_pad=_clone_static_tensor(inputs.action_dim_is_pad),
|
||||
)
|
||||
|
||||
|
||||
def _copy_context_(dst: Any, src: Any) -> None:
|
||||
for (dst_k, dst_v), (src_k, src_v) in zip(dst.kv_contexts, src.kv_contexts):
|
||||
dst_k.copy_(src_k)
|
||||
dst_v.copy_(src_v)
|
||||
if src.cross_mask is not None:
|
||||
dst.cross_mask.copy_(src.cross_mask)
|
||||
if src.self_mask is not None:
|
||||
dst.self_mask.copy_(src.self_mask)
|
||||
if src.valid_action is not None:
|
||||
dst.valid_action.copy_(src.valid_action)
|
||||
if src.rope_cache is not None:
|
||||
for dst_tensor, src_tensor in zip(dst.rope_cache, src.rope_cache):
|
||||
dst_tensor.copy_(src_tensor)
|
||||
|
||||
|
||||
def _copy_inputs_(dst: _ActionFlowInputs, src: _ActionFlowInputs) -> None:
|
||||
dst.trajectory.copy_(src.trajectory)
|
||||
_copy_context_(dst.context, src.context)
|
||||
if src.action_dim_is_pad is not None:
|
||||
dst.action_dim_is_pad.copy_(src.action_dim_is_pad)
|
||||
|
||||
|
||||
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def _apply_rotary_pos_emb(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
unsqueeze_dim: int = 1,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (_rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (_rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def _capture_cuda_graph(
|
||||
fn,
|
||||
device: torch.device,
|
||||
*,
|
||||
after_warmup=None,
|
||||
) -> tuple[torch.cuda.CUDAGraph, Any]:
|
||||
warmup_stream = torch.cuda.Stream(device=device)
|
||||
warmup_stream.wait_stream(torch.cuda.current_stream(device))
|
||||
with torch.cuda.stream(warmup_stream):
|
||||
fn()
|
||||
torch.cuda.current_stream(device).wait_stream(warmup_stream)
|
||||
if after_warmup is not None:
|
||||
after_warmup()
|
||||
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
output = fn()
|
||||
return graph, output
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,431 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
"""
|
||||
Processor class for MolmoAct2.
|
||||
"""
|
||||
|
||||
from typing import Optional, Union
|
||||
import dataclasses
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.video_utils import VideoInput
|
||||
from transformers.processing_utils import (
|
||||
Unpack,
|
||||
ProcessingKwargs,
|
||||
ProcessorMixin,
|
||||
)
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
|
||||
from transformers.utils import logging
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from .image_processing_molmoact2 import MolmoAct2ImagesKwargs, MolmoAct2ImageProcessor
|
||||
from .video_processing_molmoact2 import MolmoAct2VideoProcessorKwargs, MolmoAct2VideoProcessor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Special tokens, these should be present in any tokenizer we use since the preprocessor uses them
|
||||
IMAGE_PATCH_TOKEN = f"<im_patch>" # Where to insert high-res tokens
|
||||
IMAGE_LOW_RES_TOKEN = f"<im_low>" # Where to insert low-res tokens
|
||||
IM_START_TOKEN = f"<im_start>"
|
||||
LOW_RES_IMAGE_START_TOKEN = f"<low_res_im_start>"
|
||||
FRAME_START_TOKEN = f"<frame_start>"
|
||||
IM_END_TOKEN = f"<im_end>"
|
||||
FRAME_END_TOKEN = f"<frame_end>"
|
||||
IM_COL_TOKEN = f"<im_col>"
|
||||
IMAGE_PROMPT = "<|image|>"
|
||||
VIDEO_PROMPT = "<|video|>"
|
||||
|
||||
IMAGE_TOKENS = [
|
||||
IMAGE_PATCH_TOKEN,
|
||||
IM_COL_TOKEN,
|
||||
IM_START_TOKEN,
|
||||
LOW_RES_IMAGE_START_TOKEN,
|
||||
FRAME_START_TOKEN,
|
||||
IM_END_TOKEN,
|
||||
FRAME_END_TOKEN,
|
||||
IMAGE_LOW_RES_TOKEN,
|
||||
]
|
||||
|
||||
|
||||
class MolmoAct2ProcessorKwargs(ProcessingKwargs, total=False):
|
||||
"""MolmoAct2 processor kwargs"""
|
||||
|
||||
images_kwargs: MolmoAct2ImagesKwargs
|
||||
videos_kwargs: MolmoAct2VideoProcessorKwargs
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": False,
|
||||
"return_mm_token_type_ids": True,
|
||||
},
|
||||
"videos_kwargs": {"return_metadata": True},
|
||||
}
|
||||
|
||||
|
||||
class MolmoAct2Processor(ProcessorMixin):
|
||||
attributes = ["image_processor", "video_processor", "tokenizer"]
|
||||
optional_attributes = [
|
||||
"chat_template",
|
||||
"time_mode",
|
||||
"image_use_col_tokens",
|
||||
"use_single_crop_col_tokens",
|
||||
"use_single_crop_start_token",
|
||||
"video_use_col_tokens",
|
||||
"use_frame_special_tokens",
|
||||
]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
video_processor_class = "AutoVideoProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor: MolmoAct2ImageProcessor = None,
|
||||
video_processor: MolmoAct2VideoProcessor = None,
|
||||
tokenizer: AutoTokenizer = None,
|
||||
chat_template: str | None = None,
|
||||
image_use_col_tokens: bool | None = True,
|
||||
use_single_crop_col_tokens: bool | None = None,
|
||||
use_single_crop_start_token: bool | None = True,
|
||||
video_use_col_tokens: bool | None = False,
|
||||
use_frame_special_tokens: bool | None = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
image_processor,
|
||||
video_processor,
|
||||
tokenizer,
|
||||
chat_template=chat_template,
|
||||
)
|
||||
self.image_use_col_tokens = image_use_col_tokens
|
||||
self.use_single_crop_col_tokens = use_single_crop_col_tokens
|
||||
self.use_single_crop_start_token = use_single_crop_start_token
|
||||
self.video_use_col_tokens = video_use_col_tokens
|
||||
self.use_frame_special_tokens = use_frame_special_tokens
|
||||
|
||||
self.image_placeholder_token = IMAGE_PROMPT
|
||||
self.video_placeholder_token = VIDEO_PROMPT
|
||||
self.image_token_ids = [tokenizer.convert_tokens_to_ids(token) for token in IMAGE_TOKENS]
|
||||
|
||||
def get_image_tokens(self, image_grid: np.ndarray):
|
||||
resized_h, resized_w, height, width = image_grid
|
||||
if int(height) == 0 or int(width) == 0:
|
||||
per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
|
||||
use_single_crop_col_tokens = (
|
||||
self.image_use_col_tokens
|
||||
if self.use_single_crop_col_tokens is None
|
||||
else self.use_single_crop_col_tokens
|
||||
)
|
||||
if use_single_crop_col_tokens:
|
||||
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
|
||||
joint = [
|
||||
[IM_START_TOKEN],
|
||||
np.tile(per_row, [resized_h]),
|
||||
[IM_END_TOKEN],
|
||||
]
|
||||
return np.concatenate(joint)
|
||||
per_row = np.full(width, IMAGE_PATCH_TOKEN)
|
||||
if self.image_use_col_tokens:
|
||||
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
|
||||
joint = [
|
||||
[IM_START_TOKEN],
|
||||
np.tile(per_row, [height]),
|
||||
[IM_END_TOKEN],
|
||||
]
|
||||
per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
|
||||
use_single_crop_col_tokens = (
|
||||
self.image_use_col_tokens
|
||||
if self.use_single_crop_col_tokens is None
|
||||
else self.use_single_crop_col_tokens
|
||||
)
|
||||
image_start_token = LOW_RES_IMAGE_START_TOKEN if self.use_single_crop_start_token else IM_START_TOKEN
|
||||
if use_single_crop_col_tokens:
|
||||
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
|
||||
joint = [
|
||||
[image_start_token],
|
||||
np.tile(per_row, [resized_h]),
|
||||
[IM_END_TOKEN],
|
||||
] + joint
|
||||
|
||||
return np.concatenate(joint)
|
||||
|
||||
def get_video_string(
|
||||
self,
|
||||
video_grid: np.ndarray,
|
||||
timestamps: np.ndarray,
|
||||
):
|
||||
if self.use_frame_special_tokens:
|
||||
start_token_id = FRAME_START_TOKEN
|
||||
end_token_id = FRAME_END_TOKEN
|
||||
else:
|
||||
start_token_id = IM_START_TOKEN
|
||||
end_token_id = IM_END_TOKEN
|
||||
|
||||
num_frames, h, w = video_grid
|
||||
video_string: str = ""
|
||||
for frame_idx, frame_time in enumerate(timestamps):
|
||||
# `per-frame-compact` time mode
|
||||
prev_space = " " if frame_idx > 0 else ""
|
||||
frame_prefix = prev_space + f"{frame_time:.1f} " # explicit whitespace before/after image tokens
|
||||
|
||||
video_string += frame_prefix
|
||||
per_row = np.full(w, IMAGE_PATCH_TOKEN)
|
||||
if self.video_use_col_tokens:
|
||||
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
|
||||
extra_tokens = np.tile(per_row, [h])
|
||||
video_tokens = [
|
||||
[start_token_id],
|
||||
extra_tokens,
|
||||
[end_token_id],
|
||||
]
|
||||
video_string += "".join(np.concatenate(video_tokens, 0))
|
||||
|
||||
return video_string
|
||||
|
||||
def insert_bos(
|
||||
self,
|
||||
input_ids: np.ndarray,
|
||||
attention_mask: np.ndarray,
|
||||
bos_token_id: int,
|
||||
pad_token_id: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
input_ids: [B, S] array with left padding
|
||||
attention_mask: [B, S] array (0 for pad, 1 for valid)
|
||||
bos_token_id: int
|
||||
pad_token_id: int
|
||||
Returns:
|
||||
input_ids_out: [B, S] or [B, S+1] array with bos inserted if needed
|
||||
attention_mask_out: same shape as input_ids_out
|
||||
"""
|
||||
|
||||
need_to_expand = len(input_ids.shape) == 1
|
||||
if need_to_expand:
|
||||
input_ids = input_ids[None, :]
|
||||
attention_mask = attention_mask[None, :]
|
||||
|
||||
B, S = input_ids.shape
|
||||
|
||||
# Handle zero-length sequence
|
||||
if S == 0:
|
||||
new_input_ids = np.full((B, 1), bos_token_id, dtype=input_ids.dtype)
|
||||
new_attention_mask = np.ones((B, 1), dtype=attention_mask.dtype)
|
||||
if need_to_expand:
|
||||
new_input_ids = new_input_ids[0]
|
||||
new_attention_mask = new_attention_mask[0]
|
||||
return new_input_ids, new_attention_mask
|
||||
|
||||
first_valid_index = (attention_mask == 1).argmax(axis=-1) # [B]
|
||||
bos_already_present = np.all(input_ids[np.arange(B), first_valid_index] == bos_token_id)
|
||||
|
||||
if bos_already_present:
|
||||
if need_to_expand:
|
||||
input_ids = input_ids[0]
|
||||
attention_mask = attention_mask[0]
|
||||
return input_ids, attention_mask
|
||||
else:
|
||||
new_input_ids = np.full((B, S + 1), pad_token_id, dtype=input_ids.dtype)
|
||||
new_attention_mask = np.zeros((B, S + 1), dtype=attention_mask.dtype)
|
||||
|
||||
src_idx = np.tile(np.arange(S), (B, 1)) # [B, S]
|
||||
valid_mask = src_idx >= first_valid_index[:, None] # [B, S]
|
||||
tgt_idx = src_idx + 1 # shit right
|
||||
batch_idx = np.tile(np.arange(B)[:, None], (1, S)) # [B, S]
|
||||
|
||||
# flatten valid_positions
|
||||
flat_vals = input_ids[valid_mask]
|
||||
flat_batch = batch_idx[valid_mask]
|
||||
flat_tgt = tgt_idx[valid_mask]
|
||||
|
||||
new_input_ids[flat_batch, flat_tgt] = flat_vals
|
||||
new_attention_mask[flat_batch, flat_tgt] = 1
|
||||
|
||||
insert_pos = first_valid_index
|
||||
new_input_ids[np.arange(B), insert_pos] = bos_token_id
|
||||
new_attention_mask[np.arange(B), insert_pos] = 1
|
||||
|
||||
if need_to_expand:
|
||||
new_input_ids = new_input_ids[0]
|
||||
new_attention_mask = new_attention_mask[0]
|
||||
|
||||
return new_input_ids, new_attention_mask
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
|
||||
images: ImageInput = None,
|
||||
videos: VideoInput = None,
|
||||
**kwargs: Unpack[MolmoAct2ProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
|
||||
Args:
|
||||
text (`str`, `list[str]`, `list[list[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
videos (`dict[str, Any]` or `list[dict[str, Any]]`):
|
||||
The video or batch of videos to be prepared. Each video can be a dictionary with the following keys:
|
||||
- `"frames"`: `np.ndarray` of shape (T, H, W, 3)
|
||||
- `"timestamps"`: `np.ndarray` of shape (T,)
|
||||
- `"sampled_fps"`: `float` (optional)
|
||||
- `"sampling_augmentation"`: `str` (optional)
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
|
||||
Returns:
|
||||
`BatchFeature`: A [`BatchFeature`] with the following fields:
|
||||
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not `None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
- **image_token_pooling** -- Indices of the patches in `image_grids` to pool for each token in `image_tokens`.
|
||||
Returned when `images` is not `None`.
|
||||
- **image_grids** -- Grids of images. Returned when `images` is not `None`.
|
||||
- **image_num_crops** -- Number of crops for each image. Returned when `images` is not `None`.
|
||||
- **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
|
||||
- **video_token_pooling** -- Indices of the patches in `video_grids` to pool for each token in `video_tokens`.
|
||||
Returned when `videos` is not `None`.
|
||||
- **video_grids** -- Grids of videos. Returned when `videos` is not `None`.
|
||||
"""
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
MolmoAct2ProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if images is not None:
|
||||
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
|
||||
image_grids = image_inputs["image_grids"]
|
||||
else:
|
||||
image_inputs = {}
|
||||
image_grids = None
|
||||
|
||||
if videos is not None:
|
||||
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
|
||||
video_grids = videos_inputs["video_grids"]
|
||||
# If user has not requested video metadata, pop it
|
||||
if "return_metadata" not in kwargs:
|
||||
video_metadata = videos_inputs.pop("video_metadata")
|
||||
else:
|
||||
video_metadata = videos_inputs["video_metadata"]
|
||||
else:
|
||||
videos_inputs = {}
|
||||
video_grids = None
|
||||
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
|
||||
text = text.copy() # below lines change text in-place
|
||||
|
||||
if image_grids is not None:
|
||||
index = 0
|
||||
for i in range(len(text)):
|
||||
num_images = text[i].count(self.image_placeholder_token)
|
||||
image_grids_i = image_grids[index : index + num_images]
|
||||
for image_grid in image_grids_i:
|
||||
image_tokens = self.get_image_tokens(image_grid)
|
||||
image_string = "".join(image_tokens)
|
||||
text[i] = text[i].replace(self.image_placeholder_token, image_string, 1)
|
||||
index += num_images
|
||||
|
||||
if video_grids is not None:
|
||||
index = 0
|
||||
for i in range(len(text)):
|
||||
num_videos = text[i].count(self.video_placeholder_token)
|
||||
assert num_videos in {0, 1}, "At most one video is supported for now"
|
||||
video_grids_i = video_grids[index : index + num_videos]
|
||||
metadata_i = video_metadata[index : index + num_videos]
|
||||
for video_grid, metadata in zip(video_grids_i, metadata_i):
|
||||
video_string = self.get_video_string(
|
||||
video_grid,
|
||||
metadata.timestamps,
|
||||
)
|
||||
text[i] = text[i].replace(self.video_placeholder_token, video_string, 1)
|
||||
index += num_videos
|
||||
|
||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
|
||||
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||
|
||||
input_ids = text_inputs["input_ids"]
|
||||
attention_mask = text_inputs["attention_mask"]
|
||||
|
||||
input_ids = np.array(input_ids)
|
||||
attention_mask = np.array(attention_mask)
|
||||
|
||||
bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
|
||||
input_ids, attention_mask = self.insert_bos(
|
||||
input_ids, attention_mask, bos, self.tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
if return_mm_token_type_ids:
|
||||
image_tokens = np.array(self.image_token_ids).astype(input_ids.dtype)
|
||||
token_type_ids = np.any(input_ids[:, :, None] == image_tokens[None, None, :], axis=-1)
|
||||
text_inputs["token_type_ids"] = token_type_ids.tolist()
|
||||
|
||||
text_inputs["input_ids"] = input_ids.tolist()
|
||||
text_inputs["attention_mask"] = attention_mask.tolist()
|
||||
|
||||
return BatchFeature(
|
||||
data={**text_inputs, **image_inputs, **videos_inputs},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
def post_process_image_text_to_text(
|
||||
self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
|
||||
):
|
||||
"""
|
||||
Post-process the output of the model to decode the text.
|
||||
|
||||
Args:
|
||||
generated_outputs (`torch.Tensor` or `np.ndarray`):
|
||||
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
|
||||
or `(sequence_length,)`.
|
||||
skip_special_tokens (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
|
||||
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
|
||||
**kwargs:
|
||||
Additional arguments to be passed to the tokenizer's `batch_decode method`.
|
||||
|
||||
Returns:
|
||||
`list[str]`: The decoded text.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(
|
||||
generated_outputs,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
MolmoAct2Processor.register_for_auto_class()
|
||||
@@ -1,997 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
"""Video processor class for MolmoAct2"""
|
||||
|
||||
from functools import partial
|
||||
import os
|
||||
import warnings
|
||||
from contextlib import redirect_stdout
|
||||
from io import BytesIO
|
||||
from urllib.parse import urlparse
|
||||
from typing import Optional, Union
|
||||
from collections.abc import Callable
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import einops
|
||||
import torch
|
||||
import torchvision.transforms
|
||||
|
||||
from transformers.image_utils import (
|
||||
IMAGENET_STANDARD_MEAN,
|
||||
IMAGENET_STANDARD_STD,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
validate_kwargs,
|
||||
)
|
||||
from transformers.video_utils import (
|
||||
VideoInput,
|
||||
is_valid_video,
|
||||
make_batched_videos,
|
||||
make_batched_metadata,
|
||||
VideoMetadata,
|
||||
)
|
||||
from transformers.processing_utils import Unpack, VideosKwargs
|
||||
from transformers.video_processing_utils import BaseVideoProcessor
|
||||
from transformers.utils import logging
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.utils import (
|
||||
is_av_available,
|
||||
is_decord_available,
|
||||
is_torchcodec_available,
|
||||
is_yt_dlp_available,
|
||||
TensorType,
|
||||
logging,
|
||||
to_numpy,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
MAX_VIDEO_FPS = 8
|
||||
|
||||
|
||||
def normalize_image(
|
||||
image: np.ndarray,
|
||||
image_mean: list[float],
|
||||
image_std: list[float],
|
||||
) -> np.ndarray:
|
||||
if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
|
||||
return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
|
||||
image -= np.array(image_mean, dtype=np.float32)[None, None, :]
|
||||
image /= np.array(image_std, dtype=np.float32)[None, None, :]
|
||||
return image
|
||||
|
||||
|
||||
def resize_image(
|
||||
image: np.ndarray,
|
||||
desired_output_size: list[int],
|
||||
resample: PILImageResampling,
|
||||
) -> np.ndarray:
|
||||
if len(image.shape) == 3:
|
||||
is_video = False
|
||||
image = torch.permute(torch.from_numpy(image), [2, 0, 1])
|
||||
else:
|
||||
is_video = True
|
||||
image = torch.permute(torch.from_numpy(image), [0, 3, 1, 2])
|
||||
dtype = image.dtype
|
||||
if torch.is_floating_point(image):
|
||||
in_min = 0.0
|
||||
in_max = 1.0
|
||||
resized = torchvision.transforms.Resize(
|
||||
desired_output_size,
|
||||
resample,
|
||||
antialias=False,
|
||||
)(image)
|
||||
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
|
||||
else:
|
||||
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(
|
||||
image.dtype
|
||||
)
|
||||
in_min = 0.0
|
||||
in_max = 255.0
|
||||
resized = torchvision.transforms.Resize(
|
||||
desired_output_size,
|
||||
resample,
|
||||
antialias=False,
|
||||
)(image)
|
||||
resized = torch.clip(resized, 0, 255).to(dtype)
|
||||
|
||||
resized = resized.to(torch.float32)
|
||||
resized = (resized - in_min) / (in_max - in_min)
|
||||
|
||||
if is_video:
|
||||
resized = torch.permute(resized, [0, 2, 3, 1]).numpy()
|
||||
else:
|
||||
resized = torch.permute(resized, [1, 2, 0]).numpy()
|
||||
|
||||
return resized
|
||||
|
||||
|
||||
def build_resized_image(
|
||||
image: np.ndarray,
|
||||
base_image_input_size: list[int],
|
||||
resample: PILImageResampling,
|
||||
image_mean: list[float],
|
||||
image_std: list[float],
|
||||
image_patch_size: int,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
resized = resize_image(
|
||||
image,
|
||||
base_image_input_size,
|
||||
resample,
|
||||
)
|
||||
resized = normalize_image(resized, image_mean, image_std)
|
||||
if len(resized.shape) == 3:
|
||||
resized = np.expand_dims(resized, 0)
|
||||
crop_patch_w = base_image_input_size[1] // image_patch_size
|
||||
crop_patch_h = base_image_input_size[0] // image_patch_size
|
||||
resize_idx = np.arange(crop_patch_w * crop_patch_h).reshape([crop_patch_h, crop_patch_w])
|
||||
return resized, resize_idx
|
||||
|
||||
|
||||
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
|
||||
"""Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
|
||||
if len(array.shape) == 3:
|
||||
n_crops, h, w = array.shape
|
||||
h_patches = h // patch_size
|
||||
w_patches = w // patch_size
|
||||
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
|
||||
array = np.transpose(array, [0, 1, 3, 2, 4])
|
||||
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size])
|
||||
return array
|
||||
else:
|
||||
n_crops, h, w, c = array.shape
|
||||
h_patches = h // patch_size
|
||||
w_patches = w // patch_size
|
||||
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
|
||||
array = np.transpose(array, [0, 1, 3, 2, 4, 5])
|
||||
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size * c])
|
||||
return array
|
||||
|
||||
|
||||
def arange_for_pooling(
|
||||
idx_arr: np.ndarray,
|
||||
pool_h: int,
|
||||
pool_w: int,
|
||||
) -> np.ndarray:
|
||||
h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
|
||||
w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
|
||||
idx_arr = np.pad(
|
||||
idx_arr,
|
||||
[[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]],
|
||||
mode="constant",
|
||||
constant_values=-1,
|
||||
)
|
||||
return einops.rearrange(idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
|
||||
|
||||
|
||||
def image_to_patches_and_grids(
|
||||
image: ImageInput,
|
||||
base_image_input_size: list[int],
|
||||
resample: PILImageResampling,
|
||||
image_mean: list[float],
|
||||
image_std: list[float],
|
||||
image_patch_size: int,
|
||||
image_pooling_w: int,
|
||||
image_pooling_h: int,
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
:return image_grids, the shape of each image after pooling
|
||||
:return crops, the image crops to processes with the ViT
|
||||
:return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
|
||||
patches in `crops` to pool for that token, masked with -1
|
||||
"""
|
||||
if isinstance(base_image_input_size, int):
|
||||
base_image_input_size = (base_image_input_size, base_image_input_size)
|
||||
|
||||
pooling_w = image_pooling_w
|
||||
pooling_h = image_pooling_h
|
||||
|
||||
resized, resize_idx = build_resized_image(
|
||||
image,
|
||||
base_image_input_size,
|
||||
resample,
|
||||
image_mean,
|
||||
image_std,
|
||||
image_patch_size,
|
||||
)
|
||||
pooling_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
|
||||
h, w = pooling_idx.shape[:2]
|
||||
pooling_idx = pooling_idx.reshape([-1, pooling_h * pooling_w])
|
||||
image_grid = [h, w]
|
||||
return (
|
||||
image_grid,
|
||||
batch_pixels_to_patches(resized, image_patch_size),
|
||||
pooling_idx,
|
||||
)
|
||||
|
||||
|
||||
def get_candidate_target_fps(
|
||||
video_fps: int | float,
|
||||
sampling_fps: int | float,
|
||||
max_fps: int | float = MAX_VIDEO_FPS,
|
||||
) -> list[float]:
|
||||
"""
|
||||
Return the subset of `video_fps` factors that remain multiples of `sampling_fps`.
|
||||
|
||||
Examples:
|
||||
>>> get_candidate_target_fps(video_fps=6, sampling_fps=2)
|
||||
[2, 6]
|
||||
>>> get_candidate_target_fps(video_fps=5, sampling_fps=1)
|
||||
[1, 5]
|
||||
>>> get_candidate_target_fps(video_fps=2, sampling_fps=2)
|
||||
[2]
|
||||
>>> get_candidate_target_fps(video_fps=5, sampling_fps=2)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: sampling_fps=2 must divide video_fps=5 to produce consistent frame steps.
|
||||
"""
|
||||
video_fps = int(video_fps)
|
||||
sampling_fps = int(sampling_fps)
|
||||
max_fps = int(max_fps)
|
||||
|
||||
if sampling_fps is None:
|
||||
raise ValueError("sampling_fps must be provided")
|
||||
if video_fps <= 0 or sampling_fps <= 0:
|
||||
raise ValueError(f"video_fps and sampling_fps must be positive (got {video_fps}, {sampling_fps})")
|
||||
if video_fps % sampling_fps != 0:
|
||||
raise ValueError(f"sampling_fps={sampling_fps} must divide video_fps={video_fps}.")
|
||||
|
||||
candidates = []
|
||||
for candidate in range(sampling_fps, video_fps + 1, sampling_fps):
|
||||
if candidate > max_fps:
|
||||
break
|
||||
if video_fps % candidate == 0:
|
||||
candidates.append(float(candidate))
|
||||
|
||||
return candidates
|
||||
|
||||
|
||||
def read_video_decord(
|
||||
video_path,
|
||||
sample_timestamps_fn: Callable,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Decode a video using the Decord backend.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
sample_timestamps_fn (`Callable`):
|
||||
A callable function that will return timestamps at which the video should be sampled.
|
||||
|
||||
Returns:
|
||||
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- `VideoMetadata` object.
|
||||
"""
|
||||
# Lazy import from decord
|
||||
import importlib
|
||||
|
||||
decord = importlib.import_module("decord")
|
||||
|
||||
vr = decord.VideoReader(uri=video_path, ctx=decord.cpu(0)) # decord has problems with gpu
|
||||
video_fps = vr.get_avg_fps()
|
||||
total_num_frames = len(vr)
|
||||
time_stamps = vr.get_frame_timestamp(list(range(len(vr))))
|
||||
duration = time_stamps[-1][1] - time_stamps[0][0]
|
||||
|
||||
metadata = VideoMetadata(
|
||||
total_num_frames=int(total_num_frames),
|
||||
fps=float(video_fps),
|
||||
duration=float(duration),
|
||||
video_backend="decord",
|
||||
)
|
||||
|
||||
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
|
||||
target_timestamps = np.array(target_timestamps)
|
||||
offset = time_stamps[0, 0]
|
||||
|
||||
ix = np.searchsorted(time_stamps[:, 1], target_timestamps + offset, side="right")
|
||||
ix = np.minimum(ix, len(time_stamps) - 1)
|
||||
|
||||
video = vr.get_batch(ix).asnumpy()
|
||||
metadata.update(
|
||||
{
|
||||
"frames_indices": target_timestamps * video_fps,
|
||||
"height": video.shape[1],
|
||||
"width": video.shape[2],
|
||||
}
|
||||
)
|
||||
return video, metadata
|
||||
|
||||
|
||||
def read_video_torchcodec(
|
||||
video_path,
|
||||
sample_timestamps_fn: Callable,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Decode a video using torchcodec decoder.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
sample_timestamps_fn (`Callable`):
|
||||
A callable function that will return timestamps at which the video should be sampled.
|
||||
|
||||
Returns:
|
||||
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- `VideoMetadata` object.
|
||||
"""
|
||||
# Lazy import torchcodec
|
||||
import importlib
|
||||
|
||||
torchcodec = importlib.import_module("torchcodec")
|
||||
|
||||
decoder = torchcodec.decoders.VideoDecoder(
|
||||
video_path,
|
||||
# Interestingly `exact` mode takes less than approximate when we load the whole video
|
||||
seek_mode="exact",
|
||||
# Allow FFmpeg decide on the number of threads for efficiency
|
||||
num_ffmpeg_threads=0,
|
||||
)
|
||||
# If the first frame starts at > 0, we effectively clip the video starting at that time
|
||||
# since (most) video players would also skip to that time
|
||||
time_offset = decoder.metadata.begin_stream_seconds_from_content
|
||||
# Note this duration does assume we started playing at `time_offset`
|
||||
duration = decoder.metadata.duration_seconds
|
||||
|
||||
metadata = VideoMetadata(
|
||||
total_num_frames=decoder.metadata.num_frames,
|
||||
fps=decoder.metadata.average_fps,
|
||||
duration=duration,
|
||||
video_backend="torchcodec",
|
||||
height=decoder.metadata.height,
|
||||
width=decoder.metadata.width,
|
||||
)
|
||||
|
||||
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
|
||||
|
||||
# Floating point/rounding issues might cause `target_timestamps` to be very slightly
|
||||
# out-of-bounds, to handle this we sanity check then clip them
|
||||
assert all(x >= 0 for x in target_timestamps)
|
||||
assert all(x < duration + 1e-6 for x in target_timestamps)
|
||||
# 1e-6 padding since torchcodec can throw out-of-bounds errors even if you ask for the
|
||||
# exact boundary value, we should still get the first/last frame anyway
|
||||
max_timestamp = decoder.metadata.end_stream_seconds_from_content - 1e-6
|
||||
min_timestamp = decoder.metadata.begin_stream_seconds_from_content + 1e-6
|
||||
# Note we avoid using numpy ops here to reduce floating precision issues
|
||||
timestamps = [x + time_offset for x in target_timestamps]
|
||||
timestamps = [max(min_timestamp, min(max_timestamp, x)) for x in timestamps]
|
||||
|
||||
video = (
|
||||
decoder.get_frames_played_at(timestamps).data.numpy().transpose(0, 2, 3, 1)
|
||||
) # Convert to THWC format
|
||||
target_timestamps = np.array(target_timestamps)
|
||||
metadata.frames_indices = target_timestamps * metadata.fps
|
||||
|
||||
return video, metadata
|
||||
|
||||
|
||||
def read_video_pyav(
|
||||
video_path,
|
||||
sample_timestamps_fn: Callable,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Decode a video using the PyAV backend.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
sample_timestamps_fn (`Callable`):
|
||||
A callable function that will return timestamps at which the video should be sampled.
|
||||
|
||||
Returns:
|
||||
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- `VideoMetadata` object.
|
||||
"""
|
||||
# Lazy import torchcodec
|
||||
import importlib
|
||||
|
||||
av = importlib.import_module("av")
|
||||
|
||||
with av.open(video_path) as container:
|
||||
video_stream = container.streams.video[0]
|
||||
fps = video_stream.average_rate or video_stream.guessed_rate
|
||||
it = container.decode(video=0)
|
||||
frames = list(it)
|
||||
|
||||
stream = container.streams.video[0]
|
||||
start = frames[0].pts * stream.time_base
|
||||
container_end = stream.duration
|
||||
if container_end is not None:
|
||||
container_end *= stream.time_base
|
||||
if container_end is None or container_end < frames[-1].pts:
|
||||
# Some problem with stream duration, so use the frame PTS directly
|
||||
# and guess the duration of the last frame
|
||||
end = frames[-1].pts * stream.time_base + 1 / fps
|
||||
else:
|
||||
end = container_end
|
||||
duration = float(end - start)
|
||||
|
||||
metadata = VideoMetadata(
|
||||
total_num_frames=len(frames),
|
||||
fps=float(fps),
|
||||
duration=float(duration),
|
||||
video_backend="pyav",
|
||||
height=video_stream.height,
|
||||
width=video_stream.width,
|
||||
)
|
||||
|
||||
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
|
||||
offset = float(start)
|
||||
|
||||
target_timestamps = np.array(target_timestamps)
|
||||
end_time_stamps = np.array([float(frame.pts * stream.time_base) for frame in frames[1:]] + [duration])
|
||||
indices = np.searchsorted(end_time_stamps, target_timestamps + offset, side="right")
|
||||
indices = np.minimum(indices, len(end_time_stamps) - 1)
|
||||
|
||||
video = np.stack(
|
||||
[frames[i].to_ndarray(format="rgb24", channel_last=True) for i in indices],
|
||||
axis=0,
|
||||
)
|
||||
|
||||
metadata.frames_indices = target_timestamps * fps
|
||||
|
||||
return video, metadata
|
||||
|
||||
|
||||
VIDEO_DECODERS = {
|
||||
"decord": read_video_decord,
|
||||
"torchcodec": read_video_torchcodec,
|
||||
"pyav": read_video_pyav,
|
||||
}
|
||||
|
||||
|
||||
def load_video(
|
||||
video: VideoInput,
|
||||
backend: str = "decord",
|
||||
sample_timestamps_fn: Callable | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Loads `video` to a numpy array.
|
||||
|
||||
Args:
|
||||
video (`VideoInput`):
|
||||
The video to convert to the numpy array format. Can be a link to video or local path.
|
||||
backend (`str`, *optional*, defaults to `"decord"`):
|
||||
The backend to use when loading the video. Can be any of ["decord", "pyav", ""torchcodec"]. Defaults to "decord".
|
||||
sample_timestamps_fn (`Callable`):
|
||||
A callable function that will return timestamps at which the video should be sampled.
|
||||
"""
|
||||
|
||||
# Early exit if provided an array or `PIL` frames
|
||||
if not isinstance(video, str):
|
||||
metadata = [None] * len(video)
|
||||
return video, metadata
|
||||
|
||||
if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]:
|
||||
if not is_yt_dlp_available():
|
||||
raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
|
||||
# Lazy import from yt_dlp
|
||||
import importlib
|
||||
|
||||
yt_dlp = importlib.import_module("yt_dlp")
|
||||
|
||||
buffer = BytesIO()
|
||||
with redirect_stdout(buffer), yt_dlp.YoutubeDL() as f:
|
||||
f.download([video])
|
||||
bytes_obj = buffer.getvalue()
|
||||
file_obj = BytesIO(bytes_obj)
|
||||
elif video.startswith("http://") or video.startswith("https://"):
|
||||
file_obj = BytesIO(requests.get(video, timeout=10).content)
|
||||
elif os.path.isfile(video):
|
||||
file_obj = video
|
||||
else:
|
||||
raise TypeError(
|
||||
"Incorrect format used for video. Should be an url linking to an video or a local path."
|
||||
)
|
||||
|
||||
# can also load with decord, but not cv2/torchvision
|
||||
# both will fail in case of url links
|
||||
video_is_url = video.startswith("http://") or video.startswith("https://")
|
||||
if video_is_url and backend == "opencv":
|
||||
raise ValueError("If you are trying to load a video from URL, you cannot use 'opencv' as backend")
|
||||
|
||||
if (
|
||||
(not is_decord_available() and backend == "decord")
|
||||
or (not is_torchcodec_available() and backend == "torchcodec")
|
||||
or (not is_av_available() and backend == "pyav")
|
||||
):
|
||||
raise ImportError(
|
||||
f"You chose backend={backend} for loading the video but the required library is not found in your environment "
|
||||
f"Make sure to install {backend} before loading the video."
|
||||
)
|
||||
|
||||
video_decoder = VIDEO_DECODERS[backend]
|
||||
video, metadata = video_decoder(file_obj, sample_timestamps_fn, **kwargs)
|
||||
return video, metadata
|
||||
|
||||
|
||||
def get_target_fps(
|
||||
video_fps: float,
|
||||
max_frames: int,
|
||||
total_frames: int,
|
||||
frame_sample_mode: str,
|
||||
candidate_target_fps: tuple[float],
|
||||
) -> float:
|
||||
"""
|
||||
Get the target fps that best spans the video and has the most frames sampled
|
||||
"""
|
||||
num_frames_sampled = 0
|
||||
selected_target_fps = None
|
||||
for target_fps in candidate_target_fps:
|
||||
step_size = max(int(video_fps / target_fps), 1)
|
||||
num_frames_sampled_at_fps = int(total_frames / step_size)
|
||||
if num_frames_sampled == 0:
|
||||
if "uniform" in frame_sample_mode:
|
||||
if num_frames_sampled_at_fps > max_frames:
|
||||
break
|
||||
selected_target_fps = target_fps
|
||||
num_frames_sampled = num_frames_sampled_at_fps
|
||||
|
||||
else:
|
||||
# the candidate sampling fps increases so frame count can't decrease
|
||||
assert num_frames_sampled <= num_frames_sampled_at_fps
|
||||
if num_frames_sampled_at_fps > max_frames:
|
||||
# choose the sampling fps that spans the video
|
||||
continue
|
||||
|
||||
elif num_frames_sampled_at_fps > num_frames_sampled:
|
||||
# both are less than max_frames, choose the one with higher density of frames sampled
|
||||
selected_target_fps = target_fps
|
||||
num_frames_sampled = num_frames_sampled_at_fps
|
||||
return selected_target_fps
|
||||
|
||||
|
||||
def get_frame_times_and_chosen_fps(selected_target_fps, total_frames, max_frames, video_fps):
|
||||
if selected_target_fps is None:
|
||||
frame_indices = np.linspace(0, total_frames, max_frames, endpoint=False, dtype=int)
|
||||
else:
|
||||
step_size = max(int(video_fps / selected_target_fps), 1)
|
||||
frame_indices = np.arange(0, total_frames, step_size)
|
||||
if len(frame_indices) > max_frames:
|
||||
frame_indices = frame_indices[:max_frames]
|
||||
return selected_target_fps, frame_indices
|
||||
|
||||
|
||||
class MolmoAct2VideoProcessorKwargs(VideosKwargs, total=False):
|
||||
patch_size: int | None
|
||||
pooling_size: list[int] | None
|
||||
frame_sample_mode: str | None
|
||||
max_fps: int | None
|
||||
sampling_fps: int | None
|
||||
|
||||
|
||||
class MolmoAct2VideoProcessor(BaseVideoProcessor):
|
||||
resample = PILImageResampling.BILINEAR
|
||||
size = {"height": 378, "width": 378}
|
||||
image_mean = IMAGENET_STANDARD_MEAN
|
||||
image_std = IMAGENET_STANDARD_STD
|
||||
do_resize = True
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_convert_rgb = True
|
||||
patch_size = 14
|
||||
pooling_size = [3, 3]
|
||||
do_sample_frames = True
|
||||
frame_sample_mode = "uniform_last_frame"
|
||||
max_fps = 2
|
||||
sampling_fps = 2
|
||||
valid_kwargs = MolmoAct2VideoProcessorKwargs
|
||||
model_input_names = ["pixel_values_videos", "video_token_pooling", "video_grids"]
|
||||
|
||||
def __init__(self, **kwargs: Unpack[MolmoAct2VideoProcessorKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
if self.size is not None and (
|
||||
self.size.get("height", None) is None or self.size.get("width", None) is None
|
||||
):
|
||||
raise ValueError("size must contain 'height' and 'width' keys.")
|
||||
|
||||
def _further_process_kwargs(
|
||||
self,
|
||||
size: SizeDict | None = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""
|
||||
Update kwargs that need further processing before being validated
|
||||
Can be overridden by subclasses to customize the processing of kwargs.
|
||||
"""
|
||||
if size is not None and ("height" not in size or "width" not in size):
|
||||
raise ValueError("size must contain 'height' and 'width' keys.")
|
||||
|
||||
return super()._further_process_kwargs(size=size, **kwargs)
|
||||
|
||||
def sample_times(
|
||||
self,
|
||||
metadata: VideoMetadata,
|
||||
frame_sample_mode: str,
|
||||
num_frames: int,
|
||||
max_fps: int | None = None,
|
||||
sampling_fps: int | None = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Time-based sampling if an array video is passed
|
||||
Args:
|
||||
metadata (`VideoMetadata`):
|
||||
Metadata of the video containing information about total duration, fps and total number of frames.
|
||||
frame_sample_mode (`str`, *optional*):
|
||||
Mode to sample frames. Defaults to `self.frame_sample_mode`.
|
||||
num_frames (`int`, *optional*):
|
||||
Maximum number of frames to sample. Defaults to `self.num_frames`.
|
||||
man_fps (`int`, *optional*):
|
||||
Maximum frames per second to sample.
|
||||
sampling_fps (`int`, *optional*):
|
||||
Sampling frames per second. Defaults to `self.sampling_fps`.
|
||||
Used when `frame_sample_mode` is `"fps"`.
|
||||
"""
|
||||
frame_sample_mode = frame_sample_mode or self.frame_sample_mode
|
||||
num_frames = num_frames or self.num_frames
|
||||
sampling_fps = sampling_fps or self.sampling_fps
|
||||
|
||||
duration = metadata.duration or metadata.total_num_frames / metadata.fps
|
||||
if frame_sample_mode == "fps":
|
||||
candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
|
||||
# Try larger and larger FPSs until we hit one that can't span the video
|
||||
target_fps = candidate_target_fps[0]
|
||||
for candidate_fps in candidate_target_fps[1:]:
|
||||
if num_frames / candidate_fps < duration:
|
||||
break
|
||||
target_fps = candidate_fps
|
||||
times = np.arange(0, num_frames) / target_fps
|
||||
times = times[times < duration]
|
||||
return times
|
||||
elif frame_sample_mode == "uniform_last_frame":
|
||||
if max_fps is not None:
|
||||
max_duration = (num_frames - 1) / max_fps # -1 to include the last frame
|
||||
if max_duration < duration:
|
||||
times = np.linspace(0, duration, num=num_frames, endpoint=True, dtype=np.float64)
|
||||
else:
|
||||
times = np.arange(0.0, stop=duration, step=1 / max_fps)
|
||||
times = np.concatenate([times, [duration]], axis=0)
|
||||
assert len(times) <= num_frames
|
||||
else:
|
||||
times = np.linspace(0, duration, num=num_frames, endpoint=True, dtype=np.float64)
|
||||
return times
|
||||
else:
|
||||
raise NotImplementedError(frame_sample_mode)
|
||||
|
||||
def sample_frames(
|
||||
self,
|
||||
metadata: VideoMetadata,
|
||||
frame_sample_mode: str | None = None,
|
||||
num_frames: int | None = None,
|
||||
max_fps: int | None = None,
|
||||
sampling_fps: int | None = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Frame-based sampling if an array video is passed
|
||||
Args:
|
||||
metadata (`VideoMetadata`):
|
||||
Metadata of the video containing information about total duration, fps and total number of frames.
|
||||
frame_sample_mode (`str`, *optional*):
|
||||
Mode to sample frames. Defaults to `self.frame_sample_mode`.
|
||||
num_frames (`int`, *optional*):
|
||||
Maximum number of frames to sample. Defaults to `self.num_frames`.
|
||||
max_fps (`int`, *optional*):
|
||||
Maximum frames per second to sample.
|
||||
sampling_fps (`int`, *optional*):
|
||||
Sampling frames per second. Defaults to `self.sampling_fps`.
|
||||
Used when `frame_sample_mode` is `"fps"`.
|
||||
"""
|
||||
frame_sample_mode = frame_sample_mode or self.frame_sample_mode
|
||||
num_frames = num_frames or self.num_frames
|
||||
sampling_fps = sampling_fps or self.sampling_fps
|
||||
|
||||
total_num_frames = metadata.total_num_frames
|
||||
if frame_sample_mode == "uniform_last_frame" and max_fps is not None:
|
||||
duration = total_num_frames / metadata.fps
|
||||
if total_num_frames <= 2:
|
||||
return np.arange(total_num_frames).astype(int)
|
||||
if duration > (num_frames - 1) / max_fps: # -1 to include the last frame
|
||||
# uniform fallback
|
||||
indices = np.linspace(
|
||||
0,
|
||||
total_num_frames - 1,
|
||||
num=min(num_frames, total_num_frames),
|
||||
endpoint=True,
|
||||
).astype(int)
|
||||
return indices
|
||||
else:
|
||||
float_indices = np.arange(
|
||||
0.0,
|
||||
stop=total_num_frames - 1,
|
||||
step=float(metadata.fps / max_fps),
|
||||
)
|
||||
if np.round(float_indices[-1]) != total_num_frames - 1:
|
||||
float_indices = np.concatenate([float_indices, [total_num_frames - 1]], axis=0)
|
||||
indices = np.round(float_indices).astype(int)
|
||||
assert indices[-1] < total_num_frames
|
||||
assert len(float_indices) <= num_frames
|
||||
return indices
|
||||
elif frame_sample_mode == "uniform_last_frame":
|
||||
indices = np.linspace(
|
||||
0,
|
||||
total_num_frames - 1,
|
||||
num=min(num_frames, total_num_frames),
|
||||
endpoint=True,
|
||||
).astype(int)
|
||||
return indices
|
||||
elif frame_sample_mode == "fps":
|
||||
candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
|
||||
selected_target_fps = get_target_fps(
|
||||
metadata.fps,
|
||||
num_frames,
|
||||
total_num_frames,
|
||||
frame_sample_mode,
|
||||
candidate_target_fps,
|
||||
)
|
||||
_, indices = get_frame_times_and_chosen_fps(
|
||||
selected_target_fps,
|
||||
total_num_frames,
|
||||
num_frames,
|
||||
metadata.fps,
|
||||
)
|
||||
return indices
|
||||
else:
|
||||
raise NotImplementedError(frame_sample_mode)
|
||||
|
||||
def fetch_videos(self, video_url_or_urls: str | list[str] | list[list[str]], sample_timestamps_fn=None):
|
||||
"""
|
||||
Convert a single or a list of urls into the corresponding `np.array` objects.
|
||||
|
||||
If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
|
||||
returned.
|
||||
"""
|
||||
if (not is_decord_available()) and (not is_torchcodec_available()) and (not is_av_available()):
|
||||
raise ImportError(
|
||||
"MolmoAct2VideoProcessor requires `decord`, `torchcodec`, or `av` to be installed."
|
||||
)
|
||||
|
||||
if is_decord_available():
|
||||
backend = "decord"
|
||||
elif is_torchcodec_available():
|
||||
warnings.warn(
|
||||
"`decord` is not installed and cannot be used to decode the video by default. "
|
||||
"Falling back to `torchcodec`."
|
||||
)
|
||||
backend = "torchcodec"
|
||||
else:
|
||||
warnings.warn(
|
||||
"`decord` is not installed and cannot be used to decode the video by default. "
|
||||
"Falling back to `PyAV`."
|
||||
)
|
||||
backend = "pyav"
|
||||
|
||||
if isinstance(video_url_or_urls, list):
|
||||
return list(
|
||||
zip(
|
||||
*[
|
||||
self.fetch_videos(x, sample_timestamps_fn=sample_timestamps_fn)
|
||||
for x in video_url_or_urls
|
||||
]
|
||||
)
|
||||
)
|
||||
else:
|
||||
return load_video(video_url_or_urls, backend=backend, sample_timestamps_fn=sample_timestamps_fn)
|
||||
|
||||
def _decode_and_sample_videos(
|
||||
self,
|
||||
videos: VideoInput,
|
||||
video_metadata: VideoMetadata | dict,
|
||||
do_sample_frames: bool | None = None,
|
||||
sample_indices_fn: Callable | None = None,
|
||||
sample_timestamps_fn: Callable | None = None,
|
||||
):
|
||||
"""
|
||||
Decode input videos and sample frames if needed.
|
||||
"""
|
||||
videos = make_batched_videos(videos)
|
||||
video_metadata = make_batched_metadata(videos, video_metadata=video_metadata)
|
||||
|
||||
# Framed-based sampling if an array video is passed
|
||||
# Otherwise, time-based sampling with decoding
|
||||
if is_valid_video(videos[0]) and do_sample_frames:
|
||||
assert video_metadata[0].fps is not None, "FPS must be provided for video input"
|
||||
sampled_videos = []
|
||||
sampled_metadata = []
|
||||
for video, metadata in zip(videos, video_metadata):
|
||||
indices = sample_indices_fn(metadata=metadata)
|
||||
metadata.frames_indices = indices
|
||||
sampled_videos.append(video[indices])
|
||||
sampled_metadata.append(metadata)
|
||||
videos = sampled_videos
|
||||
video_metadata = sampled_metadata
|
||||
elif not is_valid_video(videos[0]):
|
||||
if sample_indices_fn is None:
|
||||
logger.warning(
|
||||
"do_sample_frames is False, but video array is not provided: "
|
||||
"Will decode the video and sample frames using MolmoAct2's default sampling mode"
|
||||
)
|
||||
if isinstance(videos[0], list):
|
||||
raise ValueError("A list of images is not supported for video input!")
|
||||
else:
|
||||
videos, video_metadata = self.fetch_videos(videos, sample_timestamps_fn=sample_timestamps_fn)
|
||||
|
||||
return videos, video_metadata
|
||||
|
||||
def _prepare_input_videos(
|
||||
self,
|
||||
videos: VideoInput,
|
||||
**kwargs,
|
||||
) -> list[np.ndarray]:
|
||||
processed_videos = [to_numpy(video) for video in videos]
|
||||
return processed_videos
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
videos: VideoInput,
|
||||
**kwargs: Unpack[MolmoAct2VideoProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
validate_kwargs(
|
||||
captured_kwargs=kwargs.keys(),
|
||||
valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) + ["return_tensors"],
|
||||
)
|
||||
|
||||
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
||||
# by the user, it gets its default value from the instance, or is set to None.
|
||||
for kwarg_name in self.valid_kwargs.__annotations__:
|
||||
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
|
||||
|
||||
do_sample_frames = kwargs.pop("do_sample_frames")
|
||||
video_metadata = kwargs.pop("video_metadata")
|
||||
|
||||
sample_indices_fn = partial(self.sample_frames, **kwargs) if do_sample_frames else None
|
||||
sample_timestamps_fn = partial(self.sample_times, **kwargs)
|
||||
videos, video_metadata = self._decode_and_sample_videos(
|
||||
videos,
|
||||
video_metadata=video_metadata,
|
||||
do_sample_frames=do_sample_frames,
|
||||
sample_indices_fn=sample_indices_fn,
|
||||
sample_timestamps_fn=sample_timestamps_fn,
|
||||
)
|
||||
videos = self._prepare_input_videos(videos=videos)
|
||||
|
||||
kwargs = self._further_process_kwargs(**kwargs)
|
||||
|
||||
return_metadata = kwargs.pop("return_metadata")
|
||||
preprocessed_videos = self._preprocess(videos=videos, **kwargs)
|
||||
if return_metadata:
|
||||
preprocessed_videos["video_metadata"] = video_metadata
|
||||
return preprocessed_videos
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
videos: list[np.ndarray],
|
||||
size: SizeDict | None = None,
|
||||
resample: PILImageResampling | None = None,
|
||||
image_mean: float | list[float] | None = None,
|
||||
image_std: float | list[float] | None = None,
|
||||
do_convert_rgb: bool | None = None,
|
||||
patch_size: int | None = None,
|
||||
pooling_size: list[int] | None = None,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Preprocess a video for the model.
|
||||
Args:
|
||||
videos (`VideoInput`):
|
||||
Video to preprocess.
|
||||
size (`SizeDict`, *optional*, defaults to `self.size`):
|
||||
Size of the image after resizing.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||
`True`.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
patch_size (`int`, *optional*, defaults to `self.patch_size`):
|
||||
The spatial patch size of the vision encoder.
|
||||
pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
|
||||
The pooling size of the vision adapter.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
|
||||
Returns:
|
||||
A `BatchFeature` containing the following keys:
|
||||
- `pixel_values_videos`: The preprocessed videos.
|
||||
- `video_token_pooling`: The indices of the patches in `crops` to pool for each token in `video_tokens`.
|
||||
- `video_grids`: The video grids.
|
||||
"""
|
||||
if size.height is None or size.width is None:
|
||||
raise ValueError("size must contain 'height' and 'width' keys.")
|
||||
|
||||
base_image_input_size = [size.height, size.width]
|
||||
|
||||
resample = resample or self.resample
|
||||
image_mean = image_mean or self.image_mean
|
||||
image_std = image_std or self.image_std
|
||||
do_convert_rgb = do_convert_rgb or self.do_convert_rgb
|
||||
|
||||
patch_size = patch_size or self.patch_size
|
||||
pooling_size = pooling_size or self.pooling_size
|
||||
|
||||
image_pooling_h, image_pooling_w = pooling_size
|
||||
|
||||
batch_grids = []
|
||||
batch_crops = []
|
||||
batch_pooled_patches_idx = []
|
||||
|
||||
for video in videos:
|
||||
all_crops = []
|
||||
pooled_patches_idx = []
|
||||
|
||||
for frame in video:
|
||||
image_grid, crops, pooled_idx = image_to_patches_and_grids(
|
||||
frame,
|
||||
base_image_input_size,
|
||||
resample,
|
||||
image_mean,
|
||||
image_std,
|
||||
patch_size,
|
||||
image_pooling_w,
|
||||
image_pooling_h,
|
||||
)
|
||||
offset = sum(np.prod(x.shape[:2]) for x in all_crops)
|
||||
pooled_idx_with_offset = np.where(pooled_idx >= 0, pooled_idx + offset, pooled_idx)
|
||||
pooled_patches_idx.append(pooled_idx_with_offset)
|
||||
all_crops.append(crops)
|
||||
|
||||
video_grid = np.array([len(video), image_grid[0], image_grid[1]])
|
||||
all_crops = np.concatenate(all_crops, 0)
|
||||
pooled_patches_idx = np.concatenate(pooled_patches_idx, 0)
|
||||
|
||||
batch_grids.append(video_grid)
|
||||
batch_crops.append(all_crops)
|
||||
batch_pooled_patches_idx.append(pooled_patches_idx)
|
||||
|
||||
video_grids = np.stack(batch_grids, 0)
|
||||
pixel_values_videos = np.concatenate(batch_crops, 0)
|
||||
video_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
|
||||
|
||||
data = dict(
|
||||
pixel_values_videos=pixel_values_videos,
|
||||
video_token_pooling=video_token_pooling,
|
||||
video_grids=video_grids,
|
||||
)
|
||||
|
||||
return BatchFeature(data, tensor_type=return_tensors)
|
||||
|
||||
|
||||
MolmoAct2VideoProcessor.register_for_auto_class()
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1 +0,0 @@
|
||||
../../../../docs/source/policy_vla_jepa_README.md
|
||||
@@ -1,337 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import nn
|
||||
from torch.distributions import Beta
|
||||
|
||||
from lerobot.utils.import_utils import _diffusers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _diffusers_available:
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from diffusers.models.attention import Attention, FeedForward
|
||||
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
||||
else:
|
||||
|
||||
class ModelMixin: # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
class ConfigMixin: # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
register_to_config = lambda f: f # noqa: E731
|
||||
Attention = FeedForward = TimestepEmbedding = Timesteps = None
|
||||
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
timesteps = timesteps.float()
|
||||
batch_size, seq_len = timesteps.shape
|
||||
half_dim = self.embedding_dim // 2
|
||||
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device)
|
||||
exponent = exponent * (torch.log(torch.tensor(10000.0, device=timesteps.device)) / max(half_dim, 1))
|
||||
freqs = timesteps.unsqueeze(-1) * exponent.exp()
|
||||
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1).view(batch_size, seq_len, -1)
|
||||
|
||||
|
||||
class ActionEncoder(nn.Module):
|
||||
def __init__(self, action_dim: int, hidden_size: int):
|
||||
super().__init__()
|
||||
self.layer1 = nn.Linear(action_dim, hidden_size)
|
||||
self.layer2 = nn.Linear(hidden_size * 2, hidden_size)
|
||||
self.layer3 = nn.Linear(hidden_size, hidden_size)
|
||||
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
|
||||
|
||||
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, seq_len, _ = actions.shape
|
||||
if timesteps.ndim != 1 or timesteps.shape[0] != batch_size:
|
||||
raise ValueError("timesteps must have shape [batch_size].")
|
||||
timesteps = timesteps.unsqueeze(1).expand(-1, seq_len)
|
||||
action_emb = self.layer1(actions)
|
||||
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
|
||||
return self.layer3(F.silu(self.layer2(torch.cat([action_emb, time_emb], dim=-1))))
|
||||
|
||||
|
||||
class TimestepEncoder(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
require_package("diffusers", extra="vla_jepa")
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
projected = self.time_proj(timesteps).to(dtype=next(self.parameters()).dtype)
|
||||
return self.timestep_embedder(projected)
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
|
||||
self.norm = nn.LayerNorm(embedding_dim, eps=1e-5, elementwise_affine=False)
|
||||
self.silu = nn.SiLU()
|
||||
|
||||
def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
|
||||
scale, shift = self.linear(self.silu(temb)).chunk(2, dim=-1)
|
||||
return self.norm(x) * (1 + scale[:, None]) + shift[:, None]
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout: float,
|
||||
cross_attention_dim: int,
|
||||
is_cross_attention: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_cross_attention = is_cross_attention
|
||||
self.norm1 = AdaLayerNorm(dim)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=True,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
out_bias=True,
|
||||
)
|
||||
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
|
||||
self.ff = FeedForward(dim, dropout=dropout, activation_fn="gelu-approximate", final_dropout=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor | None,
|
||||
temb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
attn_input = self.norm1(hidden_states, temb)
|
||||
attention_context = encoder_hidden_states if self.is_cross_attention else None
|
||||
hidden_states = hidden_states + self.attn1(attn_input, encoder_hidden_states=attention_context)
|
||||
hidden_states = hidden_states + self.ff(self.norm2(hidden_states))
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DiT(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = False
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
output_dim: int,
|
||||
num_layers: int,
|
||||
dropout: float,
|
||||
cross_attention_dim: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.timestep_encoder = TimestepEncoder(self.inner_dim)
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim if layer_idx % 2 == 0 else self.inner_dim,
|
||||
is_cross_attention=layer_idx % 2 == 0,
|
||||
)
|
||||
for layer_idx in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, eps=1e-6, elementwise_affine=False)
|
||||
self.proj_out_1 = nn.Linear(self.inner_dim, self.inner_dim * 2)
|
||||
self.proj_out_2 = nn.Linear(self.inner_dim, output_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
temb = self.timestep_encoder(timestep)
|
||||
x = hidden_states
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, encoder_hidden_states=encoder_hidden_states, temb=temb)
|
||||
shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1)
|
||||
x = self.norm_out(x) * (1 + scale[:, None]) + shift[:, None]
|
||||
return self.proj_out_2(x)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionModelPreset:
|
||||
hidden_size: int
|
||||
attention_head_dim: int
|
||||
num_attention_heads: int
|
||||
|
||||
|
||||
DIT_PRESETS = {
|
||||
"DiT-B": ActionModelPreset(hidden_size=768, attention_head_dim=64, num_attention_heads=12),
|
||||
"DiT-L": ActionModelPreset(hidden_size=1536, attention_head_dim=48, num_attention_heads=32),
|
||||
"DiT-test": ActionModelPreset(hidden_size=16, attention_head_dim=8, num_attention_heads=2),
|
||||
}
|
||||
|
||||
|
||||
class VLAJEPAActionHead(nn.Module):
|
||||
def __init__(self, config: VLAJEPAConfig, cross_attention_dim: int) -> None:
|
||||
super().__init__()
|
||||
preset = DIT_PRESETS[config.action_model_type]
|
||||
self.config = config
|
||||
num_heads = config.action_num_heads or preset.num_attention_heads
|
||||
head_dim = config.action_attention_head_dim or preset.attention_head_dim
|
||||
inner_dim = num_heads * head_dim # e.g. DiT-B: 12 × 64 = 768
|
||||
|
||||
self.input_embedding_dim = inner_dim
|
||||
self.action_horizon = config.chunk_size
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
|
||||
hidden_size = config.action_hidden_size
|
||||
self.model = DiT(
|
||||
num_attention_heads=num_heads,
|
||||
attention_head_dim=head_dim,
|
||||
output_dim=hidden_size,
|
||||
num_layers=config.action_num_layers,
|
||||
dropout=config.action_dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
self.action_encoder = ActionEncoder(config.action_dim, inner_dim)
|
||||
self.action_decoder = nn.Sequential(
|
||||
OrderedDict(
|
||||
[
|
||||
("layer1", nn.Linear(hidden_size, hidden_size)),
|
||||
("relu", nn.ReLU()),
|
||||
("layer2", nn.Linear(hidden_size, config.action_dim)),
|
||||
]
|
||||
)
|
||||
)
|
||||
self.state_encoder = (
|
||||
nn.Sequential(
|
||||
OrderedDict(
|
||||
[
|
||||
("layer1", nn.Linear(config.state_dim, hidden_size)),
|
||||
("relu", nn.ReLU()),
|
||||
("layer2", nn.Linear(hidden_size, inner_dim)),
|
||||
]
|
||||
)
|
||||
)
|
||||
if config.state_dim > 0
|
||||
else None
|
||||
)
|
||||
self.future_tokens = nn.Embedding(config.num_embodied_action_tokens_per_instruction, inner_dim)
|
||||
self.position_embedding = nn.Embedding(
|
||||
max(1024, config.chunk_size + config.num_action_tokens_per_timestep + 4),
|
||||
inner_dim,
|
||||
)
|
||||
self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta)
|
||||
|
||||
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
sample = self.beta_dist.sample([batch_size]).to(device=device, dtype=dtype)
|
||||
return (self.config.action_noise_s - sample) / self.config.action_noise_s
|
||||
|
||||
def _build_inputs(
|
||||
self,
|
||||
conditioning_tokens: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
state: torch.Tensor | None,
|
||||
timesteps: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
action_features = self.action_encoder(actions, timesteps)
|
||||
pos_ids = torch.arange(action_features.shape[1], device=actions.device)
|
||||
action_features = action_features + self.position_embedding(pos_ids)[None]
|
||||
|
||||
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(actions.shape[0], -1, -1)
|
||||
seq = [future_tokens, action_features]
|
||||
if state is not None and self.state_encoder is not None:
|
||||
if state.ndim == 2:
|
||||
state = state.unsqueeze(1)
|
||||
seq.insert(0, self.state_encoder(state))
|
||||
return torch.cat(seq, dim=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
conditioning_tokens: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
state: torch.Tensor | None = None,
|
||||
action_is_pad: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
noise = torch.randn_like(actions)
|
||||
t = self.sample_time(actions.shape[0], actions.device, actions.dtype)
|
||||
noisy_actions = (1 - t[:, None, None]) * noise + t[:, None, None] * actions
|
||||
velocity = actions - noise
|
||||
t_discretized = (t * self.config.action_num_timestep_buckets).long()
|
||||
|
||||
hidden_states = self._build_inputs(conditioning_tokens, noisy_actions, state, t_discretized)
|
||||
pred = self.model(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=conditioning_tokens,
|
||||
timestep=t_discretized,
|
||||
)
|
||||
pred_actions = self.action_decoder(pred[:, -actions.shape[1] :])
|
||||
|
||||
if action_is_pad is None:
|
||||
action_is_pad = torch.zeros(actions.shape[:2], dtype=torch.bool, device=actions.device)
|
||||
|
||||
loss = F.mse_loss(pred_actions, velocity, reduction="none") # [B, T, action_dim]
|
||||
valid_mask = ~action_is_pad.unsqueeze(-1) # [B, T, 1]
|
||||
num_valid = valid_mask.sum() * loss.shape[-1]
|
||||
return (loss * valid_mask).sum() / num_valid.clamp_min(1)
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action(
|
||||
self,
|
||||
conditioning_tokens: torch.Tensor,
|
||||
state: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size = conditioning_tokens.shape[0]
|
||||
actions = torch.randn(
|
||||
batch_size,
|
||||
self.action_horizon,
|
||||
self.config.action_dim,
|
||||
dtype=conditioning_tokens.dtype,
|
||||
device=conditioning_tokens.device,
|
||||
)
|
||||
dt = 1.0 / max(self.num_inference_timesteps, 1)
|
||||
for step in range(self.num_inference_timesteps):
|
||||
t_cont = step / float(max(self.num_inference_timesteps, 1))
|
||||
t_value = int(t_cont * self.config.action_num_timestep_buckets)
|
||||
timesteps = torch.full(
|
||||
(batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long
|
||||
)
|
||||
hidden_states = self._build_inputs(conditioning_tokens, actions, state, timesteps)
|
||||
pred = self.model(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=conditioning_tokens,
|
||||
timestep=timesteps,
|
||||
)
|
||||
pred_velocity = self.action_decoder(pred[:, -self.action_horizon :])
|
||||
actions = actions + dt * pred_velocity
|
||||
return actions
|
||||
@@ -1,154 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("vla_jepa")
|
||||
@dataclass
|
||||
class VLAJEPAConfig(PreTrainedConfig):
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 7
|
||||
n_action_steps: int = 7
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct"
|
||||
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
|
||||
freeze_qwen: bool = False
|
||||
enable_world_model: bool = True
|
||||
# Enables cross-embodiment transfer: when fine-tuning a pretrained model on a robot with a
|
||||
# different action or state dimensionality, the input/output projection layers must be
|
||||
# re-initialised from scratch while the rest of the network keeps its pretrained weights.
|
||||
# List the key prefixes that are allowed to have shape mismatches; anything else raises an error.
|
||||
# e.g. ["model.action_model.action_encoder", "model.action_model.state_encoder"]
|
||||
reinit_modules: list[str] | None = None
|
||||
|
||||
tokenizer_padding_side: str = "left"
|
||||
prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."
|
||||
special_action_token: str = "<|action_{}|>"
|
||||
embodied_action_token: str = "<|embodied_action|>"
|
||||
|
||||
action_dim: int = 7
|
||||
state_dim: int = 8
|
||||
|
||||
num_action_tokens_per_timestep: int = 8
|
||||
num_embodied_action_tokens_per_instruction: int = 32
|
||||
num_inference_timesteps: int = 4
|
||||
|
||||
action_hidden_size: int = 1024
|
||||
action_model_type: str = "DiT-B"
|
||||
action_num_layers: int = 16
|
||||
action_num_heads: int | None = None
|
||||
action_attention_head_dim: int | None = None
|
||||
action_dropout: float = 0.2
|
||||
action_num_timestep_buckets: int = 1000
|
||||
action_noise_beta_alpha: float = 1.5
|
||||
action_noise_beta_beta: float = 1.0
|
||||
action_noise_s: float = 0.999
|
||||
num_target_vision_tokens: int = 32
|
||||
action_max_seq_len: int = 1024
|
||||
|
||||
# total video frames loaded per sample
|
||||
num_video_frames: int = 8
|
||||
predictor_depth: int = 12
|
||||
predictor_num_heads: int = 8
|
||||
predictor_mlp_ratio: float = 4.0
|
||||
predictor_dropout: float = 0.0
|
||||
world_model_loss_weight: float = 0.1
|
||||
jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256)
|
||||
repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style)
|
||||
|
||||
resize_images_to: tuple[int, int] | None = None
|
||||
binarize_gripper_action: bool = True
|
||||
pre_snap_gripper_action: bool = True
|
||||
clip_normalized_actions: bool = True
|
||||
gripper_dim: int = 6
|
||||
gripper_threshold: float = 0.5
|
||||
torch_dtype: str = "bfloat16"
|
||||
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-10
|
||||
optimizer_grad_clip_norm: float = 10.0
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
if self.freeze_qwen and self.enable_world_model:
|
||||
# freezing qwen backbone makes world model training irrelevant since no grad flows
|
||||
self.enable_world_model = False
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError("`n_action_steps` must be <= `chunk_size`.")
|
||||
if self.num_video_frames < 2 * self.jepa_tubelet_size:
|
||||
raise ValueError(
|
||||
f"`video_horizon` ({self.num_video_frames}) must be >= 2 * `jepa_tubelet_size` "
|
||||
f"({self.jepa_tubelet_size}) to have at least one context and one GT temporal position."
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.image_features:
|
||||
raise ValueError("VLAJEPA requires at least one visual input feature.")
|
||||
if self.action_feature is None:
|
||||
raise ValueError("VLAJEPA requires an action output feature.")
|
||||
self.action_dim = self.action_feature.shape[0]
|
||||
if self.robot_state_feature is not None:
|
||||
self.state_dim = self.robot_state_feature.shape[0]
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int]:
|
||||
# load video_horizon frames starting from current timestep: [t, t+1, ..., t+video_horizon-1]
|
||||
# matches original repo's observation_indices=list(range(video_horizon))
|
||||
return list(range(self.num_video_frames))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -1,629 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from PIL import Image
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.utils import populate_queues
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoModel, AutoVideoProcessor
|
||||
else:
|
||||
AutoModel = None
|
||||
AutoVideoProcessor = None
|
||||
|
||||
from .action_head import VLAJEPAActionHead
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
from .qwen_interface import Qwen3VLInterface
|
||||
from .world_model import ActionConditionedVideoPredictor
|
||||
|
||||
# ============================================================================
|
||||
# Native VLA-JEPA Model - follows original starVLA VLA_JEPA.py implementation
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class VLAJEPAModel(nn.Module):
|
||||
"""
|
||||
Native VLA-JEPA model following the original starVLA VLA_JEPA.py.
|
||||
|
||||
Components:
|
||||
- Qwen3-VL: vision-language backbone for fused embeddings
|
||||
- DiT-B: flow-matching action head for future action prediction
|
||||
- V-JEPA: world model for video frame prediction
|
||||
|
||||
Input: List[dict] native format (same as original starVLA)
|
||||
- "image": List[PIL.Image] (multi-view images)
|
||||
- "video": np.ndarray [V, T, H, W, 3]
|
||||
- "lang": str (task instruction)
|
||||
- "action": np.ndarray [T, action_dim] (optional, training only)
|
||||
- "state": np.ndarray [1, state_dim] (optional)
|
||||
"""
|
||||
|
||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||
super().__init__()
|
||||
require_package("transformers", extra="vla_jepa")
|
||||
self.config = config
|
||||
|
||||
# Vision-language backbone
|
||||
self.qwen = Qwen3VLInterface(config)
|
||||
|
||||
# Tokenizer expansion for special action tokens
|
||||
self.action_tokens, self.action_token_ids, self.embodied_action_token_id = (
|
||||
self.qwen.expand_tokenizer()
|
||||
)
|
||||
|
||||
# Action head (flow-matching DiT)
|
||||
self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size)
|
||||
|
||||
# JEPA world model components
|
||||
if config.enable_world_model:
|
||||
self.video_encoder = AutoModel.from_pretrained(
|
||||
config.jepa_encoder_name,
|
||||
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
|
||||
)
|
||||
self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name)
|
||||
num_views = config.jepa_tubelet_size
|
||||
tubelet_size = self.video_encoder.config.tubelet_size
|
||||
image_size = getattr(self.video_encoder.config, "image_size", None)
|
||||
if image_size is None:
|
||||
first_image_shape = next(iter(config.image_features.values())).shape
|
||||
image_size = first_image_shape[-1]
|
||||
self.video_predictor = ActionConditionedVideoPredictor(
|
||||
num_frames=config.num_video_frames // tubelet_size,
|
||||
img_size=(image_size, image_size),
|
||||
patch_size=16,
|
||||
tubelet_size=1,
|
||||
embed_dim=self.video_encoder.config.hidden_size * num_views,
|
||||
action_embed_dim=self.qwen.model.config.hidden_size,
|
||||
predictor_embed_dim=self.video_encoder.config.hidden_size,
|
||||
depth=config.predictor_depth,
|
||||
num_heads=config.predictor_num_heads,
|
||||
mlp_ratio=config.predictor_mlp_ratio,
|
||||
num_action_tokens_per_step=config.num_action_tokens_per_timestep,
|
||||
)
|
||||
else:
|
||||
self.video_encoder = None
|
||||
self.video_processor = None
|
||||
self.video_predictor = None
|
||||
|
||||
if config.freeze_qwen:
|
||||
self.qwen.requires_grad_(False)
|
||||
|
||||
# Build prompt placeholders.
|
||||
# Use the encoder's actual tubelet_size when available (world model enabled),
|
||||
# otherwise fall back to config.
|
||||
_tubelet_size = (
|
||||
self.video_encoder.config.tubelet_size
|
||||
if config.enable_world_model
|
||||
else self.config.jepa_tubelet_size
|
||||
)
|
||||
num_action_prompt_steps = self.config.num_video_frames // _tubelet_size - 1
|
||||
self.replace_prompt = "".join(
|
||||
token * self.config.num_action_tokens_per_timestep
|
||||
for token in self.action_tokens[:num_action_prompt_steps]
|
||||
)
|
||||
self.embodied_replace_prompt = (
|
||||
self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
|
||||
)
|
||||
|
||||
def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Return the last decoder hidden state before the final RMSNorm.
|
||||
|
||||
The model was trained with the output of the last transformer block BEFORE
|
||||
the final RMSNorm. In transformers 5.x, `hidden_states[-1]` from
|
||||
`output_hidden_states=True` is post-norm (tied to `last_hidden_state` via
|
||||
`@capture_outputs`). A forward hook on `language_model.layers[-1]` recovers
|
||||
the correct pre-RMSNorm state, matching the training-time representation.
|
||||
"""
|
||||
captured: list[torch.Tensor] = []
|
||||
|
||||
def _hook(module, input, output):
|
||||
h = output[0] if isinstance(output, tuple) else output
|
||||
captured.append(h)
|
||||
|
||||
last_layer = self.qwen.model.model.language_model.layers[-1]
|
||||
handle = last_layer.register_forward_hook(_hook)
|
||||
try:
|
||||
self.qwen.model(
|
||||
**qwen_inputs,
|
||||
output_hidden_states=False,
|
||||
output_attentions=False,
|
||||
return_dict=True,
|
||||
)
|
||||
finally:
|
||||
handle.remove()
|
||||
|
||||
return captured[0] # [B, seq_len, H]
|
||||
|
||||
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
|
||||
|
||||
def forward(self, examples: list[dict]) -> dict[str, Tensor]:
|
||||
"""
|
||||
Native forward pass following original starVLA VLA_JEPA.forward.
|
||||
|
||||
Args:
|
||||
examples: List of per-sample dicts with keys:
|
||||
"image" : List[PIL.Image] — multi-view images
|
||||
"video" : np.ndarray [V, T, H, W, 3]
|
||||
"lang" : str — task instruction
|
||||
"action" : np.ndarray [T, action_dim] (optional)
|
||||
"state" : np.ndarray [1, state_dim] (optional)
|
||||
|
||||
Returns:
|
||||
dict with "action_loss" and "wm_loss" keys (scalar Tensors).
|
||||
"""
|
||||
# Unpack native format (same pattern as original VLA_JEPA.py)
|
||||
batch_images = [ex["image"] for ex in examples] # List[List[PIL.Image]]
|
||||
batch_videos = [ex["video"] for ex in examples] # List[np.ndarray]
|
||||
instructions = [ex["lang"] for ex in examples] # List[str]
|
||||
has_action = "action" in examples[0] and examples[0]["action"] is not None
|
||||
actions = [ex["action"] for ex in examples] if has_action else None
|
||||
has_state = "state" in examples[0] and examples[0]["state"] is not None
|
||||
state = [ex["state"] for ex in examples] if has_state else None
|
||||
action_is_pad = (
|
||||
[ex["action_is_pad"] for ex in examples]
|
||||
if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
|
||||
batch_videos = np.stack(batch_videos)
|
||||
batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4) # [B, V, T, 3, H, W]
|
||||
|
||||
# Adjust number of views for the world model:
|
||||
# - fewer views than expected: duplicate the first view to fill up
|
||||
# - more views than expected: keep only the first num_views_world_model views
|
||||
num_views_world_model = self.config.jepa_tubelet_size
|
||||
if batch_videos.shape[1] < num_views_world_model:
|
||||
num_missing_views = num_views_world_model - batch_videos.shape[1]
|
||||
first_view = np.repeat(batch_videos[:, :1], num_missing_views, axis=1)
|
||||
batch_videos = np.concatenate([batch_videos, first_view], axis=1)
|
||||
elif batch_videos.shape[1] > num_views_world_model:
|
||||
batch_videos = batch_videos[:, :num_views_world_model]
|
||||
|
||||
# ---- Step 1: QwenVL encode (same as original) ----
|
||||
qwen_inputs = self.qwen.build_inputs(
|
||||
images=batch_images,
|
||||
instructions=instructions,
|
||||
action_prompt=self.replace_prompt,
|
||||
embodied_prompt=self.embodied_replace_prompt,
|
||||
)
|
||||
|
||||
# Locate embodied-action tokens (always needed for action head)
|
||||
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
||||
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
||||
|
||||
# Locate action tokens (only needed for world model predictor)
|
||||
if self.config.enable_world_model:
|
||||
action_mask = torch.isin(
|
||||
qwen_inputs["input_ids"],
|
||||
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
|
||||
)
|
||||
action_indices = action_mask.nonzero(as_tuple=True)
|
||||
|
||||
device_type = next(self.parameters()).device.type
|
||||
|
||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
||||
b, _, h = last_hidden.shape
|
||||
|
||||
if self.config.enable_world_model:
|
||||
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
|
||||
|
||||
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
||||
|
||||
# ---- Step 2+3: JEPA Encoder + Predictor ----
|
||||
device_wm = last_hidden.device
|
||||
if not self.config.enable_world_model:
|
||||
wm_loss = torch.tensor(0.0, device=device_wm)
|
||||
else:
|
||||
b, v, t_frames, c, h_img, w_img = batch_videos.shape
|
||||
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
|
||||
|
||||
video_pixels = self.video_processor(videos=list(batch_videos_flat), return_tensors="pt")[
|
||||
"pixel_values_videos"
|
||||
].to(self.video_encoder.device) # [B*V, T, C, H, W]
|
||||
|
||||
with torch.no_grad():
|
||||
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
|
||||
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
|
||||
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
|
||||
|
||||
tubelet_size = self.video_encoder.config.tubelet_size
|
||||
device_wm = video_embeddings.device
|
||||
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
|
||||
t_enc_total = self.config.num_video_frames // tubelet_size
|
||||
|
||||
if t_enc_total < 2:
|
||||
wm_loss = torch.tensor(0.0, device=device_wm)
|
||||
else:
|
||||
# Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232):
|
||||
# input_states: positions 0..T-2, gt_states: positions 1..T-1
|
||||
t_enc_ctx = t_enc_total - 1
|
||||
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
|
||||
|
||||
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
|
||||
gt_states = video_embeddings[:, tokens_per_frame:, :]
|
||||
|
||||
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
|
||||
if action_tokens.shape[1] < expected_actions:
|
||||
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
|
||||
action_tokens = torch.cat([action_tokens, pad], dim=1)
|
||||
|
||||
predicted_states = self.video_predictor(
|
||||
input_states.float(),
|
||||
action_tokens[:, :expected_actions].float(),
|
||||
)
|
||||
|
||||
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
|
||||
|
||||
if not has_action:
|
||||
return {"wm_loss": wm_loss}
|
||||
|
||||
# ---- Step 4: Action Head ----
|
||||
with torch.autocast(device_type=device_type, dtype=torch.float32):
|
||||
actions_tensor = torch.tensor(
|
||||
np.array(actions), device=last_hidden.device, dtype=torch.float32
|
||||
) # [B, T_full, action_dim]
|
||||
action_horizon = self.config.chunk_size
|
||||
actions_target = actions_tensor[:, -action_horizon:, :]
|
||||
|
||||
state_tensor = None
|
||||
if state is not None:
|
||||
state_tensor = torch.tensor(
|
||||
np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
|
||||
) # [B, 1, state_dim]
|
||||
|
||||
repeated_diffusion_steps = self.config.repeated_diffusion_steps
|
||||
actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1)
|
||||
embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1)
|
||||
if state_tensor is not None:
|
||||
state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1)
|
||||
|
||||
action_is_pad_rep = None
|
||||
if action_is_pad is not None:
|
||||
pad_tensor = torch.stack(
|
||||
[
|
||||
p.to(actions_target.device)
|
||||
if isinstance(p, Tensor)
|
||||
else torch.tensor(p, device=actions_target.device)
|
||||
for p in action_is_pad
|
||||
]
|
||||
) # [B, T_full]
|
||||
pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon]
|
||||
action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon]
|
||||
|
||||
action_loss = self.action_model(
|
||||
embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep
|
||||
)
|
||||
|
||||
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
|
||||
|
||||
# ---- Native predict_action (follows original VLA_JEPA.predict_action) ----
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action(
|
||||
self,
|
||||
batch_images: list[list[Image.Image]],
|
||||
instructions: list[str],
|
||||
state: np.ndarray | None = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Native action prediction following original VLA_JEPA.predict_action.
|
||||
|
||||
Args:
|
||||
batch_images: List of samples; each is List[PIL.Image] (multi-view).
|
||||
instructions: Task instructions, one per sample.
|
||||
state: Optional [B, state_dim] numpy array.
|
||||
|
||||
Returns:
|
||||
np.ndarray [B, action_horizon, action_dim] — predicted actions.
|
||||
"""
|
||||
if self.config.resize_images_to is not None:
|
||||
height, width = self.config.resize_images_to
|
||||
resampling = getattr(Image, "Resampling", Image).BOX
|
||||
batch_images = [
|
||||
[image.resize((width, height), resample=resampling) for image in sample_images]
|
||||
for sample_images in batch_images
|
||||
]
|
||||
|
||||
qwen_inputs = self.qwen.build_inputs(
|
||||
images=batch_images,
|
||||
instructions=instructions,
|
||||
action_prompt=self.replace_prompt,
|
||||
embodied_prompt=self.embodied_replace_prompt,
|
||||
)
|
||||
|
||||
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
||||
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
||||
|
||||
device_type = next(self.parameters()).device.type
|
||||
|
||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
||||
b, _, h = last_hidden.shape
|
||||
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
||||
|
||||
state_tensor = None
|
||||
if state is not None:
|
||||
state_tensor = torch.from_numpy(np.array(state)).to(
|
||||
device=last_hidden.device, dtype=last_hidden.dtype
|
||||
)
|
||||
|
||||
pred_actions = self.action_model.predict_action(
|
||||
embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None
|
||||
) # [B, action_horizon, action_dim]
|
||||
|
||||
return pred_actions.detach().cpu().numpy()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class VLAJEPAPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
LeRobot adapter for VLA-JEPA.
|
||||
|
||||
Converts LeRobot's standard batch format (dict[str, Tensor]) to the native
|
||||
VLA-JEPA format (List[dict]), calls the native model, and converts outputs
|
||||
back to LeRobot format.
|
||||
"""
|
||||
|
||||
config_class = VLAJEPAConfig
|
||||
name = "vla_jepa"
|
||||
|
||||
def __init__(self, config: VLAJEPAConfig, **kwargs) -> None:
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
if dataset_meta := kwargs.get("dataset_meta"):
|
||||
# cfg.input_features keeps the pretrained model's feature keys (needed for rename_map
|
||||
# compatibility), so validate_features() may have read stale dims from a pretrained
|
||||
# config. Override state_dim/action_dim from the actual dataset being used.
|
||||
ds_features = dataset_meta.features
|
||||
if OBS_STATE in ds_features:
|
||||
config.state_dim = ds_features[OBS_STATE]["shape"][0]
|
||||
if ACTION in ds_features:
|
||||
config.action_dim = ds_features[ACTION]["shape"][0]
|
||||
|
||||
self.model = VLAJEPAModel(config)
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
self._queues = {ACTION: deque(maxlen=self.config.n_action_steps)}
|
||||
|
||||
# ---- Format Conversion: LeRobot → Native ----
|
||||
|
||||
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> list[dict]:
|
||||
"""
|
||||
Convert LeRobot batch format to native VLA-JEPA examples format.
|
||||
|
||||
LeRobot format:
|
||||
batch = {
|
||||
"observation.images.<key>": Tensor [B, C, H, W] or [B, T, C, H, W],
|
||||
"observation.state": Tensor [B, state_dim] or [B, T, state_dim],
|
||||
"action": Tensor [B, chunk_size, action_dim], (training only)
|
||||
"task": str | List[str], (optional instruction)
|
||||
}
|
||||
|
||||
Native format (List[dict]):
|
||||
{
|
||||
"image": List[PIL.Image], # multi-view images per sample
|
||||
"video": np.ndarray [V, T, H, W, 3],
|
||||
"lang": str, # task instruction
|
||||
"action": np.ndarray [T, action_dim], # optional
|
||||
"state": np.ndarray [1, state_dim], # optional
|
||||
}
|
||||
"""
|
||||
# Determine batch size from the first image feature
|
||||
image_keys = list(self.config.image_features.keys())
|
||||
if not image_keys:
|
||||
raise ValueError("VLAJEPA requires at least one image feature.")
|
||||
first_key = image_keys[0]
|
||||
first_tensor = batch[first_key]
|
||||
batch_size = first_tensor.shape[0]
|
||||
|
||||
# ---- Collect images per sample ----
|
||||
# images_per_sample[b][v] = PIL.Image for view v
|
||||
images_per_sample: list[list[Image.Image]] = [[] for _ in range(batch_size)]
|
||||
for key in image_keys:
|
||||
tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W]
|
||||
if tensor.ndim == 5:
|
||||
# observation_delta_indices = [0, 1, ..., num_video_frames-1]
|
||||
# index 0 is the current observation (delta=0)
|
||||
tensor = tensor[:, 0]
|
||||
for b in range(batch_size):
|
||||
images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))
|
||||
|
||||
# ---- Collect videos per sample ----
|
||||
# Build video arrays: for each sample, stack views as [V, T, H, W, 3]
|
||||
# Check whether any image feature has a time dimension
|
||||
video_source = None
|
||||
for k in image_keys:
|
||||
if k in batch:
|
||||
video_source = batch[k] # Use first available for shape inspection
|
||||
break
|
||||
|
||||
if video_source is None:
|
||||
raise ValueError("No image data found in batch for video construction.")
|
||||
|
||||
videos_per_sample = []
|
||||
for b in range(batch_size):
|
||||
sample_views = []
|
||||
for k in image_keys:
|
||||
t = batch[k][b] # [C, H, W] or [T, C, H, W]
|
||||
if t.ndim == 3:
|
||||
t = t.unsqueeze(0) # [1, C, H, W]
|
||||
# Convert to [T, H, W, 3] numpy
|
||||
t_np = t.permute(0, 2, 3, 1).detach().cpu().float().numpy()
|
||||
# Clamp to [0, 255]
|
||||
if t_np.max() <= 1.0:
|
||||
t_np = t_np * 255.0
|
||||
t_np = np.rint(t_np.clip(0, 255)).astype(np.uint8)
|
||||
sample_views.append(t_np)
|
||||
# Stack views: [V, T, H, W, 3]
|
||||
videos_per_sample.append(np.stack(sample_views, axis=0))
|
||||
|
||||
# ---- Collect instructions ----
|
||||
tasks = batch.get("task")
|
||||
if tasks is None:
|
||||
instructions = ["Execute the robot action."] * batch_size
|
||||
elif isinstance(tasks, str):
|
||||
instructions = [tasks] * batch_size
|
||||
else:
|
||||
instructions = list(tasks)
|
||||
|
||||
# ---- Collect actions (training only) ----
|
||||
actions_list = None
|
||||
action_is_pad_list = None
|
||||
actions_tensor = batch.get(ACTION)
|
||||
if actions_tensor is not None:
|
||||
if actions_tensor.ndim == 2:
|
||||
actions_tensor = actions_tensor.unsqueeze(1)
|
||||
actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
|
||||
action_is_pad_tensor = batch.get("action_is_pad")
|
||||
if action_is_pad_tensor is not None:
|
||||
action_is_pad_list = [action_is_pad_tensor[b].detach().cpu() for b in range(batch_size)]
|
||||
|
||||
# ---- Collect state ----
|
||||
state_list = None
|
||||
state_tensor = batch.get(OBS_STATE)
|
||||
if state_tensor is not None:
|
||||
if state_tensor.ndim > 2:
|
||||
state_tensor = state_tensor[:, -1, :]
|
||||
if state_tensor.ndim == 2:
|
||||
state_tensor = state_tensor.unsqueeze(1) # [B, 1, state_dim]
|
||||
state_list = [state_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
|
||||
|
||||
# ---- Assemble native examples ----
|
||||
examples = []
|
||||
for b in range(batch_size):
|
||||
example = {
|
||||
"image": images_per_sample[b],
|
||||
"video": videos_per_sample[b],
|
||||
"lang": instructions[b],
|
||||
}
|
||||
if actions_list is not None:
|
||||
example["action"] = actions_list[b]
|
||||
if action_is_pad_list is not None:
|
||||
example["action_is_pad"] = action_is_pad_list[b]
|
||||
if state_list is not None:
|
||||
example["state"] = state_list[b]
|
||||
examples.append(example)
|
||||
|
||||
return examples
|
||||
|
||||
# ---- LeRobot Policy Interface ----
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""LeRobot train forward: convert → native forward → aggregate losses."""
|
||||
examples = self._prepare_model_inputs(batch)
|
||||
native_output = self.model.forward(examples)
|
||||
|
||||
ref = next(iter(native_output.values()))
|
||||
zero = torch.zeros((), device=ref.device, dtype=ref.dtype)
|
||||
total_loss = native_output.get("action_loss", zero) + native_output.get("wm_loss", zero)
|
||||
logs = {k: v.detach().item() for k, v in native_output.items()}
|
||||
logs["loss"] = total_loss.detach().item()
|
||||
return total_loss, logs
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.model.parameters()
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""LeRobot inference: convert → native predict → return as Tensor."""
|
||||
self.eval()
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
|
||||
examples = self._prepare_model_inputs(batch)
|
||||
batch_images = [ex["image"] for ex in examples]
|
||||
instructions = [ex["lang"] for ex in examples]
|
||||
|
||||
state_np = None
|
||||
if "state" in examples[0] and examples[0]["state"] is not None:
|
||||
state_np = np.stack([ex["state"] for ex in examples])
|
||||
|
||||
actions_np = self.model.predict_action(batch_images, instructions, state_np)
|
||||
return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""LeRobot select_action with action queue caching."""
|
||||
self.eval()
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
if len(self._queues[ACTION]) == 0:
|
||||
actions = self.predict_action_chunk(batch)
|
||||
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
|
||||
return self._queues[ACTION].popleft()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
**kwargs,
|
||||
):
|
||||
return super().from_pretrained(pretrained_name_or_path, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
||||
reinit_prefixes = model.config.reinit_modules
|
||||
if not reinit_prefixes:
|
||||
return super()._load_as_safetensor(model, model_file, map_location, strict)
|
||||
|
||||
from safetensors.torch import load_file
|
||||
|
||||
state_dict = load_file(model_file, device=map_location)
|
||||
current = model.state_dict()
|
||||
|
||||
reinitialized: list[str] = []
|
||||
filtered: dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if key in current and value.shape != current[key].shape:
|
||||
if not any(key.startswith(p) for p in reinit_prefixes):
|
||||
raise ValueError(
|
||||
f"Shape mismatch for '{key}' (checkpoint {tuple(value.shape)} vs model "
|
||||
f"{tuple(current[key].shape)}) and its prefix is not in `reinit_modules`."
|
||||
)
|
||||
reinitialized.append(
|
||||
f"{key}: checkpoint {tuple(value.shape)} → model {tuple(current[key].shape)}"
|
||||
)
|
||||
else:
|
||||
filtered[key] = value
|
||||
|
||||
if reinitialized:
|
||||
logging.warning(
|
||||
f"reinit_modules: skipping {len(reinitialized)} tensor(s) with mismatched shapes "
|
||||
f"(randomly re-initialised):\n " + "\n ".join(reinitialized)
|
||||
)
|
||||
|
||||
from lerobot.policies.utils import log_model_loading_keys
|
||||
|
||||
missing_keys, unexpected_keys = model.load_state_dict(filtered, strict=False)
|
||||
log_model_loading_keys(missing_keys, unexpected_keys)
|
||||
return model
|
||||
@@ -1,155 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
EnvTransition,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_clip_actions")
|
||||
class ClipActionsProcessorStep(ProcessorStep):
|
||||
"""Clips action tensor to [-1, 1] before unnormalization."""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None:
|
||||
transition = dict(transition)
|
||||
transition[TransitionKey.ACTION] = action.clamp(-1.0, 1.0)
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_pre_snap_gripper")
|
||||
class PreSnapGripperProcessorStep(ProcessorStep):
|
||||
"""Snaps a gripper dimension to {0, 1} BEFORE unnormalization.
|
||||
|
||||
Mirrors the original starVLA LIBERO eval:
|
||||
normalized[:, gripper_dim] = np.where(normalized[:, gripper_dim] < threshold, 0, 1)
|
||||
This ensures the unnormalizer receives an exact binary value, which is
|
||||
required when the model was trained with gripper in identity (mask=False)
|
||||
space where 0=open and 1=close.
|
||||
"""
|
||||
|
||||
def __init__(self, gripper_dim: int = 6, threshold: float = 0.5):
|
||||
self.gripper_dim = gripper_dim
|
||||
self.threshold = threshold
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and action.shape[-1] > self.gripper_dim:
|
||||
transition = dict(transition)
|
||||
a = action.clone()
|
||||
a[..., self.gripper_dim] = (a[..., self.gripper_dim] >= self.threshold).float()
|
||||
transition[TransitionKey.ACTION] = a
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_binarize_gripper")
|
||||
class BinarizeGripperProcessorStep(ProcessorStep):
|
||||
"""Binarizes a gripper dimension after unnormalization.
|
||||
|
||||
Maps continuous value to {-1, 1}: > threshold → -1, <= threshold → 1 (matches starVLA convention).
|
||||
Only applied when action has more dimensions than gripper_dim.
|
||||
"""
|
||||
|
||||
def __init__(self, gripper_dim: int = 6, threshold: float = 0.5):
|
||||
self.gripper_dim = gripper_dim
|
||||
self.threshold = threshold
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and action.shape[-1] > self.gripper_dim:
|
||||
transition = dict(transition)
|
||||
a = action.clone()
|
||||
a[..., self.gripper_dim] = 1.0 - 2.0 * (a[..., self.gripper_dim] > self.threshold).float()
|
||||
transition[TransitionKey.ACTION] = a
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
def make_vla_jepa_pre_post_processors(
|
||||
config: VLAJEPAConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
features = {**config.input_features, **config.output_features}
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
]
|
||||
output_steps: list[ProcessorStep] = []
|
||||
if config.clip_normalized_actions:
|
||||
output_steps.append(ClipActionsProcessorStep())
|
||||
if config.pre_snap_gripper_action:
|
||||
output_steps.append(
|
||||
PreSnapGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold)
|
||||
)
|
||||
output_steps.append(
|
||||
UnnormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
)
|
||||
)
|
||||
if config.binarize_gripper_action:
|
||||
output_steps.append(
|
||||
BinarizeGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold)
|
||||
)
|
||||
output_steps.append(DeviceProcessorStep(device="cpu"))
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -1,117 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
|
||||
else:
|
||||
AutoProcessor = None
|
||||
Qwen3VLForConditionalGeneration = None
|
||||
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
|
||||
|
||||
class Qwen3VLInterface(torch.nn.Module):
|
||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
|
||||
config.qwen_model_name,
|
||||
torch_dtype=self._get_torch_dtype(config.torch_dtype),
|
||||
)
|
||||
self.processor = AutoProcessor.from_pretrained(config.qwen_model_name)
|
||||
self.processor.tokenizer.padding_side = config.tokenizer_padding_side
|
||||
self.model.config.hidden_size = self.model.config.text_config.hidden_size
|
||||
|
||||
@staticmethod
|
||||
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
|
||||
if dtype_name == "float32":
|
||||
return torch.float32
|
||||
if dtype_name == "float16":
|
||||
return torch.float16
|
||||
return torch.bfloat16
|
||||
|
||||
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
|
||||
# starVLA/JEVLA checkpoints expand action tokens as action_horizon * 4,
|
||||
# independent of vj2 num_action_tokens_per_timestep. Keeping this count
|
||||
# is required for Qwen embedding/lm_head checkpoint shapes to match.
|
||||
max_action_tokens = self.config.chunk_size * 4
|
||||
tokenizer = self.processor.tokenizer
|
||||
action_tokens = []
|
||||
action_token_ids = []
|
||||
for idx in range(max_action_tokens):
|
||||
token = self.config.special_action_token.format(idx)
|
||||
action_tokens.append(token)
|
||||
if token not in tokenizer.get_vocab():
|
||||
tokenizer.add_tokens([token], special_tokens=True)
|
||||
action_token_ids.append(tokenizer.convert_tokens_to_ids(token))
|
||||
|
||||
embodied_action_token = self.config.embodied_action_token
|
||||
if embodied_action_token not in tokenizer.get_vocab():
|
||||
tokenizer.add_tokens([embodied_action_token], special_tokens=True)
|
||||
embodied_action_token_id = tokenizer.convert_tokens_to_ids(embodied_action_token)
|
||||
|
||||
if self.model.get_input_embeddings().weight.size(0) < len(tokenizer):
|
||||
self.model.resize_token_embeddings(len(tokenizer))
|
||||
return action_tokens, action_token_ids, embodied_action_token_id
|
||||
|
||||
def build_inputs(
|
||||
self,
|
||||
images: Sequence[Sequence[Image.Image]],
|
||||
instructions: Sequence[str],
|
||||
action_prompt: str,
|
||||
embodied_prompt: str,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
messages = []
|
||||
for sample_images, instruction in zip(images, instructions, strict=True):
|
||||
prompt = self.config.prompt_template.format(
|
||||
instruction=instruction,
|
||||
actions=action_prompt,
|
||||
e_actions=embodied_prompt,
|
||||
)
|
||||
content = [{"type": "image", "image": img} for img in sample_images]
|
||||
content.append({"type": "text", "text": prompt})
|
||||
messages.append([{"role": "user", "content": content}])
|
||||
|
||||
batch_inputs = self.processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
processor_kwargs={"padding": True, "return_tensors": "pt"},
|
||||
)
|
||||
return batch_inputs.to(self.model.device)
|
||||
|
||||
@staticmethod
|
||||
def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
|
||||
image = image_tensor.detach().cpu()
|
||||
if image.ndim == 3 and image.shape[0] in (1, 3):
|
||||
image = image.permute(1, 2, 0)
|
||||
image = image.float()
|
||||
if image.max() <= 1.0:
|
||||
image = image * 255.0
|
||||
image = image.clamp(0, 255).round().to(torch.uint8).numpy()
|
||||
if image.shape[-1] == 1:
|
||||
image = np.repeat(image, 3, axis=-1)
|
||||
return Image.fromarray(image)
|
||||
@@ -1,418 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import nn
|
||||
|
||||
|
||||
def build_action_block_causal_attention_mask(
|
||||
num_frames: int, grid_height: int, grid_width: int, add_tokens: int = 1
|
||||
) -> torch.Tensor:
|
||||
tokens_per_frame = add_tokens + grid_height * grid_width
|
||||
num_tokens = num_frames * tokens_per_frame
|
||||
mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool)
|
||||
mask_block = torch.ones(tokens_per_frame, tokens_per_frame, dtype=torch.bool)
|
||||
local_window_time = num_frames
|
||||
|
||||
for current_frame in range(num_frames):
|
||||
first_context_frame = max(0, current_frame - local_window_time + 1)
|
||||
for context_frame in range(first_context_frame, current_frame + 1):
|
||||
row = slice(current_frame * tokens_per_frame, (current_frame + 1) * tokens_per_frame)
|
||||
col = slice(context_frame * tokens_per_frame, (context_frame + 1) * tokens_per_frame)
|
||||
mask[row, col] = mask_block
|
||||
return mask
|
||||
|
||||
|
||||
def rotate_queries_or_keys(x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
|
||||
_, _, _, dim = x.size()
|
||||
if dim % 2 != 0:
|
||||
raise ValueError("Embedding dimension must be even for rotary position encoding.")
|
||||
|
||||
omega = torch.arange(dim // 2, dtype=x.dtype, device=x.device)
|
||||
omega /= dim / 2.0
|
||||
omega = 1.0 / 10000**omega
|
||||
freqs = torch.einsum("..., f -> ... f", pos, omega)
|
||||
emb_sin = freqs.sin().squeeze(-1).repeat(1, 1, 1, 2)
|
||||
emb_cos = freqs.cos().squeeze(-1).repeat(1, 1, 1, 2)
|
||||
|
||||
y = x.unflatten(-1, (-1, 2))
|
||||
y1, y2 = y.unbind(dim=-1)
|
||||
y = torch.stack((-y2, y1), dim=-1).flatten(-2)
|
||||
return x * emb_cos + y * emb_sin
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
def __init__(self, drop_prob: float = 0.0) -> None:
|
||||
super().__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.drop_prob == 0.0 or not self.training:
|
||||
return x
|
||||
keep_prob = 1 - self.drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
|
||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
||||
random_tensor.floor_()
|
||||
return x.div(keep_prob) * random_tensor
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: int | None = None,
|
||||
out_features: int | None = None,
|
||||
act_layer: type[nn.Module] = nn.GELU,
|
||||
drop: float = 0.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class ACRoPEAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_scale: float | None = None,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
use_sdpa: bool = True,
|
||||
is_causal: bool = False,
|
||||
grid_size: int = 16,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = qk_scale or self.head_dim**-0.5
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop_prob = proj_drop
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
self.use_sdpa = use_sdpa
|
||||
self.d_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||
self.h_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||
self.w_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||
self.grid_size = grid_size
|
||||
self.is_causal = is_causal
|
||||
|
||||
@staticmethod
|
||||
def _get_frame_pos(ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
return ids // int(height * width)
|
||||
|
||||
def _get_height_pos(self, ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
frame_ids = self._get_frame_pos(ids, height, width)
|
||||
ids = ids - int(height * width) * frame_ids
|
||||
return ids // width
|
||||
|
||||
def separate_positions(
|
||||
self, ids: torch.Tensor, height: int, width: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
frame_ids = self._get_frame_pos(ids, height, width)
|
||||
height_ids = self._get_height_pos(ids, height, width)
|
||||
width_ids = ids - int(height * width) * frame_ids - width * height_ids
|
||||
return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: torch.Tensor | None = None,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
num_frames: int | None = None,
|
||||
grid_height: int | None = None,
|
||||
grid_width: int | None = None,
|
||||
action_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
batch_size, num_tokens, channels = x.size()
|
||||
if num_frames is None or grid_height is None or grid_width is None:
|
||||
raise ValueError("num_frames, grid_height and grid_width are required.")
|
||||
|
||||
if mask is not None:
|
||||
mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1)
|
||||
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
|
||||
else:
|
||||
mask = torch.arange(int(num_frames * grid_height * grid_width), device=x.device)
|
||||
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
|
||||
|
||||
h_mask *= self.grid_size / grid_height
|
||||
w_mask *= self.grid_size / grid_width
|
||||
|
||||
if action_tokens > 0:
|
||||
x = x.view(batch_size, -1, action_tokens + grid_height * grid_width, channels)
|
||||
action_q, action_k, action_v = [], [], []
|
||||
for idx in range(action_tokens):
|
||||
action_token = x[:, :, idx : idx + 1, :].flatten(1, 2)
|
||||
qkv = self.qkv(action_token).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
qd = rotate_queries_or_keys(
|
||||
q[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
|
||||
)
|
||||
kd = rotate_queries_or_keys(
|
||||
k[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
|
||||
)
|
||||
qr = q[..., self.d_dim :]
|
||||
kr = k[..., self.d_dim :]
|
||||
action_q.append(
|
||||
torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
|
||||
)
|
||||
action_k.append(
|
||||
torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
|
||||
)
|
||||
action_v.append(v.view(batch_size, self.num_heads, num_frames, 1, -1))
|
||||
|
||||
action_q = torch.cat(action_q, dim=3).flatten(2, 3)
|
||||
action_k = torch.cat(action_k, dim=3).flatten(2, 3)
|
||||
action_v = torch.cat(action_v, dim=3).flatten(2, 3)
|
||||
x = x[:, :, action_tokens:, :].flatten(1, 2)
|
||||
|
||||
qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
offset = 0
|
||||
qd = rotate_queries_or_keys(q[..., offset : offset + self.d_dim], pos=d_mask)
|
||||
kd = rotate_queries_or_keys(k[..., offset : offset + self.d_dim], pos=d_mask)
|
||||
offset += self.d_dim
|
||||
qh = rotate_queries_or_keys(q[..., offset : offset + self.h_dim], pos=h_mask)
|
||||
kh = rotate_queries_or_keys(k[..., offset : offset + self.h_dim], pos=h_mask)
|
||||
offset += self.h_dim
|
||||
qw = rotate_queries_or_keys(q[..., offset : offset + self.w_dim], pos=w_mask)
|
||||
kw = rotate_queries_or_keys(k[..., offset : offset + self.w_dim], pos=w_mask)
|
||||
offset += self.w_dim
|
||||
|
||||
if offset < self.head_dim:
|
||||
q = torch.cat([qd, qh, qw, q[..., offset:]], dim=-1)
|
||||
k = torch.cat([kd, kh, kw, k[..., offset:]], dim=-1)
|
||||
else:
|
||||
q = torch.cat([qd, qh, qw], dim=-1)
|
||||
k = torch.cat([kd, kh, kw], dim=-1)
|
||||
|
||||
if action_tokens > 0:
|
||||
|
||||
def merge(frame_tokens: torch.Tensor, action_token_values: torch.Tensor) -> torch.Tensor:
|
||||
frame_tokens = frame_tokens.view(
|
||||
batch_size, self.num_heads, num_frames, grid_height * grid_width, -1
|
||||
)
|
||||
action_token_values = action_token_values.view(
|
||||
batch_size, self.num_heads, num_frames, action_tokens, -1
|
||||
)
|
||||
return torch.cat([action_token_values, frame_tokens], dim=3).flatten(2, 3)
|
||||
|
||||
q = merge(q, action_q)
|
||||
k = merge(k, action_k)
|
||||
v = merge(v, action_v)
|
||||
|
||||
if attn_mask is not None or self.use_sdpa:
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask
|
||||
)
|
||||
else:
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = attn @ v
|
||||
|
||||
x = x.transpose(1, 2).reshape(batch_size, num_tokens, channels)
|
||||
x = self.proj(x)
|
||||
return self.proj_drop(x)
|
||||
|
||||
|
||||
class ACBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qkv_bias: bool = True,
|
||||
qk_scale: float | None = None,
|
||||
drop: float = 0.0,
|
||||
attn_drop: float = 0.0,
|
||||
drop_path: float = 0.0,
|
||||
norm_layer: type[nn.Module] = nn.LayerNorm,
|
||||
use_sdpa: bool = True,
|
||||
is_causal: bool = False,
|
||||
grid_size: int = 16,
|
||||
use_rope: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
if not use_rope:
|
||||
raise ValueError("JEVLA1 world predictor uses AC RoPE attention.")
|
||||
self.attn = ACRoPEAttention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
use_sdpa=use_sdpa,
|
||||
is_causal=is_causal,
|
||||
grid_size=grid_size,
|
||||
proj_drop=drop,
|
||||
)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.mlp = MLP(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=nn.GELU,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
num_frames: int | None = None,
|
||||
grid_height: int | None = None,
|
||||
grid_width: int | None = None,
|
||||
action_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
y = self.norm1(x)
|
||||
y = self.attn(
|
||||
y,
|
||||
mask=None,
|
||||
attn_mask=attn_mask,
|
||||
num_frames=num_frames,
|
||||
grid_height=grid_height,
|
||||
grid_width=grid_width,
|
||||
action_tokens=action_tokens,
|
||||
)
|
||||
x = x + self.drop_path(y)
|
||||
y = self.norm2(x)
|
||||
return x + self.drop_path(self.mlp(y))
|
||||
|
||||
|
||||
class ActionConditionedVideoPredictor(nn.Module):
|
||||
"""JEVLA1-compatible action-conditioned V-JEPA predictor."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_frames: int,
|
||||
img_size: tuple[int, int],
|
||||
patch_size: int,
|
||||
tubelet_size: int,
|
||||
embed_dim: int,
|
||||
action_embed_dim: int,
|
||||
predictor_embed_dim: int,
|
||||
depth: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float,
|
||||
num_action_tokens_per_step: int,
|
||||
use_extrinsics: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_frame_causal = True
|
||||
self.use_extrinsics = use_extrinsics
|
||||
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
|
||||
self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
|
||||
self.state_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
|
||||
self.extrinsics_encoder = nn.Linear(action_embed_dim - 1, predictor_embed_dim, bias=True)
|
||||
|
||||
self.img_height, self.img_width = img_size
|
||||
self.patch_size = patch_size
|
||||
self.num_frames = num_frames
|
||||
self.tubelet_size = tubelet_size
|
||||
self.grid_height = self.img_height // self.patch_size
|
||||
self.grid_width = self.img_width // self.patch_size
|
||||
|
||||
self.predictor_blocks = nn.ModuleList(
|
||||
[
|
||||
ACBlock(
|
||||
dim=predictor_embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=True,
|
||||
drop=0.0,
|
||||
attn_drop=0.0,
|
||||
drop_path=0.0,
|
||||
norm_layer=lambda dim: nn.LayerNorm(dim, eps=1e-6),
|
||||
grid_size=self.grid_height,
|
||||
use_rope=True,
|
||||
)
|
||||
for _ in range(depth)
|
||||
]
|
||||
)
|
||||
self.predictor_norm = nn.LayerNorm(predictor_embed_dim, eps=1e-6)
|
||||
self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
|
||||
self.num_action_tokens_per_step = num_action_tokens_per_step
|
||||
|
||||
@property
|
||||
def norm(self) -> nn.LayerNorm:
|
||||
return self.predictor_norm
|
||||
|
||||
@property
|
||||
def proj(self) -> nn.Linear:
|
||||
return self.predictor_proj
|
||||
|
||||
def forward(
|
||||
self,
|
||||
frame_tokens: torch.Tensor,
|
||||
action_tokens: torch.Tensor,
|
||||
extrinsics: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# starVLA input convention: frame_tokens [B, T*H*W, D], actions [B, T*A, D].
|
||||
x = self.predictor_embed(frame_tokens)
|
||||
batch_size, num_context_tokens, hidden_dim = x.size()
|
||||
num_frames = num_context_tokens // (self.grid_height * self.grid_width)
|
||||
|
||||
actions = self.action_encoder(action_tokens)
|
||||
actions = actions.view(batch_size, num_frames, -1, hidden_dim)
|
||||
cond_tokens = actions.shape[2]
|
||||
|
||||
x = x.view(batch_size, num_frames, self.grid_height * self.grid_width, hidden_dim)
|
||||
if self.use_extrinsics:
|
||||
if extrinsics is None:
|
||||
raise ValueError("extrinsics are required when use_extrinsics=True.")
|
||||
cond_tokens += 1
|
||||
extrinsic_tokens = self.extrinsics_encoder(extrinsics).unsqueeze(2)
|
||||
x = torch.cat([actions, extrinsic_tokens, x], dim=2).flatten(1, 2)
|
||||
else:
|
||||
x = torch.cat([actions, x], dim=2).flatten(1, 2)
|
||||
|
||||
attn_mask = build_action_block_causal_attention_mask(
|
||||
num_frames, self.grid_height, self.grid_width, add_tokens=cond_tokens
|
||||
)
|
||||
attn_mask = attn_mask[: x.size(1), : x.size(1)].to(x.device, non_blocking=True)
|
||||
|
||||
for block in self.predictor_blocks:
|
||||
x = block(
|
||||
x,
|
||||
attn_mask=attn_mask,
|
||||
num_frames=num_frames,
|
||||
grid_height=self.grid_height,
|
||||
grid_width=self.grid_width,
|
||||
action_tokens=cond_tokens,
|
||||
)
|
||||
|
||||
x = x.view(batch_size, num_frames, cond_tokens + self.grid_height * self.grid_width, hidden_dim)
|
||||
x = x[:, :, cond_tokens:, :].flatten(1, 2)
|
||||
x = self.predictor_norm(x)
|
||||
return self.predictor_proj(x)
|
||||
@@ -1,286 +0,0 @@
|
||||
# Remote Inference Architecture
|
||||
|
||||
How `lerobot-policy-server` and `lerobot-rollout --inference.type=remote` decouple GPU-bound policy inference from high-frequency robot control over Zenoh.
|
||||
|
||||
This document explains the **internals** — the wire protocol, threading models, state machines, and safety invariants. For the user-facing guide (CLI quickstarts, deployment), see [`docs/source/remote_inference.mdx`](../../../docs/source/remote_inference.mdx).
|
||||
|
||||
## 1. The problem and the shape of the solution
|
||||
|
||||
Running a large policy (Pi0-class, ~150 ms inference) inside a 33 ms control loop doesn't work, and putting a GPU next to every robot doesn't scale. LeRobot already solved the _local_ version of this problem: `RTCInferenceEngine` runs inference in a background **thread** that fills a thread-safe `ActionQueue`, while the control loop pops one action per tick.
|
||||
|
||||
Remote inference is **that same architecture with the thread boundary replaced by a network boundary**:
|
||||
|
||||
```
|
||||
local RTC: control loop ──ActionQueue── inference thread (same process, same GPU)
|
||||
remote: control loop ──ActionQueue── network worker ══zenoh══ policy server (GPU, elsewhere)
|
||||
```
|
||||
|
||||
Three design commitments follow from this:
|
||||
|
||||
- **The client is a backend, not a CLI.** `RemoteInferenceEngine` plugs into the existing `InferenceEngine` seam (`rollout/inference/base.py`), so every rollout strategy (base, sentry, highlight, dagger, episodic) gets network inference — including dataset recording, pause/resume, and safe teardown — without changing a line.
|
||||
- **The client is weightless.** No policy weights, no policy processors on the edge. `--policy.path` resolves to a config-only `PreTrainedConfig` used for pre-flight validation and action ordering.
|
||||
- **The server is stateless per request.** All chunk state (RTC prefixes, latency tracking, delay computation) lives client-side in the existing `ActionQueue`/`LatencyTracker`. The client ships prefixes + a delay hint with every observation, so a server crash loses zero control state.
|
||||
|
||||
## 2. Component map
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
subgraph EDGE["Edge (per robot, weightless)"]
|
||||
R[Robot HW] --> S["Rollout strategy<br/>(sentry / dagger / ...)"]
|
||||
S -->|"notify_observation()"| E[RemoteInferenceEngine]
|
||||
E -->|"get_action()"| S
|
||||
S -->|actions| R
|
||||
E --- AQ[("ActionQueue<br/>(chunk buffer)")]
|
||||
end
|
||||
|
||||
subgraph NET["Transport"]
|
||||
Z["zenohd router(s)<br/>(robots dial out, mTLS + ACL)"]
|
||||
end
|
||||
|
||||
subgraph GPU["GPU pod (one model · one device · one process)"]
|
||||
PS[PolicyServer]
|
||||
PS --- SR["SessionRegistry<br/>(per-client mailboxes + pipelines)"]
|
||||
PS --- W["Inference worker<br/>(1 thread, owns GPU)"]
|
||||
W --- P["PreTrainedPolicy<br/>(pre-warmed)"]
|
||||
end
|
||||
|
||||
E <-->|"obs ↑ / chunks ↓ (pub/sub)<br/>status · session · reset (queryables)"| Z
|
||||
Z <--> PS
|
||||
```
|
||||
|
||||
One server process = one pre-warmed `(model, revision, dtype, device)` serving up to `max_sessions` robots. Scaling out = more pods; clients rejected with the current load retry another replica.
|
||||
|
||||
## 3. Where the network cut goes
|
||||
|
||||
The local RTC pipeline is split at the cheapest, most hardware-coupled point. Everything policy-coupled (resize, normalize, tokenize) runs server-side with the **canonical training-time processors**, so serve-time preprocessing is byte-identical to train-time:
|
||||
|
||||
```
|
||||
robot obs (processed dict)
|
||||
→ build_dataset_frame(...) CLIENT cheap, hardware-coupled
|
||||
→ rename_map applied to keys CLIENT wire format = canonical policy keys
|
||||
══════════════════════ network (msgpack + JPEG) ══════════════════════
|
||||
→ prepare_observation_for_inference(...) SERVER tensors, batch dim, device
|
||||
→ per-session preprocessor(...) SERVER stateful within the request
|
||||
→ policy.predict_action_chunk(obs, delay, prefix) SERVER pure for allowlisted policies
|
||||
→ per-session postprocessor(...) SERVER reads state cached at preprocess
|
||||
══════════════════════ network ══════════════════════
|
||||
→ ActionQueue.merge(original, processed, delay, idx_before) CLIENT
|
||||
```
|
||||
|
||||
The reply carries **both** the model-space (`chunk_model`) and robot-space (`chunk_robot`) chunks because `ActionQueue.merge` needs both, and the next request's relative-action prefix re-anchoring needs the robot-space tail.
|
||||
|
||||
## 4. Wire protocol
|
||||
|
||||
### 4.1 Key-expression schema (Zenoh)
|
||||
|
||||
```
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/<client_uuid>/obs client → server pub/sub
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/<client_uuid>/action server → client pub/sub
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/status queryable (capabilities)
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/session queryable (open / close)
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/<client_uuid>/reset queryable (episode boundary)
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/<client_uuid>/alive liveliness token (client)
|
||||
@lerobot/<model_slug>/<revision>/<task_slug>/server/alive liveliness token (server)
|
||||
```
|
||||
|
||||
`@lerobot` is a **verbatim chunk**: wildcards never match it, so third-party `**` subscribers on a shared router cannot scrape the tree. User-supplied segments are sanitized (`sanitize_key_segment`), and the server subscribes with single-depth wildcards only (`.../*/obs`, never `**`).
|
||||
|
||||
Data plane = pub/sub (a late chunk is still usable; a timed-out query reply is not). Control plane = queryables with explicit timeouts (the rmw_zenoh pattern). QoS (`zenoh_utils.py`): actions are `RELIABLE + congestion DROP + express + INTERACTIVE_HIGH` — **never BLOCK**, so one dead robot uplink can never stall the server's publish path; a dropped chunk is recoverable because the client buffer keeps the robot moving.
|
||||
|
||||
### 4.2 Messages
|
||||
|
||||
Every data-plane message carries a **packed little-endian attachment header** (27 bytes, parsed without touching the body):
|
||||
|
||||
| field | type | meaning |
|
||||
| ---------------- | ---- | --------------------------------------------------------- |
|
||||
| `schema_version` | u16 | negotiated at session open; additive-only body evolution |
|
||||
| `msg_type` | u8 | OBS / CHUNK / EVENT |
|
||||
| `seq_id` | u64 | per-session monotonic; echoed in the chunk |
|
||||
| `episode_id` | u32 | bumped by `reset()` |
|
||||
| `client_mono_ns` | i64 | client monotonic clock — **opaque to the server, echoed** |
|
||||
| `session_epoch` | u32 | bumped per (re)connect; stale-epoch chunks dropped |
|
||||
|
||||
Bodies are msgpack (`codec.py`): tensors as raw little-endian bytes + dtype + shape, images JPEG (RGB convention enforced inside the codec; `jpeg_quality=0` = raw). No pickle anywhere — nothing on the wire can carry code.
|
||||
|
||||
**Clock iron rule:** wall-clock instants never cross machines. The client computes RTT from its own monotonic clock via the echoed `client_mono_ns`; the server reports only **durations** (`queue_wait_ms`, `inference_ms`).
|
||||
|
||||
### 4.3 Session lifecycle
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant C as RemoteInferenceEngine
|
||||
participant S as PolicyServer
|
||||
|
||||
C->>S: GET status (timeout 2s)
|
||||
S-->>C: capabilities (model, action_names, cameras, chunk_size, supports_rtc, ...)
|
||||
C->>S: GET session {op: open, action_names, cameras, state_dim, fps, rtc, task}
|
||||
Note over S: validate (hard: action name ORDER,<br/>cameras, state_dim, schema, capacity)
|
||||
S-->>C: SessionAck {session_id, warnings, rtc_execution_horizon, ...}
|
||||
Note over C,S: both declare liveliness tokens
|
||||
|
||||
loop self-clocked by buffer_time_s (one-in-flight)
|
||||
C->>S: PUB obs {state, images, delay_steps, prefix_model, prefix_robot} + header
|
||||
Note over S: latest-only mailbox → worker →<br/>preprocess → predict_action_chunk → postprocess
|
||||
S-->>C: PUB chunk {chunk_model, chunk_robot, durations} + echoed header
|
||||
Note over C: validate (episode, epoch) → ActionQueue.merge(..., idx_before)
|
||||
end
|
||||
|
||||
C->>S: GET reset {episode_id} (episode boundary, acked)
|
||||
C->>S: GET session {op: close} (graceful stop)
|
||||
```
|
||||
|
||||
The **action-name order check is a hard reject**: it is the contract that maps chunk columns to motors. A mismatch means wrong-joint commands, so the session never opens.
|
||||
|
||||
## 5. The client: `RemoteInferenceEngine`
|
||||
|
||||
File: `src/lerobot/rollout/inference/remote.py`, registered as `--inference.type=remote` (`RemoteInferenceConfig` in `factory.py`).
|
||||
|
||||
### 5.1 Threading model
|
||||
|
||||
| thread | role |
|
||||
| ---------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| main (strategy loop) | `notify_observation()` → latest-only slot; `get_action()` → `ActionQueue.get()` + staleness check + fallback. **Never any I/O.** |
|
||||
| network worker (1) | gate on `buffer_time_s` → snapshot `(seq, episode, epoch)` then `idx_before` + RTC prefixes → publish obs → await chunk (timeout) → revalidate → merge. Owns the state machine and reconnects. |
|
||||
| zenoh callback threads | deposit-only: chunk → bounded queue; server liveliness → event. |
|
||||
|
||||
**One-in-flight is a correctness requirement, not a tuning choice.** `merge(..., idx_before)` validates against the consumption index snapshotted at send time; two in-flight requests would carry conflicting snapshots and corrupt both RTC-replace and append modes. The worker therefore publishes one observation, waits for its chunk (or timeout), then sends the next. A late chunk is accepted only if it answers the latest outstanding `seq_id` _and_ the current `(episode, epoch)`.
|
||||
|
||||
### 5.2 The request cycle
|
||||
|
||||
```
|
||||
queue playback ≤ buffer_time_s? (self-clocking: ~1–4 Hz, not the 30 Hz control rate)
|
||||
├─ snapshot (seq, episode, epoch)
|
||||
├─ snapshot idx_before, prefix_model = queue.get_left_over()[:H],
|
||||
│ prefix_robot = queue.get_processed_left_over()[:H]
|
||||
├─ revalidate (episode, epoch) unchanged ← a reset racing the snapshot skips the cycle
|
||||
├─ delay_steps = ceil(LatencyTracker.max() / dt)
|
||||
├─ publish obs + header
|
||||
├─ await chunk (request_timeout_s)
|
||||
├─ revalidate (episode, epoch) under _anchor_lock ← a stale chunk can never survive a reset
|
||||
└─ merge(chunk_model, chunk_robot, ceil(measured_latency/dt), idx_before); update anchor
|
||||
```
|
||||
|
||||
Because the `LatencyTracker` samples are full network-inclusive cycle times, RTT compensation falls out for free — the same `delay`-trimming machinery local RTC uses absorbs network latency as just more delay.
|
||||
|
||||
### 5.3 Fail-safe state machine
|
||||
|
||||
```mermaid
|
||||
stateDiagram-v2
|
||||
[*] --> CONNECTING
|
||||
CONNECTING --> STREAMING: first merge
|
||||
STREAMING --> DEGRADED: request timeouts,<br/>queue still has actions
|
||||
DEGRADED --> STREAMING: merge
|
||||
DEGRADED --> STALLED: queue empty or<br/>max_action_age_s hit
|
||||
STALLED --> RECONNECTING: timeout streak /<br/>server liveliness drop
|
||||
DEGRADED --> RECONNECTING: timeout streak /<br/>server liveliness drop
|
||||
RECONNECTING --> STREAMING: re-handshake OK<br/>(epoch++)
|
||||
RECONNECTING --> DEAD: offline > max_offline_s,<br/>capability/model mismatch
|
||||
DEAD --> [*]: failed=True → shutdown_event<br/>→ strategy teardown
|
||||
```
|
||||
|
||||
- **DEGRADED**: the chunk buffer _is_ the fault tolerance — 1–3 s of buffered actions makes network blips and clean server drains invisible to the robot.
|
||||
- **Staleness bound** (`max_action_age_s`): `get_action` refuses any action whose source observation is too old, bounding open-loop execution after a stall. Then the **fallback ladder** applies: `hold` (return `None`; the robot holds), `repeat_last`, or `zero` (the safe stop for velocity-controlled robots).
|
||||
- **Watchdog layering**: per-request timeout (catches a _hung-but-connected_ server) → server liveliness token (catches a dead server/router) → staleness bound (the robot-side invariant that holds regardless of why data stopped).
|
||||
- **DEAD** is reserved for hard failures: offline beyond `max_offline_s` with no successful merge (a server that handshakes but never delivers chunks still runs out of budget), or a contract violation on reconnect (model/revision changed, RTC capability flipped — never execute wrong-model chunks). It triggers the exact mechanism local RTC uses: `failed=True` + the global `shutdown_event`, so the existing teardown (return-to-initial-pose) runs unchanged.
|
||||
- **Pause/resume** (DAgger): `pause()` stops publishing; the queue stays intact. A pause during an outage freezes the offline budget so a human correction can never be aborted by `max_offline_s`.
|
||||
|
||||
### 5.4 Episode boundaries
|
||||
|
||||
`reset()` (control thread) atomically — under the same lock the merge path takes — clears the `ActionQueue`, nulls the staleness anchor, bumps `episode_id`, and invalidates the observation slot (the previous episode's final frame must not seed the new one). The worker sends an acked `reset` query, and the next observation header carries the new `episode_id` anyway — so a lost ack costs nothing (the server is stateless per request).
|
||||
|
||||
## 6. The server: `PolicyServer`
|
||||
|
||||
Files: `src/lerobot/policy_server/`. Entry point: `lerobot-policy-server --manifest server.yaml` (draccus dataclasses in `manifest.py`).
|
||||
|
||||
### 6.1 Concurrency model
|
||||
|
||||
zenoh-python is thread-based (no asyncio); callbacks must be deposit-only:
|
||||
|
||||
```
|
||||
zenoh subscriber (.../*/obs) inference worker (1 thread, owns GPU)
|
||||
deposit-only callback: loop:
|
||||
session.deposit(header, body) ──► scheduler picks next session with pending obs
|
||||
(per-client latest-only mailbox) decode → episode-boundary check
|
||||
preprocess → predict_action_chunk(delay, prefix)
|
||||
control queryables (status / postprocess → encode
|
||||
session / reset): validate, publisher.put(.../<uuid>/action)
|
||||
mutate registry, reply inline
|
||||
|
||||
liveliness subscriber (.../*/alive): mark sessions for GC on token DELETE
|
||||
```
|
||||
|
||||
- **Latest-only mailboxes**: the newest observation wins; superseded requests are counted and reported in the next reply (`superseded_seqs`), so drops are visible client-side. The client decides _when_ to request; the server never second-guesses observation content.
|
||||
- **Single inference worker** + round-robin over ready sessions: every ready session gets exactly one inference per cycle — starvation is structurally impossible. Overload degrades into longer cycle times → larger (but correct) client `delay_steps` → eventually the client staleness bound trips and the robot holds. Safe by construction.
|
||||
- The `Scheduler` seam (`scheduler.py`) exists so cross-session micro-batching can land later without redesign (blocked today on `predict_action_chunk` taking a _scalar_ `inference_delay`).
|
||||
- `_inference_lock` serializes the worker's predict path against episode resets arriving on queryable threads (in exclusive mode a `policy.reset()` mid-predict would corrupt the in-flight request).
|
||||
|
||||
### 6.2 Multi-tenancy: engineered, not assumed
|
||||
|
||||
Sharing one policy instance across sessions is only safe when `predict_action_chunk` touches no cross-request instance state. That property is **verified per family and encoded as a registry** (`validation.py`) — never inferred:
|
||||
|
||||
| class | policies | mode | why |
|
||||
| --------------- | ------------------------------------------------- | ----------- | ----------------------------------------------------------------------------------- |
|
||||
| chunk-stateless | `act`, `pi0`, `pi05`, `smolvla` (`n_obs_steps=1`) | `shared` | chunk call is pure (smolvla overwrites its 1-deep queue with the request's own obs) |
|
||||
| chunk-stateful | `diffusion` (and `smolvla` with `n_obs_steps>1`) | `exclusive` | chunk call reads `select_action`-fed `_queues` → server populates them per request |
|
||||
| no chunk API | `sac`, `tdmpc`, ... (no `predict_action_chunk`) | refused | nothing to serve |
|
||||
| unverified | any other chunk-API policy | `exclusive` | a manifest can force `exclusive`, but never `shared` for an unverified policy |
|
||||
|
||||
The real multi-tenancy hazard is **processor state**, not just policy purity: `RelativeActionsProcessorStep` caches `_last_state` at preprocess and the postprocessor reads it back. The server therefore builds a **fresh pre/post pipeline pair per session** — two robots at different joint positions can never cross-contaminate each other's action conversions. `policy.reset()` is **never** called in shared mode (it is global to the shared instance).
|
||||
|
||||
### 6.3 Statelessness and the RTC prefix
|
||||
|
||||
The server holds no cross-request control state. Each observation ships everything inference needs:
|
||||
|
||||
- `inference_delay_steps` — computed client-side from network-inclusive latency.
|
||||
- `prefix_model` — the unexecuted tail of the previous chunk in model space (feeds `prev_chunk_left_over`).
|
||||
- `prefix_robot` — the same tail in robot space. For relative-action policies the server **re-anchors** it against the state cached by _this request's_ preprocess (`reanchor_relative_rtc_prefix`, mirroring `rtc.py`), so the prefix is expressed relative to where the robot actually is now.
|
||||
|
||||
Consequences: reconnects are trivial, horizontal scaling is trivial, and a `kill -9` on the server loses nothing the client can't re-send.
|
||||
|
||||
### 6.4 Episode and reconnect hygiene
|
||||
|
||||
- Fresh sessions start at the `episode_id = -1` sentinel: the **first** observation of any session always triggers the boundary branch (pipelines reset; exclusive policies `reset()`), so a mid-episode reconnect can never inherit stale state.
|
||||
- Session replacement is identity-checked (`SessionRegistry.remove(expected=...)`): a GC sweep that snapshotted an old session can never tear down its just-handshaked replacement.
|
||||
- Liveliness GC double-checks with an explicit liveliness `get` before closing: the token key is per-client (not per-epoch), so a _late_ DELETE from a previous incarnation must not kill the live session.
|
||||
- Drain (`SIGTERM`): drop the liveliness token first (clients ride their buffers), finish the in-flight inference, undeclare the control surface, then close. Clients reconnect to another replica invisibly.
|
||||
|
||||
## 7. Latency budget (why the transport is never the bottleneck)
|
||||
|
||||
| stage | LAN | WAN (50 ms RTT) |
|
||||
| ------------------------------ | ------------- | --------------- |
|
||||
| JPEG encode + serialize (edge) | 2–9 ms | 2–9 ms |
|
||||
| uplink | ~2 ms | ~54 ms |
|
||||
| decode + canonical preprocess | 4–10 ms | 4–10 ms |
|
||||
| **inference** | **15–150 ms** | **15–150 ms** |
|
||||
| postprocess + downlink + merge | ~2 ms | ~27 ms |
|
||||
|
||||
Inference dominates (60–85% on LAN). At 30 fps a WAN deployment lands `delay_steps ≈ 4–8`, comfortably inside RTC execution horizons: WAN degrades smoothness parameters, never correctness. Requests are self-clocked by `buffer_time_s` to ~1–4 Hz per robot, so 300 robots cost ~0.3–10 Mbps each.
|
||||
|
||||
Capacity per GPU: `N_max ≈ 0.8 / (request_rate × inference_time)` → ~40 ACT-class or ~5 Pi0-class clients; `max_sessions` enforces it at session open (rejected clients receive the current load and retry another replica).
|
||||
|
||||
## 8. Observability & reproducibility
|
||||
|
||||
The contract is **fully logged + replayable**, not "deterministic" (no seed controls hardware or network jitter):
|
||||
|
||||
- **Client = source of truth**: recording strategies persist observations + executed actions as usual; the engine tracks `(session_id, seq_id, episode_id)` and per-cycle stats.
|
||||
- **Server**: one JSON audit line per request on the `lerobot.policy_server.audit` logger — `{session_id, client_uuid, seq_id, episode_id, queue_wait_ms, inference_ms, superseded, outcome}` — plus `/healthz` and Prometheus-style `/metrics`, and an optional bounded raw request/response capture (`debug.capture_dir`) for byte-exact offline replay.
|
||||
- Every hop shares `(session_id, seq_id)`, so joining a robot-side stutter to a server-side cause is mechanical.
|
||||
|
||||
## 9. File map
|
||||
|
||||
| path | contents |
|
||||
| ---------------------------------- | ---------------------------------------------------------------------------------------------------- |
|
||||
| `policy_server/schema.py` | wire messages, packed header, key-expression schema + sanitizer |
|
||||
| `policy_server/codec.py` | msgpack bodies, tensor codec (LE bytes), JPEG image codec (RGB convention) |
|
||||
| `policy_server/manifest.py` | draccus config: model, zenoh endpoints/TLS, serving mode, capacity, RTC, health |
|
||||
| `policy_server/validation.py` | serving-mode registry + session-open capability matrix |
|
||||
| `policy_server/session.py` | per-client `Session` (pipelines, latest-only mailbox, stats) + identity-safe registry |
|
||||
| `policy_server/scheduler.py` | `Scheduler` seam; `RoundRobinScheduler` |
|
||||
| `policy_server/zenoh_utils.py` | config builder, QoS profiles, lazy import with install hint |
|
||||
| `policy_server/server.py` | `PolicyServer`: zenoh surface, inference worker, GC, warmup, drain, health/metrics |
|
||||
| `rollout/inference/remote.py` | `RemoteInferenceEngine` (the edge client) |
|
||||
| `rollout/inference/factory.py` | `RemoteInferenceConfig`, `FallbackMode`, factory dispatch |
|
||||
| `scripts/lerobot_policy_server.py` | console entry point (`--manifest` → draccus `--config_path`) |
|
||||
| `tests/policy_server/` | codec/schema/validation/scheduler/session units, server logic, zenoh loopback + chaos, golden parity |
|
||||
|
||||
The golden parity test (`tests/policy_server/test_golden_parity.py`) is the standing contract: the remote request path (encode → decode → `run_inference_request` → encode → decode → merge) must produce **byte-identical** action queues to the local RTC compute path on identical inputs.
|
||||
@@ -1,53 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Multi-client GPU policy serving over Zenoh (``lerobot-policy-server``).
|
||||
|
||||
The wire schema (:mod:`.schema`) and codecs (:mod:`.codec`) are shared
|
||||
with the edge-side :class:`~lerobot.rollout.inference.remote.RemoteInferenceEngine`.
|
||||
Heavy/optional imports (msgpack, zenoh, torch server) are deferred so the
|
||||
schema stays importable without the ``async`` extra.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .manifest import (
|
||||
DebugSpec,
|
||||
ModelSpec,
|
||||
PolicyServerManifest,
|
||||
ZenohSpec,
|
||||
)
|
||||
from .schema import SCHEMA_VERSION, MsgHeader, service_prefix
|
||||
|
||||
__all__ = [
|
||||
"SCHEMA_VERSION",
|
||||
"DebugSpec",
|
||||
"ModelSpec",
|
||||
"MsgHeader",
|
||||
"PolicyServer",
|
||||
"PolicyServerManifest",
|
||||
"ZenohSpec",
|
||||
"codec",
|
||||
"service_prefix",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
import importlib
|
||||
|
||||
if name == "PolicyServer":
|
||||
return importlib.import_module(".server", __name__).PolicyServer
|
||||
if name == "codec":
|
||||
return importlib.import_module(".codec", __name__)
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
@@ -1,262 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""MessagePack codecs for the remote-inference wire schema.
|
||||
|
||||
Encoding rules:
|
||||
- Tensors are raw little-endian bytes + dtype + shape (msgpack's ``bin``
|
||||
type), so decoding is a zero-parse ``np.frombuffer``.
|
||||
- Images are JPEG by default (``jpeg_quality=0`` sends raw bytes). The
|
||||
in-memory convention on both ends is **RGB** uint8 HWC; the OpenCV
|
||||
BGR↔RGB conversion happens inside this module only.
|
||||
- Decoders are tolerant: unknown keys are ignored, missing optional keys
|
||||
take dataclass defaults — schema evolution is additive-only.
|
||||
- No pickle anywhere: nothing in this codec can carry code.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import msgpack
|
||||
except ImportError as e: # pragma: no cover
|
||||
raise ImportError(
|
||||
"Remote inference requires the 'async' extra: pip install 'lerobot[async]' (eclipse-zenoh + msgpack)"
|
||||
) from e
|
||||
|
||||
from .schema import (
|
||||
IMAGE_CODEC_JPEG,
|
||||
IMAGE_CODEC_RAW,
|
||||
ActionChunkMsg,
|
||||
ObservationMsg,
|
||||
ResetAckMsg,
|
||||
ResetMsg,
|
||||
SessionAckMsg,
|
||||
SessionCloseMsg,
|
||||
SessionOpenMsg,
|
||||
StatusMsg,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tensor codec
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _to_little_endian(arr: np.ndarray) -> np.ndarray:
|
||||
if arr.dtype.byteorder == ">":
|
||||
arr = arr.astype(arr.dtype.newbyteorder("<"))
|
||||
return np.ascontiguousarray(arr)
|
||||
|
||||
|
||||
def encode_tensor(arr: np.ndarray | None) -> dict[str, Any] | None:
|
||||
"""Encode an ndarray as raw little-endian bytes + dtype + shape."""
|
||||
if arr is None:
|
||||
return None
|
||||
arr = np.asarray(arr)
|
||||
# Record the shape before ascontiguousarray, which promotes 0-d to 1-d.
|
||||
shape = list(arr.shape)
|
||||
arr = _to_little_endian(arr)
|
||||
return {"dtype": arr.dtype.str, "shape": shape, "data": arr.tobytes()}
|
||||
|
||||
|
||||
def decode_tensor(obj: dict[str, Any] | None) -> np.ndarray | None:
|
||||
if obj is None:
|
||||
return None
|
||||
dtype = np.dtype(obj["dtype"])
|
||||
if dtype.hasobject:
|
||||
raise ValueError(f"Refusing object dtype {dtype} on the wire")
|
||||
arr = np.frombuffer(obj["data"], dtype=dtype).reshape(obj["shape"])
|
||||
# frombuffer returns a read-only view; copy so downstream torch.from_numpy works.
|
||||
return arr.copy()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image codec (RGB uint8 HWC on both ends)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def encode_image(img: np.ndarray, jpeg_quality: int = 90) -> dict[str, Any]:
|
||||
"""Encode an RGB uint8 HWC image; ``jpeg_quality=0`` keeps it raw."""
|
||||
img = np.asarray(img)
|
||||
if img.dtype != np.uint8 or img.ndim != 3 or img.shape[2] != 3:
|
||||
raise ValueError(f"Expected uint8 HWC RGB image, got dtype={img.dtype} shape={img.shape}")
|
||||
if jpeg_quality <= 0:
|
||||
return {"codec": IMAGE_CODEC_RAW, "shape": list(img.shape), "data": _to_little_endian(img).tobytes()}
|
||||
ok, buf = cv2.imencode(
|
||||
".jpg", cv2.cvtColor(img, cv2.COLOR_RGB2BGR), [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_quality)]
|
||||
)
|
||||
if not ok:
|
||||
raise ValueError("JPEG encoding failed")
|
||||
return {"codec": IMAGE_CODEC_JPEG, "data": buf.tobytes()}
|
||||
|
||||
|
||||
def decode_image(obj: dict[str, Any]) -> np.ndarray:
|
||||
"""Decode to an RGB uint8 HWC image."""
|
||||
codec = obj.get("codec", IMAGE_CODEC_JPEG)
|
||||
if codec == IMAGE_CODEC_RAW:
|
||||
return np.frombuffer(obj["data"], dtype=np.uint8).reshape(obj["shape"]).copy()
|
||||
if codec == IMAGE_CODEC_JPEG:
|
||||
bgr = cv2.imdecode(np.frombuffer(obj["data"], dtype=np.uint8), cv2.IMREAD_COLOR)
|
||||
if bgr is None:
|
||||
raise ValueError("JPEG decoding failed")
|
||||
return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
||||
raise ValueError(f"Unknown image codec: {codec!r}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# msgpack helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _packb(obj: dict[str, Any]) -> bytes:
|
||||
return msgpack.packb(obj, use_bin_type=True)
|
||||
|
||||
|
||||
def _unpackb(data: bytes) -> dict[str, Any]:
|
||||
return msgpack.unpackb(data, raw=False)
|
||||
|
||||
|
||||
def decode_raw(data: bytes) -> dict[str, Any]:
|
||||
"""Decode a body to a plain dict (e.g. to peek a control-plane ``op``)."""
|
||||
return _unpackb(data)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data-plane messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def encode_observation(msg: ObservationMsg) -> bytes:
|
||||
return _packb(
|
||||
{
|
||||
"state": encode_tensor(msg.state),
|
||||
"images": {name: encode_image(img, msg.jpeg_quality) for name, img in msg.images.items()},
|
||||
"task": msg.task,
|
||||
"inference_delay_steps": int(msg.inference_delay_steps),
|
||||
"prefix_model": encode_tensor(msg.prefix_model),
|
||||
"prefix_robot": encode_tensor(msg.prefix_robot),
|
||||
"episode_start": bool(msg.episode_start),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def decode_observation(data: bytes) -> ObservationMsg:
|
||||
obj = _unpackb(data)
|
||||
return ObservationMsg(
|
||||
state=decode_tensor(obj.get("state")),
|
||||
images={name: decode_image(img) for name, img in obj.get("images", {}).items()},
|
||||
task=obj.get("task", ""),
|
||||
inference_delay_steps=obj.get("inference_delay_steps", 0),
|
||||
prefix_model=decode_tensor(obj.get("prefix_model")),
|
||||
prefix_robot=decode_tensor(obj.get("prefix_robot")),
|
||||
episode_start=obj.get("episode_start", False),
|
||||
)
|
||||
|
||||
|
||||
def encode_action_chunk(msg: ActionChunkMsg) -> bytes:
|
||||
return _packb(
|
||||
{
|
||||
"seq_id_echo": int(msg.seq_id_echo),
|
||||
"client_mono_ns_echo": int(msg.client_mono_ns_echo),
|
||||
"episode_id_echo": int(msg.episode_id_echo),
|
||||
"chunk_model": encode_tensor(msg.chunk_model),
|
||||
"chunk_robot": encode_tensor(msg.chunk_robot),
|
||||
"queue_wait_ms": float(msg.queue_wait_ms),
|
||||
"inference_ms": float(msg.inference_ms),
|
||||
"superseded_seqs": int(msg.superseded_seqs),
|
||||
"server_load": float(msg.server_load),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def decode_action_chunk(data: bytes) -> ActionChunkMsg:
|
||||
obj = _unpackb(data)
|
||||
return ActionChunkMsg(
|
||||
seq_id_echo=obj.get("seq_id_echo", 0),
|
||||
client_mono_ns_echo=obj.get("client_mono_ns_echo", 0),
|
||||
episode_id_echo=obj.get("episode_id_echo", 0),
|
||||
chunk_model=decode_tensor(obj.get("chunk_model")),
|
||||
chunk_robot=decode_tensor(obj.get("chunk_robot")),
|
||||
queue_wait_ms=obj.get("queue_wait_ms", 0.0),
|
||||
inference_ms=obj.get("inference_ms", 0.0),
|
||||
superseded_seqs=obj.get("superseded_seqs", 0),
|
||||
server_load=obj.get("server_load", 0.0),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Control-plane messages (flat scalar/list/dict fields → generic codec)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _encode_flat(msg: Any) -> bytes:
|
||||
return _packb(dict(vars(msg).items()))
|
||||
|
||||
|
||||
def _decode_flat(cls: type, data: bytes) -> Any:
|
||||
obj = _unpackb(data)
|
||||
known = set(cls.__dataclass_fields__)
|
||||
return cls(**{k: v for k, v in obj.items() if k in known})
|
||||
|
||||
|
||||
def encode_session_open(msg: SessionOpenMsg) -> bytes:
|
||||
return _encode_flat(msg)
|
||||
|
||||
|
||||
def decode_session_open(data: bytes) -> SessionOpenMsg:
|
||||
return _decode_flat(SessionOpenMsg, data)
|
||||
|
||||
|
||||
def encode_session_ack(msg: SessionAckMsg) -> bytes:
|
||||
return _encode_flat(msg)
|
||||
|
||||
|
||||
def decode_session_ack(data: bytes) -> SessionAckMsg:
|
||||
return _decode_flat(SessionAckMsg, data)
|
||||
|
||||
|
||||
def encode_status(msg: StatusMsg) -> bytes:
|
||||
return _encode_flat(msg)
|
||||
|
||||
|
||||
def decode_status(data: bytes) -> StatusMsg:
|
||||
return _decode_flat(StatusMsg, data)
|
||||
|
||||
|
||||
def encode_reset(msg: ResetMsg) -> bytes:
|
||||
return _encode_flat(msg)
|
||||
|
||||
|
||||
def decode_reset(data: bytes) -> ResetMsg:
|
||||
return _decode_flat(ResetMsg, data)
|
||||
|
||||
|
||||
def encode_reset_ack(msg: ResetAckMsg) -> bytes:
|
||||
return _encode_flat(msg)
|
||||
|
||||
|
||||
def decode_reset_ack(data: bytes) -> ResetAckMsg:
|
||||
return _decode_flat(ResetAckMsg, data)
|
||||
|
||||
|
||||
def encode_session_close(msg: SessionCloseMsg) -> bytes:
|
||||
return _encode_flat(msg)
|
||||
|
||||
|
||||
def decode_session_close(data: bytes) -> SessionCloseMsg:
|
||||
return _decode_flat(SessionCloseMsg, data)
|
||||
@@ -1,139 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Policy-server manifest: one process = one (model, revision, dtype, device) on one GPU.
|
||||
|
||||
Loaded from YAML via ``lerobot-policy-server --manifest server.yaml`` (or
|
||||
individual ``--model.repo_or_path=...`` CLI overrides through draccus).
|
||||
Dynamic model loading is deliberately unsupported: pre-warmed processes
|
||||
keep capacity planning honest and keep code-carrying payloads off the wire.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SERVING_MODE_AUTO = "auto"
|
||||
SERVING_MODE_SHARED = "shared"
|
||||
SERVING_MODE_EXCLUSIVE = "exclusive"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelSpec:
|
||||
"""Which policy this process serves, and where it runs."""
|
||||
|
||||
repo_or_path: str = ""
|
||||
revision: str = "main"
|
||||
# Optional torch dtype cast applied after load (e.g. "bfloat16").
|
||||
dtype: str | None = None
|
||||
device: str = "cuda"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ZenohSpec:
|
||||
"""Transport endpoints and security.
|
||||
|
||||
Robots and servers both *dial out* to a ``zenohd`` router in
|
||||
production (``mode=client``). ``mode=peer`` + ``listen_endpoints``
|
||||
supports router-less LAN and loopback test deployments. Multicast
|
||||
scouting is always disabled: fleet discovery is configuration, not
|
||||
protocol magic.
|
||||
"""
|
||||
|
||||
mode: str = "client" # "client" (via router) | "peer" (direct)
|
||||
connect_endpoints: list[str] = field(default_factory=list)
|
||||
listen_endpoints: list[str] = field(default_factory=list)
|
||||
# mTLS material (PEM paths). All three are required for TLS endpoints.
|
||||
tls_root_ca_certificate: str | None = None
|
||||
tls_connect_certificate: str | None = None
|
||||
tls_connect_private_key: str | None = None
|
||||
# Escape hatch: raw JSON5 merged into the zenoh config last.
|
||||
extra_config_json5: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DebugSpec:
|
||||
"""Optional bounded request/response capture for offline replay."""
|
||||
|
||||
capture_dir: str | None = None
|
||||
capture_max: int = 256
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyServerManifest:
|
||||
"""Top-level config for ``lerobot-policy-server``."""
|
||||
|
||||
model: ModelSpec = field(default_factory=ModelSpec)
|
||||
zenoh: ZenohSpec = field(default_factory=ZenohSpec)
|
||||
|
||||
# The task namespace this service is published under. When
|
||||
# ``pin_task`` is true, session opens with a different task string
|
||||
# are rejected; otherwise VLA clients may set the task per session.
|
||||
default_task: str = ""
|
||||
pin_task: bool = False
|
||||
# Optional override for the <task_slug> key segment (defaults to a
|
||||
# slug of ``default_task``).
|
||||
service_name: str = ""
|
||||
|
||||
# "auto" resolves from the policy classification (shared for
|
||||
# chunk-stateless policies, exclusive otherwise). "exclusive" can be
|
||||
# forced; "shared" cannot override a chunk-stateful classification.
|
||||
serving_mode: str = SERVING_MODE_AUTO
|
||||
max_sessions: int = 5
|
||||
warmup_inferences: int = 2
|
||||
|
||||
# FPS contract: warn on mismatch unless strict.
|
||||
trained_fps: float = 30.0
|
||||
strict_fps: bool = False
|
||||
|
||||
# RTC behaviour for this server process (global to the shared policy:
|
||||
# ``init_rtc_processor`` mutates the policy instance, so it is a
|
||||
# per-process decision, not per-session).
|
||||
rtc: RTCConfig = field(default_factory=RTCConfig)
|
||||
|
||||
# Sessions with no liveliness token and no traffic for this long are
|
||||
# garbage-collected (belt-and-braces behind liveliness GC).
|
||||
session_idle_timeout_s: float = 300.0
|
||||
|
||||
# HTTP health + Prometheus metrics port; 0 disables the endpoint.
|
||||
health_port: int = 9100
|
||||
|
||||
debug: DebugSpec = field(default_factory=DebugSpec)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.model.repo_or_path:
|
||||
raise ValueError("--model.repo_or_path is required (the policy this server serves)")
|
||||
if self.serving_mode not in (SERVING_MODE_AUTO, SERVING_MODE_SHARED, SERVING_MODE_EXCLUSIVE):
|
||||
raise ValueError(f"serving_mode must be one of auto|shared|exclusive, got {self.serving_mode!r}")
|
||||
if self.max_sessions < 1:
|
||||
raise ValueError(f"max_sessions must be >= 1, got {self.max_sessions}")
|
||||
if self.zenoh.mode not in ("client", "peer"):
|
||||
raise ValueError(f"zenoh.mode must be 'client' or 'peer', got {self.zenoh.mode!r}")
|
||||
if self.zenoh.mode == "client" and not self.zenoh.connect_endpoints:
|
||||
raise ValueError("zenoh.connect_endpoints is required in client mode (router address)")
|
||||
tls_fields = (
|
||||
self.zenoh.tls_root_ca_certificate,
|
||||
self.zenoh.tls_connect_certificate,
|
||||
self.zenoh.tls_connect_private_key,
|
||||
)
|
||||
if any(tls_fields) and not all(tls_fields):
|
||||
raise ValueError(
|
||||
"TLS requires all of tls_root_ca_certificate, tls_connect_certificate, "
|
||||
"tls_connect_private_key"
|
||||
)
|
||||
@@ -1,58 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Scheduling seam between the session registry and the inference worker.
|
||||
|
||||
The v1 scheduler is strict round-robin over sessions with a pending
|
||||
observation: every ready session gets exactly one inference per cycle,
|
||||
so starvation is structurally impossible. The seam exists so that
|
||||
cross-session micro-batching can land later without redesign (blocked
|
||||
today on ``predict_action_chunk`` taking a *scalar* ``inference_delay``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
|
||||
from .session import Session
|
||||
|
||||
|
||||
class Scheduler(abc.ABC):
|
||||
"""Pick which ready session(s) the worker serves next."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def select(self, ready: list[Session]) -> list[Session]:
|
||||
"""Return the sessions to serve this cycle (subset of ``ready``)."""
|
||||
|
||||
|
||||
class RoundRobinScheduler(Scheduler):
|
||||
"""Serve one session per cycle, fairly, in client_uuid order."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._last_served: str | None = None
|
||||
|
||||
def select(self, ready: list[Session]) -> list[Session]:
|
||||
if not ready:
|
||||
return []
|
||||
ring = sorted(ready, key=lambda s: s.client_uuid)
|
||||
if self._last_served is not None:
|
||||
for i, session in enumerate(ring):
|
||||
if session.client_uuid > self._last_served:
|
||||
ring = ring[i:] + ring[:i]
|
||||
break
|
||||
else:
|
||||
pass # wrap: everyone is <= last served, keep sorted order
|
||||
chosen = ring[0]
|
||||
self._last_served = chosen.client_uuid
|
||||
return [chosen]
|
||||
@@ -1,340 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Wire schema for remote policy inference.
|
||||
|
||||
Message dataclasses, the packed attachment header, and the Zenoh
|
||||
key-expression layout shared by the policy server and the remote
|
||||
inference engine. This module is transport-free (no zenoh import) so
|
||||
codecs and validation can be unit-tested without the optional extra.
|
||||
|
||||
Schema discipline: bodies are MessagePack maps decoded tolerantly
|
||||
(unknown keys ignored, missing optional keys defaulted) so evolution is
|
||||
additive-only. Any change to the attachment layout requires a
|
||||
``SCHEMA_VERSION`` bump; versions are negotiated at session open.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import struct
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import numpy as np
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Versioning
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SCHEMA_VERSION = 1
|
||||
# Oldest schema version this build can still serve.
|
||||
MIN_SUPPORTED_SCHEMA_VERSION = 1
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Attachment header (fixed layout, packed little-endian)
|
||||
#
|
||||
# Parsed without touching the msgpack body so routing, correlation and
|
||||
# supersession decisions never pay deserialization costs. The
|
||||
# ``client_mono_ns`` field is a client-monotonic timestamp that is
|
||||
# OPAQUE to the server: it is echoed back verbatim so the client can
|
||||
# compute round-trip times on its own clock. Wall-clock instants never
|
||||
# cross machines (the clock iron rule).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_HEADER_STRUCT = struct.Struct("<HBQIqI") # schema_version, msg_type, seq_id, episode_id, mono_ns, epoch
|
||||
|
||||
MSG_TYPE_OBS = 1
|
||||
MSG_TYPE_CHUNK = 2
|
||||
MSG_TYPE_EVENT = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
class MsgHeader:
|
||||
"""Packed per-message header carried in the Zenoh attachment."""
|
||||
|
||||
schema_version: int = SCHEMA_VERSION
|
||||
msg_type: int = MSG_TYPE_OBS
|
||||
seq_id: int = 0
|
||||
episode_id: int = 0
|
||||
client_mono_ns: int = 0
|
||||
session_epoch: int = 0
|
||||
|
||||
def pack(self) -> bytes:
|
||||
return _HEADER_STRUCT.pack(
|
||||
self.schema_version,
|
||||
self.msg_type,
|
||||
self.seq_id,
|
||||
self.episode_id,
|
||||
self.client_mono_ns,
|
||||
self.session_epoch,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def unpack(cls, data: bytes) -> MsgHeader:
|
||||
if len(data) != _HEADER_STRUCT.size:
|
||||
raise ValueError(f"Bad header length: {len(data)} (expected {_HEADER_STRUCT.size})")
|
||||
version, msg_type, seq_id, episode_id, mono_ns, epoch = _HEADER_STRUCT.unpack(data)
|
||||
return cls(
|
||||
schema_version=version,
|
||||
msg_type=msg_type,
|
||||
seq_id=seq_id,
|
||||
episode_id=episode_id,
|
||||
client_mono_ns=mono_ns,
|
||||
session_epoch=epoch,
|
||||
)
|
||||
|
||||
|
||||
HEADER_SIZE = _HEADER_STRUCT.size
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message bodies
|
||||
#
|
||||
# ``np.ndarray`` fields travel as raw little-endian bytes + dtype + shape
|
||||
# (see codec.py). Images travel JPEG-compressed by default.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
IMAGE_CODEC_JPEG = "jpeg"
|
||||
IMAGE_CODEC_RAW = "raw"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ObservationMsg:
|
||||
"""Client → server: one inference request (data plane)."""
|
||||
|
||||
state: np.ndarray | None = None # float32 [state_dim]
|
||||
images: dict[str, np.ndarray] = field(default_factory=dict) # name -> uint8 HWC RGB
|
||||
task: str = ""
|
||||
inference_delay_steps: int = 0
|
||||
# RTC prefixes: the unexecuted tail of the previous chunk, in model
|
||||
# space (original) and robot space (postprocessed). Both are needed
|
||||
# because the server re-anchors relative-action prefixes against the
|
||||
# current state and the client's ActionQueue.merge needs both chunks.
|
||||
prefix_model: np.ndarray | None = None # float32 [T, action_dim]
|
||||
prefix_robot: np.ndarray | None = None # float32 [T, action_dim]
|
||||
episode_start: bool = False
|
||||
# JPEG quality the images were encoded with; 0 means raw.
|
||||
jpeg_quality: int = 90
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionChunkMsg:
|
||||
"""Server → client: one action chunk (data plane)."""
|
||||
|
||||
seq_id_echo: int = 0
|
||||
client_mono_ns_echo: int = 0
|
||||
episode_id_echo: int = 0
|
||||
chunk_model: np.ndarray | None = None # float32 [H, action_dim] (pre-postprocessor)
|
||||
chunk_robot: np.ndarray | None = None # float32 [H, action_dim] (postprocessed)
|
||||
# Durations only — measured on the server's monotonic clock, never
|
||||
# compared against client time (the clock iron rule).
|
||||
queue_wait_ms: float = 0.0
|
||||
inference_ms: float = 0.0
|
||||
# Observations from this client that were superseded (overwritten in
|
||||
# the latest-only mailbox) since the previous reply — makes drops visible.
|
||||
superseded_seqs: int = 0
|
||||
server_load: float = 0.0 # active_sessions / max_sessions
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionOpenMsg:
|
||||
"""Client → server (control plane): open and validate a session."""
|
||||
|
||||
op: str = "open"
|
||||
client_uuid: str = ""
|
||||
robot_type: str = ""
|
||||
policy_type: str = ""
|
||||
fps: float = 0.0
|
||||
# Hard sync-safety contract: must equal the server's action feature
|
||||
# names *and order* — this maps chunk columns to motors.
|
||||
action_names: list[str] = field(default_factory=list)
|
||||
camera_names: list[str] = field(default_factory=list) # canonical keys (post-rename)
|
||||
state_dim: int = 0
|
||||
schema_version: int = SCHEMA_VERSION
|
||||
rtc_enabled: bool = False
|
||||
task: str = ""
|
||||
tags: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionAckMsg:
|
||||
"""Server → client (control plane): session accept/reject + capabilities."""
|
||||
|
||||
accepted: bool = False
|
||||
error: str = ""
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
session_id: str = ""
|
||||
model_repo: str = ""
|
||||
model_revision: str = ""
|
||||
policy_type: str = ""
|
||||
action_names: list[str] = field(default_factory=list)
|
||||
expected_cameras: list[str] = field(default_factory=list)
|
||||
state_dim: int = 0
|
||||
chunk_size: int = 0
|
||||
trained_fps: float = 0.0
|
||||
supports_rtc: bool = False
|
||||
rtc_execution_horizon: int = 0
|
||||
serving_mode: str = ""
|
||||
warmed_up: bool = False
|
||||
schema_version: int = SCHEMA_VERSION
|
||||
server_load: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class StatusMsg:
|
||||
"""Server → client (control plane): pre-flight capability snapshot."""
|
||||
|
||||
model_repo: str = ""
|
||||
model_revision: str = ""
|
||||
policy_type: str = ""
|
||||
action_names: list[str] = field(default_factory=list)
|
||||
expected_cameras: list[str] = field(default_factory=list)
|
||||
state_dim: int = 0
|
||||
chunk_size: int = 0
|
||||
trained_fps: float = 0.0
|
||||
supports_rtc: bool = False
|
||||
rtc_execution_horizon: int = 0
|
||||
serving_mode: str = ""
|
||||
warmed_up: bool = False
|
||||
min_schema_version: int = MIN_SUPPORTED_SCHEMA_VERSION
|
||||
max_schema_version: int = SCHEMA_VERSION
|
||||
active_sessions: int = 0
|
||||
max_sessions: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResetMsg:
|
||||
"""Client → server (control plane): episode boundary (acknowledged)."""
|
||||
|
||||
client_uuid: str = ""
|
||||
episode_id: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResetAckMsg:
|
||||
"""Server → client: reset acknowledgement."""
|
||||
|
||||
ok: bool = True
|
||||
error: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionCloseMsg:
|
||||
"""Client → server (control plane): graceful session close."""
|
||||
|
||||
op: str = "close"
|
||||
client_uuid: str = ""
|
||||
session_id: str = ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Key-expression schema
|
||||
#
|
||||
# @lerobot/<service>/<client_uuid>/obs client → server (pub/sub)
|
||||
# @lerobot/<service>/<client_uuid>/action server → client (pub/sub)
|
||||
# @lerobot/<service>/status queryable (capabilities)
|
||||
# @lerobot/<service>/session queryable (open/close)
|
||||
# @lerobot/<service>/<client_uuid>/reset queryable (episode boundary)
|
||||
# @lerobot/<service>/<client_uuid>/alive liveliness token (client)
|
||||
# @lerobot/<service>/server/alive liveliness token (server)
|
||||
#
|
||||
# where <service> = <model_slug>/<revision_slug>/<task_slug>. The task
|
||||
# segment is a *namespace label* derived from the server's default task
|
||||
# (or an explicit service name) — the actual inference task string
|
||||
# travels in the session/observation messages.
|
||||
#
|
||||
# ``@lerobot`` is a verbatim chunk: it is only matched by an identical
|
||||
# chunk, so third-party ``**`` subscribers on a shared router can never
|
||||
# scrape this tree.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
KEY_ROOT = "@lerobot"
|
||||
|
||||
# Conservative allowlist for user-supplied key segments. Everything
|
||||
# else (including '/', '*', '$', '?', '#', whitespace) is folded to '-'.
|
||||
_SEGMENT_SANITIZE_RE = re.compile(r"[^A-Za-z0-9_.\-]+")
|
||||
|
||||
# Reserved final chunks of the key tree; a client UUID must never
|
||||
# collide with them.
|
||||
RESERVED_SEGMENTS = frozenset({"status", "session", "server", "obs", "action", "reset", "alive"})
|
||||
|
||||
|
||||
def sanitize_key_segment(segment: str) -> str:
|
||||
"""Fold an arbitrary string into a single safe Zenoh key chunk."""
|
||||
slug = _SEGMENT_SANITIZE_RE.sub("-", segment.strip()).strip("-.")
|
||||
if not slug:
|
||||
raise ValueError(f"Key segment {segment!r} sanitizes to an empty chunk")
|
||||
if slug in RESERVED_SEGMENTS:
|
||||
raise ValueError(f"Key segment {segment!r} collides with reserved chunk {slug!r}")
|
||||
return slug
|
||||
|
||||
|
||||
def service_prefix(model_id: str, revision: str, task: str) -> str:
|
||||
"""Build the shared key prefix for one served (model, revision, task) triple."""
|
||||
return "/".join(
|
||||
(
|
||||
KEY_ROOT,
|
||||
sanitize_key_segment(model_id),
|
||||
sanitize_key_segment(revision or "main"),
|
||||
sanitize_key_segment(task or "default"),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def obs_key(prefix: str, client_uuid: str) -> str:
|
||||
return f"{prefix}/{sanitize_key_segment(client_uuid)}/obs"
|
||||
|
||||
|
||||
def action_key(prefix: str, client_uuid: str) -> str:
|
||||
return f"{prefix}/{sanitize_key_segment(client_uuid)}/action"
|
||||
|
||||
|
||||
def reset_key(prefix: str, client_uuid: str) -> str:
|
||||
return f"{prefix}/{sanitize_key_segment(client_uuid)}/reset"
|
||||
|
||||
|
||||
def client_alive_key(prefix: str, client_uuid: str) -> str:
|
||||
return f"{prefix}/{sanitize_key_segment(client_uuid)}/alive"
|
||||
|
||||
|
||||
def status_key(prefix: str) -> str:
|
||||
return f"{prefix}/status"
|
||||
|
||||
|
||||
def session_key(prefix: str) -> str:
|
||||
return f"{prefix}/session"
|
||||
|
||||
|
||||
def server_alive_key(prefix: str) -> str:
|
||||
return f"{prefix}/server/alive"
|
||||
|
||||
|
||||
# Single-depth wildcards only — '**' would also match status/session/alive.
|
||||
def obs_wildcard(prefix: str) -> str:
|
||||
return f"{prefix}/*/obs"
|
||||
|
||||
|
||||
def reset_wildcard(prefix: str) -> str:
|
||||
return f"{prefix}/*/reset"
|
||||
|
||||
|
||||
def client_alive_wildcard(prefix: str) -> str:
|
||||
return f"{prefix}/*/alive"
|
||||
|
||||
|
||||
def client_uuid_from_key(key: str) -> str:
|
||||
"""Extract the client UUID chunk from an obs/reset/alive key."""
|
||||
parts = key.split("/")
|
||||
if len(parts) < 2:
|
||||
raise ValueError(f"Key {key!r} has no client chunk")
|
||||
return parts[-2]
|
||||
@@ -1,934 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""``lerobot-policy-server``: multi-client GPU inference over Zenoh.
|
||||
|
||||
One process serves one pre-warmed (model, revision, dtype, device) to up
|
||||
to ``max_sessions`` robot clients. The process is **stateless per
|
||||
request**: clients ship RTC prefixes and a delay hint with every
|
||||
observation, so a server crash loses zero control state and reconnects
|
||||
are trivial.
|
||||
|
||||
Concurrency model (pure threads — zenoh-python has no asyncio API):
|
||||
|
||||
zenoh subscriber (.../*/obs) inference worker (1 thread, owns GPU)
|
||||
deposit-only callback: loop:
|
||||
session.deposit(header, body) ──► pick next session with pending obs (RR)
|
||||
(per-client latest-only slot) decode → per-session preprocess
|
||||
predict_action_chunk(delay, prefix)
|
||||
control queryables (status/session/ per-session postprocess → encode
|
||||
reset): validate, mutate session publisher.put(.../<uuid>/action)
|
||||
registry, reply inline
|
||||
|
||||
The single worker thread serializes GPU access; newest-wins mailboxes
|
||||
mean overload degrades into longer cycle times (larger but correct
|
||||
client delays), never into queue buildup.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import http.server
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import uuid as uuid_module
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs import FeatureType
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc.relative import reanchor_relative_rtc_prefix
|
||||
from lerobot.policies.utils import populate_queues, prepare_observation_for_inference
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
from . import codec
|
||||
from .manifest import PolicyServerManifest
|
||||
from .scheduler import RoundRobinScheduler, Scheduler
|
||||
from .schema import (
|
||||
SCHEMA_VERSION,
|
||||
ActionChunkMsg,
|
||||
MsgHeader,
|
||||
ObservationMsg,
|
||||
ResetAckMsg,
|
||||
SessionAckMsg,
|
||||
SessionCloseMsg,
|
||||
SessionOpenMsg,
|
||||
StatusMsg,
|
||||
action_key,
|
||||
client_alive_key,
|
||||
client_alive_wildcard,
|
||||
client_uuid_from_key,
|
||||
obs_wildcard,
|
||||
reset_wildcard,
|
||||
server_alive_key,
|
||||
service_prefix,
|
||||
session_key,
|
||||
status_key,
|
||||
)
|
||||
from .session import Session, SessionRegistry
|
||||
from .validation import (
|
||||
PolicyClassification,
|
||||
classify_policy,
|
||||
resolve_serving_mode,
|
||||
validate_session_open,
|
||||
)
|
||||
from .zenoh_utils import action_publisher_qos, build_zenoh_config, import_zenoh
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
audit_logger = logging.getLogger("lerobot.policy_server.audit")
|
||||
|
||||
# Grace period after a client liveliness token drops before its session
|
||||
# is garbage-collected (rides out router blips and reconnects).
|
||||
_LIVELINESS_GC_GRACE_S = 5.0
|
||||
# Worker idle wait between work-event checks (also paces the GC sweep).
|
||||
_WORKER_IDLE_WAIT_S = 0.05
|
||||
|
||||
|
||||
def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int) -> torch.Tensor:
|
||||
"""Pad or truncate RTC prefix actions to a fixed length (mirrors rtc.py)."""
|
||||
if prev_actions.ndim != 2:
|
||||
raise ValueError(f"Expected 2D [T, A] tensor, got shape={tuple(prev_actions.shape)}")
|
||||
steps, action_dim = prev_actions.shape
|
||||
if steps == target_steps:
|
||||
return prev_actions
|
||||
if steps > target_steps:
|
||||
return prev_actions[:target_steps]
|
||||
padded = torch.zeros((target_steps, action_dim), dtype=prev_actions.dtype, device=prev_actions.device)
|
||||
padded[:steps] = prev_actions
|
||||
return padded
|
||||
|
||||
|
||||
class PolicyServer:
|
||||
"""Zenoh policy server: control-plane queryables + data-plane pub/sub + one GPU worker."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manifest: PolicyServerManifest,
|
||||
*,
|
||||
policy: PreTrainedPolicy | None = None,
|
||||
policy_cfg: PreTrainedConfig | None = None,
|
||||
processor_factory: Callable[[], tuple[Any, Any]] | None = None,
|
||||
classification: PolicyClassification | None = None,
|
||||
scheduler: Scheduler | None = None,
|
||||
) -> None:
|
||||
"""``policy``/``policy_cfg``/``processor_factory``/``classification``
|
||||
are injection points for tests; production loads everything from
|
||||
the manifest via :meth:`load_policy`.
|
||||
"""
|
||||
self._manifest = manifest
|
||||
self._device = torch.device(manifest.model.device)
|
||||
self._policy = policy
|
||||
self._policy_cfg = policy_cfg
|
||||
self._processor_factory = processor_factory
|
||||
self._classification = classification
|
||||
self._scheduler = scheduler or RoundRobinScheduler()
|
||||
|
||||
self._serving_mode: str = ""
|
||||
self._max_sessions: int = manifest.max_sessions
|
||||
self._rtc_active = False
|
||||
self._warmed_up = False
|
||||
|
||||
self.registry = SessionRegistry()
|
||||
self._registry_lock = threading.Lock() # serializes open/close/GC decisions
|
||||
# Serializes inference against episode resets: in exclusive mode a
|
||||
# reset (policy.reset(), pipeline reset) arriving on a queryable
|
||||
# thread mid-predict would corrupt the in-flight request's state.
|
||||
self._inference_lock = threading.Lock()
|
||||
|
||||
self._zenoh = None
|
||||
self._declarations: list[Any] = []
|
||||
self._alive_token = None
|
||||
|
||||
self._work = threading.Event()
|
||||
self._shutdown = threading.Event()
|
||||
self._worker: threading.Thread | None = None
|
||||
self._health_server: http.server.ThreadingHTTPServer | None = None
|
||||
|
||||
self._unknown_clients_warned: set[str] = set()
|
||||
self._capture_count = 0
|
||||
|
||||
self.metrics: dict[str, float] = {
|
||||
"requests_total": 0,
|
||||
"errors_total": 0,
|
||||
"superseded_total": 0,
|
||||
"dropped_unknown_client_total": 0,
|
||||
"sessions_opened_total": 0,
|
||||
"sessions_closed_total": 0,
|
||||
}
|
||||
self._metrics_lock = threading.Lock()
|
||||
|
||||
task_slug_source = manifest.service_name or manifest.default_task or "default"
|
||||
self.prefix = service_prefix(manifest.model.repo_or_path, manifest.model.revision, task_slug_source)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Loading & warmup
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def load_policy(self) -> None:
|
||||
"""Load config + weights, apply RTC settings, classify, warm up."""
|
||||
manifest = self._manifest
|
||||
if self._policy is None:
|
||||
logger.info(
|
||||
"Loading policy from '%s' (revision=%s)...",
|
||||
manifest.model.repo_or_path,
|
||||
manifest.model.revision,
|
||||
)
|
||||
policy_cfg = PreTrainedConfig.from_pretrained(manifest.model.repo_or_path)
|
||||
policy_cfg.pretrained_path = manifest.model.repo_or_path
|
||||
policy_class = get_policy_class(policy_cfg.type)
|
||||
policy = policy_class.from_pretrained(manifest.model.repo_or_path, config=policy_cfg)
|
||||
self._policy = policy
|
||||
self._policy_cfg = policy_cfg
|
||||
elif self._policy_cfg is None:
|
||||
self._policy_cfg = self._policy.config
|
||||
|
||||
if self._classification is None:
|
||||
self._classification = classify_policy(self._policy)
|
||||
logger.info("Policy classification: %s", self._classification.reason)
|
||||
|
||||
self._serving_mode, self._max_sessions = resolve_serving_mode(self._classification, manifest)
|
||||
logger.info("Serving mode: %s (max_sessions=%d)", self._serving_mode, self._max_sessions)
|
||||
|
||||
# RTC is a per-process decision: init_rtc_processor mutates the
|
||||
# shared policy instance.
|
||||
self._rtc_active = (
|
||||
manifest.rtc.enabled
|
||||
and self._classification.supports_rtc
|
||||
and hasattr(self._policy.config, "rtc_config")
|
||||
)
|
||||
if self._rtc_active:
|
||||
self._policy.config.rtc_config = manifest.rtc
|
||||
if hasattr(self._policy, "init_rtc_processor"):
|
||||
self._policy.init_rtc_processor()
|
||||
logger.info("RTC active (execution_horizon=%d)", manifest.rtc.execution_horizon)
|
||||
|
||||
if manifest.model.dtype:
|
||||
self._policy = self._policy.to(getattr(torch, manifest.model.dtype))
|
||||
self._policy = self._policy.to(self._device)
|
||||
self._policy.eval()
|
||||
|
||||
if not self.action_names:
|
||||
logger.warning(
|
||||
"Policy config has no action_feature_names: the action-order contract "
|
||||
"cannot be enforced at session open. Clients are trusted to match training order."
|
||||
)
|
||||
|
||||
if manifest.warmup_inferences > 0:
|
||||
self._warmup(manifest.warmup_inferences)
|
||||
self._warmed_up = True
|
||||
|
||||
def make_session_processors(self) -> tuple[Any, Any]:
|
||||
"""Build a fresh per-session pre/post pipeline pair.
|
||||
|
||||
The rename step is forced to identity: clients apply their
|
||||
rename map before encoding, so the wire format is canonical
|
||||
policy-feature keys across heterogeneous robots.
|
||||
"""
|
||||
if self._processor_factory is not None:
|
||||
return self._processor_factory()
|
||||
return make_pre_post_processors(
|
||||
policy_cfg=self._policy_cfg,
|
||||
pretrained_path=self._policy_cfg.pretrained_path,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": str(self._device)},
|
||||
"rename_observations_processor": {"rename_map": {}},
|
||||
},
|
||||
)
|
||||
|
||||
def _warmup(self, n: int) -> None:
|
||||
"""Run dummy forwards through the full request path (covers compile/caches)."""
|
||||
logger.info("Warmup: %d inferences...", n)
|
||||
obs = self._synthetic_observation()
|
||||
preprocessor, postprocessor = self.make_session_processors()
|
||||
session = Session(
|
||||
session_id="warmup",
|
||||
client_uuid="warmup",
|
||||
task=self._manifest.default_task,
|
||||
robot_type="",
|
||||
rtc_enabled=self._rtc_active,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
reply = self.run_inference_request(session, MsgHeader(), obs)
|
||||
if self._rtc_active and reply.chunk_model is not None and n > 1:
|
||||
# Exercise the prefix-conditioned path too so its compile/cache
|
||||
# cost isn't paid by the first real RTC request.
|
||||
action_dim = reply.chunk_model.shape[-1]
|
||||
horizon = self._manifest.rtc.execution_horizon
|
||||
obs.prefix_model = np.zeros((horizon, action_dim), dtype=np.float32)
|
||||
obs.prefix_robot = np.zeros((horizon, action_dim), dtype=np.float32)
|
||||
obs.inference_delay_steps = 1
|
||||
for _ in range(n - 1):
|
||||
self.run_inference_request(session, MsgHeader(), obs)
|
||||
session.close()
|
||||
# Stateful policies must not carry warmup observations into real sessions.
|
||||
if self._serving_mode == "exclusive":
|
||||
self._policy.reset()
|
||||
logger.info("Warmup complete")
|
||||
|
||||
def _synthetic_observation(self) -> ObservationMsg:
|
||||
cfg = self._policy_cfg
|
||||
state_dim = self.state_dim or 1
|
||||
images = {}
|
||||
for key, feature in cfg.input_features.items():
|
||||
if feature.type == FeatureType.VISUAL:
|
||||
channels, height, width = feature.shape
|
||||
images[key] = np.zeros((height, width, channels), dtype=np.uint8)
|
||||
return ObservationMsg(
|
||||
state=np.zeros(state_dim, dtype=np.float32),
|
||||
images=images,
|
||||
task=self._manifest.default_task,
|
||||
jpeg_quality=0,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Capabilities
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def action_names(self) -> list[str]:
|
||||
names = getattr(self._policy_cfg, "action_feature_names", None)
|
||||
return list(names) if names else []
|
||||
|
||||
@property
|
||||
def state_dim(self) -> int:
|
||||
cfg = self._policy_cfg
|
||||
for key, feature in getattr(cfg, "input_features", {}).items():
|
||||
if key == OBS_STATE or feature.type == FeatureType.STATE:
|
||||
return int(feature.shape[0])
|
||||
return 0
|
||||
|
||||
@property
|
||||
def chunk_size(self) -> int:
|
||||
cfg = self._policy_cfg
|
||||
for attr in ("chunk_size", "n_action_steps", "horizon"):
|
||||
value = getattr(cfg, attr, None)
|
||||
if value:
|
||||
return int(value)
|
||||
return 0
|
||||
|
||||
def status_snapshot(self) -> StatusMsg:
|
||||
cfg = self._policy_cfg
|
||||
expected_cameras = [
|
||||
key
|
||||
for key, feature in getattr(cfg, "input_features", {}).items()
|
||||
if feature.type == FeatureType.VISUAL
|
||||
]
|
||||
return StatusMsg(
|
||||
model_repo=self._manifest.model.repo_or_path,
|
||||
model_revision=self._manifest.model.revision,
|
||||
policy_type=getattr(cfg, "type", "") or getattr(self._policy, "name", ""),
|
||||
action_names=self.action_names,
|
||||
expected_cameras=expected_cameras,
|
||||
state_dim=self.state_dim,
|
||||
chunk_size=self.chunk_size,
|
||||
trained_fps=self._manifest.trained_fps,
|
||||
supports_rtc=self._rtc_active,
|
||||
rtc_execution_horizon=self._manifest.rtc.execution_horizon if self._rtc_active else 0,
|
||||
serving_mode=self._serving_mode,
|
||||
warmed_up=self._warmed_up,
|
||||
active_sessions=len(self.registry),
|
||||
max_sessions=self._max_sessions,
|
||||
)
|
||||
|
||||
@property
|
||||
def server_load(self) -> float:
|
||||
return len(self.registry) / max(1, self._max_sessions)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# The per-request inference path (pure: no zenoh — parity-testable)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def run_inference_request(
|
||||
self, session: Session, header: MsgHeader, obs: ObservationMsg
|
||||
) -> ActionChunkMsg:
|
||||
"""Mirror of the local RTC loop's compute step (rtc.py), minus the queue merge."""
|
||||
t0 = time.perf_counter()
|
||||
|
||||
obs_np: dict[str, np.ndarray] = {}
|
||||
if obs.state is not None:
|
||||
obs_np[OBS_STATE] = np.asarray(obs.state, dtype=np.float32)
|
||||
for name, img in obs.images.items():
|
||||
obs_np[name] = img
|
||||
|
||||
task = obs.task or session.task or self._manifest.default_task
|
||||
batch = prepare_observation_for_inference(obs_np, self._device, task, session.robot_type)
|
||||
batch["task"] = [task]
|
||||
|
||||
preprocessed = session.preprocessor(batch)
|
||||
|
||||
use_rtc = self._rtc_active and session.rtc_enabled
|
||||
if use_rtc:
|
||||
delay = max(0, int(obs.inference_delay_steps))
|
||||
prev_actions: torch.Tensor | None = None
|
||||
if obs.prefix_model is not None and obs.prefix_model.size:
|
||||
prev_actions = torch.from_numpy(np.ascontiguousarray(obs.prefix_model)).to(self._device)
|
||||
|
||||
if prev_actions is not None and session.relative_step is not None:
|
||||
# Re-anchor the absolute leftover tail against the state
|
||||
# cached by THIS request's preprocess (mirrors rtc.py).
|
||||
raw_state = session.relative_step.get_cached_state()
|
||||
prefix_robot = obs.prefix_robot
|
||||
if raw_state is not None and prefix_robot is not None and prefix_robot.size:
|
||||
prev_actions = reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute=torch.from_numpy(np.ascontiguousarray(prefix_robot)),
|
||||
current_state=raw_state,
|
||||
relative_step=session.relative_step,
|
||||
normalizer_step=session.normalizer_step,
|
||||
policy_device=self._device,
|
||||
)
|
||||
|
||||
if prev_actions is not None:
|
||||
prev_actions = _normalize_prev_actions_length(
|
||||
prev_actions, target_steps=self._manifest.rtc.execution_horizon
|
||||
)
|
||||
|
||||
actions = self._policy.predict_action_chunk(
|
||||
preprocessed, inference_delay=delay, prev_chunk_left_over=prev_actions
|
||||
)
|
||||
else:
|
||||
if self._classification is not None and self._classification.needs_queue_population:
|
||||
preprocessed = self._populate_select_queues(preprocessed)
|
||||
actions = self._policy.predict_action_chunk(preprocessed)
|
||||
|
||||
original = actions.squeeze(0).clone()
|
||||
processed = session.postprocessor(actions).squeeze(0)
|
||||
inference_ms = (time.perf_counter() - t0) * 1e3
|
||||
|
||||
session.stats.requests += 1
|
||||
session.stats.last_inference_ms = inference_ms
|
||||
superseded = session.take_superseded()
|
||||
|
||||
return ActionChunkMsg(
|
||||
seq_id_echo=header.seq_id,
|
||||
client_mono_ns_echo=header.client_mono_ns,
|
||||
episode_id_echo=header.episode_id,
|
||||
chunk_model=original.detach().to("cpu", torch.float32).numpy(),
|
||||
chunk_robot=processed.detach().to("cpu", torch.float32).numpy(),
|
||||
inference_ms=inference_ms,
|
||||
superseded_seqs=superseded,
|
||||
server_load=self.server_load,
|
||||
)
|
||||
|
||||
def _populate_select_queues(self, batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Exclusive-mode shim for select_action-fed policies (diffusion family).
|
||||
|
||||
Mirrors ``DiffusionPolicy.select_action``: stack camera features
|
||||
into OBS_IMAGES, then populate the policy's observation queues so
|
||||
``predict_action_chunk`` sees the same history it would locally.
|
||||
"""
|
||||
policy = self._policy
|
||||
batch = {k: v for k, v in batch.items() if k != ACTION}
|
||||
if getattr(policy.config, "image_features", None):
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in policy.config.image_features], dim=-4)
|
||||
policy._queues = populate_queues(policy._queues, batch)
|
||||
return batch
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Zenoh wiring
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def start(self) -> None:
|
||||
"""Open zenoh, declare the service surface, start worker + health threads."""
|
||||
if self._policy is None or not self._warmed_up:
|
||||
self.load_policy()
|
||||
|
||||
zenoh = import_zenoh()
|
||||
spec = self._manifest.zenoh
|
||||
self._zenoh = zenoh.open(
|
||||
build_zenoh_config(
|
||||
mode=spec.mode,
|
||||
connect_endpoints=spec.connect_endpoints,
|
||||
listen_endpoints=spec.listen_endpoints,
|
||||
tls_root_ca_certificate=spec.tls_root_ca_certificate,
|
||||
tls_connect_certificate=spec.tls_connect_certificate,
|
||||
tls_connect_private_key=spec.tls_connect_private_key,
|
||||
extra_config_json5=spec.extra_config_json5,
|
||||
)
|
||||
)
|
||||
handlers = zenoh.handlers
|
||||
|
||||
# Data plane: wildcard subscriber, deposit-only callback.
|
||||
self._declarations.append(
|
||||
self._zenoh.declare_subscriber(obs_wildcard(self.prefix), handlers.Callback(self._on_obs))
|
||||
)
|
||||
# Control plane: queryables reply inline (low rate).
|
||||
self._declarations.append(
|
||||
self._zenoh.declare_queryable(status_key(self.prefix), handlers.Callback(self._on_status_query))
|
||||
)
|
||||
self._declarations.append(
|
||||
self._zenoh.declare_queryable(session_key(self.prefix), handlers.Callback(self._on_session_query))
|
||||
)
|
||||
self._declarations.append(
|
||||
self._zenoh.declare_queryable(
|
||||
reset_wildcard(self.prefix), handlers.Callback(self._on_reset_query)
|
||||
)
|
||||
)
|
||||
# Presence: watch client tokens; publish our own.
|
||||
self._declarations.append(
|
||||
self._zenoh.liveliness().declare_subscriber(
|
||||
client_alive_wildcard(self.prefix), handlers.Callback(self._on_liveliness), history=True
|
||||
)
|
||||
)
|
||||
self._alive_token = self._zenoh.liveliness().declare_token(server_alive_key(self.prefix))
|
||||
|
||||
self._shutdown.clear()
|
||||
self._worker = threading.Thread(target=self._worker_loop, daemon=True, name="PolicyServerWorker")
|
||||
self._worker.start()
|
||||
|
||||
if self._manifest.health_port:
|
||||
self._start_health_server(self._manifest.health_port)
|
||||
|
||||
logger.info(
|
||||
"Policy server up: prefix=%s mode=%s max_sessions=%d rtc=%s",
|
||||
self.prefix,
|
||||
self._serving_mode,
|
||||
self._max_sessions,
|
||||
self._rtc_active,
|
||||
)
|
||||
|
||||
def serve_forever(self) -> None:
|
||||
try:
|
||||
while not self._shutdown.is_set():
|
||||
self._shutdown.wait(timeout=0.5)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Interrupted — draining")
|
||||
finally:
|
||||
self.stop()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Drain: drop the liveliness token first (clients ride their buffers
|
||||
through the drain), finish the in-flight inference, then close."""
|
||||
if self._alive_token is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
self._alive_token.undeclare()
|
||||
self._alive_token = None
|
||||
|
||||
self._shutdown.set()
|
||||
self._work.set()
|
||||
if self._worker is not None and self._worker.is_alive():
|
||||
self._worker.join(timeout=10.0)
|
||||
if self._worker.is_alive():
|
||||
logger.warning("Inference worker did not join within 10s")
|
||||
self._worker = None
|
||||
|
||||
# Undeclare the control/data surface BEFORE closing sessions so a
|
||||
# late session open cannot be accepted by a server that has
|
||||
# already drained its worker.
|
||||
for declaration in self._declarations:
|
||||
with contextlib.suppress(Exception):
|
||||
declaration.undeclare()
|
||||
self._declarations.clear()
|
||||
|
||||
for session in self.registry.snapshot():
|
||||
self._close_session(session, reason="server shutdown")
|
||||
|
||||
if self._zenoh is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
self._zenoh.close()
|
||||
self._zenoh = None
|
||||
|
||||
if self._health_server is not None:
|
||||
self._health_server.shutdown()
|
||||
self._health_server = None
|
||||
logger.info("Policy server stopped")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Zenoh callbacks (deposit-only on the data plane)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _on_obs(self, sample: Any) -> None:
|
||||
try:
|
||||
attachment = sample.attachment
|
||||
if attachment is None:
|
||||
return
|
||||
header = MsgHeader.unpack(attachment.to_bytes())
|
||||
if header.schema_version != SCHEMA_VERSION:
|
||||
return
|
||||
client_uuid = client_uuid_from_key(str(sample.key_expr))
|
||||
session = self.registry.get(client_uuid)
|
||||
if session is None:
|
||||
self._bump("dropped_unknown_client_total")
|
||||
# Bounded: garbage publishers must not grow this set (or
|
||||
# the log) without limit.
|
||||
if (
|
||||
client_uuid not in self._unknown_clients_warned
|
||||
and len(self._unknown_clients_warned) < 256
|
||||
):
|
||||
self._unknown_clients_warned.add(client_uuid)
|
||||
logger.warning(
|
||||
"Observation from unknown client '%s' (no session) — dropping", client_uuid
|
||||
)
|
||||
return
|
||||
session.deposit(header, sample.payload.to_bytes())
|
||||
self._work.set()
|
||||
except Exception as e: # noqa: BLE001 — a malformed sample must never kill the subscriber
|
||||
logger.error("obs callback error: %s", e)
|
||||
|
||||
def _on_liveliness(self, sample: Any) -> None:
|
||||
try:
|
||||
import zenoh
|
||||
|
||||
client_uuid = client_uuid_from_key(str(sample.key_expr))
|
||||
session = self.registry.get(client_uuid)
|
||||
if session is None:
|
||||
return
|
||||
if sample.kind == zenoh.SampleKind.DELETE:
|
||||
session.alive = False
|
||||
session.token_dropped_mono = time.monotonic()
|
||||
logger.info(
|
||||
"Client '%s' liveliness dropped — GC in %.0fs", client_uuid, _LIVELINESS_GC_GRACE_S
|
||||
)
|
||||
else:
|
||||
session.alive = True
|
||||
session.token_dropped_mono = None
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("liveliness callback error: %s", e)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Control-plane queryables
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _on_status_query(self, query: Any) -> None:
|
||||
try:
|
||||
query.reply(status_key(self.prefix), codec.encode_status(self.status_snapshot()))
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("status query error: %s", e)
|
||||
|
||||
def _on_session_query(self, query: Any) -> None:
|
||||
try:
|
||||
payload = query.payload.to_bytes() if query.payload is not None else b""
|
||||
op = codec.decode_raw(payload).get("op", "open") if payload else "open"
|
||||
if op == "close":
|
||||
self._handle_session_close(codec.decode_session_close(payload))
|
||||
query.reply(session_key(self.prefix), codec.encode_reset_ack(ResetAckMsg(ok=True)))
|
||||
return
|
||||
ack = self._handle_session_open(codec.decode_session_open(payload))
|
||||
query.reply(session_key(self.prefix), codec.encode_session_ack(ack))
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("session query error: %s\n%s", e, traceback.format_exc())
|
||||
with contextlib.suppress(Exception):
|
||||
query.reply(
|
||||
session_key(self.prefix),
|
||||
codec.encode_session_ack(SessionAckMsg(accepted=False, error=f"server error: {e}")),
|
||||
)
|
||||
|
||||
def _on_reset_query(self, query: Any) -> None:
|
||||
try:
|
||||
payload = query.payload.to_bytes() if query.payload is not None else b""
|
||||
msg = codec.decode_reset(payload)
|
||||
session = self.registry.get(msg.client_uuid)
|
||||
if session is None:
|
||||
ack = ResetAckMsg(ok=False, error=f"unknown client '{msg.client_uuid}'")
|
||||
else:
|
||||
# Serialize with the worker: resetting pipelines/policy
|
||||
# mid-predict would corrupt the in-flight request.
|
||||
with self._inference_lock:
|
||||
session.reset_episode(msg.episode_id)
|
||||
if self._serving_mode == "exclusive":
|
||||
# Safe: max_sessions=1, the policy belongs to this client.
|
||||
self._policy.reset()
|
||||
ack = ResetAckMsg(ok=True)
|
||||
query.reply(str(query.key_expr), codec.encode_reset_ack(ack))
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("reset query error: %s", e)
|
||||
|
||||
def _handle_session_open(self, msg: SessionOpenMsg) -> SessionAckMsg:
|
||||
capabilities = self.status_snapshot()
|
||||
with self._registry_lock:
|
||||
# A re-handshake from a known client replaces its session and
|
||||
# does not count against capacity.
|
||||
existing = self.registry.get(msg.client_uuid)
|
||||
active = len(self.registry) - (1 if existing else 0)
|
||||
result = validate_session_open(msg, capabilities, self._manifest, active)
|
||||
if not result.ok:
|
||||
logger.warning("Session rejected for '%s': %s", msg.client_uuid, result.error)
|
||||
return SessionAckMsg(accepted=False, error=result.error, server_load=self.server_load)
|
||||
|
||||
preprocessor, postprocessor = self.make_session_processors()
|
||||
session = Session(
|
||||
session_id=uuid_module.uuid4().hex,
|
||||
client_uuid=msg.client_uuid,
|
||||
task=msg.task or self._manifest.default_task,
|
||||
robot_type=msg.robot_type,
|
||||
rtc_enabled=msg.rtc_enabled and not result.rtc_downgraded,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
action_publisher=self._declare_action_publisher(msg.client_uuid),
|
||||
)
|
||||
if session.relative_step is not None and session.relative_step.action_names is None:
|
||||
session.relative_step.action_names = self.action_names or list(msg.action_names)
|
||||
# Sentinel: the FIRST observation of a fresh session always
|
||||
# triggers the episode-boundary branch in _serve_one, so a
|
||||
# mid-episode reconnect can never inherit stale state.
|
||||
session.episode_id = -1
|
||||
displaced = self.registry.add(session)
|
||||
if displaced is not None:
|
||||
displaced.close()
|
||||
self._bump("sessions_closed_total")
|
||||
logger.info("Client '%s' re-handshake: previous session replaced", msg.client_uuid)
|
||||
if self._serving_mode == "exclusive":
|
||||
# A new exclusive session must start from fresh policy state.
|
||||
with self._inference_lock:
|
||||
self._policy.reset()
|
||||
self._bump("sessions_opened_total")
|
||||
self._unknown_clients_warned.discard(msg.client_uuid)
|
||||
logger.info(
|
||||
"Session opened: client=%s session=%s task=%r rtc=%s (%d/%d)",
|
||||
msg.client_uuid,
|
||||
session.session_id,
|
||||
session.task,
|
||||
session.rtc_enabled,
|
||||
len(self.registry),
|
||||
self._max_sessions,
|
||||
)
|
||||
return SessionAckMsg(
|
||||
accepted=True,
|
||||
warnings=result.warnings,
|
||||
session_id=session.session_id,
|
||||
model_repo=capabilities.model_repo,
|
||||
model_revision=capabilities.model_revision,
|
||||
policy_type=capabilities.policy_type,
|
||||
action_names=capabilities.action_names,
|
||||
expected_cameras=capabilities.expected_cameras,
|
||||
state_dim=capabilities.state_dim,
|
||||
chunk_size=capabilities.chunk_size,
|
||||
trained_fps=capabilities.trained_fps,
|
||||
supports_rtc=capabilities.supports_rtc and session.rtc_enabled,
|
||||
rtc_execution_horizon=capabilities.rtc_execution_horizon,
|
||||
serving_mode=capabilities.serving_mode,
|
||||
warmed_up=capabilities.warmed_up,
|
||||
server_load=self.server_load,
|
||||
)
|
||||
|
||||
def _declare_action_publisher(self, client_uuid: str) -> Any:
|
||||
if self._zenoh is None: # pure-logic tests run without transport
|
||||
return None
|
||||
zenoh = import_zenoh()
|
||||
return self._zenoh.declare_publisher(
|
||||
action_key(self.prefix, client_uuid), **action_publisher_qos(zenoh)
|
||||
)
|
||||
|
||||
def _handle_session_close(self, msg: SessionCloseMsg) -> None:
|
||||
session = self.registry.get(msg.client_uuid)
|
||||
if session is not None and (not msg.session_id or msg.session_id == session.session_id):
|
||||
self._close_session(session, reason="client close")
|
||||
|
||||
def _close_session(self, session: Session, reason: str) -> None:
|
||||
# Identity-checked removal: never tear down a same-uuid session
|
||||
# that replaced this one via a re-handshake.
|
||||
removed = self.registry.remove(session.client_uuid, expected=session)
|
||||
if removed is not None:
|
||||
removed.close()
|
||||
self._bump("sessions_closed_total")
|
||||
logger.info("Session closed: client=%s (%s)", session.client_uuid, reason)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Inference worker
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _worker_loop(self) -> None:
|
||||
last_gc = time.monotonic()
|
||||
while not self._shutdown.is_set():
|
||||
ready = [s for s in self.registry.snapshot() if s.has_pending()]
|
||||
if not ready:
|
||||
self._work.wait(timeout=_WORKER_IDLE_WAIT_S)
|
||||
self._work.clear()
|
||||
else:
|
||||
for session in self._scheduler.select(ready):
|
||||
self._serve_one(session)
|
||||
|
||||
now = time.monotonic()
|
||||
if now - last_gc > 1.0:
|
||||
last_gc = now
|
||||
self._gc_sessions(now)
|
||||
|
||||
def _serve_one(self, session: Session) -> None:
|
||||
item = session.take()
|
||||
if item is None:
|
||||
return
|
||||
queue_wait_ms = (time.monotonic() - item.recv_mono) * 1e3
|
||||
outcome = "ok"
|
||||
try:
|
||||
obs = codec.decode_observation(item.payload)
|
||||
self._capture("req", item.payload)
|
||||
|
||||
with self._inference_lock:
|
||||
# Belt-and-braces episode ordering: the first observation of
|
||||
# an episode also announces the boundary (one-in-flight makes
|
||||
# the reset query race-free, but a lost ack must not desync
|
||||
# us; fresh sessions start at the -1 sentinel so their first
|
||||
# request always lands here).
|
||||
if obs.episode_start or item.header.episode_id != session.episode_id:
|
||||
session.preprocessor.reset()
|
||||
session.postprocessor.reset()
|
||||
session.episode_id = item.header.episode_id
|
||||
if self._serving_mode == "exclusive":
|
||||
self._policy.reset()
|
||||
|
||||
reply = self.run_inference_request(session, item.header, obs)
|
||||
reply.queue_wait_ms = queue_wait_ms
|
||||
session.stats.last_queue_wait_ms = queue_wait_ms
|
||||
|
||||
body = codec.encode_action_chunk(reply)
|
||||
self._capture("rep", body)
|
||||
# Local ref: a re-handshake can null session.action_publisher
|
||||
# between the check and the put.
|
||||
publisher = session.action_publisher
|
||||
if publisher is not None:
|
||||
reply_header = MsgHeader(
|
||||
schema_version=SCHEMA_VERSION,
|
||||
msg_type=2, # MSG_TYPE_CHUNK
|
||||
seq_id=item.header.seq_id,
|
||||
episode_id=item.header.episode_id,
|
||||
client_mono_ns=item.header.client_mono_ns,
|
||||
session_epoch=item.header.session_epoch,
|
||||
)
|
||||
publisher.put(body, attachment=reply_header.pack())
|
||||
self._bump("requests_total")
|
||||
self._bump("superseded_total", reply.superseded_seqs)
|
||||
except Exception as e: # noqa: BLE001 — one bad request must not kill the worker
|
||||
outcome = f"error: {e}"
|
||||
session.stats.errors += 1
|
||||
self._bump("errors_total")
|
||||
logger.error(
|
||||
"Inference error for client '%s' seq=%d: %s\n%s",
|
||||
session.client_uuid,
|
||||
item.header.seq_id,
|
||||
e,
|
||||
traceback.format_exc(),
|
||||
)
|
||||
finally:
|
||||
audit_logger.info(
|
||||
json.dumps(
|
||||
{
|
||||
"session_id": session.session_id,
|
||||
"client_uuid": session.client_uuid,
|
||||
"seq_id": item.header.seq_id,
|
||||
"episode_id": item.header.episode_id,
|
||||
"queue_wait_ms": round(queue_wait_ms, 3),
|
||||
"inference_ms": round(session.stats.last_inference_ms, 3),
|
||||
"superseded": session.stats.superseded,
|
||||
"outcome": outcome,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
def _gc_sessions(self, now: float) -> None:
|
||||
for session in self.registry.snapshot():
|
||||
if (
|
||||
session.token_dropped_mono is not None
|
||||
and now - session.token_dropped_mono > _LIVELINESS_GC_GRACE_S
|
||||
):
|
||||
if self._client_token_alive(session.client_uuid):
|
||||
# The DELETE was a late echo of a previous incarnation
|
||||
# (the token key is per client, not per epoch) — the
|
||||
# client re-declared and is alive.
|
||||
session.token_dropped_mono = None
|
||||
session.alive = True
|
||||
continue
|
||||
self._close_session(session, reason="liveliness token dropped")
|
||||
elif now - session.last_seen_mono > self._manifest.session_idle_timeout_s:
|
||||
self._close_session(session, reason="idle timeout")
|
||||
|
||||
def _client_token_alive(self, client_uuid: str) -> bool:
|
||||
"""Confirm a client's liveliness token via an explicit get (GC double-check)."""
|
||||
if self._zenoh is None:
|
||||
return False
|
||||
try:
|
||||
zenoh = import_zenoh()
|
||||
replies = self._zenoh.liveliness().get(
|
||||
client_alive_key(self.prefix, client_uuid),
|
||||
handler=zenoh.handlers.FifoChannel(4),
|
||||
timeout=0.5,
|
||||
)
|
||||
deadline = time.monotonic() + 1.0
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
reply = replies.try_recv()
|
||||
except Exception: # channel closed: no token found # noqa: BLE001
|
||||
return False
|
||||
if reply is None:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
if reply.ok is not None:
|
||||
return True
|
||||
return False
|
||||
except Exception: # noqa: BLE001 — treat transport trouble as "not alive"
|
||||
return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Misc
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _bump(self, key: str, amount: float = 1) -> None:
|
||||
with self._metrics_lock:
|
||||
self.metrics[key] = self.metrics.get(key, 0) + amount
|
||||
|
||||
def _capture(self, kind: str, data: bytes) -> None:
|
||||
capture_dir = self._manifest.debug.capture_dir
|
||||
if not capture_dir:
|
||||
return
|
||||
try:
|
||||
directory = Path(capture_dir)
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
index = self._capture_count % max(1, self._manifest.debug.capture_max)
|
||||
(directory / f"{kind}_{index:05d}.bin").write_bytes(data)
|
||||
if kind == "rep":
|
||||
self._capture_count += 1
|
||||
except OSError as e:
|
||||
logger.warning("debug capture failed: %s", e)
|
||||
|
||||
def _start_health_server(self, port: int) -> None:
|
||||
server_ref = self
|
||||
|
||||
class Handler(http.server.BaseHTTPRequestHandler):
|
||||
def do_GET(self) -> None: # noqa: N802 — http.server API
|
||||
if self.path == "/healthz":
|
||||
worker = server_ref._worker # local ref: stop() may null it mid-read
|
||||
healthy = worker is not None and worker.is_alive()
|
||||
self.send_response(200 if healthy else 503)
|
||||
self.end_headers()
|
||||
self.wfile.write(b"ok" if healthy else b"worker dead")
|
||||
elif self.path == "/metrics":
|
||||
with server_ref._metrics_lock:
|
||||
counters = dict(server_ref.metrics)
|
||||
counters["active_sessions"] = len(server_ref.registry)
|
||||
counters["server_load"] = server_ref.server_load
|
||||
body = "".join(
|
||||
f"lerobot_policy_server_{name} {value}\n" for name, value in sorted(counters.items())
|
||||
).encode()
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/plain; version=0.0.4")
|
||||
self.end_headers()
|
||||
self.wfile.write(body)
|
||||
else:
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
|
||||
def log_message(self, *args: Any) -> None: # silence per-request logging
|
||||
pass
|
||||
|
||||
self._health_server = http.server.ThreadingHTTPServer(("0.0.0.0", port), Handler) # nosec B104
|
||||
threading.Thread(target=self._health_server.serve_forever, daemon=True, name="HealthHTTP").start()
|
||||
logger.info("Health/metrics on :%d (/healthz, /metrics)", port)
|
||||
@@ -1,203 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Per-client session state and the latest-only observation mailbox.
|
||||
|
||||
The server holds **no cross-request control state**: RTC prefixes and
|
||||
delay hints arrive with every observation. What a session does hold:
|
||||
|
||||
- Per-session processor pipeline instances. Mandatory:
|
||||
``RelativeActionsProcessorStep`` caches ``_last_state`` at preprocess
|
||||
and the postprocessor reads it back — a pipeline shared across clients
|
||||
would be a race.
|
||||
- A one-slot mailbox: the newest observation wins; superseded requests
|
||||
are counted so drops stay visible to the client.
|
||||
- Counters for the audit log and ``/metrics``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from lerobot.processor import (
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
RelativeActionsProcessorStep,
|
||||
)
|
||||
|
||||
from .schema import MsgHeader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MailboxItem:
|
||||
header: MsgHeader
|
||||
payload: bytes
|
||||
recv_mono: float # server-local monotonic deposit time (for queue_wait_ms)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionStats:
|
||||
requests: int = 0
|
||||
errors: int = 0
|
||||
superseded: int = 0 # observations overwritten before inference (lifetime)
|
||||
superseded_since_reply: int = 0
|
||||
last_inference_ms: float = 0.0
|
||||
last_queue_wait_ms: float = 0.0
|
||||
|
||||
|
||||
class Session:
|
||||
"""One connected robot client."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
client_uuid: str,
|
||||
task: str,
|
||||
robot_type: str,
|
||||
rtc_enabled: bool,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
action_publisher: Any = None, # zenoh.Publisher (Any: zenoh optional at import)
|
||||
) -> None:
|
||||
self.session_id = session_id
|
||||
self.client_uuid = client_uuid
|
||||
self.task = task
|
||||
self.robot_type = robot_type
|
||||
self.rtc_enabled = rtc_enabled
|
||||
self.preprocessor = preprocessor
|
||||
self.postprocessor = postprocessor
|
||||
self.action_publisher = action_publisher
|
||||
|
||||
self.episode_id = 0
|
||||
self.stats = SessionStats()
|
||||
self.alive = True
|
||||
self.last_seen_mono = time.monotonic()
|
||||
# Set when the client's liveliness token drops; GC after grace.
|
||||
self.token_dropped_mono: float | None = None
|
||||
|
||||
# Processor introspection for relative-action prefix re-anchoring
|
||||
# (mirrors RTCInferenceEngine.__init__).
|
||||
self.relative_step = next(
|
||||
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
|
||||
None,
|
||||
)
|
||||
self.normalizer_step = next(
|
||||
(s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)),
|
||||
None,
|
||||
)
|
||||
|
||||
self._mailbox: MailboxItem | None = None
|
||||
self._mailbox_lock = Lock()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Mailbox (deposit-only callbacks write, the inference worker reads)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def deposit(self, header: MsgHeader, payload: bytes) -> None:
|
||||
"""Latest-only deposit; counts superseded observations."""
|
||||
item = MailboxItem(header=header, payload=payload, recv_mono=time.monotonic())
|
||||
with self._mailbox_lock:
|
||||
if self._mailbox is not None:
|
||||
self.stats.superseded += 1
|
||||
self.stats.superseded_since_reply += 1
|
||||
self._mailbox = item
|
||||
self.alive = True
|
||||
self.token_dropped_mono = None
|
||||
self.last_seen_mono = item.recv_mono
|
||||
|
||||
def take(self) -> MailboxItem | None:
|
||||
with self._mailbox_lock:
|
||||
item, self._mailbox = self._mailbox, None
|
||||
return item
|
||||
|
||||
def take_superseded(self) -> int:
|
||||
"""Atomically read-and-reset the per-reply supersession counter."""
|
||||
with self._mailbox_lock:
|
||||
count = self.stats.superseded_since_reply
|
||||
self.stats.superseded_since_reply = 0
|
||||
return count
|
||||
|
||||
def has_pending(self) -> bool:
|
||||
with self._mailbox_lock:
|
||||
return self._mailbox is not None
|
||||
|
||||
def clear_mailbox(self) -> None:
|
||||
with self._mailbox_lock:
|
||||
self._mailbox = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Episode boundary
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def reset_episode(self, episode_id: int | None = None) -> None:
|
||||
"""Clear per-episode state. The shared policy is NOT touched here."""
|
||||
self.clear_mailbox()
|
||||
self.preprocessor.reset()
|
||||
self.postprocessor.reset()
|
||||
self.episode_id = episode_id if episode_id is not None else self.episode_id + 1
|
||||
|
||||
def close(self) -> None:
|
||||
self.clear_mailbox()
|
||||
publisher = self.action_publisher
|
||||
self.action_publisher = None
|
||||
if publisher is not None:
|
||||
# Already-closed transport is fine on teardown.
|
||||
with contextlib.suppress(Exception):
|
||||
publisher.undeclare()
|
||||
|
||||
|
||||
class SessionRegistry:
|
||||
"""Thread-safe map of client_uuid → Session."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._sessions: dict[str, Session] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def add(self, session: Session) -> Session | None:
|
||||
"""Register, returning a displaced same-client session (caller closes it)."""
|
||||
with self._lock:
|
||||
old = self._sessions.get(session.client_uuid)
|
||||
self._sessions[session.client_uuid] = session
|
||||
return old
|
||||
|
||||
def get(self, client_uuid: str) -> Session | None:
|
||||
with self._lock:
|
||||
return self._sessions.get(client_uuid)
|
||||
|
||||
def remove(self, client_uuid: str, expected: Session | None = None) -> Session | None:
|
||||
"""Remove by uuid; with ``expected``, only if it is still that exact session.
|
||||
|
||||
The identity check stops a GC sweep that snapshotted an old
|
||||
session from tearing down its just-handshaked replacement.
|
||||
"""
|
||||
with self._lock:
|
||||
current = self._sessions.get(client_uuid)
|
||||
if current is None or (expected is not None and current is not expected):
|
||||
return None
|
||||
return self._sessions.pop(client_uuid)
|
||||
|
||||
def snapshot(self) -> list[Session]:
|
||||
with self._lock:
|
||||
return list(self._sessions.values())
|
||||
|
||||
def __len__(self) -> int:
|
||||
with self._lock:
|
||||
return len(self._sessions)
|
||||
@@ -1,265 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Serving-mode classification and session capability validation.
|
||||
|
||||
Multi-tenancy is engineered, not assumed: sharing one policy instance
|
||||
across sessions is only safe when ``predict_action_chunk`` touches no
|
||||
instance state. That property has been verified per policy family and
|
||||
is encoded here as an explicit registry — never inferred.
|
||||
|
||||
- ``act``/``pi0``/``pi05``: chunk-stateless (verified in-tree).
|
||||
- ``smolvla``: populates its ``_queues`` *inside* ``predict_action_chunk``;
|
||||
with ``n_obs_steps == 1`` the queue is overwritten with the request's
|
||||
own observation before being read, so sharing is safe. With history
|
||||
(``n_obs_steps > 1``) requests would read other sessions' frames →
|
||||
exclusive.
|
||||
- ``diffusion``: ``predict_action_chunk`` reads ``_queues`` that only
|
||||
``select_action`` populates → exclusive, with the server populating
|
||||
the observation queues per request (mirroring ``select_action``).
|
||||
- Policies without a ``predict_action_chunk`` override are refused.
|
||||
- Unverified chunk-API policies default to exclusive; ``shared`` cannot
|
||||
be forced for them (the roadmap upstreams a
|
||||
``supports_stateless_chunking`` attribute to policy classes).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
from .manifest import (
|
||||
SERVING_MODE_EXCLUSIVE,
|
||||
SERVING_MODE_SHARED,
|
||||
PolicyServerManifest,
|
||||
)
|
||||
from .schema import SessionOpenMsg, StatusMsg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ServingClass(Enum):
|
||||
SHARED = "shared"
|
||||
EXCLUSIVE = "exclusive"
|
||||
REFUSED = "refused"
|
||||
|
||||
|
||||
# Verified chunk-stateless families (predict_action_chunk touches no
|
||||
# cross-request instance state).
|
||||
VERIFIED_CHUNK_STATELESS: frozenset[str] = frozenset({"act", "pi0", "pi05"})
|
||||
|
||||
# Families whose predict_action_chunk reads select_action-fed queues:
|
||||
# the server must populate the observation queues per request.
|
||||
QUEUE_POPULATED_IN_SELECT: frozenset[str] = frozenset({"diffusion"})
|
||||
|
||||
# Families whose predict_action_chunk accepts the RTC kwargs
|
||||
# (inference_delay / prev_chunk_left_over) — see each family's
|
||||
# ActionSelectKwargs TypedDict.
|
||||
RTC_CAPABLE: frozenset[str] = frozenset({"pi0", "pi05", "smolvla"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyClassification:
|
||||
serving_class: ServingClass
|
||||
supports_rtc: bool
|
||||
needs_queue_population: bool
|
||||
reason: str
|
||||
|
||||
|
||||
def _has_chunk_api(policy: PreTrainedPolicy) -> bool:
|
||||
method = getattr(type(policy), "predict_action_chunk", None)
|
||||
return method is not None and method is not PreTrainedPolicy.predict_action_chunk
|
||||
|
||||
|
||||
def classify_policy(policy: PreTrainedPolicy) -> PolicyClassification:
|
||||
"""Classify a loaded policy into a serving class. Registry-driven, never inferred."""
|
||||
name = getattr(policy, "name", type(policy).__name__)
|
||||
|
||||
if not _has_chunk_api(policy):
|
||||
return PolicyClassification(
|
||||
ServingClass.REFUSED,
|
||||
supports_rtc=False,
|
||||
needs_queue_population=False,
|
||||
reason=f"policy '{name}' does not implement predict_action_chunk",
|
||||
)
|
||||
|
||||
supports_rtc = name in RTC_CAPABLE
|
||||
|
||||
if name in VERIFIED_CHUNK_STATELESS:
|
||||
return PolicyClassification(
|
||||
ServingClass.SHARED, supports_rtc, False, f"'{name}' is verified chunk-stateless"
|
||||
)
|
||||
|
||||
if name == "smolvla":
|
||||
n_obs_steps = getattr(policy.config, "n_obs_steps", 1)
|
||||
if n_obs_steps == 1:
|
||||
return PolicyClassification(
|
||||
ServingClass.SHARED,
|
||||
supports_rtc,
|
||||
False,
|
||||
"'smolvla' with n_obs_steps=1 overwrites its queues per request",
|
||||
)
|
||||
return PolicyClassification(
|
||||
ServingClass.EXCLUSIVE,
|
||||
supports_rtc,
|
||||
False,
|
||||
f"'smolvla' with n_obs_steps={n_obs_steps} keeps observation history across requests",
|
||||
)
|
||||
|
||||
if name in QUEUE_POPULATED_IN_SELECT:
|
||||
return PolicyClassification(
|
||||
ServingClass.EXCLUSIVE,
|
||||
supports_rtc,
|
||||
True,
|
||||
f"'{name}' predict_action_chunk reads select_action-fed queues",
|
||||
)
|
||||
|
||||
return PolicyClassification(
|
||||
ServingClass.EXCLUSIVE,
|
||||
supports_rtc,
|
||||
False,
|
||||
f"'{name}' has a chunk API but is not in the verified chunk-stateless registry",
|
||||
)
|
||||
|
||||
|
||||
def resolve_serving_mode(
|
||||
classification: PolicyClassification, manifest: PolicyServerManifest
|
||||
) -> tuple[str, int]:
|
||||
"""Resolve the final (serving_mode, max_sessions) from classification + manifest.
|
||||
|
||||
The manifest may force ``exclusive`` but can never force ``shared``
|
||||
for a policy that is not verified chunk-stateless.
|
||||
"""
|
||||
if classification.serving_class is ServingClass.REFUSED:
|
||||
raise ValueError(f"Refusing to serve this policy: {classification.reason}")
|
||||
|
||||
if manifest.serving_mode == SERVING_MODE_SHARED:
|
||||
if classification.serving_class is not ServingClass.SHARED:
|
||||
raise ValueError(
|
||||
f"serving_mode=shared is unsafe for this policy: {classification.reason}. "
|
||||
"Use serving_mode=exclusive (or auto)."
|
||||
)
|
||||
mode = SERVING_MODE_SHARED
|
||||
elif manifest.serving_mode == SERVING_MODE_EXCLUSIVE:
|
||||
mode = SERVING_MODE_EXCLUSIVE
|
||||
else: # auto
|
||||
mode = (
|
||||
SERVING_MODE_SHARED
|
||||
if classification.serving_class is ServingClass.SHARED
|
||||
else SERVING_MODE_EXCLUSIVE
|
||||
)
|
||||
|
||||
max_sessions = manifest.max_sessions
|
||||
if mode == SERVING_MODE_EXCLUSIVE and max_sessions != 1:
|
||||
logger.warning(
|
||||
"serving_mode=exclusive forces max_sessions=1 (manifest had %d)", manifest.max_sessions
|
||||
)
|
||||
max_sessions = 1
|
||||
return mode, max_sessions
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session-open validation (fail fast, fail loud)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
error: str = "" # non-empty → hard reject
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
# RTC requested but unsupported → downgrade to plain chunk-append.
|
||||
rtc_downgraded: bool = False
|
||||
|
||||
@property
|
||||
def ok(self) -> bool:
|
||||
return not self.error
|
||||
|
||||
|
||||
def validate_session_open(
|
||||
msg: SessionOpenMsg,
|
||||
capabilities: StatusMsg,
|
||||
manifest: PolicyServerManifest,
|
||||
active_sessions: int,
|
||||
) -> ValidationResult:
|
||||
"""Apply the capability matrix from the design doc (§8.4)."""
|
||||
result = ValidationResult()
|
||||
|
||||
# Schema version: client must be within the server's supported range.
|
||||
if not (capabilities.min_schema_version <= msg.schema_version <= capabilities.max_schema_version):
|
||||
result.error = (
|
||||
f"schema_version {msg.schema_version} outside supported range "
|
||||
f"[{capabilities.min_schema_version}, {capabilities.max_schema_version}]"
|
||||
)
|
||||
return result
|
||||
|
||||
# Capacity: reject with current load so the client can retry another replica.
|
||||
if active_sessions >= capabilities.max_sessions:
|
||||
result.error = f"server full: {active_sessions}/{capabilities.max_sessions} sessions active"
|
||||
return result
|
||||
|
||||
# Action names AND order: the hard sync-safety contract mapping
|
||||
# chunk columns to motors.
|
||||
if capabilities.action_names and msg.action_names != capabilities.action_names:
|
||||
result.error = (
|
||||
"action feature names/order mismatch — refusing to map chunk columns to motors.\n"
|
||||
f" server: {capabilities.action_names}\n"
|
||||
f" client: {msg.action_names}"
|
||||
)
|
||||
return result
|
||||
|
||||
# State dim.
|
||||
if capabilities.state_dim and msg.state_dim and msg.state_dim != capabilities.state_dim:
|
||||
result.error = f"state dim mismatch: server={capabilities.state_dim}, client={msg.state_dim}"
|
||||
return result
|
||||
|
||||
# Camera names: the client set must cover the policy's visual features.
|
||||
missing = set(capabilities.expected_cameras) - set(msg.camera_names)
|
||||
if missing:
|
||||
result.error = (
|
||||
f"missing camera features {sorted(missing)} "
|
||||
f"(client provides {sorted(msg.camera_names)}; resolution may differ — names may not)"
|
||||
)
|
||||
return result
|
||||
|
||||
# Task pinning.
|
||||
if manifest.pin_task and msg.task and msg.task != manifest.default_task:
|
||||
result.error = f"task is pinned to {manifest.default_task!r} on this server, got {msg.task!r}"
|
||||
return result
|
||||
|
||||
# fps: warn unless strict.
|
||||
if capabilities.trained_fps and msg.fps and abs(msg.fps - capabilities.trained_fps) > 1e-6:
|
||||
fps_msg = f"client fps={msg.fps:g} != trained fps={capabilities.trained_fps:g}"
|
||||
if manifest.strict_fps:
|
||||
result.error = fps_msg + " (strict_fps=true)"
|
||||
return result
|
||||
result.warnings.append(fps_msg)
|
||||
|
||||
# Policy type sanity (informational mismatch is a warning, not fatal:
|
||||
# the action/state/camera contracts above are the binding ones).
|
||||
if msg.policy_type and capabilities.policy_type and msg.policy_type != capabilities.policy_type:
|
||||
result.warnings.append(
|
||||
f"client expected policy_type={msg.policy_type!r}, server runs {capabilities.policy_type!r}"
|
||||
)
|
||||
|
||||
# RTC: requested but unsupported → serve plain chunks, client appends.
|
||||
if msg.rtc_enabled and not capabilities.supports_rtc:
|
||||
result.rtc_downgraded = True
|
||||
result.warnings.append(
|
||||
"RTC requested but this server/policy does not support it — downgrading to chunk-append"
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -1,101 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Zenoh session construction shared by the policy server and the remote engine.
|
||||
|
||||
Verified against eclipse-zenoh 1.9 (thread-based; no asyncio API).
|
||||
Multicast scouting is always disabled — fleet "discovery" is static
|
||||
endpoint configuration plus liveliness tokens, never protocol magic.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ZENOH_IMPORT_HINT = (
|
||||
"Remote inference requires the 'async' extra: pip install 'lerobot[async]' (eclipse-zenoh + msgpack)"
|
||||
)
|
||||
|
||||
|
||||
def import_zenoh():
|
||||
"""Import zenoh lazily with an actionable error message."""
|
||||
try:
|
||||
import zenoh
|
||||
except ImportError as e:
|
||||
raise ImportError(_ZENOH_IMPORT_HINT) from e
|
||||
return zenoh
|
||||
|
||||
|
||||
def build_zenoh_config(
|
||||
*,
|
||||
mode: str = "client",
|
||||
connect_endpoints: list[str] | None = None,
|
||||
listen_endpoints: list[str] | None = None,
|
||||
tls_root_ca_certificate: str | None = None,
|
||||
tls_connect_certificate: str | None = None,
|
||||
tls_connect_private_key: str | None = None,
|
||||
extra_config_json5: str | None = None,
|
||||
):
|
||||
"""Build a zenoh.Config (values are JSON5 strings — note the inner quoting)."""
|
||||
zenoh = import_zenoh()
|
||||
cfg = zenoh.Config()
|
||||
cfg.insert_json5("mode", json.dumps(mode))
|
||||
cfg.insert_json5("scouting/multicast/enabled", "false")
|
||||
if connect_endpoints:
|
||||
cfg.insert_json5("connect/endpoints", json.dumps(list(connect_endpoints)))
|
||||
if listen_endpoints:
|
||||
cfg.insert_json5("listen/endpoints", json.dumps(list(listen_endpoints)))
|
||||
if tls_root_ca_certificate:
|
||||
cfg.insert_json5("transport/link/tls/root_ca_certificate", json.dumps(tls_root_ca_certificate))
|
||||
if tls_connect_certificate:
|
||||
cfg.insert_json5("transport/link/tls/connect_certificate", json.dumps(tls_connect_certificate))
|
||||
if tls_connect_private_key:
|
||||
cfg.insert_json5("transport/link/tls/connect_private_key", json.dumps(tls_connect_private_key))
|
||||
if extra_config_json5:
|
||||
merged = json.loads(extra_config_json5)
|
||||
for key, value in merged.items():
|
||||
cfg.insert_json5(key, json.dumps(value))
|
||||
return cfg
|
||||
|
||||
|
||||
def action_publisher_qos(zenoh) -> dict:
|
||||
"""QoS for the action topic: RELIABLE + congestion DROP (never BLOCK) + express.
|
||||
|
||||
DROP so one dead robot uplink can never stall the server's publish
|
||||
path; a dropped chunk is recoverable by design — the client's action
|
||||
buffer keeps the robot moving and the next chunk replaces it.
|
||||
"""
|
||||
return {
|
||||
"reliability": zenoh.Reliability.RELIABLE,
|
||||
"congestion_control": zenoh.CongestionControl.DROP,
|
||||
"express": True,
|
||||
"priority": zenoh.Priority.INTERACTIVE_HIGH,
|
||||
}
|
||||
|
||||
|
||||
def obs_publisher_qos(zenoh) -> dict:
|
||||
"""QoS for the observation topic: best-effort drop, default priority.
|
||||
|
||||
Intentional drop already happened at the client's one-slot holder;
|
||||
if the uplink stalls, dropping a frame protects the control loop.
|
||||
"""
|
||||
return {
|
||||
"reliability": zenoh.Reliability.BEST_EFFORT,
|
||||
"congestion_control": zenoh.CongestionControl.DROP,
|
||||
"express": False,
|
||||
"priority": zenoh.Priority.DATA,
|
||||
}
|
||||
@@ -32,6 +32,7 @@ from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
@@ -280,11 +281,6 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
|
||||
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
_serialized_state_filenames: tuple[str | None, ...] | None = field(
|
||||
default=None,
|
||||
init=False,
|
||||
repr=False,
|
||||
)
|
||||
|
||||
def __call__(self, data: TInput) -> TOutput:
|
||||
"""Processes input data through the full pipeline.
|
||||
@@ -342,108 +338,30 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
transition = processor_step(transition)
|
||||
yield transition
|
||||
|
||||
def _get_sanitized_name(self) -> str:
|
||||
"""Return a filename-safe version of the pipeline name.
|
||||
def _save_pretrained(self, save_directory: Path, **kwargs):
|
||||
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||
|
||||
Returns:
|
||||
The lower-cased pipeline name with non-alphanumeric characters replaced by underscores.
|
||||
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||
"""
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
|
||||
@staticmethod
|
||||
def _get_state_filename(
|
||||
*,
|
||||
step_index: int,
|
||||
registry_name: str | None,
|
||||
sanitized_name: str,
|
||||
) -> str:
|
||||
"""Return the safetensors filename for one stateful processor step.
|
||||
# Sanitize the pipeline name to create a valid filename prefix.
|
||||
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||
|
||||
Args:
|
||||
step_index: The index of the processor step in this pipeline.
|
||||
registry_name: The registered processor step name, if available.
|
||||
sanitized_name: The filename-safe pipeline name.
|
||||
if config_filename is None:
|
||||
config_filename = f"{sanitized_name}.json"
|
||||
|
||||
Returns:
|
||||
The state filename used by the existing disk serialization format.
|
||||
"""
|
||||
if registry_name:
|
||||
return f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||||
|
||||
return f"{sanitized_name}_step_{step_index}.safetensors"
|
||||
|
||||
@staticmethod
|
||||
def _get_state_key(state_filename: str) -> str:
|
||||
"""Return the in-memory state key for a serialized state filename.
|
||||
|
||||
Args:
|
||||
state_filename: The `.safetensors` filename from the serialized config.
|
||||
|
||||
Returns:
|
||||
The state key used by the in-memory pipeline state dictionary.
|
||||
"""
|
||||
return state_filename.removesuffix(".safetensors")
|
||||
|
||||
@staticmethod
|
||||
def _get_state_filenames_from_config(loaded_config: dict[str, Any]) -> tuple[str | None, ...]:
|
||||
"""Return serialized state filenames in step order.
|
||||
|
||||
Args:
|
||||
loaded_config: A validated processor pipeline config.
|
||||
|
||||
Returns:
|
||||
A tuple containing each step's serialized state filename, or None for stateless steps.
|
||||
"""
|
||||
return tuple(step_entry.get("state_file") for step_entry in loaded_config["steps"])
|
||||
|
||||
def _get_state_filenames_for_loading(self) -> tuple[str | None, ...]:
|
||||
"""Return expected state filenames in step order for `load_state_dict()`.
|
||||
|
||||
Returns:
|
||||
The preserved serialized state filenames when available, otherwise filenames derived from
|
||||
current non-empty step state.
|
||||
"""
|
||||
if self._serialized_state_filenames is not None and len(self._serialized_state_filenames) == len(
|
||||
self.steps
|
||||
):
|
||||
return self._serialized_state_filenames
|
||||
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
state_filenames: list[str | None] = []
|
||||
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if not step_state_dict:
|
||||
state_filenames.append(None)
|
||||
continue
|
||||
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
state_filenames.append(
|
||||
self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
)
|
||||
|
||||
return tuple(state_filenames)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return the JSON-serializable pipeline configuration.
|
||||
|
||||
Returns:
|
||||
A dictionary with the same content that `save_pretrained()` writes as JSON.
|
||||
"""
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
pipeline_config: dict[str, Any] = {
|
||||
config: dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"steps": [],
|
||||
}
|
||||
|
||||
# Iterate through each step to build its configuration entry.
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
step_entry: dict[str, Any] = {}
|
||||
|
||||
step_entry: dict[str, Any] = {}
|
||||
# Prefer registry name for portability, otherwise fall back to full class path.
|
||||
if registry_name:
|
||||
step_entry["registry_name"] = registry_name
|
||||
else:
|
||||
@@ -451,110 +369,31 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
|
||||
)
|
||||
|
||||
step_entry["config"] = processor_step.get_config()
|
||||
# Save step configuration if `get_config` is implemented.
|
||||
if hasattr(processor_step, "get_config"):
|
||||
step_entry["config"] = processor_step.get_config()
|
||||
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if step_state_dict:
|
||||
step_entry["state_file"] = self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
# Save step state if `state_dict` is implemented and returns a non-empty dict.
|
||||
if hasattr(processor_step, "state_dict"):
|
||||
state = processor_step.state_dict()
|
||||
if state:
|
||||
# Clone tensors to avoid modifying the original state.
|
||||
cloned_state = {key: tensor.clone() for key, tensor in state.items()}
|
||||
|
||||
pipeline_config["steps"].append(step_entry)
|
||||
# Create a unique filename for the state file.
|
||||
if registry_name:
|
||||
state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||||
else:
|
||||
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
|
||||
|
||||
return pipeline_config
|
||||
save_file(cloned_state, os.path.join(str(save_directory), state_filename))
|
||||
step_entry["state_file"] = state_filename
|
||||
|
||||
def state_dict(self) -> dict[str, dict[str, torch.Tensor]]:
|
||||
"""Return pipeline state tensors grouped by state key.
|
||||
config["steps"].append(step_entry)
|
||||
|
||||
Returns:
|
||||
A dictionary mapping suffixless state keys to cloned step state dictionaries.
|
||||
"""
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
pipeline_state_dict: dict[str, dict[str, torch.Tensor]] = {}
|
||||
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if not step_state_dict:
|
||||
continue
|
||||
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
state_filename = self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
state_key = self._get_state_key(state_filename)
|
||||
pipeline_state_dict[state_key] = {
|
||||
tensor_name: tensor.clone() for tensor_name, tensor in step_state_dict.items()
|
||||
}
|
||||
|
||||
return pipeline_state_dict
|
||||
|
||||
def load_state_dict(
|
||||
self,
|
||||
state_dict: dict[str, dict[str, torch.Tensor]],
|
||||
) -> None:
|
||||
"""Load pipeline state tensors into the existing steps.
|
||||
|
||||
Args:
|
||||
state_dict: A dictionary mapping suffixless state keys to step state dictionaries.
|
||||
|
||||
Raises:
|
||||
KeyError: If loading finds missing expected state or unexpected extra state.
|
||||
"""
|
||||
expected_state_filenames = self._get_state_filenames_for_loading()
|
||||
used_state_keys: set[str] = set()
|
||||
|
||||
for step_index, (processor_step, state_filename) in enumerate(
|
||||
zip(self.steps, expected_state_filenames, strict=True)
|
||||
):
|
||||
if state_filename is None:
|
||||
continue
|
||||
|
||||
state_key = self._get_state_key(state_filename)
|
||||
if state_key not in state_dict:
|
||||
raise KeyError(
|
||||
f"Missing state key '{state_key}' for processor step {step_index}. "
|
||||
f"Available state keys: {sorted(state_dict.keys())}"
|
||||
)
|
||||
|
||||
processor_step.load_state_dict(state_dict[state_key])
|
||||
used_state_keys.add(state_key)
|
||||
|
||||
unexpected_state_keys = set(state_dict) - used_state_keys
|
||||
if unexpected_state_keys:
|
||||
expected_state_key_set = {
|
||||
self._get_state_key(state_filename)
|
||||
for state_filename in expected_state_filenames
|
||||
if state_filename is not None
|
||||
}
|
||||
raise KeyError(
|
||||
f"Unexpected processor state keys: {sorted(unexpected_state_keys)}. "
|
||||
f"Expected state keys: {sorted(expected_state_key_set)}"
|
||||
)
|
||||
|
||||
def _save_pretrained(self, save_directory: Path, **kwargs) -> None:
|
||||
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||
|
||||
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||
"""
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
|
||||
if config_filename is None:
|
||||
config_filename = f"{sanitized_name}.json"
|
||||
|
||||
pipeline_config = self.get_config()
|
||||
pipeline_state_dict = self.state_dict()
|
||||
|
||||
for state_key, step_state_dict in pipeline_state_dict.items():
|
||||
state_filename = f"{state_key}.safetensors"
|
||||
save_file(step_state_dict, save_directory / state_filename)
|
||||
|
||||
with open(save_directory / config_filename, "w") as file_pointer:
|
||||
json.dump(pipeline_config, file_pointer, indent=2)
|
||||
# Write the main configuration JSON file.
|
||||
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
|
||||
json.dump(config, file_pointer, indent=2)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
@@ -738,54 +577,12 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
cls._validate_overrides_used(validated_overrides, loaded_config)
|
||||
|
||||
# 5. Construct and return the final pipeline instance
|
||||
pipeline = cls(
|
||||
return cls(
|
||||
steps=steps,
|
||||
name=loaded_config.get("name", "DataProcessorPipeline"),
|
||||
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||
)
|
||||
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(loaded_config)
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: dict[str, Any],
|
||||
*,
|
||||
state_dict: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
overrides: dict[str, Any] | None = None,
|
||||
to_transition: Callable[[TInput], EnvTransition] | None = None,
|
||||
to_output: Callable[[EnvTransition], TOutput] | None = None,
|
||||
) -> DataProcessorPipeline[TInput, TOutput]:
|
||||
"""Build a pipeline from an in-memory config and optional state tensors.
|
||||
|
||||
Args:
|
||||
config: A config dictionary with the same structure as the saved processor JSON.
|
||||
state_dict: Optional in-memory pipeline state grouped by suffixless state key.
|
||||
overrides: Optional constructor overrides keyed by registry name or class name.
|
||||
to_transition: Optional converter from input data to `EnvTransition`.
|
||||
to_output: Optional converter from `EnvTransition` to output data.
|
||||
|
||||
Returns:
|
||||
A processor pipeline built from the config and optional state.
|
||||
"""
|
||||
cls._validate_loaded_config("<in-memory config>", config, "<in-memory config>")
|
||||
|
||||
steps, remaining_override_keys = cls._build_steps_from_config(config, overrides or {})
|
||||
cls._validate_overrides_used(remaining_override_keys, config)
|
||||
|
||||
pipeline = cls(
|
||||
steps=steps,
|
||||
name=config.get("name", "DataProcessorPipeline"),
|
||||
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||
)
|
||||
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(config)
|
||||
|
||||
if state_dict is not None:
|
||||
pipeline.load_state_dict(state_dict)
|
||||
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
def _load_config(
|
||||
@@ -869,7 +666,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
) from e
|
||||
|
||||
@classmethod
|
||||
def _validate_loaded_config(cls, model_id: str, loaded_config: Any, config_filename: str) -> None:
|
||||
def _validate_loaded_config(
|
||||
cls, model_id: str, loaded_config: dict[str, Any], config_filename: str
|
||||
) -> None:
|
||||
"""Validate that a config was loaded and is a valid processor config.
|
||||
|
||||
This method validates processor config format with intelligent migration detection:
|
||||
@@ -889,7 +688,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
|
||||
Args:
|
||||
model_id: The model identifier (used for migration detection)
|
||||
loaded_config: The loaded config value to validate (may be non-dict)
|
||||
loaded_config: The loaded config dictionary (guaranteed non-None)
|
||||
config_filename: The config filename that was loaded (for error messages)
|
||||
|
||||
Raises:
|
||||
@@ -903,14 +702,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
model_id,
|
||||
f"Config file '{config_filename}' is not a valid processor configuration",
|
||||
)
|
||||
loaded_config_description = (
|
||||
list(loaded_config.keys())
|
||||
if isinstance(loaded_config, dict)
|
||||
else type(loaded_config).__name__
|
||||
)
|
||||
raise ValueError(
|
||||
f"Config file '{config_filename}' is not a valid processor configuration. "
|
||||
f"Expected a config with 'steps' field, but got: {loaded_config_description}"
|
||||
f"Expected a config with 'steps' field, but got: {list(loaded_config.keys())}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -972,41 +766,26 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
ImportError: If a step class cannot be imported or found in registry
|
||||
ValueError: If a step cannot be instantiated with its configuration
|
||||
"""
|
||||
steps, remaining_override_keys = cls._build_steps_from_config(loaded_config, overrides)
|
||||
|
||||
for step_instance, step_entry in zip(steps, loaded_config["steps"], strict=True):
|
||||
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
|
||||
|
||||
return steps, remaining_override_keys
|
||||
|
||||
@classmethod
|
||||
def _build_steps_from_config(
|
||||
cls,
|
||||
loaded_config: dict[str, Any],
|
||||
overrides: dict[str, Any],
|
||||
) -> tuple[list[ProcessorStep], set[str]]:
|
||||
"""Build processor steps from config without loading tensor state.
|
||||
|
||||
Args:
|
||||
loaded_config: The loaded processor configuration.
|
||||
overrides: User-provided constructor overrides keyed by step key.
|
||||
|
||||
Returns:
|
||||
A tuple containing instantiated steps and override keys that did not match a step.
|
||||
"""
|
||||
processor_steps: list[ProcessorStep] = []
|
||||
remaining_override_keys = set(overrides.keys())
|
||||
steps: list[ProcessorStep] = []
|
||||
override_keys = set(overrides.keys())
|
||||
|
||||
for step_entry in loaded_config["steps"]:
|
||||
# 1. Get step class and key
|
||||
step_class, step_key = cls._resolve_step_class(step_entry)
|
||||
processor_step = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
||||
|
||||
if step_key in remaining_override_keys:
|
||||
remaining_override_keys.discard(step_key)
|
||||
# 2. Instantiate step with overrides
|
||||
step_instance = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
||||
|
||||
processor_steps.append(processor_step)
|
||||
# 3. Load step state if available
|
||||
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
|
||||
|
||||
return processor_steps, remaining_override_keys
|
||||
# 4. Track used overrides
|
||||
if step_key in override_keys:
|
||||
override_keys.discard(step_key)
|
||||
|
||||
steps.append(step_instance)
|
||||
|
||||
return steps, override_keys
|
||||
|
||||
@classmethod
|
||||
def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]:
|
||||
@@ -1317,7 +1096,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def _is_processor_config(cls, config: Any) -> bool:
|
||||
def _is_processor_config(cls, config: dict) -> bool:
|
||||
"""Check if config follows DataProcessorPipeline format.
|
||||
|
||||
This method validates the processor configuration structure:
|
||||
@@ -1368,9 +1147,6 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
Returns:
|
||||
True if config follows valid DataProcessorPipeline format, False otherwise
|
||||
"""
|
||||
if not isinstance(config, dict):
|
||||
return False
|
||||
|
||||
# Must have a "steps" field with a list of step configurations
|
||||
if not isinstance(config.get("steps"), list):
|
||||
return False
|
||||
|
||||
@@ -81,7 +81,7 @@ def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) ->
|
||||
return actions
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("relative_actions_processor")
|
||||
@ProcessorStepRegistry.register("delta_actions_processor")
|
||||
@dataclass
|
||||
class RelativeActionsProcessorStep(ProcessorStep):
|
||||
"""Converts absolute actions to relative actions (action -= state) for masked dimensions.
|
||||
|
||||
@@ -20,16 +20,12 @@ from .factory import (
|
||||
make_reward_pre_post_processors as make_reward_pre_post_processors,
|
||||
)
|
||||
from .pretrained import PreTrainedRewardModel as PreTrainedRewardModel
|
||||
from .robometer.configuration_robometer import RobometerConfig as RobometerConfig
|
||||
from .sarm.configuration_sarm import SARMConfig as SARMConfig
|
||||
from .topreward.configuration_topreward import TOPRewardConfig as TOPRewardConfig
|
||||
|
||||
__all__ = [
|
||||
# Configuration classes
|
||||
"RewardClassifierConfig",
|
||||
"RobometerConfig",
|
||||
"SARMConfig",
|
||||
"TOPRewardConfig",
|
||||
# Base class
|
||||
"PreTrainedRewardModel",
|
||||
# Factory functions
|
||||
|
||||
@@ -25,9 +25,7 @@ from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
|
||||
from .classifier.configuration_classifier import RewardClassifierConfig
|
||||
from .pretrained import PreTrainedRewardModel
|
||||
from .robometer.configuration_robometer import RobometerConfig
|
||||
from .sarm.configuration_sarm import SARMConfig
|
||||
from .topreward.configuration_topreward import TOPRewardConfig
|
||||
|
||||
|
||||
def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
||||
@@ -39,7 +37,7 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
||||
|
||||
Args:
|
||||
name: The name of the reward model. Supported names are "reward_classifier",
|
||||
"sarm", "robometer", "topreward".
|
||||
"sarm".
|
||||
|
||||
Returns:
|
||||
The reward model class corresponding to the given name.
|
||||
@@ -55,14 +53,6 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
||||
from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
|
||||
|
||||
return SARMRewardModel
|
||||
elif name == "robometer":
|
||||
from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel
|
||||
|
||||
return RobometerRewardModel
|
||||
elif name == "topreward":
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
return TOPRewardModel
|
||||
else:
|
||||
try:
|
||||
return _get_reward_model_cls_from_name(name=name)
|
||||
@@ -79,7 +69,7 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
|
||||
|
||||
Args:
|
||||
reward_type: The type of the reward model. Supported types include
|
||||
"reward_classifier", "sarm", "robometer", "topreward".
|
||||
"reward_classifier", "sarm".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -92,10 +82,6 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif reward_type == "sarm":
|
||||
return SARMConfig(**kwargs)
|
||||
elif reward_type == "robometer":
|
||||
return RobometerConfig(**kwargs)
|
||||
elif reward_type == "topreward":
|
||||
return TOPRewardConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = RewardModelConfig.get_choice_class(reward_type)
|
||||
@@ -175,21 +161,6 @@ def make_reward_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
elif isinstance(reward_cfg, RobometerConfig):
|
||||
from lerobot.rewards.robometer.processor_robometer import make_robometer_pre_post_processors
|
||||
|
||||
return make_robometer_pre_post_processors(
|
||||
config=reward_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(reward_cfg, TOPRewardConfig):
|
||||
from lerobot.rewards.topreward.processor_topreward import make_topreward_pre_post_processors
|
||||
|
||||
return make_topreward_pre_post_processors(
|
||||
config=reward_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_robometer import RobometerConfig
|
||||
from .modeling_robometer import RobometerRewardModel
|
||||
from .processor_robometer import make_robometer_pre_post_processors
|
||||
|
||||
__all__ = ["RobometerConfig", "RobometerRewardModel", "make_robometer_pre_post_processors"]
|
||||
@@ -1,320 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Compute per-frame Robometer progress and success curves for a LeRobot dataset.
|
||||
|
||||
For each episode, builds per-frame sub-samples using the frame-steps
|
||||
strategy from the Robometer eval server: for each original frame ``t``,
|
||||
linspace-subsample ``[0, t]`` into ``K`` frames (default 4, matching
|
||||
``NUM_SUBSAMPLED_FRAMES`` in the eval server), run one forward through
|
||||
the Robometer processor + model, and keep the last-frame progress value.
|
||||
All sub-samples are the same size ``K`` so they batch cleanly.
|
||||
|
||||
The parquet uses the same schema as SARM's
|
||||
:mod:`lerobot.rewards.sarm.compute_rabc_weights` so existing consumers —
|
||||
:class:`lerobot.rewards.sarm.rabc.RABCWeights` (which reads
|
||||
``progress_sparse``) and the progress-overlay script in
|
||||
``examples/dataset/create_progress_videos.py`` — work without modification.
|
||||
|
||||
Usage:
|
||||
# Dense per-frame progress for one episode
|
||||
python -m lerobot.rewards.robometer.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--reward-model-path lerobot/Robometer-4B \\
|
||||
--episodes 0
|
||||
|
||||
# All episodes with batching
|
||||
python -m lerobot.rewards.robometer.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--reward-model-path lerobot/Robometer-4B \\
|
||||
--batch-size 16
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.rewards.robometer.configuration_robometer import RobometerConfig
|
||||
from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel
|
||||
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
|
||||
from lerobot.types import TransitionKey
|
||||
|
||||
DEFAULT_OUTPUT_FILENAME = "robometer_progress.parquet"
|
||||
|
||||
# Upstream Robometer eval server uses K=4 for frame-steps sub-samples.
|
||||
DEFAULT_NUM_SUBSAMPLED_FRAMES = 4
|
||||
|
||||
|
||||
def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:
|
||||
"""Read ``reward_model_path`` from parquet metadata if available."""
|
||||
if not parquet_path.exists():
|
||||
return None
|
||||
try:
|
||||
metadata = pq.read_metadata(parquet_path).schema.to_arrow_schema().metadata
|
||||
if metadata and b"reward_model_path" in metadata:
|
||||
return metadata[b"reward_model_path"].decode()
|
||||
except Exception: # nosec B110
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_task(sample: dict[str, Any], default: str) -> str:
|
||||
"""Best-effort task extraction from a dataset sample."""
|
||||
task = sample.get("task")
|
||||
if isinstance(task, str) and task:
|
||||
return task
|
||||
return default
|
||||
|
||||
|
||||
def _build_subsample_indices(num_frames: int, num_subsampled_frames: int) -> list[np.ndarray]:
|
||||
"""Frame-steps linspace expansion.
|
||||
|
||||
For each ``t in [0, num_frames - 1]`` returns ``num_subsampled_frames``
|
||||
indices from ``np.linspace(0, t, num_subsampled_frames)`` — the first
|
||||
and last frames are always included. Each entry is a fixed-size array
|
||||
so the model can batch them.
|
||||
"""
|
||||
return [np.linspace(0, t, num_subsampled_frames).round().astype(np.int64) for t in range(num_frames)]
|
||||
|
||||
|
||||
def compute_robometer_progress(
|
||||
dataset_repo_id: str,
|
||||
reward_model_path: str,
|
||||
output_path: str | None = None,
|
||||
device: str = "cuda",
|
||||
batch_size: int = 32,
|
||||
num_subsampled_frames: int = DEFAULT_NUM_SUBSAMPLED_FRAMES,
|
||||
episodes: list[int] | None = None,
|
||||
image_key: str | None = None,
|
||||
) -> Path:
|
||||
"""Run Robometer over a dataset and write per-frame progress + success."""
|
||||
logging.info(f"Loading Robometer: {reward_model_path}")
|
||||
config = RobometerConfig(pretrained_path=reward_model_path, device=device)
|
||||
if image_key is not None:
|
||||
config.image_key = image_key
|
||||
model = RobometerRewardModel.from_pretrained(reward_model_path, config=config)
|
||||
model.to(device).eval()
|
||||
|
||||
encoder = RobometerEncoderProcessorStep(
|
||||
base_model_id=config.base_model_id,
|
||||
image_key=config.image_key,
|
||||
task_key=config.task_key,
|
||||
default_task=config.default_task,
|
||||
max_frames=num_subsampled_frames,
|
||||
use_multi_image=config.use_multi_image,
|
||||
use_per_frame_progress_token=config.use_per_frame_progress_token,
|
||||
)
|
||||
|
||||
image_key = config.image_key
|
||||
|
||||
logging.info(f"Loading dataset: {dataset_repo_id}")
|
||||
dataset = LeRobotDataset(dataset_repo_id, download_videos=True)
|
||||
logging.info(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
|
||||
|
||||
episode_indices = list(range(dataset.num_episodes)) if episodes is None else episodes
|
||||
logging.info(f"Processing {len(episode_indices)} episode(s)")
|
||||
|
||||
all_index: list[int] = []
|
||||
all_episode: list[int] = []
|
||||
all_frame: list[int] = []
|
||||
all_progress: list[float] = []
|
||||
|
||||
for episode_idx in tqdm(episode_indices, desc="Episodes"):
|
||||
ep = dataset.meta.episodes[episode_idx]
|
||||
ep_start = int(ep["dataset_from_index"])
|
||||
ep_end = int(ep["dataset_to_index"])
|
||||
num_frames = ep_end - ep_start
|
||||
if num_frames <= 0:
|
||||
continue
|
||||
|
||||
first_sample = dataset[ep_start]
|
||||
task = _resolve_task(first_sample, default=config.default_task or "perform the task")
|
||||
|
||||
ep_frames = torch.stack([dataset[ep_start + i][image_key] for i in range(num_frames)])
|
||||
|
||||
sub_indices = _build_subsample_indices(num_frames, num_subsampled_frames)
|
||||
|
||||
progress_per_frame = np.zeros(num_frames, dtype=np.float32)
|
||||
|
||||
for start in tqdm(range(0, num_frames, batch_size), desc=f" Ep {episode_idx}", leave=False):
|
||||
end = min(start + batch_size, num_frames)
|
||||
frames_batch = torch.stack([ep_frames[sub_indices[i]] for i in range(start, end)])
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {image_key: frames_batch},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {"task": task},
|
||||
}
|
||||
encoded = encoder(transition)
|
||||
obs = encoded[TransitionKey.OBSERVATION]
|
||||
batch = {
|
||||
key: value.to(device) if isinstance(value, torch.Tensor) else value
|
||||
for key, value in obs.items()
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
rewards = model.compute_reward(batch)
|
||||
progress_per_frame[start:end] = rewards.cpu().numpy()
|
||||
|
||||
for local in range(num_frames):
|
||||
all_index.append(ep_start + local)
|
||||
all_episode.append(episode_idx)
|
||||
all_frame.append(local)
|
||||
all_progress.append(float(progress_per_frame[local]))
|
||||
|
||||
if device.startswith("cuda"):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
table = pa.table(
|
||||
{
|
||||
"index": np.asarray(all_index, dtype=np.int64),
|
||||
"episode_index": np.asarray(all_episode, dtype=np.int64),
|
||||
"frame_index": np.asarray(all_frame, dtype=np.int64),
|
||||
"progress_sparse": np.asarray(all_progress, dtype=np.float32),
|
||||
}
|
||||
).replace_schema_metadata({b"reward_model_path": reward_model_path.encode()})
|
||||
|
||||
out = Path(dataset.root) / DEFAULT_OUTPUT_FILENAME if output_path is None else Path(output_path)
|
||||
out.parent.mkdir(parents=True, exist_ok=True)
|
||||
pq.write_table(table, out)
|
||||
logging.info(f"Saved {len(table)} frame values to {out}")
|
||||
|
||||
progress_arr = np.asarray(all_progress, dtype=np.float32)
|
||||
if progress_arr.size:
|
||||
logging.info(
|
||||
f"Progress: mean={float(progress_arr.mean()):.4f}, "
|
||||
f"std={float(progress_arr.std()):.4f}, "
|
||||
f"min={float(progress_arr.min()):.4f}, "
|
||||
f"max={float(progress_arr.max()):.4f}"
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Compute per-frame Robometer progress curves for RA-BC weighting.",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Dense per-frame progress for one episode
|
||||
python -m lerobot.rewards.robometer.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--reward-model-path lerobot/Robometer-4B \\
|
||||
--episodes 0
|
||||
|
||||
# All episodes, smaller batches for memory-constrained GPUs
|
||||
python -m lerobot.rewards.robometer.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--reward-model-path lerobot/Robometer-4B \\
|
||||
--batch-size 16
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-repo-id", type=str, required=True, help="HuggingFace dataset repo id or local path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reward-model-path", type=str, default=None, help="Robometer checkpoint repo id or local path."
|
||||
)
|
||||
parser.add_argument("--output-path", type=str, default=None, help="Output parquet path.")
|
||||
parser.add_argument("--device", type=str, default="cuda", help="Device to use (default: cuda).")
|
||||
parser.add_argument(
|
||||
"--batch-size", type=int, default=32, help="Sub-samples per Qwen forward (default: 32)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-subsampled-frames",
|
||||
type=int,
|
||||
default=DEFAULT_NUM_SUBSAMPLED_FRAMES,
|
||||
help=f"Frames per sub-sample (default: {DEFAULT_NUM_SUBSAMPLED_FRAMES}, matches eval server).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--episodes", type=int, nargs="+", default=None, help="Process only these episode indices."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image-key", type=str, default=None, help="Image observation key (default: from config)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub", action="store_true", help="Upload to the dataset repo on HuggingFace Hub."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
|
||||
reward_model_path = args.reward_model_path
|
||||
if reward_model_path is None:
|
||||
temp_dataset = LeRobotDataset(args.dataset_repo_id, download_videos=False)
|
||||
parquet_path = Path(temp_dataset.root) / DEFAULT_OUTPUT_FILENAME
|
||||
reward_model_path = get_reward_model_path_from_parquet(parquet_path)
|
||||
if reward_model_path:
|
||||
logging.info(f"Using reward model from parquet metadata: {reward_model_path}")
|
||||
else:
|
||||
raise ValueError(
|
||||
"--reward-model-path is required (no existing parquet with model metadata found)."
|
||||
)
|
||||
|
||||
output_path = compute_robometer_progress(
|
||||
dataset_repo_id=args.dataset_repo_id,
|
||||
reward_model_path=reward_model_path,
|
||||
output_path=args.output_path,
|
||||
device=args.device,
|
||||
batch_size=args.batch_size,
|
||||
num_subsampled_frames=args.num_subsampled_frames,
|
||||
episodes=args.episodes,
|
||||
image_key=args.image_key,
|
||||
)
|
||||
|
||||
print(f"\nRobometer progress saved to: {output_path}")
|
||||
|
||||
if args.push_to_hub:
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
hub_path = DEFAULT_OUTPUT_FILENAME
|
||||
|
||||
print(f"\nUploading to Hub: {args.dataset_repo_id}/{hub_path}")
|
||||
api.upload_file(
|
||||
path_or_fileobj=str(output_path),
|
||||
path_in_repo=hub_path,
|
||||
repo_id=args.dataset_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
print(
|
||||
"Successfully uploaded to: "
|
||||
f"https://huggingface.co/datasets/{args.dataset_repo_id}/blob/main/{hub_path}"
|
||||
)
|
||||
|
||||
print("\nTo use in training, add to your config:")
|
||||
print(" use_rabc: true")
|
||||
print(f" rabc_progress_path: hf://datasets/{args.dataset_repo_id}/{hub_path}")
|
||||
print(" rabc_head_mode: sparse")
|
||||
else:
|
||||
print("\nTo use in training, add to your config:")
|
||||
print(" use_rabc: true")
|
||||
print(f" rabc_progress_path: {output_path}")
|
||||
print(" rabc_head_mode: sparse")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,158 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
else:
|
||||
AutoConfig = None # type: ignore[assignment]
|
||||
AutoTokenizer = None # type: ignore[assignment]
|
||||
|
||||
|
||||
# Special tokens Robometer adds to the Qwen-VL tokenizer at construction time.
|
||||
# The order is part of the data contract: upstream resized ``embed_tokens``
|
||||
# after adding these tokens in this exact order, so changing the set or order
|
||||
# would silently misalign the saved embedding rows with their token ids.
|
||||
# ``<|reward_token|>`` and ``<|sim_token|>`` are leftover from earlier upstream
|
||||
# heads (never read at inference) but still occupy rows the checkpoint expects.
|
||||
ROBOMETER_SPECIAL_TOKENS = (
|
||||
"<|split_token|>",
|
||||
"<|reward_token|>",
|
||||
"<|pref_token|>",
|
||||
"<|sim_token|>",
|
||||
"<|prog_token|>",
|
||||
)
|
||||
|
||||
|
||||
@RewardModelConfig.register_subclass("robometer")
|
||||
@dataclass
|
||||
class RobometerConfig(RewardModelConfig):
|
||||
"""Configuration for the Robometer reward model."""
|
||||
|
||||
pretrained_path: str | None = "lerobot/Robometer-4B"
|
||||
image_key: str = OBS_IMAGES + ".top"
|
||||
task_key: str = "task"
|
||||
default_task: str | None = None
|
||||
|
||||
max_frames: int | None = 8
|
||||
reward_output: str = "progress" # "progress" or "success"
|
||||
success_threshold: float = 0.5
|
||||
|
||||
license: str | None = "apache-2.0"
|
||||
tags: list[str] | None = field(
|
||||
default_factory=lambda: ["reward-model", "vision-language", "qwen3-vl", "zero-shot"]
|
||||
)
|
||||
|
||||
base_model_id: str = "Qwen/Qwen3-VL-4B-Instruct"
|
||||
torch_dtype: str = "bfloat16"
|
||||
use_multi_image: bool = True
|
||||
use_per_frame_progress_token: bool = True
|
||||
average_temporal_patches: bool = True
|
||||
frame_pooling: str = "mean" # "mean" | "boundary" | "attention"
|
||||
frame_pooling_attn_temperature: float = 1.0
|
||||
progress_loss_type: str = "discrete" # "l1" | "l2" | "discrete"
|
||||
progress_discrete_bins: int = 10
|
||||
|
||||
# Serialised Qwen backbone config (post-resize). Always populated by
|
||||
# ``__post_init__`` from ``base_model_id`` + ``len(tokenizer) + 5``, so it
|
||||
# is non-empty after construction. Saved into ``config.json`` automatically
|
||||
# by the base ``_save_pretrained``.
|
||||
vlm_config: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"REWARD": NormalizationMode.IDENTITY,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
if self.reward_output not in {"progress", "success"}:
|
||||
raise ValueError(f"reward_output must be 'progress' or 'success', got {self.reward_output!r}")
|
||||
if self.max_frames is not None and self.max_frames < 1:
|
||||
raise ValueError(f"max_frames must be >= 1, got {self.max_frames}")
|
||||
if self.frame_pooling not in {"mean", "boundary", "attention"}:
|
||||
raise ValueError(f"frame_pooling must be mean/boundary/attention; got {self.frame_pooling!r}")
|
||||
if self.frame_pooling_attn_temperature <= 0:
|
||||
raise ValueError("frame_pooling_attn_temperature must be > 0")
|
||||
if self.progress_loss_type not in {"l1", "l2", "discrete"}:
|
||||
raise ValueError(f"progress_loss_type must be l1/l2/discrete; got {self.progress_loss_type!r}")
|
||||
if self.use_per_frame_progress_token and not self.use_multi_image:
|
||||
raise ValueError("use_per_frame_progress_token=True requires use_multi_image=True")
|
||||
|
||||
if self.image_key not in self.input_features:
|
||||
self.input_features[self.image_key] = PolicyFeature(shape=(3, 224, 224), type=FeatureType.VISUAL)
|
||||
self.output_features.setdefault("progress", PolicyFeature(shape=(1,), type=FeatureType.REWARD))
|
||||
self.output_features.setdefault("success", PolicyFeature(shape=(1,), type=FeatureType.REWARD))
|
||||
|
||||
# Deterministically populate ``vlm_config`` so it is non-empty after
|
||||
# construction. For ``Qwen/Qwen3-VL-4B-Instruct`` this gives
|
||||
# ``len(tokenizer) + 5 = 151,669 + 5 = 151,674`` — the exact post-resize
|
||||
# vocab the published ``Robometer-4B`` checkpoint was saved with.
|
||||
if not self.vlm_config:
|
||||
require_package("transformers", extra="robometer")
|
||||
vlm = AutoConfig.from_pretrained(self.base_model_id).to_dict()
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.base_model_id)
|
||||
text_config = vlm.get("text_config")
|
||||
if not isinstance(text_config, dict):
|
||||
raise ValueError(
|
||||
f"Backbone config for {self.base_model_id!r} has no nested `text_config`; "
|
||||
"Robometer expects a Qwen-VL-style config."
|
||||
)
|
||||
text_config["vocab_size"] = len(tokenizer) + len(ROBOMETER_SPECIAL_TOKENS)
|
||||
self.vlm_config = vlm
|
||||
|
||||
@property
|
||||
def use_discrete_progress(self) -> bool:
|
||||
"""Whether the progress head outputs distribution logits over bins."""
|
||||
return self.progress_loss_type.lower() == "discrete"
|
||||
|
||||
@property
|
||||
def vlm_backbone_config(self):
|
||||
"""Reconstruct the Qwen backbone config from :attr:`vlm_config`."""
|
||||
require_package("transformers", extra="robometer")
|
||||
config_dict = deepcopy(self.vlm_config)
|
||||
model_type = config_dict.pop("model_type", None)
|
||||
if model_type is None:
|
||||
raise ValueError("vlm_config must include `model_type` to reconstruct the backbone config")
|
||||
return AutoConfig.for_model(model_type, **config_dict)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int] | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if self.image_key not in self.input_features:
|
||||
raise ValueError(f"Robometer requires image input feature {self.image_key!r}")
|
||||
@@ -1,481 +0,0 @@
|
||||
# Copyright 2026 Anthony Liang, Yigit Korkmaz, Stephen Tu, Erdem Bıyık, Jesse Zhang
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""ROBOMETER: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons.
|
||||
|
||||
Paper: https://arxiv.org/abs/2603.02115
|
||||
Project: https://robometer.github.io
|
||||
Original code: https://github.com/aliang8/robometer
|
||||
Model: https://huggingface.co/robometer/Robometer-4B
|
||||
|
||||
Robometer is a general-purpose, video-language-input reward model built on
|
||||
``Qwen/Qwen3-VL-4B-Instruct``. It is trained with a dual reward-prediction
|
||||
objective:
|
||||
|
||||
- A frame-level progress loss anchoring reward magnitude on expert data.
|
||||
- A trajectory-comparison preference loss imposing global ordering constraints
|
||||
across trajectories sharing the same instruction.
|
||||
|
||||
To support downstream RL it also predicts a frame-level binary success. The
|
||||
training prompt inserts three learnable tokens:
|
||||
|
||||
- ``<|prog_token|>`` after each frame to read per-frame progress and success.
|
||||
- ``<|pref_token|>`` at the end to read pairwise preference (training-only).
|
||||
- ``<|split_token|>`` between two trajectories in preference samples
|
||||
(training-only).
|
||||
|
||||
Progress is modeled as a categorical distribution over ``progress_discrete_bins``
|
||||
uniformly-spaced centers in ``[0, 1]`` (C51-style), and the continuous estimate
|
||||
is recovered as the softmax-weighted mean of those centers — see
|
||||
:func:`convert_bins_to_continuous`.
|
||||
|
||||
This LeRobot port is **inference-only**: the preference head is preserved in
|
||||
the state dict for byte-equivalence with the published ``Robometer-4B``
|
||||
checkpoint but is not queried by :meth:`RobometerRewardModel.compute_reward`,
|
||||
which returns the last-frame progress (clamped to ``[0, 1]``) or sigmoid'd
|
||||
success probability depending on :attr:`RobometerConfig.reward_output`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.rewards.robometer.configuration_robometer import RobometerConfig
|
||||
from lerobot.utils.constants import OBS_PREFIX
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoModelForImageTextToText
|
||||
else:
|
||||
AutoModelForImageTextToText = None # type: ignore[assignment]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Namespace for Robometer's pre-encoded Qwen-VL observation tensors.
|
||||
ROBOMETER_FEATURE_PREFIX = f"{OBS_PREFIX}robometer."
|
||||
ROBOMETER_QWEN_INPUT_KEYS = (
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"pixel_values",
|
||||
"pixel_values_videos",
|
||||
"image_grid_thw",
|
||||
"video_grid_thw",
|
||||
"second_per_grid_ts",
|
||||
"mm_token_type_ids",
|
||||
)
|
||||
ROBOMETER_METADATA_KEYS = (
|
||||
"prog_token_id",
|
||||
"vision_start_token_id",
|
||||
"vision_end_token_id",
|
||||
"video_merge_size",
|
||||
)
|
||||
ROBOMETER_INPUT_KEYS = ROBOMETER_QWEN_INPUT_KEYS + ROBOMETER_METADATA_KEYS
|
||||
|
||||
|
||||
def convert_bins_to_continuous(bin_logits: Tensor) -> Tensor:
|
||||
"""Collapse per-bin logits into a single value in ``[0, 1]``.
|
||||
|
||||
The discrete progress head outputs ``num_bins`` logits per frame. Bins are
|
||||
evenly spaced centers in ``[0, 1]``; the continuous prediction is the
|
||||
softmax-weighted mean of those centers.
|
||||
"""
|
||||
bin_probs = torch.softmax(bin_logits, dim=-1)
|
||||
num_bins = bin_logits.shape[-1]
|
||||
bin_centers = torch.linspace(0.0, 1.0, num_bins, device=bin_logits.device, dtype=bin_logits.dtype)
|
||||
return (bin_probs * bin_centers).sum(dim=-1)
|
||||
|
||||
|
||||
def _squeeze_last_safe(x: Tensor) -> Tensor:
|
||||
"""Drop a trailing singleton dim only when present."""
|
||||
return x.squeeze(-1) if x.ndim > 1 and x.shape[-1] == 1 else x
|
||||
|
||||
|
||||
def _torch_dtype(name: str) -> torch.dtype:
|
||||
dtype = getattr(torch, name, None)
|
||||
if isinstance(dtype, torch.dtype):
|
||||
return dtype
|
||||
raise ValueError(f"Unknown torch dtype: {name!r}")
|
||||
|
||||
|
||||
class RobometerPredictionHead(nn.Sequential):
|
||||
"""Small MLP head used for Robometer's progress / success / preference outputs."""
|
||||
|
||||
def __init__(self, hidden_dim: int, output_size: int, *, dropout: float, with_sigmoid: bool) -> None:
|
||||
layers: list[nn.Module] = [
|
||||
nn.Linear(hidden_dim, hidden_dim // 2),
|
||||
nn.LayerNorm(hidden_dim // 2),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim // 2, output_size),
|
||||
]
|
||||
if with_sigmoid:
|
||||
layers.append(nn.Sigmoid())
|
||||
super().__init__(*layers)
|
||||
|
||||
|
||||
def decode_progress_outputs(
|
||||
progress_logits: Tensor | None,
|
||||
success_logits: Tensor | None,
|
||||
*,
|
||||
is_discrete_mode: bool,
|
||||
) -> dict[str, list[list[float]]]:
|
||||
"""Decode RBM head outputs into per-frame floats.
|
||||
|
||||
Args:
|
||||
progress_logits: ``(B, T)`` (continuous) or ``(B, T, num_bins)`` (discrete).
|
||||
success_logits: ``(B, T)`` raw logits, ``sigmoid``-ed to probabilities.
|
||||
is_discrete_mode: if True the progress logits get a softmax over bins
|
||||
and are projected onto bin centers via :func:`convert_bins_to_continuous`.
|
||||
|
||||
Returns:
|
||||
Dict with ``progress_pred`` and ``success_probs``, each a list of
|
||||
length ``B`` of per-frame float lists.
|
||||
"""
|
||||
progress_pred: list[list[float]] = []
|
||||
success_probs: list[list[float]] = []
|
||||
|
||||
if progress_logits is not None:
|
||||
for sample_logits in progress_logits:
|
||||
if is_discrete_mode:
|
||||
continuous = convert_bins_to_continuous(sample_logits.detach().float().cpu())
|
||||
progress_pred.append(continuous.flatten().tolist())
|
||||
else:
|
||||
progress_pred.append(sample_logits.detach().float().cpu().flatten().tolist())
|
||||
|
||||
if success_logits is not None:
|
||||
for sample_logits in success_logits:
|
||||
success_probs.append(torch.sigmoid(sample_logits.detach().float().cpu()).flatten().tolist())
|
||||
|
||||
return {"progress_pred": progress_pred, "success_probs": success_probs}
|
||||
|
||||
|
||||
class RobometerRewardModel(PreTrainedRewardModel):
|
||||
"""Robometer (RBM) reward model — inference-only LeRobot port.
|
||||
|
||||
Wraps a Qwen-VL backbone (default: ``Qwen/Qwen3-VL-4B-Instruct``) with three
|
||||
prediction heads from the paper (progress, success, preference). At
|
||||
inference time only the progress and success heads are queried; the
|
||||
preference head is kept on the module so the published ``Robometer-4B``
|
||||
safetensors load unchanged.
|
||||
"""
|
||||
|
||||
name = "robometer"
|
||||
config_class = RobometerConfig
|
||||
|
||||
def __init__(self, config: RobometerConfig, *, dropout: float = 0.1) -> None:
|
||||
require_package("transformers", extra="robometer")
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Two backbone-build paths (EO-1 style, branched on ``pretrained_path``):
|
||||
#
|
||||
# - Fresh training (``pretrained_path is None``): download the base
|
||||
# Qwen weights and resize the embed table to match
|
||||
# ``vlm_config.text_config.vocab_size`` — populated deterministically
|
||||
# in ``RobometerConfig.__post_init__`` as
|
||||
# ``len(tokenizer) + len(ROBOMETER_SPECIAL_TOKENS)``
|
||||
#
|
||||
# - Loading a saved checkpoint (``pretrained_path`` is set): rebuild
|
||||
# the empty architecture from ``vlm_config`` via
|
||||
# ``AutoModelForImageTextToText.from_config`` so the subsequent
|
||||
# ``model.safetensors`` load is a direct fill of the right shape —
|
||||
# no redundant Qwen weight download.
|
||||
torch_dtype = _torch_dtype(config.torch_dtype)
|
||||
if config.pretrained_path is None:
|
||||
self.model = AutoModelForImageTextToText.from_pretrained(
|
||||
config.base_model_id,
|
||||
dtype=torch_dtype,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
target_vocab = config.vlm_config["text_config"]["vocab_size"]
|
||||
self.model.resize_token_embeddings(target_vocab)
|
||||
else:
|
||||
self.model = AutoModelForImageTextToText.from_config(
|
||||
config.vlm_backbone_config,
|
||||
dtype=torch_dtype,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
# All Qwen-VL backbones Robometer supports expose `text_config.hidden_size`.
|
||||
# Falls back to the top-level `hidden_size` so future non-multimodal
|
||||
# variants would still resolve.
|
||||
backbone_config = self.model.config
|
||||
text_config = getattr(backbone_config, "text_config", None)
|
||||
hidden_size = getattr(text_config, "hidden_size", None) if text_config is not None else None
|
||||
if hidden_size is None:
|
||||
hidden_size = getattr(backbone_config, "hidden_size", None)
|
||||
if hidden_size is None:
|
||||
raise AttributeError(
|
||||
f"Could not infer hidden_size from backbone config of {config.base_model_id}"
|
||||
)
|
||||
hidden_dim = int(hidden_size)
|
||||
|
||||
# Robometer's three prediction heads + frame-pool attention.
|
||||
progress_output = config.progress_discrete_bins if config.use_discrete_progress else 1
|
||||
self.progress_head = RobometerPredictionHead(
|
||||
hidden_dim,
|
||||
progress_output,
|
||||
dropout=dropout,
|
||||
with_sigmoid=not config.use_discrete_progress,
|
||||
)
|
||||
self.preference_head = RobometerPredictionHead(hidden_dim, 1, dropout=dropout, with_sigmoid=False)
|
||||
self.success_head = RobometerPredictionHead(hidden_dim, 1, dropout=dropout, with_sigmoid=False)
|
||||
self.frame_pool_attn = nn.Linear(hidden_dim, 1, bias=False)
|
||||
|
||||
# Match the dtype of the loaded base model so weight loading is a no-op cast.
|
||||
model_dtype = next(self.model.parameters()).dtype
|
||||
self.progress_head.to(dtype=model_dtype)
|
||||
self.preference_head.to(dtype=model_dtype)
|
||||
self.success_head.to(dtype=model_dtype)
|
||||
self.frame_pool_attn.to(dtype=model_dtype)
|
||||
|
||||
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
inputs = {
|
||||
key: batch[f"{ROBOMETER_FEATURE_PREFIX}{key}"]
|
||||
for key in ROBOMETER_INPUT_KEYS
|
||||
if f"{ROBOMETER_FEATURE_PREFIX}{key}" in batch
|
||||
}
|
||||
if "input_ids" not in inputs:
|
||||
raise KeyError(
|
||||
f"Robometer batch missing pre-encoded inputs (expected "
|
||||
f"`{ROBOMETER_FEATURE_PREFIX}input_ids`). Make sure the "
|
||||
"RobometerEncoderProcessorStep ran before `compute_reward`."
|
||||
)
|
||||
|
||||
device = next(self.model.parameters()).device
|
||||
inputs = {key: value.to(device) if hasattr(value, "to") else value for key, value in inputs.items()}
|
||||
|
||||
self.eval()
|
||||
with torch.no_grad():
|
||||
progress_logits, success_logits = self._compute_rbm_logits(inputs)
|
||||
|
||||
decoded = decode_progress_outputs(
|
||||
progress_logits,
|
||||
success_logits,
|
||||
is_discrete_mode=self.config.use_discrete_progress,
|
||||
)
|
||||
values = (
|
||||
decoded["success_probs"] if self.config.reward_output == "success" else decoded["progress_pred"]
|
||||
)
|
||||
|
||||
rewards = torch.stack([torch.as_tensor(seq, dtype=torch.float32)[-1] for seq in values])
|
||||
if self.config.reward_output == "success":
|
||||
rewards = (rewards > self.config.success_threshold).float()
|
||||
else:
|
||||
# Match upstream Robometer's ``extract_rewards_from_output``: per-frame
|
||||
# progress predictions are clamped to ``[0, 1]`` before being returned.
|
||||
rewards = rewards.clamp(0.0, 1.0)
|
||||
return rewards.to(self.config.device or "cpu")
|
||||
|
||||
def _compute_rbm_logits(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
"""Run the Qwen3-VL backbone and apply Robometer's heads.
|
||||
|
||||
``inputs`` is the encoded batch produced by
|
||||
:class:`RobometerEncoderProcessorStep`. It carries Qwen tensors as well
|
||||
as Robometer-specific metadata (``prog_token_id``,
|
||||
``vision_start_token_id``, ``vision_end_token_id``, ``video_merge_size``)
|
||||
— the metadata is popped here so the rest can be forwarded straight to
|
||||
the Qwen model.
|
||||
|
||||
Returns ``(progress_logits, success_logits)``. Shapes:
|
||||
|
||||
- ``progress_logits``: ``(B, T)`` (continuous) or ``(B, T, num_bins)`` (discrete).
|
||||
- ``success_logits``: ``(B, T)`` raw logits (sigmoid happens at decode time).
|
||||
"""
|
||||
prog_token_id = inputs.pop("prog_token_id", None)
|
||||
vision_start_token_id = inputs.pop("vision_start_token_id", None)
|
||||
vision_end_token_id = inputs.pop("vision_end_token_id", None)
|
||||
video_merge_size = inputs.pop("video_merge_size", 14)
|
||||
|
||||
# Qwen3-VL doesn't reliably populate `last_hidden_state`; ask for the
|
||||
# full hidden-state tuple and take the last layer. This matches the
|
||||
# `is_qwen3` path in upstream Robometer's `RBM.forward_qwen` (main).
|
||||
outputs = self.model(**inputs, output_hidden_states=True, return_dict=True)
|
||||
hidden_state = (
|
||||
outputs.hidden_states[-1]
|
||||
if getattr(outputs, "hidden_states", None)
|
||||
else outputs.last_hidden_state
|
||||
)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
if self.config.use_per_frame_progress_token:
|
||||
if prog_token_id is None:
|
||||
raise KeyError("`prog_token_id` missing in batch (run RobometerEncoderProcessorStep first)")
|
||||
return self._process_token_extraction(hidden_state, input_ids, prog_token_id=prog_token_id)
|
||||
if self.config.use_multi_image:
|
||||
if vision_start_token_id is None or vision_end_token_id is None:
|
||||
raise KeyError(
|
||||
"`vision_start_token_id` / `vision_end_token_id` missing in batch "
|
||||
"(run RobometerEncoderProcessorStep first)"
|
||||
)
|
||||
return self._process_multi_image_frames(
|
||||
hidden_state,
|
||||
input_ids,
|
||||
start_id=vision_start_token_id,
|
||||
end_id=vision_end_token_id,
|
||||
)
|
||||
video_grid_thw = inputs.get("video_grid_thw")
|
||||
if video_grid_thw is None:
|
||||
raise ValueError("video_grid_thw is required for video-mode Robometer inference")
|
||||
if vision_start_token_id is None:
|
||||
raise KeyError("`vision_start_token_id` missing in batch")
|
||||
return self._process_video_frames(
|
||||
hidden_state,
|
||||
input_ids,
|
||||
video_grid_thw,
|
||||
start_id=vision_start_token_id,
|
||||
merge_size=video_merge_size,
|
||||
)
|
||||
|
||||
def _apply_heads_to_hidden_states(self, frame_embeddings: Tensor) -> tuple[Tensor, Tensor]:
|
||||
"""Apply progress + success heads to a tensor of frame embeddings."""
|
||||
progress_out = self.progress_head(frame_embeddings)
|
||||
progress = progress_out if self.config.use_discrete_progress else _squeeze_last_safe(progress_out)
|
||||
success = _squeeze_last_safe(self.success_head(frame_embeddings))
|
||||
return progress, success
|
||||
|
||||
def _process_token_extraction(
|
||||
self,
|
||||
hidden_state: Tensor,
|
||||
input_ids: Tensor,
|
||||
*,
|
||||
prog_token_id: int,
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
"""Per-frame progress/success from ``<|prog_token|>`` positions."""
|
||||
token_mask = input_ids == prog_token_id
|
||||
batch_indices, positions = token_mask.nonzero(as_tuple=True)
|
||||
if positions.numel() == 0:
|
||||
raise ValueError("`<|prog_token|>` not found in any sequence")
|
||||
|
||||
per_sample_hidden = [
|
||||
hidden_state[i, positions[batch_indices == i]] for i in range(input_ids.shape[0])
|
||||
]
|
||||
progress_list, success_list = [], []
|
||||
for embeddings in per_sample_hidden:
|
||||
if embeddings.shape[0] == 0:
|
||||
raise ValueError("`<|prog_token|>` missing in a sequence")
|
||||
progress, success = self._apply_heads_to_hidden_states(embeddings)
|
||||
progress_list.append(progress)
|
||||
success_list.append(success)
|
||||
|
||||
return torch.stack(progress_list), torch.stack(success_list)
|
||||
|
||||
def _process_multi_image_frames(
|
||||
self,
|
||||
hidden_state: Tensor,
|
||||
input_ids: Tensor,
|
||||
*,
|
||||
start_id: int,
|
||||
end_id: int,
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
"""Per-frame progress/success in multi-image mode (Qwen-VL)."""
|
||||
progress_list, success_list = [], []
|
||||
for batch_idx in range(input_ids.shape[0]):
|
||||
seq_ids = input_ids[batch_idx]
|
||||
seq_hidden = hidden_state[batch_idx]
|
||||
frame_embeddings = self._extract_hidden_states_from_token_pairs(
|
||||
seq_hidden, seq_ids, start_id, end_id
|
||||
)
|
||||
progress, success = self._apply_heads_to_hidden_states(frame_embeddings)
|
||||
progress_list.append(progress)
|
||||
success_list.append(success)
|
||||
|
||||
return torch.stack(progress_list), torch.stack(success_list)
|
||||
|
||||
def _extract_hidden_states_from_token_pairs(
|
||||
self,
|
||||
hidden_state: Tensor,
|
||||
input_ids: Tensor,
|
||||
start_id: int,
|
||||
end_id: int,
|
||||
) -> Tensor:
|
||||
start_positions = (input_ids == start_id).nonzero(as_tuple=True)[0]
|
||||
end_positions = (input_ids == end_id).nonzero(as_tuple=True)[0]
|
||||
if start_positions.numel() == 0:
|
||||
raise ValueError("`<|vision_start|>` not found in sequence")
|
||||
if start_positions.numel() != end_positions.numel():
|
||||
raise ValueError(
|
||||
f"Mismatched vision token counts: {start_positions.numel()} start vs "
|
||||
f"{end_positions.numel()} end"
|
||||
)
|
||||
|
||||
frames: list[Tensor] = []
|
||||
for start, end in zip(start_positions.tolist(), end_positions.tolist(), strict=True):
|
||||
if start >= end:
|
||||
raise ValueError(f"Invalid vision token pair: start={start} end={end}")
|
||||
patch_tokens = hidden_state[start + 1 : end]
|
||||
if patch_tokens.shape[0] == 0:
|
||||
frames.append((hidden_state[start] + hidden_state[end]) / 2.0)
|
||||
continue
|
||||
|
||||
pooling = self.config.frame_pooling
|
||||
if pooling == "mean":
|
||||
frames.append(patch_tokens.mean(dim=0))
|
||||
elif pooling == "boundary":
|
||||
frames.append(patch_tokens[-1])
|
||||
else: # attention
|
||||
scores = (
|
||||
self.frame_pool_attn(patch_tokens).squeeze(-1)
|
||||
/ self.config.frame_pooling_attn_temperature
|
||||
)
|
||||
weights = torch.softmax(scores, dim=0).unsqueeze(-1)
|
||||
frames.append((weights * patch_tokens).sum(dim=0))
|
||||
|
||||
return torch.stack(frames)
|
||||
|
||||
def _process_video_frames(
|
||||
self,
|
||||
hidden_state: Tensor,
|
||||
input_ids: Tensor,
|
||||
video_grid_thw: Tensor,
|
||||
*,
|
||||
start_id: int,
|
||||
merge_size: int,
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
"""Per-frame progress/success in video mode (Qwen-VL)."""
|
||||
progress_list, success_list = [], []
|
||||
for batch_idx in range(input_ids.shape[0]):
|
||||
seq_ids = input_ids[batch_idx]
|
||||
seq_hidden = hidden_state[batch_idx]
|
||||
start_positions = (seq_ids == start_id).nonzero(as_tuple=True)[0]
|
||||
if start_positions.numel() == 0:
|
||||
raise ValueError("`<|vision_start|>` not found in sequence")
|
||||
t_dim, h_dim, w_dim = (int(x) for x in video_grid_thw[batch_idx].tolist())
|
||||
tokens_per_frame = (h_dim * w_dim) // (merge_size**2)
|
||||
|
||||
cursor = start_positions[0].item()
|
||||
frame_embeddings: list[Tensor] = []
|
||||
for _ in range(t_dim):
|
||||
if self.config.average_temporal_patches:
|
||||
patch = seq_hidden[cursor : cursor + tokens_per_frame]
|
||||
frame_embeddings.append(patch.mean(dim=0))
|
||||
else:
|
||||
frame_embeddings.append(seq_hidden[cursor + tokens_per_frame])
|
||||
cursor += tokens_per_frame
|
||||
|
||||
stacked = torch.stack(frame_embeddings)
|
||||
progress, success = self._apply_heads_to_hidden_states(stacked)
|
||||
progress_list.append(progress)
|
||||
success_list.append(success)
|
||||
|
||||
return torch.stack(progress_list), torch.stack(success_list)
|
||||
@@ -1,338 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Robometer pre/post processing pipelines."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
policy_action_to_transition,
|
||||
)
|
||||
from lerobot.rewards.robometer.configuration_robometer import (
|
||||
ROBOMETER_SPECIAL_TOKENS,
|
||||
RobometerConfig,
|
||||
)
|
||||
from lerobot.rewards.robometer.modeling_robometer import ROBOMETER_FEATURE_PREFIX
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_IMAGES,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoProcessor
|
||||
else:
|
||||
AutoProcessor = None
|
||||
|
||||
PROGRESS_PROMPT = (
|
||||
"The task for the robot is '{task}'. Given the trajectory video, predict "
|
||||
"the task progress at each frame, how far along the robot is towards "
|
||||
"completing the task, a float between 0 and 1, where 0 is the starting "
|
||||
"state and 1 is when the task is completed. If the robot is not "
|
||||
"performing the same task, predict 0 progress."
|
||||
)
|
||||
|
||||
|
||||
def _frames_to_pil(frames: np.ndarray) -> list[Image.Image]:
|
||||
"""Convert ``(T, H, W, C)`` uint8 frames to a list of PIL images."""
|
||||
if frames.ndim != 4:
|
||||
raise ValueError(f"Expected (T,H,W,C) frames; got shape {frames.shape}")
|
||||
if frames.dtype != np.uint8:
|
||||
frames = np.clip(frames, 0, 255).astype(np.uint8)
|
||||
return [Image.fromarray(frames[i]) for i in range(frames.shape[0])]
|
||||
|
||||
|
||||
def _video_to_numpy(video: Tensor, *, max_frames: int | None) -> np.ndarray:
|
||||
"""Convert one trajectory tensor to a ``(T, H, W, C) uint8`` numpy array."""
|
||||
if max_frames is not None:
|
||||
video = video[-max_frames:]
|
||||
if video.shape[1] in (1, 3):
|
||||
video = video.permute(0, 2, 3, 1)
|
||||
elif video.shape[-1] not in (1, 3):
|
||||
raise ValueError(f"Expected channel dim of size 1 or 3, got shape {tuple(video.shape)}")
|
||||
|
||||
array = video.detach().cpu().numpy()
|
||||
if np.issubdtype(array.dtype, np.floating) and array.size > 0 and array.max() <= 1.0:
|
||||
array = array * 255.0
|
||||
return np.clip(array, 0, 255).astype(np.uint8)
|
||||
|
||||
|
||||
def _expand_tasks(task: Any, *, batch_size: int, default: str | None) -> list[str]:
|
||||
if task is None:
|
||||
task = default
|
||||
if task is None:
|
||||
raise KeyError("Robometer expected a task description in complementary data")
|
||||
if isinstance(task, str):
|
||||
return [task] * batch_size
|
||||
if isinstance(task, tuple):
|
||||
task = list(task)
|
||||
if not (isinstance(task, list) and all(isinstance(item, str) for item in task)):
|
||||
raise TypeError(f"Robometer task must be a string or list of strings, got {type(task)}")
|
||||
if len(task) == 1 and batch_size > 1:
|
||||
return task * batch_size
|
||||
if len(task) != batch_size:
|
||||
raise ValueError(f"Expected {batch_size} tasks, got {len(task)}")
|
||||
return task
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="robometer_encoder")
|
||||
class RobometerEncoderProcessorStep(ProcessorStep):
|
||||
"""Encode raw frames + task into Qwen-VL tensors for the Robometer model.
|
||||
|
||||
Loads a :class:`~transformers.AutoProcessor` matching ``base_model_id`` and
|
||||
registers Robometer's special tokens on the tokenizer. The matching
|
||||
embedding resize happens model-side in
|
||||
:meth:`RobometerRewardModel.__init__`.
|
||||
|
||||
At call time the step reads:
|
||||
|
||||
- ``observation[image_key]``: ``(B, T, C, H, W)`` or ``(B, C, H, W)`` frames.
|
||||
- ``complementary_data[task_key]``: a string or list of strings.
|
||||
|
||||
and writes ``observation[f"{ROBOMETER_FEATURE_PREFIX}<name>"]`` for:
|
||||
|
||||
- the Qwen-VL processor outputs: ``input_ids``, ``attention_mask``,
|
||||
``pixel_values``, ``image_grid_thw``, ``video_grid_thw``, ...
|
||||
- Robometer-specific token ids consumed by the model heads:
|
||||
``prog_token_id``, ``vision_start_token_id``, ``vision_end_token_id``,
|
||||
``video_merge_size``.
|
||||
"""
|
||||
|
||||
base_model_id: str = "Qwen/Qwen3-VL-4B-Instruct"
|
||||
image_key: str = OBS_IMAGES + ".top"
|
||||
task_key: str = "task"
|
||||
default_task: str | None = None
|
||||
max_frames: int | None = 8
|
||||
use_multi_image: bool = True
|
||||
use_per_frame_progress_token: bool = True
|
||||
max_length: int = 1024
|
||||
|
||||
_processor: Any = field(default=None, init=False, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
require_package("transformers", extra="robometer")
|
||||
require_package("qwen-vl-utils", extra="robometer", import_name="qwen_vl_utils")
|
||||
|
||||
self._processor = AutoProcessor.from_pretrained(
|
||||
self.base_model_id,
|
||||
trust_remote_code=True,
|
||||
do_sample_frames=False,
|
||||
padding_side="right",
|
||||
)
|
||||
|
||||
# Register Robometer's special tokens on the tokenizer. The matching
|
||||
# embedding resize happens model-side in `RobometerRewardModel.__init__`.
|
||||
tokenizer = self._processor.tokenizer
|
||||
# Qwen tokenizers may not define a pad token, but batched prompts/videos
|
||||
# require padding, so reuse EOS as the padding token.
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
for token in ROBOMETER_SPECIAL_TOKENS:
|
||||
if token not in tokenizer.get_vocab():
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": [token]})
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
if not isinstance(observation, dict):
|
||||
raise ValueError("RobometerEncoderProcessorStep requires an observation dict")
|
||||
|
||||
if self.image_key not in observation:
|
||||
raise KeyError(f"Robometer expected image key {self.image_key!r} in observation")
|
||||
|
||||
frames = observation[self.image_key]
|
||||
tensor = frames.detach().cpu() if isinstance(frames, Tensor) else torch.as_tensor(frames)
|
||||
if tensor.ndim == 4:
|
||||
tensor = tensor.unsqueeze(1)
|
||||
elif tensor.ndim != 5:
|
||||
raise ValueError(
|
||||
f"Expected Robometer frames with shape (B,C,H,W) or (B,T,C,H,W); got {tuple(tensor.shape)}"
|
||||
)
|
||||
|
||||
batch_size = tensor.shape[0]
|
||||
tasks = _expand_tasks(
|
||||
complementary.get(self.task_key, self.default_task),
|
||||
batch_size=batch_size,
|
||||
default=self.default_task,
|
||||
)
|
||||
|
||||
samples = [
|
||||
(_video_to_numpy(tensor[i], max_frames=self.max_frames), tasks[i]) for i in range(batch_size)
|
||||
]
|
||||
encoded = self.encode_samples(samples)
|
||||
|
||||
new_observation = dict(observation)
|
||||
for key, value in encoded.items():
|
||||
new_observation[f"{ROBOMETER_FEATURE_PREFIX}{key}"] = value
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = new_observation
|
||||
return new_transition
|
||||
|
||||
def encode_samples(self, samples: list[tuple[np.ndarray, str]]) -> dict[str, Tensor]:
|
||||
"""Run the Qwen-VL processor on a list of ``(frames, task)`` samples."""
|
||||
from qwen_vl_utils import process_vision_info
|
||||
|
||||
conversations = [self._build_conversation(frames, task) for frames, task in samples]
|
||||
|
||||
texts = [
|
||||
self._processor.apply_chat_template(
|
||||
msg,
|
||||
tokenize=False,
|
||||
add_generation_prompt=False,
|
||||
add_vision_id=True,
|
||||
enable_thinking=False,
|
||||
fps=1,
|
||||
)
|
||||
for msg in conversations
|
||||
]
|
||||
|
||||
process_kwargs: dict[str, Any] = {
|
||||
"return_video_kwargs": True,
|
||||
"return_video_metadata": True,
|
||||
}
|
||||
image_processor = getattr(self._processor, "image_processor", None)
|
||||
if image_processor is not None and hasattr(image_processor, "patch_size"):
|
||||
process_kwargs["image_patch_size"] = image_processor.patch_size
|
||||
|
||||
image_inputs, video_inputs, video_kwargs = process_vision_info(conversations, **process_kwargs)
|
||||
|
||||
videos: list[Any] | None = None
|
||||
video_metadatas: list[Any] | None = None
|
||||
if video_inputs:
|
||||
if isinstance(video_inputs[0], tuple) and len(video_inputs[0]) == 2:
|
||||
videos_seq, metadatas_seq = zip(*video_inputs, strict=False)
|
||||
videos = list(videos_seq)
|
||||
video_metadatas = list(metadatas_seq)
|
||||
else:
|
||||
videos = list(video_inputs)
|
||||
|
||||
processor_kwargs: dict[str, Any] = {
|
||||
"text": texts,
|
||||
"images": image_inputs,
|
||||
"padding": True,
|
||||
"truncation": False,
|
||||
"max_length": self.max_length,
|
||||
"return_tensors": "pt",
|
||||
"do_resize": False,
|
||||
}
|
||||
if videos is not None:
|
||||
processor_kwargs["videos"] = videos
|
||||
if video_metadatas is not None:
|
||||
processor_kwargs["video_metadata"] = video_metadatas
|
||||
if video_kwargs:
|
||||
processor_kwargs.update(video_kwargs)
|
||||
|
||||
encoded = self._processor(**processor_kwargs)
|
||||
|
||||
# Write Robometer-specific token ids and the video patch merge size into
|
||||
# the encoded batch so `RobometerRewardModel` doesn't need its own
|
||||
# tokenizer at inference (EO1-style separation: the processor owns the
|
||||
# tokenizer, the model owns the backbone and heads).
|
||||
tokenizer = self._processor.tokenizer
|
||||
encoded["prog_token_id"] = tokenizer.convert_tokens_to_ids("<|prog_token|>")
|
||||
encoded["vision_start_token_id"] = tokenizer.convert_tokens_to_ids("<|vision_start|>")
|
||||
encoded["vision_end_token_id"] = tokenizer.convert_tokens_to_ids("<|vision_end|>")
|
||||
video_processor = getattr(self._processor, "video_processor", None)
|
||||
encoded["video_merge_size"] = int(getattr(video_processor, "merge_size", 14))
|
||||
return encoded
|
||||
|
||||
def _build_conversation(self, frames: np.ndarray, task: str) -> list[dict[str, Any]]:
|
||||
pil_frames = _frames_to_pil(frames)
|
||||
prompt = PROGRESS_PROMPT.format(task=task)
|
||||
content: list[dict[str, Any]] = [{"type": "text", "text": prompt}]
|
||||
|
||||
if self.use_multi_image:
|
||||
for image in pil_frames:
|
||||
content.append({"type": "image", "image": image})
|
||||
if self.use_per_frame_progress_token:
|
||||
content.append({"type": "text", "text": "<|prog_token|>"})
|
||||
else:
|
||||
content.append({"type": "video", "video": pil_frames, "sample_fps": 1.0})
|
||||
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"base_model_id": self.base_model_id,
|
||||
"image_key": self.image_key,
|
||||
"task_key": self.task_key,
|
||||
"default_task": self.default_task,
|
||||
"max_frames": self.max_frames,
|
||||
"use_multi_image": self.use_multi_image,
|
||||
"use_per_frame_progress_token": self.use_per_frame_progress_token,
|
||||
"max_length": self.max_length,
|
||||
}
|
||||
|
||||
|
||||
def make_robometer_pre_post_processors(
|
||||
config: RobometerConfig,
|
||||
dataset_stats: dict[str, dict[str, Any]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Pipeline that pre-encodes frames + task into Qwen-VL tensors.
|
||||
|
||||
The preprocessor adds a batch dimension if needed, runs Robometer's
|
||||
encoder, and moves everything to the configured device. The
|
||||
postprocessor is the identity since Robometer outputs a single reward
|
||||
tensor.
|
||||
"""
|
||||
del dataset_stats # Robometer has its own normalisation inside the Qwen-VL processor.
|
||||
|
||||
preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=[
|
||||
AddBatchDimensionProcessorStep(),
|
||||
RobometerEncoderProcessorStep(
|
||||
base_model_id=config.base_model_id,
|
||||
image_key=config.image_key,
|
||||
task_key=config.task_key,
|
||||
default_task=config.default_task,
|
||||
max_frames=config.max_frames,
|
||||
use_multi_image=config.use_multi_image,
|
||||
use_per_frame_progress_token=config.use_per_frame_progress_token,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device or "cpu"),
|
||||
],
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
postprocessor = PolicyProcessorPipeline(
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
)
|
||||
return preprocessor, postprocessor
|
||||
@@ -1,19 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_topreward import TOPRewardConfig
|
||||
from .modeling_topreward import TOPRewardModel
|
||||
from .processor_topreward import make_topreward_pre_post_processors
|
||||
|
||||
__all__ = ["TOPRewardConfig", "TOPRewardModel", "make_topreward_pre_post_processors"]
|
||||
@@ -1,353 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Compute per-frame TOPReward progress curves for a LeRobot dataset.
|
||||
|
||||
For each episode, scores trajectory prefixes of increasing length using
|
||||
the TOPReward reward model, min-max normalises the raw log-prob rewards per episode,
|
||||
and writes a parquet file with one row per frame.
|
||||
|
||||
The parquet uses the same schema as SARM's :mod:`lerobot.rewards.sarm.compute_rabc_weights`.
|
||||
|
||||
Usage:
|
||||
# Sparse-dense mode (15 anchors per episode, matches upstream)
|
||||
python -m lerobot.rewards.topreward.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--num-samples 15
|
||||
|
||||
# Use a different VLM backbone
|
||||
python -m lerobot.rewards.topreward.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--vlm-name Qwen/Qwen3-VL-4B-Instruct
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.rewards.topreward.configuration_topreward import TOPRewardConfig
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
from lerobot.rewards.topreward.processor_topreward import TOPRewardEncoderProcessorStep
|
||||
from lerobot.types import TransitionKey
|
||||
|
||||
DEFAULT_OUTPUT_FILENAME = "topreward_progress.parquet"
|
||||
|
||||
|
||||
def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:
|
||||
"""Read ``reward_model_path`` from parquet metadata if available."""
|
||||
if not parquet_path.exists():
|
||||
return None
|
||||
try:
|
||||
metadata = pq.read_metadata(parquet_path).schema.to_arrow_schema().metadata
|
||||
if metadata and b"reward_model_path" in metadata:
|
||||
return metadata[b"reward_model_path"].decode()
|
||||
except Exception: # nosec B110
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_task(sample: dict[str, Any], default: str) -> str:
|
||||
"""Best-effort task extraction from a dataset sample."""
|
||||
task = sample.get("task")
|
||||
if isinstance(task, str) and task:
|
||||
return task
|
||||
return default
|
||||
|
||||
|
||||
def normalize_rewards(rewards: list[float] | np.ndarray) -> np.ndarray:
|
||||
"""Min-max normalise raw log-prob rewards into ``[0, 1]``."""
|
||||
rewards_arr = np.asarray(rewards, dtype=np.float64)
|
||||
if rewards_arr.size == 0:
|
||||
return rewards_arr.astype(np.float32)
|
||||
if rewards_arr.size == 1:
|
||||
return np.array([1.0], dtype=np.float32)
|
||||
r_min, r_max = rewards_arr.min(), rewards_arr.max()
|
||||
if r_max == r_min:
|
||||
return np.ones_like(rewards_arr, dtype=np.float32)
|
||||
return ((rewards_arr - r_min) / (r_max - r_min)).astype(np.float32)
|
||||
|
||||
|
||||
def compute_instruction_rewards_for_prefixes(
|
||||
model: TOPRewardModel,
|
||||
encoder: TOPRewardEncoderProcessorStep,
|
||||
dataset: LeRobotDataset,
|
||||
ep_start: int,
|
||||
num_frames: int,
|
||||
task: str,
|
||||
image_key: str,
|
||||
num_samples: int | None,
|
||||
device: str,
|
||||
) -> np.ndarray:
|
||||
"""Score an episode via prefix sweep and return a per-frame normalised curve."""
|
||||
if num_samples is None or num_samples >= num_frames:
|
||||
prefix_lengths = np.arange(1, num_frames + 1, dtype=np.int64)
|
||||
else:
|
||||
prefix_lengths = np.unique(np.linspace(1, num_frames, num_samples).round().astype(np.int64))
|
||||
|
||||
episode_frames = torch.stack([dataset[ep_start + i][image_key] for i in range(num_frames)])
|
||||
rewards: list[float] = []
|
||||
for length in prefix_lengths:
|
||||
frames = episode_frames[: int(length)].unsqueeze(0) # (1, T, C, H, W)
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {image_key: frames},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {"task": task},
|
||||
}
|
||||
encoded = encoder(transition)
|
||||
obs = encoded[TransitionKey.OBSERVATION]
|
||||
batch = {
|
||||
key: value.to(device) if isinstance(value, torch.Tensor) else value for key, value in obs.items()
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
reward = model.compute_reward(batch)
|
||||
rewards.append(float(reward.item()))
|
||||
|
||||
normalized_rewards = normalize_rewards(rewards)
|
||||
|
||||
if prefix_lengths.shape[0] == num_frames:
|
||||
return normalized_rewards
|
||||
|
||||
return np.interp(
|
||||
np.arange(1, num_frames + 1, dtype=np.float64),
|
||||
prefix_lengths.astype(np.float64),
|
||||
normalized_rewards.astype(np.float64),
|
||||
).astype(np.float32)
|
||||
|
||||
|
||||
def compute_topreward_progress(
|
||||
dataset_repo_id: str,
|
||||
reward_model_path: str | None = None,
|
||||
vlm_name: str | None = None,
|
||||
output_path: str | None = None,
|
||||
device: str = "cuda",
|
||||
num_samples: int | None = None,
|
||||
fps: float | None = None,
|
||||
episodes: list[int] | None = None,
|
||||
) -> Path:
|
||||
"""Run TOPReward over a dataset and write per-frame progress."""
|
||||
if reward_model_path is not None:
|
||||
logging.info(f"Loading TOPReward config from: {reward_model_path}")
|
||||
model = TOPRewardModel.from_pretrained(reward_model_path)
|
||||
config = model.config
|
||||
config.device = device
|
||||
if vlm_name is not None and vlm_name != config.vlm_name:
|
||||
logging.info(f"Overriding vlm_name from config: {config.vlm_name} -> {vlm_name}")
|
||||
config.vlm_name = vlm_name
|
||||
model = TOPRewardModel(config)
|
||||
else:
|
||||
config_kwargs: dict[str, Any] = {"device": device}
|
||||
if vlm_name is not None:
|
||||
config_kwargs["vlm_name"] = vlm_name
|
||||
if fps is not None:
|
||||
config_kwargs["fps"] = fps
|
||||
config = TOPRewardConfig(**config_kwargs)
|
||||
logging.info(f"Constructing TOPReward with VLM: {config.vlm_name}")
|
||||
model = TOPRewardModel(config)
|
||||
|
||||
model.to(device).eval()
|
||||
|
||||
encoder = TOPRewardEncoderProcessorStep(
|
||||
vlm_name=config.vlm_name,
|
||||
image_key=config.image_key,
|
||||
task_key=config.task_key,
|
||||
default_task=config.default_task,
|
||||
max_frames=None, # no tail-crop: we control prefix length explicitly
|
||||
fps=config.fps,
|
||||
prompt_prefix=config.prompt_prefix,
|
||||
prompt_suffix_template=config.prompt_suffix_template,
|
||||
add_chat_template=config.add_chat_template,
|
||||
max_length=config.max_input_length,
|
||||
)
|
||||
|
||||
image_key = config.image_key
|
||||
|
||||
logging.info(f"Loading dataset: {dataset_repo_id}")
|
||||
dataset = LeRobotDataset(dataset_repo_id, download_videos=True)
|
||||
logging.info(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
|
||||
|
||||
episode_indices = list(range(dataset.num_episodes)) if episodes is None else episodes
|
||||
logging.info(f"Processing {len(episode_indices)} episode(s)")
|
||||
|
||||
all_index: list[int] = []
|
||||
all_episode: list[int] = []
|
||||
all_frame: list[int] = []
|
||||
all_progress: list[float] = []
|
||||
|
||||
for episode_idx in tqdm(episode_indices, desc="Episodes"):
|
||||
ep = dataset.meta.episodes[episode_idx]
|
||||
ep_start = int(ep["dataset_from_index"])
|
||||
ep_end = int(ep["dataset_to_index"])
|
||||
num_frames = ep_end - ep_start
|
||||
if num_frames <= 0:
|
||||
continue
|
||||
|
||||
first_sample = dataset[ep_start]
|
||||
task = _resolve_task(first_sample, default=config.default_task or "perform the task")
|
||||
|
||||
per_frame = compute_instruction_rewards_for_prefixes(
|
||||
model=model,
|
||||
encoder=encoder,
|
||||
dataset=dataset,
|
||||
ep_start=ep_start,
|
||||
num_frames=num_frames,
|
||||
task=task,
|
||||
image_key=image_key,
|
||||
num_samples=num_samples,
|
||||
device=device,
|
||||
)
|
||||
|
||||
for local in range(num_frames):
|
||||
all_index.append(ep_start + local)
|
||||
all_episode.append(episode_idx)
|
||||
all_frame.append(local)
|
||||
all_progress.append(float(per_frame[local]))
|
||||
|
||||
if device.startswith("cuda"):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
table = pa.table(
|
||||
{
|
||||
"index": np.asarray(all_index, dtype=np.int64),
|
||||
"episode_index": np.asarray(all_episode, dtype=np.int64),
|
||||
"frame_index": np.asarray(all_frame, dtype=np.int64),
|
||||
"progress_sparse": np.asarray(all_progress, dtype=np.float32),
|
||||
}
|
||||
)
|
||||
|
||||
schema_metadata: dict[bytes, bytes] = {b"vlm_name": config.vlm_name.encode()}
|
||||
if reward_model_path is not None:
|
||||
schema_metadata[b"reward_model_path"] = reward_model_path.encode()
|
||||
table = table.replace_schema_metadata(schema_metadata)
|
||||
|
||||
out = Path(dataset.root) / DEFAULT_OUTPUT_FILENAME if output_path is None else Path(output_path)
|
||||
out.parent.mkdir(parents=True, exist_ok=True)
|
||||
pq.write_table(table, out)
|
||||
logging.info(f"Saved {len(table)} frame values to {out}")
|
||||
|
||||
progress_arr = np.asarray(all_progress, dtype=np.float32)
|
||||
if progress_arr.size:
|
||||
logging.info(
|
||||
f"Progress: mean={float(progress_arr.mean()):.4f}, "
|
||||
f"std={float(progress_arr.std()):.4f}, "
|
||||
f"min={float(progress_arr.min()):.4f}, "
|
||||
f"max={float(progress_arr.max()):.4f}"
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Compute per-frame TOPReward progress curves for RA-BC weighting.",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Sparse-dense mode (matches upstream TOPReward num_samples=15)
|
||||
python -m lerobot.rewards.topreward.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--num-samples 15
|
||||
|
||||
# Use a smaller VLM
|
||||
python -m lerobot.rewards.topreward.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--vlm-name Qwen/Qwen3-VL-4B-Instruct
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-repo-id", type=str, required=True, help="HuggingFace dataset repo id or local path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reward-model-path", type=str, default=None, help="Optional TOPReward LeRobot config."
|
||||
)
|
||||
parser.add_argument("--vlm-name", type=str, default=None, help="Override the VLM backbone (HF Hub id).")
|
||||
parser.add_argument("--output-path", type=str, default=None, help="Output parquet path.")
|
||||
parser.add_argument("--device", type=str, default="cuda", help="Device to use (default: cuda).")
|
||||
parser.add_argument(
|
||||
"--num-samples",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Anchor prefix samples per episode. None = dense. 15 matches upstream.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--episodes",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="Process only these episode indices (e.g. --episodes 0 or --episodes 0 5 10).",
|
||||
)
|
||||
parser.add_argument("--fps", type=float, default=None, help="Override TOPRewardConfig.fps.")
|
||||
parser.add_argument(
|
||||
"--push-to-hub", action="store_true", help="Upload to the dataset repo on HuggingFace Hub."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
|
||||
output_path = compute_topreward_progress(
|
||||
dataset_repo_id=args.dataset_repo_id,
|
||||
reward_model_path=args.reward_model_path,
|
||||
vlm_name=args.vlm_name,
|
||||
output_path=args.output_path,
|
||||
device=args.device,
|
||||
num_samples=args.num_samples,
|
||||
fps=args.fps,
|
||||
episodes=args.episodes,
|
||||
)
|
||||
|
||||
print(f"\nTOPReward progress saved to: {output_path}")
|
||||
|
||||
if args.push_to_hub:
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
hub_path = DEFAULT_OUTPUT_FILENAME
|
||||
|
||||
print(f"\nUploading to Hub: {args.dataset_repo_id}/{hub_path}")
|
||||
api.upload_file(
|
||||
path_or_fileobj=str(output_path),
|
||||
path_in_repo=hub_path,
|
||||
repo_id=args.dataset_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
print(
|
||||
"Successfully uploaded to: "
|
||||
f"https://huggingface.co/datasets/{args.dataset_repo_id}/blob/main/{hub_path}"
|
||||
)
|
||||
|
||||
print("\nTo use in training, add to your config:")
|
||||
print(" use_rabc: true")
|
||||
print(f" rabc_progress_path: hf://datasets/{args.dataset_repo_id}/{hub_path}")
|
||||
print(" rabc_head_mode: sparse")
|
||||
else:
|
||||
print("\nTo use in training, add to your config:")
|
||||
print(" use_rabc: true")
|
||||
print(f" rabc_progress_path: {output_path}")
|
||||
print(" rabc_head_mode: sparse")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,146 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
# Default prompt scaffolding from the upstream TOPReward paper / reference
|
||||
# implementation (``QwenClient.compute_instruction_reward``). The prompt
|
||||
# scores the terminal ``True`` token in ``f"{instruction} ... True"``
|
||||
# given the video.
|
||||
DEFAULT_PROMPT_PREFIX = (
|
||||
"The above video shows a robot manipulation trajectory that completes the following task: "
|
||||
)
|
||||
DEFAULT_PROMPT_SUFFIX_TEMPLATE = (
|
||||
"{instruction} Decide whether the above statement is True or not. The answer is: True"
|
||||
)
|
||||
|
||||
|
||||
@RewardModelConfig.register_subclass("topreward")
|
||||
@dataclass
|
||||
class TOPRewardConfig(RewardModelConfig):
|
||||
"""Configuration for the TOPReward zero-shot reward model.
|
||||
|
||||
TOPReward is **zero-shot**: it has no learnable parameters of its own.
|
||||
The "model" is a generic vision-language model (default
|
||||
``Qwen/Qwen3-VL-8B-Instruct``) used with a fixed prompt to extract
|
||||
token log-probabilities as a reward signal. There is therefore no
|
||||
fine-tuned checkpoint to host: ``pretrained_path`` is unused at
|
||||
runtime — the model identity is :attr:`vlm_name` (an HF Hub id).
|
||||
|
||||
Args:
|
||||
vlm_name: Hugging Face Hub id of the underlying VLM. Must be a
|
||||
Qwen3-VL family model (the only client implemented in this
|
||||
LeRobot port).
|
||||
torch_dtype: Torch dtype name passed to the VLM loader
|
||||
(``"auto"``, ``"bfloat16"``, ``"float16"``, ...).
|
||||
attn_implementation: ``transformers`` attention implementation
|
||||
(e.g. ``"flash_attention_2"``, ``"sdpa"``). Defaults to
|
||||
``None`` so the upstream picks the best available.
|
||||
image_key: Observation key that holds the trajectory frames.
|
||||
task_key: Complementary-data key that holds the task instruction.
|
||||
default_task: Fallback instruction when ``task_key`` is absent.
|
||||
max_frames: Cap on the number of frames fed to the VLM per
|
||||
sample. ``None`` = use all frames.
|
||||
fps: Frames-per-second metadata for the Qwen video processor.
|
||||
prompt_prefix: Text shown to the VLM right after the video and
|
||||
before the suffix template.
|
||||
prompt_suffix_template: Suffix appended after ``prompt_prefix``.
|
||||
Must contain ``{instruction}``; the VLM scores the
|
||||
log-likelihood of the tokens that follow the prefix.
|
||||
add_chat_template: If ``True``, wrap the full prompt with the
|
||||
tokenizer's chat template before tokenisation (matches
|
||||
upstream ``add_chat_template=True``).
|
||||
success_threshold: Optional log-prob threshold. If finite,
|
||||
:meth:`TOPRewardModel.compute_reward` returns
|
||||
``(reward > success_threshold).float()`` instead of the raw
|
||||
log-prob.
|
||||
max_input_length: Hard limit on the total tokenized input length;
|
||||
samples that exceed it raise a ``ValueError``.
|
||||
"""
|
||||
|
||||
# Path to a local LeRobot dir or HF repo that holds a ``config.json``
|
||||
# snapshot of this TOPRewardConfig. The VLM weights themselves are
|
||||
# always identified by ``vlm_name``.
|
||||
pretrained_path: str | None = None
|
||||
|
||||
vlm_name: str = "Qwen/Qwen3-VL-8B-Instruct"
|
||||
torch_dtype: str = "auto"
|
||||
attn_implementation: str | None = None
|
||||
|
||||
image_key: str = OBS_IMAGES + ".top"
|
||||
task_key: str = "task"
|
||||
default_task: str | None = None
|
||||
max_frames: int | None = 16
|
||||
fps: float = 2.0
|
||||
|
||||
prompt_prefix: str = DEFAULT_PROMPT_PREFIX
|
||||
prompt_suffix_template: str = DEFAULT_PROMPT_SUFFIX_TEMPLATE
|
||||
add_chat_template: bool = False
|
||||
|
||||
success_threshold: float = float("-inf")
|
||||
max_input_length: int = 32768
|
||||
|
||||
license: str | None = "mit" # matches upstream TOPReward
|
||||
tags: list[str] | None = field(
|
||||
default_factory=lambda: ["reward-model", "vision-language", "qwen3-vl", "zero-shot"]
|
||||
)
|
||||
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"REWARD": NormalizationMode.IDENTITY,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
if self.max_frames is not None and self.max_frames < 1:
|
||||
raise ValueError(f"max_frames must be >= 1, got {self.max_frames}")
|
||||
if self.fps <= 0:
|
||||
raise ValueError(f"fps must be > 0, got {self.fps}")
|
||||
if "{instruction}" not in self.prompt_suffix_template:
|
||||
raise ValueError(
|
||||
"prompt_suffix_template must contain `{instruction}` so the model "
|
||||
"scores the log-likelihood of the task suffix."
|
||||
)
|
||||
if self.max_input_length <= 0:
|
||||
raise ValueError(f"max_input_length must be > 0, got {self.max_input_length}")
|
||||
|
||||
if self.image_key not in self.input_features:
|
||||
self.input_features[self.image_key] = PolicyFeature(shape=(3, 224, 224), type=FeatureType.VISUAL)
|
||||
self.output_features.setdefault("reward", PolicyFeature(shape=(1,), type=FeatureType.REWARD))
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int] | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if self.image_key not in self.input_features:
|
||||
raise ValueError(f"TOPReward requires image input feature {self.image_key!r}")
|
||||
@@ -1,238 +0,0 @@
|
||||
# Copyright 2026 Shirui Chen, Cole Harrison, Ying-Chun Lee, Angela Jin Yang,
|
||||
# Zhongzheng Ren, Lillian J. Ratliff, Jiafei Duan, Dieter Fox, Ranjay Krishna
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""TOPReward: Token Probabilities as Hidden Zero-Shot Rewards for Robotics.
|
||||
|
||||
Paper: https://arxiv.org/abs/2602.19313
|
||||
Project: https://topreward.github.io/webpage/
|
||||
Original code: https://github.com/TOPReward/TOPReward
|
||||
Backbone: https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct (default)
|
||||
|
||||
TOPReward is a **zero-shot** reward model: it has no fine-tuned weights of
|
||||
its own. Given a video trajectory and a task instruction, it asks an
|
||||
off-the-shelf VLM how likely the instruction is, conditioned on the video,
|
||||
and returns that log-likelihood as the reward signal.
|
||||
|
||||
Inference recipe:
|
||||
|
||||
1. The processor builds a chat-style prompt, tokenises it, and emits
|
||||
``input_ids``, ``attention_mask``, vision tensors, and ``labels``.
|
||||
The processor label-masks everything except the terminal answer token with
|
||||
``-100``.
|
||||
2. Forward the full token sequence through the VLM.
|
||||
3. Read the terminal answer token log-probability from the logits as the
|
||||
scalar reward.
|
||||
|
||||
With the default ``prompt_suffix_template``, the only unmasked token is the
|
||||
literal ``"True"`` at the end — the reward is
|
||||
``log P("True" | video + prompt + instruction)``.
|
||||
|
||||
This LeRobot port is **inference-only and not trainable** — :meth:`forward`
|
||||
is intentionally inherited from :class:`PreTrainedRewardModel` and raises
|
||||
``NotImplementedError``, making :attr:`PreTrainedRewardModel.is_trainable`
|
||||
return ``False``.
|
||||
|
||||
Because the VLM weights live on the Hugging Face Hub under their canonical
|
||||
id (``Qwen/Qwen3-VL-8B-Instruct`` etc.) and TOPReward never modifies them,
|
||||
:meth:`_save_pretrained` and :meth:`from_pretrained` are overridden so a
|
||||
TOPReward LeRobot "checkpoint" is a single ``config.json`` (the VLM is
|
||||
re-fetched from the Hub at load time).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
from huggingface_hub.constants import CONFIG_NAME
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
from torch import Tensor
|
||||
from torch.nn.functional import cross_entropy
|
||||
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.rewards.topreward.configuration_topreward import TOPRewardConfig
|
||||
from lerobot.rewards.topreward.processor_topreward import TOPREWARD_FEATURE_PREFIX, TOPREWARD_INPUT_KEYS
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import Qwen3VLForConditionalGeneration
|
||||
else:
|
||||
Qwen3VLForConditionalGeneration = None # type: ignore[assignment]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T", bound="TOPRewardModel")
|
||||
|
||||
|
||||
def _torch_dtype(name: str) -> torch.dtype | str:
|
||||
"""Resolve a torch dtype name; ``"auto"`` is passed through verbatim."""
|
||||
if name == "auto":
|
||||
return "auto"
|
||||
dtype = getattr(torch, name, None)
|
||||
if isinstance(dtype, torch.dtype):
|
||||
return dtype
|
||||
raise ValueError(f"Unknown torch dtype: {name!r}")
|
||||
|
||||
|
||||
class TOPRewardModel(PreTrainedRewardModel):
|
||||
"""TOPReward zero-shot reward model."""
|
||||
|
||||
name = "topreward"
|
||||
config_class = TOPRewardConfig
|
||||
|
||||
def __init__(self, config: TOPRewardConfig) -> None:
|
||||
require_package("transformers", extra="topreward")
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
torch_dtype = _torch_dtype(config.torch_dtype)
|
||||
model_kwargs: dict[str, Any] = {"dtype": torch_dtype, "trust_remote_code": True}
|
||||
if config.attn_implementation is not None:
|
||||
model_kwargs["attn_implementation"] = config.attn_implementation
|
||||
|
||||
self.model = Qwen3VLForConditionalGeneration.from_pretrained(config.vlm_name, **model_kwargs)
|
||||
|
||||
def compute_reward(self, batch: dict[str, Any]) -> Tensor:
|
||||
"""Return one log-prob reward per sample in the batch."""
|
||||
inputs: dict[str, Any] = {}
|
||||
for key in TOPREWARD_INPUT_KEYS:
|
||||
batch_key = f"{TOPREWARD_FEATURE_PREFIX}{key}"
|
||||
if batch_key not in batch:
|
||||
raise KeyError(
|
||||
f"TOPReward batch missing `{batch_key}`. Make sure the "
|
||||
"TOPRewardEncoderProcessorStep ran before `compute_reward`."
|
||||
)
|
||||
inputs[key] = batch[batch_key]
|
||||
|
||||
device = next(self.model.parameters()).device
|
||||
inputs = {key: value.to(device) if hasattr(value, "to") else value for key, value in inputs.items()}
|
||||
labels = inputs.pop("labels")
|
||||
inputs["logits_to_keep"] = 2
|
||||
|
||||
self.eval()
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs)
|
||||
logits = outputs.logits
|
||||
rewards = -cross_entropy(logits[:, -2, :].float(), labels[:, -1], reduction="none")
|
||||
if np.isfinite(self.config.success_threshold):
|
||||
rewards = (rewards > self.config.success_threshold).float()
|
||||
return rewards.to(self.config.device or "cpu")
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
"""Save ``config.json`` only."""
|
||||
self.config._save_pretrained(save_directory)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: builtins.type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
config: RewardModelConfig | None = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
strict: bool = False, # noqa: ARG003 — accepted for API parity; unused (no safetensors to load)
|
||||
**kwargs: Any,
|
||||
) -> T:
|
||||
"""Load a TOPReward configuration and instantiate the wrapped VLM."""
|
||||
if config is None:
|
||||
config = RewardModelConfig.from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
**kwargs,
|
||||
)
|
||||
if not isinstance(config, TOPRewardConfig):
|
||||
raise TypeError(
|
||||
f"Expected a TOPRewardConfig, got {type(config).__name__}. Make sure "
|
||||
f"`pretrained_name_or_path={pretrained_name_or_path!r}` points at a "
|
||||
"TOPReward checkpoint."
|
||||
)
|
||||
|
||||
model_id = str(pretrained_name_or_path)
|
||||
if not os.path.isdir(model_id):
|
||||
try:
|
||||
hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=CONFIG_NAME,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(
|
||||
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
|
||||
) from e
|
||||
|
||||
instance = cls(config, **kwargs)
|
||||
instance.to(config.device)
|
||||
instance.eval()
|
||||
return instance
|
||||
|
||||
def push_model_to_hub(self, cfg: TrainPipelineConfig):
|
||||
"""Push the TOPReward ``config.json`` + model card to the Hub."""
|
||||
api = HfApi()
|
||||
repo_id = api.create_repo(
|
||||
repo_id=self.config.repo_id, private=self.config.private, exist_ok=True
|
||||
).repo_id
|
||||
|
||||
with TemporaryDirectory(ignore_cleanup_errors=True) as tmp:
|
||||
saved_path = Path(tmp) / repo_id
|
||||
saved_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.config._save_pretrained(saved_path)
|
||||
|
||||
card = self.generate_model_card(
|
||||
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags
|
||||
)
|
||||
card.save(str(saved_path / "README.md"))
|
||||
|
||||
cfg.save_pretrained(saved_path)
|
||||
|
||||
commit_info = api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
folder_path=saved_path,
|
||||
commit_message="Upload TOPReward config and readme",
|
||||
allow_patterns=["*.json", "*.yaml", "*.md"],
|
||||
ignore_patterns=["*.tmp", "*.log", "*.safetensors"],
|
||||
)
|
||||
|
||||
logger.info(f"Model pushed to {commit_info.repo_url.url}")
|
||||
@@ -1,305 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""TOPReward pre/post processing pipeline."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
policy_action_to_transition,
|
||||
)
|
||||
from lerobot.rewards.topreward.configuration_topreward import (
|
||||
DEFAULT_PROMPT_PREFIX,
|
||||
DEFAULT_PROMPT_SUFFIX_TEMPLATE,
|
||||
TOPRewardConfig,
|
||||
)
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_IMAGES,
|
||||
OBS_PREFIX,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoProcessor
|
||||
else:
|
||||
AutoProcessor = None
|
||||
|
||||
TOPREWARD_FEATURE_PREFIX = f"{OBS_PREFIX}topreward."
|
||||
|
||||
_TRUE_ANSWER = "True"
|
||||
|
||||
TOPREWARD_VLM_INPUT_KEYS = (
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"pixel_values_videos",
|
||||
"video_grid_thw",
|
||||
"mm_token_type_ids",
|
||||
)
|
||||
TOPREWARD_INPUT_KEYS = TOPREWARD_VLM_INPUT_KEYS + ("labels",)
|
||||
|
||||
|
||||
def _prepare_video_batch(video: Tensor, *, max_frames: int | None) -> Tensor:
|
||||
"""Return videos as ``(B, T, C, H, W)`` uint8 tensors for Qwen3-VL."""
|
||||
if video.ndim == 4:
|
||||
video = video.unsqueeze(1)
|
||||
elif video.ndim != 5:
|
||||
raise ValueError(
|
||||
f"Expected TOPReward frames with shape (B,C,H,W) or (B,T,C,H,W); got {tuple(video.shape)}"
|
||||
)
|
||||
|
||||
if max_frames is not None:
|
||||
video = video[:, -max_frames:]
|
||||
if video.shape[-1] in (1, 3):
|
||||
video = video.permute(0, 1, 4, 2, 3)
|
||||
elif video.shape[2] not in (1, 3):
|
||||
raise ValueError(f"Expected channel dim of size 1 or 3, got shape {tuple(video.shape)}")
|
||||
|
||||
if video.is_floating_point():
|
||||
video = video * 255.0
|
||||
|
||||
return video.clamp(0, 255).to(torch.uint8).contiguous()
|
||||
|
||||
|
||||
def _expand_tasks(task: Any, *, batch_size: int, default: str | None) -> list[str]:
|
||||
if task is None:
|
||||
task = default
|
||||
if task is None:
|
||||
raise KeyError("TOPReward expected a task description in complementary data")
|
||||
if isinstance(task, str):
|
||||
return [task] * batch_size
|
||||
if isinstance(task, tuple):
|
||||
task = list(task)
|
||||
if not (isinstance(task, list) and all(isinstance(item, str) for item in task)):
|
||||
raise TypeError(f"TOPReward task must be a string or list of strings, got {type(task)}")
|
||||
if len(task) == 1 and batch_size > 1:
|
||||
return task * batch_size
|
||||
if len(task) != batch_size:
|
||||
raise ValueError(f"Expected {batch_size} tasks, got {len(task)}")
|
||||
return task
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="topreward_encoder")
|
||||
class TOPRewardEncoderProcessorStep(ProcessorStep):
|
||||
"""Encode raw frames + task into Qwen-VL tensors for the TOPReward model.
|
||||
|
||||
Loads a :class:`~transformers.AutoProcessor` matching ``vlm_name`` and
|
||||
builds the full chat prompt including the instruction suffix. The
|
||||
resulting ``input_ids``, ``attention_mask``, vision tensors, and
|
||||
``labels`` are written under the ``observation.topreward.*`` namespace
|
||||
so the model can score without re-tokenising.
|
||||
|
||||
At call time the step reads:
|
||||
|
||||
- ``observation[image_key]``: ``(B, T, C, H, W)`` or ``(B, C, H, W)`` frames.
|
||||
- ``complementary_data[task_key]``: a string or list of strings.
|
||||
|
||||
and writes ``observation[f"{TOPREWARD_FEATURE_PREFIX}<name>"]`` for the
|
||||
Qwen-VL tensors plus ``labels``.
|
||||
"""
|
||||
|
||||
vlm_name: str = "Qwen/Qwen3-VL-8B-Instruct"
|
||||
image_key: str = OBS_IMAGES + ".top"
|
||||
task_key: str = "task"
|
||||
default_task: str | None = None
|
||||
max_frames: int | None = 16
|
||||
fps: float = 2.0
|
||||
prompt_prefix: str = DEFAULT_PROMPT_PREFIX
|
||||
prompt_suffix_template: str = DEFAULT_PROMPT_SUFFIX_TEMPLATE
|
||||
add_chat_template: bool = False
|
||||
max_length: int = 32768
|
||||
|
||||
_processor: Any = field(default=None, init=False, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
require_package("transformers", extra="topreward")
|
||||
self._processor = AutoProcessor.from_pretrained(self.vlm_name, trust_remote_code=True)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
if self.image_key not in observation:
|
||||
raise KeyError(f"TOPReward expected image key {self.image_key!r} in observation")
|
||||
|
||||
frames = observation[self.image_key]
|
||||
videos = frames.detach().cpu() if isinstance(frames, Tensor) else torch.as_tensor(frames)
|
||||
videos = _prepare_video_batch(videos, max_frames=self.max_frames)
|
||||
|
||||
batch_size = videos.shape[0]
|
||||
tasks = _expand_tasks(
|
||||
complementary.get(self.task_key, self.default_task),
|
||||
batch_size=batch_size,
|
||||
default=self.default_task,
|
||||
)
|
||||
|
||||
encoded = self._encode_batch(videos, tasks, batch_size)
|
||||
|
||||
new_observation = dict(observation)
|
||||
for key, value in encoded.items():
|
||||
new_observation[f"{TOPREWARD_FEATURE_PREFIX}{key}"] = value
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = new_observation
|
||||
return new_transition
|
||||
|
||||
def _encode_batch(self, videos: Tensor, tasks: list[str], batch_size) -> dict[str, Any]:
|
||||
"""Tokenise a batch of (frames, task) pairs into Qwen-VL tensors.
|
||||
|
||||
The loop only builds per-sample chat strings. Tokenisation, padding,
|
||||
video preprocessing, and label construction are batched.
|
||||
"""
|
||||
|
||||
texts: list[str] = []
|
||||
video_metadata = [
|
||||
{
|
||||
"total_num_frames": int(videos.shape[1]),
|
||||
"fps": float(self.fps),
|
||||
"frames_indices": list(range(int(videos.shape[1]))),
|
||||
}
|
||||
for _ in range(batch_size)
|
||||
]
|
||||
eos_token = self._processor.tokenizer.eos_token
|
||||
|
||||
for i in range(batch_size):
|
||||
instruction_suffix = self.prompt_suffix_template.format(instruction=tasks[i])
|
||||
if self.add_chat_template:
|
||||
suffix_for_template = instruction_suffix.removesuffix(_TRUE_ANSWER).rstrip()
|
||||
templated_messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "video": videos[i], "fps": self.fps},
|
||||
{"type": "text", "text": f"{self.prompt_prefix}{suffix_for_template}"},
|
||||
],
|
||||
}
|
||||
]
|
||||
prompt_chat = self._processor.apply_chat_template(
|
||||
templated_messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
full_text = f"{prompt_chat}{_TRUE_ANSWER}"
|
||||
else:
|
||||
user_messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "video": videos[i], "fps": self.fps},
|
||||
{"type": "text", "text": self.prompt_prefix},
|
||||
],
|
||||
}
|
||||
]
|
||||
prompt_chat = self._processor.apply_chat_template(
|
||||
user_messages, tokenize=False, add_generation_prompt=False
|
||||
)
|
||||
if eos_token is not None:
|
||||
prompt_chat = prompt_chat.split(eos_token)[0]
|
||||
full_text = f"{prompt_chat}{instruction_suffix}"
|
||||
|
||||
texts.append(full_text)
|
||||
|
||||
result = self._processor(
|
||||
text=texts,
|
||||
videos=videos,
|
||||
video_metadata=video_metadata,
|
||||
do_sample_frames=False,
|
||||
padding=True,
|
||||
padding_side="left",
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_ids = result["input_ids"]
|
||||
|
||||
if input_ids.shape[-1] > self.max_length:
|
||||
raise ValueError(
|
||||
f"TOPReward input length {input_ids.shape[-1]} exceeds max_length "
|
||||
f"{self.max_length}; lower `max_frames` or raise `max_length`."
|
||||
)
|
||||
|
||||
labels = torch.full_like(input_ids, -100)
|
||||
labels[:, -1] = input_ids[:, -1]
|
||||
result["labels"] = labels
|
||||
return result
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"vlm_name": self.vlm_name,
|
||||
"image_key": self.image_key,
|
||||
"task_key": self.task_key,
|
||||
"default_task": self.default_task,
|
||||
"max_frames": self.max_frames,
|
||||
"fps": self.fps,
|
||||
"prompt_prefix": self.prompt_prefix,
|
||||
"prompt_suffix_template": self.prompt_suffix_template,
|
||||
"add_chat_template": self.add_chat_template,
|
||||
"max_length": self.max_length,
|
||||
}
|
||||
|
||||
|
||||
def make_topreward_pre_post_processors(
|
||||
config: TOPRewardConfig,
|
||||
dataset_stats: dict[str, dict[str, Any]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Pipeline that pre-encodes frames + task into Qwen-VL tensors.
|
||||
|
||||
The preprocessor adds a batch dimension if needed, runs TOPReward's
|
||||
encoder (which tokenises the full prompt and emits ``labels``), and
|
||||
moves everything to the configured device. The postprocessor is
|
||||
the identity since TOPReward outputs a single reward tensor.
|
||||
"""
|
||||
preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=[
|
||||
AddBatchDimensionProcessorStep(),
|
||||
TOPRewardEncoderProcessorStep(
|
||||
vlm_name=config.vlm_name,
|
||||
image_key=config.image_key,
|
||||
task_key=config.task_key,
|
||||
default_task=config.default_task,
|
||||
max_frames=config.max_frames,
|
||||
fps=config.fps,
|
||||
prompt_prefix=config.prompt_prefix,
|
||||
prompt_suffix_template=config.prompt_suffix_template,
|
||||
add_chat_template=config.add_chat_template,
|
||||
max_length=config.max_input_length,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device or "cpu"),
|
||||
],
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
postprocessor = PolicyProcessorPipeline(
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
)
|
||||
return preprocessor, postprocessor
|
||||
@@ -23,7 +23,6 @@ from .configs import (
|
||||
DAggerKeyboardConfig,
|
||||
DAggerPedalConfig,
|
||||
DAggerStrategyConfig,
|
||||
EpisodicStrategyConfig,
|
||||
HighlightStrategyConfig,
|
||||
RolloutConfig,
|
||||
RolloutStrategyConfig,
|
||||
@@ -39,10 +38,8 @@ from .context import (
|
||||
build_rollout_context,
|
||||
)
|
||||
from .inference import (
|
||||
FallbackMode,
|
||||
InferenceEngine,
|
||||
InferenceEngineConfig,
|
||||
RemoteInferenceConfig,
|
||||
RTCInferenceConfig,
|
||||
RTCInferenceEngine,
|
||||
SyncInferenceConfig,
|
||||
@@ -52,7 +49,6 @@ from .inference import (
|
||||
from .strategies import (
|
||||
BaseStrategy,
|
||||
DAggerStrategy,
|
||||
EpisodicStrategy,
|
||||
HighlightStrategy,
|
||||
RolloutStrategy,
|
||||
SentryStrategy,
|
||||
@@ -70,16 +66,12 @@ __all__ = [
|
||||
"HardwareContext",
|
||||
"HighlightStrategy",
|
||||
"HighlightStrategyConfig",
|
||||
"EpisodicStrategy",
|
||||
"EpisodicStrategyConfig",
|
||||
"FallbackMode",
|
||||
"InferenceEngine",
|
||||
"InferenceEngineConfig",
|
||||
"PolicyContext",
|
||||
"ProcessorContext",
|
||||
"RTCInferenceConfig",
|
||||
"RTCInferenceEngine",
|
||||
"RemoteInferenceConfig",
|
||||
"RolloutConfig",
|
||||
"RolloutContext",
|
||||
"RolloutStrategy",
|
||||
|
||||
@@ -121,35 +121,6 @@ class DAggerPedalConfig:
|
||||
upload: str = "KEY_C"
|
||||
|
||||
|
||||
@RolloutStrategyConfig.register_subclass("episodic")
|
||||
@dataclass
|
||||
class EpisodicStrategyConfig(RolloutStrategyConfig):
|
||||
"""Episode-oriented recording that mirrors the behavior of ``lerobot-record``.
|
||||
|
||||
Records ``dataset.num_episodes`` episodes of maximum ``dataset.episode_time_s`` each.
|
||||
After each episode, runs ``dataset.reset_time_s`` seconds of reset time.
|
||||
|
||||
Keyboard controls:
|
||||
Right arrow — end current episode or reset phase early
|
||||
Left arrow — discard current episode and re-record
|
||||
Escape — stop recording session
|
||||
|
||||
In between episodes:
|
||||
- if there is no teleop leader, the robot is held at its initial joint positions captured at startup.
|
||||
- else, the robot is moved smoothly to the position of the teleop leader.
|
||||
"""
|
||||
|
||||
# This only applies if there are no teleop leaders specified.
|
||||
# When True (default), moves the robot back to the joint positions captured at startup.
|
||||
# Otherwise, leave the robot in its current position.
|
||||
reset_to_initial_position: bool = True
|
||||
|
||||
# Whether to turn on or off the leader -> follower smooth handover behavior.
|
||||
# When False, fallback to follower -> leader handover.
|
||||
# Note that leader -> follower handover is only supported when the leader has `send_feedback` capability.
|
||||
smooth_leader_to_follower_handover: bool = True
|
||||
|
||||
|
||||
@RolloutStrategyConfig.register_subclass("dagger")
|
||||
@dataclass
|
||||
class DAggerStrategyConfig(RolloutStrategyConfig):
|
||||
@@ -258,13 +229,7 @@ class RolloutConfig:
|
||||
|
||||
# TODO(Steven): DAgger shouldn't require a dataset (user may want to just rollout+intervene without recording), but for now we require it to simplify the implementation.
|
||||
needs_dataset = isinstance(
|
||||
self.strategy,
|
||||
(
|
||||
SentryStrategyConfig,
|
||||
HighlightStrategyConfig,
|
||||
DAggerStrategyConfig,
|
||||
EpisodicStrategyConfig,
|
||||
),
|
||||
self.strategy, (SentryStrategyConfig, HighlightStrategyConfig, DAggerStrategyConfig)
|
||||
)
|
||||
if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
|
||||
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
|
||||
|
||||
@@ -51,7 +51,6 @@ from lerobot.utils.feature_utils import combine_feature_dicts, hw_to_dataset_fea
|
||||
from .configs import BaseStrategyConfig, DAggerStrategyConfig, RolloutConfig
|
||||
from .inference import (
|
||||
InferenceEngine,
|
||||
RemoteInferenceConfig,
|
||||
RTCInferenceConfig,
|
||||
SyncInferenceConfig,
|
||||
create_inference_engine,
|
||||
@@ -114,17 +113,11 @@ class HardwareContext:
|
||||
|
||||
@dataclass
|
||||
class PolicyContext:
|
||||
"""Loaded policy and its inference engine.
|
||||
"""Loaded policy and its inference engine."""
|
||||
|
||||
``policy``/``preprocessor``/``postprocessor`` are ``None`` for the
|
||||
weightless remote backend (``--inference.type=remote``): inference
|
||||
runs on a ``lerobot-policy-server`` and strategies only ever consume
|
||||
``inference``.
|
||||
"""
|
||||
|
||||
policy: PreTrainedPolicy | None
|
||||
preprocessor: PolicyProcessorPipeline | None
|
||||
postprocessor: PolicyProcessorPipeline | None
|
||||
policy: PreTrainedPolicy
|
||||
preprocessor: PolicyProcessorPipeline
|
||||
postprocessor: PolicyProcessorPipeline
|
||||
inference: InferenceEngine
|
||||
|
||||
|
||||
@@ -179,66 +172,54 @@ def build_rollout_context(
|
||||
fails fast without touching the robot.
|
||||
"""
|
||||
is_rtc = isinstance(cfg.inference, RTCInferenceConfig)
|
||||
is_remote = isinstance(cfg.inference, RemoteInferenceConfig)
|
||||
|
||||
# --- 1. Policy (heavy I/O, but no hardware yet) -------------------
|
||||
# Remote inference keeps the edge weightless: the config-only
|
||||
# PreTrainedConfig (already loaded by RolloutConfig.__post_init__,
|
||||
# no weight download) is all the client needs for pre-flight
|
||||
# validation and action ordering.
|
||||
logger.info("Loading policy from '%s'...", cfg.policy.pretrained_path)
|
||||
policy_config = cfg.policy
|
||||
policy = None
|
||||
if is_remote:
|
||||
logger.info(
|
||||
"Remote inference: weightless client for '%s' (no weights downloaded)",
|
||||
cfg.policy.pretrained_path,
|
||||
policy_class = get_policy_class(policy_config.type)
|
||||
|
||||
if hasattr(policy_config, "compile_model"):
|
||||
policy_config.compile_model = cfg.use_torch_compile
|
||||
|
||||
if policy_config.type == "vqbet" and cfg.device == "mps":
|
||||
raise NotImplementedError(
|
||||
"Current implementation of VQBeT does not support `mps` backend. "
|
||||
"Please use `cpu` or `cuda` backend."
|
||||
)
|
||||
|
||||
if policy_config.use_peft:
|
||||
from peft import PeftConfig, PeftModel
|
||||
|
||||
peft_path = policy_config.pretrained_path
|
||||
peft_config = PeftConfig.from_pretrained(peft_path)
|
||||
policy = policy_class.from_pretrained(
|
||||
pretrained_name_or_path=peft_config.base_model_name_or_path, config=policy_config
|
||||
)
|
||||
policy = PeftModel.from_pretrained(policy, peft_path, config=peft_config)
|
||||
else:
|
||||
logger.info("Loading policy from '%s'...", cfg.policy.pretrained_path)
|
||||
policy_class = get_policy_class(policy_config.type)
|
||||
policy = policy_class.from_pretrained(policy_config.pretrained_path, config=policy_config)
|
||||
|
||||
if hasattr(policy_config, "compile_model"):
|
||||
policy_config.compile_model = cfg.use_torch_compile
|
||||
if is_rtc:
|
||||
policy.config.rtc_config = cfg.inference.rtc
|
||||
if hasattr(policy, "init_rtc_processor"):
|
||||
policy.init_rtc_processor()
|
||||
|
||||
if policy_config.type == "vqbet" and cfg.device == "mps":
|
||||
raise NotImplementedError(
|
||||
"Current implementation of VQBeT does not support `mps` backend. "
|
||||
"Please use `cpu` or `cuda` backend."
|
||||
)
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
logger.info("Policy loaded: type=%s, device=%s", policy_config.type, cfg.device)
|
||||
|
||||
if policy_config.use_peft:
|
||||
from peft import PeftConfig, PeftModel
|
||||
|
||||
peft_path = policy_config.pretrained_path
|
||||
peft_config = PeftConfig.from_pretrained(peft_path)
|
||||
policy = policy_class.from_pretrained(
|
||||
pretrained_name_or_path=peft_config.base_model_name_or_path, config=policy_config
|
||||
)
|
||||
policy = PeftModel.from_pretrained(policy, peft_path, config=peft_config)
|
||||
else:
|
||||
policy = policy_class.from_pretrained(policy_config.pretrained_path, config=policy_config)
|
||||
|
||||
if is_rtc:
|
||||
policy.config.rtc_config = cfg.inference.rtc
|
||||
if hasattr(policy, "init_rtc_processor"):
|
||||
policy.init_rtc_processor()
|
||||
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
logger.info("Policy loaded: type=%s, device=%s", policy_config.type, cfg.device)
|
||||
|
||||
if cfg.use_torch_compile and policy.type not in ("pi0", "pi05"):
|
||||
try:
|
||||
if hasattr(torch, "compile"):
|
||||
compile_kwargs = {
|
||||
"backend": cfg.torch_compile_backend,
|
||||
"mode": cfg.torch_compile_mode,
|
||||
"options": {"triton.cudagraphs": False},
|
||||
}
|
||||
policy.predict_action_chunk = torch.compile(policy.predict_action_chunk, **compile_kwargs)
|
||||
logger.info("torch.compile applied to predict_action_chunk")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to apply torch.compile: %s", e)
|
||||
if cfg.use_torch_compile and policy.type not in ("pi0", "pi05"):
|
||||
try:
|
||||
if hasattr(torch, "compile"):
|
||||
compile_kwargs = {
|
||||
"backend": cfg.torch_compile_backend,
|
||||
"mode": cfg.torch_compile_mode,
|
||||
"options": {"triton.cudagraphs": False},
|
||||
}
|
||||
policy.predict_action_chunk = torch.compile(policy.predict_action_chunk, **compile_kwargs)
|
||||
logger.info("torch.compile applied to predict_action_chunk")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to apply torch.compile: %s", e)
|
||||
|
||||
# --- 2. Robot-side processors (user-supplied or defaults) --------
|
||||
if (
|
||||
@@ -397,36 +378,31 @@ def build_rollout_context(
|
||||
logger.info("Dataset ready: %s (%d existing episodes)", dataset.repo_id, dataset.num_episodes)
|
||||
|
||||
# --- 6. Policy pre/post processors (needs dataset stats if any) ---
|
||||
# Remote inference runs the policy processors server-side (per
|
||||
# session); the edge ships canonical dataset-format observations.
|
||||
preprocessor = None
|
||||
postprocessor = None
|
||||
if not is_remote:
|
||||
dataset_stats = None
|
||||
if dataset is not None:
|
||||
dataset_stats = rename_stats(
|
||||
dataset.meta.stats,
|
||||
cfg.rename_map,
|
||||
)
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy_config,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=dataset_stats,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.rename_map},
|
||||
},
|
||||
dataset_stats = None
|
||||
if dataset is not None:
|
||||
dataset_stats = rename_stats(
|
||||
dataset.meta.stats,
|
||||
cfg.rename_map,
|
||||
)
|
||||
|
||||
if isinstance(cfg.inference, SyncInferenceConfig) and any(
|
||||
isinstance(step, RelativeActionsProcessorStep) and step.enabled
|
||||
for step in getattr(preprocessor, "steps", ())
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"SyncInferenceEngine does not support policies with relative actions for now."
|
||||
"Use --inference.type=rtc or remove relative action processor steps from the policy pipeline."
|
||||
)
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy_config,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=dataset_stats,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
if isinstance(cfg.inference, SyncInferenceConfig) and any(
|
||||
isinstance(step, RelativeActionsProcessorStep) and step.enabled
|
||||
for step in getattr(preprocessor, "steps", ())
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"SyncInferenceEngine does not support policies with relative actions for now."
|
||||
"Use --inference.type=rtc or remove relative action processor steps from the policy pipeline."
|
||||
)
|
||||
|
||||
# --- 7. Inference strategy (needs policy + pre/post + hardware) --
|
||||
logger.info(
|
||||
@@ -449,8 +425,6 @@ def build_rollout_context(
|
||||
use_torch_compile=cfg.use_torch_compile,
|
||||
compile_warmup_inferences=cfg.compile_warmup_inferences,
|
||||
shutdown_event=shutdown_event,
|
||||
policy_config=policy_config,
|
||||
rename_map=cfg.rename_map,
|
||||
)
|
||||
|
||||
# --- 8. Assemble ---------------------------------------------------
|
||||
|
||||
@@ -14,18 +14,13 @@
|
||||
|
||||
"""Inference engine package — backend-agnostic action production.
|
||||
|
||||
Concrete backends (``sync``, ``rtc``, ``remote``, ...) expose the same
|
||||
small interface so rollout strategies never branch on which backend is
|
||||
in use.
|
||||
Concrete backends (``sync``, ``rtc``, ...) expose the same small interface so
|
||||
rollout strategies never branch on which backend is in use.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .base import InferenceEngine
|
||||
from .factory import (
|
||||
FallbackMode,
|
||||
InferenceEngineConfig,
|
||||
RemoteInferenceConfig,
|
||||
RTCInferenceConfig,
|
||||
SyncInferenceConfig,
|
||||
create_inference_engine,
|
||||
@@ -34,23 +29,11 @@ from .rtc import RTCInferenceEngine
|
||||
from .sync import SyncInferenceEngine
|
||||
|
||||
__all__ = [
|
||||
"FallbackMode",
|
||||
"InferenceEngine",
|
||||
"InferenceEngineConfig",
|
||||
"RTCInferenceConfig",
|
||||
"RTCInferenceEngine",
|
||||
"RemoteInferenceConfig",
|
||||
"RemoteInferenceEngine",
|
||||
"SyncInferenceConfig",
|
||||
"SyncInferenceEngine",
|
||||
"create_inference_engine",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
# Lazy: RemoteInferenceEngine pulls in msgpack/zenoh ('async' extra).
|
||||
if name == "RemoteInferenceEngine":
|
||||
from .remote import RemoteInferenceEngine
|
||||
|
||||
return RemoteInferenceEngine
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@@ -14,9 +14,9 @@
|
||||
|
||||
"""Inference engine configs and factory.
|
||||
|
||||
Selection is explicit via ``--inference.type=sync|rtc|remote``. Adding a
|
||||
new backend requires registering its config subclass and dispatching it
|
||||
in :func:`create_inference_engine`.
|
||||
Selection is explicit via ``--inference.type=sync|rtc``. Adding a new
|
||||
backend requires registering its config subclass and dispatching it in
|
||||
:func:`create_inference_engine`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -24,12 +24,10 @@ from __future__ import annotations
|
||||
import abc
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
from threading import Event
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
@@ -76,73 +74,6 @@ class RTCInferenceConfig(InferenceEngineConfig):
|
||||
queue_threshold: int = 30
|
||||
|
||||
|
||||
class FallbackMode(StrEnum):
|
||||
"""What ``get_action`` returns when the remote queue runs dry (STALLED)."""
|
||||
|
||||
HOLD = "hold" # return None: the robot holds its last commanded position
|
||||
REPEAT_LAST = "repeat_last" # re-send the last executed action
|
||||
ZERO = "zero" # explicit zero command (required for velocity-controlled robots)
|
||||
|
||||
|
||||
@InferenceEngineConfig.register_subclass("remote")
|
||||
@dataclass
|
||||
class RemoteInferenceConfig(InferenceEngineConfig):
|
||||
"""Network inference against a ``lerobot-policy-server`` over Zenoh.
|
||||
|
||||
The edge stays weightless: ``--policy.path`` resolves to a
|
||||
config-only ``PreTrainedConfig`` (no weight download) used for
|
||||
pre-flight validation and action ordering. Requires the ``async``
|
||||
extra (``pip install 'lerobot[async]'``).
|
||||
"""
|
||||
|
||||
# Transport: robots dial out to a zenoh router (NAT-friendly).
|
||||
connect_endpoint: str = "tcp/localhost:7447"
|
||||
# "client" via a zenohd router (production) | "peer" direct (LAN/tests).
|
||||
zenoh_mode: str = "client"
|
||||
tls_ca: str | None = None
|
||||
tls_cert: str | None = None
|
||||
tls_key: str | None = None
|
||||
|
||||
# Service addressing: which (model, revision, task) key tree to dial.
|
||||
# service_model_id defaults to --policy.path; service_task to the
|
||||
# rollout task. These must match the server manifest's namespace.
|
||||
service_model_id: str = ""
|
||||
service_revision: str = "main"
|
||||
service_task: str = ""
|
||||
|
||||
# Identity: "" → a fresh uuid4 per run. Set a stable ID per robot for
|
||||
# fleet-wide log correlation and per-robot router ACLs.
|
||||
client_uuid: str = ""
|
||||
|
||||
# Observation encoding: JPEG quality (0 = raw, LAN/debug only).
|
||||
jpeg_quality: int = 90
|
||||
|
||||
# Self-clocking: request the next chunk when the local queue holds
|
||||
# less than this many seconds of playback.
|
||||
buffer_time_s: float = 0.5
|
||||
|
||||
# Safety: never execute an action whose source observation is older
|
||||
# than this (bounds open-loop execution after a network stall).
|
||||
max_action_age_s: float = 3.0
|
||||
# Fallback when the queue runs dry (see FallbackMode).
|
||||
fallback: FallbackMode = FallbackMode.HOLD
|
||||
|
||||
# Watchdogs & reconnection.
|
||||
degraded_after_s: float = 1.0
|
||||
request_timeout_s: float = 5.0
|
||||
handshake_timeout_s: float = 2.0
|
||||
reconnect_initial_backoff_s: float = 0.5
|
||||
reconnect_max_backoff_s: float = 10.0
|
||||
max_offline_s: float = 60.0
|
||||
|
||||
# RTC settings (enabled → replace-merge with prefix conditioning when
|
||||
# the server supports it; otherwise downgraded to chunk-append).
|
||||
rtc: RTCConfig = field(default_factory=RTCConfig)
|
||||
|
||||
# Free-form labels forwarded in the session handshake (telemetry only).
|
||||
tags: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Factory
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -151,9 +82,9 @@ class RemoteInferenceConfig(InferenceEngineConfig):
|
||||
def create_inference_engine(
|
||||
config: InferenceEngineConfig,
|
||||
*,
|
||||
policy: PreTrainedPolicy | None,
|
||||
preprocessor: PolicyProcessorPipeline | None,
|
||||
postprocessor: PolicyProcessorPipeline | None,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
robot_wrapper: ThreadSafeRobot,
|
||||
hw_features: dict,
|
||||
dataset_features: dict,
|
||||
@@ -164,19 +95,10 @@ def create_inference_engine(
|
||||
use_torch_compile: bool = False,
|
||||
compile_warmup_inferences: int = 2,
|
||||
shutdown_event: Event | None = None,
|
||||
policy_config: PreTrainedConfig | None = None,
|
||||
rename_map: dict[str, str] | None = None,
|
||||
) -> InferenceEngine:
|
||||
"""Instantiate the appropriate inference engine from a config object.
|
||||
|
||||
``policy``/``preprocessor``/``postprocessor`` are required for the
|
||||
local backends (``sync``, ``rtc``) and must be ``None``-free there;
|
||||
the ``remote`` backend is weightless and needs only ``policy_config``.
|
||||
"""
|
||||
"""Instantiate the appropriate inference engine from a config object."""
|
||||
logger.info("Creating inference engine: %s", config.type)
|
||||
if isinstance(config, SyncInferenceConfig):
|
||||
if policy is None or preprocessor is None or postprocessor is None:
|
||||
raise ValueError("sync inference requires a loaded policy and processors")
|
||||
return SyncInferenceEngine(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
@@ -188,8 +110,6 @@ def create_inference_engine(
|
||||
robot_type=robot_wrapper.robot_type,
|
||||
)
|
||||
if isinstance(config, RTCInferenceConfig):
|
||||
if policy is None or preprocessor is None or postprocessor is None:
|
||||
raise ValueError("rtc inference requires a loaded policy and processors")
|
||||
return RTCInferenceEngine(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
@@ -205,25 +125,4 @@ def create_inference_engine(
|
||||
rtc_queue_threshold=config.queue_threshold,
|
||||
shutdown_event=shutdown_event,
|
||||
)
|
||||
if isinstance(config, RemoteInferenceConfig):
|
||||
if policy_config is None:
|
||||
raise ValueError("remote inference requires policy_config (from config-only --policy.path)")
|
||||
if use_torch_compile:
|
||||
logger.warning("--use_torch_compile is ignored with remote inference (server-side concern)")
|
||||
if device not in (None, "cpu"):
|
||||
logger.warning("--device=%s is ignored with remote inference (server-side concern)", device)
|
||||
# Lazy import: eclipse-zenoh/msgpack live behind the 'async' extra.
|
||||
from .remote import RemoteInferenceEngine
|
||||
|
||||
return RemoteInferenceEngine(
|
||||
config=config,
|
||||
policy_config=policy_config,
|
||||
hw_features=hw_features,
|
||||
ordered_action_keys=ordered_action_keys,
|
||||
task=task,
|
||||
fps=fps,
|
||||
robot_type=robot_wrapper.robot_type,
|
||||
rename_map=rename_map,
|
||||
shutdown_event=shutdown_event,
|
||||
)
|
||||
raise ValueError(f"Unknown inference engine type: {type(config).__name__}")
|
||||
|
||||
@@ -1,851 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Remote inference engine: network-decoupled policy inference over Zenoh.
|
||||
|
||||
The same architecture as :class:`RTCInferenceEngine` with the thread
|
||||
boundary replaced by a network boundary. The edge stays **weightless**
|
||||
(no policy weights, no policy processors); a ``lerobot-policy-server``
|
||||
runs the heavy half. All chunk state — leftover prefixes, latency
|
||||
tracking, delay computation — lives client-side in the existing
|
||||
``ActionQueue``/``LatencyTracker`` machinery, so the server is stateless
|
||||
per request and a server crash loses zero control state.
|
||||
|
||||
Threading model:
|
||||
- **Main thread** (strategy loop): ``notify_observation`` writes a
|
||||
latest-only slot; ``get_action`` pops the local queue and applies the
|
||||
staleness bound + fallback ladder. Never any I/O.
|
||||
- **Network worker** (one daemon thread): self-clocked by
|
||||
``buffer_time_s``, publishes one observation, awaits its chunk (or
|
||||
timeout), merges, repeats. One-in-flight is a *correctness*
|
||||
requirement: ``idx_before``/prefix snapshots must serialize with
|
||||
merges.
|
||||
- **Zenoh threads**: deposit-only callbacks (chunk → bounded queue,
|
||||
liveliness → event).
|
||||
|
||||
Clock iron rule: wall-clock instants never cross machines. The header's
|
||||
``client_mono_ns`` is opaque to the server and echoed back; the server
|
||||
reports only durations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import math
|
||||
import queue as queue_module
|
||||
import time
|
||||
import traceback
|
||||
import uuid as uuid_module
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.policies.rtc import ActionQueue, LatencyTracker
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policy_server import codec
|
||||
from lerobot.policy_server.schema import (
|
||||
MSG_TYPE_OBS,
|
||||
SCHEMA_VERSION,
|
||||
MsgHeader,
|
||||
ObservationMsg,
|
||||
ResetMsg,
|
||||
SessionAckMsg,
|
||||
SessionCloseMsg,
|
||||
SessionOpenMsg,
|
||||
action_key,
|
||||
client_alive_key,
|
||||
obs_key,
|
||||
reset_key,
|
||||
sanitize_key_segment,
|
||||
server_alive_key,
|
||||
service_prefix,
|
||||
session_key,
|
||||
status_key,
|
||||
)
|
||||
from lerobot.policy_server.zenoh_utils import build_zenoh_config, import_zenoh, obs_publisher_qos
|
||||
from lerobot.utils.constants import OBS_STATE, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
|
||||
from .base import InferenceEngine
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
from .factory import RemoteInferenceConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_IDLE_SLEEP_S = 0.01
|
||||
_MAX_CONSECUTIVE_WORKER_ERRORS = 10
|
||||
|
||||
|
||||
class ClientState:
|
||||
"""Fail-safe state machine states (see the design doc §9.2)."""
|
||||
|
||||
CONNECTING = "CONNECTING"
|
||||
STREAMING = "STREAMING"
|
||||
DEGRADED = "DEGRADED"
|
||||
STALLED = "STALLED"
|
||||
RECONNECTING = "RECONNECTING"
|
||||
DEAD = "DEAD"
|
||||
|
||||
|
||||
class RemoteInferenceEngine(InferenceEngine):
|
||||
"""``--inference.type=remote``: weightless edge client of a policy server."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: RemoteInferenceConfig,
|
||||
policy_config: PreTrainedConfig,
|
||||
hw_features: dict,
|
||||
ordered_action_keys: list[str],
|
||||
task: str,
|
||||
fps: float,
|
||||
robot_type: str,
|
||||
rename_map: dict[str, str] | None = None,
|
||||
shutdown_event: Event | None = None,
|
||||
) -> None:
|
||||
self._config = config
|
||||
self._policy_config = policy_config
|
||||
self._hw_features = hw_features
|
||||
self._ordered_action_keys = list(ordered_action_keys)
|
||||
self._task = task
|
||||
self._fps = float(fps)
|
||||
self._dt = 1.0 / self._fps
|
||||
self._robot_type = robot_type
|
||||
self._rename_map = dict(rename_map or {})
|
||||
self._global_shutdown_event = shutdown_event
|
||||
|
||||
self._client_uuid = sanitize_key_segment(config.client_uuid or uuid_module.uuid4().hex)
|
||||
model_id = config.service_model_id or getattr(policy_config, "pretrained_path", "") or "model"
|
||||
self._prefix = service_prefix(model_id, config.service_revision, config.service_task or task)
|
||||
|
||||
# Latest-only observation slot (identical to rtc.py's _obs_holder).
|
||||
self._obs_holder: dict[str, Any] = {"obs": None}
|
||||
self._obs_lock = Lock()
|
||||
|
||||
self._action_queue: ActionQueue | None = None
|
||||
self._latency_tracker = LatencyTracker()
|
||||
self._effective_rtc: RTCConfig = config.rtc
|
||||
|
||||
# Replies deposited by the zenoh callback, consumed by the worker.
|
||||
self._reply_queue: queue_module.Queue[tuple[MsgHeader, bytes]] = queue_module.Queue(maxsize=4)
|
||||
|
||||
self._zenoh = None
|
||||
self._obs_publisher = None
|
||||
self._declarations: list[Any] = []
|
||||
self._alive_token = None
|
||||
self._server_alive = Event()
|
||||
|
||||
self._worker: Thread | None = None
|
||||
self._stop_event = Event()
|
||||
self._active = Event()
|
||||
self._dead = Event()
|
||||
self._session_ack: SessionAckMsg | None = None
|
||||
|
||||
self.state = ClientState.CONNECTING
|
||||
self._state_lock = Lock()
|
||||
self._seq_id = 0
|
||||
self._epoch = 0
|
||||
self._episode_id = 0
|
||||
self._pending_reset = False
|
||||
|
||||
# Staleness bookkeeping: client-monotonic send time of the
|
||||
# observation that produced the current queue contents.
|
||||
# _anchor_lock serializes {merge + anchor update} (worker),
|
||||
# {staleness clear} (control thread), and {reset clear} so a
|
||||
# stale chunk can never merge into a freshly-reset queue and the
|
||||
# safety path can never clear a just-merged one.
|
||||
self._anchor_lock = Lock()
|
||||
self._chunk_anchor_mono: float | None = None
|
||||
self._last_chunk_mono: float | None = None
|
||||
self._offline_since_mono: float | None = None
|
||||
self._last_action: torch.Tensor | None = None
|
||||
|
||||
self.stats: dict[str, float] = {
|
||||
"requests": 0,
|
||||
"timeouts": 0,
|
||||
"merges": 0,
|
||||
"stale_drops": 0,
|
||||
"fallback_ticks": 0,
|
||||
"reconnects": 0,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def ready(self) -> bool:
|
||||
"""Session opened, capabilities validated, server warmed up."""
|
||||
ack = self._session_ack
|
||||
return ack is not None and ack.warmed_up and not self._dead.is_set()
|
||||
|
||||
@property
|
||||
def failed(self) -> bool:
|
||||
return self._dead.is_set()
|
||||
|
||||
@property
|
||||
def action_queue(self) -> ActionQueue | None:
|
||||
return self._action_queue
|
||||
|
||||
def start(self) -> None:
|
||||
"""Open transport, handshake, start the network worker.
|
||||
|
||||
Raises on initial connection/validation failure so a bad
|
||||
deployment aborts before the robot moves (reconnect logic only
|
||||
guards established sessions).
|
||||
"""
|
||||
zenoh = import_zenoh()
|
||||
cfg = self._config
|
||||
self._zenoh = zenoh.open(
|
||||
build_zenoh_config(
|
||||
mode=cfg.zenoh_mode,
|
||||
connect_endpoints=[cfg.connect_endpoint] if cfg.connect_endpoint else None,
|
||||
tls_root_ca_certificate=cfg.tls_ca,
|
||||
tls_connect_certificate=cfg.tls_cert,
|
||||
tls_connect_private_key=cfg.tls_key,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
ack = self._handshake(initial=True)
|
||||
except Exception:
|
||||
# Fail fast without leaking the transport session.
|
||||
with contextlib.suppress(Exception):
|
||||
self._zenoh.close()
|
||||
self._zenoh = None
|
||||
raise
|
||||
self._configure_from_ack(ack)
|
||||
|
||||
handlers = zenoh.handlers
|
||||
self._declarations.append(
|
||||
self._zenoh.declare_subscriber(
|
||||
action_key(self._prefix, self._client_uuid), handlers.Callback(self._on_chunk)
|
||||
)
|
||||
)
|
||||
self._obs_publisher = self._zenoh.declare_publisher(
|
||||
obs_key(self._prefix, self._client_uuid), **obs_publisher_qos(zenoh)
|
||||
)
|
||||
self._declarations.append(self._obs_publisher)
|
||||
self._server_alive.set()
|
||||
self._declarations.append(
|
||||
self._zenoh.liveliness().declare_subscriber(
|
||||
server_alive_key(self._prefix), handlers.Callback(self._on_server_liveliness), history=True
|
||||
)
|
||||
)
|
||||
self._alive_token = self._zenoh.liveliness().declare_token(
|
||||
client_alive_key(self._prefix, self._client_uuid)
|
||||
)
|
||||
|
||||
self._stop_event.clear()
|
||||
self._dead.clear()
|
||||
self._active.set()
|
||||
self._worker = Thread(target=self._worker_loop, daemon=True, name="RemoteInference")
|
||||
self._worker.start()
|
||||
logger.info(
|
||||
"Remote inference started: prefix=%s client=%s rtc=%s",
|
||||
self._prefix,
|
||||
self._client_uuid,
|
||||
self._effective_rtc.enabled,
|
||||
)
|
||||
|
||||
def stop(self) -> None:
|
||||
logger.info("Stopping remote inference engine...")
|
||||
self._stop_event.set()
|
||||
self._active.clear()
|
||||
if self._worker is not None and self._worker.is_alive():
|
||||
# Worst case the worker is mid-handshake inside _enter_reconnect.
|
||||
join_timeout = max(3.0, self._config.handshake_timeout_s + self._config.request_timeout_s + 2.0)
|
||||
self._worker.join(timeout=join_timeout)
|
||||
if self._worker.is_alive():
|
||||
logger.warning("Remote inference worker did not join")
|
||||
self._worker = None
|
||||
|
||||
if self._zenoh is not None:
|
||||
# Best-effort graceful close; the server also GCs on liveliness drop.
|
||||
with contextlib.suppress(Exception):
|
||||
self._control_query(
|
||||
session_key(self._prefix),
|
||||
codec.encode_session_close(
|
||||
SessionCloseMsg(
|
||||
client_uuid=self._client_uuid,
|
||||
session_id=self._session_ack.session_id if self._session_ack else "",
|
||||
)
|
||||
),
|
||||
timeout=1.0,
|
||||
)
|
||||
if self._alive_token is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
self._alive_token.undeclare()
|
||||
self._alive_token = None
|
||||
for declaration in self._declarations:
|
||||
with contextlib.suppress(Exception):
|
||||
declaration.undeclare()
|
||||
self._declarations.clear()
|
||||
self._obs_publisher = None
|
||||
with contextlib.suppress(Exception):
|
||||
self._zenoh.close()
|
||||
self._zenoh = None
|
||||
logger.info("Remote inference engine stopped")
|
||||
|
||||
def pause(self) -> None:
|
||||
"""Stop publishing observations; the local queue stays intact."""
|
||||
logger.info("Pausing remote inference (publishing stops, queue intact)")
|
||||
self._active.clear()
|
||||
|
||||
def resume(self) -> None:
|
||||
logger.info("Resuming remote inference")
|
||||
self._active.set()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Episode boundary: clear local chunk state, notify the server.
|
||||
|
||||
The acked reset query runs on the worker thread (never I/O on the
|
||||
control thread); thanks to per-request server statelessness a
|
||||
lost ack only costs a warning — the next observation announces
|
||||
the new episode in its header anyway.
|
||||
"""
|
||||
logger.info("Resetting remote inference state (queue + episode)")
|
||||
with self._anchor_lock:
|
||||
if self._action_queue is not None:
|
||||
self._action_queue.clear()
|
||||
self._chunk_anchor_mono = None
|
||||
with self._state_lock:
|
||||
self._episode_id += 1
|
||||
self._pending_reset = True
|
||||
with self._obs_lock:
|
||||
# The previous episode's final frame must not seed the new
|
||||
# episode's first request.
|
||||
self._obs_holder["obs"] = None
|
||||
self._last_action = None
|
||||
# LatencyTracker intentionally survives reset: latency is
|
||||
# episode-invariant (parity with local RTC).
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Action production (main thread — never any I/O)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def notify_observation(self, obs: dict) -> None:
|
||||
with self._obs_lock:
|
||||
self._obs_holder["obs"] = obs
|
||||
|
||||
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
|
||||
queue = self._action_queue
|
||||
if queue is None:
|
||||
return None
|
||||
|
||||
# Staleness bound (sync safety): never execute an action whose
|
||||
# source observation is older than max_action_age_s. The lock
|
||||
# makes the check-and-clear atomic with the worker's merge.
|
||||
with self._anchor_lock:
|
||||
anchor = self._chunk_anchor_mono
|
||||
if (
|
||||
anchor is not None
|
||||
and queue.qsize() > 0
|
||||
and time.monotonic() - anchor > self._config.max_action_age_s
|
||||
):
|
||||
logger.warning(
|
||||
"Dropping %d stale actions (older than %.1fs) — applying fallback",
|
||||
queue.qsize(),
|
||||
self._config.max_action_age_s,
|
||||
)
|
||||
self.stats["stale_drops"] += 1
|
||||
queue.clear()
|
||||
self._chunk_anchor_mono = None
|
||||
|
||||
action = queue.get()
|
||||
if action is not None:
|
||||
self._last_action = action
|
||||
return action
|
||||
|
||||
self._set_state(ClientState.STALLED if self.state == ClientState.DEGRADED else self.state)
|
||||
return self._fallback_action()
|
||||
|
||||
def _fallback_action(self) -> torch.Tensor | None:
|
||||
from .factory import FallbackMode
|
||||
|
||||
mode = self._config.fallback
|
||||
if mode == FallbackMode.REPEAT_LAST and self._last_action is not None:
|
||||
self.stats["fallback_ticks"] += 1
|
||||
return self._last_action.clone()
|
||||
if mode == FallbackMode.ZERO:
|
||||
# For velocity-controlled robots "send nothing" means "keep
|
||||
# last velocity" — an explicit zero command is the safe stop.
|
||||
self.stats["fallback_ticks"] += 1
|
||||
return torch.zeros(len(self._ordered_action_keys))
|
||||
return None # HOLD: send_next_action tolerates None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Handshake & control plane
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _handshake(self, initial: bool) -> SessionAckMsg:
|
||||
"""status (pre-flight) + session open; raises on rejection."""
|
||||
cfg = self._config
|
||||
status_data = self._control_query(status_key(self._prefix), b"", timeout=cfg.handshake_timeout_s)
|
||||
if status_data is None:
|
||||
raise ConnectionError(
|
||||
f"No policy server answered status query at {status_key(self._prefix)!r} "
|
||||
f"via {cfg.connect_endpoint!r} (timeout {cfg.handshake_timeout_s}s)"
|
||||
)
|
||||
status = codec.decode_status(status_data)
|
||||
logger.info(
|
||||
"Server status: model=%s@%s policy=%s sessions=%d/%d warmed_up=%s",
|
||||
status.model_repo,
|
||||
status.model_revision,
|
||||
status.policy_type,
|
||||
status.active_sessions,
|
||||
status.max_sessions,
|
||||
status.warmed_up,
|
||||
)
|
||||
|
||||
open_msg = SessionOpenMsg(
|
||||
client_uuid=self._client_uuid,
|
||||
robot_type=self._robot_type,
|
||||
policy_type=getattr(self._policy_config, "type", ""),
|
||||
fps=self._fps,
|
||||
action_names=self._ordered_action_keys,
|
||||
camera_names=self._wire_camera_names(),
|
||||
state_dim=self._state_dim(),
|
||||
schema_version=SCHEMA_VERSION,
|
||||
rtc_enabled=cfg.rtc.enabled,
|
||||
task=self._task,
|
||||
tags=cfg.tags,
|
||||
)
|
||||
ack_data = self._control_query(
|
||||
session_key(self._prefix), codec.encode_session_open(open_msg), timeout=cfg.request_timeout_s
|
||||
)
|
||||
if ack_data is None:
|
||||
raise ConnectionError("Session open query timed out")
|
||||
ack = codec.decode_session_ack(ack_data)
|
||||
if not ack.accepted:
|
||||
raise ConnectionError(f"Policy server rejected the session: {ack.error}")
|
||||
for warning in ack.warnings:
|
||||
logger.warning("Server warning: %s", warning)
|
||||
|
||||
# Hard sync-safety contract: chunk columns map to motors by order.
|
||||
if ack.action_names and ack.action_names != self._ordered_action_keys:
|
||||
raise ValueError(
|
||||
"Action name/order mismatch between server policy and this robot.\n"
|
||||
f" server: {ack.action_names}\n client: {self._ordered_action_keys}"
|
||||
)
|
||||
if not initial and self._session_ack is not None:
|
||||
previous = self._session_ack
|
||||
if (ack.model_repo, ack.model_revision) != (previous.model_repo, previous.model_revision):
|
||||
raise ValueError(
|
||||
f"Server model changed across reconnect "
|
||||
f"({previous.model_repo}@{previous.model_revision} → "
|
||||
f"{ack.model_repo}@{ack.model_revision}) — refusing to execute wrong-model chunks"
|
||||
)
|
||||
return ack
|
||||
|
||||
def _configure_from_ack(self, ack: SessionAckMsg) -> None:
|
||||
rtc_requested = self._config.rtc.enabled
|
||||
rtc_effective = rtc_requested and ack.supports_rtc
|
||||
if rtc_requested and not rtc_effective:
|
||||
logger.warning("RTC downgraded to chunk-append (server does not support RTC)")
|
||||
if self._action_queue is not None and self._action_queue.cfg.enabled != rtc_effective:
|
||||
# The queue's merge semantics (replace vs append) were fixed at
|
||||
# session start; a server whose RTC capability changed across a
|
||||
# reconnect would corrupt them.
|
||||
raise ValueError(
|
||||
"Server RTC capability changed across reconnect "
|
||||
f"(queue merge mode {'replace' if self._action_queue.cfg.enabled else 'append'} "
|
||||
f"vs server RTC={rtc_effective}) — refusing to continue"
|
||||
)
|
||||
self._effective_rtc = RTCConfig(
|
||||
enabled=rtc_effective,
|
||||
prefix_attention_schedule=self._config.rtc.prefix_attention_schedule,
|
||||
max_guidance_weight=self._config.rtc.max_guidance_weight,
|
||||
execution_horizon=ack.rtc_execution_horizon or self._config.rtc.execution_horizon,
|
||||
debug=self._config.rtc.debug,
|
||||
debug_maxlen=self._config.rtc.debug_maxlen,
|
||||
)
|
||||
if self._action_queue is None:
|
||||
self._action_queue = ActionQueue(self._effective_rtc)
|
||||
self._session_ack = ack
|
||||
|
||||
def _control_query(self, key: str, payload: bytes, timeout: float) -> bytes | None:
|
||||
"""One request/reply on the control plane; None on timeout/no-server."""
|
||||
zenoh = import_zenoh()
|
||||
try:
|
||||
replies = self._zenoh.get(
|
||||
key,
|
||||
handler=zenoh.handlers.FifoChannel(4),
|
||||
payload=payload,
|
||||
timeout=timeout,
|
||||
)
|
||||
deadline = time.monotonic() + timeout + 0.5
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
reply = replies.try_recv()
|
||||
except Exception: # zenoh.ZError: channel closed (no queryable / finished)
|
||||
return None
|
||||
if reply is None:
|
||||
time.sleep(0.005)
|
||||
continue
|
||||
if reply.ok is not None:
|
||||
return reply.ok.payload.to_bytes()
|
||||
return None # Reply.err (e.g. b"Timeout")
|
||||
return None
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning("Control query %s failed: %s", key, e)
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Zenoh callbacks (deposit-only)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _on_chunk(self, sample: Any) -> None:
|
||||
try:
|
||||
attachment = sample.attachment
|
||||
if attachment is None:
|
||||
return
|
||||
header = MsgHeader.unpack(attachment.to_bytes())
|
||||
item = (header, sample.payload.to_bytes())
|
||||
try:
|
||||
self._reply_queue.put_nowait(item)
|
||||
except queue_module.Full:
|
||||
# Drop oldest, keep newest.
|
||||
with contextlib.suppress(queue_module.Empty):
|
||||
self._reply_queue.get_nowait()
|
||||
with contextlib.suppress(queue_module.Full):
|
||||
self._reply_queue.put_nowait(item)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("chunk callback error: %s", e)
|
||||
|
||||
def _on_server_liveliness(self, sample: Any) -> None:
|
||||
try:
|
||||
import zenoh
|
||||
|
||||
if sample.kind == zenoh.SampleKind.DELETE:
|
||||
logger.warning("Server liveliness token dropped")
|
||||
self._server_alive.clear()
|
||||
else:
|
||||
self._server_alive.set()
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("liveliness callback error: %s", e)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Network worker
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _worker_loop(self) -> None:
|
||||
consecutive_errors = 0
|
||||
try:
|
||||
while not self._stop_event.is_set():
|
||||
if not self._active.is_set():
|
||||
time.sleep(_IDLE_SLEEP_S)
|
||||
continue
|
||||
try:
|
||||
self._maybe_send_reset()
|
||||
|
||||
if not self._server_alive.is_set():
|
||||
self._enter_reconnect("server liveliness dropped")
|
||||
continue
|
||||
|
||||
queue = self._action_queue
|
||||
if queue is not None and queue.qsize() * self._dt > self._config.buffer_time_s:
|
||||
time.sleep(_IDLE_SLEEP_S)
|
||||
continue
|
||||
|
||||
with self._obs_lock:
|
||||
obs = self._obs_holder.get("obs")
|
||||
if obs is None:
|
||||
time.sleep(_IDLE_SLEEP_S)
|
||||
continue
|
||||
|
||||
self._request_cycle(obs)
|
||||
consecutive_errors = 0
|
||||
except ConnectionError as e:
|
||||
# Raised by reconnect on hard contract violations.
|
||||
raise e
|
||||
except Exception as e: # noqa: BLE001 — transient worker errors retry
|
||||
consecutive_errors += 1
|
||||
logger.error(
|
||||
"Remote inference worker error (%d/%d): %s",
|
||||
consecutive_errors,
|
||||
_MAX_CONSECUTIVE_WORKER_ERRORS,
|
||||
e,
|
||||
)
|
||||
logger.debug(traceback.format_exc())
|
||||
if consecutive_errors >= _MAX_CONSECUTIVE_WORKER_ERRORS:
|
||||
raise
|
||||
time.sleep(0.5)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("Fatal error in remote inference worker: %s", e)
|
||||
logger.error(traceback.format_exc())
|
||||
self._go_dead(str(e))
|
||||
|
||||
def _request_cycle(self, obs: dict) -> None:
|
||||
"""Publish one observation and merge its chunk (one-in-flight)."""
|
||||
cfg = self._config
|
||||
queue = self._action_queue
|
||||
|
||||
obs_frame = build_dataset_frame(self._hw_features, obs, prefix=OBS_STR)
|
||||
if self._rename_map:
|
||||
obs_frame = {self._rename_map.get(k, k): v for k, v in obs_frame.items()}
|
||||
|
||||
state = obs_frame.pop(OBS_STATE, None)
|
||||
images = {k: v for k, v in obs_frame.items() if isinstance(v, np.ndarray) and v.ndim == 3}
|
||||
|
||||
with self._state_lock:
|
||||
self._seq_id += 1
|
||||
seq_id = self._seq_id
|
||||
episode_id = self._episode_id
|
||||
epoch = self._epoch
|
||||
|
||||
# Snapshot RTC state (must precede the publish; merge validates
|
||||
# against idx_before).
|
||||
idx_before = queue.get_action_index()
|
||||
prefix_model: np.ndarray | None = None
|
||||
prefix_robot: np.ndarray | None = None
|
||||
delay_steps = 0
|
||||
if self._effective_rtc.enabled:
|
||||
horizon = self._effective_rtc.execution_horizon
|
||||
left_over = queue.get_left_over()
|
||||
if left_over is not None and left_over.numel():
|
||||
prefix_model = left_over[:horizon].to(torch.float32).numpy()
|
||||
processed_left_over = queue.get_processed_left_over()
|
||||
if processed_left_over is not None and processed_left_over.numel():
|
||||
prefix_robot = processed_left_over[:horizon].to(torch.float32).numpy()
|
||||
max_latency = self._latency_tracker.max() if len(self._latency_tracker) else 0.0
|
||||
delay_steps = math.ceil(max_latency / self._dt) if max_latency else 0
|
||||
|
||||
# A reset/reconnect between the counter snapshot and the prefix
|
||||
# snapshot would pair a new episode id with old-episode prefixes
|
||||
# — skip the cycle instead.
|
||||
with self._state_lock:
|
||||
if (self._episode_id, self._epoch) != (episode_id, epoch):
|
||||
return
|
||||
|
||||
header = MsgHeader(
|
||||
schema_version=SCHEMA_VERSION,
|
||||
msg_type=MSG_TYPE_OBS,
|
||||
seq_id=seq_id,
|
||||
episode_id=episode_id,
|
||||
client_mono_ns=time.monotonic_ns(),
|
||||
session_epoch=epoch,
|
||||
)
|
||||
msg = ObservationMsg(
|
||||
state=state,
|
||||
images=images,
|
||||
task=self._task,
|
||||
inference_delay_steps=delay_steps,
|
||||
prefix_model=prefix_model,
|
||||
prefix_robot=prefix_robot,
|
||||
episode_start=(queue.qsize() == 0 and idx_before == 0 and self._chunk_anchor_mono is None),
|
||||
jpeg_quality=cfg.jpeg_quality,
|
||||
)
|
||||
|
||||
t_send = time.perf_counter()
|
||||
self._obs_publisher.put(codec.encode_observation(msg), attachment=header.pack())
|
||||
self.stats["requests"] += 1
|
||||
|
||||
reply = self._await_chunk(seq_id, episode_id, epoch, timeout=cfg.request_timeout_s)
|
||||
if reply is None:
|
||||
self.stats["timeouts"] += 1
|
||||
self._on_request_timeout()
|
||||
return
|
||||
|
||||
chunk = codec.decode_action_chunk(reply)
|
||||
if chunk.chunk_model is None or chunk.chunk_robot is None:
|
||||
# A persistently malformed server must still trip the
|
||||
# degradation ladder, not stall in nominal state.
|
||||
logger.warning("Chunk for seq=%d had empty tensors — dropping", seq_id)
|
||||
self.stats["timeouts"] += 1
|
||||
self._on_request_timeout()
|
||||
return
|
||||
|
||||
latency = time.perf_counter() - t_send
|
||||
real_delay = math.ceil(latency / self._dt)
|
||||
with self._anchor_lock:
|
||||
# reset() takes the same lock before clearing: either the
|
||||
# reset fully precedes this merge (episode changed → drop the
|
||||
# stale chunk) or the merge completes first (and the reset
|
||||
# then clears it) — a stale chunk can never survive a reset.
|
||||
with self._state_lock:
|
||||
if (self._episode_id, self._epoch) != (episode_id, epoch):
|
||||
logger.debug("Dropping chunk seq=%d: episode/epoch changed mid-flight", seq_id)
|
||||
return
|
||||
queue.merge(
|
||||
torch.from_numpy(np.ascontiguousarray(chunk.chunk_model)),
|
||||
torch.from_numpy(np.ascontiguousarray(chunk.chunk_robot)),
|
||||
real_delay,
|
||||
idx_before,
|
||||
)
|
||||
self._chunk_anchor_mono = time.monotonic() - latency # ≈ when the source obs was sent
|
||||
self._latency_tracker.add(latency)
|
||||
self._last_chunk_mono = time.monotonic()
|
||||
self._offline_since_mono = None
|
||||
self.stats["merges"] += 1
|
||||
self._set_state(ClientState.STREAMING)
|
||||
logger.debug(
|
||||
"merge: seq=%d latency=%.0fms delay=%d queue=%d server(inf=%.0fms wait=%.0fms load=%.2f)",
|
||||
seq_id,
|
||||
latency * 1e3,
|
||||
real_delay,
|
||||
queue.qsize(),
|
||||
chunk.inference_ms,
|
||||
chunk.queue_wait_ms,
|
||||
chunk.server_load,
|
||||
)
|
||||
|
||||
def _await_chunk(self, seq_id: int, episode_id: int, epoch: int, timeout: float) -> bytes | None:
|
||||
"""Wait for the chunk answering the latest outstanding request.
|
||||
|
||||
Stale replies (older seq/episode/epoch) are dropped — under
|
||||
one-in-flight a late chunk can only ever answer an older request.
|
||||
"""
|
||||
deadline = time.monotonic() + timeout
|
||||
while not self._stop_event.is_set():
|
||||
remaining = deadline - time.monotonic()
|
||||
if remaining <= 0:
|
||||
return None
|
||||
try:
|
||||
header, payload = self._reply_queue.get(timeout=min(remaining, 0.1))
|
||||
except queue_module.Empty:
|
||||
continue
|
||||
if header.session_epoch != epoch or header.episode_id != episode_id:
|
||||
continue # stale epoch/episode (reset or reconnect happened)
|
||||
if header.seq_id != seq_id:
|
||||
continue # late reply to a superseded request
|
||||
return payload
|
||||
return None
|
||||
|
||||
def _maybe_send_reset(self) -> None:
|
||||
with self._state_lock:
|
||||
pending, episode_id = self._pending_reset, self._episode_id
|
||||
self._pending_reset = False
|
||||
if pending and self._zenoh is not None:
|
||||
ack_data = self._control_query(
|
||||
reset_key(self._prefix, self._client_uuid),
|
||||
codec.encode_reset(ResetMsg(client_uuid=self._client_uuid, episode_id=episode_id)),
|
||||
timeout=1.0,
|
||||
)
|
||||
if ack_data is None:
|
||||
# Harmless: the server is stateless per request and the next
|
||||
# observation header announces the new episode anyway.
|
||||
logger.warning("Reset ack not received (continuing — header carries the episode bump)")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Degradation / reconnect / death
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _on_request_timeout(self) -> None:
|
||||
if self._stop_event.is_set():
|
||||
# _await_chunk aborted by a normal stop(), not by the network.
|
||||
return
|
||||
now = time.monotonic()
|
||||
if self._offline_since_mono is None:
|
||||
self._offline_since_mono = now
|
||||
offline_for = now - self._offline_since_mono
|
||||
|
||||
queue = self._action_queue
|
||||
if queue is not None and queue.qsize() > 0:
|
||||
self._set_state(ClientState.DEGRADED)
|
||||
else:
|
||||
self._set_state(ClientState.STALLED)
|
||||
|
||||
last = self._last_chunk_mono or 0.0
|
||||
if (now - last if last else offline_for) >= self._config.degraded_after_s:
|
||||
logger.warning(
|
||||
"No chunk for %.1fs (queue=%d) — %s",
|
||||
offline_for,
|
||||
queue.qsize() if queue else 0,
|
||||
self.state,
|
||||
)
|
||||
if offline_for > self._config.max_offline_s:
|
||||
self._go_dead(f"offline for {offline_for:.0f}s (> max_offline_s)")
|
||||
return
|
||||
if not self._server_alive.is_set() or offline_for >= 2 * self._config.request_timeout_s:
|
||||
self._enter_reconnect(f"request timeouts for {offline_for:.0f}s")
|
||||
|
||||
def _enter_reconnect(self, reason: str) -> None:
|
||||
"""Backoff + re-handshake loop. Hard contract violations → DEAD."""
|
||||
self._set_state(ClientState.RECONNECTING)
|
||||
logger.warning("Reconnecting: %s", reason)
|
||||
if self._offline_since_mono is None:
|
||||
self._offline_since_mono = time.monotonic()
|
||||
backoff = self._config.reconnect_initial_backoff_s
|
||||
while not self._stop_event.is_set():
|
||||
if not self._active.is_set():
|
||||
# Paused (e.g. DAgger human correction): keep trying to
|
||||
# reconnect, but a pause must never burn the offline budget
|
||||
# into a mid-correction shutdown.
|
||||
self._offline_since_mono = time.monotonic()
|
||||
offline_for = time.monotonic() - self._offline_since_mono
|
||||
if offline_for > self._config.max_offline_s:
|
||||
self._go_dead(f"offline for {offline_for:.0f}s (> max_offline_s)")
|
||||
return
|
||||
self._stop_event.wait(timeout=backoff)
|
||||
if self._stop_event.is_set():
|
||||
return
|
||||
backoff = min(backoff * 2, self._config.reconnect_max_backoff_s)
|
||||
try:
|
||||
with self._state_lock:
|
||||
self._epoch += 1
|
||||
ack = self._handshake(initial=False)
|
||||
self._configure_from_ack(ack)
|
||||
except ValueError as e:
|
||||
# Capability/schema/model mismatch: never execute wrong-model chunks.
|
||||
self._go_dead(str(e))
|
||||
return
|
||||
except Exception as e: # noqa: BLE001 — server still down, keep trying
|
||||
logger.info("Re-handshake failed (%s) — retrying in %.1fs", e, backoff)
|
||||
continue
|
||||
# A successful handshake is proof of life even if the liveliness
|
||||
# PUT was missed or hasn't been delivered yet.
|
||||
self._server_alive.set()
|
||||
# The offline budget is only reset by the next successful merge:
|
||||
# a server that handshakes but never delivers chunks must still
|
||||
# run out of budget and go DEAD.
|
||||
self.stats["reconnects"] += 1
|
||||
self._set_state(ClientState.STREAMING)
|
||||
logger.info("Reconnected (epoch=%d, session=%s)", self._epoch, ack.session_id)
|
||||
return
|
||||
|
||||
def _go_dead(self, reason: str) -> None:
|
||||
if self._dead.is_set():
|
||||
return
|
||||
logger.error("Remote inference DEAD: %s", reason)
|
||||
self._set_state(ClientState.DEAD)
|
||||
self._dead.set()
|
||||
if self._global_shutdown_event is not None:
|
||||
self._global_shutdown_event.set()
|
||||
|
||||
def _set_state(self, new_state: str) -> None:
|
||||
if new_state != self.state:
|
||||
logger.info("Client state: %s → %s", self.state, new_state)
|
||||
self.state = new_state
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Feature helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _wire_camera_names(self) -> list[str]:
|
||||
names = [
|
||||
key for key, feature in self._hw_features.items() if feature.get("dtype") in ("image", "video")
|
||||
]
|
||||
return [self._rename_map.get(name, name) for name in names]
|
||||
|
||||
def _state_dim(self) -> int:
|
||||
state_feature = self._hw_features.get(OBS_STATE)
|
||||
if state_feature and state_feature.get("names"):
|
||||
return len(state_feature["names"])
|
||||
return 0
|
||||
@@ -17,7 +17,6 @@
|
||||
from .base import BaseStrategy
|
||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||
from .dagger import DAggerEvents, DAggerPhase, DAggerStrategy
|
||||
from .episodic import EpisodicStrategy
|
||||
from .factory import create_strategy
|
||||
from .highlight import HighlightStrategy
|
||||
from .sentry import SentryStrategy
|
||||
@@ -28,7 +27,6 @@ __all__ = [
|
||||
"DAggerPhase",
|
||||
"DAggerStrategy",
|
||||
"HighlightStrategy",
|
||||
"EpisodicStrategy",
|
||||
"RolloutStrategy",
|
||||
"SentryStrategy",
|
||||
"create_strategy",
|
||||
|
||||
@@ -56,14 +56,10 @@ from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.control_utils import (
|
||||
follower_smooth_move_to,
|
||||
is_headless,
|
||||
teleop_smooth_move_to,
|
||||
teleop_supports_feedback,
|
||||
)
|
||||
from lerobot.common.control_utils import is_headless
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
from lerobot.teleoperators import Teleoperator
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.import_utils import _pynput_available
|
||||
@@ -73,6 +69,7 @@ from lerobot.utils.utils import log_say
|
||||
|
||||
from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from ..robot_wrapper import ThreadSafeRobot
|
||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||
|
||||
PYNPUT_AVAILABLE = _pynput_available
|
||||
@@ -174,6 +171,64 @@ class DAggerEvents:
|
||||
self.upload_requested.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Teleoperator helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _teleop_supports_feedback(teleop: Teleoperator) -> bool:
|
||||
"""Return True when the teleop can receive position feedback (is actuated).
|
||||
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
|
||||
"""
|
||||
return (
|
||||
bool(teleop.feedback_features)
|
||||
and hasattr(teleop, "disable_torque")
|
||||
and hasattr(teleop, "enable_torque")
|
||||
)
|
||||
|
||||
|
||||
def _teleop_smooth_move_to(
|
||||
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 30
|
||||
) -> None:
|
||||
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
|
||||
|
||||
Requires the teleoperator to support feedback
|
||||
(i.e. have non-empty ``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
|
||||
|
||||
TODO(Maxime): This blocks up to ``duration_s`` seconds, during this time
|
||||
the follower robot doesn't receive new actions, this could be an issue on LeKiwi.
|
||||
"""
|
||||
teleop.enable_torque()
|
||||
current = teleop.get_action()
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {
|
||||
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
|
||||
}
|
||||
teleop.send_feedback(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
def _follower_smooth_move_to(
|
||||
robot: ThreadSafeRobot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
|
||||
) -> None:
|
||||
"""Smoothly move the follower robot from ``current`` to ``target`` action.
|
||||
|
||||
Used when the teleop is non-actuated: instead of driving the leader arm
|
||||
to the follower, we bring the follower to the teleop's current pose.
|
||||
Both ``current`` and ``target`` must be in robot-action key space.
|
||||
"""
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
|
||||
robot.send_action(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Input device handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -701,31 +756,31 @@ class DAggerStrategy(RolloutStrategy):
|
||||
logger.info("Pausing engine - robot holds position")
|
||||
engine.pause()
|
||||
|
||||
if teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
if _teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
# TODO(Maxime): prev_action is in robot action key space (output of robot_action_processor).
|
||||
# send_feedback expects teleop feedback key space. For homogeneous setups (e.g. SO-101
|
||||
# leader + SO-101 follower) the keys are identical so this works. If the processor pipeline
|
||||
# does non-trivial key renaming (e.g. a rename_map on action keys), the interpolation in
|
||||
# teleop_smooth_move_to silently no-ops and the arm doesn't move.
|
||||
# _teleop_smooth_move_to silently no-ops and the arm doesn't move.
|
||||
logger.info("Smooth handover: moving leader arm to follower position")
|
||||
teleop_smooth_move_to(teleop, prev_action)
|
||||
_teleop_smooth_move_to(teleop, prev_action)
|
||||
|
||||
elif old_phase == DAggerPhase.PAUSED and new_phase == DAggerPhase.CORRECTING:
|
||||
logger.info("Entering correction mode - human teleop control")
|
||||
if not teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
if not _teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
logger.info("Smooth handover: sliding follower to teleop position")
|
||||
obs = robot.get_observation()
|
||||
teleop_action = teleop.get_action()
|
||||
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
|
||||
target = ctx.processors.robot_action_processor((processed, obs))
|
||||
follower_smooth_move_to(robot, prev_action, target)
|
||||
_follower_smooth_move_to(robot, prev_action, target)
|
||||
|
||||
# unlock the teleop for human control
|
||||
if teleop_supports_feedback(teleop):
|
||||
if _teleop_supports_feedback(teleop):
|
||||
teleop.disable_torque()
|
||||
|
||||
elif old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
|
||||
if teleop_supports_feedback(teleop):
|
||||
if _teleop_supports_feedback(teleop):
|
||||
teleop.enable_torque()
|
||||
|
||||
elif new_phase == DAggerPhase.AUTONOMOUS:
|
||||
@@ -735,7 +790,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
engine.resume()
|
||||
|
||||
# release teleop before resuming the policy
|
||||
if teleop_supports_feedback(teleop):
|
||||
if _teleop_supports_feedback(teleop):
|
||||
teleop.disable_torque()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -1,335 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Episodic rollout strategy: mirrors the behavior of ``lerobot-record``.
|
||||
|
||||
- Policy drives the robot during each recording episode.
|
||||
- An optional teleoperator can drive the robot during reset phases so the
|
||||
operator can bring the environment back to its starting configuration.
|
||||
If no teleop is connected the robot stays in its current position.
|
||||
- Keyboard controls:
|
||||
|
||||
Right arrow — end the current episode or reset phase early
|
||||
Left arrow — discard the current episode and re-record it
|
||||
Escape — stop the recording session
|
||||
|
||||
Dataset naming follows the rollout convention: repo names must start with ``rollout_``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
|
||||
from lerobot.common.control_utils import (
|
||||
follower_smooth_move_to,
|
||||
init_keyboard_listener,
|
||||
is_headless,
|
||||
teleop_smooth_move_to,
|
||||
teleop_supports_feedback,
|
||||
)
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import log_rerun_data
|
||||
|
||||
from ..configs import EpisodicStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from .core import RolloutStrategy, safe_push_to_hub, send_next_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EpisodicStrategy(RolloutStrategy):
|
||||
"""Policy-driven multi-episode recording, mirrors the behavior of ``lerobot-record``.
|
||||
|
||||
Each recording episode runs the policy for maximum ``dataset.episode_time_s``
|
||||
seconds, recording every frame. A reset phase of ``dataset.reset_time_s``
|
||||
follows every episode (except the last) so the operator can manually
|
||||
reset the environment. During the reset phase, an optional teleoperator
|
||||
drives the robot; if none is present the robot returns to its initial joint positions captured at startup.
|
||||
|
||||
The policy state (hidden state, RTC queue, interpolator) is reset at
|
||||
the start of each recording episode.
|
||||
|
||||
Keyboard events:
|
||||
right arrow → end current episode or reset phase early
|
||||
left arrow → discard & re-record current episode
|
||||
ESC → stop the session
|
||||
"""
|
||||
|
||||
config: EpisodicStrategyConfig
|
||||
|
||||
def __init__(self, config: EpisodicStrategyConfig) -> None:
|
||||
super().__init__(config)
|
||||
self._listener = None
|
||||
self._events: dict | None = None
|
||||
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
"""Start the inference engine and attach the keyboard listener."""
|
||||
self._init_engine(ctx)
|
||||
self._listener, self._events = init_keyboard_listener()
|
||||
logger.info("Episodic strategy ready")
|
||||
|
||||
def run(self, ctx: RolloutContext) -> None:
|
||||
"""Main multi-episode recording loop."""
|
||||
cfg = ctx.runtime.cfg
|
||||
dataset_cfg = cfg.dataset
|
||||
robot = ctx.hardware.robot_wrapper
|
||||
teleop = ctx.hardware.teleop
|
||||
dataset = ctx.data.dataset
|
||||
events = self._events
|
||||
features = ctx.data.dataset_features
|
||||
|
||||
fps = cfg.fps
|
||||
episode_time_s = dataset_cfg.episode_time_s
|
||||
reset_time_s = dataset_cfg.reset_time_s
|
||||
num_episodes = dataset_cfg.num_episodes
|
||||
single_task = dataset_cfg.single_task or cfg.task
|
||||
play_sounds = cfg.play_sounds
|
||||
|
||||
display_compressed = (
|
||||
True
|
||||
if (cfg.display_data and cfg.display_ip is not None and cfg.display_port is not None)
|
||||
else cfg.display_compressed_images
|
||||
)
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
try:
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < num_episodes and not events["stop_recording"]:
|
||||
if ctx.runtime.shutdown_event.is_set():
|
||||
break
|
||||
|
||||
# Reset policy state at episode start (discard leftover hidden state / queue)
|
||||
self._engine.reset()
|
||||
self._interpolator.reset()
|
||||
self._engine.resume()
|
||||
|
||||
log_say(f"Recording episode {dataset.num_episodes}", play_sounds)
|
||||
self._policy_loop(
|
||||
ctx=ctx,
|
||||
robot=robot,
|
||||
events=events,
|
||||
features=features,
|
||||
fps=fps,
|
||||
control_time_s=episode_time_s,
|
||||
dataset=dataset,
|
||||
single_task=single_task,
|
||||
)
|
||||
|
||||
# Reset phase, skip after the last episode (but run when re-recording)
|
||||
if not events["stop_recording"] and (
|
||||
recorded_episodes < num_episodes - 1 or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment", play_sounds)
|
||||
|
||||
if teleop:
|
||||
# Smooth handover so the transition to teleop control is jerk-free.
|
||||
# For actuated teleops: drive the leader arm to the follower's current
|
||||
# position so the operator takes over without fighting the arm.
|
||||
# For non-actuated teleops: slide the follower to the teleop's current
|
||||
# pose instead, since the leader cannot be driven.
|
||||
obs = robot.get_observation()
|
||||
current_pos = {k: v for k, v in obs.items() if k.endswith(".pos")}
|
||||
if (
|
||||
teleop_supports_feedback(teleop)
|
||||
and self.config.smooth_leader_to_follower_handover
|
||||
):
|
||||
logger.info("Smooth handover: moving leader arm to follower position")
|
||||
teleop_smooth_move_to(teleop, current_pos, duration_s=2)
|
||||
teleop.disable_torque()
|
||||
else:
|
||||
logger.info("Smooth handover: sliding follower to teleop position")
|
||||
teleop_action = teleop.get_action()
|
||||
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
|
||||
target = ctx.processors.robot_action_processor((processed, obs))
|
||||
follower_smooth_move_to(robot, current_pos, target, duration_s=1)
|
||||
|
||||
elif self.config.reset_to_initial_position:
|
||||
# No teleop: return the robot to its startup position.
|
||||
self._return_to_initial_position(hw=ctx.hardware, duration_s=1)
|
||||
|
||||
self._reset_loop(
|
||||
ctx=ctx,
|
||||
robot=robot,
|
||||
teleop=teleop,
|
||||
events=events,
|
||||
fps=fps,
|
||||
control_time_s=reset_time_s,
|
||||
display_data=cfg.display_data,
|
||||
display_compressed=display_compressed,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode", play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
|
||||
# returns to its initial joint positions captured at startup
|
||||
if not teleop and self.config.reset_to_initial_position:
|
||||
self._return_to_initial_position(hw=ctx.hardware, duration_s=1)
|
||||
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
finally:
|
||||
# Save any frames buffered in the current episode so an unexpected
|
||||
# exception or KeyboardInterrupt does not silently drop recorded data.
|
||||
# suppress: save_episode raises if the buffer is empty (nothing to lose).
|
||||
logger.info("Episodic control loop ended — saving any in-progress episode")
|
||||
with contextlib.suppress(Exception):
|
||||
dataset.save_episode()
|
||||
|
||||
def _policy_loop(
|
||||
self,
|
||||
ctx: RolloutContext,
|
||||
robot,
|
||||
events: dict,
|
||||
features: dict,
|
||||
fps: float,
|
||||
control_time_s: float,
|
||||
dataset,
|
||||
single_task: str,
|
||||
) -> None:
|
||||
"""Policy-driven recording loop for a single episode."""
|
||||
interpolator = self._interpolator
|
||||
control_interval = interpolator.get_control_interval(fps)
|
||||
|
||||
timestamp = 0.0
|
||||
start_t = time.perf_counter()
|
||||
|
||||
while timestamp < control_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
if ctx.runtime.shutdown_event.is_set():
|
||||
break
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_processed = self._process_observation_and_notify(ctx.processors, obs)
|
||||
|
||||
if self._handle_warmup(ctx.runtime.cfg.use_torch_compile, loop_start, control_interval):
|
||||
continue
|
||||
|
||||
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
|
||||
|
||||
if action_dict is not None:
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
|
||||
dataset.add_frame({**obs_frame, **action_frame, "task": single_task})
|
||||
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
sleep_t = control_interval - dt
|
||||
if sleep_t < 0:
|
||||
logger.warning(
|
||||
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({fps} Hz). "
|
||||
"Dataset frames might be dropped and robot control might be unstable. "
|
||||
"Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long "
|
||||
"3) CPU starvation"
|
||||
)
|
||||
precise_sleep(max(sleep_t, 0.0))
|
||||
timestamp = time.perf_counter() - start_t
|
||||
|
||||
def _reset_loop(
|
||||
self,
|
||||
ctx: RolloutContext,
|
||||
robot,
|
||||
teleop,
|
||||
events: dict,
|
||||
fps: float,
|
||||
control_time_s: float,
|
||||
display_data: bool,
|
||||
display_compressed: bool,
|
||||
) -> None:
|
||||
"""Reset-phase loop: teleop drives the robot if available, no recording."""
|
||||
processors = ctx.processors
|
||||
control_interval = 1.0 / fps
|
||||
|
||||
timestamp = 0.0
|
||||
start_t = time.perf_counter()
|
||||
|
||||
while timestamp < control_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
if ctx.runtime.shutdown_event.is_set():
|
||||
break
|
||||
|
||||
obs = robot.get_observation()
|
||||
|
||||
if teleop is not None:
|
||||
act = teleop.get_action()
|
||||
act_teleop = processors.teleop_action_processor((act, obs))
|
||||
robot_action = processors.robot_action_processor((act_teleop, obs))
|
||||
robot.send_action(robot_action)
|
||||
|
||||
if display_data:
|
||||
obs_processed = processors.robot_observation_processor(obs)
|
||||
log_rerun_data(
|
||||
observation=obs_processed,
|
||||
action=act_teleop,
|
||||
compress_images=display_compressed,
|
||||
)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
sleep_t = control_interval - dt
|
||||
precise_sleep(max(sleep_t, 0.0))
|
||||
timestamp = time.perf_counter() - start_t
|
||||
|
||||
def teardown(self, ctx: RolloutContext) -> None:
|
||||
"""Finalise dataset, stop listener, push to hub, and disconnect hardware."""
|
||||
cfg = ctx.runtime.cfg
|
||||
play_sounds = cfg.play_sounds
|
||||
|
||||
log_say("Stop recording", play_sounds, blocking=True)
|
||||
|
||||
if not is_headless() and self._listener is not None:
|
||||
self._listener.stop()
|
||||
|
||||
if ctx.data.dataset is not None:
|
||||
logger.info("Finalizing dataset...")
|
||||
ctx.data.dataset.finalize()
|
||||
|
||||
if (
|
||||
cfg.dataset is not None
|
||||
and cfg.dataset.push_to_hub
|
||||
and ctx.data.dataset is not None
|
||||
and safe_push_to_hub(
|
||||
ctx.data.dataset,
|
||||
tags=cfg.dataset.tags,
|
||||
private=cfg.dataset.private,
|
||||
)
|
||||
):
|
||||
logger.info("Dataset uploaded to hub")
|
||||
log_say("Dataset uploaded to hub", play_sounds)
|
||||
|
||||
self._teardown_hardware(
|
||||
ctx.hardware,
|
||||
return_to_initial_position=cfg.return_to_initial_position,
|
||||
)
|
||||
log_say("Exiting", play_sounds)
|
||||
logger.info("Episodic strategy teardown complete")
|
||||
@@ -21,7 +21,6 @@ from typing import TYPE_CHECKING
|
||||
from .base import BaseStrategy
|
||||
from .core import RolloutStrategy
|
||||
from .dagger import DAggerStrategy
|
||||
from .episodic import EpisodicStrategy
|
||||
from .highlight import HighlightStrategy
|
||||
from .sentry import SentryStrategy
|
||||
|
||||
@@ -43,8 +42,4 @@ def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
|
||||
return HighlightStrategy(config)
|
||||
if config.type == "dagger":
|
||||
return DAggerStrategy(config)
|
||||
if config.type == "episodic":
|
||||
return EpisodicStrategy(config)
|
||||
raise ValueError(
|
||||
f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger, episodic"
|
||||
)
|
||||
raise ValueError(f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger")
|
||||
|
||||
@@ -1,93 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Serve a pretrained policy to remote ``lerobot-rollout`` clients over Zenoh.
|
||||
|
||||
One process = one pre-warmed (model, revision, dtype, device) on one GPU.
|
||||
Robots connect with ``lerobot-rollout --inference.type=remote``.
|
||||
|
||||
Usage examples
|
||||
--------------
|
||||
|
||||
Serve a model from a YAML manifest::
|
||||
|
||||
lerobot-policy-server --manifest server.yaml
|
||||
|
||||
Minimal manifest::
|
||||
|
||||
model:
|
||||
repo_or_path: lerobot/pi0_towels
|
||||
device: cuda
|
||||
default_task: "fold the towel"
|
||||
max_sessions: 5
|
||||
zenoh:
|
||||
mode: client
|
||||
connect_endpoints: ["tcp/router.gpu-cluster.internal:7447"]
|
||||
|
||||
Everything in the manifest can also be set directly on the CLI::
|
||||
|
||||
lerobot-policy-server \\
|
||||
--model.repo_or_path=lerobot/pi0_towels \\
|
||||
--default_task="fold the towel" \\
|
||||
--zenoh.mode=peer --zenoh.listen_endpoints='["tcp/0.0.0.0:7447"]'
|
||||
|
||||
SIGTERM/SIGINT drains gracefully: the server drops its liveliness token
|
||||
(clients ride their action buffers through the drain), finishes the
|
||||
in-flight inference, and exits.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.policy_server.manifest import PolicyServerManifest
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def policy_server(manifest: PolicyServerManifest):
|
||||
init_logging()
|
||||
from lerobot.policy_server.server import PolicyServer
|
||||
|
||||
server = PolicyServer(manifest)
|
||||
server.load_policy()
|
||||
|
||||
def _drain(signum, frame): # noqa: ARG001
|
||||
logger.info("Signal %s received — draining", signum)
|
||||
server.stop()
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGTERM, _drain)
|
||||
server.start()
|
||||
server.serve_forever()
|
||||
|
||||
|
||||
def main():
|
||||
# `--manifest path.yaml` is sugar for draccus's `--config_path`.
|
||||
sys.argv = [
|
||||
arg.replace("--manifest=", "--config_path=") if arg.startswith("--manifest=") else arg
|
||||
for arg in sys.argv
|
||||
]
|
||||
if "--manifest" in sys.argv:
|
||||
sys.argv[sys.argv.index("--manifest")] = "--config_path"
|
||||
policy_server()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -25,13 +25,11 @@ Strategies
|
||||
--strategy.type=sentry Continuous recording with auto-upload
|
||||
--strategy.type=highlight Ring buffer + keystroke save
|
||||
--strategy.type=dagger Human-in-the-loop (DAgger / RaC)
|
||||
--strategy.type=episodic Episode-oriented recording with reset phases
|
||||
|
||||
Inference backends
|
||||
------------------
|
||||
--inference.type=sync One policy call per control tick (default)
|
||||
--inference.type=rtc Real-Time Chunking for slow VLA models
|
||||
--inference.type=remote Network inference via lerobot-policy-server (weightless edge)
|
||||
|
||||
Usage examples
|
||||
--------------
|
||||
@@ -113,18 +111,6 @@ Usage examples
|
||||
--display_data=true \\
|
||||
--use_torch_compile=true
|
||||
|
||||
# Episodic mode — episode-oriented recording with reset phases
|
||||
lerobot-rollout \\
|
||||
--strategy.type=episodic \\
|
||||
--policy.path=user/my_policy \\
|
||||
--robot.type=so100_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--teleop.type=so100_leader \\
|
||||
--teleop.port=/dev/ttyACM1 \\
|
||||
--dataset.repo_id=user/rollout_episodic_data \\
|
||||
--dataset.num_episodes=20 \\
|
||||
--dataset.single_task="Grab the cube"
|
||||
|
||||
# Resume a previous sentry recording session
|
||||
lerobot-rollout \\
|
||||
--strategy.type=sentry \\
|
||||
@@ -146,19 +132,6 @@ Usage examples
|
||||
--dataset.camera_encoder.vcodec=h264 \\
|
||||
--dataset.camera_encoder.preset=fast \\
|
||||
--dataset.camera_encoder.extra_options={"tune": "film", "profile:v": "high", "bf": 2}
|
||||
|
||||
# Sentry mode — remote inference against a lerobot-policy-server (weightless edge)
|
||||
lerobot-rollout \\
|
||||
--strategy.type=sentry \\
|
||||
--strategy.upload_every_n_episodes=5 \\
|
||||
--policy.path=lerobot/pi0_base \\
|
||||
--inference.type=remote \\
|
||||
--inference.connect_endpoint=tcp/router.gpu-cluster.internal:7447 \\
|
||||
--inference.rtc.execution_horizon=10 \\
|
||||
--robot.type=so100_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--dataset.repo_id=user/rollout_sentry_data \\
|
||||
--dataset.single_task="patrol" --duration=3600
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
@@ -99,9 +99,6 @@ def update_policy(
|
||||
start_time = time.perf_counter()
|
||||
policy.train()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
# Compute sample weights if a weighter is provided
|
||||
sample_weights = None
|
||||
weight_stats = None
|
||||
@@ -161,8 +158,6 @@ def update_policy(
|
||||
train_metrics.grad_norm = grad_norm.item()
|
||||
train_metrics.lr = optimizer.param_groups[0]["lr"]
|
||||
train_metrics.update_s = time.perf_counter() - start_time
|
||||
if torch.cuda.is_available():
|
||||
train_metrics.gpu_mem_gb = torch.cuda.max_memory_allocated() / (1024**3)
|
||||
return train_metrics, output_dict
|
||||
|
||||
|
||||
@@ -237,18 +232,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
# Dataset loading synchronization: each node's local main process downloads first to avoid
|
||||
# race conditions (the global main process only exists on node 0, so gating on it would let
|
||||
# all ranks of the other nodes download and build the Arrow cache concurrently).
|
||||
if accelerator.is_local_main_process:
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
# Dataset loading synchronization: main process downloads first to avoid race conditions
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Now all other processes can safely load the dataset from the local cache
|
||||
if not accelerator.is_local_main_process:
|
||||
# Now all other processes can safely load the dataset
|
||||
if not is_main_process:
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||
@@ -300,8 +292,19 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
active_cfg = cfg.trainable_config
|
||||
processor_pretrained_path = active_cfg.pretrained_path
|
||||
if (
|
||||
getattr(active_cfg, "use_relative_actions", False)
|
||||
and processor_pretrained_path is not None
|
||||
and not cfg.resume
|
||||
):
|
||||
logging.warning(
|
||||
"use_relative_actions=true with pretrained processors can skip relative transforms if "
|
||||
"the checkpoint processors do not define them. Building processors from current policy config."
|
||||
)
|
||||
processor_pretrained_path = None
|
||||
|
||||
processor_kwargs = {}
|
||||
postprocessor_kwargs = {}
|
||||
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
|
||||
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||
|
||||
@@ -309,31 +312,24 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
processor_kwargs["dataset_meta"] = dataset.meta
|
||||
|
||||
if not cfg.is_reward_model_training and processor_pretrained_path is not None:
|
||||
preprocessor_overrides = {
|
||||
processor_kwargs["preprocessor_overrides"] = {
|
||||
"device_processor": {"device": device.type},
|
||||
"normalizer_processor": {
|
||||
"stats": dataset.meta.stats,
|
||||
"features": {**policy.config.input_features, **policy.config.output_features},
|
||||
"norm_map": policy.config.normalization_mapping,
|
||||
},
|
||||
"rename_observations_processor": {"rename_map": cfg.rename_map},
|
||||
}
|
||||
postprocessor_overrides = {
|
||||
processor_kwargs["preprocessor_overrides"]["rename_observations_processor"] = {
|
||||
"rename_map": cfg.rename_map
|
||||
}
|
||||
postprocessor_kwargs["postprocessor_overrides"] = {
|
||||
"unnormalizer_processor": {
|
||||
"stats": dataset.meta.stats,
|
||||
"features": policy.config.output_features,
|
||||
"norm_map": policy.config.normalization_mapping,
|
||||
},
|
||||
}
|
||||
if getattr(active_cfg, "use_relative_actions", False):
|
||||
preprocessor_overrides["relative_actions_processor"] = {
|
||||
"enabled": True,
|
||||
"exclude_joints": getattr(active_cfg, "relative_exclude_joints", []),
|
||||
"action_names": getattr(active_cfg, "action_feature_names", None),
|
||||
}
|
||||
postprocessor_overrides["absolute_actions_processor"] = {"enabled": True}
|
||||
processor_kwargs["preprocessor_overrides"] = preprocessor_overrides
|
||||
processor_kwargs["postprocessor_overrides"] = postprocessor_overrides
|
||||
|
||||
if cfg.is_reward_model_training:
|
||||
preprocessor, postprocessor = make_reward_pre_post_processors(
|
||||
@@ -345,6 +341,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=processor_pretrained_path,
|
||||
**processor_kwargs,
|
||||
**postprocessor_kwargs,
|
||||
)
|
||||
|
||||
if is_main_process:
|
||||
@@ -394,19 +391,12 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
# create dataloader for offline training
|
||||
if hasattr(active_cfg, "drop_n_last_frames"):
|
||||
shuffle = False
|
||||
# A dedicated generator (rather than the global torch RNG) lets accelerator.prepare
|
||||
# synchronize the shuffle permutation across ranks, keeping batch shards disjoint even
|
||||
# when ranks consume the global RNG asymmetrically (e.g. eval on the main process only).
|
||||
sampler_generator = torch.Generator()
|
||||
if cfg.seed is not None:
|
||||
sampler_generator.manual_seed(cfg.seed)
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=active_cfg.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
generator=sampler_generator,
|
||||
)
|
||||
else:
|
||||
shuffle = True
|
||||
@@ -439,22 +429,12 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
policy.train()
|
||||
|
||||
train_metrics = {
|
||||
# Per-rank loss reflects only one shard of the global batch; mean recovers the loss DDP
|
||||
# is actually optimizing. grad_norm and lr are already identical on every rank (post
|
||||
# gradient sync / deterministic scheduler) so reducing them would be a no-op collective.
|
||||
"loss": AverageMeter("loss", ":.3f", reduction="mean"),
|
||||
"loss": AverageMeter("loss", ":.3f"),
|
||||
"grad_norm": AverageMeter("grdn", ":.3f"),
|
||||
"lr": AverageMeter("lr", ":0.1e"),
|
||||
# Report the slowest rank for bottleneck-style timings so multi-GPU runs surface the
|
||||
# true straggler instead of rank 0's view.
|
||||
"update_s": AverageMeter("updt_s", ":.3f", reduction="max"),
|
||||
"dataloading_s": AverageMeter("data_s", ":.3f", reduction="max"),
|
||||
# Derived from the post-reduce max step time; set once per log window on the main rank.
|
||||
"samples_per_s": AverageMeter("smp/s", ":.0f"),
|
||||
"update_s": AverageMeter("updt_s", ":.3f"),
|
||||
"dataloading_s": AverageMeter("data_s", ":.3f"),
|
||||
}
|
||||
if torch.cuda.is_available():
|
||||
# max() because headroom is gated by the worst-case rank.
|
||||
train_metrics["gpu_mem_gb"] = AverageMeter("mem_gb", ":.2f", reduction="max")
|
||||
|
||||
# Keep global batch size for logging; MetricsTracker handles world size internally.
|
||||
effective_batch_size = cfg.batch_size * accelerator.num_processes
|
||||
@@ -506,29 +486,21 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if is_main_process:
|
||||
progbar.update(1)
|
||||
train_tracker.step()
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
|
||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
|
||||
|
||||
if is_log_step:
|
||||
# Collective reduce must run on every rank, before the main-process gate below.
|
||||
train_tracker.reduce_across_ranks()
|
||||
if is_main_process:
|
||||
# Cluster-wide throughput, derived from the already-reduced (max) step time so it
|
||||
# reflects the slowest rank — which is what actually gates the next iteration.
|
||||
step_time = train_tracker.update_s.avg + train_tracker.dataloading_s.avg
|
||||
if step_time > 0:
|
||||
train_tracker.samples_per_s = effective_batch_size / step_time
|
||||
logging.info(train_tracker)
|
||||
if wandb_logger:
|
||||
wandb_log_dict = train_tracker.to_dict()
|
||||
if output_dict:
|
||||
wandb_log_dict.update(output_dict)
|
||||
# Log sample weighting statistics if enabled
|
||||
if sample_weighter is not None:
|
||||
weighter_stats = sample_weighter.get_stats()
|
||||
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
logging.info(train_tracker)
|
||||
if wandb_logger:
|
||||
wandb_log_dict = train_tracker.to_dict()
|
||||
if output_dict:
|
||||
wandb_log_dict.update(output_dict)
|
||||
# Log sample weighting statistics if enabled
|
||||
if sample_weighter is not None:
|
||||
weighter_stats = sample_weighter.get_stats()
|
||||
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
train_tracker.reset_averages()
|
||||
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user