mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-05 09:07:03 +00:00
feat(pi052): auto-fit FAST tokenizer per-dataset before training
Per Pertsch et al. 2025 (FAST paper, [64] in π0.5) and π0.5 §III.C,
the recommended practice is to *fit* the FAST action tokenizer on
the specific dataset's action distribution rather than using the
published universal codebook off the shelf. The universal tokenizer
works on any 6-DoF action sequence but produces suboptimal
compression, which slows CE convergence and wastes vocab capacity.
New utility ``lerobot.policies.pi052.fit_fast_tokenizer``:
* samples N action chunks from the LeRobotDataset (default 1024)
* loads ``physical-intelligence/fast`` as the base
* calls ``.fit(actions)`` (the AutoProcessor API the HF model card
documents) — produces a per-dataset codebook
* saves to ``{cache_dir}/{sha256(dataset, base, n_samples)[:16]}/``
* returns the local path, ready to feed
``ActionTokenizerProcessorStep(action_tokenizer_name=...)``.
Cache is keyed on (dataset, base tokenizer, sample count) so changing
any of them re-runs the fit. Re-running training on the same dataset
re-uses the cache (one fit per dataset per machine).
Auto-fit wiring:
* PI052Config gets ``auto_fit_fast_tokenizer`` (default True),
``fast_tokenizer_cache_dir`` (default ~/.cache/lerobot/...),
``fast_tokenizer_fit_samples`` (default 1024).
* make_pi052_pre_post_processors now takes ``dataset_repo_id``;
when ``enable_fast_action_loss`` and ``auto_fit_fast_tokenizer``
are both True and a repo_id is provided, the factory calls
``fit_fast_tokenizer`` before constructing the processor step
and points it at the fitted path.
* ProcessorConfigKwargs gains ``dataset_repo_id``; the global
factory dispatch threads it through for ``pi052`` policies.
* lerobot_train.py populates ``processor_kwargs['dataset_repo_id']``
from ``--dataset.repo_id`` for pi052 runs.
Failure mode: if ``.fit()`` fails (e.g. older transformers without
the method, or no usable action chunks in the dataset), the factory
logs a warning and falls back to the universal base tokenizer. Training
still works; you just lose the compression improvement.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -252,6 +252,12 @@ class ProcessorConfigKwargs(TypedDict, total=False):
|
||||
preprocessor_overrides: dict[str, Any] | None
|
||||
postprocessor_overrides: dict[str, Any] | None
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
|
||||
# Optional: HF Hub repo id of the dataset the policy is being
|
||||
# trained on. Used by policies that auto-fit pieces of their
|
||||
# preprocessing (e.g. pi052's FAST action tokenizer per
|
||||
# Pertsch et al. 2025 [64], π0.5 §III.C). When omitted, those
|
||||
# policies fall back to their universal pre-fitted tokenizers.
|
||||
dataset_repo_id: str | None
|
||||
|
||||
|
||||
def make_pre_post_processors(
|
||||
@@ -387,6 +393,11 @@ def make_pre_post_processors(
|
||||
processors = make_pi052_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
# ``dataset_repo_id`` flows in via kwargs when FAST CE is
|
||||
# enabled — the train loop sets it from ``--dataset.repo_id``.
|
||||
# When ``None``, ``make_pi052_pre_post_processors`` skips
|
||||
# the auto-fit and uses the universal tokenizer.
|
||||
dataset_repo_id=kwargs.get("dataset_repo_id"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, PI05Config):
|
||||
|
||||
@@ -127,6 +127,23 @@ class PI052Config(PI05Config):
|
||||
fast_action_loss_weight: float = 1.0
|
||||
"""Weight on the FAST-action-token CE loss. Paper §III.C uses 1.0."""
|
||||
|
||||
auto_fit_fast_tokenizer: bool = True
|
||||
"""If True (default), the processor factory checks
|
||||
``fast_tokenizer_cache_dir`` for a previously-fitted tokenizer keyed
|
||||
on ``(dataset_repo_id, base_tokenizer_name, fit_samples)``. On cache
|
||||
miss, it loads ``action_tokenizer_name`` as a base, samples
|
||||
``fast_tokenizer_fit_samples`` action chunks from the dataset, runs
|
||||
``.fit()``, saves the result, and uses *that* fitted path as the
|
||||
actual tokenizer. Pertsch et al. 2025 (FAST paper [64], π0.5 §III.C)
|
||||
explicitly recommend per-dataset fitting for best compression."""
|
||||
|
||||
fast_tokenizer_cache_dir: str = "~/.cache/lerobot/fast_tokenizers"
|
||||
"""Where fitted FAST tokenizers are stored. ``~`` expands."""
|
||||
|
||||
fast_tokenizer_fit_samples: int = 1024
|
||||
"""Number of action chunks to sample for the fit. The FAST paper uses
|
||||
a few thousand; 1024 is a reasonable default for medium datasets."""
|
||||
|
||||
# Knowledge insulation — paper §III.B --------------------------------
|
||||
# When enabled, gradients from the action expert's flow loss are
|
||||
# *blocked* from flowing back into the VLM's K/V projections. This
|
||||
|
||||
@@ -0,0 +1,197 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Dataset-specific FAST action tokenizer fitting.
|
||||
|
||||
The published ``physical-intelligence/fast`` tokenizer is a *universal*
|
||||
codebook fitted on a heterogeneous mix of robot datasets. Per Pertsch
|
||||
et al. 2025 (the FAST paper, [64] in the π0.5 paper) and §III.C of
|
||||
π0.5 itself, the recommended practice is to **finetune the tokenizer on
|
||||
your specific dataset's action distribution** before training the
|
||||
policy — same way one would adapt a language tokenizer to a domain
|
||||
corpus. Without this finetune step, action sequences from your robot
|
||||
may require more tokens per chunk than necessary, lowering effective
|
||||
compression and slowing convergence of the action-CE loss.
|
||||
|
||||
This module provides a single utility, :func:`fit_fast_tokenizer`,
|
||||
that does the finetune. The pi052 processor factory invokes it
|
||||
automatically when the policy's ``enable_fast_action_loss`` and
|
||||
``auto_fit_fast_tokenizer`` flags are both ``True`` and no cached
|
||||
fitted tokenizer is found at ``fast_tokenizer_cache_dir``.
|
||||
|
||||
The fitted tokenizer is saved to
|
||||
``{cache_dir}/{signature}/`` (``signature`` being the short hash returned by ``_dataset_signature``) so successive training
|
||||
runs over the same dataset re-use it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Files whose presence marks ``out_dir`` as holding a saved tokenizer.
# NOTE(review): which of these ``save_pretrained`` writes depends on the
# processor class / transformers version — the original code only probed
# ``preprocessor_config.json``; confirm against a real fitted FAST output.
_FITTED_CONFIG_FILES = ("processor_config.json", "preprocessor_config.json")


def _dataset_signature(
    dataset_repo_id: str,
    base_tokenizer_name: str,
    n_samples: int,
    chunk_size: int | None = None,
    seed: int | None = None,
) -> str:
    """Return a deterministic 16-hex-char hash naming the cache directory.

    The hash always covers ``(dataset_repo_id, base_tokenizer_name,
    n_samples)``. ``chunk_size`` and ``seed`` are folded in when given,
    so a tokenizer fitted on one chunk length / sample draw is never
    re-used for another.

    Args:
        dataset_repo_id: Dataset identifier the tokenizer is fitted on.
        base_tokenizer_name: Identifier of the base tokenizer.
        n_samples: Number of action chunks requested for the fit.
        chunk_size: Optional chunk length used for the fit.
        seed: Optional RNG seed used for sample selection.

    Returns:
        The first 16 hex characters of the SHA-256 digest of the
        NUL-separated fields.
    """
    fields = [dataset_repo_id, base_tokenizer_name, str(n_samples)]
    # Optional fields are labelled and appended only when supplied, so the
    # three-argument form keeps producing exactly the digest it always did.
    if chunk_size is not None:
        fields.append(f"chunk_size={chunk_size}")
    if seed is not None:
        fields.append(f"seed={seed}")
    return hashlib.sha256("\0".join(fields).encode("utf-8")).hexdigest()[:16]


def fit_fast_tokenizer(
    *,
    dataset_repo_id: str,
    cache_dir: str | Path,
    base_tokenizer_name: str = "physical-intelligence/fast",
    n_samples: int = 1024,
    chunk_size: int = 50,
    seed: int = 42,
) -> str:
    """Fit a FAST tokenizer on a LeRobot dataset's action distribution.

    Looks for a previously fitted tokenizer under ``cache_dir`` first and
    returns it on a hit. Otherwise loads the base tokenizer, samples
    action chunks from randomly ordered episodes, calls ``.fit()`` on
    them and saves the result.

    Args:
        dataset_repo_id: Repo id passed to ``LeRobotDataset``.
        cache_dir: Directory under which to save (and look up) fitted
            tokenizers; ``~`` is expanded. The actual save path is
            ``{cache_dir}/{signature}``.
        base_tokenizer_name: Identifier handed to
            ``AutoProcessor.from_pretrained`` for the base tokenizer.
        n_samples: Upper bound on the number of action chunks sampled
            for the fit. Fewer are used (with a warning) when the
            dataset cannot supply that many distinct chunks.
        chunk_size: Number of consecutive frames per action chunk.
        seed: RNG seed for episode order and start-offset selection.

    Returns:
        The local path of the fitted tokenizer as a string, usable as
        ``action_tokenizer_name``.

    Raises:
        ValueError: If ``n_samples`` or ``chunk_size`` is smaller than 1.
        ImportError: If the loaded base tokenizer has no ``.fit()``
            method.
        RuntimeError: If no action chunk could be collected.

    Exceptions raised by the initial ``LeRobotDataset`` load or by
    ``AutoProcessor.from_pretrained`` propagate unchanged.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    cache_dir = Path(cache_dir).expanduser()
    # chunk_size and seed both change what gets fitted, so they are part
    # of the cache key alongside dataset / base / sample count.
    sig = _dataset_signature(dataset_repo_id, base_tokenizer_name, n_samples, chunk_size, seed)
    out_dir = cache_dir / sig

    if any((out_dir / name).exists() for name in _FITTED_CONFIG_FILES):
        logger.info(
            "FAST tokenizer cache hit: %s — re-using fitted tokenizer for "
            "dataset=%s base=%s n_samples=%d",
            out_dir, dataset_repo_id, base_tokenizer_name, n_samples,
        )
        return str(out_dir)

    logger.info(
        "FAST tokenizer cache miss — fitting on dataset=%s "
        "base=%s n_samples=%d chunk_size=%d → %s",
        dataset_repo_id, base_tokenizer_name, n_samples, chunk_size, out_dir,
    )

    from transformers import AutoProcessor  # noqa: PLC0415

    from lerobot.datasets.lerobot_dataset import LeRobotDataset  # noqa: PLC0415

    # Load and validate the base tokenizer *before* touching the dataset so
    # an unusable tokenizer fails fast instead of after the sampling pass.
    # NOTE: ``trust_remote_code=True`` executes code shipped with the
    # tokenizer repo — only point ``base_tokenizer_name`` at trusted sources.
    base = AutoProcessor.from_pretrained(base_tokenizer_name, trust_remote_code=True)
    if not hasattr(base, "fit"):
        raise ImportError(
            f"Base FAST tokenizer {base_tokenizer_name!r} has no ``.fit()`` "
            "method — your transformers / model snapshot is too old. Update "
            "to the current ``physical-intelligence/fast`` revision."
        )

    rng = np.random.default_rng(seed)

    # A one-episode load is used only to read the total episode count.
    ds_meta_only = LeRobotDataset(dataset_repo_id, episodes=[0])
    num_episodes = ds_meta_only.meta.total_episodes
    del ds_meta_only

    # Ceiling division: with floor division the per-episode quota times the
    # episode count can fall short of ``n_samples`` (e.g. 1024 over 100
    # episodes → 10 each → only 1000 chunks).
    samples_per_episode = max(1, -(-n_samples // max(num_episodes, 1)))

    actions_buf: list[np.ndarray] = []
    eps_used = 0
    # One episode is loaded at a time to bound memory on large datasets.
    for ep_idx in rng.permutation(num_episodes):
        remaining = n_samples - len(actions_buf)
        if remaining <= 0:
            break
        ep_idx = int(ep_idx)
        try:
            ds = LeRobotDataset(dataset_repo_id, episodes=[ep_idx])
        except Exception as exc:  # noqa: BLE001
            logger.warning("FAST fit: skipping episode %d: %s", ep_idx, exc)
            continue

        n_starts = len(ds) - chunk_size + 1
        if n_starts < 1:
            # Episode is shorter than one chunk.
            continue

        # Draw distinct start offsets so a short episode cannot contribute
        # the same chunk several times.
        quota = min(samples_per_episode, n_starts, remaining)
        starts = rng.choice(n_starts, size=quota, replace=False)

        got_any = False
        for s in starts:
            start = int(s)
            try:
                chunk_actions = np.stack(
                    [
                        np.asarray(ds[start + j]["action"].cpu().numpy())
                        for j in range(chunk_size)
                    ],
                    axis=0,
                )
            except Exception as exc:  # noqa: BLE001
                logger.debug("FAST fit: chunk at ep=%d s=%d failed: %s", ep_idx, start, exc)
                continue
            actions_buf.append(chunk_actions)
            got_any = True
        if got_any:
            eps_used += 1

    if not actions_buf:
        raise RuntimeError(
            f"FAST fit collected zero action chunks from {dataset_repo_id!r}. "
            "Check that the dataset has an ``action`` column and chunks of "
            f"length ≥ {chunk_size}."
        )
    if len(actions_buf) < n_samples:
        logger.warning(
            "FAST fit: only %d of the requested %d chunks could be collected",
            len(actions_buf), n_samples,
        )

    actions = np.stack(actions_buf, axis=0)  # (N, H, D)
    logger.info(
        "FAST fit: collected %d chunks of shape %s from %d episodes",
        actions.shape[0], actions.shape[1:], eps_used,
    )

    fitted = base.fit(actions)
    out_dir.mkdir(parents=True, exist_ok=True)
    fitted.save_pretrained(str(out_dir))
    logger.info("FAST fit: saved fitted tokenizer to %s", out_dir)
    return str(out_dir)
|
||||
@@ -66,6 +66,7 @@ from .text_processor_pi052 import PI052TextTokenizerStep
|
||||
def make_pi052_pre_post_processors(
|
||||
config: PI052Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
dataset_repo_id: str | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
@@ -110,9 +111,39 @@ def make_pi052_pre_post_processors(
|
||||
# writes ACTION_TOKENS / ACTION_TOKEN_MASK into
|
||||
# ``COMPLEMENTARY_DATA`` and the modeling forward picks them up.
|
||||
if getattr(config, "enable_fast_action_loss", False):
|
||||
# Per Pertsch et al. 2025 (FAST [64], π0.5 §III.C): fit the
|
||||
# tokenizer on this dataset's action distribution rather than
|
||||
# using the universal codebook off the shelf. We do this once
|
||||
# and cache to disk, keyed on (dataset, base, n_samples).
|
||||
action_tokenizer_path = config.action_tokenizer_name
|
||||
if (
|
||||
getattr(config, "auto_fit_fast_tokenizer", False)
|
||||
and dataset_repo_id is not None
|
||||
):
|
||||
from .fit_fast_tokenizer import fit_fast_tokenizer # noqa: PLC0415
|
||||
|
||||
cache_dir = Path(config.fast_tokenizer_cache_dir).expanduser()
|
||||
try:
|
||||
action_tokenizer_path = fit_fast_tokenizer(
|
||||
dataset_repo_id=dataset_repo_id,
|
||||
cache_dir=cache_dir,
|
||||
base_tokenizer_name=config.action_tokenizer_name,
|
||||
n_samples=config.fast_tokenizer_fit_samples,
|
||||
chunk_size=config.chunk_size,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
import logging # noqa: PLC0415
|
||||
|
||||
logging.getLogger(__name__).warning(
|
||||
"FAST tokenizer fit failed (%s) — falling back to "
|
||||
"the universal base tokenizer %r. Train will still "
|
||||
"work but compression will be suboptimal.",
|
||||
exc, config.action_tokenizer_name,
|
||||
)
|
||||
|
||||
input_steps.append(
|
||||
ActionTokenizerProcessorStep(
|
||||
action_tokenizer_name=config.action_tokenizer_name,
|
||||
action_tokenizer_name=action_tokenizer_path,
|
||||
max_action_tokens=config.max_action_tokens,
|
||||
fast_skip_tokens=config.fast_skip_tokens,
|
||||
paligemma_tokenizer_name="google/paligemma-3b-pt-224",
|
||||
|
||||
@@ -286,6 +286,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if cfg.policy.type == "sarm":
|
||||
processor_kwargs["dataset_meta"] = dataset.meta
|
||||
|
||||
# For pi052 (and any future policy that auto-fits part of its
|
||||
# preprocessing per-dataset), pass the dataset repo id so the
|
||||
# processor factory can locate/refresh dataset-specific artifacts
|
||||
# (e.g. fitted FAST tokenizers per Pertsch et al. 2025 [64],
|
||||
# π0.5 §III.C).
|
||||
if cfg.policy.type == "pi052":
|
||||
processor_kwargs["dataset_repo_id"] = cfg.dataset.repo_id
|
||||
|
||||
if processor_pretrained_path is not None:
|
||||
processor_kwargs["preprocessor_overrides"] = {
|
||||
"device_processor": {"device": device.type},
|
||||
|
||||
Reference in New Issue
Block a user