mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 03:30:10 +00:00
fix pre-commit errors
This commit is contained in:
committed by
Michel Aractingi
parent
9ce6dd9e25
commit
a0c9a7d85d
@@ -6,12 +6,13 @@ This repository contains the Hugging Face port of **WALL-OSS**, a Vision-Languag
|
|||||||
|
|
||||||
## Model Overview
|
## Model Overview
|
||||||
|
|
||||||
| Feature | Description |
|
| Feature | Description |
|
||||||
| -------------------- | ------------------------------------------------------------------------ |
|
| ------------------ | ----------------------------------------------------- | --- |
|
||||||
| Base Model | Qwen2.5-VL (Vision-Language Model) |
|
| Base Model | Qwen2.5-VL (Vision-Language Model) |
|
||||||
| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) |
|
| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) |
|
||||||
| Architecture | Mixture of Experts (MoE) with action-specific routing | |
|
| Architecture | Mixture of Experts (MoE) with action-specific routing | |
|
||||||
| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception |
|
| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
@@ -32,4 +33,3 @@ If you use this work, please cite:
|
|||||||
## License
|
## License
|
||||||
|
|
||||||
This port follows the **Apache 2.0 License**.
|
This port follows the **Apache 2.0 License**.
|
||||||
|
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ class WallXConfig(PreTrainedConfig):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# ==================== Action Prediction ====================
|
# ==================== Action Prediction ====================
|
||||||
# Pretrained model paths
|
# Pretrained model paths
|
||||||
pretrained_name_or_path: str = "x-square-robot/wall-oss-flow"
|
pretrained_name_or_path: str = "x-square-robot/wall-oss-flow"
|
||||||
|
|
||||||
@@ -85,20 +85,16 @@ class WallXConfig(PreTrainedConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.prediction_mode not in ["diffusion", "fast"]:
|
if self.prediction_mode not in ["diffusion", "fast"]:
|
||||||
raise ValueError(
|
raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}")
|
||||||
f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assign use_fast_tokenizer based on prediction_mode
|
# Assign use_fast_tokenizer based on prediction_mode
|
||||||
if self.prediction_mode == "fast":
|
if self.prediction_mode == "fast":
|
||||||
self.use_fast_tokenizer = True
|
self.use_fast_tokenizer = True
|
||||||
elif self.prediction_mode == "diffusion":
|
elif self.prediction_mode == "diffusion":
|
||||||
self.use_fast_tokenizer = False
|
self.use_fast_tokenizer = False
|
||||||
self.action_tokenizer_path = None # disable action tokenizer for diffusion mode
|
self.action_tokenizer_path = None # disable action tokenizer for diffusion mode
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}")
|
||||||
f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def validate_features(self) -> None:
|
def validate_features(self) -> None:
|
||||||
"""Validate and set up input/output features."""
|
"""Validate and set up input/output features."""
|
||||||
|
|||||||
@@ -38,4 +38,4 @@ PRIORITY_ORDER = None
|
|||||||
GENERATE_SUBTASK_RATIO = 0.0
|
GENERATE_SUBTASK_RATIO = 0.0
|
||||||
MODEL_TYPE = "qwen2_5"
|
MODEL_TYPE = "qwen2_5"
|
||||||
|
|
||||||
TOKENIZER_MAX_LENGTH = 768
|
TOKENIZER_MAX_LENGTH = 768
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -33,6 +33,8 @@ from lerobot.processor import (
|
|||||||
)
|
)
|
||||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||||
|
|
||||||
|
|
||||||
def make_wall_x_pre_post_processors(
|
def make_wall_x_pre_post_processors(
|
||||||
config: WallXConfig,
|
config: WallXConfig,
|
||||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||||
@@ -75,9 +77,7 @@ def make_wall_x_pre_post_processors(
|
|||||||
|
|
||||||
output_steps = [
|
output_steps = [
|
||||||
UnnormalizerProcessorStep(
|
UnnormalizerProcessorStep(
|
||||||
features=config.output_features,
|
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||||
norm_map=config.normalization_mapping,
|
|
||||||
stats=dataset_stats
|
|
||||||
),
|
),
|
||||||
DeviceProcessorStep(device="cpu"),
|
DeviceProcessorStep(device="cpu"),
|
||||||
]
|
]
|
||||||
@@ -123,9 +123,7 @@ class WallXTaskProcessor(ComplementaryDataProcessorStep):
|
|||||||
new_complementary_data["task"] = f"{task}."
|
new_complementary_data["task"] = f"{task}."
|
||||||
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
||||||
# List of strings: format each
|
# List of strings: format each
|
||||||
new_complementary_data["task"] = [
|
new_complementary_data["task"] = [t if t.endswith(".") else f"{t}." for t in task]
|
||||||
t if t.endswith(".") else f"{t}." for t in task
|
|
||||||
]
|
|
||||||
|
|
||||||
return new_complementary_data
|
return new_complementary_data
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -25,7 +25,7 @@ import random
|
|||||||
import re
|
import re
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import BatchFeature
|
from transformers import BatchFeature
|
||||||
@@ -46,11 +46,11 @@ class X2RDataProcessingConfig:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Action prediction configuration
|
# Action prediction configuration
|
||||||
predict_action_keys: List[str] = field(default_factory=list)
|
predict_action_keys: list[str] = field(default_factory=list)
|
||||||
obs_action_keys: List[str] = field(default_factory=list)
|
obs_action_keys: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
# Image resolution settings for different views
|
# Image resolution settings for different views
|
||||||
resolution: Dict[str, int] = field(
|
resolution: dict[str, int] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"face_view": -1,
|
"face_view": -1,
|
||||||
"left_wrist_view": 128,
|
"left_wrist_view": 128,
|
||||||
@@ -63,7 +63,7 @@ class X2RDataProcessingConfig:
|
|||||||
split_seed: int = 42
|
split_seed: int = 42
|
||||||
|
|
||||||
# Instruction handling
|
# Instruction handling
|
||||||
priority_order: Optional[Dict[str, float]] = None
|
priority_order: dict[str, float] | None = None
|
||||||
|
|
||||||
# Vision model parameters
|
# Vision model parameters
|
||||||
model_type: str = "qwen2_5"
|
model_type: str = "qwen2_5"
|
||||||
@@ -77,11 +77,9 @@ class X2RDataProcessingConfig:
|
|||||||
"""Post-initialization validation and setup."""
|
"""Post-initialization validation and setup."""
|
||||||
# Validate train/test split
|
# Validate train/test split
|
||||||
if not 0 < self.train_test_split < 1:
|
if not 0 < self.train_test_split < 1:
|
||||||
raise ValueError(
|
raise ValueError(f"train_test_split must be between 0 and 1, got {self.train_test_split}")
|
||||||
f"train_test_split must be between 0 and 1, got {self.train_test_split}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def as_dict(self) -> Dict:
|
def as_dict(self) -> dict:
|
||||||
"""Convert configuration to dictionary format.
|
"""Convert configuration to dictionary format.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -105,14 +103,15 @@ class X2RDataProcessingConfig:
|
|||||||
raise ValueError(f"Unknown configuration parameter: {key}")
|
raise ValueError(f"Unknown configuration parameter: {key}")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
def preprocesser_call(
|
def preprocesser_call(
|
||||||
processor,
|
processor,
|
||||||
images: Optional[Union[List, Any]] = None,
|
images: list | Any | None = None,
|
||||||
text: Optional[Union[str, List[str]]] = None,
|
text: str | list[str] | None = None,
|
||||||
videos: Optional[Union[List, Any]] = None,
|
videos: list | Any | None = None,
|
||||||
padding: Union[bool, str] = False,
|
padding: bool | str = False,
|
||||||
truncation: Optional[bool] = None,
|
truncation: bool | None = None,
|
||||||
max_length: Optional[int] = None,
|
max_length: int | None = None,
|
||||||
return_tensors: str = "pt",
|
return_tensors: str = "pt",
|
||||||
) -> BatchFeature:
|
) -> BatchFeature:
|
||||||
"""Unified preprocessing function for Wall-X model handling text, image and video inputs.
|
"""Unified preprocessing function for Wall-X model handling text, image and video inputs.
|
||||||
@@ -145,9 +144,7 @@ def preprocesser_call(
|
|||||||
"""
|
"""
|
||||||
# Process image inputs
|
# Process image inputs
|
||||||
if images is not None and len(images) > 0:
|
if images is not None and len(images) > 0:
|
||||||
image_inputs = processor.image_processor(
|
image_inputs = processor.image_processor(images=images, videos=None, return_tensors=return_tensors)
|
||||||
images=images, videos=None, return_tensors=return_tensors
|
|
||||||
)
|
|
||||||
image_grid_thw = image_inputs["image_grid_thw"]
|
image_grid_thw = image_inputs["image_grid_thw"]
|
||||||
else:
|
else:
|
||||||
image_inputs = {}
|
image_inputs = {}
|
||||||
@@ -155,9 +152,7 @@ def preprocesser_call(
|
|||||||
|
|
||||||
# Process video inputs
|
# Process video inputs
|
||||||
if videos is not None:
|
if videos is not None:
|
||||||
videos_inputs = processor.image_processor(
|
videos_inputs = processor.image_processor(images=None, videos=videos, return_tensors=return_tensors)
|
||||||
images=None, videos=videos, return_tensors=return_tensors
|
|
||||||
)
|
|
||||||
video_grid_thw = videos_inputs["video_grid_thw"]
|
video_grid_thw = videos_inputs["video_grid_thw"]
|
||||||
else:
|
else:
|
||||||
videos_inputs = {}
|
videos_inputs = {}
|
||||||
@@ -183,9 +178,7 @@ def preprocesser_call(
|
|||||||
break
|
break
|
||||||
# Replace image placeholder with actual token count
|
# Replace image placeholder with actual token count
|
||||||
token_count = image_grid_thw[index].prod() // merge_length
|
token_count = image_grid_thw[index].prod() // merge_length
|
||||||
text[i] = text[i].replace(
|
text[i] = text[i].replace("<|image_pad|>", "<|placeholder|>" * token_count, 1)
|
||||||
"<|image_pad|>", "<|placeholder|>" * token_count, 1
|
|
||||||
)
|
|
||||||
index += 1
|
index += 1
|
||||||
text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")
|
text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")
|
||||||
|
|
||||||
@@ -197,9 +190,7 @@ def preprocesser_call(
|
|||||||
while "<|video_pad|>" in text[i]:
|
while "<|video_pad|>" in text[i]:
|
||||||
# Replace video placeholder with actual token count
|
# Replace video placeholder with actual token count
|
||||||
token_count = video_grid_thw[index].prod() // merge_length
|
token_count = video_grid_thw[index].prod() // merge_length
|
||||||
text[i] = text[i].replace(
|
text[i] = text[i].replace("<|video_pad|>", "<|placeholder|>" * token_count, 1)
|
||||||
"<|video_pad|>", "<|placeholder|>" * token_count, 1
|
|
||||||
)
|
|
||||||
index += 1
|
index += 1
|
||||||
text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>")
|
text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>")
|
||||||
|
|
||||||
@@ -221,9 +212,7 @@ def preprocesser_call(
|
|||||||
labels = torch.full_like(text_inputs.input_ids, -100)
|
labels = torch.full_like(text_inputs.input_ids, -100)
|
||||||
assistant_marker = "<|im_start|>assistant\n"
|
assistant_marker = "<|im_start|>assistant\n"
|
||||||
im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
|
im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
|
||||||
assistant_tokens = processor.tokenizer(
|
assistant_tokens = processor.tokenizer("<|im_start|>assistant\n", add_special_tokens=False).input_ids
|
||||||
"<|im_start|>assistant\n", add_special_tokens=False
|
|
||||||
).input_ids
|
|
||||||
|
|
||||||
for i in range(len(text)):
|
for i in range(len(text)):
|
||||||
assistant_regions = []
|
assistant_regions = []
|
||||||
@@ -249,9 +238,7 @@ def preprocesser_call(
|
|||||||
# From second part onwards, each part starts with assistant response
|
# From second part onwards, each part starts with assistant response
|
||||||
for k in range(current_pos + 1, len(text_inputs.input_ids[i])):
|
for k in range(current_pos + 1, len(text_inputs.input_ids[i])):
|
||||||
if text_inputs.input_ids[i][k] == im_end_token_id:
|
if text_inputs.input_ids[i][k] == im_end_token_id:
|
||||||
assistant_regions.append(
|
assistant_regions.append((current_pos + len(assistant_tokens), k + 2))
|
||||||
(current_pos + len(assistant_tokens), k + 2)
|
|
||||||
)
|
|
||||||
break
|
break
|
||||||
current_pos += len(part_tokens) + 3
|
current_pos += len(part_tokens) + 3
|
||||||
|
|
||||||
@@ -344,7 +331,7 @@ def process_grounding_points(
|
|||||||
coords = [new_x1, new_y1, new_x2, new_y2]
|
coords = [new_x1, new_y1, new_x2, new_y2]
|
||||||
|
|
||||||
# Return processed point tag
|
# Return processed point tag
|
||||||
return f'<point>[{", ".join(map(str, coords))}]</point>'
|
return f"<point>[{', '.join(map(str, coords))}]</point>"
|
||||||
|
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
# Return original content if processing fails
|
# Return original content if processing fails
|
||||||
@@ -356,10 +343,10 @@ def process_grounding_points(
|
|||||||
|
|
||||||
|
|
||||||
def get_frame_instruction(
|
def get_frame_instruction(
|
||||||
instruction_info: Dict[str, Any],
|
instruction_info: dict[str, Any],
|
||||||
frame_idx: Optional[int] = None,
|
frame_idx: int | None = None,
|
||||||
truncate_keys: Optional[List[str]] = None,
|
truncate_keys: list[str] | None = None,
|
||||||
) -> Tuple[Dict[str, Any], Optional[int]]:
|
) -> tuple[dict[str, Any], int | None]:
|
||||||
"""Extract frame-specific instruction from instruction dictionary.
|
"""Extract frame-specific instruction from instruction dictionary.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -388,11 +375,7 @@ def get_frame_instruction(
|
|||||||
start_frame, end_frame = map(int, frame_range.split(" "))
|
start_frame, end_frame = map(int, frame_range.split(" "))
|
||||||
if start_frame <= frame_idx < end_frame or (start_frame == frame_idx):
|
if start_frame <= frame_idx < end_frame or (start_frame == frame_idx):
|
||||||
instruction_for_frame[key] = frame_instruction
|
instruction_for_frame[key] = frame_instruction
|
||||||
if (
|
if truncate_keys is not None and split_end is None and key in truncate_keys:
|
||||||
truncate_keys is not None
|
|
||||||
and split_end is None
|
|
||||||
and key in truncate_keys
|
|
||||||
):
|
|
||||||
split_end = end_frame + 1
|
split_end = end_frame + 1
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
@@ -402,7 +385,7 @@ def get_frame_instruction(
|
|||||||
|
|
||||||
|
|
||||||
def get_task_instruction(
|
def get_task_instruction(
|
||||||
frame_instruction_info: Dict[str, Any], priority_order: Optional[OrderedDict] = None
|
frame_instruction_info: dict[str, Any], priority_order: OrderedDict | None = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Construct task instruction from available instruction fields using priority sampling.
|
"""Construct task instruction from available instruction fields using priority sampling.
|
||||||
|
|
||||||
@@ -450,13 +433,13 @@ def get_task_instruction(
|
|||||||
|
|
||||||
|
|
||||||
def get_wallx_normal_text(
|
def get_wallx_normal_text(
|
||||||
instruction_info: Dict[str, Any],
|
instruction_info: dict[str, Any],
|
||||||
action_chunk_size: int,
|
action_chunk_size: int,
|
||||||
frame_idx: int,
|
frame_idx: int,
|
||||||
priority_order: Optional[OrderedDict] = None,
|
priority_order: OrderedDict | None = None,
|
||||||
img_keys: Optional[List[str]] = None,
|
img_keys: list[str] | None = None,
|
||||||
generate_subtask_ratio: float = 0.0,
|
generate_subtask_ratio: float = 0.0,
|
||||||
) -> Tuple[str, bool]:
|
) -> tuple[str, bool]:
|
||||||
"""Construct complete multimodal prompt text for Wall-X model.
|
"""Construct complete multimodal prompt text for Wall-X model.
|
||||||
|
|
||||||
Formats input using special tokens including:
|
Formats input using special tokens including:
|
||||||
@@ -488,9 +471,7 @@ def get_wallx_normal_text(
|
|||||||
action_fast_symbol = "<|action_fast|>"
|
action_fast_symbol = "<|action_fast|>"
|
||||||
|
|
||||||
# System prologue
|
# System prologue
|
||||||
prologue = (
|
prologue = f"{role_start_symbol}system\nYou are a helpful assistant.{role_end_symbol}\n"
|
||||||
f"{role_start_symbol}system\nYou are a helpful assistant.{role_end_symbol}\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
# User request with observation
|
# User request with observation
|
||||||
user_request = f"{role_start_symbol}user\nObservation:"
|
user_request = f"{role_start_symbol}user\nObservation:"
|
||||||
@@ -501,9 +482,7 @@ def get_wallx_normal_text(
|
|||||||
user_request += "\nInstruction:"
|
user_request += "\nInstruction:"
|
||||||
|
|
||||||
# Get frame-specific instruction
|
# Get frame-specific instruction
|
||||||
frame_instruction_info, _ = get_frame_instruction(
|
frame_instruction_info, _ = get_frame_instruction(instruction_info, frame_idx=frame_idx)
|
||||||
instruction_info, frame_idx=frame_idx
|
|
||||||
)
|
|
||||||
|
|
||||||
generate_subtask = False
|
generate_subtask = False
|
||||||
priority_keys = ["subtask_generation", "distribute"]
|
priority_keys = ["subtask_generation", "distribute"]
|
||||||
@@ -524,15 +503,11 @@ def get_wallx_normal_text(
|
|||||||
output_instruction = frame_instruction_info[key]
|
output_instruction = frame_instruction_info[key]
|
||||||
break
|
break
|
||||||
|
|
||||||
assistant_output = (
|
assistant_output = f"{role_start_symbol}assistant\n{output_instruction}\n{role_end_symbol}"
|
||||||
f"{role_start_symbol}assistant\n{output_instruction}\n{role_end_symbol}"
|
|
||||||
)
|
|
||||||
generate_subtask = True
|
generate_subtask = True
|
||||||
else:
|
else:
|
||||||
# Generate actions
|
# Generate actions
|
||||||
instruction = get_task_instruction(
|
instruction = get_task_instruction(frame_instruction_info, priority_order=priority_order)
|
||||||
frame_instruction_info, priority_order=priority_order
|
|
||||||
)
|
|
||||||
text_prompt = f"\nPredict the next action in robot action.\nProprioception: {propri_symbol}\n"
|
text_prompt = f"\nPredict the next action in robot action.\nProprioception: {propri_symbol}\n"
|
||||||
user_message = f"{user_request} {instruction}{text_prompt}{role_end_symbol}\n"
|
user_message = f"{user_request} {instruction}{text_prompt}{role_end_symbol}\n"
|
||||||
assistant_output = f"{role_start_symbol}assistant\n{action_fast_symbol}{role_end_symbol}\n{action_symbol * action_chunk_size}"
|
assistant_output = f"{role_start_symbol}assistant\n{action_fast_symbol}{role_end_symbol}\n{action_symbol * action_chunk_size}"
|
||||||
@@ -540,7 +515,8 @@ def get_wallx_normal_text(
|
|||||||
complete_text = prologue + user_message + assistant_output
|
complete_text = prologue + user_message + assistant_output
|
||||||
return complete_text, generate_subtask
|
return complete_text, generate_subtask
|
||||||
|
|
||||||
def img_key_mapping(img_keys: List[str]) -> List[str]:
|
|
||||||
|
def img_key_mapping(img_keys: list[str]) -> list[str]:
|
||||||
"""Map image keys to camera names.
|
"""Map image keys to camera names.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -555,16 +531,15 @@ def img_key_mapping(img_keys: List[str]) -> List[str]:
|
|||||||
if key in CAMERA_NAME_MAPPING:
|
if key in CAMERA_NAME_MAPPING:
|
||||||
key = CAMERA_NAME_MAPPING[key]
|
key = CAMERA_NAME_MAPPING[key]
|
||||||
else:
|
else:
|
||||||
if 'view' in key:
|
if "view" in key:
|
||||||
key = key.replace('_', ' ')
|
key = key.replace("_", " ")
|
||||||
else:
|
else:
|
||||||
key = key + " view"
|
key = key + " view"
|
||||||
processed_img_keys.append(key)
|
processed_img_keys.append(key)
|
||||||
return processed_img_keys
|
return processed_img_keys
|
||||||
|
|
||||||
def get_action_tokens(
|
|
||||||
normalized_actions: Union[torch.Tensor, List], action_tokenizer
|
def get_action_tokens(normalized_actions: torch.Tensor | list, action_tokenizer) -> list[list[str]]:
|
||||||
) -> List[List[str]]:
|
|
||||||
"""Convert normalized actions to action token strings.
|
"""Convert normalized actions to action token strings.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -590,8 +565,8 @@ def get_action_tokens(
|
|||||||
|
|
||||||
|
|
||||||
def pad_action_token_strs(
|
def pad_action_token_strs(
|
||||||
actions_token_lists: List[List[str]], pad_token: str = "<|endoftext|>"
|
actions_token_lists: list[list[str]], pad_token: str = "<|endoftext|>"
|
||||||
) -> List[str]:
|
) -> list[str]:
|
||||||
"""Pad action token lists to same length and join as strings.
|
"""Pad action token lists to same length and join as strings.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -605,20 +580,18 @@ def pad_action_token_strs(
|
|||||||
padded_action_strs = []
|
padded_action_strs = []
|
||||||
|
|
||||||
for tokens in actions_token_lists:
|
for tokens in actions_token_lists:
|
||||||
padded_tokens = (
|
padded_tokens = tokens + ["<|im_end|>\n"] + [pad_token] * (max_len - len(tokens))
|
||||||
tokens + ["<|im_end|>\n"] + [pad_token] * (max_len - len(tokens))
|
|
||||||
)
|
|
||||||
padded_action_strs.append("".join(padded_tokens))
|
padded_action_strs.append("".join(padded_tokens))
|
||||||
|
|
||||||
return padded_action_strs
|
return padded_action_strs
|
||||||
|
|
||||||
|
|
||||||
def replace_action_token(
|
def replace_action_token(
|
||||||
text: List[str],
|
text: list[str],
|
||||||
norm_action: Optional[torch.Tensor],
|
norm_action: torch.Tensor | None,
|
||||||
action_tokenizer,
|
action_tokenizer,
|
||||||
dof_masks: Optional[torch.Tensor] = None,
|
dof_masks: torch.Tensor | None = None,
|
||||||
) -> List[str]:
|
) -> list[str]:
|
||||||
"""Replace action placeholders in text with actual action tokens.
|
"""Replace action placeholders in text with actual action tokens.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -632,10 +605,7 @@ def replace_action_token(
|
|||||||
"""
|
"""
|
||||||
if action_tokenizer is not None and norm_action is not None:
|
if action_tokenizer is not None and norm_action is not None:
|
||||||
# Extract actions based on chunk sizes and DOF masks
|
# Extract actions based on chunk sizes and DOF masks
|
||||||
norm_action = [
|
norm_action = [action[:32, dof_masks[i, 0].bool()] for i, action in enumerate(norm_action)]
|
||||||
action[: 32, dof_masks[i, 0].bool()]
|
|
||||||
for i, action in enumerate(norm_action)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Convert to action tokens and pad
|
# Convert to action tokens and pad
|
||||||
actions_fast_tokens = get_action_tokens(norm_action, action_tokenizer)
|
actions_fast_tokens = get_action_tokens(norm_action, action_tokenizer)
|
||||||
@@ -658,4 +628,3 @@ def replace_action_token(
|
|||||||
text = [t.replace("<|action_fast|><|im_end|>\n", "") for t in text]
|
text = [t.replace("<|action_fast|><|im_end|>\n", "") for t in text]
|
||||||
|
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|||||||
@@ -35,10 +35,11 @@ from lerobot.policies.wall_x import ( # noqa: E402
|
|||||||
)
|
)
|
||||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
def test_policy_instantiation():
|
def test_policy_instantiation():
|
||||||
# Create config
|
# Create config
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
config = WallXConfig(device='cuda')
|
config = WallXConfig(device="cuda")
|
||||||
|
|
||||||
# Set up input_features and output_features in the config
|
# Set up input_features and output_features in the config
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
@@ -118,6 +119,7 @@ def test_policy_instantiation():
|
|||||||
print(f"Action prediction failed: {e}")
|
print(f"Action prediction failed: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def test_config_creation():
|
def test_config_creation():
|
||||||
"""Test policy config creation through factory."""
|
"""Test policy config creation through factory."""
|
||||||
try:
|
try:
|
||||||
@@ -130,6 +132,7 @@ def test_config_creation():
|
|||||||
print(f"Config creation failed: {e}")
|
print(f"Config creation failed: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_policy_instantiation()
|
test_policy_instantiation()
|
||||||
test_config_creation()
|
test_config_creation()
|
||||||
|
|||||||
Reference in New Issue
Block a user