From a5e64099854c416c791fc737df48d1e6b96959a4 Mon Sep 17 00:00:00 2001 From: javadcc_mac Date: Mon, 11 May 2026 21:51:41 +0800 Subject: [PATCH] fix(evo1): finalize policy guide alignment --- docs/source/evo1.mdx | 56 ++++++++++++- docs/source/policy_evo1_README.md | 18 +++++ pyproject.toml | 2 +- src/lerobot/policies/evo1/README.md | 2 +- .../policies/evo1/configuration_evo1.py | 16 +++- .../policies/evo1/internvl3_embedder.py | 81 ++++++++++++++++--- src/lerobot/policies/evo1/modeling_evo1.py | 36 +++++++-- tests/policies/evo1/test_evo1.py | 60 +++++++++++++- uv.lock | 3 + 9 files changed, 253 insertions(+), 21 deletions(-) create mode 100644 docs/source/policy_evo1_README.md diff --git a/docs/source/evo1.mdx b/docs/source/evo1.mdx index a86f7a56b..210e4e488 100644 --- a/docs/source/evo1.mdx +++ b/docs/source/evo1.mdx @@ -26,7 +26,13 @@ The broader EVO1 project may include additional training scripts and dataset too pip install -e ".[evo1]" ``` -3. Install a `flash-attn` wheel only if it is compatible with your Python, PyTorch, CUDA, and GPU stack. EVO1 falls back to standard attention when `flash_attn` is not available. + For LIBERO evaluation, install the LIBERO extra as well: + + ```bash + pip install -e ".[evo1,libero]" + ``` + +3. Install a `flash-attn` wheel only if it is compatible with your Python, PyTorch, CUDA, and GPU stack. EVO1 falls back to standard attention when `flash_attn` is not available, but reproducing the official LIBERO checkpoint conversion result below requires the same FlashAttention path used by the original EVO1 checkpoint. EVO1 uses InternVL3 through the Hugging Face `transformers` remote-code path, so the first run may download the configured VLM checkpoint unless `policy.vlm_model_name` points to a local model directory. @@ -61,6 +67,12 @@ Once a LeRobot-format EVO1 checkpoint is available, load it with: policy.path=your-org/your-evo1-checkpoint ``` +The converted LIBERO checkpoint used for this PR is available at: + +```python +policy.path=javadcc/evo1-libero-lerobot +``` + ## Training ### Stage 1 @@ -105,12 +117,19 @@ lerobot-train \ --output_dir=./outputs/evo1_stage2 ``` +By default, `policy.training_stage` reapplies the finetuning defaults for that stage. This is important when +starting Stage 2 from a Stage 1 checkpoint, because the Stage 1 checkpoint config stores the VLM finetuning +flags as disabled. These stage defaults take precedence over saved or manually supplied `policy.finetune_*` +flags unless `policy.apply_training_stage_defaults=false`, so set that flag only when manually controlling +every finetuning flag. + ### Key Training Parameters | Parameter | Default | Description | | --------------------------------------------- | ------------------------ | ----------------------------------------------------------------- | | `policy.vlm_model_name` | `OpenGVLab/InternVL3-1B` | InternVL3 checkpoint or local model directory | | `policy.training_stage` | `stage1` | `stage1` trains the action head; `stage2` finetunes VLM branches | +| `policy.apply_training_stage_defaults` | `true` | Reapplies stage finetuning defaults after loading a checkpoint | | `policy.vlm_num_layers` | `14` | Number of InternVL3 language layers kept for the policy | | `policy.vlm_dtype` | `bfloat16` | Requested VLM dtype | | `policy.use_flash_attn` | `true` | Requests FlashAttention when installed; otherwise falls back | @@ -122,6 +141,41 @@ lerobot-train \ | `policy.max_action_dim` | `24` | Action padding dimension | | `policy.task_field` | `task` | Batch field used as the language prompt | +## Results + +### LIBERO Object Checkpoint Conversion + +The checkpoint [javadcc/evo1-libero-lerobot](https://huggingface.co/javadcc/evo1-libero-lerobot) +is the LeRobot-format conversion of the official EVO1 LIBERO checkpoint. The conversion was checked against +the official EVO1 checkpoint with the same LIBERO Object initial states and action postprocessing. + +| Checkpoint | Suite | Episodes | Success Rate | +| -------------------------- | --------------- | --------------- | ------------ | +| Official EVO1 checkpoint | `libero_object` | 10, one per task | 100% | +| LeRobot converted checkpoint | `libero_object` | 10, one per task | 100% | + +For a fixed `libero_object` rollout, the official checkpoint and LeRobot checkpoint produced identical +pixel embeddings, VLM fused tokens, normalized actions, and denormalized actions for the checked action step +(`max_abs_diff=0.0`). + +The published checkpoint expects the raw LIBERO camera feature names +`observation.images.agentview_image` and `observation.images.robot0_eye_in_hand_image`. To run the converted +checkpoint with LeRobot LIBERO evaluation for the same one-episode-per-task setting, keep those camera names +instead of the default `image`/`image2` mapping: + +```bash +lerobot-eval \ + --policy.path=javadcc/evo1-libero-lerobot \ + --policy.device=cuda \ + --env.type=libero \ + --env.task=libero_object \ + --env.camera_name_mapping="{agentview_image: agentview_image, robot0_eye_in_hand_image: robot0_eye_in_hand_image}" \ + --env.observation_height=448 \ + --env.observation_width=448 \ + --eval.batch_size=1 \ + --eval.n_episodes=1 +``` + ## References - [EVO1 repository](https://github.com/MINT-SJTU/Evo-1) diff --git a/docs/source/policy_evo1_README.md b/docs/source/policy_evo1_README.md new file mode 100644 index 000000000..3c6d31c83 --- /dev/null +++ b/docs/source/policy_evo1_README.md @@ -0,0 +1,18 @@ +# EVO1 + +EVO1 is a Vision-Language-Action policy for robot control. The LeRobot +integration uses an InternVL3 vision-language backbone with a flow-matching +action head, and supports staged training through the standard LeRobot policy +APIs. + +The upstream EVO1 project is available at +[MINT-SJTU/Evo-1](https://github.com/MINT-SJTU/Evo-1). + +```bibtex +@misc{evo1, + title = {EVO1}, + author = {{MINT-SJTU}}, + year = {2026}, + howpublished = {\url{https://github.com/MINT-SJTU/Evo-1}}, +} +``` diff --git a/pyproject.toml b/pyproject.toml index 241a5bcf0..49f61523d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -195,7 +195,7 @@ groot = [ sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"] xvla = ["lerobot[transformers-dep]"] eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"] -evo1 = ["lerobot[transformers-dep]"] +evo1 = ["lerobot[transformers-dep]", "timm>=1.0.0,<1.1.0"] hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] # Features diff --git a/src/lerobot/policies/evo1/README.md b/src/lerobot/policies/evo1/README.md index bcd30fe73..6c4284fb9 120000 --- a/src/lerobot/policies/evo1/README.md +++ b/src/lerobot/policies/evo1/README.md @@ -1 +1 @@ -../../../../docs/source/evo1.mdx \ No newline at end of file +../../../../docs/source/policy_evo1_README.md \ No newline at end of file diff --git a/src/lerobot/policies/evo1/configuration_evo1.py b/src/lerobot/policies/evo1/configuration_evo1.py index 4cfec4d28..6804535d0 100644 --- a/src/lerobot/policies/evo1/configuration_evo1.py +++ b/src/lerobot/policies/evo1/configuration_evo1.py @@ -89,6 +89,9 @@ class Evo1Config(PreTrainedConfig): finetune_language_model: bool | None = None finetune_vision_model: bool | None = None finetune_action_head: bool | None = None + # Reapply stage defaults after loading checkpoint configs so stage2 cannot + # accidentally inherit the frozen VLM flags stored by a stage1 checkpoint. + apply_training_stage_defaults: bool = True task_field: str = "task" embodiment_id_field: str | None = None @@ -110,7 +113,18 @@ class Evo1Config(PreTrainedConfig): f"Unsupported EVO1 training_stage '{self.training_stage}', expected 'stage1' or 'stage2'" ) - if self.training_stage == "stage1": + if self.apply_training_stage_defaults: + if self.training_stage == "stage1": + self.finetune_vlm = False + self.finetune_language_model = False + self.finetune_vision_model = False + self.finetune_action_head = True + elif self.training_stage == "stage2": + self.finetune_vlm = True + self.finetune_language_model = True + self.finetune_vision_model = True + self.finetune_action_head = True + elif self.training_stage == "stage1": if self.finetune_vlm is None: self.finetune_vlm = False if self.finetune_language_model is None: diff --git a/src/lerobot/policies/evo1/internvl3_embedder.py b/src/lerobot/policies/evo1/internvl3_embedder.py index 8962b8f0d..20745f8b6 100644 --- a/src/lerobot/policies/evo1/internvl3_embedder.py +++ b/src/lerobot/policies/evo1/internvl3_embedder.py @@ -16,11 +16,14 @@ from __future__ import annotations import functools import logging +import types from collections.abc import Sequence +from contextlib import contextmanager from typing import TYPE_CHECKING import torch import torch.nn as nn +import torch.utils.checkpoint import torchvision.transforms.functional as TF from PIL import Image from torchvision.transforms.functional import to_pil_image @@ -42,6 +45,31 @@ IMG_END_TOKEN = "" # nosec B105 logger = logging.getLogger(__name__) +def _patch_vision_encoder_checkpointing(encoder: nn.Module, use_reentrant: bool) -> None: + if getattr(encoder, "_evo1_checkpoint_patch_applied", False): + encoder.gradient_checkpointing_use_reentrant = use_reentrant + return + + original_forward = encoder.forward + + def forward_with_checkpoint_kwargs(self, *args, **kwargs): + original_checkpoint = torch.utils.checkpoint.checkpoint + + def checkpoint(function, *checkpoint_args, **checkpoint_kwargs): + checkpoint_kwargs.setdefault("use_reentrant", self.gradient_checkpointing_use_reentrant) + return original_checkpoint(function, *checkpoint_args, **checkpoint_kwargs) + + torch.utils.checkpoint.checkpoint = checkpoint + try: + return original_forward(*args, **kwargs) + finally: + torch.utils.checkpoint.checkpoint = original_checkpoint + + encoder.gradient_checkpointing_use_reentrant = use_reentrant + encoder.forward = types.MethodType(forward_with_checkpoint_kwargs, encoder) + encoder._evo1_checkpoint_patch_applied = True + + def flash_attn_is_available() -> bool: try: import flash_attn # noqa: F401 @@ -50,6 +78,32 @@ def flash_attn_is_available() -> bool: return True +@contextmanager +def _internvl_transformers5_load_compatibility(): + from transformers.modeling_utils import PreTrainedModel + + original_linspace = torch.linspace + original_mark_tied = PreTrainedModel.mark_tied_weights_as_initialized + + def linspace(*args, **kwargs): + if kwargs.get("device") is None: + kwargs["device"] = torch.device("cpu") + return original_linspace(*args, **kwargs) + + def mark_tied_weights_as_initialized(self, loading_info): + if not hasattr(self, "all_tied_weights_keys"): + self.all_tied_weights_keys = {} + return original_mark_tied(self, loading_info) + + torch.linspace = linspace + PreTrainedModel.mark_tied_weights_as_initialized = mark_tied_weights_as_initialized + try: + yield + finally: + torch.linspace = original_linspace + PreTrainedModel.mark_tied_weights_as_initialized = original_mark_tied + + @functools.lru_cache(maxsize=10000) def get_target_aspect_ratio(orig_width: int, orig_height: int, image_size: int, min_num: int, max_num: int): aspect_ratio = orig_width / orig_height @@ -130,14 +184,19 @@ class InternVL3Embedder(nn.Module): if use_flash_attn and not resolved_use_flash_attn: logger.warning("flash_attn is not installed. Falling back to standard attention.") - self.model = AutoModel.from_pretrained( - model_name, - torch_dtype=model_dtype, - trust_remote_code=True, - use_flash_attn=resolved_use_flash_attn, - low_cpu_mem_usage=True, - _fast_init=False, - ).to(self._requested_device) + # InternVL3 remote code predates Transformers 5 post-init conventions: + # it computes stochastic-depth scalars via torch.linspace(...).item() + # while Transformers initializes under torch.device("meta"), and it + # does not populate all_tied_weights_keys before loading finalization. + with _internvl_transformers5_load_compatibility(): + self.model = AutoModel.from_pretrained( + model_name, + torch_dtype=model_dtype, + trust_remote_code=True, + use_flash_attn=resolved_use_flash_attn, + low_cpu_mem_usage=True, + _fast_init=False, + ).to(self._requested_device) if hasattr(self.model.language_model, "model"): layers = self.model.language_model.model.layers @@ -192,7 +251,11 @@ class InternVL3Embedder(nn.Module): enabled_any = _enable_ckpt(self.model) if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"): - self.model.vision_model.encoder.gradient_checkpointing = True + encoder = self.model.vision_model.encoder + encoder.gradient_checkpointing = True + _patch_vision_encoder_checkpointing( + encoder, use_reentrant=self.gradient_checkpointing_use_reentrant + ) enabled_any = True language_model = getattr(self.model, "language_model", None) diff --git a/src/lerobot/policies/evo1/modeling_evo1.py b/src/lerobot/policies/evo1/modeling_evo1.py index 91459d722..7867d0c8e 100644 --- a/src/lerobot/policies/evo1/modeling_evo1.py +++ b/src/lerobot/policies/evo1/modeling_evo1.py @@ -295,6 +295,14 @@ class EVO1Policy(PreTrainedPolicy): embodiment_ids = embodiment_ids[:, -1] return embodiment_ids.to(device=self.config.device, dtype=torch.long) + @property + def _tracks_vlm_gradients(self) -> bool: + return bool( + self.config.finetune_vlm + or self.config.finetune_language_model + or self.config.finetune_vision_model + ) + def _collect_image_batches(self, batch: dict[str, Tensor]) -> tuple[list[list[Tensor]], Tensor]: camera_keys = self._camera_keys or sorted(key for key in batch if key.startswith(f"{OBS_IMAGES}.")) if not camera_keys: @@ -338,12 +346,28 @@ class EVO1Policy(PreTrainedPolicy): image_batches: list[list[Tensor]], image_masks: Tensor, ) -> Tensor: - fused_tokens = self.model.get_vl_embeddings( - images=image_batches, - image_mask=image_masks, - prompt=prompts, - return_cls_only=self.config.return_cls_only, - ) + track_vlm_gradients = self._tracks_vlm_gradients + grad_context = nullcontext() if track_vlm_gradients else torch.no_grad() + embedder = getattr(self.model, "embedder", None) + embedder_was_training = embedder.training if embedder is not None else None + + if not track_vlm_gradients and embedder is not None: + embedder.eval() + + try: + with grad_context: + fused_tokens = self.model.get_vl_embeddings( + images=image_batches, + image_mask=image_masks, + prompt=prompts, + return_cls_only=self.config.return_cls_only, + ) + finally: + if not track_vlm_gradients and embedder is not None and embedder_was_training is not None: + embedder.train(embedder_was_training) + + if not track_vlm_gradients: + fused_tokens = fused_tokens.detach() return fused_tokens.to(device=self.config.device, dtype=self._compute_dtype) def _compute_masked_loss( diff --git a/tests/policies/evo1/test_evo1.py b/tests/policies/evo1/test_evo1.py index 706c1903f..7ccd6274e 100644 --- a/tests/policies/evo1/test_evo1.py +++ b/tests/policies/evo1/test_evo1.py @@ -38,15 +38,20 @@ class DummyEVO1(nn.Module): def __init__(self, config): super().__init__() self.config = config + self.embedder = nn.Dropout(p=0.0) self.action_head = nn.Linear(1, 1) self.get_vl_embeddings_calls = 0 + self.grad_enabled_calls = [] + self.embedder_training_calls = [] def set_finetune_flags(self): return None def get_vl_embeddings(self, images, image_mask, prompt=None, return_cls_only=False): self.get_vl_embeddings_calls += 1 - return torch.ones(len(images), 4, EMBED_DIM) + self.grad_enabled_calls.append(torch.is_grad_enabled()) + self.embedder_training_calls.append(self.embedder.training) + return torch.ones(len(images), 4, EMBED_DIM, requires_grad=torch.is_grad_enabled()) def forward( self, @@ -136,8 +141,27 @@ def test_evo1_stage_defaults_and_consistency(): ) assert stage2.finetune_action_head is True + stage2_from_stage1_checkpoint_flags = make_config( + training_stage="stage2", + finetune_vlm=False, + finetune_language_model=False, + finetune_vision_model=False, + finetune_action_head=False, + ) + assert ( + stage2_from_stage1_checkpoint_flags.finetune_vlm, + stage2_from_stage1_checkpoint_flags.finetune_language_model, + stage2_from_stage1_checkpoint_flags.finetune_vision_model, + ) == ( + True, + True, + True, + ) + assert stage2_from_stage1_checkpoint_flags.finetune_action_head is True + explicit_off = make_config( training_stage="stage2", + apply_training_stage_defaults=False, finetune_vlm=False, finetune_language_model=False, finetune_vision_model=False, @@ -155,7 +179,12 @@ def test_evo1_stage_defaults_and_consistency(): assert explicit_off.finetune_action_head is False try: - make_config(training_stage="stage2", finetune_vlm=True, finetune_language_model=False) + make_config( + training_stage="stage2", + apply_training_stage_defaults=False, + finetune_vlm=True, + finetune_language_model=False, + ) except ValueError as exc: assert "Inconsistent EVO1 finetune config" in str(exc) else: @@ -180,6 +209,33 @@ def test_evo1_policy_forward_and_inference_use_batched_embedding(monkeypatch): assert selected.shape == (2, ACTION_DIM) +def test_stage1_frozen_vlm_embeddings_do_not_track_gradients(monkeypatch): + monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1) + policy = modeling_evo1.EVO1Policy(make_config(training_stage="stage1")) + policy.train() + + image_batches, image_masks = policy._collect_image_batches(make_batch(include_action=False)) + fused_tokens = policy._compute_fused_tokens(["pick", "place"], image_batches, image_masks) + + assert policy.model.grad_enabled_calls == [False] + assert policy.model.embedder_training_calls == [False] + assert not fused_tokens.requires_grad + assert policy.model.embedder.training is True + + +def test_stage2_vlm_embeddings_track_gradients(monkeypatch): + monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1) + policy = modeling_evo1.EVO1Policy(make_config(training_stage="stage2")) + policy.train() + + image_batches, image_masks = policy._collect_image_batches(make_batch(include_action=False)) + fused_tokens = policy._compute_fused_tokens(["pick", "place"], image_batches, image_masks) + + assert policy.model.grad_enabled_calls == [True] + assert policy.model.embedder_training_calls == [True] + assert fused_tokens.requires_grad + + def test_collect_image_batches_handles_unbatched_chw(monkeypatch): # Regression for an issue where batch_size was read from shape[0] before normalizing # per-camera tensor dims, so an unbatched (C, H, W) input was treated as batch_size=C. diff --git a/uv.lock b/uv.lock index 36560c289..08a285e92 100644 --- a/uv.lock +++ b/uv.lock @@ -2723,6 +2723,7 @@ all = [ { name = "scikit-image" }, { name = "scipy" }, { name = "teleop" }, + { name = "timm" }, { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, { name = "torchdiffeq" }, { name = "transformers" }, @@ -2815,6 +2816,7 @@ evaluation = [ { name = "av" }, ] evo1 = [ + { name = "timm" }, { name = "transformers" }, ] feetech = [ @@ -3198,6 +3200,7 @@ requires-dist = [ { name = "setuptools", specifier = ">=71.0.0,<81.0.0" }, { name = "teleop", marker = "extra == 'phone'", specifier = ">=0.1.0,<0.2.0" }, { name = "termcolor", specifier = ">=2.4.0,<4.0.0" }, + { name = "timm", marker = "extra == 'evo1'", specifier = ">=1.0.0,<1.1.0" }, { name = "timm", marker = "extra == 'groot'", specifier = ">=1.0.0,<1.1.0" }, { name = "torch", specifier = ">=2.7,<2.13.0" }, { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux' and extra == 'dataset') or (platform_machine != 'x86_64' and sys_platform == 'darwin' and extra == 'dataset') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'dataset')", specifier = ">=0.3.0,<0.13.0" },