mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-26 12:47:18 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9423deda02 | |||
| 25556ceefe | |||
| 4cfa762da8 | |||
| fa984990c0 |
+13
-3
@@ -139,6 +139,8 @@ every finetuning flag.
|
||||
| `policy.n_action_steps` | `50` | Number of actions consumed from a sampled chunk |
|
||||
| `policy.max_state_dim` | `24` | State padding dimension |
|
||||
| `policy.max_action_dim` | `24` | Action padding dimension |
|
||||
| `policy.postprocess_action_dim` | `null` | Optional action dimension returned after EVO1 postprocessing |
|
||||
| `policy.binarize_gripper` | `false` | Binarizes the postprocessed gripper channel for LIBERO-style eval |
|
||||
| `policy.task_field` | `task` | Batch field used as the language prompt |
|
||||
|
||||
## Results
|
||||
@@ -159,14 +161,22 @@ pixel embeddings, VLM fused tokens, normalized actions, and denormalized actions
|
||||
(`max_abs_diff=0.0`).
|
||||
|
||||
The published checkpoint expects the raw LIBERO camera feature names
|
||||
`observation.images.agentview_image` and `observation.images.robot0_eye_in_hand_image`. To run the converted
|
||||
checkpoint with LeRobot LIBERO evaluation for the same one-episode-per-task setting, keep those camera names
|
||||
instead of the default `image`/`image2` mapping:
|
||||
`observation.images.agentview_image` and `observation.images.robot0_eye_in_hand_image`. The official EVO1 LIBERO
|
||||
rollout protocol also replans every 14 actions and binarizes the gripper command before stepping the simulator.
|
||||
The EVO1 policy postprocessor can crop the padded 24D action back to the 7D LIBERO action space and apply that
|
||||
gripper binarization. To run the converted checkpoint with LeRobot LIBERO evaluation for the same
|
||||
one-episode-per-task setting, keep the raw camera names instead of the default `image`/`image2` mapping, enable
|
||||
FlashAttention, and set the LIBERO action postprocessing flags:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=javadcc/evo1-libero-lerobot \
|
||||
--policy.vlm_model_name=OpenGVLab/InternVL3-1B \
|
||||
--policy.device=cuda \
|
||||
--policy.use_flash_attn=true \
|
||||
--policy.n_action_steps=14 \
|
||||
--policy.postprocess_action_dim=7 \
|
||||
--policy.binarize_gripper=true \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object \
|
||||
--env.camera_name_mapping="{agentview_image: agentview_image, robot0_eye_in_hand_image: robot0_eye_in_hand_image}" \
|
||||
|
||||
+3
-3
@@ -140,6 +140,7 @@ pyserial-dep = ["pyserial>=3.5,<4.0"]
|
||||
deepdiff-dep = ["deepdiff>=7.0.1,<9.0.0"]
|
||||
pynput-dep = ["pynput>=1.7.8,<1.9.0"]
|
||||
pyzmq-dep = ["pyzmq>=26.2.1,<28.0.0"]
|
||||
timm-dep = ["timm>=1.0.0,<1.1.0"]
|
||||
|
||||
# Motors
|
||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0", "lerobot[pyserial-dep]", "lerobot[deepdiff-dep]"]
|
||||
@@ -187,7 +188,7 @@ groot = [
|
||||
"lerobot[peft-dep]",
|
||||
"lerobot[diffusers-dep]",
|
||||
"dm-tree>=0.1.8,<1.0.0",
|
||||
"timm>=1.0.0,<1.1.0",
|
||||
"lerobot[timm-dep]",
|
||||
"decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
|
||||
"ninja>=1.11.1,<2.0.0",
|
||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||
@@ -195,7 +196,7 @@ groot = [
|
||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
evo1 = ["lerobot[transformers-dep]", "timm>=1.0.0,<1.1.0"]
|
||||
evo1 = ["lerobot[transformers-dep]", "lerobot[timm-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
@@ -350,7 +351,6 @@ ignore = [
|
||||
# E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect
|
||||
"src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"]
|
||||
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Suppress these as they come from the original Qwen2_5_vl code. TODO(pepijn): refactor original
|
||||
"src/lerobot/policies/evo1/**" = ["N801", "N812"]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
combine-as-imports = true
|
||||
|
||||
@@ -26,7 +26,6 @@ from gymnasium.envs.registration import registry as gym_registry
|
||||
from lerobot.configs import FeatureType, PolicyFeature
|
||||
from lerobot.processor import (
|
||||
IsaaclabArenaProcessorStep,
|
||||
LiberoActionProcessorStep,
|
||||
LiberoProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
)
|
||||
@@ -128,7 +127,7 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
vec = env_cls([_make_one for _ in range(n_envs)], **extra_kwargs)
|
||||
return {self.type: {0: vec}}
|
||||
|
||||
def get_env_processors(self, policy_cfg: Any | None = None):
|
||||
def get_env_processors(self):
|
||||
"""Return (preprocessor, postprocessor) for this env. Default: identity."""
|
||||
return PolicyProcessorPipeline(steps=[]), PolicyProcessorPipeline(steps=[])
|
||||
|
||||
@@ -441,13 +440,10 @@ class LiberoEnv(EnvConfig):
|
||||
is_libero_plus=self.is_libero_plus,
|
||||
)
|
||||
|
||||
def get_env_processors(self, policy_cfg: Any | None = None):
|
||||
max_state_dim = getattr(policy_cfg, "max_state_dim", None) if getattr(policy_cfg, "type", None) == "evo1" else None
|
||||
action_feature = self.features.get(ACTION)
|
||||
action_dim = int(action_feature.shape[0]) if action_feature is not None else 7
|
||||
def get_env_processors(self):
|
||||
return (
|
||||
PolicyProcessorPipeline(steps=[LiberoProcessorStep(max_state_dim=max_state_dim)]),
|
||||
PolicyProcessorPipeline(steps=[LiberoActionProcessorStep(action_dim=action_dim)]),
|
||||
PolicyProcessorPipeline(steps=[LiberoProcessorStep()]),
|
||||
PolicyProcessorPipeline(steps=[]),
|
||||
)
|
||||
|
||||
|
||||
@@ -713,7 +709,7 @@ class IsaaclabArenaEnv(HubEnvConfig):
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {}
|
||||
|
||||
def get_env_processors(self, policy_cfg: Any | None = None):
|
||||
def get_env_processors(self):
|
||||
state_keys = tuple(k.strip() for k in (self.state_keys or "").split(",") if k.strip())
|
||||
camera_keys = tuple(k.strip() for k in (self.camera_keys or "").split(",") if k.strip())
|
||||
if not state_keys and not camera_keys:
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
@@ -53,14 +52,7 @@ def make_env_pre_post_processors(
|
||||
|
||||
return make_xvla_libero_pre_post_processors()
|
||||
|
||||
get_processors = env_cfg.get_env_processors
|
||||
signature = inspect.signature(get_processors)
|
||||
supports_policy_cfg = "policy_cfg" in signature.parameters or any(
|
||||
param.kind is inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()
|
||||
)
|
||||
if supports_policy_cfg:
|
||||
return get_processors(policy_cfg=policy_cfg)
|
||||
return get_processors()
|
||||
return env_cfg.get_env_processors()
|
||||
|
||||
|
||||
def make_env(
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
@@ -26,6 +27,8 @@ from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("evo1_exact")
|
||||
@dataclass
|
||||
@@ -59,6 +62,12 @@ class Evo1Config(PreTrainedConfig):
|
||||
max_views: int = 3
|
||||
image_resolution: tuple[int, int] = (448, 448)
|
||||
empty_cameras: int = 0
|
||||
postprocess_action_dim: int | None = None
|
||||
binarize_gripper: bool = False
|
||||
gripper_index: int = 6
|
||||
gripper_threshold: float = 0.5
|
||||
gripper_below_threshold_value: float = 1.0
|
||||
gripper_above_threshold_value: float = -1.0
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
@@ -68,7 +77,7 @@ class Evo1Config(PreTrainedConfig):
|
||||
}
|
||||
)
|
||||
|
||||
vlm_model_name: str = "OpenGVLab/InternVL3-1B"
|
||||
vlm_model_name: str = "OpenGVLab/InternVL3-1B-hf"
|
||||
vlm_num_layers: int | None = 14
|
||||
vlm_dtype: str = "bfloat16"
|
||||
use_flash_attn: bool = True
|
||||
@@ -114,16 +123,32 @@ class Evo1Config(PreTrainedConfig):
|
||||
)
|
||||
|
||||
if self.apply_training_stage_defaults:
|
||||
if self.training_stage == "stage1":
|
||||
self.finetune_vlm = False
|
||||
self.finetune_language_model = False
|
||||
self.finetune_vision_model = False
|
||||
self.finetune_action_head = True
|
||||
elif self.training_stage == "stage2":
|
||||
self.finetune_vlm = True
|
||||
self.finetune_language_model = True
|
||||
self.finetune_vision_model = True
|
||||
self.finetune_action_head = True
|
||||
stage_defaults = {
|
||||
"stage1": {
|
||||
"finetune_vlm": False,
|
||||
"finetune_language_model": False,
|
||||
"finetune_vision_model": False,
|
||||
"finetune_action_head": True,
|
||||
},
|
||||
"stage2": {
|
||||
"finetune_vlm": True,
|
||||
"finetune_language_model": True,
|
||||
"finetune_vision_model": True,
|
||||
"finetune_action_head": True,
|
||||
},
|
||||
}[self.training_stage]
|
||||
for flag_name, default_value in stage_defaults.items():
|
||||
current_value = getattr(self, flag_name)
|
||||
if current_value is not None and current_value != default_value:
|
||||
logger.warning(
|
||||
"EVO1 %s=%s is overridden by training_stage=%s default %s. "
|
||||
"Set apply_training_stage_defaults=false to keep explicit finetuning flags.",
|
||||
flag_name,
|
||||
current_value,
|
||||
self.training_stage,
|
||||
default_value,
|
||||
)
|
||||
setattr(self, flag_name, default_value)
|
||||
elif self.training_stage == "stage1":
|
||||
if self.finetune_vlm is None:
|
||||
self.finetune_vlm = False
|
||||
@@ -171,6 +196,11 @@ class Evo1Config(PreTrainedConfig):
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) must be <= chunk_size ({self.chunk_size})"
|
||||
)
|
||||
if len(self.image_resolution) != 2 or self.image_resolution[0] != self.image_resolution[1]:
|
||||
raise ValueError(
|
||||
"EVO1 currently expects a square image_resolution because InternVL3 preprocessing "
|
||||
f"uses a scalar image_size, got {self.image_resolution}."
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if self.input_features is None:
|
||||
|
||||
@@ -21,8 +21,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.policies.evo1.flow_matching import FlowmatchingActionHead
|
||||
from lerobot.policies.evo1.internvl3_embedder import InternVL3Embedder
|
||||
from .flow_matching import FlowmatchingActionHead
|
||||
from .internvl3_embedder import InternVL3Embedder
|
||||
|
||||
|
||||
def _cfgget(config: Any, key: str, default=None):
|
||||
@@ -163,37 +163,6 @@ class EVO1(nn.Module):
|
||||
embodiment_id=embodiment_ids,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def run_inference(
|
||||
self,
|
||||
images: list[Image.Image | torch.Tensor],
|
||||
image_mask: torch.Tensor,
|
||||
prompt: str,
|
||||
state_input: list | torch.Tensor,
|
||||
return_cls_only: bool | None = None,
|
||||
action_mask: torch.Tensor | None = None,
|
||||
embodiment_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if image_mask.dim() == 1:
|
||||
image_mask = image_mask.unsqueeze(0)
|
||||
|
||||
fused_tokens = self.get_vl_embeddings(
|
||||
images=[images],
|
||||
image_mask=image_mask,
|
||||
prompt=[prompt],
|
||||
return_cls_only=return_cls_only,
|
||||
)
|
||||
state_tensor = self.prepare_state(state_input)
|
||||
action = self.predict_action(
|
||||
fused_tokens,
|
||||
state_tensor,
|
||||
action_mask=action_mask,
|
||||
embodiment_ids=embodiment_ids,
|
||||
)
|
||||
if isinstance(action, torch.Tensor) and action.dtype == torch.bfloat16:
|
||||
action = action.to(torch.float32)
|
||||
return action
|
||||
|
||||
def forward(
|
||||
self,
|
||||
fused_tokens: torch.Tensor,
|
||||
|
||||
@@ -129,7 +129,10 @@ class MultiEmbodimentActionEncoder(nn.Module):
|
||||
|
||||
def forward(self, action_seq: torch.Tensor, category_id: torch.LongTensor):
|
||||
batch_size, horizon, action_dim = action_seq.shape
|
||||
assert self.horizon == horizon, "Action sequence length must match horizon"
|
||||
if self.horizon != horizon:
|
||||
raise ValueError(
|
||||
f"Action sequence length must match horizon: got {horizon}, expected {self.horizon}."
|
||||
)
|
||||
|
||||
x = action_seq.reshape(batch_size * horizon, action_dim)
|
||||
if category_id.dim() == 0:
|
||||
|
||||
@@ -16,15 +16,12 @@ from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import types
|
||||
from collections.abc import Sequence
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
import torchvision.transforms.functional as TF
|
||||
import torchvision.transforms.functional as tvf
|
||||
from PIL import Image
|
||||
from torchvision.transforms.functional import to_pil_image
|
||||
|
||||
@@ -45,65 +42,6 @@ IMG_END_TOKEN = "</img>" # nosec B105
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _patch_vision_encoder_checkpointing(encoder: nn.Module, use_reentrant: bool) -> None:
|
||||
if getattr(encoder, "_evo1_checkpoint_patch_applied", False):
|
||||
encoder.gradient_checkpointing_use_reentrant = use_reentrant
|
||||
return
|
||||
|
||||
original_forward = encoder.forward
|
||||
|
||||
def forward_with_checkpoint_kwargs(self, *args, **kwargs):
|
||||
original_checkpoint = torch.utils.checkpoint.checkpoint
|
||||
|
||||
def checkpoint(function, *checkpoint_args, **checkpoint_kwargs):
|
||||
checkpoint_kwargs.setdefault("use_reentrant", self.gradient_checkpointing_use_reentrant)
|
||||
return original_checkpoint(function, *checkpoint_args, **checkpoint_kwargs)
|
||||
|
||||
torch.utils.checkpoint.checkpoint = checkpoint
|
||||
try:
|
||||
return original_forward(*args, **kwargs)
|
||||
finally:
|
||||
torch.utils.checkpoint.checkpoint = original_checkpoint
|
||||
|
||||
encoder.gradient_checkpointing_use_reentrant = use_reentrant
|
||||
encoder.forward = types.MethodType(forward_with_checkpoint_kwargs, encoder)
|
||||
encoder._evo1_checkpoint_patch_applied = True
|
||||
|
||||
|
||||
def flash_attn_is_available() -> bool:
|
||||
try:
|
||||
import flash_attn # noqa: F401
|
||||
except ModuleNotFoundError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _internvl_transformers5_load_compatibility():
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
original_linspace = torch.linspace
|
||||
original_mark_tied = PreTrainedModel.mark_tied_weights_as_initialized
|
||||
|
||||
def linspace(*args, **kwargs):
|
||||
if kwargs.get("device") is None:
|
||||
kwargs["device"] = torch.device("cpu")
|
||||
return original_linspace(*args, **kwargs)
|
||||
|
||||
def mark_tied_weights_as_initialized(self, loading_info):
|
||||
if not hasattr(self, "all_tied_weights_keys"):
|
||||
self.all_tied_weights_keys = {}
|
||||
return original_mark_tied(self, loading_info)
|
||||
|
||||
torch.linspace = linspace
|
||||
PreTrainedModel.mark_tied_weights_as_initialized = mark_tied_weights_as_initialized
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch.linspace = original_linspace
|
||||
PreTrainedModel.mark_tied_weights_as_initialized = original_mark_tied
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=10000)
|
||||
def get_target_aspect_ratio(orig_width: int, orig_height: int, image_size: int, min_num: int, max_num: int):
|
||||
aspect_ratio = orig_width / orig_height
|
||||
@@ -152,9 +90,11 @@ def dynamic_preprocess(image, min_num=1, max_num=1, image_size=448, use_thumbnai
|
||||
|
||||
|
||||
class InternVL3Embedder(nn.Module):
|
||||
"""Vision-language embedder using the native HF InternVL3 model (no trust_remote_code)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name="OpenGVLab/InternVL3-1B",
|
||||
model_name="OpenGVLab/InternVL3-1B-hf",
|
||||
image_size=448,
|
||||
device="cuda",
|
||||
num_language_layers: int | None = 14,
|
||||
@@ -173,43 +113,31 @@ class InternVL3Embedder(nn.Module):
|
||||
|
||||
require_package("transformers", extra="evo1")
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
if isinstance(model_dtype, str):
|
||||
try:
|
||||
model_dtype = getattr(torch, model_dtype)
|
||||
except AttributeError as exc:
|
||||
raise ValueError(f"Unsupported EVO1 vlm_dtype '{model_dtype}'") from exc
|
||||
|
||||
resolved_use_flash_attn = bool(use_flash_attn and flash_attn_is_available())
|
||||
if use_flash_attn and not resolved_use_flash_attn:
|
||||
logger.warning("flash_attn is not installed. Falling back to standard attention.")
|
||||
attn_implementation = "flash_attention_2" if (use_flash_attn and _flash_attn_available()) else "eager"
|
||||
if use_flash_attn and attn_implementation == "eager":
|
||||
logger.warning("flash_attn is not installed. Falling back to eager attention.")
|
||||
|
||||
# InternVL3 remote code predates Transformers 5 post-init conventions:
|
||||
# it computes stochastic-depth scalars via torch.linspace(...).item()
|
||||
# while Transformers initializes under torch.device("meta"), and it
|
||||
# does not populate all_tied_weights_keys before loading finalization.
|
||||
with _internvl_transformers5_load_compatibility():
|
||||
self.model = AutoModel.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=model_dtype,
|
||||
trust_remote_code=True,
|
||||
use_flash_attn=resolved_use_flash_attn,
|
||||
low_cpu_mem_usage=True,
|
||||
_fast_init=False,
|
||||
).to(self._requested_device)
|
||||
self.model = AutoModel.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=model_dtype,
|
||||
attn_implementation=attn_implementation,
|
||||
low_cpu_mem_usage=True,
|
||||
).to(self._requested_device)
|
||||
|
||||
if hasattr(self.model.language_model, "model"):
|
||||
layers = self.model.language_model.model.layers
|
||||
else:
|
||||
layers = self.model.language_model.layers
|
||||
self.num_image_token = self.model.config.image_seq_length
|
||||
|
||||
# Truncate language model to the requested number of layers
|
||||
layers = self.model.language_model.layers
|
||||
if self.num_language_layers is not None:
|
||||
layers = layers[: self.num_language_layers]
|
||||
|
||||
if hasattr(self.model.language_model, "model"):
|
||||
self.model.language_model.model.layers = torch.nn.ModuleList(layers)
|
||||
else:
|
||||
self.model.language_model.layers = torch.nn.ModuleList(layers)
|
||||
self.model.language_model.lm_head = torch.nn.Identity()
|
||||
self.model.language_model.layers = torch.nn.ModuleList(layers)
|
||||
|
||||
self._configure_memory_features()
|
||||
self.img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
|
||||
@@ -218,20 +146,12 @@ class InternVL3Embedder(nn.Module):
|
||||
checkpoint_kwargs = {"use_reentrant": self.gradient_checkpointing_use_reentrant}
|
||||
|
||||
if not self.enable_gradient_checkpointing:
|
||||
if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"):
|
||||
self.model.vision_model.encoder.gradient_checkpointing = False
|
||||
language_model = getattr(self.model, "language_model", None)
|
||||
if language_model is not None:
|
||||
if hasattr(language_model, "gradient_checkpointing_disable"):
|
||||
language_model.gradient_checkpointing_disable()
|
||||
elif hasattr(language_model, "gradient_checkpointing"):
|
||||
language_model.gradient_checkpointing = False
|
||||
if hasattr(language_model, "model"):
|
||||
inner = language_model.model
|
||||
if hasattr(inner, "gradient_checkpointing_disable"):
|
||||
inner.gradient_checkpointing_disable()
|
||||
elif hasattr(inner, "gradient_checkpointing"):
|
||||
inner.gradient_checkpointing = False
|
||||
language_model = self.model.language_model
|
||||
if hasattr(language_model, "gradient_checkpointing_disable"):
|
||||
language_model.gradient_checkpointing_disable()
|
||||
vision_tower = getattr(self.model, "vision_tower", None)
|
||||
if vision_tower is not None and hasattr(vision_tower, "encoder"):
|
||||
vision_tower.encoder.gradient_checkpointing = False
|
||||
return
|
||||
|
||||
def _enable_ckpt(module: nn.Module | None) -> bool:
|
||||
@@ -250,21 +170,14 @@ class InternVL3Embedder(nn.Module):
|
||||
|
||||
enabled_any = _enable_ckpt(self.model)
|
||||
|
||||
if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"):
|
||||
encoder = self.model.vision_model.encoder
|
||||
encoder.gradient_checkpointing = True
|
||||
_patch_vision_encoder_checkpointing(
|
||||
encoder, use_reentrant=self.gradient_checkpointing_use_reentrant
|
||||
)
|
||||
enabled_any = True
|
||||
vision_tower = getattr(self.model, "vision_tower", None)
|
||||
if vision_tower is not None:
|
||||
enabled_any = _enable_ckpt(vision_tower) or enabled_any
|
||||
|
||||
language_model = getattr(self.model, "language_model", None)
|
||||
if language_model is not None:
|
||||
enabled_any = _enable_ckpt(language_model) or enabled_any
|
||||
if hasattr(language_model, "model"):
|
||||
enabled_any = _enable_ckpt(language_model.model) or enabled_any
|
||||
if hasattr(language_model, "config"):
|
||||
language_model.config.use_cache = False
|
||||
language_model = self.model.language_model
|
||||
enabled_any = _enable_ckpt(language_model) or enabled_any
|
||||
if hasattr(language_model, "config"):
|
||||
language_model.config.use_cache = False
|
||||
|
||||
if hasattr(self.model, "config"):
|
||||
self.model.config.use_cache = False
|
||||
@@ -284,7 +197,7 @@ class InternVL3Embedder(nn.Module):
|
||||
else:
|
||||
pil_image = image.convert("RGB")
|
||||
tiles = dynamic_preprocess(pil_image, image_size=self.image_size)
|
||||
tile_tensors = torch.stack([TF.to_tensor(tile) for tile in tiles]).to(
|
||||
tile_tensors = torch.stack([tvf.to_tensor(tile) for tile in tiles]).to(
|
||||
device=self.device, dtype=torch.bfloat16
|
||||
)
|
||||
mean = torch.tensor(IMAGENET_MEAN, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1)
|
||||
@@ -323,76 +236,12 @@ class InternVL3Embedder(nn.Module):
|
||||
for num_tiles_list, text_prompt in zip(batch_num_tiles_list, text_prompts, strict=True):
|
||||
prompt_segments = []
|
||||
for i, tile_count in enumerate(num_tiles_list):
|
||||
token_count = self.model.num_image_token * tile_count
|
||||
token_count = self.num_image_token * tile_count
|
||||
image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * token_count + IMG_END_TOKEN
|
||||
prompt_segments.append(f"Image-{i + 1}: {image_tokens}\n")
|
||||
prompts.append("".join(prompt_segments) + text_prompt.strip())
|
||||
return prompts
|
||||
|
||||
def _prepare_and_fuse_embeddings(
|
||||
self,
|
||||
prompts: Sequence[str],
|
||||
vit_embeds: torch.Tensor,
|
||||
image_masks: torch.Tensor,
|
||||
batch_num_tiles_list: list[list[int]],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
untruncated_ids = self.tokenizer(list(prompts), padding=False, truncation=False)["input_ids"]
|
||||
true_sequence_length = max((len(ids) for ids in untruncated_ids), default=0)
|
||||
if true_sequence_length > self.max_text_length:
|
||||
logger.warning(
|
||||
"InternVL3 prompt truncated in batch: max_length=%s actual_max_length=%s",
|
||||
self.max_text_length,
|
||||
true_sequence_length,
|
||||
)
|
||||
|
||||
model_inputs = self.tokenizer(
|
||||
list(prompts),
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.max_text_length,
|
||||
).to(self.device)
|
||||
input_ids = model_inputs["input_ids"]
|
||||
attention_mask = model_inputs["attention_mask"]
|
||||
|
||||
img_token_mask = input_ids == self.img_context_token_id
|
||||
input_embeds = self.model.language_model.get_input_embeddings()(input_ids).clone()
|
||||
|
||||
batch_size, _, channels = input_embeds.shape
|
||||
vit_embeds = vit_embeds.reshape(-1, channels).to(dtype=input_embeds.dtype, device=input_embeds.device)
|
||||
tokens_per_tile = self.model.num_image_token
|
||||
actual_vis_tokens_list = img_token_mask.sum(dim=1).tolist()
|
||||
|
||||
vit_idx = 0
|
||||
for batch_index in range(batch_size):
|
||||
expected_vis_tokens = sum(batch_num_tiles_list[batch_index]) * tokens_per_tile
|
||||
mask_b = img_token_mask[batch_index]
|
||||
actual_vis_tokens = actual_vis_tokens_list[batch_index]
|
||||
|
||||
item_vit_embeds = vit_embeds[vit_idx : vit_idx + expected_vis_tokens]
|
||||
vit_idx += expected_vis_tokens
|
||||
if actual_vis_tokens > 0:
|
||||
if item_vit_embeds.shape[0] < actual_vis_tokens:
|
||||
raise ValueError(
|
||||
f"InternVL3 produced fewer image tokens than expected for sample {batch_index}: "
|
||||
f"got {item_vit_embeds.shape[0]}, need {actual_vis_tokens}"
|
||||
)
|
||||
input_embeds[batch_index, mask_b] = item_vit_embeds[:actual_vis_tokens]
|
||||
|
||||
current_token_idx = 0
|
||||
img_token_locations = torch.where(mask_b)[0]
|
||||
for image_index, num_tiles in enumerate(batch_num_tiles_list[batch_index]):
|
||||
num_tokens_for_image = num_tiles * tokens_per_tile
|
||||
if not bool(image_masks[batch_index, image_index].item()):
|
||||
start_offset = current_token_idx
|
||||
end_offset = min(current_token_idx + num_tokens_for_image, len(img_token_locations))
|
||||
if start_offset < end_offset:
|
||||
idxs = img_token_locations[start_offset:end_offset]
|
||||
attention_mask[batch_index, idxs] = 0
|
||||
current_token_idx += num_tokens_for_image
|
||||
|
||||
return input_embeds, attention_mask
|
||||
|
||||
def get_fused_image_text_embedding_from_tensor_images(
|
||||
self,
|
||||
image_tensors_batch: Sequence[Sequence[Image.Image | torch.Tensor]],
|
||||
@@ -404,27 +253,46 @@ class InternVL3Embedder(nn.Module):
|
||||
if pixel_values.shape[0] == 0:
|
||||
logger.warning("InternVL3 received an empty image batch after preprocessing.")
|
||||
hidden_size = getattr(self.model.config, "hidden_size", None)
|
||||
if hidden_size is None and hasattr(self.model.language_model, "config"):
|
||||
hidden_size = getattr(self.model.language_model.config, "hidden_size", None)
|
||||
if hidden_size is None:
|
||||
hidden_size = getattr(self.model.config.text_config, "hidden_size", None)
|
||||
if hidden_size is None:
|
||||
raise RuntimeError("Unable to infer hidden size for empty InternVL3 batch.")
|
||||
empty = torch.empty(0, hidden_size, device=self.device, dtype=torch.float32)
|
||||
return empty
|
||||
|
||||
prompts = self._build_multimodal_prompts(batch_num_tiles_list, text_prompts)
|
||||
vit_embeds = self.model.extract_feature(pixel_values)
|
||||
inputs_embeds, attention_mask = self._prepare_and_fuse_embeddings(
|
||||
prompts,
|
||||
vit_embeds,
|
||||
image_masks.to(device=self.device),
|
||||
batch_num_tiles_list,
|
||||
)
|
||||
|
||||
outputs = self.model.language_model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
model_inputs = self.tokenizer(
|
||||
list(prompts),
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.max_text_length,
|
||||
).to(self.device)
|
||||
input_ids = model_inputs["input_ids"]
|
||||
attention_mask = model_inputs["attention_mask"]
|
||||
|
||||
# Zero out attention for absent images
|
||||
img_token_mask = input_ids == self.img_context_token_id
|
||||
tokens_per_tile = self.num_image_token
|
||||
for batch_index in range(input_ids.shape[0]):
|
||||
current_token_idx = 0
|
||||
img_token_locations = torch.where(img_token_mask[batch_index])[0]
|
||||
for image_index, num_tiles in enumerate(batch_num_tiles_list[batch_index]):
|
||||
num_tokens_for_image = num_tiles * tokens_per_tile
|
||||
if not bool(image_masks[batch_index, image_index].item()):
|
||||
start_offset = current_token_idx
|
||||
end_offset = min(current_token_idx + num_tokens_for_image, len(img_token_locations))
|
||||
if start_offset < end_offset:
|
||||
idxs = img_token_locations[start_offset:end_offset]
|
||||
attention_mask[batch_index, idxs] = 0
|
||||
current_token_idx += num_tokens_for_image
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
use_cache=False,
|
||||
return_dict=True,
|
||||
)
|
||||
fused_hidden = outputs.hidden_states[-1].to(torch.float32)
|
||||
@@ -433,3 +301,11 @@ class InternVL3Embedder(nn.Module):
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return next(self.model.parameters()).device
|
||||
|
||||
|
||||
def _flash_attn_available() -> bool:
|
||||
try:
|
||||
import flash_attn # noqa: F401
|
||||
except ModuleNotFoundError:
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -45,6 +45,7 @@ class EVO1Policy(PreTrainedPolicy):
|
||||
self.config = config
|
||||
self.model = EVO1(self._build_model_config(config))
|
||||
self.model.set_finetune_flags()
|
||||
self._keep_frozen_embedder_eval()
|
||||
self.reset()
|
||||
|
||||
@classmethod
|
||||
@@ -64,7 +65,7 @@ class EVO1Policy(PreTrainedPolicy):
|
||||
**kwargs,
|
||||
) -> T:
|
||||
if strict is None:
|
||||
strict = not (config is not None and getattr(config, "training_stage", None) == "stage2")
|
||||
strict = True
|
||||
return super().from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
config=config,
|
||||
@@ -85,6 +86,7 @@ class EVO1Policy(PreTrainedPolicy):
|
||||
"device": config.device,
|
||||
"return_cls_only": config.return_cls_only,
|
||||
"vlm_name": config.vlm_model_name,
|
||||
"image_size": int(config.image_resolution[0]),
|
||||
"vlm_num_layers": config.vlm_num_layers,
|
||||
"vlm_dtype": config.vlm_dtype,
|
||||
"use_flash_attn": config.use_flash_attn,
|
||||
@@ -100,7 +102,8 @@ class EVO1Policy(PreTrainedPolicy):
|
||||
"dropout": config.dropout,
|
||||
"num_inference_timesteps": config.num_inference_timesteps,
|
||||
"num_categories": config.num_categories,
|
||||
"enable_gradient_checkpointing": config.enable_gradient_checkpointing,
|
||||
"enable_gradient_checkpointing": config.enable_gradient_checkpointing
|
||||
and bool(config.finetune_vlm or config.finetune_language_model or config.finetune_vision_model),
|
||||
"gradient_checkpointing_use_reentrant": config.gradient_checkpointing_use_reentrant,
|
||||
"finetune_vlm": config.finetune_vlm,
|
||||
"finetune_language_model": config.finetune_language_model,
|
||||
@@ -303,6 +306,18 @@ class EVO1Policy(PreTrainedPolicy):
|
||||
or self.config.finetune_vision_model
|
||||
)
|
||||
|
||||
def _keep_frozen_embedder_eval(self) -> None:
|
||||
if self._tracks_vlm_gradients:
|
||||
return
|
||||
embedder = getattr(self.model, "embedder", None)
|
||||
if embedder is not None:
|
||||
embedder.eval()
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
super().train(mode)
|
||||
self._keep_frozen_embedder_eval()
|
||||
return self
|
||||
|
||||
def _collect_image_batches(self, batch: dict[str, Tensor]) -> tuple[list[list[Tensor]], Tensor]:
|
||||
camera_keys = self._camera_keys or sorted(key for key in batch if key.startswith(f"{OBS_IMAGES}."))
|
||||
if not camera_keys:
|
||||
@@ -348,23 +363,13 @@ class EVO1Policy(PreTrainedPolicy):
|
||||
) -> Tensor:
|
||||
track_vlm_gradients = self._tracks_vlm_gradients
|
||||
grad_context = nullcontext() if track_vlm_gradients else torch.no_grad()
|
||||
embedder = getattr(self.model, "embedder", None)
|
||||
embedder_was_training = embedder.training if embedder is not None else None
|
||||
|
||||
if not track_vlm_gradients and embedder is not None:
|
||||
embedder.eval()
|
||||
|
||||
try:
|
||||
with grad_context:
|
||||
fused_tokens = self.model.get_vl_embeddings(
|
||||
images=image_batches,
|
||||
image_mask=image_masks,
|
||||
prompt=prompts,
|
||||
return_cls_only=self.config.return_cls_only,
|
||||
)
|
||||
finally:
|
||||
if not track_vlm_gradients and embedder is not None and embedder_was_training is not None:
|
||||
embedder.train(embedder_was_training)
|
||||
with grad_context:
|
||||
fused_tokens = self.model.get_vl_embeddings(
|
||||
images=image_batches,
|
||||
image_mask=image_masks,
|
||||
prompt=prompts,
|
||||
return_cls_only=self.config.return_cls_only,
|
||||
)
|
||||
|
||||
if not track_vlm_gradients:
|
||||
fused_tokens = fused_tokens.detach()
|
||||
@@ -439,7 +444,7 @@ class EVO1Policy(PreTrainedPolicy):
|
||||
embodiment_ids=embodiment_ids,
|
||||
)
|
||||
actions = actions.view(states.shape[0], self.config.chunk_size, self.config.max_action_dim)
|
||||
return actions[:, :, : self._env_action_dim]
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
|
||||
|
||||
@@ -14,17 +14,24 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
ObservationProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyActionProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
@@ -34,11 +41,13 @@ from lerobot.processor.converters import (
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
DONE,
|
||||
INFO,
|
||||
OBS_PREFIX,
|
||||
OBS_STATE,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
REWARD,
|
||||
@@ -65,6 +74,305 @@ def evo1_batch_to_transition(batch: dict[str, Any]):
|
||||
)
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register(name="evo1_pad_state_processor")
class Evo1PadStateProcessorStep(ObservationProcessorStep):
    """Zero-pad ``OBS_STATE`` on its last axis up to ``max_state_dim``.

    Observations without an ``OBS_STATE`` entry pass through untouched.
    """

    # Width of the last state axis after padding.
    max_state_dim: int = 24

    def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
        """Return the observation with ``OBS_STATE`` right-padded with zeros.

        Args:
            observation: Observation mapping; only ``OBS_STATE`` is inspected.

        Returns:
            The input mapping when no padding is needed, otherwise a shallow
            copy whose ``OBS_STATE`` has ``max_state_dim`` trailing entries.

        Raises:
            ValueError: If the state is wider than ``max_state_dim``.
        """
        if OBS_STATE not in observation:
            return observation

        state = observation[OBS_STATE]
        state_dim = state.shape[-1]
        if state_dim > self.max_state_dim:
            raise ValueError(
                f"EVO1 state has {state_dim} dims, which exceeds max_state_dim={self.max_state_dim}."
            )
        if state_dim < self.max_state_dim:
            # Copy so the caller's mapping is not mutated.
            observation = observation.copy()
            observation[OBS_STATE] = torch.nn.functional.pad(state, (0, self.max_state_dim - state_dim))
        return observation

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Report the padded ``OBS_STATE`` shape in the pipeline feature description.

        The mapping is keyed by ``PipelineFeatureType``, so the state feature is
        looked up in the ``OBSERVATION`` bucket. Buckets are copied, never
        created: a description without ``OBS_STATE`` is returned unchanged.
        """
        new_features = {ft: feats.copy() for ft, feats in features.items()}
        # FeatureType.STATE is a per-feature type rather than a bucket key; it is
        # still checked so a mapping keyed that way keeps being updated.
        for bucket_key in (PipelineFeatureType.OBSERVATION, FeatureType.STATE):
            bucket = new_features.get(bucket_key)
            if bucket is not None and OBS_STATE in bucket:
                bucket[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(self.max_state_dim,))
        return new_features

    def get_config(self) -> dict[str, Any]:
        """Return the serializable configuration of this step."""
        return {"max_state_dim": self.max_state_dim}
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register(name="evo1_pad_action_processor")
class Evo1PadActionProcessorStep(ProcessorStep):
    """Zero-pad the transition action to ``max_action_dim`` and record an ``action_mask``.

    The boolean mask is stored in the transition's complementary data. It is
    all-True over the incoming action unless the caller supplied one, and is
    padded alongside the action on the last axis.
    """

    # Width of the last action axis after padding.
    max_action_dim: int = 24

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a copy of ``transition`` with padded action and ``action_mask``.

        Transitions without an action are returned as-is.

        Raises:
            ValueError: If the action is not a ``PolicyAction``, is wider than
                ``max_action_dim``, or an existing ``action_mask`` has a
                different shape than the action.
        """
        action = transition.get(TransitionKey.ACTION)
        if action is None:
            return transition
        if not isinstance(action, PolicyAction):
            raise ValueError(f"EVO1 action should be a PolicyAction tensor, but got {type(action)}.")

        action_dim = action.shape[-1]
        pad_width = self.max_action_dim - action_dim
        if pad_width < 0:
            raise ValueError(
                f"EVO1 action has {action_dim} dims, which exceeds max_action_dim={self.max_action_dim}."
            )

        result = transition.copy()
        extras = dict(result.get(TransitionKey.COMPLEMENTARY_DATA) or {})

        action_mask = extras.get("action_mask")
        if action_mask is not None:
            action_mask = torch.as_tensor(action_mask, dtype=torch.bool, device=action.device)
        else:
            action_mask = torch.ones(action.shape, dtype=torch.bool, device=action.device)
        if action_mask.shape != action.shape:
            raise ValueError(
                f"action_mask shape {tuple(action_mask.shape)} does not match action shape {tuple(action.shape)}."
            )

        padded_action = action
        if pad_width > 0:
            # Action and mask are padded together on the last axis.
            padded_action = torch.nn.functional.pad(action, (0, pad_width))
            action_mask = torch.nn.functional.pad(action_mask, (0, pad_width))

        extras["action_mask"] = action_mask
        result[TransitionKey.ACTION] = padded_action
        result[TransitionKey.COMPLEMENTARY_DATA] = extras
        return result

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Declare the action feature with the padded ``max_action_dim`` shape."""
        copied = {bucket: entries.copy() for bucket, entries in features.items()}
        # NOTE(review): the bucket key is FeatureType.ACTION although the mapping is
        # annotated with PipelineFeatureType keys - confirm both resolve to one key.
        action_bucket = copied.setdefault(FeatureType.ACTION, {})
        action_bucket[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(self.max_action_dim,))
        return copied

    def get_config(self) -> dict[str, Any]:
        """Return the serializable configuration of this step."""
        return {"max_action_dim": self.max_action_dim}
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register(name="evo1_action_processor")
class Evo1ActionProcessorStep(PolicyActionProcessorStep):
    """Crop an action to its first ``action_dim`` entries and optionally binarize one channel.

    With ``binarize_gripper`` enabled, the channel at ``gripper_index`` becomes
    ``gripper_above_threshold_value`` where it is strictly greater than
    ``gripper_threshold`` and ``gripper_below_threshold_value`` elsewhere.
    """

    action_dim: int
    binarize_gripper: bool = False
    gripper_index: int = 6
    gripper_threshold: float = 0.5
    gripper_below_threshold_value: float = 1.0
    gripper_above_threshold_value: float = -1.0

    def action(self, action: PolicyAction) -> PolicyAction:
        """Return the cropped (and, if enabled, gripper-binarized) action.

        Raises:
            ValueError: If the action has fewer than ``action_dim`` trailing
                entries, or binarization is enabled and ``gripper_index`` is
                outside ``[0, action_dim)``.
        """
        if action.shape[-1] < self.action_dim:
            raise ValueError(
                f"EVO1 action has {action.shape[-1]} dims, which is smaller than action_dim={self.action_dim}."
            )

        cropped = action[..., : self.action_dim]
        if self.binarize_gripper:
            if self.gripper_index < 0 or self.gripper_index >= self.action_dim:
                raise ValueError(
                    f"gripper_index={self.gripper_index} must be within action_dim={self.action_dim}."
                )

            # Clone first: the crop is a view and the input must stay untouched.
            cropped = cropped.clone()
            tensor_kwargs = {"dtype": cropped.dtype, "device": cropped.device}
            low_side = torch.as_tensor(self.gripper_below_threshold_value, **tensor_kwargs)
            high_side = torch.as_tensor(self.gripper_above_threshold_value, **tensor_kwargs)
            exceeds = cropped[..., self.gripper_index] > self.gripper_threshold
            cropped[..., self.gripper_index] = torch.where(exceeds, high_side, low_side)
        return cropped

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Declare the action feature with the cropped ``action_dim`` shape."""
        copied = {bucket: entries.copy() for bucket, entries in features.items()}
        # NOTE(review): the bucket key is FeatureType.ACTION although the mapping is
        # annotated with PipelineFeatureType keys - confirm both resolve to one key.
        action_bucket = copied.setdefault(FeatureType.ACTION, {})
        action_bucket[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(self.action_dim,))
        return copied

    def get_config(self) -> dict[str, Any]:
        """Return the serializable configuration of this step."""
        field_names = (
            "action_dim",
            "binarize_gripper",
            "gripper_index",
            "gripper_threshold",
            "gripper_below_threshold_value",
            "gripper_above_threshold_value",
        )
        return {field_name: getattr(self, field_name) for field_name in field_names}
|
||||
|
||||
|
||||
def _evo1_action_dim(config: Evo1Config) -> int:
    """Pick the action width returned after EVO1 postprocessing.

    Precedence: ``config.postprocess_action_dim`` when set, else the first
    shape entry of ``config.action_feature``, else ``config.max_action_dim``.
    """
    override = config.postprocess_action_dim
    if override is not None:
        return override
    feature = config.action_feature
    return config.max_action_dim if feature is None else int(feature.shape[0])
|
||||
|
||||
|
||||
def _evo1_normalization_features(config: Evo1Config) -> dict[str, PolicyFeature]:
    """Merge input and output features, overriding state and action with padded shapes.

    ``OBS_STATE`` is declared with ``config.max_state_dim`` entries and
    ``ACTION`` with ``config.max_action_dim`` entries.
    """
    return {
        **config.input_features,
        **config.output_features,
        OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(config.max_state_dim,)),
        ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(config.max_action_dim,)),
    }
|
||||
|
||||
|
||||
def _evo1_action_features(config: Evo1Config) -> dict[str, PolicyFeature]:
    """Return a single-entry feature map describing the padded action."""
    padded_action = PolicyFeature(type=FeatureType.ACTION, shape=(config.max_action_dim,))
    return {ACTION: padded_action}
|
||||
|
||||
|
||||
# Fill value appended to each statistic for padded dimensions; statistics that
# are not listed here are padded with 0.0.
_STAT_PAD_VALUES = dict(
    mean=0.0,
    std=1.0,
    min=-1.0,
    max=1.0,
    q01=-1.0,
    q99=1.0,
    q10=-1.0,
    q90=1.0,
)


def _pad_stat_value(value: Any, target_dim: int, stat_name: str) -> torch.Tensor:
    """Extend a statistic along its last axis to ``target_dim`` entries.

    Args:
        value: Anything ``torch.as_tensor`` accepts; non-floating input is
            converted to ``float32``.
        target_dim: Desired size of the last axis.
        stat_name: Statistic name used to pick the fill value.

    Returns:
        The tensor unchanged when it is a scalar or already at least
        ``target_dim`` wide, otherwise the tensor with fill values appended.
    """
    # NOTE(review): every 1-D entry narrower than target_dim is padded, whatever
    # its name - confirm no per-feature scalar statistic is passed through here.
    stat = torch.as_tensor(value)
    if not stat.is_floating_point():
        stat = stat.to(dtype=torch.float32)

    current_dim = stat.shape[-1] if stat.ndim > 0 else None
    if current_dim is None or current_dim >= target_dim:
        return stat

    fill_value = _STAT_PAD_VALUES.get(stat_name, 0.0)
    # new_full inherits dtype and device from the statistic being extended.
    tail = stat.new_full((*stat.shape[:-1], target_dim - current_dim), fill_value)
    return torch.cat((stat, tail), dim=-1)
|
||||
|
||||
|
||||
def _pad_feature_stats(
    stats: dict[str, dict[str, Any]],
    feature_key: str,
    target_dim: int,
) -> None:
    """Pad every statistic of one feature in place; do nothing if the feature is absent."""
    if feature_key in stats:
        padded_entries = {}
        for stat_name, stat_value in stats[feature_key].items():
            padded_entries[stat_name] = _pad_stat_value(stat_value, target_dim, stat_name)
        stats[feature_key] = padded_entries
|
||||
|
||||
|
||||
def _pad_evo1_stats(
    config: Evo1Config,
    stats: dict[str, dict[str, Any]] | None,
) -> dict[str, dict[str, Any]] | None:
    """Return a deep copy of ``stats`` with state and action statistics padded.

    ``None`` is passed through. The caller's ``stats`` object is never mutated.
    """
    if stats is None:
        return None

    padded_stats = deepcopy(stats)
    # The extra dimensions correspond to zero-padding inside EVO1; the neutral
    # fill values keep padded observations at normalized zero and exist only
    # for shape compatibility.
    padding_targets = (
        (OBS_STATE, config.max_state_dim),
        (ACTION, config.max_action_dim),
    )
    for feature_key, target_dim in padding_targets:
        _pad_feature_stats(padded_stats, feature_key, target_dim)
    return padded_stats
|
||||
|
||||
|
||||
def _refresh_evo1_normalization_steps(
    config: Evo1Config,
    preprocessor: PolicyProcessorPipeline,
    postprocessor: PolicyProcessorPipeline,
) -> None:
    """Point existing (un)normalizer steps at padded features and padded stats.

    Normalizer steps in the preprocessor get the merged padded feature map.
    Unnormalizer steps in the postprocessor get the padded action feature only.
    Each updated step is then re-applied to its own device and dtype.
    """
    refresh_plan = (
        (preprocessor, NormalizerProcessorStep, _evo1_normalization_features(config)),
        (postprocessor, UnnormalizerProcessorStep, _evo1_action_features(config)),
    )
    for pipeline, step_type, padded_features in refresh_plan:
        for step in pipeline.steps:
            if not isinstance(step, step_type):
                continue
            step.features = padded_features
            step.stats = _pad_evo1_stats(config, step.stats)
            step.to(device=step.device, dtype=step.dtype)
|
||||
|
||||
|
||||
def ensure_evo1_processor_steps(
    config: Evo1Config,
    preprocessor: PolicyProcessorPipeline,
    postprocessor: PolicyProcessorPipeline,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
    """Insert missing EVO1 steps into pipelines loaded from older checkpoints.

    State and action padding steps go before the first normalizer, or at the
    end when there is none. The action postprocessing step goes right after the
    first unnormalizer, or first when there is none. Steps already present are
    left alone, and normalization steps are refreshed afterwards.

    Returns:
        The same two pipeline objects, updated in place.
    """

    def _has_step(pipeline: PolicyProcessorPipeline, step_type: type) -> bool:
        return any(isinstance(step, step_type) for step in pipeline.steps)

    def _position(steps: list, step_type: type, fallback: int, offset: int = 0) -> int:
        for idx, step in enumerate(steps):
            if isinstance(step, step_type):
                return idx + offset
        return fallback

    # State padding is handled first, so the action padding step lands between
    # it and the normalizer.
    padding_builders = (
        (Evo1PadStateProcessorStep, lambda: Evo1PadStateProcessorStep(max_state_dim=config.max_state_dim)),
        (Evo1PadActionProcessorStep, lambda: Evo1PadActionProcessorStep(max_action_dim=config.max_action_dim)),
    )
    for step_type, build_step in padding_builders:
        if _has_step(preprocessor, step_type):
            continue
        pre_steps = list(preprocessor.steps)
        pre_steps.insert(_position(pre_steps, NormalizerProcessorStep, len(pre_steps)), build_step())
        preprocessor.steps = pre_steps

    if not _has_step(postprocessor, Evo1ActionProcessorStep):
        post_steps = list(postprocessor.steps)
        target_idx = _position(post_steps, UnnormalizerProcessorStep, 0, offset=1)
        action_step = Evo1ActionProcessorStep(
            action_dim=_evo1_action_dim(config),
            binarize_gripper=config.binarize_gripper,
            gripper_index=config.gripper_index,
            gripper_threshold=config.gripper_threshold,
            gripper_below_threshold_value=config.gripper_below_threshold_value,
            gripper_above_threshold_value=config.gripper_above_threshold_value,
        )
        post_steps.insert(target_idx, action_step)
        postprocessor.steps = post_steps

    _refresh_evo1_normalization_steps(config, preprocessor, postprocessor)
    return preprocessor, postprocessor
|
||||
|
||||
|
||||
def make_evo1_pre_post_processors(
|
||||
config: Evo1Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
@@ -72,21 +380,35 @@ def make_evo1_pre_post_processors(
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
normalization_features = _evo1_normalization_features(config)
|
||||
action_features = _evo1_action_features(config)
|
||||
normalization_stats = _pad_evo1_stats(config, dataset_stats)
|
||||
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
Evo1PadStateProcessorStep(max_state_dim=config.max_state_dim),
|
||||
Evo1PadActionProcessorStep(max_action_dim=config.max_action_dim),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
features=normalization_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
stats=normalization_stats,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
features=action_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
stats=normalization_stats,
|
||||
),
|
||||
Evo1ActionProcessorStep(
|
||||
action_dim=_evo1_action_dim(config),
|
||||
binarize_gripper=config.binarize_gripper,
|
||||
gripper_index=config.gripper_index,
|
||||
gripper_threshold=config.gripper_threshold,
|
||||
gripper_below_threshold_value=config.gripper_below_threshold_value,
|
||||
gripper_above_threshold_value=config.gripper_above_threshold_value,
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
@@ -311,6 +311,14 @@ def make_pre_post_processors(
|
||||
to_output=transition_to_policy_action,
|
||||
)
|
||||
_reconnect_relative_absolute_steps(preprocessor, postprocessor)
|
||||
if isinstance(policy_cfg, Evo1Config):
|
||||
from .evo1.processor_evo1 import ensure_evo1_processor_steps
|
||||
|
||||
preprocessor, postprocessor = ensure_evo1_processor_steps(
|
||||
policy_cfg,
|
||||
preprocessor,
|
||||
postprocessor,
|
||||
)
|
||||
return preprocessor, postprocessor
|
||||
|
||||
# Create a new processor based on policy type
|
||||
|
||||
@@ -40,7 +40,7 @@ from .converters import (
|
||||
)
|
||||
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
|
||||
from .device_processor import DeviceProcessorStep
|
||||
from .env_processor import IsaaclabArenaProcessorStep, LiberoActionProcessorStep, LiberoProcessorStep
|
||||
from .env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep
|
||||
from .factory import (
|
||||
make_default_processors,
|
||||
make_default_robot_action_processor,
|
||||
@@ -149,7 +149,6 @@ __all__ = [
|
||||
"RewardProcessorStep",
|
||||
"DataProcessorPipeline",
|
||||
"IsaaclabArenaProcessorStep",
|
||||
"LiberoActionProcessorStep",
|
||||
"LiberoProcessorStep",
|
||||
"TimeLimitProcessorStep",
|
||||
"AddBatchDimensionProcessorStep",
|
||||
|
||||
@@ -18,9 +18,9 @@ from dataclasses import dataclass
|
||||
import torch
|
||||
|
||||
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR
|
||||
|
||||
from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -46,8 +46,6 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
- This accounts for the HuggingFaceVLA/libero camera orientation convention.
|
||||
"""
|
||||
|
||||
max_state_dim: int | None = None
|
||||
|
||||
def _process_observation(self, observation):
|
||||
"""
|
||||
Processes both image and robot_state observations from LIBERO.
|
||||
@@ -80,16 +78,6 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
state = state.float()
|
||||
if state.dim() == 1:
|
||||
state = state.unsqueeze(0)
|
||||
if self.max_state_dim is not None:
|
||||
if state.shape[-1] > self.max_state_dim:
|
||||
raise ValueError(
|
||||
f"LIBERO state has {state.shape[-1]} dims, which is larger than "
|
||||
f"configured max_state_dim={self.max_state_dim}."
|
||||
)
|
||||
if state.shape[-1] < self.max_state_dim:
|
||||
pad_width = self.max_state_dim - state.shape[-1]
|
||||
state = torch.nn.functional.pad(state, (0, pad_width))
|
||||
|
||||
processed_obs[OBS_STATE] = state
|
||||
return processed_obs
|
||||
|
||||
@@ -112,7 +100,7 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
# add our new flattened state
|
||||
state_feats[OBS_STATE] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim or 8,), # [eef_pos(3), axis_angle(3), gripper(2)] plus padding
|
||||
shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)]
|
||||
)
|
||||
|
||||
new_features[FeatureType.STATE] = state_feats
|
||||
@@ -122,9 +110,6 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
def observation(self, observation):
|
||||
return self._process_observation(observation)
|
||||
|
||||
def get_config(self) -> dict:
|
||||
return {"max_state_dim": self.max_state_dim}
|
||||
|
||||
def _quat2axisangle(self, quat: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Convert batched quaternions to axis-angle format.
|
||||
@@ -167,32 +152,6 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="libero_action_processor")
|
||||
class LiberoActionProcessorStep(ActionProcessorStep):
|
||||
"""Slices padded policy actions back to the executable LIBERO action space."""
|
||||
|
||||
action_dim: int = 7
|
||||
|
||||
def action(self, action):
|
||||
if action.shape[-1] < self.action_dim:
|
||||
raise ValueError(
|
||||
f"LIBERO action has {action.shape[-1]} dims, which is smaller than action_dim={self.action_dim}."
|
||||
)
|
||||
return action[..., : self.action_dim]
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
new_features = {ft: feats.copy() for ft, feats in features.items()}
|
||||
action_feats = new_features.setdefault(FeatureType.ACTION, {})
|
||||
action_feats[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(self.action_dim,))
|
||||
return new_features
|
||||
|
||||
def get_config(self) -> dict:
|
||||
return {"action_dim": self.action_dim}
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="isaaclab_arena_processor")
|
||||
class IsaaclabArenaProcessorStep(ObservationProcessorStep):
|
||||
|
||||
@@ -191,7 +191,7 @@ def rollout(
|
||||
action = action_transition[ACTION]
|
||||
|
||||
# Convert to CPU / numpy.
|
||||
action_numpy: np.ndarray = action.to("cpu").numpy()
|
||||
action_numpy = _action_to_env_numpy(action)
|
||||
assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)"
|
||||
|
||||
# Apply the next action.
|
||||
@@ -261,6 +261,11 @@ def rollout(
|
||||
return ret
|
||||
|
||||
|
||||
def _action_to_env_numpy(action: Tensor) -> np.ndarray:
    """Convert a policy action tensor to a float32 NumPy array on the CPU.

    The tensor is detached first, so tensors that require grad are accepted.
    """
    cpu_action = action.detach().to(device="cpu", dtype=torch.float32)
    return cpu_action.numpy()
|
||||
|
||||
|
||||
def eval_policy(
|
||||
env: gym.vector.VectorEnv,
|
||||
policy: PreTrainedPolicy,
|
||||
|
||||
+10
-37
@@ -13,7 +13,7 @@ from gymnasium.envs.registration import register, registry as gym_registry
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.envs.configs import EnvConfig, LiberoEnv
|
||||
from lerobot.envs.factory import make_env, make_env_config, make_env_pre_post_processors
|
||||
from lerobot.processor import LiberoActionProcessorStep, LiberoProcessorStep
|
||||
from lerobot.processor import LiberoProcessorStep
|
||||
from lerobot.utils.constants import OBS_PREFIX, OBS_STATE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -86,32 +86,18 @@ def test_processors_delegation_supports_legacy_override_signature():
|
||||
assert isinstance(post, DataProcessorPipeline)
|
||||
|
||||
|
||||
def test_libero_evo1_processors_use_padded_state_and_env_action_dim():
|
||||
"""EVO1 uses padded LIBERO state features while env actions stay executable."""
|
||||
|
||||
class _Evo1Config:
|
||||
type = "evo1"
|
||||
max_state_dim = 24
|
||||
|
||||
def test_libero_processors_are_policy_agnostic():
|
||||
cfg = LiberoEnv()
|
||||
pre, post = make_env_pre_post_processors(cfg, policy_cfg=_Evo1Config())
|
||||
pre, post = make_env_pre_post_processors(cfg, policy_cfg=object())
|
||||
|
||||
assert isinstance(pre.steps[0], LiberoProcessorStep)
|
||||
assert pre.steps[0].max_state_dim == 24
|
||||
assert isinstance(post.steps[0], LiberoActionProcessorStep)
|
||||
assert post.steps[0].action_dim == cfg.features["action"].shape[0] == 7
|
||||
|
||||
class _OtherConfig:
|
||||
type = "other"
|
||||
|
||||
pre_other, _ = make_env_pre_post_processors(cfg, policy_cfg=_OtherConfig())
|
||||
assert pre_other.steps[0].max_state_dim is None
|
||||
assert len(post.steps) == 0
|
||||
|
||||
|
||||
def test_libero_processor_pads_state_to_max_dim():
|
||||
step = LiberoProcessorStep(max_state_dim=24)
|
||||
def test_libero_processor_flattens_state_to_raw_8_dim():
|
||||
step = LiberoProcessorStep()
|
||||
observation = {
|
||||
OBS_PREFIX
|
||||
+ "robot_state": {
|
||||
OBS_PREFIX + "robot_state": {
|
||||
"eef": {
|
||||
"pos": torch.tensor([[1.0, 2.0, 3.0]]),
|
||||
"quat": torch.tensor([[0.0, 0.0, 0.0, 1.0]]),
|
||||
@@ -121,21 +107,8 @@ def test_libero_processor_pads_state_to_max_dim():
|
||||
}
|
||||
|
||||
state = step.observation(observation)[OBS_STATE]
|
||||
assert state.shape == (1, 24)
|
||||
assert torch.allclose(state[:, :8], torch.tensor([[1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 4.0, 5.0]]))
|
||||
assert torch.count_nonzero(state[:, 8:]).item() == 0
|
||||
|
||||
|
||||
def test_libero_action_processor_slices_padded_action():
|
||||
step = LiberoActionProcessorStep(action_dim=7)
|
||||
action = torch.arange(2 * 3 * 24, dtype=torch.float32).reshape(2, 3, 24)
|
||||
|
||||
sliced = step.action(action)
|
||||
assert sliced.shape == (2, 3, 7)
|
||||
assert torch.equal(sliced, action[..., :7])
|
||||
|
||||
with pytest.raises(ValueError, match="smaller than action_dim=7"):
|
||||
step.action(torch.zeros(2, 6))
|
||||
assert state.shape == (1, 8)
|
||||
assert torch.allclose(state, torch.tensor([[1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 4.0, 5.0]]))
|
||||
|
||||
|
||||
def test_base_create_envs():
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@@ -23,7 +24,15 @@ import lerobot.policies.evo1.modeling_evo1 as modeling_evo1
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
|
||||
from lerobot.policies.evo1.flow_matching import FlowmatchingActionHead
|
||||
from lerobot.policies.evo1.processor_evo1 import (
|
||||
Evo1ActionProcessorStep,
|
||||
Evo1PadActionProcessorStep,
|
||||
Evo1PadStateProcessorStep,
|
||||
ensure_evo1_processor_steps,
|
||||
make_evo1_pre_post_processors,
|
||||
)
|
||||
from lerobot.policies.factory import get_policy_class, make_policy_config
|
||||
from lerobot.processor import NormalizerProcessorStep, PolicyProcessorPipeline, UnnormalizerProcessorStep
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
STATE_DIM = 4
|
||||
@@ -108,6 +117,19 @@ def make_batch(include_action=True):
|
||||
return batch
|
||||
|
||||
|
||||
def make_stats(state_dim=STATE_DIM, action_dim=ACTION_DIM):
    """Build min/max dataset stats: state in [-2, 2], action in [-1, 1]."""

    def _bounds(dim, low, high):
        return {
            "min": torch.full((dim,), low),
            "max": torch.full((dim,), high),
        }

    return {
        OBS_STATE: _bounds(state_dim, -2.0, 2.0),
        ACTION: _bounds(action_dim, -1.0, 1.0),
    }
|
||||
|
||||
|
||||
def test_evo1_factory_registration():
|
||||
cfg = make_policy_config(
|
||||
"evo1",
|
||||
@@ -191,22 +213,151 @@ def test_evo1_stage_defaults_and_consistency():
|
||||
raise AssertionError("Expected inconsistent finetune config to raise ValueError")
|
||||
|
||||
|
||||
def test_evo1_rejects_non_square_image_resolution():
    """A non-square ``image_resolution`` must be rejected when the config is built."""
    non_square_resolution = (448, 320)
    with pytest.raises(ValueError, match="square image_resolution"):
        make_config(image_resolution=non_square_resolution)
|
||||
|
||||
|
||||
def test_evo1_build_model_config_uses_image_resolution_and_trainable_checkpointing():
    """Model config takes image size from the resolution; checkpointing is off in stage1, on in stage2."""
    build_model_config = modeling_evo1.EVO1Policy._build_model_config
    resolution = (224, 224)

    stage1_model_config = build_model_config(make_config(training_stage="stage1", image_resolution=resolution))
    assert stage1_model_config["image_size"] == 224
    assert stage1_model_config["enable_gradient_checkpointing"] is False

    stage2_model_config = build_model_config(make_config(training_stage="stage2", image_resolution=resolution))
    assert stage2_model_config["enable_gradient_checkpointing"] is True
|
||||
|
||||
|
||||
def test_evo1_policy_processors_pad_state_crop_action_and_binarize_gripper():
    """Freshly built EVO1 processors pad inputs, mask actions, crop outputs and binarize the gripper."""
    libero_action_dim = 7
    padded_action_dim = 8
    libero_action_feature = PolicyFeature(type=FeatureType.ACTION, shape=(libero_action_dim,))
    config = make_config(
        max_state_dim=MAX_STATE_DIM,
        max_action_dim=padded_action_dim,
        postprocess_action_dim=libero_action_dim,
        binarize_gripper=True,
        output_features={ACTION: libero_action_feature},
    )
    preprocessor, postprocessor = make_evo1_pre_post_processors(
        config,
        dataset_stats=make_stats(action_dim=libero_action_dim),
    )

    # Step layout: padding steps sit right before the normalizer, cropping right after the unnormalizer.
    expected_pre = (Evo1PadStateProcessorStep, Evo1PadActionProcessorStep, NormalizerProcessorStep)
    for idx, step_type in enumerate(expected_pre, start=2):
        assert isinstance(preprocessor.steps[idx], step_type)
    expected_post = (UnnormalizerProcessorStep, Evo1ActionProcessorStep)
    for idx, step_type in enumerate(expected_post):
        assert isinstance(postprocessor.steps[idx], step_type)

    normalizer = preprocessor.steps[4]
    assert normalizer.features[OBS_STATE].shape == (MAX_STATE_DIM,)
    assert normalizer.features[ACTION].shape == (padded_action_dim,)
    assert normalizer._tensor_stats[OBS_STATE]["min"].shape == (MAX_STATE_DIM,)
    assert normalizer._tensor_stats[ACTION]["min"].shape == (padded_action_dim,)

    raw_batch = {
        "task": "pick the block",
        OBS_STATE: torch.zeros(STATE_DIM),
        ACTION: torch.zeros(libero_action_dim),
        f"{OBS_IMAGES}.front": torch.rand(3, 16, 16),
    }
    processed_batch = preprocessor(raw_batch)

    processed_state = processed_batch[OBS_STATE]
    assert processed_state.shape == (1, MAX_STATE_DIM)
    assert torch.allclose(processed_state, torch.zeros_like(processed_state))

    processed_action = processed_batch[ACTION]
    assert processed_action.shape == (1, padded_action_dim)
    assert torch.allclose(processed_action, torch.zeros_like(processed_action))

    action_mask = processed_batch["action_mask"]
    assert action_mask.shape == (1, padded_action_dim)
    assert action_mask[:, :libero_action_dim].all()
    assert not action_mask[:, libero_action_dim:].any()

    # Gripper channel (index 6): 0.5 is not above the threshold, 0.6 is.
    action = torch.tensor(
        [
            [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.5, 0.7],
            [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7],
        ],
        dtype=torch.float32,
    )
    processed = postprocessor(action)

    assert processed.shape == (2, 7)
    assert torch.allclose(processed[:, :6], action[:, :6])
    assert torch.equal(processed[:, 6], torch.tensor([1.0, -1.0]))
|
||||
|
||||
|
||||
def test_evo1_legacy_processors_are_completed_before_normalization():
    """Pipelines holding only (un)normalizers gain the EVO1 steps exactly once, with padded stats."""
    padded_action_dim = 8
    config = make_config(
        max_state_dim=MAX_STATE_DIM,
        max_action_dim=padded_action_dim,
        postprocess_action_dim=7,
        binarize_gripper=True,
    )
    stats = make_stats(action_dim=7)
    legacy_normalizer = NormalizerProcessorStep(
        features={**config.input_features, **config.output_features},
        norm_map=config.normalization_mapping,
        stats=stats,
    )
    legacy_unnormalizer = UnnormalizerProcessorStep(
        features=config.output_features,
        norm_map=config.normalization_mapping,
        stats=stats,
    )
    legacy_pre = PolicyProcessorPipeline(steps=[legacy_normalizer])
    legacy_post = PolicyProcessorPipeline(steps=[legacy_unnormalizer])

    preprocessor, postprocessor = ensure_evo1_processor_steps(config, legacy_pre, legacy_post)

    expected_pre = (Evo1PadStateProcessorStep, Evo1PadActionProcessorStep, NormalizerProcessorStep)
    for idx, step_type in enumerate(expected_pre):
        assert isinstance(preprocessor.steps[idx], step_type)
    expected_post = (UnnormalizerProcessorStep, Evo1ActionProcessorStep)
    for idx, step_type in enumerate(expected_post):
        assert isinstance(postprocessor.steps[idx], step_type)

    action_step = postprocessor.steps[1]
    assert action_step.action_dim == 7
    assert action_step.binarize_gripper is True

    normalizer = preprocessor.steps[2]
    assert normalizer.features[OBS_STATE].shape == (MAX_STATE_DIM,)
    assert normalizer._tensor_stats[OBS_STATE]["min"].shape == (MAX_STATE_DIM,)
    assert normalizer._tensor_stats[ACTION]["min"].shape == (padded_action_dim,)

    unnormalizer = postprocessor.steps[0]
    assert unnormalizer.features[ACTION].shape == (padded_action_dim,)
    assert unnormalizer._tensor_stats[ACTION]["min"].shape == (padded_action_dim,)

    # A second pass must not insert any of the steps again.
    preprocessor, postprocessor = ensure_evo1_processor_steps(config, preprocessor, postprocessor)

    def _count(pipeline, step_type):
        return sum(isinstance(step, step_type) for step in pipeline.steps)

    assert _count(preprocessor, Evo1PadStateProcessorStep) == 1
    assert _count(preprocessor, Evo1PadActionProcessorStep) == 1
    assert _count(postprocessor, Evo1ActionProcessorStep) == 1
|
||||
|
||||
|
||||
def test_evo1_policy_forward_and_inference_use_batched_embedding(monkeypatch):
    """Exercise EVO1Policy training and inference paths against the dummy backbone.

    The test checks three things:

    * the EVO1 preprocessor pads actions to ``MAX_ACTION_DIM`` and emits an
      ``action_mask`` that is set only on the first ``ACTION_DIM`` channels;
    * ``forward`` on the preprocessed batch yields a finite scalar loss, reports
      ``ACTION_DIM * CHUNK_SIZE`` active action dims and triggers exactly one
      ``get_vl_embeddings`` call on the model;
    * ``predict_action_chunk`` and ``select_action`` return actions at the padded
      width ``MAX_ACTION_DIM`` (no cropping is asserted at the policy level).
    """
    # Swap the real backbone for the lightweight test double before the policy is built.
    monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1)
    policy = modeling_evo1.EVO1Policy(make_config())
    preprocessor, _postprocessor = make_evo1_pre_post_processors(policy.config, dataset_stats=make_stats())
    training_batch = preprocessor(make_batch(include_action=True))

    # Padding and masking are produced by the preprocessor, not by the policy.
    assert training_batch[ACTION].shape == (2, CHUNK_SIZE, MAX_ACTION_DIM)
    assert training_batch["action_mask"].shape == (2, CHUNK_SIZE, MAX_ACTION_DIM)
    assert training_batch["action_mask"][:, :, :ACTION_DIM].all()
    assert not training_batch["action_mask"][:, :, ACTION_DIM:].any()

    # forward() must run exactly once before the call counter is checked.
    # NOTE(review): presumably DummyEVO1 increments get_vl_embeddings_calls per
    # forward pass - confirm against the test double.
    loss, metrics = policy.forward(training_batch)
    assert loss.ndim == 0
    assert torch.isfinite(loss)
    assert metrics["active_action_dims"] == ACTION_DIM * CHUNK_SIZE
    assert policy.model.get_vl_embeddings_calls == 1

    action_chunk = policy.predict_action_chunk(make_batch(include_action=False))
    assert action_chunk.shape == (2, CHUNK_SIZE, MAX_ACTION_DIM)

    # Clear any queued actions so select_action starts from a fresh chunk.
    policy.reset()
    selected = policy.select_action(make_batch(include_action=False))
    assert selected.shape == (2, MAX_ACTION_DIM)
|
||||
|
||||
|
||||
def test_stage1_frozen_vlm_embeddings_do_not_track_gradients(monkeypatch):
|
||||
@@ -220,7 +371,7 @@ def test_stage1_frozen_vlm_embeddings_do_not_track_gradients(monkeypatch):
|
||||
assert policy.model.grad_enabled_calls == [False]
|
||||
assert policy.model.embedder_training_calls == [False]
|
||||
assert not fused_tokens.requires_grad
|
||||
assert policy.model.embedder.training is True
|
||||
assert policy.model.embedder.training is False
|
||||
|
||||
|
||||
def test_stage2_vlm_embeddings_track_gradients(monkeypatch):
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.scripts.lerobot_eval import _action_to_env_numpy
|
||||
|
||||
|
||||
def test_action_to_env_numpy_casts_bfloat16_to_float32():
    """Verify that a bfloat16 action tensor is handed to the env as float32 NumPy data.

    The helper's output must keep the input shape, have dtype ``np.float32`` and
    match the original values.
    """
    # 0.5 and -1.0 are exactly representable in bfloat16, so no rounding error
    # is introduced by building the input in reduced precision.
    expected = np.array([[0.5, -1.0]], dtype=np.float32)
    bf16_action = torch.tensor([[0.5, -1.0]], dtype=torch.bfloat16)

    converted = _action_to_env_numpy(bf16_action)

    assert converted.shape == (1, 2)
    assert converted.dtype == np.float32
    np.testing.assert_allclose(converted, expected)
|
||||
Reference in New Issue
Block a user