From 0f4faddc01fa5f5ee5a982d805120d7084753c83 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Wed, 13 May 2026 11:52:31 +0200 Subject: [PATCH] feat(pi052): auto-fit FAST tokenizer per-dataset before training MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per Pertsch et al. 2025 (FAST paper, [64] in π0.5) and π0.5 §III.C, the recommended practice is to *fit* the FAST action tokenizer on the specific dataset's action distribution rather than using the published universal codebook off the shelf. The universal tokenizer works on any 6-DoF action sequence but produces suboptimal compression, which slows CE convergence and wastes vocab capacity. New utility ``lerobot.policies.pi052.fit_fast_tokenizer``: * samples N action chunks from the LeRobotDataset (default 1024) * loads ``physical-intelligence/fast`` as the base * calls ``.fit(actions)`` (the AutoProcessor API the HF model card documents) — produces a per-dataset codebook * saves to ``{cache_dir}/{sha256(dataset, base, n_samples)[:16]}/`` * returns the local path, ready to feed ``ActionTokenizerProcessorStep(action_tokenizer_name=...)``. Cache is keyed on (dataset, base tokenizer, sample count) so changing any of them re-runs the fit. Re-running training on the same dataset re-uses the cache (one fit per dataset per machine). Auto-fit wiring: * PI052Config gets ``auto_fit_fast_tokenizer`` (default True), ``fast_tokenizer_cache_dir`` (default ~/.cache/lerobot/...), ``fast_tokenizer_fit_samples`` (default 1024). * make_pi052_pre_post_processors now takes ``dataset_repo_id``; when ``enable_fast_action_loss`` and ``auto_fit_fast_tokenizer`` are both True and a repo_id is provided, the factory calls ``fit_fast_tokenizer`` before constructing the processor step and points it at the fitted path. * ProcessorConfigKwargs gains ``dataset_repo_id``; the global factory dispatch threads it through for ``pi052`` policies. 
* lerobot_train.py populates ``processor_kwargs['dataset_repo_id']`` from ``--dataset.repo_id`` for pi052 runs. Failure mode: if ``.fit()`` fails (e.g. older transformers without the method, or no usable action chunks in the dataset), the factory logs a warning and falls back to the universal base tokenizer. Train still works; you just lose the compression improvement. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lerobot/policies/factory.py | 11 + .../policies/pi052/configuration_pi052.py | 17 ++ .../policies/pi052/fit_fast_tokenizer.py | 197 ++++++++++++++++++ src/lerobot/policies/pi052/processor_pi052.py | 33 ++- src/lerobot/scripts/lerobot_train.py | 8 + 5 files changed, 265 insertions(+), 1 deletion(-) create mode 100644 src/lerobot/policies/pi052/fit_fast_tokenizer.py diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 47777f1a8..777f67f84 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -252,6 +252,12 @@ class ProcessorConfigKwargs(TypedDict, total=False): preprocessor_overrides: dict[str, Any] | None postprocessor_overrides: dict[str, Any] | None dataset_stats: dict[str, dict[str, torch.Tensor]] | None + # Optional: HF Hub repo id of the dataset the policy is being + # trained on. Used by policies that auto-fit pieces of their + # preprocessing (e.g. pi052's FAST action tokenizer per + # Pertsch et al. 2025 [64], π0.5 §III.C). When omitted, those + # policies fall back to their universal pre-fitted tokenizers. + dataset_repo_id: str | None def make_pre_post_processors( @@ -387,6 +393,11 @@ def make_pre_post_processors( processors = make_pi052_pre_post_processors( config=policy_cfg, dataset_stats=kwargs.get("dataset_stats"), + # ``dataset_repo_id`` flows in via kwargs when FAST CE is + # enabled — the train loop sets it from ``--dataset.repo_id``. + # When ``None``, ``make_pi052_pre_post_processors`` skips + # the auto-fit and uses the universal tokenizer. 
+ dataset_repo_id=kwargs.get("dataset_repo_id"), ) elif isinstance(policy_cfg, PI05Config): diff --git a/src/lerobot/policies/pi052/configuration_pi052.py b/src/lerobot/policies/pi052/configuration_pi052.py index f657c84b0..67dfd446a 100644 --- a/src/lerobot/policies/pi052/configuration_pi052.py +++ b/src/lerobot/policies/pi052/configuration_pi052.py @@ -127,6 +127,23 @@ class PI052Config(PI05Config): fast_action_loss_weight: float = 1.0 """Weight on the FAST-action-token CE loss. Paper §III.C uses 1.0.""" + auto_fit_fast_tokenizer: bool = True + """If True (default), the processor factory checks + ``fast_tokenizer_cache_dir`` for a previously-fitted tokenizer keyed + on ``(dataset_repo_id, base_tokenizer_name, fit_samples)``. On cache + miss, it loads ``action_tokenizer_name`` as a base, samples + ``fast_tokenizer_fit_samples`` action chunks from the dataset, runs + ``.fit()``, saves the result, and uses *that* fitted path as the + actual tokenizer. Pertsch et al. 2025 (FAST paper [64], π0.5 §III.C) + explicitly recommend per-dataset fitting for best compression.""" + + fast_tokenizer_cache_dir: str = "~/.cache/lerobot/fast_tokenizers" + """Where fitted FAST tokenizers are stored. ``~`` expands.""" + + fast_tokenizer_fit_samples: int = 1024 + """Number of action chunks to sample for the fit. The FAST paper uses + a few thousand; 1024 is a reasonable default for medium datasets.""" + # Knowledge insulation — paper §III.B -------------------------------- # When enabled, gradients from the action expert's flow loss are # *blocked* from flowing back into the VLM's K/V projections. This diff --git a/src/lerobot/policies/pi052/fit_fast_tokenizer.py b/src/lerobot/policies/pi052/fit_fast_tokenizer.py new file mode 100644 index 000000000..eda84e204 --- /dev/null +++ b/src/lerobot/policies/pi052/fit_fast_tokenizer.py @@ -0,0 +1,197 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dataset-specific FAST action tokenizer fitting. + +The published ``physical-intelligence/fast`` tokenizer is a *universal* +codebook fitted on a heterogeneous mix of robot datasets. Per Pertsch +et al. 2025 (the FAST paper, [64] in the π0.5 paper) and §III.C of +π0.5 itself, the recommended practice is to **finetune the tokenizer on +your specific dataset's action distribution** before training the +policy — same way one would adapt a language tokenizer to a domain +corpus. Without this finetune step, action sequences from your robot +may require more tokens per chunk than necessary, lowering effective +compression and slowing convergence of the action-CE loss. + +This module provides a single utility, :func:`fit_fast_tokenizer`, +that does the finetune. The training entry point invokes it +automatically when the policy's ``enable_fast_action_loss`` and +``auto_fit_fast_tokenizer`` flags are both ``True`` and no cached +fitted tokenizer is found at ``fast_tokenizer_cache_dir``. + +The fitted tokenizer is saved to ``{cache_dir}/{signature}/`` (the +short hash from ``_dataset_signature``) so successive training +runs over the same dataset re-use it. 
+"""
+
+from __future__ import annotations
+
+import hashlib
+import logging
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+# Empty file written *after* ``save_pretrained`` succeeds. A cache hit
+# requires it, so an interrupted save is never mistaken for a finished
+# fit and the check does not depend on which file names
+# ``save_pretrained`` happens to emit.
+_FIT_COMPLETE_MARKER = ".fast_fit_complete"
+
+
+def _dataset_signature(
+    dataset_repo_id: str,
+    base_tokenizer_name: str,
+    n_samples: int,
+    chunk_size: int | None = None,
+) -> str:
+    """Return a deterministic short hash naming the cache directory.
+
+    The hash covers the dataset, the base tokenizer, the sample count
+    and (when given) the chunk length, so changing any of them selects
+    a different cache directory instead of re-using a stale fit.
+
+    Args:
+        dataset_repo_id: Repo id of the dataset being fitted on.
+        base_tokenizer_name: Identifier of the base tokenizer.
+        n_samples: Number of action chunks requested for the fit.
+        chunk_size: Length of each sampled chunk. ``None`` leaves it
+            out of the hash and yields the three-field signature.
+
+    Returns:
+        The first 16 hex characters of the SHA-256 digest.
+    """
+    fields = [dataset_repo_id, base_tokenizer_name, str(n_samples)]
+    if chunk_size is not None:
+        fields.append(str(chunk_size))
+    # NUL-separate the fields so ("ab", "c") and ("a", "bc") hash differently.
+    payload = "\0".join(fields).encode("utf-8")
+    return hashlib.sha256(payload).hexdigest()[:16]
+
+
+def fit_fast_tokenizer(
+    *,
+    dataset_repo_id: str,
+    cache_dir: str | Path,
+    base_tokenizer_name: str = "physical-intelligence/fast",
+    n_samples: int = 1024,
+    chunk_size: int = 50,
+    seed: int = 42,
+) -> str:
+    """Fit a FAST tokenizer on a LeRobot dataset's action distribution.
+
+    Args:
+        dataset_repo_id: HF Hub repo id of the LeRobotDataset to fit on.
+        cache_dir: Directory under which to save (and look up) fitted
+            tokenizers; ``~`` is expanded. The actual save path is
+            ``{cache_dir}/{signature}``.
+        base_tokenizer_name: HF identifier for the base FAST tokenizer
+            to finetune from. ``physical-intelligence/fast`` is the
+            universal one.
+        n_samples: Upper bound on the number of action chunks sampled
+            for the fit. Fewer are used when episodes are skipped
+            (too short or unreadable).
+        chunk_size: Length of each action chunk (matches
+            ``policy.chunk_size``). Part of the cache key, because the
+            tokenizer is fit on sequences of exactly this length.
+        seed: RNG seed for sample selection.
+
+    Returns:
+        The local path to the fitted tokenizer, suitable as the
+        ``action_tokenizer_name`` of the action tokenizer step.
+
+    Raises:
+        ValueError: If ``n_samples`` or ``chunk_size`` is not positive.
+        ImportError: If the loaded base tokenizer has no ``.fit()``
+            method.
+        RuntimeError: If no action chunk could be collected.
+
+    Errors raised while loading the base tokenizer or the dataset
+    metadata propagate unchanged.
+    """
+    if n_samples <= 0 or chunk_size <= 0:
+        raise ValueError(
+            "n_samples and chunk_size must be positive, got "
+            f"n_samples={n_samples} chunk_size={chunk_size}"
+        )
+
+    # Expand ``~`` here too so direct callers get the same behaviour
+    # as the processor factory (expanding twice is harmless).
+    cache_dir = Path(cache_dir).expanduser()
+    sig = _dataset_signature(dataset_repo_id, base_tokenizer_name, n_samples, chunk_size)
+    out_dir = cache_dir / sig
+
+    if (out_dir / _FIT_COMPLETE_MARKER).is_file():
+        logger.info(
+            "FAST tokenizer cache hit: %s — re-using fitted tokenizer for "
+            "dataset=%s base=%s n_samples=%d",
+            out_dir, dataset_repo_id, base_tokenizer_name, n_samples,
+        )
+        return str(out_dir)
+
+    logger.info(
+        "FAST tokenizer cache miss — fitting on dataset=%s "
+        "base=%s n_samples=%d chunk_size=%d → %s",
+        dataset_repo_id, base_tokenizer_name, n_samples, chunk_size, out_dir,
+    )
+
+    from transformers import AutoProcessor  # noqa: PLC0415
+
+    from lerobot.datasets.lerobot_dataset import LeRobotDataset  # noqa: PLC0415
+
+    # Load the base first: a tokenizer without ``.fit()`` should fail
+    # before the slow dataset sampling, not after it.
+    base = AutoProcessor.from_pretrained(base_tokenizer_name, trust_remote_code=True)
+    if not hasattr(base, "fit"):
+        raise ImportError(
+            f"Base FAST tokenizer {base_tokenizer_name!r} has no ``.fit()`` "
+            "method — your transformers / model snapshot is too old. Update "
+            "to the current ``physical-intelligence/fast`` revision."
+        )
+
+    # Load one episode at a time so we don't blow memory on huge
+    # datasets. Random episode + random start offset gives a
+    # reasonable spread.
+    rng = np.random.default_rng(seed)
+    actions_buf: list[np.ndarray] = []
+
+    # Load a single episode first just to read the episode count.
+    ds_meta_only = LeRobotDataset(dataset_repo_id, episodes=[0])
+    num_episodes = ds_meta_only.meta.total_episodes
+    del ds_meta_only
+
+    # Ceil division: with floor division ``quota * num_episodes`` falls
+    # short of ``n_samples`` whenever the two do not divide evenly.
+    samples_per_episode = max(1, -(-n_samples // max(num_episodes, 1)))
+    eps_visited = 0
+    for ep_idx in rng.permutation(num_episodes):
+        if len(actions_buf) >= n_samples:
+            break
+        ep_idx = int(ep_idx)
+        try:
+            ds = LeRobotDataset(dataset_repo_id, episodes=[ep_idx])
+        except Exception as exc:  # noqa: BLE001
+            # Best effort: one unreadable episode must not abort the fit.
+            logger.warning("FAST fit: skipping episode %d: %s", ep_idx, exc)
+            continue
+        if len(ds) < chunk_size:
+            continue
+        # Sample ``samples_per_episode`` start indices uniformly within
+        # the episode; the upper bound is exclusive, so every chunk fits.
+        starts = rng.integers(0, len(ds) - chunk_size + 1, size=samples_per_episode)
+        for s in starts:
+            if len(actions_buf) >= n_samples:
+                break
+            start = int(s)
+            try:
+                chunk_actions = np.stack(
+                    [ds[start + j]["action"].cpu().numpy() for j in range(chunk_size)],
+                    axis=0,
+                )
+            except Exception as exc:  # noqa: BLE001
+                logger.debug("FAST fit: chunk at ep=%d s=%d failed: %s", ep_idx, start, exc)
+                continue
+            actions_buf.append(chunk_actions)
+        eps_visited += 1
+
+    if not actions_buf:
+        raise RuntimeError(
+            f"FAST fit collected zero action chunks from {dataset_repo_id!r}. "
+            "Check that the dataset has an ``action`` column and chunks of "
+            f"length ≥ {chunk_size}."
+        )
+
+    actions = np.stack(actions_buf, axis=0)  # (N, H, D)
+    logger.info(
+        "FAST fit: collected %d chunks of shape %s from %d episodes",
+        actions.shape[0], actions.shape[1:], eps_visited,
+    )
+
+    fitted = base.fit(actions)
+    out_dir.mkdir(parents=True, exist_ok=True)
+    fitted.save_pretrained(str(out_dir))
+    # Written last: only a fully saved tokenizer counts as a cache hit.
+    (out_dir / _FIT_COMPLETE_MARKER).touch()
+    logger.info("FAST fit: saved fitted tokenizer to %s", out_dir)
+    return str(out_dir)
diff --git a/src/lerobot/policies/pi052/processor_pi052.py b/src/lerobot/policies/pi052/processor_pi052.py
index 7c3f9c4eb..f7ec21d0a 100644
--- a/src/lerobot/policies/pi052/processor_pi052.py
+++ b/src/lerobot/policies/pi052/processor_pi052.py
@@ -66,6 +66,7 @@ from .text_processor_pi052 import PI052TextTokenizerStep
 def make_pi052_pre_post_processors(
     config: PI052Config,
     dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
+    dataset_repo_id: str | None = None,
 ) -> tuple[
     PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
     PolicyProcessorPipeline[PolicyAction, PolicyAction],
@@ -110,9 +111,39 @@ def make_pi052_pre_post_processors(
     # writes ACTION_TOKENS / ACTION_TOKEN_MASK into
     # ``COMPLEMENTARY_DATA`` and the modeling forward picks them up.
     if getattr(config, "enable_fast_action_loss", False):
+        # Per Pertsch et al. 2025 (FAST [64], π0.5 §III.C): fit the
+        # tokenizer on this dataset's action distribution rather than
+        # using the universal codebook off the shelf. We do this once
+        # and cache to disk, keyed on (dataset, base, n_samples, chunk_size).
+ action_tokenizer_path = config.action_tokenizer_name + if ( + getattr(config, "auto_fit_fast_tokenizer", False) + and dataset_repo_id is not None + ): + from .fit_fast_tokenizer import fit_fast_tokenizer # noqa: PLC0415 + + cache_dir = Path(config.fast_tokenizer_cache_dir).expanduser() + try: + action_tokenizer_path = fit_fast_tokenizer( + dataset_repo_id=dataset_repo_id, + cache_dir=cache_dir, + base_tokenizer_name=config.action_tokenizer_name, + n_samples=config.fast_tokenizer_fit_samples, + chunk_size=config.chunk_size, + ) + except Exception as exc: # noqa: BLE001 + import logging # noqa: PLC0415 + + logging.getLogger(__name__).warning( + "FAST tokenizer fit failed (%s) — falling back to " + "the universal base tokenizer %r. Train will still " + "work but compression will be suboptimal.", + exc, config.action_tokenizer_name, + ) + input_steps.append( ActionTokenizerProcessorStep( - action_tokenizer_name=config.action_tokenizer_name, + action_tokenizer_name=action_tokenizer_path, max_action_tokens=config.max_action_tokens, fast_skip_tokens=config.fast_skip_tokens, paligemma_tokenizer_name="google/paligemma-3b-pt-224", diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index ec923bcc1..a0cd204f2 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -286,6 +286,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): if cfg.policy.type == "sarm": processor_kwargs["dataset_meta"] = dataset.meta + # For pi052 (and any future policy that auto-fits part of its + # preprocessing per-dataset), pass the dataset repo id so the + # processor factory can locate/refresh dataset-specific artifacts + # (e.g. fitted FAST tokenizers per Pertsch et al. 2025 [64], + # π0.5 §III.C). 
+ if cfg.policy.type == "pi052": + processor_kwargs["dataset_repo_id"] = cfg.dataset.repo_id + if processor_pretrained_path is not None: processor_kwargs["preprocessor_overrides"] = { "device_processor": {"device": device.type},