mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
fix pre-commit errors
This commit is contained in:
committed by
Michel Aractingi
parent
9ce6dd9e25
commit
a0c9a7d85d
@@ -7,11 +7,12 @@ This repository contains the Hugging Face port of **WALL-OSS**, a Vision-Languag
|
|||||||
## Model Overview
|
## Model Overview
|
||||||
|
|
||||||
| Feature | Description |
|
| Feature | Description |
|
||||||
| -------------------- | ------------------------------------------------------------------------ |
|
| ------------------ | ----------------------------------------------------- |
|
||||||
| Base Model | Qwen2.5-VL (Vision-Language Model) |
|
| Base Model | Qwen2.5-VL (Vision-Language Model) |
|
||||||
| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) |
|
| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) |
|
||||||
| Architecture | Mixture of Experts (MoE) with action-specific routing | |
|
| Architecture | Mixture of Experts (MoE) with action-specific routing |
|
||||||
| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception |
|
| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
@@ -32,4 +33,3 @@ If you use this work, please cite:
|
|||||||
## License
|
## License
|
||||||
|
|
||||||
This port follows the **Apache 2.0 License**.
|
This port follows the **Apache 2.0 License**.
|
||||||
|
|
||||||
|
|||||||
@@ -85,9 +85,7 @@ class WallXConfig(PreTrainedConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.prediction_mode not in ["diffusion", "fast"]:
|
if self.prediction_mode not in ["diffusion", "fast"]:
|
||||||
raise ValueError(
|
raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}")
|
||||||
f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assign use_fast_tokenizer based on prediction_mode
|
# Assign use_fast_tokenizer based on prediction_mode
|
||||||
if self.prediction_mode == "fast":
|
if self.prediction_mode == "fast":
|
||||||
@@ -96,9 +94,7 @@ class WallXConfig(PreTrainedConfig):
|
|||||||
self.use_fast_tokenizer = False
|
self.use_fast_tokenizer = False
|
||||||
self.action_tokenizer_path = None # disable action tokenizer for diffusion mode
|
self.action_tokenizer_path = None # disable action tokenizer for diffusion mode
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}")
|
||||||
f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def validate_features(self) -> None:
|
def validate_features(self) -> None:
|
||||||
"""Validate and set up input/output features."""
|
"""Validate and set up input/output features."""
|
||||||
|
|||||||
@@ -34,61 +34,57 @@ lerobot-train \
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from os import PathLike
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from os import PathLike
|
||||||
from PIL import Image
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from peft import LoraConfig, get_peft_model
|
from peft import LoraConfig, get_peft_model
|
||||||
|
from PIL import Image
|
||||||
|
from qwen_vl_utils.vision_process import smart_resize
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.distributions import Beta
|
from torch.distributions import Beta
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
from torchdiffeq import odeint
|
from torchdiffeq import odeint
|
||||||
from transformers import AutoProcessor
|
from transformers import AutoProcessor, BatchFeature
|
||||||
from transformers.cache_utils import (
|
from transformers.cache_utils import (
|
||||||
StaticCache,
|
StaticCache,
|
||||||
)
|
)
|
||||||
|
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
||||||
|
Qwen2_5_VLForConditionalGeneration,
|
||||||
|
)
|
||||||
from transformers.utils import is_torchdynamo_compiling, logging
|
from transformers.utils import is_torchdynamo_compiling, logging
|
||||||
from transformers import AutoProcessor, BatchFeature
|
|
||||||
from qwen_vl_utils.vision_process import smart_resize
|
|
||||||
|
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.utils import populate_queues
|
from lerobot.policies.utils import populate_queues
|
||||||
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
|
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
|
||||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
|
||||||
|
|
||||||
from lerobot.policies.wall_x.utils import (
|
|
||||||
replace_action_token,
|
|
||||||
preprocesser_call,
|
|
||||||
get_wallx_normal_text,
|
|
||||||
process_grounding_points,
|
|
||||||
)
|
|
||||||
from lerobot.policies.wall_x.constant import (
|
from lerobot.policies.wall_x.constant import (
|
||||||
MODEL_TYPE,
|
|
||||||
TOKENIZER_MAX_LENGTH,
|
|
||||||
PRIORITY_ORDER,
|
|
||||||
GENERATE_SUBTASK_RATIO,
|
GENERATE_SUBTASK_RATIO,
|
||||||
RESOLUTION,
|
IMAGE_FACTOR,
|
||||||
MAX_PIXELS,
|
MAX_PIXELS,
|
||||||
MIN_PIXELS,
|
MIN_PIXELS,
|
||||||
IMAGE_FACTOR,
|
MODEL_TYPE,
|
||||||
|
PRIORITY_ORDER,
|
||||||
|
RESOLUTION,
|
||||||
|
TOKENIZER_MAX_LENGTH,
|
||||||
)
|
)
|
||||||
from lerobot.policies.wall_x.qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig
|
from lerobot.policies.wall_x.qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig
|
||||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
|
||||||
Qwen2_5_VLForConditionalGeneration,
|
|
||||||
)
|
|
||||||
|
|
||||||
from lerobot.policies.wall_x.qwen_model.qwen2_5_vl_moe import (
|
from lerobot.policies.wall_x.qwen_model.qwen2_5_vl_moe import (
|
||||||
Qwen2_5_VisionTransformerPretrainedModel,
|
Qwen2_5_VisionTransformerPretrainedModel,
|
||||||
Qwen2_5_VLACausalLMOutputWithPast,
|
Qwen2_5_VLACausalLMOutputWithPast,
|
||||||
Qwen2_5_VLMoEModel,
|
Qwen2_5_VLMoEModel,
|
||||||
)
|
)
|
||||||
|
from lerobot.policies.wall_x.utils import (
|
||||||
|
get_wallx_normal_text,
|
||||||
|
preprocesser_call,
|
||||||
|
process_grounding_points,
|
||||||
|
replace_action_token,
|
||||||
|
)
|
||||||
|
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
@@ -151,7 +147,7 @@ class ActionHead(nn.Module):
|
|||||||
"""Sample timesteps using Beta distribution (always in float32 for numerical stability)."""
|
"""Sample timesteps using Beta distribution (always in float32 for numerical stability)."""
|
||||||
beta_dist = Beta(
|
beta_dist = Beta(
|
||||||
torch.tensor(self.beta_alpha, dtype=torch.float32, device=device),
|
torch.tensor(self.beta_alpha, dtype=torch.float32, device=device),
|
||||||
torch.tensor(self.beta_beta, dtype=torch.float32, device=device)
|
torch.tensor(self.beta_beta, dtype=torch.float32, device=device),
|
||||||
)
|
)
|
||||||
sample = beta_dist.sample([batch_size])
|
sample = beta_dist.sample([batch_size])
|
||||||
time = (1 - sample) * self.s
|
time = (1 - sample) * self.s
|
||||||
@@ -275,14 +271,14 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
pretrained_name_or_path,
|
pretrained_name_or_path,
|
||||||
config=None,
|
config=None,
|
||||||
action_tokenizer_path=None,
|
action_tokenizer_path=None,
|
||||||
attn_implementation: str = 'eager',
|
attn_implementation: str = "eager",
|
||||||
cache_dir: str | PathLike | None = None,
|
cache_dir: str | PathLike | None = None,
|
||||||
force_download: bool = False,
|
force_download: bool = False,
|
||||||
local_files_only: bool = False,
|
local_files_only: bool = False,
|
||||||
token: str | bool | None = None,
|
token: str | bool | None = None,
|
||||||
revision: str = "main",
|
revision: str = "main",
|
||||||
strict: bool = False,
|
strict: bool = False,
|
||||||
**kwargs: Any
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Load model from pretrained model path.
|
Load model from pretrained model path.
|
||||||
@@ -312,9 +308,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
config._attn_implementation = attn_implementation
|
config._attn_implementation = attn_implementation
|
||||||
processor = AutoProcessor.from_pretrained(pretrained_name_or_path, use_fast=True)
|
processor = AutoProcessor.from_pretrained(pretrained_name_or_path, use_fast=True)
|
||||||
if action_tokenizer_path is not None:
|
if action_tokenizer_path is not None:
|
||||||
action_tokenizer = AutoProcessor.from_pretrained(
|
action_tokenizer = AutoProcessor.from_pretrained(action_tokenizer_path, trust_remote_code=True)
|
||||||
action_tokenizer_path, trust_remote_code=True
|
|
||||||
)
|
|
||||||
processor.action_processor = action_tokenizer
|
processor.action_processor = action_tokenizer
|
||||||
else:
|
else:
|
||||||
action_tokenizer = None
|
action_tokenizer = None
|
||||||
@@ -387,9 +381,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
# Initialize vision transformer and language model components
|
# Initialize vision transformer and language model components
|
||||||
self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(
|
self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)
|
||||||
config.vision_config
|
|
||||||
)
|
|
||||||
self.model = Qwen2_5_VLMoEModel(config)
|
self.model = Qwen2_5_VLMoEModel(config)
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
@@ -446,12 +438,8 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
# Create list of fast action token IDs
|
# Create list of fast action token IDs
|
||||||
fast_action_token_list = []
|
fast_action_token_list = []
|
||||||
if self.use_fast_tokenizer:
|
if self.use_fast_tokenizer:
|
||||||
for i in range(
|
for i in range(self.processor.tokenizer.init_kwargs["action_token_vocab_size"]):
|
||||||
self.processor.tokenizer.init_kwargs["action_token_vocab_size"]
|
action_token_id = self.processor.tokenizer.convert_tokens_to_ids(f"<|action_token_{i}|>")
|
||||||
):
|
|
||||||
action_token_id = self.processor.tokenizer.convert_tokens_to_ids(
|
|
||||||
f"<|action_token_{i}|>"
|
|
||||||
)
|
|
||||||
fast_action_token_list.append(action_token_id)
|
fast_action_token_list.append(action_token_id)
|
||||||
|
|
||||||
# Get special action token IDs
|
# Get special action token IDs
|
||||||
@@ -465,9 +453,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
"action_token_id": action_token_id,
|
"action_token_id": action_token_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
def add_lora(
|
def add_lora(self, r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.1):
|
||||||
self, r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.1
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Add LoRA (Low-Rank Adaptation) adapters to the model.
|
Add LoRA (Low-Rank Adaptation) adapters to the model.
|
||||||
|
|
||||||
@@ -516,12 +502,12 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
|
|
||||||
def get_rope_index(
|
def get_rope_index(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: torch.LongTensor | None = None,
|
||||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
image_grid_thw: torch.LongTensor | None = None,
|
||||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
video_grid_thw: torch.LongTensor | None = None,
|
||||||
second_per_grid_ts: Optional[torch.Tensor] = None,
|
second_per_grid_ts: torch.Tensor | None = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: torch.Tensor | None = None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Calculate 3D RoPE (Rotary Position Embedding) indices for vision and text tokens.
|
Calculate 3D RoPE (Rotary Position Embedding) indices for vision and text tokens.
|
||||||
|
|
||||||
@@ -555,9 +541,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
vision_start_token_id = self.config.vision_start_token_id
|
vision_start_token_id = self.config.vision_start_token_id
|
||||||
mrope_position_deltas = []
|
mrope_position_deltas = []
|
||||||
|
|
||||||
if input_ids is not None and (
|
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
|
||||||
image_grid_thw is not None or video_grid_thw is not None
|
|
||||||
):
|
|
||||||
total_input_ids = input_ids
|
total_input_ids = input_ids
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones_like(total_input_ids)
|
attention_mask = torch.ones_like(total_input_ids)
|
||||||
@@ -580,9 +564,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
image_nums, video_nums = 0, 0
|
image_nums, video_nums = 0, 0
|
||||||
|
|
||||||
# Find vision tokens and count images/videos
|
# Find vision tokens and count images/videos
|
||||||
vision_start_indices = torch.argwhere(
|
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
|
||||||
input_ids == vision_start_token_id
|
|
||||||
).squeeze(1)
|
|
||||||
vision_tokens = input_ids[vision_start_indices + 1]
|
vision_tokens = input_ids[vision_start_indices + 1]
|
||||||
image_nums = (vision_tokens == image_token_id).sum()
|
image_nums = (vision_tokens == image_token_id).sum()
|
||||||
video_nums = (vision_tokens == video_token_id).sum()
|
video_nums = (vision_tokens == video_token_id).sum()
|
||||||
@@ -641,14 +623,8 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
text_len = ed - st
|
text_len = ed - st
|
||||||
|
|
||||||
# Add position IDs for text tokens before vision token
|
# Add position IDs for text tokens before vision token
|
||||||
st_idx = (
|
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||||
llm_pos_ids_list[-1].max() + 1
|
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
||||||
if len(llm_pos_ids_list) > 0
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
llm_pos_ids_list.append(
|
|
||||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
|
||||||
)
|
|
||||||
|
|
||||||
# Calculate 3D position embeddings for vision tokens
|
# Calculate 3D position embeddings for vision tokens
|
||||||
range_tensor = torch.arange(llm_grid_t).view(-1, 1)
|
range_tensor = torch.arange(llm_grid_t).view(-1, 1)
|
||||||
@@ -656,71 +632,43 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
|
|
||||||
# Calculate temporal position IDs with time scaling
|
# Calculate temporal position IDs with time scaling
|
||||||
time_tensor = (
|
time_tensor = (
|
||||||
expanded_range
|
expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second
|
||||||
* second_per_grid_t
|
|
||||||
* self.config.vision_config.tokens_per_second
|
|
||||||
)
|
)
|
||||||
time_tensor_long = time_tensor.long()
|
time_tensor_long = time_tensor.long()
|
||||||
t_index = time_tensor_long.flatten()
|
t_index = time_tensor_long.flatten()
|
||||||
|
|
||||||
# Calculate spatial position IDs
|
# Calculate spatial position IDs
|
||||||
h_index = (
|
h_index = (
|
||||||
torch.arange(llm_grid_h)
|
torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
|
||||||
.view(1, -1, 1)
|
|
||||||
.expand(llm_grid_t, -1, llm_grid_w)
|
|
||||||
.flatten()
|
|
||||||
)
|
)
|
||||||
w_index = (
|
w_index = (
|
||||||
torch.arange(llm_grid_w)
|
torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
|
||||||
.view(1, 1, -1)
|
|
||||||
.expand(llm_grid_t, llm_grid_h, -1)
|
|
||||||
.flatten()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add 3D position IDs for vision tokens
|
# Add 3D position IDs for vision tokens
|
||||||
llm_pos_ids_list.append(
|
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
|
||||||
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
|
|
||||||
)
|
|
||||||
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
||||||
|
|
||||||
# Add position IDs for remaining text tokens
|
# Add position IDs for remaining text tokens
|
||||||
if st < len(input_tokens):
|
if st < len(input_tokens):
|
||||||
st_idx = (
|
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||||
llm_pos_ids_list[-1].max() + 1
|
|
||||||
if len(llm_pos_ids_list) > 0
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
text_len = len(input_tokens) - st
|
text_len = len(input_tokens) - st
|
||||||
llm_pos_ids_list.append(
|
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
||||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
|
||||||
)
|
|
||||||
|
|
||||||
# Concatenate all position IDs for this sequence
|
# Concatenate all position IDs for this sequence
|
||||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||||
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(
|
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
|
||||||
position_ids.device
|
mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
|
||||||
)
|
|
||||||
mrope_position_deltas.append(
|
|
||||||
llm_positions.max() + 1 - len(total_input_ids[i])
|
|
||||||
)
|
|
||||||
|
|
||||||
mrope_position_deltas = torch.tensor(
|
mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
|
||||||
mrope_position_deltas, device=input_ids.device
|
|
||||||
).unsqueeze(1)
|
|
||||||
return position_ids, mrope_position_deltas
|
return position_ids, mrope_position_deltas
|
||||||
else:
|
else:
|
||||||
# Handle case without vision tokens - use standard 1D position embeddings
|
# Handle case without vision tokens - use standard 1D position embeddings
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||||
position_ids = (
|
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
|
||||||
position_ids.unsqueeze(0)
|
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
|
||||||
.expand(3, -1, -1)
|
|
||||||
.to(attention_mask.device)
|
|
||||||
)
|
|
||||||
max_position_ids = position_ids.max(0, keepdim=False)[0].max(
|
|
||||||
-1, keepdim=True
|
|
||||||
)[0]
|
|
||||||
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
|
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
|
||||||
else:
|
else:
|
||||||
position_ids = (
|
position_ids = (
|
||||||
@@ -739,33 +687,29 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
def train_step_forward(
|
def train_step_forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: torch.Tensor | None = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: torch.LongTensor | None = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: list[torch.FloatTensor] | None = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: torch.FloatTensor | None = None,
|
||||||
moe_token_types: Optional[
|
moe_token_types: torch.LongTensor | None = None, # MoE token type assignments
|
||||||
torch.LongTensor
|
labels: torch.LongTensor | None = None,
|
||||||
] = None, # MoE token type assignments
|
use_cache: bool | None = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
output_attentions: bool | None = None,
|
||||||
use_cache: Optional[bool] = None,
|
output_hidden_states: bool | None = None,
|
||||||
output_attentions: Optional[bool] = None,
|
return_dict: bool | None = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
pixel_values: torch.Tensor | None = None,
|
||||||
return_dict: Optional[bool] = None,
|
pixel_values_videos: torch.FloatTensor | None = None,
|
||||||
pixel_values: Optional[torch.Tensor] = None,
|
image_grid_thw: torch.LongTensor | None = None,
|
||||||
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
video_grid_thw: torch.LongTensor | None = None,
|
||||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
action_chunk: torch.FloatTensor | None = None, # Action trajectory chunks
|
||||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
proprioception: torch.FloatTensor | None = None, # Joint position/orientation data
|
||||||
action_chunk: Optional[torch.FloatTensor] = None, # Action trajectory chunks
|
rope_deltas: torch.LongTensor | None = None,
|
||||||
proprioception: Optional[
|
cache_position: torch.LongTensor | None = None,
|
||||||
torch.FloatTensor
|
second_per_grid_ts: torch.Tensor | None = None,
|
||||||
] = None, # Joint position/orientation data
|
dof_mask: torch.FloatTensor | None = None,
|
||||||
rope_deltas: Optional[torch.LongTensor] = None,
|
agent_pos_mask: torch.FloatTensor | None = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
second_per_grid_ts: Optional[torch.Tensor] = None,
|
|
||||||
dof_mask: Optional[torch.FloatTensor] = None,
|
|
||||||
agent_pos_mask: Optional[torch.FloatTensor] = None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[Tuple, Qwen2_5_VLACausalLMOutputWithPast]:
|
) -> tuple | Qwen2_5_VLACausalLMOutputWithPast:
|
||||||
"""
|
"""
|
||||||
Forward pass for training with multi-modal inputs including vision, text, and action data.
|
Forward pass for training with multi-modal inputs including vision, text, and action data.
|
||||||
|
|
||||||
@@ -806,24 +750,16 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
|
|
||||||
# Set output configuration from model config if not specified
|
# Set output configuration from model config if not specified
|
||||||
output_attentions = (
|
output_attentions = (
|
||||||
output_attentions
|
output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
)
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
# Calculate RoPE position IDs if not provided
|
# Calculate RoPE position IDs if not provided
|
||||||
# Note: Cannot calculate rope deltas with 4D attention mask. TODO: Fix this limitation
|
# Note: Cannot calculate rope deltas with 4D attention mask. TODO: Fix this limitation
|
||||||
if position_ids is None and (
|
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
||||||
attention_mask is None or attention_mask.ndim == 2
|
|
||||||
):
|
|
||||||
# Calculate RoPE index once per generation in the pre-fill stage only
|
# Calculate RoPE index once per generation in the pre-fill stage only
|
||||||
if (
|
if (
|
||||||
(cache_position is not None and cache_position[0] == 0)
|
(cache_position is not None and cache_position[0] == 0)
|
||||||
@@ -865,9 +801,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
||||||
image_mask = mask_expanded.to(inputs_embeds.device)
|
image_mask = mask_expanded.to(inputs_embeds.device)
|
||||||
|
|
||||||
image_embeds = image_embeds.to(
|
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
inputs_embeds.device, inputs_embeds.dtype
|
|
||||||
)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||||
|
|
||||||
# Process video embeddings
|
# Process video embeddings
|
||||||
@@ -887,19 +821,13 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
||||||
video_mask = mask_expanded.to(inputs_embeds.device)
|
video_mask = mask_expanded.to(inputs_embeds.device)
|
||||||
|
|
||||||
video_embeds = video_embeds.to(
|
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
inputs_embeds.device, inputs_embeds.dtype
|
|
||||||
)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||||
|
|
||||||
# Process proprioceptive data (joint positions, orientations, etc.)
|
# Process proprioceptive data (joint positions, orientations, etc.)
|
||||||
if proprioception is not None:
|
if proprioception is not None:
|
||||||
proprioception = proprioception.to(inputs_embeds.device).to(
|
proprioception = proprioception.to(inputs_embeds.device).to(inputs_embeds.dtype)
|
||||||
inputs_embeds.dtype
|
agent_pos_mask = agent_pos_mask.to(inputs_embeds.device).to(inputs_embeds.dtype)
|
||||||
)
|
|
||||||
agent_pos_mask = agent_pos_mask.to(inputs_embeds.device).to(
|
|
||||||
inputs_embeds.dtype
|
|
||||||
)
|
|
||||||
proprioception = self.action_preprocessor.proprioception_proj(
|
proprioception = self.action_preprocessor.proprioception_proj(
|
||||||
proprioception,
|
proprioception,
|
||||||
agent_pos_mask,
|
agent_pos_mask,
|
||||||
@@ -910,12 +838,8 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
||||||
proprioception_mask = mask_expanded.to(inputs_embeds.device)
|
proprioception_mask = mask_expanded.to(inputs_embeds.device)
|
||||||
|
|
||||||
proprioception = proprioception.to(
|
proprioception = proprioception.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
inputs_embeds.device, inputs_embeds.dtype
|
inputs_embeds = inputs_embeds.masked_scatter(proprioception_mask, proprioception)
|
||||||
)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(
|
|
||||||
proprioception_mask, proprioception
|
|
||||||
)
|
|
||||||
elif self.training:
|
elif self.training:
|
||||||
# Dummy forward pass to ensure gradient registration in DDP
|
# Dummy forward pass to ensure gradient registration in DDP
|
||||||
# This handles cases where one process has proprioception data while another doesn't
|
# This handles cases where one process has proprioception data while another doesn't
|
||||||
@@ -925,32 +849,22 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
self.action_preprocessor.propri_dim * 2,
|
self.action_preprocessor.propri_dim * 2,
|
||||||
device=inputs_embeds.device,
|
device=inputs_embeds.device,
|
||||||
)
|
)
|
||||||
dummy_forward = self.action_preprocessor.proprioception_proj(
|
dummy_forward = self.action_preprocessor.proprioception_proj(dummy_input)
|
||||||
dummy_input
|
|
||||||
)
|
|
||||||
dummy_loss = sum(p.sum() for p in dummy_forward)
|
dummy_loss = sum(p.sum() for p in dummy_forward)
|
||||||
inputs_embeds = inputs_embeds + 0 * dummy_loss
|
inputs_embeds = inputs_embeds + 0 * dummy_loss
|
||||||
|
|
||||||
# Process action chunk data
|
# Process action chunk data
|
||||||
if action_chunk is not None:
|
if action_chunk is not None:
|
||||||
action_chunk = action_chunk.to(inputs_embeds.device).to(
|
action_chunk = action_chunk.to(inputs_embeds.device).to(inputs_embeds.dtype)
|
||||||
inputs_embeds.dtype
|
|
||||||
)
|
|
||||||
dof_mask = dof_mask.to(inputs_embeds.device).to(inputs_embeds.dtype)
|
dof_mask = dof_mask.to(inputs_embeds.device).to(inputs_embeds.dtype)
|
||||||
noisy_action_emb, flow = self.action_preprocessor(
|
noisy_action_emb, flow = self.action_preprocessor(action_chunk, dof_mask)
|
||||||
action_chunk, dof_mask
|
|
||||||
)
|
|
||||||
mask = input_ids == self.action_token_id_set["action_token_id"]
|
mask = input_ids == self.action_token_id_set["action_token_id"]
|
||||||
mask_unsqueezed = mask.unsqueeze(-1)
|
mask_unsqueezed = mask.unsqueeze(-1)
|
||||||
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
||||||
action_mask = mask_expanded.to(inputs_embeds.device)
|
action_mask = mask_expanded.to(inputs_embeds.device)
|
||||||
|
|
||||||
noisy_action_emb = noisy_action_emb.to(
|
noisy_action_emb = noisy_action_emb.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
inputs_embeds.device, inputs_embeds.dtype
|
inputs_embeds = inputs_embeds.masked_scatter(action_mask, noisy_action_emb)
|
||||||
)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(
|
|
||||||
action_mask, noisy_action_emb
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||||
@@ -1011,18 +925,14 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
if action_mask.any():
|
if action_mask.any():
|
||||||
action_hidden_states = hidden_states[action_mask].to(torch.float32)
|
action_hidden_states = hidden_states[action_mask].to(torch.float32)
|
||||||
flow = flow.reshape(-1, flow.shape[-1]).to(torch.float32)
|
flow = flow.reshape(-1, flow.shape[-1]).to(torch.float32)
|
||||||
_flow_loss = self.action_preprocessor.flow_loss(
|
_flow_loss = self.action_preprocessor.flow_loss(action_hidden_states, flow, dof_mask)
|
||||||
action_hidden_states, flow, dof_mask
|
|
||||||
)
|
|
||||||
if isinstance(_flow_loss, torch.Tensor):
|
if isinstance(_flow_loss, torch.Tensor):
|
||||||
flow_loss = _flow_loss.mean()
|
flow_loss = _flow_loss.mean()
|
||||||
if loss is not None:
|
if loss is not None:
|
||||||
loss = loss + self.flow_loss_weight * flow_loss.to(torch.float32)
|
loss = loss + self.flow_loss_weight * flow_loss.to(torch.float32)
|
||||||
else:
|
else:
|
||||||
loss = self.flow_loss_weight * flow_loss.to(torch.float32)
|
loss = self.flow_loss_weight * flow_loss.to(torch.float32)
|
||||||
_flow_loss = _flow_loss.view(
|
_flow_loss = _flow_loss.view(dof_mask.shape[0], dof_mask.shape[1], dof_mask.shape[2])
|
||||||
dof_mask.shape[0], dof_mask.shape[1], dof_mask.shape[2]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return outputs based on return_dict setting
|
# Return outputs based on return_dict setting
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
@@ -1031,9 +941,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
|
|
||||||
return Qwen2_5_VLACausalLMOutputWithPast(
|
return Qwen2_5_VLACausalLMOutputWithPast(
|
||||||
loss=loss,
|
loss=loss,
|
||||||
cross_entropy_loss=(
|
cross_entropy_loss=(cross_entropy_loss.clone() if cross_entropy_loss is not None else None),
|
||||||
cross_entropy_loss.clone() if cross_entropy_loss is not None else None
|
|
||||||
),
|
|
||||||
flow_loss=flow_loss,
|
flow_loss=flow_loss,
|
||||||
logits=logits,
|
logits=logits,
|
||||||
past_key_values=outputs.past_key_values,
|
past_key_values=outputs.past_key_values,
|
||||||
@@ -1065,31 +973,31 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
def predict(
|
def predict(
|
||||||
self,
|
self,
|
||||||
predict_mode: str,
|
predict_mode: str,
|
||||||
pred_horizon: Optional[int] = None,
|
pred_horizon: int | None = None,
|
||||||
action_dim: Optional[int] = None,
|
action_dim: int | None = None,
|
||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: torch.Tensor | None = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: torch.LongTensor | None = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: list[torch.FloatTensor] | None = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: torch.FloatTensor | None = None,
|
||||||
moe_token_types: Optional[torch.LongTensor] = None,
|
moe_token_types: torch.LongTensor | None = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: torch.LongTensor | None = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: bool | None = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: bool | None = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: bool | None = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: bool | None = None,
|
||||||
pixel_values: Optional[torch.Tensor] = None,
|
pixel_values: torch.Tensor | None = None,
|
||||||
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
pixel_values_videos: torch.FloatTensor | None = None,
|
||||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
image_grid_thw: torch.LongTensor | None = None,
|
||||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
video_grid_thw: torch.LongTensor | None = None,
|
||||||
action_chunk: Optional[torch.FloatTensor] = None,
|
action_chunk: torch.FloatTensor | None = None,
|
||||||
proprioception: Optional[torch.FloatTensor] = None,
|
proprioception: torch.FloatTensor | None = None,
|
||||||
rope_deltas: Optional[torch.LongTensor] = None,
|
rope_deltas: torch.LongTensor | None = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: torch.LongTensor | None = None,
|
||||||
second_per_grid_ts: Optional[torch.Tensor] = None,
|
second_per_grid_ts: torch.Tensor | None = None,
|
||||||
num_inference_timesteps: Optional[int] = 10,
|
num_inference_timesteps: int | None = 10,
|
||||||
dof_mask: Optional[torch.FloatTensor] = None,
|
dof_mask: torch.FloatTensor | None = None,
|
||||||
agent_pos_mask: Optional[torch.FloatTensor] = None,
|
agent_pos_mask: torch.FloatTensor | None = None,
|
||||||
re_generate: bool = False,
|
re_generate: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@@ -1139,30 +1047,20 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
- 'predict_output_text': Generated text (for text/fast modes)
|
- 'predict_output_text': Generated text (for text/fast modes)
|
||||||
- 'gt_output_text': Ground truth text (for text/fast modes)
|
- 'gt_output_text': Ground truth text (for text/fast modes)
|
||||||
"""
|
"""
|
||||||
batch_size = (
|
batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
|
||||||
input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Text and fast modes require batch size 1 for autoregressive generation
|
# Text and fast modes require batch size 1 for autoregressive generation
|
||||||
if predict_mode in ["text", "fast"]:
|
if predict_mode in ["text", "fast"]:
|
||||||
assert (
|
assert batch_size == 1, "predict only support batch size 1 for ar generation"
|
||||||
batch_size == 1
|
|
||||||
), "predict only support batch size 1 for ar generation"
|
|
||||||
|
|
||||||
# Set output configuration from model config if not specified
|
# Set output configuration from model config if not specified
|
||||||
output_attentions = (
|
output_attentions = (
|
||||||
output_attentions
|
output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
)
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
# Process input embeddings with multi-modal data
|
# Process input embeddings with multi-modal data
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
@@ -1186,9 +1084,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
||||||
image_mask = mask_expanded.to(inputs_embeds.device)
|
image_mask = mask_expanded.to(inputs_embeds.device)
|
||||||
|
|
||||||
image_embeds = image_embeds.to(
|
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
inputs_embeds.device, inputs_embeds.dtype
|
|
||||||
)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||||
|
|
||||||
# Process video embeddings
|
# Process video embeddings
|
||||||
@@ -1209,40 +1105,28 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
||||||
video_mask = mask_expanded.to(inputs_embeds.device)
|
video_mask = mask_expanded.to(inputs_embeds.device)
|
||||||
|
|
||||||
video_embeds = video_embeds.to(
|
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
inputs_embeds.device, inputs_embeds.dtype
|
|
||||||
)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||||
|
|
||||||
# Process proprioceptive data
|
# Process proprioceptive data
|
||||||
if proprioception is not None:
|
if proprioception is not None:
|
||||||
proprioception = proprioception.to(inputs_embeds.device).to(
|
proprioception = proprioception.to(inputs_embeds.device).to(inputs_embeds.dtype)
|
||||||
inputs_embeds.dtype
|
agent_pos_mask = agent_pos_mask.to(inputs_embeds.device).to(inputs_embeds.dtype)
|
||||||
)
|
|
||||||
agent_pos_mask = agent_pos_mask.to(inputs_embeds.device).to(
|
|
||||||
inputs_embeds.dtype
|
|
||||||
)
|
|
||||||
proprio_embed = self.action_preprocessor.proprioception_proj(
|
proprio_embed = self.action_preprocessor.proprioception_proj(
|
||||||
proprioception,
|
proprioception,
|
||||||
agent_pos_mask,
|
agent_pos_mask,
|
||||||
use_history=proprioception.shape[1] > 1,
|
use_history=proprioception.shape[1] > 1,
|
||||||
)
|
)
|
||||||
proprioception_mask = (
|
proprioception_mask = input_ids == self.action_token_id_set["propri_token_id"]
|
||||||
input_ids == self.action_token_id_set["propri_token_id"]
|
|
||||||
)
|
|
||||||
proprio_embed = proprio_embed.to(torch.bfloat16)
|
proprio_embed = proprio_embed.to(torch.bfloat16)
|
||||||
inputs_embeds[proprioception_mask] = proprio_embed.reshape(
|
inputs_embeds[proprioception_mask] = proprio_embed.reshape(-1, inputs_embeds.shape[-1])
|
||||||
-1, inputs_embeds.shape[-1]
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||||
|
|
||||||
# Calculate RoPE position IDs if not provided
|
# Calculate RoPE position IDs if not provided
|
||||||
# Note: Cannot calculate rope deltas with 4D attention mask. TODO: Fix this limitation
|
# Note: Cannot calculate rope deltas with 4D attention mask. TODO: Fix this limitation
|
||||||
if position_ids is None and (
|
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
||||||
attention_mask is None or attention_mask.ndim == 2
|
|
||||||
):
|
|
||||||
# Calculate RoPE index once per generation in the pre-fill stage only
|
# Calculate RoPE index once per generation in the pre-fill stage only
|
||||||
if (
|
if (
|
||||||
(cache_position is not None and cache_position[0] == 0)
|
(cache_position is not None and cache_position[0] == 0)
|
||||||
@@ -1332,12 +1216,8 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
eos_token_id=[self.processor.tokenizer.eos_token_id],
|
eos_token_id=[self.processor.tokenizer.eos_token_id],
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
pad_token_id=self.processor.tokenizer.pad_token_id,
|
pad_token_id=self.processor.tokenizer.pad_token_id,
|
||||||
temperature=(
|
temperature=(1.0 if not re_generate else 0.7), # Higher temperature for regeneration
|
||||||
1.0 if not re_generate else 0.7
|
do_sample=(False if not re_generate else True), # Enable sampling for regeneration
|
||||||
), # Higher temperature for regeneration
|
|
||||||
do_sample=(
|
|
||||||
False if not re_generate else True
|
|
||||||
), # Enable sampling for regeneration
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Decode generated and ground truth text
|
# Decode generated and ground truth text
|
||||||
@@ -1359,15 +1239,9 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
action_id = []
|
action_id = []
|
||||||
# Extract action tokens from generated sequence
|
# Extract action tokens from generated sequence
|
||||||
for token_id_i in predict_output_ids[0]:
|
for token_id_i in predict_output_ids[0]:
|
||||||
if (
|
if token_id_i.item() >= self.processor.tokenizer.init_kwargs["action_token_start_index"]:
|
||||||
token_id_i.item()
|
|
||||||
>= self.processor.tokenizer.init_kwargs["action_token_start_index"]
|
|
||||||
):
|
|
||||||
action_id.append(
|
action_id.append(
|
||||||
token_id_i.item()
|
token_id_i.item() - self.processor.tokenizer.init_kwargs["action_token_start_index"]
|
||||||
- self.processor.tokenizer.init_kwargs[
|
|
||||||
"action_token_start_index"
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
predict_action = self.processor.action_processor.decode(
|
predict_action = self.processor.action_processor.decode(
|
||||||
@@ -1382,9 +1256,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
predict_action = torch.tensor(predict_action, device=self.device)
|
predict_action = torch.tensor(predict_action, device=self.device)
|
||||||
dof_mask = dof_mask.to(self.device).to(pixel_values.dtype)
|
dof_mask = dof_mask.to(self.device).to(pixel_values.dtype)
|
||||||
# removed unnormalization step for now
|
# removed unnormalization step for now
|
||||||
predict_action = (
|
predict_action = predict_action[:, :, dof_mask[0, 0, :].bool()]
|
||||||
predict_action[:, :, dof_mask[0, 0, :].bool()]
|
|
||||||
)
|
|
||||||
output["predict_action"] = predict_action
|
output["predict_action"] = predict_action
|
||||||
|
|
||||||
# Process ground truth actions if available
|
# Process ground truth actions if available
|
||||||
@@ -1459,7 +1331,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
times = torch.linspace(
|
times = torch.linspace(
|
||||||
0,
|
0,
|
||||||
1,
|
1,
|
||||||
num_inference_timesteps+1,
|
num_inference_timesteps + 1,
|
||||||
device=inputs_embeds.device,
|
device=inputs_embeds.device,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
@@ -1477,9 +1349,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def forward(
|
def forward(self, mode: str | None = None, predict_mode: str | None = "text", **kwargs):
|
||||||
self, mode: Optional[str] = None, predict_mode: Optional[str] = "text", **kwargs
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Main forward pass dispatcher for different execution modes.
|
Main forward pass dispatcher for different execution modes.
|
||||||
|
|
||||||
@@ -1592,9 +1462,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
|
|
||||||
# Handle input slicing based on cache state and special cases
|
# Handle input slicing based on cache state and special cases
|
||||||
if past_key_values is not None:
|
if past_key_values is not None:
|
||||||
if (
|
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4: input_embeds case
|
||||||
inputs_embeds is not None and input_ids.shape[1] == 0
|
|
||||||
): # Exception 4: input_embeds case
|
|
||||||
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
|
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
|
||||||
moe_token_types = moe_token_types[:, -cache_position.shape[0] :]
|
moe_token_types = moe_token_types[:, -cache_position.shape[0] :]
|
||||||
elif inputs_embeds is not None or ( # Exception 1: input_embeds provided
|
elif inputs_embeds is not None or ( # Exception 1: input_embeds provided
|
||||||
@@ -1602,9 +1470,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
): # Exception 3: GPU sync edge case
|
): # Exception 3: GPU sync edge case
|
||||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||||
moe_token_types = moe_token_types[:, -cache_position.shape[0] :]
|
moe_token_types = moe_token_types[:, -cache_position.shape[0] :]
|
||||||
elif (
|
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (Exception 2 is no-op)
|
||||||
input_ids.shape[1] != cache_position.shape[0]
|
|
||||||
): # Default case (Exception 2 is no-op)
|
|
||||||
cache_pos = cache_position.clone()
|
cache_pos = cache_position.clone()
|
||||||
input_ids = input_ids[:, cache_pos]
|
input_ids = input_ids[:, cache_pos]
|
||||||
moe_token_types = moe_token_types[:, cache_pos]
|
moe_token_types = moe_token_types[:, cache_pos]
|
||||||
@@ -1629,8 +1495,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
batch_size, sequence_length = input_ids.shape
|
batch_size, sequence_length = input_ids.shape
|
||||||
device = input_ids.device
|
device = input_ids.device
|
||||||
|
|
||||||
attention_mask = (
|
attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
self.model._prepare_4d_causal_attention_mask_with_cache_position(
|
|
||||||
attention_mask,
|
attention_mask,
|
||||||
sequence_length=sequence_length,
|
sequence_length=sequence_length,
|
||||||
target_length=past_key_values.get_max_cache_shape(),
|
target_length=past_key_values.get_max_cache_shape(),
|
||||||
@@ -1641,7 +1506,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
config=self.config,
|
config=self.config,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# Assemble all model inputs for generation
|
# Assemble all model inputs for generation
|
||||||
model_inputs.update(
|
model_inputs.update(
|
||||||
@@ -1666,8 +1530,8 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
|
|
||||||
def _get_image_nums_and_video_nums(
|
def _get_image_nums_and_video_nums(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor],
|
input_ids: torch.LongTensor | None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Get the number of images and videos for each sample to calculate tensor separation lengths.
|
Get the number of images and videos for each sample to calculate tensor separation lengths.
|
||||||
|
|
||||||
@@ -1702,9 +1566,9 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
self,
|
self,
|
||||||
expand_size: int = 1,
|
expand_size: int = 1,
|
||||||
is_encoder_decoder: bool = False,
|
is_encoder_decoder: bool = False,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: torch.LongTensor | None = None,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
) -> Tuple[torch.LongTensor, Dict[str, Any]]:
|
) -> tuple[torch.LongTensor, dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Expand inputs for generation with support for multi-modal tensors.
|
Expand inputs for generation with support for multi-modal tensors.
|
||||||
|
|
||||||
@@ -1745,9 +1609,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
"""Split tensor by lengths and repeat each sample."""
|
"""Split tensor by lengths and repeat each sample."""
|
||||||
samples = torch.split(x, lengths)
|
samples = torch.split(x, lengths)
|
||||||
repeat_args = [repeat_times] + [1] * (x.dim() - 1)
|
repeat_args = [repeat_times] + [1] * (x.dim() - 1)
|
||||||
result = torch.cat(
|
result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
|
||||||
[sample.repeat(*repeat_args) for sample in samples], dim=0
|
|
||||||
)
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
for key in dict_to_expand:
|
for key in dict_to_expand:
|
||||||
@@ -1785,9 +1647,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
)
|
)
|
||||||
tensor = torch.tensor(dict_to_expand[key])
|
tensor = torch.tensor(dict_to_expand[key])
|
||||||
lengths = list(video_nums)
|
lengths = list(video_nums)
|
||||||
tensor = _repeat_interleave_samples(
|
tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size)
|
||||||
tensor, lengths=lengths, repeat_times=expand_size
|
|
||||||
)
|
|
||||||
dict_to_expand[key] = tensor.tolist()
|
dict_to_expand[key] = tensor.tolist()
|
||||||
return dict_to_expand
|
return dict_to_expand
|
||||||
|
|
||||||
@@ -1800,9 +1660,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
and isinstance(dict_to_expand[key], torch.Tensor)
|
and isinstance(dict_to_expand[key], torch.Tensor)
|
||||||
and key not in visual_keys
|
and key not in visual_keys
|
||||||
):
|
):
|
||||||
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(
|
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
|
||||||
expand_size, dim=0
|
|
||||||
)
|
|
||||||
return dict_to_expand
|
return dict_to_expand
|
||||||
|
|
||||||
# Expand visual inputs only if input_ids is available for counting images/videos
|
# Expand visual inputs only if input_ids is available for counting images/videos
|
||||||
@@ -1823,9 +1681,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined."
|
"If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined."
|
||||||
)
|
)
|
||||||
model_kwargs["encoder_outputs"] = _expand_dict_for_generation(
|
model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
|
||||||
model_kwargs["encoder_outputs"]
|
|
||||||
)
|
|
||||||
|
|
||||||
return input_ids, model_kwargs
|
return input_ids, model_kwargs
|
||||||
|
|
||||||
@@ -1850,7 +1706,7 @@ class WallXPolicy(PreTrainedPolicy):
|
|||||||
self.model = Qwen2_5_VLMoEForAction.from_pretrained(
|
self.model = Qwen2_5_VLMoEForAction.from_pretrained(
|
||||||
pretrained_name_or_path=config.pretrained_name_or_path,
|
pretrained_name_or_path=config.pretrained_name_or_path,
|
||||||
action_tokenizer_path=config.action_tokenizer_path,
|
action_tokenizer_path=config.action_tokenizer_path,
|
||||||
attn_implementation=config.attn_implementation
|
attn_implementation=config.attn_implementation,
|
||||||
)
|
)
|
||||||
self.model.to(config.device)
|
self.model.to(config.device)
|
||||||
self.model.to_bfloat16_for_selected_params()
|
self.model.to_bfloat16_for_selected_params()
|
||||||
@@ -1869,7 +1725,7 @@ class WallXPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
def preprocess_inputs(
|
def preprocess_inputs(
|
||||||
self,
|
self,
|
||||||
batch: Dict[str, Any],
|
batch: dict[str, Any],
|
||||||
) -> BatchFeature:
|
) -> BatchFeature:
|
||||||
"""
|
"""
|
||||||
Convert a batch of LeRobot dataset items to Wall-X model input format.
|
Convert a batch of LeRobot dataset items to Wall-X model input format.
|
||||||
@@ -1950,12 +1806,10 @@ class WallXPolicy(PreTrainedPolicy):
|
|||||||
)
|
)
|
||||||
|
|
||||||
text = process_grounding_points(
|
text = process_grounding_points(
|
||||||
complete_text, orig_height, orig_width, resized_height, resized_width,
|
complete_text, orig_height, orig_width, resized_height, resized_width, MODEL_TYPE
|
||||||
MODEL_TYPE
|
|
||||||
)
|
)
|
||||||
all_texts.append(text)
|
all_texts.append(text)
|
||||||
|
|
||||||
|
|
||||||
# ==================== PROCESS AGENT POS ====================
|
# ==================== PROCESS AGENT POS ====================
|
||||||
agent_pos = batch[OBS_STATE] # (batch_size, state_dim)
|
agent_pos = batch[OBS_STATE] # (batch_size, state_dim)
|
||||||
if agent_pos.dim() == 2:
|
if agent_pos.dim() == 2:
|
||||||
@@ -1965,17 +1819,28 @@ class WallXPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
if agent_pos.shape[-1] != 20:
|
if agent_pos.shape[-1] != 20:
|
||||||
pad_size = 20 - agent_pos.shape[-1]
|
pad_size = 20 - agent_pos.shape[-1]
|
||||||
agent_pos = torch.cat([
|
agent_pos = torch.cat(
|
||||||
|
[
|
||||||
agent_pos,
|
agent_pos,
|
||||||
torch.zeros(agent_pos.shape[0], agent_pos.shape[1], pad_size, device=agent_pos.device)
|
torch.zeros(agent_pos.shape[0], agent_pos.shape[1], pad_size, device=agent_pos.device),
|
||||||
], dim=-1)
|
],
|
||||||
agent_pos_mask = torch.cat([
|
dim=-1,
|
||||||
|
)
|
||||||
|
agent_pos_mask = torch.cat(
|
||||||
|
[
|
||||||
agent_pos_mask,
|
agent_pos_mask,
|
||||||
torch.zeros(agent_pos_mask.shape[0], agent_pos_mask.shape[1], pad_size, device=agent_pos_mask.device)
|
torch.zeros(
|
||||||
], dim=-1)
|
agent_pos_mask.shape[0],
|
||||||
|
agent_pos_mask.shape[1],
|
||||||
|
pad_size,
|
||||||
|
device=agent_pos_mask.device,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
# ==================== PROCESS ACTIONS ====================
|
# ==================== PROCESS ACTIONS ====================
|
||||||
action = batch.get(ACTION, None) # (batch_size, chunk_size, action_dim)
|
action = batch.get(ACTION) # (batch_size, chunk_size, action_dim)
|
||||||
if action is not None:
|
if action is not None:
|
||||||
if action.dim() == 2:
|
if action.dim() == 2:
|
||||||
action = action.unsqueeze(1)
|
action = action.unsqueeze(1)
|
||||||
@@ -1984,20 +1849,30 @@ class WallXPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
if action.shape[-1] != 20:
|
if action.shape[-1] != 20:
|
||||||
pad_size = 20 - action.shape[-1]
|
pad_size = 20 - action.shape[-1]
|
||||||
action = torch.cat([
|
action = torch.cat(
|
||||||
action,
|
[action, torch.zeros(action.shape[0], action.shape[1], pad_size, device=action.device)],
|
||||||
torch.zeros(action.shape[0], action.shape[1], pad_size, device=action.device)
|
dim=-1,
|
||||||
], dim=-1)
|
)
|
||||||
dof_mask = torch.cat([
|
dof_mask = torch.cat(
|
||||||
|
[
|
||||||
dof_mask,
|
dof_mask,
|
||||||
torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device)
|
torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device),
|
||||||
], dim=-1)
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
action_dim = self.config.output_features["action"].shape[0]
|
action_dim = self.config.output_features["action"].shape[0]
|
||||||
dof_mask = torch.cat([
|
dof_mask = torch.cat(
|
||||||
torch.ones(batch_size, self.config.chunk_size, action_dim, device=batch[OBS_STATE].device),
|
[
|
||||||
torch.zeros(batch_size, self.config.chunk_size, 20 - action_dim, device=batch[OBS_STATE].device)
|
torch.ones(
|
||||||
], dim=-1)
|
batch_size, self.config.chunk_size, action_dim, device=batch[OBS_STATE].device
|
||||||
|
),
|
||||||
|
torch.zeros(
|
||||||
|
batch_size, self.config.chunk_size, 20 - action_dim, device=batch[OBS_STATE].device
|
||||||
|
),
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
# ==================== ACTION TOKEN REPLACEMENT ====================
|
# ==================== ACTION TOKEN REPLACEMENT ====================
|
||||||
all_texts = replace_action_token(
|
all_texts = replace_action_token(
|
||||||
@@ -2028,7 +1903,11 @@ class WallXPolicy(PreTrainedPolicy):
|
|||||||
inputs["action_chunk"] = action
|
inputs["action_chunk"] = action
|
||||||
inputs["dof_mask"] = dof_mask
|
inputs["dof_mask"] = dof_mask
|
||||||
inputs["moe_token_types"] = moe_token_types
|
inputs["moe_token_types"] = moe_token_types
|
||||||
inputs["frame_index"] = batch["frame_index"] if "frame_index" in batch else torch.zeros(batch_size, device=batch[OBS_STATE].device)
|
inputs["frame_index"] = (
|
||||||
|
batch["frame_index"]
|
||||||
|
if "frame_index" in batch
|
||||||
|
else torch.zeros(batch_size, device=batch[OBS_STATE].device)
|
||||||
|
)
|
||||||
|
|
||||||
# Move all tensors to the correct device
|
# Move all tensors to the correct device
|
||||||
device = self.config.device
|
device = self.config.device
|
||||||
@@ -2056,10 +1935,7 @@ class WallXPolicy(PreTrainedPolicy):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Call the underlying model's forward with mode="train"
|
# Call the underlying model's forward with mode="train"
|
||||||
outputs = self.model(
|
outputs = self.model(**batch, mode="train")
|
||||||
**batch,
|
|
||||||
mode="train"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract losses from output
|
# Extract losses from output
|
||||||
loss = outputs.loss
|
loss = outputs.loss
|
||||||
@@ -2096,7 +1972,7 @@ class WallXPolicy(PreTrainedPolicy):
|
|||||||
action_dim=self.config.max_action_dim,
|
action_dim=self.config.max_action_dim,
|
||||||
pred_horizon=self.config.chunk_size,
|
pred_horizon=self.config.chunk_size,
|
||||||
mode="predict",
|
mode="predict",
|
||||||
predict_mode="diffusion"
|
predict_mode="diffusion",
|
||||||
)
|
)
|
||||||
elif self.config.prediction_mode == "fast":
|
elif self.config.prediction_mode == "fast":
|
||||||
output = self.model(
|
output = self.model(
|
||||||
@@ -2104,7 +1980,7 @@ class WallXPolicy(PreTrainedPolicy):
|
|||||||
action_dim=self.config.output_features["action"].shape[0],
|
action_dim=self.config.output_features["action"].shape[0],
|
||||||
pred_horizon=self.config.chunk_size,
|
pred_horizon=self.config.chunk_size,
|
||||||
mode="predict",
|
mode="predict",
|
||||||
predict_mode="fast"
|
predict_mode="fast",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Prediction mode {self.config.prediction_mode} not implemented")
|
raise NotImplementedError(f"Prediction mode {self.config.prediction_mode} not implemented")
|
||||||
@@ -2127,6 +2003,6 @@ class WallXPolicy(PreTrainedPolicy):
|
|||||||
# Use action queue
|
# Use action queue
|
||||||
if len(self._queues[ACTION]) == 0:
|
if len(self._queues[ACTION]) == 0:
|
||||||
actions = self.predict_action_chunk(batch)
|
actions = self.predict_action_chunk(batch)
|
||||||
self._queues[ACTION].extend(actions.transpose(0, 1)[:self.config.n_action_steps])
|
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
|
||||||
|
|
||||||
return self._queues[ACTION].popleft()
|
return self._queues[ACTION].popleft()
|
||||||
|
|||||||
@@ -33,6 +33,8 @@ from lerobot.processor import (
|
|||||||
)
|
)
|
||||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||||
|
|
||||||
|
|
||||||
def make_wall_x_pre_post_processors(
|
def make_wall_x_pre_post_processors(
|
||||||
config: WallXConfig,
|
config: WallXConfig,
|
||||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||||
@@ -75,9 +77,7 @@ def make_wall_x_pre_post_processors(
|
|||||||
|
|
||||||
output_steps = [
|
output_steps = [
|
||||||
UnnormalizerProcessorStep(
|
UnnormalizerProcessorStep(
|
||||||
features=config.output_features,
|
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||||
norm_map=config.normalization_mapping,
|
|
||||||
stats=dataset_stats
|
|
||||||
),
|
),
|
||||||
DeviceProcessorStep(device="cpu"),
|
DeviceProcessorStep(device="cpu"),
|
||||||
]
|
]
|
||||||
@@ -123,9 +123,7 @@ class WallXTaskProcessor(ComplementaryDataProcessorStep):
|
|||||||
new_complementary_data["task"] = f"{task}."
|
new_complementary_data["task"] = f"{task}."
|
||||||
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
||||||
# List of strings: format each
|
# List of strings: format each
|
||||||
new_complementary_data["task"] = [
|
new_complementary_data["task"] = [t if t.endswith(".") else f"{t}." for t in task]
|
||||||
t if t.endswith(".") else f"{t}." for t in task
|
|
||||||
]
|
|
||||||
|
|
||||||
return new_complementary_data
|
return new_complementary_data
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -25,7 +25,7 @@ import random
|
|||||||
import re
|
import re
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import BatchFeature
|
from transformers import BatchFeature
|
||||||
@@ -46,11 +46,11 @@ class X2RDataProcessingConfig:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Action prediction configuration
|
# Action prediction configuration
|
||||||
predict_action_keys: List[str] = field(default_factory=list)
|
predict_action_keys: list[str] = field(default_factory=list)
|
||||||
obs_action_keys: List[str] = field(default_factory=list)
|
obs_action_keys: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
# Image resolution settings for different views
|
# Image resolution settings for different views
|
||||||
resolution: Dict[str, int] = field(
|
resolution: dict[str, int] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"face_view": -1,
|
"face_view": -1,
|
||||||
"left_wrist_view": 128,
|
"left_wrist_view": 128,
|
||||||
@@ -63,7 +63,7 @@ class X2RDataProcessingConfig:
|
|||||||
split_seed: int = 42
|
split_seed: int = 42
|
||||||
|
|
||||||
# Instruction handling
|
# Instruction handling
|
||||||
priority_order: Optional[Dict[str, float]] = None
|
priority_order: dict[str, float] | None = None
|
||||||
|
|
||||||
# Vision model parameters
|
# Vision model parameters
|
||||||
model_type: str = "qwen2_5"
|
model_type: str = "qwen2_5"
|
||||||
@@ -77,11 +77,9 @@ class X2RDataProcessingConfig:
|
|||||||
"""Post-initialization validation and setup."""
|
"""Post-initialization validation and setup."""
|
||||||
# Validate train/test split
|
# Validate train/test split
|
||||||
if not 0 < self.train_test_split < 1:
|
if not 0 < self.train_test_split < 1:
|
||||||
raise ValueError(
|
raise ValueError(f"train_test_split must be between 0 and 1, got {self.train_test_split}")
|
||||||
f"train_test_split must be between 0 and 1, got {self.train_test_split}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def as_dict(self) -> Dict:
|
def as_dict(self) -> dict:
|
||||||
"""Convert configuration to dictionary format.
|
"""Convert configuration to dictionary format.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -105,14 +103,15 @@ class X2RDataProcessingConfig:
|
|||||||
raise ValueError(f"Unknown configuration parameter: {key}")
|
raise ValueError(f"Unknown configuration parameter: {key}")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
def preprocesser_call(
|
def preprocesser_call(
|
||||||
processor,
|
processor,
|
||||||
images: Optional[Union[List, Any]] = None,
|
images: list | Any | None = None,
|
||||||
text: Optional[Union[str, List[str]]] = None,
|
text: str | list[str] | None = None,
|
||||||
videos: Optional[Union[List, Any]] = None,
|
videos: list | Any | None = None,
|
||||||
padding: Union[bool, str] = False,
|
padding: bool | str = False,
|
||||||
truncation: Optional[bool] = None,
|
truncation: bool | None = None,
|
||||||
max_length: Optional[int] = None,
|
max_length: int | None = None,
|
||||||
return_tensors: str = "pt",
|
return_tensors: str = "pt",
|
||||||
) -> BatchFeature:
|
) -> BatchFeature:
|
||||||
"""Unified preprocessing function for Wall-X model handling text, image and video inputs.
|
"""Unified preprocessing function for Wall-X model handling text, image and video inputs.
|
||||||
@@ -145,9 +144,7 @@ def preprocesser_call(
|
|||||||
"""
|
"""
|
||||||
# Process image inputs
|
# Process image inputs
|
||||||
if images is not None and len(images) > 0:
|
if images is not None and len(images) > 0:
|
||||||
image_inputs = processor.image_processor(
|
image_inputs = processor.image_processor(images=images, videos=None, return_tensors=return_tensors)
|
||||||
images=images, videos=None, return_tensors=return_tensors
|
|
||||||
)
|
|
||||||
image_grid_thw = image_inputs["image_grid_thw"]
|
image_grid_thw = image_inputs["image_grid_thw"]
|
||||||
else:
|
else:
|
||||||
image_inputs = {}
|
image_inputs = {}
|
||||||
@@ -155,9 +152,7 @@ def preprocesser_call(
|
|||||||
|
|
||||||
# Process video inputs
|
# Process video inputs
|
||||||
if videos is not None:
|
if videos is not None:
|
||||||
videos_inputs = processor.image_processor(
|
videos_inputs = processor.image_processor(images=None, videos=videos, return_tensors=return_tensors)
|
||||||
images=None, videos=videos, return_tensors=return_tensors
|
|
||||||
)
|
|
||||||
video_grid_thw = videos_inputs["video_grid_thw"]
|
video_grid_thw = videos_inputs["video_grid_thw"]
|
||||||
else:
|
else:
|
||||||
videos_inputs = {}
|
videos_inputs = {}
|
||||||
@@ -183,9 +178,7 @@ def preprocesser_call(
|
|||||||
break
|
break
|
||||||
# Replace image placeholder with actual token count
|
# Replace image placeholder with actual token count
|
||||||
token_count = image_grid_thw[index].prod() // merge_length
|
token_count = image_grid_thw[index].prod() // merge_length
|
||||||
text[i] = text[i].replace(
|
text[i] = text[i].replace("<|image_pad|>", "<|placeholder|>" * token_count, 1)
|
||||||
"<|image_pad|>", "<|placeholder|>" * token_count, 1
|
|
||||||
)
|
|
||||||
index += 1
|
index += 1
|
||||||
text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")
|
text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")
|
||||||
|
|
||||||
@@ -197,9 +190,7 @@ def preprocesser_call(
|
|||||||
while "<|video_pad|>" in text[i]:
|
while "<|video_pad|>" in text[i]:
|
||||||
# Replace video placeholder with actual token count
|
# Replace video placeholder with actual token count
|
||||||
token_count = video_grid_thw[index].prod() // merge_length
|
token_count = video_grid_thw[index].prod() // merge_length
|
||||||
text[i] = text[i].replace(
|
text[i] = text[i].replace("<|video_pad|>", "<|placeholder|>" * token_count, 1)
|
||||||
"<|video_pad|>", "<|placeholder|>" * token_count, 1
|
|
||||||
)
|
|
||||||
index += 1
|
index += 1
|
||||||
text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>")
|
text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>")
|
||||||
|
|
||||||
@@ -221,9 +212,7 @@ def preprocesser_call(
|
|||||||
labels = torch.full_like(text_inputs.input_ids, -100)
|
labels = torch.full_like(text_inputs.input_ids, -100)
|
||||||
assistant_marker = "<|im_start|>assistant\n"
|
assistant_marker = "<|im_start|>assistant\n"
|
||||||
im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
|
im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
|
||||||
assistant_tokens = processor.tokenizer(
|
assistant_tokens = processor.tokenizer("<|im_start|>assistant\n", add_special_tokens=False).input_ids
|
||||||
"<|im_start|>assistant\n", add_special_tokens=False
|
|
||||||
).input_ids
|
|
||||||
|
|
||||||
for i in range(len(text)):
|
for i in range(len(text)):
|
||||||
assistant_regions = []
|
assistant_regions = []
|
||||||
@@ -249,9 +238,7 @@ def preprocesser_call(
|
|||||||
# From second part onwards, each part starts with assistant response
|
# From second part onwards, each part starts with assistant response
|
||||||
for k in range(current_pos + 1, len(text_inputs.input_ids[i])):
|
for k in range(current_pos + 1, len(text_inputs.input_ids[i])):
|
||||||
if text_inputs.input_ids[i][k] == im_end_token_id:
|
if text_inputs.input_ids[i][k] == im_end_token_id:
|
||||||
assistant_regions.append(
|
assistant_regions.append((current_pos + len(assistant_tokens), k + 2))
|
||||||
(current_pos + len(assistant_tokens), k + 2)
|
|
||||||
)
|
|
||||||
break
|
break
|
||||||
current_pos += len(part_tokens) + 3
|
current_pos += len(part_tokens) + 3
|
||||||
|
|
||||||
@@ -344,7 +331,7 @@ def process_grounding_points(
|
|||||||
coords = [new_x1, new_y1, new_x2, new_y2]
|
coords = [new_x1, new_y1, new_x2, new_y2]
|
||||||
|
|
||||||
# Return processed point tag
|
# Return processed point tag
|
||||||
return f'<point>[{", ".join(map(str, coords))}]</point>'
|
return f"<point>[{', '.join(map(str, coords))}]</point>"
|
||||||
|
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
# Return original content if processing fails
|
# Return original content if processing fails
|
||||||
@@ -356,10 +343,10 @@ def process_grounding_points(
|
|||||||
|
|
||||||
|
|
||||||
def get_frame_instruction(
|
def get_frame_instruction(
|
||||||
instruction_info: Dict[str, Any],
|
instruction_info: dict[str, Any],
|
||||||
frame_idx: Optional[int] = None,
|
frame_idx: int | None = None,
|
||||||
truncate_keys: Optional[List[str]] = None,
|
truncate_keys: list[str] | None = None,
|
||||||
) -> Tuple[Dict[str, Any], Optional[int]]:
|
) -> tuple[dict[str, Any], int | None]:
|
||||||
"""Extract frame-specific instruction from instruction dictionary.
|
"""Extract frame-specific instruction from instruction dictionary.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -388,11 +375,7 @@ def get_frame_instruction(
|
|||||||
start_frame, end_frame = map(int, frame_range.split(" "))
|
start_frame, end_frame = map(int, frame_range.split(" "))
|
||||||
if start_frame <= frame_idx < end_frame or (start_frame == frame_idx):
|
if start_frame <= frame_idx < end_frame or (start_frame == frame_idx):
|
||||||
instruction_for_frame[key] = frame_instruction
|
instruction_for_frame[key] = frame_instruction
|
||||||
if (
|
if truncate_keys is not None and split_end is None and key in truncate_keys:
|
||||||
truncate_keys is not None
|
|
||||||
and split_end is None
|
|
||||||
and key in truncate_keys
|
|
||||||
):
|
|
||||||
split_end = end_frame + 1
|
split_end = end_frame + 1
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
@@ -402,7 +385,7 @@ def get_frame_instruction(
|
|||||||
|
|
||||||
|
|
||||||
def get_task_instruction(
|
def get_task_instruction(
|
||||||
frame_instruction_info: Dict[str, Any], priority_order: Optional[OrderedDict] = None
|
frame_instruction_info: dict[str, Any], priority_order: OrderedDict | None = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Construct task instruction from available instruction fields using priority sampling.
|
"""Construct task instruction from available instruction fields using priority sampling.
|
||||||
|
|
||||||
@@ -450,13 +433,13 @@ def get_task_instruction(
|
|||||||
|
|
||||||
|
|
||||||
def get_wallx_normal_text(
|
def get_wallx_normal_text(
|
||||||
instruction_info: Dict[str, Any],
|
instruction_info: dict[str, Any],
|
||||||
action_chunk_size: int,
|
action_chunk_size: int,
|
||||||
frame_idx: int,
|
frame_idx: int,
|
||||||
priority_order: Optional[OrderedDict] = None,
|
priority_order: OrderedDict | None = None,
|
||||||
img_keys: Optional[List[str]] = None,
|
img_keys: list[str] | None = None,
|
||||||
generate_subtask_ratio: float = 0.0,
|
generate_subtask_ratio: float = 0.0,
|
||||||
) -> Tuple[str, bool]:
|
) -> tuple[str, bool]:
|
||||||
"""Construct complete multimodal prompt text for Wall-X model.
|
"""Construct complete multimodal prompt text for Wall-X model.
|
||||||
|
|
||||||
Formats input using special tokens including:
|
Formats input using special tokens including:
|
||||||
@@ -488,9 +471,7 @@ def get_wallx_normal_text(
|
|||||||
action_fast_symbol = "<|action_fast|>"
|
action_fast_symbol = "<|action_fast|>"
|
||||||
|
|
||||||
# System prologue
|
# System prologue
|
||||||
prologue = (
|
prologue = f"{role_start_symbol}system\nYou are a helpful assistant.{role_end_symbol}\n"
|
||||||
f"{role_start_symbol}system\nYou are a helpful assistant.{role_end_symbol}\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
# User request with observation
|
# User request with observation
|
||||||
user_request = f"{role_start_symbol}user\nObservation:"
|
user_request = f"{role_start_symbol}user\nObservation:"
|
||||||
@@ -501,9 +482,7 @@ def get_wallx_normal_text(
|
|||||||
user_request += "\nInstruction:"
|
user_request += "\nInstruction:"
|
||||||
|
|
||||||
# Get frame-specific instruction
|
# Get frame-specific instruction
|
||||||
frame_instruction_info, _ = get_frame_instruction(
|
frame_instruction_info, _ = get_frame_instruction(instruction_info, frame_idx=frame_idx)
|
||||||
instruction_info, frame_idx=frame_idx
|
|
||||||
)
|
|
||||||
|
|
||||||
generate_subtask = False
|
generate_subtask = False
|
||||||
priority_keys = ["subtask_generation", "distribute"]
|
priority_keys = ["subtask_generation", "distribute"]
|
||||||
@@ -524,15 +503,11 @@ def get_wallx_normal_text(
|
|||||||
output_instruction = frame_instruction_info[key]
|
output_instruction = frame_instruction_info[key]
|
||||||
break
|
break
|
||||||
|
|
||||||
assistant_output = (
|
assistant_output = f"{role_start_symbol}assistant\n{output_instruction}\n{role_end_symbol}"
|
||||||
f"{role_start_symbol}assistant\n{output_instruction}\n{role_end_symbol}"
|
|
||||||
)
|
|
||||||
generate_subtask = True
|
generate_subtask = True
|
||||||
else:
|
else:
|
||||||
# Generate actions
|
# Generate actions
|
||||||
instruction = get_task_instruction(
|
instruction = get_task_instruction(frame_instruction_info, priority_order=priority_order)
|
||||||
frame_instruction_info, priority_order=priority_order
|
|
||||||
)
|
|
||||||
text_prompt = f"\nPredict the next action in robot action.\nProprioception: {propri_symbol}\n"
|
text_prompt = f"\nPredict the next action in robot action.\nProprioception: {propri_symbol}\n"
|
||||||
user_message = f"{user_request} {instruction}{text_prompt}{role_end_symbol}\n"
|
user_message = f"{user_request} {instruction}{text_prompt}{role_end_symbol}\n"
|
||||||
assistant_output = f"{role_start_symbol}assistant\n{action_fast_symbol}{role_end_symbol}\n{action_symbol * action_chunk_size}"
|
assistant_output = f"{role_start_symbol}assistant\n{action_fast_symbol}{role_end_symbol}\n{action_symbol * action_chunk_size}"
|
||||||
@@ -540,7 +515,8 @@ def get_wallx_normal_text(
|
|||||||
complete_text = prologue + user_message + assistant_output
|
complete_text = prologue + user_message + assistant_output
|
||||||
return complete_text, generate_subtask
|
return complete_text, generate_subtask
|
||||||
|
|
||||||
def img_key_mapping(img_keys: List[str]) -> List[str]:
|
|
||||||
|
def img_key_mapping(img_keys: list[str]) -> list[str]:
|
||||||
"""Map image keys to camera names.
|
"""Map image keys to camera names.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -555,16 +531,15 @@ def img_key_mapping(img_keys: List[str]) -> List[str]:
|
|||||||
if key in CAMERA_NAME_MAPPING:
|
if key in CAMERA_NAME_MAPPING:
|
||||||
key = CAMERA_NAME_MAPPING[key]
|
key = CAMERA_NAME_MAPPING[key]
|
||||||
else:
|
else:
|
||||||
if 'view' in key:
|
if "view" in key:
|
||||||
key = key.replace('_', ' ')
|
key = key.replace("_", " ")
|
||||||
else:
|
else:
|
||||||
key = key + " view"
|
key = key + " view"
|
||||||
processed_img_keys.append(key)
|
processed_img_keys.append(key)
|
||||||
return processed_img_keys
|
return processed_img_keys
|
||||||
|
|
||||||
def get_action_tokens(
|
|
||||||
normalized_actions: Union[torch.Tensor, List], action_tokenizer
|
def get_action_tokens(normalized_actions: torch.Tensor | list, action_tokenizer) -> list[list[str]]:
|
||||||
) -> List[List[str]]:
|
|
||||||
"""Convert normalized actions to action token strings.
|
"""Convert normalized actions to action token strings.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -590,8 +565,8 @@ def get_action_tokens(
|
|||||||
|
|
||||||
|
|
||||||
def pad_action_token_strs(
|
def pad_action_token_strs(
|
||||||
actions_token_lists: List[List[str]], pad_token: str = "<|endoftext|>"
|
actions_token_lists: list[list[str]], pad_token: str = "<|endoftext|>"
|
||||||
) -> List[str]:
|
) -> list[str]:
|
||||||
"""Pad action token lists to same length and join as strings.
|
"""Pad action token lists to same length and join as strings.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -605,20 +580,18 @@ def pad_action_token_strs(
|
|||||||
padded_action_strs = []
|
padded_action_strs = []
|
||||||
|
|
||||||
for tokens in actions_token_lists:
|
for tokens in actions_token_lists:
|
||||||
padded_tokens = (
|
padded_tokens = tokens + ["<|im_end|>\n"] + [pad_token] * (max_len - len(tokens))
|
||||||
tokens + ["<|im_end|>\n"] + [pad_token] * (max_len - len(tokens))
|
|
||||||
)
|
|
||||||
padded_action_strs.append("".join(padded_tokens))
|
padded_action_strs.append("".join(padded_tokens))
|
||||||
|
|
||||||
return padded_action_strs
|
return padded_action_strs
|
||||||
|
|
||||||
|
|
||||||
def replace_action_token(
|
def replace_action_token(
|
||||||
text: List[str],
|
text: list[str],
|
||||||
norm_action: Optional[torch.Tensor],
|
norm_action: torch.Tensor | None,
|
||||||
action_tokenizer,
|
action_tokenizer,
|
||||||
dof_masks: Optional[torch.Tensor] = None,
|
dof_masks: torch.Tensor | None = None,
|
||||||
) -> List[str]:
|
) -> list[str]:
|
||||||
"""Replace action placeholders in text with actual action tokens.
|
"""Replace action placeholders in text with actual action tokens.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -632,10 +605,7 @@ def replace_action_token(
|
|||||||
"""
|
"""
|
||||||
if action_tokenizer is not None and norm_action is not None:
|
if action_tokenizer is not None and norm_action is not None:
|
||||||
# Extract actions based on chunk sizes and DOF masks
|
# Extract actions based on chunk sizes and DOF masks
|
||||||
norm_action = [
|
norm_action = [action[:32, dof_masks[i, 0].bool()] for i, action in enumerate(norm_action)]
|
||||||
action[: 32, dof_masks[i, 0].bool()]
|
|
||||||
for i, action in enumerate(norm_action)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Convert to action tokens and pad
|
# Convert to action tokens and pad
|
||||||
actions_fast_tokens = get_action_tokens(norm_action, action_tokenizer)
|
actions_fast_tokens = get_action_tokens(norm_action, action_tokenizer)
|
||||||
@@ -658,4 +628,3 @@ def replace_action_token(
|
|||||||
text = [t.replace("<|action_fast|><|im_end|>\n", "") for t in text]
|
text = [t.replace("<|action_fast|><|im_end|>\n", "") for t in text]
|
||||||
|
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|||||||
@@ -35,10 +35,11 @@ from lerobot.policies.wall_x import ( # noqa: E402
|
|||||||
)
|
)
|
||||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
def test_policy_instantiation():
|
def test_policy_instantiation():
|
||||||
# Create config
|
# Create config
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
config = WallXConfig(device='cuda')
|
config = WallXConfig(device="cuda")
|
||||||
|
|
||||||
# Set up input_features and output_features in the config
|
# Set up input_features and output_features in the config
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
@@ -118,6 +119,7 @@ def test_policy_instantiation():
|
|||||||
print(f"Action prediction failed: {e}")
|
print(f"Action prediction failed: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def test_config_creation():
|
def test_config_creation():
|
||||||
"""Test policy config creation through factory."""
|
"""Test policy config creation through factory."""
|
||||||
try:
|
try:
|
||||||
@@ -130,6 +132,7 @@ def test_config_creation():
|
|||||||
print(f"Config creation failed: {e}")
|
print(f"Config creation failed: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_policy_instantiation()
|
test_policy_instantiation()
|
||||||
test_config_creation()
|
test_config_creation()
|
||||||
Reference in New Issue
Block a user