fix(pi052): read actions from the ``action`` column in FAST tokenizer fit, not ds[i]

fit_fast_tokenizer collected action chunks via ds[i]["action"], which
builds a full training item — delta-timestamp expansion, video decode,
image transforms. Any video-decode failure raised, was swallowed at
debug level, and dropped that chunk; when decoding failed for every
chunk the fit was silently starved → "FAST fit collected zero action
chunks", falling back to the universal tokenizer.

Read the ``action`` column straight from the HF dataset instead: it
carries no video, so it is immune to decode errors and far faster.
Also fail fast with a clear message when the dataset has no ``action``
feature or all episodes are shorter than chunk_size.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-19 22:51:53 +02:00
parent ddf4bc2063
commit bc0c993b25
@@ -40,7 +40,6 @@ from __future__ import annotations
import hashlib import hashlib
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Any
import numpy as np import numpy as np
@@ -133,44 +132,49 @@ def fit_fast_tokenizer(
# Stream a single episode's worth of action chunks at a time so # Stream a single episode's worth of action chunks at a time so
# we don't blow memory on huge datasets. Random episode + # we don't blow memory on huge datasets. Random episode +
# random start offset gives a reasonable spread. # random start offset gives a reasonable spread.
#
# Actions are read straight from the underlying HF dataset's
# ``action`` *column* — never via ``ds[i]``. ``ds[i]`` builds a full
# training item (delta-timestamp expansion + video decode + image
# transforms); a single bad video frame would then throw and, since
# the failure was swallowed at debug level, silently starve the fit
# of every chunk. The action column carries no video, so reading it
# directly is both faster and immune to decode errors.
rng = np.random.default_rng(seed) rng = np.random.default_rng(seed)
actions_buf: list[np.ndarray] = [] actions_buf: list[np.ndarray] = []
# Load just the metadata first to know episode boundaries. # Load just the metadata first to know episode boundaries.
ds_meta_only = LeRobotDataset(dataset_repo_id, episodes=[0]) ds_meta_only = LeRobotDataset(dataset_repo_id, episodes=[0])
num_episodes = ds_meta_only.meta.total_episodes num_episodes = ds_meta_only.meta.total_episodes
if "action" not in ds_meta_only.features:
available = ", ".join(sorted(ds_meta_only.features)) or "<none>"
raise RuntimeError(
f"FAST fit: dataset {dataset_repo_id!r} has no ``action`` feature. "
f"Available features: {available}."
)
del ds_meta_only del ds_meta_only
samples_per_episode = max(1, n_samples // max(num_episodes, 1)) samples_per_episode = max(1, n_samples // max(num_episodes, 1))
collected = 0 collected = 0
eps_visited = 0 eps_visited = 0
short_episodes = 0
for ep_idx in rng.permutation(num_episodes): for ep_idx in rng.permutation(num_episodes):
if collected >= n_samples: if collected >= n_samples:
break break
ep_idx = int(ep_idx) ep_idx = int(ep_idx)
try: try:
ds = LeRobotDataset(dataset_repo_id, episodes=[ep_idx]) ds = LeRobotDataset(dataset_repo_id, episodes=[ep_idx])
ep_actions = np.asarray(ds.hf_dataset["action"], dtype=np.float32)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
logger.warning("FAST fit: skipping episode %d: %s", ep_idx, exc) logger.warning("FAST fit: skipping episode %d: %s", ep_idx, exc)
continue continue
if len(ds) < chunk_size: if ep_actions.ndim != 2 or ep_actions.shape[0] < chunk_size:
short_episodes += 1
continue continue
# Sample ``samples_per_episode`` start indices uniformly within # Sample ``samples_per_episode`` contiguous chunks uniformly.
# the episode. starts = rng.integers(0, ep_actions.shape[0] - chunk_size + 1, size=samples_per_episode)
starts = rng.integers(0, len(ds) - chunk_size + 1, size=samples_per_episode)
for s in starts: for s in starts:
try: actions_buf.append(ep_actions[int(s) : int(s) + chunk_size])
chunk_actions = np.stack(
[
np.asarray(ds[int(s) + j]["action"].cpu().numpy())
for j in range(chunk_size)
],
axis=0,
)
except Exception as exc: # noqa: BLE001
logger.debug("FAST fit: chunk at ep=%d s=%d failed: %s", ep_idx, s, exc)
continue
actions_buf.append(chunk_actions)
collected += 1 collected += 1
if collected >= n_samples: if collected >= n_samples:
break break
@@ -178,9 +182,11 @@ def fit_fast_tokenizer(
if not actions_buf: if not actions_buf:
raise RuntimeError( raise RuntimeError(
f"FAST fit collected zero action chunks from {dataset_repo_id!r}. " f"FAST fit collected zero action chunks from {dataset_repo_id!r}: "
"Check that the dataset has an ``action`` column and chunks of " f"all {num_episodes} episodes were shorter than chunk_size="
f"length ≥ {chunk_size}." f"{chunk_size} ({short_episodes} too short) or had an unreadable "
"``action`` column. Lower ``chunk_size`` to match your episode "
"lengths."
) )
actions = np.stack(actions_buf, axis=0) # (N, H, D) actions = np.stack(actions_buf, axis=0) # (N, H, D)