diff --git a/docs/source/evo1.mdx b/docs/source/evo1.mdx index d768103f6..eda8d65fa 100644 --- a/docs/source/evo1.mdx +++ b/docs/source/evo1.mdx @@ -139,6 +139,8 @@ every finetuning flag. | `policy.n_action_steps` | `50` | Number of actions consumed from a sampled chunk | | `policy.max_state_dim` | `24` | State padding dimension | | `policy.max_action_dim` | `24` | Action padding dimension | +| `policy.postprocess_action_dim` | `null` | Optional action dimension returned after EVO1 postprocessing | +| `policy.binarize_gripper` | `false` | Binarizes the postprocessed gripper channel for LIBERO-style eval | | `policy.task_field` | `task` | Batch field used as the language prompt | ## Results @@ -161,16 +163,20 @@ pixel embeddings, VLM fused tokens, normalized actions, and denormalized actions The published checkpoint expects the raw LIBERO camera feature names `observation.images.agentview_image` and `observation.images.robot0_eye_in_hand_image`. The official EVO1 LIBERO rollout protocol also replans every 14 actions and binarizes the gripper command before stepping the simulator. -The LIBERO environment postprocessor applies the gripper binarization automatically for EVO1 policies. To run the -converted checkpoint with LeRobot LIBERO evaluation for the same one-episode-per-task setting, keep the raw camera -names instead of the default `image`/`image2` mapping and override `policy.n_action_steps` to 14: +The EVO1 policy postprocessor can crop the padded 24D action back to the 7D LIBERO action space and apply that +gripper binarization. To run the converted checkpoint with LeRobot LIBERO evaluation for the same +one-episode-per-task setting, keep the raw camera names instead of the default `image`/`image2` mapping, enable +FlashAttention, and set the LIBERO action postprocessing flags: ```bash lerobot-eval \ --policy.path=javadcc/evo1-libero-lerobot \ --policy.vlm_model_name=OpenGVLab/InternVL3-1B \ --policy.device=cuda \ + --policy.use_flash_attn=true \ --policy.n_action_steps=14 \ + --policy.postprocess_action_dim=7 \ + --policy.binarize_gripper=true \ --env.type=libero \ --env.task=libero_object \ --env.camera_name_mapping="{agentview_image: agentview_image, robot0_eye_in_hand_image: robot0_eye_in_hand_image}" \ diff --git a/pyproject.toml b/pyproject.toml index dceece62a..b6533cf22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -140,6 +140,7 @@ pyserial-dep = ["pyserial>=3.5,<4.0"] deepdiff-dep = ["deepdiff>=7.0.1,<9.0.0"] pynput-dep = ["pynput>=1.7.8,<1.9.0"] pyzmq-dep = ["pyzmq>=26.2.1,<28.0.0"] +timm-dep = ["timm>=1.0.0,<1.1.0"] # Motors feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0", "lerobot[pyserial-dep]", "lerobot[deepdiff-dep]"] @@ -187,7 +188,7 @@ groot = [ "lerobot[peft-dep]", "lerobot[diffusers-dep]", "dm-tree>=0.1.8,<1.0.0", - "timm>=1.0.0,<1.1.0", + "lerobot[timm-dep]", "decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "ninja>=1.11.1,<2.0.0", "flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'" @@ -195,7 +196,7 @@ groot = [ sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"] xvla = ["lerobot[transformers-dep]"] eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"] -evo1 = ["lerobot[transformers-dep]", "timm>=1.0.0,<1.1.0"] +evo1 = ["lerobot[transformers-dep]", "lerobot[timm-dep]"] hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] # Features @@ -350,7 +351,6 @@ ignore = [ # E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect "src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"] "src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original -"src/lerobot/policies/evo1/**" = ["N801", "N812"] [tool.ruff.lint.isort] combine-as-imports = true diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index 0807641f4..e53f668a3 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -26,7 +26,6 @@ from gymnasium.envs.registration import registry as gym_registry from lerobot.configs import FeatureType, PolicyFeature from lerobot.processor import ( IsaaclabArenaProcessorStep, - LiberoActionProcessorStep, LiberoProcessorStep, PolicyProcessorPipeline, ) @@ -128,7 +127,7 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC): vec = env_cls([_make_one for _ in range(n_envs)], **extra_kwargs) return {self.type: {0: vec}} - def get_env_processors(self, policy_cfg: Any | None = None): + def get_env_processors(self): """Return (preprocessor, postprocessor) for this env. Default: identity.""" return PolicyProcessorPipeline(steps=[]), PolicyProcessorPipeline(steps=[]) @@ -357,7 +356,6 @@ class LiberoEnv(EnvConfig): } ) control_mode: str = "relative" # or "absolute" - binarize_gripper: bool | None = None def __post_init__(self): if self.obs_type == "pixels": @@ -442,22 +440,10 @@ class LiberoEnv(EnvConfig): is_libero_plus=self.is_libero_plus, ) - def get_env_processors(self, policy_cfg: Any | None = None): - is_evo1 = getattr(policy_cfg, "type", None) == "evo1" - max_state_dim = getattr(policy_cfg, "max_state_dim", None) if is_evo1 else None - action_feature = self.features.get(ACTION) - action_dim = int(action_feature.shape[0]) if action_feature is not None else 7 - binarize_gripper = is_evo1 if self.binarize_gripper is None else self.binarize_gripper + def get_env_processors(self): return ( - PolicyProcessorPipeline(steps=[LiberoProcessorStep(max_state_dim=max_state_dim)]), - PolicyProcessorPipeline( - steps=[ - LiberoActionProcessorStep( - action_dim=action_dim, - binarize_gripper=binarize_gripper, - ) - ] - ), + PolicyProcessorPipeline(steps=[LiberoProcessorStep()]), + PolicyProcessorPipeline(steps=[]), ) @@ -723,7 +709,7 @@ class IsaaclabArenaEnv(HubEnvConfig): def gym_kwargs(self) -> dict: return {} - def get_env_processors(self, policy_cfg: Any | None = None): + def get_env_processors(self): state_keys = tuple(k.strip() for k in (self.state_keys or "").split(",") if k.strip()) camera_keys = tuple(k.strip() for k in (self.camera_keys or "").split(",") if k.strip()) if not state_keys and not camera_keys: diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py index 22f43f4cd..317cf2e6f 100644 --- a/src/lerobot/envs/factory.py +++ b/src/lerobot/envs/factory.py @@ -15,7 +15,6 @@ # limitations under the License. from __future__ import annotations -import inspect from typing import Any import gymnasium as gym @@ -53,14 +52,7 @@ def make_env_pre_post_processors( return make_xvla_libero_pre_post_processors() - get_processors = env_cfg.get_env_processors - signature = inspect.signature(get_processors) - supports_policy_cfg = "policy_cfg" in signature.parameters or any( - param.kind is inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values() - ) - if supports_policy_cfg: - return get_processors(policy_cfg=policy_cfg) - return get_processors() + return env_cfg.get_env_processors() def make_env( diff --git a/src/lerobot/policies/evo1/README.md b/src/lerobot/policies/evo1/README.md deleted file mode 120000 index 6c4284fb9..000000000 --- a/src/lerobot/policies/evo1/README.md +++ /dev/null @@ -1 +0,0 @@ -../../../../docs/source/policy_evo1_README.md \ No newline at end of file diff --git a/src/lerobot/policies/evo1/README.md b/src/lerobot/policies/evo1/README.md new file mode 100644 index 000000000..3c6d31c83 --- /dev/null +++ b/src/lerobot/policies/evo1/README.md @@ -0,0 +1,18 @@ +# EVO1 + +EVO1 is a Vision-Language-Action policy for robot control. The LeRobot +integration uses an InternVL3 vision-language backbone with a flow-matching +action head, and supports staged training through the standard LeRobot policy +APIs. + +The upstream EVO1 project is available at +[MINT-SJTU/Evo-1](https://github.com/MINT-SJTU/Evo-1). + +```bibtex +@misc{evo1, + title = {EVO1}, + author = {{MINT-SJTU}}, + year = {2026}, + howpublished = {\url{https://github.com/MINT-SJTU/Evo-1}}, +} +``` diff --git a/src/lerobot/policies/evo1/configuration_evo1.py b/src/lerobot/policies/evo1/configuration_evo1.py index 6804535d0..b7dd72a95 100644 --- a/src/lerobot/policies/evo1/configuration_evo1.py +++ b/src/lerobot/policies/evo1/configuration_evo1.py @@ -14,6 +14,7 @@ from __future__ import annotations +import logging import math from dataclasses import dataclass, field @@ -26,6 +27,8 @@ from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import LRSchedulerConfig from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE +logger = logging.getLogger(__name__) + @LRSchedulerConfig.register_subclass("evo1_exact") @dataclass @@ -59,6 +62,12 @@ class Evo1Config(PreTrainedConfig): max_views: int = 3 image_resolution: tuple[int, int] = (448, 448) empty_cameras: int = 0 + postprocess_action_dim: int | None = None + binarize_gripper: bool = False + gripper_index: int = 6 + gripper_threshold: float = 0.5 + gripper_below_threshold_value: float = 1.0 + gripper_above_threshold_value: float = -1.0 normalization_mapping: dict[str, NormalizationMode] = field( default_factory=lambda: { @@ -114,16 +123,32 @@ class Evo1Config(PreTrainedConfig): ) if self.apply_training_stage_defaults: - if self.training_stage == "stage1": - self.finetune_vlm = False - self.finetune_language_model = False - self.finetune_vision_model = False - self.finetune_action_head = True - elif self.training_stage == "stage2": - self.finetune_vlm = True - self.finetune_language_model = True - self.finetune_vision_model = True - self.finetune_action_head = True + stage_defaults = { + "stage1": { + "finetune_vlm": False, + "finetune_language_model": False, + "finetune_vision_model": False, + "finetune_action_head": True, + }, + "stage2": { + "finetune_vlm": True, + "finetune_language_model": True, + "finetune_vision_model": True, + "finetune_action_head": True, + }, + }[self.training_stage] + for flag_name, default_value in stage_defaults.items(): + current_value = getattr(self, flag_name) + if current_value is not None and current_value != default_value: + logger.warning( + "EVO1 %s=%s is overridden by training_stage=%s default %s. " + "Set apply_training_stage_defaults=false to keep explicit finetuning flags.", + flag_name, + current_value, + self.training_stage, + default_value, + ) + setattr(self, flag_name, default_value) elif self.training_stage == "stage1": if self.finetune_vlm is None: self.finetune_vlm = False @@ -171,6 +196,11 @@ class Evo1Config(PreTrainedConfig): raise ValueError( f"n_action_steps ({self.n_action_steps}) must be <= chunk_size ({self.chunk_size})" ) + if len(self.image_resolution) != 2 or self.image_resolution[0] != self.image_resolution[1]: + raise ValueError( + "EVO1 currently expects a square image_resolution because InternVL3 preprocessing " + f"uses a scalar image_size, got {self.image_resolution}." + ) def validate_features(self) -> None: if self.input_features is None: diff --git a/src/lerobot/policies/evo1/evo1_model.py b/src/lerobot/policies/evo1/evo1_model.py index 18e4dab33..a9637eda0 100644 --- a/src/lerobot/policies/evo1/evo1_model.py +++ b/src/lerobot/policies/evo1/evo1_model.py @@ -21,8 +21,8 @@ import torch import torch.nn as nn from PIL import Image -from lerobot.policies.evo1.flow_matching import FlowmatchingActionHead -from lerobot.policies.evo1.internvl3_embedder import InternVL3Embedder +from .flow_matching import FlowmatchingActionHead +from .internvl3_embedder import InternVL3Embedder def _cfgget(config: Any, key: str, default=None): @@ -163,37 +163,6 @@ class EVO1(nn.Module): embodiment_id=embodiment_ids, ) - @torch.no_grad() - def run_inference( - self, - images: list[Image.Image | torch.Tensor], - image_mask: torch.Tensor, - prompt: str, - state_input: list | torch.Tensor, - return_cls_only: bool | None = None, - action_mask: torch.Tensor | None = None, - embodiment_ids: torch.Tensor | None = None, - ) -> torch.Tensor: - if image_mask.dim() == 1: - image_mask = image_mask.unsqueeze(0) - - fused_tokens = self.get_vl_embeddings( - images=[images], - image_mask=image_mask, - prompt=[prompt], - return_cls_only=return_cls_only, - ) - state_tensor = self.prepare_state(state_input) - action = self.predict_action( - fused_tokens, - state_tensor, - action_mask=action_mask, - embodiment_ids=embodiment_ids, - ) - if isinstance(action, torch.Tensor) and action.dtype == torch.bfloat16: - action = action.to(torch.float32) - return action - def forward( self, fused_tokens: torch.Tensor, diff --git a/src/lerobot/policies/evo1/flow_matching.py b/src/lerobot/policies/evo1/flow_matching.py index b36af406c..986ea7a72 100644 --- a/src/lerobot/policies/evo1/flow_matching.py +++ b/src/lerobot/policies/evo1/flow_matching.py @@ -129,7 +129,10 @@ class MultiEmbodimentActionEncoder(nn.Module): def forward(self, action_seq: torch.Tensor, category_id: torch.LongTensor): batch_size, horizon, action_dim = action_seq.shape - assert self.horizon == horizon, "Action sequence length must match horizon" + if self.horizon != horizon: + raise ValueError( + f"Action sequence length must match horizon: got {horizon}, expected {self.horizon}." + ) x = action_seq.reshape(batch_size * horizon, action_dim) if category_id.dim() == 0: diff --git a/src/lerobot/policies/evo1/internvl3_embedder.py b/src/lerobot/policies/evo1/internvl3_embedder.py index 20745f8b6..fa7b6eb7d 100644 --- a/src/lerobot/policies/evo1/internvl3_embedder.py +++ b/src/lerobot/policies/evo1/internvl3_embedder.py @@ -24,7 +24,7 @@ from typing import TYPE_CHECKING import torch import torch.nn as nn import torch.utils.checkpoint -import torchvision.transforms.functional as TF +import torchvision.transforms.functional as tvf from PIL import Image from torchvision.transforms.functional import to_pil_image @@ -46,6 +46,26 @@ logger = logging.getLogger(__name__) def _patch_vision_encoder_checkpointing(encoder: nn.Module, use_reentrant: bool) -> None: + for attr_name in ("_gradient_checkpointing_func", "gradient_checkpointing_func"): + original_func = getattr(encoder, attr_name, None) + if not callable(original_func): + continue + patch_attr = f"_evo1_{attr_name}_patch_applied" + if getattr(encoder, patch_attr, False): + encoder.gradient_checkpointing_use_reentrant = use_reentrant + return + + def checkpoint_with_kwargs( + function, *checkpoint_args, _original_func=original_func, **checkpoint_kwargs + ): + checkpoint_kwargs.setdefault("use_reentrant", encoder.gradient_checkpointing_use_reentrant) + return _original_func(function, *checkpoint_args, **checkpoint_kwargs) + + encoder.gradient_checkpointing_use_reentrant = use_reentrant + setattr(encoder, attr_name, checkpoint_with_kwargs) + setattr(encoder, patch_attr, True) + return + if getattr(encoder, "_evo1_checkpoint_patch_applied", False): encoder.gradient_checkpointing_use_reentrant = use_reentrant return @@ -59,6 +79,9 @@ def _patch_vision_encoder_checkpointing(encoder: nn.Module, use_reentrant: bool) checkpoint_kwargs.setdefault("use_reentrant", self.gradient_checkpointing_use_reentrant) return original_checkpoint(function, *checkpoint_args, **checkpoint_kwargs) + # Some InternVL3 remote-code versions call torch.utils.checkpoint.checkpoint + # directly and do not expose a per-encoder checkpoint function to patch. + # Keep this compatibility fallback scoped to encoder.forward and restore it. torch.utils.checkpoint.checkpoint = checkpoint try: return original_forward(*args, **kwargs) @@ -280,11 +303,13 @@ class InternVL3Embedder(nn.Module): def _preprocess_single_image(self, image: Image.Image | torch.Tensor) -> torch.Tensor: if isinstance(image, torch.Tensor): + # Match upstream EVO1/InternVL preprocessing, which converts tensors + # through PIL before tiling and ImageNet normalization. pil_image = to_pil_image(image.detach().cpu()) else: pil_image = image.convert("RGB") tiles = dynamic_preprocess(pil_image, image_size=self.image_size) - tile_tensors = torch.stack([TF.to_tensor(tile) for tile in tiles]).to( + tile_tensors = torch.stack([tvf.to_tensor(tile) for tile in tiles]).to( device=self.device, dtype=torch.bfloat16 ) mean = torch.tensor(IMAGENET_MEAN, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1) diff --git a/src/lerobot/policies/evo1/modeling_evo1.py b/src/lerobot/policies/evo1/modeling_evo1.py index 7867d0c8e..3606d63ca 100644 --- a/src/lerobot/policies/evo1/modeling_evo1.py +++ b/src/lerobot/policies/evo1/modeling_evo1.py @@ -45,6 +45,7 @@ class EVO1Policy(PreTrainedPolicy): self.config = config self.model = EVO1(self._build_model_config(config)) self.model.set_finetune_flags() + self._keep_frozen_embedder_eval() self.reset() @classmethod @@ -64,7 +65,7 @@ class EVO1Policy(PreTrainedPolicy): **kwargs, ) -> T: if strict is None: - strict = not (config is not None and getattr(config, "training_stage", None) == "stage2") + strict = True return super().from_pretrained( pretrained_name_or_path=pretrained_name_or_path, config=config, @@ -85,6 +86,7 @@ class EVO1Policy(PreTrainedPolicy): "device": config.device, "return_cls_only": config.return_cls_only, "vlm_name": config.vlm_model_name, + "image_size": int(config.image_resolution[0]), "vlm_num_layers": config.vlm_num_layers, "vlm_dtype": config.vlm_dtype, "use_flash_attn": config.use_flash_attn, @@ -100,7 +102,8 @@ class EVO1Policy(PreTrainedPolicy): "dropout": config.dropout, "num_inference_timesteps": config.num_inference_timesteps, "num_categories": config.num_categories, - "enable_gradient_checkpointing": config.enable_gradient_checkpointing, + "enable_gradient_checkpointing": config.enable_gradient_checkpointing + and bool(config.finetune_vlm or config.finetune_language_model or config.finetune_vision_model), "gradient_checkpointing_use_reentrant": config.gradient_checkpointing_use_reentrant, "finetune_vlm": config.finetune_vlm, "finetune_language_model": config.finetune_language_model, @@ -303,6 +306,18 @@ class EVO1Policy(PreTrainedPolicy): or self.config.finetune_vision_model ) + def _keep_frozen_embedder_eval(self) -> None: + if self._tracks_vlm_gradients: + return + embedder = getattr(self.model, "embedder", None) + if embedder is not None: + embedder.eval() + + def train(self, mode: bool = True): + super().train(mode) + self._keep_frozen_embedder_eval() + return self + def _collect_image_batches(self, batch: dict[str, Tensor]) -> tuple[list[list[Tensor]], Tensor]: camera_keys = self._camera_keys or sorted(key for key in batch if key.startswith(f"{OBS_IMAGES}.")) if not camera_keys: @@ -348,23 +363,13 @@ class EVO1Policy(PreTrainedPolicy): ) -> Tensor: track_vlm_gradients = self._tracks_vlm_gradients grad_context = nullcontext() if track_vlm_gradients else torch.no_grad() - embedder = getattr(self.model, "embedder", None) - embedder_was_training = embedder.training if embedder is not None else None - - if not track_vlm_gradients and embedder is not None: - embedder.eval() - - try: - with grad_context: - fused_tokens = self.model.get_vl_embeddings( - images=image_batches, - image_mask=image_masks, - prompt=prompts, - return_cls_only=self.config.return_cls_only, - ) - finally: - if not track_vlm_gradients and embedder is not None and embedder_was_training is not None: - embedder.train(embedder_was_training) + with grad_context: + fused_tokens = self.model.get_vl_embeddings( + images=image_batches, + image_mask=image_masks, + prompt=prompts, + return_cls_only=self.config.return_cls_only, + ) if not track_vlm_gradients: fused_tokens = fused_tokens.detach() @@ -439,7 +444,7 @@ class EVO1Policy(PreTrainedPolicy): embodiment_ids=embodiment_ids, ) actions = actions.view(states.shape[0], self.config.chunk_size, self.config.max_action_dim) - return actions[:, :, : self._env_action_dim] + return actions @torch.no_grad() def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor: diff --git a/src/lerobot/policies/evo1/processor_evo1.py b/src/lerobot/policies/evo1/processor_evo1.py index f1a162df1..9acc7bb74 100644 --- a/src/lerobot/policies/evo1/processor_evo1.py +++ b/src/lerobot/policies/evo1/processor_evo1.py @@ -14,17 +14,24 @@ from __future__ import annotations +from copy import deepcopy +from dataclasses import dataclass from typing import Any import torch +from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.policies.evo1.configuration_evo1 import Evo1Config from lerobot.processor import ( AddBatchDimensionProcessorStep, DeviceProcessorStep, NormalizerProcessorStep, + ObservationProcessorStep, PolicyAction, + PolicyActionProcessorStep, PolicyProcessorPipeline, + ProcessorStep, + ProcessorStepRegistry, RenameObservationsProcessorStep, UnnormalizerProcessorStep, ) @@ -34,11 +41,13 @@ from lerobot.processor.converters import ( policy_action_to_transition, transition_to_policy_action, ) +from lerobot.types import EnvTransition, TransitionKey from lerobot.utils.constants import ( ACTION, DONE, INFO, OBS_PREFIX, + OBS_STATE, POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME, REWARD, @@ -65,6 +74,305 @@ def evo1_batch_to_transition(batch: dict[str, Any]): ) +@dataclass +@ProcessorStepRegistry.register(name="evo1_pad_state_processor") +class Evo1PadStateProcessorStep(ObservationProcessorStep): + """Pad policy observations to EVO1's fixed state width before normalization.""" + + max_state_dim: int = 24 + + def observation(self, observation: dict[str, Any]) -> dict[str, Any]: + if OBS_STATE not in observation: + return observation + + state = observation[OBS_STATE] + state_dim = state.shape[-1] + if state_dim > self.max_state_dim: + raise ValueError( + f"EVO1 state has {state_dim} dims, which exceeds max_state_dim={self.max_state_dim}." + ) + if state_dim < self.max_state_dim: + observation = observation.copy() + observation[OBS_STATE] = torch.nn.functional.pad(state, (0, self.max_state_dim - state_dim)) + return observation + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + new_features = {ft: feats.copy() for ft, feats in features.items()} + state_feats = new_features.setdefault(FeatureType.STATE, {}) + if OBS_STATE in state_feats: + state_feats[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(self.max_state_dim,)) + return new_features + + def get_config(self) -> dict[str, Any]: + return {"max_state_dim": self.max_state_dim} + + +@dataclass +@ProcessorStepRegistry.register(name="evo1_pad_action_processor") +class Evo1PadActionProcessorStep(ProcessorStep): + """Pad training actions and preserve the active action dimensions with action_mask.""" + + max_action_dim: int = 24 + + def __call__(self, transition: EnvTransition) -> EnvTransition: + action = transition.get(TransitionKey.ACTION) + if action is None: + return transition + if not isinstance(action, PolicyAction): + raise ValueError(f"EVO1 action should be a PolicyAction tensor, but got {type(action)}.") + + action_dim = action.shape[-1] + if action_dim > self.max_action_dim: + raise ValueError( + f"EVO1 action has {action_dim} dims, which exceeds max_action_dim={self.max_action_dim}." + ) + + new_transition = transition.copy() + new_action = action + if action_dim < self.max_action_dim: + new_action = torch.nn.functional.pad(action, (0, self.max_action_dim - action_dim)) + + complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}) + action_mask = complementary_data.get("action_mask") + if action_mask is None: + action_mask = torch.ones(action.shape, dtype=torch.bool, device=action.device) + else: + action_mask = torch.as_tensor(action_mask, dtype=torch.bool, device=action.device) + if action_mask.shape != action.shape: + raise ValueError( + f"action_mask shape {tuple(action_mask.shape)} does not match action shape {tuple(action.shape)}." + ) + if action_dim < self.max_action_dim: + action_mask = torch.nn.functional.pad(action_mask, (0, self.max_action_dim - action_dim)) + + complementary_data["action_mask"] = action_mask + new_transition[TransitionKey.ACTION] = new_action + new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data + return new_transition + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + new_features = {ft: feats.copy() for ft, feats in features.items()} + action_feats = new_features.setdefault(FeatureType.ACTION, {}) + action_feats[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(self.max_action_dim,)) + return new_features + + def get_config(self) -> dict[str, Any]: + return {"max_action_dim": self.max_action_dim} + + +@dataclass +@ProcessorStepRegistry.register(name="evo1_action_processor") +class Evo1ActionProcessorStep(PolicyActionProcessorStep): + """Crop padded EVO1 actions and optionally binarize the LIBERO gripper channel.""" + + action_dim: int + binarize_gripper: bool = False + gripper_index: int = 6 + gripper_threshold: float = 0.5 + gripper_below_threshold_value: float = 1.0 + gripper_above_threshold_value: float = -1.0 + + def action(self, action: PolicyAction) -> PolicyAction: + if action.shape[-1] < self.action_dim: + raise ValueError( + f"EVO1 action has {action.shape[-1]} dims, which is smaller than action_dim={self.action_dim}." + ) + + action = action[..., : self.action_dim] + if not self.binarize_gripper: + return action + + if not 0 <= self.gripper_index < self.action_dim: + raise ValueError( + f"gripper_index={self.gripper_index} must be within action_dim={self.action_dim}." + ) + + action = action.clone() + below = torch.as_tensor( + self.gripper_below_threshold_value, + dtype=action.dtype, + device=action.device, + ) + above = torch.as_tensor( + self.gripper_above_threshold_value, + dtype=action.dtype, + device=action.device, + ) + action[..., self.gripper_index] = torch.where( + action[..., self.gripper_index] > self.gripper_threshold, + above, + below, + ) + return action + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + new_features = {ft: feats.copy() for ft, feats in features.items()} + action_feats = new_features.setdefault(FeatureType.ACTION, {}) + action_feats[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(self.action_dim,)) + return new_features + + def get_config(self) -> dict[str, Any]: + return { + "action_dim": self.action_dim, + "binarize_gripper": self.binarize_gripper, + "gripper_index": self.gripper_index, + "gripper_threshold": self.gripper_threshold, + "gripper_below_threshold_value": self.gripper_below_threshold_value, + "gripper_above_threshold_value": self.gripper_above_threshold_value, + } + + +def _evo1_action_dim(config: Evo1Config) -> int: + if config.postprocess_action_dim is not None: + return config.postprocess_action_dim + action_feature = config.action_feature + if action_feature is None: + return config.max_action_dim + return int(action_feature.shape[0]) + + +def _evo1_normalization_features(config: Evo1Config) -> dict[str, PolicyFeature]: + features = {**config.input_features, **config.output_features} + features[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(config.max_state_dim,)) + features[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(config.max_action_dim,)) + return features + + +def _evo1_action_features(config: Evo1Config) -> dict[str, PolicyFeature]: + return {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(config.max_action_dim,))} + + +_STAT_PAD_VALUES = { + "mean": 0.0, + "std": 1.0, + "min": -1.0, + "max": 1.0, + "q01": -1.0, + "q99": 1.0, + "q10": -1.0, + "q90": 1.0, +} + + +def _pad_stat_value(value: Any, target_dim: int, stat_name: str) -> torch.Tensor: + tensor = torch.as_tensor(value) + if not tensor.is_floating_point(): + tensor = tensor.to(dtype=torch.float32) + if tensor.ndim == 0 or tensor.shape[-1] >= target_dim: + return tensor + + pad_shape = (*tensor.shape[:-1], target_dim - tensor.shape[-1]) + pad_value = _STAT_PAD_VALUES.get(stat_name, 0.0) + padding = torch.full(pad_shape, pad_value, dtype=tensor.dtype, device=tensor.device) + return torch.cat([tensor, padding], dim=-1) + + +def _pad_feature_stats( + stats: dict[str, dict[str, Any]], + feature_key: str, + target_dim: int, +) -> None: + if feature_key not in stats: + return + stats[feature_key] = { + stat_name: _pad_stat_value(stat_value, target_dim, stat_name) + for stat_name, stat_value in stats[feature_key].items() + } + + +def _pad_evo1_stats( + config: Evo1Config, + stats: dict[str, dict[str, Any]] | None, +) -> dict[str, dict[str, Any]] | None: + if stats is None: + return None + + padded_stats = deepcopy(stats) + # Added dimensions represent zero-padding inside EVO1. These neutral stats keep + # padded observations at normalized zero and only provide shape compatibility. + _pad_feature_stats(padded_stats, OBS_STATE, config.max_state_dim) + _pad_feature_stats(padded_stats, ACTION, config.max_action_dim) + return padded_stats + + +def _refresh_evo1_normalization_steps( + config: Evo1Config, + preprocessor: PolicyProcessorPipeline, + postprocessor: PolicyProcessorPipeline, +) -> None: + normalization_features = _evo1_normalization_features(config) + action_features = _evo1_action_features(config) + + for step in preprocessor.steps: + if isinstance(step, NormalizerProcessorStep): + step.features = normalization_features + step.stats = _pad_evo1_stats(config, step.stats) + step.to(device=step.device, dtype=step.dtype) + + for step in postprocessor.steps: + if isinstance(step, UnnormalizerProcessorStep): + step.features = action_features + step.stats = _pad_evo1_stats(config, step.stats) + step.to(device=step.device, dtype=step.dtype) + + +def ensure_evo1_processor_steps( + config: Evo1Config, + preprocessor: PolicyProcessorPipeline, + postprocessor: PolicyProcessorPipeline, +) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: + """Add EVO1 processor steps when loading older checkpoints that do not serialize them.""" + + has_state_padding = any(isinstance(step, Evo1PadStateProcessorStep) for step in preprocessor.steps) + if not has_state_padding: + steps = list(preprocessor.steps) + insert_idx = next( + (idx for idx, step in enumerate(steps) if isinstance(step, NormalizerProcessorStep)), + len(steps), + ) + steps.insert(insert_idx, Evo1PadStateProcessorStep(max_state_dim=config.max_state_dim)) + preprocessor.steps = steps + + has_action_padding = any(isinstance(step, Evo1PadActionProcessorStep) for step in preprocessor.steps) + if not has_action_padding: + steps = list(preprocessor.steps) + insert_idx = next( + (idx for idx, step in enumerate(steps) if isinstance(step, NormalizerProcessorStep)), + len(steps), + ) + steps.insert(insert_idx, Evo1PadActionProcessorStep(max_action_dim=config.max_action_dim)) + preprocessor.steps = steps + + has_action_processor = any(isinstance(step, Evo1ActionProcessorStep) for step in postprocessor.steps) + if not has_action_processor: + steps = list(postprocessor.steps) + insert_idx = next( + (idx + 1 for idx, step in enumerate(steps) if isinstance(step, UnnormalizerProcessorStep)), + 0, + ) + steps.insert( + insert_idx, + Evo1ActionProcessorStep( + action_dim=_evo1_action_dim(config), + binarize_gripper=config.binarize_gripper, + gripper_index=config.gripper_index, + gripper_threshold=config.gripper_threshold, + gripper_below_threshold_value=config.gripper_below_threshold_value, + gripper_above_threshold_value=config.gripper_above_threshold_value, + ), + ) + postprocessor.steps = steps + + _refresh_evo1_normalization_steps(config, preprocessor, postprocessor) + return preprocessor, postprocessor + + def make_evo1_pre_post_processors( config: Evo1Config, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, @@ -72,21 +380,35 @@ def make_evo1_pre_post_processors( PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], PolicyProcessorPipeline[PolicyAction, PolicyAction], ]: + normalization_features = _evo1_normalization_features(config) + action_features = _evo1_action_features(config) + normalization_stats = _pad_evo1_stats(config, dataset_stats) + input_steps = [ RenameObservationsProcessorStep(rename_map={}), AddBatchDimensionProcessorStep(), + Evo1PadStateProcessorStep(max_state_dim=config.max_state_dim), + Evo1PadActionProcessorStep(max_action_dim=config.max_action_dim), NormalizerProcessorStep( - features={**config.input_features, **config.output_features}, + features=normalization_features, norm_map=config.normalization_mapping, - stats=dataset_stats, + stats=normalization_stats, ), DeviceProcessorStep(device=config.device), ] output_steps = [ UnnormalizerProcessorStep( - features=config.output_features, + features=action_features, norm_map=config.normalization_mapping, - stats=dataset_stats, + stats=normalization_stats, + ), + Evo1ActionProcessorStep( + action_dim=_evo1_action_dim(config), + binarize_gripper=config.binarize_gripper, + gripper_index=config.gripper_index, + gripper_threshold=config.gripper_threshold, + gripper_below_threshold_value=config.gripper_below_threshold_value, + gripper_above_threshold_value=config.gripper_above_threshold_value, ), DeviceProcessorStep(device="cpu"), ] diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index a511de67a..fa322c237 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -311,6 +311,14 @@ def make_pre_post_processors( to_output=transition_to_policy_action, ) _reconnect_relative_absolute_steps(preprocessor, postprocessor) + if isinstance(policy_cfg, Evo1Config): + from .evo1.processor_evo1 import ensure_evo1_processor_steps + + preprocessor, postprocessor = ensure_evo1_processor_steps( + policy_cfg, + preprocessor, + postprocessor, + ) return preprocessor, postprocessor # Create a new processor based on policy type diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 3329729fb..3688a4b8c 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -40,7 +40,7 @@ from .converters import ( ) from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep from .device_processor import DeviceProcessorStep -from .env_processor import IsaaclabArenaProcessorStep, LiberoActionProcessorStep, LiberoProcessorStep +from .env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep from .factory import ( make_default_processors, make_default_robot_action_processor, @@ -149,7 +149,6 @@ __all__ = [ "RewardProcessorStep", "DataProcessorPipeline", "IsaaclabArenaProcessorStep", - "LiberoActionProcessorStep", "LiberoProcessorStep", "TimeLimitProcessorStep", "AddBatchDimensionProcessorStep", diff --git a/src/lerobot/processor/env_processor.py b/src/lerobot/processor/env_processor.py index 4369a81af..72399e9ab 100644 --- a/src/lerobot/processor/env_processor.py +++ b/src/lerobot/processor/env_processor.py @@ -18,9 +18,9 @@ from dataclasses import dataclass import torch from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature -from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR +from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR -from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry +from .pipeline import ObservationProcessorStep, ProcessorStepRegistry @dataclass @@ -46,8 +46,6 @@ class LiberoProcessorStep(ObservationProcessorStep): - This accounts for the HuggingFaceVLA/libero camera orientation convention. """ - max_state_dim: int | None = None - def _process_observation(self, observation): """ Processes both image and robot_state observations from LIBERO. @@ -80,16 +78,6 @@ class LiberoProcessorStep(ObservationProcessorStep): state = state.float() if state.dim() == 1: state = state.unsqueeze(0) - if self.max_state_dim is not None: - if state.shape[-1] > self.max_state_dim: - raise ValueError( - f"LIBERO state has {state.shape[-1]} dims, which is larger than " - f"configured max_state_dim={self.max_state_dim}." - ) - if state.shape[-1] < self.max_state_dim: - pad_width = self.max_state_dim - state.shape[-1] - state = torch.nn.functional.pad(state, (0, pad_width)) - processed_obs[OBS_STATE] = state return processed_obs @@ -112,7 +100,7 @@ class LiberoProcessorStep(ObservationProcessorStep): # add our new flattened state state_feats[OBS_STATE] = PolicyFeature( type=FeatureType.STATE, - shape=(self.max_state_dim or 8,), # [eef_pos(3), axis_angle(3), gripper(2)] plus padding + shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)] ) new_features[FeatureType.STATE] = state_feats @@ -122,9 +110,6 @@ class LiberoProcessorStep(ObservationProcessorStep): def observation(self, observation): return self._process_observation(observation) - def get_config(self) -> dict: - return {"max_state_dim": self.max_state_dim} - def _quat2axisangle(self, quat: torch.Tensor) -> torch.Tensor: """ Convert batched quaternions to axis-angle format. @@ -167,68 +152,6 @@ class LiberoProcessorStep(ObservationProcessorStep): return result -@dataclass -@ProcessorStepRegistry.register(name="libero_action_processor") -class LiberoActionProcessorStep(ActionProcessorStep): - """Slices padded policy actions back to the executable LIBERO action space.""" - - action_dim: int = 7 - binarize_gripper: bool = False - gripper_index: int = 6 - gripper_threshold: float = 0.5 - gripper_below_threshold_value: float = 1.0 - gripper_above_threshold_value: float = -1.0 - - def action(self, action): - if action.shape[-1] < self.action_dim: - raise ValueError( - f"LIBERO action has {action.shape[-1]} dims, which is smaller than action_dim={self.action_dim}." - ) - action = action[..., : self.action_dim] - if not self.binarize_gripper: - return action - - if not 0 <= self.gripper_index < self.action_dim: - raise ValueError( - f"gripper_index={self.gripper_index} must be within sliced action_dim={self.action_dim}." - ) - action = action.clone() - below = torch.as_tensor( - self.gripper_below_threshold_value, - dtype=action.dtype, - device=action.device, - ) - above = torch.as_tensor( - self.gripper_above_threshold_value, - dtype=action.dtype, - device=action.device, - ) - action[..., self.gripper_index] = torch.where( - action[..., self.gripper_index] > self.gripper_threshold, - above, - below, - ) - return action - - def transform_features( - self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] - ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: - new_features = {ft: feats.copy() for ft, feats in features.items()} - action_feats = new_features.setdefault(FeatureType.ACTION, {}) - action_feats[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(self.action_dim,)) - return new_features - - def get_config(self) -> dict: - return { - "action_dim": self.action_dim, - "binarize_gripper": self.binarize_gripper, - "gripper_index": self.gripper_index, - "gripper_threshold": self.gripper_threshold, - "gripper_below_threshold_value": self.gripper_below_threshold_value, - "gripper_above_threshold_value": self.gripper_above_threshold_value, - } - - @dataclass @ProcessorStepRegistry.register(name="isaaclab_arena_processor") class IsaaclabArenaProcessorStep(ObservationProcessorStep): diff --git a/tests/envs/test_dispatch.py b/tests/envs/test_dispatch.py index eeed9d1aa..b038832af 100644 --- a/tests/envs/test_dispatch.py +++ b/tests/envs/test_dispatch.py @@ -13,7 +13,7 @@ from gymnasium.envs.registration import register, registry as gym_registry from lerobot.configs.types import PolicyFeature from lerobot.envs.configs import EnvConfig, LiberoEnv from lerobot.envs.factory import make_env, make_env_config, make_env_pre_post_processors -from lerobot.processor import LiberoActionProcessorStep, LiberoProcessorStep +from lerobot.processor import LiberoProcessorStep from lerobot.utils.constants import OBS_PREFIX, OBS_STATE logger = logging.getLogger(__name__) @@ -86,38 +86,18 @@ def test_processors_delegation_supports_legacy_override_signature(): assert isinstance(post, DataProcessorPipeline) -def test_libero_evo1_processors_use_padded_state_and_env_action_dim(): - """EVO1 uses padded LIBERO state features while env actions stay executable.""" - - class _Evo1Config: - type = "evo1" - max_state_dim = 24 - +def test_libero_processors_are_policy_agnostic(): cfg = LiberoEnv() - pre, post = make_env_pre_post_processors(cfg, policy_cfg=_Evo1Config()) + pre, post = make_env_pre_post_processors(cfg, policy_cfg=object()) + assert isinstance(pre.steps[0], LiberoProcessorStep) - assert pre.steps[0].max_state_dim == 24 - assert isinstance(post.steps[0], LiberoActionProcessorStep) - assert post.steps[0].action_dim == cfg.features["action"].shape[0] == 7 - assert post.steps[0].binarize_gripper is True - - class _OtherConfig: - type = "other" - - pre_other, post_other = make_env_pre_post_processors(cfg, policy_cfg=_OtherConfig()) - assert pre_other.steps[0].max_state_dim is None - assert post_other.steps[0].binarize_gripper is False - - cfg.binarize_gripper = False - _, post_disabled = make_env_pre_post_processors(cfg, policy_cfg=_Evo1Config()) - assert post_disabled.steps[0].binarize_gripper is False + assert len(post.steps) == 0 -def test_libero_processor_pads_state_to_max_dim(): - step = LiberoProcessorStep(max_state_dim=24) +def test_libero_processor_flattens_state_to_raw_8_dim(): + step = LiberoProcessorStep() observation = { - OBS_PREFIX - + "robot_state": { + OBS_PREFIX + "robot_state": { "eef": { "pos": torch.tensor([[1.0, 2.0, 3.0]]), "quat": torch.tensor([[0.0, 0.0, 0.0, 1.0]]), @@ -127,39 +107,8 @@ def test_libero_processor_pads_state_to_max_dim(): } state = step.observation(observation)[OBS_STATE] - assert state.shape == (1, 24) - assert torch.allclose(state[:, :8], torch.tensor([[1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 4.0, 5.0]])) - assert torch.count_nonzero(state[:, 8:]).item() == 0 - - -def test_libero_action_processor_slices_padded_action(): - step = LiberoActionProcessorStep(action_dim=7) - action = torch.arange(2 * 3 * 24, dtype=torch.float32).reshape(2, 3, 24) - - sliced = step.action(action) - assert sliced.shape == (2, 3, 7) - assert torch.equal(sliced, action[..., :7]) - - with pytest.raises(ValueError, match="smaller than action_dim=7"): - step.action(torch.zeros(2, 6)) - - -def test_libero_action_processor_can_binarize_gripper(): - step = LiberoActionProcessorStep(action_dim=7, binarize_gripper=True) - action = torch.tensor( - [ - [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 0.5, 7.0], - [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 0.6, 7.0], - ], - dtype=torch.float32, - ) - - processed = step.action(action) - - assert processed.shape == (2, 7) - assert torch.equal(processed[:, :6], action[:, :6]) - assert torch.equal(processed[:, 6], torch.tensor([1.0, -1.0])) - assert torch.equal(action[:, 6], torch.tensor([0.5, 0.6])) + assert state.shape == (1, 8) + assert torch.allclose(state, torch.tensor([[1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 4.0, 5.0]])) def test_base_create_envs(): diff --git a/tests/policies/evo1/test_evo1.py b/tests/policies/evo1/test_evo1.py index 7ccd6274e..9ab531f63 100644 --- a/tests/policies/evo1/test_evo1.py +++ b/tests/policies/evo1/test_evo1.py @@ -16,6 +16,7 @@ from __future__ import annotations +import pytest import torch from torch import nn @@ -23,7 +24,15 @@ import lerobot.policies.evo1.modeling_evo1 as modeling_evo1 from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.policies.evo1.configuration_evo1 import Evo1Config from lerobot.policies.evo1.flow_matching import FlowmatchingActionHead +from lerobot.policies.evo1.processor_evo1 import ( + Evo1ActionProcessorStep, + Evo1PadActionProcessorStep, + Evo1PadStateProcessorStep, + ensure_evo1_processor_steps, + make_evo1_pre_post_processors, +) from lerobot.policies.factory import get_policy_class, make_policy_config +from lerobot.processor import NormalizerProcessorStep, PolicyProcessorPipeline, UnnormalizerProcessorStep from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE STATE_DIM = 4 @@ -108,6 +117,19 @@ def make_batch(include_action=True): return batch +def make_stats(state_dim=STATE_DIM, action_dim=ACTION_DIM): + return { + OBS_STATE: { + "min": torch.full((state_dim,), -2.0), + "max": torch.full((state_dim,), 2.0), + }, + ACTION: { + "min": torch.full((action_dim,), -1.0), + "max": torch.full((action_dim,), 1.0), + }, + } + + def test_evo1_factory_registration(): cfg = make_policy_config( "evo1", @@ -191,22 +213,151 @@ def test_evo1_stage_defaults_and_consistency(): raise AssertionError("Expected inconsistent finetune config to raise ValueError") +def test_evo1_rejects_non_square_image_resolution(): + with pytest.raises(ValueError, match="square image_resolution"): + make_config(image_resolution=(448, 320)) + + +def test_evo1_build_model_config_uses_image_resolution_and_trainable_checkpointing(): + stage1 = make_config(training_stage="stage1", image_resolution=(224, 224)) + stage1_model_config = modeling_evo1.EVO1Policy._build_model_config(stage1) + + assert stage1_model_config["image_size"] == 224 + assert stage1_model_config["enable_gradient_checkpointing"] is False + + stage2 = make_config(training_stage="stage2", image_resolution=(224, 224)) + stage2_model_config = modeling_evo1.EVO1Policy._build_model_config(stage2) + + assert stage2_model_config["enable_gradient_checkpointing"] is True + + +def test_evo1_policy_processors_pad_state_crop_action_and_binarize_gripper(): + libero_action_dim = 7 + config = make_config( + max_state_dim=MAX_STATE_DIM, + max_action_dim=8, + postprocess_action_dim=libero_action_dim, + binarize_gripper=True, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(libero_action_dim,))}, + ) + stats = make_stats(action_dim=libero_action_dim) + + preprocessor, postprocessor = make_evo1_pre_post_processors(config, dataset_stats=stats) + + assert isinstance(preprocessor.steps[2], Evo1PadStateProcessorStep) + assert isinstance(preprocessor.steps[3], Evo1PadActionProcessorStep) + assert isinstance(preprocessor.steps[4], NormalizerProcessorStep) + assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep) + assert isinstance(postprocessor.steps[1], Evo1ActionProcessorStep) + + normalizer = preprocessor.steps[4] + assert normalizer.features[OBS_STATE].shape == (MAX_STATE_DIM,) + assert normalizer.features[ACTION].shape == (8,) + assert normalizer._tensor_stats[OBS_STATE]["min"].shape == (MAX_STATE_DIM,) + assert normalizer._tensor_stats[ACTION]["min"].shape == (8,) + + processed_batch = preprocessor( + { + "task": "pick the block", + OBS_STATE: torch.zeros(STATE_DIM), + ACTION: torch.zeros(libero_action_dim), + f"{OBS_IMAGES}.front": torch.rand(3, 16, 16), + } + ) + processed_state = processed_batch[OBS_STATE] + assert processed_state.shape == (1, MAX_STATE_DIM) + assert torch.allclose(processed_state, torch.zeros_like(processed_state)) + assert processed_batch[ACTION].shape == (1, 8) + assert torch.allclose(processed_batch[ACTION], torch.zeros_like(processed_batch[ACTION])) + assert processed_batch["action_mask"].shape == (1, 8) + assert processed_batch["action_mask"][:, :libero_action_dim].all() + assert not processed_batch["action_mask"][:, libero_action_dim:].any() + + action = torch.tensor( + [ + [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.5, 0.7], + [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], + ], + dtype=torch.float32, + ) + processed = postprocessor(action) + + assert processed.shape == (2, 7) + assert torch.allclose(processed[:, :6], action[:, :6]) + assert torch.equal(processed[:, 6], torch.tensor([1.0, -1.0])) + + +def test_evo1_legacy_processors_are_completed_before_normalization(): + config = make_config( + max_state_dim=MAX_STATE_DIM, + max_action_dim=8, + postprocess_action_dim=7, + binarize_gripper=True, + ) + stats = make_stats(action_dim=7) + legacy_pre = PolicyProcessorPipeline( + steps=[ + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=stats, + ) + ] + ) + legacy_post = PolicyProcessorPipeline( + steps=[ + UnnormalizerProcessorStep( + features=config.output_features, + norm_map=config.normalization_mapping, + stats=stats, + ) + ] + ) + + preprocessor, postprocessor = ensure_evo1_processor_steps(config, legacy_pre, legacy_post) + + assert isinstance(preprocessor.steps[0], Evo1PadStateProcessorStep) + assert isinstance(preprocessor.steps[1], Evo1PadActionProcessorStep) + assert isinstance(preprocessor.steps[2], NormalizerProcessorStep) + assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep) + assert isinstance(postprocessor.steps[1], Evo1ActionProcessorStep) + assert postprocessor.steps[1].action_dim == 7 + assert postprocessor.steps[1].binarize_gripper is True + assert preprocessor.steps[2].features[OBS_STATE].shape == (MAX_STATE_DIM,) + assert preprocessor.steps[2]._tensor_stats[OBS_STATE]["min"].shape == (MAX_STATE_DIM,) + assert preprocessor.steps[2]._tensor_stats[ACTION]["min"].shape == (8,) + assert postprocessor.steps[0].features[ACTION].shape == (8,) + assert postprocessor.steps[0]._tensor_stats[ACTION]["min"].shape == (8,) + + preprocessor, postprocessor = ensure_evo1_processor_steps(config, preprocessor, postprocessor) + assert sum(isinstance(step, Evo1PadStateProcessorStep) for step in preprocessor.steps) == 1 + assert sum(isinstance(step, Evo1PadActionProcessorStep) for step in preprocessor.steps) == 1 + assert sum(isinstance(step, Evo1ActionProcessorStep) for step in postprocessor.steps) == 1 + + def test_evo1_policy_forward_and_inference_use_batched_embedding(monkeypatch): monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1) policy = modeling_evo1.EVO1Policy(make_config()) + preprocessor, _postprocessor = make_evo1_pre_post_processors(policy.config, dataset_stats=make_stats()) + training_batch = preprocessor(make_batch(include_action=True)) - loss, metrics = policy.forward(make_batch(include_action=True)) + assert training_batch[ACTION].shape == (2, CHUNK_SIZE, MAX_ACTION_DIM) + assert training_batch["action_mask"].shape == (2, CHUNK_SIZE, MAX_ACTION_DIM) + assert training_batch["action_mask"][:, :, :ACTION_DIM].all() + assert not training_batch["action_mask"][:, :, ACTION_DIM:].any() + + loss, metrics = policy.forward(training_batch) assert loss.ndim == 0 assert torch.isfinite(loss) assert metrics["active_action_dims"] == ACTION_DIM * CHUNK_SIZE assert policy.model.get_vl_embeddings_calls == 1 action_chunk = policy.predict_action_chunk(make_batch(include_action=False)) - assert action_chunk.shape == (2, CHUNK_SIZE, ACTION_DIM) + assert action_chunk.shape == (2, CHUNK_SIZE, MAX_ACTION_DIM) policy.reset() selected = policy.select_action(make_batch(include_action=False)) - assert selected.shape == (2, ACTION_DIM) + assert selected.shape == (2, MAX_ACTION_DIM) def test_stage1_frozen_vlm_embeddings_do_not_track_gradients(monkeypatch): @@ -220,7 +371,7 @@ def test_stage1_frozen_vlm_embeddings_do_not_track_gradients(monkeypatch): assert policy.model.grad_enabled_calls == [False] assert policy.model.embedder_training_calls == [False] assert not fused_tokens.requires_grad - assert policy.model.embedder.training is True + assert policy.model.embedder.training is False def test_stage2_vlm_embeddings_track_gradients(monkeypatch):