mirror of https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
fix: train tokenizer CLI entry point (#2784)
@@ -14,11 +14,14 @@
 """Train FAST tokenizer for action encoding.
 
 This script:
-1. Loads action chunks from LeRobotDataset (with sampling)
-2. Applies delta transforms and per-timestamp normalization
-3. Trains FAST tokenizer on specified action dimensions
-4. Saves tokenizer to assets directory
-5. Reports compression statistics
+1. Loads action chunks from LeRobotDataset (with episode sampling)
+2. Optionally applies delta transforms (relative vs absolute actions)
+3. Extracts specified action dimensions for encoding
+4. Applies normalization (MEAN_STD, MIN_MAX, QUANTILES, or other modes)
+5. Trains FAST tokenizer (BPE on DCT coefficients) on the action chunks
+6. Saves tokenizer to output directory
+7. Optionally pushes tokenizer to Hugging Face Hub
+8. Reports compression statistics
 
 Example:
 
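Note (not part of the diff): step 1's chunk extraction lives in process_episode, whose body is outside this commit. A minimal sketch of what slicing fixed-horizon chunks with fractional sampling could look like — the helper name and sampling strategy are illustrative assumptions, while horizon and fraction mirror the action_horizon and sample_fraction config fields below:

    import numpy as np

    def sample_action_chunks(actions: np.ndarray, horizon: int, fraction: float, seed: int = 0) -> np.ndarray:
        """actions: (num_frames, action_dim) -> (kept_chunks, horizon, action_dim)."""
        num_windows = actions.shape[0] - horizon + 1
        if num_windows <= 0:
            return np.empty((0, horizon, actions.shape[1]))
        # keep roughly `fraction` of all sliding windows
        rng = np.random.default_rng(seed)
        starts = np.flatnonzero(rng.random(num_windows) < fraction)
        if starts.size == 0:
            return np.empty((0, horizon, actions.shape[1]))
        return np.stack([actions[s : s + horizon] for s in starts])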
@@ -42,18 +45,64 @@ lerobot-train-tokenizer \
 """
 
 import json
+from dataclasses import dataclass
 from pathlib import Path
+from typing import TYPE_CHECKING
 
 import numpy as np
 import torch
-import tyro
 from huggingface_hub import HfApi
-from transformers import AutoProcessor
+
+from lerobot.utils.import_utils import _transformers_available
+
+if TYPE_CHECKING or _transformers_available:
+    from transformers import AutoProcessor
+else:
+    AutoProcessor = None
 
+from lerobot.configs import parser
 from lerobot.configs.types import NormalizationMode
 from lerobot.datasets.lerobot_dataset import LeRobotDataset
 
 
+@dataclass
+class TokenizerTrainingConfig:
+    """Configuration for training FAST tokenizer."""
+
+    # LeRobot dataset repository ID
+    repo_id: str
+    # Root directory for dataset (default: ~/.cache/huggingface/lerobot)
+    root: str | None = None
+    # Number of future actions in each chunk
+    action_horizon: int = 10
+    # Max episodes to use (None = all episodes in dataset)
+    max_episodes: int | None = None
+    # Fraction of chunks to sample per episode
+    sample_fraction: float = 0.1
+    # Comma-separated dimension ranges to encode (e.g., "0:6,7:23")
+    encoded_dims: str = "0:6,7:23"
+    # Comma-separated dimension indices for delta transform (e.g., "0,1,2,3,4,5")
+    delta_dims: str | None = None
+    # Whether to apply delta transform (relative actions vs absolute actions)
+    use_delta_transform: bool = False
+    # Dataset key for state observations (default: "observation.state")
+    state_key: str = "observation.state"
+    # Normalization mode (MEAN_STD, MIN_MAX, QUANTILES, QUANTILE10, IDENTITY)
+    normalization_mode: str = "QUANTILES"
+    # FAST vocabulary size (BPE vocab size)
+    vocab_size: int = 1024
+    # DCT scaling factor (default: 10.0)
+    scale: float = 10.0
+    # Directory to save tokenizer (default: ./fast_tokenizer_{repo_id})
+    output_dir: str | None = None
+    # Whether to push the tokenizer to Hugging Face Hub
+    push_to_hub: bool = False
+    # Hub repository ID (e.g., "username/tokenizer-name"). If None, uses output_dir name
+    hub_repo_id: str | None = None
+    # Whether to create a private repository on the Hub
+    hub_private: bool = False
 
 
 def apply_delta_transform(state: np.ndarray, actions: np.ndarray, delta_dims: list[int] | None) -> np.ndarray:
     """Apply delta transform to specified dimensions.
 
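Note (not part of the diff): the body of apply_delta_transform is outside this hunk. A hedged sketch of the assumed semantics — the selected dimensions become offsets from the current state, the rest stay absolute; the exact behavior in lerobot may differ:

    import numpy as np

    def apply_delta_transform_sketch(state: np.ndarray, actions: np.ndarray, delta_dims: list[int] | None) -> np.ndarray:
        """state: (action_dim,); actions: (horizon, action_dim)."""
        if not delta_dims:
            return actions
        out = actions.copy()
        # express delta dims relative to the current state
        out[:, delta_dims] -= state[delta_dims]
        return out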
@@ -327,88 +376,57 @@ def compute_compression_stats(tokenizer, action_chunks: np.ndarray):
     return stats
 
 
-def main(
-    repo_id: str,
-    root: str | None = None,
-    action_horizon: int = 10,
-    max_episodes: int | None = None,
-    sample_fraction: float = 0.1,
-    encoded_dims: str = "0:6,7:23",
-    delta_dims: str | None = None,
-    use_delta_transform: bool = False,
-    state_key: str = "observation.state",
-    normalization_mode: str = "QUANTILES",
-    vocab_size: int = 1024,
-    scale: float = 10.0,
-    output_dir: str | None = None,
-    push_to_hub: bool = False,
-    hub_repo_id: str | None = None,
-    hub_private: bool = False,
-):
+@parser.wrap()
+def train_tokenizer(cfg: TokenizerTrainingConfig):
     """
     Train FAST tokenizer for action encoding.
 
     Args:
-        repo_id: LeRobot dataset repository ID
-        root: Root directory for dataset (default: ~/.cache/huggingface/lerobot)
-        action_horizon: Number of future actions in each chunk
-        max_episodes: Max episodes to use (None = all episodes in dataset)
-        sample_fraction: Fraction of chunks to sample per episode
-        encoded_dims: Comma-separated dimension ranges to encode (e.g., "0:6,7:23")
-        delta_dims: Comma-separated dimension indices for delta transform (e.g., "0,1,2,3,4,5")
-        use_delta_transform: Whether to apply delta transform (relative actions vs absolute actions)
-        state_key: Dataset key for state observations (default: "observation.state")
-        normalization_mode: Normalization mode (MEAN_STD, MIN_MAX, QUANTILES, QUANTILE10, IDENTITY)
-        vocab_size: FAST vocabulary size (BPE vocab size)
-        scale: DCT scaling factor (default: 10.0)
-        output_dir: Directory to save tokenizer (default: ./fast_tokenizer_{repo_id})
-        push_to_hub: Whether to push the tokenizer to Hugging Face Hub
-        hub_repo_id: Hub repository ID (e.g., "username/tokenizer-name"). If None, uses output_dir name
-        hub_private: Whether to create a private repository on the Hub
+        cfg: TokenizerTrainingConfig dataclass with all configuration parameters
     """
     # load dataset
-    print(f"Loading dataset: {repo_id}")
-    dataset = LeRobotDataset(repo_id=repo_id, root=root)
+    print(f"Loading dataset: {cfg.repo_id}")
+    dataset = LeRobotDataset(repo_id=cfg.repo_id, root=cfg.root)
     print(f"Dataset loaded: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
 
     # parse normalization mode
     try:
-        norm_mode = NormalizationMode(normalization_mode)
+        norm_mode = NormalizationMode(cfg.normalization_mode)
     except ValueError as err:
         raise ValueError(
-            f"Invalid normalization_mode: {normalization_mode}. "
+            f"Invalid normalization_mode: {cfg.normalization_mode}. "
             f"Must be one of: {', '.join([m.value for m in NormalizationMode])}"
         ) from err
     print(f"Normalization mode: {norm_mode.value}")
 
     # parse encoded dimensions
     encoded_dim_ranges = []
-    for range_str in encoded_dims.split(","):
+    for range_str in cfg.encoded_dims.split(","):
         start, end = map(int, range_str.strip().split(":"))
         encoded_dim_ranges.append((start, end))
 
     total_encoded_dims = sum(end - start for start, end in encoded_dim_ranges)
-    print(f"Encoding {total_encoded_dims} dimensions: {encoded_dims}")
+    print(f"Encoding {total_encoded_dims} dimensions: {cfg.encoded_dims}")
 
     # parse delta dimensions
     delta_dim_list = None
-    if delta_dims is not None and delta_dims.strip():
-        delta_dim_list = [int(d.strip()) for d in delta_dims.split(",")]
+    if cfg.delta_dims is not None and cfg.delta_dims.strip():
+        delta_dim_list = [int(d.strip()) for d in cfg.delta_dims.split(",")]
         print(f"Delta dimensions: {delta_dim_list}")
     else:
        print("No delta dimensions specified")
 
-    print(f"Use delta transform: {use_delta_transform}")
-    if use_delta_transform and (delta_dim_list is None or len(delta_dim_list) == 0):
+    print(f"Use delta transform: {cfg.use_delta_transform}")
+    if cfg.use_delta_transform and (delta_dim_list is None or len(delta_dim_list) == 0):
         print("Warning: use_delta_transform=True but no delta_dims specified. No delta will be applied.")
 
-    print(f"Action horizon: {action_horizon}")
-    print(f"State key: {state_key}")
+    print(f"Action horizon: {cfg.action_horizon}")
+    print(f"State key: {cfg.state_key}")
 
     # determine episodes to process
     num_episodes = dataset.num_episodes
-    if max_episodes is not None:
-        num_episodes = min(max_episodes, num_episodes)
+    if cfg.max_episodes is not None:
+        num_episodes = min(cfg.max_episodes, num_episodes)
 
     print(f"Processing {num_episodes} episodes...")
 
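Note (not part of the diff): a worked example of the encoded_dims parsing above. "0:6,7:23" yields the half-open ranges [(0, 6), (7, 23)], i.e. 6 + 16 = 22 encoded dimensions (dimension 6 is deliberately skipped), and the selected columns can be gathered with a concatenate:

    import numpy as np

    encoded_dims = "0:6,7:23"
    ranges = [tuple(map(int, r.strip().split(":"))) for r in encoded_dims.split(",")]
    assert ranges == [(0, 6), (7, 23)]
    assert sum(end - start for start, end in ranges) == 22

    chunks = np.zeros((100, 10, 24))  # (num_chunks, horizon, action_dim), dummy data
    encoded = np.concatenate([chunks[..., s:e] for s, e in ranges], axis=-1)
    assert encoded.shape == (100, 10, 22)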
@@ -419,7 +437,15 @@ def main(
         print(f"  Processing episode {ep_idx}/{num_episodes}...")
 
         chunks = process_episode(
-            (dataset, ep_idx, action_horizon, delta_dim_list, sample_fraction, state_key, use_delta_transform)
+            (
+                dataset,
+                ep_idx,
+                cfg.action_horizon,
+                delta_dim_list,
+                cfg.sample_fraction,
+                cfg.state_key,
+                cfg.use_delta_transform,
+            )
         )
         if chunks is not None:
             all_chunks.append(chunks)
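Note (not part of the diff): the normalization step between collecting chunks and training (docstring step 4) is also outside this commit. A rough sketch of QUANTILES-style normalization, assuming a rescale by a low/high percentile range into roughly [-1, 1] — the exact quantiles and formula in lerobot may differ:

    import numpy as np

    def normalize_quantiles_sketch(chunks: np.ndarray, low_pct: float = 1.0, high_pct: float = 99.0) -> np.ndarray:
        """chunks: (num_chunks, horizon, dims) -> values roughly in [-1, 1]."""
        flat = chunks.reshape(-1, chunks.shape[-1])
        q_low = np.percentile(flat, low_pct, axis=0)
        q_high = np.percentile(flat, high_pct, axis=0)
        scale = np.maximum(q_high - q_low, 1e-8)  # guard near-constant dims
        return 2.0 * (chunks - q_low) / scale - 1.0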
@@ -495,16 +521,17 @@ def main(
     # train FAST tokenizer
     tokenizer = train_fast_tokenizer(
         encoded_chunks,
-        vocab_size=vocab_size,
-        scale=scale,
+        vocab_size=cfg.vocab_size,
+        scale=cfg.scale,
     )
 
     # compute compression statistics
     compression_stats = compute_compression_stats(tokenizer, encoded_chunks)
 
     # save tokenizer
+    output_dir = cfg.output_dir
     if output_dir is None:
-        output_dir = f"fast_tokenizer_{repo_id.replace('/', '_')}"
+        output_dir = f"fast_tokenizer_{cfg.repo_id.replace('/', '_')}"
     output_path = Path(output_dir)
     output_path.mkdir(parents=True, exist_ok=True)
 
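Note (not part of the diff): train_fast_tokenizer's body is outside this commit. Based on the physical-intelligence/fast model card, training presumably wraps something like the following; how vocab_size and scale are threaded through is an assumption, not verified against lerobot's helper:

    from transformers import AutoProcessor

    def train_fast_tokenizer_sketch(encoded_chunks):
        # encoded_chunks: (num_chunks, horizon, encoded_dims) float array
        tok = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
        # fit() returns a new tokenizer trained on this data (per the model card)
        return tok.fit(encoded_chunks)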
@@ -512,18 +539,18 @@ def main(
 
     # save metadata
     metadata = {
-        "repo_id": repo_id,
-        "vocab_size": vocab_size,
-        "scale": scale,
-        "encoded_dims": encoded_dims,
+        "repo_id": cfg.repo_id,
+        "vocab_size": cfg.vocab_size,
+        "scale": cfg.scale,
+        "encoded_dims": cfg.encoded_dims,
         "encoded_dim_ranges": encoded_dim_ranges,
         "total_encoded_dims": total_encoded_dims,
-        "delta_dims": delta_dims,
+        "delta_dims": cfg.delta_dims,
         "delta_dim_list": delta_dim_list,
-        "use_delta_transform": use_delta_transform,
-        "state_key": state_key,
+        "use_delta_transform": cfg.use_delta_transform,
+        "state_key": cfg.state_key,
         "normalization_mode": norm_mode.value,
-        "action_horizon": action_horizon,
+        "action_horizon": cfg.action_horizon,
         "num_training_chunks": len(encoded_chunks),
         "compression_stats": compression_stats,
     }
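Note (not part of the diff): compute_compression_stats is unchanged here and its body is elided; in essence the ratio it reports is raw values per chunk (horizon x dims) over tokens per chunk. A hedged sketch, assuming the tokenizer call returns one token-id list per chunk as the FAST processor does:

    import numpy as np

    def compression_stats_sketch(tokenizer, chunks: np.ndarray) -> dict:
        """chunks: (num_chunks, horizon, dims)."""
        token_lists = tokenizer(chunks)
        token_counts = np.array([len(t) for t in token_lists])
        raw_values_per_chunk = chunks.shape[1] * chunks.shape[2]
        return {
            "mean_tokens_per_chunk": float(token_counts.mean()),
            "compression_ratio": float(raw_values_per_chunk / token_counts.mean()),
        }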
@@ -535,21 +562,22 @@ def main(
     print(f"Metadata: {json.dumps(metadata, indent=2)}")
 
     # push to Hugging Face Hub if requested
-    if push_to_hub:
+    if cfg.push_to_hub:
         # determine the hub repository ID
+        hub_repo_id = cfg.hub_repo_id
         if hub_repo_id is None:
             hub_repo_id = output_path.name
             print(f"\nNo hub_repo_id provided, using: {hub_repo_id}")
 
         print(f"\nPushing tokenizer to Hugging Face Hub: {hub_repo_id}")
-        print(f"  Private: {hub_private}")
+        print(f"  Private: {cfg.hub_private}")
 
         try:
             # use the tokenizer's push_to_hub method
             tokenizer.push_to_hub(
                 repo_id=hub_repo_id,
-                private=hub_private,
-                commit_message=f"Upload FAST tokenizer trained on {repo_id}",
+                private=cfg.hub_private,
+                commit_message=f"Upload FAST tokenizer trained on {cfg.repo_id}",
             )
 
             # also upload the metadata.json file separately
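Note (not part of the diff): the "also upload the metadata.json" step is truncated in this view. With HfApi already imported at the top of the file, the upload plausibly looks like the following sketch (the filename and commit message are assumptions):

    from huggingface_hub import HfApi

    api = HfApi()
    api.upload_file(
        path_or_fileobj=str(output_path / "metadata.json"),  # assumed filename
        path_in_repo="metadata.json",
        repo_id=hub_repo_id,
        commit_message="Upload tokenizer training metadata",
    )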
@@ -568,5 +596,10 @@ def main(
         print("  Make sure you're logged in with `huggingface-cli login`")
 
 
+def main():
+    """CLI entry point that parses arguments and runs the tokenizer training."""
+    train_tokenizer()
+
+
 if __name__ == "__main__":
-    tyro.cli(main)
+    main()
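Note (not part of the diff): the point of the fix. An installed console script (the docstring's lerobot-train-tokenizer command) calls the referenced function with zero arguments; the old module only ran `tyro.cli(main)` under `__main__`, so the entry point invoked `main(repo_id, ...)` bare and failed. The new zero-argument `main()` delegates to the `@parser.wrap()`-decorated function, which reads the CLI itself. The pattern in miniature, with argparse standing in for lerobot's parser machinery (which is not shown in this diff):

    import argparse
    from dataclasses import dataclass

    @dataclass
    class Cfg:
        repo_id: str

    def train_tokenizer_mini():
        # parses its own CLI, like the @parser.wrap()-decorated function
        ap = argparse.ArgumentParser()
        ap.add_argument("--repo_id", required=True)
        cfg = Cfg(**vars(ap.parse_args()))
        print(f"training on {cfg.repo_id}")

    def main():
        # zero-argument callable: what a console_scripts entry point requires
        train_tokenizer_mini()

    if __name__ == "__main__":
        main()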