mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
feat(policies): add EO-1 model (#3403)
* feat(policies): add EO-1 model * chore(eo1): adjust policy_eo1_README.md to to avoid duplicate with eo1.mdx * chore(eo1): remove policy_eo1_README.md, link eo1.mdx in policy folder --------- Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,186 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Smoke tests for EO1's public LeRobot policy interface."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.eo1.modeling_eo1 import EO1Policy
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
# Dimensions are kept deliberately tiny so these smoke tests run fast on CPU.
HIDDEN_SIZE = 8  # hidden size of the dummy VLM backbone's embeddings
STATE_DIM = 4  # proprioceptive state dimension fed to the policy
ACTION_DIM = 3  # true action dimension declared in output_features
CHUNK_SIZE = 3  # actions predicted per chunk (EO1Config.chunk_size)
N_ACTION_STEPS = 2  # steps consumed from each chunk before re-sampling
MAX_ACTION_DIM = 6  # padded internal action width (EO1Config.max_action_dim)
STATE_TOKEN_ID = 5  # token id marking the state slot in input_ids
ACTION_TOKEN_ID = 6  # token id marking the action slots in input_ids
|
||||
|
||||
|
||||
class DummyVLMBackbone(nn.Module):
    """Minimal stand-in for the Qwen2.5-VL backbone used by EO1.

    Mimics just enough of the real model's surface (``model`` property,
    ``get_input_embeddings``, ``get_rope_index``, gradient-checkpointing
    toggles, and a ``forward`` returning ``last_hidden_state``) for the
    policy smoke tests to exercise EO1 without downloading weights.
    """

    def __init__(self, hidden_size: int, vocab_size: int = 64):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        # The real backbone exposes config.text_config.hidden_size; replicate it.
        self.config = SimpleNamespace(text_config=SimpleNamespace(hidden_size=hidden_size))

    @property
    def model(self):
        # The real wrapper exposes an inner ``.model``; the stub is its own inner model.
        return self

    def get_input_embeddings(self):
        return self.embedding

    def get_rope_index(
        self,
        input_ids: torch.Tensor,
        image_grid_thw: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        mm_token_type_ids: torch.Tensor | None = None,
    ):
        """Return text-only 3D rope position ids and zero rope deltas.

        ``image_grid_thw`` / ``mm_token_type_ids`` are accepted only for
        signature compatibility with the real backbone and are ignored.
        """
        bsz, length = input_ids.shape
        if attention_mask is not None:
            # Cumulative positions over attended tokens; padded slots pinned to 0.
            positions = attention_mask.long().cumsum(-1) - 1
            positions = positions.masked_fill(attention_mask == 0, 0)
        else:
            positions = torch.arange(length, device=input_ids.device).expand(bsz, -1)
        # Broadcast the same positions across the 3 rope axes (t, h, w).
        rope_positions = positions.view(1, bsz, length).expand(3, bsz, length)
        deltas = torch.zeros(bsz, 1, dtype=torch.long, device=input_ids.device)
        return rope_positions, deltas

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        # No-op: nothing to checkpoint in the stub.
        return gradient_checkpointing_kwargs

    def gradient_checkpointing_disable(self):
        # No-op counterpart of the enable call above.
        return None

    def forward(
        self,
        *,
        input_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ):
        """Echo embeddings as the hidden state; ignore all other kwargs."""
        embeds = inputs_embeds if inputs_embeds is not None else self.embedding(input_ids)
        return SimpleNamespace(
            last_hidden_state=embeds,
            # Fake KV cache exposing the .crop(prefix_len) API the policy may call.
            past_key_values=SimpleNamespace(crop=lambda prefix_len: None),
        )
|
||||
|
||||
|
||||
def make_eo1_config():
    """Build a minimal CPU-only EO1Config for smoke testing.

    Uses the tiny module-level dimensions and a dummy VLM base so no real
    weights are needed; import is local to keep collection cheap when
    transformers is unavailable.
    """
    from lerobot.policies.eo1.configuration_eo1 import EO1Config

    inputs = {
        OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
        "observation.images.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
    }
    outputs = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}

    return EO1Config(
        device="cpu",
        dtype="float32",
        vlm_base="dummy-qwen",
        vlm_config={},
        chunk_size=CHUNK_SIZE,
        n_action_steps=N_ACTION_STEPS,
        max_state_dim=STATE_DIM,
        max_action_dim=MAX_ACTION_DIM,
        num_denoise_steps=2,
        input_features=inputs,
        output_features=outputs,
    )
