mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 21:50:03 +00:00
fix: train tokenizer CLI entry point (#2784)
This commit is contained in:
@@ -14,11 +14,14 @@
|
|||||||
"""Train FAST tokenizer for action encoding.
|
"""Train FAST tokenizer for action encoding.
|
||||||
|
|
||||||
This script:
|
This script:
|
||||||
1. Loads action chunks from LeRobotDataset (with sampling)
|
1. Loads action chunks from LeRobotDataset (with episode sampling)
|
||||||
2. Applies delta transforms and per-timestamp normalization
|
2. Optionally applies delta transforms (relative vs absolute actions)
|
||||||
3. Trains FAST tokenizer on specified action dimensions
|
3. Extracts specified action dimensions for encoding
|
||||||
4. Saves tokenizer to assets directory
|
4. Applies normalization (MEAN_STD, MIN_MAX, QUANTILES, or other modes)
|
||||||
5. Reports compression statistics
|
5. Trains FAST tokenizer (BPE on DCT coefficients) on the action chunks
|
||||||
|
6. Saves tokenizer to output directory
|
||||||
|
7. Optionally pushes tokenizer to Hugging Face Hub
|
||||||
|
8. Reports compression statistics
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@@ -42,18 +45,64 @@ lerobot-train-tokenizer \
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import tyro
|
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from transformers import AutoProcessor
|
|
||||||
|
|
||||||
|
from lerobot.utils.import_utils import _transformers_available
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _transformers_available:
|
||||||
|
from transformers import AutoProcessor
|
||||||
|
else:
|
||||||
|
AutoProcessor = None
|
||||||
|
|
||||||
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.types import NormalizationMode
|
from lerobot.configs.types import NormalizationMode
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TokenizerTrainingConfig:
|
||||||
|
"""Configuration for training FAST tokenizer."""
|
||||||
|
|
||||||
|
# LeRobot dataset repository ID
|
||||||
|
repo_id: str
|
||||||
|
# Root directory for dataset (default: ~/.cache/huggingface/lerobot)
|
||||||
|
root: str | None = None
|
||||||
|
# Number of future actions in each chunk
|
||||||
|
action_horizon: int = 10
|
||||||
|
# Max episodes to use (None = all episodes in dataset)
|
||||||
|
max_episodes: int | None = None
|
||||||
|
# Fraction of chunks to sample per episode
|
||||||
|
sample_fraction: float = 0.1
|
||||||
|
# Comma-separated dimension ranges to encode (e.g., "0:6,7:23")
|
||||||
|
encoded_dims: str = "0:6,7:23"
|
||||||
|
# Comma-separated dimension indices for delta transform (e.g., "0,1,2,3,4,5")
|
||||||
|
delta_dims: str | None = None
|
||||||
|
# Whether to apply delta transform (relative actions vs absolute actions)
|
||||||
|
use_delta_transform: bool = False
|
||||||
|
# Dataset key for state observations (default: "observation.state")
|
||||||
|
state_key: str = "observation.state"
|
||||||
|
# Normalization mode (MEAN_STD, MIN_MAX, QUANTILES, QUANTILE10, IDENTITY)
|
||||||
|
normalization_mode: str = "QUANTILES"
|
||||||
|
# FAST vocabulary size (BPE vocab size)
|
||||||
|
vocab_size: int = 1024
|
||||||
|
# DCT scaling factor (default: 10.0)
|
||||||
|
scale: float = 10.0
|
||||||
|
# Directory to save tokenizer (default: ./fast_tokenizer_{repo_id})
|
||||||
|
output_dir: str | None = None
|
||||||
|
# Whether to push the tokenizer to Hugging Face Hub
|
||||||
|
push_to_hub: bool = False
|
||||||
|
# Hub repository ID (e.g., "username/tokenizer-name"). If None, uses output_dir name
|
||||||
|
hub_repo_id: str | None = None
|
||||||
|
# Whether to create a private repository on the Hub
|
||||||
|
hub_private: bool = False
|
||||||
|
|
||||||
|
|
||||||
def apply_delta_transform(state: np.ndarray, actions: np.ndarray, delta_dims: list[int] | None) -> np.ndarray:
|
def apply_delta_transform(state: np.ndarray, actions: np.ndarray, delta_dims: list[int] | None) -> np.ndarray:
|
||||||
"""Apply delta transform to specified dimensions.
|
"""Apply delta transform to specified dimensions.
|
||||||
|
|
||||||
@@ -327,88 +376,57 @@ def compute_compression_stats(tokenizer, action_chunks: np.ndarray):
|
|||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
|
||||||
def main(
|
@parser.wrap()
|
||||||
repo_id: str,
|
def train_tokenizer(cfg: TokenizerTrainingConfig):
|
||||||
root: str | None = None,
|
|
||||||
action_horizon: int = 10,
|
|
||||||
max_episodes: int | None = None,
|
|
||||||
sample_fraction: float = 0.1,
|
|
||||||
encoded_dims: str = "0:6,7:23",
|
|
||||||
delta_dims: str | None = None,
|
|
||||||
use_delta_transform: bool = False,
|
|
||||||
state_key: str = "observation.state",
|
|
||||||
normalization_mode: str = "QUANTILES",
|
|
||||||
vocab_size: int = 1024,
|
|
||||||
scale: float = 10.0,
|
|
||||||
output_dir: str | None = None,
|
|
||||||
push_to_hub: bool = False,
|
|
||||||
hub_repo_id: str | None = None,
|
|
||||||
hub_private: bool = False,
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Train FAST tokenizer for action encoding.
|
Train FAST tokenizer for action encoding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
repo_id: LeRobot dataset repository ID
|
cfg: TokenizerTrainingConfig dataclass with all configuration parameters
|
||||||
root: Root directory for dataset (default: ~/.cache/huggingface/lerobot)
|
|
||||||
action_horizon: Number of future actions in each chunk
|
|
||||||
max_episodes: Max episodes to use (None = all episodes in dataset)
|
|
||||||
sample_fraction: Fraction of chunks to sample per episode
|
|
||||||
encoded_dims: Comma-separated dimension ranges to encode (e.g., "0:6,7:23")
|
|
||||||
delta_dims: Comma-separated dimension indices for delta transform (e.g., "0,1,2,3,4,5")
|
|
||||||
use_delta_transform: Whether to apply delta transform (relative actions vs absolute actions)
|
|
||||||
state_key: Dataset key for state observations (default: "observation.state")
|
|
||||||
normalization_mode: Normalization mode (MEAN_STD, MIN_MAX, QUANTILES, QUANTILE10, IDENTITY)
|
|
||||||
vocab_size: FAST vocabulary size (BPE vocab size)
|
|
||||||
scale: DCT scaling factor (default: 10.0)
|
|
||||||
output_dir: Directory to save tokenizer (default: ./fast_tokenizer_{repo_id})
|
|
||||||
push_to_hub: Whether to push the tokenizer to Hugging Face Hub
|
|
||||||
hub_repo_id: Hub repository ID (e.g., "username/tokenizer-name"). If None, uses output_dir name
|
|
||||||
hub_private: Whether to create a private repository on the Hub
|
|
||||||
"""
|
"""
|
||||||
# load dataset
|
# load dataset
|
||||||
print(f"Loading dataset: {repo_id}")
|
print(f"Loading dataset: {cfg.repo_id}")
|
||||||
dataset = LeRobotDataset(repo_id=repo_id, root=root)
|
dataset = LeRobotDataset(repo_id=cfg.repo_id, root=cfg.root)
|
||||||
print(f"Dataset loaded: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
|
print(f"Dataset loaded: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
|
||||||
|
|
||||||
# parse normalization mode
|
# parse normalization mode
|
||||||
try:
|
try:
|
||||||
norm_mode = NormalizationMode(normalization_mode)
|
norm_mode = NormalizationMode(cfg.normalization_mode)
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid normalization_mode: {normalization_mode}. "
|
f"Invalid normalization_mode: {cfg.normalization_mode}. "
|
||||||
f"Must be one of: {', '.join([m.value for m in NormalizationMode])}"
|
f"Must be one of: {', '.join([m.value for m in NormalizationMode])}"
|
||||||
) from err
|
) from err
|
||||||
print(f"Normalization mode: {norm_mode.value}")
|
print(f"Normalization mode: {norm_mode.value}")
|
||||||
|
|
||||||
# parse encoded dimensions
|
# parse encoded dimensions
|
||||||
encoded_dim_ranges = []
|
encoded_dim_ranges = []
|
||||||
for range_str in encoded_dims.split(","):
|
for range_str in cfg.encoded_dims.split(","):
|
||||||
start, end = map(int, range_str.strip().split(":"))
|
start, end = map(int, range_str.strip().split(":"))
|
||||||
encoded_dim_ranges.append((start, end))
|
encoded_dim_ranges.append((start, end))
|
||||||
|
|
||||||
total_encoded_dims = sum(end - start for start, end in encoded_dim_ranges)
|
total_encoded_dims = sum(end - start for start, end in encoded_dim_ranges)
|
||||||
print(f"Encoding {total_encoded_dims} dimensions: {encoded_dims}")
|
print(f"Encoding {total_encoded_dims} dimensions: {cfg.encoded_dims}")
|
||||||
|
|
||||||
# parse delta dimensions
|
# parse delta dimensions
|
||||||
delta_dim_list = None
|
delta_dim_list = None
|
||||||
if delta_dims is not None and delta_dims.strip():
|
if cfg.delta_dims is not None and cfg.delta_dims.strip():
|
||||||
delta_dim_list = [int(d.strip()) for d in delta_dims.split(",")]
|
delta_dim_list = [int(d.strip()) for d in cfg.delta_dims.split(",")]
|
||||||
print(f"Delta dimensions: {delta_dim_list}")
|
print(f"Delta dimensions: {delta_dim_list}")
|
||||||
else:
|
else:
|
||||||
print("No delta dimensions specified")
|
print("No delta dimensions specified")
|
||||||
|
|
||||||
print(f"Use delta transform: {use_delta_transform}")
|
print(f"Use delta transform: {cfg.use_delta_transform}")
|
||||||
if use_delta_transform and (delta_dim_list is None or len(delta_dim_list) == 0):
|
if cfg.use_delta_transform and (delta_dim_list is None or len(delta_dim_list) == 0):
|
||||||
print("Warning: use_delta_transform=True but no delta_dims specified. No delta will be applied.")
|
print("Warning: use_delta_transform=True but no delta_dims specified. No delta will be applied.")
|
||||||
|
|
||||||
print(f"Action horizon: {action_horizon}")
|
print(f"Action horizon: {cfg.action_horizon}")
|
||||||
print(f"State key: {state_key}")
|
print(f"State key: {cfg.state_key}")
|
||||||
|
|
||||||
# determine episodes to process
|
# determine episodes to process
|
||||||
num_episodes = dataset.num_episodes
|
num_episodes = dataset.num_episodes
|
||||||
if max_episodes is not None:
|
if cfg.max_episodes is not None:
|
||||||
num_episodes = min(max_episodes, num_episodes)
|
num_episodes = min(cfg.max_episodes, num_episodes)
|
||||||
|
|
||||||
print(f"Processing {num_episodes} episodes...")
|
print(f"Processing {num_episodes} episodes...")
|
||||||
|
|
||||||
@@ -419,7 +437,15 @@ def main(
|
|||||||
print(f" Processing episode {ep_idx}/{num_episodes}...")
|
print(f" Processing episode {ep_idx}/{num_episodes}...")
|
||||||
|
|
||||||
chunks = process_episode(
|
chunks = process_episode(
|
||||||
(dataset, ep_idx, action_horizon, delta_dim_list, sample_fraction, state_key, use_delta_transform)
|
(
|
||||||
|
dataset,
|
||||||
|
ep_idx,
|
||||||
|
cfg.action_horizon,
|
||||||
|
delta_dim_list,
|
||||||
|
cfg.sample_fraction,
|
||||||
|
cfg.state_key,
|
||||||
|
cfg.use_delta_transform,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if chunks is not None:
|
if chunks is not None:
|
||||||
all_chunks.append(chunks)
|
all_chunks.append(chunks)
|
||||||
@@ -495,16 +521,17 @@ def main(
|
|||||||
# train FAST tokenizer
|
# train FAST tokenizer
|
||||||
tokenizer = train_fast_tokenizer(
|
tokenizer = train_fast_tokenizer(
|
||||||
encoded_chunks,
|
encoded_chunks,
|
||||||
vocab_size=vocab_size,
|
vocab_size=cfg.vocab_size,
|
||||||
scale=scale,
|
scale=cfg.scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
# compute compression statistics
|
# compute compression statistics
|
||||||
compression_stats = compute_compression_stats(tokenizer, encoded_chunks)
|
compression_stats = compute_compression_stats(tokenizer, encoded_chunks)
|
||||||
|
|
||||||
# save tokenizer
|
# save tokenizer
|
||||||
|
output_dir = cfg.output_dir
|
||||||
if output_dir is None:
|
if output_dir is None:
|
||||||
output_dir = f"fast_tokenizer_{repo_id.replace('/', '_')}"
|
output_dir = f"fast_tokenizer_{cfg.repo_id.replace('/', '_')}"
|
||||||
output_path = Path(output_dir)
|
output_path = Path(output_dir)
|
||||||
output_path.mkdir(parents=True, exist_ok=True)
|
output_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
@@ -512,18 +539,18 @@ def main(
|
|||||||
|
|
||||||
# save metadata
|
# save metadata
|
||||||
metadata = {
|
metadata = {
|
||||||
"repo_id": repo_id,
|
"repo_id": cfg.repo_id,
|
||||||
"vocab_size": vocab_size,
|
"vocab_size": cfg.vocab_size,
|
||||||
"scale": scale,
|
"scale": cfg.scale,
|
||||||
"encoded_dims": encoded_dims,
|
"encoded_dims": cfg.encoded_dims,
|
||||||
"encoded_dim_ranges": encoded_dim_ranges,
|
"encoded_dim_ranges": encoded_dim_ranges,
|
||||||
"total_encoded_dims": total_encoded_dims,
|
"total_encoded_dims": total_encoded_dims,
|
||||||
"delta_dims": delta_dims,
|
"delta_dims": cfg.delta_dims,
|
||||||
"delta_dim_list": delta_dim_list,
|
"delta_dim_list": delta_dim_list,
|
||||||
"use_delta_transform": use_delta_transform,
|
"use_delta_transform": cfg.use_delta_transform,
|
||||||
"state_key": state_key,
|
"state_key": cfg.state_key,
|
||||||
"normalization_mode": norm_mode.value,
|
"normalization_mode": norm_mode.value,
|
||||||
"action_horizon": action_horizon,
|
"action_horizon": cfg.action_horizon,
|
||||||
"num_training_chunks": len(encoded_chunks),
|
"num_training_chunks": len(encoded_chunks),
|
||||||
"compression_stats": compression_stats,
|
"compression_stats": compression_stats,
|
||||||
}
|
}
|
||||||
@@ -535,21 +562,22 @@ def main(
|
|||||||
print(f"Metadata: {json.dumps(metadata, indent=2)}")
|
print(f"Metadata: {json.dumps(metadata, indent=2)}")
|
||||||
|
|
||||||
# push to Hugging Face Hub if requested
|
# push to Hugging Face Hub if requested
|
||||||
if push_to_hub:
|
if cfg.push_to_hub:
|
||||||
# determine the hub repository ID
|
# determine the hub repository ID
|
||||||
|
hub_repo_id = cfg.hub_repo_id
|
||||||
if hub_repo_id is None:
|
if hub_repo_id is None:
|
||||||
hub_repo_id = output_path.name
|
hub_repo_id = output_path.name
|
||||||
print(f"\nNo hub_repo_id provided, using: {hub_repo_id}")
|
print(f"\nNo hub_repo_id provided, using: {hub_repo_id}")
|
||||||
|
|
||||||
print(f"\nPushing tokenizer to Hugging Face Hub: {hub_repo_id}")
|
print(f"\nPushing tokenizer to Hugging Face Hub: {hub_repo_id}")
|
||||||
print(f" Private: {hub_private}")
|
print(f" Private: {cfg.hub_private}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# use the tokenizer's push_to_hub method
|
# use the tokenizer's push_to_hub method
|
||||||
tokenizer.push_to_hub(
|
tokenizer.push_to_hub(
|
||||||
repo_id=hub_repo_id,
|
repo_id=hub_repo_id,
|
||||||
private=hub_private,
|
private=cfg.hub_private,
|
||||||
commit_message=f"Upload FAST tokenizer trained on {repo_id}",
|
commit_message=f"Upload FAST tokenizer trained on {cfg.repo_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
# also upload the metadata.json file separately
|
# also upload the metadata.json file separately
|
||||||
@@ -568,5 +596,10 @@ def main(
|
|||||||
print(" Make sure you're logged in with `huggingface-cli login`")
|
print(" Make sure you're logged in with `huggingface-cli login`")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""CLI entry point that parses arguments and runs the tokenizer training."""
|
||||||
|
train_tokenizer()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tyro.cli(main)
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user