Compare commits

...

3 Commits

Author SHA1 Message Date
Maximellerbach 6021554770 chore(rollout): nice collored cli 2026-05-07 11:12:02 +02:00
Haoming Song e99c55af4b feat(policies): add EO-1 model (#3403)
* feat(policies): add EO-1 model

* chore(eo1): adjust policy_eo1_README.md to to avoid duplicate with eo1.mdx

* chore(eo1): remove policy_eo1_README.md, link eo1.mdx in policy folder

---------

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-05-06 18:01:16 +02:00
Steven Palma 408e0ca763 fix(robots): openarm features with openarmmini (#3524) 2026-05-06 17:03:09 +02:00
24 changed files with 1843 additions and 26 deletions
+2
View File
@@ -47,6 +47,8 @@
title: π₀-FAST (Pi0Fast)
- local: pi05
title: π₀.₅ (Pi05)
- local: eo1
title: EO-1
- local: groot
title: NVIDIA GR00T N1.5
- local: xvla
+168
View File
@@ -0,0 +1,168 @@
# EO-1
EO-1 is a **Vision-Language-Action policy for robot control**. The LeRobot implementation integrates EO-1 with the standard LeRobot training, evaluation, processor interface.
## Model Overview
EO-1 uses a Qwen2.5-VL backbone for vision-language understanding and adds a continuous flow-matching action head for robot control. The policy formats each robot-control sample as a multimodal conversation: camera images are passed to Qwen2.5-VL, the robot state is represented with EO-1 state tokens, and the future action chunk is represented with EO-1 action tokens.
<img
src="https://huggingface.co/datasets/HaomingSong/lerobot-documentation-images/resolve/main/lerobot/eo_pipeline.png"
alt="An overview of EO-1"
width="85%"
/>
During training, EO-1 learns to denoise continuous action chunks at the action-token positions. During inference, it samples an action chunk, returns continuous actions, and executes `n_action_steps` from the chunk before sampling again.
### What the LeRobot Integration Covers
- Standard `policy.type=eo1` configuration through LeRobot
- Qwen2.5-VL image and text preprocessing through policy processors
- Continuous flow-matching action prediction
- Checkpoint save/load through LeRobot policy APIs
- Training with `lerobot-train` and evaluation with `lerobot-eval`
The broader EO-1 project also includes interleaved vision-text-action pretraining and multimodal reasoning workflows. This page focuses on the LeRobot robot-control policy path.
## Installation Requirements
1. Install LeRobot by following the [Installation Guide](./installation).
2. Install EO-1 dependencies by running:
```bash
pip install -e ".[eo1]"
```
3. If you want to train or evaluate on LIBERO, install the LIBERO dependencies too:
```bash
pip install -e ".[eo1,libero]"
```
EO-1 can use the standard PyTorch scaled-dot-product attention backend through `policy.attn_implementation=sdpa`. If your environment has a compatible `flash_attn` installation, you can request `policy.attn_implementation=flash_attention_2`.
## Data Requirements
EO-1 expects a LeRobot dataset with:
- At least one visual observation, for example `observation.images.image`
- `observation.state`
- `action`
- A language task instruction through the dataset `task` field
If your dataset uses different observation names, use `rename_map` to align them with the names expected by your training or evaluation setup.
## Usage
To use EO-1 in a LeRobot configuration, specify the policy type as:
```python
policy.type=eo1
```
By default, a new EO-1 policy initializes its backbone from:
```python
policy.vlm_base=Qwen/Qwen2.5-VL-3B-Instruct
```
Once a LeRobot-format EO-1 checkpoint is available, load it with:
```python
policy.path=your-org/your-eo1-checkpoint
```
## Training
### Training Command Example
```bash
lerobot-train \
--dataset.repo_id=your_org/your_dataset \
--policy.type=eo1 \
--policy.vlm_base=Qwen/Qwen2.5-VL-3B-Instruct \
--policy.dtype=bfloat16 \
--policy.attn_implementation=sdpa \
--policy.gradient_checkpointing=false \
--output_dir=./outputs/eo1_training \
--job_name=eo1_training \
--steps=300000 \
--batch_size=16 \
--policy.device=cuda
```
### Key Training Parameters
| Parameter | Default | Description |
| -------------------------------------- | ----------------------------- | ----------------------------------------------------------------------- |
| `policy.vlm_base` | `Qwen/Qwen2.5-VL-3B-Instruct` | Qwen2.5-VL checkpoint used to initialize a new policy |
| `policy.dtype` | `auto` | Backbone dtype request: `auto`, `bfloat16`, or `float32` |
| `policy.attn_implementation` | `None` | Optional Qwen attention backend, such as `sdpa` |
| `policy.gradient_checkpointing` | `false` | Reduces memory usage during training |
| `policy.chunk_size` | `8` | Number of future actions predicted per chunk |
| `policy.n_action_steps` | `8` | Number of actions consumed from a sampled chunk |
| `policy.num_denoise_steps` | `10` | Number of flow-matching denoising steps used during sampling |
| `policy.max_state_dim` | `32` | State padding dimension |
| `policy.max_action_dim` | `32` | Action padding dimension |
| `policy.force_fp32_autocast` | `true` | Keeps the flow head in fp32 even when the backbone uses mixed precision |
| `policy.supervise_padding_action_dims` | `true` | Controls whether padded action dimensions are supervised |
| `policy.supervise_padding_actions` | `true` | Controls whether padded future action rows are supervised |
## Evaluation
EO-1 can be evaluated through `lerobot-eval` once you have a LeRobot-format checkpoint:
```bash
lerobot-eval \
--policy.path=your-org/your-eo1-checkpoint \
--env.type=libero \
--env.task=libero_object \
--eval.batch_size=1 \
--eval.n_episodes=20
```
For datasets or environments whose camera names differ from the checkpoint configuration, pass a `rename_map`:
```bash
lerobot-eval \
--policy.path=your-org/your-eo1-checkpoint \
--env.type=libero \
--env.task=libero_object \
--rename_map='{"observation.images.image2":"observation.images.wrist_image"}'
```
## Configuration Notes
### Image Processing
EO-1 uses the Qwen2.5-VL processor. The `policy.image_min_pixels` and `policy.image_max_pixels` settings control the image resizing bounds before the visual tokens are passed into the backbone.
### State and Action Dimensions
The policy pads state and action vectors to `policy.max_state_dim` and `policy.max_action_dim` before the EO-1 flow head. Predictions are cropped back to the original action dimension before being returned by the policy.
### Attention Backend
Use `policy.attn_implementation=sdpa` for a portable setup. Use `flash_attention_2` only when `flash_attn` is installed and compatible with your environment.
## References
- [EO-1 project](https://github.com/EO-Robotics/EO1)
- [EO-1 paper](https://arxiv.org/abs/2508.21112)
- [Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)
## Citation
```bibtex
@article{eo1,
title={EO-1: Interleaved Vision-Text-Action Pretraining for General Robot Control},
author={Delin Qu and Haoming Song and Qizhi Chen and Zhaoqing Chen and Xianqiang Gao and Xinyi Ye and Qi Lv and Modi Shi and Guanghui Ren and Cheng Ruan and Maoqing Yao and Haoran Yang and Jiacheng Bao and Bin Zhao and Dong Wang},
journal={arXiv preprint},
year={2025},
url={https://arxiv.org/abs/2508.21112}
}
```
## License
This LeRobot integration follows the **Apache 2.0 License** used by LeRobot. Check the upstream EO-1 model and dataset pages for the licenses of released EO-1 checkpoints and data.
+1
View File
@@ -194,6 +194,7 @@ groot = [
]
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
+2
View File
@@ -16,6 +16,7 @@ from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterp
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .eo1.configuration_eo1 import EO1Config as EO1Config
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
from .groot.configuration_groot import GrootConfig as GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
@@ -41,6 +42,7 @@ __all__ = [
"DiffusionConfig",
"GrootConfig",
"MultiTaskDiTConfig",
"EO1Config",
"PI0Config",
"PI0FastConfig",
"PI05Config",
+1
View File
@@ -0,0 +1 @@
../../../../docs/source/eo1.mdx
+7
View File
@@ -0,0 +1,7 @@
#!/usr/bin/env python
from .configuration_eo1 import EO1Config
from .modeling_eo1 import EO1Policy
from .processor_eo1 import make_eo1_pre_post_processors
__all__ = ["EO1Config", "EO1Policy", "make_eo1_pre_post_processors"]
@@ -0,0 +1,193 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from copy import deepcopy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig,
Qwen2_5_VLTextConfig,
Qwen2_5_VLVisionConfig,
)
else:
Qwen2_5_VLConfig = None
Qwen2_5_VLTextConfig = None
Qwen2_5_VLVisionConfig = None
@PreTrainedConfig.register_subclass("eo1")
@dataclass
class EO1Config(PreTrainedConfig):
"""Configuration for native EO1 policy integration in LeRobot."""
vlm_base: str = "Qwen/Qwen2.5-VL-3B-Instruct"
vlm_config: dict | None = None
# Vision processor settings.
image_min_pixels: int | None = 64 * 28 * 28
image_max_pixels: int | None = 128 * 28 * 28
use_fast_processor: bool = False
# Execution and action horizon.
n_obs_steps: int = 1
chunk_size: int = 8
n_action_steps: int = 8
# State/action padding to match EO1 flow head dimensionality.
max_state_dim: int = 32
max_action_dim: int = 32
# Flow matching sampling.
num_denoise_steps: int = 10
num_action_layers: int = 2
action_act: str = "linear"
time_sampling_beta_alpha: float = 1.5
time_sampling_beta_beta: float = 1.0
time_sampling_scale: float = 0.999
time_sampling_offset: float = 0.001
min_period: float = 4e-3
max_period: float = 4.0
supervise_padding_action_dims: bool = True
supervise_padding_actions: bool = True
# Policy-level dtype request for the Qwen backbone.
# - "auto": follow the backbone config/checkpoint default dtype. For Qwen2.5-VL this resolves to bf16.
# The EO1 flow-matching head still keeps its own parameters in fp32.
# - "bfloat16": force the backbone to initialize/load in bf16 regardless of the saved config default.
# - "float32": force the backbone to initialize/load in fp32 for maximum numerical conservatism.
dtype: str = "auto" # Options: "auto", "bfloat16", "float32"
force_fp32_autocast: bool = True
# Optional attention backend request passed through to the Qwen backbone.
# Common values: None, "eager", "sdpa", "flash_attention_2".
attn_implementation: str | None = None
# Training settings.
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# Optimizer settings aligned with EO1/experiments/2_libero/train.sh and EO1 TrainPipelineConfig defaults.
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.1
optimizer_grad_clip_norm: float = 1.0
# Scheduler settings aligned with EO1 train.sh: cosine schedule with warmup_ratio=0.03.
# Note: These will auto-scale if --steps < scheduler_decay_steps
# For example, --steps=3000 will scale warmup to 100 and decay to 3000
scheduler_warmup_steps: int = 900 # 0.03 * 30_000 long-run steps
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 0.0
def __post_init__(self):
super().__post_init__()
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
)
# Populate the serialized backbone config only when the caller did not provide one.
if self.vlm_config is None:
require_package("transformers", extra="eo1")
self.vlm_config = Qwen2_5_VLConfig.from_pretrained(self.vlm_base).to_dict()
@property
def vlm_backbone_config(self) -> Qwen2_5_VLConfig:
require_package("transformers", extra="eo1")
config_dict = deepcopy(self.vlm_config)
if self.attn_implementation is not None:
config_dict["attn_implementation"] = self.attn_implementation
return Qwen2_5_VLConfig(**config_dict)
@property
def text_config(self) -> Qwen2_5_VLTextConfig:
return self.vlm_backbone_config.text_config
@property
def vision_config(self) -> Qwen2_5_VLVisionConfig:
return self.vlm_backbone_config.vision_config
def validate_features(self) -> None:
"""Validate and set up EO1 input and output features."""
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
if not image_features:
raise ValueError(
"EO1 policy requires at least one visual input feature. "
"No features of type FeatureType.VISUAL found in input_features."
)
if OBS_STATE not in self.input_features:
state_feature = PolicyFeature(
type=FeatureType.STATE,
shape=(self.max_state_dim,),
)
self.input_features[OBS_STATE] = state_feature
if ACTION not in self.output_features:
action_feature = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.max_action_dim,),
)
self.output_features[ACTION] = action_feature
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
+620
View File
@@ -0,0 +1,620 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import contextlib
import logging
import math
from collections import deque
from typing import TYPE_CHECKING, Any
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
import torch.utils.checkpoint
from torch import Tensor
from lerobot.policies.eo1.configuration_eo1 import EO1Config
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers.activations import ACT2FN
from transformers.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from transformers.utils import torch_compilable_check
else:
ACT2FN = None
Qwen2_5_VLForConditionalGeneration = None
torch_compilable_check = None
logger = logging.getLogger(__name__)
def pad_vector(vector, new_dim):
"""Pad the last dimension of a vector to new_dim with zeros.
Can be (batch_size x sequence_length x features_dimension)
or (batch_size x features_dimension)
"""
if vector.shape[-1] >= new_dim:
return vector
return F.pad(vector, (0, new_dim - vector.shape[-1]))
class EO1Policy(PreTrainedPolicy):
"""EO1 policy wrapper for LeRobot robot-only training/evaluation."""
config_class = EO1Config
name = "eo1"
def __init__(self, config: EO1Config, **kwargs):
require_package("transformers", extra="eo1")
super().__init__(config)
config.validate_features()
self.config = config
if config.pretrained_path is None:
# Initialize from pretrained VLM
vlm_backbone = Qwen2_5_VLForConditionalGeneration.from_pretrained(
config.vlm_base,
dtype=config.dtype,
attn_implementation=config.attn_implementation,
)
else:
vlm_backbone = Qwen2_5_VLForConditionalGeneration._from_config(
config.vlm_backbone_config,
dtype=config.vlm_backbone_config.dtype if config.dtype == "auto" else config.dtype,
)
self.model = EO1VisionFlowMatchingModel(config, vlm_backbone)
if config.gradient_checkpointing:
self.model.gradient_checkpointing_enable()
self.model.to(config.device)
self.reset()
def reset(self):
self._action_queue = deque(maxlen=self.config.n_action_steps)
@staticmethod
def _get_model_inputs(batch: dict[str, Tensor], excluded_keys: set[str]) -> dict[str, Tensor]:
return {key: value for key, value in batch.items() if key not in excluded_keys}
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
state = self.prepare_state(batch[OBS_STATE])
actions = self.prepare_action(batch[ACTION])
model_inputs = self._get_model_inputs(batch, {OBS_STATE, ACTION})
loss = self.model(states=state, action=actions, **model_inputs)
loss_dict = {"loss": loss.item()}
return loss, loss_dict
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
self.eval()
states = self.prepare_state(batch[OBS_STATE])
model_inputs = self._get_model_inputs(batch, {OBS_STATE})
actions = self.model.sample_actions(states=states, **model_inputs).to(torch.float32)
original_action_dim = self.config.output_features[ACTION].shape[0]
return actions[:, :, :original_action_dim]
def prepare_state(self, state: Tensor) -> Tensor:
return pad_vector(state, self.config.max_state_dim)
def prepare_action(self, action: Tensor) -> Tensor:
return pad_vector(action, self.config.max_action_dim)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
self.eval()
if len(self._action_queue) == 0:
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def get_optim_params(self) -> dict:
return self.parameters()
def get_safe_dtype(target_dtype, device_type):
"""Get a safe dtype for the given device type."""
if device_type == "mps" and target_dtype == torch.float64:
return torch.float32
if device_type == "cpu":
# CPU doesn't support bfloat16, use float32 instead
if target_dtype == torch.bfloat16:
return torch.float32
if target_dtype == torch.float64:
return torch.float64
return target_dtype
def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy)
time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
"""Computes sine-cosine positional embedding vectors for scalar positions."""
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
period = min_period * (max_period / min_period) ** fraction
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
# Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU
alpha_t = torch.tensor(alpha, dtype=torch.float32)
beta_t = torch.tensor(beta, dtype=torch.float32)
dist = torch.distributions.Beta(alpha_t, beta_t)
return dist.sample((bsize,)).to(device)
class EO1VisionActionProjector(torch.nn.Sequential):
"""This block implements the multi-layer perceptron (MLP) module."""
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 2,
activation_layer: str = "linear",
bias: bool = True,
device: Any = None,
dtype: torch.dtype = torch.float32,
):
layers = []
in_dim = in_channels
hidden_channels = [in_dim] * (num_layers - 1) + [out_channels]
for hidden_dim in hidden_channels[:-1]:
layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device))
layers.append(ACT2FN[activation_layer])
in_dim = hidden_dim
layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias, dtype=dtype, device=device))
super().__init__(*layers)
@property
def dtype(self):
return self[0].weight.dtype
class EO1VisionFlowMatchingModel(nn.Module):
def __init__(
self,
config: EO1Config,
vlm_backbone: Qwen2_5_VLForConditionalGeneration | None = None,
):
require_package("transformers", extra="eo1")
super().__init__()
self.config = config
# Preserve the backbone dtype selected at construction time so Qwen's fp32 rotary buffers stay intact.
self.vlm_backbone = vlm_backbone
self.hidden_size = self.vlm_backbone.config.text_config.hidden_size
max_state_dim = config.max_state_dim
max_action_dim = config.max_action_dim
self.state_proj = nn.Linear(max_state_dim, self.hidden_size, dtype=torch.float32)
self.action_in_proj = nn.Linear(max_action_dim, self.hidden_size, dtype=torch.float32)
self.action_out_proj = EO1VisionActionProjector(
self.hidden_size,
max_action_dim,
config.num_action_layers,
config.action_act,
dtype=torch.float32,
)
self.action_time_mlp_in = nn.Linear(self.hidden_size * 2, self.hidden_size, dtype=torch.float32)
self.action_time_mlp_out = nn.Linear(self.hidden_size, self.hidden_size, dtype=torch.float32)
self.gradient_checkpointing_enabled = False
def get_input_embeddings(self):
return self.vlm_backbone.get_input_embeddings()
def flow_head_autocast_context(self):
if self.config.force_fp32_autocast:
return torch.autocast(
device_type=self.state_proj.weight.device.type,
enabled=False,
)
return contextlib.nullcontext()
def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for the Qwen2.5-VL backbone."""
self.gradient_checkpointing_enabled = True
self.vlm_backbone.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
logger.info("Enabled gradient checkpointing for EO1VisionFlowMatchingModel")
def gradient_checkpointing_disable(self):
"""Disable gradient checkpointing for the Qwen2.5-VL backbone."""
self.gradient_checkpointing_enabled = False
self.vlm_backbone.gradient_checkpointing_disable()
logger.info("Disabled gradient checkpointing for EO1VisionFlowMatchingModel")
def _apply_checkpoint(self, func, *args, **kwargs):
"""Apply manual gradient checkpointing to EO1 flow-head computations when training."""
if self.gradient_checkpointing_enabled and self.training and torch.is_grad_enabled():
return torch.utils.checkpoint.checkpoint(
func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
)
return func(*args, **kwargs)
def sample_noise(self, shape, device):
noise = torch.normal(
mean=0.0,
std=1.0,
size=shape,
dtype=torch.float32,
device=device,
)
return noise
def sample_time(self, bsize, device):
time_beta = sample_beta(
self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device
)
time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset
return time.to(dtype=torch.float32, device=device)
def get_placeholder_mask(
self,
input_ids: torch.LongTensor | None,
inputs_embeds: torch.FloatTensor | None,
state_features: torch.FloatTensor | None = None,
action_features: torch.FloatTensor | None = None,
*,
state_token_id: int,
action_token_id: int,
) -> tuple[torch.BoolTensor, torch.BoolTensor]:
"""Return EO1 state/action placeholder masks, following Qwen's multimodal mask style."""
if input_ids is None:
special_state_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(state_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_state_mask = special_state_mask.all(-1)
special_action_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(action_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_action_mask = special_action_mask.all(-1)
else:
special_state_mask = input_ids == state_token_id
special_action_mask = input_ids == action_token_id
n_state_tokens = special_state_mask.sum()
special_state_mask = (
special_state_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
)
if state_features is not None:
torch_compilable_check(
inputs_embeds[special_state_mask].numel() == state_features.numel(),
f"State features and state tokens do not match, tokens: {n_state_tokens}, features: {state_features.shape[0]}",
)
n_action_tokens = special_action_mask.sum()
special_action_mask = (
special_action_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
)
if action_features is not None:
torch_compilable_check(
inputs_embeds[special_action_mask].numel() == action_features.numel(),
f"Action features and action tokens do not match, tokens: {n_action_tokens}, features: {action_features.shape[0]}",
)
return special_state_mask, special_action_mask
def embed_prefix(
self,
input_ids: torch.LongTensor,
states: torch.Tensor,
*,
state_token_id: int,
action_token_id: int,
) -> torch.FloatTensor:
"""Embed the EO1 prefix tokens before native Qwen injects multimodal features."""
# Get the input embeddings for the input IDs
def input_embed_func(input_ids: torch.LongTensor) -> torch.FloatTensor:
return self.get_input_embeddings()(input_ids)
inputs_embeds = self._apply_checkpoint(input_embed_func, input_ids)
# Project the states to the hidden size
def state_proj_func(states: torch.Tensor) -> torch.FloatTensor:
with self.flow_head_autocast_context():
states = states.to(dtype=self.state_proj.weight.dtype)
return self.state_proj(states)
state_embs = self._apply_checkpoint(state_proj_func, states)
state_mask, _ = self.get_placeholder_mask(
input_ids,
inputs_embeds,
state_features=state_embs,
state_token_id=state_token_id,
action_token_id=action_token_id,
)
state_embs = state_embs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(state_mask, state_embs)
return inputs_embeds
def embed_suffix(
self,
timestep: torch.Tensor,
noisy_actions: torch.Tensor,
) -> torch.FloatTensor:
"""Embed the suffix"""
def action_proj_func(noisy_actions: torch.Tensor) -> torch.FloatTensor:
with self.flow_head_autocast_context():
noisy_actions = noisy_actions.to(dtype=self.action_in_proj.weight.dtype)
return self.action_in_proj(noisy_actions)
action_embs = self._apply_checkpoint(action_proj_func, noisy_actions)
time_embs = create_sinusoidal_pos_embedding(
timestep,
self.hidden_size,
min_period=self.config.min_period,
max_period=self.config.max_period,
device=action_embs.device,
)
time_embs = time_embs.to(dtype=action_embs.dtype)
time_embs = time_embs[:, None, :].expand_as(action_embs)
action_time_embs = torch.cat([action_embs, time_embs], dim=2)
def mlp_func(action_time_embs: torch.Tensor) -> torch.FloatTensor:
with self.flow_head_autocast_context():
action_time_embs = action_time_embs.to(dtype=self.action_time_mlp_in.weight.dtype)
action_time_embs = self.action_time_mlp_in(action_time_embs)
action_time_embs = F.silu(action_time_embs)
return self.action_time_mlp_out(action_time_embs)
action_time_embs = self._apply_checkpoint(mlp_func, action_time_embs)
return action_time_embs
def forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.LongTensor | None = None,
pixel_values: torch.FloatTensor | None = None,
image_grid_thw: torch.LongTensor | None = None,
mm_token_type_ids: torch.IntTensor | None = None,
states: torch.FloatTensor | None = None,
action: torch.FloatTensor | None = None,
action_is_pad: torch.BoolTensor | None = None,
*,
state_token_id: int,
action_token_id: int,
**kwargs,
) -> Tensor:
"""Run the EO1 training forward pass and compute the flow-matching loss."""
# 1. Build the EO1 prefix with state placeholders resolved.
inputs_embeds = self.embed_prefix(
input_ids,
states=states,
state_token_id=state_token_id,
action_token_id=action_token_id,
)
# 2. Sample the diffusion target and replace the action placeholders.
time = self.sample_time(action.shape[0], inputs_embeds.device)
noise = self.sample_noise(action.shape, inputs_embeds.device)
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * action
u_t = noise - action
action_time_embs = self.embed_suffix(time, x_t)
_, action_mask = self.get_placeholder_mask(
input_ids,
inputs_embeds,
action_features=action_time_embs,
state_token_id=state_token_id,
action_token_id=action_token_id,
)
action_time_embs = action_time_embs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(action_mask, action_time_embs)
# 3. Optionally drop padded action tokens from backbone attention.
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
if not self.config.supervise_padding_actions:
action_is_pad = action_is_pad.to(device=inputs_embeds.device, dtype=torch.bool)
action_token_mask = action_mask[..., 0]
action_padding_mask = torch.zeros_like(action_token_mask)
action_padding_mask = action_padding_mask.masked_scatter(
action_token_mask,
action_is_pad.reshape(-1),
)
attention_mask = attention_mask.masked_fill(action_padding_mask, 0)
# 4. Run the Qwen backbone on the fused EO1 sequence.
def vlm_forward_func(
input_ids: torch.LongTensor,
attention_mask: torch.Tensor | None,
inputs_embeds: torch.FloatTensor,
pixel_values: torch.Tensor | None,
image_grid_thw: torch.LongTensor | None,
mm_token_type_ids: torch.IntTensor | None,
) -> torch.FloatTensor:
outputs = self.vlm_backbone.model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
mm_token_type_ids=mm_token_type_ids,
use_cache=False,
output_hidden_states=False,
return_dict=True,
)
return outputs.last_hidden_state
hidden_states = self._apply_checkpoint(
vlm_forward_func,
input_ids,
attention_mask,
inputs_embeds,
pixel_values,
image_grid_thw,
mm_token_type_ids,
)
action_hidden_states = hidden_states[action_mask[..., 0]]
# 5. Project the action-token hidden states back to the flow target space.
def action_out_proj_func(action_hidden_states: torch.FloatTensor) -> torch.FloatTensor:
with self.flow_head_autocast_context():
action_hidden_states = action_hidden_states.to(dtype=self.action_out_proj.dtype)
return self.action_out_proj(action_hidden_states)
v_t = self._apply_checkpoint(action_out_proj_func, action_hidden_states)
v_t = v_t.reshape(u_t.shape).to(dtype=u_t.dtype)
losses = F.mse_loss(u_t, v_t, reduction="none")
# 6. Apply the configured supervision mask and reduce the loss.
if not self.config.supervise_padding_action_dims:
original_action_dim = self.config.output_features[ACTION].shape[0]
losses = losses[..., :original_action_dim]
if not self.config.supervise_padding_actions:
losses = losses[~action_is_pad]
return losses.mean()
@torch.no_grad()
def sample_actions(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.Tensor | None = None,
pixel_values: torch.Tensor | None = None,
image_grid_thw: torch.LongTensor | None = None,
mm_token_type_ids: torch.IntTensor | None = None,
states: torch.Tensor | None = None,
*,
state_token_id: int,
action_token_id: int,
**kwargs,
) -> Tensor:
"""Sample actions from the model."""
if states is None:
raise ValueError("states are required for EO1 action sampling.")
if mm_token_type_ids is None:
raise ValueError("mm_token_type_ids are required for EO1 action sampling.")
# 1. Resolve the left-padded rollout prompt and locate the action span.
chunk_size = self.config.chunk_size
inputs_embeds = self.embed_prefix(
input_ids,
states=states,
state_token_id=state_token_id,
action_token_id=action_token_id,
).clone()
_, action_placeholder_mask = self.get_placeholder_mask(
input_ids,
inputs_embeds,
state_token_id=state_token_id,
action_token_id=action_token_id,
)
action_mask = action_placeholder_mask[..., 0]
token_counts = action_mask.sum(dim=1)
if not torch.all(token_counts == chunk_size):
raise ValueError(
f"Each sample must contain exactly {chunk_size} action tokens, got {token_counts.tolist()}."
)
if action_mask.ne(action_mask[:1]).any():
raise ValueError(
"Batch inference expects all samples to share the same action token mask after left padding."
)
act_start = int(action_mask[0].to(torch.int64).argmax().item())
act_end = act_start + self.config.chunk_size
if not torch.all(action_mask[:, act_start:act_end]):
raise ValueError("Action tokens must form a contiguous chunk of length chunk_size.")
act_slice = slice(act_start, act_end)
# 2. Encode the fixed prefix once and cache its KV state.
batch_size = input_ids.shape[0]
device = inputs_embeds.device
attention_mask = attention_mask.to(device)
mm_token_type_ids = mm_token_type_ids.to(device)
position_ids, _ = self.vlm_backbone.model.get_rope_index(
input_ids,
image_grid_thw=image_grid_thw,
attention_mask=attention_mask,
mm_token_type_ids=mm_token_type_ids,
)
position_ids = position_ids.to(device)
outputs = self.vlm_backbone.model(
input_ids=input_ids[:, :act_start],
attention_mask=attention_mask[:, :act_start],
position_ids=position_ids[..., :act_start],
inputs_embeds=inputs_embeds[:, :act_start],
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
mm_token_type_ids=mm_token_type_ids[:, :act_start],
use_cache=True,
return_dict=True,
)
x_t = self.sample_noise(
(batch_size, chunk_size, self.config.max_action_dim),
device,
).to(dtype=self.action_in_proj.weight.dtype)
dt = -1.0 / self.config.num_denoise_steps
past_key_values = outputs.past_key_values
# 3. Denoise only the action chunk while keeping the prefix cache invariant.
for step in range(self.config.num_denoise_steps):
time = torch.full(
(batch_size,),
1.0 + step * dt,
device=device,
dtype=torch.float32,
)
action_time_embs = self.embed_suffix(time, x_t)
inputs_embeds[:, act_slice] = action_time_embs.to(inputs_embeds.dtype)
# Keep the prefix KV cache invariant across denoising steps.
past_key_values.crop(act_start)
outputs = self.vlm_backbone.model(
attention_mask=attention_mask[:, :act_end],
past_key_values=past_key_values,
inputs_embeds=inputs_embeds[:, act_slice],
position_ids=position_ids[..., act_slice],
use_cache=True,
return_dict=True,
)
with self.flow_head_autocast_context():
hidden_states = outputs.last_hidden_state[:, :chunk_size]
hidden_states = hidden_states.to(dtype=self.action_out_proj.dtype)
v_t = self.action_out_proj(hidden_states)
x_t += dt * v_t.reshape(x_t.shape)
return x_t
+282
View File
@@ -0,0 +1,282 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
import torch
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.policies.eo1.configuration_eo1 import EO1Config
from lerobot.processor import (
AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.types import TransitionKey
from lerobot.utils.constants import (
OBS_STATE,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
else:
Qwen2_5_VLProcessor = None
SYSTEM_MESSAGE = "You are a helpful physical assistant."
# EO-1 special tokens
ACTION_START_TOKEN = "<|action_start|>" # nosec B105
DEFAULT_ACTION_TOKEN = "<|action_pad|>" # nosec B105
ACTION_END_TOKEN = "<|action_end|>" # nosec B105
STATE_START_TOKEN = "<|state_start|>" # nosec B105
DEFAULT_STATE_TOKEN = "<|state_pad|>" # nosec B105
STATE_END_TOKEN = "<|state_end|>" # nosec B105
TASK_VLA_TOKEN = "<|vla|>" # nosec B105
EO1_SPECIAL_TOKENS = [
ACTION_START_TOKEN,
DEFAULT_ACTION_TOKEN,
ACTION_END_TOKEN,
STATE_START_TOKEN,
DEFAULT_STATE_TOKEN,
STATE_END_TOKEN,
TASK_VLA_TOKEN,
]
@dataclass
@ProcessorStepRegistry.register(name="eo1_conversation_template_processor")
class EO1ConversationTemplateStep(ComplementaryDataProcessorStep):
input_features: dict[str, PolicyFeature] | dict[str, dict[str, Any]]
chunk_size: int
_image_keys: list[str] = field(default_factory=list, init=False, repr=False)
def __post_init__(self):
# Robust JSON deserialization handling (guard empty maps).
if self.input_features:
first_val = next(iter(self.input_features.values()))
if isinstance(first_val, dict):
reconstructed = {}
for key, ft_dict in self.input_features.items():
reconstructed[key] = PolicyFeature(
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
)
self.input_features = reconstructed
self._image_keys = [
key for key, value in self.input_features.items() if value.type == FeatureType.VISUAL
]
def complementary_data(self, complementary_data):
tasks = complementary_data.get("task")
if tasks is None:
raise ValueError("Task is required for EO1ConversationTemplateStep.")
observation = self.transition.get(TransitionKey.OBSERVATION)
if observation is None:
raise ValueError("Observation is required for EO1ConversationTemplateStep.")
if OBS_STATE in observation and observation[OBS_STATE].shape[0] != len(tasks):
raise ValueError("Batch size mismatch between observation.state and task list.")
# LeRobot visual observations reach in processor as float32 tensors in [0, 1].
# Convert to uint8 in [0, 255] to meet the input requirement of Qwen2.5-VL-3B-Instruct.
images = {
key: observation[key].clamp(0, 1).mul(255.0).round().to(torch.uint8) for key in self._image_keys
}
messages = []
for i in range(len(tasks)):
content = [
*[{"type": "image", "image": images[key][i]} for key in self._image_keys],
{
"type": "text",
"text": (
f"{STATE_START_TOKEN}{DEFAULT_STATE_TOKEN}{STATE_END_TOKEN}{tasks[i]}{TASK_VLA_TOKEN}"
),
},
]
messages.append(
[
{"role": "system", "content": [{"type": "text", "text": SYSTEM_MESSAGE}]},
{"role": "user", "content": content},
{
"role": "assistant",
"content": [
{
"type": "text",
"text": f"{ACTION_START_TOKEN}{DEFAULT_ACTION_TOKEN * self.chunk_size}{ACTION_END_TOKEN}",
}
],
},
]
)
complementary_data["messages"] = messages
return complementary_data
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
This step only materializes EO1-specific message objects in complementary_data.
PipelineFeatureType tracks only ACTION and OBSERVATION, so there is no static
feature contract change to record here.
"""
return features
def get_config(self) -> dict[str, Any]:
return {
"input_features": {
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.input_features.items()
},
"chunk_size": self.chunk_size,
}
@dataclass
@ProcessorStepRegistry.register(name="eo1_qwen_processor")
class EO1QwenProcessorStep(ComplementaryDataProcessorStep):
processor_name: str = "Qwen/Qwen2.5-VL-3B-Instruct"
image_min_pixels: int | None = 64 * 28 * 28
image_max_pixels: int | None = 128 * 28 * 28
use_fast_processor: bool = False
_processor: Qwen2_5_VLProcessor | None = field(default=None, init=False, repr=False)
_state_token_id: int | None = field(default=None, init=False, repr=False)
_action_token_id: int | None = field(default=None, init=False, repr=False)
def __post_init__(self):
require_package("transformers", extra="eo1")
self._processor = Qwen2_5_VLProcessor.from_pretrained(
self.processor_name,
use_fast=self.use_fast_processor,
)
self._processor.tokenizer.add_tokens(EO1_SPECIAL_TOKENS, special_tokens=True)
self._state_token_id = self._processor.tokenizer.convert_tokens_to_ids(DEFAULT_STATE_TOKEN)
self._action_token_id = self._processor.tokenizer.convert_tokens_to_ids(DEFAULT_ACTION_TOKEN)
def complementary_data(self, complementary_data):
messages = complementary_data.pop("messages", None)
if messages is None:
raise ValueError("Messages are required for EO1QwenProcessorStep.")
# Rollout batches use left padding so action spans stay aligned across samples.
# Supervised batches use right padding to match standard training collation.
padding_side = "right" if self.transition.get(TransitionKey.ACTION) is not None else "left"
inputs = self._processor.apply_chat_template(
messages,
tokenize=True,
padding=True,
padding_side=padding_side,
min_pixels=self.image_min_pixels,
max_pixels=self.image_max_pixels,
add_generation_prompt=False,
return_dict=True,
return_tensors="pt",
)
complementary_data["input_ids"] = inputs["input_ids"]
complementary_data["pixel_values"] = inputs["pixel_values"]
complementary_data["image_grid_thw"] = inputs["image_grid_thw"]
complementary_data["attention_mask"] = inputs["attention_mask"]
complementary_data["mm_token_type_ids"] = inputs["mm_token_type_ids"]
complementary_data["state_token_id"] = self._state_token_id
complementary_data["action_token_id"] = self._action_token_id
return complementary_data
def get_config(self) -> dict[str, Any]:
return {
"processor_name": self.processor_name,
"image_min_pixels": self.image_min_pixels,
"image_max_pixels": self.image_max_pixels,
"use_fast_processor": self.use_fast_processor,
}
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
This step only converts the messages to the model input format.
"""
return features
def make_eo1_pre_post_processors(
config: EO1Config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Build pre/post processor pipelines for EO1."""
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
EO1ConversationTemplateStep(input_features=config.input_features, chunk_size=config.chunk_size),
EO1QwenProcessorStep(
processor_name=config.vlm_base,
image_min_pixels=config.image_min_pixels,
image_max_pixels=config.image_max_pixels,
use_fast_processor=config.use_fast_processor,
),
DeviceProcessorStep(device=config.device),
]
output_steps: list[ProcessorStep] = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
+14
View File
@@ -46,6 +46,7 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
from .act.configuration_act import ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig
from .eo1.configuration_eo1 import EO1Config
from .groot.configuration_groot import GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config
@@ -146,6 +147,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .wall_x.modeling_wall_x import WallXPolicy
return WallXPolicy
elif name == "eo1":
from .eo1.modeling_eo1 import EO1Policy
return EO1Policy
else:
try:
return _get_policy_cls_from_policy_name(name=name)
@@ -196,6 +201,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return XVLAConfig(**kwargs)
elif policy_type == "wall_x":
return WallXConfig(**kwargs)
elif policy_type == "eo1":
return EO1Config(**kwargs)
else:
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
@@ -399,6 +406,13 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, EO1Config):
from .eo1.processor_eo1 import make_eo1_pre_post_processors
processors = make_eo1_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
try:
@@ -54,6 +54,7 @@ class BiOpenArmFollower(Robot):
calibration_dir=config.calibration_dir,
port=config.left_arm_config.port,
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque,
max_relative_target=config.left_arm_config.max_relative_target,
cameras=left_cameras,
side=config.left_arm_config.side,
@@ -72,6 +73,7 @@ class BiOpenArmFollower(Robot):
calibration_dir=config.calibration_dir,
port=config.right_arm_config.port,
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque,
max_relative_target=config.right_arm_config.max_relative_target,
cameras=right_cameras,
side=config.right_arm_config.side,
@@ -66,6 +66,10 @@ class OpenArmFollowerConfigBase:
# Whether to disable torque when disconnecting
disable_torque_on_disconnect: bool = True
# When True, expose `.vel` and `.torque` per motor in observation features.
# Default False for compatibility with the position-only openarm_mini teleoperator.
use_velocity_and_torque: bool = False
# Safety limit for relative target positions
# Set to a positive scalar for all motors, or a dict mapping motor names to limits
max_relative_target: float | dict[str, float] | None = None
@@ -93,8 +93,9 @@ class OpenArmFollower(Robot):
features: dict[str, type] = {}
for motor in self.bus.motors:
features[f"{motor}.pos"] = float
features[f"{motor}.vel"] = float # Add this
features[f"{motor}.torque"] = float # Add this
if self.config.use_velocity_and_torque:
features[f"{motor}.vel"] = float
features[f"{motor}.torque"] = float
return features
@property
@@ -235,8 +236,9 @@ class OpenArmFollower(Robot):
for motor in self.bus.motors:
state = states.get(motor, {})
obs_dict[f"{motor}.pos"] = state.get("position", 0.0)
obs_dict[f"{motor}.vel"] = state.get("velocity", 0.0)
obs_dict[f"{motor}.torque"] = state.get("torque", 0.0)
if self.config.use_velocity_and_torque:
obs_dict[f"{motor}.vel"] = state.get("velocity", 0.0)
obs_dict[f"{motor}.torque"] = state.get("torque", 0.0)
# Capture images from cameras
for cam_key, cam in self.cameras.items():
+4 -3
View File
@@ -23,6 +23,7 @@ from lerobot.utils.robot_utils import precise_sleep
from ..context import RolloutContext
from .core import RolloutStrategy, send_next_action
from .display import BaseDisplay
logger = logging.getLogger(__name__)
@@ -38,6 +39,8 @@ class BaseStrategy(RolloutStrategy):
"""Initialise the inference engine."""
self._init_engine(ctx)
logger.info("Base strategy ready")
self._display = BaseDisplay(duration=ctx.runtime.cfg.duration)
self._display.show_banner()
def run(self, ctx: RolloutContext) -> None:
"""Run the autonomous control loop until shutdown or duration expires."""
@@ -72,9 +75,7 @@ class BaseStrategy(RolloutStrategy):
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
else:
logger.warning(
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
)
self._warn_slow_loop(dt, control_interval, cfg.fps)
def teardown(self, ctx: RolloutContext) -> None:
"""Disconnect hardware and stop inference."""
+12
View File
@@ -33,6 +33,7 @@ from ..inference import InferenceEngine
if TYPE_CHECKING:
from ..configs import RolloutStrategyConfig
from ..context import HardwareContext, ProcessorContext, RolloutContext, RuntimeContext
from .display import RolloutStatusDisplay
logger = logging.getLogger(__name__)
@@ -51,6 +52,17 @@ class RolloutStrategy(abc.ABC):
self._interpolator: ActionInterpolator | None = None
self._warmup_flushed: bool = False
self._cached_obs_processed: dict | None = None
self._display: RolloutStatusDisplay | None = None
def _warn_slow_loop(self, dt: float, control_interval: float, fps: float) -> None:
"""Warn when the control loop runs slower than the target FPS."""
if dt > control_interval:
logger.warning(
"Control loop running slower (%.1f Hz) than target (%.0f Hz). "
"Possible causes: camera FPS not keeping up, slow policy inference, CPU starvation.",
1 / dt,
fps,
)
def _init_engine(self, ctx: RolloutContext) -> None:
"""Attach the inference engine and action interpolator, then start the backend.
+30 -7
View File
@@ -71,6 +71,7 @@ from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyCon
from ..context import RolloutContext
from ..robot_wrapper import ThreadSafeRobot
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
from .display import DAggerDisplay
PYNPUT_AVAILABLE = _pynput_available
keyboard = None
@@ -286,7 +287,7 @@ def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
listener = keyboard.Listener(on_press=on_press)
listener.start()
logger.info(
logger.debug(
"DAgger keyboard listener started (pause_resume='%s', correction='%s', upload='%s', ESC=stop)",
cfg.pause_resume,
cfg.correction,
@@ -370,6 +371,28 @@ class DAggerStrategy(RolloutStrategy):
self._episode_duration_s,
)
if self.config.input_device == "keyboard":
kb = self.config.keyboard
pause_key, correction_key, upload_key = (
kb.pause_resume.upper(),
kb.correction.upper(),
kb.upload.upper(),
)
else:
pb = self.config.pedal
pause_key, correction_key, upload_key = pb.pause_resume, pb.correction, pb.upload
self._display = DAggerDisplay(
record_autonomous=self.config.record_autonomous,
num_episodes=self.config.num_episodes,
episode_duration_s=self._episode_duration_s,
input_device=self.config.input_device,
pause_key=pause_key,
correction_key=correction_key,
upload_key=upload_key,
)
self._display.show_banner()
def run(self, ctx: RolloutContext) -> None:
"""Run DAgger episodes with human-in-the-loop intervention."""
if self.config.record_autonomous:
@@ -442,6 +465,7 @@ class DAggerStrategy(RolloutStrategy):
interpolator.reset()
events.reset()
engine.resume()
self._display.show_state(DAggerPhase.AUTONOMOUS)
last_action: dict[str, Any] | None = None
record_tick = 0
@@ -472,6 +496,7 @@ class DAggerStrategy(RolloutStrategy):
ctx,
last_action,
)
self._display.show_state(new_phase)
if new_phase == DAggerPhase.AUTONOMOUS:
last_action = None
@@ -556,9 +581,7 @@ class DAggerStrategy(RolloutStrategy):
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
else:
logger.warning(
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
)
self._warn_slow_loop(dt, control_interval, cfg.fps)
finally:
logger.info("DAgger continuous control loop ended — pausing engine")
@@ -599,6 +622,7 @@ class DAggerStrategy(RolloutStrategy):
interpolator.reset()
events.reset()
engine.resume()
self._display.show_state(DAggerPhase.AUTONOMOUS)
last_action: dict[str, Any] | None = None
start_time = time.perf_counter()
@@ -633,6 +657,7 @@ class DAggerStrategy(RolloutStrategy):
ctx,
last_action,
)
self._display.show_state(new_phase)
if new_phase == DAggerPhase.AUTONOMOUS:
last_action = None
@@ -705,9 +730,7 @@ class DAggerStrategy(RolloutStrategy):
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
else:
logger.warning(
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
)
self._warn_slow_loop(dt, control_interval, cfg.fps)
finally:
logger.info("DAgger corrections-only loop ended — pausing engine")
+263
View File
@@ -0,0 +1,263 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Console status display for rollout strategies.
One subclass per strategy static states/controls are declared as class
constants; runtime-dependent values are passed to ``__init__``.
In each strategy's ``setup()``:
self._display = DAggerDisplay(
record_autonomous=self.config.record_autonomous,
num_episodes=self.config.num_episodes,
episode_duration_s=self._episode_duration_s,
input_device=self.config.input_device,
pause_key="SPACE",
correction_key="TAB",
upload_key="ENTER",
)
self._display.show_banner()
On each state transition:
self._display.show_state("correcting")
"""
from __future__ import annotations
import enum
import sys
from dataclasses import dataclass
def _supports_color() -> bool:
return hasattr(sys.stdout, "isatty") and sys.stdout.isatty()
class _C:
"""ANSI escape codes."""
RESET = "\033[0m"
BOLD = "\033[1m"
DIM = "\033[2m"
GREEN = "\033[1;92m"
YELLOW = "\033[1;93m"
RED = "\033[1;91m"
CYAN = "\033[1;96m"
WHITE = "\033[1;97m"
GRAY = "\033[2;37m"
@dataclass
class StateConfig:
"""One named rollout state.
``key`` must match the string passed to ``RolloutStatusDisplay.show_state()``.
"""
key: str
emoji: str
label: str
description: str
color: str = _C.WHITE
@dataclass
class ControlConfig:
"""One keyboard/pedal binding shown in the startup banner."""
key: str
description: str
# ---------------------------------------------------------------------------
# Base display class
# ---------------------------------------------------------------------------
class RolloutStatusDisplay:
"""Unified console status display. Subclass once per strategy."""
def __init__(
self,
strategy: str,
states: list[StateConfig],
controls: list[ControlConfig],
info: list[str] | None = None,
) -> None:
self.strategy = strategy
self._states = {s.key: s for s in states}
self._controls = controls
self._info = info or []
self._use_color = _supports_color()
def _c(self, code: str, text: str) -> str:
if not self._use_color:
return text
return f"{code}{text}{_C.RESET}"
def show_banner(self) -> None:
"""Print startup banner: strategy name, states, controls, config info."""
width = 62
sep = self._c(_C.BOLD, "" * width)
print(f"\n{sep}")
print(self._c(_C.BOLD, f" lerobot-rollout │ {self.strategy}"))
if self._states:
print()
for state in self._states.values():
label = self._c(state.color, f"{state.label:<14}")
desc = self._c(_C.GRAY, state.description)
print(f" {state.emoji} {label} {desc}")
if self._controls:
print()
key_width = max(len(c.key) for c in self._controls)
for ctrl in self._controls:
key_str = self._c(_C.CYAN, f"[{ctrl.key:<{key_width}}]")
print(f" {key_str} {ctrl.description}")
if self._info:
print()
for item in self._info:
print(f" {item}")
print(f"{sep}\n")
def show_state(self, state_key: str | enum.Enum) -> None:
"""Print the current state and available controls - call this on every transition."""
key = state_key.value if isinstance(state_key, enum.Enum) else state_key
state = self._states.get(key)
if state is None:
return
label = self._c(state.color, f"{state.label:<14}")
desc = self._c(_C.GRAY, state.description)
print(f"\n {state.emoji} {label} {desc}\n")
if self._controls:
key_width = max(len(c.key) for c in self._controls)
for ctrl in self._controls:
key_str = self._c(_C.CYAN, f"[{ctrl.key:<{key_width}}]")
print(f" {key_str} {ctrl.description}")
print()
# ---------------------------------------------------------------------------
# One display subclass per strategy
# ---------------------------------------------------------------------------
class BaseDisplay(RolloutStatusDisplay):
"""Status display for the base (eval-only, no recording) strategy."""
_STATES = [StateConfig("running", "🟢", "RUNNING", "autonomous rollout — no recording", _C.GREEN)]
_CONTROLS = [ControlConfig("Ctrl+C", "stop session")]
def __init__(self, duration: float = 0) -> None:
info = ["No recording — evaluation only."]
if duration > 0:
info.append(f"Duration: {duration:.0f}s")
super().__init__("base", self._STATES, self._CONTROLS, info)
class SentryDisplay(RolloutStatusDisplay):
"""Status display for the sentry (continuous autonomous recording) strategy."""
_STATES = [StateConfig("recording", "🟢", "RECORDING", "continuous autonomous recording", _C.GREEN)]
_CONTROLS = [ControlConfig("Ctrl+C", "stop session")]
def __init__(self, episode_duration_s: float, upload_every_n_episodes: int) -> None:
info = [
f"Episode rotation: ~{episode_duration_s:.0f}s | "
f"Upload every {upload_every_n_episodes} episodes",
]
super().__init__("sentry", self._STATES, self._CONTROLS, info)
class HighlightDisplay(RolloutStatusDisplay):
"""Status display for the highlight (ring-buffer on-demand save) strategy."""
def __init__(self, ring_buffer_seconds: float, save_key: str, push_key: str) -> None:
states = [
StateConfig(
"buffering",
"",
"BUFFERING",
f"ring buffer active — last {ring_buffer_seconds:.0f}s captured",
_C.WHITE,
),
StateConfig("recording", "🔴", "RECORDING", "live recording — press [s] to save episode", _C.RED),
]
controls = [
ControlConfig(save_key, "BUFFERING ↔ RECORDING start recording / save episode"),
ControlConfig(push_key, "push dataset to Hub (background)"),
ControlConfig("ESC", "stop session"),
]
super().__init__("highlight", states, controls)
class DAggerDisplay(RolloutStatusDisplay):
"""Status display for the dagger (human-in-the-loop) strategy."""
_PAUSED_STATE = StateConfig("paused", "🟡", "PAUSED", "holding last position — awaiting input", _C.YELLOW)
_CORRECTING_STATE = StateConfig(
"correcting", "🔴", "CORRECTING", "human teleop active — recording correction", _C.RED
)
def __init__(
self,
record_autonomous: bool,
num_episodes: int,
episode_duration_s: float,
input_device: str,
pause_key: str,
correction_key: str,
upload_key: str,
) -> None:
mode = "continuous recording" if record_autonomous else "corrections only"
auto_desc = "policy running — recording" if record_autonomous else "policy running — no recording"
states = [
StateConfig("autonomous", "🟢", "AUTONOMOUS", auto_desc, _C.GREEN),
self._PAUSED_STATE,
self._CORRECTING_STATE,
]
controls = [
ControlConfig(pause_key, "AUTONOMOUS ↔ PAUSED pause / resume policy"),
ControlConfig(correction_key, "PAUSED ↔ CORRECTING start / stop correction"),
ControlConfig(upload_key, "push dataset to Hub"),
ControlConfig("ESC", "stop session"),
]
info = [f"Target: {num_episodes} episodes | Input: {input_device}"]
if record_autonomous:
info.append(f"Episode rotation: ~{episode_duration_s:.0f}s")
super().__init__(f"dagger [{mode}]", states, controls, info)
if __name__ == "__main__":
dagger_display = DAggerDisplay(
record_autonomous=False,
num_episodes=20,
episode_duration_s=30,
input_device="keyboard",
pause_key="SPACE",
correction_key="TAB",
upload_key="ENTER",
)
dagger_display.show_banner()
dagger_display.show_state("paused")
dagger_display.show_state("correcting")
dagger_display.show_state("paused")
dagger_display.show_state("autonomous")
+20 -4
View File
@@ -17,6 +17,7 @@
from __future__ import annotations
import contextlib
import enum
import logging
import os
import sys
@@ -36,6 +37,7 @@ from ..configs import HighlightStrategyConfig
from ..context import RolloutContext
from ..ring_buffer import RolloutRingBuffer
from .core import RolloutStrategy, safe_push_to_hub, send_next_action
from .display import HighlightDisplay
PYNPUT_AVAILABLE = _pynput_available
keyboard = None
@@ -53,6 +55,13 @@ if PYNPUT_AVAILABLE:
logger = logging.getLogger(__name__)
class HighlightPhase(enum.Enum):
"""Observable phases of a Highlight session."""
BUFFERING = "buffering" # Ring buffer accumulating frames, not recording
RECORDING = "recording" # Live recording active
class HighlightStrategy(RolloutStrategy):
"""Autonomous rollout with on-demand recording via ring buffer.
@@ -105,6 +114,13 @@ class HighlightStrategy(RolloutStrategy):
self.config.save_key,
self.config.push_key,
)
self._display = HighlightDisplay(
ring_buffer_seconds=self.config.ring_buffer_seconds,
save_key=self.config.save_key,
push_key=self.config.push_key,
)
self._display.show_banner()
self._display.show_state(HighlightPhase.BUFFERING)
def run(self, ctx: RolloutContext) -> None:
"""Run the autonomous loop, buffering frames and recording on demand."""
@@ -162,6 +178,7 @@ class HighlightStrategy(RolloutStrategy):
for buffered_frame in ring.drain():
dataset.add_frame(buffered_frame)
self._recording_live.set()
self._display.show_state(HighlightPhase.RECORDING)
else:
dataset.add_frame(frame)
with self._episode_lock:
@@ -172,6 +189,7 @@ class HighlightStrategy(RolloutStrategy):
play_sounds,
)
self._recording_live.clear()
self._display.show_state(HighlightPhase.BUFFERING)
continue # frame already consumed — skip ring.append
if self._push_requested.is_set():
@@ -188,9 +206,7 @@ class HighlightStrategy(RolloutStrategy):
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
else:
logger.warning(
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
)
self._warn_slow_loop(dt, control_interval, cfg.fps)
finally:
logger.info("Highlight control loop ended")
@@ -255,7 +271,7 @@ class HighlightStrategy(RolloutStrategy):
self._listener = keyboard.Listener(on_press=on_press)
self._listener.start()
logger.info("Keyboard listener started (save='%s', push='%s', ESC=stop)", save_key, push_key)
logger.debug("Keyboard listener started (save='%s', push='%s', ESC=stop)", save_key, push_key)
except ImportError:
logger.warning("pynput not available — keyboard listener disabled")
+7 -3
View File
@@ -32,6 +32,7 @@ from lerobot.utils.utils import log_say
from ..configs import SentryStrategyConfig
from ..context import RolloutContext
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
from .display import SentryDisplay
logger = logging.getLogger(__name__)
@@ -79,6 +80,11 @@ class SentryStrategy(RolloutStrategy):
self._episode_duration_s,
self.config.upload_every_n_episodes,
)
self._display = SentryDisplay(
episode_duration_s=self._episode_duration_s,
upload_every_n_episodes=self.config.upload_every_n_episodes,
)
self._display.show_banner()
def run(self, ctx: RolloutContext) -> None:
"""Run the continuous recording loop with automatic episode rotation."""
@@ -160,9 +166,7 @@ class SentryStrategy(RolloutStrategy):
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
else:
logger.warning(
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
)
self._warn_slow_loop(dt, control_interval, cfg.fps)
finally:
logger.info("Sentry control loop ended — saving final episode")
@@ -49,6 +49,7 @@ class BiOpenArmLeader(Teleoperator):
can_data_bitrate=config.left_arm_config.can_data_bitrate,
motor_config=config.left_arm_config.motor_config,
manual_control=config.left_arm_config.manual_control,
use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque,
position_kd=config.left_arm_config.position_kd,
position_kp=config.left_arm_config.position_kp,
)
@@ -63,6 +64,7 @@ class BiOpenArmLeader(Teleoperator):
can_data_bitrate=config.right_arm_config.can_data_bitrate,
motor_config=config.right_arm_config.motor_config,
manual_control=config.right_arm_config.manual_control,
use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque,
position_kd=config.right_arm_config.position_kd,
position_kp=config.right_arm_config.position_kp,
)
@@ -60,6 +60,10 @@ class OpenArmLeaderConfigBase:
# When enabled, motors have torque disabled for manual movement
manual_control: bool = True
# When True, expose `.vel` and `.torque` per motor in action features.
# Default False for compatibility with the position-only openarm_mini teleoperator.
use_velocity_and_torque: bool = False
# TODO(Steven, Pepijn): Not used ... ?
# MIT control parameters (used when manual_control=False for torque control)
# List of 8 values: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper]
@@ -70,8 +70,9 @@ class OpenArmLeader(Teleoperator):
features: dict[str, type] = {}
for motor in self.bus.motors:
features[f"{motor}.pos"] = float
features[f"{motor}.vel"] = float
features[f"{motor}.torque"] = float
if self.config.use_velocity_and_torque:
features[f"{motor}.vel"] = float
features[f"{motor}.torque"] = float
return features
@property
@@ -201,8 +202,9 @@ class OpenArmLeader(Teleoperator):
for motor in self.bus.motors:
state = states.get(motor, {})
action_dict[f"{motor}.pos"] = state.get("position")
action_dict[f"{motor}.vel"] = state.get("velocity")
action_dict[f"{motor}.torque"] = state.get("torque")
if self.config.use_velocity_and_torque:
action_dict[f"{motor}.vel"] = state.get("velocity")
action_dict[f"{motor}.torque"] = state.get("torque")
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
+186
View File
@@ -0,0 +1,186 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Smoke tests for EO1's public LeRobot policy interface."""
from __future__ import annotations
from types import SimpleNamespace
import pytest
import torch
from torch import nn
pytest.importorskip("transformers")
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.eo1.modeling_eo1 import EO1Policy
from lerobot.utils.constants import ACTION, OBS_STATE
HIDDEN_SIZE = 8
STATE_DIM = 4
ACTION_DIM = 3
CHUNK_SIZE = 3
N_ACTION_STEPS = 2
MAX_ACTION_DIM = 6
STATE_TOKEN_ID = 5
ACTION_TOKEN_ID = 6
class DummyVLMBackbone(nn.Module):
def __init__(self, hidden_size: int, vocab_size: int = 64):
super().__init__()
self.embedding = nn.Embedding(vocab_size, hidden_size)
self.config = SimpleNamespace(text_config=SimpleNamespace(hidden_size=hidden_size))
@property
def model(self):
return self
def get_input_embeddings(self):
return self.embedding
def get_rope_index(
self,
input_ids: torch.Tensor,
image_grid_thw: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
mm_token_type_ids: torch.Tensor | None = None,
):
batch_size, seq_len = input_ids.shape
if attention_mask is None:
text_positions = torch.arange(seq_len, device=input_ids.device).expand(batch_size, -1)
else:
text_positions = attention_mask.long().cumsum(-1) - 1
text_positions = text_positions.masked_fill(attention_mask == 0, 0)
position_ids = text_positions.view(1, batch_size, seq_len).expand(3, batch_size, seq_len)
rope_deltas = torch.zeros(batch_size, 1, dtype=torch.long, device=input_ids.device)
return position_ids, rope_deltas
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
return gradient_checkpointing_kwargs
def gradient_checkpointing_disable(self):
return None
def forward(
self,
*,
input_ids: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs,
):
if inputs_embeds is None:
inputs_embeds = self.embedding(input_ids)
return SimpleNamespace(
last_hidden_state=inputs_embeds,
past_key_values=SimpleNamespace(crop=lambda prefix_len: None),
)
def make_eo1_config():
from lerobot.policies.eo1.configuration_eo1 import EO1Config
return EO1Config(
device="cpu",
dtype="float32",
vlm_base="dummy-qwen",
vlm_config={},
chunk_size=CHUNK_SIZE,
n_action_steps=N_ACTION_STEPS,
max_state_dim=STATE_DIM,
max_action_dim=MAX_ACTION_DIM,
num_denoise_steps=2,
input_features={
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
"observation.images.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
},
output_features={
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,)),
},
)
def make_policy_batch(include_action: bool) -> dict[str, torch.Tensor | int]:
batch_size = 1
seq_len = CHUNK_SIZE + 4
input_ids = torch.tensor(
[[11, STATE_TOKEN_ID, 12, ACTION_TOKEN_ID, ACTION_TOKEN_ID, ACTION_TOKEN_ID, 13]],
dtype=torch.long,
)
assert input_ids.shape == (batch_size, seq_len)
batch: dict[str, torch.Tensor | int] = {
OBS_STATE: torch.randn(batch_size, STATE_DIM, dtype=torch.float32),
"input_ids": input_ids,
"attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
"pixel_values": torch.zeros(batch_size, 3, 4, 4, dtype=torch.float32),
"image_grid_thw": torch.tensor([[1, 2, 2]], dtype=torch.long),
"mm_token_type_ids": torch.zeros(batch_size, seq_len, dtype=torch.int32),
"state_token_id": STATE_TOKEN_ID,
"action_token_id": ACTION_TOKEN_ID,
}
if include_action:
batch[ACTION] = torch.randn(batch_size, CHUNK_SIZE, ACTION_DIM, dtype=torch.float32)
return batch
def test_lerobot_eo1_forward_pass(monkeypatch):
monkeypatch.setattr(
"lerobot.policies.eo1.modeling_eo1.Qwen2_5_VLForConditionalGeneration.from_pretrained",
lambda *args, **kwargs: DummyVLMBackbone(HIDDEN_SIZE),
)
policy = EO1Policy(make_eo1_config())
loss, metrics = policy.forward(make_policy_batch(include_action=True))
assert loss.ndim == 0
assert torch.isfinite(loss)
assert metrics["loss"] == pytest.approx(loss.item())
def test_lerobot_eo1_inference(monkeypatch):
monkeypatch.setattr(
"lerobot.policies.eo1.modeling_eo1.Qwen2_5_VLForConditionalGeneration.from_pretrained",
lambda *args, **kwargs: DummyVLMBackbone(HIDDEN_SIZE),
)
policy = EO1Policy(make_eo1_config())
sample_calls = {"count": 0}
fixed_chunk = torch.tensor(
[
[
[0.1, 0.2, 0.3, 9.0, 9.0, 9.0],
[1.1, 1.2, 1.3, 9.0, 9.0, 9.0],
[2.1, 2.2, 2.3, 9.0, 9.0, 9.0],
]
],
dtype=torch.float32,
)
def fake_sample_actions(**kwargs):
sample_calls["count"] += 1
return fixed_chunk
monkeypatch.setattr(policy.model, "sample_actions", fake_sample_actions)
batch = make_policy_batch(include_action=False)
action_0 = policy.select_action(batch)
action_1 = policy.select_action(batch)
torch.testing.assert_close(action_0, fixed_chunk[:, 0, :ACTION_DIM])
torch.testing.assert_close(action_1, fixed_chunk[:, 1, :ACTION_DIM])
assert sample_calls["count"] == 1
Generated
+7 -1
View File
@@ -2723,6 +2723,10 @@ dynamixel = [
{ name = "dynamixel-sdk" },
{ name = "pyserial" },
]
eo1 = [
{ name = "qwen-vl-utils" },
{ name = "transformers" },
]
evaluation = [
{ name = "av" },
]
@@ -3029,6 +3033,7 @@ requires-dist = [
{ name = "lerobot", extras = ["pyserial-dep"], marker = "extra == 'unitree-g1'" },
{ name = "lerobot", extras = ["pyzmq-dep"], marker = "extra == 'lekiwi'" },
{ name = "lerobot", extras = ["pyzmq-dep"], marker = "extra == 'unitree-g1'" },
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'eo1'" },
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'sarm'" },
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'wallx'" },
{ name = "lerobot", extras = ["reachy2"], marker = "extra == 'all'" },
@@ -3043,6 +3048,7 @@ requires-dist = [
{ name = "lerobot", extras = ["smolvla"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["test"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["training"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'eo1'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'groot'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'hilserl'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'libero'" },
@@ -3112,7 +3118,7 @@ requires-dist = [
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" },
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
]
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "hilserl", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "eo1", "hilserl", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
[[package]]
name = "librt"