mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
Compare commits
22 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 132ea975f0 | |||
| 961e0d9bcd | |||
| 6496728025 | |||
| 3b37bd0ca6 | |||
| 8e692e365c | |||
| f617b2c2bf | |||
| c6a51b9b60 | |||
| ab49c71c22 | |||
| 459efef8a0 | |||
| 5568ce7af1 | |||
| be0320a420 | |||
| 5222f3a4a7 | |||
| f9d12db9cf | |||
| 71aacda05e | |||
| e3deff00ad | |||
| 09808183ca | |||
| 4dfa8cea65 | |||
| 2e9cd87bbd | |||
| d1b1c5c8cf | |||
| 741c2d0a39 | |||
| 19fe315971 | |||
| 906b585826 |
@@ -22,6 +22,10 @@ outputs
|
||||
rl
|
||||
media
|
||||
|
||||
# Local virtualenvs (the image provides its own)
|
||||
.venv
|
||||
venv
|
||||
|
||||
|
||||
# Logging
|
||||
logs
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
- sections:
|
||||
- local: il_robots
|
||||
title: Imitation Learning for Robots
|
||||
- local: lelab
|
||||
title: LeLab - Lerobot GUI
|
||||
- local: bring_your_own_policies
|
||||
title: Adding a Policy
|
||||
- local: integrate_hardware
|
||||
@@ -61,8 +63,12 @@
|
||||
title: π₀.₅ (Pi05)
|
||||
- local: molmoact2
|
||||
title: MolmoAct2
|
||||
- local: vla_jepa
|
||||
title: VLA-JEPA
|
||||
- local: eo1
|
||||
title: EO-1
|
||||
- local: lingbot_va
|
||||
title: LingBot-VA
|
||||
- local: groot
|
||||
title: NVIDIA GR00T N1.5
|
||||
- local: xvla
|
||||
|
||||
@@ -647,5 +647,6 @@ The `--strategy.type` flag selects the execution mode:
|
||||
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
|
||||
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
|
||||
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
|
||||
- `episodic`: Episode-oriented policy recording with reset phases between episodes
|
||||
|
||||
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
|
||||
|
||||
@@ -157,6 +157,44 @@ Foot pedal input is also supported via `--strategy.input_device=pedal`. Configur
|
||||
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
|
||||
| `--teleop.type` | **Required.** Teleoperator type |
|
||||
|
||||
### Episodic (`--strategy.type=episodic`)
|
||||
|
||||
Episode-oriented recording that mirrors the behavior of `lerobot-record`. The policy drives the robot for each episode; an optional teleoperator can drive the robot during the reset phase between episodes.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=episodic \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/ttyACM1 \
|
||||
--dataset.repo_id=${HF_USER}/my_eval_data \
|
||||
--dataset.num_episodes=20 \
|
||||
--dataset.episode_time_s=30 \
|
||||
--dataset.reset_time_s=10 \
|
||||
--dataset.single_task="Pick up the red cube"
|
||||
```
|
||||
|
||||
Teleop is optional — if omitted the robot holds its position during the reset phase.
|
||||
|
||||
**Keyboard controls:**
|
||||
|
||||
| Key | Action |
|
||||
| ----------- | -------------------------------- |
|
||||
| `→` (right) | End the current episode early |
|
||||
| `←` (left) | Discard episode and re-record it |
|
||||
| `ESC` | Stop the recording session |
|
||||
|
||||
| Flag | Description |
|
||||
| ----------------------------------------------- | -------------------------------------------------------------------------- |
|
||||
| `--dataset.num_episodes` | Number of episodes to record |
|
||||
| `--dataset.episode_time_s` | Duration of each recording episode in seconds |
|
||||
| `--dataset.reset_time_s` | Duration of the reset phase between episodes in seconds |
|
||||
| `--teleop.type` | Optional. Teleoperator to drive the robot during resets |
|
||||
| `--strategy.reset_to_initial_position` | Whether to reset the robot to its initial position between episodes |
|
||||
| `--strategy.smooth_leader_to_follower_handover` | Whether to turn on or off the leader -> follower smooth handover behavior. |
|
||||
|
||||
---
|
||||
|
||||
## Inference Backends
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
# LeLab - LeRobot Guide
|
||||
|
||||
LeLab is a graphical user interface built on top of the LeRobot library, designed to make robotics accessible without needing to memorize CLI commands. From a single app you can configure your robot, teleoperate it, collect datasets, train policies locally or on cloud GPUs via HF Jobs, and deploy trained models back onto your robot. It's the easiest way to go from an unboxed SO-101 to a working policy, and a great companion for anyone learning the LeRobot workflow. Source code and issues live on GitHub: [huggingface/leLab](https://github.com/huggingface/leLab).
|
||||
|
||||
> [!TIP]
|
||||
> For now LeLab is compatible only with SO-ARM101
|
||||
|
||||
<Youtube id="VqyKUuW9V1g" />
|
||||
|
||||
### Installation
|
||||
|
||||
Requires [`uv`](https://docs.astral.sh/uv/getting-started/installation/). Install and launch in one command:
|
||||
|
||||
```
|
||||
uv tool install git+https://github.com/huggingface/leLab.git && lelab
|
||||
```
|
||||
|
||||
After install, run `lelab` from your terminal anytime to start the app.
|
||||
|
||||
### Features
|
||||
|
||||
- **Add robots** — Select arm type (leader/follower), calibrate each joint from the middle position, and attach cameras.
|
||||
- **Teleoperation** — Control the follower arm with the leader and see a live 3D visualization of the arms.
|
||||
- **Dataset recording** — Define a task description, number of episodes, and episode/reset durations. Press spacebar to advance between episodes. 30+ episodes recommended.
|
||||
- **Local training** — Train a policy directly on your own machine with a selected dataset, policy type, batch size, and step count.
|
||||
- **Cloud training with HF Jobs** — Train on powerful GPUs via [HF Jobs](https://huggingface.co/docs/huggingface_hub/en/guides/jobs) with transparent pricing. Run `hf auth login` first. See the [Compute HW Guide](hardware_guide) for hardware/batch size tips.
|
||||
- **Training visualization** — Watch progress live in the app, with checkpoints saved automatically.
|
||||
- **Run trained policies** — Pick any model from your jobs list and run inference on your robot with one click.
|
||||
- **Use community datasets** — Provide any Hugging Face dataset ID to train on datasets you didn't record yourself.
|
||||
@@ -275,7 +275,7 @@ A converter aggregates per‑episode files into larger shards and writes episode
|
||||
pip install "https://github.com/huggingface/lerobot/archive/33cad37054c2b594ceba57463e8f11ee374fa93c.zip"
|
||||
|
||||
# Convert an existing v2.1 dataset hosted on the Hub:
|
||||
python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id=<HF_USER/DATASET_ID>
|
||||
python -m lerobot.scripts.convert_dataset_v21_to_v30 --repo-id=<HF_USER/DATASET_ID>
|
||||
```
|
||||
|
||||
**What it does**
|
||||
|
||||
@@ -0,0 +1,187 @@
|
||||
# LingBot-VA
|
||||
|
||||
LingBot-VA is an **autoregressive video-action world-model policy** built on the **Wan2.2**
|
||||
video-diffusion stack. It interleaves, in one autoregressive sequence, the prediction of
|
||||
future **video latents** and **robot actions** ("VA" = Video-Action). The LeRobot
|
||||
integration wires LingBot-VA into the standard training, evaluation and processor
|
||||
interfaces.
|
||||
|
||||
## Model Overview
|
||||
|
||||
LingBot-VA is a **dual-stream "mixture-of-transformers"**: a video/latent stream
|
||||
(`patch_embedding_mlp → blocks → proj_out`) and an action stream
|
||||
(`action_embedder → blocks → action_proj_out`) share the same 30 transformer blocks and
|
||||
text conditioning.
|
||||
|
||||
| Component | Class | Role |
|
||||
| ------------------------ | ----------------------- | ----------------------------------------------------------- |
|
||||
| DiT backbone (trainable) | `WanTransformer3DModel` | ~5B-param dual-stream transformer. |
|
||||
| VAE (frozen) | `AutoencoderKLWan` | Wan2.2 VAE, `z_dim=48`. Lazy-pulled from the source repo. |
|
||||
| Text encoder (frozen) | `UMT5EncoderModel` | UMT5-XXL, `d_model=4096`. Lazy-pulled from the source repo. |
|
||||
|
||||
At inference the policy runs an autoregressive loop per chunk: it denoises the video-latent
|
||||
stream (CFG, ~20 steps) and the action stream (~50 steps) with two independent
|
||||
flow-matching schedulers, maintaining a KV cache across chunks. Real observed keyframes are
|
||||
fed back into the KV cache as the chunk is executed (closed-loop world modeling).
|
||||
|
||||
### What the LeRobot Integration Covers
|
||||
|
||||
- Standard `policy.type=lingbot_va` configuration through LeRobot.
|
||||
- Ready-to-use LeRobot-format checkpoints on the Hub (converted from the released upstream ones).
|
||||
- Autoregressive dual-stream inference behind the standard `select_action` interface
|
||||
(single-environment eval, `--eval.batch_size=1`).
|
||||
- Opt-in saving of the policy's **predicted (imagined) videos** during eval / training.
|
||||
- Evaluation with `lerobot-eval` on LIBERO and RoboTwin.
|
||||
- Training / fine-tuning via the dual-stream flow-matching loss (`policy.forward`), see below.
|
||||
|
||||
## Installation
|
||||
|
||||
1. Install LeRobot by following the [Installation Guide](./installation).
|
||||
2. Install the LingBot-VA extra:
|
||||
|
||||
```bash
|
||||
pip install -e ".[lingbot_va]"
|
||||
```
|
||||
|
||||
## Checkpoints
|
||||
|
||||
The released upstream checkpoints have been converted to LeRobot format and pushed to the Hub:
|
||||
|
||||
| Variant | LeRobot checkpoint |
|
||||
| ---------------------- | -------------------------------- |
|
||||
| LIBERO-Long post-train | `lerobot/lingbot_va_libero_long` |
|
||||
| RoboTwin post-train | `lerobot/lingbot_va_robotwin` |
|
||||
| Pretrained base | `lerobot/lingbot_va_base` |
|
||||
|
||||
Only the trainable ~5B transformer is stored in the LeRobot
|
||||
`model.safetensors`. The frozen VAE + UMT5 + tokenizer (~20 GB) are pulled from
|
||||
`config.wan_pretrained_path` at load time (defaults to the source `robbyant/*` repo). The
|
||||
UMT5-XXL text encoder runs on CPU by default (`config.text_encoder_device`) so the 5B
|
||||
transformer + VAE fit on a single 24–32 GB GPU.
|
||||
|
||||
## Evaluation (LIBERO)
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/lingbot_va_libero_long \
|
||||
--policy.device=cuda \
|
||||
--env.type=libero --env.task=libero_10 \
|
||||
--env.observation_height=128 --env.observation_width=128 \
|
||||
--eval.n_episodes=50 --eval.batch_size=1 \
|
||||
--output_dir=outputs/eval/lingbot_va_libero
|
||||
```
|
||||
|
||||
LingBot-VA's streaming inference (KV cache + observed-keyframe feedback) is implemented for
|
||||
single-environment eval; use `--eval.batch_size=1`.
|
||||
|
||||
## Evaluation (RoboTwin)
|
||||
|
||||
RoboTwin 2.0 needs the SAPIEN + CuRobo simulator stack. You can use the benchmark Docker image
|
||||
(`docker/Dockerfile.benchmark.robotwin`, which also needs `warp-lang==1.3.1` and CuRobo built
|
||||
with the GPU's compute capability in `TORCH_CUDA_ARCH_LIST`). RoboTwin uses **end-effector-pose
|
||||
control**, so run with `--env.action_mode=ee`: the policy predicts per-arm `xyz+quaternion+gripper`
|
||||
deltas (`robotwin_tshape` latent layout) that are composed onto the episode's initial eef pose and
|
||||
executed via CuRobo IK.
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/lingbot_va_robotwin \
|
||||
--policy.device=cuda \
|
||||
--env.type=robotwin --env.task=beat_block_hammer --env.action_mode=ee \
|
||||
--eval.n_episodes=10 --eval.batch_size=1 \
|
||||
--output_dir=outputs/eval/lingbot_va_robotwin
|
||||
```
|
||||
|
||||
### Saving predicted (imagined) videos
|
||||
|
||||
Set `--policy.save_predicted_video=true` to additionally VAE-decode the predicted video
|
||||
latents and write `pred_episode_*.mp4` next to the env-rendered `eval_episode_*.mp4` videos.
|
||||
The same flag works for the periodic eval during `lerobot-train`.
|
||||
|
||||
## Training / fine-tuning
|
||||
|
||||
`LingBotVAPolicy.forward(batch)` implements the dual-stream **flow-matching** loss
|
||||
(`latent_loss + action_loss`, timestep-weighted, action-masked) from the paper: it VAE-encodes
|
||||
the camera clips into video latents, UMT5-encodes the task, noises both streams, runs the
|
||||
transformer's block-causal training pass and returns `(loss, metrics)`. Optimizer preset is AdamW
|
||||
with a linear-warmup-then-constant schedule (matching upstream).
|
||||
|
||||
Requirements:
|
||||
|
||||
- The block-causal masks use PyTorch **flex-attention**, so build the policy with
|
||||
`--policy.attn_mode=flex` for training (the default `torch` SDPA is inference-only).
|
||||
- The full 5B DiT does not fit a single 24–32 GB GPU under AdamW; fine-tune with **LoRA**
|
||||
(`--policy.use_peft=true`) and/or optimizer offload. `get_optim_params` returns only the
|
||||
trainable (e.g. adapter) parameters; the VAE + UMT5 text encoder stay frozen.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/lingbot_va_libero_long --policy.attn_mode=flex \
|
||||
--policy.use_peft=true \
|
||||
--dataset.repo_id=<your LeRobot-format dataset> \
|
||||
--batch_size=1 --steps=... --output_dir=outputs/train/lingbot_va
|
||||
```
|
||||
|
||||
The dataset must provide camera clips (a temporal window per camera, VAE-encoded to
|
||||
`frame_chunk_size` latent frames) and `frame_chunk_size * action_per_frame` action steps per item.
|
||||
|
||||
## Data format (action channels & camera order)
|
||||
|
||||
LingBot-VA is an **end-effector (Cartesian) pose** policy, it predicts EEF poses + gripper, not
|
||||
joint positions. Actions live in a fixed multi-embodiment **30-dim** layout; map your robot's
|
||||
action dimensions into these channels and pad the rest with `0` (`used_action_channel_ids` selects
|
||||
the channels a given checkpoint actually uses):
|
||||
|
||||
| channels | meaning |
|
||||
| -------- | ----------------------------------------------------- |
|
||||
| 0–6 | Left-arm end-effector pose |
|
||||
| 7–13 | Right-arm end-effector pose |
|
||||
| 14–20 | Left-arm joints (unused by the released checkpoints) |
|
||||
| 21–27 | Right-arm joints (unused by the released checkpoints) |
|
||||
| 28 | Left gripper |
|
||||
| 29 | Right gripper |
|
||||
|
||||
- **LIBERO** uses channels `0–6`: a 6-DoF EEF delta (xyz + rotation) + gripper (single arm).
|
||||
- **RoboTwin** uses channels `[0–6, 28, 7–13, 29]`: left EEF (xyz + quaternion) + left gripper +
|
||||
right EEF + right gripper (16 dims). The env converts these poses to joint trajectories via
|
||||
CuRobo IK — joints are never predicted.
|
||||
|
||||
Joint-space datasets (or a different EEF convention) must be remapped into this schema before
|
||||
fine-tuning these checkpoints.
|
||||
|
||||
**Camera order is fixed and order-sensitive**, per-camera latents are concatenated spatially in
|
||||
`obs_cam_keys` order, so the physical camera→slot mapping must match training:
|
||||
|
||||
| benchmark | `obs_cam_keys` (in order) | `camera_layout` |
|
||||
| --------- | ----------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------- |
|
||||
| LIBERO | `observation.images.image` (agentview / 3rd-person), `observation.images.image2` (eye-in-hand wrist) | `width_concat` (latents concatenated on width) |
|
||||
| RoboTwin | `observation.images.head_camera`, `observation.images.left_camera`, `observation.images.right_camera` | `robotwin_tshape` (full-res head below, two half-res wrists on top) |
|
||||
|
||||
The first camera is the exterior/head view and the rest are wrist views.
|
||||
|
||||
## Inference Hyperparameters (LIBERO)
|
||||
|
||||
| Key | Value |
|
||||
| -------------------------------------- | --------------------------------------------------------------------------------- |
|
||||
| height × width | 128 × 128 |
|
||||
| cameras | `observation.images.image` (agentview), `observation.images.image2` (eye-in-hand) |
|
||||
| action channels used | 0–6 (7-DoF arm + gripper) |
|
||||
| action_per_frame / frame_chunk_size | 4 / 4 |
|
||||
| attn_window | 30 |
|
||||
| video / action denoising steps | 20 / 50 |
|
||||
| guidance_scale / action_guidance_scale | 5 / 1 |
|
||||
| snr_shift / action_snr_shift | 5.0 / 0.05 |
|
||||
|
||||
These are the defaults of `LingBotVAConfig`; override any of them via `--policy.<name>=...`.
|
||||
|
||||
## Notes
|
||||
|
||||
- **Attention backend:** inference uses the `torch` SDPA backend (always available). The
|
||||
`flashattn` and `flex` backends are optional; `flex` is only needed for training.
|
||||
- **Model size:** the DiT is ~5B params and the frozen VAE+UMT5 add ~20 GB; inference needs
|
||||
roughly 18–24 GB of VRAM.
|
||||
|
||||
## License
|
||||
|
||||
LingBot-VA is released under Apache-2.0. See the
|
||||
[upstream repository](https://github.com/Robbyant/lingbot-va).
|
||||
@@ -238,7 +238,7 @@ your dataset has not been converted with quantile statistics, you can add them
|
||||
with:
|
||||
|
||||
```bash
|
||||
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
|
||||
python src/lerobot/scripts/augment_dataset_quantile_stats.py \
|
||||
--repo-id=your_dataset
|
||||
```
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ lerobot-train \
|
||||
If your dataset is not converted with `quantiles`, you can convert it with the following command:
|
||||
|
||||
```bash
|
||||
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
|
||||
python src/lerobot/scripts/augment_dataset_quantile_stats.py \
|
||||
--repo-id=your_dataset \
|
||||
```
|
||||
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
# VLA-JEPA
|
||||
|
||||
This repository contains the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
|
||||
|
||||
Converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA).
|
||||
|
||||
---
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
| Component | Module | Role |
|
||||
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
|
||||
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
|
||||
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
|
||||
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
|
||||
|
||||
At inference time only the Qwen backbone and action head are used; the world model is not needed.
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
|
||||
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
|
||||
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
|
||||
year = {2026},
|
||||
eprint = {2602.10098},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.RO},
|
||||
url = {https://arxiv.org/abs/2602.10098},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.
|
||||
@@ -300,7 +300,7 @@ This replaces the old episode-per-file structure with efficient, optimally-sized
|
||||
If you have existing datasets in v2.1 format, use the migration tool:
|
||||
|
||||
```bash
|
||||
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
|
||||
python src/lerobot/scripts/convert_dataset_v21_to_v30.py \
|
||||
--repo-id your_id/existing_dataset
|
||||
```
|
||||
|
||||
|
||||
@@ -0,0 +1,235 @@
|
||||
# VLA-JEPA
|
||||
|
||||
This is the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
|
||||
|
||||
---
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
VLA-JEPA has three main components:
|
||||
|
||||
| Component | Module | Role |
|
||||
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
|
||||
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
|
||||
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
|
||||
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
|
||||
|
||||
### Data flow
|
||||
|
||||
**Training:**
|
||||
|
||||
1. A video clip of `num_video_frames` frames is encoded by V-JEPA2 into per-frame patch tokens.
|
||||
2. The Qwen3-VL backbone processes multi-view images + the task instruction and produces a sequence of context tokens that includes special action tokens (for world model conditioning) and embodied tokens.
|
||||
3. The action head receives those context tokens as cross-attention keys/values and predicts a denoised action chunk via flow matching.
|
||||
4. The world model predictor uses the action tokens extracted from Qwen to predict future V-JEPA2 frame embeddings; a regression loss on those predictions is added to the action loss.
|
||||
|
||||
**Inference:**
|
||||
Only Qwen + the action head are used. The world model is not needed at inference time.
|
||||
|
||||
### Action head details
|
||||
|
||||
Available presets via `action_model_type`:
|
||||
|
||||
| Preset | Hidden dim | Heads | Head dim |
|
||||
| ------- | ---------- | ----- | -------- |
|
||||
| `DiT-B` | 768 | 12 | 64 |
|
||||
| `DiT-L` | 1536 | 32 | 48 |
|
||||
|
||||
### World model details
|
||||
|
||||
The video predictor is a ViT-style transformer (`ActionConditionedVideoPredictor`) that takes:
|
||||
|
||||
- **Frame tokens**: V-JEPA2 patch embeddings projected to `predictor_embed_dim`
|
||||
- **Action tokens**: Qwen action token embeddings projected to `predictor_embed_dim`
|
||||
|
||||
It uses block-causal attention so each temporal step can attend to all previous steps. The predictor's input `embed_dim` equals `num_views × video_encoder_hidden_size` (e.g. 2 views × 1024 = 2048 for the pretrained checkpoints).
|
||||
|
||||
---
|
||||
|
||||
## Pretrained Checkpoints
|
||||
|
||||
Three checkpoints are available directly inside the LeRobot org here: [`lerobot/VLA-JEPA`](https://huggingface.co/collections/lerobot/vla-jepa), converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA):
|
||||
|
||||
| Checkpoint | Dataset | Cameras | World model | Action dim |
|
||||
| ----------------------------- | ----------------- | ----------------------- | ----------- | ---------- |
|
||||
| `lerobot/VLA-JEPA-LIBERO` | LIBERO-10 | 2 (agentview + wrist) | Enabled | 7 |
|
||||
| `lerobot/VLA-JEPA-Pretrain` | DROID 1.0.1 | 2 (exterior left views) | Enabled | 7 |
|
||||
| `lerobot/VLA-JEPA-SimplerEnv` | OXE Bridge / RT-1 | 1 (view duplicated ×2) | Enabled | 7 |
|
||||
|
||||
All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone.
|
||||
|
||||
---
|
||||
|
||||
## Configuration
|
||||
|
||||
Key parameters in `VLAJEPAConfig`:
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| ------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `chunk_size` | 7 | Number of actions predicted per inference call |
|
||||
| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning |
|
||||
| `num_video_frames` | 8 | Video clip length fed to the world model |
|
||||
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
|
||||
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
|
||||
| `num_inference_timesteps` | 4 | Euler integration steps for action denoising |
|
||||
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
|
||||
| `reinit_modules` | `None` | Key prefixes allowed to be randomly re-initialised on load (for cross-embodiment transfer, see [Fine-tuning on a different embodiment](#fine-tuning-on-a-different-embodiment)) |
|
||||
| `gripper_dim` | 6 | Index of the gripper dimension in the action vector (e.g. 6 for a 7-DoF arm with gripper as the last joint) |
|
||||
| `gripper_threshold` | 0.5 | Threshold used by `pre_snap_gripper_action` and `binarize_gripper_action` to binarize the gripper dimension |
|
||||
| `pre_snap_gripper_action` | `True` | Snap the gripper dim to {0, 1} before unnormalization. Set to `False` for robots without a binary gripper |
|
||||
| `binarize_gripper_action` | `True` | Binarize the gripper dim to {-1, 1} after unnormalization. Set to `False` for robots without a binary gripper |
|
||||
|
||||
---
|
||||
|
||||
## Training
|
||||
|
||||
Number of training steps may vary based on dataset size and compute budget. The original paper pretrained for 50k on ssv2 + droid jointly, then additional 30k steps for LIBERO, but fewer steps may still yield good performance when fine-tuning from the provided pretrained checkpoints.
|
||||
|
||||
### Full training from scratch
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
policy.type=vla_jepa \
|
||||
policy.repo_id=your_org/your_repo \
|
||||
dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
### Fine-tuning from a pretrained checkpoint
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
If you want to freeze the Qwen backbone and only train the action head, set `policy.freeze_qwen=True`:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--policy.freeze_qwen=true \
|
||||
--dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
### Fine-tuning on a different embodiment
|
||||
|
||||
When the target robot has a different action or state dimensionality than the pretrained checkpoint, the input/output projection layers of the action head will have mismatched shapes and cannot be loaded directly. `reinit_modules` lets you list the key prefixes that are allowed to mismatch — those layers are randomly re-initialised while every other weight is reused from the checkpoint. Any shape mismatch outside the listed prefixes raises an error.
|
||||
|
||||
The layers that depend on `action_dim` and `state_dim` are:
|
||||
|
||||
| Layer | Key prefix |
|
||||
| ----------------------------------------- | ----------------------------------- |
|
||||
| Action encoder (action_dim → inner_dim) | `model.action_model.action_encoder` |
|
||||
| Action decoder (hidden_size → action_dim) | `model.action_model.action_decoder` |
|
||||
| State encoder (state_dim → inner_dim) | `model.action_model.state_encoder` |
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--policy.freeze_qwen=true \
|
||||
--policy.reinit_modules='["model.action_model.action_encoder", "model.action_model.action_decoder", "model.action_model.state_encoder"]' \
|
||||
--dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
If your robot has no proprioceptive state, omit `model.action_model.state_encoder` from the list.
|
||||
|
||||
### Reproducing the LIBERO results
|
||||
|
||||
**Training on LIBERO:**
|
||||
starts the training from the Pretrain checkpoint, trains for 30k steps on the LIBERO dataset.
|
||||
Original paper mentions training across 8 GPUs with a batch size of 32, meaning global batch size of 256.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||
--steps=30000
|
||||
```
|
||||
|
||||
**Evaluating the pretrained LIBERO-10 checkpoint:**
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/VLA-JEPA-LIBERO \
|
||||
--env.type=libero \
|
||||
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
|
||||
--eval.n_episodes=10 \
|
||||
--eval.batch_size=5
|
||||
```
|
||||
|
||||
To evaluate a subset of tasks only:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/VLA-JEPA-LIBERO \
|
||||
--env.type=libero \
|
||||
--env.task=libero_10 \
|
||||
--env.task_ids='[0,1,2]' \
|
||||
--eval.n_episodes=10 \
|
||||
--eval.batch_size=5
|
||||
```
|
||||
|
||||
**Expected results:**
|
||||
|
||||
| Suite | Episodes | Successes | Success Rate |
|
||||
| -------------- | -------- | --------- | ------------ |
|
||||
| libero_spatial | 100 | 93 | **95.0%** |
|
||||
| libero_object | 100 | 100 | **100.0%** |
|
||||
| libero_goal | 100 | 98 | **98.0%** |
|
||||
| libero_10 | 100 | 96 | **93.0%** |
|
||||
| **Overall** | **400** | **387** | **96.5%** |
|
||||
|
||||
---
|
||||
|
||||
## Fine-tuning on datasets with a different number of cameras
|
||||
|
||||
The pretrained world model predictor was trained with `embed_dim = jepa_tubelet_size × 1024` (default `jepa_tubelet_size=2`).
|
||||
|
||||
**Default behaviour — view padding / trimming (no action required)**
|
||||
|
||||
When fine-tuning from `VLA-JEPA-Pretrain` the model automatically adjusts the number of views fed to the world model to match `jepa_tubelet_size`:
|
||||
|
||||
- **Single-view datasets (e.g. BridgeV2):** the single-view latent is duplicated to produce a two-view world-model input, preserving the JEPA self-supervised signal without any weight mismatch.
|
||||
- **>2-view datasets (e.g. DROID with 3 views):** all views are passed to the Qwen backbone (for richer context), but only the first `jepa_tubelet_size` views (one wrist + one third-person, following the configured view order) are used for the world model.
|
||||
|
||||
**Option 1 — Disable the world model**
|
||||
|
||||
Set `enable_world_model=False` to skip the JEPA loss entirely. Only the Qwen backbone and action head are loaded and trained. This is sufficient for good action performance.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.enable_world_model=false \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--dataset.repo_id=your_org/single_camera_dataset
|
||||
```
|
||||
|
||||
**Option 2 — Reinitialize the predictor input projection**
|
||||
|
||||
If you want to change `jepa_tubelet_size` to a value other than 2, load the checkpoint with `strict=False` and reinitialize `model.video_predictor.predictor_embed` for the new `embed_dim`. All other predictor block weights (attention, MLP, norm, output projection) are camera-count-agnostic and can be reused from the pretrained checkpoint.
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
|
||||
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
|
||||
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
|
||||
year = {2026},
|
||||
eprint = {2602.10098},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.RO},
|
||||
url = {https://arxiv.org/abs/2602.10098},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.
|
||||
+9
-1
@@ -146,7 +146,8 @@ grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
|
||||
diffusers-dep = ["diffusers>=0.27.2,<0.36.0"]
|
||||
diffusers-dep = ["diffusers>=0.27.2,<0.37.0"]
|
||||
imageio-dep = ["imageio[ffmpeg]>=2.34.0,<3.0.0"]
|
||||
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]
|
||||
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"] # NOTE: Explicitly listing contourpy helps the resolver converge faster.
|
||||
pyserial-dep = ["pyserial>=3.5,<4.0"]
|
||||
@@ -217,6 +218,8 @@ topreward = ["lerobot[transformers-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
lingbot_va = ["lerobot[transformers-dep]", "diffusers>=0.36.0,<0.37.0", "lerobot[imageio-dep]", "accelerate>=1.10.0,<2.0.0", "ftfy>=6.0.0,<7.0.0"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
@@ -282,6 +285,8 @@ all = [
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[xvla]",
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[vla_jepa]",
|
||||
"lerobot[lingbot_va]",
|
||||
"lerobot[async]",
|
||||
"lerobot[dev]",
|
||||
"lerobot[test]",
|
||||
@@ -373,6 +378,9 @@ ignore = [
|
||||
# E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect
|
||||
"src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"]
|
||||
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original
|
||||
# Vendored Wan2.2 / LingBot-VA model code uses tensor-dimension names (B, F, H, W) and `F` for
|
||||
# torch.nn.functional.
|
||||
"src/lerobot/policies/lingbot_va/**" = ["N803", "N806", "N812", "SIM102"]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
combine-as-imports = true
|
||||
|
||||
@@ -18,6 +18,7 @@ from __future__ import annotations
|
||||
# Utilities
|
||||
########################################################################################
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from contextlib import nullcontext
|
||||
from copy import copy
|
||||
@@ -243,3 +244,72 @@ def sanity_check_dataset_robot_compatibility(
|
||||
raise ValueError(
|
||||
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
|
||||
)
|
||||
|
||||
|
||||
########################################################################################
|
||||
# Teleoperator smooth handover helpers
|
||||
# NOTE(Maxime): These functions use minimal type hints to maintain compatibility with utils
|
||||
# being a root module.
|
||||
########################################################################################
|
||||
|
||||
|
||||
def teleop_supports_feedback(teleop) -> bool:
|
||||
"""Return True when the teleop can receive position feedback (is actuated).
|
||||
|
||||
Actuated teleops (e.g. SO-101, OpenArmMini) have non-empty ``feedback_features``
|
||||
and expose ``enable_torque`` / ``disable_torque`` motor-control methods.
|
||||
|
||||
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
|
||||
"""
|
||||
return (
|
||||
bool(teleop.feedback_features)
|
||||
and hasattr(teleop, "disable_torque")
|
||||
and hasattr(teleop, "enable_torque")
|
||||
)
|
||||
|
||||
|
||||
def teleop_smooth_move_to(teleop, target_pos: dict, duration_s: float = 2.0, fps: int = 30) -> None:
|
||||
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
|
||||
|
||||
Requires the teleoperator to support feedback (i.e. have non-empty
|
||||
``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
|
||||
|
||||
``target_pos`` is expected to be in the teleop's action/feedback key space.
|
||||
For homogeneous setups (e.g. SO-101 leader + SO-101 follower) this matches
|
||||
the robot action key space directly.
|
||||
|
||||
TODO(Maxime): This blocks up to ``duration_s`` seconds; during this time the
|
||||
follower robot does not receive new actions, which could be an issue on LeKiwi.
|
||||
"""
|
||||
teleop.enable_torque()
|
||||
current = teleop.get_action()
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {
|
||||
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
|
||||
}
|
||||
teleop.send_feedback(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
def follower_smooth_move_to(
|
||||
robot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
|
||||
) -> None:
|
||||
"""Smoothly move the follower robot from ``current`` to ``target`` action.
|
||||
|
||||
Used when the teleop is non-actuated: instead of driving the leader arm to
|
||||
the follower, the follower is brought to the teleop's current pose so the
|
||||
robot meets the operator's hand rather than jumping to it on the first frame.
|
||||
|
||||
Both ``current`` and ``target`` must be in the robot action key space
|
||||
(i.e. the output of ``robot_action_processor``).
|
||||
"""
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
|
||||
robot.send_action(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
@@ -41,8 +41,8 @@ class DatasetRecordConfig:
|
||||
video: bool = True
|
||||
# Upload dataset to Hugging Face hub.
|
||||
push_to_hub: bool = True
|
||||
# Upload on private repository on the Hugging Face hub.
|
||||
private: bool = False
|
||||
# If True, upload as private; if None, defer to the org default on the Hub (only affects orgs).
|
||||
private: bool | None = None
|
||||
# Add tags to your dataset on the hub.
|
||||
tags: list[str] | None = None
|
||||
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
|
||||
|
||||
@@ -177,6 +177,12 @@ class TrainPipelineConfig(HubMixin):
|
||||
)
|
||||
|
||||
active_cfg = self.trainable_config
|
||||
if self.rename_map and active_cfg.pretrained_path is None:
|
||||
raise ValueError(
|
||||
"`rename_map` requires a pretrained policy checkpoint. "
|
||||
"Fresh initialization derives feature names from the current dataset, so no rename is applied."
|
||||
)
|
||||
|
||||
if not self.job_name:
|
||||
if self.env is None:
|
||||
self.job_name = f"{active_cfg.type}"
|
||||
|
||||
@@ -524,7 +524,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
license: str | None = "apache-2.0",
|
||||
tag_version: bool = True,
|
||||
push_videos: bool = True,
|
||||
private: bool = False,
|
||||
private: bool | None = None,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
upload_large_folder: bool = False,
|
||||
**card_kwargs,
|
||||
@@ -543,7 +543,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
tag_version: If ``True``, create a Git tag for the current codebase
|
||||
version.
|
||||
push_videos: If ``False``, skip uploading the ``videos/`` directory.
|
||||
private: If ``True``, create a private repository.
|
||||
private: If ``True``, create a private repository. If ``None``
|
||||
(default), defer to the org default on the Hub (only affects orgs).
|
||||
allow_patterns: Glob pattern(s) restricting which files to upload.
|
||||
upload_large_folder: If ``True``, use ``upload_large_folder`` instead
|
||||
of ``upload_folder`` for very large datasets.
|
||||
|
||||
@@ -757,7 +757,7 @@ class RoboTwinEnvConfig(EnvConfig):
|
||||
|
||||
task: str = "beat_block_hammer" # single task or comma-separated list
|
||||
fps: int = 25
|
||||
episode_length: int = 300
|
||||
episode_length: int = 1200
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
render_mode: str = "rgb_array"
|
||||
# Available cameras from RoboTwin's aloha-agilex embodiment: head_camera
|
||||
@@ -768,6 +768,9 @@ class RoboTwinEnvConfig(EnvConfig):
|
||||
# must equal what SAPIEN actually renders.
|
||||
observation_height: int = 240
|
||||
observation_width: int = 320
|
||||
# "joint": 14-d joint-space control. "ee": 16-d end-effector-pose deltas executed via CuRobo IK
|
||||
# (for world-model policies like LingBot-VA that predict per-arm xyz+quaternion+gripper poses).
|
||||
action_mode: str = "joint"
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
|
||||
@@ -784,6 +787,8 @@ class RoboTwinEnvConfig(EnvConfig):
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.action_mode == "ee":
|
||||
self.features[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(16,))
|
||||
cam_list = [c.strip() for c in self.camera_names.split(",") if c.strip()]
|
||||
for cam in cam_list:
|
||||
self.features[f"pixels/{cam}"] = PolicyFeature(
|
||||
@@ -826,6 +831,7 @@ class RoboTwinEnvConfig(EnvConfig):
|
||||
observation_height=self.observation_height,
|
||||
observation_width=self.observation_width,
|
||||
episode_length=self.episode_length,
|
||||
action_mode=self.action_mode,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
@@ -41,10 +42,117 @@ ROBOTWIN_CAMERA_NAMES: tuple[str, ...] = (
|
||||
"right_camera",
|
||||
)
|
||||
|
||||
ACTION_DIM = 14 # 7 DOF × 2 arms
|
||||
ACTION_DIM = 14 # 7 DOF × 2 arms (joint-space control mode)
|
||||
# End-effector-pose control mode: per arm [x, y, z, qx, qy, qz, qw, gripper] = 8, dual-arm = 16.
|
||||
# Used by world-model policies (e.g. LingBot-VA) that predict eef-pose deltas executed via CuRobo IK.
|
||||
EEF_ACTION_DIM = 16
|
||||
ACTION_LOW = -1.0
|
||||
ACTION_HIGH = 1.0
|
||||
DEFAULT_EPISODE_LENGTH = 300
|
||||
DEFAULT_EPISODE_LENGTH = 1200
|
||||
OFFICIAL_INSTRUCTION_ENV = "LEROBOT_ROBOTWIN_OFFICIAL_INSTRUCTION"
|
||||
OFFICIAL_INSTRUCTION_TYPE_ENV = "LEROBOT_ROBOTWIN_INSTRUCTION_TYPE"
|
||||
OFFICIAL_INSTRUCTION_MAX_ENV = "LEROBOT_ROBOTWIN_INSTRUCTION_MAX"
|
||||
|
||||
|
||||
def _compose_eef_pose(new_pose: np.ndarray, init_pose: np.ndarray) -> np.ndarray:
|
||||
"""Compose a single-arm predicted delta pose onto the initial pose.
|
||||
|
||||
``new_pose`` / ``init_pose`` are 8-vectors ``[x, y, z, qx, qy, qz, qw, gripper]``. Translation
|
||||
is added, rotation is composed (``init_R * new_R``), and the gripper is taken from the
|
||||
prediction. Mirrors ``add_eef_pose`` in the upstream LingBot-VA RoboTwin client.
|
||||
"""
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
new_r = Rotation.from_quat(new_pose[3:7])
|
||||
init_r = Rotation.from_quat(init_pose[3:7])
|
||||
out_rot = (init_r * new_r).as_quat()
|
||||
out_trans = new_pose[:3] + init_pose[:3]
|
||||
return np.concatenate([out_trans, out_rot, new_pose[7:8]])
|
||||
|
||||
|
||||
def _add_init_eef_pose(delta_pose: np.ndarray, init_pose: np.ndarray) -> np.ndarray:
|
||||
"""Compose a dual-arm (16-d) predicted delta pose onto the initial eef pose, normalizing quats."""
|
||||
left = _compose_eef_pose(delta_pose[:8], init_pose[:8])
|
||||
right = _compose_eef_pose(delta_pose[8:], init_pose[8:])
|
||||
out = np.concatenate([left, right])
|
||||
# Normalize the two quaternions (indices 3:7 and 11:15) as the upstream client does.
|
||||
out[3:7] = out[3:7] / (np.linalg.norm(out[3:7]) + 1e-8)
|
||||
out[11:15] = out[11:15] / (np.linalg.norm(out[11:15]) + 1e-8)
|
||||
return out
|
||||
|
||||
|
||||
def _env_flag(name: str, default: bool = False) -> bool:
|
||||
raw = os.environ.get(name)
|
||||
if raw is None:
|
||||
return default
|
||||
return raw.strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _arm_for_block(block: Any) -> str:
|
||||
return "left" if float(block.get_pose().p[0]) < 0 else "right"
|
||||
|
||||
|
||||
def _robotwin_blocks_episode_info(task_name: str, env: Any) -> dict[str, str] | None:
|
||||
"""Infer the episode-info dict used by RoboTwin's official instruction generator for block ranking."""
|
||||
if task_name == "blocks_ranking_rgb":
|
||||
return {
|
||||
"{A}": "red block",
|
||||
"{B}": "green block",
|
||||
"{C}": "blue block",
|
||||
"{a}": _arm_for_block(env.block1),
|
||||
"{b}": _arm_for_block(env.block2),
|
||||
"{c}": _arm_for_block(env.block3),
|
||||
}
|
||||
if task_name == "blocks_ranking_size":
|
||||
return {
|
||||
"{A}": "large block",
|
||||
"{B}": "medium block",
|
||||
"{C}": "small block",
|
||||
"{a}": _arm_for_block(env.block1),
|
||||
"{b}": _arm_for_block(env.block2),
|
||||
"{c}": _arm_for_block(env.block3),
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def _generate_robotwin_official_instruction(task_name: str, env: Any) -> str:
|
||||
"""Generate language with RoboTwin's official task templates, matching its eval client."""
|
||||
fallback = task_name.replace("_", " ")
|
||||
episode_info = _robotwin_blocks_episode_info(task_name, env)
|
||||
if episode_info is None:
|
||||
logger.warning("Official RoboTwin instruction is not implemented for task=%s; using %r.", task_name, fallback)
|
||||
return fallback
|
||||
|
||||
try:
|
||||
from description.utils.generate_episode_instructions import generate_episode_descriptions
|
||||
except Exception:
|
||||
logger.warning("Failed to import RoboTwin official instruction generator; using %r.", fallback, exc_info=True)
|
||||
return fallback
|
||||
|
||||
instruction_type = os.environ.get(OFFICIAL_INSTRUCTION_TYPE_ENV, "seen")
|
||||
try:
|
||||
max_descriptions = int(os.environ.get(OFFICIAL_INSTRUCTION_MAX_ENV, "1000000"))
|
||||
except ValueError:
|
||||
max_descriptions = 1000000
|
||||
|
||||
results = generate_episode_descriptions(task_name, [episode_info], max_descriptions=max_descriptions)
|
||||
if not results:
|
||||
logger.warning("RoboTwin generated no official instructions for task=%s; using %r.", task_name, fallback)
|
||||
return fallback
|
||||
|
||||
options = results[0].get(instruction_type) or results[0].get("seen") or results[0].get("unseen")
|
||||
if not options:
|
||||
logger.warning(
|
||||
"RoboTwin generated no %s official instructions for task=%s; using %r.",
|
||||
instruction_type,
|
||||
task_name,
|
||||
fallback,
|
||||
)
|
||||
return fallback
|
||||
|
||||
return str(np.random.choice(options))
|
||||
|
||||
|
||||
# D435 dims from task_config/_camera_config.yml (what demo_clean.yml selects).
|
||||
DEFAULT_CAMERA_H = 240
|
||||
DEFAULT_CAMERA_W = 320
|
||||
@@ -234,6 +342,7 @@ class RoboTwinEnv(gym.Env):
|
||||
observation_width: int | None = None,
|
||||
episode_length: int = DEFAULT_EPISODE_LENGTH,
|
||||
render_mode: str = "rgb_array",
|
||||
action_mode: str = "joint",
|
||||
):
|
||||
super().__init__()
|
||||
self.task_name = task_name
|
||||
@@ -241,6 +350,13 @@ class RoboTwinEnv(gym.Env):
|
||||
self.task_description = task_name.replace("_", " ")
|
||||
self.episode_index = episode_index
|
||||
self._reset_stride = n_envs
|
||||
# "joint": 14-d joint-space actions via take_action(action). "ee": 16-d end-effector-pose
|
||||
# deltas (added onto the episode's initial eef pose) executed via take_action(.., "ee") + IK.
|
||||
if action_mode not in ("joint", "ee"):
|
||||
raise ValueError(f"action_mode must be 'joint' or 'ee'; got {action_mode!r}")
|
||||
self.action_mode = action_mode
|
||||
self._action_dim = EEF_ACTION_DIM if action_mode == "ee" else ACTION_DIM
|
||||
self._init_eef_pose: np.ndarray | None = None
|
||||
self.camera_names = list(camera_names)
|
||||
# Default to D435 dims (the camera type baked into task_config/demo_clean.yml).
|
||||
# The YAML-driven lookup is deferred to reset() so construction doesn't
|
||||
@@ -271,7 +387,7 @@ class RoboTwinEnv(gym.Env):
|
||||
}
|
||||
)
|
||||
self.action_space = spaces.Box(
|
||||
low=ACTION_LOW, high=ACTION_HIGH, shape=(ACTION_DIM,), dtype=np.float32
|
||||
low=ACTION_LOW, high=ACTION_HIGH, shape=(self._action_dim,), dtype=np.float32
|
||||
)
|
||||
|
||||
def _ensure_env(self) -> None:
|
||||
@@ -317,6 +433,18 @@ class RoboTwinEnv(gym.Env):
|
||||
|
||||
return {"pixels": images, "agent_pos": joint_state}
|
||||
|
||||
def _read_eef_pose(self) -> np.ndarray:
|
||||
"""Read the current 16-d dual-arm eef pose [left(xyz+quat)+grip, right(xyz+quat)+grip]."""
|
||||
assert self._env is not None, "_read_eef_pose called before _ensure_env()"
|
||||
ep = self._env.get_obs()["endpose"]
|
||||
pose = (
|
||||
list(ep["left_endpose"])
|
||||
+ [ep["left_gripper"]]
|
||||
+ list(ep["right_endpose"])
|
||||
+ [ep["right_gripper"]]
|
||||
)
|
||||
return np.asarray(pose, dtype=np.float64)
|
||||
|
||||
def reset(self, seed: int | None = None, **kwargs) -> tuple[RobotObservation, dict]:
|
||||
self._ensure_env()
|
||||
super().reset(seed=seed)
|
||||
@@ -330,16 +458,32 @@ class RoboTwinEnv(gym.Env):
|
||||
self.episode_index += self._reset_stride
|
||||
self._step_count = 0
|
||||
|
||||
use_official_instruction = self.task_name in {"blocks_ranking_rgb", "blocks_ranking_size"}
|
||||
if _env_flag(OFFICIAL_INSTRUCTION_ENV, default=use_official_instruction):
|
||||
self.task_description = _generate_robotwin_official_instruction(self.task_name, self._env)
|
||||
if hasattr(self._env, "set_instruction"):
|
||||
self._env.set_instruction(instruction=self.task_description)
|
||||
logger.info("RoboTwin official instruction | task=%s | %s", self.task_name, self.task_description)
|
||||
else:
|
||||
self.task_description = self.task_name.replace("_", " ")
|
||||
|
||||
# In eef mode the policy predicts pose deltas relative to the initial eef pose.
|
||||
if self.action_mode == "ee":
|
||||
self._init_eef_pose = self._read_eef_pose()
|
||||
|
||||
obs = self._get_obs()
|
||||
return obs, {"is_success": False, "task": self.task_name}
|
||||
|
||||
def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]:
|
||||
assert self._env is not None, "step() called before reset()"
|
||||
if action.ndim != 1 or action.shape[0] != ACTION_DIM:
|
||||
raise ValueError(f"Expected 1-D action of shape ({ACTION_DIM},), got {action.shape}")
|
||||
if action.ndim != 1 or action.shape[0] != self._action_dim:
|
||||
raise ValueError(f"Expected 1-D action of shape ({self._action_dim},), got {action.shape}")
|
||||
|
||||
with torch.enable_grad():
|
||||
if hasattr(self._env, "take_action"):
|
||||
if self.action_mode == "ee":
|
||||
ee_action = _add_init_eef_pose(np.asarray(action, dtype=np.float64), self._init_eef_pose)
|
||||
self._env.take_action(ee_action, action_type="ee")
|
||||
elif hasattr(self._env, "take_action"):
|
||||
self._env.take_action(action)
|
||||
else:
|
||||
self._env.step(action)
|
||||
@@ -398,6 +542,7 @@ def _make_env_fns(
|
||||
observation_height: int,
|
||||
observation_width: int,
|
||||
episode_length: int,
|
||||
action_mode: str = "joint",
|
||||
) -> list[Callable[[], RoboTwinEnv]]:
|
||||
"""Return n_envs factory callables for a single task."""
|
||||
|
||||
@@ -410,6 +555,7 @@ def _make_env_fns(
|
||||
observation_height=observation_height,
|
||||
observation_width=observation_width,
|
||||
episode_length=episode_length,
|
||||
action_mode=action_mode,
|
||||
)
|
||||
|
||||
return [partial(_make_one, i) for i in range(n_envs)]
|
||||
@@ -423,6 +569,7 @@ def create_robotwin_envs(
|
||||
observation_height: int = DEFAULT_CAMERA_H,
|
||||
observation_width: int = DEFAULT_CAMERA_W,
|
||||
episode_length: int = DEFAULT_EPISODE_LENGTH,
|
||||
action_mode: str = "joint",
|
||||
) -> dict[str, dict[int, Any]]:
|
||||
"""Create vectorized RoboTwin 2.0 environments.
|
||||
|
||||
@@ -473,6 +620,7 @@ def create_robotwin_envs(
|
||||
observation_height=observation_height,
|
||||
observation_width=observation_width,
|
||||
episode_length=episode_length,
|
||||
action_mode=action_mode,
|
||||
)
|
||||
if is_async:
|
||||
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space, cached_metadata)
|
||||
|
||||
@@ -83,6 +83,28 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("constant_with_warmup")
|
||||
@dataclass
|
||||
class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig):
|
||||
"""Linear warmup followed by a constant learning rate.
|
||||
|
||||
Mirrors the ``warmup_constant_lambda`` used by LingBot-VA (upstream ``wan_va/train.py``):
|
||||
the LR ramps linearly from 0 to the peak over ``num_warmup_steps`` steps, then stays flat.
|
||||
"""
|
||||
|
||||
num_warmup_steps: int = 1000
|
||||
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
warmup_steps = self.num_warmup_steps or 0
|
||||
|
||||
def lr_lambda(current_step):
|
||||
if current_step < warmup_steps:
|
||||
return float(current_step) / float(max(1, warmup_steps))
|
||||
return 1.0
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
|
||||
@dataclass
|
||||
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
||||
|
||||
@@ -20,6 +20,7 @@ from .eo1.configuration_eo1 import EO1Config as EO1Config
|
||||
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
|
||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||
from .lingbot_va.configuration_lingbot_va import LingBotVAConfig as LingBotVAConfig
|
||||
from .molmoact2.configuration_molmoact2 import MolmoAct2Config as MolmoAct2Config
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
@@ -44,6 +45,7 @@ __all__ = [
|
||||
"EO1Config",
|
||||
"GaussianActorConfig",
|
||||
"GrootConfig",
|
||||
"LingBotVAConfig",
|
||||
"MolmoAct2Config",
|
||||
"MultiTaskDiTConfig",
|
||||
"PI0Config",
|
||||
|
||||
@@ -49,6 +49,7 @@ from .diffusion.configuration_diffusion import DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from .groot.configuration_groot import GrootConfig
|
||||
from .lingbot_va.configuration_lingbot_va import LingBotVAConfig
|
||||
from .molmoact2.configuration_molmoact2 import MolmoAct2Config
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config
|
||||
@@ -57,6 +58,7 @@ from .pretrained import PreTrainedPolicy
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from .utils import validate_visual_features_consistency
|
||||
from .vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig
|
||||
from .wall_x.configuration_wall_x import WallXConfig
|
||||
from .xvla.configuration_xvla import XVLAConfig
|
||||
@@ -157,6 +159,14 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .molmoact2.modeling_molmoact2 import MolmoAct2Policy
|
||||
|
||||
return MolmoAct2Policy
|
||||
elif name == "vla_jepa":
|
||||
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
|
||||
|
||||
return VLAJEPAPolicy
|
||||
elif name == "lingbot_va":
|
||||
from .lingbot_va.modeling_lingbot_va import LingBotVAPolicy
|
||||
|
||||
return LingBotVAPolicy
|
||||
else:
|
||||
try:
|
||||
return _get_policy_cls_from_policy_name(name=name)
|
||||
@@ -211,6 +221,10 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return EO1Config(**kwargs)
|
||||
elif policy_type == "molmoact2":
|
||||
return MolmoAct2Config(**kwargs)
|
||||
elif policy_type == "vla_jepa":
|
||||
return VLAJEPAConfig(**kwargs)
|
||||
elif policy_type == "lingbot_va":
|
||||
return LingBotVAConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||
@@ -415,6 +429,7 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, EO1Config):
|
||||
from .eo1.processor_eo1 import make_eo1_pre_post_processors
|
||||
|
||||
@@ -432,6 +447,22 @@ def make_pre_post_processors(
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, VLAJEPAConfig):
|
||||
from .vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||
|
||||
processors = make_vla_jepa_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, LingBotVAConfig):
|
||||
from .lingbot_va.processor_lingbot_va import make_lingbot_va_pre_post_processors
|
||||
|
||||
processors = make_lingbot_va_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
processors = _make_processors_from_policy_config(
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
../../../../docs/source/lingbot_va.mdx
|
||||
@@ -0,0 +1,33 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# NOTE: ``LingBotVAPolicy`` (and the Wan transformer it owns) imports ``diffusers`` as a
|
||||
# hard dependency at class-definition time (it subclasses diffusers' ModelMixin/ConfigMixin).
|
||||
# To keep base ``import lerobot`` working without the optional ``lingbot_va`` extra, the
|
||||
# policy is exposed lazily via module ``__getattr__`` — the heavy import only happens when
|
||||
# ``LingBotVAPolicy`` is actually accessed (mirroring the lazy import in policies/factory.py).
|
||||
from .configuration_lingbot_va import LingBotVAConfig
|
||||
from .processor_lingbot_va import make_lingbot_va_pre_post_processors
|
||||
|
||||
__all__ = ["LingBotVAConfig", "LingBotVAPolicy", "make_lingbot_va_pre_post_processors"]
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
if name == "LingBotVAPolicy":
|
||||
from .modeling_lingbot_va import LingBotVAPolicy
|
||||
|
||||
return LingBotVAPolicy
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
@@ -0,0 +1,168 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Configuration for the LingBot-VA policy.
|
||||
|
||||
LingBot-VA is an autoregressive video-action world-model policy built on the Wan2.2
|
||||
video-diffusion stack. It interleaves prediction of future video latents and robot
|
||||
actions in a single dual-stream transformer. See ``docs/source/lingbot_va.mdx`` and the
|
||||
upstream repository (https://github.com/Robbyant/lingbot-va).
|
||||
|
||||
Defaults below match the upstream LIBERO configuration (``wan_va/configs/va_libero_cfg.py``)
|
||||
and the ``transformer/config.json`` of the released checkpoints.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("lingbot_va")
|
||||
@dataclass
|
||||
class LingBotVAConfig(PreTrainedConfig):
|
||||
"""Configuration for the native LingBot-VA policy integration in LeRobot."""
|
||||
|
||||
# Wan transformer architecture
|
||||
patch_size: tuple[int, int, int] = (1, 2, 2)
|
||||
num_attention_heads: int = 24
|
||||
attention_head_dim: int = 128
|
||||
in_channels: int = 48
|
||||
out_channels: int = 48
|
||||
action_dim: int = 30
|
||||
text_dim: int = 4096
|
||||
freq_dim: int = 256
|
||||
ffn_dim: int = 14336
|
||||
num_layers: int = 30
|
||||
cross_attn_norm: bool = True
|
||||
eps: float = 1e-6
|
||||
rope_max_seq_len: int = 1024
|
||||
# "flex" = training only (needs recent torch); inference uses "torch" SDPA or "flashattn".
|
||||
attn_mode: str = "torch"
|
||||
|
||||
# Frozen sub-models (VAE + UMT5 text encoder + tokenizer)
|
||||
# ~20 GB of frozen weights, NOT bundled in the checkpoint; lazily pulled from this HF repo /
|
||||
# local dir (must hold diffusers-style ``vae/``, ``text_encoder/``, ``tokenizer/`` sub-folders).
|
||||
wan_pretrained_path: str = "robbyant/lingbot-va-base"
|
||||
dtype: str = "bfloat16" # transformer / VAE / text-encoder dtype: "bfloat16", "float16", "float32"
|
||||
# Frozen UMT5-XXL encoder device; "cpu" frees ~11 GB VRAM (it runs once per episode).
|
||||
text_encoder_device: str = "cpu"
|
||||
|
||||
# Observation cameras (order matters: latents are concatenated on width; LIBERO defaults)
|
||||
obs_cam_keys: list[str] = field(
|
||||
default_factory=lambda: ["observation.images.image", "observation.images.image2"]
|
||||
)
|
||||
# Undo the LIBERO env processor's extra horizontal flip to match the model's training orientation.
|
||||
image_hflip: bool = False
|
||||
# Camera latent layout: "width_concat" (cameras concatenated on width; LIBERO) or
|
||||
# "robotwin_tshape" (full-res head + half-res wrists in a "T"; RoboTwin).
|
||||
camera_layout: str = "width_concat"
|
||||
|
||||
# Inference hyperparameters (LIBERO defaults)
|
||||
n_obs_steps: int = 1
|
||||
height: int = 128
|
||||
width: int = 128
|
||||
action_per_frame: int = 4
|
||||
frame_chunk_size: int = 4
|
||||
attn_window: int = 30
|
||||
num_inference_steps: int = 20
|
||||
video_exec_step: int = -1
|
||||
action_num_inference_steps: int = 50
|
||||
guidance_scale: float = 5.0
|
||||
action_guidance_scale: float = 1.0
|
||||
snr_shift: float = 5.0
|
||||
action_snr_shift: float = 0.05
|
||||
max_sequence_length: int = 512 # UMT5 prompt length
|
||||
|
||||
# Subset of the 30-d action space used by the benchmark (LIBERO = 7-DoF). The action
|
||||
# (un)normalization quantiles live in the checkpoint's ``policy_postprocessor.json``, not here.
|
||||
used_action_channel_ids: list[int] = field(default_factory=lambda: list(range(7)))
|
||||
|
||||
# Opt-in: VAE-decode predicted video latents to ``self.last_predicted_frames`` for saving MP4s.
|
||||
save_predicted_video: bool = False
|
||||
|
||||
# Normalization: IDENTITY here; images are scaled + VAE-encoded and actions are
|
||||
# quantile-(un)normalized inside the policy / dedicated processor steps.
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
)
|
||||
|
||||
# Optimizer / scheduler (training; AdamW + warmup-constant per upstream train.py)
|
||||
optimizer_lr: float = 1e-5
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-4
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
scheduler_warmup_steps: int = 1000
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.attn_mode not in ("torch", "flashattn", "flex"):
|
||||
raise ValueError(f"attn_mode must be one of 'torch', 'flashattn', 'flex'; got {self.attn_mode!r}")
|
||||
|
||||
@property
|
||||
def chunk_size(self) -> int:
|
||||
"""Number of single-step actions produced per autoregressive chunk."""
|
||||
return self.frame_chunk_size * self.action_per_frame
|
||||
|
||||
@property
|
||||
def n_action_steps(self) -> int:
|
||||
"""Number of actions executed before refilling (the whole chunk)."""
|
||||
return self.chunk_size
|
||||
|
||||
def validate_features(self) -> None:
|
||||
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
|
||||
if not image_features:
|
||||
raise ValueError(
|
||||
"LingBot-VA requires at least one visual input feature. "
|
||||
"No features of type FeatureType.VISUAL found in input_features."
|
||||
)
|
||||
if ACTION not in self.output_features:
|
||||
self.output_features[ACTION] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(len(self.used_action_channel_ids),)
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
||||
# Upstream uses a linear warmup followed by a constant LR (warmup_constant_lambda).
|
||||
from lerobot.optim.schedulers import ConstantWithWarmupSchedulerConfig
|
||||
|
||||
return ConstantWithWarmupSchedulerConfig(num_warmup_steps=self.scheduler_warmup_steps)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,87 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Pre/post-processor pipelines for the LingBot-VA policy.
|
||||
|
||||
The preprocessor passes inputs through (IDENTITY) and the postprocessor maps the policy's
|
||||
``[-1, 1]`` actions back to physical units with the built-in ``UnnormalizerProcessorStep``
|
||||
(QUANTILES) using per-channel q01/q99 restored from the checkpoint.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import (
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
from .configuration_lingbot_va import LingBotVAConfig
|
||||
|
||||
|
||||
def make_lingbot_va_pre_post_processors(
|
||||
config: LingBotVAConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Build the pre/post processor pipelines for LingBot-VA."""
|
||||
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
# Unnormalize actions from [-1, 1] to physical units (QUANTILES) using q01/q99 restored from the checkpoint.
|
||||
output_steps: list[ProcessorStep] = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map={FeatureType.ACTION: NormalizationMode.QUANTILES},
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
+1
@@ -0,0 +1 @@
|
||||
../../../../docs/source/policy_vla_jepa_README.md
|
||||
@@ -0,0 +1,23 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
from .modeling_vla_jepa import VLAJEPAPolicy
|
||||
from .processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||
|
||||
__all__ = [
|
||||
"VLAJEPAConfig",
|
||||
"VLAJEPAPolicy",
|
||||
"make_vla_jepa_pre_post_processors",
|
||||
]
|
||||
@@ -0,0 +1,337 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import nn
|
||||
from torch.distributions import Beta
|
||||
|
||||
from lerobot.utils.import_utils import _diffusers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _diffusers_available:
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from diffusers.models.attention import Attention, FeedForward
|
||||
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
||||
else:
|
||||
|
||||
class ModelMixin: # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
class ConfigMixin: # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
register_to_config = lambda f: f # noqa: E731
|
||||
Attention = FeedForward = TimestepEmbedding = Timesteps = None
|
||||
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
timesteps = timesteps.float()
|
||||
batch_size, seq_len = timesteps.shape
|
||||
half_dim = self.embedding_dim // 2
|
||||
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device)
|
||||
exponent = exponent * (torch.log(torch.tensor(10000.0, device=timesteps.device)) / max(half_dim, 1))
|
||||
freqs = timesteps.unsqueeze(-1) * exponent.exp()
|
||||
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1).view(batch_size, seq_len, -1)
|
||||
|
||||
|
||||
class ActionEncoder(nn.Module):
|
||||
def __init__(self, action_dim: int, hidden_size: int):
|
||||
super().__init__()
|
||||
self.layer1 = nn.Linear(action_dim, hidden_size)
|
||||
self.layer2 = nn.Linear(hidden_size * 2, hidden_size)
|
||||
self.layer3 = nn.Linear(hidden_size, hidden_size)
|
||||
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
|
||||
|
||||
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, seq_len, _ = actions.shape
|
||||
if timesteps.ndim != 1 or timesteps.shape[0] != batch_size:
|
||||
raise ValueError("timesteps must have shape [batch_size].")
|
||||
timesteps = timesteps.unsqueeze(1).expand(-1, seq_len)
|
||||
action_emb = self.layer1(actions)
|
||||
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
|
||||
return self.layer3(F.silu(self.layer2(torch.cat([action_emb, time_emb], dim=-1))))
|
||||
|
||||
|
||||
class TimestepEncoder(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
require_package("diffusers", extra="vla_jepa")
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
projected = self.time_proj(timesteps).to(dtype=next(self.parameters()).dtype)
|
||||
return self.timestep_embedder(projected)
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
|
||||
self.norm = nn.LayerNorm(embedding_dim, eps=1e-5, elementwise_affine=False)
|
||||
self.silu = nn.SiLU()
|
||||
|
||||
def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
|
||||
scale, shift = self.linear(self.silu(temb)).chunk(2, dim=-1)
|
||||
return self.norm(x) * (1 + scale[:, None]) + shift[:, None]
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout: float,
|
||||
cross_attention_dim: int,
|
||||
is_cross_attention: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_cross_attention = is_cross_attention
|
||||
self.norm1 = AdaLayerNorm(dim)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=True,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
out_bias=True,
|
||||
)
|
||||
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
|
||||
self.ff = FeedForward(dim, dropout=dropout, activation_fn="gelu-approximate", final_dropout=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor | None,
|
||||
temb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
attn_input = self.norm1(hidden_states, temb)
|
||||
attention_context = encoder_hidden_states if self.is_cross_attention else None
|
||||
hidden_states = hidden_states + self.attn1(attn_input, encoder_hidden_states=attention_context)
|
||||
hidden_states = hidden_states + self.ff(self.norm2(hidden_states))
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DiT(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = False
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
output_dim: int,
|
||||
num_layers: int,
|
||||
dropout: float,
|
||||
cross_attention_dim: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.timestep_encoder = TimestepEncoder(self.inner_dim)
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim if layer_idx % 2 == 0 else self.inner_dim,
|
||||
is_cross_attention=layer_idx % 2 == 0,
|
||||
)
|
||||
for layer_idx in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, eps=1e-6, elementwise_affine=False)
|
||||
self.proj_out_1 = nn.Linear(self.inner_dim, self.inner_dim * 2)
|
||||
self.proj_out_2 = nn.Linear(self.inner_dim, output_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
temb = self.timestep_encoder(timestep)
|
||||
x = hidden_states
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, encoder_hidden_states=encoder_hidden_states, temb=temb)
|
||||
shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1)
|
||||
x = self.norm_out(x) * (1 + scale[:, None]) + shift[:, None]
|
||||
return self.proj_out_2(x)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionModelPreset:
|
||||
hidden_size: int
|
||||
attention_head_dim: int
|
||||
num_attention_heads: int
|
||||
|
||||
|
||||
DIT_PRESETS = {
|
||||
"DiT-B": ActionModelPreset(hidden_size=768, attention_head_dim=64, num_attention_heads=12),
|
||||
"DiT-L": ActionModelPreset(hidden_size=1536, attention_head_dim=48, num_attention_heads=32),
|
||||
"DiT-test": ActionModelPreset(hidden_size=16, attention_head_dim=8, num_attention_heads=2),
|
||||
}
|
||||
|
||||
|
||||
class VLAJEPAActionHead(nn.Module):
|
||||
def __init__(self, config: VLAJEPAConfig, cross_attention_dim: int) -> None:
|
||||
super().__init__()
|
||||
preset = DIT_PRESETS[config.action_model_type]
|
||||
self.config = config
|
||||
num_heads = config.action_num_heads or preset.num_attention_heads
|
||||
head_dim = config.action_attention_head_dim or preset.attention_head_dim
|
||||
inner_dim = num_heads * head_dim # e.g. DiT-B: 12 × 64 = 768
|
||||
|
||||
self.input_embedding_dim = inner_dim
|
||||
self.action_horizon = config.chunk_size
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
|
||||
hidden_size = config.action_hidden_size
|
||||
self.model = DiT(
|
||||
num_attention_heads=num_heads,
|
||||
attention_head_dim=head_dim,
|
||||
output_dim=hidden_size,
|
||||
num_layers=config.action_num_layers,
|
||||
dropout=config.action_dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
self.action_encoder = ActionEncoder(config.action_dim, inner_dim)
|
||||
self.action_decoder = nn.Sequential(
|
||||
OrderedDict(
|
||||
[
|
||||
("layer1", nn.Linear(hidden_size, hidden_size)),
|
||||
("relu", nn.ReLU()),
|
||||
("layer2", nn.Linear(hidden_size, config.action_dim)),
|
||||
]
|
||||
)
|
||||
)
|
||||
self.state_encoder = (
|
||||
nn.Sequential(
|
||||
OrderedDict(
|
||||
[
|
||||
("layer1", nn.Linear(config.state_dim, hidden_size)),
|
||||
("relu", nn.ReLU()),
|
||||
("layer2", nn.Linear(hidden_size, inner_dim)),
|
||||
]
|
||||
)
|
||||
)
|
||||
if config.state_dim > 0
|
||||
else None
|
||||
)
|
||||
self.future_tokens = nn.Embedding(config.num_embodied_action_tokens_per_instruction, inner_dim)
|
||||
self.position_embedding = nn.Embedding(
|
||||
max(1024, config.chunk_size + config.num_action_tokens_per_timestep + 4),
|
||||
inner_dim,
|
||||
)
|
||||
self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta)
|
||||
|
||||
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
sample = self.beta_dist.sample([batch_size]).to(device=device, dtype=dtype)
|
||||
return (self.config.action_noise_s - sample) / self.config.action_noise_s
|
||||
|
||||
def _build_inputs(
|
||||
self,
|
||||
conditioning_tokens: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
state: torch.Tensor | None,
|
||||
timesteps: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
action_features = self.action_encoder(actions, timesteps)
|
||||
pos_ids = torch.arange(action_features.shape[1], device=actions.device)
|
||||
action_features = action_features + self.position_embedding(pos_ids)[None]
|
||||
|
||||
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(actions.shape[0], -1, -1)
|
||||
seq = [future_tokens, action_features]
|
||||
if state is not None and self.state_encoder is not None:
|
||||
if state.ndim == 2:
|
||||
state = state.unsqueeze(1)
|
||||
seq.insert(0, self.state_encoder(state))
|
||||
return torch.cat(seq, dim=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
conditioning_tokens: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
state: torch.Tensor | None = None,
|
||||
action_is_pad: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
noise = torch.randn_like(actions)
|
||||
t = self.sample_time(actions.shape[0], actions.device, actions.dtype)
|
||||
noisy_actions = (1 - t[:, None, None]) * noise + t[:, None, None] * actions
|
||||
velocity = actions - noise
|
||||
t_discretized = (t * self.config.action_num_timestep_buckets).long()
|
||||
|
||||
hidden_states = self._build_inputs(conditioning_tokens, noisy_actions, state, t_discretized)
|
||||
pred = self.model(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=conditioning_tokens,
|
||||
timestep=t_discretized,
|
||||
)
|
||||
pred_actions = self.action_decoder(pred[:, -actions.shape[1] :])
|
||||
|
||||
if action_is_pad is None:
|
||||
action_is_pad = torch.zeros(actions.shape[:2], dtype=torch.bool, device=actions.device)
|
||||
|
||||
loss = F.mse_loss(pred_actions, velocity, reduction="none") # [B, T, action_dim]
|
||||
valid_mask = ~action_is_pad.unsqueeze(-1) # [B, T, 1]
|
||||
num_valid = valid_mask.sum() * loss.shape[-1]
|
||||
return (loss * valid_mask).sum() / num_valid.clamp_min(1)
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action(
|
||||
self,
|
||||
conditioning_tokens: torch.Tensor,
|
||||
state: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size = conditioning_tokens.shape[0]
|
||||
actions = torch.randn(
|
||||
batch_size,
|
||||
self.action_horizon,
|
||||
self.config.action_dim,
|
||||
dtype=conditioning_tokens.dtype,
|
||||
device=conditioning_tokens.device,
|
||||
)
|
||||
dt = 1.0 / max(self.num_inference_timesteps, 1)
|
||||
for step in range(self.num_inference_timesteps):
|
||||
t_cont = step / float(max(self.num_inference_timesteps, 1))
|
||||
t_value = int(t_cont * self.config.action_num_timestep_buckets)
|
||||
timesteps = torch.full(
|
||||
(batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long
|
||||
)
|
||||
hidden_states = self._build_inputs(conditioning_tokens, actions, state, timesteps)
|
||||
pred = self.model(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=conditioning_tokens,
|
||||
timestep=timesteps,
|
||||
)
|
||||
pred_velocity = self.action_decoder(pred[:, -self.action_horizon :])
|
||||
actions = actions + dt * pred_velocity
|
||||
return actions
|
||||
@@ -0,0 +1,154 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("vla_jepa")
|
||||
@dataclass
|
||||
class VLAJEPAConfig(PreTrainedConfig):
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 7
|
||||
n_action_steps: int = 7
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct"
|
||||
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
|
||||
freeze_qwen: bool = False
|
||||
enable_world_model: bool = True
|
||||
# Enables cross-embodiment transfer: when fine-tuning a pretrained model on a robot with a
|
||||
# different action or state dimensionality, the input/output projection layers must be
|
||||
# re-initialised from scratch while the rest of the network keeps its pretrained weights.
|
||||
# List the key prefixes that are allowed to have shape mismatches; anything else raises an error.
|
||||
# e.g. ["model.action_model.action_encoder", "model.action_model.state_encoder"]
|
||||
reinit_modules: list[str] | None = None
|
||||
|
||||
tokenizer_padding_side: str = "left"
|
||||
prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."
|
||||
special_action_token: str = "<|action_{}|>"
|
||||
embodied_action_token: str = "<|embodied_action|>"
|
||||
|
||||
action_dim: int = 7
|
||||
state_dim: int = 8
|
||||
|
||||
num_action_tokens_per_timestep: int = 8
|
||||
num_embodied_action_tokens_per_instruction: int = 32
|
||||
num_inference_timesteps: int = 4
|
||||
|
||||
action_hidden_size: int = 1024
|
||||
action_model_type: str = "DiT-B"
|
||||
action_num_layers: int = 16
|
||||
action_num_heads: int | None = None
|
||||
action_attention_head_dim: int | None = None
|
||||
action_dropout: float = 0.2
|
||||
action_num_timestep_buckets: int = 1000
|
||||
action_noise_beta_alpha: float = 1.5
|
||||
action_noise_beta_beta: float = 1.0
|
||||
action_noise_s: float = 0.999
|
||||
num_target_vision_tokens: int = 32
|
||||
action_max_seq_len: int = 1024
|
||||
|
||||
# total video frames loaded per sample
|
||||
num_video_frames: int = 8
|
||||
predictor_depth: int = 12
|
||||
predictor_num_heads: int = 8
|
||||
predictor_mlp_ratio: float = 4.0
|
||||
predictor_dropout: float = 0.0
|
||||
world_model_loss_weight: float = 0.1
|
||||
jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256)
|
||||
repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style)
|
||||
|
||||
resize_images_to: tuple[int, int] | None = None
|
||||
binarize_gripper_action: bool = True
|
||||
pre_snap_gripper_action: bool = True
|
||||
clip_normalized_actions: bool = True
|
||||
gripper_dim: int = 6
|
||||
gripper_threshold: float = 0.5
|
||||
torch_dtype: str = "bfloat16"
|
||||
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-10
|
||||
optimizer_grad_clip_norm: float = 10.0
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
if self.freeze_qwen and self.enable_world_model:
|
||||
# freezing qwen backbone makes world model training irrelevant since no grad flows
|
||||
self.enable_world_model = False
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError("`n_action_steps` must be <= `chunk_size`.")
|
||||
if self.num_video_frames < 2 * self.jepa_tubelet_size:
|
||||
raise ValueError(
|
||||
f"`video_horizon` ({self.num_video_frames}) must be >= 2 * `jepa_tubelet_size` "
|
||||
f"({self.jepa_tubelet_size}) to have at least one context and one GT temporal position."
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.image_features:
|
||||
raise ValueError("VLAJEPA requires at least one visual input feature.")
|
||||
if self.action_feature is None:
|
||||
raise ValueError("VLAJEPA requires an action output feature.")
|
||||
self.action_dim = self.action_feature.shape[0]
|
||||
if self.robot_state_feature is not None:
|
||||
self.state_dim = self.robot_state_feature.shape[0]
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int]:
|
||||
# load video_horizon frames starting from current timestep: [t, t+1, ..., t+video_horizon-1]
|
||||
# matches original repo's observation_indices=list(range(video_horizon))
|
||||
return list(range(self.num_video_frames))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -0,0 +1,629 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from PIL import Image
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.utils import populate_queues
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoModel, AutoVideoProcessor
|
||||
else:
|
||||
AutoModel = None
|
||||
AutoVideoProcessor = None
|
||||
|
||||
from .action_head import VLAJEPAActionHead
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
from .qwen_interface import Qwen3VLInterface
|
||||
from .world_model import ActionConditionedVideoPredictor
|
||||
|
||||
# ============================================================================
|
||||
# Native VLA-JEPA Model - follows original starVLA VLA_JEPA.py implementation
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class VLAJEPAModel(nn.Module):
|
||||
"""
|
||||
Native VLA-JEPA model following the original starVLA VLA_JEPA.py.
|
||||
|
||||
Components:
|
||||
- Qwen3-VL: vision-language backbone for fused embeddings
|
||||
- DiT-B: flow-matching action head for future action prediction
|
||||
- V-JEPA: world model for video frame prediction
|
||||
|
||||
Input: List[dict] native format (same as original starVLA)
|
||||
- "image": List[PIL.Image] (multi-view images)
|
||||
- "video": np.ndarray [V, T, H, W, 3]
|
||||
- "lang": str (task instruction)
|
||||
- "action": np.ndarray [T, action_dim] (optional, training only)
|
||||
- "state": np.ndarray [1, state_dim] (optional)
|
||||
"""
|
||||
|
||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||
super().__init__()
|
||||
require_package("transformers", extra="vla_jepa")
|
||||
self.config = config
|
||||
|
||||
# Vision-language backbone
|
||||
self.qwen = Qwen3VLInterface(config)
|
||||
|
||||
# Tokenizer expansion for special action tokens
|
||||
self.action_tokens, self.action_token_ids, self.embodied_action_token_id = (
|
||||
self.qwen.expand_tokenizer()
|
||||
)
|
||||
|
||||
# Action head (flow-matching DiT)
|
||||
self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size)
|
||||
|
||||
# JEPA world model components
|
||||
if config.enable_world_model:
|
||||
self.video_encoder = AutoModel.from_pretrained(
|
||||
config.jepa_encoder_name,
|
||||
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
|
||||
)
|
||||
self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name)
|
||||
num_views = config.jepa_tubelet_size
|
||||
tubelet_size = self.video_encoder.config.tubelet_size
|
||||
image_size = getattr(self.video_encoder.config, "image_size", None)
|
||||
if image_size is None:
|
||||
first_image_shape = next(iter(config.image_features.values())).shape
|
||||
image_size = first_image_shape[-1]
|
||||
self.video_predictor = ActionConditionedVideoPredictor(
|
||||
num_frames=config.num_video_frames // tubelet_size,
|
||||
img_size=(image_size, image_size),
|
||||
patch_size=16,
|
||||
tubelet_size=1,
|
||||
embed_dim=self.video_encoder.config.hidden_size * num_views,
|
||||
action_embed_dim=self.qwen.model.config.hidden_size,
|
||||
predictor_embed_dim=self.video_encoder.config.hidden_size,
|
||||
depth=config.predictor_depth,
|
||||
num_heads=config.predictor_num_heads,
|
||||
mlp_ratio=config.predictor_mlp_ratio,
|
||||
num_action_tokens_per_step=config.num_action_tokens_per_timestep,
|
||||
)
|
||||
else:
|
||||
self.video_encoder = None
|
||||
self.video_processor = None
|
||||
self.video_predictor = None
|
||||
|
||||
if config.freeze_qwen:
|
||||
self.qwen.requires_grad_(False)
|
||||
|
||||
# Build prompt placeholders.
|
||||
# Use the encoder's actual tubelet_size when available (world model enabled),
|
||||
# otherwise fall back to config.
|
||||
_tubelet_size = (
|
||||
self.video_encoder.config.tubelet_size
|
||||
if config.enable_world_model
|
||||
else self.config.jepa_tubelet_size
|
||||
)
|
||||
num_action_prompt_steps = self.config.num_video_frames // _tubelet_size - 1
|
||||
self.replace_prompt = "".join(
|
||||
token * self.config.num_action_tokens_per_timestep
|
||||
for token in self.action_tokens[:num_action_prompt_steps]
|
||||
)
|
||||
self.embodied_replace_prompt = (
|
||||
self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
|
||||
)
|
||||
|
||||
def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Return the last decoder hidden state before the final RMSNorm.
|
||||
|
||||
The model was trained with the output of the last transformer block BEFORE
|
||||
the final RMSNorm. In transformers 5.x, `hidden_states[-1]` from
|
||||
`output_hidden_states=True` is post-norm (tied to `last_hidden_state` via
|
||||
`@capture_outputs`). A forward hook on `language_model.layers[-1]` recovers
|
||||
the correct pre-RMSNorm state, matching the training-time representation.
|
||||
"""
|
||||
captured: list[torch.Tensor] = []
|
||||
|
||||
def _hook(module, input, output):
|
||||
h = output[0] if isinstance(output, tuple) else output
|
||||
captured.append(h)
|
||||
|
||||
last_layer = self.qwen.model.model.language_model.layers[-1]
|
||||
handle = last_layer.register_forward_hook(_hook)
|
||||
try:
|
||||
self.qwen.model(
|
||||
**qwen_inputs,
|
||||
output_hidden_states=False,
|
||||
output_attentions=False,
|
||||
return_dict=True,
|
||||
)
|
||||
finally:
|
||||
handle.remove()
|
||||
|
||||
return captured[0] # [B, seq_len, H]
|
||||
|
||||
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
|
||||
|
||||
def forward(self, examples: list[dict]) -> dict[str, Tensor]:
|
||||
"""
|
||||
Native forward pass following original starVLA VLA_JEPA.forward.
|
||||
|
||||
Args:
|
||||
examples: List of per-sample dicts with keys:
|
||||
"image" : List[PIL.Image] — multi-view images
|
||||
"video" : np.ndarray [V, T, H, W, 3]
|
||||
"lang" : str — task instruction
|
||||
"action" : np.ndarray [T, action_dim] (optional)
|
||||
"state" : np.ndarray [1, state_dim] (optional)
|
||||
|
||||
Returns:
|
||||
dict with "action_loss" and "wm_loss" keys (scalar Tensors).
|
||||
"""
|
||||
# Unpack native format (same pattern as original VLA_JEPA.py)
|
||||
batch_images = [ex["image"] for ex in examples] # List[List[PIL.Image]]
|
||||
batch_videos = [ex["video"] for ex in examples] # List[np.ndarray]
|
||||
instructions = [ex["lang"] for ex in examples] # List[str]
|
||||
has_action = "action" in examples[0] and examples[0]["action"] is not None
|
||||
actions = [ex["action"] for ex in examples] if has_action else None
|
||||
has_state = "state" in examples[0] and examples[0]["state"] is not None
|
||||
state = [ex["state"] for ex in examples] if has_state else None
|
||||
action_is_pad = (
|
||||
[ex["action_is_pad"] for ex in examples]
|
||||
if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
|
||||
batch_videos = np.stack(batch_videos)
|
||||
batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4) # [B, V, T, 3, H, W]
|
||||
|
||||
# Adjust number of views for the world model:
|
||||
# - fewer views than expected: duplicate the first view to fill up
|
||||
# - more views than expected: keep only the first num_views_world_model views
|
||||
num_views_world_model = self.config.jepa_tubelet_size
|
||||
if batch_videos.shape[1] < num_views_world_model:
|
||||
num_missing_views = num_views_world_model - batch_videos.shape[1]
|
||||
first_view = np.repeat(batch_videos[:, :1], num_missing_views, axis=1)
|
||||
batch_videos = np.concatenate([batch_videos, first_view], axis=1)
|
||||
elif batch_videos.shape[1] > num_views_world_model:
|
||||
batch_videos = batch_videos[:, :num_views_world_model]
|
||||
|
||||
# ---- Step 1: QwenVL encode (same as original) ----
|
||||
qwen_inputs = self.qwen.build_inputs(
|
||||
images=batch_images,
|
||||
instructions=instructions,
|
||||
action_prompt=self.replace_prompt,
|
||||
embodied_prompt=self.embodied_replace_prompt,
|
||||
)
|
||||
|
||||
# Locate embodied-action tokens (always needed for action head)
|
||||
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
||||
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
||||
|
||||
# Locate action tokens (only needed for world model predictor)
|
||||
if self.config.enable_world_model:
|
||||
action_mask = torch.isin(
|
||||
qwen_inputs["input_ids"],
|
||||
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
|
||||
)
|
||||
action_indices = action_mask.nonzero(as_tuple=True)
|
||||
|
||||
device_type = next(self.parameters()).device.type
|
||||
|
||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
||||
b, _, h = last_hidden.shape
|
||||
|
||||
if self.config.enable_world_model:
|
||||
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
|
||||
|
||||
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
||||
|
||||
# ---- Step 2+3: JEPA Encoder + Predictor ----
|
||||
device_wm = last_hidden.device
|
||||
if not self.config.enable_world_model:
|
||||
wm_loss = torch.tensor(0.0, device=device_wm)
|
||||
else:
|
||||
b, v, t_frames, c, h_img, w_img = batch_videos.shape
|
||||
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
|
||||
|
||||
video_pixels = self.video_processor(videos=list(batch_videos_flat), return_tensors="pt")[
|
||||
"pixel_values_videos"
|
||||
].to(self.video_encoder.device) # [B*V, T, C, H, W]
|
||||
|
||||
with torch.no_grad():
|
||||
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
|
||||
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
|
||||
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
|
||||
|
||||
tubelet_size = self.video_encoder.config.tubelet_size
|
||||
device_wm = video_embeddings.device
|
||||
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
|
||||
t_enc_total = self.config.num_video_frames // tubelet_size
|
||||
|
||||
if t_enc_total < 2:
|
||||
wm_loss = torch.tensor(0.0, device=device_wm)
|
||||
else:
|
||||
# Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232):
|
||||
# input_states: positions 0..T-2, gt_states: positions 1..T-1
|
||||
t_enc_ctx = t_enc_total - 1
|
||||
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
|
||||
|
||||
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
|
||||
gt_states = video_embeddings[:, tokens_per_frame:, :]
|
||||
|
||||
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
|
||||
if action_tokens.shape[1] < expected_actions:
|
||||
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
|
||||
action_tokens = torch.cat([action_tokens, pad], dim=1)
|
||||
|
||||
predicted_states = self.video_predictor(
|
||||
input_states.float(),
|
||||
action_tokens[:, :expected_actions].float(),
|
||||
)
|
||||
|
||||
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
|
||||
|
||||
if not has_action:
|
||||
return {"wm_loss": wm_loss}
|
||||
|
||||
# ---- Step 4: Action Head ----
|
||||
with torch.autocast(device_type=device_type, dtype=torch.float32):
|
||||
actions_tensor = torch.tensor(
|
||||
np.array(actions), device=last_hidden.device, dtype=torch.float32
|
||||
) # [B, T_full, action_dim]
|
||||
action_horizon = self.config.chunk_size
|
||||
actions_target = actions_tensor[:, -action_horizon:, :]
|
||||
|
||||
state_tensor = None
|
||||
if state is not None:
|
||||
state_tensor = torch.tensor(
|
||||
np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
|
||||
) # [B, 1, state_dim]
|
||||
|
||||
repeated_diffusion_steps = self.config.repeated_diffusion_steps
|
||||
actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1)
|
||||
embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1)
|
||||
if state_tensor is not None:
|
||||
state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1)
|
||||
|
||||
action_is_pad_rep = None
|
||||
if action_is_pad is not None:
|
||||
pad_tensor = torch.stack(
|
||||
[
|
||||
p.to(actions_target.device)
|
||||
if isinstance(p, Tensor)
|
||||
else torch.tensor(p, device=actions_target.device)
|
||||
for p in action_is_pad
|
||||
]
|
||||
) # [B, T_full]
|
||||
pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon]
|
||||
action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon]
|
||||
|
||||
action_loss = self.action_model(
|
||||
embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep
|
||||
)
|
||||
|
||||
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
|
||||
|
||||
# ---- Native predict_action (follows original VLA_JEPA.predict_action) ----
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action(
|
||||
self,
|
||||
batch_images: list[list[Image.Image]],
|
||||
instructions: list[str],
|
||||
state: np.ndarray | None = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Native action prediction following original VLA_JEPA.predict_action.
|
||||
|
||||
Args:
|
||||
batch_images: List of samples; each is List[PIL.Image] (multi-view).
|
||||
instructions: Task instructions, one per sample.
|
||||
state: Optional [B, state_dim] numpy array.
|
||||
|
||||
Returns:
|
||||
np.ndarray [B, action_horizon, action_dim] — predicted actions.
|
||||
"""
|
||||
if self.config.resize_images_to is not None:
|
||||
height, width = self.config.resize_images_to
|
||||
resampling = getattr(Image, "Resampling", Image).BOX
|
||||
batch_images = [
|
||||
[image.resize((width, height), resample=resampling) for image in sample_images]
|
||||
for sample_images in batch_images
|
||||
]
|
||||
|
||||
qwen_inputs = self.qwen.build_inputs(
|
||||
images=batch_images,
|
||||
instructions=instructions,
|
||||
action_prompt=self.replace_prompt,
|
||||
embodied_prompt=self.embodied_replace_prompt,
|
||||
)
|
||||
|
||||
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
||||
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
||||
|
||||
device_type = next(self.parameters()).device.type
|
||||
|
||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
||||
b, _, h = last_hidden.shape
|
||||
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
||||
|
||||
state_tensor = None
|
||||
if state is not None:
|
||||
state_tensor = torch.from_numpy(np.array(state)).to(
|
||||
device=last_hidden.device, dtype=last_hidden.dtype
|
||||
)
|
||||
|
||||
pred_actions = self.action_model.predict_action(
|
||||
embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None
|
||||
) # [B, action_horizon, action_dim]
|
||||
|
||||
return pred_actions.detach().cpu().numpy()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class VLAJEPAPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
LeRobot adapter for VLA-JEPA.
|
||||
|
||||
Converts LeRobot's standard batch format (dict[str, Tensor]) to the native
|
||||
VLA-JEPA format (List[dict]), calls the native model, and converts outputs
|
||||
back to LeRobot format.
|
||||
"""
|
||||
|
||||
config_class = VLAJEPAConfig
|
||||
name = "vla_jepa"
|
||||
|
||||
def __init__(self, config: VLAJEPAConfig, **kwargs) -> None:
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
if dataset_meta := kwargs.get("dataset_meta"):
|
||||
# cfg.input_features keeps the pretrained model's feature keys (needed for rename_map
|
||||
# compatibility), so validate_features() may have read stale dims from a pretrained
|
||||
# config. Override state_dim/action_dim from the actual dataset being used.
|
||||
ds_features = dataset_meta.features
|
||||
if OBS_STATE in ds_features:
|
||||
config.state_dim = ds_features[OBS_STATE]["shape"][0]
|
||||
if ACTION in ds_features:
|
||||
config.action_dim = ds_features[ACTION]["shape"][0]
|
||||
|
||||
self.model = VLAJEPAModel(config)
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
self._queues = {ACTION: deque(maxlen=self.config.n_action_steps)}
|
||||
|
||||
# ---- Format Conversion: LeRobot → Native ----
|
||||
|
||||
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> list[dict]:
|
||||
"""
|
||||
Convert LeRobot batch format to native VLA-JEPA examples format.
|
||||
|
||||
LeRobot format:
|
||||
batch = {
|
||||
"observation.images.<key>": Tensor [B, C, H, W] or [B, T, C, H, W],
|
||||
"observation.state": Tensor [B, state_dim] or [B, T, state_dim],
|
||||
"action": Tensor [B, chunk_size, action_dim], (training only)
|
||||
"task": str | List[str], (optional instruction)
|
||||
}
|
||||
|
||||
Native format (List[dict]):
|
||||
{
|
||||
"image": List[PIL.Image], # multi-view images per sample
|
||||
"video": np.ndarray [V, T, H, W, 3],
|
||||
"lang": str, # task instruction
|
||||
"action": np.ndarray [T, action_dim], # optional
|
||||
"state": np.ndarray [1, state_dim], # optional
|
||||
}
|
||||
"""
|
||||
# Determine batch size from the first image feature
|
||||
image_keys = list(self.config.image_features.keys())
|
||||
if not image_keys:
|
||||
raise ValueError("VLAJEPA requires at least one image feature.")
|
||||
first_key = image_keys[0]
|
||||
first_tensor = batch[first_key]
|
||||
batch_size = first_tensor.shape[0]
|
||||
|
||||
# ---- Collect images per sample ----
|
||||
# images_per_sample[b][v] = PIL.Image for view v
|
||||
images_per_sample: list[list[Image.Image]] = [[] for _ in range(batch_size)]
|
||||
for key in image_keys:
|
||||
tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W]
|
||||
if tensor.ndim == 5:
|
||||
# observation_delta_indices = [0, 1, ..., num_video_frames-1]
|
||||
# index 0 is the current observation (delta=0)
|
||||
tensor = tensor[:, 0]
|
||||
for b in range(batch_size):
|
||||
images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))
|
||||
|
||||
# ---- Collect videos per sample ----
|
||||
# Build video arrays: for each sample, stack views as [V, T, H, W, 3]
|
||||
# Check whether any image feature has a time dimension
|
||||
video_source = None
|
||||
for k in image_keys:
|
||||
if k in batch:
|
||||
video_source = batch[k] # Use first available for shape inspection
|
||||
break
|
||||
|
||||
if video_source is None:
|
||||
raise ValueError("No image data found in batch for video construction.")
|
||||
|
||||
videos_per_sample = []
|
||||
for b in range(batch_size):
|
||||
sample_views = []
|
||||
for k in image_keys:
|
||||
t = batch[k][b] # [C, H, W] or [T, C, H, W]
|
||||
if t.ndim == 3:
|
||||
t = t.unsqueeze(0) # [1, C, H, W]
|
||||
# Convert to [T, H, W, 3] numpy
|
||||
t_np = t.permute(0, 2, 3, 1).detach().cpu().float().numpy()
|
||||
# Clamp to [0, 255]
|
||||
if t_np.max() <= 1.0:
|
||||
t_np = t_np * 255.0
|
||||
t_np = np.rint(t_np.clip(0, 255)).astype(np.uint8)
|
||||
sample_views.append(t_np)
|
||||
# Stack views: [V, T, H, W, 3]
|
||||
videos_per_sample.append(np.stack(sample_views, axis=0))
|
||||
|
||||
# ---- Collect instructions ----
|
||||
tasks = batch.get("task")
|
||||
if tasks is None:
|
||||
instructions = ["Execute the robot action."] * batch_size
|
||||
elif isinstance(tasks, str):
|
||||
instructions = [tasks] * batch_size
|
||||
else:
|
||||
instructions = list(tasks)
|
||||
|
||||
# ---- Collect actions (training only) ----
|
||||
actions_list = None
|
||||
action_is_pad_list = None
|
||||
actions_tensor = batch.get(ACTION)
|
||||
if actions_tensor is not None:
|
||||
if actions_tensor.ndim == 2:
|
||||
actions_tensor = actions_tensor.unsqueeze(1)
|
||||
actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
|
||||
action_is_pad_tensor = batch.get("action_is_pad")
|
||||
if action_is_pad_tensor is not None:
|
||||
action_is_pad_list = [action_is_pad_tensor[b].detach().cpu() for b in range(batch_size)]
|
||||
|
||||
# ---- Collect state ----
|
||||
state_list = None
|
||||
state_tensor = batch.get(OBS_STATE)
|
||||
if state_tensor is not None:
|
||||
if state_tensor.ndim > 2:
|
||||
state_tensor = state_tensor[:, -1, :]
|
||||
if state_tensor.ndim == 2:
|
||||
state_tensor = state_tensor.unsqueeze(1) # [B, 1, state_dim]
|
||||
state_list = [state_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
|
||||
|
||||
# ---- Assemble native examples ----
|
||||
examples = []
|
||||
for b in range(batch_size):
|
||||
example = {
|
||||
"image": images_per_sample[b],
|
||||
"video": videos_per_sample[b],
|
||||
"lang": instructions[b],
|
||||
}
|
||||
if actions_list is not None:
|
||||
example["action"] = actions_list[b]
|
||||
if action_is_pad_list is not None:
|
||||
example["action_is_pad"] = action_is_pad_list[b]
|
||||
if state_list is not None:
|
||||
example["state"] = state_list[b]
|
||||
examples.append(example)
|
||||
|
||||
return examples
|
||||
|
||||
# ---- LeRobot Policy Interface ----
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""LeRobot train forward: convert → native forward → aggregate losses."""
|
||||
examples = self._prepare_model_inputs(batch)
|
||||
native_output = self.model.forward(examples)
|
||||
|
||||
ref = next(iter(native_output.values()))
|
||||
zero = torch.zeros((), device=ref.device, dtype=ref.dtype)
|
||||
total_loss = native_output.get("action_loss", zero) + native_output.get("wm_loss", zero)
|
||||
logs = {k: v.detach().item() for k, v in native_output.items()}
|
||||
logs["loss"] = total_loss.detach().item()
|
||||
return total_loss, logs
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.model.parameters()
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""LeRobot inference: convert → native predict → return as Tensor."""
|
||||
self.eval()
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
|
||||
examples = self._prepare_model_inputs(batch)
|
||||
batch_images = [ex["image"] for ex in examples]
|
||||
instructions = [ex["lang"] for ex in examples]
|
||||
|
||||
state_np = None
|
||||
if "state" in examples[0] and examples[0]["state"] is not None:
|
||||
state_np = np.stack([ex["state"] for ex in examples])
|
||||
|
||||
actions_np = self.model.predict_action(batch_images, instructions, state_np)
|
||||
return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""LeRobot select_action with action queue caching."""
|
||||
self.eval()
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
if len(self._queues[ACTION]) == 0:
|
||||
actions = self.predict_action_chunk(batch)
|
||||
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
|
||||
return self._queues[ACTION].popleft()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
**kwargs,
|
||||
):
|
||||
return super().from_pretrained(pretrained_name_or_path, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
||||
reinit_prefixes = model.config.reinit_modules
|
||||
if not reinit_prefixes:
|
||||
return super()._load_as_safetensor(model, model_file, map_location, strict)
|
||||
|
||||
from safetensors.torch import load_file
|
||||
|
||||
state_dict = load_file(model_file, device=map_location)
|
||||
current = model.state_dict()
|
||||
|
||||
reinitialized: list[str] = []
|
||||
filtered: dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if key in current and value.shape != current[key].shape:
|
||||
if not any(key.startswith(p) for p in reinit_prefixes):
|
||||
raise ValueError(
|
||||
f"Shape mismatch for '{key}' (checkpoint {tuple(value.shape)} vs model "
|
||||
f"{tuple(current[key].shape)}) and its prefix is not in `reinit_modules`."
|
||||
)
|
||||
reinitialized.append(
|
||||
f"{key}: checkpoint {tuple(value.shape)} → model {tuple(current[key].shape)}"
|
||||
)
|
||||
else:
|
||||
filtered[key] = value
|
||||
|
||||
if reinitialized:
|
||||
logging.warning(
|
||||
f"reinit_modules: skipping {len(reinitialized)} tensor(s) with mismatched shapes "
|
||||
f"(randomly re-initialised):\n " + "\n ".join(reinitialized)
|
||||
)
|
||||
|
||||
from lerobot.policies.utils import log_model_loading_keys
|
||||
|
||||
missing_keys, unexpected_keys = model.load_state_dict(filtered, strict=False)
|
||||
log_model_loading_keys(missing_keys, unexpected_keys)
|
||||
return model
|
||||
@@ -0,0 +1,155 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
EnvTransition,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_clip_actions")
|
||||
class ClipActionsProcessorStep(ProcessorStep):
|
||||
"""Clips action tensor to [-1, 1] before unnormalization."""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None:
|
||||
transition = dict(transition)
|
||||
transition[TransitionKey.ACTION] = action.clamp(-1.0, 1.0)
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_pre_snap_gripper")
|
||||
class PreSnapGripperProcessorStep(ProcessorStep):
|
||||
"""Snaps a gripper dimension to {0, 1} BEFORE unnormalization.
|
||||
|
||||
Mirrors the original starVLA LIBERO eval:
|
||||
normalized[:, gripper_dim] = np.where(normalized[:, gripper_dim] < threshold, 0, 1)
|
||||
This ensures the unnormalizer receives an exact binary value, which is
|
||||
required when the model was trained with gripper in identity (mask=False)
|
||||
space where 0=open and 1=close.
|
||||
"""
|
||||
|
||||
def __init__(self, gripper_dim: int = 6, threshold: float = 0.5):
|
||||
self.gripper_dim = gripper_dim
|
||||
self.threshold = threshold
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and action.shape[-1] > self.gripper_dim:
|
||||
transition = dict(transition)
|
||||
a = action.clone()
|
||||
a[..., self.gripper_dim] = (a[..., self.gripper_dim] >= self.threshold).float()
|
||||
transition[TransitionKey.ACTION] = a
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_binarize_gripper")
|
||||
class BinarizeGripperProcessorStep(ProcessorStep):
|
||||
"""Binarizes a gripper dimension after unnormalization.
|
||||
|
||||
Maps continuous value to {-1, 1}: > threshold → -1, <= threshold → 1 (matches starVLA convention).
|
||||
Only applied when action has more dimensions than gripper_dim.
|
||||
"""
|
||||
|
||||
def __init__(self, gripper_dim: int = 6, threshold: float = 0.5):
|
||||
self.gripper_dim = gripper_dim
|
||||
self.threshold = threshold
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and action.shape[-1] > self.gripper_dim:
|
||||
transition = dict(transition)
|
||||
a = action.clone()
|
||||
a[..., self.gripper_dim] = 1.0 - 2.0 * (a[..., self.gripper_dim] > self.threshold).float()
|
||||
transition[TransitionKey.ACTION] = a
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
def make_vla_jepa_pre_post_processors(
|
||||
config: VLAJEPAConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
features = {**config.input_features, **config.output_features}
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
]
|
||||
output_steps: list[ProcessorStep] = []
|
||||
if config.clip_normalized_actions:
|
||||
output_steps.append(ClipActionsProcessorStep())
|
||||
if config.pre_snap_gripper_action:
|
||||
output_steps.append(
|
||||
PreSnapGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold)
|
||||
)
|
||||
output_steps.append(
|
||||
UnnormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
)
|
||||
)
|
||||
if config.binarize_gripper_action:
|
||||
output_steps.append(
|
||||
BinarizeGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold)
|
||||
)
|
||||
output_steps.append(DeviceProcessorStep(device="cpu"))
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,117 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
|
||||
else:
|
||||
AutoProcessor = None
|
||||
Qwen3VLForConditionalGeneration = None
|
||||
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
|
||||
|
||||
class Qwen3VLInterface(torch.nn.Module):
|
||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
|
||||
config.qwen_model_name,
|
||||
torch_dtype=self._get_torch_dtype(config.torch_dtype),
|
||||
)
|
||||
self.processor = AutoProcessor.from_pretrained(config.qwen_model_name)
|
||||
self.processor.tokenizer.padding_side = config.tokenizer_padding_side
|
||||
self.model.config.hidden_size = self.model.config.text_config.hidden_size
|
||||
|
||||
@staticmethod
|
||||
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
|
||||
if dtype_name == "float32":
|
||||
return torch.float32
|
||||
if dtype_name == "float16":
|
||||
return torch.float16
|
||||
return torch.bfloat16
|
||||
|
||||
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
|
||||
# starVLA/JEVLA checkpoints expand action tokens as action_horizon * 4,
|
||||
# independent of vj2 num_action_tokens_per_timestep. Keeping this count
|
||||
# is required for Qwen embedding/lm_head checkpoint shapes to match.
|
||||
max_action_tokens = self.config.chunk_size * 4
|
||||
tokenizer = self.processor.tokenizer
|
||||
action_tokens = []
|
||||
action_token_ids = []
|
||||
for idx in range(max_action_tokens):
|
||||
token = self.config.special_action_token.format(idx)
|
||||
action_tokens.append(token)
|
||||
if token not in tokenizer.get_vocab():
|
||||
tokenizer.add_tokens([token], special_tokens=True)
|
||||
action_token_ids.append(tokenizer.convert_tokens_to_ids(token))
|
||||
|
||||
embodied_action_token = self.config.embodied_action_token
|
||||
if embodied_action_token not in tokenizer.get_vocab():
|
||||
tokenizer.add_tokens([embodied_action_token], special_tokens=True)
|
||||
embodied_action_token_id = tokenizer.convert_tokens_to_ids(embodied_action_token)
|
||||
|
||||
if self.model.get_input_embeddings().weight.size(0) < len(tokenizer):
|
||||
self.model.resize_token_embeddings(len(tokenizer))
|
||||
return action_tokens, action_token_ids, embodied_action_token_id
|
||||
|
||||
def build_inputs(
|
||||
self,
|
||||
images: Sequence[Sequence[Image.Image]],
|
||||
instructions: Sequence[str],
|
||||
action_prompt: str,
|
||||
embodied_prompt: str,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
messages = []
|
||||
for sample_images, instruction in zip(images, instructions, strict=True):
|
||||
prompt = self.config.prompt_template.format(
|
||||
instruction=instruction,
|
||||
actions=action_prompt,
|
||||
e_actions=embodied_prompt,
|
||||
)
|
||||
content = [{"type": "image", "image": img} for img in sample_images]
|
||||
content.append({"type": "text", "text": prompt})
|
||||
messages.append([{"role": "user", "content": content}])
|
||||
|
||||
batch_inputs = self.processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
processor_kwargs={"padding": True, "return_tensors": "pt"},
|
||||
)
|
||||
return batch_inputs.to(self.model.device)
|
||||
|
||||
@staticmethod
|
||||
def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
|
||||
image = image_tensor.detach().cpu()
|
||||
if image.ndim == 3 and image.shape[0] in (1, 3):
|
||||
image = image.permute(1, 2, 0)
|
||||
image = image.float()
|
||||
if image.max() <= 1.0:
|
||||
image = image * 255.0
|
||||
image = image.clamp(0, 255).round().to(torch.uint8).numpy()
|
||||
if image.shape[-1] == 1:
|
||||
image = np.repeat(image, 3, axis=-1)
|
||||
return Image.fromarray(image)
|
||||
@@ -0,0 +1,418 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import nn
|
||||
|
||||
|
||||
def build_action_block_causal_attention_mask(
|
||||
num_frames: int, grid_height: int, grid_width: int, add_tokens: int = 1
|
||||
) -> torch.Tensor:
|
||||
tokens_per_frame = add_tokens + grid_height * grid_width
|
||||
num_tokens = num_frames * tokens_per_frame
|
||||
mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool)
|
||||
mask_block = torch.ones(tokens_per_frame, tokens_per_frame, dtype=torch.bool)
|
||||
local_window_time = num_frames
|
||||
|
||||
for current_frame in range(num_frames):
|
||||
first_context_frame = max(0, current_frame - local_window_time + 1)
|
||||
for context_frame in range(first_context_frame, current_frame + 1):
|
||||
row = slice(current_frame * tokens_per_frame, (current_frame + 1) * tokens_per_frame)
|
||||
col = slice(context_frame * tokens_per_frame, (context_frame + 1) * tokens_per_frame)
|
||||
mask[row, col] = mask_block
|
||||
return mask
|
||||
|
||||
|
||||
def rotate_queries_or_keys(x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
|
||||
_, _, _, dim = x.size()
|
||||
if dim % 2 != 0:
|
||||
raise ValueError("Embedding dimension must be even for rotary position encoding.")
|
||||
|
||||
omega = torch.arange(dim // 2, dtype=x.dtype, device=x.device)
|
||||
omega /= dim / 2.0
|
||||
omega = 1.0 / 10000**omega
|
||||
freqs = torch.einsum("..., f -> ... f", pos, omega)
|
||||
emb_sin = freqs.sin().squeeze(-1).repeat(1, 1, 1, 2)
|
||||
emb_cos = freqs.cos().squeeze(-1).repeat(1, 1, 1, 2)
|
||||
|
||||
y = x.unflatten(-1, (-1, 2))
|
||||
y1, y2 = y.unbind(dim=-1)
|
||||
y = torch.stack((-y2, y1), dim=-1).flatten(-2)
|
||||
return x * emb_cos + y * emb_sin
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
def __init__(self, drop_prob: float = 0.0) -> None:
|
||||
super().__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.drop_prob == 0.0 or not self.training:
|
||||
return x
|
||||
keep_prob = 1 - self.drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
|
||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
||||
random_tensor.floor_()
|
||||
return x.div(keep_prob) * random_tensor
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: int | None = None,
|
||||
out_features: int | None = None,
|
||||
act_layer: type[nn.Module] = nn.GELU,
|
||||
drop: float = 0.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class ACRoPEAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_scale: float | None = None,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
use_sdpa: bool = True,
|
||||
is_causal: bool = False,
|
||||
grid_size: int = 16,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = qk_scale or self.head_dim**-0.5
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop_prob = proj_drop
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
self.use_sdpa = use_sdpa
|
||||
self.d_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||
self.h_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||
self.w_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||
self.grid_size = grid_size
|
||||
self.is_causal = is_causal
|
||||
|
||||
@staticmethod
|
||||
def _get_frame_pos(ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
return ids // int(height * width)
|
||||
|
||||
def _get_height_pos(self, ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
frame_ids = self._get_frame_pos(ids, height, width)
|
||||
ids = ids - int(height * width) * frame_ids
|
||||
return ids // width
|
||||
|
||||
def separate_positions(
|
||||
self, ids: torch.Tensor, height: int, width: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
frame_ids = self._get_frame_pos(ids, height, width)
|
||||
height_ids = self._get_height_pos(ids, height, width)
|
||||
width_ids = ids - int(height * width) * frame_ids - width * height_ids
|
||||
return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: torch.Tensor | None = None,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
num_frames: int | None = None,
|
||||
grid_height: int | None = None,
|
||||
grid_width: int | None = None,
|
||||
action_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
batch_size, num_tokens, channels = x.size()
|
||||
if num_frames is None or grid_height is None or grid_width is None:
|
||||
raise ValueError("num_frames, grid_height and grid_width are required.")
|
||||
|
||||
if mask is not None:
|
||||
mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1)
|
||||
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
|
||||
else:
|
||||
mask = torch.arange(int(num_frames * grid_height * grid_width), device=x.device)
|
||||
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
|
||||
|
||||
h_mask *= self.grid_size / grid_height
|
||||
w_mask *= self.grid_size / grid_width
|
||||
|
||||
if action_tokens > 0:
|
||||
x = x.view(batch_size, -1, action_tokens + grid_height * grid_width, channels)
|
||||
action_q, action_k, action_v = [], [], []
|
||||
for idx in range(action_tokens):
|
||||
action_token = x[:, :, idx : idx + 1, :].flatten(1, 2)
|
||||
qkv = self.qkv(action_token).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
qd = rotate_queries_or_keys(
|
||||
q[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
|
||||
)
|
||||
kd = rotate_queries_or_keys(
|
||||
k[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
|
||||
)
|
||||
qr = q[..., self.d_dim :]
|
||||
kr = k[..., self.d_dim :]
|
||||
action_q.append(
|
||||
torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
|
||||
)
|
||||
action_k.append(
|
||||
torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
|
||||
)
|
||||
action_v.append(v.view(batch_size, self.num_heads, num_frames, 1, -1))
|
||||
|
||||
action_q = torch.cat(action_q, dim=3).flatten(2, 3)
|
||||
action_k = torch.cat(action_k, dim=3).flatten(2, 3)
|
||||
action_v = torch.cat(action_v, dim=3).flatten(2, 3)
|
||||
x = x[:, :, action_tokens:, :].flatten(1, 2)
|
||||
|
||||
qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
offset = 0
|
||||
qd = rotate_queries_or_keys(q[..., offset : offset + self.d_dim], pos=d_mask)
|
||||
kd = rotate_queries_or_keys(k[..., offset : offset + self.d_dim], pos=d_mask)
|
||||
offset += self.d_dim
|
||||
qh = rotate_queries_or_keys(q[..., offset : offset + self.h_dim], pos=h_mask)
|
||||
kh = rotate_queries_or_keys(k[..., offset : offset + self.h_dim], pos=h_mask)
|
||||
offset += self.h_dim
|
||||
qw = rotate_queries_or_keys(q[..., offset : offset + self.w_dim], pos=w_mask)
|
||||
kw = rotate_queries_or_keys(k[..., offset : offset + self.w_dim], pos=w_mask)
|
||||
offset += self.w_dim
|
||||
|
||||
if offset < self.head_dim:
|
||||
q = torch.cat([qd, qh, qw, q[..., offset:]], dim=-1)
|
||||
k = torch.cat([kd, kh, kw, k[..., offset:]], dim=-1)
|
||||
else:
|
||||
q = torch.cat([qd, qh, qw], dim=-1)
|
||||
k = torch.cat([kd, kh, kw], dim=-1)
|
||||
|
||||
if action_tokens > 0:
|
||||
|
||||
def merge(frame_tokens: torch.Tensor, action_token_values: torch.Tensor) -> torch.Tensor:
|
||||
frame_tokens = frame_tokens.view(
|
||||
batch_size, self.num_heads, num_frames, grid_height * grid_width, -1
|
||||
)
|
||||
action_token_values = action_token_values.view(
|
||||
batch_size, self.num_heads, num_frames, action_tokens, -1
|
||||
)
|
||||
return torch.cat([action_token_values, frame_tokens], dim=3).flatten(2, 3)
|
||||
|
||||
q = merge(q, action_q)
|
||||
k = merge(k, action_k)
|
||||
v = merge(v, action_v)
|
||||
|
||||
if attn_mask is not None or self.use_sdpa:
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask
|
||||
)
|
||||
else:
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = attn @ v
|
||||
|
||||
x = x.transpose(1, 2).reshape(batch_size, num_tokens, channels)
|
||||
x = self.proj(x)
|
||||
return self.proj_drop(x)
|
||||
|
||||
|
||||
class ACBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qkv_bias: bool = True,
|
||||
qk_scale: float | None = None,
|
||||
drop: float = 0.0,
|
||||
attn_drop: float = 0.0,
|
||||
drop_path: float = 0.0,
|
||||
norm_layer: type[nn.Module] = nn.LayerNorm,
|
||||
use_sdpa: bool = True,
|
||||
is_causal: bool = False,
|
||||
grid_size: int = 16,
|
||||
use_rope: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
if not use_rope:
|
||||
raise ValueError("JEVLA1 world predictor uses AC RoPE attention.")
|
||||
self.attn = ACRoPEAttention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
use_sdpa=use_sdpa,
|
||||
is_causal=is_causal,
|
||||
grid_size=grid_size,
|
||||
proj_drop=drop,
|
||||
)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.mlp = MLP(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=nn.GELU,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
num_frames: int | None = None,
|
||||
grid_height: int | None = None,
|
||||
grid_width: int | None = None,
|
||||
action_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
y = self.norm1(x)
|
||||
y = self.attn(
|
||||
y,
|
||||
mask=None,
|
||||
attn_mask=attn_mask,
|
||||
num_frames=num_frames,
|
||||
grid_height=grid_height,
|
||||
grid_width=grid_width,
|
||||
action_tokens=action_tokens,
|
||||
)
|
||||
x = x + self.drop_path(y)
|
||||
y = self.norm2(x)
|
||||
return x + self.drop_path(self.mlp(y))
|
||||
|
||||
|
||||
class ActionConditionedVideoPredictor(nn.Module):
|
||||
"""JEVLA1-compatible action-conditioned V-JEPA predictor."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_frames: int,
|
||||
img_size: tuple[int, int],
|
||||
patch_size: int,
|
||||
tubelet_size: int,
|
||||
embed_dim: int,
|
||||
action_embed_dim: int,
|
||||
predictor_embed_dim: int,
|
||||
depth: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float,
|
||||
num_action_tokens_per_step: int,
|
||||
use_extrinsics: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_frame_causal = True
|
||||
self.use_extrinsics = use_extrinsics
|
||||
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
|
||||
self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
|
||||
self.state_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
|
||||
self.extrinsics_encoder = nn.Linear(action_embed_dim - 1, predictor_embed_dim, bias=True)
|
||||
|
||||
self.img_height, self.img_width = img_size
|
||||
self.patch_size = patch_size
|
||||
self.num_frames = num_frames
|
||||
self.tubelet_size = tubelet_size
|
||||
self.grid_height = self.img_height // self.patch_size
|
||||
self.grid_width = self.img_width // self.patch_size
|
||||
|
||||
self.predictor_blocks = nn.ModuleList(
|
||||
[
|
||||
ACBlock(
|
||||
dim=predictor_embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=True,
|
||||
drop=0.0,
|
||||
attn_drop=0.0,
|
||||
drop_path=0.0,
|
||||
norm_layer=lambda dim: nn.LayerNorm(dim, eps=1e-6),
|
||||
grid_size=self.grid_height,
|
||||
use_rope=True,
|
||||
)
|
||||
for _ in range(depth)
|
||||
]
|
||||
)
|
||||
self.predictor_norm = nn.LayerNorm(predictor_embed_dim, eps=1e-6)
|
||||
self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
|
||||
self.num_action_tokens_per_step = num_action_tokens_per_step
|
||||
|
||||
@property
|
||||
def norm(self) -> nn.LayerNorm:
|
||||
return self.predictor_norm
|
||||
|
||||
@property
|
||||
def proj(self) -> nn.Linear:
|
||||
return self.predictor_proj
|
||||
|
||||
def forward(
|
||||
self,
|
||||
frame_tokens: torch.Tensor,
|
||||
action_tokens: torch.Tensor,
|
||||
extrinsics: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# starVLA input convention: frame_tokens [B, T*H*W, D], actions [B, T*A, D].
|
||||
x = self.predictor_embed(frame_tokens)
|
||||
batch_size, num_context_tokens, hidden_dim = x.size()
|
||||
num_frames = num_context_tokens // (self.grid_height * self.grid_width)
|
||||
|
||||
actions = self.action_encoder(action_tokens)
|
||||
actions = actions.view(batch_size, num_frames, -1, hidden_dim)
|
||||
cond_tokens = actions.shape[2]
|
||||
|
||||
x = x.view(batch_size, num_frames, self.grid_height * self.grid_width, hidden_dim)
|
||||
if self.use_extrinsics:
|
||||
if extrinsics is None:
|
||||
raise ValueError("extrinsics are required when use_extrinsics=True.")
|
||||
cond_tokens += 1
|
||||
extrinsic_tokens = self.extrinsics_encoder(extrinsics).unsqueeze(2)
|
||||
x = torch.cat([actions, extrinsic_tokens, x], dim=2).flatten(1, 2)
|
||||
else:
|
||||
x = torch.cat([actions, x], dim=2).flatten(1, 2)
|
||||
|
||||
attn_mask = build_action_block_causal_attention_mask(
|
||||
num_frames, self.grid_height, self.grid_width, add_tokens=cond_tokens
|
||||
)
|
||||
attn_mask = attn_mask[: x.size(1), : x.size(1)].to(x.device, non_blocking=True)
|
||||
|
||||
for block in self.predictor_blocks:
|
||||
x = block(
|
||||
x,
|
||||
attn_mask=attn_mask,
|
||||
num_frames=num_frames,
|
||||
grid_height=self.grid_height,
|
||||
grid_width=self.grid_width,
|
||||
action_tokens=cond_tokens,
|
||||
)
|
||||
|
||||
x = x.view(batch_size, num_frames, cond_tokens + self.grid_height * self.grid_width, hidden_dim)
|
||||
x = x[:, :, cond_tokens:, :].flatten(1, 2)
|
||||
x = self.predictor_norm(x)
|
||||
return self.predictor_proj(x)
|
||||
@@ -81,7 +81,7 @@ def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) ->
|
||||
return actions
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("delta_actions_processor")
|
||||
@ProcessorStepRegistry.register("relative_actions_processor")
|
||||
@dataclass
|
||||
class RelativeActionsProcessorStep(ProcessorStep):
|
||||
"""Converts absolute actions to relative actions (action -= state) for masked dimensions.
|
||||
|
||||
@@ -23,6 +23,7 @@ from .configs import (
|
||||
DAggerKeyboardConfig,
|
||||
DAggerPedalConfig,
|
||||
DAggerStrategyConfig,
|
||||
EpisodicStrategyConfig,
|
||||
HighlightStrategyConfig,
|
||||
RolloutConfig,
|
||||
RolloutStrategyConfig,
|
||||
@@ -49,6 +50,7 @@ from .inference import (
|
||||
from .strategies import (
|
||||
BaseStrategy,
|
||||
DAggerStrategy,
|
||||
EpisodicStrategy,
|
||||
HighlightStrategy,
|
||||
RolloutStrategy,
|
||||
SentryStrategy,
|
||||
@@ -66,6 +68,8 @@ __all__ = [
|
||||
"HardwareContext",
|
||||
"HighlightStrategy",
|
||||
"HighlightStrategyConfig",
|
||||
"EpisodicStrategy",
|
||||
"EpisodicStrategyConfig",
|
||||
"InferenceEngine",
|
||||
"InferenceEngineConfig",
|
||||
"PolicyContext",
|
||||
|
||||
@@ -121,6 +121,35 @@ class DAggerPedalConfig:
|
||||
upload: str = "KEY_C"
|
||||
|
||||
|
||||
@RolloutStrategyConfig.register_subclass("episodic")
|
||||
@dataclass
|
||||
class EpisodicStrategyConfig(RolloutStrategyConfig):
|
||||
"""Episode-oriented recording that mirrors the behavior of ``lerobot-record``.
|
||||
|
||||
Records ``dataset.num_episodes`` episodes of maximum ``dataset.episode_time_s`` each.
|
||||
After each episode, runs ``dataset.reset_time_s`` seconds of reset time.
|
||||
|
||||
Keyboard controls:
|
||||
Right arrow — end current episode or reset phase early
|
||||
Left arrow — discard current episode and re-record
|
||||
Escape — stop recording session
|
||||
|
||||
In between episodes:
|
||||
- if there is no teleop leader, the robot is held at its initial joint positions captured at startup.
|
||||
- else, the robot is moved smoothly to the position of the teleop leader.
|
||||
"""
|
||||
|
||||
# This only applies if there are no teleop leaders specified.
|
||||
# When True (default), moves the robot back to the joint positions captured at startup.
|
||||
# Otherwise, leave the robot in its current position.
|
||||
reset_to_initial_position: bool = True
|
||||
|
||||
# Whether to turn on or off the leader -> follower smooth handover behavior.
|
||||
# When False, fallback to follower -> leader handover.
|
||||
# Note that leader -> follower handover is only supported when the leader has `send_feedback` capability.
|
||||
smooth_leader_to_follower_handover: bool = True
|
||||
|
||||
|
||||
@RolloutStrategyConfig.register_subclass("dagger")
|
||||
@dataclass
|
||||
class DAggerStrategyConfig(RolloutStrategyConfig):
|
||||
@@ -229,7 +258,13 @@ class RolloutConfig:
|
||||
|
||||
# TODO(Steven): DAgger shouldn't require a dataset (user may want to just rollout+intervene without recording), but for now we require it to simplify the implementation.
|
||||
needs_dataset = isinstance(
|
||||
self.strategy, (SentryStrategyConfig, HighlightStrategyConfig, DAggerStrategyConfig)
|
||||
self.strategy,
|
||||
(
|
||||
SentryStrategyConfig,
|
||||
HighlightStrategyConfig,
|
||||
DAggerStrategyConfig,
|
||||
EpisodicStrategyConfig,
|
||||
),
|
||||
)
|
||||
if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
|
||||
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
from .base import BaseStrategy
|
||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||
from .dagger import DAggerEvents, DAggerPhase, DAggerStrategy
|
||||
from .episodic import EpisodicStrategy
|
||||
from .factory import create_strategy
|
||||
from .highlight import HighlightStrategy
|
||||
from .sentry import SentryStrategy
|
||||
@@ -27,6 +28,7 @@ __all__ = [
|
||||
"DAggerPhase",
|
||||
"DAggerStrategy",
|
||||
"HighlightStrategy",
|
||||
"EpisodicStrategy",
|
||||
"RolloutStrategy",
|
||||
"SentryStrategy",
|
||||
"create_strategy",
|
||||
|
||||
@@ -56,10 +56,14 @@ from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.control_utils import is_headless
|
||||
from lerobot.common.control_utils import (
|
||||
follower_smooth_move_to,
|
||||
is_headless,
|
||||
teleop_smooth_move_to,
|
||||
teleop_supports_feedback,
|
||||
)
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
from lerobot.teleoperators import Teleoperator
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.import_utils import _pynput_available
|
||||
@@ -69,7 +73,6 @@ from lerobot.utils.utils import log_say
|
||||
|
||||
from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from ..robot_wrapper import ThreadSafeRobot
|
||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||
|
||||
PYNPUT_AVAILABLE = _pynput_available
|
||||
@@ -171,64 +174,6 @@ class DAggerEvents:
|
||||
self.upload_requested.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Teleoperator helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _teleop_supports_feedback(teleop: Teleoperator) -> bool:
|
||||
"""Return True when the teleop can receive position feedback (is actuated).
|
||||
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
|
||||
"""
|
||||
return (
|
||||
bool(teleop.feedback_features)
|
||||
and hasattr(teleop, "disable_torque")
|
||||
and hasattr(teleop, "enable_torque")
|
||||
)
|
||||
|
||||
|
||||
def _teleop_smooth_move_to(
|
||||
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 30
|
||||
) -> None:
|
||||
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
|
||||
|
||||
Requires the teleoperator to support feedback
|
||||
(i.e. have non-empty ``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
|
||||
|
||||
TODO(Maxime): This blocks up to ``duration_s`` seconds, during this time
|
||||
the follower robot doesn't receive new actions, this could be an issue on LeKiwi.
|
||||
"""
|
||||
teleop.enable_torque()
|
||||
current = teleop.get_action()
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {
|
||||
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
|
||||
}
|
||||
teleop.send_feedback(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
def _follower_smooth_move_to(
|
||||
robot: ThreadSafeRobot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
|
||||
) -> None:
|
||||
"""Smoothly move the follower robot from ``current`` to ``target`` action.
|
||||
|
||||
Used when the teleop is non-actuated: instead of driving the leader arm
|
||||
to the follower, we bring the follower to the teleop's current pose.
|
||||
Both ``current`` and ``target`` must be in robot-action key space.
|
||||
"""
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
|
||||
robot.send_action(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Input device handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -756,31 +701,31 @@ class DAggerStrategy(RolloutStrategy):
|
||||
logger.info("Pausing engine - robot holds position")
|
||||
engine.pause()
|
||||
|
||||
if _teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
if teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
# TODO(Maxime): prev_action is in robot action key space (output of robot_action_processor).
|
||||
# send_feedback expects teleop feedback key space. For homogeneous setups (e.g. SO-101
|
||||
# leader + SO-101 follower) the keys are identical so this works. If the processor pipeline
|
||||
# does non-trivial key renaming (e.g. a rename_map on action keys), the interpolation in
|
||||
# _teleop_smooth_move_to silently no-ops and the arm doesn't move.
|
||||
# teleop_smooth_move_to silently no-ops and the arm doesn't move.
|
||||
logger.info("Smooth handover: moving leader arm to follower position")
|
||||
_teleop_smooth_move_to(teleop, prev_action)
|
||||
teleop_smooth_move_to(teleop, prev_action)
|
||||
|
||||
elif old_phase == DAggerPhase.PAUSED and new_phase == DAggerPhase.CORRECTING:
|
||||
logger.info("Entering correction mode - human teleop control")
|
||||
if not _teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
if not teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
logger.info("Smooth handover: sliding follower to teleop position")
|
||||
obs = robot.get_observation()
|
||||
teleop_action = teleop.get_action()
|
||||
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
|
||||
target = ctx.processors.robot_action_processor((processed, obs))
|
||||
_follower_smooth_move_to(robot, prev_action, target)
|
||||
follower_smooth_move_to(robot, prev_action, target)
|
||||
|
||||
# unlock the teleop for human control
|
||||
if _teleop_supports_feedback(teleop):
|
||||
if teleop_supports_feedback(teleop):
|
||||
teleop.disable_torque()
|
||||
|
||||
elif old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
|
||||
if _teleop_supports_feedback(teleop):
|
||||
if teleop_supports_feedback(teleop):
|
||||
teleop.enable_torque()
|
||||
|
||||
elif new_phase == DAggerPhase.AUTONOMOUS:
|
||||
@@ -790,7 +735,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
engine.resume()
|
||||
|
||||
# release teleop before resuming the policy
|
||||
if _teleop_supports_feedback(teleop):
|
||||
if teleop_supports_feedback(teleop):
|
||||
teleop.disable_torque()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,335 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Episodic rollout strategy: mirrors the behavior of ``lerobot-record``.
|
||||
|
||||
- Policy drives the robot during each recording episode.
|
||||
- An optional teleoperator can drive the robot during reset phases so the
|
||||
operator can bring the environment back to its starting configuration.
|
||||
If no teleop is connected the robot stays in its current position.
|
||||
- Keyboard controls:
|
||||
|
||||
Right arrow — end the current episode or reset phase early
|
||||
Left arrow — discard the current episode and re-record it
|
||||
Escape — stop the recording session
|
||||
|
||||
Dataset naming follows the rollout convention: repo names must start with ``rollout_``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
|
||||
from lerobot.common.control_utils import (
|
||||
follower_smooth_move_to,
|
||||
init_keyboard_listener,
|
||||
is_headless,
|
||||
teleop_smooth_move_to,
|
||||
teleop_supports_feedback,
|
||||
)
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import log_rerun_data
|
||||
|
||||
from ..configs import EpisodicStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from .core import RolloutStrategy, safe_push_to_hub, send_next_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EpisodicStrategy(RolloutStrategy):
|
||||
"""Policy-driven multi-episode recording, mirrors the behavior of ``lerobot-record``.
|
||||
|
||||
Each recording episode runs the policy for maximum ``dataset.episode_time_s``
|
||||
seconds, recording every frame. A reset phase of ``dataset.reset_time_s``
|
||||
follows every episode (except the last) so the operator can manually
|
||||
reset the environment. During the reset phase, an optional teleoperator
|
||||
drives the robot; if none is present the robot returns to its initial joint positions captured at startup.
|
||||
|
||||
The policy state (hidden state, RTC queue, interpolator) is reset at
|
||||
the start of each recording episode.
|
||||
|
||||
Keyboard events:
|
||||
right arrow → end current episode or reset phase early
|
||||
left arrow → discard & re-record current episode
|
||||
ESC → stop the session
|
||||
"""
|
||||
|
||||
config: EpisodicStrategyConfig
|
||||
|
||||
def __init__(self, config: EpisodicStrategyConfig) -> None:
|
||||
super().__init__(config)
|
||||
self._listener = None
|
||||
self._events: dict | None = None
|
||||
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
"""Start the inference engine and attach the keyboard listener."""
|
||||
self._init_engine(ctx)
|
||||
self._listener, self._events = init_keyboard_listener()
|
||||
logger.info("Episodic strategy ready")
|
||||
|
||||
def run(self, ctx: RolloutContext) -> None:
|
||||
"""Main multi-episode recording loop."""
|
||||
cfg = ctx.runtime.cfg
|
||||
dataset_cfg = cfg.dataset
|
||||
robot = ctx.hardware.robot_wrapper
|
||||
teleop = ctx.hardware.teleop
|
||||
dataset = ctx.data.dataset
|
||||
events = self._events
|
||||
features = ctx.data.dataset_features
|
||||
|
||||
fps = cfg.fps
|
||||
episode_time_s = dataset_cfg.episode_time_s
|
||||
reset_time_s = dataset_cfg.reset_time_s
|
||||
num_episodes = dataset_cfg.num_episodes
|
||||
single_task = dataset_cfg.single_task or cfg.task
|
||||
play_sounds = cfg.play_sounds
|
||||
|
||||
display_compressed = (
|
||||
True
|
||||
if (cfg.display_data and cfg.display_ip is not None and cfg.display_port is not None)
|
||||
else cfg.display_compressed_images
|
||||
)
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
try:
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < num_episodes and not events["stop_recording"]:
|
||||
if ctx.runtime.shutdown_event.is_set():
|
||||
break
|
||||
|
||||
# Reset policy state at episode start (discard leftover hidden state / queue)
|
||||
self._engine.reset()
|
||||
self._interpolator.reset()
|
||||
self._engine.resume()
|
||||
|
||||
log_say(f"Recording episode {dataset.num_episodes}", play_sounds)
|
||||
self._policy_loop(
|
||||
ctx=ctx,
|
||||
robot=robot,
|
||||
events=events,
|
||||
features=features,
|
||||
fps=fps,
|
||||
control_time_s=episode_time_s,
|
||||
dataset=dataset,
|
||||
single_task=single_task,
|
||||
)
|
||||
|
||||
# Reset phase, skip after the last episode (but run when re-recording)
|
||||
if not events["stop_recording"] and (
|
||||
recorded_episodes < num_episodes - 1 or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment", play_sounds)
|
||||
|
||||
if teleop:
|
||||
# Smooth handover so the transition to teleop control is jerk-free.
|
||||
# For actuated teleops: drive the leader arm to the follower's current
|
||||
# position so the operator takes over without fighting the arm.
|
||||
# For non-actuated teleops: slide the follower to the teleop's current
|
||||
# pose instead, since the leader cannot be driven.
|
||||
obs = robot.get_observation()
|
||||
current_pos = {k: v for k, v in obs.items() if k.endswith(".pos")}
|
||||
if (
|
||||
teleop_supports_feedback(teleop)
|
||||
and self.config.smooth_leader_to_follower_handover
|
||||
):
|
||||
logger.info("Smooth handover: moving leader arm to follower position")
|
||||
teleop_smooth_move_to(teleop, current_pos, duration_s=2)
|
||||
teleop.disable_torque()
|
||||
else:
|
||||
logger.info("Smooth handover: sliding follower to teleop position")
|
||||
teleop_action = teleop.get_action()
|
||||
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
|
||||
target = ctx.processors.robot_action_processor((processed, obs))
|
||||
follower_smooth_move_to(robot, current_pos, target, duration_s=1)
|
||||
|
||||
elif self.config.reset_to_initial_position:
|
||||
# No teleop: return the robot to its startup position.
|
||||
self._return_to_initial_position(hw=ctx.hardware, duration_s=1)
|
||||
|
||||
self._reset_loop(
|
||||
ctx=ctx,
|
||||
robot=robot,
|
||||
teleop=teleop,
|
||||
events=events,
|
||||
fps=fps,
|
||||
control_time_s=reset_time_s,
|
||||
display_data=cfg.display_data,
|
||||
display_compressed=display_compressed,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode", play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
|
||||
# returns to its initial joint positions captured at startup
|
||||
if not teleop and self.config.reset_to_initial_position:
|
||||
self._return_to_initial_position(hw=ctx.hardware, duration_s=1)
|
||||
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
finally:
|
||||
# Save any frames buffered in the current episode so an unexpected
|
||||
# exception or KeyboardInterrupt does not silently drop recorded data.
|
||||
# suppress: save_episode raises if the buffer is empty (nothing to lose).
|
||||
logger.info("Episodic control loop ended — saving any in-progress episode")
|
||||
with contextlib.suppress(Exception):
|
||||
dataset.save_episode()
|
||||
|
||||
def _policy_loop(
|
||||
self,
|
||||
ctx: RolloutContext,
|
||||
robot,
|
||||
events: dict,
|
||||
features: dict,
|
||||
fps: float,
|
||||
control_time_s: float,
|
||||
dataset,
|
||||
single_task: str,
|
||||
) -> None:
|
||||
"""Policy-driven recording loop for a single episode."""
|
||||
interpolator = self._interpolator
|
||||
control_interval = interpolator.get_control_interval(fps)
|
||||
|
||||
timestamp = 0.0
|
||||
start_t = time.perf_counter()
|
||||
|
||||
while timestamp < control_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
if ctx.runtime.shutdown_event.is_set():
|
||||
break
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_processed = self._process_observation_and_notify(ctx.processors, obs)
|
||||
|
||||
if self._handle_warmup(ctx.runtime.cfg.use_torch_compile, loop_start, control_interval):
|
||||
continue
|
||||
|
||||
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
|
||||
|
||||
if action_dict is not None:
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
|
||||
dataset.add_frame({**obs_frame, **action_frame, "task": single_task})
|
||||
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
sleep_t = control_interval - dt
|
||||
if sleep_t < 0:
|
||||
logger.warning(
|
||||
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({fps} Hz). "
|
||||
"Dataset frames might be dropped and robot control might be unstable. "
|
||||
"Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long "
|
||||
"3) CPU starvation"
|
||||
)
|
||||
precise_sleep(max(sleep_t, 0.0))
|
||||
timestamp = time.perf_counter() - start_t
|
||||
|
||||
def _reset_loop(
|
||||
self,
|
||||
ctx: RolloutContext,
|
||||
robot,
|
||||
teleop,
|
||||
events: dict,
|
||||
fps: float,
|
||||
control_time_s: float,
|
||||
display_data: bool,
|
||||
display_compressed: bool,
|
||||
) -> None:
|
||||
"""Reset-phase loop: teleop drives the robot if available, no recording."""
|
||||
processors = ctx.processors
|
||||
control_interval = 1.0 / fps
|
||||
|
||||
timestamp = 0.0
|
||||
start_t = time.perf_counter()
|
||||
|
||||
while timestamp < control_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
if ctx.runtime.shutdown_event.is_set():
|
||||
break
|
||||
|
||||
obs = robot.get_observation()
|
||||
|
||||
if teleop is not None:
|
||||
act = teleop.get_action()
|
||||
act_teleop = processors.teleop_action_processor((act, obs))
|
||||
robot_action = processors.robot_action_processor((act_teleop, obs))
|
||||
robot.send_action(robot_action)
|
||||
|
||||
if display_data:
|
||||
obs_processed = processors.robot_observation_processor(obs)
|
||||
log_rerun_data(
|
||||
observation=obs_processed,
|
||||
action=act_teleop,
|
||||
compress_images=display_compressed,
|
||||
)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
sleep_t = control_interval - dt
|
||||
precise_sleep(max(sleep_t, 0.0))
|
||||
timestamp = time.perf_counter() - start_t
|
||||
|
||||
def teardown(self, ctx: RolloutContext) -> None:
|
||||
"""Finalise dataset, stop listener, push to hub, and disconnect hardware."""
|
||||
cfg = ctx.runtime.cfg
|
||||
play_sounds = cfg.play_sounds
|
||||
|
||||
log_say("Stop recording", play_sounds, blocking=True)
|
||||
|
||||
if not is_headless() and self._listener is not None:
|
||||
self._listener.stop()
|
||||
|
||||
if ctx.data.dataset is not None:
|
||||
logger.info("Finalizing dataset...")
|
||||
ctx.data.dataset.finalize()
|
||||
|
||||
if (
|
||||
cfg.dataset is not None
|
||||
and cfg.dataset.push_to_hub
|
||||
and ctx.data.dataset is not None
|
||||
and safe_push_to_hub(
|
||||
ctx.data.dataset,
|
||||
tags=cfg.dataset.tags,
|
||||
private=cfg.dataset.private,
|
||||
)
|
||||
):
|
||||
logger.info("Dataset uploaded to hub")
|
||||
log_say("Dataset uploaded to hub", play_sounds)
|
||||
|
||||
self._teardown_hardware(
|
||||
ctx.hardware,
|
||||
return_to_initial_position=cfg.return_to_initial_position,
|
||||
)
|
||||
log_say("Exiting", play_sounds)
|
||||
logger.info("Episodic strategy teardown complete")
|
||||
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
|
||||
from .base import BaseStrategy
|
||||
from .core import RolloutStrategy
|
||||
from .dagger import DAggerStrategy
|
||||
from .episodic import EpisodicStrategy
|
||||
from .highlight import HighlightStrategy
|
||||
from .sentry import SentryStrategy
|
||||
|
||||
@@ -42,4 +43,8 @@ def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
|
||||
return HighlightStrategy(config)
|
||||
if config.type == "dagger":
|
||||
return DAggerStrategy(config)
|
||||
raise ValueError(f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger")
|
||||
if config.type == "episodic":
|
||||
return EpisodicStrategy(config)
|
||||
raise ValueError(
|
||||
f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger, episodic"
|
||||
)
|
||||
|
||||
@@ -105,6 +105,7 @@ def rollout(
|
||||
seeds: list[int] | None = None,
|
||||
return_observations: bool = False,
|
||||
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
|
||||
predicted_latents_callback: Callable[[PreTrainedPolicy], None] | None = None,
|
||||
) -> dict:
|
||||
"""Run a batched policy rollout once through a batch of environments.
|
||||
|
||||
@@ -134,6 +135,9 @@ def rollout(
|
||||
are returned optionally because they typically take more memory to cache. Defaults to False.
|
||||
render_callback: Optional rendering callback to be used after the environments are reset, and after
|
||||
every step.
|
||||
predicted_latents_callback: Optional callback invoked after every ``select_action`` with the policy
|
||||
itself. World-model policies (e.g. LingBot-VA) stash predicted video latents on
|
||||
``policy.last_predicted_latents``; this lets the caller concatenate chunks and decode once.
|
||||
Returns:
|
||||
The dictionary described above.
|
||||
"""
|
||||
@@ -184,6 +188,8 @@ def rollout(
|
||||
observation = preprocessor(observation)
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
if predicted_latents_callback is not None:
|
||||
predicted_latents_callback(policy)
|
||||
action = postprocessor(action)
|
||||
|
||||
action_transition = {ACTION: action}
|
||||
@@ -203,12 +209,22 @@ def rollout(
|
||||
# available if none of the envs finished.
|
||||
if "final_info" in info:
|
||||
final_info = info["final_info"]
|
||||
if not isinstance(final_info, dict):
|
||||
raise RuntimeError(
|
||||
"Unsupported `final_info` format: expected dict (Gymnasium >= 1.0). "
|
||||
"You're likely using an older version of gymnasium (< 1.0). Please upgrade."
|
||||
if isinstance(final_info, dict):
|
||||
is_success = final_info.get("is_success", [False] * env.num_envs)
|
||||
successes = (
|
||||
is_success.tolist()
|
||||
if hasattr(is_success, "tolist")
|
||||
else [bool(is_success)] * env.num_envs
|
||||
)
|
||||
successes = final_info["is_success"].tolist()
|
||||
else:
|
||||
# Gymnasium < 1.0 returns final_info as a per-env sequence/object array,
|
||||
# with entries set to a dict only for envs that just finished.
|
||||
successes = []
|
||||
for item in final_info:
|
||||
if isinstance(item, dict) and "is_success" in item:
|
||||
successes.append(bool(item["is_success"]))
|
||||
else:
|
||||
successes.append(False)
|
||||
elif "is_success" in info:
|
||||
is_success = info["is_success"]
|
||||
successes = (
|
||||
@@ -273,6 +289,7 @@ def eval_policy(
|
||||
videos_dir: Path | None = None,
|
||||
return_episode_data: bool = False,
|
||||
start_seed: int | None = None,
|
||||
save_predicted_video: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Args:
|
||||
@@ -291,6 +308,11 @@ def eval_policy(
|
||||
if max_episodes_rendered > 0 and not videos_dir:
|
||||
raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")
|
||||
|
||||
# World-model policies (e.g. LingBot-VA) opt into predicted-video saving via their config.
|
||||
save_predicted_video = save_predicted_video or bool(
|
||||
getattr(getattr(policy, "config", None), "save_predicted_video", False)
|
||||
)
|
||||
|
||||
if not isinstance(policy, PreTrainedPolicy):
|
||||
exc = ValueError(
|
||||
f"Policy of type 'PreTrainedPolicy' is expected, but type '{type(policy)}' was provided."
|
||||
@@ -334,6 +356,22 @@ def eval_policy(
|
||||
if max_episodes_rendered > 0:
|
||||
video_paths: list[str] = []
|
||||
|
||||
if save_predicted_video:
|
||||
if not videos_dir:
|
||||
raise ValueError("If save_predicted_video is True, videos_dir must be provided.")
|
||||
predicted_video_paths: list[str] = []
|
||||
n_predicted_rendered = 0
|
||||
|
||||
# Collect predicted-video latents across a rollout (world-model policies only). The latents are
|
||||
# concatenated and decoded once after the rollout, matching upstream LingBot-VA's visualization path.
|
||||
def collect_predicted_latents(policy: PreTrainedPolicy):
|
||||
latents = getattr(policy, "last_predicted_latents", None)
|
||||
if latents is not None:
|
||||
pred_latents.append(
|
||||
latents.detach().to("cpu") if hasattr(latents, "detach") else torch.as_tensor(latents).cpu()
|
||||
)
|
||||
policy.last_predicted_latents = None
|
||||
|
||||
if return_episode_data:
|
||||
episode_data: dict | None = None
|
||||
|
||||
@@ -345,6 +383,9 @@ def eval_policy(
|
||||
if max_episodes_rendered > 0:
|
||||
ep_frames: list[np.ndarray] = []
|
||||
|
||||
if save_predicted_video:
|
||||
pred_latents: list[torch.Tensor] = []
|
||||
|
||||
if start_seed is None:
|
||||
seeds = None
|
||||
else:
|
||||
@@ -361,6 +402,7 @@ def eval_policy(
|
||||
seeds=list(seeds) if seeds else None,
|
||||
return_observations=return_episode_data,
|
||||
render_callback=render_frame if max_episodes_rendered > 0 else None,
|
||||
predicted_latents_callback=collect_predicted_latents if save_predicted_video else None,
|
||||
)
|
||||
|
||||
# Figure out where in each rollout sequence the first done condition was encountered (results after
|
||||
@@ -426,6 +468,35 @@ def eval_policy(
|
||||
threads.append(thread)
|
||||
n_episodes_rendered += 1
|
||||
|
||||
# Maybe save the policy's predicted (imagined) video for this batch's rollout.
|
||||
if save_predicted_video and len(pred_latents) > 0:
|
||||
predicted_latent = torch.cat(pred_latents, dim=2)
|
||||
decoder = getattr(policy, "decode_predicted_latents", None) or getattr(
|
||||
policy, "_decode_predicted_video", None
|
||||
)
|
||||
if decoder is None:
|
||||
raise AttributeError(
|
||||
"Policy config requested predicted-video saving, but the policy does not expose "
|
||||
"`decode_predicted_latents` or `_decode_predicted_video`."
|
||||
)
|
||||
predicted_video = decoder(predicted_latent)
|
||||
if hasattr(predicted_video, "detach"):
|
||||
predicted_video = predicted_video.detach().to("cpu").numpy()
|
||||
videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
predicted_video_path = videos_dir / f"pred_episode_{n_predicted_rendered}.mp4"
|
||||
predicted_video_paths.append(str(predicted_video_path))
|
||||
thread = threading.Thread(
|
||||
target=write_video,
|
||||
args=(
|
||||
str(predicted_video_path),
|
||||
predicted_video,
|
||||
env.unwrapped.metadata["render_fps"],
|
||||
),
|
||||
)
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
n_predicted_rendered += 1
|
||||
|
||||
progbar.set_postfix(
|
||||
{"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"}
|
||||
)
|
||||
@@ -469,6 +540,9 @@ def eval_policy(
|
||||
if max_episodes_rendered > 0:
|
||||
info["video_paths"] = video_paths
|
||||
|
||||
if save_predicted_video:
|
||||
info["predicted_video_paths"] = predicted_video_paths
|
||||
|
||||
return info
|
||||
|
||||
|
||||
@@ -600,9 +674,10 @@ class TaskMetrics(TypedDict):
|
||||
max_rewards: list[float]
|
||||
successes: list[bool]
|
||||
video_paths: list[str]
|
||||
predicted_video_paths: list[str]
|
||||
|
||||
|
||||
ACC_KEYS = ("sum_rewards", "max_rewards", "successes", "video_paths")
|
||||
ACC_KEYS = ("sum_rewards", "max_rewards", "successes", "video_paths", "predicted_video_paths")
|
||||
|
||||
|
||||
def eval_one(
|
||||
@@ -643,6 +718,7 @@ def eval_one(
|
||||
max_rewards=[ep["max_reward"] for ep in per_episode],
|
||||
successes=[ep["success"] for ep in per_episode],
|
||||
video_paths=task_result.get("video_paths", []),
|
||||
predicted_video_paths=task_result.get("predicted_video_paths", []),
|
||||
)
|
||||
|
||||
|
||||
@@ -689,6 +765,7 @@ def run_one(
|
||||
# ensure we always provide video_paths key to simplify accumulation
|
||||
if max_episodes_rendered > 0:
|
||||
metrics.setdefault("video_paths", [])
|
||||
metrics.setdefault("predicted_video_paths", [])
|
||||
return task_group, task_id, metrics
|
||||
|
||||
|
||||
@@ -742,11 +819,11 @@ def eval_policy_all(
|
||||
_append("sum_rewards", metrics.get("sum_rewards"))
|
||||
_append("max_rewards", metrics.get("max_rewards"))
|
||||
_append("successes", metrics.get("successes"))
|
||||
# video_paths is list-like
|
||||
paths = metrics.get("video_paths", [])
|
||||
if paths:
|
||||
group_acc[group]["video_paths"].extend(paths)
|
||||
overall["video_paths"].extend(paths)
|
||||
for key in ("video_paths", "predicted_video_paths"):
|
||||
paths = metrics.get(key, [])
|
||||
if paths:
|
||||
group_acc[group][key].extend(paths)
|
||||
overall[key].extend(paths)
|
||||
|
||||
# Choose runner (sequential vs threaded)
|
||||
task_runner = partial(
|
||||
@@ -814,6 +891,7 @@ def eval_policy_all(
|
||||
"pc_success": _agg_from_list(acc["successes"]) * 100 if acc["successes"] else float("nan"),
|
||||
"n_episodes": len(acc["sum_rewards"]),
|
||||
"video_paths": list(acc["video_paths"]),
|
||||
"predicted_video_paths": list(acc["predicted_video_paths"]),
|
||||
}
|
||||
|
||||
# overall aggregates
|
||||
@@ -825,6 +903,7 @@ def eval_policy_all(
|
||||
"eval_s": time.time() - start_t,
|
||||
"eval_ep_s": (time.time() - start_t) / max(1, len(overall["sum_rewards"])),
|
||||
"video_paths": list(overall["video_paths"]),
|
||||
"predicted_video_paths": list(overall["predicted_video_paths"]),
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -25,6 +25,7 @@ Strategies
|
||||
--strategy.type=sentry Continuous recording with auto-upload
|
||||
--strategy.type=highlight Ring buffer + keystroke save
|
||||
--strategy.type=dagger Human-in-the-loop (DAgger / RaC)
|
||||
--strategy.type=episodic Episode-oriented recording with reset phases
|
||||
|
||||
Inference backends
|
||||
------------------
|
||||
@@ -111,6 +112,18 @@ Usage examples
|
||||
--display_data=true \\
|
||||
--use_torch_compile=true
|
||||
|
||||
# Episodic mode — episode-oriented recording with reset phases
|
||||
lerobot-rollout \\
|
||||
--strategy.type=episodic \\
|
||||
--policy.path=user/my_policy \\
|
||||
--robot.type=so100_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--teleop.type=so100_leader \\
|
||||
--teleop.port=/dev/ttyACM1 \\
|
||||
--dataset.repo_id=user/rollout_episodic_data \\
|
||||
--dataset.num_episodes=20 \\
|
||||
--dataset.single_task="Grab the cube"
|
||||
|
||||
# Resume a previous sentry recording session
|
||||
lerobot-rollout \\
|
||||
--strategy.type=sentry \\
|
||||
|
||||
@@ -292,19 +292,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
active_cfg = cfg.trainable_config
|
||||
processor_pretrained_path = active_cfg.pretrained_path
|
||||
if (
|
||||
getattr(active_cfg, "use_relative_actions", False)
|
||||
and processor_pretrained_path is not None
|
||||
and not cfg.resume
|
||||
):
|
||||
logging.warning(
|
||||
"use_relative_actions=true with pretrained processors can skip relative transforms if "
|
||||
"the checkpoint processors do not define them. Building processors from current policy config."
|
||||
)
|
||||
processor_pretrained_path = None
|
||||
|
||||
processor_kwargs = {}
|
||||
postprocessor_kwargs = {}
|
||||
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
|
||||
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||
|
||||
@@ -312,24 +301,31 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
processor_kwargs["dataset_meta"] = dataset.meta
|
||||
|
||||
if not cfg.is_reward_model_training and processor_pretrained_path is not None:
|
||||
processor_kwargs["preprocessor_overrides"] = {
|
||||
preprocessor_overrides = {
|
||||
"device_processor": {"device": device.type},
|
||||
"normalizer_processor": {
|
||||
"stats": dataset.meta.stats,
|
||||
"features": {**policy.config.input_features, **policy.config.output_features},
|
||||
"norm_map": policy.config.normalization_mapping,
|
||||
},
|
||||
"rename_observations_processor": {"rename_map": cfg.rename_map},
|
||||
}
|
||||
processor_kwargs["preprocessor_overrides"]["rename_observations_processor"] = {
|
||||
"rename_map": cfg.rename_map
|
||||
}
|
||||
postprocessor_kwargs["postprocessor_overrides"] = {
|
||||
postprocessor_overrides = {
|
||||
"unnormalizer_processor": {
|
||||
"stats": dataset.meta.stats,
|
||||
"features": policy.config.output_features,
|
||||
"norm_map": policy.config.normalization_mapping,
|
||||
},
|
||||
}
|
||||
if getattr(active_cfg, "use_relative_actions", False):
|
||||
preprocessor_overrides["relative_actions_processor"] = {
|
||||
"enabled": True,
|
||||
"exclude_joints": getattr(active_cfg, "relative_exclude_joints", []),
|
||||
"action_names": getattr(active_cfg, "action_feature_names", None),
|
||||
}
|
||||
postprocessor_overrides["absolute_actions_processor"] = {"enabled": True}
|
||||
processor_kwargs["preprocessor_overrides"] = preprocessor_overrides
|
||||
processor_kwargs["postprocessor_overrides"] = postprocessor_overrides
|
||||
|
||||
if cfg.is_reward_model_training:
|
||||
preprocessor, postprocessor = make_reward_pre_post_processors(
|
||||
@@ -341,7 +337,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=processor_pretrained_path,
|
||||
**processor_kwargs,
|
||||
**postprocessor_kwargs,
|
||||
)
|
||||
|
||||
if is_main_process:
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -0,0 +1,78 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
|
||||
|
||||
def make_config(**overrides) -> LingBotVAConfig:
|
||||
kwargs = {"device": "cpu"}
|
||||
kwargs.update(overrides)
|
||||
return LingBotVAConfig(**kwargs)
|
||||
|
||||
|
||||
def test_registered_in_choice_registry() -> None:
|
||||
assert "lingbot_va" in PreTrainedConfig.get_known_choices()
|
||||
assert PreTrainedConfig.get_choice_class("lingbot_va") is LingBotVAConfig
|
||||
|
||||
|
||||
def test_type_property() -> None:
|
||||
assert make_config().type == "lingbot_va"
|
||||
|
||||
|
||||
def test_chunk_size_and_action_steps() -> None:
|
||||
cfg = make_config(frame_chunk_size=4, action_per_frame=4)
|
||||
assert cfg.chunk_size == 16
|
||||
assert cfg.n_action_steps == 16
|
||||
assert cfg.action_delta_indices == list(range(16))
|
||||
assert cfg.observation_delta_indices is None
|
||||
assert cfg.reward_delta_indices is None
|
||||
|
||||
|
||||
def test_optimizer_and_scheduler_presets() -> None:
|
||||
cfg = make_config()
|
||||
opt = cfg.get_optimizer_preset()
|
||||
assert opt.lr == cfg.optimizer_lr
|
||||
sched = cfg.get_scheduler_preset()
|
||||
assert sched.num_warmup_steps == cfg.scheduler_warmup_steps
|
||||
|
||||
|
||||
def test_validate_features_sets_action_feature() -> None:
|
||||
cfg = make_config()
|
||||
cfg.input_features = {f"{OBS_IMAGES}.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128))}
|
||||
cfg.output_features = {}
|
||||
cfg.validate_features()
|
||||
assert ACTION in cfg.output_features
|
||||
assert cfg.output_features[ACTION].shape == (len(cfg.used_action_channel_ids),)
|
||||
|
||||
|
||||
def test_validate_features_no_visual_raises() -> None:
|
||||
cfg = make_config()
|
||||
cfg.input_features = {}
|
||||
cfg.output_features = {}
|
||||
with pytest.raises(ValueError, match="at least one visual input feature"):
|
||||
cfg.validate_features()
|
||||
|
||||
|
||||
def test_invalid_attn_mode_raises() -> None:
|
||||
with pytest.raises(ValueError, match="attn_mode"):
|
||||
make_config(attn_mode="banana")
|
||||
@@ -0,0 +1,38 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.policies.factory import make_policy_config
|
||||
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
|
||||
|
||||
|
||||
def test_make_policy_config_returns_lingbot_va() -> None:
|
||||
cfg = make_policy_config("lingbot_va", device="cpu")
|
||||
assert isinstance(cfg, LingBotVAConfig)
|
||||
|
||||
|
||||
def test_get_policy_class_resolves_lazily() -> None:
|
||||
# Importing the policy class pulls in diffusers (Wan2.2 stack); skip if unavailable.
|
||||
pytest.importorskip("diffusers")
|
||||
pytest.importorskip("transformers")
|
||||
from lerobot.policies.factory import get_policy_class
|
||||
|
||||
cls = get_policy_class("lingbot_va")
|
||||
assert cls.name == "lingbot_va"
|
||||
assert cls.config_class is LingBotVAConfig
|
||||
@@ -0,0 +1,131 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for the vendored LingBot-VA helper code (scheduler + grid utilities)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("diffusers") # the model code lives in modeling_lingbot_va, which imports diffusers
|
||||
|
||||
from lerobot.policies.lingbot_va.modeling_lingbot_va import ( # noqa: E402
|
||||
FlowMatchScheduler,
|
||||
data_seq_to_patch,
|
||||
get_mesh_id,
|
||||
)
|
||||
|
||||
|
||||
def test_flow_match_scheduler_timesteps_monotone_decreasing() -> None:
|
||||
sch = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
|
||||
sch.set_timesteps(20)
|
||||
assert sch.timesteps.shape == (20,)
|
||||
diffs = sch.timesteps[1:] - sch.timesteps[:-1]
|
||||
assert torch.all(diffs <= 0) # decreasing
|
||||
|
||||
|
||||
def test_flow_match_scheduler_step_preserves_shape() -> None:
|
||||
sch = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
|
||||
sch.set_timesteps(20)
|
||||
sample = torch.zeros(1, 48, 4, 8, 16)
|
||||
out = sch.step(torch.ones_like(sample), sch.timesteps[0], sample)
|
||||
assert out.shape == sample.shape
|
||||
|
||||
|
||||
def test_flow_match_scheduler_add_noise() -> None:
|
||||
sch = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
|
||||
sch.set_timesteps(20)
|
||||
sample = torch.randn(1, 48, 4, 8, 16)
|
||||
noise = torch.randn_like(sample)
|
||||
noisy = sch.add_noise(sample, noise, sch.timesteps[:4], t_dim=2)
|
||||
assert noisy.shape == sample.shape
|
||||
|
||||
|
||||
def test_get_mesh_id_latent_shape() -> None:
|
||||
grid = get_mesh_id(4, 8, 16, 0, 1, 0)
|
||||
assert grid.shape == (4, 4 * 8 * 16) # (f, h, w, stream) x tokens
|
||||
|
||||
|
||||
def test_get_mesh_id_action_shape() -> None:
|
||||
grid = get_mesh_id(4, 4, 1, 1, 1, 0, action=True)
|
||||
assert grid.shape == (4, 4 * 4 * 1)
|
||||
# Action rows for h/w are sentinel -1.
|
||||
assert torch.all(grid[1] < 0)
|
||||
assert torch.all(grid[2] < 0)
|
||||
|
||||
|
||||
def test_data_seq_to_patch_roundtrip_shape() -> None:
|
||||
b, f, h, w, c = 1, 4, 8, 16, 48
|
||||
seq = torch.arange(b * f * h * w * c, dtype=torch.float32).reshape(b, f * h * w, c)
|
||||
out = data_seq_to_patch((1, 2, 2), seq, f, h, w, batch_size=b)
|
||||
assert out.shape == (b, c, f, h, w)
|
||||
|
||||
|
||||
def test_training_step_reduces_loss_tiny_flex() -> None:
|
||||
"""End-to-end single training step (flow-matching loss -> backward -> AdamW) on a tiny config.
|
||||
|
||||
Exercises the flex-attention training path; requires a CUDA GPU with flex-attention support.
|
||||
"""
|
||||
if not torch.cuda.is_available():
|
||||
import pytest
|
||||
|
||||
pytest.skip("training step test requires a CUDA GPU (flex-attention)")
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
|
||||
from lerobot.policies.lingbot_va.modeling_lingbot_va import LingBotVAPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
|
||||
cfg = LingBotVAConfig(
|
||||
attn_mode="flex",
|
||||
dtype="bfloat16",
|
||||
in_channels=16,
|
||||
out_channels=16,
|
||||
action_dim=8,
|
||||
text_dim=32,
|
||||
freq_dim=64,
|
||||
ffn_dim=64,
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=24,
|
||||
num_layers=2,
|
||||
frame_chunk_size=2,
|
||||
action_per_frame=4,
|
||||
used_action_channel_ids=[0, 1, 2, 3],
|
||||
obs_cam_keys=[f"{OBS_IMAGES}.image"],
|
||||
device="cuda",
|
||||
)
|
||||
cfg.input_features = {f"{OBS_IMAGES}.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64))}
|
||||
cfg.output_features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,))}
|
||||
cfg.validate_features()
|
||||
|
||||
policy = LingBotVAPolicy(cfg).to("cuda")
|
||||
policy.train()
|
||||
opt = torch.optim.AdamW(policy.get_optim_params(), lr=1e-4)
|
||||
|
||||
b, fc, apf = 1, cfg.frame_chunk_size, cfg.action_per_frame
|
||||
latents = torch.randn(b, cfg.in_channels, fc, 4, 4, device="cuda", dtype=torch.bfloat16)
|
||||
actions = torch.randn(b, cfg.action_dim, fc, apf, 1, device="cuda", dtype=torch.bfloat16)
|
||||
amask = torch.zeros(cfg.action_dim, device="cuda")
|
||||
amask[cfg.used_action_channel_ids] = 1.0
|
||||
actions_mask = amask.view(1, -1, 1, 1, 1).expand_as(actions)
|
||||
text_emb = torch.randn(b, cfg.max_sequence_length, cfg.text_dim, device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
loss, metrics = policy.training_loss_from_streams(latents, actions, actions_mask, text_emb)
|
||||
assert torch.isfinite(loss) and {"latent_loss", "action_loss"} <= set(metrics)
|
||||
loss.backward()
|
||||
assert any(p.grad is not None and torch.isfinite(p.grad).all() for p in policy.get_optim_params())
|
||||
opt.step()
|
||||
@@ -0,0 +1,88 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
|
||||
from lerobot.policies.lingbot_va.processor_lingbot_va import make_lingbot_va_pre_post_processors
|
||||
from lerobot.processor import PolicyProcessorPipeline, UnnormalizerProcessorStep
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_IMAGES,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
|
||||
def _make_config() -> LingBotVAConfig:
|
||||
cfg = LingBotVAConfig(device="cpu")
|
||||
cfg.input_features = {f"{OBS_IMAGES}.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128))}
|
||||
cfg.output_features = {}
|
||||
cfg.validate_features()
|
||||
return cfg
|
||||
|
||||
|
||||
def test_make_pre_post_processors_names_and_steps() -> None:
|
||||
cfg = _make_config()
|
||||
pre, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats=None)
|
||||
assert pre.name == POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
assert post.name == POLICY_POSTPROCESSOR_DEFAULT_NAME
|
||||
# Actions are unnormalized by the standard built-in quantile unnormalizer.
|
||||
assert any(isinstance(s, UnnormalizerProcessorStep) for s in post.steps)
|
||||
|
||||
|
||||
def test_freshly_built_postprocessor_is_identity() -> None:
|
||||
# Without action stats the quantile unnormalizer is a no-op (identity passthrough): the real
|
||||
# per-benchmark q01/q99 are restored from the saved checkpoint on load, not hardcoded here.
|
||||
cfg = _make_config()
|
||||
_, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats=None)
|
||||
normed = torch.tensor([[0.3, -0.5, 1.0, -1.0, 0.0, 0.7, -0.2]])
|
||||
assert torch.allclose(post(normed), normed, atol=1e-6)
|
||||
|
||||
|
||||
def test_postprocessor_quantile_unnormalization() -> None:
|
||||
# QUANTILES unnormalize maps [-1, 1] -> [q01, q99]: -1 -> q01, +1 -> q99.
|
||||
cfg = _make_config()
|
||||
q01 = [-1.0, -0.5, 0.0, -1.0, -1.0, -1.0, -1.0]
|
||||
q99 = [1.0, 0.5, 2.0, 1.0, 1.0, 1.0, 1.0]
|
||||
stats = {ACTION: {"q01": q01, "q99": q99}}
|
||||
_, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats=stats)
|
||||
out_lo = post(torch.full((1, 7), -1.0))
|
||||
out_hi = post(torch.full((1, 7), 1.0))
|
||||
assert torch.allclose(out_lo, torch.tensor(q01).unsqueeze(0), atol=1e-4)
|
||||
assert torch.allclose(out_hi, torch.tensor(q99).unsqueeze(0), atol=1e-4)
|
||||
|
||||
|
||||
def test_postprocessor_stats_survive_save_load(tmp_path) -> None:
|
||||
# Regression guard for the Hub mechanism: the q01/q99 stats live in the saved post-processor
|
||||
# state and must round-trip through save_pretrained / from_pretrained.
|
||||
cfg = _make_config()
|
||||
q01 = [-0.6, -0.8, -0.9, -0.1, -0.15, -0.25, -1.0]
|
||||
q99 = [0.9, 0.85, 0.9, 0.17, 0.18, 0.34, 1.0]
|
||||
_, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats={ACTION: {"q01": q01, "q99": q99}})
|
||||
post.save_pretrained(tmp_path)
|
||||
loaded = PolicyProcessorPipeline.from_pretrained(
|
||||
tmp_path,
|
||||
config_filename=f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json",
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
)
|
||||
out = loaded(torch.full((1, 7), -1.0))
|
||||
assert torch.allclose(out, torch.tensor(q01).unsqueeze(0), atol=1e-4)
|
||||
@@ -0,0 +1,273 @@
|
||||
#!/usr/bin/env python
|
||||
"""Shared fixtures and helpers for VLA-JEPA tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
BATCH_SIZE = 2
|
||||
ACTION_DIM = 3
|
||||
STATE_DIM = 4
|
||||
IMAGE_SIZE = 8
|
||||
ACTION_HORIZON = 4
|
||||
N_ACTION_STEPS = 2
|
||||
NUM_VIDEO_FRAMES = 3
|
||||
QWEN_HIDDEN_SIZE = 16 # hidden size produced by _FakeQwenBackbone
|
||||
|
||||
EXPECTED_ACTION_CHUNK_SHAPE = (BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||
EXPECTED_SELECT_ACTION_SHAPE = (BATCH_SIZE, ACTION_DIM)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def set_seed_all(seed: int) -> None:
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def make_config(
|
||||
action_dim: int = ACTION_DIM,
|
||||
state_dim: int = STATE_DIM,
|
||||
action_horizon: int = ACTION_HORIZON,
|
||||
num_video_frames: int = NUM_VIDEO_FRAMES,
|
||||
) -> VLAJEPAConfig:
|
||||
config = VLAJEPAConfig(
|
||||
input_features={
|
||||
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)),
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
|
||||
},
|
||||
output_features={
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)),
|
||||
},
|
||||
device="cpu",
|
||||
chunk_size=action_horizon,
|
||||
n_action_steps=min(N_ACTION_STEPS, action_horizon),
|
||||
action_dim=action_dim,
|
||||
state_dim=state_dim,
|
||||
num_video_frames=num_video_frames,
|
||||
num_action_tokens_per_timestep=2,
|
||||
num_embodied_action_tokens_per_instruction=3,
|
||||
num_inference_timesteps=2,
|
||||
action_hidden_size=QWEN_HIDDEN_SIZE,
|
||||
action_model_type="DiT-test",
|
||||
action_num_layers=1,
|
||||
predictor_depth=1,
|
||||
predictor_num_heads=2,
|
||||
predictor_mlp_ratio=2.0,
|
||||
jepa_tubelet_size=1,
|
||||
)
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
|
||||
def make_train_batch(
|
||||
batch_size: int = BATCH_SIZE,
|
||||
action_dim: int = ACTION_DIM,
|
||||
state_dim: int = STATE_DIM,
|
||||
action_horizon: int = ACTION_HORIZON,
|
||||
num_video_frames: int = NUM_VIDEO_FRAMES,
|
||||
) -> dict[str, Tensor | list[str]]:
|
||||
return {
|
||||
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, num_video_frames, 3, IMAGE_SIZE, IMAGE_SIZE),
|
||||
OBS_STATE: torch.randn(batch_size, 1, state_dim),
|
||||
ACTION: torch.randn(batch_size, action_horizon, action_dim),
|
||||
"task": ["pick up the cube"] * batch_size,
|
||||
}
|
||||
|
||||
|
||||
def make_inference_batch(
|
||||
batch_size: int = BATCH_SIZE,
|
||||
state_dim: int = STATE_DIM,
|
||||
) -> dict[str, Tensor | list[str]]:
|
||||
return {
|
||||
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE),
|
||||
OBS_STATE: torch.randn(batch_size, state_dim),
|
||||
"task": ["pick up the cube"] * batch_size,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fake external models (replace Qwen3-VL and V-JEPA at test time)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeLanguageLayer(nn.Module):
|
||||
"""Leaf module whose forward hook is captured by _qwen_last_decoder_hidden."""
|
||||
|
||||
def __init__(self, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self._hidden_size = hidden_size
|
||||
|
||||
def forward(self, hidden: Tensor, **_: object) -> tuple[Tensor, ...]:
|
||||
return (hidden,)
|
||||
|
||||
|
||||
class _FakeLanguageModel(nn.Module):
|
||||
def __init__(self, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self._hidden_size = hidden_size
|
||||
self.layers = nn.ModuleList([_FakeLanguageLayer(hidden_size)])
|
||||
|
||||
def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
|
||||
batch_size, seq_len = input_ids.shape
|
||||
hidden = torch.zeros(batch_size, seq_len, self._hidden_size, device=input_ids.device)
|
||||
self.layers[-1](hidden)
|
||||
return SimpleNamespace()
|
||||
|
||||
|
||||
class _FakeQwenInnerModel(nn.Module):
|
||||
"""Mimics the `.model.model` level that _qwen_last_decoder_hidden walks into."""
|
||||
|
||||
def __init__(self, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self.language_model = _FakeLanguageModel(hidden_size)
|
||||
|
||||
def forward(self, input_ids: Tensor, **kwargs: object) -> SimpleNamespace:
|
||||
return self.language_model(input_ids)
|
||||
|
||||
|
||||
class _FakeQwenBackbone(nn.Module):
|
||||
def __init__(self, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(1))
|
||||
self.config = SimpleNamespace(
|
||||
hidden_size=hidden_size,
|
||||
text_config=SimpleNamespace(hidden_size=hidden_size),
|
||||
)
|
||||
self.model = _FakeQwenInnerModel(hidden_size)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.weight.device
|
||||
|
||||
def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
|
||||
batch_size, seq_len = input_ids.shape
|
||||
hidden_size = self.config.hidden_size
|
||||
values = torch.arange(
|
||||
batch_size * seq_len * hidden_size,
|
||||
device=input_ids.device,
|
||||
dtype=torch.float32,
|
||||
).view(batch_size, seq_len, hidden_size)
|
||||
hidden = values / values.numel() + self.weight
|
||||
self.model(input_ids) # call through so the forward hook on layers[-1] fires
|
||||
return SimpleNamespace(hidden_states=[hidden])
|
||||
|
||||
|
||||
class _FakeQwenInterface(nn.Module):
|
||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = _FakeQwenBackbone(hidden_size=QWEN_HIDDEN_SIZE)
|
||||
|
||||
@staticmethod
|
||||
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
|
||||
return torch.float32 if dtype_name == "float32" else torch.bfloat16
|
||||
|
||||
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
|
||||
max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep
|
||||
action_tokens = [self.config.special_action_token.format(idx) for idx in range(max_action_tokens)]
|
||||
action_token_ids = list(range(1000, 1000 + max_action_tokens))
|
||||
return action_tokens, action_token_ids, 2000
|
||||
|
||||
def build_inputs(
|
||||
self,
|
||||
images: list[list[Image.Image]],
|
||||
instructions: list[str],
|
||||
action_prompt: str,
|
||||
embodied_prompt: str,
|
||||
) -> dict[str, Tensor]:
|
||||
batch_size = len(images)
|
||||
del images, instructions, action_prompt, embodied_prompt
|
||||
action_count = (self.config.num_video_frames - 1) * self.config.num_action_tokens_per_timestep
|
||||
token_ids = (
|
||||
[10]
|
||||
+ list(range(1000, 1000 + action_count))
|
||||
+ [2000] * self.config.num_embodied_action_tokens_per_instruction
|
||||
+ [11]
|
||||
)
|
||||
return {
|
||||
"input_ids": torch.tensor(
|
||||
[token_ids] * batch_size,
|
||||
device=self.model.device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def tensor_to_pil(image_tensor: Tensor) -> Image.Image:
|
||||
image = image_tensor.detach().cpu()
|
||||
if image.ndim == 3 and image.shape[0] in (1, 3):
|
||||
image = image.permute(1, 2, 0)
|
||||
image = (image.float().clamp(0, 1) * 255).to(torch.uint8).numpy()
|
||||
return Image.fromarray(image)
|
||||
|
||||
|
||||
class _FakeVideoEncoder(nn.Module):
|
||||
def __init__(self, hidden_size: int = 8, tubelet_size: int = 1) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(1))
|
||||
# image_size must be >= patch_size (16) so the predictor grid is non-zero.
|
||||
# Setting image_size=16 gives a 1x1 grid (1 patch per frame).
|
||||
self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size, image_size=16)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.weight.device
|
||||
|
||||
def get_vision_features(self, pixel_values_videos: Tensor) -> Tensor:
|
||||
batch_size, num_frames = pixel_values_videos.shape[:2]
|
||||
hidden_size = self.config.hidden_size
|
||||
frame_values = pixel_values_videos.float().mean(dim=(2, 3, 4), keepdim=False)
|
||||
return frame_values[:, :, None].expand(batch_size, num_frames, hidden_size)
|
||||
|
||||
|
||||
class _FakeVideoProcessor:
|
||||
def __call__(self, videos, return_tensors: str) -> dict[str, Tensor]:
|
||||
assert return_tensors == "pt"
|
||||
if isinstance(videos, list):
|
||||
pixel_values = torch.stack([torch.as_tensor(v) for v in videos])
|
||||
else:
|
||||
pixel_values = torch.as_tensor(videos).unsqueeze(0)
|
||||
return {"pixel_values_videos": pixel_values}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_vla_jepa_external_models(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from lerobot.policies.vla_jepa import modeling_vla_jepa
|
||||
|
||||
monkeypatch.setattr(modeling_vla_jepa, "Qwen3VLInterface", _FakeQwenInterface)
|
||||
monkeypatch.setattr(
|
||||
modeling_vla_jepa.AutoModel,
|
||||
"from_pretrained",
|
||||
lambda *args, **kwargs: _FakeVideoEncoder(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
modeling_vla_jepa.AutoVideoProcessor,
|
||||
"from_pretrained",
|
||||
lambda *args, **kwargs: _FakeVideoProcessor(),
|
||||
)
|
||||
@@ -0,0 +1,157 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("diffusers")
|
||||
|
||||
from conftest import (
|
||||
ACTION_DIM,
|
||||
ACTION_HORIZON,
|
||||
BATCH_SIZE,
|
||||
QWEN_HIDDEN_SIZE,
|
||||
STATE_DIM,
|
||||
make_config,
|
||||
set_seed_all,
|
||||
) # noqa: E402
|
||||
|
||||
from lerobot.policies.vla_jepa.action_head import ( # noqa: E402
|
||||
VLAJEPAActionHead,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VLAJEPAActionHead
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"action_dim,state_dim,action_horizon",
|
||||
[
|
||||
(3, 4, 4), # default test dims
|
||||
(7, 0, 16), # no proprioceptive state, production-like action space
|
||||
(6, 8, 8), # medium dims
|
||||
],
|
||||
)
|
||||
def test_action_head_sample_time_range(action_dim: int, state_dim: int, action_horizon: int) -> None:
|
||||
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
t = head.sample_time(batch_size=200, device=torch.device("cpu"), dtype=torch.float32)
|
||||
assert t.shape == (200,)
|
||||
assert torch.isfinite(t).all()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"action_dim,state_dim,action_horizon",
|
||||
[
|
||||
(3, 4, 4),
|
||||
(7, 0, 16),
|
||||
(6, 8, 8),
|
||||
],
|
||||
)
|
||||
def test_action_head_build_inputs_shape(action_dim: int, state_dim: int, action_horizon: int) -> None:
|
||||
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
|
||||
actions = torch.randn(2, action_horizon, action_dim)
|
||||
timesteps = torch.randint(0, 100, (2,))
|
||||
|
||||
state = torch.randn(2, state_dim) if state_dim > 0 else None
|
||||
out_with = head._build_inputs(conditioning, actions, state, timesteps)
|
||||
out_none = head._build_inputs(conditioning, actions, None, timesteps)
|
||||
|
||||
assert out_with.ndim == 3 and out_none.ndim == 3
|
||||
if state_dim > 0:
|
||||
assert out_with.shape[1] > out_none.shape[1]
|
||||
assert torch.isfinite(out_with).all() and torch.isfinite(out_none).all()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"action_dim,state_dim,action_horizon",
|
||||
[
|
||||
(3, 4, 4),
|
||||
(7, 0, 16),
|
||||
(6, 8, 8),
|
||||
],
|
||||
)
|
||||
def test_action_head_forward_loss_valid(action_dim: int, state_dim: int, action_horizon: int) -> None:
|
||||
set_seed_all(42)
|
||||
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
|
||||
actions = torch.randn(2, action_horizon, action_dim)
|
||||
state = torch.randn(2, state_dim) if state_dim > 0 else None
|
||||
loss = head.forward(conditioning, actions, state)
|
||||
assert loss.shape == ()
|
||||
assert torch.isfinite(loss) and loss > 0
|
||||
|
||||
|
||||
def test_action_head_forward_gradient_flows() -> None:
|
||||
set_seed_all(42)
|
||||
config = make_config()
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
|
||||
actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||
state = torch.randn(BATCH_SIZE, STATE_DIM)
|
||||
loss = head.forward(conditioning, actions, state)
|
||||
loss.backward()
|
||||
assert any(p.grad is not None for p in head.parameters() if p.requires_grad)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@pytest.mark.parametrize(
|
||||
"action_dim,state_dim,action_horizon",
|
||||
[
|
||||
(3, 4, 4),
|
||||
(7, 0, 16),
|
||||
(6, 8, 8),
|
||||
],
|
||||
)
|
||||
def test_action_head_predict_action_shape(action_dim: int, state_dim: int, action_horizon: int) -> None:
|
||||
set_seed_all(42)
|
||||
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
|
||||
state = torch.randn(2, state_dim) if state_dim > 0 else None
|
||||
pred = head.predict_action(conditioning, state)
|
||||
assert tuple(pred.shape) == (2, action_horizon, action_dim)
|
||||
assert torch.isfinite(pred).all()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# action_is_pad masking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_action_head_loss_fully_padded_is_zero() -> None:
|
||||
"""Loss is 0 when every timestep is padded (exercises the clamp_min guard)."""
|
||||
set_seed_all(42)
|
||||
config = make_config()
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
|
||||
actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||
state = torch.randn(BATCH_SIZE, STATE_DIM)
|
||||
|
||||
action_is_pad = torch.ones(BATCH_SIZE, ACTION_HORIZON, dtype=torch.bool)
|
||||
loss = head.forward(conditioning, actions, state, action_is_pad)
|
||||
assert loss.item() == 0.0
|
||||
|
||||
|
||||
def test_action_head_loss_none_matches_no_padding() -> None:
|
||||
"""action_is_pad=None is equivalent to an all-False (no padding) mask."""
|
||||
set_seed_all(42)
|
||||
config = make_config()
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
|
||||
actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||
state = torch.randn(BATCH_SIZE, STATE_DIM)
|
||||
|
||||
set_seed_all(0)
|
||||
loss_none = head.forward(conditioning, actions, state, action_is_pad=None)
|
||||
|
||||
set_seed_all(0)
|
||||
no_pad = torch.zeros(BATCH_SIZE, ACTION_HORIZON, dtype=torch.bool)
|
||||
loss_zeros = head.forward(conditioning, actions, state, action_is_pad=no_pad)
|
||||
|
||||
assert torch.isclose(loss_none, loss_zeros)
|
||||
@@ -0,0 +1,57 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from conftest import ACTION_DIM, ACTION_HORIZON, IMAGE_SIZE, NUM_VIDEO_FRAMES, STATE_DIM, make_config
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
def test_delta_indices() -> None:
|
||||
config = make_config()
|
||||
assert config.observation_delta_indices == list(range(NUM_VIDEO_FRAMES))
|
||||
assert config.action_delta_indices == list(range(ACTION_HORIZON))
|
||||
|
||||
|
||||
def test_n_action_steps_exceeds_chunk_size_raises() -> None:
|
||||
with pytest.raises(ValueError, match="n_action_steps"):
|
||||
VLAJEPAConfig(chunk_size=4, n_action_steps=8)
|
||||
|
||||
|
||||
def test_too_few_video_frames_raises() -> None:
|
||||
with pytest.raises(ValueError, match="video_horizon"):
|
||||
VLAJEPAConfig(
|
||||
chunk_size=16,
|
||||
n_action_steps=16,
|
||||
num_video_frames=2,
|
||||
jepa_tubelet_size=2, # needs >= 4 frames (2 for current, 2 for future) to have a window of size > 0
|
||||
)
|
||||
|
||||
|
||||
def test_validate_features_no_image_raises() -> None:
|
||||
config = VLAJEPAConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))},
|
||||
)
|
||||
with pytest.raises(ValueError, match="at least one visual input feature"):
|
||||
config.validate_features()
|
||||
|
||||
|
||||
def test_validate_features_no_action_raises() -> None:
|
||||
config = VLAJEPAConfig(
|
||||
input_features={
|
||||
f"{OBS_IMAGES}.cam": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)),
|
||||
},
|
||||
output_features={},
|
||||
)
|
||||
with pytest.raises(ValueError, match="action output feature"):
|
||||
config.validate_features()
|
||||
|
||||
|
||||
def test_validate_features_sets_action_dim_from_feature() -> None:
|
||||
config = make_config(action_dim=6, state_dim=10)
|
||||
assert config.action_dim == 6
|
||||
assert config.state_dim == 10
|
||||
@@ -0,0 +1,598 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
pytest.importorskip("diffusers")
|
||||
|
||||
pytestmark = pytest.mark.filterwarnings(
|
||||
"ignore:In CPU autocast, but the target dtype is not supported:UserWarning"
|
||||
)
|
||||
|
||||
from conftest import ( # noqa: E402
|
||||
ACTION_DIM,
|
||||
ACTION_HORIZON,
|
||||
BATCH_SIZE,
|
||||
EXPECTED_ACTION_CHUNK_SHAPE,
|
||||
EXPECTED_SELECT_ACTION_SHAPE,
|
||||
IMAGE_SIZE,
|
||||
N_ACTION_STEPS,
|
||||
QWEN_HIDDEN_SIZE,
|
||||
STATE_DIM,
|
||||
make_config,
|
||||
make_inference_batch,
|
||||
make_train_batch,
|
||||
set_seed_all,
|
||||
)
|
||||
|
||||
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig # noqa: E402
|
||||
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy # noqa: E402
|
||||
from lerobot.utils.constants import ACTION # noqa: E402
|
||||
|
||||
PRETRAINED_REPO_ID = "ginwind/VLA-JEPA"
|
||||
PRETRAINED_SUBFOLDER = "LIBERO"
|
||||
|
||||
# extended hub tests load the full converted safetensors checkpoints (~5 GB) and are
|
||||
# skipped by default. Set VLA_JEPA_EXTENDED=1 to opt in.
|
||||
_VLA_JEPA_EXTENDED = os.environ.get("VLA_JEPA_EXTENDED", "0") != "0"
|
||||
extended_test = pytest.mark.skipif(not _VLA_JEPA_EXTENDED, reason="Set VLA_JEPA_EXTENDED=1 to run hub tests")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core training / inference tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_training_forward_pass(patch_vla_jepa_external_models: None) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.train()
|
||||
|
||||
batch = make_train_batch()
|
||||
batch_before = deepcopy(batch)
|
||||
|
||||
loss, logs = policy.forward(batch)
|
||||
|
||||
assert loss.shape == ()
|
||||
assert torch.isfinite(loss)
|
||||
assert set(logs) == {"action_loss", "wm_loss", "loss"}
|
||||
assert logs["action_loss"] > 0
|
||||
assert logs["wm_loss"] >= 0
|
||||
|
||||
loss.backward()
|
||||
assert any(p.grad is not None for p in policy.model.action_model.parameters() if p.requires_grad)
|
||||
# Batch must not be mutated.
|
||||
assert set(batch) == set(batch_before)
|
||||
for key, value in batch.items():
|
||||
if isinstance(value, Tensor):
|
||||
assert torch.equal(value, batch_before[key])
|
||||
else:
|
||||
assert value == batch_before[key]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 4])
|
||||
def test_training_forward_various_batch_sizes(patch_vla_jepa_external_models: None, batch_size: int) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.train()
|
||||
loss, logs = policy.forward(make_train_batch(batch_size=batch_size))
|
||||
assert torch.isfinite(loss) and loss > 0
|
||||
assert set(logs) == {"action_loss", "wm_loss", "loss"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"action_dim,state_dim,action_horizon",
|
||||
[
|
||||
(3, 4, 4),
|
||||
(7, 0, 16),
|
||||
(6, 8, 8),
|
||||
],
|
||||
)
|
||||
def test_training_forward_various_dims(
|
||||
patch_vla_jepa_external_models: None,
|
||||
action_dim: int,
|
||||
state_dim: int,
|
||||
action_horizon: int,
|
||||
) -> None:
|
||||
set_seed_all(42)
|
||||
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||
policy = VLAJEPAPolicy(config)
|
||||
policy.train()
|
||||
batch = make_train_batch(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||
loss, _ = policy.forward(batch)
|
||||
assert torch.isfinite(loss) and loss > 0
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_action_generation_shape(patch_vla_jepa_external_models: None) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.eval()
|
||||
batch = make_inference_batch()
|
||||
|
||||
chunk = policy.predict_action_chunk(batch)
|
||||
assert tuple(chunk.shape) == EXPECTED_ACTION_CHUNK_SHAPE
|
||||
assert chunk.device.type == "cpu"
|
||||
assert torch.isfinite(chunk).all()
|
||||
|
||||
a1 = policy.select_action(batch)
|
||||
a2 = policy.select_action(batch)
|
||||
assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
||||
assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
||||
assert torch.isfinite(a1).all() and torch.isfinite(a2).all()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@pytest.mark.parametrize("action_dim,state_dim", [(3, 4), (7, 0), (6, 8)])
|
||||
def test_action_generation_various_dims(
|
||||
patch_vla_jepa_external_models: None, action_dim: int, state_dim: int
|
||||
) -> None:
|
||||
set_seed_all(42)
|
||||
config = make_config(action_dim=action_dim, state_dim=state_dim)
|
||||
policy = VLAJEPAPolicy(config)
|
||||
policy.eval()
|
||||
batch = make_inference_batch(state_dim=state_dim)
|
||||
chunk = policy.predict_action_chunk(batch)
|
||||
assert chunk.shape[-1] == action_dim
|
||||
assert torch.isfinite(chunk).all()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_inference_reproducibility(patch_vla_jepa_external_models: None) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.eval()
|
||||
batch = make_inference_batch()
|
||||
|
||||
set_seed_all(123)
|
||||
actions_1 = policy.predict_action_chunk(batch)
|
||||
set_seed_all(123)
|
||||
actions_2 = policy.predict_action_chunk(batch)
|
||||
|
||||
assert tuple(actions_1.shape) == EXPECTED_ACTION_CHUNK_SHAPE
|
||||
assert torch.allclose(actions_1, actions_2, atol=1e-6)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_predict_action_chunk_always_finite(patch_vla_jepa_external_models: None) -> None:
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.eval()
|
||||
for seed in [0, 42, 123]:
|
||||
set_seed_all(seed)
|
||||
chunk = policy.predict_action_chunk(make_inference_batch())
|
||||
assert torch.isfinite(chunk).all(), f"non-finite actions with seed={seed}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Action queue behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_select_action_queue_drains_before_refill(patch_vla_jepa_external_models: None) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.eval()
|
||||
batch = make_inference_batch()
|
||||
|
||||
# First call fills the queue (n_action_steps items) and pops one.
|
||||
a1 = policy.select_action(batch)
|
||||
assert len(policy._queues[ACTION]) == N_ACTION_STEPS - 1
|
||||
|
||||
# Second call pops from the existing queue without calling predict_action_chunk.
|
||||
a2 = policy.select_action(batch)
|
||||
assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
||||
assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_reset_clears_action_queue(patch_vla_jepa_external_models: None) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.eval()
|
||||
policy.select_action(make_inference_batch())
|
||||
assert len(policy._queues[ACTION]) > 0
|
||||
|
||||
policy.reset()
|
||||
assert len(policy._queues[ACTION]) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Format conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_prepare_model_inputs_training_format(patch_vla_jepa_external_models: None) -> None:
|
||||
from PIL import Image
|
||||
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
examples = policy._prepare_model_inputs(make_train_batch())
|
||||
|
||||
assert len(examples) == BATCH_SIZE
|
||||
for ex in examples:
|
||||
assert set(ex) >= {"image", "video", "lang", "action", "state"}
|
||||
assert len(ex["image"]) == 1 and isinstance(ex["image"][0], Image.Image)
|
||||
assert ex["video"].ndim == 5 and ex["video"].dtype == np.uint8 # [V,T,H,W,C]
|
||||
assert ex["action"].shape == (ACTION_HORIZON, ACTION_DIM)
|
||||
assert ex["state"].shape == (1, STATE_DIM)
|
||||
|
||||
|
||||
def test_prepare_model_inputs_inference_omits_action(patch_vla_jepa_external_models: None) -> None:
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
for ex in policy._prepare_model_inputs(make_inference_batch()):
|
||||
assert "action" not in ex
|
||||
assert "image" in ex and "video" in ex and "lang" in ex
|
||||
|
||||
|
||||
def test_prepare_model_inputs_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None:
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
batch = make_inference_batch()
|
||||
del batch["task"]
|
||||
examples = policy._prepare_model_inputs(batch)
|
||||
assert all(isinstance(ex["lang"], str) and len(ex["lang"]) > 0 for ex in examples)
|
||||
|
||||
|
||||
def test_prepare_model_inputs_string_task_broadcast(patch_vla_jepa_external_models: None) -> None:
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
batch = make_inference_batch()
|
||||
batch["task"] = "open the drawer"
|
||||
assert all(ex["lang"] == "open the drawer" for ex in policy._prepare_model_inputs(batch))
|
||||
|
||||
|
||||
def test_prepare_model_inputs_no_state_omitted(patch_vla_jepa_external_models: None) -> None:
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
batch = make_inference_batch()
|
||||
del batch[OBS_STATE]
|
||||
assert all("state" not in ex for ex in policy._prepare_model_inputs(batch))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pretrained checkpoint
|
||||
# Hub tests (opt-in: VLA_JEPA_EXTENDED=1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_hub_train_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict:
|
||||
"""Build a training batch whose keys/shapes match a hub-loaded policy config."""
|
||||
cfg = policy.config
|
||||
batch: dict = {"task": ["pick up the cube"] * batch_size}
|
||||
for key, feat in cfg.image_features.items():
|
||||
h, w = feat.shape[-2], feat.shape[-1]
|
||||
batch[key] = torch.rand(batch_size, cfg.num_video_frames, 3, h, w)
|
||||
if cfg.robot_state_feature is not None:
|
||||
batch["observation.state"] = torch.randn(batch_size, 1, cfg.robot_state_feature.shape[0])
|
||||
batch[ACTION] = torch.randn(batch_size, cfg.chunk_size, cfg.action_dim)
|
||||
return batch
|
||||
|
||||
|
||||
def _make_hub_inference_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict:
|
||||
"""Build an inference batch whose keys/shapes match a hub-loaded policy config."""
|
||||
cfg = policy.config
|
||||
batch: dict = {"task": ["pick up the cube"] * batch_size}
|
||||
for key, feat in cfg.image_features.items():
|
||||
h, w = feat.shape[-2], feat.shape[-1]
|
||||
batch[key] = torch.rand(batch_size, 3, h, w)
|
||||
if cfg.robot_state_feature is not None:
|
||||
batch["observation.state"] = torch.randn(batch_size, cfg.robot_state_feature.shape[0])
|
||||
return batch
|
||||
|
||||
|
||||
_CP_ROOT = "lerobot"
|
||||
|
||||
# Each tuple: (repo_id, enable_world_model)
|
||||
_HUB_VARIANTS = [
|
||||
(f"{_CP_ROOT}/VLA-JEPA-LIBERO", True),
|
||||
(f"{_CP_ROOT}/VLA-JEPA-Pretrain", True),
|
||||
(f"{_CP_ROOT}/VLA-JEPA-SimplerEnv", False),
|
||||
]
|
||||
|
||||
|
||||
@extended_test
|
||||
@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS)
|
||||
def test_hub_checkpoint_loads(repo_id: str, enable_world_model: bool) -> None:
|
||||
"""Policy loads from the converted safetensors checkpoint on the Hub."""
|
||||
policy = VLAJEPAPolicy.from_pretrained(repo_id)
|
||||
assert policy.config.enable_world_model == enable_world_model
|
||||
assert sum(p.numel() for p in policy.parameters()) > 0
|
||||
|
||||
|
||||
@extended_test
|
||||
@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS)
|
||||
def test_hub_checkpoint_forward_pass(repo_id: str, enable_world_model: bool) -> None:
|
||||
"""Policy loaded from hub produces finite losses with a correctly-shaped batch."""
|
||||
policy = VLAJEPAPolicy.from_pretrained(repo_id)
|
||||
policy.train()
|
||||
|
||||
batch = _make_hub_train_batch(policy)
|
||||
loss, logs = policy.forward(batch)
|
||||
assert torch.isfinite(loss)
|
||||
assert "action_loss" in logs
|
||||
if enable_world_model:
|
||||
assert "wm_loss" in logs
|
||||
|
||||
|
||||
@extended_test
|
||||
def test_hub_freeze_qwen_disables_world_model() -> None:
|
||||
"""freeze_qwen=True (via cli_overrides) freezes qwen and disables the world model."""
|
||||
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO", cli_overrides=["freeze_qwen=true"])
|
||||
assert not policy.config.enable_world_model
|
||||
assert policy.model.video_predictor is None
|
||||
qwen_params = list(policy.model.qwen.parameters())
|
||||
assert all(not p.requires_grad for p in qwen_params)
|
||||
assert any(p.requires_grad for p in policy.model.action_model.parameters())
|
||||
|
||||
|
||||
@extended_test
|
||||
def test_hub_disable_world_model_loads_simpler_env() -> None:
|
||||
"""SimplerEnv checkpoint (world model disabled) loads cleanly and runs inference."""
|
||||
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-SimplerEnv")
|
||||
assert not policy.config.enable_world_model
|
||||
assert policy.model.video_predictor is None
|
||||
assert policy.model.video_encoder is None
|
||||
|
||||
|
||||
@extended_test
|
||||
def test_hub_libero_inference_shape() -> None:
|
||||
"""select_action returns the expected shape using the LIBERO hub checkpoint."""
|
||||
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO")
|
||||
policy.eval()
|
||||
batch = _make_hub_inference_batch(policy)
|
||||
action = policy.select_action(batch)
|
||||
assert action.shape[-1] == policy.config.action_dim
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Postprocessor unnormalization tests
|
||||
#
|
||||
# These tests verify that the postprocessor pipeline (clip → unnorm → binarize)
|
||||
# correctly applies MIN_MAX unnormalization after predict_action_chunk.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_dataset_stats(action_dim: int = ACTION_DIM) -> dict:
|
||||
"""Returns sample dataset_stats with a simple [i, i+10] range per action dim."""
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
return {
|
||||
ACTION: {
|
||||
"min": torch.tensor([float(i) for i in range(action_dim)], dtype=torch.float32),
|
||||
"max": torch.tensor([float(i) + 10.0 for i in range(action_dim)], dtype=torch.float32),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_postprocessor_unnormalizes_actions(patch_vla_jepa_external_models: None) -> None:
|
||||
"""UnnormalizerProcessorStep with MIN_MAX produces the correct inverse of MIN_MAX normalization."""
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.processor import UnnormalizerProcessorStep
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
dataset_stats = _make_dataset_stats()
|
||||
|
||||
rng = np.random.default_rng(7)
|
||||
actions_np = rng.uniform(-1.0, 1.0, (2, ACTION_HORIZON, ACTION_DIM)).astype(np.float32)
|
||||
|
||||
a_min = dataset_stats[ACTION]["min"].numpy()
|
||||
a_max = dataset_stats[ACTION]["max"].numpy()
|
||||
expected = (actions_np + 1.0) / 2.0 * (a_max - a_min) + a_min
|
||||
|
||||
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}
|
||||
unnorm_step = UnnormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map={FeatureType.ACTION: NormalizationMode.MIN_MAX},
|
||||
stats=dataset_stats,
|
||||
)
|
||||
|
||||
actions_tensor = torch.from_numpy(actions_np)
|
||||
transition = policy_action_to_transition(actions_tensor)
|
||||
result = transition_to_policy_action(unnorm_step(transition)).numpy()
|
||||
|
||||
np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_postprocessor_clip_clamps_before_unnorm(patch_vla_jepa_external_models: None) -> None:
|
||||
"""ClipActionsProcessorStep clamps to [-1, 1] before unnormalization."""
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies.vla_jepa.processor_vla_jepa import ClipActionsProcessorStep
|
||||
from lerobot.processor import UnnormalizerProcessorStep
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
dataset_stats = _make_dataset_stats()
|
||||
a_min = dataset_stats[ACTION]["min"].numpy()
|
||||
a_max = dataset_stats[ACTION]["max"].numpy()
|
||||
|
||||
# Deliberately out-of-range inputs
|
||||
actions_np = np.array([[[2.0] * ACTION_DIM, [-3.0] * ACTION_DIM]], dtype=np.float32)
|
||||
clipped = np.clip(actions_np, -1.0, 1.0)
|
||||
expected = (clipped + 1.0) / 2.0 * (a_max - a_min) + a_min
|
||||
|
||||
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}
|
||||
clip_step = ClipActionsProcessorStep()
|
||||
unnorm_step = UnnormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map={FeatureType.ACTION: NormalizationMode.MIN_MAX},
|
||||
stats=dataset_stats,
|
||||
)
|
||||
|
||||
transition = policy_action_to_transition(torch.from_numpy(actions_np))
|
||||
transition = clip_step(transition)
|
||||
result = transition_to_policy_action(unnorm_step(transition)).numpy()
|
||||
|
||||
np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_postprocessor_applied_after_predict_action_chunk(
|
||||
patch_vla_jepa_external_models: None, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""predict_action_chunk returns raw actions; the postprocessor applies unnormalization.
|
||||
|
||||
Verifies the split: predict_action_chunk returns normalized actions, and calling the
|
||||
postprocessor on them produces the correctly unnormalized result.
|
||||
"""
|
||||
from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||
|
||||
raw_actions = np.zeros((BATCH_SIZE, ACTION_HORIZON, ACTION_DIM), dtype=np.float32)
|
||||
|
||||
cfg = make_config()
|
||||
cfg.clip_normalized_actions = False
|
||||
cfg.binarize_gripper_action = False
|
||||
policy = VLAJEPAPolicy(cfg)
|
||||
policy.eval()
|
||||
monkeypatch.setattr(policy.model, "predict_action", lambda *a, **kw: raw_actions.copy())
|
||||
|
||||
dataset_stats = _make_dataset_stats()
|
||||
_, postprocessor = make_vla_jepa_pre_post_processors(cfg, dataset_stats)
|
||||
|
||||
batch = make_inference_batch()
|
||||
chunk = policy.predict_action_chunk(batch)
|
||||
|
||||
# predict_action_chunk returns raw (normalized) actions
|
||||
assert torch.allclose(chunk, torch.zeros_like(chunk), atol=1e-6), (
|
||||
"predict_action_chunk should return raw actions without unnormalization applied."
|
||||
)
|
||||
|
||||
# Postprocessor applies unnormalization: 0 → (0+1)/2 * (max-min) + min = 5 + i
|
||||
unnormed = postprocessor(chunk)
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
a_min = dataset_stats[ACTION]["min"].numpy()
|
||||
a_max = dataset_stats[ACTION]["max"].numpy()
|
||||
expected_first = 0.5 * (0.0 + 1.0) * (a_max[0] - a_min[0]) + a_min[0]
|
||||
assert unnormed[0, 0, 0].item() == pytest.approx(expected_first, abs=1e-5)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# World-model view adjustment (padding / trimming) tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_MULTIVIEW_NUM_FRAMES = 4 # must be >= 2 * jepa_tubelet_size (=2) for world-model tests
|
||||
|
||||
|
||||
def _make_multiview_config(num_views: int, jepa_tubelet_size: int = 2) -> VLAJEPAConfig:
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
|
||||
config = VLAJEPAConfig(
|
||||
input_features={
|
||||
**{
|
||||
f"{OBS_IMAGES}.cam{i}": PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)
|
||||
)
|
||||
for i in range(num_views)
|
||||
},
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
|
||||
},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))},
|
||||
device="cpu",
|
||||
chunk_size=ACTION_HORIZON,
|
||||
n_action_steps=N_ACTION_STEPS,
|
||||
action_dim=ACTION_DIM,
|
||||
state_dim=STATE_DIM,
|
||||
num_video_frames=_MULTIVIEW_NUM_FRAMES,
|
||||
num_action_tokens_per_timestep=2,
|
||||
num_embodied_action_tokens_per_instruction=3,
|
||||
num_inference_timesteps=2,
|
||||
action_hidden_size=QWEN_HIDDEN_SIZE,
|
||||
action_model_type="DiT-test",
|
||||
action_num_layers=1,
|
||||
predictor_depth=1,
|
||||
predictor_num_heads=2,
|
||||
predictor_mlp_ratio=2.0,
|
||||
jepa_tubelet_size=jepa_tubelet_size,
|
||||
)
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
|
||||
def _make_multiview_train_batch(num_views: int, batch_size: int = BATCH_SIZE) -> dict:
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
|
||||
batch = {
|
||||
f"{OBS_IMAGES}.cam{i}": torch.rand(batch_size, _MULTIVIEW_NUM_FRAMES, 3, IMAGE_SIZE, IMAGE_SIZE)
|
||||
for i in range(num_views)
|
||||
}
|
||||
batch[OBS_STATE] = torch.randn(batch_size, 1, STATE_DIM)
|
||||
batch[ACTION] = torch.randn(batch_size, ACTION_HORIZON, ACTION_DIM)
|
||||
batch["task"] = ["pick up the cube"] * batch_size
|
||||
return batch
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_views",
|
||||
[
|
||||
1, # fewer views than jepa_tubelet_size → first view duplicated
|
||||
2, # exact match → unchanged
|
||||
3, # more views than jepa_tubelet_size → trimmed to first two
|
||||
],
|
||||
)
|
||||
def test_training_forward_world_model_view_adjustment(
|
||||
patch_vla_jepa_external_models: None,
|
||||
num_views: int,
|
||||
) -> None:
|
||||
"""World-model view padding/trimming must not break the training forward pass."""
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(_make_multiview_config(num_views=num_views, jepa_tubelet_size=2))
|
||||
policy.train()
|
||||
loss, logs = policy.forward(_make_multiview_train_batch(num_views=num_views))
|
||||
assert torch.isfinite(loss)
|
||||
assert logs["wm_loss"] >= 0
|
||||
|
||||
|
||||
def test_single_view_is_duplicated_for_world_model(patch_vla_jepa_external_models: None) -> None:
|
||||
"""With one dataset view and jepa_tubelet_size=2, the view must be duplicated before encoding."""
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(_make_multiview_config(num_views=1, jepa_tubelet_size=2))
|
||||
policy.train()
|
||||
|
||||
captured_videos: list = []
|
||||
original_processor = policy.model.video_processor
|
||||
|
||||
class _CapturingProcessor:
|
||||
def __call__(self, videos: list, return_tensors: str) -> dict:
|
||||
captured_videos.extend(videos)
|
||||
return original_processor(videos=videos, return_tensors=return_tensors)
|
||||
|
||||
policy.model.video_processor = _CapturingProcessor()
|
||||
policy.forward(_make_multiview_train_batch(num_views=1))
|
||||
|
||||
# reshape is batch-major: (b0v0, b0v1, b1v0, b1v1, …)
|
||||
assert len(captured_videos) == BATCH_SIZE * 2
|
||||
for i in range(BATCH_SIZE):
|
||||
np.testing.assert_array_equal(captured_videos[2 * i], captured_videos[2 * i + 1])
|
||||
|
||||
|
||||
def test_excess_views_trimmed_for_world_model(patch_vla_jepa_external_models: None) -> None:
|
||||
"""With three dataset views and jepa_tubelet_size=2, only the first two views reach the encoder."""
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(_make_multiview_config(num_views=3, jepa_tubelet_size=2))
|
||||
policy.train()
|
||||
|
||||
captured_videos: list = []
|
||||
original_processor = policy.model.video_processor
|
||||
|
||||
class _CapturingProcessor:
|
||||
def __call__(self, videos: list, return_tensors: str) -> dict:
|
||||
captured_videos.extend(videos)
|
||||
return original_processor(videos=videos, return_tensors=return_tensors)
|
||||
|
||||
policy.model.video_processor = _CapturingProcessor()
|
||||
policy.forward(_make_multiview_train_batch(num_views=3))
|
||||
|
||||
# Only B*2 items must reach the encoder, not B*3.
|
||||
assert len(captured_videos) == BATCH_SIZE * 2
|
||||
@@ -0,0 +1,60 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.policies.vla_jepa.world_model import (
|
||||
ActionConditionedVideoPredictor,
|
||||
)
|
||||
|
||||
_ACTION_EMBED_DIM = 8
|
||||
|
||||
|
||||
def _make_predictor(
|
||||
embed_dim: int = 8,
|
||||
action_embed_dim: int = _ACTION_EMBED_DIM,
|
||||
predictor_embed_dim: int = 24,
|
||||
num_action_tokens: int = 2,
|
||||
tokens_per_frame: int = 1,
|
||||
) -> ActionConditionedVideoPredictor:
|
||||
return ActionConditionedVideoPredictor(
|
||||
num_frames=1,
|
||||
img_size=(1, tokens_per_frame),
|
||||
patch_size=1,
|
||||
tubelet_size=1,
|
||||
embed_dim=embed_dim,
|
||||
action_embed_dim=action_embed_dim,
|
||||
predictor_embed_dim=predictor_embed_dim,
|
||||
depth=1,
|
||||
num_heads=2,
|
||||
mlp_ratio=2.0,
|
||||
num_action_tokens_per_step=num_action_tokens,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch,num_steps,tokens_per_frame,embed_dim",
|
||||
[
|
||||
(1, 2, 1, 8),
|
||||
(2, 3, 4, 8),
|
||||
(4, 5, 2, 16),
|
||||
],
|
||||
)
|
||||
def test_predictor_output_shape(batch: int, num_steps: int, tokens_per_frame: int, embed_dim: int) -> None:
|
||||
predictor = _make_predictor(
|
||||
embed_dim=embed_dim, action_embed_dim=_ACTION_EMBED_DIM, tokens_per_frame=tokens_per_frame
|
||||
)
|
||||
frame_tokens = torch.randn(batch, num_steps * tokens_per_frame, embed_dim)
|
||||
action_tokens = torch.randn(batch, num_steps * 2, _ACTION_EMBED_DIM)
|
||||
out = predictor(frame_tokens, action_tokens)
|
||||
assert tuple(out.shape) == (batch, num_steps * tokens_per_frame, embed_dim)
|
||||
assert torch.isfinite(out).all()
|
||||
|
||||
|
||||
def test_predictor_step_mismatch_raises() -> None:
|
||||
predictor = _make_predictor(tokens_per_frame=4)
|
||||
frame_tokens = torch.randn(2, 3 * 4, 8) # 3 steps, 4 tokens each
|
||||
with pytest.raises(RuntimeError):
|
||||
predictor(frame_tokens, torch.randn(2, 2 * 2, 8)) # 2 steps → mismatch
|
||||
@@ -59,6 +59,7 @@ def test_strategy_config_types():
|
||||
from lerobot.rollout import (
|
||||
BaseStrategyConfig,
|
||||
DAggerStrategyConfig,
|
||||
EpisodicStrategyConfig,
|
||||
HighlightStrategyConfig,
|
||||
SentryStrategyConfig,
|
||||
)
|
||||
@@ -67,6 +68,7 @@ def test_strategy_config_types():
|
||||
assert SentryStrategyConfig().type == "sentry"
|
||||
assert HighlightStrategyConfig().type == "highlight"
|
||||
assert DAggerStrategyConfig().type == "dagger"
|
||||
assert EpisodicStrategyConfig().type == "episodic"
|
||||
|
||||
|
||||
def test_dagger_config_invalid_input_device():
|
||||
@@ -203,6 +205,8 @@ def test_create_strategy_dispatches():
|
||||
BaseStrategyConfig,
|
||||
DAggerStrategy,
|
||||
DAggerStrategyConfig,
|
||||
EpisodicStrategy,
|
||||
EpisodicStrategyConfig,
|
||||
SentryStrategy,
|
||||
SentryStrategyConfig,
|
||||
create_strategy,
|
||||
@@ -211,6 +215,7 @@ def test_create_strategy_dispatches():
|
||||
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
|
||||
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
|
||||
assert isinstance(create_strategy(DAggerStrategyConfig()), DAggerStrategy)
|
||||
assert isinstance(create_strategy(EpisodicStrategyConfig()), EpisodicStrategy)
|
||||
|
||||
|
||||
def test_create_strategy_unknown_raises():
|
||||
|
||||
@@ -1172,10 +1172,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "diffusers"
|
||||
version = "0.35.2"
|
||||
version = "0.36.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "filelock" },
|
||||
{ name = "httpx" },
|
||||
{ name = "huggingface-hub" },
|
||||
{ name = "importlib-metadata" },
|
||||
{ name = "numpy" },
|
||||
@@ -1184,9 +1185,9 @@ dependencies = [
|
||||
{ name = "requests" },
|
||||
{ name = "safetensors" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/03/68/288ca23c7c05c73e87ffe5efffc282400ac9b017f7a9bb03883f4310ea15/diffusers-0.35.2.tar.gz", hash = "sha256:30ecd552303edfcfe1724573c3918a8462ee3ab4d529bdbd4c0045f763affded", size = 3366711, upload-time = "2025-10-15T04:05:17.213Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/88/45/ccb2e2180ddf475a0f931dac6a50346310e4c464ce3cccb8a65d1fc1e16d/diffusers-0.36.0.tar.gz", hash = "sha256:a9cde8721b415bde6a678f2d02abb85396487e1b0e0d2b4abb462d14a9825ab0", size = 3795088, upload-time = "2025-12-08T10:14:34.255Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/2e/38d9824f8c6bb048c5ba21c6d4da54c29c162a46b58b3ef907a360a76d3e/diffusers-0.35.2-py3-none-any.whl", hash = "sha256:d50d5e74fdd6dcf55e5c1d304bc52cc7c2659abd1752740d736d7b54078b4db5", size = 4121649, upload-time = "2025-10-15T04:05:14.391Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/35/50/281f92cb1f83854dbd79b6e958b3bc5018607e2542971d41604ba7a14b2f/diffusers-0.36.0-py3-none-any.whl", hash = "sha256:525d42abc74bfc3b2db594999961295c054b48ef40a11724dacf50e6abd1af98", size = 4597884, upload-time = "2025-12-08T10:14:31.979Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1636,6 +1637,18 @@ http = [
|
||||
{ name = "aiohttp" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ftfy"
|
||||
version = "6.3.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "wcwidth" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a5/d3/8650919bc3c7c6e90ee3fa7fd618bf373cbbe55dff043bd67353dbb20cd8/ftfy-6.3.1.tar.gz", hash = "sha256:9b3c3d90f84fb267fe64d375a07b7f8912d817cf86009ae134aa03e1819506ec", size = 308927, upload-time = "2024-10-26T00:50:35.149Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/6e/81d47999aebc1b155f81eca4477a616a70f238a2549848c38983f3c22a82/ftfy-6.3.1-py3-none-any.whl", hash = "sha256:7c70eb532015cd2f9adb53f101fb6c7945988d023a085d127d1573dc49dd0083", size = 44821, upload-time = "2024-10-26T00:50:33.425Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "future"
|
||||
version = "1.0.0"
|
||||
@@ -2696,6 +2709,7 @@ all = [
|
||||
{ name = "faker" },
|
||||
{ name = "fastapi" },
|
||||
{ name = "feetech-servo-sdk" },
|
||||
{ name = "ftfy" },
|
||||
{ name = "grpcio" },
|
||||
{ name = "grpcio-tools" },
|
||||
{ name = "gym-aloha" },
|
||||
@@ -2704,6 +2718,7 @@ all = [
|
||||
{ name = "hebi-py" },
|
||||
{ name = "hf-libero", marker = "sys_platform == 'linux'" },
|
||||
{ name = "hidapi" },
|
||||
{ name = "imageio", extra = ["ffmpeg"] },
|
||||
{ name = "ipykernel" },
|
||||
{ name = "jsonlines" },
|
||||
{ name = "jupyter" },
|
||||
@@ -2877,6 +2892,9 @@ hopejr = [
|
||||
{ name = "pygame" },
|
||||
{ name = "pyserial" },
|
||||
]
|
||||
imageio-dep = [
|
||||
{ name = "imageio", extra = ["ffmpeg"] },
|
||||
]
|
||||
intelrealsense = [
|
||||
{ name = "pyrealsense2", marker = "sys_platform != 'darwin'" },
|
||||
{ name = "pyrealsense2-macosx", marker = "sys_platform == 'darwin'" },
|
||||
@@ -2901,6 +2919,13 @@ libero = [
|
||||
{ name = "torchcodec", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'win32'" },
|
||||
{ name = "transformers" },
|
||||
]
|
||||
lingbot-va = [
|
||||
{ name = "accelerate" },
|
||||
{ name = "diffusers" },
|
||||
{ name = "ftfy" },
|
||||
{ name = "imageio", extra = ["ffmpeg"] },
|
||||
{ name = "transformers" },
|
||||
]
|
||||
matplotlib-dep = [
|
||||
{ name = "contourpy" },
|
||||
{ name = "matplotlib" },
|
||||
@@ -3052,6 +3077,11 @@ video-benchmark = [
|
||||
viz = [
|
||||
{ name = "rerun-sdk" },
|
||||
]
|
||||
vla-jepa = [
|
||||
{ name = "diffusers" },
|
||||
{ name = "qwen-vl-utils" },
|
||||
{ name = "transformers" },
|
||||
]
|
||||
wallx = [
|
||||
{ name = "peft" },
|
||||
{ name = "qwen-vl-utils" },
|
||||
@@ -3065,6 +3095,7 @@ xvla = [
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "accelerate", marker = "extra == 'lingbot-va'", specifier = ">=1.10.0,<2.0.0" },
|
||||
{ name = "accelerate", marker = "extra == 'smolvla'", specifier = ">=1.7.0,<2.0.0" },
|
||||
{ name = "accelerate", marker = "extra == 'training'", specifier = ">=1.10.0,<2.0.0" },
|
||||
{ name = "av", marker = "extra == 'av-dep'", specifier = ">=15.0.0,<16.0.0" },
|
||||
@@ -3074,7 +3105,8 @@ requires-dist = [
|
||||
{ name = "debugpy", marker = "extra == 'dev'", specifier = ">=1.8.1,<1.9.0" },
|
||||
{ name = "decord", marker = "(platform_machine == 'AMD64' and extra == 'groot') or (platform_machine == 'x86_64' and extra == 'groot')", specifier = ">=0.6.0,<1.0.0" },
|
||||
{ name = "deepdiff", marker = "extra == 'deepdiff-dep'", specifier = ">=7.0.1,<9.0.0" },
|
||||
{ name = "diffusers", marker = "extra == 'diffusers-dep'", specifier = ">=0.27.2,<0.36.0" },
|
||||
{ name = "diffusers", marker = "extra == 'diffusers-dep'", specifier = ">=0.27.2,<0.37.0" },
|
||||
{ name = "diffusers", marker = "extra == 'lingbot-va'", specifier = ">=0.36.0,<0.37.0" },
|
||||
{ name = "dm-tree", marker = "extra == 'groot'", specifier = ">=0.1.8,<1.0.0" },
|
||||
{ name = "draccus", specifier = "==0.10.0" },
|
||||
{ name = "dynamixel-sdk", marker = "extra == 'dynamixel'", specifier = ">=3.7.31,<3.9.0" },
|
||||
@@ -3083,6 +3115,7 @@ requires-dist = [
|
||||
{ name = "fastapi", marker = "extra == 'phone'", specifier = "<1.0" },
|
||||
{ name = "feetech-servo-sdk", marker = "extra == 'feetech'", specifier = ">=1.0.0,<2.0.0" },
|
||||
{ name = "flash-attn", marker = "sys_platform != 'darwin' and extra == 'groot'", specifier = ">=2.5.9,<3.0.0" },
|
||||
{ name = "ftfy", marker = "extra == 'lingbot-va'", specifier = ">=6.0.0,<7.0.0" },
|
||||
{ name = "grpcio", marker = "extra == 'grpcio-dep'", specifier = "==1.73.1" },
|
||||
{ name = "grpcio-tools", marker = "extra == 'dev'", specifier = "==1.73.1" },
|
||||
{ name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.2,<0.2.0" },
|
||||
@@ -3093,6 +3126,7 @@ requires-dist = [
|
||||
{ name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.3,<0.2.0" },
|
||||
{ name = "hidapi", marker = "extra == 'gamepad'", specifier = ">=0.14.0,<0.15.0" },
|
||||
{ name = "huggingface-hub", specifier = ">=1.0.0,<2.0.0" },
|
||||
{ name = "imageio", extras = ["ffmpeg"], marker = "extra == 'imageio-dep'", specifier = ">=2.34.0,<3.0.0" },
|
||||
{ name = "ipykernel", marker = "extra == 'notebook'", specifier = ">=6.0.0,<7.0.0" },
|
||||
{ name = "jsonlines", marker = "extra == 'dataset'", specifier = ">=4.0.0,<5.0.0" },
|
||||
{ name = "jupyter", marker = "extra == 'notebook'", specifier = ">=1.0.0,<2.0.0" },
|
||||
@@ -3120,8 +3154,10 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'diffusion'" },
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'groot'" },
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'multi-task-dit'" },
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'vla-jepa'" },
|
||||
{ name = "lerobot", extras = ["diffusion"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["dynamixel"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["eo1"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'hopejr'" },
|
||||
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'lekiwi'" },
|
||||
@@ -3133,10 +3169,12 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["hardware"], marker = "extra == 'core-scripts'" },
|
||||
{ name = "lerobot", extras = ["hilserl"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["hopejr"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["imageio-dep"], marker = "extra == 'lingbot-va'" },
|
||||
{ name = "lerobot", extras = ["intelrealsense"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["kinematics"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["lekiwi"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["libero"], marker = "sys_platform == 'linux' and extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["lingbot-va"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["matplotlib-dep"], marker = "extra == 'async'" },
|
||||
{ name = "lerobot", extras = ["matplotlib-dep"], marker = "extra == 'sarm'" },
|
||||
{ name = "lerobot", extras = ["matplotlib-dep"], marker = "extra == 'unitree-g1'" },
|
||||
@@ -3171,6 +3209,7 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'eo1'" },
|
||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'robometer'" },
|
||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'sarm'" },
|
||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'vla-jepa'" },
|
||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'wallx'" },
|
||||
{ name = "lerobot", extras = ["reachy2"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["rebot"], marker = "extra == 'all'" },
|
||||
@@ -3192,6 +3231,7 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'groot'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'hilserl'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'libero'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'lingbot-va'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'molmoact2'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'multi-task-dit'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'peft'" },
|
||||
@@ -3200,12 +3240,14 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'sarm'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'smolvla'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'topreward'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'vla-jepa'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'wallx'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'xvla'" },
|
||||
{ name = "lerobot", extras = ["video-benchmark"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["viz"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["viz"], marker = "extra == 'core-scripts'" },
|
||||
{ name = "lerobot", extras = ["viz"], marker = "extra == 'dataset-viz'" },
|
||||
{ name = "lerobot", extras = ["vla-jepa"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["wallx"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["xvla"], marker = "extra == 'all'" },
|
||||
{ name = "matplotlib", marker = "extra == 'matplotlib-dep'", specifier = ">=3.10.3,<4.0.0" },
|
||||
@@ -3267,7 +3309,7 @@ requires-dist = [
|
||||
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" },
|
||||
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
|
||||
]
|
||||
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
||||
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "imageio-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "lingbot-va", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
||||
|
||||
[[package]]
|
||||
name = "librt"
|
||||
|
||||
Reference in New Issue
Block a user