|
||||
|
||||
|
||||
def make_policy_batch(include_action: bool) -> dict[str, torch.Tensor | int]:
    """Build a single-sample batch matching the tiny EO1 test config.

    The token sequence contains one state slot and CHUNK_SIZE action slots;
    when ``include_action`` is True a ground-truth action chunk is attached
    under the ACTION key (as needed for the training forward pass).
    """
    bsz = 1
    seq_len = CHUNK_SIZE + 4
    token_ids = torch.tensor(
        [[11, STATE_TOKEN_ID, 12, ACTION_TOKEN_ID, ACTION_TOKEN_ID, ACTION_TOKEN_ID, 13]],
        dtype=torch.long,
    )
    # Sanity check: the hand-written sequence must match the declared length.
    assert token_ids.shape == (bsz, seq_len)

    batch: dict[str, torch.Tensor | int] = {
        OBS_STATE: torch.randn(bsz, STATE_DIM, dtype=torch.float32),
        "input_ids": token_ids,
        "attention_mask": torch.ones(bsz, seq_len, dtype=torch.long),
        "pixel_values": torch.zeros(bsz, 3, 4, 4, dtype=torch.float32),
        "image_grid_thw": torch.tensor([[1, 2, 2]], dtype=torch.long),
        "mm_token_type_ids": torch.zeros(bsz, seq_len, dtype=torch.int32),
        "state_token_id": STATE_TOKEN_ID,
        "action_token_id": ACTION_TOKEN_ID,
    }
    if include_action:
        batch[ACTION] = torch.randn(bsz, CHUNK_SIZE, ACTION_DIM, dtype=torch.float32)
    return batch
|
||||
|
||||
|
||||
def test_lerobot_eo1_forward_pass(monkeypatch):
    """Training forward pass yields a finite scalar loss mirrored in metrics."""

    def fake_from_pretrained(*args, **kwargs):
        # Swap the heavyweight Qwen2.5-VL download for the local stub.
        return DummyVLMBackbone(HIDDEN_SIZE)

    monkeypatch.setattr(
        "lerobot.policies.eo1.modeling_eo1.Qwen2_5_VLForConditionalGeneration.from_pretrained",
        fake_from_pretrained,
    )
    policy = EO1Policy(make_eo1_config())

    loss, metrics = policy.forward(make_policy_batch(include_action=True))

    assert loss.ndim == 0
    assert torch.isfinite(loss)
    assert metrics["loss"] == pytest.approx(loss.item())
|
||||
|
||||
|
||||
def test_lerobot_eo1_inference(monkeypatch):
    """select_action drains a cached chunk: one sample_actions call serves
    N_ACTION_STEPS successive queries, each truncated to ACTION_DIM."""

    def fake_from_pretrained(*args, **kwargs):
        # Replace the real VLM download with the local stub backbone.
        return DummyVLMBackbone(HIDDEN_SIZE)

    monkeypatch.setattr(
        "lerobot.policies.eo1.modeling_eo1.Qwen2_5_VLForConditionalGeneration.from_pretrained",
        fake_from_pretrained,
    )
    policy = EO1Policy(make_eo1_config())

    # Deterministic chunk: padded columns hold 9.0 beyond the true ACTION_DIM.
    fixed_chunk = torch.tensor(
        [
            [
                [0.1, 0.2, 0.3, 9.0, 9.0, 9.0],
                [1.1, 1.2, 1.3, 9.0, 9.0, 9.0],
                [2.1, 2.2, 2.3, 9.0, 9.0, 9.0],
            ]
        ],
        dtype=torch.float32,
    )
    call_count = [0]

    def fake_sample_actions(**kwargs):
        call_count[0] += 1
        return fixed_chunk

    monkeypatch.setattr(policy.model, "sample_actions", fake_sample_actions)

    batch = make_policy_batch(include_action=False)
    first = policy.select_action(batch)
    second = policy.select_action(batch)

    # Consecutive steps come from the same cached chunk, unpadded to ACTION_DIM.
    torch.testing.assert_close(first, fixed_chunk[:, 0, :ACTION_DIM])
    torch.testing.assert_close(second, fixed_chunk[:, 1, :ACTION_DIM])
    assert call_count[0] == 1
|
||||
Reference in New Issue
Block a user