Files
lerobot/tests/policies/eo1/test_eo1.py
T
Haoming Song e99c55af4b feat(policies): add EO-1 model (#3403)
* feat(policies): add EO-1 model

* chore(eo1): adjust policy_eo1_README.md to to avoid duplicate with eo1.mdx

* chore(eo1): remove policy_eo1_README.md, link eo1.mdx in policy folder

---------

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-05-06 18:01:16 +02:00

187 lines
6.1 KiB
Python

#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Smoke tests for EO1's public LeRobot policy interface."""
from __future__ import annotations
from types import SimpleNamespace
import pytest
import torch
from torch import nn
pytest.importorskip("transformers")
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.eo1.modeling_eo1 import EO1Policy
from lerobot.utils.constants import ACTION, OBS_STATE
# Tiny model/fixture dimensions keep the smoke tests fast on CPU.
HIDDEN_SIZE = 8  # hidden size of the dummy VLM backbone
STATE_DIM = 4  # dimensionality of the proprioceptive state feature
ACTION_DIM = 3  # dimensionality of the emitted action vectors
CHUNK_SIZE = 3  # number of actions predicted per chunk
N_ACTION_STEPS = 2  # steps executed from each chunk before re-sampling
MAX_ACTION_DIM = 6  # padded action width used inside the model
STATE_TOKEN_ID = 5  # token id marking the state slot in input_ids
ACTION_TOKEN_ID = 6  # token id marking each action slot in input_ids
class DummyVLMBackbone(nn.Module):
    """Minimal stand-in for the Qwen2.5-VL backbone used in EO1 smoke tests.

    Implements only the surface EO1Policy touches: an embedding table,
    ``config.text_config.hidden_size``, rope-index computation, the
    gradient-checkpointing toggles, and a forward pass that simply echoes
    the input embeddings back as hidden states.
    """

    def __init__(self, hidden_size: int, vocab_size: int = 64):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        # Mirrors the HF config layout the policy reads the hidden size from.
        self.config = SimpleNamespace(text_config=SimpleNamespace(hidden_size=hidden_size))

    @property
    def model(self):
        # The policy reaches the backbone through ``.model``; the dummy is its own model.
        return self

    def get_input_embeddings(self):
        return self.embedding

    def get_rope_index(
        self,
        input_ids: torch.Tensor,
        image_grid_thw: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        mm_token_type_ids: torch.Tensor | None = None,
    ):
        """Return trivial text-only 3D rope position ids and zero rope deltas."""
        n_batch, n_tokens = input_ids.shape
        if attention_mask is not None:
            # Position = running count of attended tokens; padded slots pinned to 0.
            positions = attention_mask.long().cumsum(-1) - 1
            positions = positions.masked_fill(attention_mask == 0, 0)
        else:
            positions = torch.arange(n_tokens, device=input_ids.device).expand(n_batch, -1)
        stacked = positions.view(1, n_batch, n_tokens).expand(3, n_batch, n_tokens)
        deltas = torch.zeros(n_batch, 1, dtype=torch.long, device=input_ids.device)
        return stacked, deltas

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        return gradient_checkpointing_kwargs

    def gradient_checkpointing_disable(self):
        return None

    def forward(
        self,
        *,
        input_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ):
        # Echo embeddings as hidden states; expose a no-op KV-cache ``crop``.
        embeds = self.embedding(input_ids) if inputs_embeds is None else inputs_embeds
        return SimpleNamespace(
            last_hidden_state=embeds,
            past_key_values=SimpleNamespace(crop=lambda prefix_len: None),
        )
def make_eo1_config():
    """Build a tiny CPU-only EO1Config for fast smoke tests."""
    # Imported lazily so the module can be collected even when the EO1
    # config import itself is what a test is exercising.
    from lerobot.policies.eo1.configuration_eo1 import EO1Config

    input_feats = {
        OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
        "observation.images.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
    }
    output_feats = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}
    return EO1Config(
        device="cpu",
        dtype="float32",
        vlm_base="dummy-qwen",
        vlm_config={},
        chunk_size=CHUNK_SIZE,
        n_action_steps=N_ACTION_STEPS,
        max_state_dim=STATE_DIM,
        max_action_dim=MAX_ACTION_DIM,
        num_denoise_steps=2,
        input_features=input_feats,
        output_features=output_feats,
    )
