fix pre-commit errors

This commit is contained in:
Geoffrey19
2025-12-17 20:27:14 +08:00
committed by Michel Aractingi
parent 9ce6dd9e25
commit a0c9a7d85d
8 changed files with 577 additions and 955 deletions
+2 -2
View File
@@ -7,11 +7,12 @@ This repository contains the Hugging Face port of **WALL-OSS**, a Vision-Languag
## Model Overview ## Model Overview
| Feature | Description | | Feature | Description |
| -------------------- | ------------------------------------------------------------------------ | | ------------------ | ----------------------------------------------------- | --- |
| Base Model | Qwen2.5-VL (Vision-Language Model) | | Base Model | Qwen2.5-VL (Vision-Language Model) |
| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) | | Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) |
| Architecture | Mixture of Experts (MoE) with action-specific routing | | | Architecture | Mixture of Experts (MoE) with action-specific routing | |
| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception | | Multi-Modal Inputs | Vision (images/videos), Language, Proprioception |
--- ---
## Citation ## Citation
@@ -32,4 +33,3 @@ If you use this work, please cite:
## License ## License
This port follows the **Apache 2.0 License**. This port follows the **Apache 2.0 License**.
@@ -85,9 +85,7 @@ class WallXConfig(PreTrainedConfig):
) )
if self.prediction_mode not in ["diffusion", "fast"]: if self.prediction_mode not in ["diffusion", "fast"]:
raise ValueError( raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}")
f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
)
# Assign use_fast_tokenizer based on prediction_mode # Assign use_fast_tokenizer based on prediction_mode
if self.prediction_mode == "fast": if self.prediction_mode == "fast":
@@ -96,9 +94,7 @@ class WallXConfig(PreTrainedConfig):
self.use_fast_tokenizer = False self.use_fast_tokenizer = False
self.action_tokenizer_path = None # disable action tokenizer for diffusion mode self.action_tokenizer_path = None # disable action tokenizer for diffusion mode
else: else:
raise ValueError( raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}")
f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
)
def validate_features(self) -> None: def validate_features(self) -> None:
"""Validate and set up input/output features.""" """Validate and set up input/output features."""
+194 -318
View File
@@ -34,61 +34,57 @@ lerobot-train \
``` ```
""" """
import math import math
from os import PathLike
from collections import deque from collections import deque
from typing import Any, Dict, List, Optional, Tuple, Union from os import PathLike
from PIL import Image from typing import Any
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from peft import LoraConfig, get_peft_model from peft import LoraConfig, get_peft_model
from PIL import Image
from qwen_vl_utils.vision_process import smart_resize
from torch import Tensor from torch import Tensor
from torch.distributions import Beta from torch.distributions import Beta
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from torchdiffeq import odeint from torchdiffeq import odeint
from transformers import AutoProcessor from transformers import AutoProcessor, BatchFeature
from transformers.cache_utils import ( from transformers.cache_utils import (
StaticCache, StaticCache,
) )
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration,
)
from transformers.utils import is_torchdynamo_compiling, logging from transformers.utils import is_torchdynamo_compiling, logging
from transformers import AutoProcessor, BatchFeature
from qwen_vl_utils.vision_process import smart_resize
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import populate_queues from lerobot.policies.utils import populate_queues
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.policies.wall_x.utils import (
replace_action_token,
preprocesser_call,
get_wallx_normal_text,
process_grounding_points,
)
from lerobot.policies.wall_x.constant import ( from lerobot.policies.wall_x.constant import (
MODEL_TYPE,
TOKENIZER_MAX_LENGTH,
PRIORITY_ORDER,
GENERATE_SUBTASK_RATIO, GENERATE_SUBTASK_RATIO,
RESOLUTION, IMAGE_FACTOR,
MAX_PIXELS, MAX_PIXELS,
MIN_PIXELS, MIN_PIXELS,
IMAGE_FACTOR, MODEL_TYPE,
PRIORITY_ORDER,
RESOLUTION,
TOKENIZER_MAX_LENGTH,
) )
from lerobot.policies.wall_x.qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig from lerobot.policies.wall_x.qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration,
)
from lerobot.policies.wall_x.qwen_model.qwen2_5_vl_moe import ( from lerobot.policies.wall_x.qwen_model.qwen2_5_vl_moe import (
Qwen2_5_VisionTransformerPretrainedModel, Qwen2_5_VisionTransformerPretrainedModel,
Qwen2_5_VLACausalLMOutputWithPast, Qwen2_5_VLACausalLMOutputWithPast,
Qwen2_5_VLMoEModel, Qwen2_5_VLMoEModel,
) )
from lerobot.policies.wall_x.utils import (
get_wallx_normal_text,
preprocesser_call,
process_grounding_points,
replace_action_token,
)
from lerobot.utils.constants import ACTION, OBS_STATE
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -151,7 +147,7 @@ class ActionHead(nn.Module):
"""Sample timesteps using Beta distribution (always in float32 for numerical stability).""" """Sample timesteps using Beta distribution (always in float32 for numerical stability)."""
beta_dist = Beta( beta_dist = Beta(
torch.tensor(self.beta_alpha, dtype=torch.float32, device=device), torch.tensor(self.beta_alpha, dtype=torch.float32, device=device),
torch.tensor(self.beta_beta, dtype=torch.float32, device=device) torch.tensor(self.beta_beta, dtype=torch.float32, device=device),
) )
sample = beta_dist.sample([batch_size]) sample = beta_dist.sample([batch_size])
time = (1 - sample) * self.s time = (1 - sample) * self.s
@@ -275,14 +271,14 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
pretrained_name_or_path, pretrained_name_or_path,
config=None, config=None,
action_tokenizer_path=None, action_tokenizer_path=None,
attn_implementation: str = 'eager', attn_implementation: str = "eager",
cache_dir: str | PathLike | None = None, cache_dir: str | PathLike | None = None,
force_download: bool = False, force_download: bool = False,
local_files_only: bool = False, local_files_only: bool = False,
token: str | bool | None = None, token: str | bool | None = None,
revision: str = "main", revision: str = "main",
strict: bool = False, strict: bool = False,
**kwargs: Any **kwargs: Any,
): ):
""" """
Load model from pretrained model path. Load model from pretrained model path.
@@ -312,9 +308,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
config._attn_implementation = attn_implementation config._attn_implementation = attn_implementation
processor = AutoProcessor.from_pretrained(pretrained_name_or_path, use_fast=True) processor = AutoProcessor.from_pretrained(pretrained_name_or_path, use_fast=True)
if action_tokenizer_path is not None: if action_tokenizer_path is not None:
action_tokenizer = AutoProcessor.from_pretrained( action_tokenizer = AutoProcessor.from_pretrained(action_tokenizer_path, trust_remote_code=True)
action_tokenizer_path, trust_remote_code=True
)
processor.action_processor = action_tokenizer processor.action_processor = action_tokenizer
else: else:
action_tokenizer = None action_tokenizer = None
@@ -387,9 +381,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
super().__init__(config) super().__init__(config)
# Initialize vision transformer and language model components # Initialize vision transformer and language model components
self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config( self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)
config.vision_config
)
self.model = Qwen2_5_VLMoEModel(config) self.model = Qwen2_5_VLMoEModel(config)
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
@@ -446,12 +438,8 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
# Create list of fast action token IDs # Create list of fast action token IDs
fast_action_token_list = [] fast_action_token_list = []
if self.use_fast_tokenizer: if self.use_fast_tokenizer:
for i in range( for i in range(self.processor.tokenizer.init_kwargs["action_token_vocab_size"]):
self.processor.tokenizer.init_kwargs["action_token_vocab_size"] action_token_id = self.processor.tokenizer.convert_tokens_to_ids(f"<|action_token_{i}|>")
):
action_token_id = self.processor.tokenizer.convert_tokens_to_ids(
f"<|action_token_{i}|>"
)
fast_action_token_list.append(action_token_id) fast_action_token_list.append(action_token_id)
# Get special action token IDs # Get special action token IDs
@@ -465,9 +453,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
"action_token_id": action_token_id, "action_token_id": action_token_id,
} }
def add_lora( def add_lora(self, r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.1):
self, r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.1
):
""" """
Add LoRA (Low-Rank Adaptation) adapters to the model. Add LoRA (Low-Rank Adaptation) adapters to the model.
@@ -516,12 +502,12 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
def get_rope_index( def get_rope_index(
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: torch.LongTensor | None = None,
image_grid_thw: Optional[torch.LongTensor] = None, image_grid_thw: torch.LongTensor | None = None,
video_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: torch.LongTensor | None = None,
second_per_grid_ts: Optional[torch.Tensor] = None, second_per_grid_ts: torch.Tensor | None = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: torch.Tensor | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Calculate 3D RoPE (Rotary Position Embedding) indices for vision and text tokens. Calculate 3D RoPE (Rotary Position Embedding) indices for vision and text tokens.
@@ -555,9 +541,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
vision_start_token_id = self.config.vision_start_token_id vision_start_token_id = self.config.vision_start_token_id
mrope_position_deltas = [] mrope_position_deltas = []
if input_ids is not None and ( if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
image_grid_thw is not None or video_grid_thw is not None
):
total_input_ids = input_ids total_input_ids = input_ids
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones_like(total_input_ids) attention_mask = torch.ones_like(total_input_ids)
@@ -580,9 +564,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
image_nums, video_nums = 0, 0 image_nums, video_nums = 0, 0
# Find vision tokens and count images/videos # Find vision tokens and count images/videos
vision_start_indices = torch.argwhere( vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
input_ids == vision_start_token_id
).squeeze(1)
vision_tokens = input_ids[vision_start_indices + 1] vision_tokens = input_ids[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum() image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum() video_nums = (vision_tokens == video_token_id).sum()
@@ -641,14 +623,8 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
text_len = ed - st text_len = ed - st
# Add position IDs for text tokens before vision token # Add position IDs for text tokens before vision token
st_idx = ( st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list[-1].max() + 1 llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
if len(llm_pos_ids_list) > 0
else 0
)
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
# Calculate 3D position embeddings for vision tokens # Calculate 3D position embeddings for vision tokens
range_tensor = torch.arange(llm_grid_t).view(-1, 1) range_tensor = torch.arange(llm_grid_t).view(-1, 1)
@@ -656,71 +632,43 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
# Calculate temporal position IDs with time scaling # Calculate temporal position IDs with time scaling
time_tensor = ( time_tensor = (
expanded_range expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second
* second_per_grid_t
* self.config.vision_config.tokens_per_second
) )
time_tensor_long = time_tensor.long() time_tensor_long = time_tensor.long()
t_index = time_tensor_long.flatten() t_index = time_tensor_long.flatten()
# Calculate spatial position IDs # Calculate spatial position IDs
h_index = ( h_index = (
torch.arange(llm_grid_h) torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
) )
w_index = ( w_index = (
torch.arange(llm_grid_w) torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
) )
# Add 3D position IDs for vision tokens # Add 3D position IDs for vision tokens
llm_pos_ids_list.append( llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w st = ed + llm_grid_t * llm_grid_h * llm_grid_w
# Add position IDs for remaining text tokens # Add position IDs for remaining text tokens
if st < len(input_tokens): if st < len(input_tokens):
st_idx = ( st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list[-1].max() + 1
if len(llm_pos_ids_list) > 0
else 0
)
text_len = len(input_tokens) - st text_len = len(input_tokens) - st
llm_pos_ids_list.append( llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
# Concatenate all position IDs for this sequence # Concatenate all position IDs for this sequence
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
position_ids.device mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
)
mrope_position_deltas.append(
llm_positions.max() + 1 - len(total_input_ids[i])
)
mrope_position_deltas = torch.tensor( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
mrope_position_deltas, device=input_ids.device
).unsqueeze(1)
return position_ids, mrope_position_deltas return position_ids, mrope_position_deltas
else: else:
# Handle case without vision tokens - use standard 1D position embeddings # Handle case without vision tokens - use standard 1D position embeddings
if attention_mask is not None: if attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1 position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1) position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = ( position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
position_ids.unsqueeze(0) max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
.expand(3, -1, -1)
.to(attention_mask.device)
)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(
-1, keepdim=True
)[0]
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
else: else:
position_ids = ( position_ids = (
@@ -739,33 +687,29 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
def train_step_forward( def train_step_forward(
self, self,
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: torch.Tensor | None = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: torch.LongTensor | None = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, past_key_values: list[torch.FloatTensor] | None = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: torch.FloatTensor | None = None,
moe_token_types: Optional[ moe_token_types: torch.LongTensor | None = None, # MoE token type assignments
torch.LongTensor labels: torch.LongTensor | None = None,
] = None, # MoE token type assignments use_cache: bool | None = None,
labels: Optional[torch.LongTensor] = None, output_attentions: bool | None = None,
use_cache: Optional[bool] = None, output_hidden_states: bool | None = None,
output_attentions: Optional[bool] = None, return_dict: bool | None = None,
output_hidden_states: Optional[bool] = None, pixel_values: torch.Tensor | None = None,
return_dict: Optional[bool] = None, pixel_values_videos: torch.FloatTensor | None = None,
pixel_values: Optional[torch.Tensor] = None, image_grid_thw: torch.LongTensor | None = None,
pixel_values_videos: Optional[torch.FloatTensor] = None, video_grid_thw: torch.LongTensor | None = None,
image_grid_thw: Optional[torch.LongTensor] = None, action_chunk: torch.FloatTensor | None = None, # Action trajectory chunks
video_grid_thw: Optional[torch.LongTensor] = None, proprioception: torch.FloatTensor | None = None, # Joint position/orientation data
action_chunk: Optional[torch.FloatTensor] = None, # Action trajectory chunks rope_deltas: torch.LongTensor | None = None,
proprioception: Optional[ cache_position: torch.LongTensor | None = None,
torch.FloatTensor second_per_grid_ts: torch.Tensor | None = None,
] = None, # Joint position/orientation data dof_mask: torch.FloatTensor | None = None,
rope_deltas: Optional[torch.LongTensor] = None, agent_pos_mask: torch.FloatTensor | None = None,
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
dof_mask: Optional[torch.FloatTensor] = None,
agent_pos_mask: Optional[torch.FloatTensor] = None,
**kwargs, **kwargs,
) -> Union[Tuple, Qwen2_5_VLACausalLMOutputWithPast]: ) -> tuple | Qwen2_5_VLACausalLMOutputWithPast:
""" """
Forward pass for training with multi-modal inputs including vision, text, and action data. Forward pass for training with multi-modal inputs including vision, text, and action data.
@@ -806,24 +750,16 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
# Set output configuration from model config if not specified # Set output configuration from model config if not specified
output_attentions = ( output_attentions = (
output_attentions output_attentions if output_attentions is not None else self.config.output_attentions
if output_attentions is not None
else self.config.output_attentions
) )
output_hidden_states = ( output_hidden_states = (
output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Calculate RoPE position IDs if not provided # Calculate RoPE position IDs if not provided
# Note: Cannot calculate rope deltas with 4D attention mask. TODO: Fix this limitation # Note: Cannot calculate rope deltas with 4D attention mask. TODO: Fix this limitation
if position_ids is None and ( if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
attention_mask is None or attention_mask.ndim == 2
):
# Calculate RoPE index once per generation in the pre-fill stage only # Calculate RoPE index once per generation in the pre-fill stage only
if ( if (
(cache_position is not None and cache_position[0] == 0) (cache_position is not None and cache_position[0] == 0)
@@ -865,9 +801,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
image_mask = mask_expanded.to(inputs_embeds.device) image_mask = mask_expanded.to(inputs_embeds.device)
image_embeds = image_embeds.to( image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds.device, inputs_embeds.dtype
)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
# Process video embeddings # Process video embeddings
@@ -887,19 +821,13 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
video_mask = mask_expanded.to(inputs_embeds.device) video_mask = mask_expanded.to(inputs_embeds.device)
video_embeds = video_embeds.to( video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds.device, inputs_embeds.dtype
)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
# Process proprioceptive data (joint positions, orientations, etc.) # Process proprioceptive data (joint positions, orientations, etc.)
if proprioception is not None: if proprioception is not None:
proprioception = proprioception.to(inputs_embeds.device).to( proprioception = proprioception.to(inputs_embeds.device).to(inputs_embeds.dtype)
inputs_embeds.dtype agent_pos_mask = agent_pos_mask.to(inputs_embeds.device).to(inputs_embeds.dtype)
)
agent_pos_mask = agent_pos_mask.to(inputs_embeds.device).to(
inputs_embeds.dtype
)
proprioception = self.action_preprocessor.proprioception_proj( proprioception = self.action_preprocessor.proprioception_proj(
proprioception, proprioception,
agent_pos_mask, agent_pos_mask,
@@ -910,12 +838,8 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
proprioception_mask = mask_expanded.to(inputs_embeds.device) proprioception_mask = mask_expanded.to(inputs_embeds.device)
proprioception = proprioception.to( proprioception = proprioception.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds.device, inputs_embeds.dtype inputs_embeds = inputs_embeds.masked_scatter(proprioception_mask, proprioception)
)
inputs_embeds = inputs_embeds.masked_scatter(
proprioception_mask, proprioception
)
elif self.training: elif self.training:
# Dummy forward pass to ensure gradient registration in DDP # Dummy forward pass to ensure gradient registration in DDP
# This handles cases where one process has proprioception data while another doesn't # This handles cases where one process has proprioception data while another doesn't
@@ -925,32 +849,22 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
self.action_preprocessor.propri_dim * 2, self.action_preprocessor.propri_dim * 2,
device=inputs_embeds.device, device=inputs_embeds.device,
) )
dummy_forward = self.action_preprocessor.proprioception_proj( dummy_forward = self.action_preprocessor.proprioception_proj(dummy_input)
dummy_input
)
dummy_loss = sum(p.sum() for p in dummy_forward) dummy_loss = sum(p.sum() for p in dummy_forward)
inputs_embeds = inputs_embeds + 0 * dummy_loss inputs_embeds = inputs_embeds + 0 * dummy_loss
# Process action chunk data # Process action chunk data
if action_chunk is not None: if action_chunk is not None:
action_chunk = action_chunk.to(inputs_embeds.device).to( action_chunk = action_chunk.to(inputs_embeds.device).to(inputs_embeds.dtype)
inputs_embeds.dtype
)
dof_mask = dof_mask.to(inputs_embeds.device).to(inputs_embeds.dtype) dof_mask = dof_mask.to(inputs_embeds.device).to(inputs_embeds.dtype)
noisy_action_emb, flow = self.action_preprocessor( noisy_action_emb, flow = self.action_preprocessor(action_chunk, dof_mask)
action_chunk, dof_mask
)
mask = input_ids == self.action_token_id_set["action_token_id"] mask = input_ids == self.action_token_id_set["action_token_id"]
mask_unsqueezed = mask.unsqueeze(-1) mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
action_mask = mask_expanded.to(inputs_embeds.device) action_mask = mask_expanded.to(inputs_embeds.device)
noisy_action_emb = noisy_action_emb.to( noisy_action_emb = noisy_action_emb.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds.device, inputs_embeds.dtype inputs_embeds = inputs_embeds.masked_scatter(action_mask, noisy_action_emb)
)
inputs_embeds = inputs_embeds.masked_scatter(
action_mask, noisy_action_emb
)
if attention_mask is not None: if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device) attention_mask = attention_mask.to(inputs_embeds.device)
@@ -1011,18 +925,14 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
if action_mask.any(): if action_mask.any():
action_hidden_states = hidden_states[action_mask].to(torch.float32) action_hidden_states = hidden_states[action_mask].to(torch.float32)
flow = flow.reshape(-1, flow.shape[-1]).to(torch.float32) flow = flow.reshape(-1, flow.shape[-1]).to(torch.float32)
_flow_loss = self.action_preprocessor.flow_loss( _flow_loss = self.action_preprocessor.flow_loss(action_hidden_states, flow, dof_mask)
action_hidden_states, flow, dof_mask
)
if isinstance(_flow_loss, torch.Tensor): if isinstance(_flow_loss, torch.Tensor):
flow_loss = _flow_loss.mean() flow_loss = _flow_loss.mean()
if loss is not None: if loss is not None:
loss = loss + self.flow_loss_weight * flow_loss.to(torch.float32) loss = loss + self.flow_loss_weight * flow_loss.to(torch.float32)
else: else:
loss = self.flow_loss_weight * flow_loss.to(torch.float32) loss = self.flow_loss_weight * flow_loss.to(torch.float32)
_flow_loss = _flow_loss.view( _flow_loss = _flow_loss.view(dof_mask.shape[0], dof_mask.shape[1], dof_mask.shape[2])
dof_mask.shape[0], dof_mask.shape[1], dof_mask.shape[2]
)
# Return outputs based on return_dict setting # Return outputs based on return_dict setting
if not return_dict: if not return_dict:
@@ -1031,9 +941,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
return Qwen2_5_VLACausalLMOutputWithPast( return Qwen2_5_VLACausalLMOutputWithPast(
loss=loss, loss=loss,
cross_entropy_loss=( cross_entropy_loss=(cross_entropy_loss.clone() if cross_entropy_loss is not None else None),
cross_entropy_loss.clone() if cross_entropy_loss is not None else None
),
flow_loss=flow_loss, flow_loss=flow_loss,
logits=logits, logits=logits,
past_key_values=outputs.past_key_values, past_key_values=outputs.past_key_values,
@@ -1065,31 +973,31 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
def predict( def predict(
self, self,
predict_mode: str, predict_mode: str,
pred_horizon: Optional[int] = None, pred_horizon: int | None = None,
action_dim: Optional[int] = None, action_dim: int | None = None,
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: torch.Tensor | None = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: torch.LongTensor | None = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, past_key_values: list[torch.FloatTensor] | None = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: torch.FloatTensor | None = None,
moe_token_types: Optional[torch.LongTensor] = None, moe_token_types: torch.LongTensor | None = None,
labels: Optional[torch.LongTensor] = None, labels: torch.LongTensor | None = None,
use_cache: Optional[bool] = None, use_cache: bool | None = None,
output_attentions: Optional[bool] = None, output_attentions: bool | None = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: bool | None = None,
return_dict: Optional[bool] = None, return_dict: bool | None = None,
pixel_values: Optional[torch.Tensor] = None, pixel_values: torch.Tensor | None = None,
pixel_values_videos: Optional[torch.FloatTensor] = None, pixel_values_videos: torch.FloatTensor | None = None,
image_grid_thw: Optional[torch.LongTensor] = None, image_grid_thw: torch.LongTensor | None = None,
video_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: torch.LongTensor | None = None,
action_chunk: Optional[torch.FloatTensor] = None, action_chunk: torch.FloatTensor | None = None,
proprioception: Optional[torch.FloatTensor] = None, proprioception: torch.FloatTensor | None = None,
rope_deltas: Optional[torch.LongTensor] = None, rope_deltas: torch.LongTensor | None = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: torch.LongTensor | None = None,
second_per_grid_ts: Optional[torch.Tensor] = None, second_per_grid_ts: torch.Tensor | None = None,
num_inference_timesteps: Optional[int] = 10, num_inference_timesteps: int | None = 10,
dof_mask: Optional[torch.FloatTensor] = None, dof_mask: torch.FloatTensor | None = None,
agent_pos_mask: Optional[torch.FloatTensor] = None, agent_pos_mask: torch.FloatTensor | None = None,
re_generate: bool = False, re_generate: bool = False,
**kwargs, **kwargs,
): ):
@@ -1139,30 +1047,20 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
- 'predict_output_text': Generated text (for text/fast modes) - 'predict_output_text': Generated text (for text/fast modes)
- 'gt_output_text': Ground truth text (for text/fast modes) - 'gt_output_text': Ground truth text (for text/fast modes)
""" """
batch_size = ( batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
)
# Text and fast modes require batch size 1 for autoregressive generation # Text and fast modes require batch size 1 for autoregressive generation
if predict_mode in ["text", "fast"]: if predict_mode in ["text", "fast"]:
assert ( assert batch_size == 1, "predict only support batch size 1 for ar generation"
batch_size == 1
), "predict only support batch size 1 for ar generation"
# Set output configuration from model config if not specified # Set output configuration from model config if not specified
output_attentions = ( output_attentions = (
output_attentions output_attentions if output_attentions is not None else self.config.output_attentions
if output_attentions is not None
else self.config.output_attentions
) )
output_hidden_states = ( output_hidden_states = (
output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Process input embeddings with multi-modal data # Process input embeddings with multi-modal data
if inputs_embeds is None: if inputs_embeds is None:
@@ -1186,9 +1084,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
image_mask = mask_expanded.to(inputs_embeds.device) image_mask = mask_expanded.to(inputs_embeds.device)
image_embeds = image_embeds.to( image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds.device, inputs_embeds.dtype
)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
# Process video embeddings # Process video embeddings
@@ -1209,40 +1105,28 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
video_mask = mask_expanded.to(inputs_embeds.device) video_mask = mask_expanded.to(inputs_embeds.device)
video_embeds = video_embeds.to( video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds.device, inputs_embeds.dtype
)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
# Process proprioceptive data # Process proprioceptive data
if proprioception is not None: if proprioception is not None:
proprioception = proprioception.to(inputs_embeds.device).to( proprioception = proprioception.to(inputs_embeds.device).to(inputs_embeds.dtype)
inputs_embeds.dtype agent_pos_mask = agent_pos_mask.to(inputs_embeds.device).to(inputs_embeds.dtype)
)
agent_pos_mask = agent_pos_mask.to(inputs_embeds.device).to(
inputs_embeds.dtype
)
proprio_embed = self.action_preprocessor.proprioception_proj( proprio_embed = self.action_preprocessor.proprioception_proj(
proprioception, proprioception,
agent_pos_mask, agent_pos_mask,
use_history=proprioception.shape[1] > 1, use_history=proprioception.shape[1] > 1,
) )
proprioception_mask = ( proprioception_mask = input_ids == self.action_token_id_set["propri_token_id"]
input_ids == self.action_token_id_set["propri_token_id"]
)
proprio_embed = proprio_embed.to(torch.bfloat16) proprio_embed = proprio_embed.to(torch.bfloat16)
inputs_embeds[proprioception_mask] = proprio_embed.reshape( inputs_embeds[proprioception_mask] = proprio_embed.reshape(-1, inputs_embeds.shape[-1])
-1, inputs_embeds.shape[-1]
)
if attention_mask is not None: if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device) attention_mask = attention_mask.to(inputs_embeds.device)
# Calculate RoPE position IDs if not provided # Calculate RoPE position IDs if not provided
# Note: Cannot calculate rope deltas with 4D attention mask. TODO: Fix this limitation # Note: Cannot calculate rope deltas with 4D attention mask. TODO: Fix this limitation
if position_ids is None and ( if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
attention_mask is None or attention_mask.ndim == 2
):
# Calculate RoPE index once per generation in the pre-fill stage only # Calculate RoPE index once per generation in the pre-fill stage only
if ( if (
(cache_position is not None and cache_position[0] == 0) (cache_position is not None and cache_position[0] == 0)
@@ -1332,12 +1216,8 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
eos_token_id=[self.processor.tokenizer.eos_token_id], eos_token_id=[self.processor.tokenizer.eos_token_id],
use_cache=True, use_cache=True,
pad_token_id=self.processor.tokenizer.pad_token_id, pad_token_id=self.processor.tokenizer.pad_token_id,
temperature=( temperature=(1.0 if not re_generate else 0.7), # Higher temperature for regeneration
1.0 if not re_generate else 0.7 do_sample=(False if not re_generate else True), # Enable sampling for regeneration
), # Higher temperature for regeneration
do_sample=(
False if not re_generate else True
), # Enable sampling for regeneration
) )
# Decode generated and ground truth text # Decode generated and ground truth text
@@ -1359,15 +1239,9 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
action_id = [] action_id = []
# Extract action tokens from generated sequence # Extract action tokens from generated sequence
for token_id_i in predict_output_ids[0]: for token_id_i in predict_output_ids[0]:
if ( if token_id_i.item() >= self.processor.tokenizer.init_kwargs["action_token_start_index"]:
token_id_i.item()
>= self.processor.tokenizer.init_kwargs["action_token_start_index"]
):
action_id.append( action_id.append(
token_id_i.item() token_id_i.item() - self.processor.tokenizer.init_kwargs["action_token_start_index"]
- self.processor.tokenizer.init_kwargs[
"action_token_start_index"
]
) )
predict_action = self.processor.action_processor.decode( predict_action = self.processor.action_processor.decode(
@@ -1382,9 +1256,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
predict_action = torch.tensor(predict_action, device=self.device) predict_action = torch.tensor(predict_action, device=self.device)
dof_mask = dof_mask.to(self.device).to(pixel_values.dtype) dof_mask = dof_mask.to(self.device).to(pixel_values.dtype)
# removed unnormalization step for now # removed unnormalization step for now
predict_action = ( predict_action = predict_action[:, :, dof_mask[0, 0, :].bool()]
predict_action[:, :, dof_mask[0, 0, :].bool()]
)
output["predict_action"] = predict_action output["predict_action"] = predict_action
# Process ground truth actions if available # Process ground truth actions if available
@@ -1459,7 +1331,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
times = torch.linspace( times = torch.linspace(
0, 0,
1, 1,
num_inference_timesteps+1, num_inference_timesteps + 1,
device=inputs_embeds.device, device=inputs_embeds.device,
dtype=torch.float32, dtype=torch.float32,
) )
@@ -1477,9 +1349,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
return output return output
def forward( def forward(self, mode: str | None = None, predict_mode: str | None = "text", **kwargs):
self, mode: Optional[str] = None, predict_mode: Optional[str] = "text", **kwargs
):
""" """
Main forward pass dispatcher for different execution modes. Main forward pass dispatcher for different execution modes.
@@ -1592,9 +1462,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
# Handle input slicing based on cache state and special cases # Handle input slicing based on cache state and special cases
if past_key_values is not None: if past_key_values is not None:
if ( if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4: input_embeds case
inputs_embeds is not None and input_ids.shape[1] == 0
): # Exception 4: input_embeds case
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
moe_token_types = moe_token_types[:, -cache_position.shape[0] :] moe_token_types = moe_token_types[:, -cache_position.shape[0] :]
elif inputs_embeds is not None or ( # Exception 1: input_embeds provided elif inputs_embeds is not None or ( # Exception 1: input_embeds provided
@@ -1602,9 +1470,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
): # Exception 3: GPU sync edge case ): # Exception 3: GPU sync edge case
input_ids = input_ids[:, -cache_position.shape[0] :] input_ids = input_ids[:, -cache_position.shape[0] :]
moe_token_types = moe_token_types[:, -cache_position.shape[0] :] moe_token_types = moe_token_types[:, -cache_position.shape[0] :]
elif ( elif input_ids.shape[1] != cache_position.shape[0]: # Default case (Exception 2 is no-op)
input_ids.shape[1] != cache_position.shape[0]
): # Default case (Exception 2 is no-op)
cache_pos = cache_position.clone() cache_pos = cache_position.clone()
input_ids = input_ids[:, cache_pos] input_ids = input_ids[:, cache_pos]
moe_token_types = moe_token_types[:, cache_pos] moe_token_types = moe_token_types[:, cache_pos]
@@ -1629,8 +1495,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
batch_size, sequence_length = input_ids.shape batch_size, sequence_length = input_ids.shape
device = input_ids.device device = input_ids.device
attention_mask = ( attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
self.model._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask, attention_mask,
sequence_length=sequence_length, sequence_length=sequence_length,
target_length=past_key_values.get_max_cache_shape(), target_length=past_key_values.get_max_cache_shape(),
@@ -1641,7 +1506,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
config=self.config, config=self.config,
past_key_values=past_key_values, past_key_values=past_key_values,
) )
)
# Assemble all model inputs for generation # Assemble all model inputs for generation
model_inputs.update( model_inputs.update(
@@ -1666,8 +1530,8 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
def _get_image_nums_and_video_nums( def _get_image_nums_and_video_nums(
self, self,
input_ids: Optional[torch.LongTensor], input_ids: torch.LongTensor | None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Get the number of images and videos for each sample to calculate tensor separation lengths. Get the number of images and videos for each sample to calculate tensor separation lengths.
@@ -1702,9 +1566,9 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
self, self,
expand_size: int = 1, expand_size: int = 1,
is_encoder_decoder: bool = False, is_encoder_decoder: bool = False,
input_ids: Optional[torch.LongTensor] = None, input_ids: torch.LongTensor | None = None,
**model_kwargs, **model_kwargs,
) -> Tuple[torch.LongTensor, Dict[str, Any]]: ) -> tuple[torch.LongTensor, dict[str, Any]]:
""" """
Expand inputs for generation with support for multi-modal tensors. Expand inputs for generation with support for multi-modal tensors.
@@ -1745,9 +1609,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
"""Split tensor by lengths and repeat each sample.""" """Split tensor by lengths and repeat each sample."""
samples = torch.split(x, lengths) samples = torch.split(x, lengths)
repeat_args = [repeat_times] + [1] * (x.dim() - 1) repeat_args = [repeat_times] + [1] * (x.dim() - 1)
result = torch.cat( result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
[sample.repeat(*repeat_args) for sample in samples], dim=0
)
return result return result
for key in dict_to_expand: for key in dict_to_expand:
@@ -1785,9 +1647,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
) )
tensor = torch.tensor(dict_to_expand[key]) tensor = torch.tensor(dict_to_expand[key])
lengths = list(video_nums) lengths = list(video_nums)
tensor = _repeat_interleave_samples( tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size)
tensor, lengths=lengths, repeat_times=expand_size
)
dict_to_expand[key] = tensor.tolist() dict_to_expand[key] = tensor.tolist()
return dict_to_expand return dict_to_expand
@@ -1800,9 +1660,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
and isinstance(dict_to_expand[key], torch.Tensor) and isinstance(dict_to_expand[key], torch.Tensor)
and key not in visual_keys and key not in visual_keys
): ):
dict_to_expand[key] = dict_to_expand[key].repeat_interleave( dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
expand_size, dim=0
)
return dict_to_expand return dict_to_expand
# Expand visual inputs only if input_ids is available for counting images/videos # Expand visual inputs only if input_ids is available for counting images/videos
@@ -1823,9 +1681,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
raise ValueError( raise ValueError(
"If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined." "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined."
) )
model_kwargs["encoder_outputs"] = _expand_dict_for_generation( model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
model_kwargs["encoder_outputs"]
)
return input_ids, model_kwargs return input_ids, model_kwargs
@@ -1850,7 +1706,7 @@ class WallXPolicy(PreTrainedPolicy):
self.model = Qwen2_5_VLMoEForAction.from_pretrained( self.model = Qwen2_5_VLMoEForAction.from_pretrained(
pretrained_name_or_path=config.pretrained_name_or_path, pretrained_name_or_path=config.pretrained_name_or_path,
action_tokenizer_path=config.action_tokenizer_path, action_tokenizer_path=config.action_tokenizer_path,
attn_implementation=config.attn_implementation attn_implementation=config.attn_implementation,
) )
self.model.to(config.device) self.model.to(config.device)
self.model.to_bfloat16_for_selected_params() self.model.to_bfloat16_for_selected_params()
@@ -1869,7 +1725,7 @@ class WallXPolicy(PreTrainedPolicy):
def preprocess_inputs( def preprocess_inputs(
self, self,
batch: Dict[str, Any], batch: dict[str, Any],
) -> BatchFeature: ) -> BatchFeature:
""" """
Convert a batch of LeRobot dataset items to Wall-X model input format. Convert a batch of LeRobot dataset items to Wall-X model input format.
@@ -1950,12 +1806,10 @@ class WallXPolicy(PreTrainedPolicy):
) )
text = process_grounding_points( text = process_grounding_points(
complete_text, orig_height, orig_width, resized_height, resized_width, complete_text, orig_height, orig_width, resized_height, resized_width, MODEL_TYPE
MODEL_TYPE
) )
all_texts.append(text) all_texts.append(text)
# ==================== PROCESS AGENT POS ==================== # ==================== PROCESS AGENT POS ====================
agent_pos = batch[OBS_STATE] # (batch_size, state_dim) agent_pos = batch[OBS_STATE] # (batch_size, state_dim)
if agent_pos.dim() == 2: if agent_pos.dim() == 2:
@@ -1965,17 +1819,28 @@ class WallXPolicy(PreTrainedPolicy):
if agent_pos.shape[-1] != 20: if agent_pos.shape[-1] != 20:
pad_size = 20 - agent_pos.shape[-1] pad_size = 20 - agent_pos.shape[-1]
agent_pos = torch.cat([ agent_pos = torch.cat(
[
agent_pos, agent_pos,
torch.zeros(agent_pos.shape[0], agent_pos.shape[1], pad_size, device=agent_pos.device) torch.zeros(agent_pos.shape[0], agent_pos.shape[1], pad_size, device=agent_pos.device),
], dim=-1) ],
agent_pos_mask = torch.cat([ dim=-1,
)
agent_pos_mask = torch.cat(
[
agent_pos_mask, agent_pos_mask,
torch.zeros(agent_pos_mask.shape[0], agent_pos_mask.shape[1], pad_size, device=agent_pos_mask.device) torch.zeros(
], dim=-1) agent_pos_mask.shape[0],
agent_pos_mask.shape[1],
pad_size,
device=agent_pos_mask.device,
),
],
dim=-1,
)
# ==================== PROCESS ACTIONS ==================== # ==================== PROCESS ACTIONS ====================
action = batch.get(ACTION, None) # (batch_size, chunk_size, action_dim) action = batch.get(ACTION) # (batch_size, chunk_size, action_dim)
if action is not None: if action is not None:
if action.dim() == 2: if action.dim() == 2:
action = action.unsqueeze(1) action = action.unsqueeze(1)
@@ -1984,20 +1849,30 @@ class WallXPolicy(PreTrainedPolicy):
if action.shape[-1] != 20: if action.shape[-1] != 20:
pad_size = 20 - action.shape[-1] pad_size = 20 - action.shape[-1]
action = torch.cat([ action = torch.cat(
action, [action, torch.zeros(action.shape[0], action.shape[1], pad_size, device=action.device)],
torch.zeros(action.shape[0], action.shape[1], pad_size, device=action.device) dim=-1,
], dim=-1) )
dof_mask = torch.cat([ dof_mask = torch.cat(
[
dof_mask, dof_mask,
torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device) torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device),
], dim=-1) ],
dim=-1,
)
else: else:
action_dim = self.config.output_features["action"].shape[0] action_dim = self.config.output_features["action"].shape[0]
dof_mask = torch.cat([ dof_mask = torch.cat(
torch.ones(batch_size, self.config.chunk_size, action_dim, device=batch[OBS_STATE].device), [
torch.zeros(batch_size, self.config.chunk_size, 20 - action_dim, device=batch[OBS_STATE].device) torch.ones(
], dim=-1) batch_size, self.config.chunk_size, action_dim, device=batch[OBS_STATE].device
),
torch.zeros(
batch_size, self.config.chunk_size, 20 - action_dim, device=batch[OBS_STATE].device
),
],
dim=-1,
)
# ==================== ACTION TOKEN REPLACEMENT ==================== # ==================== ACTION TOKEN REPLACEMENT ====================
all_texts = replace_action_token( all_texts = replace_action_token(
@@ -2028,7 +1903,11 @@ class WallXPolicy(PreTrainedPolicy):
inputs["action_chunk"] = action inputs["action_chunk"] = action
inputs["dof_mask"] = dof_mask inputs["dof_mask"] = dof_mask
inputs["moe_token_types"] = moe_token_types inputs["moe_token_types"] = moe_token_types
inputs["frame_index"] = batch["frame_index"] if "frame_index" in batch else torch.zeros(batch_size, device=batch[OBS_STATE].device) inputs["frame_index"] = (
batch["frame_index"]
if "frame_index" in batch
else torch.zeros(batch_size, device=batch[OBS_STATE].device)
)
# Move all tensors to the correct device # Move all tensors to the correct device
device = self.config.device device = self.config.device
@@ -2056,10 +1935,7 @@ class WallXPolicy(PreTrainedPolicy):
) )
# Call the underlying model's forward with mode="train" # Call the underlying model's forward with mode="train"
outputs = self.model( outputs = self.model(**batch, mode="train")
**batch,
mode="train"
)
# Extract losses from output # Extract losses from output
loss = outputs.loss loss = outputs.loss
@@ -2096,7 +1972,7 @@ class WallXPolicy(PreTrainedPolicy):
action_dim=self.config.max_action_dim, action_dim=self.config.max_action_dim,
pred_horizon=self.config.chunk_size, pred_horizon=self.config.chunk_size,
mode="predict", mode="predict",
predict_mode="diffusion" predict_mode="diffusion",
) )
elif self.config.prediction_mode == "fast": elif self.config.prediction_mode == "fast":
output = self.model( output = self.model(
@@ -2104,7 +1980,7 @@ class WallXPolicy(PreTrainedPolicy):
action_dim=self.config.output_features["action"].shape[0], action_dim=self.config.output_features["action"].shape[0],
pred_horizon=self.config.chunk_size, pred_horizon=self.config.chunk_size,
mode="predict", mode="predict",
predict_mode="fast" predict_mode="fast",
) )
else: else:
raise NotImplementedError(f"Prediction mode {self.config.prediction_mode} not implemented") raise NotImplementedError(f"Prediction mode {self.config.prediction_mode} not implemented")
@@ -2127,6 +2003,6 @@ class WallXPolicy(PreTrainedPolicy):
# Use action queue # Use action queue
if len(self._queues[ACTION]) == 0: if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch) actions = self.predict_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1)[:self.config.n_action_steps]) self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
return self._queues[ACTION].popleft() return self._queues[ACTION].popleft()
@@ -33,6 +33,8 @@ from lerobot.processor import (
) )
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
def make_wall_x_pre_post_processors( def make_wall_x_pre_post_processors(
config: WallXConfig, config: WallXConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
@@ -75,9 +77,7 @@ def make_wall_x_pre_post_processors(
output_steps = [ output_steps = [
UnnormalizerProcessorStep( UnnormalizerProcessorStep(
features=config.output_features, features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
norm_map=config.normalization_mapping,
stats=dataset_stats
), ),
DeviceProcessorStep(device="cpu"), DeviceProcessorStep(device="cpu"),
] ]
@@ -123,9 +123,7 @@ class WallXTaskProcessor(ComplementaryDataProcessorStep):
new_complementary_data["task"] = f"{task}." new_complementary_data["task"] = f"{task}."
elif isinstance(task, list) and all(isinstance(t, str) for t in task): elif isinstance(task, list) and all(isinstance(t, str) for t in task):
# List of strings: format each # List of strings: format each
new_complementary_data["task"] = [ new_complementary_data["task"] = [t if t.endswith(".") else f"{t}." for t in task]
t if t.endswith(".") else f"{t}." for t in task
]
return new_complementary_data return new_complementary_data
File diff suppressed because it is too large Load Diff
+49 -80
View File
@@ -25,7 +25,7 @@ import random
import re import re
from collections import OrderedDict from collections import OrderedDict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any
import torch import torch
from transformers import BatchFeature from transformers import BatchFeature
@@ -46,11 +46,11 @@ class X2RDataProcessingConfig:
""" """
# Action prediction configuration # Action prediction configuration
predict_action_keys: List[str] = field(default_factory=list) predict_action_keys: list[str] = field(default_factory=list)
obs_action_keys: List[str] = field(default_factory=list) obs_action_keys: list[str] = field(default_factory=list)
# Image resolution settings for different views # Image resolution settings for different views
resolution: Dict[str, int] = field( resolution: dict[str, int] = field(
default_factory=lambda: { default_factory=lambda: {
"face_view": -1, "face_view": -1,
"left_wrist_view": 128, "left_wrist_view": 128,
@@ -63,7 +63,7 @@ class X2RDataProcessingConfig:
split_seed: int = 42 split_seed: int = 42
# Instruction handling # Instruction handling
priority_order: Optional[Dict[str, float]] = None priority_order: dict[str, float] | None = None
# Vision model parameters # Vision model parameters
model_type: str = "qwen2_5" model_type: str = "qwen2_5"
@@ -77,11 +77,9 @@ class X2RDataProcessingConfig:
"""Post-initialization validation and setup.""" """Post-initialization validation and setup."""
# Validate train/test split # Validate train/test split
if not 0 < self.train_test_split < 1: if not 0 < self.train_test_split < 1:
raise ValueError( raise ValueError(f"train_test_split must be between 0 and 1, got {self.train_test_split}")
f"train_test_split must be between 0 and 1, got {self.train_test_split}"
)
def as_dict(self) -> Dict: def as_dict(self) -> dict:
"""Convert configuration to dictionary format. """Convert configuration to dictionary format.
Returns: Returns:
@@ -105,14 +103,15 @@ class X2RDataProcessingConfig:
raise ValueError(f"Unknown configuration parameter: {key}") raise ValueError(f"Unknown configuration parameter: {key}")
return self return self
def preprocesser_call( def preprocesser_call(
processor, processor,
images: Optional[Union[List, Any]] = None, images: list | Any | None = None,
text: Optional[Union[str, List[str]]] = None, text: str | list[str] | None = None,
videos: Optional[Union[List, Any]] = None, videos: list | Any | None = None,
padding: Union[bool, str] = False, padding: bool | str = False,
truncation: Optional[bool] = None, truncation: bool | None = None,
max_length: Optional[int] = None, max_length: int | None = None,
return_tensors: str = "pt", return_tensors: str = "pt",
) -> BatchFeature: ) -> BatchFeature:
"""Unified preprocessing function for Wall-X model handling text, image and video inputs. """Unified preprocessing function for Wall-X model handling text, image and video inputs.
@@ -145,9 +144,7 @@ def preprocesser_call(
""" """
# Process image inputs # Process image inputs
if images is not None and len(images) > 0: if images is not None and len(images) > 0:
image_inputs = processor.image_processor( image_inputs = processor.image_processor(images=images, videos=None, return_tensors=return_tensors)
images=images, videos=None, return_tensors=return_tensors
)
image_grid_thw = image_inputs["image_grid_thw"] image_grid_thw = image_inputs["image_grid_thw"]
else: else:
image_inputs = {} image_inputs = {}
@@ -155,9 +152,7 @@ def preprocesser_call(
# Process video inputs # Process video inputs
if videos is not None: if videos is not None:
videos_inputs = processor.image_processor( videos_inputs = processor.image_processor(images=None, videos=videos, return_tensors=return_tensors)
images=None, videos=videos, return_tensors=return_tensors
)
video_grid_thw = videos_inputs["video_grid_thw"] video_grid_thw = videos_inputs["video_grid_thw"]
else: else:
videos_inputs = {} videos_inputs = {}
@@ -183,9 +178,7 @@ def preprocesser_call(
break break
# Replace image placeholder with actual token count # Replace image placeholder with actual token count
token_count = image_grid_thw[index].prod() // merge_length token_count = image_grid_thw[index].prod() // merge_length
text[i] = text[i].replace( text[i] = text[i].replace("<|image_pad|>", "<|placeholder|>" * token_count, 1)
"<|image_pad|>", "<|placeholder|>" * token_count, 1
)
index += 1 index += 1
text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>") text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")
@@ -197,9 +190,7 @@ def preprocesser_call(
while "<|video_pad|>" in text[i]: while "<|video_pad|>" in text[i]:
# Replace video placeholder with actual token count # Replace video placeholder with actual token count
token_count = video_grid_thw[index].prod() // merge_length token_count = video_grid_thw[index].prod() // merge_length
text[i] = text[i].replace( text[i] = text[i].replace("<|video_pad|>", "<|placeholder|>" * token_count, 1)
"<|video_pad|>", "<|placeholder|>" * token_count, 1
)
index += 1 index += 1
text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>") text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>")
@@ -221,9 +212,7 @@ def preprocesser_call(
labels = torch.full_like(text_inputs.input_ids, -100) labels = torch.full_like(text_inputs.input_ids, -100)
assistant_marker = "<|im_start|>assistant\n" assistant_marker = "<|im_start|>assistant\n"
im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>") im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
assistant_tokens = processor.tokenizer( assistant_tokens = processor.tokenizer("<|im_start|>assistant\n", add_special_tokens=False).input_ids
"<|im_start|>assistant\n", add_special_tokens=False
).input_ids
for i in range(len(text)): for i in range(len(text)):
assistant_regions = [] assistant_regions = []
@@ -249,9 +238,7 @@ def preprocesser_call(
# From second part onwards, each part starts with assistant response # From second part onwards, each part starts with assistant response
for k in range(current_pos + 1, len(text_inputs.input_ids[i])): for k in range(current_pos + 1, len(text_inputs.input_ids[i])):
if text_inputs.input_ids[i][k] == im_end_token_id: if text_inputs.input_ids[i][k] == im_end_token_id:
assistant_regions.append( assistant_regions.append((current_pos + len(assistant_tokens), k + 2))
(current_pos + len(assistant_tokens), k + 2)
)
break break
current_pos += len(part_tokens) + 3 current_pos += len(part_tokens) + 3
@@ -344,7 +331,7 @@ def process_grounding_points(
coords = [new_x1, new_y1, new_x2, new_y2] coords = [new_x1, new_y1, new_x2, new_y2]
# Return processed point tag # Return processed point tag
return f'<point>[{", ".join(map(str, coords))}]</point>' return f"<point>[{', '.join(map(str, coords))}]</point>"
except (ValueError, TypeError): except (ValueError, TypeError):
# Return original content if processing fails # Return original content if processing fails
@@ -356,10 +343,10 @@ def process_grounding_points(
def get_frame_instruction( def get_frame_instruction(
instruction_info: Dict[str, Any], instruction_info: dict[str, Any],
frame_idx: Optional[int] = None, frame_idx: int | None = None,
truncate_keys: Optional[List[str]] = None, truncate_keys: list[str] | None = None,
) -> Tuple[Dict[str, Any], Optional[int]]: ) -> tuple[dict[str, Any], int | None]:
"""Extract frame-specific instruction from instruction dictionary. """Extract frame-specific instruction from instruction dictionary.
Args: Args:
@@ -388,11 +375,7 @@ def get_frame_instruction(
start_frame, end_frame = map(int, frame_range.split(" ")) start_frame, end_frame = map(int, frame_range.split(" "))
if start_frame <= frame_idx < end_frame or (start_frame == frame_idx): if start_frame <= frame_idx < end_frame or (start_frame == frame_idx):
instruction_for_frame[key] = frame_instruction instruction_for_frame[key] = frame_instruction
if ( if truncate_keys is not None and split_end is None and key in truncate_keys:
truncate_keys is not None
and split_end is None
and key in truncate_keys
):
split_end = end_frame + 1 split_end = end_frame + 1
break break
else: else:
@@ -402,7 +385,7 @@ def get_frame_instruction(
def get_task_instruction( def get_task_instruction(
frame_instruction_info: Dict[str, Any], priority_order: Optional[OrderedDict] = None frame_instruction_info: dict[str, Any], priority_order: OrderedDict | None = None
) -> str: ) -> str:
"""Construct task instruction from available instruction fields using priority sampling. """Construct task instruction from available instruction fields using priority sampling.
@@ -450,13 +433,13 @@ def get_task_instruction(
def get_wallx_normal_text( def get_wallx_normal_text(
instruction_info: Dict[str, Any], instruction_info: dict[str, Any],
action_chunk_size: int, action_chunk_size: int,
frame_idx: int, frame_idx: int,
priority_order: Optional[OrderedDict] = None, priority_order: OrderedDict | None = None,
img_keys: Optional[List[str]] = None, img_keys: list[str] | None = None,
generate_subtask_ratio: float = 0.0, generate_subtask_ratio: float = 0.0,
) -> Tuple[str, bool]: ) -> tuple[str, bool]:
"""Construct complete multimodal prompt text for Wall-X model. """Construct complete multimodal prompt text for Wall-X model.
Formats input using special tokens including: Formats input using special tokens including:
@@ -488,9 +471,7 @@ def get_wallx_normal_text(
action_fast_symbol = "<|action_fast|>" action_fast_symbol = "<|action_fast|>"
# System prologue # System prologue
prologue = ( prologue = f"{role_start_symbol}system\nYou are a helpful assistant.{role_end_symbol}\n"
f"{role_start_symbol}system\nYou are a helpful assistant.{role_end_symbol}\n"
)
# User request with observation # User request with observation
user_request = f"{role_start_symbol}user\nObservation:" user_request = f"{role_start_symbol}user\nObservation:"
@@ -501,9 +482,7 @@ def get_wallx_normal_text(
user_request += "\nInstruction:" user_request += "\nInstruction:"
# Get frame-specific instruction # Get frame-specific instruction
frame_instruction_info, _ = get_frame_instruction( frame_instruction_info, _ = get_frame_instruction(instruction_info, frame_idx=frame_idx)
instruction_info, frame_idx=frame_idx
)
generate_subtask = False generate_subtask = False
priority_keys = ["subtask_generation", "distribute"] priority_keys = ["subtask_generation", "distribute"]
@@ -524,15 +503,11 @@ def get_wallx_normal_text(
output_instruction = frame_instruction_info[key] output_instruction = frame_instruction_info[key]
break break
assistant_output = ( assistant_output = f"{role_start_symbol}assistant\n{output_instruction}\n{role_end_symbol}"
f"{role_start_symbol}assistant\n{output_instruction}\n{role_end_symbol}"
)
generate_subtask = True generate_subtask = True
else: else:
# Generate actions # Generate actions
instruction = get_task_instruction( instruction = get_task_instruction(frame_instruction_info, priority_order=priority_order)
frame_instruction_info, priority_order=priority_order
)
text_prompt = f"\nPredict the next action in robot action.\nProprioception: {propri_symbol}\n" text_prompt = f"\nPredict the next action in robot action.\nProprioception: {propri_symbol}\n"
user_message = f"{user_request} {instruction}{text_prompt}{role_end_symbol}\n" user_message = f"{user_request} {instruction}{text_prompt}{role_end_symbol}\n"
assistant_output = f"{role_start_symbol}assistant\n{action_fast_symbol}{role_end_symbol}\n{action_symbol * action_chunk_size}" assistant_output = f"{role_start_symbol}assistant\n{action_fast_symbol}{role_end_symbol}\n{action_symbol * action_chunk_size}"
@@ -540,7 +515,8 @@ def get_wallx_normal_text(
complete_text = prologue + user_message + assistant_output complete_text = prologue + user_message + assistant_output
return complete_text, generate_subtask return complete_text, generate_subtask
def img_key_mapping(img_keys: List[str]) -> List[str]:
def img_key_mapping(img_keys: list[str]) -> list[str]:
"""Map image keys to camera names. """Map image keys to camera names.
Args: Args:
@@ -555,16 +531,15 @@ def img_key_mapping(img_keys: List[str]) -> List[str]:
if key in CAMERA_NAME_MAPPING: if key in CAMERA_NAME_MAPPING:
key = CAMERA_NAME_MAPPING[key] key = CAMERA_NAME_MAPPING[key]
else: else:
if 'view' in key: if "view" in key:
key = key.replace('_', ' ') key = key.replace("_", " ")
else: else:
key = key + " view" key = key + " view"
processed_img_keys.append(key) processed_img_keys.append(key)
return processed_img_keys return processed_img_keys
def get_action_tokens(
normalized_actions: Union[torch.Tensor, List], action_tokenizer def get_action_tokens(normalized_actions: torch.Tensor | list, action_tokenizer) -> list[list[str]]:
) -> List[List[str]]:
"""Convert normalized actions to action token strings. """Convert normalized actions to action token strings.
Args: Args:
@@ -590,8 +565,8 @@ def get_action_tokens(
def pad_action_token_strs( def pad_action_token_strs(
actions_token_lists: List[List[str]], pad_token: str = "<|endoftext|>" actions_token_lists: list[list[str]], pad_token: str = "<|endoftext|>"
) -> List[str]: ) -> list[str]:
"""Pad action token lists to same length and join as strings. """Pad action token lists to same length and join as strings.
Args: Args:
@@ -605,20 +580,18 @@ def pad_action_token_strs(
padded_action_strs = [] padded_action_strs = []
for tokens in actions_token_lists: for tokens in actions_token_lists:
padded_tokens = ( padded_tokens = tokens + ["<|im_end|>\n"] + [pad_token] * (max_len - len(tokens))
tokens + ["<|im_end|>\n"] + [pad_token] * (max_len - len(tokens))
)
padded_action_strs.append("".join(padded_tokens)) padded_action_strs.append("".join(padded_tokens))
return padded_action_strs return padded_action_strs
def replace_action_token( def replace_action_token(
text: List[str], text: list[str],
norm_action: Optional[torch.Tensor], norm_action: torch.Tensor | None,
action_tokenizer, action_tokenizer,
dof_masks: Optional[torch.Tensor] = None, dof_masks: torch.Tensor | None = None,
) -> List[str]: ) -> list[str]:
"""Replace action placeholders in text with actual action tokens. """Replace action placeholders in text with actual action tokens.
Args: Args:
@@ -632,10 +605,7 @@ def replace_action_token(
""" """
if action_tokenizer is not None and norm_action is not None: if action_tokenizer is not None and norm_action is not None:
# Extract actions based on chunk sizes and DOF masks # Extract actions based on chunk sizes and DOF masks
norm_action = [ norm_action = [action[:32, dof_masks[i, 0].bool()] for i, action in enumerate(norm_action)]
action[: 32, dof_masks[i, 0].bool()]
for i, action in enumerate(norm_action)
]
# Convert to action tokens and pad # Convert to action tokens and pad
actions_fast_tokens = get_action_tokens(norm_action, action_tokenizer) actions_fast_tokens = get_action_tokens(norm_action, action_tokenizer)
@@ -658,4 +628,3 @@ def replace_action_token(
text = [t.replace("<|action_fast|><|im_end|>\n", "") for t in text] text = [t.replace("<|action_fast|><|im_end|>\n", "") for t in text]
return text return text
+4 -1
View File
@@ -35,10 +35,11 @@ from lerobot.policies.wall_x import ( # noqa: E402
) )
from lerobot.utils.random_utils import set_seed # noqa: E402 from lerobot.utils.random_utils import set_seed # noqa: E402
def test_policy_instantiation(): def test_policy_instantiation():
# Create config # Create config
set_seed(42) set_seed(42)
config = WallXConfig(device='cuda') config = WallXConfig(device="cuda")
# Set up input_features and output_features in the config # Set up input_features and output_features in the config
from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.configs.types import FeatureType, PolicyFeature
@@ -118,6 +119,7 @@ def test_policy_instantiation():
print(f"Action prediction failed: {e}") print(f"Action prediction failed: {e}")
raise raise
def test_config_creation(): def test_config_creation():
"""Test policy config creation through factory.""" """Test policy config creation through factory."""
try: try:
@@ -130,6 +132,7 @@ def test_config_creation():
print(f"Config creation failed: {e}") print(f"Config creation failed: {e}")
raise raise
if __name__ == "__main__": if __name__ == "__main__":
test_policy_instantiation() test_policy_instantiation()
test_config_creation() test_config_creation()