mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-03 08:07:03 +00:00
Compare commits
24 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3fa1415057 | |||
| 8dcbfd9d25 | |||
| edc01c3b94 | |||
| f5ac58adb9 | |||
| 2afe2864e9 | |||
| d61941fe68 | |||
| e181f2e383 | |||
| 9f5ddeb761 | |||
| 13adcea522 | |||
| 5b541c042d | |||
| 9423deda02 | |||
| 25556ceefe | |||
| 4cfa762da8 | |||
| fa984990c0 | |||
| f9b8f297b4 | |||
| 95527f6051 | |||
| 407ee867b9 | |||
| a5e6409985 | |||
| 1c9fbba9a9 | |||
| 6a1b5ceb9d | |||
| daa4c4dd30 | |||
| ff992a7a1d | |||
| 48269dddb3 | |||
| 8df8d3d866 |
@@ -71,6 +71,8 @@
|
||||
title: EO-1
|
||||
- local: fastwam
|
||||
title: FastWAM
|
||||
- local: evo1
|
||||
title: EVO1
|
||||
- local: groot
|
||||
title: NVIDIA GR00T N1.5
|
||||
- local: xvla
|
||||
|
||||
@@ -0,0 +1,191 @@
|
||||
# EVO1
|
||||
|
||||
EVO1 is a Vision-Language-Action policy for robot control built around an InternVL3 backbone and a continuous flow-matching action head. This LeRobot integration exposes EVO1 as a standard policy type so it can be trained and evaluated with the usual LeRobot dataset, checkpoint, and processor APIs.
|
||||
|
||||
## Model Overview
|
||||
|
||||
The policy embeds one or more camera images and the language task prompt with InternVL3, pads robot state/action vectors to fixed maximum dimensions, and predicts future action chunks with a flow-matching action head. During inference, the policy samples an action chunk and returns `n_action_steps` actions from that chunk before sampling again.
|
||||
|
||||
### What the LeRobot Integration Covers
|
||||
|
||||
- Standard `policy.type=evo1` configuration through LeRobot
|
||||
- InternVL3 image/text embedding with optional FlashAttention fallback
|
||||
- Stage-based finetuning controls for action-head-only and VLM finetuning runs
|
||||
- Continuous flow-matching action prediction
|
||||
- Checkpoint save/load through LeRobot policy APIs
|
||||
- Training with `lerobot-train` and evaluation with standard policy inference APIs
|
||||
|
||||
The broader EVO1 project may include additional training scripts and dataset tooling. This page focuses on the LeRobot robot-control policy path.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following the [Installation Guide](./installation).
|
||||
2. Install EVO1 dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e ".[evo1]"
|
||||
```
|
||||
|
||||
For LIBERO evaluation, install the LIBERO extra as well:
|
||||
|
||||
```bash
|
||||
pip install -e ".[evo1,libero]"
|
||||
```
|
||||
|
||||
3. Install a `flash-attn` wheel only if it is compatible with your Python, PyTorch, CUDA, and GPU stack. EVO1 falls back to standard attention when `flash_attn` is not available.
|
||||
|
||||
EVO1 uses the native Hugging Face `transformers` InternVL implementation, so `policy.vlm_model_name` must point to a natively converted checkpoint such as `OpenGVLab/InternVL3-1B-hf` (note the `-hf` suffix). The first run may download the configured VLM checkpoint unless `policy.vlm_model_name` points to a local model directory.
|
||||
|
||||
## Data Requirements
|
||||
|
||||
EVO1 expects a LeRobot dataset with:
|
||||
|
||||
- One to `policy.max_views` visual observations, for example `observation.images.image`
|
||||
- `observation.state`
|
||||
- `action`
|
||||
- A language task instruction in the dataset `task` field, or another field configured with `policy.task_field`
|
||||
|
||||
State and action vectors are padded to `policy.max_state_dim` and `policy.max_action_dim`. Predictions are cropped back to the dataset action dimension before being returned.
|
||||
|
||||
## Usage
|
||||
|
||||
To use EVO1 in a LeRobot configuration, specify:
|
||||
|
||||
```python
|
||||
policy.type=evo1
|
||||
```
|
||||
|
||||
By default, a new EVO1 policy initializes its VLM from:
|
||||
|
||||
```python
|
||||
policy.vlm_model_name=OpenGVLab/InternVL3-1B-hf
|
||||
```
|
||||
|
||||
Once a LeRobot-format EVO1 checkpoint is available, load it with:
|
||||
|
||||
```python
|
||||
policy.path=your-org/your-evo1-checkpoint
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Stage 1
|
||||
|
||||
Stage 1 freezes the VLM and trains the action head:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your_org/your_dataset \
|
||||
--policy.type=evo1 \
|
||||
--policy.training_stage=stage1 \
|
||||
--policy.vlm_model_name=OpenGVLab/InternVL3-1B-hf \
|
||||
--policy.device=cuda \
|
||||
--policy.chunk_size=50 \
|
||||
--policy.n_action_steps=50 \
|
||||
--policy.max_state_dim=24 \
|
||||
--policy.max_action_dim=24 \
|
||||
--policy.optimizer_lr=1e-5 \
|
||||
--batch_size=4 \
|
||||
--steps=5000 \
|
||||
--output_dir=./outputs/evo1_stage1
|
||||
```
|
||||
|
||||
### Stage 2
|
||||
|
||||
Stage 2 finetunes the VLM branches and action head. A common workflow starts from a Stage 1 checkpoint:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your_org/your_dataset \
|
||||
--policy.path=./outputs/evo1_stage1/checkpoints/005000/pretrained_model \
|
||||
--policy.training_stage=stage2 \
|
||||
--policy.vlm_model_name=OpenGVLab/InternVL3-1B-hf \
|
||||
--policy.device=cuda \
|
||||
--policy.chunk_size=50 \
|
||||
--policy.n_action_steps=50 \
|
||||
--policy.max_state_dim=24 \
|
||||
--policy.max_action_dim=24 \
|
||||
--policy.optimizer_lr=1e-5 \
|
||||
--batch_size=4 \
|
||||
--steps=80000 \
|
||||
--output_dir=./outputs/evo1_stage2
|
||||
```
|
||||
|
||||
By default, `policy.training_stage` reapplies the finetuning defaults for that stage. This is important when
|
||||
starting Stage 2 from a Stage 1 checkpoint, because the Stage 1 checkpoint config stores the VLM finetuning
|
||||
flags as disabled. These stage defaults take precedence over saved or manually supplied `policy.finetune_*`
|
||||
flags unless `policy.apply_training_stage_defaults=false`, so set that flag only when manually controlling
|
||||
every finetuning flag.
|
||||
|
||||
### Key Training Parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| --------------------------------------------- | --------------------------- | ----------------------------------------------------------------- |
|
||||
| `policy.vlm_model_name` | `OpenGVLab/InternVL3-1B-hf` | Natively converted InternVL3 checkpoint or local model directory |
|
||||
| `policy.training_stage` | `stage1` | `stage1` trains the action head; `stage2` finetunes VLM branches |
|
||||
| `policy.apply_training_stage_defaults` | `true` | Reapplies stage finetuning defaults after loading a checkpoint |
|
||||
| `policy.vlm_num_layers` | `14` | Number of InternVL3 language layers kept for the policy |
|
||||
| `policy.vlm_dtype` | `bfloat16` | Requested VLM dtype |
|
||||
| `policy.use_flash_attn` | `true` | Requests FlashAttention when installed; otherwise falls back |
|
||||
| `policy.enable_gradient_checkpointing` | `true` | Enables checkpointing on supported InternVL3 modules |
|
||||
| `policy.gradient_checkpointing_use_reentrant` | `false` | Reentrant setting passed to gradient checkpointing when supported |
|
||||
| `policy.chunk_size` | `50` | Number of future actions predicted per chunk |
|
||||
| `policy.n_action_steps` | `50` | Number of actions consumed from a sampled chunk |
|
||||
| `policy.max_state_dim` | `24` | State padding dimension |
|
||||
| `policy.max_action_dim` | `24` | Action padding dimension |
|
||||
| `policy.postprocess_action_dim` | `null` | Optional action dimension returned after EVO1 postprocessing |
|
||||
| `policy.binarize_gripper` | `false` | Binarizes the postprocessed gripper channel for LIBERO-style eval |
|
||||
| `policy.task_field` | `task` | Batch field used as the language prompt |
|
||||
|
||||
## Inference
|
||||
|
||||
Try it out with a trained EVO1 checkpoint (the `--inference.type=rtc` flag is optional and enables Real-Time Chunking):
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--policy.path=your-org/your-evo1-checkpoint \
|
||||
--inference.type=rtc \
|
||||
...
|
||||
```
|
||||
|
||||
## Results
|
||||
|
||||
### LIBERO Evaluation
|
||||
|
||||
> [!NOTE]
|
||||
> Benchmark results for a `lerobot`-hosted LIBERO checkpoint trained with this implementation
|
||||
> will be added once training completes.
|
||||
|
||||
The official EVO1 LIBERO rollout protocol uses the raw LIBERO camera feature names
|
||||
(`observation.images.agentview_image` and `observation.images.robot0_eye_in_hand_image`), replans every
|
||||
14 actions, and binarizes the gripper command before stepping the simulator. The EVO1 policy postprocessor
|
||||
can crop the padded 24D action back to the 7D LIBERO action space and apply that gripper binarization. To
|
||||
evaluate a LIBERO checkpoint under the same one-episode-per-task setting, keep the raw camera names instead
|
||||
of the default `image`/`image2` mapping and set the LIBERO action postprocessing flags:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=your-org/your-evo1-libero-checkpoint \
|
||||
--policy.vlm_model_name=OpenGVLab/InternVL3-1B-hf \
|
||||
--policy.device=cuda \
|
||||
--policy.use_flash_attn=true \
|
||||
--policy.n_action_steps=14 \
|
||||
--policy.postprocess_action_dim=7 \
|
||||
--policy.binarize_gripper=true \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object \
|
||||
--env.camera_name_mapping="{agentview_image: agentview_image, robot0_eye_in_hand_image: robot0_eye_in_hand_image}" \
|
||||
--env.observation_height=448 \
|
||||
--env.observation_width=448 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- [EVO1 repository](https://github.com/MINT-SJTU/Evo-1)
|
||||
- [InternVL3-1B-hf](https://huggingface.co/OpenGVLab/InternVL3-1B-hf)
|
||||
|
||||
## License
|
||||
|
||||
This LeRobot integration follows the Apache 2.0 License used by LeRobot. Check the upstream EVO1 and InternVL3 model pages for the licenses of released checkpoints and data.
|
||||
@@ -0,0 +1,18 @@
|
||||
# EVO1
|
||||
|
||||
EVO1 is a Vision-Language-Action policy for robot control. The LeRobot
|
||||
integration uses an InternVL3 vision-language backbone with a flow-matching
|
||||
action head, and supports staged training through the standard LeRobot policy
|
||||
APIs.
|
||||
|
||||
The upstream EVO1 project is available at
|
||||
[MINT-SJTU/Evo-1](https://github.com/MINT-SJTU/Evo-1).
|
||||
|
||||
```bibtex
|
||||
@misc{evo1,
|
||||
title = {EVO1},
|
||||
author = {{MINT-SJTU}},
|
||||
year = {2025},
|
||||
howpublished = {\url{https://github.com/MINT-SJTU/Evo-1}},
|
||||
}
|
||||
```
|
||||
+4
-1
@@ -164,6 +164,7 @@ pynput-dep = ["pynput>=1.7.8,<1.9.0"]
|
||||
pyzmq-dep = ["pyzmq>=26.2.1,<28.0.0"]
|
||||
motorbridge-dep = ["motorbridge>=0.3.2,<0.4.0"]
|
||||
motorbridge-smart-servo-dep = ["motorbridge-smart-servo>=0.0.4,<0.1.0"]
|
||||
timm-dep = ["timm>=1.0.0,<1.1.0"]
|
||||
|
||||
# Motors
|
||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0", "lerobot[pyserial-dep]", "lerobot[deepdiff-dep]"]
|
||||
@@ -220,7 +221,7 @@ groot = [
|
||||
"lerobot[peft-dep]",
|
||||
"lerobot[diffusers-dep]",
|
||||
"dm-tree>=0.1.8,<1.0.0",
|
||||
"timm>=1.0.0,<1.1.0",
|
||||
"lerobot[timm-dep]",
|
||||
"decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
|
||||
"ninja>=1.11.1,<2.0.0",
|
||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||
@@ -234,6 +235,7 @@ fastwam = [
|
||||
"lerobot[transformers-dep]",
|
||||
"lerobot[diffusers-dep]",
|
||||
]
|
||||
evo1 = ["lerobot[transformers-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
|
||||
@@ -316,6 +318,7 @@ all = [
|
||||
"lerobot[fastwam]",
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[xvla]",
|
||||
"lerobot[evo1]",
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[vla_jepa]",
|
||||
"lerobot[async]",
|
||||
|
||||
@@ -83,6 +83,28 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("cosine_annealing_with_warmup")
@dataclass
class CosineAnnealingWithWarmupSchedulerConfig(LRSchedulerConfig):
    """Linear warmup followed by cosine annealing from the peak LR to zero.

    Used by EVO1; the annealing phase always spans the remaining training steps.
    Once the annealing phase is over (``current_step >= num_training_steps``) the
    multiplier stays at zero instead of following the cosine back up.

    Attributes:
        num_warmup_steps: Number of steps over which the LR multiplier grows
            linearly from 0 to 1.
    """

    num_warmup_steps: int

    def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
        """Create the ``LambdaLR`` scheduler for ``optimizer``.

        Args:
            optimizer: Optimizer whose learning rate is scaled.
            num_training_steps: Total number of training steps; annealing covers
                the steps left after warmup (at least one step).

        Returns:
            A ``LambdaLR`` whose multiplier is ``step / num_warmup_steps`` during
            warmup and ``0.5 * (1 + cos(pi * progress))`` afterwards.
        """
        warmup_steps = self.num_warmup_steps
        # Guard against a zero/negative annealing span (warmup >= total steps).
        annealing_steps = max(1, num_training_steps - warmup_steps)

        def lr_lambda(current_step: int) -> float:
            if current_step < warmup_steps:
                return current_step / max(1, warmup_steps)
            # Clamp so steps beyond the schedule do not re-enter the rising half
            # of the cosine and silently raise the learning rate again.
            progress = min(1.0, (current_step - warmup_steps) / annealing_steps)
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

        return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
|
||||
@dataclass
|
||||
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
||||
|
||||
@@ -17,6 +17,7 @@ from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterp
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config as EO1Config
|
||||
from .evo1.configuration_evo1 import Evo1Config as Evo1Config
|
||||
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
||||
from .fastwam.configuration_fastwam import FastWAMConfig as FastWAMConfig
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
|
||||
@@ -45,6 +46,7 @@ __all__ = [
|
||||
"EO1Config",
|
||||
"FastWAMConfig",
|
||||
"GaussianActorConfig",
|
||||
"Evo1Config",
|
||||
"GrootConfig",
|
||||
"MolmoAct2Config",
|
||||
"MultiTaskDiTConfig",
|
||||
|
||||
+1
@@ -0,0 +1 @@
|
||||
../../../../docs/source/policy_evo1_README.md
|
||||
@@ -0,0 +1,19 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_evo1 import Evo1Config
|
||||
from .modeling_evo1 import Evo1Policy
|
||||
from .processor_evo1 import make_evo1_pre_post_processors
|
||||
|
||||
__all__ = ["Evo1Config", "Evo1Policy", "make_evo1_pre_post_processors"]
|
||||
@@ -0,0 +1,252 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineAnnealingWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
from ..rtc.configuration_rtc import RTCConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("evo1")
@dataclass
class Evo1Config(PreTrainedConfig):
    """Configuration of the EVO1 policy (registered as policy type ``evo1``).

    Groups the action-chunking sizes, padding dimensions, action postprocessing
    switches, VLM / action-head hyper-parameters, stage-based finetuning flags,
    and the optimizer / scheduler presets.
    """

    training_stage: str = "stage1"
    # When True and the policy runs on CUDA, EVO1 wraps its own forward passes (training and
    # inference) in a bfloat16 autocast block, so its numerics do not depend on the dtype of any
    # outer autocast context opened by lerobot-train/lerobot-eval.
    use_amp: bool = True

    n_obs_steps: int = 1
    chunk_size: int = 50
    n_action_steps: int = 50

    max_state_dim: int = 24
    max_action_dim: int = 24
    max_views: int = 3
    image_resolution: tuple[int, int] = (448, 448)
    empty_cameras: int = 0
    postprocess_action_dim: int | None = None
    binarize_gripper: bool = False
    gripper_index: int = 6
    gripper_threshold: float = 0.5
    gripper_below_threshold_value: float = 1.0
    gripper_above_threshold_value: float = -1.0

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: dict(
            VISUAL=NormalizationMode.IDENTITY,
            STATE=NormalizationMode.MIN_MAX,
            ACTION=NormalizationMode.MIN_MAX,
        )
    )

    vlm_model_name: str = "OpenGVLab/InternVL3-1B-hf"
    vlm_num_layers: int | None = 14
    vlm_dtype: str = "bfloat16"
    # Max token length for tokenizing the (image placeholders + instruction) prompt. Prompts longer
    # than this are right-truncated, so raise it for tasks with long language instructions or many views.
    max_text_length: int = 1024
    use_flash_attn: bool = True
    action_head: str = "flowmatching"
    embed_dim: int = 896
    hidden_dim: int = 1024
    state_hidden_dim: int = 1024
    num_heads: int = 8
    num_layers: int = 8
    dropout: float = 0.0
    num_inference_timesteps: int = 32
    num_categories: int = 1
    # When True, the action head is conditioned on a single pooled VL token (the last non-padding
    # token of the causal decoder) instead of the full fused token sequence.
    return_cls_only: bool = False
    enable_gradient_checkpointing: bool = True
    gradient_checkpointing_use_reentrant: bool = False

    # ``None`` means "not set explicitly"; __post_init__ resolves each flag to a boolean.
    finetune_vlm: bool | None = None
    finetune_language_model: bool | None = None
    finetune_vision_model: bool | None = None
    finetune_action_head: bool | None = None
    # Reapply stage defaults after loading checkpoint configs so stage2 cannot
    # accidentally inherit the frozen VLM flags stored by a stage1 checkpoint.
    apply_training_stage_defaults: bool = True

    task_field: str = "task"
    embodiment_id_field: str | None = None
    default_embodiment_id: int = 0

    # Real-Time Chunking guidance for asynchronous inference (lerobot-rollout --inference.type=rtc
    # sets this and calls init_rtc_processor()); None disables RTC.
    rtc_config: RTCConfig | None = None

    optimizer_lr: float = 1e-5
    optimizer_betas: tuple[float, float] = (0.9, 0.999)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 1e-5
    optimizer_grad_clip_norm: float = 1.0

    scheduler_warmup_steps: int = 300

    def __post_init__(self):
        """Resolve the finetune flags for the selected stage and validate the config.

        Raises:
            ValueError: on an unknown ``training_stage``, inconsistent finetune flags,
                ``n_action_steps > chunk_size``, a non-square ``image_resolution``, or a
                ``default_embodiment_id`` outside ``[0, num_categories)``.
        """
        super().__post_init__()
        if self.training_stage not in {"stage1", "stage2"}:
            raise ValueError(
                f"Unsupported EVO1 training_stage '{self.training_stage}', expected 'stage1' or 'stage2'"
            )

        if self.apply_training_stage_defaults:
            self._force_stage_defaults()
        elif self.training_stage == "stage1":
            self._fill_unset_stage1_flags()
        else:
            # Only "stage2" can reach this point: the stage was validated above.
            self._fill_unset_stage2_flags()

        # Whatever is still unset at this point is treated as "do not finetune".
        for flag_name in (
            "finetune_vlm",
            "finetune_language_model",
            "finetune_vision_model",
            "finetune_action_head",
        ):
            if getattr(self, flag_name) is None:
                setattr(self, flag_name, False)

        self._validate_resolved_config()

    def _force_stage_defaults(self) -> None:
        """Overwrite every finetune flag with the stage default, warning on explicit conflicts."""
        trains_vlm = self.training_stage == "stage2"
        forced_flags = {
            "finetune_vlm": trains_vlm,
            "finetune_language_model": trains_vlm,
            "finetune_vision_model": trains_vlm,
            "finetune_action_head": True,
        }
        for flag_name, default_value in forced_flags.items():
            current_value = getattr(self, flag_name)
            explicitly_different = current_value is not None and current_value != default_value
            if explicitly_different:
                logger.warning(
                    "EVO1 %s=%s is overridden by training_stage=%s default %s. "
                    "Set apply_training_stage_defaults=false to keep explicit finetuning flags.",
                    flag_name,
                    current_value,
                    self.training_stage,
                    default_value,
                )
            setattr(self, flag_name, default_value)

    def _fill_unset_stage1_flags(self) -> None:
        """Stage 1 without forced defaults: unset VLM flags become False, the action head True."""
        for vlm_flag in ("finetune_vlm", "finetune_language_model", "finetune_vision_model"):
            if getattr(self, vlm_flag) is None:
                setattr(self, vlm_flag, False)
        if self.finetune_action_head is None:
            self.finetune_action_head = True

    def _fill_unset_stage2_flags(self) -> None:
        """Stage 2 without forced defaults: derive unset VLM flags from the explicit ones."""
        language_flag = self.finetune_language_model
        vision_flag = self.finetune_vision_model
        if language_flag is None and vision_flag is None:
            # An explicit finetune_vlm decides both branches; otherwise stage2 defaults to a
            # full-VLM finetune.
            whole_vlm = True if self.finetune_vlm is None else self.finetune_vlm
            self.finetune_vlm = whole_vlm
            self.finetune_language_model = whole_vlm
            self.finetune_vision_model = whole_vlm
        elif self.finetune_vlm is None:
            # Branch flags were given explicitly: the VLM flag is their union.
            self.finetune_vlm = bool(language_flag or vision_flag)
        if self.finetune_action_head is None:
            self.finetune_action_head = True

    def _validate_resolved_config(self) -> None:
        """Run the cross-field checks once every finetune flag holds a concrete value."""
        branch_vlm = self.finetune_language_model or self.finetune_vision_model
        if self.finetune_vlm != branch_vlm:
            raise ValueError(
                "Inconsistent EVO1 finetune config: "
                f"finetune_vlm={self.finetune_vlm} but "
                f"(finetune_language_model or finetune_vision_model)={branch_vlm}. "
                "When branch-level flags are used, finetune_vlm must match their effective union."
            )

        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                f"n_action_steps ({self.n_action_steps}) must be <= chunk_size ({self.chunk_size})"
            )

        resolution = self.image_resolution
        if len(resolution) != 2 or resolution[0] != resolution[1]:
            raise ValueError(
                "EVO1 currently expects a square image_resolution because InternVL3 preprocessing "
                f"uses a scalar image_size, got {self.image_resolution}."
            )

        if not 0 <= self.default_embodiment_id < self.num_categories:
            raise ValueError(
                f"default_embodiment_id ({self.default_embodiment_id}) must be in "
                f"[0, num_categories={self.num_categories})"
            )

    def validate_features(self) -> None:
        """Ensure the feature dicts exist and add any missing default entries.

        Adds one visual feature per requested empty camera, plus state and action
        features sized to the padding dimensions when they are absent.
        """
        if self.input_features is None:
            self.input_features = {}
        if self.output_features is None:
            self.output_features = {}

        for camera_idx in range(self.empty_cameras):
            camera_key = OBS_IMAGES + f".empty_camera_{camera_idx}"
            if camera_key in self.input_features:
                continue
            self.input_features[camera_key] = PolicyFeature(
                type=FeatureType.VISUAL,
                shape=(3, *self.image_resolution),
            )

        if OBS_STATE not in self.input_features:
            state_feature = PolicyFeature(type=FeatureType.STATE, shape=(self.max_state_dim,))
            self.input_features[OBS_STATE] = state_feature

        if ACTION not in self.output_features:
            action_feature = PolicyFeature(type=FeatureType.ACTION, shape=(self.max_action_dim,))
            self.output_features[ACTION] = action_feature

    def get_optimizer_preset(self) -> AdamWConfig:
        """Return the AdamW preset built from the ``optimizer_*`` fields."""
        preset_kwargs = {
            "lr": self.optimizer_lr,
            "betas": self.optimizer_betas,
            "eps": self.optimizer_eps,
            "weight_decay": self.optimizer_weight_decay,
            "grad_clip_norm": self.optimizer_grad_clip_norm,
        }
        return AdamWConfig(**preset_kwargs)

    def get_scheduler_preset(self):
        """Return the warmup + cosine-annealing preset using ``scheduler_warmup_steps``."""
        warmup_steps = self.scheduler_warmup_steps
        return CosineAnnealingWithWarmupSchedulerConfig(num_warmup_steps=warmup_steps)

    @property
    def observation_delta_indices(self) -> list[int]:
        """Only the current observation (offset 0) is used."""
        return [0]

    @property
    def action_delta_indices(self) -> list[int]:
        """Offsets ``0 .. chunk_size - 1`` of the predicted action chunk."""
        return [*range(self.chunk_size)]

    @property
    def reward_delta_indices(self) -> None:
        """EVO1 does not use reward offsets."""
        return None
|
||||
@@ -0,0 +1,210 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .configuration_evo1 import Evo1Config
|
||||
from .flow_matching import FlowmatchingActionHead
|
||||
from .internvl3_embedder import InternVL3Embedder
|
||||
|
||||
|
||||
class Evo1Model(nn.Module):
|
||||
def __init__(self, config: Evo1Config, vlm_hub_kwargs: dict | None = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self._device = config.device
|
||||
self.return_cls_only = config.return_cls_only
|
||||
# Set by Evo1Policy.init_rtc_processor() when config.rtc_config is provided.
|
||||
self.rtc_processor = None
|
||||
|
||||
# Gradient checkpointing only pays off when the VLM is actually being trained; keep it off
|
||||
# whenever every VLM branch is frozen so the frozen forward stays cheap.
|
||||
tracks_vlm_gradients = bool(
|
||||
config.finetune_vlm or config.finetune_language_model or config.finetune_vision_model
|
||||
)
|
||||
enable_gradient_checkpointing = config.enable_gradient_checkpointing and tracks_vlm_gradients
|
||||
|
||||
self.embedder = InternVL3Embedder(
|
||||
model_name=config.vlm_model_name,
|
||||
image_size=int(config.image_resolution[0]),
|
||||
device=self._device,
|
||||
num_language_layers=config.vlm_num_layers,
|
||||
model_dtype=config.vlm_dtype,
|
||||
use_flash_attn=config.use_flash_attn,
|
||||
max_text_length=config.max_text_length,
|
||||
enable_gradient_checkpointing=enable_gradient_checkpointing,
|
||||
gradient_checkpointing_use_reentrant=config.gradient_checkpointing_use_reentrant,
|
||||
hub_kwargs=vlm_hub_kwargs,
|
||||
)
|
||||
|
||||
action_head_type = config.action_head.lower()
|
||||
if action_head_type != "flowmatching":
|
||||
raise NotImplementedError(f"Unknown action_head: {action_head_type}")
|
||||
|
||||
horizon = config.chunk_size
|
||||
per_action_dim = config.max_action_dim
|
||||
action_dim = horizon * per_action_dim
|
||||
|
||||
self.horizon = horizon
|
||||
self.per_action_dim = per_action_dim
|
||||
self.action_head = FlowmatchingActionHead(
|
||||
embed_dim=config.embed_dim,
|
||||
hidden_dim=config.hidden_dim,
|
||||
action_dim=action_dim,
|
||||
horizon=horizon,
|
||||
per_action_dim=per_action_dim,
|
||||
num_heads=config.num_heads,
|
||||
num_layers=config.num_layers,
|
||||
dropout=config.dropout,
|
||||
num_inference_timesteps=config.num_inference_timesteps,
|
||||
num_categories=config.num_categories,
|
||||
state_dim=config.max_state_dim,
|
||||
state_hidden_dim=config.state_hidden_dim,
|
||||
).to(self._device)
|
||||
|
||||
def get_vl_embeddings(
|
||||
self,
|
||||
images: list[torch.Tensor],
|
||||
image_mask: torch.Tensor,
|
||||
prompt: str | list[str] | None = None,
|
||||
return_cls_only: bool | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
"""Fused VL embeddings from per-camera image batches.
|
||||
|
||||
Args:
|
||||
images: list of per-camera tensors, each shaped ``(B, C, H, W)`` with values in ``[0, 1]``.
|
||||
image_mask: bool tensor ``(B, max_views)`` marking present views.
|
||||
|
||||
Returns:
|
||||
``(embeddings, valid_mask)``: the fused tokens and the bool mask of attendable context
|
||||
positions (None when a single pooled token is returned).
|
||||
"""
|
||||
if return_cls_only is None:
|
||||
return_cls_only = self.return_cls_only
|
||||
if not images:
|
||||
raise ValueError("EVO1 expects at least one image per sample.")
|
||||
|
||||
batch_size = images[0].shape[0]
|
||||
if prompt is None:
|
||||
prompts = [""] * batch_size
|
||||
elif isinstance(prompt, str):
|
||||
prompts = [prompt] * batch_size
|
||||
else:
|
||||
prompts = [str(p) for p in prompt]
|
||||
if len(prompts) != batch_size:
|
||||
raise ValueError(
|
||||
f"Prompt batch size {len(prompts)} does not match image batch size {batch_size}"
|
||||
)
|
||||
|
||||
if image_mask.dim() == 1:
|
||||
image_mask = image_mask.unsqueeze(0)
|
||||
if image_mask.shape[0] != batch_size:
|
||||
raise ValueError(
|
||||
f"image_mask batch size {image_mask.shape[0]} does not match image batch size {batch_size}"
|
||||
)
|
||||
|
||||
return self.embedder.get_fused_image_text_embedding_batched(
|
||||
camera_images=images,
|
||||
image_masks=image_mask,
|
||||
text_prompts=prompts,
|
||||
return_cls_only=return_cls_only,
|
||||
)
|
||||
|
||||
def predict_action(
|
||||
self,
|
||||
fused_tokens: torch.Tensor,
|
||||
state: torch.Tensor,
|
||||
actions_gt: torch.Tensor | None = None,
|
||||
action_mask: torch.Tensor | None = None,
|
||||
embodiment_ids: torch.Tensor | None = None,
|
||||
context_mask: torch.Tensor | None = None,
|
||||
inference_delay: int | None = None,
|
||||
prev_chunk_left_over: torch.Tensor | None = None,
|
||||
execution_horizon: int | None = None,
|
||||
):
|
||||
if actions_gt is None:
|
||||
return self.action_head.get_action(
|
||||
fused_tokens,
|
||||
state=state,
|
||||
action_mask=action_mask,
|
||||
embodiment_id=embodiment_ids,
|
||||
context_mask=context_mask,
|
||||
inference_delay=inference_delay,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
execution_horizon=execution_horizon,
|
||||
rtc_processor=self.rtc_processor,
|
||||
)
|
||||
return self.action_head(
|
||||
fused_tokens,
|
||||
state=state,
|
||||
actions_gt=actions_gt,
|
||||
action_mask=action_mask,
|
||||
embodiment_id=embodiment_ids,
|
||||
context_mask=context_mask,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
fused_tokens: torch.Tensor,
|
||||
state: torch.Tensor | None = None,
|
||||
actions_gt: torch.Tensor | None = None,
|
||||
action_mask: torch.Tensor | None = None,
|
||||
embodiment_ids: torch.Tensor | None = None,
|
||||
context_mask: torch.Tensor | None = None,
|
||||
inference_delay: int | None = None,
|
||||
prev_chunk_left_over: torch.Tensor | None = None,
|
||||
execution_horizon: int | None = None,
|
||||
):
|
||||
return self.predict_action(
|
||||
fused_tokens,
|
||||
state,
|
||||
actions_gt,
|
||||
action_mask,
|
||||
embodiment_ids,
|
||||
context_mask,
|
||||
inference_delay,
|
||||
prev_chunk_left_over,
|
||||
execution_horizon,
|
||||
)
|
||||
|
||||
def _set_module_trainable(self, module: nn.Module, trainable: bool):
|
||||
for param in module.parameters():
|
||||
param.requires_grad = trainable
|
||||
|
||||
def _vlm_submodule(self, name: str) -> nn.Module:
|
||||
module = getattr(self.embedder.model, name, None)
|
||||
if not isinstance(module, nn.Module):
|
||||
raise AttributeError(
|
||||
f"InternVL model {type(self.embedder.model).__name__} has no '{name}' submodule; "
|
||||
"the native HF InternVL layout (language_model / vision_tower / "
|
||||
"multi_modal_projector) is required to apply the EVO1 finetune flags."
|
||||
)
|
||||
return module
|
||||
|
||||
def set_finetune_flags(self):
    """Apply the config's finetune flags to the embedder and the action head.

    The whole embedder is frozen first; then each VLM branch is set according
    to its flag: ``language_model`` follows ``finetune_language_model``, while
    ``vision_tower`` and ``multi_modal_projector`` both follow
    ``finetune_vision_model``. The action head follows ``finetune_action_head``
    in both directions, so calling this again after the flags change (for
    example, a later stage re-enabling a previously frozen head) yields the
    requested state rather than leaving stale ``requires_grad`` values behind.

    Raises:
        AttributeError: Propagated from ``_vlm_submodule`` when an expected VLM
            branch is missing.
    """
    # __post_init__ resolves every finetune flag to a concrete boolean, so branch-level flags
    # are authoritative here. Freeze everything first, then re-enable the requested branches.
    self._set_module_trainable(self.embedder, False)
    self._set_module_trainable(
        self._vlm_submodule("language_model"), bool(self.config.finetune_language_model)
    )
    finetune_vision = bool(self.config.finetune_vision_model)
    self._set_module_trainable(self._vlm_submodule("vision_tower"), finetune_vision)
    self._set_module_trainable(self._vlm_submodule("multi_modal_projector"), finetune_vision)

    # Set the action head explicitly in both directions (not only when freezing) so the
    # flag is authoritative here too, matching the VLM branches above.
    self._set_module_trainable(self.action_head, bool(self.config.finetune_action_head))
|
||||
@@ -0,0 +1,483 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
    """Sinusoidal position table stored as a buffer that grows on demand.

    The buffer ``pe`` has shape ``(1, max_len, dim)``. Even feature indices hold
    ``sin(position * freq)`` and odd indices hold ``cos(position * freq)``.
    Odd ``dim`` values are supported: the final cosine column that would not
    fit is simply dropped.
    """

    def __init__(self, dim: int, max_len: int = 1000):
        super().__init__()
        self.register_buffer("pe", self._build_table(0, max_len, dim).unsqueeze(0))

    @staticmethod
    def _build_table(start: int, stop: int, dim: int) -> torch.Tensor:
        """Return the ``(stop - start, dim)`` float32 table for positions ``[start, stop)``."""
        positions = torch.arange(start, stop, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim))
        angles = positions * div_term
        table = torch.zeros(stop - start, dim)
        table[:, 0::2] = torch.sin(angles)
        # With odd `dim` there is one more sine column than cosine columns; trim to fit.
        table[:, 1::2] = torch.cos(angles)[:, : dim // 2]
        return table

    def forward(self, seq_len: int):
        """Return the first ``seq_len`` rows as ``(1, seq_len, dim)``, extending the table if needed."""
        if seq_len > self.pe.size(1):
            self._extend_pe(seq_len)
        return self.pe[:, :seq_len, :]

    def _extend_pe(self, new_max_len):
        """Grow the ``pe`` buffer in place so that it covers ``new_max_len`` positions."""
        old_max_len, dim = self.pe.size(1), self.pe.size(2)
        if new_max_len <= old_max_len:
            return
        extra_pe = self._build_table(old_max_len, new_max_len, dim).unsqueeze(0)
        # Match the buffer's device and dtype so extending a half-precision module does not
        # silently promote the whole buffer back to float32.
        extra_pe = extra_pe.to(device=self.pe.device, dtype=self.pe.dtype)
        self.pe = torch.cat([self.pe, extra_pe], dim=1)
|
||||
|
||||
|
||||
class CategorySpecificLinear(nn.Module):
    """Linear layer with one weight matrix and bias per category.

    With ``num_categories <= 1`` this wraps a plain ``nn.Linear`` and ignores
    ``category_id``. Otherwise it holds a ``(num_categories, in_dim, out_dim)``
    weight and a ``(num_categories, out_dim)`` bias, selected per row by
    ``category_id``.
    """

    def __init__(self, in_dim: int, out_dim: int, num_categories: int = 1):
        super().__init__()
        self.num_categories = num_categories
        if num_categories <= 1:
            self.linear = nn.Linear(in_dim, out_dim)
            return

        self.weight = nn.Parameter(torch.empty(num_categories, in_dim, out_dim))
        self.bias = nn.Parameter(torch.zeros(num_categories, out_dim))
        # Initialize each per-category (in_dim, out_dim) matrix separately: xavier on the full
        # 3D tensor would compute fan_in = in_dim * out_dim and badly under-scale the weights.
        for index in range(num_categories):
            nn.init.xavier_uniform_(self.weight[index])

    def forward(self, x: torch.Tensor, category_id: torch.LongTensor):
        """Apply the per-category affine map over the last dimension of ``x``.

        Args:
            x: Input whose last dimension is ``in_dim``; it is cast to the
                weight dtype if needed.
            category_id: A 0-dim tensor (one category for every row) or a
                tensor with one id per flattened row of ``x``.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by ``out_dim``.

        Raises:
            ValueError: If a non-scalar ``category_id`` does not have one entry
                per flattened row.
        """
        if self.num_categories <= 1:
            # Tensor.to returns the input itself when the dtype already matches.
            return self.linear(x.to(dtype=self.linear.weight.dtype))

        x = x.to(dtype=self.weight.dtype)
        leading_dims = x.shape[:-1]
        rows = x.reshape(-1, x.shape[-1])

        if category_id.dim() == 0:
            selected = category_id.item()
            out = rows @ self.weight[selected] + self.bias[selected]
        else:
            ids = category_id.reshape(-1)
            if ids.numel() != rows.size(0):
                raise ValueError(
                    f"category_id length {ids.numel()} does not match flattened batch {rows.size(0)}"
                )
            # One (1, in_dim) x (in_dim, out_dim) product per row, using that row's own weights.
            per_row_weight = self.weight[ids]
            per_row_bias = self.bias[ids]
            out = torch.bmm(rows.unsqueeze(1), per_row_weight).squeeze(1) + per_row_bias

        return out.view(*leading_dims, out.shape[-1])
|
||||
|
||||
|
||||
class CategorySpecificMLP(nn.Module):
    """Two-layer MLP whose layers are both ``CategorySpecificLinear`` with a ReLU in between."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_categories: int = 1):
        super().__init__()
        self.fc1 = CategorySpecificLinear(input_dim, hidden_dim, num_categories)
        self.fc2 = CategorySpecificLinear(hidden_dim, output_dim, num_categories)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor, category_id: torch.LongTensor):
        """Return ``fc2(relu(fc1(x)))``, using ``category_id`` to select weights in both layers."""
        hidden = self.fc1(x, category_id)
        hidden = self.activation(hidden)
        return self.fc2(hidden, category_id)
|
||||
|
||||
|
||||
class MultiEmbodimentActionEncoder(nn.Module):
    """Encode an action chunk into per-step tokens using category-specific layers.

    Each action step goes through three ``CategorySpecificLinear`` layers
    (``W1`` -> ``W2`` -> ``W3``) with ReLU after the first two; a sinusoidal
    positional encoding over the horizon is added after the first layer.
    """

    def __init__(
        self, action_dim: int, embed_dim: int, hidden_dim: int, horizon: int, num_categories: int = 1
    ):
        super().__init__()
        self.horizon = horizon
        self.embed_dim = embed_dim
        self.num_categories = num_categories

        self.W1 = CategorySpecificLinear(action_dim, hidden_dim, num_categories)
        self.W2 = CategorySpecificLinear(hidden_dim, hidden_dim, num_categories)
        self.W3 = CategorySpecificLinear(hidden_dim, embed_dim, num_categories)

        self.pos_encoding = SinusoidalPositionalEncoding(hidden_dim, max_len=horizon)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, action_seq: torch.Tensor, category_id: torch.LongTensor):
        """Encode ``action_seq`` of shape ``(B, horizon, action_dim)`` to ``(B, horizon, embed_dim)``.

        Args:
            action_seq: Action chunk; its second dimension must equal ``self.horizon``.
            category_id: A 0-dim tensor shared by the whole batch, or one id per sample.

        Raises:
            ValueError: If the sequence length differs from ``self.horizon``.
        """
        batch_size, horizon, action_dim = action_seq.shape
        if horizon != self.horizon:
            raise ValueError(
                f"Action sequence length must match horizon: got {horizon}, expected {self.horizon}."
            )

        num_rows = batch_size * horizon
        flat_actions = action_seq.reshape(num_rows, action_dim)

        # Give every flattened (sample, step) row its own category id.
        if category_id.dim() != 0:
            row_ids = category_id.unsqueeze(1).expand(batch_size, horizon).reshape(num_rows)
        else:
            row_ids = category_id.expand(num_rows)

        hidden = self.activation(self.W1(flat_actions, row_ids))
        positions = self.pos_encoding(horizon).to(device=hidden.device, dtype=hidden.dtype)
        hidden = (hidden.view(batch_size, horizon, -1) + positions).view(num_rows, -1)
        hidden = self.activation(self.W2(hidden, row_ids))
        tokens = self.W3(hidden, row_ids)
        return tokens.view(batch_size, horizon, self.embed_dim)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
    """Pre-norm cross-attention block: action tokens attend to context tokens, then a GELU MLP.

    The time embedding, when given, is added to the normalized input of the
    feed-forward branch only; the attention branch does not see it.
    """

    def __init__(self, embed_dim: int, num_heads: int, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(
        self,
        action_tokens: torch.Tensor,
        context_tokens: torch.Tensor,
        time_emb: torch.Tensor,
        context_key_padding_mask: torch.Tensor | None = None,
    ):
        """Run one cross-attention + feed-forward step.

        Args:
            action_tokens: Query tokens, batch-first.
            context_tokens: Tokens used as both keys and values.
            time_emb: Per-sample embedding broadcast over the token axis via
                ``unsqueeze(1)``; may be ``None`` to skip time conditioning.
            context_key_padding_mask: Passed to ``nn.MultiheadAttention`` as
                ``key_padding_mask`` (True marks keys to ignore).

        Returns:
            Updated action tokens with the same shape as ``action_tokens``.
        """
        queries = self.norm1(action_tokens)
        attended, _ = self.attn(
            queries, context_tokens, context_tokens, key_padding_mask=context_key_padding_mask
        )
        residual = action_tokens + attended

        ff_input = self.norm2(residual)
        if time_emb is not None:
            ff_input = ff_input + time_emb.unsqueeze(1)
        return residual + self.ff(ff_input)
|
||||
|
||||
|
||||
class FlowmatchingActionHead(nn.Module):
    """Flow-matching action head conditioned on vision-language tokens and robot state.

    Noisy action tokens cross-attend to the context through a stack of
    ``BasicTransformerBlock`` layers, are pooled to a single vector, and are
    mapped by a category-specific MLP to a flat velocity of size ``action_dim``.

    ``forward`` is the training path and returns ``(pred_velocity, noise)``.
    ``get_action`` is the sampling path: it integrates the predicted velocity
    from ``t = 0`` to ``t = 1`` with ``num_inference_timesteps`` Euler steps.
    """

    def __init__(
        self,
        embed_dim: int = 896,
        hidden_dim: int = 1024,
        action_dim: int = 16 * 7,
        horizon: int = 16,
        per_action_dim: int = 7,
        num_heads: int = 8,
        num_layers: int = 8,
        dropout: float = 0.0,
        num_inference_timesteps: int = 20,
        num_categories: int = 1,
        state_dim: int | None = None,
        state_hidden_dim: int | None = None,
    ):
        super().__init__()

        logger.info("FlowmatchingActionHead num_inference_timesteps=%s", num_inference_timesteps)
        self.embed_dim = embed_dim
        self.horizon = horizon
        self.per_action_dim = per_action_dim
        self.action_dim = action_dim
        self.num_inference_timesteps = num_inference_timesteps
        self.num_categories = num_categories

        self.time_pos_enc = SinusoidalPositionalEncoding(embed_dim, max_len=1000)
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    hidden_dim=embed_dim * 4,
                    dropout=dropout,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_out = nn.LayerNorm(embed_dim)
        self.seq_pool_proj = nn.Linear(self.horizon * self.embed_dim, self.embed_dim)
        self.mlp_head = CategorySpecificMLP(
            input_dim=embed_dim,
            hidden_dim=hidden_dim,
            output_dim=action_dim,
            num_categories=num_categories,
        )

        # The state encoder is optional: without `state_dim`, the state is not added to the context.
        self.state_encoder = None
        if state_dim is not None:
            state_hidden = state_hidden_dim if state_hidden_dim is not None else embed_dim
            self.state_encoder = CategorySpecificMLP(
                input_dim=state_dim,
                hidden_dim=state_hidden,
                output_dim=embed_dim,
                num_categories=num_categories,
            )

        if horizon > 1:
            self.action_encoder = MultiEmbodimentActionEncoder(
                action_dim=self.per_action_dim,
                embed_dim=embed_dim,
                hidden_dim=embed_dim,
                horizon=horizon,
                num_categories=num_categories,
            )
            self.single_action_proj = None
        else:
            self.action_encoder = None
            self.single_action_proj = nn.Linear(self.per_action_dim, self.embed_dim)

    def _project_actions(self, action_seq: torch.Tensor, embodiment_id: torch.LongTensor) -> torch.Tensor:
        """Embed ``(B, horizon, per_action_dim)`` actions into ``(B, horizon, embed_dim)`` tokens."""
        if self.horizon > 1 and self.action_encoder is not None:
            return self.action_encoder(action_seq, embodiment_id)
        if self.single_action_proj is None:
            raise RuntimeError("single_action_proj is not initialized for horizon <= 1.")
        return self.single_action_proj(action_seq)

    def _expand_action_mask(
        self,
        action_mask: torch.Tensor,
        batch_size: int,
        per_action_dim: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Normalize ``action_mask`` to ``(B, horizon, per_action_dim)`` on ``device``/``dtype``.

        Accepted shapes are ``(B, horizon * per_action_dim)``, ``(B, per_action_dim)``
        (repeated over the horizon) and ``(B, horizon, per_action_dim)``.

        Raises:
            ValueError: If the mask is ``None`` or has an unsupported shape or rank.
        """
        if action_mask is None:
            raise ValueError("action_mask must be provided for flow matching inference.")

        if action_mask.dim() == 2:
            expected_last_dim = self.horizon * per_action_dim
            if action_mask.shape == (batch_size, expected_last_dim):
                expanded_mask = action_mask.reshape(batch_size, self.horizon, per_action_dim)
            elif action_mask.shape == (batch_size, per_action_dim):
                expanded_mask = action_mask.unsqueeze(1).expand(batch_size, self.horizon, per_action_dim)
            else:
                raise ValueError(
                    f"Expected action_mask shape {(batch_size, expected_last_dim)} or "
                    f"{(batch_size, per_action_dim)}, got {tuple(action_mask.shape)}"
                )
        elif action_mask.dim() == 3:
            expected_shape = (batch_size, self.horizon, per_action_dim)
            if tuple(action_mask.shape) != expected_shape:
                raise ValueError(
                    f"Expected action_mask shape {expected_shape}, got {tuple(action_mask.shape)}"
                )
            expanded_mask = action_mask
        else:
            raise ValueError(f"Unsupported action_mask rank: {action_mask.dim()}")

        return expanded_mask.to(device=device, dtype=dtype)

    def _prepare_context(
        self,
        fused_tokens: torch.Tensor,
        state: torch.Tensor | None,
        embodiment_id: torch.LongTensor | None,
        context_mask: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.LongTensor]:
        """Normalize the VL context and embodiment ids shared by training and inference.

        Returns the context tokens ``(B, S, E)``, a key_padding_mask for
        ``nn.MultiheadAttention`` (True = ignore) or None, and the resolved embodiment ids.

        Raises:
            ValueError: If ``num_categories > 1`` and an embodiment id lies outside
                ``[0, num_categories)``.
        """
        batch_size = fused_tokens.size(0)
        device = fused_tokens.device
        if embodiment_id is None:
            embodiment_id = torch.zeros(batch_size, dtype=torch.long, device=device)
        elif self.num_categories > 1 and (
            int(embodiment_id.min()) < 0 or int(embodiment_id.max()) >= self.num_categories
        ):
            raise ValueError(
                f"embodiment ids must be in [0, num_categories={self.num_categories}), "
                f"got range [{int(embodiment_id.min())}, {int(embodiment_id.max())}]"
            )

        context_tokens = fused_tokens
        if context_tokens.dim() == 2:
            # A single pooled VL token (return_cls_only): give it a sequence dim of 1.
            context_tokens = context_tokens.unsqueeze(1)
            context_mask = None
        if state is not None and self.state_encoder is not None:
            # The encoded state is appended as one extra, always-valid context token.
            state_emb = self.state_encoder(state, embodiment_id).unsqueeze(1)
            context_tokens = torch.cat([context_tokens, state_emb], dim=1)
            if context_mask is not None:
                state_valid = torch.ones(batch_size, 1, dtype=torch.bool, device=context_mask.device)
                context_mask = torch.cat([context_mask.to(torch.bool), state_valid], dim=1)

        key_padding_mask = None if context_mask is None else ~context_mask.to(torch.bool)
        return context_tokens, key_padding_mask, embodiment_id

    def _predict_flat_velocity(
        self,
        action_tokens: torch.Tensor,
        context_tokens: torch.Tensor,
        time_emb: torch.Tensor,
        key_padding_mask: torch.Tensor | None,
        embodiment_id: torch.LongTensor,
    ) -> torch.Tensor:
        """Run transformer stack, output norm, pooling and MLP head; returns ``(B, action_dim)``."""
        x = action_tokens
        for block in self.transformer_blocks:
            x = block(x, context_tokens, time_emb, key_padding_mask)
        x = self.norm_out(x)

        if self.horizon > 1:
            # Pool the sequence by flattening all step tokens and projecting back to embed_dim.
            x_pooled = self.seq_pool_proj(x.reshape(x.size(0), -1))
        else:
            x_pooled = x.squeeze(1)
        return self.mlp_head(x_pooled, embodiment_id)

    def forward(
        self,
        fused_tokens: torch.Tensor,
        state: torch.Tensor = None,
        actions_gt: torch.Tensor = None,
        embodiment_id: torch.LongTensor = None,
        action_mask: torch.Tensor = None,
        context_mask: torch.Tensor = None,
    ):
        """Training step: predict the flow velocity at a random time for noisy actions.

        Args:
            fused_tokens: VL context, ``(B, S, E)`` or a pooled ``(B, E)`` token.
            state: Optional robot state; used only when a state encoder exists.
            actions_gt: Ground-truth actions, either ``(B, horizon, per_action_dim)``
                or flat ``(B, horizon * per_action_dim)``. When ``None``, this
                delegates to ``get_action``.
            embodiment_id: Optional per-sample category ids; defaults to zeros.
            action_mask: Optional mask with the same shape as ``actions_gt``; masked
                entries are zeroed in both the targets and the noise.
            context_mask: Optional ``(B, S)`` validity mask for the context tokens.

        Returns:
            ``(pred_velocity, noise)``: ``pred_velocity`` is the flat MLP-head output
            of size ``action_dim`` per sample; ``noise`` is the sampled (masked) noise
            in the same shape as ``actions_gt``.

        Raises:
            ValueError: If ``action_mask`` does not match the shape of ``actions_gt``.
        """
        if actions_gt is None:
            return self.get_action(
                fused_tokens,
                state=state,
                embodiment_id=embodiment_id,
                action_mask=action_mask,
                context_mask=context_mask,
            )

        batch_size = fused_tokens.size(0)
        device = fused_tokens.device
        context_tokens, key_padding_mask, embodiment_id = self._prepare_context(
            fused_tokens, state, embodiment_id, context_mask
        )

        # Sample one flow time per sample from Beta(2, 2), kept away from the endpoints.
        t = (
            torch.distributions.Beta(2, 2)
            .sample((batch_size,))
            .clamp(0.02, 0.98)
            .to(device)
            .to(dtype=self.dtype)
        )
        time_index = (t * 999).long().clamp_(0, 999)
        time_emb = self.time_pos_enc(1000)[:, time_index, :].squeeze(0).to(dtype=context_tokens.dtype)

        actions_gt_seq = actions_gt
        # Uniform noise in [-1, 1) with the same shape as the targets.
        noise = torch.rand_like(actions_gt) * 2 - 1
        if action_mask is not None:
            action_mask = action_mask.to(dtype=noise.dtype, device=noise.device)
            if action_mask.shape != noise.shape:
                raise ValueError(f"action_mask shape {action_mask.shape} != noise shape {noise.shape}")
            actions_gt_seq = actions_gt_seq * action_mask
            noise = noise * action_mask

        # Bring BOTH the noise and the targets to (B, horizon, per_action_dim). Previously only
        # the noise was reshaped, so flat (B, horizon * per_action_dim) targets broadcast
        # incorrectly against the (B, 1, 1) time tensor below.
        if self.horizon > 1:
            seq_shape = (batch_size, self.horizon, self.per_action_dim)
            noise_seq = noise.view(seq_shape)
            actions_gt_seq = actions_gt_seq.reshape(seq_shape)
        else:
            noise_seq = noise if noise.dim() == 3 else noise.unsqueeze(1)
            actions_gt_seq = actions_gt_seq if actions_gt_seq.dim() == 3 else actions_gt_seq.unsqueeze(1)
        t_broadcast = t.view(batch_size, 1, 1)
        # Linear interpolation between noise (t = 0) and the clean actions (t = 1).
        action_intermediate_seq = (1 - t_broadcast) * noise_seq + t_broadcast * actions_gt_seq

        action_tokens = self._project_actions(action_intermediate_seq, embodiment_id)
        target_dtype = self.dtype
        action_tokens = action_tokens.to(dtype=target_dtype)
        context_tokens = context_tokens.to(dtype=target_dtype)
        time_emb = time_emb.to(dtype=target_dtype)

        pred_velocity = self._predict_flat_velocity(
            action_tokens, context_tokens, time_emb, key_padding_mask, embodiment_id
        )
        return pred_velocity, noise

    def get_action(
        self,
        fused_tokens: torch.Tensor,
        state: torch.Tensor = None,
        embodiment_id: torch.LongTensor = None,
        action_mask: torch.Tensor = None,
        context_mask: torch.Tensor = None,
        inference_delay: int | None = None,
        prev_chunk_left_over: torch.Tensor | None = None,
        execution_horizon: int | None = None,
        rtc_processor=None,
    ):
        """Sample an action chunk by Euler-integrating the predicted velocity from noise.

        Args:
            fused_tokens: VL context, ``(B, S, E)`` or a pooled ``(B, E)`` token.
            state: Optional robot state; used only when a state encoder exists.
            embodiment_id: Optional per-sample category ids; defaults to zeros.
            action_mask: Required mask; see ``_expand_action_mask`` for accepted shapes.
            context_mask: Optional ``(B, S)`` validity mask for the context tokens.
            inference_delay: Forwarded to ``rtc_processor.denoise_step``.
            prev_chunk_left_over: Forwarded to ``rtc_processor.denoise_step``.
            execution_horizon: Forwarded to ``rtc_processor.denoise_step``.
            rtc_processor: Optional real-time-chunking processor. It is used only
                when it is set and at least one of ``inference_delay`` or
                ``prev_chunk_left_over`` is given.

        Returns:
            Flat masked actions of shape ``(B, horizon * per_action_dim)``.

        Raises:
            ValueError: If ``action_mask`` is missing or malformed, or if
                ``num_inference_timesteps`` is not positive.
        """
        batch_size = fused_tokens.size(0)
        device = fused_tokens.device
        context_tokens, key_padding_mask, embodiment_id = self._prepare_context(
            fused_tokens, state, embodiment_id, context_mask
        )

        action_dim_total = self.action_dim
        per_action_dim = self.per_action_dim

        # Start from uniform noise in [-1, 1), the same distribution used during training.
        action = torch.rand(batch_size, action_dim_total, device=device, dtype=context_tokens.dtype) * 2 - 1
        action_seq = action.view(batch_size, self.horizon, per_action_dim)
        action_mask = self._expand_action_mask(
            action_mask,
            batch_size=batch_size,
            per_action_dim=per_action_dim,
            device=action_seq.device,
            dtype=action_seq.dtype,
        )
        action_seq = action_seq * action_mask

        target_dtype = self.dtype
        context_tokens = context_tokens.to(dtype=target_dtype)

        num_steps = int(self.num_inference_timesteps)
        if num_steps <= 0:
            raise ValueError(f"num_inference_timesteps must be positive, got {num_steps}")
        dt = 1.0 / num_steps

        use_rtc = rtc_processor is not None and (
            inference_delay is not None or prev_chunk_left_over is not None
        )

        def predict_velocity(seq: torch.Tensor, step_time_emb: torch.Tensor) -> torch.Tensor:
            """Predict the masked flow velocity (x1 - x0 convention) for one integration step."""
            seq = seq * action_mask
            action_tokens = self._project_actions(seq, embodiment_id).to(dtype=target_dtype)
            pred = self._predict_flat_velocity(
                action_tokens, context_tokens, step_time_emb, key_padding_mask, embodiment_id
            )
            return pred.view(batch_size, self.horizon, per_action_dim) * action_mask

        # The time-embedding table does not change between steps: look it up and move it to
        # the working device/dtype once instead of on every iteration.
        time_table = self.time_pos_enc(1000).to(device=device, dtype=target_dtype)

        for i in range(num_steps):
            t = i / num_steps
            time_index = min(int(t * 999), 999)
            time_emb = time_table[0, time_index].unsqueeze(0).repeat(batch_size, 1)

            if use_rtc:
                # RTCProcessor assumes the pi0 flow convention: its `time` runs 1 -> 0 and the
                # clean-action estimate is x1 = x_t - time * v. EVO1 integrates t: 0 -> 1 with
                # velocity v = x1 - x0 (so x1 = x_t + (1 - t) * v); passing time = 1 - t and
                # flipping the velocity sign in both directions maps one convention onto the other.
                guided = rtc_processor.denoise_step(
                    x_t=action_seq,
                    prev_chunk_left_over=prev_chunk_left_over,
                    inference_delay=inference_delay,
                    time=1.0 - t,
                    original_denoise_step_partial=lambda seq, emb=time_emb: -predict_velocity(seq, emb),
                    execution_horizon=execution_horizon,
                )
                velocity = -guided
            else:
                velocity = predict_velocity(action_seq, time_emb)

            action_seq = action_seq + dt * velocity

        action_seq = action_seq * action_mask
        return action_seq.reshape(batch_size, -1)

    @property
    def device(self):
        """Device of the first parameter of this module."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first parameter of this module."""
        return next(self.parameters()).dtype
|
||||
@@ -0,0 +1,369 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms.functional as tvf
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
else:
|
||||
AutoModel = None
|
||||
AutoTokenizer = None
|
||||
|
||||
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
||||
IMAGENET_STD = (0.229, 0.224, 0.225)
|
||||
IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>" # nosec B105
|
||||
IMG_START_TOKEN = "<img>" # nosec B105
|
||||
IMG_END_TOKEN = "</img>" # nosec B105
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _batched_resize_01(images: torch.Tensor, image_size: int) -> torch.Tensor:
    """Resize a batch of ``[0, 1]`` float images to a square ``image_size`` on their device.

    The input is first quantized to uint8 (multiply by 255, clamp, truncate),
    then resized with antialiased bicubic interpolation and rescaled to
    ``[0, 1]``. The quantize-then-resize order is intended to mirror the
    reference PIL pipeline (``to_pil_image`` -> ``Image.resize`` ->
    ``to_tensor``) so results stay compatible with upstream EVO1 preprocessing.

    Args:
        images: float tensor of shape ``(N, C, H, W)`` with values in ``[0, 1]``.
        image_size: Target height and width in pixels.

    Returns:
        float32 tensor of shape ``(N, C, image_size, image_size)`` with values in ``[0, 1]``.
    """
    target_hw = [image_size, image_size]
    # Replicate to_pil_image()'s float -> uint8 conversion (x * 255, truncated) so the
    # bicubic resample operates on the same integer pixels.
    quantized = (images * 255.0).clamp(0, 255).byte()
    resampled = tvf.resize(
        quantized,
        target_hw,
        interpolation=InterpolationMode.BICUBIC,
        antialias=True,
    )
    return resampled.float() / 255.0
|
||||
|
||||
|
||||
def _batched_pixel_values(
    camera_images: Sequence[torch.Tensor],
    max_views: int,
    image_size: int,
    mean: torch.Tensor,
    std: torch.Tensor,
    dtype: torch.dtype,
    device: torch.device | str,
) -> torch.Tensor:
    """Turn per-camera ``[0, 1]`` image batches into normalized ``pixel_values`` on ``device``.

    Every camera batch is moved to ``device``, resized with ``_batched_resize_01``
    and cast to ``dtype``. Missing views (fewer cameras than ``max_views``) are
    filled with all-zero images, and extra cameras beyond ``max_views`` are
    dropped. The stacked views are then normalized with ``mean`` / ``std`` per
    channel. The zero-filled views are expected to be masked out of attention
    downstream (via ``_mask_absent_image_tokens``).

    Args:
        camera_images: One ``(B, C, H, W)`` tensor per camera; must be non-empty.
        max_views: Number of view slots per sample in the output.
        image_size: Square output resolution.
        mean: Per-channel mean used for normalization.
        std: Per-channel standard deviation used for normalization.
        dtype: Output dtype.
        device: Output device.

    Returns:
        ``pixel_values`` of shape ``(B * max_views, C, image_size, image_size)``, ordered row-major
        over ``(sample, view)`` to line up with the per-view image placeholders in the prompt.
    """
    views: list[torch.Tensor] = [
        _batched_resize_01(camera.to(device=device), image_size).to(dtype) for camera in camera_images
    ]

    batch_size = views[0].shape[0]
    channels = views[0].shape[1]

    # Pad with blank views until every sample has `max_views` slots.
    num_missing = max_views - len(views)
    if num_missing > 0:
        blank_view = torch.zeros(batch_size, channels, image_size, image_size, dtype=dtype, device=device)
        views.extend([blank_view] * num_missing)

    per_sample_views = torch.stack(views[:max_views], dim=1)  # (B, V, C, H, W)
    channel_shape = (1, 1, -1, 1, 1)
    mean = mean.to(device=device, dtype=dtype).view(channel_shape)
    std = std.to(device=device, dtype=dtype).view(channel_shape)
    normalized = per_sample_views.sub(mean).div(std)
    return normalized.reshape(batch_size * max_views, channels, image_size, image_size)
|
||||
|
||||
|
||||
class InternVL3Embedder(nn.Module):
|
||||
"""Vision-language embedder using the native HF InternVL3 model (no trust_remote_code)."""
|
||||
|
||||
def __init__(
    self,
    model_name="OpenGVLab/InternVL3-1B-hf",
    image_size=448,
    device="cuda",
    num_language_layers: int | None = 14,
    model_dtype: str | torch.dtype = "bfloat16",
    use_flash_attn: bool = True,
    max_text_length: int = 1024,
    enable_gradient_checkpointing: bool = True,
    gradient_checkpointing_use_reentrant: bool = False,
    hub_kwargs: dict | None = None,
):
    """Load the tokenizer and InternVL model, truncate the language model, and configure memory features.

    Args:
        model_name: Checkpoint id or path passed to ``from_pretrained``.
        image_size: Square input resolution; it must equal the checkpoint's
            ``vision_config.image_size`` when that is defined.
        device: Device the loaded model is moved to.
        num_language_layers: Number of leading language-model layers to keep;
            ``None`` keeps all of them.
        model_dtype: ``torch.dtype`` or the name of one (for example ``"bfloat16"``).
        use_flash_attn: Request ``flash_attention_2``; falls back to eager
            attention with a warning when flash-attn is unavailable.
        max_text_length: Stored on the instance as ``self.max_text_length``.
        enable_gradient_checkpointing: Consumed by ``_configure_memory_features``.
        gradient_checkpointing_use_reentrant: Consumed by ``_configure_memory_features``.
        hub_kwargs: Extra keyword arguments forwarded to both ``from_pretrained`` calls.

    Raises:
        ValueError: If ``model_dtype`` is a string that does not name a torch dtype,
            or if ``image_size`` differs from the checkpoint's native image size.
    """
    super().__init__()
    self._requested_device = device
    self.image_size = image_size
    self.num_language_layers = num_language_layers
    self.max_text_length = max_text_length
    self.enable_gradient_checkpointing = bool(enable_gradient_checkpointing)
    self.gradient_checkpointing_use_reentrant = bool(gradient_checkpointing_use_reentrant)
    hub_kwargs = hub_kwargs or {}

    require_package("transformers", extra="evo1")

    # Resolve and validate the dtype before any hub download so a bad config fails fast.
    # getattr(torch, name) alone would also accept non-dtype attributes such as "nn" or "cuda".
    if isinstance(model_dtype, str):
        resolved_dtype = getattr(torch, model_dtype, None)
        if not isinstance(resolved_dtype, torch.dtype):
            raise ValueError(f"Unsupported EVO1 vlm_dtype '{model_dtype}'")
        model_dtype = resolved_dtype
    self.model_dtype = model_dtype

    self.tokenizer = AutoTokenizer.from_pretrained(model_name, **hub_kwargs)

    # `_flash_attn_available` is not defined in the visible part of this module — presumably a
    # helper defined elsewhere in the file; confirm it is in scope.
    attn_implementation = "flash_attention_2" if (use_flash_attn and _flash_attn_available()) else "eager"
    if use_flash_attn and attn_implementation == "eager":
        logger.warning("flash_attn is not installed. Falling back to eager attention.")

    self.model = AutoModel.from_pretrained(
        model_name,
        torch_dtype=model_dtype,
        attn_implementation=attn_implementation,
        low_cpu_mem_usage=True,
        **hub_kwargs,
    ).to(self._requested_device)

    checkpoint_image_size = getattr(self.model.config.vision_config, "image_size", None)
    if isinstance(checkpoint_image_size, (list, tuple)):
        checkpoint_image_size = checkpoint_image_size[0]
    if checkpoint_image_size is not None and int(checkpoint_image_size) != int(image_size):
        raise ValueError(
            f"EVO1 image_resolution ({image_size}) must match the InternVL checkpoint's native "
            f"image size ({checkpoint_image_size}): the checkpoint's image_seq_length assumes "
            "its native resolution, so other sizes would desync the image placeholder tokens "
            "from the vision features."
        )

    self.num_image_token = self.model.config.image_seq_length

    # Truncate language model to the requested number of layers
    layers = self.model.language_model.layers
    if self.num_language_layers is not None:
        layers = layers[: self.num_language_layers]
    self.model.language_model.layers = torch.nn.ModuleList(layers)

    self._configure_memory_features()
    self.img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
|
||||
|
||||
def _configure_memory_features(self) -> None:
    """Enable or disable gradient checkpointing on the wrapped VLM.

    When checkpointing is disabled, it is turned off on the language model
    (if it exposes ``gradient_checkpointing_disable``) and on the vision
    tower's encoder (if present), and nothing else is changed.

    When enabled, checkpointing is requested on the full model, the vision
    tower and the language model; ``use_cache`` is set to False on the
    available configs; and ``enable_input_require_grads`` is called when the
    model provides it. The outcome is logged.
    """
    vlm = self.model

    if not self.enable_gradient_checkpointing:
        text_model = vlm.language_model
        if hasattr(text_model, "gradient_checkpointing_disable"):
            text_model.gradient_checkpointing_disable()
        vision = getattr(vlm, "vision_tower", None)
        if vision is not None and hasattr(vision, "encoder"):
            vision.encoder.gradient_checkpointing = False
        return

    ckpt_kwargs = {"use_reentrant": self.gradient_checkpointing_use_reentrant}

    def _try_enable(target: nn.Module | None) -> bool:
        """Turn checkpointing on for ``target``; return whether any control was found."""
        if target is None:
            return False
        if hasattr(target, "gradient_checkpointing_enable"):
            try:
                target.gradient_checkpointing_enable(gradient_checkpointing_kwargs=ckpt_kwargs)
            except TypeError:
                # Fall back to the no-argument call when the kwargs are not accepted.
                target.gradient_checkpointing_enable()
            return True
        if hasattr(target, "gradient_checkpointing"):
            target.gradient_checkpointing = True
            return True
        return False

    text_model = vlm.language_model
    targets = (vlm, getattr(vlm, "vision_tower", None), text_model)
    # Build the full list first so every target is attempted (no short-circuiting).
    enabled_any = any([_try_enable(target) for target in targets])

    # Disable KV caching on every config that is available.
    for owner in (text_model, vlm):
        if hasattr(owner, "config"):
            owner.config.use_cache = False
    if hasattr(vlm, "enable_input_require_grads"):
        vlm.enable_input_require_grads()

    if not enabled_any:
        logger.warning(
            "Requested gradient checkpointing, but model does not expose checkpointing controls."
        )
    else:
        logger.info("Gradient checkpointing enabled for InternVL3 embedder.")
|
||||
|
||||
def _build_multimodal_prompts(
    self,
    batch_num_tiles_list: list[list[int]],
    text_prompts: Sequence[str],
) -> list[str]:
    """Prefix each text prompt with one image-placeholder line per view.

    For view ``i`` (1-based) with ``n`` tiles, the line is
    ``"Image-i: <img>" + "<IMG_CONTEXT>" * (num_image_token * n) + "</img>\\n"``.
    The stripped text prompt follows the image lines.

    Args:
        batch_num_tiles_list: Per sample, the tile count of each view.
        text_prompts: One prompt per sample; must have the same length as
            ``batch_num_tiles_list`` (enforced by ``zip(strict=True)``).

    Returns:
        One assembled prompt string per sample.
    """
    assembled = []
    for tiles_per_view, text in zip(batch_num_tiles_list, text_prompts, strict=True):
        image_lines = "".join(
            f"Image-{view_number}: "
            f"{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * (self.num_image_token * tiles)}{IMG_END_TOKEN}\n"
            for view_number, tiles in enumerate(tiles_per_view, start=1)
        )
        assembled.append(image_lines + text.strip())
    return assembled
|
||||
|
||||
def get_fused_image_text_embedding_batched(
    self,
    camera_images: Sequence[torch.Tensor],
    image_masks: torch.Tensor,
    text_prompts: Sequence[str],
    return_cls_only: bool = True,
):
    """Fused VL embedding from per-camera ``[0, 1]`` image batches (no PIL, no host round-trip).

    Args:
        camera_images: list of per-camera tensors, each shaped ``(B, C, H, W)`` in ``[0, 1]``.
        image_masks: bool tensor ``(B, max_views)`` marking present views.
        text_prompts: one task prompt per sample.
        return_cls_only: when True return a single pooled vector per sample instead of the
            full token sequence.

    Returns:
        A ``(embeddings, valid_mask)`` tuple. With ``return_cls_only=False``, ``embeddings`` is
        ``(B, L, H)`` and ``valid_mask`` is a ``(B, L)`` bool tensor marking tokens downstream
        attention may attend to (padding and absent-view tokens are False). With
        ``return_cls_only=True``, ``embeddings`` is the pooled ``(B, H)`` last-valid-token state
        and ``valid_mask`` is None.
    """
    batch_size, max_views = int(image_masks.shape[0]), int(image_masks.shape[1])

    # Normalization statistics live on the model's device and dtype so preprocessing stays there.
    norm_kwargs = {"device": self.device, "dtype": self.model_dtype}
    mean = torch.tensor(IMAGENET_MEAN, **norm_kwargs)
    std = torch.tensor(IMAGENET_STD, **norm_kwargs)
    pixel_values = _batched_pixel_values(
        camera_images, max_views, self.image_size, mean, std, self.model_dtype, self.device
    )

    # InternVL3 preprocessing uses a single tile per image (max_num=1).
    single_tile_views = [1] * max_views
    batch_num_tiles_list = [list(single_tile_views) for _ in range(batch_size)]
    return self._forward_vlm(
        pixel_values, batch_num_tiles_list, image_masks, text_prompts, return_cls_only
    )
|
||||
|
||||
def _mask_absent_image_tokens(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
image_masks: torch.Tensor,
|
||||
batch_num_tiles_list: list[list[int]],
|
||||
) -> torch.Tensor:
|
||||
"""Zero attention over the image-context tokens of absent (zero-padded) views.
|
||||
|
||||
Fully vectorized: runs without any host<->device synchronization.
|
||||
"""
|
||||
# A single tile per image (max_num=1), so every image occupies the same number of
|
||||
# context tokens.
|
||||
tiles_per_image = (
|
||||
batch_num_tiles_list[0][0] if batch_num_tiles_list and batch_num_tiles_list[0] else 1
|
||||
)
|
||||
tokens_per_image = self.num_image_token * tiles_per_image
|
||||
|
||||
image_masks = image_masks.to(device=input_ids.device).bool()
|
||||
img_token_mask = input_ids == self.img_context_token_id # (B, L)
|
||||
# keep[b, k] tells whether the k-th image-context token (ordered view0, view1, ...) survives.
|
||||
per_token_keep = image_masks.repeat_interleave(tokens_per_image, dim=1) # (B, V * tokens_per_image)
|
||||
# Rank each context token by its running position among the row's context tokens.
|
||||
ctx_index = img_token_mask.to(torch.long).cumsum(dim=1) - 1
|
||||
ctx_index = ctx_index.clamp(min=0, max=per_token_keep.shape[1] - 1)
|
||||
keep_here = torch.gather(per_token_keep, 1, ctx_index) # (B, L)
|
||||
drop = img_token_mask & ~keep_here
|
||||
return attention_mask.masked_fill(drop, 0)
|
||||
|
||||
def _forward_vlm(
    self,
    pixel_values: torch.Tensor,
    batch_num_tiles_list: list[list[int]],
    image_masks: torch.Tensor,
    text_prompts: Sequence[str],
    return_cls_only: bool,
):
    """Tokenize the multimodal prompts, run the VLM and return its last hidden states.

    Args:
        pixel_values: preprocessed image tensor passed to the model as ``pixel_values``.
        batch_num_tiles_list: per-sample tile counts used to build the image placeholders.
        image_masks: ``(B, V)`` mask of present views; context tokens of absent views are
            removed from the attention mask.
        text_prompts: one task prompt per sample.
        return_cls_only: return only the hidden state of the last valid token per sample.

    Returns:
        ``(embeddings, valid_mask)``. With ``return_cls_only=True`` this is a float32
        ``(B, H)`` tensor and ``None``; otherwise the float32 ``(B, L, H)`` hidden states and
        the ``(B, L)`` bool attention mask. An empty image batch always yields an empty
        ``(0, H)`` tensor and ``None``.

    Raises:
        RuntimeError: if the batch is empty and no hidden size can be found on the config.
        ValueError: if truncation removed image placeholder tokens from any prompt.
    """
    if pixel_values.shape[0] == 0:
        logger.warning("InternVL3 received an empty image batch after preprocessing.")
        model_config = self.model.config
        hidden_size = getattr(model_config, "hidden_size", None)
        if hidden_size is None:
            # Fall back to the nested language-model config. It is looked up with getattr so
            # a config without it reaches the RuntimeError below instead of raising
            # AttributeError. `llm_config` is also tried — NOTE(review): that is the name
            # used by InternVL remote-code configs; confirm for the checkpoints in use.
            for nested_name in ("text_config", "llm_config"):
                nested_config = getattr(model_config, nested_name, None)
                hidden_size = getattr(nested_config, "hidden_size", None)
                if hidden_size is not None:
                    break
        if hidden_size is None:
            raise RuntimeError("Unable to infer hidden size for empty InternVL3 batch.")
        return torch.empty(0, hidden_size, device=self.device, dtype=torch.float32), None

    prompts = self._build_multimodal_prompts(batch_num_tiles_list, text_prompts)

    model_inputs = self.tokenizer(
        list(prompts),
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=self.max_text_length,
    ).to(self.device)
    input_ids = model_inputs["input_ids"]
    if input_ids.shape[1] >= self.max_text_length:
        # Truncation cuts from the right, so text is dropped before image placeholders — but a
        # large max_views * image_seq_length budget can still eat into them. Fail loudly instead
        # of letting the VLM crash on a placeholder/vision-feature count mismatch.
        expected_image_tokens = self.num_image_token * sum(batch_num_tiles_list[0])
        image_token_counts = (input_ids == self.img_context_token_id).sum(dim=1)
        if not bool((image_token_counts == expected_image_tokens).all()):
            raise ValueError(
                f"Prompt truncation at max_text_length={self.max_text_length} cut into the "
                f"image placeholder tokens ({expected_image_tokens} expected per sample). "
                "Increase max_text_length or reduce max_views."
            )
    attention_mask = self._mask_absent_image_tokens(
        input_ids, model_inputs["attention_mask"], image_masks, batch_num_tiles_list
    )

    outputs = self.model(
        input_ids=input_ids,
        pixel_values=pixel_values,
        attention_mask=attention_mask,
        output_hidden_states=True,
        return_dict=True,
    )
    fused_hidden = outputs.hidden_states[-1].to(torch.float32)
    valid_mask = attention_mask.to(torch.bool)
    if return_cls_only:
        # Right-padded causal decoder: the last valid token is the only one that has attended
        # to the full image + text prompt.
        positions = torch.arange(valid_mask.shape[1], device=valid_mask.device)
        # Highest position index whose mask entry is set (index 0 if a row is fully masked).
        last_valid = (valid_mask.long() * positions).argmax(dim=1)
        batch_index = torch.arange(fused_hidden.shape[0], device=fused_hidden.device)
        return fused_hidden[batch_index, last_valid], None
    return fused_hidden, valid_mask
|
||||
|
||||
@property
def device(self) -> torch.device:
    """Device of the wrapped model, read from the first parameter it yields."""
    first_parameter = next(self.model.parameters())
    return first_parameter.device
|
||||
|
||||
|
||||
def _flash_attn_available() -> bool:
|
||||
try:
|
||||
import flash_attn # noqa: F401
|
||||
except ModuleNotFoundError:
|
||||
return False
|
||||
return True
|
||||
@@ -0,0 +1,532 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
from collections import deque
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import TypedDict, Unpack
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
from ..rtc.modeling_rtc import RTCProcessor
|
||||
from .configuration_evo1 import Evo1Config
|
||||
from .evo1_model import Evo1Model
|
||||
|
||||
|
||||
class ActionSelectKwargs(TypedDict, total=False):
|
||||
inference_delay: int | None
|
||||
prev_chunk_left_over: Tensor | None
|
||||
execution_horizon: int | None
|
||||
|
||||
|
||||
class Evo1Policy(PreTrainedPolicy):
    """LeRobot policy wrapper around ``Evo1Model``.

    The wrapper turns a LeRobot batch (task text, camera images, robot state and, for training,
    action chunks) into the padded tensors the model consumes, computes the masked
    flow-matching loss in ``forward`` and serves actions through ``predict_action_chunk`` /
    ``select_action``.
    """

    config_class = Evo1Config
    name = "evo1"

    def __init__(self, config: Evo1Config, *, vlm_hub_kwargs: dict | None = None, **kwargs):
        """Build the model, apply the finetune flags and reset the inference state.

        Args:
            config: policy configuration; its features are validated here.
            vlm_hub_kwargs: optional options forwarded to ``Evo1Model``.
            **kwargs: accepted for interface compatibility and ignored.

        Raises:
            ValueError: if more image features are configured than ``config.max_views``.
        """
        super().__init__(config)
        config.validate_features()

        if len(config.image_features) > config.max_views:
            raise ValueError(
                f"EVO1 supports at most {config.max_views} camera streams, got {len(config.image_features)}"
            )

        self.config = config
        self.model = Evo1Model(config, vlm_hub_kwargs=vlm_hub_kwargs)
        self.model.set_finetune_flags()
        self._keep_frozen_embedder_eval()
        self.init_rtc_processor()
        self.reset()

    def init_rtc_processor(self):
        """Create the RTC processor when config.rtc_config is set.

        The RTC rollout backend assigns config.rtc_config after loading the policy and re-invokes
        this method.
        """
        self.rtc_processor = None
        if self.config.rtc_config is not None:
            self.rtc_processor = RTCProcessor(self.config.rtc_config)
        # Keep the model's handle in sync, including resetting it to None when RTC is removed.
        model = getattr(self, "model", None)
        if model is not None:
            model.rtc_processor = self.rtc_processor

    def _rtc_enabled(self) -> bool:
        """Return True when an RTC config is present and its ``enabled`` flag is set."""
        rtc_config = self.config.rtc_config
        return rtc_config is not None and rtc_config.enabled

    @classmethod
    def from_pretrained(
        cls: builtins.type[T],
        pretrained_name_or_path: str | Path,
        *,
        config: PreTrainedConfig | None = None,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        strict: bool | None = None,
        **kwargs,
    ) -> T:
        """Load a policy, defaulting to strict weight loading and forwarding hub options.

        When ``config`` is not given it is loaded from ``pretrained_name_or_path`` first. Unless
        the caller passes ``vlm_hub_kwargs`` explicitly, the download options that are set
        (``token``, ``cache_dir``, ``local_files_only``, ``proxies``) are forwarded as
        ``vlm_hub_kwargs`` so the base VLM is fetched with the same settings.
        """
        if strict is None:
            strict = True
        vlm_hub_kwargs = kwargs.pop("vlm_hub_kwargs", None)

        # Options shared by the config load and the policy load.
        hub_options = {
            "force_download": force_download,
            "resume_download": resume_download,
            "proxies": proxies,
            "token": token,
            "cache_dir": cache_dir,
            "local_files_only": local_files_only,
            "revision": revision,
        }
        if config is None:
            config = PreTrainedConfig.from_pretrained(
                pretrained_name_or_path=pretrained_name_or_path,
                **hub_options,
                **kwargs,
            )
        if vlm_hub_kwargs is None:
            # Forward the hub download options to the base-VLM download as well; `revision` is not
            # forwarded because it identifies the policy repo, not the VLM repo.
            vlm_hub_kwargs = {
                key: value
                for key, value in (
                    ("token", token),
                    ("cache_dir", cache_dir),
                    ("local_files_only", local_files_only),
                    ("proxies", proxies),
                )
                if value not in (None, False)
            }
        kwargs["vlm_hub_kwargs"] = vlm_hub_kwargs
        return super().from_pretrained(
            pretrained_name_or_path=pretrained_name_or_path,
            config=config,
            strict=strict,
            **hub_options,
            **kwargs,
        )

    @property
    def _camera_keys(self) -> list[str]:
        """Names of the configured image features, in config order."""
        return list(self.config.image_features)

    @property
    def _env_action_dim(self) -> int:
        """Real action width: first dim of the action feature, or ``max_action_dim`` if unset."""
        action_feature = self.config.action_feature
        if action_feature is None:
            return self.config.max_action_dim
        return int(action_feature.shape[0])

    @property
    def _compute_dtype(self) -> torch.dtype:
        """Dtype of the action head's first parameter."""
        return next(self.model.action_head.parameters()).dtype

    @property
    def _device(self) -> torch.device:
        # The device the policy actually lives on. Derived from the parameters rather than
        # config.device so the policy keeps working after accelerate (or a plain .to()) moves it.
        return next(self.model.action_head.parameters()).device

    @property
    def _amp_enabled(self) -> bool:
        """True when ``config.use_amp`` is set and the policy lives on a CUDA device."""
        return bool(self.config.use_amp) and self._device.type == "cuda"

    def _maybe_autocast(self):
        # EVO1 manages its own mixed precision: an explicit bf16 autocast that also overrides any
        # outer autocast context (e.g. lerobot-eval's fp16 default), keeping train and eval
        # numerics identical.
        if self._amp_enabled:
            return torch.autocast(device_type="cuda", dtype=torch.bfloat16)
        return nullcontext()

    def get_optim_params(self) -> list[dict]:
        """Split trainable parameters into a weight-decay group and a no-decay group.

        Biases, 1-D parameters and parameters whose name contains "norm" get no weight decay.
        """
        decay, no_decay = [], []
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue
            is_bias = name.endswith("bias") or ".bias" in name
            is_norm = param.dim() == 1 or "norm" in name.lower()
            (no_decay if is_bias or is_norm else decay).append(param)
        return [
            {"params": decay, "weight_decay": self.config.optimizer_weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ]

    def reset(self):
        """Drop any queued actions; the queue holds at most ``n_action_steps`` entries."""
        self._action_queue = deque([], maxlen=self.config.n_action_steps)

    def _normalize_task_batch(self, batch: dict[str, Tensor | list[str] | str]) -> list[str]:
        """Return the task prompts of ``batch`` as a list of strings.

        Looks up ``config.task_field`` first and falls back to ``"task"``.

        Raises:
            ValueError: if neither field is present.
            TypeError: if the value is not a string, list or tuple.
        """
        prompts = batch.get(self.config.task_field)
        if prompts is None and self.config.task_field != "task":
            prompts = batch.get("task")
        if prompts is None:
            raise ValueError(f"EVO1 expects a '{self.config.task_field}' text field in the batch.")
        if isinstance(prompts, str):
            return [prompts]
        if isinstance(prompts, (list, tuple)):
            return [str(prompt) for prompt in prompts]
        raise TypeError(f"Unsupported prompt batch type: {type(prompts)}")

    def _prepare_state(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Pad the robot state to ``max_state_dim`` and build its validity mask.

        A 1-D state gets a batch dim, a 3-D state keeps only its last entry along dim 1. An
        optional ``state_mask`` in the batch is normalized the same way and must match the
        state shape.

        Returns:
            ``(state, mask)``: the zero-padded state in the compute dtype and a bool mask, both
            shaped ``(B, max_state_dim)``.

        Raises:
            ValueError: on a missing state, unsupported ranks, an oversized state, or a
                ``state_mask`` whose shape does not match the state.
        """
        if OBS_STATE not in batch:
            raise ValueError(f"EVO1 requires '{OBS_STATE}' in the batch.")
        state = batch[OBS_STATE]
        if state.dim() == 1:
            state = state.unsqueeze(0)
        elif state.dim() == 3:
            state = state[:, -1]
        elif state.dim() != 2:
            raise ValueError(f"Unsupported state tensor shape for EVO1: {tuple(state.shape)}")
        batch_size, state_dim = state.shape
        if state_dim > self.config.max_state_dim:
            raise ValueError(
                f"State dim {state_dim} exceeds configured max_state_dim {self.config.max_state_dim}"
            )
        explicit_mask = batch.get("state_mask")
        if explicit_mask is not None:
            if explicit_mask.dim() == 1:
                explicit_mask = explicit_mask.unsqueeze(0)
            elif explicit_mask.dim() == 3:
                explicit_mask = explicit_mask[:, -1]
            elif explicit_mask.dim() != 2:
                raise ValueError(
                    f"Unsupported state_mask tensor shape for EVO1: {tuple(explicit_mask.shape)}"
                )
            if explicit_mask.shape != (batch_size, state_dim):
                raise ValueError(
                    f"state_mask shape {tuple(explicit_mask.shape)} does not match state shape {(batch_size, state_dim)}"
                )
        device = self._device
        padded = torch.zeros(
            batch_size,
            self.config.max_state_dim,
            dtype=state.dtype,
            device=device,
        )
        padded[:, :state_dim] = state.to(device=device)
        mask = torch.zeros(
            batch_size,
            self.config.max_state_dim,
            dtype=torch.bool,
            device=device,
        )
        if explicit_mask is None:
            mask[:, :state_dim] = True
        else:
            mask[:, :state_dim] = explicit_mask.to(device=device, dtype=torch.bool)
            # Zero out masked state dims so an explicit state_mask actually affects the model input
            # (the state encoder has no mask argument of its own). masked_fill is used instead of
            # a multiplication so NaN/inf values in masked dims are cleared too (NaN * 0 is NaN).
            padded = padded.masked_fill(~mask, 0)
        return padded.to(dtype=self._compute_dtype), mask

    def _prepare_actions(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Pad the ground-truth action chunk to ``max_action_dim`` and build the loss mask.

        A 2-D action is treated as a chunk of length 1. The mask is True for real action dims
        (or follows an explicit ``action_mask``) and False for every step flagged by
        ``action_is_pad``.

        Returns:
            ``(actions, mask)``: zero-padded actions in the compute dtype and a bool mask, both
            shaped ``(B, chunk_size, max_action_dim)``.

        Raises:
            ValueError: on missing actions, a horizon different from ``chunk_size``, an
                oversized action, or mask tensors with unsupported or mismatching shapes.
        """
        if ACTION not in batch:
            raise ValueError(f"EVO1 requires '{ACTION}' in the batch for training.")
        action = batch[ACTION]
        if action.dim() == 2:
            action = action.unsqueeze(1)
        batch_size, horizon, action_dim = action.shape
        if horizon != self.config.chunk_size:
            raise ValueError(
                f"EVO1 expects chunk_size={self.config.chunk_size}, got action horizon {horizon}"
            )
        if action_dim > self.config.max_action_dim:
            raise ValueError(
                f"Action dim {action_dim} exceeds configured max_action_dim {self.config.max_action_dim}"
            )
        explicit_mask = batch.get("action_mask")
        if explicit_mask is not None:
            if explicit_mask.dim() == 2:
                if horizon == 1:
                    explicit_mask = explicit_mask.unsqueeze(1)
                else:
                    raise ValueError(
                        f"2D action_mask is only supported when chunk_size=1, got action horizon {horizon}"
                    )
            elif explicit_mask.dim() != 3:
                raise ValueError(
                    f"Unsupported action_mask tensor shape for EVO1: {tuple(explicit_mask.shape)}"
                )
            if explicit_mask.shape != (batch_size, horizon, action_dim):
                raise ValueError(
                    "action_mask shape "
                    f"{tuple(explicit_mask.shape)} does not match action shape {(batch_size, horizon, action_dim)}"
                )
        device = self._device
        padded = torch.zeros(
            batch_size,
            horizon,
            self.config.max_action_dim,
            dtype=action.dtype,
            device=device,
        )
        padded[:, :, :action_dim] = action.to(device=device)
        mask = torch.zeros(
            batch_size,
            horizon,
            self.config.max_action_dim,
            dtype=torch.bool,
            device=device,
        )
        if explicit_mask is None:
            mask[:, :, :action_dim] = True
        else:
            mask[:, :, :action_dim] = explicit_mask.to(device=device, dtype=torch.bool)

        # Timesteps beyond the episode end hold fabricated (repeated) actions; exclude them from
        # the loss like the other chunked policies do.
        action_is_pad = batch.get("action_is_pad")
        if action_is_pad is not None:
            if action_is_pad.shape != (batch_size, horizon):
                raise ValueError(
                    f"action_is_pad shape {tuple(action_is_pad.shape)} does not match "
                    f"(batch_size, chunk_size)={(batch_size, horizon)}"
                )
            in_episode = ~action_is_pad.to(device=device, dtype=torch.bool)
            mask = mask & in_episode.unsqueeze(-1)
        return padded.to(dtype=self._compute_dtype), mask

    def _prepare_inference_action_mask(self, batch_size: int) -> Tensor:
        """Bool ``(batch_size, max_action_dim)`` mask, True on the first ``_env_action_dim`` dims."""
        mask = torch.zeros(
            batch_size,
            self.config.max_action_dim,
            dtype=torch.bool,
            device=self._device,
        )
        mask[:, : self._env_action_dim] = True
        return mask

    def _get_embodiment_ids(self, batch: dict[str, Tensor], batch_size: int) -> Tensor:
        """Return the embodiment ids as a long tensor on the policy device.

        Reads ``"embodiment_id"``, then ``config.embodiment_id_field``; when neither is present
        every sample gets ``config.default_embodiment_id``. A scalar gets a leading dim and
        tensors with more than one dim keep only their last column.
        """
        embodiment_ids = batch.get("embodiment_id")
        if embodiment_ids is None and self.config.embodiment_id_field:
            embodiment_ids = batch.get(self.config.embodiment_id_field)
        if embodiment_ids is None:
            return torch.full(
                (batch_size,),
                self.config.default_embodiment_id,
                dtype=torch.long,
                device=self._device,
            )
        if embodiment_ids.dim() == 0:
            embodiment_ids = embodiment_ids.unsqueeze(0)
        elif embodiment_ids.dim() > 1:
            embodiment_ids = embodiment_ids[:, -1]
        return embodiment_ids.to(device=self._device, dtype=torch.long)

    @property
    def _tracks_vlm_gradients(self) -> bool:
        """True when any of the VLM finetune flags of the config is set."""
        return bool(
            self.config.finetune_vlm
            or self.config.finetune_language_model
            or self.config.finetune_vision_model
        )

    def _keep_frozen_embedder_eval(self) -> None:
        """Put the embedder in eval mode unless some part of the VLM is being finetuned."""
        if self._tracks_vlm_gradients:
            return
        embedder = getattr(self.model, "embedder", None)
        if embedder is not None:
            embedder.eval()

    def train(self, mode: bool = True):
        """Switch train/eval mode, then force a frozen embedder back to eval mode."""
        super().train(mode)
        self._keep_frozen_embedder_eval()
        return self

    def _collect_image_batches(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], Tensor]:
        """Gather the camera tensors present in ``batch`` and the matching view mask.

        Camera keys come from the config, or from the sorted ``OBS_IMAGES``-prefixed batch keys
        when the config lists none, and are capped at ``max_views``. A 3-D image gets a batch
        dim and a 5-D image keeps only its last entry along dim 1.

        Returns:
            ``(camera_images, image_masks)``: one ``(B, C, H, W)`` tensor per present camera,
            and a bool ``(B, max_views)`` mask whose first ``len(camera_images)`` columns are True.

        Raises:
            ValueError: if no camera is available, more cameras are missing than
                ``config.empty_cameras`` allows, or an image has an unsupported rank.
        """
        camera_keys = self._camera_keys or sorted(key for key in batch if key.startswith(f"{OBS_IMAGES}."))
        if not camera_keys:
            raise ValueError("EVO1 requires at least one visual observation feature.")
        camera_keys = list(camera_keys)[: self.config.max_views]

        # Configured cameras may be absent from the batch up to the empty_cameras budget (e.g. the
        # placeholder features added by validate_features); they become masked-out views that the
        # embedder zero-pads. Any other absent camera is an error.
        present_keys = [key for key in camera_keys if key in batch]
        missing_keys = [key for key in camera_keys if key not in batch]
        if len(missing_keys) > self.config.empty_cameras:
            raise ValueError(
                f"Missing camera features {missing_keys} in batch; at most "
                f"empty_cameras={self.config.empty_cameras} may be absent."
            )
        if not present_keys:
            raise ValueError("EVO1 requires at least one visual observation in the batch.")

        # Keep each present camera as a batched (B, C, H, W) tensor on its current (GPU) device.
        # Resizing/normalization and zero-padding of absent views happen batched inside the
        # embedder, so images never leave the device here.
        camera_images: list[Tensor] = []
        for camera_key in present_keys:
            image = batch[camera_key]
            if image.dim() == 3:
                # Promote an unbatched (C, H, W) frame so batch_size is read from a real batch dim.
                image = image.unsqueeze(0)
            elif image.dim() == 5:
                image = image[:, -1]
            elif image.dim() != 4:
                raise ValueError(
                    f"Unsupported image tensor shape for EVO1: key={camera_key} shape={tuple(image.shape)}"
                )
            camera_images.append(image)

        batch_size = camera_images[0].shape[0]
        n_present = len(camera_images)
        image_masks = torch.zeros(
            batch_size, self.config.max_views, dtype=torch.bool, device=camera_images[0].device
        )
        image_masks[:, :n_present] = True

        return camera_images, image_masks

    def _compute_fused_tokens(
        self,
        prompts: list[str],
        image_batches: list[Tensor],
        image_masks: Tensor,
    ) -> tuple[Tensor, Tensor | None]:
        """Run the vision-language embedder and move its outputs to the policy device.

        Gradients are tracked only when a VLM finetune flag is set; otherwise the embedder runs
        under ``torch.no_grad()`` and its output is detached.

        Returns:
            ``(fused_tokens, context_mask)``; the mask is passed through unchanged when None.
        """
        track_vlm_gradients = self._tracks_vlm_gradients
        grad_context = nullcontext() if track_vlm_gradients else torch.no_grad()
        with grad_context:
            fused_tokens, context_mask = self.model.get_vl_embeddings(
                images=image_batches,
                image_mask=image_masks,
                prompt=prompts,
                return_cls_only=self.config.return_cls_only,
            )

        if not track_vlm_gradients:
            fused_tokens = fused_tokens.detach()
        fused_tokens = fused_tokens.to(device=self._device, dtype=self._compute_dtype)
        if context_mask is not None:
            context_mask = context_mask.to(device=self._device)
        return fused_tokens, context_mask

    def _compute_masked_loss(
        self,
        pred_velocity: Tensor,
        target_velocity: Tensor,
        action_mask: Tensor,
        reduction: str,
    ) -> Tensor:
        """Squared error between predicted and target velocity over the active action entries.

        Args:
            pred_velocity: predictions, flattened per sample.
            target_velocity: targets with the same shape as ``pred_velocity``.
            action_mask: mask that is flattened per sample to match the predictions.
            reduction: ``"none"`` for a per-sample mean over active entries, ``"mean"`` for the
                total squared error divided by the summed per-sample active counts (each count
                is at least 1).

        Raises:
            ValueError: for any other ``reduction``.
        """
        flat_mask = action_mask.reshape(action_mask.shape[0], -1).to(dtype=pred_velocity.dtype)
        sq_error = ((pred_velocity - target_velocity) * flat_mask).pow(2)
        # Clamp so a fully masked sample does not divide by zero.
        active = flat_mask.sum(dim=1).clamp_min(1.0)
        if reduction == "none":
            return sq_error.sum(dim=1) / active
        if reduction != "mean":
            raise ValueError(f"Unsupported reduction '{reduction}'")
        return sq_error.sum() / active.sum()

    def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
        """Compute the flow-matching training loss for a batch.

        Returns:
            ``(loss, info)`` where ``info`` holds the scalar ``"loss"`` and
            ``"active_action_dims"``, the mean number of unmasked action entries per sample.
        """
        prompts = self._normalize_task_batch(batch)
        image_batches, image_masks = self._collect_image_batches(batch)
        states, _state_mask = self._prepare_state(batch)
        actions_gt, action_mask = self._prepare_actions(batch)
        embodiment_ids = self._get_embodiment_ids(batch, states.shape[0])

        with self._maybe_autocast():
            fused_tokens, context_mask = self._compute_fused_tokens(prompts, image_batches, image_masks)
            pred_velocity, noise = self.model(
                fused_tokens,
                state=states,
                actions_gt=actions_gt,
                action_mask=action_mask.to(device=self._device, dtype=self._compute_dtype),
                embodiment_ids=embodiment_ids,
                context_mask=context_mask,
            )

        # Compute the flow-matching regression loss in fp32, outside the autocast block.
        pred_velocity = pred_velocity.float()
        noise = noise.float()
        flat_action_mask = action_mask.reshape(action_mask.shape[0], -1).to(dtype=torch.float32)
        # Flow-matching velocity target. Padded (masked-out) action dims are already zero on both sides
        # here (`actions_gt` is zero-padded in `_prepare_actions`, and `noise` is masked inside the head),
        # and the whole difference is multiplied by `flat_action_mask`, so padded dims contribute nothing.
        target_velocity = (actions_gt.float() - noise).view(actions_gt.shape[0], -1) * flat_action_mask
        loss = self._compute_masked_loss(pred_velocity, target_velocity, action_mask, reduction)
        loss_mean = loss.mean().item() if loss.ndim > 0 else loss.item()
        return loss, {
            "loss": loss_mean,
            "active_action_dims": float(action_mask.sum(dim=(1, 2)).float().mean().item()),
        }

    @torch.no_grad()
    def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
        """Sample one action chunk for the batch.

        Returns:
            A float32 tensor shaped ``(B, chunk_size, max_action_dim)``.

        Raises:
            RuntimeError: if ``inference_delay`` or ``prev_chunk_left_over`` is passed while RTC
                is not enabled.
        """
        inference_delay = kwargs.get("inference_delay")
        prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
        execution_horizon = kwargs.get("execution_horizon")
        if (inference_delay is not None or prev_chunk_left_over is not None) and not self._rtc_enabled():
            raise RuntimeError(
                "Received RTC arguments but RTC is not configured for this EVO1 policy: set "
                "config.rtc_config and call init_rtc_processor() (lerobot-rollout does this for "
                "--inference.type=rtc)."
            )
        self.eval()

        prompts = self._normalize_task_batch(batch)
        image_batches, image_masks = self._collect_image_batches(batch)
        states, _state_mask = self._prepare_state(batch)
        embodiment_ids = self._get_embodiment_ids(batch, states.shape[0])
        action_mask = self._prepare_inference_action_mask(states.shape[0])
        if prev_chunk_left_over is not None:
            prev_chunk_left_over = prev_chunk_left_over.to(device=self._device)

        with self._maybe_autocast():
            fused_tokens, context_mask = self._compute_fused_tokens(prompts, image_batches, image_masks)
            actions = self.model(
                fused_tokens,
                state=states,
                action_mask=action_mask,
                embodiment_ids=embodiment_ids,
                context_mask=context_mask,
                inference_delay=inference_delay,
                prev_chunk_left_over=prev_chunk_left_over,
                execution_horizon=execution_horizon,
            )
        # reshape instead of view: identical for contiguous outputs, and it does not raise if
        # the model returns a non-contiguous tensor.
        actions = actions.reshape(states.shape[0], self.config.chunk_size, self.config.max_action_dim)
        return actions.to(dtype=torch.float32)

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
        """Return the next queued action, sampling a new chunk when the queue is empty.

        Only the first ``n_action_steps`` steps of a sampled chunk are queued. Extra keyword
        arguments are accepted and ignored.
        """
        assert not self._rtc_enabled(), (
            "RTC is not supported for select_action, use it with predict_action_chunk"
        )
        self.eval()
        if len(self._action_queue) == 0:
            action_chunk = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
            # Queue one (B, max_action_dim) entry per time step.
            self._action_queue.extend(action_chunk.transpose(0, 1))
        # Returns one step of shape (B, max_action_dim): actions are emitted at the padded max_action_dim
        # width and cropped to the real action dim downstream by the postprocessor (Evo1ActionProcessorStep).
        # Callers that bypass the postprocessor receive the padded width.
        return self._action_queue.popleft()
|
||||
@@ -0,0 +1,400 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
ObservationProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyActionProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
batch_to_transition,
|
||||
create_transition,
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
DONE,
|
||||
INFO,
|
||||
OBS_PREFIX,
|
||||
OBS_STATE,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
REWARD,
|
||||
TRUNCATED,
|
||||
)
|
||||
|
||||
from .configuration_evo1 import Evo1Config
|
||||
|
||||
|
||||
def evo1_batch_to_transition(batch: dict[str, Any]):
    """Convert a batch to a transition, keeping non-standard batch keys as complementary data.

    The batch is first converted with ``batch_to_transition``. Every batch key that is neither
    one of the action/reward/done/truncated/info keys nor prefixed with ``OBS_PREFIX`` is then
    added to the complementary data, unless the converted transition already holds a value
    under that key.
    """
    transition = batch_to_transition(batch)
    complementary_data = dict(transition.get("complementary_data") or {})

    standard_keys = {ACTION, REWARD, DONE, TRUNCATED, INFO}
    for key in batch:
        is_standard = key in standard_keys or key.startswith(OBS_PREFIX)
        # Values already produced by batch_to_transition take precedence.
        if not is_standard and key not in complementary_data:
            complementary_data[key] = batch[key]

    return create_transition(
        observation=transition.get("observation"),
        action=transition.get("action"),
        reward=transition.get("reward", 0.0),
        done=transition.get("done", False),
        truncated=transition.get("truncated", False),
        info=transition.get("info", {}),
        complementary_data=complementary_data,
    )
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register(name="evo1_pad_state_processor")
class Evo1PadStateProcessorStep(ObservationProcessorStep):
    """Pad policy observations to EVO1's fixed state width before normalization."""

    # Width the state's last dimension is zero-padded to.
    max_state_dim: int = 24

    def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
        """Zero-pad the state's last dim to ``max_state_dim``.

        The input dict is returned untouched when it has no state or the state already has the
        target width; otherwise a shallow copy with the padded state is returned.

        Raises:
            ValueError: if the state is wider than ``max_state_dim``.
        """
        if OBS_STATE not in observation:
            return observation

        state = observation[OBS_STATE]
        state_dim = state.shape[-1]
        missing = self.max_state_dim - state_dim
        if missing < 0:
            raise ValueError(
                f"EVO1 state has {state_dim} dims, which exceeds max_state_dim={self.max_state_dim}."
            )
        if missing == 0:
            return observation

        padded_observation = observation.copy()
        padded_observation[OBS_STATE] = torch.nn.functional.pad(state, (0, missing))
        return padded_observation

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Report the padded state shape without mutating the incoming feature dicts."""
        updated: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {}
        for feature_type, feats in features.items():
            updated[feature_type] = feats.copy()
        observation_features = updated.setdefault(PipelineFeatureType.OBSERVATION, {})
        if OBS_STATE in observation_features:
            observation_features[OBS_STATE] = PolicyFeature(
                type=FeatureType.STATE, shape=(self.max_state_dim,)
            )
        return updated

    def get_config(self) -> dict[str, Any]:
        """Return the step configuration as a plain dict."""
        return {"max_state_dim": self.max_state_dim}
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register(name="evo1_pad_action_processor")
class Evo1PadActionProcessorStep(ProcessorStep):
    """Pad actions to ``max_action_dim`` and track the real entries in ``action_mask``."""

    # Width the last axis of the action tensor is padded to.
    max_action_dim: int = 24

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Pad the transition's action and store a matching boolean ``action_mask``.

        Transitions without an action are returned unchanged. Otherwise a shallow copy
        is returned whose action is padded on the last axis and whose complementary
        data holds ``action_mask``. An existing mask is converted to a bool tensor and
        reused; when none is present, every original action entry is marked True.
        Entries appended by padding are False in the mask.

        Raises:
            ValueError: If the action is not a ``PolicyAction``, is wider than
                ``max_action_dim``, or an existing mask does not match its shape.
        """
        action = transition.get(TransitionKey.ACTION)
        if action is None:
            return transition
        if not isinstance(action, PolicyAction):
            raise ValueError(f"EVO1 action should be a PolicyAction tensor, but got {type(action)}.")

        action_dim = action.shape[-1]
        pad_width = self.max_action_dim - action_dim
        if pad_width < 0:
            raise ValueError(
                f"EVO1 action has {action_dim} dims, which exceeds max_action_dim={self.max_action_dim}."
            )

        extras = dict(transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
        action_mask = extras.get("action_mask")
        if action_mask is not None:
            action_mask = torch.as_tensor(action_mask, dtype=torch.bool, device=action.device)
            if action_mask.shape != action.shape:
                raise ValueError(
                    f"action_mask shape {tuple(action_mask.shape)} does not match action shape {tuple(action.shape)}."
                )
        else:
            action_mask = torch.ones(action.shape, dtype=torch.bool, device=action.device)

        padded_action = action
        if pad_width > 0:
            last_axis_padding = (0, pad_width)
            padded_action = torch.nn.functional.pad(action, last_axis_padding)
            action_mask = torch.nn.functional.pad(action_mask, last_axis_padding)

        extras["action_mask"] = action_mask
        result = transition.copy()
        result[TransitionKey.ACTION] = padded_action
        result[TransitionKey.COMPLEMENTARY_DATA] = extras
        return result

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Declare the action feature with the padded width.

        The input mapping is not modified; each per-type feature dict is shallow-copied.
        """
        updated: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {}
        for feature_type, group in features.items():
            updated[feature_type] = group.copy()

        action_group = updated.setdefault(PipelineFeatureType.ACTION, {})
        action_group[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(self.max_action_dim,))
        return updated

    def get_config(self) -> dict[str, Any]:
        """Return the constructor arguments of this step as a plain dict."""
        return {"max_action_dim": self.max_action_dim}
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register(name="evo1_action_processor")
class Evo1ActionProcessorStep(PolicyActionProcessorStep):
    """Crop actions to ``action_dim`` and optionally snap one channel to two values."""

    # Number of leading action entries kept on the last axis.
    action_dim: int
    # When True, the channel at `gripper_index` is replaced by one of two fixed values.
    binarize_gripper: bool = False
    gripper_index: int = 6
    gripper_threshold: float = 0.5
    # Value written where the channel is <= gripper_threshold.
    gripper_below_threshold_value: float = 1.0
    # Value written where the channel is > gripper_threshold.
    gripper_above_threshold_value: float = -1.0

    def action(self, action: PolicyAction) -> PolicyAction:
        """Crop the action and, if enabled, binarize the gripper channel.

        Without binarization the result is a slice of the input. With binarization
        the slice is cloned first, so the input tensor is not modified.

        Raises:
            ValueError: If the action has fewer than ``action_dim`` entries, or if
                binarization is enabled and ``gripper_index`` is outside the crop.
        """
        if action.shape[-1] < self.action_dim:
            raise ValueError(
                f"EVO1 action has {action.shape[-1]} dims, which is smaller than action_dim={self.action_dim}."
            )

        cropped = action[..., : self.action_dim]
        if self.binarize_gripper:
            if self.gripper_index < 0 or self.gripper_index >= self.action_dim:
                raise ValueError(
                    f"gripper_index={self.gripper_index} must be within action_dim={self.action_dim}."
                )

            cropped = cropped.clone()
            below = torch.as_tensor(
                self.gripper_below_threshold_value, dtype=cropped.dtype, device=cropped.device
            )
            above = torch.as_tensor(
                self.gripper_above_threshold_value, dtype=cropped.dtype, device=cropped.device
            )
            channel = cropped[..., self.gripper_index]
            cropped[..., self.gripper_index] = torch.where(channel > self.gripper_threshold, above, below)
        return cropped

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Declare the action feature with the cropped width.

        The input mapping is not modified; each per-type feature dict is shallow-copied.
        """
        updated: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {}
        for feature_type, group in features.items():
            updated[feature_type] = group.copy()

        action_group = updated.setdefault(PipelineFeatureType.ACTION, {})
        action_group[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(self.action_dim,))
        return updated

    def get_config(self) -> dict[str, Any]:
        """Return the constructor arguments of this step as a plain dict."""
        field_names = (
            "action_dim",
            "binarize_gripper",
            "gripper_index",
            "gripper_threshold",
            "gripper_below_threshold_value",
            "gripper_above_threshold_value",
        )
        return {field_name: getattr(self, field_name) for field_name in field_names}
|
||||
|
||||
|
||||
def _evo1_action_dim(config: Evo1Config) -> int:
    """Pick the action width used by the action postprocessing step.

    Precedence: an explicit ``config.postprocess_action_dim``, then the first shape
    entry of ``config.action_feature``, and finally ``config.max_action_dim`` when
    no action feature is available.
    """
    override = config.postprocess_action_dim
    if override is not None:
        return override

    feature = config.action_feature
    return config.max_action_dim if feature is None else int(feature.shape[0])
|
||||
|
||||
|
||||
def _evo1_normalization_features(config: Evo1Config) -> dict[str, PolicyFeature]:
    """Build the feature map for normalization, using padded state/action widths.

    Input and output features are merged (output features win on key clashes), then
    the state and action entries are replaced by features shaped
    ``(max_state_dim,)`` and ``(max_action_dim,)``.
    """
    merged = dict(config.input_features)
    merged.update(config.output_features)
    merged[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(config.max_state_dim,))
    merged[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(config.max_action_dim,))
    return merged
|
||||
|
||||
|
||||
def _evo1_action_features(config: Evo1Config) -> dict[str, PolicyFeature]:
    """Return a feature map holding only the action, shaped ``(max_action_dim,)``."""
    padded_action = PolicyFeature(type=FeatureType.ACTION, shape=(config.max_action_dim,))
    return {ACTION: padded_action}
|
||||
|
||||
|
||||
_STAT_PAD_VALUES = {
|
||||
"mean": 0.0,
|
||||
"std": 1.0,
|
||||
"min": -1.0,
|
||||
"max": 1.0,
|
||||
"q01": -1.0,
|
||||
"q99": 1.0,
|
||||
"q10": -1.0,
|
||||
"q90": 1.0,
|
||||
}
|
||||
|
||||
|
||||
def _pad_stat_value(value: Any, target_dim: int, stat_name: str) -> torch.Tensor:
|
||||
tensor = torch.as_tensor(value)
|
||||
if not tensor.is_floating_point():
|
||||
tensor = tensor.to(dtype=torch.float32)
|
||||
if tensor.ndim == 0 or tensor.shape[-1] >= target_dim:
|
||||
return tensor
|
||||
|
||||
pad_shape = (*tensor.shape[:-1], target_dim - tensor.shape[-1])
|
||||
pad_value = _STAT_PAD_VALUES.get(stat_name, 0.0)
|
||||
padding = torch.full(pad_shape, pad_value, dtype=tensor.dtype, device=tensor.device)
|
||||
return torch.cat([tensor, padding], dim=-1)
|
||||
|
||||
|
||||
def _pad_feature_stats(
|
||||
stats: dict[str, dict[str, Any]],
|
||||
feature_key: str,
|
||||
target_dim: int,
|
||||
) -> None:
|
||||
if feature_key not in stats:
|
||||
return
|
||||
stats[feature_key] = {
|
||||
stat_name: _pad_stat_value(stat_value, target_dim, stat_name)
|
||||
for stat_name, stat_value in stats[feature_key].items()
|
||||
}
|
||||
|
||||
|
||||
def _pad_evo1_stats(
    config: Evo1Config,
    stats: dict[str, dict[str, Any]] | None,
) -> dict[str, dict[str, Any]] | None:
    """Return a deep copy of ``stats`` with state/action statistics padded.

    State statistics are widened to ``config.max_state_dim`` and action statistics to
    ``config.max_action_dim``. The caller's ``stats`` object is left untouched.
    ``None`` is passed through as ``None``.
    """
    if stats is None:
        return None

    padded = deepcopy(stats)
    # The appended dimensions are zero-padding inside EVO1. The neutral stats keep
    # padded observations at normalized zero and only provide shape compatibility.
    targets = (
        (OBS_STATE, config.max_state_dim),
        (ACTION, config.max_action_dim),
    )
    for feature_key, width in targets:
        _pad_feature_stats(padded, feature_key, width)
    return padded
|
||||
|
||||
|
||||
def reconcile_evo1_processors(
    config: Evo1Config,
    preprocessor: PolicyProcessorPipeline,
    postprocessor: PolicyProcessorPipeline,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
    """Align pipelines loaded from a checkpoint with the current EVO1 config.

    A serialized pipeline cannot carry two things. The first is the EVO1 batch
    converter, because converters are plain functions and are never serialized. The
    second is any eval-time CLI override of the action postprocessing flags
    (`postprocess_action_dim`, `binarize_gripper`, `gripper_*`). This function
    re-attaches the converter and rebuilds the action step from ``config`` so those
    overrides take effect.

    Both pipelines are modified in place and also returned.
    """
    # A reloaded pipeline comes back with the default batch converter, which drops the
    # non-observation extras (embodiment_id, state_mask, custom task fields) EVO1 needs.
    preprocessor.to_transition = evo1_batch_to_transition

    rebuilt_step = Evo1ActionProcessorStep(
        action_dim=_evo1_action_dim(config),
        binarize_gripper=config.binarize_gripper,
        gripper_index=config.gripper_index,
        gripper_threshold=config.gripper_threshold,
        gripper_below_threshold_value=config.gripper_below_threshold_value,
        gripper_above_threshold_value=config.gripper_above_threshold_value,
    )

    pipeline_steps = list(postprocessor.steps)
    existing_idx = None
    for idx, step in enumerate(pipeline_steps):
        if isinstance(step, Evo1ActionProcessorStep):
            existing_idx = idx
            break

    if existing_idx is not None:
        # Replace the first existing action step with the freshly configured one.
        pipeline_steps[existing_idx] = rebuilt_step
    else:
        # No action step yet: place it right after the first unnormalizer, or at the
        # front when the pipeline has none.
        insert_idx = 0
        for idx, step in enumerate(pipeline_steps):
            if isinstance(step, UnnormalizerProcessorStep):
                insert_idx = idx + 1
                break
        pipeline_steps.insert(insert_idx, rebuilt_step)

    postprocessor.steps = pipeline_steps
    return preprocessor, postprocessor
|
||||
|
||||
|
||||
def make_evo1_pre_post_processors(
    config: Evo1Config,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """Build the EVO1 preprocessing and postprocessing pipelines.

    Preprocessing runs, in order: observation renaming (empty map), batch-dimension
    insertion, state padding, action padding, normalization, and a move to
    ``config.device``. Postprocessing runs: unnormalization, the EVO1 action step
    (crop and optional gripper binarization), and a move to CPU as float32.

    Args:
        config: EVO1 policy configuration.
        dataset_stats: Optional dataset statistics; state/action entries are padded
            to the configured maximum widths before use.

    Returns:
        ``(preprocessor, postprocessor)``.
    """
    normalization_features = _evo1_normalization_features(config)
    action_features = _evo1_action_features(config)
    normalization_stats = _pad_evo1_stats(config, dataset_stats)

    normalizer = NormalizerProcessorStep(
        features=normalization_features,
        norm_map=config.normalization_mapping,
        stats=normalization_stats,
    )
    input_steps = [
        RenameObservationsProcessorStep(rename_map={}),
        AddBatchDimensionProcessorStep(),
        Evo1PadStateProcessorStep(max_state_dim=config.max_state_dim),
        Evo1PadActionProcessorStep(max_action_dim=config.max_action_dim),
        normalizer,
        DeviceProcessorStep(device=config.device),
    ]

    unnormalizer = UnnormalizerProcessorStep(
        features=action_features,
        norm_map=config.normalization_mapping,
        stats=normalization_stats,
    )
    action_step = Evo1ActionProcessorStep(
        action_dim=_evo1_action_dim(config),
        binarize_gripper=config.binarize_gripper,
        gripper_index=config.gripper_index,
        gripper_threshold=config.gripper_threshold,
        gripper_below_threshold_value=config.gripper_below_threshold_value,
        gripper_above_threshold_value=config.gripper_above_threshold_value,
    )
    output_steps = [
        unnormalizer,
        action_step,
        # float32 so downstream numpy conversion works even when the policy computes in bf16.
        DeviceProcessorStep(device="cpu", float_dtype="float32"),
    ]

    preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
        steps=input_steps,
        name=POLICY_PREPROCESSOR_DEFAULT_NAME,
        to_transition=evo1_batch_to_transition,
    )
    postprocessor = PolicyProcessorPipeline[PolicyAction, PolicyAction](
        steps=output_steps,
        name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
        to_transition=policy_action_to_transition,
        to_output=transition_to_policy_action,
    )
    return preprocessor, postprocessor
|
||||
@@ -47,6 +47,7 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
|
||||
from .act.configuration_act import ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config
|
||||
from .evo1.configuration_evo1 import Evo1Config
|
||||
from .fastwam.configuration_fastwam import FastWAMConfig
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from .groot.configuration_groot import GrootConfig
|
||||
@@ -92,7 +93,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
Args:
|
||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x",
|
||||
"molmoact2".
|
||||
"molmoact2", "eo1", "evo1".
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
|
||||
@@ -167,6 +168,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .fastwam.modeling_fastwam import FastWAMPolicy
|
||||
|
||||
return FastWAMPolicy
|
||||
elif name == "evo1":
|
||||
from .evo1.modeling_evo1 import Evo1Policy
|
||||
|
||||
return Evo1Policy
|
||||
else:
|
||||
try:
|
||||
return _get_policy_cls_from_policy_name(name=name)
|
||||
@@ -184,7 +189,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
|
||||
"smolvla", "wall_x", "molmoact2".
|
||||
"smolvla", "wall_x", "molmoact2", "eo1", "evo1".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -225,6 +230,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return VLAJEPAConfig(**kwargs)
|
||||
elif policy_type == "fastwam":
|
||||
return FastWAMConfig(**kwargs)
|
||||
elif policy_type == "evo1":
|
||||
return Evo1Config(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||
@@ -330,6 +337,14 @@ def make_pre_post_processors(
|
||||
revision=pretrained_revision,
|
||||
)
|
||||
_reconnect_relative_absolute_steps(preprocessor, postprocessor)
|
||||
if isinstance(policy_cfg, Evo1Config):
|
||||
from .evo1.processor_evo1 import reconcile_evo1_processors
|
||||
|
||||
preprocessor, postprocessor = reconcile_evo1_processors(
|
||||
policy_cfg,
|
||||
preprocessor,
|
||||
postprocessor,
|
||||
)
|
||||
return preprocessor, postprocessor
|
||||
|
||||
# Create a new processor based on policy type
|
||||
@@ -440,6 +455,13 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
elif isinstance(policy_cfg, Evo1Config):
|
||||
from .evo1.processor_evo1 import make_evo1_pre_post_processors
|
||||
|
||||
processors = make_evo1_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, MolmoAct2Config):
|
||||
from .molmoact2.processor_molmoact2 import make_molmoact2_pre_post_processors
|
||||
|
||||
@@ -7,11 +7,14 @@ from dataclasses import dataclass, field
|
||||
|
||||
import gymnasium as gym
|
||||
import pytest
|
||||
import torch
|
||||
from gymnasium.envs.registration import register, registry as gym_registry
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.envs.configs import EnvConfig
|
||||
from lerobot.envs.configs import EnvConfig, LiberoEnv
|
||||
from lerobot.envs.factory import make_env, make_env_config, make_env_pre_post_processors
|
||||
from lerobot.processor import LiberoProcessorStep
|
||||
from lerobot.utils.constants import OBS_PREFIX, OBS_STATE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -61,6 +64,31 @@ def test_processors_delegation():
|
||||
assert len(pre.steps) == 0
|
||||
|
||||
|
||||
def test_libero_processors_are_policy_agnostic():
    """An arbitrary object as ``policy_cfg`` still yields the LIBERO env processors."""
    preprocessor, postprocessor = make_env_pre_post_processors(LiberoEnv(), policy_cfg=object())

    first_step = preprocessor.steps[0]
    assert isinstance(first_step, LiberoProcessorStep)
    assert len(postprocessor.steps) == 0
|
||||
|
||||
|
||||
def test_libero_processor_flattens_state_to_raw_8_dim():
    """The LIBERO step turns the nested robot state into a single (1, 8) state tensor."""
    eef = {
        "pos": torch.tensor([[1.0, 2.0, 3.0]]),
        "quat": torch.tensor([[0.0, 0.0, 0.0, 1.0]]),
    }
    gripper = {"qpos": torch.tensor([[4.0, 5.0]])}
    observation = {OBS_PREFIX + "robot_state": {"eef": eef, "gripper": gripper}}

    processed = LiberoProcessorStep().observation(observation)
    state = processed[OBS_STATE]

    # Expected layout: position, three zeros, then the gripper qpos.
    # NOTE(review): the three zeros presumably encode the identity quaternion — confirm.
    expected = torch.tensor([[1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 4.0, 5.0]])
    assert state.shape == (1, 8)
    assert torch.allclose(state, expected)
|
||||
|
||||
|
||||
def test_base_create_envs():
|
||||
"""Base class create_envs() should build a single-task VectorEnv via gym.make()."""
|
||||
gym_id = "_dispatch_test/CartPole-v99"
|
||||
|
||||
@@ -0,0 +1,840 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
import lerobot.policies.evo1.evo1_model as evo1_model
|
||||
import lerobot.policies.evo1.modeling_evo1 as modeling_evo1
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
|
||||
from lerobot.policies.evo1.flow_matching import FlowmatchingActionHead
|
||||
from lerobot.policies.evo1.internvl3_embedder import (
|
||||
IMAGENET_MEAN,
|
||||
IMAGENET_STD,
|
||||
_batched_pixel_values,
|
||||
)
|
||||
from lerobot.policies.evo1.processor_evo1 import (
|
||||
Evo1ActionProcessorStep,
|
||||
Evo1PadActionProcessorStep,
|
||||
Evo1PadStateProcessorStep,
|
||||
evo1_batch_to_transition,
|
||||
make_evo1_pre_post_processors,
|
||||
reconcile_evo1_processors,
|
||||
)
|
||||
from lerobot.policies.factory import get_policy_class, make_policy_config
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.processor import (
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
batch_to_transition,
|
||||
policy_action_to_transition,
|
||||
transition_to_batch,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_IMAGES,
|
||||
OBS_STATE,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
STATE_DIM = 4
|
||||
ACTION_DIM = 3
|
||||
MAX_STATE_DIM = 6
|
||||
MAX_ACTION_DIM = 5
|
||||
CHUNK_SIZE = 2
|
||||
EMBED_DIM = 8
|
||||
|
||||
|
||||
class DummyEvo1Model(nn.Module):
    """Lightweight test double that records how ``get_vl_embeddings`` is called.

    It returns constant tensors instead of running any real computation.
    """

    def __init__(self, config, vlm_hub_kwargs=None):
        super().__init__()
        self.config = config
        # Tiny real modules, so train/eval switching and parameter iteration work.
        self.embedder = nn.Dropout(p=0.0)
        self.action_head = nn.Linear(1, 1)
        # Call bookkeeping inspected by the tests.
        self.get_vl_embeddings_calls = 0
        self.grad_enabled_calls = []
        self.embedder_training_calls = []

    def set_finetune_flags(self):
        """No-op: the dummy has nothing to freeze or unfreeze."""
        return None

    def get_vl_embeddings(self, images, image_mask, prompt=None, return_cls_only=False):
        """Log the call context and return all-ones tokens with an all-True mask."""
        self.get_vl_embeddings_calls += 1
        grad_on = torch.is_grad_enabled()
        self.grad_enabled_calls.append(grad_on)
        self.embedder_training_calls.append(self.embedder.training)

        # images is a list of per-camera (B, C, H, W) tensors, so the batch dim is images[0].shape[0].
        batch_size = images[0].shape[0]
        tokens = torch.ones(batch_size, 4, EMBED_DIM, requires_grad=grad_on)
        valid_mask = torch.ones(batch_size, 4, dtype=torch.bool)
        return tokens, valid_mask

    def forward(
        self,
        fused_tokens,
        state=None,
        actions_gt=None,
        action_mask=None,
        embodiment_ids=None,
        context_mask=None,
        **kwargs,
    ):
        """Return constant outputs for the training and inference paths.

        With ``actions_gt``: a zero velocity prediction and zero noise shaped like
        ``actions_gt``. Without it: an all-ones flattened action chunk.
        """
        batch_size = fused_tokens.shape[0]
        flat_width = CHUNK_SIZE * MAX_ACTION_DIM
        if actions_gt is not None:
            pred_velocity = torch.zeros(batch_size, flat_width)
            noise = torch.zeros_like(actions_gt)
            return pred_velocity, noise
        return torch.ones(batch_size, flat_width)
|
||||
|
||||
|
||||
class ChunkCountingDummyModel(DummyEvo1Model):
    """Dummy whose inference output encodes the chunk number and the step index.

    Each step of a predicted chunk carries the value ``10 * chunk_number + step``,
    so tests can observe queue ordering and re-prediction.
    """

    def __init__(self, config, vlm_hub_kwargs=None):
        super().__init__(config, vlm_hub_kwargs)
        # Number of inference-time chunks produced so far.
        self.chunks_predicted = 0

    def forward(
        self,
        fused_tokens,
        state=None,
        actions_gt=None,
        action_mask=None,
        embodiment_ids=None,
        context_mask=None,
        **kwargs,
    ):
        """Defer to the parent when training; otherwise emit a numbered chunk."""
        training_call = actions_gt is not None
        if training_call:
            return super().forward(fused_tokens, state, actions_gt, action_mask, embodiment_ids, context_mask)

        self.chunks_predicted += 1
        offset = 10.0 * self.chunks_predicted
        per_step = torch.arange(CHUNK_SIZE, dtype=torch.float32) + offset
        # Repeat each step value across the action dims, then tile over the batch.
        flat_chunk = per_step.repeat_interleave(MAX_ACTION_DIM)
        return flat_chunk.unsqueeze(0).repeat(fused_tokens.shape[0], 1)
|
||||
|
||||
|
||||
def make_config(training_stage="stage1", **kwargs):
    """Build a small CPU ``Evo1Config`` for tests; ``kwargs`` override the defaults."""
    input_features = {
        OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
        f"{OBS_IMAGES}.front": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
    }
    output_features = {
        ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,)),
    }
    defaults = dict(
        device="cpu",
        vlm_model_name="dummy-internvl3",
        training_stage=training_stage,
        chunk_size=CHUNK_SIZE,
        n_action_steps=1,
        max_state_dim=MAX_STATE_DIM,
        max_action_dim=MAX_ACTION_DIM,
        max_views=2,
        embed_dim=EMBED_DIM,
        hidden_dim=16,
        state_hidden_dim=16,
        num_heads=2,
        num_layers=1,
        num_inference_timesteps=2,
        input_features=input_features,
        output_features=output_features,
    )
    return Evo1Config(**{**defaults, **kwargs})
|
||||
|
||||
|
||||
def make_batch(include_action=True):
    """Return a random two-sample batch with task, state and one camera image.

    When ``include_action`` is true, an action chunk of shape
    ``(2, CHUNK_SIZE, ACTION_DIM)`` is added as well.
    """
    # Tensors are created in the same order as dict insertion, which keeps the
    # random-number consumption order stable.
    batch = {"task": ["pick the block", "place the block"]}
    batch[OBS_STATE] = torch.randn(2, STATE_DIM)
    batch[f"{OBS_IMAGES}.front"] = torch.rand(2, 3, 16, 16)
    if not include_action:
        return batch
    batch[ACTION] = torch.randn(2, CHUNK_SIZE, ACTION_DIM)
    return batch
|
||||
|
||||
|
||||
def make_stats(state_dim=STATE_DIM, action_dim=ACTION_DIM):
    """Return min/max stats: state bounds are -2/+2, action bounds are -1/+1."""

    def symmetric_bounds(width, limit):
        return {
            "min": torch.full((width,), -limit),
            "max": torch.full((width,), limit),
        }

    return {
        OBS_STATE: symmetric_bounds(state_dim, 2.0),
        ACTION: symmetric_bounds(action_dim, 1.0),
    }
|
||||
|
||||
|
||||
def make_flowmatching_head(**overrides):
    """Build a small ``FlowmatchingActionHead``; ``overrides`` replace default kwargs."""
    defaults = dict(
        embed_dim=EMBED_DIM,
        hidden_dim=16,
        action_dim=CHUNK_SIZE * ACTION_DIM,
        horizon=CHUNK_SIZE,
        per_action_dim=ACTION_DIM,
        num_heads=2,
        num_layers=1,
        num_inference_timesteps=2,
        state_dim=STATE_DIM,
        state_hidden_dim=16,
        num_categories=1,
    )
    return FlowmatchingActionHead(**{**defaults, **overrides})
|
||||
|
||||
|
||||
def test_evo1_factory_registration():
    """The policy factory resolves "evo1" to the EVO1 config and policy classes."""
    input_features = {
        OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
        f"{OBS_IMAGES}.front": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
    }
    output_features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}

    cfg = make_policy_config(
        "evo1",
        device="cpu",
        vlm_model_name="dummy-internvl3",
        input_features=input_features,
        output_features=output_features,
    )
    policy_cls = get_policy_class("evo1")

    assert isinstance(cfg, Evo1Config)
    assert policy_cls is modeling_evo1.Evo1Policy
|
||||
|
||||
|
||||
def test_evo1_stage_defaults_and_consistency():
    """Check how ``training_stage`` and explicit finetune flags resolve on the config.

    Covered cases:
    - stage1 freezes the VLM branches and trains the action head;
    - stage2 trains everything;
    - stage2 with stage defaults applied overrides explicitly passed False flags;
    - with ``apply_training_stage_defaults=False`` explicit flags are honored;
    - ``finetune_vlm=False`` alone freezes both VLM branches;
    - ``finetune_vlm=True`` with a frozen language model is rejected.
    """

    def vlm_flags(cfg):
        # The three VLM-related switches, in a fixed order for tuple comparison.
        return (cfg.finetune_vlm, cfg.finetune_language_model, cfg.finetune_vision_model)

    stage1 = make_config(training_stage="stage1")
    assert vlm_flags(stage1) == (False, False, False)
    assert stage1.finetune_action_head is True

    stage2 = make_config(training_stage="stage2")
    assert vlm_flags(stage2) == (True, True, True)
    assert stage2.finetune_action_head is True

    # Flags as they would come from a stage1 checkpoint: stage2 defaults still win.
    stage2_from_stage1_checkpoint_flags = make_config(
        training_stage="stage2",
        finetune_vlm=False,
        finetune_language_model=False,
        finetune_vision_model=False,
        finetune_action_head=False,
    )
    assert vlm_flags(stage2_from_stage1_checkpoint_flags) == (True, True, True)
    assert stage2_from_stage1_checkpoint_flags.finetune_action_head is True

    explicit_off = make_config(
        training_stage="stage2",
        apply_training_stage_defaults=False,
        finetune_vlm=False,
        finetune_language_model=False,
        finetune_vision_model=False,
        finetune_action_head=False,
    )
    assert vlm_flags(explicit_off) == (False, False, False)
    assert explicit_off.finetune_action_head is False

    # An explicit finetune_vlm=False without branch-level flags freezes both branches instead of
    # raising an inconsistency error.
    frozen_vlm = make_config(
        training_stage="stage2",
        apply_training_stage_defaults=False,
        finetune_vlm=False,
    )
    assert vlm_flags(frozen_vlm) == (False, False, False)

    # pytest.raises matches the sibling tests in this file and fails with a clear
    # "DID NOT RAISE" message when no exception is thrown.
    with pytest.raises(ValueError, match="Inconsistent EVO1 finetune config"):
        make_config(
            training_stage="stage2",
            apply_training_stage_defaults=False,
            finetune_vlm=True,
            finetune_language_model=False,
        )
|
||||
|
||||
|
||||
def test_evo1_rejects_non_square_image_resolution():
    """A non-square ``image_resolution`` makes config construction raise ``ValueError``."""
    non_square = (448, 320)
    with pytest.raises(ValueError, match="square image_resolution"):
        make_config(image_resolution=non_square)
|
||||
|
||||
|
||||
def test_evo1_rejects_out_of_range_default_embodiment_id():
    """``default_embodiment_id=3`` with ``num_categories=2`` raises ``ValueError``."""
    bad_kwargs = {"default_embodiment_id": 3, "num_categories": 2}
    with pytest.raises(ValueError, match="default_embodiment_id"):
        make_config(**bad_kwargs)
|
||||
|
||||
|
||||
def test_evo1_model_uses_image_resolution_and_trainable_checkpointing(monkeypatch):
    """``Evo1Model`` passes image size and a checkpointing flag to its embedder.

    The real embedder is replaced by a spy that records its constructor kwargs.
    """
    captured: dict = {}

    class SpyEmbedder(nn.Module):
        """Records the keyword arguments of the most recent construction."""

        def __init__(self, **kwargs):
            super().__init__()
            # Keep only the latest call so the stage1 and stage2 checks don't mix.
            captured.clear()
            captured.update(kwargs)

    monkeypatch.setattr(evo1_model, "InternVL3Embedder", SpyEmbedder)
    resolution = (224, 224)

    evo1_model.Evo1Model(make_config(training_stage="stage1", image_resolution=resolution))
    assert captured["image_size"] == 224
    # VLM is frozen in stage1, so gradient checkpointing is gated off.
    assert captured["enable_gradient_checkpointing"] is False

    evo1_model.Evo1Model(make_config(training_stage="stage2", image_resolution=resolution))
    assert captured["enable_gradient_checkpointing"] is True
|
||||
|
||||
|
||||
class FakeInternVLModel(nn.Module):
    """Tiny module exposing the submodule names of the native HF InternVL layout."""

    def __init__(self):
        super().__init__()
        # Registered in this order: language model, vision tower, projector.
        for submodule_name in ("language_model", "vision_tower", "multi_modal_projector"):
            setattr(self, submodule_name, nn.Linear(2, 2))
|
||||
|
||||
|
||||
class FakeEmbedder(nn.Module):
    """Embedder stand-in: ignores its kwargs and exposes a fake VLM as ``model``."""

    def __init__(self, **kwargs):
        super().__init__()
        fake_vlm = FakeInternVLModel()
        self.model = fake_vlm
|
||||
|
||||
|
||||
def test_set_finetune_flags_targets_native_hf_internvl_submodules(monkeypatch):
    """``set_finetune_flags`` sets ``requires_grad`` on the native HF submodules.

    Stage2 makes every VLM branch and the action head trainable. Stage1 freezes the
    whole VLM and keeps the action head trainable.
    """
    monkeypatch.setattr(evo1_model, "InternVL3Embedder", FakeEmbedder)

    def fully_trainable(module):
        return all(param.requires_grad for param in module.parameters())

    stage2_model = evo1_model.Evo1Model(make_config(training_stage="stage2"))
    stage2_model.set_finetune_flags()
    vlm = stage2_model.embedder.model
    assert fully_trainable(vlm.language_model)
    assert fully_trainable(vlm.vision_tower)
    assert fully_trainable(vlm.multi_modal_projector)
    assert fully_trainable(stage2_model.action_head)

    stage1_model = evo1_model.Evo1Model(make_config(training_stage="stage1"))
    stage1_model.set_finetune_flags()
    vlm = stage1_model.embedder.model
    assert not any(param.requires_grad for param in vlm.parameters())
    assert fully_trainable(stage1_model.action_head)
|
||||
|
||||
|
||||
def test_set_finetune_flags_fails_loudly_on_unknown_vlm_layout(monkeypatch):
    """A VLM lacking ``vision_tower`` makes ``set_finetune_flags`` raise ``AttributeError``."""

    class LegacyLayoutModel(nn.Module):
        """VLM stand-in using legacy attribute names instead of the native HF ones."""

        def __init__(self):
            super().__init__()
            self.language_model = nn.Linear(2, 2)
            self.vision_model = nn.Linear(2, 2)  # trust_remote_code-era attribute name
            self.mlp1 = nn.Linear(2, 2)

    class FakeEmbedder(nn.Module):
        """Embedder stand-in wrapping the legacy-layout model."""

        def __init__(self, **kwargs):
            super().__init__()
            self.model = LegacyLayoutModel()

    monkeypatch.setattr(evo1_model, "InternVL3Embedder", FakeEmbedder)
    stage2_config = make_config(training_stage="stage2")
    model = evo1_model.Evo1Model(stage2_config)

    with pytest.raises(AttributeError, match="vision_tower"):
        model.set_finetune_flags()
|
||||
|
||||
|
||||
def test_evo1_policy_processors_pad_state_crop_action_and_binarize_gripper():
    """Exercise the EVO1 pre/post processors with a 7-dim action padded to 8 dims."""
    libero_action_dim = 7
    padded_action_dim = 8
    config = make_config(
        max_state_dim=MAX_STATE_DIM,
        max_action_dim=padded_action_dim,
        postprocess_action_dim=libero_action_dim,
        binarize_gripper=True,
        output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(libero_action_dim,))},
    )
    stats = make_stats(action_dim=libero_action_dim)

    preprocessor, postprocessor = make_evo1_pre_post_processors(config, dataset_stats=stats)

    # Pipeline layout: the step types expected at fixed positions.
    expected_pre_steps = {
        2: Evo1PadStateProcessorStep,
        3: Evo1PadActionProcessorStep,
        4: NormalizerProcessorStep,
    }
    for position, step_type in expected_pre_steps.items():
        assert isinstance(preprocessor.steps[position], step_type)
    expected_post_steps = {
        0: UnnormalizerProcessorStep,
        1: Evo1ActionProcessorStep,
    }
    for position, step_type in expected_post_steps.items():
        assert isinstance(postprocessor.steps[position], step_type)

    # The normalizer's features and stats must use the padded dimensions.
    normalizer = preprocessor.steps[4]
    assert normalizer.features[OBS_STATE].shape == (MAX_STATE_DIM,)
    assert normalizer.features[ACTION].shape == (padded_action_dim,)
    assert normalizer._tensor_stats[OBS_STATE]["min"].shape == (MAX_STATE_DIM,)
    assert normalizer._tensor_stats[ACTION]["min"].shape == (padded_action_dim,)

    raw_sample = {
        "task": "pick the block",
        OBS_STATE: torch.zeros(STATE_DIM),
        ACTION: torch.zeros(libero_action_dim),
        f"{OBS_IMAGES}.front": torch.rand(3, 16, 16),
    }
    processed_batch = preprocessor(raw_sample)

    processed_state = processed_batch[OBS_STATE]
    assert processed_state.shape == (1, MAX_STATE_DIM)
    assert torch.allclose(processed_state, torch.zeros_like(processed_state))

    processed_action = processed_batch[ACTION]
    assert processed_action.shape == (1, padded_action_dim)
    assert torch.allclose(processed_action, torch.zeros_like(processed_action))

    # The action mask marks the real dims as active and the padding as inactive.
    mask = processed_batch["action_mask"]
    assert mask.shape == (1, padded_action_dim)
    assert mask[:, :libero_action_dim].all()
    assert not mask[:, libero_action_dim:].any()

    action = torch.tensor(
        [
            [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.5, 0.7],
            [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7],
        ],
        dtype=torch.float32,
    )
    processed = postprocessor(action)

    # Output is cropped to 7 dims; index 6 is expected as +1.0 / -1.0, the rest unchanged.
    gripper_index = libero_action_dim - 1
    assert processed.shape == (2, libero_action_dim)
    assert processed.dtype == torch.float32
    assert torch.allclose(processed[:, :gripper_index], action[:, :gripper_index])
    assert torch.equal(processed[:, gripper_index], torch.tensor([1.0, -1.0]))
|
||||
|
||||
|
||||
def test_evo1_postprocessor_returns_float32_for_bf16_actions():
    """A bfloat16 action tensor must come out of the postprocessor as float32."""
    _unused_pre, postprocessor = make_evo1_pre_post_processors(make_config(), dataset_stats=make_stats())

    bf16_actions = torch.zeros(2, MAX_ACTION_DIM, dtype=torch.bfloat16)
    result = postprocessor(bf16_actions)

    assert result.dtype == torch.float32
|
||||
|
||||
|
||||
def test_evo1_processor_save_load_round_trip_applies_config_overrides(tmp_path):
    """Save both pipelines, reload them, reconcile with an eval config, and check the result."""
    saved_pre, saved_post = make_evo1_pre_post_processors(make_config(), dataset_stats=make_stats())
    for pipeline in (saved_pre, saved_post):
        pipeline.save_pretrained(tmp_path)

    loaded_pre = PolicyProcessorPipeline.from_pretrained(
        tmp_path,
        config_filename=f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json",
        to_transition=batch_to_transition,
        to_output=transition_to_batch,
    )
    loaded_post = PolicyProcessorPipeline.from_pretrained(
        tmp_path,
        config_filename=f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json",
        to_transition=policy_action_to_transition,
        to_output=transition_to_policy_action,
    )

    # Simulate eval-time CLI overrides applied on top of the loaded pipelines.
    eval_config = make_config(binarize_gripper=True, postprocess_action_dim=ACTION_DIM)
    loaded_pre, loaded_post = reconcile_evo1_processors(eval_config, loaded_pre, loaded_post)

    assert loaded_pre.to_transition is evo1_batch_to_transition

    # Exactly one action step must be present, carrying the overridden settings.
    action_steps = [step for step in loaded_post.steps if isinstance(step, Evo1ActionProcessorStep)]
    assert len(action_steps) == 1
    action_step = action_steps[0]
    assert action_step.binarize_gripper is True
    assert action_step.action_dim == ACTION_DIM

    # The float32 output dtype is part of the serialized pipeline itself.
    device_step = next(step for step in loaded_post.steps if isinstance(step, DeviceProcessorStep))
    assert device_step.float_dtype == "float32"

    # Non-observation extras (embodiment_id, ...) must survive the reloaded preprocessor.
    sample = {
        "task": "pick the block",
        OBS_STATE: torch.zeros(STATE_DIM),
        f"{OBS_IMAGES}.front": torch.rand(3, 16, 16),
        "embodiment_id": torch.tensor([0]),
    }
    processed = loaded_pre(sample)
    assert "embodiment_id" in processed
|
||||
|
||||
|
||||
def test_evo1_policy_forward_and_inference_use_batched_embedding(monkeypatch):
    """Run forward, chunk prediction and action selection against ``DummyEvo1Model``."""
    monkeypatch.setattr(modeling_evo1, "Evo1Model", DummyEvo1Model)
    policy = modeling_evo1.Evo1Policy(make_config())
    preprocessor, _unused_post = make_evo1_pre_post_processors(policy.config, dataset_stats=make_stats())
    training_batch = preprocessor(make_batch(include_action=True))

    # Preprocessing pads actions to MAX_ACTION_DIM and marks only the real dims as active.
    padded_shape = (2, CHUNK_SIZE, MAX_ACTION_DIM)
    mask = training_batch["action_mask"]
    assert training_batch[ACTION].shape == padded_shape
    assert mask.shape == padded_shape
    assert mask[:, :, :ACTION_DIM].all()
    assert not mask[:, :, ACTION_DIM:].any()

    # Training path: scalar finite loss and a single VL-embedding call.
    loss, metrics = policy.forward(training_batch)
    assert loss.ndim == 0
    assert torch.isfinite(loss)
    assert metrics["active_action_dims"] == ACTION_DIM * CHUNK_SIZE
    assert policy.model.get_vl_embeddings_calls == 1

    # Inference path: a whole chunk, returned as float32.
    action_chunk = policy.predict_action_chunk(make_batch(include_action=False))
    assert action_chunk.shape == padded_shape
    assert action_chunk.dtype == torch.float32

    # After a reset, select_action yields a single step per batch element.
    policy.reset()
    selected = policy.select_action(make_batch(include_action=False))
    assert selected.shape == (2, MAX_ACTION_DIM)
|
||||
|
||||
|
||||
def test_evo1_forward_masks_padded_action_timesteps(monkeypatch):
    """Timesteps flagged in ``action_is_pad`` must not contribute to the training loss."""
    monkeypatch.setattr(modeling_evo1, "Evo1Model", DummyEvo1Model)
    policy = modeling_evo1.Evo1Policy(make_config())

    batch = make_batch(include_action=True)
    actions = torch.ones(2, CHUNK_SIZE, ACTION_DIM)
    # Give the padded (past-episode-end) timestep a huge value: if it leaked into the loss, the
    # loss would blow up far beyond 1.0.
    actions[:, -1, :] = 100.0
    batch[ACTION] = actions

    is_pad = torch.zeros(2, CHUNK_SIZE, dtype=torch.bool)
    is_pad[:, -1] = True
    batch["action_is_pad"] = is_pad

    loss, metrics = policy.forward(batch)

    # DummyEvo1Model predicts zero velocity and zero noise, so each active element contributes
    # (0 - action)^2 = 1.0 for the in-episode ones-valued actions.
    expected_active = ACTION_DIM * (CHUNK_SIZE - 1)
    assert metrics["active_action_dims"] == expected_active
    assert torch.isclose(loss, torch.tensor(1.0))
|
||||
|
||||
|
||||
def test_evo1_select_action_queue_orders_steps_and_repredicts(monkeypatch):
    """Three ``select_action`` calls must return 10, 11, then 20 using two predicted chunks."""
    monkeypatch.setattr(modeling_evo1, "Evo1Model", ChunkCountingDummyModel)
    policy = modeling_evo1.Evo1Policy(make_config(n_action_steps=CHUNK_SIZE))

    batch = make_batch(include_action=False)
    selected = [policy.select_action(batch) for _ in range(3)]

    # First chunk provides steps 10, 11 in order; the third call triggers a fresh prediction (20).
    for action, expected_value in zip(selected, (10.0, 11.0, 20.0)):
        assert torch.all(action == expected_value)
    assert policy.model.chunks_predicted == 2
|
||||
|
||||
|
||||
def test_evo1_predict_action_chunk_rejects_rtc_kwargs_without_rtc_config(monkeypatch):
    """Passing ``inference_delay`` to a policy built from the default config must raise."""
    monkeypatch.setattr(modeling_evo1, "Evo1Model", DummyEvo1Model)
    policy_without_rtc = modeling_evo1.Evo1Policy(make_config())

    with pytest.raises(RuntimeError, match="RTC"):
        policy_without_rtc.predict_action_chunk(make_batch(include_action=False), inference_delay=2)
|
||||
|
||||
|
||||
def test_evo1_rtc_processor_wiring(monkeypatch):
    """The RTC processor is absent by default and shared with the model once initialised."""
    monkeypatch.setattr(evo1_model, "InternVL3Embedder", FakeEmbedder)
    policy = modeling_evo1.Evo1Policy(make_config())

    # Neither the policy nor its model holds a processor before rtc_config is assigned.
    for owner in (policy, policy.model):
        assert owner.rtc_processor is None

    # The RTC rollout backend assigns rtc_config after loading and re-inits the processor.
    policy.config.rtc_config = RTCConfig(execution_horizon=CHUNK_SIZE)
    policy.init_rtc_processor()
    assert isinstance(policy.rtc_processor, RTCProcessor)
    assert policy.model.rtc_processor is policy.rtc_processor

    # RTC drives predict_action_chunk directly; the select_action queue path is unsupported.
    inference_batch = make_batch(include_action=False)
    with pytest.raises(AssertionError, match="select_action"):
        policy.select_action(inference_batch)
|
||||
|
||||
|
||||
def test_flowmatching_rtc_guidance_pulls_prefix_toward_previous_chunk():
    """Compare a guided and an unguided sample drawn from the same seed against ``prev_chunk``."""
    head = make_flowmatching_head(num_inference_timesteps=16)
    processor = RTCProcessor(RTCConfig(execution_horizon=CHUNK_SIZE))
    fused = torch.randn(2, 4, EMBED_DIM)
    shared_inputs = {
        "state": torch.randn(2, STATE_DIM),
        "action_mask": torch.ones(2, ACTION_DIM, dtype=torch.bool),
    }
    prev_chunk = torch.tensor([0.7, -0.4, 0.2]).expand(2, CHUNK_SIZE, ACTION_DIM).contiguous()
    chunk_shape = (2, CHUNK_SIZE, ACTION_DIM)

    # Re-seed before each sample so both calls start from the same RNG state.
    torch.manual_seed(0)
    unguided = head.get_action(fused, **shared_inputs).view(*chunk_shape)

    torch.manual_seed(0)
    guided = head.get_action(
        fused,
        inference_delay=1,
        prev_chunk_left_over=prev_chunk,
        rtc_processor=processor,
        **shared_inputs,
    ).view(*chunk_shape)

    # The frozen prefix (first inference_delay steps) must land far closer to the previous chunk
    # than the unguided sample from the same noise does.
    target_prefix = prev_chunk[:, 0]
    guided_dist = (guided[:, 0] - target_prefix).abs().mean()
    unguided_dist = (unguided[:, 0] - target_prefix).abs().mean()
    assert guided_dist < 0.5 * unguided_dist
    assert torch.isfinite(guided).all()
|
||||
|
||||
|
||||
def test_flowmatching_rtc_first_chunk_without_leftover_matches_unguided():
    """With ``prev_chunk_left_over=None`` the RTC call must match the plain call."""
    head = make_flowmatching_head(num_inference_timesteps=4)
    processor = RTCProcessor(RTCConfig(execution_horizon=CHUNK_SIZE))
    fused = torch.randn(2, 4, EMBED_DIM)
    shared_inputs = {
        "state": torch.randn(2, STATE_DIM),
        "action_mask": torch.ones(2, ACTION_DIM, dtype=torch.bool),
    }

    # Same seed before both calls so any difference comes from the RTC arguments alone.
    torch.manual_seed(0)
    unguided = head.get_action(fused, **shared_inputs)

    torch.manual_seed(0)
    first_chunk = head.get_action(
        fused,
        inference_delay=2,
        prev_chunk_left_over=None,
        rtc_processor=processor,
        **shared_inputs,
    )

    assert torch.allclose(unguided, first_chunk)
|
||||
|
||||
|
||||
def test_evo1_missing_configured_camera_needs_empty_cameras_budget(monkeypatch):
    """A configured camera missing from the batch raises unless ``empty_cameras`` covers it."""
    monkeypatch.setattr(modeling_evo1, "Evo1Model", DummyEvo1Model)
    batch = make_batch(include_action=False)  # only provides the front camera

    image_feature = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16))
    two_camera_features = {
        OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
        f"{OBS_IMAGES}.front": image_feature,
        f"{OBS_IMAGES}.wrist": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
    }
    strict_config = make_config(input_features=dict(two_camera_features))
    strict_policy = modeling_evo1.Evo1Policy(strict_config)
    with pytest.raises(ValueError, match="empty_cameras"):
        strict_policy._collect_image_batches(batch)

    # empty_cameras adds placeholder camera features that are never present in the batch; they
    # become masked-out views instead of crashing with a KeyError.
    padded_policy = modeling_evo1.Evo1Policy(make_config(empty_cameras=1))
    assert len(padded_policy.config.image_features) == 2

    camera_images, image_masks = padded_policy._collect_image_batches(batch)
    assert len(camera_images) == 1
    assert image_masks.tolist() == [[True, False], [True, False]]
|
||||
|
||||
|
||||
def test_stage1_frozen_vlm_embeddings_do_not_track_gradients(monkeypatch):
    """In stage1 the fused tokens are computed without gradients and with the embedder in eval mode."""
    monkeypatch.setattr(modeling_evo1, "Evo1Model", DummyEvo1Model)
    policy = modeling_evo1.Evo1Policy(make_config(training_stage="stage1"))
    policy.train()

    prompts = ["pick", "place"]
    views, view_masks = policy._collect_image_batches(make_batch(include_action=False))
    fused_tokens, context_mask = policy._compute_fused_tokens(prompts, views, view_masks)

    # The dummy model records the grad mode and embedder training flag seen on each call.
    recorder = policy.model
    assert recorder.grad_enabled_calls == [False]
    assert recorder.embedder_training_calls == [False]
    assert not fused_tokens.requires_grad
    assert context_mask is not None
    assert recorder.embedder.training is False
|
||||
|
||||
|
||||
def test_stage2_vlm_embeddings_track_gradients(monkeypatch):
    """In stage2 the fused tokens are computed with gradients and the embedder in train mode."""
    monkeypatch.setattr(modeling_evo1, "Evo1Model", DummyEvo1Model)
    policy = modeling_evo1.Evo1Policy(make_config(training_stage="stage2"))
    policy.train()

    prompts = ["pick", "place"]
    views, view_masks = policy._collect_image_batches(make_batch(include_action=False))
    fused_tokens, _unused_mask = policy._compute_fused_tokens(prompts, views, view_masks)

    # The dummy model records the grad mode and embedder training flag seen on each call.
    recorder = policy.model
    assert recorder.grad_enabled_calls == [True]
    assert recorder.embedder_training_calls == [True]
    assert fused_tokens.requires_grad
|
||||
|
||||
|
||||
def test_collect_image_batches_handles_unbatched_chw(monkeypatch):
    """An unbatched (C, H, W) camera frame must be promoted to a batch of one."""
    # Regression for an issue where batch_size was read from shape[0] before normalizing
    # per-camera tensor dims, so an unbatched (C, H, W) input was treated as batch_size=C.
    monkeypatch.setattr(modeling_evo1, "Evo1Model", DummyEvo1Model)
    policy = modeling_evo1.Evo1Policy(make_config())

    unbatched_frame_batch = {
        OBS_STATE: torch.randn(1, STATE_DIM),
        f"{OBS_IMAGES}.front": torch.rand(3, 16, 16),
    }
    camera_images, image_masks = policy._collect_image_batches(unbatched_frame_batch)

    # One present camera, returned as a batched (B, C, H, W) tensor with the unbatched CHW frame
    # promoted to batch_size=1 (not read as batch_size=C).
    assert len(camera_images) == 1
    (front_view,) = camera_images
    assert front_view.shape == (1, 3, 16, 16)
    assert image_masks.tolist() == [[True, False]]
|
||||
|
||||
|
||||
def test_evo1_state_mask_zeroes_masked_dims(monkeypatch):
    """State dims whose ``state_mask`` entry is False must come back as zero."""
    monkeypatch.setattr(modeling_evo1, "Evo1Model", DummyEvo1Model)
    policy = modeling_evo1.Evo1Policy(make_config())

    # Two identical rows: the first two dims are kept, the rest are masked out.
    mask_row = [True, True, False, False]
    batch = {
        OBS_STATE: torch.ones(2, STATE_DIM),
        "state_mask": torch.tensor([mask_row, mask_row]),
    }

    states, mask = policy._prepare_state(batch)

    kept, dropped = slice(None, 2), slice(2, None)
    assert torch.all(states[:, kept] == 1.0)
    assert torch.all(states[:, dropped] == 0.0)
    assert mask[:, kept].all()
    assert not mask[:, dropped].any()
|
||||
|
||||
|
||||
def test_evo1_action_mask_accepts_chunk_size_one(monkeypatch):
    """With ``chunk_size=1``, 2-D actions and masks must be returned with a chunk axis of 1."""
    monkeypatch.setattr(modeling_evo1, "Evo1Model", DummyEvo1Model)
    policy = modeling_evo1.Evo1Policy(make_config(chunk_size=1, n_action_steps=1))

    batch = make_batch(include_action=True)
    # Both tensors are supplied without a chunk axis.
    batch[ACTION] = torch.randn(2, ACTION_DIM)
    batch["action_mask"] = torch.ones(2, ACTION_DIM, dtype=torch.bool)

    actions, action_mask = policy._prepare_actions(batch)

    expected_shape = (2, 1, MAX_ACTION_DIM)
    assert actions.shape == expected_shape
    assert action_mask.shape == expected_shape
    assert action_mask[:, :, :ACTION_DIM].all()
    assert not action_mask[:, :, ACTION_DIM:].any()
|
||||
|
||||
|
||||
def test_flowmatching_state_encoder_for_horizon_one():
    """A head built with ``horizon=1`` has a state encoder and returns the expected shapes."""
    head = make_flowmatching_head(action_dim=ACTION_DIM, horizon=1)
    assert head.state_encoder is not None

    # Inputs are drawn in the same order as the positional/keyword arguments below.
    fused = torch.randn(2, 4, EMBED_DIM)
    state = torch.randn(2, STATE_DIM)
    actions_gt = torch.randn(2, 1, ACTION_DIM)
    action_mask = torch.ones(2, 1, ACTION_DIM, dtype=torch.bool)

    pred_velocity, noise = head(fused, state=state, actions_gt=actions_gt, action_mask=action_mask)

    assert pred_velocity.shape == (2, ACTION_DIM)
    assert noise.shape == (2, 1, ACTION_DIM)
|
||||
|
||||
|
||||
def test_flowmatching_get_action_real_path_respects_action_mask():
    """Sampled actions must be zero in every dimension the action mask disables.

    Only the first ``num_active_dims`` action dimensions are enabled; the assertion covers the
    whole masked tail rather than a single hard-coded index, so the check stays complete if
    ``ACTION_DIM`` changes.
    """
    torch.manual_seed(0)
    head = make_flowmatching_head()

    num_active_dims = 2
    action_mask = torch.zeros(2, ACTION_DIM, dtype=torch.bool)
    action_mask[:, :num_active_dims] = True
    actions = head.get_action(
        torch.randn(2, 4, EMBED_DIM),
        state=torch.randn(2, STATE_DIM),
        action_mask=action_mask,
    )

    assert actions.shape == (2, CHUNK_SIZE * ACTION_DIM)
    assert torch.isfinite(actions).all()
    # get_action returns a flat (batch, CHUNK_SIZE * ACTION_DIM) tensor; restore the chunk axis
    # so the masked dimensions can be sliced directly.
    action_seq = actions.view(2, CHUNK_SIZE, ACTION_DIM)
    assert torch.all(action_seq[..., num_active_dims:] == 0.0)
|
||||
|
||||
|
||||
def test_flowmatching_context_mask_blocks_masked_context_tokens():
    """Corrupting a context token that ``context_mask`` disables must not change the output."""
    head = make_flowmatching_head()
    state = torch.randn(2, STATE_DIM)
    action_mask = torch.ones(2, ACTION_DIM, dtype=torch.bool)
    fused = torch.randn(2, 4, EMBED_DIM)

    # Disable the last context token, then overwrite that token with a huge value in a copy.
    context_mask = torch.ones(2, 4, dtype=torch.bool)
    context_mask[:, -1] = False
    corrupted = fused.clone()
    corrupted[:, -1] = 1e4

    def sample(context):
        # Re-seed so both samples start from the same RNG state.
        torch.manual_seed(0)
        return head.get_action(context, state=state, action_mask=action_mask, context_mask=context_mask)

    reference = sample(fused)
    with_garbage = sample(corrupted)

    assert torch.allclose(reference, with_garbage)
|
||||
|
||||
|
||||
def test_flowmatching_head_accepts_pooled_2d_context():
    """Both the training call and ``get_action`` must accept a 2-D (B, E) context tensor."""
    head = make_flowmatching_head()
    flat_shape = (2, CHUNK_SIZE * ACTION_DIM)

    # Training call; inputs are drawn in the same order as the arguments.
    pooled_context = torch.randn(2, EMBED_DIM)  # pooled (B, E) context from return_cls_only
    train_state = torch.randn(2, STATE_DIM)
    actions_gt = torch.randn(2, CHUNK_SIZE, ACTION_DIM)
    train_mask = torch.ones(2, CHUNK_SIZE, ACTION_DIM, dtype=torch.bool)
    pred_velocity, noise = head(
        pooled_context,
        state=train_state,
        actions_gt=actions_gt,
        action_mask=train_mask,
    )
    assert pred_velocity.shape == flat_shape

    # Inference call with a fresh pooled context.
    inference_context = torch.randn(2, EMBED_DIM)
    inference_state = torch.randn(2, STATE_DIM)
    inference_mask = torch.ones(2, ACTION_DIM, dtype=torch.bool)
    actions = head.get_action(inference_context, state=inference_state, action_mask=inference_mask)
    assert actions.shape == flat_shape
|
||||
|
||||
|
||||
def test_flowmatching_rejects_out_of_range_embodiment_ids():
    """An embodiment id of 5 on a head built with ``num_categories=2`` must raise."""
    head = make_flowmatching_head(num_categories=2)

    fused = torch.randn(2, 4, EMBED_DIM)
    state = torch.randn(2, STATE_DIM)
    action_mask = torch.ones(2, ACTION_DIM, dtype=torch.bool)
    embodiment_ids = torch.tensor([0, 5])  # the second id is the out-of-range one

    with pytest.raises(ValueError, match="num_categories"):
        head.get_action(fused, state=state, action_mask=action_mask, embodiment_id=embodiment_ids)
|
||||
|
||||
|
||||
def test_evo1_batched_pixel_values_shape_and_zero_padding():
    """Check the output shape of ``_batched_pixel_values`` and the value used for absent views."""
    torch.manual_seed(0)
    batch_size, image_size, max_views = 2, 448, 3
    camera_images = [torch.rand(batch_size, 3, 40, 50)]  # a single present camera
    mean = torch.tensor(IMAGENET_MEAN)
    std = torch.tensor(IMAGENET_STD)

    pixel_values = _batched_pixel_values(
        camera_images, max_views, image_size, mean, std, torch.float32, torch.device("cpu")
    )

    view_shape = (3, image_size, image_size)
    assert pixel_values.shape == (batch_size * max_views, *view_shape)
    grouped = pixel_values.reshape(batch_size, max_views, *view_shape)

    # Absent views (indices 1, 2) are zero images, normalized to the constant -mean/std.
    expected_pad = (-mean / std).view(1, 3, 1, 1)
    pad_image = expected_pad.expand(batch_size, *view_shape)
    for view in (1, 2):
        assert torch.allclose(grouped[:, view], pad_image, atol=1e-5)

    # The present view is genuinely different from the constant pad value.
    assert not torch.allclose(grouped[:, 0], pad_image, atol=1e-3)
|
||||
@@ -2978,6 +2978,9 @@ eo1 = [
|
||||
evaluation = [
|
||||
{ name = "av" },
|
||||
]
|
||||
evo1 = [
|
||||
{ name = "transformers" },
|
||||
]
|
||||
fastwam = [
|
||||
{ name = "diffusers" },
|
||||
{ name = "transformers" },
|
||||
@@ -3179,6 +3182,9 @@ test = [
|
||||
{ name = "pytest-cov" },
|
||||
{ name = "pytest-timeout" },
|
||||
]
|
||||
timm-dep = [
|
||||
{ name = "timm" },
|
||||
]
|
||||
topreward = [
|
||||
{ name = "transformers" },
|
||||
]
|
||||
@@ -3296,6 +3302,7 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'vla-jepa'" },
|
||||
{ name = "lerobot", extras = ["diffusion"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["dynamixel"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["evo1"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["fastwam"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'hopejr'" },
|
||||
@@ -3362,10 +3369,12 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'wallx'" },
|
||||
{ name = "lerobot", extras = ["smolvla"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["test"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["timm-dep"], marker = "extra == 'groot'" },
|
||||
{ name = "lerobot", extras = ["topreward"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["training"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'annotations'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'eo1'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'evo1'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'fastwam'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'groot'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'hilserl'" },
|
||||
@@ -3436,7 +3445,7 @@ requires-dist = [
|
||||
{ name = "setuptools", specifier = ">=71.0.0,<81.0.0" },
|
||||
{ name = "teleop", marker = "extra == 'phone'", specifier = ">=0.1.0,<0.2.0" },
|
||||
{ name = "termcolor", specifier = ">=2.4.0,<4.0.0" },
|
||||
{ name = "timm", marker = "extra == 'groot'", specifier = ">=1.0.0,<1.1.0" },
|
||||
{ name = "timm", marker = "extra == 'timm-dep'", specifier = ">=1.0.0,<1.1.0" },
|
||||
{ name = "torch", marker = "sys_platform != 'linux'", specifier = ">=2.7,<2.12.0" },
|
||||
{ name = "torch", marker = "sys_platform == 'linux'", specifier = ">=2.7,<2.12.0", index = "https://download.pytorch.org/whl/cu128" },
|
||||
{ name = "torchcodec", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin' and extra == 'dataset') or (platform_machine == 'AMD64' and sys_platform == 'linux' and extra == 'dataset') or (platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'dataset')", specifier = ">=0.3.0,<0.12.0" },
|
||||
@@ -3449,7 +3458,7 @@ requires-dist = [
|
||||
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" },
|
||||
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.28.0" },
|
||||
]
|
||||
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "accelerate-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "fastwam", "hilserl", "vla-jepa", "async", "peft", "annotations", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
||||
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "accelerate-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "timm-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "fastwam", "evo1", "hilserl", "vla-jepa", "async", "peft", "annotations", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
||||
|
||||
[[package]]
|
||||
name = "librt"
|
||||
|
||||
Reference in New Issue
Block a user