mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 15:09:51 +00:00
Merge branch 'main' into feat/language-columns
This commit is contained in:
@@ -16,6 +16,7 @@ from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterp
|
||||
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config as EO1Config
|
||||
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
|
||||
@@ -41,6 +42,7 @@ __all__ = [
|
||||
"DiffusionConfig",
|
||||
"GrootConfig",
|
||||
"MultiTaskDiTConfig",
|
||||
"EO1Config",
|
||||
"PI0Config",
|
||||
"PI0FastConfig",
|
||||
"PI05Config",
|
||||
|
||||
@@ -100,8 +100,8 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
|
||||
# Inputs / output structure.
|
||||
n_obs_steps: int = 2
|
||||
horizon: int = 16
|
||||
n_action_steps: int = 8
|
||||
horizon: int = 64
|
||||
n_action_steps: int = 32
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
@@ -122,10 +122,10 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
crop_ratio: float = 1.0
|
||||
crop_shape: tuple[int, int] | None = None
|
||||
crop_is_random: bool = True
|
||||
pretrained_backbone_weights: str | None = None
|
||||
use_group_norm: bool = True
|
||||
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
|
||||
use_group_norm: bool = False
|
||||
spatial_softmax_num_keypoints: int = 32
|
||||
use_separate_rgb_encoder_per_camera: bool = False
|
||||
use_separate_rgb_encoder_per_camera: bool = True
|
||||
# Unet.
|
||||
down_dims: tuple[int, ...] = (512, 1024, 2048)
|
||||
kernel_size: int = 5
|
||||
|
||||
+1
@@ -0,0 +1 @@
|
||||
../../../../docs/source/eo1.mdx
|
||||
@@ -0,0 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from .configuration_eo1 import EO1Config
|
||||
from .modeling_eo1 import EO1Policy
|
||||
from .processor_eo1 import make_eo1_pre_post_processors
|
||||
|
||||
__all__ = ["EO1Config", "EO1Policy", "make_eo1_pre_post_processors"]
|
||||
@@ -0,0 +1,193 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||
Qwen2_5_VLConfig,
|
||||
Qwen2_5_VLTextConfig,
|
||||
Qwen2_5_VLVisionConfig,
|
||||
)
|
||||
else:
|
||||
Qwen2_5_VLConfig = None
|
||||
Qwen2_5_VLTextConfig = None
|
||||
Qwen2_5_VLVisionConfig = None
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("eo1")
|
||||
@dataclass
|
||||
class EO1Config(PreTrainedConfig):
|
||||
"""Configuration for native EO1 policy integration in LeRobot."""
|
||||
|
||||
vlm_base: str = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
vlm_config: dict | None = None
|
||||
|
||||
# Vision processor settings.
|
||||
image_min_pixels: int | None = 64 * 28 * 28
|
||||
image_max_pixels: int | None = 128 * 28 * 28
|
||||
use_fast_processor: bool = False
|
||||
|
||||
# Execution and action horizon.
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 8
|
||||
n_action_steps: int = 8
|
||||
|
||||
# State/action padding to match EO1 flow head dimensionality.
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 32
|
||||
|
||||
# Flow matching sampling.
|
||||
num_denoise_steps: int = 10
|
||||
num_action_layers: int = 2
|
||||
action_act: str = "linear"
|
||||
time_sampling_beta_alpha: float = 1.5
|
||||
time_sampling_beta_beta: float = 1.0
|
||||
time_sampling_scale: float = 0.999
|
||||
time_sampling_offset: float = 0.001
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
supervise_padding_action_dims: bool = True
|
||||
supervise_padding_actions: bool = True
|
||||
|
||||
# Policy-level dtype request for the Qwen backbone.
|
||||
# - "auto": follow the backbone config/checkpoint default dtype. For Qwen2.5-VL this resolves to bf16.
|
||||
# The EO1 flow-matching head still keeps its own parameters in fp32.
|
||||
# - "bfloat16": force the backbone to initialize/load in bf16 regardless of the saved config default.
|
||||
# - "float32": force the backbone to initialize/load in fp32 for maximum numerical conservatism.
|
||||
dtype: str = "auto" # Options: "auto", "bfloat16", "float32"
|
||||
force_fp32_autocast: bool = True
|
||||
|
||||
# Optional attention backend request passed through to the Qwen backbone.
|
||||
# Common values: None, "eager", "sdpa", "flash_attention_2".
|
||||
attn_implementation: str | None = None
|
||||
|
||||
# Training settings.
|
||||
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Optimizer settings aligned with EO1/experiments/2_libero/train.sh and EO1 TrainPipelineConfig defaults.
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 0.1
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
# Scheduler settings aligned with EO1 train.sh: cosine schedule with warmup_ratio=0.03.
|
||||
# Note: These will auto-scale if --steps < scheduler_decay_steps
|
||||
# For example, --steps=3000 will scale warmup to 100 and decay to 3000
|
||||
scheduler_warmup_steps: int = 900 # 0.03 * 30_000 long-run steps
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 0.0
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
|
||||
)
|
||||
|
||||
# Populate the serialized backbone config only when the caller did not provide one.
|
||||
if self.vlm_config is None:
|
||||
require_package("transformers", extra="eo1")
|
||||
self.vlm_config = Qwen2_5_VLConfig.from_pretrained(self.vlm_base).to_dict()
|
||||
|
||||
@property
|
||||
def vlm_backbone_config(self) -> Qwen2_5_VLConfig:
|
||||
require_package("transformers", extra="eo1")
|
||||
config_dict = deepcopy(self.vlm_config)
|
||||
if self.attn_implementation is not None:
|
||||
config_dict["attn_implementation"] = self.attn_implementation
|
||||
return Qwen2_5_VLConfig(**config_dict)
|
||||
|
||||
@property
|
||||
def text_config(self) -> Qwen2_5_VLTextConfig:
|
||||
return self.vlm_backbone_config.text_config
|
||||
|
||||
@property
|
||||
def vision_config(self) -> Qwen2_5_VLVisionConfig:
|
||||
return self.vlm_backbone_config.vision_config
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up EO1 input and output features."""
|
||||
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
|
||||
if not image_features:
|
||||
raise ValueError(
|
||||
"EO1 policy requires at least one visual input feature. "
|
||||
"No features of type FeatureType.VISUAL found in input_features."
|
||||
)
|
||||
|
||||
if OBS_STATE not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,),
|
||||
)
|
||||
self.input_features[OBS_STATE] = state_feature
|
||||
|
||||
if ACTION not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,),
|
||||
)
|
||||
self.output_features[ACTION] = action_feature
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -0,0 +1,620 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torch.utils.checkpoint
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.eo1.configuration_eo1 import EO1Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
|
||||
from transformers.utils import torch_compilable_check
|
||||
else:
|
||||
ACT2FN = None
|
||||
Qwen2_5_VLForConditionalGeneration = None
|
||||
torch_compilable_check = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||
|
||||
Can be (batch_size x sequence_length x features_dimension)
|
||||
or (batch_size x features_dimension)
|
||||
"""
|
||||
if vector.shape[-1] >= new_dim:
|
||||
return vector
|
||||
return F.pad(vector, (0, new_dim - vector.shape[-1]))
|
||||
|
||||
|
||||
class EO1Policy(PreTrainedPolicy):
|
||||
"""EO1 policy wrapper for LeRobot robot-only training/evaluation."""
|
||||
|
||||
config_class = EO1Config
|
||||
name = "eo1"
|
||||
|
||||
def __init__(self, config: EO1Config, **kwargs):
|
||||
require_package("transformers", extra="eo1")
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
if config.pretrained_path is None:
|
||||
# Initialize from pretrained VLM
|
||||
vlm_backbone = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
config.vlm_base,
|
||||
dtype=config.dtype,
|
||||
attn_implementation=config.attn_implementation,
|
||||
)
|
||||
else:
|
||||
vlm_backbone = Qwen2_5_VLForConditionalGeneration._from_config(
|
||||
config.vlm_backbone_config,
|
||||
dtype=config.vlm_backbone_config.dtype if config.dtype == "auto" else config.dtype,
|
||||
)
|
||||
|
||||
self.model = EO1VisionFlowMatchingModel(config, vlm_backbone)
|
||||
if config.gradient_checkpointing:
|
||||
self.model.gradient_checkpointing_enable()
|
||||
|
||||
self.model.to(config.device)
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self._action_queue = deque(maxlen=self.config.n_action_steps)
|
||||
|
||||
@staticmethod
|
||||
def _get_model_inputs(batch: dict[str, Tensor], excluded_keys: set[str]) -> dict[str, Tensor]:
|
||||
return {key: value for key, value in batch.items() if key not in excluded_keys}
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
state = self.prepare_state(batch[OBS_STATE])
|
||||
actions = self.prepare_action(batch[ACTION])
|
||||
model_inputs = self._get_model_inputs(batch, {OBS_STATE, ACTION})
|
||||
loss = self.model(states=state, action=actions, **model_inputs)
|
||||
|
||||
loss_dict = {"loss": loss.item()}
|
||||
return loss, loss_dict
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
|
||||
self.eval()
|
||||
|
||||
states = self.prepare_state(batch[OBS_STATE])
|
||||
model_inputs = self._get_model_inputs(batch, {OBS_STATE})
|
||||
actions = self.model.sample_actions(states=states, **model_inputs).to(torch.float32)
|
||||
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
return actions[:, :, :original_action_dim]
|
||||
|
||||
def prepare_state(self, state: Tensor) -> Tensor:
|
||||
return pad_vector(state, self.config.max_state_dim)
|
||||
|
||||
def prepare_action(self, action: Tensor) -> Tensor:
|
||||
return pad_vector(action, self.config.max_action_dim)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
self.eval()
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
|
||||
def get_safe_dtype(target_dtype, device_type):
|
||||
"""Get a safe dtype for the given device type."""
|
||||
if device_type == "mps" and target_dtype == torch.float64:
|
||||
return torch.float32
|
||||
if device_type == "cpu":
|
||||
# CPU doesn't support bfloat16, use float32 instead
|
||||
if target_dtype == torch.bfloat16:
|
||||
return torch.float32
|
||||
if target_dtype == torch.float64:
|
||||
return torch.float64
|
||||
return target_dtype
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy)
|
||||
time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
) -> Tensor:
|
||||
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
period = min_period * (max_period / min_period) ** fraction
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
|
||||
# Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU
|
||||
alpha_t = torch.tensor(alpha, dtype=torch.float32)
|
||||
beta_t = torch.tensor(beta, dtype=torch.float32)
|
||||
dist = torch.distributions.Beta(alpha_t, beta_t)
|
||||
return dist.sample((bsize,)).to(device)
|
||||
|
||||
|
||||
class EO1VisionActionProjector(torch.nn.Sequential):
|
||||
"""This block implements the multi-layer perceptron (MLP) module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_layers: int = 2,
|
||||
activation_layer: str = "linear",
|
||||
bias: bool = True,
|
||||
device: Any = None,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
):
|
||||
layers = []
|
||||
in_dim = in_channels
|
||||
hidden_channels = [in_dim] * (num_layers - 1) + [out_channels]
|
||||
for hidden_dim in hidden_channels[:-1]:
|
||||
layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device))
|
||||
layers.append(ACT2FN[activation_layer])
|
||||
in_dim = hidden_dim
|
||||
layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias, dtype=dtype, device=device))
|
||||
super().__init__(*layers)
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self[0].weight.dtype
|
||||
|
||||
|
||||
class EO1VisionFlowMatchingModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: EO1Config,
|
||||
vlm_backbone: Qwen2_5_VLForConditionalGeneration | None = None,
|
||||
):
|
||||
require_package("transformers", extra="eo1")
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
# Preserve the backbone dtype selected at construction time so Qwen's fp32 rotary buffers stay intact.
|
||||
self.vlm_backbone = vlm_backbone
|
||||
self.hidden_size = self.vlm_backbone.config.text_config.hidden_size
|
||||
max_state_dim = config.max_state_dim
|
||||
max_action_dim = config.max_action_dim
|
||||
self.state_proj = nn.Linear(max_state_dim, self.hidden_size, dtype=torch.float32)
|
||||
self.action_in_proj = nn.Linear(max_action_dim, self.hidden_size, dtype=torch.float32)
|
||||
self.action_out_proj = EO1VisionActionProjector(
|
||||
self.hidden_size,
|
||||
max_action_dim,
|
||||
config.num_action_layers,
|
||||
config.action_act,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
self.action_time_mlp_in = nn.Linear(self.hidden_size * 2, self.hidden_size, dtype=torch.float32)
|
||||
self.action_time_mlp_out = nn.Linear(self.hidden_size, self.hidden_size, dtype=torch.float32)
|
||||
self.gradient_checkpointing_enabled = False
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.vlm_backbone.get_input_embeddings()
|
||||
|
||||
def flow_head_autocast_context(self):
|
||||
if self.config.force_fp32_autocast:
|
||||
return torch.autocast(
|
||||
device_type=self.state_proj.weight.device.type,
|
||||
enabled=False,
|
||||
)
|
||||
return contextlib.nullcontext()
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
"""Enable gradient checkpointing for the Qwen2.5-VL backbone."""
|
||||
self.gradient_checkpointing_enabled = True
|
||||
self.vlm_backbone.gradient_checkpointing_enable(
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||
)
|
||||
logger.info("Enabled gradient checkpointing for EO1VisionFlowMatchingModel")
|
||||
|
||||
def gradient_checkpointing_disable(self):
|
||||
"""Disable gradient checkpointing for the Qwen2.5-VL backbone."""
|
||||
self.gradient_checkpointing_enabled = False
|
||||
self.vlm_backbone.gradient_checkpointing_disable()
|
||||
logger.info("Disabled gradient checkpointing for EO1VisionFlowMatchingModel")
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Apply manual gradient checkpointing to EO1 flow-head computations when training."""
|
||||
if self.gradient_checkpointing_enabled and self.training and torch.is_grad_enabled():
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
def sample_noise(self, shape, device):
|
||||
noise = torch.normal(
|
||||
mean=0.0,
|
||||
std=1.0,
|
||||
size=shape,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
return noise
|
||||
|
||||
def sample_time(self, bsize, device):
|
||||
time_beta = sample_beta(
|
||||
self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device
|
||||
)
|
||||
time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
|
||||
def get_placeholder_mask(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None,
|
||||
inputs_embeds: torch.FloatTensor | None,
|
||||
state_features: torch.FloatTensor | None = None,
|
||||
action_features: torch.FloatTensor | None = None,
|
||||
*,
|
||||
state_token_id: int,
|
||||
action_token_id: int,
|
||||
) -> tuple[torch.BoolTensor, torch.BoolTensor]:
|
||||
"""Return EO1 state/action placeholder masks, following Qwen's multimodal mask style."""
|
||||
if input_ids is None:
|
||||
special_state_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(state_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_state_mask = special_state_mask.all(-1)
|
||||
special_action_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(action_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_action_mask = special_action_mask.all(-1)
|
||||
else:
|
||||
special_state_mask = input_ids == state_token_id
|
||||
special_action_mask = input_ids == action_token_id
|
||||
|
||||
n_state_tokens = special_state_mask.sum()
|
||||
special_state_mask = (
|
||||
special_state_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
)
|
||||
if state_features is not None:
|
||||
torch_compilable_check(
|
||||
inputs_embeds[special_state_mask].numel() == state_features.numel(),
|
||||
f"State features and state tokens do not match, tokens: {n_state_tokens}, features: {state_features.shape[0]}",
|
||||
)
|
||||
|
||||
n_action_tokens = special_action_mask.sum()
|
||||
special_action_mask = (
|
||||
special_action_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
)
|
||||
if action_features is not None:
|
||||
torch_compilable_check(
|
||||
inputs_embeds[special_action_mask].numel() == action_features.numel(),
|
||||
f"Action features and action tokens do not match, tokens: {n_action_tokens}, features: {action_features.shape[0]}",
|
||||
)
|
||||
|
||||
return special_state_mask, special_action_mask
|
||||
|
||||
def embed_prefix(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
states: torch.Tensor,
|
||||
*,
|
||||
state_token_id: int,
|
||||
action_token_id: int,
|
||||
) -> torch.FloatTensor:
|
||||
"""Embed the EO1 prefix tokens before native Qwen injects multimodal features."""
|
||||
|
||||
# Get the input embeddings for the input IDs
|
||||
def input_embed_func(input_ids: torch.LongTensor) -> torch.FloatTensor:
|
||||
return self.get_input_embeddings()(input_ids)
|
||||
|
||||
inputs_embeds = self._apply_checkpoint(input_embed_func, input_ids)
|
||||
|
||||
# Project the states to the hidden size
|
||||
def state_proj_func(states: torch.Tensor) -> torch.FloatTensor:
|
||||
with self.flow_head_autocast_context():
|
||||
states = states.to(dtype=self.state_proj.weight.dtype)
|
||||
return self.state_proj(states)
|
||||
|
||||
state_embs = self._apply_checkpoint(state_proj_func, states)
|
||||
state_mask, _ = self.get_placeholder_mask(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
state_features=state_embs,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
)
|
||||
state_embs = state_embs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(state_mask, state_embs)
|
||||
return inputs_embeds
|
||||
|
||||
def embed_suffix(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
noisy_actions: torch.Tensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""Embed the suffix"""
|
||||
|
||||
def action_proj_func(noisy_actions: torch.Tensor) -> torch.FloatTensor:
|
||||
with self.flow_head_autocast_context():
|
||||
noisy_actions = noisy_actions.to(dtype=self.action_in_proj.weight.dtype)
|
||||
return self.action_in_proj(noisy_actions)
|
||||
|
||||
action_embs = self._apply_checkpoint(action_proj_func, noisy_actions)
|
||||
time_embs = create_sinusoidal_pos_embedding(
|
||||
timestep,
|
||||
self.hidden_size,
|
||||
min_period=self.config.min_period,
|
||||
max_period=self.config.max_period,
|
||||
device=action_embs.device,
|
||||
)
|
||||
time_embs = time_embs.to(dtype=action_embs.dtype)
|
||||
time_embs = time_embs[:, None, :].expand_as(action_embs)
|
||||
action_time_embs = torch.cat([action_embs, time_embs], dim=2)
|
||||
|
||||
def mlp_func(action_time_embs: torch.Tensor) -> torch.FloatTensor:
|
||||
with self.flow_head_autocast_context():
|
||||
action_time_embs = action_time_embs.to(dtype=self.action_time_mlp_in.weight.dtype)
|
||||
action_time_embs = self.action_time_mlp_in(action_time_embs)
|
||||
action_time_embs = F.silu(action_time_embs)
|
||||
return self.action_time_mlp_out(action_time_embs)
|
||||
|
||||
action_time_embs = self._apply_checkpoint(mlp_func, action_time_embs)
|
||||
return action_time_embs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.LongTensor | None = None,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
image_grid_thw: torch.LongTensor | None = None,
|
||||
mm_token_type_ids: torch.IntTensor | None = None,
|
||||
states: torch.FloatTensor | None = None,
|
||||
action: torch.FloatTensor | None = None,
|
||||
action_is_pad: torch.BoolTensor | None = None,
|
||||
*,
|
||||
state_token_id: int,
|
||||
action_token_id: int,
|
||||
**kwargs,
|
||||
) -> Tensor:
|
||||
"""Run the EO1 training forward pass and compute the flow-matching loss."""
|
||||
|
||||
# 1. Build the EO1 prefix with state placeholders resolved.
|
||||
inputs_embeds = self.embed_prefix(
|
||||
input_ids,
|
||||
states=states,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
)
|
||||
|
||||
# 2. Sample the diffusion target and replace the action placeholders.
|
||||
time = self.sample_time(action.shape[0], inputs_embeds.device)
|
||||
noise = self.sample_noise(action.shape, inputs_embeds.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * action
|
||||
u_t = noise - action
|
||||
action_time_embs = self.embed_suffix(time, x_t)
|
||||
_, action_mask = self.get_placeholder_mask(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
action_features=action_time_embs,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
)
|
||||
action_time_embs = action_time_embs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(action_mask, action_time_embs)
|
||||
|
||||
# 3. Optionally drop padded action tokens from backbone attention.
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
if not self.config.supervise_padding_actions:
|
||||
action_is_pad = action_is_pad.to(device=inputs_embeds.device, dtype=torch.bool)
|
||||
action_token_mask = action_mask[..., 0]
|
||||
action_padding_mask = torch.zeros_like(action_token_mask)
|
||||
action_padding_mask = action_padding_mask.masked_scatter(
|
||||
action_token_mask,
|
||||
action_is_pad.reshape(-1),
|
||||
)
|
||||
attention_mask = attention_mask.masked_fill(action_padding_mask, 0)
|
||||
|
||||
# 4. Run the Qwen backbone on the fused EO1 sequence.
|
||||
def vlm_forward_func(
|
||||
input_ids: torch.LongTensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
inputs_embeds: torch.FloatTensor,
|
||||
pixel_values: torch.Tensor | None,
|
||||
image_grid_thw: torch.LongTensor | None,
|
||||
mm_token_type_ids: torch.IntTensor | None,
|
||||
) -> torch.FloatTensor:
|
||||
outputs = self.vlm_backbone.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thw=image_grid_thw,
|
||||
mm_token_type_ids=mm_token_type_ids,
|
||||
use_cache=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
)
|
||||
return outputs.last_hidden_state
|
||||
|
||||
hidden_states = self._apply_checkpoint(
|
||||
vlm_forward_func,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
inputs_embeds,
|
||||
pixel_values,
|
||||
image_grid_thw,
|
||||
mm_token_type_ids,
|
||||
)
|
||||
action_hidden_states = hidden_states[action_mask[..., 0]]
|
||||
|
||||
# 5. Project the action-token hidden states back to the flow target space.
|
||||
def action_out_proj_func(action_hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
with self.flow_head_autocast_context():
|
||||
action_hidden_states = action_hidden_states.to(dtype=self.action_out_proj.dtype)
|
||||
return self.action_out_proj(action_hidden_states)
|
||||
|
||||
v_t = self._apply_checkpoint(action_out_proj_func, action_hidden_states)
|
||||
v_t = v_t.reshape(u_t.shape).to(dtype=u_t.dtype)
|
||||
losses = F.mse_loss(u_t, v_t, reduction="none")
|
||||
|
||||
# 6. Apply the configured supervision mask and reduce the loss.
|
||||
if not self.config.supervise_padding_action_dims:
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
losses = losses[..., :original_action_dim]
|
||||
|
||||
if not self.config.supervise_padding_actions:
|
||||
losses = losses[~action_is_pad]
|
||||
|
||||
return losses.mean()
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_actions(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
pixel_values: torch.Tensor | None = None,
|
||||
image_grid_thw: torch.LongTensor | None = None,
|
||||
mm_token_type_ids: torch.IntTensor | None = None,
|
||||
states: torch.Tensor | None = None,
|
||||
*,
|
||||
state_token_id: int,
|
||||
action_token_id: int,
|
||||
**kwargs,
|
||||
) -> Tensor:
|
||||
"""Sample actions from the model."""
|
||||
if states is None:
|
||||
raise ValueError("states are required for EO1 action sampling.")
|
||||
if mm_token_type_ids is None:
|
||||
raise ValueError("mm_token_type_ids are required for EO1 action sampling.")
|
||||
|
||||
# 1. Resolve the left-padded rollout prompt and locate the action span.
|
||||
chunk_size = self.config.chunk_size
|
||||
|
||||
inputs_embeds = self.embed_prefix(
|
||||
input_ids,
|
||||
states=states,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
).clone()
|
||||
_, action_placeholder_mask = self.get_placeholder_mask(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
)
|
||||
action_mask = action_placeholder_mask[..., 0]
|
||||
token_counts = action_mask.sum(dim=1)
|
||||
if not torch.all(token_counts == chunk_size):
|
||||
raise ValueError(
|
||||
f"Each sample must contain exactly {chunk_size} action tokens, got {token_counts.tolist()}."
|
||||
)
|
||||
if action_mask.ne(action_mask[:1]).any():
|
||||
raise ValueError(
|
||||
"Batch inference expects all samples to share the same action token mask after left padding."
|
||||
)
|
||||
act_start = int(action_mask[0].to(torch.int64).argmax().item())
|
||||
act_end = act_start + self.config.chunk_size
|
||||
if not torch.all(action_mask[:, act_start:act_end]):
|
||||
raise ValueError("Action tokens must form a contiguous chunk of length chunk_size.")
|
||||
act_slice = slice(act_start, act_end)
|
||||
|
||||
# 2. Encode the fixed prefix once and cache its KV state.
|
||||
batch_size = input_ids.shape[0]
|
||||
device = inputs_embeds.device
|
||||
attention_mask = attention_mask.to(device)
|
||||
mm_token_type_ids = mm_token_type_ids.to(device)
|
||||
position_ids, _ = self.vlm_backbone.model.get_rope_index(
|
||||
input_ids,
|
||||
image_grid_thw=image_grid_thw,
|
||||
attention_mask=attention_mask,
|
||||
mm_token_type_ids=mm_token_type_ids,
|
||||
)
|
||||
position_ids = position_ids.to(device)
|
||||
|
||||
outputs = self.vlm_backbone.model(
|
||||
input_ids=input_ids[:, :act_start],
|
||||
attention_mask=attention_mask[:, :act_start],
|
||||
position_ids=position_ids[..., :act_start],
|
||||
inputs_embeds=inputs_embeds[:, :act_start],
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thw=image_grid_thw,
|
||||
mm_token_type_ids=mm_token_type_ids[:, :act_start],
|
||||
use_cache=True,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
x_t = self.sample_noise(
|
||||
(batch_size, chunk_size, self.config.max_action_dim),
|
||||
device,
|
||||
).to(dtype=self.action_in_proj.weight.dtype)
|
||||
dt = -1.0 / self.config.num_denoise_steps
|
||||
past_key_values = outputs.past_key_values
|
||||
|
||||
# 3. Denoise only the action chunk while keeping the prefix cache invariant.
|
||||
for step in range(self.config.num_denoise_steps):
|
||||
time = torch.full(
|
||||
(batch_size,),
|
||||
1.0 + step * dt,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
action_time_embs = self.embed_suffix(time, x_t)
|
||||
inputs_embeds[:, act_slice] = action_time_embs.to(inputs_embeds.dtype)
|
||||
|
||||
# Keep the prefix KV cache invariant across denoising steps.
|
||||
past_key_values.crop(act_start)
|
||||
outputs = self.vlm_backbone.model(
|
||||
attention_mask=attention_mask[:, :act_end],
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds[:, act_slice],
|
||||
position_ids=position_ids[..., act_slice],
|
||||
use_cache=True,
|
||||
return_dict=True,
|
||||
)
|
||||
with self.flow_head_autocast_context():
|
||||
hidden_states = outputs.last_hidden_state[:, :chunk_size]
|
||||
hidden_states = hidden_states.to(dtype=self.action_out_proj.dtype)
|
||||
v_t = self.action_out_proj(hidden_states)
|
||||
|
||||
x_t += dt * v_t.reshape(x_t.shape)
|
||||
|
||||
return x_t
|
||||
@@ -0,0 +1,282 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.eo1.configuration_eo1 import EO1Config
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.types import TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_STATE,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
|
||||
else:
|
||||
Qwen2_5_VLProcessor = None
|
||||
|
||||
SYSTEM_MESSAGE = "You are a helpful physical assistant."
|
||||
|
||||
# EO-1 special tokens
|
||||
ACTION_START_TOKEN = "<|action_start|>" # nosec B105
|
||||
DEFAULT_ACTION_TOKEN = "<|action_pad|>" # nosec B105
|
||||
ACTION_END_TOKEN = "<|action_end|>" # nosec B105
|
||||
STATE_START_TOKEN = "<|state_start|>" # nosec B105
|
||||
DEFAULT_STATE_TOKEN = "<|state_pad|>" # nosec B105
|
||||
STATE_END_TOKEN = "<|state_end|>" # nosec B105
|
||||
TASK_VLA_TOKEN = "<|vla|>" # nosec B105
|
||||
|
||||
EO1_SPECIAL_TOKENS = [
|
||||
ACTION_START_TOKEN,
|
||||
DEFAULT_ACTION_TOKEN,
|
||||
ACTION_END_TOKEN,
|
||||
STATE_START_TOKEN,
|
||||
DEFAULT_STATE_TOKEN,
|
||||
STATE_END_TOKEN,
|
||||
TASK_VLA_TOKEN,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="eo1_conversation_template_processor")
|
||||
class EO1ConversationTemplateStep(ComplementaryDataProcessorStep):
|
||||
input_features: dict[str, PolicyFeature] | dict[str, dict[str, Any]]
|
||||
chunk_size: int
|
||||
|
||||
_image_keys: list[str] = field(default_factory=list, init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
# Robust JSON deserialization handling (guard empty maps).
|
||||
if self.input_features:
|
||||
first_val = next(iter(self.input_features.values()))
|
||||
if isinstance(first_val, dict):
|
||||
reconstructed = {}
|
||||
for key, ft_dict in self.input_features.items():
|
||||
reconstructed[key] = PolicyFeature(
|
||||
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
|
||||
)
|
||||
self.input_features = reconstructed
|
||||
|
||||
self._image_keys = [
|
||||
key for key, value in self.input_features.items() if value.type == FeatureType.VISUAL
|
||||
]
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
tasks = complementary_data.get("task")
|
||||
if tasks is None:
|
||||
raise ValueError("Task is required for EO1ConversationTemplateStep.")
|
||||
|
||||
observation = self.transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None:
|
||||
raise ValueError("Observation is required for EO1ConversationTemplateStep.")
|
||||
|
||||
if OBS_STATE in observation and observation[OBS_STATE].shape[0] != len(tasks):
|
||||
raise ValueError("Batch size mismatch between observation.state and task list.")
|
||||
|
||||
# LeRobot visual observations reach in processor as float32 tensors in [0, 1].
|
||||
# Convert to uint8 in [0, 255] to meet the input requirement of Qwen2.5-VL-3B-Instruct.
|
||||
images = {
|
||||
key: observation[key].clamp(0, 1).mul(255.0).round().to(torch.uint8) for key in self._image_keys
|
||||
}
|
||||
messages = []
|
||||
for i in range(len(tasks)):
|
||||
content = [
|
||||
*[{"type": "image", "image": images[key][i]} for key in self._image_keys],
|
||||
{
|
||||
"type": "text",
|
||||
"text": (
|
||||
f"{STATE_START_TOKEN}{DEFAULT_STATE_TOKEN}{STATE_END_TOKEN}{tasks[i]}{TASK_VLA_TOKEN}"
|
||||
),
|
||||
},
|
||||
]
|
||||
messages.append(
|
||||
[
|
||||
{"role": "system", "content": [{"type": "text", "text": SYSTEM_MESSAGE}]},
|
||||
{"role": "user", "content": content},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"{ACTION_START_TOKEN}{DEFAULT_ACTION_TOKEN * self.chunk_size}{ACTION_END_TOKEN}",
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
complementary_data["messages"] = messages
|
||||
|
||||
return complementary_data
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
This step only materializes EO1-specific message objects in complementary_data.
|
||||
PipelineFeatureType tracks only ACTION and OBSERVATION, so there is no static
|
||||
feature contract change to record here.
|
||||
"""
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"input_features": {
|
||||
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.input_features.items()
|
||||
},
|
||||
"chunk_size": self.chunk_size,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="eo1_qwen_processor")
|
||||
class EO1QwenProcessorStep(ComplementaryDataProcessorStep):
|
||||
processor_name: str = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
image_min_pixels: int | None = 64 * 28 * 28
|
||||
image_max_pixels: int | None = 128 * 28 * 28
|
||||
use_fast_processor: bool = False
|
||||
|
||||
_processor: Qwen2_5_VLProcessor | None = field(default=None, init=False, repr=False)
|
||||
_state_token_id: int | None = field(default=None, init=False, repr=False)
|
||||
_action_token_id: int | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
require_package("transformers", extra="eo1")
|
||||
self._processor = Qwen2_5_VLProcessor.from_pretrained(
|
||||
self.processor_name,
|
||||
use_fast=self.use_fast_processor,
|
||||
)
|
||||
self._processor.tokenizer.add_tokens(EO1_SPECIAL_TOKENS, special_tokens=True)
|
||||
self._state_token_id = self._processor.tokenizer.convert_tokens_to_ids(DEFAULT_STATE_TOKEN)
|
||||
self._action_token_id = self._processor.tokenizer.convert_tokens_to_ids(DEFAULT_ACTION_TOKEN)
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
messages = complementary_data.pop("messages", None)
|
||||
if messages is None:
|
||||
raise ValueError("Messages are required for EO1QwenProcessorStep.")
|
||||
|
||||
# Rollout batches use left padding so action spans stay aligned across samples.
|
||||
# Supervised batches use right padding to match standard training collation.
|
||||
padding_side = "right" if self.transition.get(TransitionKey.ACTION) is not None else "left"
|
||||
|
||||
inputs = self._processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
padding=True,
|
||||
padding_side=padding_side,
|
||||
min_pixels=self.image_min_pixels,
|
||||
max_pixels=self.image_max_pixels,
|
||||
add_generation_prompt=False,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
complementary_data["input_ids"] = inputs["input_ids"]
|
||||
complementary_data["pixel_values"] = inputs["pixel_values"]
|
||||
complementary_data["image_grid_thw"] = inputs["image_grid_thw"]
|
||||
complementary_data["attention_mask"] = inputs["attention_mask"]
|
||||
complementary_data["mm_token_type_ids"] = inputs["mm_token_type_ids"]
|
||||
complementary_data["state_token_id"] = self._state_token_id
|
||||
complementary_data["action_token_id"] = self._action_token_id
|
||||
|
||||
return complementary_data
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"processor_name": self.processor_name,
|
||||
"image_min_pixels": self.image_min_pixels,
|
||||
"image_max_pixels": self.image_max_pixels,
|
||||
"use_fast_processor": self.use_fast_processor,
|
||||
}
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
This step only converts the messages to the model input format.
|
||||
"""
|
||||
return features
|
||||
|
||||
|
||||
def make_eo1_pre_post_processors(
|
||||
config: EO1Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Build pre/post processor pipelines for EO1."""
|
||||
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
EO1ConversationTemplateStep(input_features=config.input_features, chunk_size=config.chunk_size),
|
||||
EO1QwenProcessorStep(
|
||||
processor_name=config.vlm_base,
|
||||
image_min_pixels=config.image_min_pixels,
|
||||
image_max_pixels=config.image_max_pixels,
|
||||
use_fast_processor=config.use_fast_processor,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
output_steps: list[ProcessorStep] = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -46,6 +46,7 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
|
||||
|
||||
from .act.configuration_act import ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config
|
||||
from .groot.configuration_groot import GrootConfig
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config
|
||||
@@ -146,6 +147,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .wall_x.modeling_wall_x import WallXPolicy
|
||||
|
||||
return WallXPolicy
|
||||
elif name == "eo1":
|
||||
from .eo1.modeling_eo1 import EO1Policy
|
||||
|
||||
return EO1Policy
|
||||
else:
|
||||
try:
|
||||
return _get_policy_cls_from_policy_name(name=name)
|
||||
@@ -196,6 +201,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return XVLAConfig(**kwargs)
|
||||
elif policy_type == "wall_x":
|
||||
return WallXConfig(**kwargs)
|
||||
elif policy_type == "eo1":
|
||||
return EO1Config(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||
@@ -399,6 +406,13 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
elif isinstance(policy_cfg, EO1Config):
|
||||
from .eo1.processor_eo1 import make_eo1_pre_post_processors
|
||||
|
||||
processors = make_eo1_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
|
||||
@@ -97,8 +97,8 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
vision_backbone: str = "resnet18"
|
||||
crop_shape: tuple[int, int] | None = (84, 84)
|
||||
crop_is_random: bool = True
|
||||
pretrained_backbone_weights: str | None = None
|
||||
use_group_norm: bool = True
|
||||
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
|
||||
use_group_norm: bool = False
|
||||
spatial_softmax_num_keypoints: int = 32
|
||||
# VQ-VAE
|
||||
n_vqvae_training_steps: int = 20000
|
||||
|
||||
@@ -54,6 +54,7 @@ class BiOpenArmFollower(Robot):
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.left_arm_config.port,
|
||||
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
||||
use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque,
|
||||
max_relative_target=config.left_arm_config.max_relative_target,
|
||||
cameras=left_cameras,
|
||||
side=config.left_arm_config.side,
|
||||
@@ -72,6 +73,7 @@ class BiOpenArmFollower(Robot):
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.right_arm_config.port,
|
||||
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
|
||||
use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque,
|
||||
max_relative_target=config.right_arm_config.max_relative_target,
|
||||
cameras=right_cameras,
|
||||
side=config.right_arm_config.side,
|
||||
|
||||
@@ -46,7 +46,7 @@ class LeKiwiConfig(RobotConfig):
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
|
||||
|
||||
# Set to `True` for backward compatibility with previous policies/dataset
|
||||
use_degrees: bool = False
|
||||
use_degrees: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -66,6 +66,10 @@ class OpenArmFollowerConfigBase:
|
||||
# Whether to disable torque when disconnecting
|
||||
disable_torque_on_disconnect: bool = True
|
||||
|
||||
# When True, expose `.vel` and `.torque` per motor in observation features.
|
||||
# Default False for compatibility with the position-only openarm_mini teleoperator.
|
||||
use_velocity_and_torque: bool = False
|
||||
|
||||
# Safety limit for relative target positions
|
||||
# Set to a positive scalar for all motors, or a dict mapping motor names to limits
|
||||
max_relative_target: float | dict[str, float] | None = None
|
||||
|
||||
@@ -93,8 +93,9 @@ class OpenArmFollower(Robot):
|
||||
features: dict[str, type] = {}
|
||||
for motor in self.bus.motors:
|
||||
features[f"{motor}.pos"] = float
|
||||
features[f"{motor}.vel"] = float # Add this
|
||||
features[f"{motor}.torque"] = float # Add this
|
||||
if self.config.use_velocity_and_torque:
|
||||
features[f"{motor}.vel"] = float
|
||||
features[f"{motor}.torque"] = float
|
||||
return features
|
||||
|
||||
@property
|
||||
@@ -235,8 +236,9 @@ class OpenArmFollower(Robot):
|
||||
for motor in self.bus.motors:
|
||||
state = states.get(motor, {})
|
||||
obs_dict[f"{motor}.pos"] = state.get("position", 0.0)
|
||||
obs_dict[f"{motor}.vel"] = state.get("velocity", 0.0)
|
||||
obs_dict[f"{motor}.torque"] = state.get("torque", 0.0)
|
||||
if self.config.use_velocity_and_torque:
|
||||
obs_dict[f"{motor}.vel"] = state.get("velocity", 0.0)
|
||||
obs_dict[f"{motor}.torque"] = state.get("torque", 0.0)
|
||||
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
|
||||
@@ -49,6 +49,7 @@ class BiOpenArmLeader(Teleoperator):
|
||||
can_data_bitrate=config.left_arm_config.can_data_bitrate,
|
||||
motor_config=config.left_arm_config.motor_config,
|
||||
manual_control=config.left_arm_config.manual_control,
|
||||
use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque,
|
||||
position_kd=config.left_arm_config.position_kd,
|
||||
position_kp=config.left_arm_config.position_kp,
|
||||
)
|
||||
@@ -63,6 +64,7 @@ class BiOpenArmLeader(Teleoperator):
|
||||
can_data_bitrate=config.right_arm_config.can_data_bitrate,
|
||||
motor_config=config.right_arm_config.motor_config,
|
||||
manual_control=config.right_arm_config.manual_control,
|
||||
use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque,
|
||||
position_kd=config.right_arm_config.position_kd,
|
||||
position_kp=config.right_arm_config.position_kp,
|
||||
)
|
||||
|
||||
@@ -60,6 +60,10 @@ class OpenArmLeaderConfigBase:
|
||||
# When enabled, motors have torque disabled for manual movement
|
||||
manual_control: bool = True
|
||||
|
||||
# When True, expose `.vel` and `.torque` per motor in action features.
|
||||
# Default False for compatibility with the position-only openarm_mini teleoperator.
|
||||
use_velocity_and_torque: bool = False
|
||||
|
||||
# TODO(Steven, Pepijn): Not used ... ?
|
||||
# MIT control parameters (used when manual_control=False for torque control)
|
||||
# List of 8 values: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper]
|
||||
|
||||
@@ -70,8 +70,9 @@ class OpenArmLeader(Teleoperator):
|
||||
features: dict[str, type] = {}
|
||||
for motor in self.bus.motors:
|
||||
features[f"{motor}.pos"] = float
|
||||
features[f"{motor}.vel"] = float
|
||||
features[f"{motor}.torque"] = float
|
||||
if self.config.use_velocity_and_torque:
|
||||
features[f"{motor}.vel"] = float
|
||||
features[f"{motor}.torque"] = float
|
||||
return features
|
||||
|
||||
@property
|
||||
@@ -201,8 +202,9 @@ class OpenArmLeader(Teleoperator):
|
||||
for motor in self.bus.motors:
|
||||
state = states.get(motor, {})
|
||||
action_dict[f"{motor}.pos"] = state.get("position")
|
||||
action_dict[f"{motor}.vel"] = state.get("velocity")
|
||||
action_dict[f"{motor}.torque"] = state.get("torque")
|
||||
if self.config.use_velocity_and_torque:
|
||||
action_dict[f"{motor}.vel"] = state.get("velocity")
|
||||
action_dict[f"{motor}.torque"] = state.get("torque")
|
||||
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
|
||||
|
||||
Reference in New Issue
Block a user