mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
fix(pi052): FAST tokenizer fit read actions from column, not ds[i]
fit_fast_tokenizer collected action chunks via ds[i]["action"], which builds a full training item — delta-timestamp expansion, video decode, image transforms. A single video-decode failure threw, was swallowed at debug level, and silently starved the fit of every chunk → "FAST fit collected zero action chunks", falling back to the universal tokenizer. Read the ``action`` column straight from the HF dataset instead: it carries no video, so it is immune to decode errors and far faster. Also fail fast with a clear message when the dataset has no ``action`` feature or all episodes are shorter than chunk_size. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -40,7 +40,6 @@ from __future__ import annotations
|
||||
import hashlib
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -133,44 +132,49 @@ def fit_fast_tokenizer(
|
||||
# Stream a single episode's worth of action chunks at a time so
|
||||
# we don't blow memory on huge datasets. Random episode +
|
||||
# random start offset gives a reasonable spread.
|
||||
#
|
||||
# Actions are read straight from the underlying HF dataset's
|
||||
# ``action`` *column* — never via ``ds[i]``. ``ds[i]`` builds a full
|
||||
# training item (delta-timestamp expansion + video decode + image
|
||||
# transforms); a single bad video frame would then throw and, since
|
||||
# the failure was swallowed at debug level, silently starve the fit
|
||||
# of every chunk. The action column carries no video, so reading it
|
||||
# directly is both faster and immune to decode errors.
|
||||
rng = np.random.default_rng(seed)
|
||||
actions_buf: list[np.ndarray] = []
|
||||
|
||||
# Load just the metadata first to know episode boundaries.
|
||||
ds_meta_only = LeRobotDataset(dataset_repo_id, episodes=[0])
|
||||
num_episodes = ds_meta_only.meta.total_episodes
|
||||
if "action" not in ds_meta_only.features:
|
||||
available = ", ".join(sorted(ds_meta_only.features)) or "<none>"
|
||||
raise RuntimeError(
|
||||
f"FAST fit: dataset {dataset_repo_id!r} has no ``action`` feature. "
|
||||
f"Available features: {available}."
|
||||
)
|
||||
del ds_meta_only
|
||||
|
||||
samples_per_episode = max(1, n_samples // max(num_episodes, 1))
|
||||
collected = 0
|
||||
eps_visited = 0
|
||||
short_episodes = 0
|
||||
for ep_idx in rng.permutation(num_episodes):
|
||||
if collected >= n_samples:
|
||||
break
|
||||
ep_idx = int(ep_idx)
|
||||
try:
|
||||
ds = LeRobotDataset(dataset_repo_id, episodes=[ep_idx])
|
||||
ep_actions = np.asarray(ds.hf_dataset["action"], dtype=np.float32)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("FAST fit: skipping episode %d: %s", ep_idx, exc)
|
||||
continue
|
||||
if len(ds) < chunk_size:
|
||||
if ep_actions.ndim != 2 or ep_actions.shape[0] < chunk_size:
|
||||
short_episodes += 1
|
||||
continue
|
||||
# Sample ``samples_per_episode`` start indices uniformly within
|
||||
# the episode.
|
||||
starts = rng.integers(0, len(ds) - chunk_size + 1, size=samples_per_episode)
|
||||
# Sample ``samples_per_episode`` contiguous chunks uniformly.
|
||||
starts = rng.integers(0, ep_actions.shape[0] - chunk_size + 1, size=samples_per_episode)
|
||||
for s in starts:
|
||||
try:
|
||||
chunk_actions = np.stack(
|
||||
[
|
||||
np.asarray(ds[int(s) + j]["action"].cpu().numpy())
|
||||
for j in range(chunk_size)
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("FAST fit: chunk at ep=%d s=%d failed: %s", ep_idx, s, exc)
|
||||
continue
|
||||
actions_buf.append(chunk_actions)
|
||||
actions_buf.append(ep_actions[int(s) : int(s) + chunk_size])
|
||||
collected += 1
|
||||
if collected >= n_samples:
|
||||
break
|
||||
@@ -178,9 +182,11 @@ def fit_fast_tokenizer(
|
||||
|
||||
if not actions_buf:
|
||||
raise RuntimeError(
|
||||
f"FAST fit collected zero action chunks from {dataset_repo_id!r}. "
|
||||
"Check that the dataset has an ``action`` column and chunks of "
|
||||
f"length ≥ {chunk_size}."
|
||||
f"FAST fit collected zero action chunks from {dataset_repo_id!r}: "
|
||||
f"all {num_episodes} episodes were shorter than chunk_size="
|
||||
f"{chunk_size} ({short_episodes} too short) or had an unreadable "
|
||||
"``action`` column. Lower ``chunk_size`` to match your episode "
|
||||
"lengths."
|
||||
)
|
||||
|
||||
actions = np.stack(actions_buf, axis=0) # (N, H, D)
|
||||
|
||||
Reference in New Issue
Block a user