fix: train tokenizer CLI entry point (#2784)

This commit is contained in:
Jade Choghari
2026-01-13 01:42:53 +01:00
committed by GitHub
parent d0f57f58d1
commit 2cdd9f43f7
+105 -72
View File
@@ -14,11 +14,14 @@
"""Train FAST tokenizer for action encoding. """Train FAST tokenizer for action encoding.
This script: This script:
1. Loads action chunks from LeRobotDataset (with sampling) 1. Loads action chunks from LeRobotDataset (with episode sampling)
2. Applies delta transforms and per-timestamp normalization 2. Optionally applies delta transforms (relative vs absolute actions)
3. Trains FAST tokenizer on specified action dimensions 3. Extracts specified action dimensions for encoding
4. Saves tokenizer to assets directory 4. Applies normalization (MEAN_STD, MIN_MAX, QUANTILES, or other modes)
5. Reports compression statistics 5. Trains FAST tokenizer (BPE on DCT coefficients) on the action chunks
6. Saves tokenizer to output directory
7. Optionally pushes tokenizer to Hugging Face Hub
8. Reports compression statistics
Example: Example:
@@ -42,18 +45,64 @@ lerobot-train-tokenizer \
""" """
import json import json
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np import numpy as np
import torch import torch
import tyro
from huggingface_hub import HfApi from huggingface_hub import HfApi
from transformers import AutoProcessor
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers import AutoProcessor
else:
AutoProcessor = None
from lerobot.configs import parser
from lerobot.configs.types import NormalizationMode from lerobot.configs.types import NormalizationMode
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
@dataclass
class TokenizerTrainingConfig:
"""Configuration for training FAST tokenizer."""
# LeRobot dataset repository ID
repo_id: str
# Root directory for dataset (default: ~/.cache/huggingface/lerobot)
root: str | None = None
# Number of future actions in each chunk
action_horizon: int = 10
# Max episodes to use (None = all episodes in dataset)
max_episodes: int | None = None
# Fraction of chunks to sample per episode
sample_fraction: float = 0.1
# Comma-separated dimension ranges to encode (e.g., "0:6,7:23")
encoded_dims: str = "0:6,7:23"
# Comma-separated dimension indices for delta transform (e.g., "0,1,2,3,4,5")
delta_dims: str | None = None
# Whether to apply delta transform (relative actions vs absolute actions)
use_delta_transform: bool = False
# Dataset key for state observations (default: "observation.state")
state_key: str = "observation.state"
# Normalization mode (MEAN_STD, MIN_MAX, QUANTILES, QUANTILE10, IDENTITY)
normalization_mode: str = "QUANTILES"
# FAST vocabulary size (BPE vocab size)
vocab_size: int = 1024
# DCT scaling factor (default: 10.0)
scale: float = 10.0
# Directory to save tokenizer (default: ./fast_tokenizer_{repo_id})
output_dir: str | None = None
# Whether to push the tokenizer to Hugging Face Hub
push_to_hub: bool = False
# Hub repository ID (e.g., "username/tokenizer-name"). If None, uses output_dir name
hub_repo_id: str | None = None
# Whether to create a private repository on the Hub
hub_private: bool = False
def apply_delta_transform(state: np.ndarray, actions: np.ndarray, delta_dims: list[int] | None) -> np.ndarray: def apply_delta_transform(state: np.ndarray, actions: np.ndarray, delta_dims: list[int] | None) -> np.ndarray:
"""Apply delta transform to specified dimensions. """Apply delta transform to specified dimensions.
@@ -327,88 +376,57 @@ def compute_compression_stats(tokenizer, action_chunks: np.ndarray):
return stats return stats
def main( @parser.wrap()
repo_id: str, def train_tokenizer(cfg: TokenizerTrainingConfig):
root: str | None = None,
action_horizon: int = 10,
max_episodes: int | None = None,
sample_fraction: float = 0.1,
encoded_dims: str = "0:6,7:23",
delta_dims: str | None = None,
use_delta_transform: bool = False,
state_key: str = "observation.state",
normalization_mode: str = "QUANTILES",
vocab_size: int = 1024,
scale: float = 10.0,
output_dir: str | None = None,
push_to_hub: bool = False,
hub_repo_id: str | None = None,
hub_private: bool = False,
):
""" """
Train FAST tokenizer for action encoding. Train FAST tokenizer for action encoding.
Args: Args:
repo_id: LeRobot dataset repository ID cfg: TokenizerTrainingConfig dataclass with all configuration parameters
root: Root directory for dataset (default: ~/.cache/huggingface/lerobot)
action_horizon: Number of future actions in each chunk
max_episodes: Max episodes to use (None = all episodes in dataset)
sample_fraction: Fraction of chunks to sample per episode
encoded_dims: Comma-separated dimension ranges to encode (e.g., "0:6,7:23")
delta_dims: Comma-separated dimension indices for delta transform (e.g., "0,1,2,3,4,5")
use_delta_transform: Whether to apply delta transform (relative actions vs absolute actions)
state_key: Dataset key for state observations (default: "observation.state")
normalization_mode: Normalization mode (MEAN_STD, MIN_MAX, QUANTILES, QUANTILE10, IDENTITY)
vocab_size: FAST vocabulary size (BPE vocab size)
scale: DCT scaling factor (default: 10.0)
output_dir: Directory to save tokenizer (default: ./fast_tokenizer_{repo_id})
push_to_hub: Whether to push the tokenizer to Hugging Face Hub
hub_repo_id: Hub repository ID (e.g., "username/tokenizer-name"). If None, uses output_dir name
hub_private: Whether to create a private repository on the Hub
""" """
# load dataset # load dataset
print(f"Loading dataset: {repo_id}") print(f"Loading dataset: {cfg.repo_id}")
dataset = LeRobotDataset(repo_id=repo_id, root=root) dataset = LeRobotDataset(repo_id=cfg.repo_id, root=cfg.root)
print(f"Dataset loaded: {dataset.num_episodes} episodes, {dataset.num_frames} frames") print(f"Dataset loaded: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
# parse normalization mode # parse normalization mode
try: try:
norm_mode = NormalizationMode(normalization_mode) norm_mode = NormalizationMode(cfg.normalization_mode)
except ValueError as err: except ValueError as err:
raise ValueError( raise ValueError(
f"Invalid normalization_mode: {normalization_mode}. " f"Invalid normalization_mode: {cfg.normalization_mode}. "
f"Must be one of: {', '.join([m.value for m in NormalizationMode])}" f"Must be one of: {', '.join([m.value for m in NormalizationMode])}"
) from err ) from err
print(f"Normalization mode: {norm_mode.value}") print(f"Normalization mode: {norm_mode.value}")
# parse encoded dimensions # parse encoded dimensions
encoded_dim_ranges = [] encoded_dim_ranges = []
for range_str in encoded_dims.split(","): for range_str in cfg.encoded_dims.split(","):
start, end = map(int, range_str.strip().split(":")) start, end = map(int, range_str.strip().split(":"))
encoded_dim_ranges.append((start, end)) encoded_dim_ranges.append((start, end))
total_encoded_dims = sum(end - start for start, end in encoded_dim_ranges) total_encoded_dims = sum(end - start for start, end in encoded_dim_ranges)
print(f"Encoding {total_encoded_dims} dimensions: {encoded_dims}") print(f"Encoding {total_encoded_dims} dimensions: {cfg.encoded_dims}")
# parse delta dimensions # parse delta dimensions
delta_dim_list = None delta_dim_list = None
if delta_dims is not None and delta_dims.strip(): if cfg.delta_dims is not None and cfg.delta_dims.strip():
delta_dim_list = [int(d.strip()) for d in delta_dims.split(",")] delta_dim_list = [int(d.strip()) for d in cfg.delta_dims.split(",")]
print(f"Delta dimensions: {delta_dim_list}") print(f"Delta dimensions: {delta_dim_list}")
else: else:
print("No delta dimensions specified") print("No delta dimensions specified")
print(f"Use delta transform: {use_delta_transform}") print(f"Use delta transform: {cfg.use_delta_transform}")
if use_delta_transform and (delta_dim_list is None or len(delta_dim_list) == 0): if cfg.use_delta_transform and (delta_dim_list is None or len(delta_dim_list) == 0):
print("Warning: use_delta_transform=True but no delta_dims specified. No delta will be applied.") print("Warning: use_delta_transform=True but no delta_dims specified. No delta will be applied.")
print(f"Action horizon: {action_horizon}") print(f"Action horizon: {cfg.action_horizon}")
print(f"State key: {state_key}") print(f"State key: {cfg.state_key}")
# determine episodes to process # determine episodes to process
num_episodes = dataset.num_episodes num_episodes = dataset.num_episodes
if max_episodes is not None: if cfg.max_episodes is not None:
num_episodes = min(max_episodes, num_episodes) num_episodes = min(cfg.max_episodes, num_episodes)
print(f"Processing {num_episodes} episodes...") print(f"Processing {num_episodes} episodes...")
@@ -419,7 +437,15 @@ def main(
print(f" Processing episode {ep_idx}/{num_episodes}...") print(f" Processing episode {ep_idx}/{num_episodes}...")
chunks = process_episode( chunks = process_episode(
(dataset, ep_idx, action_horizon, delta_dim_list, sample_fraction, state_key, use_delta_transform) (
dataset,
ep_idx,
cfg.action_horizon,
delta_dim_list,
cfg.sample_fraction,
cfg.state_key,
cfg.use_delta_transform,
)
) )
if chunks is not None: if chunks is not None:
all_chunks.append(chunks) all_chunks.append(chunks)
@@ -495,16 +521,17 @@ def main(
# train FAST tokenizer # train FAST tokenizer
tokenizer = train_fast_tokenizer( tokenizer = train_fast_tokenizer(
encoded_chunks, encoded_chunks,
vocab_size=vocab_size, vocab_size=cfg.vocab_size,
scale=scale, scale=cfg.scale,
) )
# compute compression statistics # compute compression statistics
compression_stats = compute_compression_stats(tokenizer, encoded_chunks) compression_stats = compute_compression_stats(tokenizer, encoded_chunks)
# save tokenizer # save tokenizer
output_dir = cfg.output_dir
if output_dir is None: if output_dir is None:
output_dir = f"fast_tokenizer_{repo_id.replace('/', '_')}" output_dir = f"fast_tokenizer_{cfg.repo_id.replace('/', '_')}"
output_path = Path(output_dir) output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True) output_path.mkdir(parents=True, exist_ok=True)
@@ -512,18 +539,18 @@ def main(
# save metadata # save metadata
metadata = { metadata = {
"repo_id": repo_id, "repo_id": cfg.repo_id,
"vocab_size": vocab_size, "vocab_size": cfg.vocab_size,
"scale": scale, "scale": cfg.scale,
"encoded_dims": encoded_dims, "encoded_dims": cfg.encoded_dims,
"encoded_dim_ranges": encoded_dim_ranges, "encoded_dim_ranges": encoded_dim_ranges,
"total_encoded_dims": total_encoded_dims, "total_encoded_dims": total_encoded_dims,
"delta_dims": delta_dims, "delta_dims": cfg.delta_dims,
"delta_dim_list": delta_dim_list, "delta_dim_list": delta_dim_list,
"use_delta_transform": use_delta_transform, "use_delta_transform": cfg.use_delta_transform,
"state_key": state_key, "state_key": cfg.state_key,
"normalization_mode": norm_mode.value, "normalization_mode": norm_mode.value,
"action_horizon": action_horizon, "action_horizon": cfg.action_horizon,
"num_training_chunks": len(encoded_chunks), "num_training_chunks": len(encoded_chunks),
"compression_stats": compression_stats, "compression_stats": compression_stats,
} }
@@ -535,21 +562,22 @@ def main(
print(f"Metadata: {json.dumps(metadata, indent=2)}") print(f"Metadata: {json.dumps(metadata, indent=2)}")
# push to Hugging Face Hub if requested # push to Hugging Face Hub if requested
if push_to_hub: if cfg.push_to_hub:
# determine the hub repository ID # determine the hub repository ID
hub_repo_id = cfg.hub_repo_id
if hub_repo_id is None: if hub_repo_id is None:
hub_repo_id = output_path.name hub_repo_id = output_path.name
print(f"\nNo hub_repo_id provided, using: {hub_repo_id}") print(f"\nNo hub_repo_id provided, using: {hub_repo_id}")
print(f"\nPushing tokenizer to Hugging Face Hub: {hub_repo_id}") print(f"\nPushing tokenizer to Hugging Face Hub: {hub_repo_id}")
print(f" Private: {hub_private}") print(f" Private: {cfg.hub_private}")
try: try:
# use the tokenizer's push_to_hub method # use the tokenizer's push_to_hub method
tokenizer.push_to_hub( tokenizer.push_to_hub(
repo_id=hub_repo_id, repo_id=hub_repo_id,
private=hub_private, private=cfg.hub_private,
commit_message=f"Upload FAST tokenizer trained on {repo_id}", commit_message=f"Upload FAST tokenizer trained on {cfg.repo_id}",
) )
# also upload the metadata.json file separately # also upload the metadata.json file separately
@@ -568,5 +596,10 @@ def main(
print(" Make sure you're logged in with `huggingface-cli login`") print(" Make sure you're logged in with `huggingface-cli login`")
def main():
"""CLI entry point that parses arguments and runs the tokenizer training."""
train_tokenizer()
if __name__ == "__main__": if __name__ == "__main__":
tyro.cli(main) main()