def make_policy_batch(include_action: bool) -> dict[str, torch.Tensor | int]:
    """Create a minimal tokenized batch for EO1Policy.

    Args:
        include_action: When True, add a ground-truth ``ACTION`` tensor so the
            batch can be used for a training forward pass.

    Returns:
        A batch dict with state, tokenized prompt, dummy pixel inputs, and the
        special-token ids EO1 uses to locate state/action slots.
    """
    batch_size = 1
    # Token layout: [text, STATE, text, ACTION x CHUNK_SIZE, text].
    # Built programmatically so the fixture stays valid if CHUNK_SIZE changes
    # (the original hard-coded exactly three ACTION tokens).
    token_row = [11, STATE_TOKEN_ID, 12, *([ACTION_TOKEN_ID] * CHUNK_SIZE), 13]
    input_ids = torch.tensor([token_row], dtype=torch.long)
    seq_len = input_ids.shape[1]
    assert seq_len == CHUNK_SIZE + 4
    batch: dict[str, torch.Tensor | int] = {
        OBS_STATE: torch.randn(batch_size, STATE_DIM, dtype=torch.float32),
        "input_ids": input_ids,
        "attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
        "pixel_values": torch.zeros(batch_size, 3, 4, 4, dtype=torch.float32),
        "image_grid_thw": torch.tensor([[1, 2, 2]], dtype=torch.long),
        "mm_token_type_ids": torch.zeros(batch_size, seq_len, dtype=torch.int32),
        "state_token_id": STATE_TOKEN_ID,
        "action_token_id": ACTION_TOKEN_ID,
    }
    if include_action:
        batch[ACTION] = torch.randn(batch_size, CHUNK_SIZE, ACTION_DIM, dtype=torch.float32)
    return batch
def test_lerobot_eo1_forward_pass(monkeypatch):
    """A training forward pass yields a finite scalar loss and matching metrics."""

    def _fake_from_pretrained(*args, **kwargs):
        # Swap the heavyweight HF backbone for the lightweight dummy.
        return DummyVLMBackbone(HIDDEN_SIZE)

    monkeypatch.setattr(
        "lerobot.policies.eo1.modeling_eo1.Qwen2_5_VLForConditionalGeneration.from_pretrained",
        _fake_from_pretrained,
    )
    policy = EO1Policy(make_eo1_config())
    train_batch = make_policy_batch(include_action=True)
    loss, metrics = policy.forward(train_batch)
    assert loss.ndim == 0
    assert torch.isfinite(loss)
    assert metrics["loss"] == pytest.approx(loss.item())
def test_lerobot_eo1_inference(monkeypatch):
    """select_action serves successive steps from one cached action chunk."""
    monkeypatch.setattr(
        "lerobot.policies.eo1.modeling_eo1.Qwen2_5_VLForConditionalGeneration.from_pretrained",
        lambda *args, **kwargs: DummyVLMBackbone(HIDDEN_SIZE),
    )
    policy = EO1Policy(make_eo1_config())

    # Chunk padded out to MAX_ACTION_DIM; sentinel 9.0 fills the padding.
    canned_chunk = torch.tensor(
        [
            [
                [0.1, 0.2, 0.3, 9.0, 9.0, 9.0],
                [1.1, 1.2, 1.3, 9.0, 9.0, 9.0],
                [2.1, 2.2, 2.3, 9.0, 9.0, 9.0],
            ]
        ],
        dtype=torch.float32,
    )
    call_log: list[int] = []

    def stub_sample_actions(**kwargs):
        call_log.append(1)
        return canned_chunk

    monkeypatch.setattr(policy.model, "sample_actions", stub_sample_actions)

    batch = make_policy_batch(include_action=False)
    first = policy.select_action(batch)
    second = policy.select_action(batch)

    # Steps 0 and 1 come from the same chunk, truncated to ACTION_DIM.
    torch.testing.assert_close(first, canned_chunk[:, 0, :ACTION_DIM])
    torch.testing.assert_close(second, canned_chunk[:, 1, :ACTION_DIM])
    # The model was only asked to denoise once across both calls.
    assert len(call_log) == 1