fix pre-commit errors

This commit is contained in:
Geoffrey19
2025-12-17 20:27:14 +08:00
committed by Michel Aractingi
parent 9ce6dd9e25
commit a0c9a7d85d
8 changed files with 577 additions and 955 deletions
+2 -2
View File
@@ -7,11 +7,12 @@ This repository contains the Hugging Face port of **WALL-OSS**, a Vision-Languag
## Model Overview
| Feature | Description |
| -------------------- | ------------------------------------------------------------------------ |
| ------------------ | ----------------------------------------------------- | --- |
| Base Model | Qwen2.5-VL (Vision-Language Model) |
| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) |
| Architecture | Mixture of Experts (MoE) with action-specific routing | |
| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception |
---
## Citation
@@ -32,4 +33,3 @@ If you use this work, please cite:
## License
This port follows the **Apache 2.0 License**.
@@ -85,9 +85,7 @@ class WallXConfig(PreTrainedConfig):
)
if self.prediction_mode not in ["diffusion", "fast"]:
raise ValueError(
f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
)
raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}")
# Assign use_fast_tokenizer based on prediction_mode
if self.prediction_mode == "fast":
@@ -96,9 +94,7 @@ class WallXConfig(PreTrainedConfig):
self.use_fast_tokenizer = False
self.action_tokenizer_path = None # disable action tokenizer for diffusion mode
else:
raise ValueError(
f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
)
raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}")
def validate_features(self) -> None:
"""Validate and set up input/output features."""
+194 -318
View File
@@ -34,61 +34,57 @@ lerobot-train \
```
"""
import math
from os import PathLike
from collections import deque
from typing import Any, Dict, List, Optional, Tuple, Union
from PIL import Image
from os import PathLike
from typing import Any
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model
from PIL import Image
from qwen_vl_utils.vision_process import smart_resize
from torch import Tensor
from torch.distributions import Beta
from torch.nn import CrossEntropyLoss
from torchdiffeq import odeint
from transformers import AutoProcessor
from transformers import AutoProcessor, BatchFeature
from transformers.cache_utils import (
StaticCache,
)
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration,
)
from transformers.utils import is_torchdynamo_compiling, logging
from transformers import AutoProcessor, BatchFeature
from qwen_vl_utils.vision_process import smart_resize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import populate_queues
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.policies.wall_x.utils import (
replace_action_token,
preprocesser_call,
get_wallx_normal_text,
process_grounding_points,
)
from lerobot.policies.wall_x.constant import (
MODEL_TYPE,
TOKENIZER_MAX_LENGTH,
PRIORITY_ORDER,
GENERATE_SUBTASK_RATIO,
RESOLUTION,
IMAGE_FACTOR,
MAX_PIXELS,
MIN_PIXELS,
IMAGE_FACTOR,
MODEL_TYPE,
PRIORITY_ORDER,
RESOLUTION,
TOKENIZER_MAX_LENGTH,
)
from lerobot.policies.wall_x.qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration,
)
from lerobot.policies.wall_x.qwen_model.qwen2_5_vl_moe import (
Qwen2_5_VisionTransformerPretrainedModel,
Qwen2_5_VLACausalLMOutputWithPast,
Qwen2_5_VLMoEModel,
)
from lerobot.policies.wall_x.utils import (
get_wallx_normal_text,
preprocesser_call,
process_grounding_points,
replace_action_token,
)
from lerobot.utils.constants import ACTION, OBS_STATE
logger = logging.get_logger(__name__)
@@ -151,7 +147,7 @@ class ActionHead(nn.Module):
"""Sample timesteps using Beta distribution (always in float32 for numerical stability)."""
beta_dist = Beta(
torch.tensor(self.beta_alpha, dtype=torch.float32, device=device),
torch.tensor(self.beta_beta, dtype=torch.float32, device=device)
torch.tensor(self.beta_beta, dtype=torch.float32, device=device),
)
sample = beta_dist.sample([batch_size])
time = (1 - sample) * self.s
@@ -275,14 +271,14 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
pretrained_name_or_path,
config=None,
action_tokenizer_path=None,
attn_implementation: str = 'eager',
attn_implementation: str = "eager",
cache_dir: str | PathLike | None = None,
force_download: bool = False,
local_files_only: bool = False,
token: str | bool | None = None,
revision: str = "main",
strict: bool = False,
**kwargs: Any
**kwargs: Any,
):
"""
Load model from pretrained model path.
@@ -312,9 +308,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
config._attn_implementation = attn_implementation
processor = AutoProcessor.from_pretrained(pretrained_name_or_path, use_fast=True)
if action_tokenizer_path is not None:
action_tokenizer = AutoProcessor.from_pretrained(
action_tokenizer_path, trust_remote_code=True
)
action_tokenizer = AutoProcessor.from_pretrained(action_tokenizer_path, trust_remote_code=True)
processor.action_processor = action_tokenizer
else:
action_tokenizer = None
@@ -387,9 +381,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
super().__init__(config)
# Initialize vision transformer and language model components
self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(
config.vision_config
)
self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)
self.model = Qwen2_5_VLMoEModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
@@ -446,12 +438,8 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
# Create list of fast action token IDs
fast_action_token_list = []
if self.use_fast_tokenizer:
for i in range(
self.processor.tokenizer.init_kwargs["action_token_vocab_size"]
):
action_token_id = self.processor.tokenizer.convert_tokens_to_ids(
f"<|action_token_{i}|>"
)
for i in range(self.processor.tokenizer.init_kwargs["action_token_vocab_size"]):
action_token_id = self.processor.tokenizer.convert_tokens_to_ids(f"<|action_token_{i}|>")
fast_action_token_list.append(action_token_id)
# Get special action token IDs
@@ -465,9 +453,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
"action_token_id": action_token_id,
}
def add_lora(
self, r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.1
):
def add_lora(self, r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.1):
"""
Add LoRA (Low-Rank Adaptation) adapters to the model.
@@ -516,12 +502,12 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
def get_rope_index(
self,
input_ids: Optional[torch.LongTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
input_ids: torch.LongTensor | None = None,
image_grid_thw: torch.LongTensor | None = None,
video_grid_thw: torch.LongTensor | None = None,
second_per_grid_ts: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Calculate 3D RoPE (Rotary Position Embedding) indices for vision and text tokens.
@@ -555,9 +541,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
vision_start_token_id = self.config.vision_start_token_id
mrope_position_deltas = []
if input_ids is not None and (
image_grid_thw is not None or video_grid_thw is not None
):
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
total_input_ids = input_ids
if attention_mask is None:
attention_mask = torch.ones_like(total_input_ids)
@@ -580,9 +564,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
image_nums, video_nums = 0, 0
# Find vision tokens and count images/videos
vision_start_indices = torch.argwhere(
input_ids == vision_start_token_id
).squeeze(1)
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
vision_tokens = input_ids[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
@@ -641,14 +623,8 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
text_len = ed - st
# Add position IDs for text tokens before vision token
st_idx = (
llm_pos_ids_list[-1].max() + 1
if len(llm_pos_ids_list) > 0
else 0
)
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
# Calculate 3D position embeddings for vision tokens
range_tensor = torch.arange(llm_grid_t).view(-1, 1)
@@ -656,71 +632,43 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
# Calculate temporal position IDs with time scaling
time_tensor = (
expanded_range
* second_per_grid_t
* self.config.vision_config.tokens_per_second
expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second
)
time_tensor_long = time_tensor.long()
t_index = time_tensor_long.flatten()
# Calculate spatial position IDs
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
)
# Add 3D position IDs for vision tokens
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
)
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
# Add position IDs for remaining text tokens
if st < len(input_tokens):
st_idx = (
llm_pos_ids_list[-1].max() + 1
if len(llm_pos_ids_list) > 0
else 0
)
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
# Concatenate all position IDs for this sequence
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(
position_ids.device
)
mrope_position_deltas.append(
llm_positions.max() + 1 - len(total_input_ids[i])
)
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
mrope_position_deltas = torch.tensor(
mrope_position_deltas, device=input_ids.device
).unsqueeze(1)
mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
return position_ids, mrope_position_deltas
else:
# Handle case without vision tokens - use standard 1D position embeddings
if attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = (
position_ids.unsqueeze(0)
.expand(3, -1, -1)
.to(attention_mask.device)
)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(
-1, keepdim=True
)[0]
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
else:
position_ids = (
@@ -739,33 +687,29 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
def train_step_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
moe_token_types: Optional[
torch.LongTensor
] = None, # MoE token type assignments
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
action_chunk: Optional[torch.FloatTensor] = None, # Action trajectory chunks
proprioception: Optional[
torch.FloatTensor
] = None, # Joint position/orientation data
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
dof_mask: Optional[torch.FloatTensor] = None,
agent_pos_mask: Optional[torch.FloatTensor] = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | None = None,
inputs_embeds: torch.FloatTensor | None = None,
moe_token_types: torch.LongTensor | None = None, # MoE token type assignments
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
pixel_values: torch.Tensor | None = None,
pixel_values_videos: torch.FloatTensor | None = None,
image_grid_thw: torch.LongTensor | None = None,
video_grid_thw: torch.LongTensor | None = None,
action_chunk: torch.FloatTensor | None = None, # Action trajectory chunks
proprioception: torch.FloatTensor | None = None, # Joint position/orientation data
rope_deltas: torch.LongTensor | None = None,
cache_position: torch.LongTensor | None = None,
second_per_grid_ts: torch.Tensor | None = None,
dof_mask: torch.FloatTensor | None = None,
agent_pos_mask: torch.FloatTensor | None = None,
**kwargs,
) -> Union[Tuple, Qwen2_5_VLACausalLMOutputWithPast]:
) -> tuple | Qwen2_5_VLACausalLMOutputWithPast:
"""
Forward pass for training with multi-modal inputs including vision, text, and action data.
@@ -806,24 +750,16 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
# Set output configuration from model config if not specified
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
output_attentions if output_attentions is not None else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Calculate RoPE position IDs if not provided
# Note: Cannot calculate rope deltas with 4D attention mask. TODO: Fix this limitation
if position_ids is None and (
attention_mask is None or attention_mask.ndim == 2
):
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
# Calculate RoPE index once per generation in the pre-fill stage only
if (
(cache_position is not None and cache_position[0] == 0)
@@ -865,9 +801,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
image_mask = mask_expanded.to(inputs_embeds.device)
image_embeds = image_embeds.to(
inputs_embeds.device, inputs_embeds.dtype
)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
# Process video embeddings
@@ -887,19 +821,13 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
video_mask = mask_expanded.to(inputs_embeds.device)
video_embeds = video_embeds.to(
inputs_embeds.device, inputs_embeds.dtype
)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
# Process proprioceptive data (joint positions, orientations, etc.)
if proprioception is not None:
proprioception = proprioception.to(inputs_embeds.device).to(
inputs_embeds.dtype
)
agent_pos_mask = agent_pos_mask.to(inputs_embeds.device).to(
inputs_embeds.dtype
)
proprioception = proprioception.to(inputs_embeds.device).to(inputs_embeds.dtype)
agent_pos_mask = agent_pos_mask.to(inputs_embeds.device).to(inputs_embeds.dtype)
proprioception = self.action_preprocessor.proprioception_proj(
proprioception,
agent_pos_mask,
@@ -910,12 +838,8 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
proprioception_mask = mask_expanded.to(inputs_embeds.device)
proprioception = proprioception.to(
inputs_embeds.device, inputs_embeds.dtype
)
inputs_embeds = inputs_embeds.masked_scatter(
proprioception_mask, proprioception
)
proprioception = proprioception.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(proprioception_mask, proprioception)
elif self.training:
# Dummy forward pass to ensure gradient registration in DDP
# This handles cases where one process has proprioception data while another doesn't
@@ -925,32 +849,22 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
self.action_preprocessor.propri_dim * 2,
device=inputs_embeds.device,
)
dummy_forward = self.action_preprocessor.proprioception_proj(
dummy_input
)
dummy_forward = self.action_preprocessor.proprioception_proj(dummy_input)
dummy_loss = sum(p.sum() for p in dummy_forward)
inputs_embeds = inputs_embeds + 0 * dummy_loss
# Process action chunk data
if action_chunk is not None:
action_chunk = action_chunk.to(inputs_embeds.device).to(
inputs_embeds.dtype
)
action_chunk = action_chunk.to(inputs_embeds.device).to(inputs_embeds.dtype)
dof_mask = dof_mask.to(inputs_embeds.device).to(inputs_embeds.dtype)
noisy_action_emb, flow = self.action_preprocessor(
action_chunk, dof_mask
)
noisy_action_emb, flow = self.action_preprocessor(action_chunk, dof_mask)
mask = input_ids == self.action_token_id_set["action_token_id"]
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
action_mask = mask_expanded.to(inputs_embeds.device)
noisy_action_emb = noisy_action_emb.to(
inputs_embeds.device, inputs_embeds.dtype
)
inputs_embeds = inputs_embeds.masked_scatter(
action_mask, noisy_action_emb
)
noisy_action_emb = noisy_action_emb.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(action_mask, noisy_action_emb)
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
@@ -1011,18 +925,14 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
if action_mask.any():
action_hidden_states = hidden_states[action_mask].to(torch.float32)
flow = flow.reshape(-1, flow.shape[-1]).to(torch.float32)
_flow_loss = self.action_preprocessor.flow_loss(
action_hidden_states, flow, dof_mask
)
_flow_loss = self.action_preprocessor.flow_loss(action_hidden_states, flow, dof_mask)
if isinstance(_flow_loss, torch.Tensor):
flow_loss = _flow_loss.mean()
if loss is not None:
loss = loss + self.flow_loss_weight * flow_loss.to(torch.float32)
else:
loss = self.flow_loss_weight * flow_loss.to(torch.float32)
_flow_loss = _flow_loss.view(
dof_mask.shape[0], dof_mask.shape[1], dof_mask.shape[2]
)
_flow_loss = _flow_loss.view(dof_mask.shape[0], dof_mask.shape[1], dof_mask.shape[2])
# Return outputs based on return_dict setting
if not return_dict:
@@ -1031,9 +941,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
return Qwen2_5_VLACausalLMOutputWithPast(
loss=loss,
cross_entropy_loss=(
cross_entropy_loss.clone() if cross_entropy_loss is not None else None
),
cross_entropy_loss=(cross_entropy_loss.clone() if cross_entropy_loss is not None else None),
flow_loss=flow_loss,
logits=logits,
past_key_values=outputs.past_key_values,
@@ -1065,31 +973,31 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
def predict(
self,
predict_mode: str,
pred_horizon: Optional[int] = None,
action_dim: Optional[int] = None,
pred_horizon: int | None = None,
action_dim: int | None = None,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
moe_token_types: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
action_chunk: Optional[torch.FloatTensor] = None,
proprioception: Optional[torch.FloatTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
num_inference_timesteps: Optional[int] = 10,
dof_mask: Optional[torch.FloatTensor] = None,
agent_pos_mask: Optional[torch.FloatTensor] = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | None = None,
inputs_embeds: torch.FloatTensor | None = None,
moe_token_types: torch.LongTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
pixel_values: torch.Tensor | None = None,
pixel_values_videos: torch.FloatTensor | None = None,
image_grid_thw: torch.LongTensor | None = None,
video_grid_thw: torch.LongTensor | None = None,
action_chunk: torch.FloatTensor | None = None,
proprioception: torch.FloatTensor | None = None,
rope_deltas: torch.LongTensor | None = None,
cache_position: torch.LongTensor | None = None,
second_per_grid_ts: torch.Tensor | None = None,
num_inference_timesteps: int | None = 10,
dof_mask: torch.FloatTensor | None = None,
agent_pos_mask: torch.FloatTensor | None = None,
re_generate: bool = False,
**kwargs,
):
@@ -1139,30 +1047,20 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
- 'predict_output_text': Generated text (for text/fast modes)
- 'gt_output_text': Ground truth text (for text/fast modes)
"""
batch_size = (
input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
)
batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
# Text and fast modes require batch size 1 for autoregressive generation
if predict_mode in ["text", "fast"]:
assert (
batch_size == 1
), "predict only support batch size 1 for ar generation"
assert batch_size == 1, "predict only support batch size 1 for ar generation"
# Set output configuration from model config if not specified
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
output_attentions if output_attentions is not None else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Process input embeddings with multi-modal data
if inputs_embeds is None:
@@ -1186,9 +1084,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
image_mask = mask_expanded.to(inputs_embeds.device)
image_embeds = image_embeds.to(
inputs_embeds.device, inputs_embeds.dtype
)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
# Process video embeddings
@@ -1209,40 +1105,28 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
video_mask = mask_expanded.to(inputs_embeds.device)
video_embeds = video_embeds.to(
inputs_embeds.device, inputs_embeds.dtype
)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
# Process proprioceptive data
if proprioception is not None:
proprioception = proprioception.to(inputs_embeds.device).to(
inputs_embeds.dtype
)
agent_pos_mask = agent_pos_mask.to(inputs_embeds.device).to(
inputs_embeds.dtype
)
proprioception = proprioception.to(inputs_embeds.device).to(inputs_embeds.dtype)
agent_pos_mask = agent_pos_mask.to(inputs_embeds.device).to(inputs_embeds.dtype)
proprio_embed = self.action_preprocessor.proprioception_proj(
proprioception,
agent_pos_mask,
use_history=proprioception.shape[1] > 1,
)
proprioception_mask = (
input_ids == self.action_token_id_set["propri_token_id"]
)
proprioception_mask = input_ids == self.action_token_id_set["propri_token_id"]
proprio_embed = proprio_embed.to(torch.bfloat16)
inputs_embeds[proprioception_mask] = proprio_embed.reshape(
-1, inputs_embeds.shape[-1]
)
inputs_embeds[proprioception_mask] = proprio_embed.reshape(-1, inputs_embeds.shape[-1])
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
# Calculate RoPE position IDs if not provided
# Note: Cannot calculate rope deltas with 4D attention mask. TODO: Fix this limitation
if position_ids is None and (
attention_mask is None or attention_mask.ndim == 2
):
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
# Calculate RoPE index once per generation in the pre-fill stage only
if (
(cache_position is not None and cache_position[0] == 0)
@@ -1332,12 +1216,8 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
eos_token_id=[self.processor.tokenizer.eos_token_id],
use_cache=True,
pad_token_id=self.processor.tokenizer.pad_token_id,
temperature=(
1.0 if not re_generate else 0.7
), # Higher temperature for regeneration
do_sample=(
False if not re_generate else True
), # Enable sampling for regeneration
temperature=(1.0 if not re_generate else 0.7), # Higher temperature for regeneration
do_sample=(False if not re_generate else True), # Enable sampling for regeneration
)
# Decode generated and ground truth text
@@ -1359,15 +1239,9 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
action_id = []
# Extract action tokens from generated sequence
for token_id_i in predict_output_ids[0]:
if (
token_id_i.item()
>= self.processor.tokenizer.init_kwargs["action_token_start_index"]
):
if token_id_i.item() >= self.processor.tokenizer.init_kwargs["action_token_start_index"]:
action_id.append(
token_id_i.item()
- self.processor.tokenizer.init_kwargs[
"action_token_start_index"
]
token_id_i.item() - self.processor.tokenizer.init_kwargs["action_token_start_index"]
)
predict_action = self.processor.action_processor.decode(
@@ -1382,9 +1256,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
predict_action = torch.tensor(predict_action, device=self.device)
dof_mask = dof_mask.to(self.device).to(pixel_values.dtype)
# removed unnormalization step for now
predict_action = (
predict_action[:, :, dof_mask[0, 0, :].bool()]
)
predict_action = predict_action[:, :, dof_mask[0, 0, :].bool()]
output["predict_action"] = predict_action
# Process ground truth actions if available
@@ -1459,7 +1331,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
times = torch.linspace(
0,
1,
num_inference_timesteps+1,
num_inference_timesteps + 1,
device=inputs_embeds.device,
dtype=torch.float32,
)
@@ -1477,9 +1349,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
return output
def forward(
self, mode: Optional[str] = None, predict_mode: Optional[str] = "text", **kwargs
):
def forward(self, mode: str | None = None, predict_mode: str | None = "text", **kwargs):
"""
Main forward pass dispatcher for different execution modes.
@@ -1592,9 +1462,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
# Handle input slicing based on cache state and special cases
if past_key_values is not None:
if (
inputs_embeds is not None and input_ids.shape[1] == 0
): # Exception 4: input_embeds case
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4: input_embeds case
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
moe_token_types = moe_token_types[:, -cache_position.shape[0] :]
elif inputs_embeds is not None or ( # Exception 1: input_embeds provided
@@ -1602,9 +1470,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
): # Exception 3: GPU sync edge case
input_ids = input_ids[:, -cache_position.shape[0] :]
moe_token_types = moe_token_types[:, -cache_position.shape[0] :]
elif (
input_ids.shape[1] != cache_position.shape[0]
): # Default case (Exception 2 is no-op)
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (Exception 2 is no-op)
cache_pos = cache_position.clone()
input_ids = input_ids[:, cache_pos]
moe_token_types = moe_token_types[:, cache_pos]
@@ -1629,8 +1495,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
batch_size, sequence_length = input_ids.shape
device = input_ids.device
attention_mask = (
self.model._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_cache_shape(),
@@ -1641,7 +1506,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
config=self.config,
past_key_values=past_key_values,
)
)
# Assemble all model inputs for generation
model_inputs.update(
@@ -1666,8 +1530,8 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
def _get_image_nums_and_video_nums(
self,
input_ids: Optional[torch.LongTensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
input_ids: torch.LongTensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Get the number of images and videos for each sample to calculate tensor separation lengths.
@@ -1702,9 +1566,9 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
self,
expand_size: int = 1,
is_encoder_decoder: bool = False,
input_ids: Optional[torch.LongTensor] = None,
input_ids: torch.LongTensor | None = None,
**model_kwargs,
) -> Tuple[torch.LongTensor, Dict[str, Any]]:
) -> tuple[torch.LongTensor, dict[str, Any]]:
"""
Expand inputs for generation with support for multi-modal tensors.
@@ -1745,9 +1609,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
"""Split tensor by lengths and repeat each sample."""
samples = torch.split(x, lengths)
repeat_args = [repeat_times] + [1] * (x.dim() - 1)
result = torch.cat(
[sample.repeat(*repeat_args) for sample in samples], dim=0
)
result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
return result
for key in dict_to_expand:
@@ -1785,9 +1647,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
)
tensor = torch.tensor(dict_to_expand[key])
lengths = list(video_nums)
tensor = _repeat_interleave_samples(
tensor, lengths=lengths, repeat_times=expand_size
)
tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size)
dict_to_expand[key] = tensor.tolist()
return dict_to_expand
@@ -1800,9 +1660,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
and isinstance(dict_to_expand[key], torch.Tensor)
and key not in visual_keys
):
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(
expand_size, dim=0
)
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
return dict_to_expand
# Expand visual inputs only if input_ids is available for counting images/videos
@@ -1823,9 +1681,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
raise ValueError(
"If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined."
)
model_kwargs["encoder_outputs"] = _expand_dict_for_generation(
model_kwargs["encoder_outputs"]
)
model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
return input_ids, model_kwargs
@@ -1850,7 +1706,7 @@ class WallXPolicy(PreTrainedPolicy):
self.model = Qwen2_5_VLMoEForAction.from_pretrained(
pretrained_name_or_path=config.pretrained_name_or_path,
action_tokenizer_path=config.action_tokenizer_path,
attn_implementation=config.attn_implementation
attn_implementation=config.attn_implementation,
)
self.model.to(config.device)
self.model.to_bfloat16_for_selected_params()
@@ -1869,7 +1725,7 @@ class WallXPolicy(PreTrainedPolicy):
def preprocess_inputs(
self,
batch: Dict[str, Any],
batch: dict[str, Any],
) -> BatchFeature:
"""
Convert a batch of LeRobot dataset items to Wall-X model input format.
@@ -1950,12 +1806,10 @@ class WallXPolicy(PreTrainedPolicy):
)
text = process_grounding_points(
complete_text, orig_height, orig_width, resized_height, resized_width,
MODEL_TYPE
complete_text, orig_height, orig_width, resized_height, resized_width, MODEL_TYPE
)
all_texts.append(text)
# ==================== PROCESS AGENT POS ====================
agent_pos = batch[OBS_STATE] # (batch_size, state_dim)
if agent_pos.dim() == 2:
@@ -1965,17 +1819,28 @@ class WallXPolicy(PreTrainedPolicy):
if agent_pos.shape[-1] != 20:
pad_size = 20 - agent_pos.shape[-1]
agent_pos = torch.cat([
agent_pos = torch.cat(
[
agent_pos,
torch.zeros(agent_pos.shape[0], agent_pos.shape[1], pad_size, device=agent_pos.device)
], dim=-1)
agent_pos_mask = torch.cat([
torch.zeros(agent_pos.shape[0], agent_pos.shape[1], pad_size, device=agent_pos.device),
],
dim=-1,
)
agent_pos_mask = torch.cat(
[
agent_pos_mask,
torch.zeros(agent_pos_mask.shape[0], agent_pos_mask.shape[1], pad_size, device=agent_pos_mask.device)
], dim=-1)
torch.zeros(
agent_pos_mask.shape[0],
agent_pos_mask.shape[1],
pad_size,
device=agent_pos_mask.device,
),
],
dim=-1,
)
# ==================== PROCESS ACTIONS ====================
action = batch.get(ACTION, None) # (batch_size, chunk_size, action_dim)
action = batch.get(ACTION) # (batch_size, chunk_size, action_dim)
if action is not None:
if action.dim() == 2:
action = action.unsqueeze(1)
@@ -1984,20 +1849,30 @@ class WallXPolicy(PreTrainedPolicy):
if action.shape[-1] != 20:
pad_size = 20 - action.shape[-1]
action = torch.cat([
action,
torch.zeros(action.shape[0], action.shape[1], pad_size, device=action.device)
], dim=-1)
dof_mask = torch.cat([
action = torch.cat(
[action, torch.zeros(action.shape[0], action.shape[1], pad_size, device=action.device)],
dim=-1,
)
dof_mask = torch.cat(
[
dof_mask,
torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device)
], dim=-1)
torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device),
],
dim=-1,
)
else:
action_dim = self.config.output_features["action"].shape[0]
dof_mask = torch.cat([
torch.ones(batch_size, self.config.chunk_size, action_dim, device=batch[OBS_STATE].device),
torch.zeros(batch_size, self.config.chunk_size, 20 - action_dim, device=batch[OBS_STATE].device)
], dim=-1)
dof_mask = torch.cat(
[
torch.ones(
batch_size, self.config.chunk_size, action_dim, device=batch[OBS_STATE].device
),
torch.zeros(
batch_size, self.config.chunk_size, 20 - action_dim, device=batch[OBS_STATE].device
),
],
dim=-1,
)
# ==================== ACTION TOKEN REPLACEMENT ====================
all_texts = replace_action_token(
@@ -2028,7 +1903,11 @@ class WallXPolicy(PreTrainedPolicy):
inputs["action_chunk"] = action
inputs["dof_mask"] = dof_mask
inputs["moe_token_types"] = moe_token_types
inputs["frame_index"] = batch["frame_index"] if "frame_index" in batch else torch.zeros(batch_size, device=batch[OBS_STATE].device)
inputs["frame_index"] = (
batch["frame_index"]
if "frame_index" in batch
else torch.zeros(batch_size, device=batch[OBS_STATE].device)
)
# Move all tensors to the correct device
device = self.config.device
@@ -2056,10 +1935,7 @@ class WallXPolicy(PreTrainedPolicy):
)
# Call the underlying model's forward with mode="train"
outputs = self.model(
**batch,
mode="train"
)
outputs = self.model(**batch, mode="train")
# Extract losses from output
loss = outputs.loss
@@ -2096,7 +1972,7 @@ class WallXPolicy(PreTrainedPolicy):
action_dim=self.config.max_action_dim,
pred_horizon=self.config.chunk_size,
mode="predict",
predict_mode="diffusion"
predict_mode="diffusion",
)
elif self.config.prediction_mode == "fast":
output = self.model(
@@ -2104,7 +1980,7 @@ class WallXPolicy(PreTrainedPolicy):
action_dim=self.config.output_features["action"].shape[0],
pred_horizon=self.config.chunk_size,
mode="predict",
predict_mode="fast"
predict_mode="fast",
)
else:
raise NotImplementedError(f"Prediction mode {self.config.prediction_mode} not implemented")
@@ -2127,6 +2003,6 @@ class WallXPolicy(PreTrainedPolicy):
# Use action queue
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1)[:self.config.n_action_steps])
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
return self._queues[ACTION].popleft()
@@ -33,6 +33,8 @@ from lerobot.processor import (
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
def make_wall_x_pre_post_processors(
config: WallXConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
@@ -75,9 +77,7 @@ def make_wall_x_pre_post_processors(
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=dataset_stats
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
DeviceProcessorStep(device="cpu"),
]
@@ -123,9 +123,7 @@ class WallXTaskProcessor(ComplementaryDataProcessorStep):
new_complementary_data["task"] = f"{task}."
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
# List of strings: format each
new_complementary_data["task"] = [
t if t.endswith(".") else f"{t}." for t in task
]
new_complementary_data["task"] = [t if t.endswith(".") else f"{t}." for t in task]
return new_complementary_data
File diff suppressed because it is too large Load Diff
+49 -80
View File
@@ -25,7 +25,7 @@ import random
import re
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any
import torch
from transformers import BatchFeature
@@ -46,11 +46,11 @@ class X2RDataProcessingConfig:
"""
# Action prediction configuration
predict_action_keys: List[str] = field(default_factory=list)
obs_action_keys: List[str] = field(default_factory=list)
predict_action_keys: list[str] = field(default_factory=list)
obs_action_keys: list[str] = field(default_factory=list)
# Image resolution settings for different views
resolution: Dict[str, int] = field(
resolution: dict[str, int] = field(
default_factory=lambda: {
"face_view": -1,
"left_wrist_view": 128,
@@ -63,7 +63,7 @@ class X2RDataProcessingConfig:
split_seed: int = 42
# Instruction handling
priority_order: Optional[Dict[str, float]] = None
priority_order: dict[str, float] | None = None
# Vision model parameters
model_type: str = "qwen2_5"
@@ -77,11 +77,9 @@ class X2RDataProcessingConfig:
"""Post-initialization validation and setup."""
# Validate train/test split
if not 0 < self.train_test_split < 1:
raise ValueError(
f"train_test_split must be between 0 and 1, got {self.train_test_split}"
)
raise ValueError(f"train_test_split must be between 0 and 1, got {self.train_test_split}")
def as_dict(self) -> Dict:
def as_dict(self) -> dict:
"""Convert configuration to dictionary format.
Returns:
@@ -105,14 +103,15 @@ class X2RDataProcessingConfig:
raise ValueError(f"Unknown configuration parameter: {key}")
return self
def preprocesser_call(
processor,
images: Optional[Union[List, Any]] = None,
text: Optional[Union[str, List[str]]] = None,
videos: Optional[Union[List, Any]] = None,
padding: Union[bool, str] = False,
truncation: Optional[bool] = None,
max_length: Optional[int] = None,
images: list | Any | None = None,
text: str | list[str] | None = None,
videos: list | Any | None = None,
padding: bool | str = False,
truncation: bool | None = None,
max_length: int | None = None,
return_tensors: str = "pt",
) -> BatchFeature:
"""Unified preprocessing function for Wall-X model handling text, image and video inputs.
@@ -145,9 +144,7 @@ def preprocesser_call(
"""
# Process image inputs
if images is not None and len(images) > 0:
image_inputs = processor.image_processor(
images=images, videos=None, return_tensors=return_tensors
)
image_inputs = processor.image_processor(images=images, videos=None, return_tensors=return_tensors)
image_grid_thw = image_inputs["image_grid_thw"]
else:
image_inputs = {}
@@ -155,9 +152,7 @@ def preprocesser_call(
# Process video inputs
if videos is not None:
videos_inputs = processor.image_processor(
images=None, videos=videos, return_tensors=return_tensors
)
videos_inputs = processor.image_processor(images=None, videos=videos, return_tensors=return_tensors)
video_grid_thw = videos_inputs["video_grid_thw"]
else:
videos_inputs = {}
@@ -183,9 +178,7 @@ def preprocesser_call(
break
# Replace image placeholder with actual token count
token_count = image_grid_thw[index].prod() // merge_length
text[i] = text[i].replace(
"<|image_pad|>", "<|placeholder|>" * token_count, 1
)
text[i] = text[i].replace("<|image_pad|>", "<|placeholder|>" * token_count, 1)
index += 1
text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")
@@ -197,9 +190,7 @@ def preprocesser_call(
while "<|video_pad|>" in text[i]:
# Replace video placeholder with actual token count
token_count = video_grid_thw[index].prod() // merge_length
text[i] = text[i].replace(
"<|video_pad|>", "<|placeholder|>" * token_count, 1
)
text[i] = text[i].replace("<|video_pad|>", "<|placeholder|>" * token_count, 1)
index += 1
text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>")
@@ -221,9 +212,7 @@ def preprocesser_call(
labels = torch.full_like(text_inputs.input_ids, -100)
assistant_marker = "<|im_start|>assistant\n"
im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
assistant_tokens = processor.tokenizer(
"<|im_start|>assistant\n", add_special_tokens=False
).input_ids
assistant_tokens = processor.tokenizer("<|im_start|>assistant\n", add_special_tokens=False).input_ids
for i in range(len(text)):
assistant_regions = []
@@ -249,9 +238,7 @@ def preprocesser_call(
# From second part onwards, each part starts with assistant response
for k in range(current_pos + 1, len(text_inputs.input_ids[i])):
if text_inputs.input_ids[i][k] == im_end_token_id:
assistant_regions.append(
(current_pos + len(assistant_tokens), k + 2)
)
assistant_regions.append((current_pos + len(assistant_tokens), k + 2))
break
current_pos += len(part_tokens) + 3
@@ -344,7 +331,7 @@ def process_grounding_points(
coords = [new_x1, new_y1, new_x2, new_y2]
# Return processed point tag
return f'<point>[{", ".join(map(str, coords))}]</point>'
return f"<point>[{', '.join(map(str, coords))}]</point>"
except (ValueError, TypeError):
# Return original content if processing fails
@@ -356,10 +343,10 @@ def process_grounding_points(
def get_frame_instruction(
instruction_info: Dict[str, Any],
frame_idx: Optional[int] = None,
truncate_keys: Optional[List[str]] = None,
) -> Tuple[Dict[str, Any], Optional[int]]:
instruction_info: dict[str, Any],
frame_idx: int | None = None,
truncate_keys: list[str] | None = None,
) -> tuple[dict[str, Any], int | None]:
"""Extract frame-specific instruction from instruction dictionary.
Args:
@@ -388,11 +375,7 @@ def get_frame_instruction(
start_frame, end_frame = map(int, frame_range.split(" "))
if start_frame <= frame_idx < end_frame or (start_frame == frame_idx):
instruction_for_frame[key] = frame_instruction
if (
truncate_keys is not None
and split_end is None
and key in truncate_keys
):
if truncate_keys is not None and split_end is None and key in truncate_keys:
split_end = end_frame + 1
break
else:
@@ -402,7 +385,7 @@ def get_frame_instruction(
def get_task_instruction(
frame_instruction_info: Dict[str, Any], priority_order: Optional[OrderedDict] = None
frame_instruction_info: dict[str, Any], priority_order: OrderedDict | None = None
) -> str:
"""Construct task instruction from available instruction fields using priority sampling.
@@ -450,13 +433,13 @@ def get_task_instruction(
def get_wallx_normal_text(
instruction_info: Dict[str, Any],
instruction_info: dict[str, Any],
action_chunk_size: int,
frame_idx: int,
priority_order: Optional[OrderedDict] = None,
img_keys: Optional[List[str]] = None,
priority_order: OrderedDict | None = None,
img_keys: list[str] | None = None,
generate_subtask_ratio: float = 0.0,
) -> Tuple[str, bool]:
) -> tuple[str, bool]:
"""Construct complete multimodal prompt text for Wall-X model.
Formats input using special tokens including:
@@ -488,9 +471,7 @@ def get_wallx_normal_text(
action_fast_symbol = "<|action_fast|>"
# System prologue
prologue = (
f"{role_start_symbol}system\nYou are a helpful assistant.{role_end_symbol}\n"
)
prologue = f"{role_start_symbol}system\nYou are a helpful assistant.{role_end_symbol}\n"
# User request with observation
user_request = f"{role_start_symbol}user\nObservation:"
@@ -501,9 +482,7 @@ def get_wallx_normal_text(
user_request += "\nInstruction:"
# Get frame-specific instruction
frame_instruction_info, _ = get_frame_instruction(
instruction_info, frame_idx=frame_idx
)
frame_instruction_info, _ = get_frame_instruction(instruction_info, frame_idx=frame_idx)
generate_subtask = False
priority_keys = ["subtask_generation", "distribute"]
@@ -524,15 +503,11 @@ def get_wallx_normal_text(
output_instruction = frame_instruction_info[key]
break
assistant_output = (
f"{role_start_symbol}assistant\n{output_instruction}\n{role_end_symbol}"
)
assistant_output = f"{role_start_symbol}assistant\n{output_instruction}\n{role_end_symbol}"
generate_subtask = True
else:
# Generate actions
instruction = get_task_instruction(
frame_instruction_info, priority_order=priority_order
)
instruction = get_task_instruction(frame_instruction_info, priority_order=priority_order)
text_prompt = f"\nPredict the next action in robot action.\nProprioception: {propri_symbol}\n"
user_message = f"{user_request} {instruction}{text_prompt}{role_end_symbol}\n"
assistant_output = f"{role_start_symbol}assistant\n{action_fast_symbol}{role_end_symbol}\n{action_symbol * action_chunk_size}"
@@ -540,7 +515,8 @@ def get_wallx_normal_text(
complete_text = prologue + user_message + assistant_output
return complete_text, generate_subtask
def img_key_mapping(img_keys: List[str]) -> List[str]:
def img_key_mapping(img_keys: list[str]) -> list[str]:
"""Map image keys to camera names.
Args:
@@ -555,16 +531,15 @@ def img_key_mapping(img_keys: List[str]) -> List[str]:
if key in CAMERA_NAME_MAPPING:
key = CAMERA_NAME_MAPPING[key]
else:
if 'view' in key:
key = key.replace('_', ' ')
if "view" in key:
key = key.replace("_", " ")
else:
key = key + " view"
processed_img_keys.append(key)
return processed_img_keys
def get_action_tokens(
normalized_actions: Union[torch.Tensor, List], action_tokenizer
) -> List[List[str]]:
def get_action_tokens(normalized_actions: torch.Tensor | list, action_tokenizer) -> list[list[str]]:
"""Convert normalized actions to action token strings.
Args:
@@ -590,8 +565,8 @@ def get_action_tokens(
def pad_action_token_strs(
actions_token_lists: List[List[str]], pad_token: str = "<|endoftext|>"
) -> List[str]:
actions_token_lists: list[list[str]], pad_token: str = "<|endoftext|>"
) -> list[str]:
"""Pad action token lists to same length and join as strings.
Args:
@@ -605,20 +580,18 @@ def pad_action_token_strs(
padded_action_strs = []
for tokens in actions_token_lists:
padded_tokens = (
tokens + ["<|im_end|>\n"] + [pad_token] * (max_len - len(tokens))
)
padded_tokens = tokens + ["<|im_end|>\n"] + [pad_token] * (max_len - len(tokens))
padded_action_strs.append("".join(padded_tokens))
return padded_action_strs
def replace_action_token(
text: List[str],
norm_action: Optional[torch.Tensor],
text: list[str],
norm_action: torch.Tensor | None,
action_tokenizer,
dof_masks: Optional[torch.Tensor] = None,
) -> List[str]:
dof_masks: torch.Tensor | None = None,
) -> list[str]:
"""Replace action placeholders in text with actual action tokens.
Args:
@@ -632,10 +605,7 @@ def replace_action_token(
"""
if action_tokenizer is not None and norm_action is not None:
# Extract actions based on chunk sizes and DOF masks
norm_action = [
action[: 32, dof_masks[i, 0].bool()]
for i, action in enumerate(norm_action)
]
norm_action = [action[:32, dof_masks[i, 0].bool()] for i, action in enumerate(norm_action)]
# Convert to action tokens and pad
actions_fast_tokens = get_action_tokens(norm_action, action_tokenizer)
@@ -658,4 +628,3 @@ def replace_action_token(
text = [t.replace("<|action_fast|><|im_end|>\n", "") for t in text]
return text
+4 -1
View File
@@ -35,10 +35,11 @@ from lerobot.policies.wall_x import ( # noqa: E402
)
from lerobot.utils.random_utils import set_seed # noqa: E402
def test_policy_instantiation():
# Create config
set_seed(42)
config = WallXConfig(device='cuda')
config = WallXConfig(device="cuda")
# Set up input_features and output_features in the config
from lerobot.configs.types import FeatureType, PolicyFeature
@@ -118,6 +119,7 @@ def test_policy_instantiation():
print(f"Action prediction failed: {e}")
raise
def test_config_creation():
"""Test policy config creation through factory."""
try:
@@ -130,6 +132,7 @@ def test_config_creation():
print(f"Config creation failed: {e}")
raise
if __name__ == "__main__":
test_policy_instantiation()
test_config_creation()