refactor(policies): use config for evo1 + local imports

This commit is contained in:
Steven Palma
2026-07-02 11:51:27 +02:00
parent d61941fe68
commit 2afe2864e9
9 changed files with 99 additions and 145 deletions
+1 -1
View File
@@ -235,7 +235,7 @@ fastwam = [
"lerobot[transformers-dep]",
"lerobot[diffusers-dep]",
]
evo1 = ["lerobot[transformers-dep]", "lerobot[timm-dep]"]
evo1 = ["lerobot[transformers-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
@@ -80,6 +80,9 @@ class Evo1Config(PreTrainedConfig):
vlm_model_name: str = "OpenGVLab/InternVL3-1B-hf"
vlm_num_layers: int | None = 14
vlm_dtype: str = "bfloat16"
# Max token length for tokenizing the (image placeholders + instruction) prompt. Prompts longer
# than this are right-truncated, so raise it for tasks with long language instructions or many views.
max_text_length: int = 1024
use_flash_attn: bool = True
action_head: str = "flowmatching"
embed_dim: int = 896
+42 -42
View File
@@ -14,62 +14,64 @@
from __future__ import annotations
from typing import Any
import torch
import torch.nn as nn
from .configuration_evo1 import Evo1Config
from .flow_matching import FlowmatchingActionHead
from .internvl3_embedder import InternVL3Embedder
def _cfgget(config: Any, key: str, default=None):
if isinstance(config, dict):
return config.get(key, default)
return getattr(config, key, default)
class EVO1(nn.Module):
def __init__(self, config: dict):
def __init__(self, config: Evo1Config):
super().__init__()
self.config = config
self._device = _cfgget(config, "device", "cuda")
self.return_cls_only = _cfgget(config, "return_cls_only", False)
vlm_name = _cfgget(config, "vlm_name", "OpenGVLab/InternVL3-1B")
image_size = _cfgget(config, "image_size", 448)
if image_size is None:
image_resolution = _cfgget(config, "image_resolution", (448, 448))
image_size = int(image_resolution[0])
self._device = config.device
self.return_cls_only = config.return_cls_only
# Gradient checkpointing only pays off when the VLM is actually being trained; keep it off
# whenever every VLM branch is frozen so the frozen forward stays cheap.
tracks_vlm_gradients = bool(
config.finetune_vlm or config.finetune_language_model or config.finetune_vision_model
)
enable_gradient_checkpointing = config.enable_gradient_checkpointing and tracks_vlm_gradients
self.embedder = InternVL3Embedder(
model_name=vlm_name,
image_size=image_size,
model_name=config.vlm_model_name,
image_size=int(config.image_resolution[0]),
device=self._device,
num_language_layers=_cfgget(config, "vlm_num_layers", 14),
model_dtype=_cfgget(config, "vlm_dtype", "bfloat16"),
use_flash_attn=_cfgget(config, "use_flash_attn", True),
enable_gradient_checkpointing=_cfgget(config, "enable_gradient_checkpointing", True),
gradient_checkpointing_use_reentrant=_cfgget(
config, "gradient_checkpointing_use_reentrant", False
),
num_language_layers=config.vlm_num_layers,
model_dtype=config.vlm_dtype,
use_flash_attn=config.use_flash_attn,
max_text_length=config.max_text_length,
enable_gradient_checkpointing=enable_gradient_checkpointing,
gradient_checkpointing_use_reentrant=config.gradient_checkpointing_use_reentrant,
)
action_head_type = _cfgget(config, "action_head", "flowmatching").lower()
action_head_type = config.action_head.lower()
if action_head_type != "flowmatching":
raise NotImplementedError(f"Unknown action_head: {action_head_type}")
horizon = _cfgget(config, "action_horizon", _cfgget(config, "horizon", 16))
per_action_dim = _cfgget(config, "per_action_dim", 7)
horizon = config.chunk_size
per_action_dim = config.max_action_dim
action_dim = horizon * per_action_dim
if isinstance(config, dict):
config["horizon"] = horizon
config["per_action_dim"] = per_action_dim
config["action_dim"] = action_dim
self.horizon = horizon
self.per_action_dim = per_action_dim
self.action_head = FlowmatchingActionHead(config=config).to(self._device)
self.action_head = FlowmatchingActionHead(
embed_dim=config.embed_dim,
hidden_dim=config.hidden_dim,
action_dim=action_dim,
horizon=horizon,
per_action_dim=per_action_dim,
num_heads=config.num_heads,
num_layers=config.num_layers,
dropout=config.dropout,
num_inference_timesteps=config.num_inference_timesteps,
num_categories=config.num_categories,
state_dim=config.max_state_dim,
state_hidden_dim=config.state_hidden_dim,
).to(self._device)
def get_vl_embeddings(
self,
@@ -166,15 +168,13 @@ class EVO1(nn.Module):
param.requires_grad = trainable
def set_finetune_flags(self):
finetune_vlm = _cfgget(self.config, "finetune_vlm", False)
finetune_language_model = _cfgget(self.config, "finetune_language_model", False)
finetune_vision_model = _cfgget(self.config, "finetune_vision_model", False)
finetune_vlm = bool(self.config.finetune_vlm)
finetune_language_model = bool(self.config.finetune_language_model)
finetune_vision_model = bool(self.config.finetune_vision_model)
has_explicit_branch_flags = any(
flag is not None for flag in (finetune_language_model, finetune_vision_model)
flag is not None
for flag in (self.config.finetune_language_model, self.config.finetune_vision_model)
)
finetune_language_model = bool(finetune_language_model)
finetune_vision_model = bool(finetune_vision_model)
finetune_vlm = bool(finetune_vlm)
if has_explicit_branch_flags:
self._set_module_trainable(self.embedder, False)
@@ -187,5 +187,5 @@ class EVO1(nn.Module):
elif not finetune_vlm:
self._set_module_trainable(self.embedder, False)
if not _cfgget(self.config, "finetune_action_head", False):
if not self.config.finetune_action_head:
self._set_module_trainable(self.action_head, False)
+9 -41
View File
@@ -16,7 +16,6 @@ from __future__ import annotations
import logging
import math
from types import SimpleNamespace
import torch
import torch.nn as nn
@@ -24,12 +23,6 @@ import torch.nn as nn
logger = logging.getLogger(__name__)
def _cfgget(config, key: str, default=None):
if isinstance(config, dict):
return config.get(key, default)
return getattr(config, key, default)
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, dim: int, max_len: int = 1000):
super().__init__()
@@ -171,7 +164,6 @@ class BasicTransformerBlock(nn.Module):
class FlowmatchingActionHead(nn.Module):
def __init__(
self,
config=None,
embed_dim: int = 896,
hidden_dim: int = 1024,
action_dim: int = 16 * 7,
@@ -182,40 +174,17 @@ class FlowmatchingActionHead(nn.Module):
dropout: float = 0.0,
num_inference_timesteps: int = 20,
num_categories: int = 1,
state_dim: int | None = None,
state_hidden_dim: int | None = None,
):
super().__init__()
if config is not None:
embed_dim = _cfgget(config, "embed_dim", embed_dim)
hidden_dim = _cfgget(config, "hidden_dim", hidden_dim)
action_dim = _cfgget(config, "action_dim", action_dim)
horizon = _cfgget(config, "horizon", horizon)
per_action_dim = _cfgget(config, "per_action_dim", per_action_dim)
num_heads = _cfgget(config, "num_heads", num_heads)
num_layers = _cfgget(config, "num_layers", num_layers)
dropout = _cfgget(config, "dropout", dropout)
num_inference_timesteps = _cfgget(config, "num_inference_timesteps", num_inference_timesteps)
num_categories = _cfgget(config, "num_categories", num_categories)
self.config = config
else:
self.config = SimpleNamespace(
embed_dim=embed_dim,
hidden_dim=hidden_dim,
action_dim=action_dim,
horizon=horizon,
per_action_dim=per_action_dim,
num_heads=num_heads,
num_layers=num_layers,
dropout=dropout,
num_inference_timesteps=num_inference_timesteps,
num_categories=num_categories,
)
logger.info("FlowmatchingActionHead num_inference_timesteps=%s", num_inference_timesteps)
self.embed_dim = embed_dim
self.horizon = horizon
self.per_action_dim = _cfgget(self.config, "per_action_dim", per_action_dim)
self.action_dim = _cfgget(self.config, "action_dim", action_dim)
self.per_action_dim = per_action_dim
self.action_dim = action_dim
self.num_inference_timesteps = num_inference_timesteps
self.time_pos_enc = SinusoidalPositionalEncoding(embed_dim, max_len=1000)
self.transformer_blocks = nn.ModuleList(
@@ -239,9 +208,8 @@ class FlowmatchingActionHead(nn.Module):
)
self.state_encoder = None
state_dim = _cfgget(self.config, "state_dim")
if state_dim is not None:
state_hidden = _cfgget(self.config, "state_hidden_dim", embed_dim)
state_hidden = state_hidden_dim if state_hidden_dim is not None else embed_dim
self.state_encoder = CategorySpecificMLP(
input_dim=state_dim,
hidden_dim=state_hidden,
@@ -390,8 +358,8 @@ class FlowmatchingActionHead(nn.Module):
state_emb = self.state_encoder(state, embodiment_id).unsqueeze(1)
context_tokens = torch.cat([context_tokens, state_emb], dim=1)
action_dim_total = _cfgget(self.config, "action_dim", self.action_dim)
per_action_dim = _cfgget(self.config, "per_action_dim", action_dim_total // max(self.horizon, 1))
action_dim_total = self.action_dim
per_action_dim = self.per_action_dim
action = torch.rand(batch_size, action_dim_total, device=device, dtype=context_tokens.dtype) * 2 - 1
action_seq = (
@@ -411,7 +379,7 @@ class FlowmatchingActionHead(nn.Module):
target_dtype = self.dtype
context_tokens = context_tokens.to(dtype=target_dtype)
num_steps = int(_cfgget(self.config, "num_inference_timesteps", 32))
num_steps = int(self.num_inference_timesteps)
if num_steps <= 0:
raise ValueError(f"num_inference_timesteps must be positive, got {num_steps}")
dt = 1.0 / num_steps
@@ -111,6 +111,7 @@ class InternVL3Embedder(nn.Module):
num_language_layers: int | None = 14,
model_dtype: str | torch.dtype = "bfloat16",
use_flash_attn: bool = True,
max_text_length: int = 1024,
enable_gradient_checkpointing: bool = True,
gradient_checkpointing_use_reentrant: bool = False,
):
@@ -118,7 +119,7 @@ class InternVL3Embedder(nn.Module):
self._requested_device = device
self.image_size = image_size
self.num_language_layers = num_language_layers
self.max_text_length = 1024
self.max_text_length = max_text_length
self.enable_gradient_checkpointing = bool(enable_gradient_checkpointing)
self.gradient_checkpointing_use_reentrant = bool(gradient_checkpointing_use_reentrant)
+10 -34
View File
@@ -23,11 +23,12 @@ import torch
from torch import Tensor
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
from lerobot.policies.evo1.evo1_model import EVO1
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
from .configuration_evo1 import Evo1Config
from .evo1_model import EVO1
class EVO1Policy(PreTrainedPolicy):
config_class = Evo1Config
@@ -43,7 +44,7 @@ class EVO1Policy(PreTrainedPolicy):
)
self.config = config
self.model = EVO1(self._build_model_config(config))
self.model = EVO1(config)
self.model.set_finetune_flags()
self._keep_frozen_embedder_eval()
self.reset()
@@ -80,37 +81,6 @@ class EVO1Policy(PreTrainedPolicy):
**kwargs,
)
@staticmethod
def _build_model_config(config: Evo1Config) -> dict:
return {
"device": config.device,
"return_cls_only": config.return_cls_only,
"vlm_name": config.vlm_model_name,
"image_size": int(config.image_resolution[0]),
"vlm_num_layers": config.vlm_num_layers,
"vlm_dtype": config.vlm_dtype,
"use_flash_attn": config.use_flash_attn,
"action_head": config.action_head,
"action_horizon": config.chunk_size,
"per_action_dim": config.max_action_dim,
"state_dim": config.max_state_dim,
"embed_dim": config.embed_dim,
"hidden_dim": config.hidden_dim,
"state_hidden_dim": config.state_hidden_dim,
"num_heads": config.num_heads,
"num_layers": config.num_layers,
"dropout": config.dropout,
"num_inference_timesteps": config.num_inference_timesteps,
"num_categories": config.num_categories,
"enable_gradient_checkpointing": config.enable_gradient_checkpointing
and bool(config.finetune_vlm or config.finetune_language_model or config.finetune_vision_model),
"gradient_checkpointing_use_reentrant": config.gradient_checkpointing_use_reentrant,
"finetune_vlm": config.finetune_vlm,
"finetune_language_model": config.finetune_language_model,
"finetune_vision_model": config.finetune_vision_model,
"finetune_action_head": config.finetune_action_head,
}
@property
def _camera_keys(self) -> list[str]:
return list(self.config.image_features)
@@ -406,6 +376,9 @@ class EVO1Policy(PreTrainedPolicy):
embodiment_ids=embodiment_ids,
)
flat_action_mask = action_mask.view(action_mask.shape[0], -1).to(dtype=actions_gt.dtype)
# Flow-matching velocity target. Padded (masked-out) action dims are already zero on both sides
# here (`actions_gt` is zero-padded in `_prepare_actions`, and `noise` is masked inside the head),
# and the whole difference is multiplied by `flat_action_mask`, so padded dims contribute nothing.
target_velocity = (actions_gt - noise).view(actions_gt.shape[0], -1) * flat_action_mask
loss = self._compute_masked_loss(pred_velocity, target_velocity, action_mask, reduction)
loss_mean = loss.mean().item() if loss.ndim > 0 else loss.item()
@@ -447,4 +420,7 @@ class EVO1Policy(PreTrainedPolicy):
if len(self._action_queue) == 0:
action_chunk = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
self._action_queue.extend(action_chunk.transpose(0, 1))
# Returns one step of shape (B, max_action_dim): actions are emitted at the padded max_action_dim
# width and cropped to the real action dim downstream by the postprocessor (Evo1ActionProcessorStep).
# Callers that bypass the postprocessor receive the padded width.
return self._action_queue.popleft()
+2 -1
View File
@@ -21,7 +21,6 @@ from typing import Any
import torch
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
@@ -54,6 +53,8 @@ from lerobot.utils.constants import (
TRUNCATED,
)
from .configuration_evo1 import Evo1Config
def evo1_batch_to_transition(batch: dict[str, Any]):
transition = batch_to_transition(batch)
+30 -22
View File
@@ -20,6 +20,7 @@ import pytest
import torch
from torch import nn
import lerobot.policies.evo1.evo1_model as evo1_model
import lerobot.policies.evo1.modeling_evo1 as modeling_evo1
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
@@ -225,17 +226,26 @@ def test_evo1_rejects_non_square_image_resolution():
make_config(image_resolution=(448, 320))
def test_evo1_build_model_config_uses_image_resolution_and_trainable_checkpointing():
stage1 = make_config(training_stage="stage1", image_resolution=(224, 224))
stage1_model_config = modeling_evo1.EVO1Policy._build_model_config(stage1)
def test_evo1_model_uses_image_resolution_and_trainable_checkpointing(monkeypatch):
captured: dict = {}
assert stage1_model_config["image_size"] == 224
assert stage1_model_config["enable_gradient_checkpointing"] is False
class SpyEmbedder(nn.Module):
def __init__(self, **kwargs):
super().__init__()
captured.clear()
captured.update(kwargs)
monkeypatch.setattr(evo1_model, "InternVL3Embedder", SpyEmbedder)
stage1 = make_config(training_stage="stage1", image_resolution=(224, 224))
evo1_model.EVO1(stage1)
assert captured["image_size"] == 224
# VLM is frozen in stage1, so gradient checkpointing is gated off.
assert captured["enable_gradient_checkpointing"] is False
stage2 = make_config(training_stage="stage2", image_resolution=(224, 224))
stage2_model_config = modeling_evo1.EVO1Policy._build_model_config(stage2)
assert stage2_model_config["enable_gradient_checkpointing"] is True
evo1_model.EVO1(stage2)
assert captured["enable_gradient_checkpointing"] is True
def test_evo1_policy_processors_pad_state_crop_action_and_binarize_gripper():
@@ -429,21 +439,19 @@ def test_evo1_action_mask_accepts_chunk_size_one(monkeypatch):
assert not action_mask[:, :, ACTION_DIM:].any()
def test_flowmatching_dict_config_enables_state_encoder_for_horizon_one():
def test_flowmatching_state_encoder_for_horizon_one():
head = FlowmatchingActionHead(
config={
"embed_dim": EMBED_DIM,
"hidden_dim": 16,
"action_dim": ACTION_DIM,
"horizon": 1,
"per_action_dim": ACTION_DIM,
"num_heads": 2,
"num_layers": 1,
"num_inference_timesteps": 2,
"state_dim": STATE_DIM,
"state_hidden_dim": 16,
"num_categories": 1,
}
embed_dim=EMBED_DIM,
hidden_dim=16,
action_dim=ACTION_DIM,
horizon=1,
per_action_dim=ACTION_DIM,
num_heads=2,
num_layers=1,
num_inference_timesteps=2,
state_dim=STATE_DIM,
state_hidden_dim=16,
num_categories=1,
)
assert head.state_encoder is not None
Generated
-3
View File
@@ -2875,7 +2875,6 @@ all = [
{ name = "scikit-image" },
{ name = "scipy" },
{ name = "teleop" },
{ name = "timm" },
{ name = "torchcodec", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'win32'" },
{ name = "torchdiffeq" },
{ name = "transformers" },
@@ -2980,7 +2979,6 @@ evaluation = [
{ name = "av" },
]
evo1 = [
{ name = "timm" },
{ name = "transformers" },
]
fastwam = [
@@ -3371,7 +3369,6 @@ requires-dist = [
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'wallx'" },
{ name = "lerobot", extras = ["smolvla"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["test"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["timm-dep"], marker = "extra == 'evo1'" },
{ name = "lerobot", extras = ["timm-dep"], marker = "extra == 'groot'" },
{ name = "lerobot", extras = ["topreward"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["training"], marker = "extra == 'all'" },