Compare commits

...

4 Commits

Author SHA1 Message Date
Martino Russi 9423deda02 refactor(evo1): use native HF InternVL3-1B-hf, drop trust_remote_code
- Switch from OpenGVLab/InternVL3-1B (requires trust_remote_code=True)
  to OpenGVLab/InternVL3-1B-hf (native transformers implementation).
- Replace manual _extract_feature + _prepare_and_fuse_embeddings with
  a single model.forward() call — verified bit-for-bit identical output.
- Remove ~170 lines of manual ViT/pixel-shuffle/projection logic.
- Symlink README.md to docs/source/ following repo convention.

Weights are byte-identical between both model variants; only the module
naming differs. All 12 existing unit tests pass. Local training (10 steps)
on maximellerbach/omx_pickandplace confirmed working.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-23 17:17:19 +02:00
javadcc_mac 25556ceefe fix(evo1): move LIBERO padding into policy processors 2026-06-21 15:58:38 +08:00
javadcc_mac 4cfa762da8 Fix eval action conversion for bf16 policies 2026-06-13 10:51:33 +08:00
javadcc_mac fa984990c0 Fix EVO1 LIBERO eval action postprocessing 2026-06-13 10:18:34 +08:00
17 changed files with 690 additions and 378 deletions
+13 -3
View File
@@ -139,6 +139,8 @@ every finetuning flag.
| `policy.n_action_steps` | `50` | Number of actions consumed from a sampled chunk |
| `policy.max_state_dim` | `24` | State padding dimension |
| `policy.max_action_dim` | `24` | Action padding dimension |
| `policy.postprocess_action_dim` | `null` | Optional action dimension returned after EVO1 postprocessing |
| `policy.binarize_gripper` | `false` | Whether to binarize the postprocessed gripper channel (LIBERO-style eval) |
| `policy.task_field` | `task` | Batch field used as the language prompt |
## Results
@@ -159,14 +161,22 @@ pixel embeddings, VLM fused tokens, normalized actions, and denormalized actions
(`max_abs_diff=0.0`).
The published checkpoint expects the raw LIBERO camera feature names
`observation.images.agentview_image` and `observation.images.robot0_eye_in_hand_image`. To run the converted
checkpoint with LeRobot LIBERO evaluation for the same one-episode-per-task setting, keep those camera names
instead of the default `image`/`image2` mapping:
`observation.images.agentview_image` and `observation.images.robot0_eye_in_hand_image`. The official EVO1 LIBERO
rollout protocol also replans every 14 actions and binarizes the gripper command before stepping the simulator.
The EVO1 policy postprocessor can crop the padded 24D action back to the 7D LIBERO action space and apply that
gripper binarization. To run the converted checkpoint with LeRobot LIBERO evaluation for the same
one-episode-per-task setting, keep the raw camera names instead of the default `image`/`image2` mapping, enable
FlashAttention, and set the LIBERO action postprocessing flags:
```bash
lerobot-eval \
--policy.path=javadcc/evo1-libero-lerobot \
--policy.vlm_model_name=OpenGVLab/InternVL3-1B \
--policy.device=cuda \
--policy.use_flash_attn=true \
--policy.n_action_steps=14 \
--policy.postprocess_action_dim=7 \
--policy.binarize_gripper=true \
--env.type=libero \
--env.task=libero_object \
--env.camera_name_mapping="{agentview_image: agentview_image, robot0_eye_in_hand_image: robot0_eye_in_hand_image}" \
+3 -3
View File
@@ -140,6 +140,7 @@ pyserial-dep = ["pyserial>=3.5,<4.0"]
deepdiff-dep = ["deepdiff>=7.0.1,<9.0.0"]
pynput-dep = ["pynput>=1.7.8,<1.9.0"]
pyzmq-dep = ["pyzmq>=26.2.1,<28.0.0"]
timm-dep = ["timm>=1.0.0,<1.1.0"]
# Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0", "lerobot[pyserial-dep]", "lerobot[deepdiff-dep]"]
@@ -187,7 +188,7 @@ groot = [
"lerobot[peft-dep]",
"lerobot[diffusers-dep]",
"dm-tree>=0.1.8,<1.0.0",
"timm>=1.0.0,<1.1.0",
"lerobot[timm-dep]",
"decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
"ninja>=1.11.1,<2.0.0",
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
@@ -195,7 +196,7 @@ groot = [
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
evo1 = ["lerobot[transformers-dep]", "timm>=1.0.0,<1.1.0"]
evo1 = ["lerobot[transformers-dep]", "lerobot[timm-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
@@ -350,7 +351,6 @@ ignore = [
# E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect
"src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"]
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Suppress these as they come from the original Qwen2_5_vl code TODO(pepijn): refactor original
"src/lerobot/policies/evo1/**" = ["N801", "N812"]
[tool.ruff.lint.isort]
combine-as-imports = true
+5 -9
View File
@@ -26,7 +26,6 @@ from gymnasium.envs.registration import registry as gym_registry
from lerobot.configs import FeatureType, PolicyFeature
from lerobot.processor import (
IsaaclabArenaProcessorStep,
LiberoActionProcessorStep,
LiberoProcessorStep,
PolicyProcessorPipeline,
)
@@ -128,7 +127,7 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
vec = env_cls([_make_one for _ in range(n_envs)], **extra_kwargs)
return {self.type: {0: vec}}
def get_env_processors(self, policy_cfg: Any | None = None):
def get_env_processors(self):
"""Return (preprocessor, postprocessor) for this env. Default: identity."""
return PolicyProcessorPipeline(steps=[]), PolicyProcessorPipeline(steps=[])
@@ -441,13 +440,10 @@ class LiberoEnv(EnvConfig):
is_libero_plus=self.is_libero_plus,
)
def get_env_processors(self, policy_cfg: Any | None = None):
max_state_dim = getattr(policy_cfg, "max_state_dim", None) if getattr(policy_cfg, "type", None) == "evo1" else None
action_feature = self.features.get(ACTION)
action_dim = int(action_feature.shape[0]) if action_feature is not None else 7
def get_env_processors(self):
return (
PolicyProcessorPipeline(steps=[LiberoProcessorStep(max_state_dim=max_state_dim)]),
PolicyProcessorPipeline(steps=[LiberoActionProcessorStep(action_dim=action_dim)]),
PolicyProcessorPipeline(steps=[LiberoProcessorStep()]),
PolicyProcessorPipeline(steps=[]),
)
@@ -713,7 +709,7 @@ class IsaaclabArenaEnv(HubEnvConfig):
def gym_kwargs(self) -> dict:
return {}
def get_env_processors(self, policy_cfg: Any | None = None):
def get_env_processors(self):
state_keys = tuple(k.strip() for k in (self.state_keys or "").split(",") if k.strip())
camera_keys = tuple(k.strip() for k in (self.camera_keys or "").split(",") if k.strip())
if not state_keys and not camera_keys:
+1 -9
View File
@@ -15,7 +15,6 @@
# limitations under the License.
from __future__ import annotations
import inspect
from typing import Any
import gymnasium as gym
@@ -53,14 +52,7 @@ def make_env_pre_post_processors(
return make_xvla_libero_pre_post_processors()
get_processors = env_cfg.get_env_processors
signature = inspect.signature(get_processors)
supports_policy_cfg = "policy_cfg" in signature.parameters or any(
param.kind is inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()
)
if supports_policy_cfg:
return get_processors(policy_cfg=policy_cfg)
return get_processors()
return env_cfg.get_env_processors()
def make_env(
+41 -11
View File
@@ -14,6 +14,7 @@
from __future__ import annotations
import logging
import math
from dataclasses import dataclass, field
@@ -26,6 +27,8 @@ from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
logger = logging.getLogger(__name__)
@LRSchedulerConfig.register_subclass("evo1_exact")
@dataclass
@@ -59,6 +62,12 @@ class Evo1Config(PreTrainedConfig):
max_views: int = 3
image_resolution: tuple[int, int] = (448, 448)
empty_cameras: int = 0
postprocess_action_dim: int | None = None
binarize_gripper: bool = False
gripper_index: int = 6
gripper_threshold: float = 0.5
gripper_below_threshold_value: float = 1.0
gripper_above_threshold_value: float = -1.0
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
@@ -68,7 +77,7 @@ class Evo1Config(PreTrainedConfig):
}
)
vlm_model_name: str = "OpenGVLab/InternVL3-1B"
vlm_model_name: str = "OpenGVLab/InternVL3-1B-hf"
vlm_num_layers: int | None = 14
vlm_dtype: str = "bfloat16"
use_flash_attn: bool = True
@@ -114,16 +123,32 @@ class Evo1Config(PreTrainedConfig):
)
if self.apply_training_stage_defaults:
if self.training_stage == "stage1":
self.finetune_vlm = False
self.finetune_language_model = False
self.finetune_vision_model = False
self.finetune_action_head = True
elif self.training_stage == "stage2":
self.finetune_vlm = True
self.finetune_language_model = True
self.finetune_vision_model = True
self.finetune_action_head = True
stage_defaults = {
"stage1": {
"finetune_vlm": False,
"finetune_language_model": False,
"finetune_vision_model": False,
"finetune_action_head": True,
},
"stage2": {
"finetune_vlm": True,
"finetune_language_model": True,
"finetune_vision_model": True,
"finetune_action_head": True,
},
}[self.training_stage]
for flag_name, default_value in stage_defaults.items():
current_value = getattr(self, flag_name)
if current_value is not None and current_value != default_value:
logger.warning(
"EVO1 %s=%s is overridden by training_stage=%s default %s. "
"Set apply_training_stage_defaults=false to keep explicit finetuning flags.",
flag_name,
current_value,
self.training_stage,
default_value,
)
setattr(self, flag_name, default_value)
elif self.training_stage == "stage1":
if self.finetune_vlm is None:
self.finetune_vlm = False
@@ -171,6 +196,11 @@ class Evo1Config(PreTrainedConfig):
raise ValueError(
f"n_action_steps ({self.n_action_steps}) must be <= chunk_size ({self.chunk_size})"
)
if len(self.image_resolution) != 2 or self.image_resolution[0] != self.image_resolution[1]:
raise ValueError(
"EVO1 currently expects a square image_resolution because InternVL3 preprocessing "
f"uses a scalar image_size, got {self.image_resolution}."
)
def validate_features(self) -> None:
if self.input_features is None:
+2 -33
View File
@@ -21,8 +21,8 @@ import torch
import torch.nn as nn
from PIL import Image
from lerobot.policies.evo1.flow_matching import FlowmatchingActionHead
from lerobot.policies.evo1.internvl3_embedder import InternVL3Embedder
from .flow_matching import FlowmatchingActionHead
from .internvl3_embedder import InternVL3Embedder
def _cfgget(config: Any, key: str, default=None):
@@ -163,37 +163,6 @@ class EVO1(nn.Module):
embodiment_id=embodiment_ids,
)
@torch.no_grad()
def run_inference(
self,
images: list[Image.Image | torch.Tensor],
image_mask: torch.Tensor,
prompt: str,
state_input: list | torch.Tensor,
return_cls_only: bool | None = None,
action_mask: torch.Tensor | None = None,
embodiment_ids: torch.Tensor | None = None,
) -> torch.Tensor:
if image_mask.dim() == 1:
image_mask = image_mask.unsqueeze(0)
fused_tokens = self.get_vl_embeddings(
images=[images],
image_mask=image_mask,
prompt=[prompt],
return_cls_only=return_cls_only,
)
state_tensor = self.prepare_state(state_input)
action = self.predict_action(
fused_tokens,
state_tensor,
action_mask=action_mask,
embodiment_ids=embodiment_ids,
)
if isinstance(action, torch.Tensor) and action.dtype == torch.bfloat16:
action = action.to(torch.float32)
return action
def forward(
self,
fused_tokens: torch.Tensor,
+4 -1
View File
@@ -129,7 +129,10 @@ class MultiEmbodimentActionEncoder(nn.Module):
def forward(self, action_seq: torch.Tensor, category_id: torch.LongTensor):
batch_size, horizon, action_dim = action_seq.shape
assert self.horizon == horizon, "Action sequence length must match horizon"
if self.horizon != horizon:
raise ValueError(
f"Action sequence length must match horizon: got {horizon}, expected {self.horizon}."
)
x = action_seq.reshape(batch_size * horizon, action_dim)
if category_id.dim() == 0:
+73 -197
View File
@@ -16,15 +16,12 @@ from __future__ import annotations
import functools
import logging
import types
from collections.abc import Sequence
from contextlib import contextmanager
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
import torch.utils.checkpoint
import torchvision.transforms.functional as TF
import torchvision.transforms.functional as tvf
from PIL import Image
from torchvision.transforms.functional import to_pil_image
@@ -45,65 +42,6 @@ IMG_END_TOKEN = "</img>" # nosec B105
logger = logging.getLogger(__name__)
def _patch_vision_encoder_checkpointing(encoder: nn.Module, use_reentrant: bool) -> None:
if getattr(encoder, "_evo1_checkpoint_patch_applied", False):
encoder.gradient_checkpointing_use_reentrant = use_reentrant
return
original_forward = encoder.forward
def forward_with_checkpoint_kwargs(self, *args, **kwargs):
original_checkpoint = torch.utils.checkpoint.checkpoint
def checkpoint(function, *checkpoint_args, **checkpoint_kwargs):
checkpoint_kwargs.setdefault("use_reentrant", self.gradient_checkpointing_use_reentrant)
return original_checkpoint(function, *checkpoint_args, **checkpoint_kwargs)
torch.utils.checkpoint.checkpoint = checkpoint
try:
return original_forward(*args, **kwargs)
finally:
torch.utils.checkpoint.checkpoint = original_checkpoint
encoder.gradient_checkpointing_use_reentrant = use_reentrant
encoder.forward = types.MethodType(forward_with_checkpoint_kwargs, encoder)
encoder._evo1_checkpoint_patch_applied = True
def flash_attn_is_available() -> bool:
try:
import flash_attn # noqa: F401
except ModuleNotFoundError:
return False
return True
@contextmanager
def _internvl_transformers5_load_compatibility():
from transformers.modeling_utils import PreTrainedModel
original_linspace = torch.linspace
original_mark_tied = PreTrainedModel.mark_tied_weights_as_initialized
def linspace(*args, **kwargs):
if kwargs.get("device") is None:
kwargs["device"] = torch.device("cpu")
return original_linspace(*args, **kwargs)
def mark_tied_weights_as_initialized(self, loading_info):
if not hasattr(self, "all_tied_weights_keys"):
self.all_tied_weights_keys = {}
return original_mark_tied(self, loading_info)
torch.linspace = linspace
PreTrainedModel.mark_tied_weights_as_initialized = mark_tied_weights_as_initialized
try:
yield
finally:
torch.linspace = original_linspace
PreTrainedModel.mark_tied_weights_as_initialized = original_mark_tied
@functools.lru_cache(maxsize=10000)
def get_target_aspect_ratio(orig_width: int, orig_height: int, image_size: int, min_num: int, max_num: int):
aspect_ratio = orig_width / orig_height
@@ -152,9 +90,11 @@ def dynamic_preprocess(image, min_num=1, max_num=1, image_size=448, use_thumbnai
class InternVL3Embedder(nn.Module):
"""Vision-language embedder using the native HF InternVL3 model (no trust_remote_code)."""
def __init__(
self,
model_name="OpenGVLab/InternVL3-1B",
model_name="OpenGVLab/InternVL3-1B-hf",
image_size=448,
device="cuda",
num_language_layers: int | None = 14,
@@ -173,43 +113,31 @@ class InternVL3Embedder(nn.Module):
require_package("transformers", extra="evo1")
self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
if isinstance(model_dtype, str):
try:
model_dtype = getattr(torch, model_dtype)
except AttributeError as exc:
raise ValueError(f"Unsupported EVO1 vlm_dtype '{model_dtype}'") from exc
resolved_use_flash_attn = bool(use_flash_attn and flash_attn_is_available())
if use_flash_attn and not resolved_use_flash_attn:
logger.warning("flash_attn is not installed. Falling back to standard attention.")
attn_implementation = "flash_attention_2" if (use_flash_attn and _flash_attn_available()) else "eager"
if use_flash_attn and attn_implementation == "eager":
logger.warning("flash_attn is not installed. Falling back to eager attention.")
# InternVL3 remote code predates Transformers 5 post-init conventions:
# it computes stochastic-depth scalars via torch.linspace(...).item()
# while Transformers initializes under torch.device("meta"), and it
# does not populate all_tied_weights_keys before loading finalization.
with _internvl_transformers5_load_compatibility():
self.model = AutoModel.from_pretrained(
model_name,
torch_dtype=model_dtype,
trust_remote_code=True,
use_flash_attn=resolved_use_flash_attn,
low_cpu_mem_usage=True,
_fast_init=False,
).to(self._requested_device)
self.model = AutoModel.from_pretrained(
model_name,
torch_dtype=model_dtype,
attn_implementation=attn_implementation,
low_cpu_mem_usage=True,
).to(self._requested_device)
if hasattr(self.model.language_model, "model"):
layers = self.model.language_model.model.layers
else:
layers = self.model.language_model.layers
self.num_image_token = self.model.config.image_seq_length
# Truncate language model to the requested number of layers
layers = self.model.language_model.layers
if self.num_language_layers is not None:
layers = layers[: self.num_language_layers]
if hasattr(self.model.language_model, "model"):
self.model.language_model.model.layers = torch.nn.ModuleList(layers)
else:
self.model.language_model.layers = torch.nn.ModuleList(layers)
self.model.language_model.lm_head = torch.nn.Identity()
self.model.language_model.layers = torch.nn.ModuleList(layers)
self._configure_memory_features()
self.img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
@@ -218,20 +146,12 @@ class InternVL3Embedder(nn.Module):
checkpoint_kwargs = {"use_reentrant": self.gradient_checkpointing_use_reentrant}
if not self.enable_gradient_checkpointing:
if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"):
self.model.vision_model.encoder.gradient_checkpointing = False
language_model = getattr(self.model, "language_model", None)
if language_model is not None:
if hasattr(language_model, "gradient_checkpointing_disable"):
language_model.gradient_checkpointing_disable()
elif hasattr(language_model, "gradient_checkpointing"):
language_model.gradient_checkpointing = False
if hasattr(language_model, "model"):
inner = language_model.model
if hasattr(inner, "gradient_checkpointing_disable"):
inner.gradient_checkpointing_disable()
elif hasattr(inner, "gradient_checkpointing"):
inner.gradient_checkpointing = False
language_model = self.model.language_model
if hasattr(language_model, "gradient_checkpointing_disable"):
language_model.gradient_checkpointing_disable()
vision_tower = getattr(self.model, "vision_tower", None)
if vision_tower is not None and hasattr(vision_tower, "encoder"):
vision_tower.encoder.gradient_checkpointing = False
return
def _enable_ckpt(module: nn.Module | None) -> bool:
@@ -250,21 +170,14 @@ class InternVL3Embedder(nn.Module):
enabled_any = _enable_ckpt(self.model)
if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"):
encoder = self.model.vision_model.encoder
encoder.gradient_checkpointing = True
_patch_vision_encoder_checkpointing(
encoder, use_reentrant=self.gradient_checkpointing_use_reentrant
)
enabled_any = True
vision_tower = getattr(self.model, "vision_tower", None)
if vision_tower is not None:
enabled_any = _enable_ckpt(vision_tower) or enabled_any
language_model = getattr(self.model, "language_model", None)
if language_model is not None:
enabled_any = _enable_ckpt(language_model) or enabled_any
if hasattr(language_model, "model"):
enabled_any = _enable_ckpt(language_model.model) or enabled_any
if hasattr(language_model, "config"):
language_model.config.use_cache = False
language_model = self.model.language_model
enabled_any = _enable_ckpt(language_model) or enabled_any
if hasattr(language_model, "config"):
language_model.config.use_cache = False
if hasattr(self.model, "config"):
self.model.config.use_cache = False
@@ -284,7 +197,7 @@ class InternVL3Embedder(nn.Module):
else:
pil_image = image.convert("RGB")
tiles = dynamic_preprocess(pil_image, image_size=self.image_size)
tile_tensors = torch.stack([TF.to_tensor(tile) for tile in tiles]).to(
tile_tensors = torch.stack([tvf.to_tensor(tile) for tile in tiles]).to(
device=self.device, dtype=torch.bfloat16
)
mean = torch.tensor(IMAGENET_MEAN, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1)
@@ -323,76 +236,12 @@ class InternVL3Embedder(nn.Module):
for num_tiles_list, text_prompt in zip(batch_num_tiles_list, text_prompts, strict=True):
prompt_segments = []
for i, tile_count in enumerate(num_tiles_list):
token_count = self.model.num_image_token * tile_count
token_count = self.num_image_token * tile_count
image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * token_count + IMG_END_TOKEN
prompt_segments.append(f"Image-{i + 1}: {image_tokens}\n")
prompts.append("".join(prompt_segments) + text_prompt.strip())
return prompts
def _prepare_and_fuse_embeddings(
self,
prompts: Sequence[str],
vit_embeds: torch.Tensor,
image_masks: torch.Tensor,
batch_num_tiles_list: list[list[int]],
) -> tuple[torch.Tensor, torch.Tensor]:
untruncated_ids = self.tokenizer(list(prompts), padding=False, truncation=False)["input_ids"]
true_sequence_length = max((len(ids) for ids in untruncated_ids), default=0)
if true_sequence_length > self.max_text_length:
logger.warning(
"InternVL3 prompt truncated in batch: max_length=%s actual_max_length=%s",
self.max_text_length,
true_sequence_length,
)
model_inputs = self.tokenizer(
list(prompts),
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.max_text_length,
).to(self.device)
input_ids = model_inputs["input_ids"]
attention_mask = model_inputs["attention_mask"]
img_token_mask = input_ids == self.img_context_token_id
input_embeds = self.model.language_model.get_input_embeddings()(input_ids).clone()
batch_size, _, channels = input_embeds.shape
vit_embeds = vit_embeds.reshape(-1, channels).to(dtype=input_embeds.dtype, device=input_embeds.device)
tokens_per_tile = self.model.num_image_token
actual_vis_tokens_list = img_token_mask.sum(dim=1).tolist()
vit_idx = 0
for batch_index in range(batch_size):
expected_vis_tokens = sum(batch_num_tiles_list[batch_index]) * tokens_per_tile
mask_b = img_token_mask[batch_index]
actual_vis_tokens = actual_vis_tokens_list[batch_index]
item_vit_embeds = vit_embeds[vit_idx : vit_idx + expected_vis_tokens]
vit_idx += expected_vis_tokens
if actual_vis_tokens > 0:
if item_vit_embeds.shape[0] < actual_vis_tokens:
raise ValueError(
f"InternVL3 produced fewer image tokens than expected for sample {batch_index}: "
f"got {item_vit_embeds.shape[0]}, need {actual_vis_tokens}"
)
input_embeds[batch_index, mask_b] = item_vit_embeds[:actual_vis_tokens]
current_token_idx = 0
img_token_locations = torch.where(mask_b)[0]
for image_index, num_tiles in enumerate(batch_num_tiles_list[batch_index]):
num_tokens_for_image = num_tiles * tokens_per_tile
if not bool(image_masks[batch_index, image_index].item()):
start_offset = current_token_idx
end_offset = min(current_token_idx + num_tokens_for_image, len(img_token_locations))
if start_offset < end_offset:
idxs = img_token_locations[start_offset:end_offset]
attention_mask[batch_index, idxs] = 0
current_token_idx += num_tokens_for_image
return input_embeds, attention_mask
def get_fused_image_text_embedding_from_tensor_images(
self,
image_tensors_batch: Sequence[Sequence[Image.Image | torch.Tensor]],
@@ -404,27 +253,46 @@ class InternVL3Embedder(nn.Module):
if pixel_values.shape[0] == 0:
logger.warning("InternVL3 received an empty image batch after preprocessing.")
hidden_size = getattr(self.model.config, "hidden_size", None)
if hidden_size is None and hasattr(self.model.language_model, "config"):
hidden_size = getattr(self.model.language_model.config, "hidden_size", None)
if hidden_size is None:
hidden_size = getattr(self.model.config.text_config, "hidden_size", None)
if hidden_size is None:
raise RuntimeError("Unable to infer hidden size for empty InternVL3 batch.")
empty = torch.empty(0, hidden_size, device=self.device, dtype=torch.float32)
return empty
prompts = self._build_multimodal_prompts(batch_num_tiles_list, text_prompts)
vit_embeds = self.model.extract_feature(pixel_values)
inputs_embeds, attention_mask = self._prepare_and_fuse_embeddings(
prompts,
vit_embeds,
image_masks.to(device=self.device),
batch_num_tiles_list,
)
outputs = self.model.language_model(
inputs_embeds=inputs_embeds,
model_inputs = self.tokenizer(
list(prompts),
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.max_text_length,
).to(self.device)
input_ids = model_inputs["input_ids"]
attention_mask = model_inputs["attention_mask"]
# Zero out attention for absent images
img_token_mask = input_ids == self.img_context_token_id
tokens_per_tile = self.num_image_token
for batch_index in range(input_ids.shape[0]):
current_token_idx = 0
img_token_locations = torch.where(img_token_mask[batch_index])[0]
for image_index, num_tiles in enumerate(batch_num_tiles_list[batch_index]):
num_tokens_for_image = num_tiles * tokens_per_tile
if not bool(image_masks[batch_index, image_index].item()):
start_offset = current_token_idx
end_offset = min(current_token_idx + num_tokens_for_image, len(img_token_locations))
if start_offset < end_offset:
idxs = img_token_locations[start_offset:end_offset]
attention_mask[batch_index, idxs] = 0
current_token_idx += num_tokens_for_image
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
return_dict=True,
)
fused_hidden = outputs.hidden_states[-1].to(torch.float32)
@@ -433,3 +301,11 @@ class InternVL3Embedder(nn.Module):
@property
def device(self) -> torch.device:
return next(self.model.parameters()).device
def _flash_attn_available() -> bool:
try:
import flash_attn # noqa: F401
except ModuleNotFoundError:
return False
return True
+25 -20
View File
@@ -45,6 +45,7 @@ class EVO1Policy(PreTrainedPolicy):
self.config = config
self.model = EVO1(self._build_model_config(config))
self.model.set_finetune_flags()
self._keep_frozen_embedder_eval()
self.reset()
@classmethod
@@ -64,7 +65,7 @@ class EVO1Policy(PreTrainedPolicy):
**kwargs,
) -> T:
if strict is None:
strict = not (config is not None and getattr(config, "training_stage", None) == "stage2")
strict = True
return super().from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
config=config,
@@ -85,6 +86,7 @@ class EVO1Policy(PreTrainedPolicy):
"device": config.device,
"return_cls_only": config.return_cls_only,
"vlm_name": config.vlm_model_name,
"image_size": int(config.image_resolution[0]),
"vlm_num_layers": config.vlm_num_layers,
"vlm_dtype": config.vlm_dtype,
"use_flash_attn": config.use_flash_attn,
@@ -100,7 +102,8 @@ class EVO1Policy(PreTrainedPolicy):
"dropout": config.dropout,
"num_inference_timesteps": config.num_inference_timesteps,
"num_categories": config.num_categories,
"enable_gradient_checkpointing": config.enable_gradient_checkpointing,
"enable_gradient_checkpointing": config.enable_gradient_checkpointing
and bool(config.finetune_vlm or config.finetune_language_model or config.finetune_vision_model),
"gradient_checkpointing_use_reentrant": config.gradient_checkpointing_use_reentrant,
"finetune_vlm": config.finetune_vlm,
"finetune_language_model": config.finetune_language_model,
@@ -303,6 +306,18 @@ class EVO1Policy(PreTrainedPolicy):
or self.config.finetune_vision_model
)
def _keep_frozen_embedder_eval(self) -> None:
if self._tracks_vlm_gradients:
return
embedder = getattr(self.model, "embedder", None)
if embedder is not None:
embedder.eval()
def train(self, mode: bool = True):
    """Switch train/eval mode, then re-apply eval mode to a frozen embedder.

    Args:
        mode: Forwarded unchanged to the parent ``train`` implementation.

    Returns:
        ``self``, so calls can be chained.
    """
    super().train(mode)
    # The parent call has just set the mode on the whole policy; re-run the
    # helper so an embedder that is not tracked for gradients ends up in eval mode.
    self._keep_frozen_embedder_eval()
    return self
def _collect_image_batches(self, batch: dict[str, Tensor]) -> tuple[list[list[Tensor]], Tensor]:
camera_keys = self._camera_keys or sorted(key for key in batch if key.startswith(f"{OBS_IMAGES}."))
if not camera_keys:
@@ -348,23 +363,13 @@ class EVO1Policy(PreTrainedPolicy):
) -> Tensor:
track_vlm_gradients = self._tracks_vlm_gradients
grad_context = nullcontext() if track_vlm_gradients else torch.no_grad()
embedder = getattr(self.model, "embedder", None)
embedder_was_training = embedder.training if embedder is not None else None
if not track_vlm_gradients and embedder is not None:
embedder.eval()
try:
with grad_context:
fused_tokens = self.model.get_vl_embeddings(
images=image_batches,
image_mask=image_masks,
prompt=prompts,
return_cls_only=self.config.return_cls_only,
)
finally:
if not track_vlm_gradients and embedder is not None and embedder_was_training is not None:
embedder.train(embedder_was_training)
with grad_context:
fused_tokens = self.model.get_vl_embeddings(
images=image_batches,
image_mask=image_masks,
prompt=prompts,
return_cls_only=self.config.return_cls_only,
)
if not track_vlm_gradients:
fused_tokens = fused_tokens.detach()
@@ -439,7 +444,7 @@ class EVO1Policy(PreTrainedPolicy):
embodiment_ids=embodiment_ids,
)
actions = actions.view(states.shape[0], self.config.chunk_size, self.config.max_action_dim)
return actions[:, :, : self._env_action_dim]
return actions
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
+326 -4
View File
@@ -14,17 +14,24 @@
from __future__ import annotations
from copy import deepcopy
from dataclasses import dataclass
from typing import Any
import torch
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
ObservationProcessorStep,
PolicyAction,
PolicyActionProcessorStep,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
)
@@ -34,11 +41,13 @@ from lerobot.processor.converters import (
policy_action_to_transition,
transition_to_policy_action,
)
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import (
ACTION,
DONE,
INFO,
OBS_PREFIX,
OBS_STATE,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
REWARD,
@@ -65,6 +74,305 @@ def evo1_batch_to_transition(batch: dict[str, Any]):
)
@dataclass
@ProcessorStepRegistry.register(name="evo1_pad_state_processor")
class Evo1PadStateProcessorStep(ObservationProcessorStep):
    """Zero-pad the ``OBS_STATE`` observation on its last axis up to ``max_state_dim``."""

    # Target width of the state's last dimension; wider states are rejected.
    max_state_dim: int = 24

    def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
        """Return the observation with ``OBS_STATE`` padded to ``max_state_dim``.

        The input dict is returned as-is when it has no ``OBS_STATE`` entry or
        when the state already has the target width; otherwise a shallow copy
        holding the padded state is returned.

        Raises:
            ValueError: If the state's last dimension exceeds ``max_state_dim``.
        """
        if OBS_STATE not in observation:
            return observation

        current_state = observation[OBS_STATE]
        state_dim = current_state.shape[-1]
        missing = self.max_state_dim - state_dim
        if missing < 0:
            raise ValueError(
                f"EVO1 state has {state_dim} dims, which exceeds max_state_dim={self.max_state_dim}."
            )
        if missing == 0:
            return observation

        # Shallow copy so the caller's dict keeps its unpadded state.
        padded_observation = observation.copy()
        padded_observation[OBS_STATE] = torch.nn.functional.pad(current_state, (0, missing))
        return padded_observation

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Return a copy of ``features`` whose ``OBS_STATE`` entry, if any, has the padded shape."""
        copied: dict = {}
        for feature_type, feats in features.items():
            copied[feature_type] = feats.copy()

        # NOTE(review): ``features`` is annotated as keyed by PipelineFeatureType while this
        # lookup uses FeatureType.STATE (and inserts an empty group when absent) — confirm
        # the two key spaces line up.
        state_group = copied.setdefault(FeatureType.STATE, {})
        if OBS_STATE in state_group:
            state_group[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(self.max_state_dim,))
        return copied

    def get_config(self) -> dict[str, Any]:
        """Return the step's configuration as a plain dict."""
        return dict(max_state_dim=self.max_state_dim)
@dataclass
@ProcessorStepRegistry.register(name="evo1_pad_action_processor")
class Evo1PadActionProcessorStep(ProcessorStep):
    """Zero-pad the action to ``max_action_dim`` and record the real dimensions in ``action_mask``."""

    # Target width of the action's last dimension; wider actions are rejected.
    max_action_dim: int = 24

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a copy of ``transition`` with a padded action and a matching ``action_mask``.

        A transition without an action is returned unchanged. Otherwise the
        action and its boolean mask are right-padded on the last axis (zeros for
        the action, ``False`` for the mask). The mask is taken from the
        complementary data when present, else it is all ``True`` over the
        original action shape.

        Raises:
            ValueError: If the action is not a ``PolicyAction``, is wider than
                ``max_action_dim``, or a provided mask's shape differs from the
                action's shape.
        """
        action = transition.get(TransitionKey.ACTION)
        if action is None:
            return transition
        if not isinstance(action, PolicyAction):
            raise ValueError(f"EVO1 action should be a PolicyAction tensor, but got {type(action)}.")

        action_dim = action.shape[-1]
        if action_dim > self.max_action_dim:
            raise ValueError(
                f"EVO1 action has {action_dim} dims, which exceeds max_action_dim={self.max_action_dim}."
            )
        pad_width = self.max_action_dim - action_dim

        def _widen(tensor):
            # Leave tensors that are already at full width untouched.
            if pad_width == 0:
                return tensor
            return torch.nn.functional.pad(tensor, (0, pad_width))

        result = transition.copy()
        extras = dict(result.get(TransitionKey.COMPLEMENTARY_DATA) or {})

        action_mask = extras.get("action_mask")
        if action_mask is not None:
            action_mask = torch.as_tensor(action_mask, dtype=torch.bool, device=action.device)
        else:
            action_mask = torch.ones(action.shape, dtype=torch.bool, device=action.device)
        # The mask is validated against the unpadded action, then widened alongside it.
        if action_mask.shape != action.shape:
            raise ValueError(
                f"action_mask shape {tuple(action_mask.shape)} does not match action shape {tuple(action.shape)}."
            )

        extras["action_mask"] = _widen(action_mask)
        result[TransitionKey.ACTION] = _widen(action)
        result[TransitionKey.COMPLEMENTARY_DATA] = extras
        return result

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Return a copy of ``features`` whose ``ACTION`` entry has the padded shape."""
        copied: dict = {}
        for feature_type, feats in features.items():
            copied[feature_type] = feats.copy()
        action_group = copied.setdefault(FeatureType.ACTION, {})
        action_group[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(self.max_action_dim,))
        return copied

    def get_config(self) -> dict[str, Any]:
        """Return the step's configuration as a plain dict."""
        return dict(max_action_dim=self.max_action_dim)
@dataclass
@ProcessorStepRegistry.register(name="evo1_action_processor")
class Evo1ActionProcessorStep(PolicyActionProcessorStep):
    """Crop padded EVO1 actions and optionally binarize the LIBERO gripper channel.

    The action is first cut down to its leading ``action_dim`` entries on the
    last axis. When ``binarize_gripper`` is set, the channel at
    ``gripper_index`` is then replaced by ``gripper_above_threshold_value``
    where it is strictly greater than ``gripper_threshold`` and by
    ``gripper_below_threshold_value`` everywhere else.
    """

    action_dim: int
    binarize_gripper: bool = False
    gripper_index: int = 6
    gripper_threshold: float = 0.5
    gripper_below_threshold_value: float = 1.0
    gripper_above_threshold_value: float = -1.0

    def action(self, action: PolicyAction) -> PolicyAction:
        """Crop ``action`` to ``action_dim`` and binarize the gripper if requested.

        Raises:
            ValueError: If the action has fewer than ``action_dim`` entries on
                its last axis, or if binarization is enabled and
                ``gripper_index`` is outside ``[0, action_dim)``.
        """
        if action.shape[-1] < self.action_dim:
            raise ValueError(
                f"EVO1 action has {action.shape[-1]} dims, which is smaller than action_dim={self.action_dim}."
            )

        action = action[..., : self.action_dim]
        if self.binarize_gripper:
            if self.gripper_index < 0 or self.gripper_index >= self.action_dim:
                raise ValueError(
                    f"gripper_index={self.gripper_index} must be within action_dim={self.action_dim}."
                )

            # Clone first: the crop above is a view of the caller's tensor.
            action = action.clone()
            below, above = (
                torch.as_tensor(level, dtype=action.dtype, device=action.device)
                for level in (self.gripper_below_threshold_value, self.gripper_above_threshold_value)
            )
            gripper = action[..., self.gripper_index]
            action[..., self.gripper_index] = torch.where(gripper > self.gripper_threshold, above, below)
        return action

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Return a copy of ``features`` whose ``ACTION`` entry has shape ``(action_dim,)``."""
        updated = {}
        for group, feats in features.items():
            updated[group] = feats.copy()
        updated.setdefault(FeatureType.ACTION, {})[ACTION] = PolicyFeature(
            type=FeatureType.ACTION, shape=(self.action_dim,)
        )
        return updated

    def get_config(self) -> dict[str, Any]:
        """Return the constructor arguments needed to rebuild this step."""
        field_names = (
            "action_dim",
            "binarize_gripper",
            "gripper_index",
            "gripper_threshold",
            "gripper_below_threshold_value",
            "gripper_above_threshold_value",
        )
        return {name: getattr(self, name) for name in field_names}
def _evo1_action_dim(config: Evo1Config) -> int:
    """Pick the action width returned after EVO1 postprocessing.

    Precedence: an explicit ``config.postprocess_action_dim``, then the first
    dimension of ``config.action_feature``, then ``config.max_action_dim``
    when no action feature is configured.
    """
    explicit_dim = config.postprocess_action_dim
    if explicit_dim is not None:
        return explicit_dim

    feature = config.action_feature
    return config.max_action_dim if feature is None else int(feature.shape[0])
def _evo1_normalization_features(config: Evo1Config) -> dict[str, PolicyFeature]:
    """Merge input and output features, overriding state/action with padded shapes.

    Output features win over input features on key collisions; ``OBS_STATE``
    and ``ACTION`` are then set to ``(max_state_dim,)`` and
    ``(max_action_dim,)`` respectively.
    """
    padded_shapes = {
        OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(config.max_state_dim,)),
        ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(config.max_action_dim,)),
    }
    return {**config.input_features, **config.output_features, **padded_shapes}
def _evo1_action_features(config: Evo1Config) -> dict[str, PolicyFeature]:
    """Return a feature map holding only ``ACTION`` at the padded ``max_action_dim`` width."""
    padded_action = PolicyFeature(type=FeatureType.ACTION, shape=(config.max_action_dim,))
    return {ACTION: padded_action}
# Fill value used for the dimensions appended to each statistic. Statistics
# that are not listed here are padded with 0.0 by `_pad_stat_value`.
_STAT_PAD_VALUES = {
    # location / scale
    "mean": 0.0,
    "std": 1.0,
    # range
    "min": -1.0,
    "max": 1.0,
    # quantiles
    "q01": -1.0,
    "q99": 1.0,
    "q10": -1.0,
    "q90": 1.0,
}


def _pad_stat_value(value: Any, target_dim: int, stat_name: str) -> torch.Tensor:
    """Return ``value`` as a float tensor whose last axis is at least ``target_dim`` wide.

    Args:
        value: Anything ``torch.as_tensor`` accepts.
        target_dim: Desired size of the last axis.
        stat_name: Statistic name used to look up the fill value.

    Returns:
        The converted tensor, unchanged in shape when it is 0-dimensional or
        already ``target_dim`` wide or wider; otherwise the tensor with
        constant-valued entries appended along the last axis. Non-floating
        inputs are converted to ``float32``; floating dtypes are kept.
    """
    stat = torch.as_tensor(value)
    if not stat.is_floating_point():
        stat = stat.to(dtype=torch.float32)

    # Scalars have no last axis to extend, so they are returned as they are.
    if stat.ndim == 0:
        return stat
    missing = target_dim - stat.shape[-1]
    if missing <= 0:
        return stat

    fill = _STAT_PAD_VALUES.get(stat_name, 0.0)
    tail = stat.new_full((*stat.shape[:-1], missing), fill)
    return torch.cat((stat, tail), dim=-1)
def _pad_feature_stats(
    stats: dict[str, dict[str, Any]],
    feature_key: str,
    target_dim: int,
) -> None:
    """Pad every statistic of ``feature_key`` to ``target_dim``, in place.

    ``stats`` is modified: the entry for ``feature_key`` is replaced by a new
    dict of padded tensors. Nothing happens when the feature has no stats.
    """
    if feature_key in stats:
        padded = {}
        for stat_name, stat_value in stats[feature_key].items():
            padded[stat_name] = _pad_stat_value(stat_value, target_dim, stat_name)
        stats[feature_key] = padded
def _pad_evo1_stats(
    config: Evo1Config,
    stats: dict[str, dict[str, Any]] | None,
) -> dict[str, dict[str, Any]] | None:
    """Return a deep copy of ``stats`` with state/action stats padded to EVO1 widths.

    ``None`` is passed through. The input mapping is never modified.
    """
    if stats is None:
        return None

    padded_stats = deepcopy(stats)
    # Added dimensions represent zero-padding inside EVO1. These neutral stats keep
    # padded observations at normalized zero and only provide shape compatibility.
    targets = ((OBS_STATE, config.max_state_dim), (ACTION, config.max_action_dim))
    for feature_key, target_dim in targets:
        _pad_feature_stats(padded_stats, feature_key, target_dim)
    return padded_stats
def _refresh_evo1_normalization_steps(
    config: Evo1Config,
    preprocessor: PolicyProcessorPipeline,
    postprocessor: PolicyProcessorPipeline,
) -> None:
    """Point existing (un)normalizer steps at EVO1's padded features and stats.

    Normalizer steps in ``preprocessor`` receive the padded state/action
    feature map; unnormalizer steps in ``postprocessor`` receive the padded
    action-only map. Each touched step gets padded stats and is then passed
    through ``step.to`` with its own device and dtype.
    """
    refresh_plan = (
        (preprocessor, NormalizerProcessorStep, _evo1_normalization_features(config)),
        (postprocessor, UnnormalizerProcessorStep, _evo1_action_features(config)),
    )
    for pipeline, step_type, padded_features in refresh_plan:
        for step in pipeline.steps:
            if not isinstance(step, step_type):
                continue
            step.features = padded_features
            step.stats = _pad_evo1_stats(config, step.stats)
            # presumably rebuilds the step's cached tensor stats from the new
            # `stats` — confirm against the normalizer implementation
            step.to(device=step.device, dtype=step.dtype)
def ensure_evo1_processor_steps(
    config: Evo1Config,
    preprocessor: PolicyProcessorPipeline,
    postprocessor: PolicyProcessorPipeline,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
    """Add EVO1 processor steps when loading older checkpoints that do not serialize them.

    Missing state/action padding steps are inserted directly before the first
    normalizer of ``preprocessor`` (or appended when there is none). A missing
    action-processing step is inserted directly after the first unnormalizer
    of ``postprocessor`` (or at the front when there is none). Steps that are
    already present are not duplicated. Both pipelines are modified in place
    and returned.
    """
    missing_padding = (
        (Evo1PadStateProcessorStep, {"max_state_dim": config.max_state_dim}),
        (Evo1PadActionProcessorStep, {"max_action_dim": config.max_action_dim}),
    )
    for pad_type, pad_kwargs in missing_padding:
        if any(isinstance(step, pad_type) for step in preprocessor.steps):
            continue
        pre_steps = list(preprocessor.steps)
        # Default to the end of the pipeline when no normalizer is present.
        position = len(pre_steps)
        for idx, step in enumerate(pre_steps):
            if isinstance(step, NormalizerProcessorStep):
                position = idx
                break
        pre_steps.insert(position, pad_type(**pad_kwargs))
        preprocessor.steps = pre_steps

    if not any(isinstance(step, Evo1ActionProcessorStep) for step in postprocessor.steps):
        post_steps = list(postprocessor.steps)
        # Default to the front of the pipeline when no unnormalizer is present.
        position = 0
        for idx, step in enumerate(post_steps):
            if isinstance(step, UnnormalizerProcessorStep):
                position = idx + 1
                break
        action_step = Evo1ActionProcessorStep(
            action_dim=_evo1_action_dim(config),
            binarize_gripper=config.binarize_gripper,
            gripper_index=config.gripper_index,
            gripper_threshold=config.gripper_threshold,
            gripper_below_threshold_value=config.gripper_below_threshold_value,
            gripper_above_threshold_value=config.gripper_above_threshold_value,
        )
        post_steps.insert(position, action_step)
        postprocessor.steps = post_steps

    _refresh_evo1_normalization_steps(config, preprocessor, postprocessor)
    return preprocessor, postprocessor
def make_evo1_pre_post_processors(
config: Evo1Config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
@@ -72,21 +380,35 @@ def make_evo1_pre_post_processors(
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
normalization_features = _evo1_normalization_features(config)
action_features = _evo1_action_features(config)
normalization_stats = _pad_evo1_stats(config, dataset_stats)
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
Evo1PadStateProcessorStep(max_state_dim=config.max_state_dim),
Evo1PadActionProcessorStep(max_action_dim=config.max_action_dim),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
features=normalization_features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
stats=normalization_stats,
),
DeviceProcessorStep(device=config.device),
]
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features,
features=action_features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
stats=normalization_stats,
),
Evo1ActionProcessorStep(
action_dim=_evo1_action_dim(config),
binarize_gripper=config.binarize_gripper,
gripper_index=config.gripper_index,
gripper_threshold=config.gripper_threshold,
gripper_below_threshold_value=config.gripper_below_threshold_value,
gripper_above_threshold_value=config.gripper_above_threshold_value,
),
DeviceProcessorStep(device="cpu"),
]
+8
View File
@@ -311,6 +311,14 @@ def make_pre_post_processors(
to_output=transition_to_policy_action,
)
_reconnect_relative_absolute_steps(preprocessor, postprocessor)
if isinstance(policy_cfg, Evo1Config):
from .evo1.processor_evo1 import ensure_evo1_processor_steps
preprocessor, postprocessor = ensure_evo1_processor_steps(
policy_cfg,
preprocessor,
postprocessor,
)
return preprocessor, postprocessor
# Create a new processor based on policy type
+1 -2
View File
@@ -40,7 +40,7 @@ from .converters import (
)
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
from .device_processor import DeviceProcessorStep
from .env_processor import IsaaclabArenaProcessorStep, LiberoActionProcessorStep, LiberoProcessorStep
from .env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep
from .factory import (
make_default_processors,
make_default_robot_action_processor,
@@ -149,7 +149,6 @@ __all__ = [
"RewardProcessorStep",
"DataProcessorPipeline",
"IsaaclabArenaProcessorStep",
"LiberoActionProcessorStep",
"LiberoProcessorStep",
"TimeLimitProcessorStep",
"AddBatchDimensionProcessorStep",
+3 -44
View File
@@ -18,9 +18,9 @@ from dataclasses import dataclass
import torch
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR
from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR
from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@dataclass
@@ -46,8 +46,6 @@ class LiberoProcessorStep(ObservationProcessorStep):
- This accounts for the HuggingFaceVLA/libero camera orientation convention.
"""
max_state_dim: int | None = None
def _process_observation(self, observation):
"""
Processes both image and robot_state observations from LIBERO.
@@ -80,16 +78,6 @@ class LiberoProcessorStep(ObservationProcessorStep):
state = state.float()
if state.dim() == 1:
state = state.unsqueeze(0)
if self.max_state_dim is not None:
if state.shape[-1] > self.max_state_dim:
raise ValueError(
f"LIBERO state has {state.shape[-1]} dims, which is larger than "
f"configured max_state_dim={self.max_state_dim}."
)
if state.shape[-1] < self.max_state_dim:
pad_width = self.max_state_dim - state.shape[-1]
state = torch.nn.functional.pad(state, (0, pad_width))
processed_obs[OBS_STATE] = state
return processed_obs
@@ -112,7 +100,7 @@ class LiberoProcessorStep(ObservationProcessorStep):
# add our new flattened state
state_feats[OBS_STATE] = PolicyFeature(
type=FeatureType.STATE,
shape=(self.max_state_dim or 8,), # [eef_pos(3), axis_angle(3), gripper(2)] plus padding
shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)]
)
new_features[FeatureType.STATE] = state_feats
@@ -122,9 +110,6 @@ class LiberoProcessorStep(ObservationProcessorStep):
def observation(self, observation):
return self._process_observation(observation)
def get_config(self) -> dict:
return {"max_state_dim": self.max_state_dim}
def _quat2axisangle(self, quat: torch.Tensor) -> torch.Tensor:
"""
Convert batched quaternions to axis-angle format.
@@ -167,32 +152,6 @@ class LiberoProcessorStep(ObservationProcessorStep):
return result
@dataclass
@ProcessorStepRegistry.register(name="libero_action_processor")
class LiberoActionProcessorStep(ActionProcessorStep):
    """Slices padded policy actions back to the executable LIBERO action space."""

    # Number of leading entries kept on the last action axis.
    action_dim: int = 7

    def action(self, action):
        """Return the first ``action_dim`` entries of the last axis of ``action``.

        Raises:
            ValueError: If the last axis is shorter than ``action_dim``.
        """
        if action.shape[-1] >= self.action_dim:
            return action[..., : self.action_dim]
        raise ValueError(
            f"LIBERO action has {action.shape[-1]} dims, which is smaller than action_dim={self.action_dim}."
        )

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Return a copy of ``features`` whose ``ACTION`` entry has shape ``(action_dim,)``."""
        updated = {}
        for group, feats in features.items():
            updated[group] = feats.copy()
        updated.setdefault(FeatureType.ACTION, {})[ACTION] = PolicyFeature(
            type=FeatureType.ACTION, shape=(self.action_dim,)
        )
        return updated

    def get_config(self) -> dict:
        """Return the constructor arguments needed to rebuild this step."""
        return {"action_dim": self.action_dim}
@dataclass
@ProcessorStepRegistry.register(name="isaaclab_arena_processor")
class IsaaclabArenaProcessorStep(ObservationProcessorStep):
+6 -1
View File
@@ -191,7 +191,7 @@ def rollout(
action = action_transition[ACTION]
# Convert to CPU / numpy.
action_numpy: np.ndarray = action.to("cpu").numpy()
action_numpy = _action_to_env_numpy(action)
assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)"
# Apply the next action.
@@ -261,6 +261,11 @@ def rollout(
return ret
def _action_to_env_numpy(action: Tensor) -> np.ndarray:
    """Convert policy actions to a NumPy array accepted by Gym environments.

    The tensor is detached, moved to the CPU and cast to ``float32`` before
    the NumPy conversion, so reduced-precision dtypes such as ``bfloat16``
    are accepted.
    """
    detached = action.detach()
    host_float32 = detached.to(device="cpu", dtype=torch.float32)
    return host_float32.numpy()
def eval_policy(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
+10 -37
View File
@@ -13,7 +13,7 @@ from gymnasium.envs.registration import register, registry as gym_registry
from lerobot.configs.types import PolicyFeature
from lerobot.envs.configs import EnvConfig, LiberoEnv
from lerobot.envs.factory import make_env, make_env_config, make_env_pre_post_processors
from lerobot.processor import LiberoActionProcessorStep, LiberoProcessorStep
from lerobot.processor import LiberoProcessorStep
from lerobot.utils.constants import OBS_PREFIX, OBS_STATE
logger = logging.getLogger(__name__)
@@ -86,32 +86,18 @@ def test_processors_delegation_supports_legacy_override_signature():
assert isinstance(post, DataProcessorPipeline)
def test_libero_evo1_processors_use_padded_state_and_env_action_dim():
"""EVO1 uses padded LIBERO state features while env actions stay executable."""
class _Evo1Config:
type = "evo1"
max_state_dim = 24
def test_libero_processors_are_policy_agnostic():
cfg = LiberoEnv()
pre, post = make_env_pre_post_processors(cfg, policy_cfg=_Evo1Config())
pre, post = make_env_pre_post_processors(cfg, policy_cfg=object())
assert isinstance(pre.steps[0], LiberoProcessorStep)
assert pre.steps[0].max_state_dim == 24
assert isinstance(post.steps[0], LiberoActionProcessorStep)
assert post.steps[0].action_dim == cfg.features["action"].shape[0] == 7
class _OtherConfig:
type = "other"
pre_other, _ = make_env_pre_post_processors(cfg, policy_cfg=_OtherConfig())
assert pre_other.steps[0].max_state_dim is None
assert len(post.steps) == 0
def test_libero_processor_pads_state_to_max_dim():
step = LiberoProcessorStep(max_state_dim=24)
def test_libero_processor_flattens_state_to_raw_8_dim():
step = LiberoProcessorStep()
observation = {
OBS_PREFIX
+ "robot_state": {
OBS_PREFIX + "robot_state": {
"eef": {
"pos": torch.tensor([[1.0, 2.0, 3.0]]),
"quat": torch.tensor([[0.0, 0.0, 0.0, 1.0]]),
@@ -121,21 +107,8 @@ def test_libero_processor_pads_state_to_max_dim():
}
state = step.observation(observation)[OBS_STATE]
assert state.shape == (1, 24)
assert torch.allclose(state[:, :8], torch.tensor([[1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 4.0, 5.0]]))
assert torch.count_nonzero(state[:, 8:]).item() == 0
def test_libero_action_processor_slices_padded_action():
step = LiberoActionProcessorStep(action_dim=7)
action = torch.arange(2 * 3 * 24, dtype=torch.float32).reshape(2, 3, 24)
sliced = step.action(action)
assert sliced.shape == (2, 3, 7)
assert torch.equal(sliced, action[..., :7])
with pytest.raises(ValueError, match="smaller than action_dim=7"):
step.action(torch.zeros(2, 6))
assert state.shape == (1, 8)
assert torch.allclose(state, torch.tensor([[1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 4.0, 5.0]]))
def test_base_create_envs():
+155 -4
View File
@@ -16,6 +16,7 @@
from __future__ import annotations
import pytest
import torch
from torch import nn
@@ -23,7 +24,15 @@ import lerobot.policies.evo1.modeling_evo1 as modeling_evo1
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
from lerobot.policies.evo1.flow_matching import FlowmatchingActionHead
from lerobot.policies.evo1.processor_evo1 import (
Evo1ActionProcessorStep,
Evo1PadActionProcessorStep,
Evo1PadStateProcessorStep,
ensure_evo1_processor_steps,
make_evo1_pre_post_processors,
)
from lerobot.policies.factory import get_policy_class, make_policy_config
from lerobot.processor import NormalizerProcessorStep, PolicyProcessorPipeline, UnnormalizerProcessorStep
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
STATE_DIM = 4
@@ -108,6 +117,19 @@ def make_batch(include_action=True):
return batch
def make_stats(state_dim=STATE_DIM, action_dim=ACTION_DIM):
    """Build min/max dataset stats: state bounded by +/-2.0, action by +/-1.0."""

    def _symmetric_bounds(dim, limit):
        return {
            "min": torch.full((dim,), -limit),
            "max": torch.full((dim,), limit),
        }

    return {
        OBS_STATE: _symmetric_bounds(state_dim, 2.0),
        ACTION: _symmetric_bounds(action_dim, 1.0),
    }
def test_evo1_factory_registration():
cfg = make_policy_config(
"evo1",
@@ -191,22 +213,151 @@ def test_evo1_stage_defaults_and_consistency():
raise AssertionError("Expected inconsistent finetune config to raise ValueError")
def test_evo1_rejects_non_square_image_resolution():
    """A non-square ``image_resolution`` must be rejected with a ValueError."""
    non_square = (448, 320)
    with pytest.raises(ValueError, match="square image_resolution"):
        make_config(image_resolution=non_square)
def test_evo1_build_model_config_uses_image_resolution_and_trainable_checkpointing():
    """The model config mirrors the image size; checkpointing is off in stage1 and on in stage2."""
    resolution = (224, 224)

    def build_model_config(stage):
        policy_config = make_config(training_stage=stage, image_resolution=resolution)
        return modeling_evo1.EVO1Policy._build_model_config(policy_config)

    stage1_model_config = build_model_config("stage1")
    assert stage1_model_config["image_size"] == 224
    assert stage1_model_config["enable_gradient_checkpointing"] is False

    stage2_model_config = build_model_config("stage2")
    assert stage2_model_config["enable_gradient_checkpointing"] is True
def test_evo1_policy_processors_pad_state_crop_action_and_binarize_gripper():
    """Freshly built EVO1 processors pad inputs, mask padding, crop outputs and binarize the gripper."""
    libero_action_dim = 7
    padded_action_dim = 8
    config = make_config(
        max_state_dim=MAX_STATE_DIM,
        max_action_dim=padded_action_dim,
        postprocess_action_dim=libero_action_dim,
        binarize_gripper=True,
        output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(libero_action_dim,))},
    )
    preprocessor, postprocessor = make_evo1_pre_post_processors(
        config, dataset_stats=make_stats(action_dim=libero_action_dim)
    )

    # Pipeline layout: padding steps sit right before the normalizer, the
    # action step right after the unnormalizer.
    expected_pre = ((2, Evo1PadStateProcessorStep), (3, Evo1PadActionProcessorStep), (4, NormalizerProcessorStep))
    for idx, step_type in expected_pre:
        assert isinstance(preprocessor.steps[idx], step_type)
    expected_post = ((0, UnnormalizerProcessorStep), (1, Evo1ActionProcessorStep))
    for idx, step_type in expected_post:
        assert isinstance(postprocessor.steps[idx], step_type)

    # Normalizer features and stats are widened to the padded dimensions.
    normalizer = preprocessor.steps[4]
    assert normalizer.features[OBS_STATE].shape == (MAX_STATE_DIM,)
    assert normalizer.features[ACTION].shape == (padded_action_dim,)
    assert normalizer._tensor_stats[OBS_STATE]["min"].shape == (MAX_STATE_DIM,)
    assert normalizer._tensor_stats[ACTION]["min"].shape == (padded_action_dim,)

    raw_batch = {
        "task": "pick the block",
        OBS_STATE: torch.zeros(STATE_DIM),
        ACTION: torch.zeros(libero_action_dim),
        f"{OBS_IMAGES}.front": torch.rand(3, 16, 16),
    }
    processed_batch = preprocessor(raw_batch)

    processed_state = processed_batch[OBS_STATE]
    assert processed_state.shape == (1, MAX_STATE_DIM)
    assert torch.allclose(processed_state, torch.zeros_like(processed_state))

    processed_action = processed_batch[ACTION]
    assert processed_action.shape == (1, padded_action_dim)
    assert torch.allclose(processed_action, torch.zeros_like(processed_action))

    action_mask = processed_batch["action_mask"]
    assert action_mask.shape == (1, padded_action_dim)
    assert action_mask[:, :libero_action_dim].all()
    assert not action_mask[:, libero_action_dim:].any()

    # Gripper channel (index 6): 0.5 is not above the threshold, 0.6 is.
    action = torch.tensor(
        [
            [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.5, 0.7],
            [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7],
        ],
        dtype=torch.float32,
    )
    processed = postprocessor(action)

    assert processed.shape == (2, 7)
    assert torch.allclose(processed[:, :6], action[:, :6])
    assert torch.equal(processed[:, 6], torch.tensor([1.0, -1.0]))
def test_evo1_legacy_processors_are_completed_before_normalization():
    """Legacy pipelines gain the EVO1 steps exactly once, with padded normalizer features and stats."""
    padded_action_dim = 8
    config = make_config(
        max_state_dim=MAX_STATE_DIM,
        max_action_dim=padded_action_dim,
        postprocess_action_dim=7,
        binarize_gripper=True,
    )
    stats = make_stats(action_dim=7)

    # Pipelines as an older checkpoint would provide them: no EVO1 steps.
    legacy_normalizer = NormalizerProcessorStep(
        features={**config.input_features, **config.output_features},
        norm_map=config.normalization_mapping,
        stats=stats,
    )
    legacy_unnormalizer = UnnormalizerProcessorStep(
        features=config.output_features,
        norm_map=config.normalization_mapping,
        stats=stats,
    )
    legacy_pre = PolicyProcessorPipeline(steps=[legacy_normalizer])
    legacy_post = PolicyProcessorPipeline(steps=[legacy_unnormalizer])

    preprocessor, postprocessor = ensure_evo1_processor_steps(config, legacy_pre, legacy_post)

    expected_pre = (Evo1PadStateProcessorStep, Evo1PadActionProcessorStep, NormalizerProcessorStep)
    for idx, step_type in enumerate(expected_pre):
        assert isinstance(preprocessor.steps[idx], step_type)
    expected_post = (UnnormalizerProcessorStep, Evo1ActionProcessorStep)
    for idx, step_type in enumerate(expected_post):
        assert isinstance(postprocessor.steps[idx], step_type)

    action_step = postprocessor.steps[1]
    assert action_step.action_dim == 7
    assert action_step.binarize_gripper is True

    normalizer = preprocessor.steps[2]
    assert normalizer.features[OBS_STATE].shape == (MAX_STATE_DIM,)
    assert normalizer._tensor_stats[OBS_STATE]["min"].shape == (MAX_STATE_DIM,)
    assert normalizer._tensor_stats[ACTION]["min"].shape == (padded_action_dim,)

    unnormalizer = postprocessor.steps[0]
    assert unnormalizer.features[ACTION].shape == (padded_action_dim,)
    assert unnormalizer._tensor_stats[ACTION]["min"].shape == (padded_action_dim,)

    # A second pass must not insert any step twice.
    preprocessor, postprocessor = ensure_evo1_processor_steps(config, preprocessor, postprocessor)

    def count_steps(pipeline, step_type):
        return sum(isinstance(step, step_type) for step in pipeline.steps)

    assert count_steps(preprocessor, Evo1PadStateProcessorStep) == 1
    assert count_steps(preprocessor, Evo1PadActionProcessorStep) == 1
    assert count_steps(postprocessor, Evo1ActionProcessorStep) == 1
def test_evo1_policy_forward_and_inference_use_batched_embedding(monkeypatch):
monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1)
policy = modeling_evo1.EVO1Policy(make_config())
preprocessor, _postprocessor = make_evo1_pre_post_processors(policy.config, dataset_stats=make_stats())
training_batch = preprocessor(make_batch(include_action=True))
loss, metrics = policy.forward(make_batch(include_action=True))
assert training_batch[ACTION].shape == (2, CHUNK_SIZE, MAX_ACTION_DIM)
assert training_batch["action_mask"].shape == (2, CHUNK_SIZE, MAX_ACTION_DIM)
assert training_batch["action_mask"][:, :, :ACTION_DIM].all()
assert not training_batch["action_mask"][:, :, ACTION_DIM:].any()
loss, metrics = policy.forward(training_batch)
assert loss.ndim == 0
assert torch.isfinite(loss)
assert metrics["active_action_dims"] == ACTION_DIM * CHUNK_SIZE
assert policy.model.get_vl_embeddings_calls == 1
action_chunk = policy.predict_action_chunk(make_batch(include_action=False))
assert action_chunk.shape == (2, CHUNK_SIZE, ACTION_DIM)
assert action_chunk.shape == (2, CHUNK_SIZE, MAX_ACTION_DIM)
policy.reset()
selected = policy.select_action(make_batch(include_action=False))
assert selected.shape == (2, ACTION_DIM)
assert selected.shape == (2, MAX_ACTION_DIM)
def test_stage1_frozen_vlm_embeddings_do_not_track_gradients(monkeypatch):
@@ -220,7 +371,7 @@ def test_stage1_frozen_vlm_embeddings_do_not_track_gradients(monkeypatch):
assert policy.model.grad_enabled_calls == [False]
assert policy.model.embedder_training_calls == [False]
assert not fused_tokens.requires_grad
assert policy.model.embedder.training is True
assert policy.model.embedder.training is False
def test_stage2_vlm_embeddings_track_gradients(monkeypatch):
+14
View File
@@ -0,0 +1,14 @@
import numpy as np
import torch
from lerobot.scripts.lerobot_eval import _action_to_env_numpy
def test_action_to_env_numpy_casts_bfloat16_to_float32():
    """A bfloat16 action tensor converts to a float32 NumPy array with the same values."""
    values = [[0.5, -1.0]]
    action = torch.tensor(values, dtype=torch.bfloat16)

    action_numpy = _action_to_env_numpy(action)

    expected = np.array(values, dtype=np.float32)
    assert action_numpy.shape == (1, 2)
    assert action_numpy.dtype == np.float32
    np.testing.assert_allclose(action_numpy, expected)