mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-19 17:27:03 +00:00
fix(evo1): finalize policy guide alignment
This commit is contained in:
+55
-1
@@ -26,7 +26,13 @@ The broader EVO1 project may include additional training scripts and dataset too
|
||||
pip install -e ".[evo1]"
|
||||
```
|
||||
|
||||
3. Install a `flash-attn` wheel only if it is compatible with your Python, PyTorch, CUDA, and GPU stack. EVO1 falls back to standard attention when `flash_attn` is not available.
|
||||
For LIBERO evaluation, install the LIBERO extra as well:
|
||||
|
||||
```bash
|
||||
pip install -e ".[evo1,libero]"
|
||||
```
|
||||
|
||||
3. Install a `flash-attn` wheel only if it is compatible with your Python, PyTorch, CUDA, and GPU stack. EVO1 falls back to standard attention when `flash_attn` is not available, but reproducing the official LIBERO checkpoint conversion result below requires the same FlashAttention path used by the original EVO1 checkpoint.
|
||||
|
||||
EVO1 uses InternVL3 through the Hugging Face `transformers` remote-code path, so the first run may download the configured VLM checkpoint unless `policy.vlm_model_name` points to a local model directory.
|
||||
|
||||
@@ -61,6 +67,12 @@ Once a LeRobot-format EVO1 checkpoint is available, load it with:
|
||||
policy.path=your-org/your-evo1-checkpoint
|
||||
```
|
||||
|
||||
The converted LIBERO checkpoint used for this PR is available at:
|
||||
|
||||
```python
|
||||
policy.path=javadcc/evo1-libero-lerobot
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Stage 1
|
||||
@@ -105,12 +117,19 @@ lerobot-train \
|
||||
--output_dir=./outputs/evo1_stage2
|
||||
```
|
||||
|
||||
By default, `policy.training_stage` reapplies the finetuning defaults for that stage. This is important when
|
||||
starting Stage 2 from a Stage 1 checkpoint, because the Stage 1 checkpoint config stores the VLM finetuning
|
||||
flags as disabled. These stage defaults take precedence over saved or manually supplied `policy.finetune_*`
|
||||
flags unless `policy.apply_training_stage_defaults=false`, so set that flag only when manually controlling
|
||||
every finetuning flag.
|
||||
|
||||
### Key Training Parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| --------------------------------------------- | ------------------------ | ----------------------------------------------------------------- |
|
||||
| `policy.vlm_model_name` | `OpenGVLab/InternVL3-1B` | InternVL3 checkpoint or local model directory |
|
||||
| `policy.training_stage` | `stage1` | `stage1` trains the action head; `stage2` finetunes VLM branches |
|
||||
| `policy.apply_training_stage_defaults` | `true` | Reapplies stage finetuning defaults after loading a checkpoint |
|
||||
| `policy.vlm_num_layers` | `14` | Number of InternVL3 language layers kept for the policy |
|
||||
| `policy.vlm_dtype` | `bfloat16` | Requested VLM dtype |
|
||||
| `policy.use_flash_attn` | `true` | Requests FlashAttention when installed; otherwise falls back |
|
||||
@@ -122,6 +141,41 @@ lerobot-train \
|
||||
| `policy.max_action_dim` | `24` | Action padding dimension |
|
||||
| `policy.task_field` | `task` | Batch field used as the language prompt |
|
||||
|
||||
## Results
|
||||
|
||||
### LIBERO Object Checkpoint Conversion
|
||||
|
||||
The checkpoint [javadcc/evo1-libero-lerobot](https://huggingface.co/javadcc/evo1-libero-lerobot)
|
||||
is the LeRobot-format conversion of the official EVO1 LIBERO checkpoint. The conversion was checked against
|
||||
the official EVO1 checkpoint with the same LIBERO Object initial states and action postprocessing.
|
||||
|
||||
| Checkpoint | Suite | Episodes | Success Rate |
|
||||
| -------------------------- | --------------- | --------------- | ------------ |
|
||||
| Official EVO1 checkpoint | `libero_object` | 10, one per task | 100% |
|
||||
| LeRobot converted checkpoint | `libero_object` | 10, one per task | 100% |
|
||||
|
||||
For a fixed `libero_object` rollout, the official checkpoint and LeRobot checkpoint produced identical
|
||||
pixel embeddings, VLM fused tokens, normalized actions, and denormalized actions for the checked action step
|
||||
(`max_abs_diff=0.0`).
|
||||
|
||||
The published checkpoint expects the raw LIBERO camera feature names
|
||||
`observation.images.agentview_image` and `observation.images.robot0_eye_in_hand_image`. To run the converted
|
||||
checkpoint with LeRobot LIBERO evaluation for the same one-episode-per-task setting, keep those camera names
|
||||
instead of the default `image`/`image2` mapping:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=javadcc/evo1-libero-lerobot \
|
||||
--policy.device=cuda \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object \
|
||||
--env.camera_name_mapping="{agentview_image: agentview_image, robot0_eye_in_hand_image: robot0_eye_in_hand_image}" \
|
||||
--env.observation_height=448 \
|
||||
--env.observation_width=448 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- [EVO1 repository](https://github.com/MINT-SJTU/Evo-1)
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
# EVO1
|
||||
|
||||
EVO1 is a Vision-Language-Action policy for robot control. The LeRobot
|
||||
integration uses an InternVL3 vision-language backbone with a flow-matching
|
||||
action head, and supports staged training through the standard LeRobot policy
|
||||
APIs.
|
||||
|
||||
The upstream EVO1 project is available at
|
||||
[MINT-SJTU/Evo-1](https://github.com/MINT-SJTU/Evo-1).
|
||||
|
||||
```bibtex
|
||||
@misc{evo1,
|
||||
title = {EVO1},
|
||||
author = {{MINT-SJTU}},
|
||||
year = {2026},
|
||||
howpublished = {\url{https://github.com/MINT-SJTU/Evo-1}},
|
||||
}
|
||||
```
|
||||
+1
-1
@@ -195,7 +195,7 @@ groot = [
|
||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
evo1 = ["lerobot[transformers-dep]"]
|
||||
evo1 = ["lerobot[transformers-dep]", "timm>=1.0.0,<1.1.0"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
|
||||
@@ -1 +1 @@
|
||||
../../../../docs/source/evo1.mdx
|
||||
../../../../docs/source/policy_evo1_README.md
|
||||
@@ -89,6 +89,9 @@ class Evo1Config(PreTrainedConfig):
|
||||
finetune_language_model: bool | None = None
|
||||
finetune_vision_model: bool | None = None
|
||||
finetune_action_head: bool | None = None
|
||||
# Reapply stage defaults after loading checkpoint configs so stage2 cannot
|
||||
# accidentally inherit the frozen VLM flags stored by a stage1 checkpoint.
|
||||
apply_training_stage_defaults: bool = True
|
||||
|
||||
task_field: str = "task"
|
||||
embodiment_id_field: str | None = None
|
||||
@@ -110,7 +113,18 @@ class Evo1Config(PreTrainedConfig):
|
||||
f"Unsupported EVO1 training_stage '{self.training_stage}', expected 'stage1' or 'stage2'"
|
||||
)
|
||||
|
||||
if self.training_stage == "stage1":
|
||||
if self.apply_training_stage_defaults:
|
||||
if self.training_stage == "stage1":
|
||||
self.finetune_vlm = False
|
||||
self.finetune_language_model = False
|
||||
self.finetune_vision_model = False
|
||||
self.finetune_action_head = True
|
||||
elif self.training_stage == "stage2":
|
||||
self.finetune_vlm = True
|
||||
self.finetune_language_model = True
|
||||
self.finetune_vision_model = True
|
||||
self.finetune_action_head = True
|
||||
elif self.training_stage == "stage1":
|
||||
if self.finetune_vlm is None:
|
||||
self.finetune_vlm = False
|
||||
if self.finetune_language_model is None:
|
||||
|
||||
@@ -16,11 +16,14 @@ from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import types
|
||||
from collections.abc import Sequence
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
import torchvision.transforms.functional as TF
|
||||
from PIL import Image
|
||||
from torchvision.transforms.functional import to_pil_image
|
||||
@@ -42,6 +45,31 @@ IMG_END_TOKEN = "</img>" # nosec B105
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _patch_vision_encoder_checkpointing(encoder: nn.Module, use_reentrant: bool) -> None:
|
||||
if getattr(encoder, "_evo1_checkpoint_patch_applied", False):
|
||||
encoder.gradient_checkpointing_use_reentrant = use_reentrant
|
||||
return
|
||||
|
||||
original_forward = encoder.forward
|
||||
|
||||
def forward_with_checkpoint_kwargs(self, *args, **kwargs):
|
||||
original_checkpoint = torch.utils.checkpoint.checkpoint
|
||||
|
||||
def checkpoint(function, *checkpoint_args, **checkpoint_kwargs):
|
||||
checkpoint_kwargs.setdefault("use_reentrant", self.gradient_checkpointing_use_reentrant)
|
||||
return original_checkpoint(function, *checkpoint_args, **checkpoint_kwargs)
|
||||
|
||||
torch.utils.checkpoint.checkpoint = checkpoint
|
||||
try:
|
||||
return original_forward(*args, **kwargs)
|
||||
finally:
|
||||
torch.utils.checkpoint.checkpoint = original_checkpoint
|
||||
|
||||
encoder.gradient_checkpointing_use_reentrant = use_reentrant
|
||||
encoder.forward = types.MethodType(forward_with_checkpoint_kwargs, encoder)
|
||||
encoder._evo1_checkpoint_patch_applied = True
|
||||
|
||||
|
||||
def flash_attn_is_available() -> bool:
|
||||
try:
|
||||
import flash_attn # noqa: F401
|
||||
@@ -50,6 +78,32 @@ def flash_attn_is_available() -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _internvl_transformers5_load_compatibility():
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
original_linspace = torch.linspace
|
||||
original_mark_tied = PreTrainedModel.mark_tied_weights_as_initialized
|
||||
|
||||
def linspace(*args, **kwargs):
|
||||
if kwargs.get("device") is None:
|
||||
kwargs["device"] = torch.device("cpu")
|
||||
return original_linspace(*args, **kwargs)
|
||||
|
||||
def mark_tied_weights_as_initialized(self, loading_info):
|
||||
if not hasattr(self, "all_tied_weights_keys"):
|
||||
self.all_tied_weights_keys = {}
|
||||
return original_mark_tied(self, loading_info)
|
||||
|
||||
torch.linspace = linspace
|
||||
PreTrainedModel.mark_tied_weights_as_initialized = mark_tied_weights_as_initialized
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch.linspace = original_linspace
|
||||
PreTrainedModel.mark_tied_weights_as_initialized = original_mark_tied
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=10000)
|
||||
def get_target_aspect_ratio(orig_width: int, orig_height: int, image_size: int, min_num: int, max_num: int):
|
||||
aspect_ratio = orig_width / orig_height
|
||||
@@ -130,14 +184,19 @@ class InternVL3Embedder(nn.Module):
|
||||
if use_flash_attn and not resolved_use_flash_attn:
|
||||
logger.warning("flash_attn is not installed. Falling back to standard attention.")
|
||||
|
||||
self.model = AutoModel.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=model_dtype,
|
||||
trust_remote_code=True,
|
||||
use_flash_attn=resolved_use_flash_attn,
|
||||
low_cpu_mem_usage=True,
|
||||
_fast_init=False,
|
||||
).to(self._requested_device)
|
||||
# InternVL3 remote code predates Transformers 5 post-init conventions:
|
||||
# it computes stochastic-depth scalars via torch.linspace(...).item()
|
||||
# while Transformers initializes under torch.device("meta"), and it
|
||||
# does not populate all_tied_weights_keys before loading finalization.
|
||||
with _internvl_transformers5_load_compatibility():
|
||||
self.model = AutoModel.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=model_dtype,
|
||||
trust_remote_code=True,
|
||||
use_flash_attn=resolved_use_flash_attn,
|
||||
low_cpu_mem_usage=True,
|
||||
_fast_init=False,
|
||||
).to(self._requested_device)
|
||||
|
||||
if hasattr(self.model.language_model, "model"):
|
||||
layers = self.model.language_model.model.layers
|
||||
@@ -192,7 +251,11 @@ class InternVL3Embedder(nn.Module):
|
||||
enabled_any = _enable_ckpt(self.model)
|
||||
|
||||
if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"):
|
||||
self.model.vision_model.encoder.gradient_checkpointing = True
|
||||
encoder = self.model.vision_model.encoder
|
||||
encoder.gradient_checkpointing = True
|
||||
_patch_vision_encoder_checkpointing(
|
||||
encoder, use_reentrant=self.gradient_checkpointing_use_reentrant
|
||||
)
|
||||
enabled_any = True
|
||||
|
||||
language_model = getattr(self.model, "language_model", None)
|
||||
|
||||
@@ -295,6 +295,14 @@ class EVO1Policy(PreTrainedPolicy):
|
||||
embodiment_ids = embodiment_ids[:, -1]
|
||||
return embodiment_ids.to(device=self.config.device, dtype=torch.long)
|
||||
|
||||
@property
|
||||
def _tracks_vlm_gradients(self) -> bool:
|
||||
return bool(
|
||||
self.config.finetune_vlm
|
||||
or self.config.finetune_language_model
|
||||
or self.config.finetune_vision_model
|
||||
)
|
||||
|
||||
def _collect_image_batches(self, batch: dict[str, Tensor]) -> tuple[list[list[Tensor]], Tensor]:
|
||||
camera_keys = self._camera_keys or sorted(key for key in batch if key.startswith(f"{OBS_IMAGES}."))
|
||||
if not camera_keys:
|
||||
@@ -338,12 +346,28 @@ class EVO1Policy(PreTrainedPolicy):
|
||||
image_batches: list[list[Tensor]],
|
||||
image_masks: Tensor,
|
||||
) -> Tensor:
|
||||
fused_tokens = self.model.get_vl_embeddings(
|
||||
images=image_batches,
|
||||
image_mask=image_masks,
|
||||
prompt=prompts,
|
||||
return_cls_only=self.config.return_cls_only,
|
||||
)
|
||||
track_vlm_gradients = self._tracks_vlm_gradients
|
||||
grad_context = nullcontext() if track_vlm_gradients else torch.no_grad()
|
||||
embedder = getattr(self.model, "embedder", None)
|
||||
embedder_was_training = embedder.training if embedder is not None else None
|
||||
|
||||
if not track_vlm_gradients and embedder is not None:
|
||||
embedder.eval()
|
||||
|
||||
try:
|
||||
with grad_context:
|
||||
fused_tokens = self.model.get_vl_embeddings(
|
||||
images=image_batches,
|
||||
image_mask=image_masks,
|
||||
prompt=prompts,
|
||||
return_cls_only=self.config.return_cls_only,
|
||||
)
|
||||
finally:
|
||||
if not track_vlm_gradients and embedder is not None and embedder_was_training is not None:
|
||||
embedder.train(embedder_was_training)
|
||||
|
||||
if not track_vlm_gradients:
|
||||
fused_tokens = fused_tokens.detach()
|
||||
return fused_tokens.to(device=self.config.device, dtype=self._compute_dtype)
|
||||
|
||||
def _compute_masked_loss(
|
||||
|
||||
@@ -38,15 +38,20 @@ class DummyEVO1(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embedder = nn.Dropout(p=0.0)
|
||||
self.action_head = nn.Linear(1, 1)
|
||||
self.get_vl_embeddings_calls = 0
|
||||
self.grad_enabled_calls = []
|
||||
self.embedder_training_calls = []
|
||||
|
||||
def set_finetune_flags(self):
|
||||
return None
|
||||
|
||||
def get_vl_embeddings(self, images, image_mask, prompt=None, return_cls_only=False):
|
||||
self.get_vl_embeddings_calls += 1
|
||||
return torch.ones(len(images), 4, EMBED_DIM)
|
||||
self.grad_enabled_calls.append(torch.is_grad_enabled())
|
||||
self.embedder_training_calls.append(self.embedder.training)
|
||||
return torch.ones(len(images), 4, EMBED_DIM, requires_grad=torch.is_grad_enabled())
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -136,8 +141,27 @@ def test_evo1_stage_defaults_and_consistency():
|
||||
)
|
||||
assert stage2.finetune_action_head is True
|
||||
|
||||
stage2_from_stage1_checkpoint_flags = make_config(
|
||||
training_stage="stage2",
|
||||
finetune_vlm=False,
|
||||
finetune_language_model=False,
|
||||
finetune_vision_model=False,
|
||||
finetune_action_head=False,
|
||||
)
|
||||
assert (
|
||||
stage2_from_stage1_checkpoint_flags.finetune_vlm,
|
||||
stage2_from_stage1_checkpoint_flags.finetune_language_model,
|
||||
stage2_from_stage1_checkpoint_flags.finetune_vision_model,
|
||||
) == (
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
)
|
||||
assert stage2_from_stage1_checkpoint_flags.finetune_action_head is True
|
||||
|
||||
explicit_off = make_config(
|
||||
training_stage="stage2",
|
||||
apply_training_stage_defaults=False,
|
||||
finetune_vlm=False,
|
||||
finetune_language_model=False,
|
||||
finetune_vision_model=False,
|
||||
@@ -155,7 +179,12 @@ def test_evo1_stage_defaults_and_consistency():
|
||||
assert explicit_off.finetune_action_head is False
|
||||
|
||||
try:
|
||||
make_config(training_stage="stage2", finetune_vlm=True, finetune_language_model=False)
|
||||
make_config(
|
||||
training_stage="stage2",
|
||||
apply_training_stage_defaults=False,
|
||||
finetune_vlm=True,
|
||||
finetune_language_model=False,
|
||||
)
|
||||
except ValueError as exc:
|
||||
assert "Inconsistent EVO1 finetune config" in str(exc)
|
||||
else:
|
||||
@@ -180,6 +209,33 @@ def test_evo1_policy_forward_and_inference_use_batched_embedding(monkeypatch):
|
||||
assert selected.shape == (2, ACTION_DIM)
|
||||
|
||||
|
||||
def test_stage1_frozen_vlm_embeddings_do_not_track_gradients(monkeypatch):
|
||||
monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1)
|
||||
policy = modeling_evo1.EVO1Policy(make_config(training_stage="stage1"))
|
||||
policy.train()
|
||||
|
||||
image_batches, image_masks = policy._collect_image_batches(make_batch(include_action=False))
|
||||
fused_tokens = policy._compute_fused_tokens(["pick", "place"], image_batches, image_masks)
|
||||
|
||||
assert policy.model.grad_enabled_calls == [False]
|
||||
assert policy.model.embedder_training_calls == [False]
|
||||
assert not fused_tokens.requires_grad
|
||||
assert policy.model.embedder.training is True
|
||||
|
||||
|
||||
def test_stage2_vlm_embeddings_track_gradients(monkeypatch):
|
||||
monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1)
|
||||
policy = modeling_evo1.EVO1Policy(make_config(training_stage="stage2"))
|
||||
policy.train()
|
||||
|
||||
image_batches, image_masks = policy._collect_image_batches(make_batch(include_action=False))
|
||||
fused_tokens = policy._compute_fused_tokens(["pick", "place"], image_batches, image_masks)
|
||||
|
||||
assert policy.model.grad_enabled_calls == [True]
|
||||
assert policy.model.embedder_training_calls == [True]
|
||||
assert fused_tokens.requires_grad
|
||||
|
||||
|
||||
def test_collect_image_batches_handles_unbatched_chw(monkeypatch):
|
||||
# Regression for an issue where batch_size was read from shape[0] before normalizing
|
||||
# per-camera tensor dims, so an unbatched (C, H, W) input was treated as batch_size=C.
|
||||
|
||||
@@ -2723,6 +2723,7 @@ all = [
|
||||
{ name = "scikit-image" },
|
||||
{ name = "scipy" },
|
||||
{ name = "teleop" },
|
||||
{ name = "timm" },
|
||||
{ name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" },
|
||||
{ name = "torchdiffeq" },
|
||||
{ name = "transformers" },
|
||||
@@ -2815,6 +2816,7 @@ evaluation = [
|
||||
{ name = "av" },
|
||||
]
|
||||
evo1 = [
|
||||
{ name = "timm" },
|
||||
{ name = "transformers" },
|
||||
]
|
||||
feetech = [
|
||||
@@ -3198,6 +3200,7 @@ requires-dist = [
|
||||
{ name = "setuptools", specifier = ">=71.0.0,<81.0.0" },
|
||||
{ name = "teleop", marker = "extra == 'phone'", specifier = ">=0.1.0,<0.2.0" },
|
||||
{ name = "termcolor", specifier = ">=2.4.0,<4.0.0" },
|
||||
{ name = "timm", marker = "extra == 'evo1'", specifier = ">=1.0.0,<1.1.0" },
|
||||
{ name = "timm", marker = "extra == 'groot'", specifier = ">=1.0.0,<1.1.0" },
|
||||
{ name = "torch", specifier = ">=2.7,<2.13.0" },
|
||||
{ name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux' and extra == 'dataset') or (platform_machine != 'x86_64' and sys_platform == 'darwin' and extra == 'dataset') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'dataset')", specifier = ">=0.3.0,<0.13.0" },
|
||||
|
||||
Reference in New Issue
Block a user