mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 04:30:10 +00:00
fix(pi052): FAST tokenizer fit read actions from column, not ds[i]
fit_fast_tokenizer collected action chunks via ds[i]["action"], which builds a full training item — delta-timestamp expansion, video decode, image transforms. A single video-decode failure threw, was swallowed at debug level, and silently starved the fit of every chunk → "FAST fit collected zero action chunks", falling back to the universal tokenizer. Read the ``action`` column straight from the HF dataset instead: it carries no video, so it is immune to decode errors and far faster. Also fail fast with a clear message when the dataset has no ``action`` feature or all episodes are shorter than chunk_size. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -40,7 +40,6 @@ from __future__ import annotations
|
|||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -133,44 +132,49 @@ def fit_fast_tokenizer(
|
|||||||
# Stream a single episode's worth of action chunks at a time so
|
# Stream a single episode's worth of action chunks at a time so
|
||||||
# we don't blow memory on huge datasets. Random episode +
|
# we don't blow memory on huge datasets. Random episode +
|
||||||
# random start offset gives a reasonable spread.
|
# random start offset gives a reasonable spread.
|
||||||
|
#
|
||||||
|
# Actions are read straight from the underlying HF dataset's
|
||||||
|
# ``action`` *column* — never via ``ds[i]``. ``ds[i]`` builds a full
|
||||||
|
# training item (delta-timestamp expansion + video decode + image
|
||||||
|
# transforms); a single bad video frame would then throw and, since
|
||||||
|
# the failure was swallowed at debug level, silently starve the fit
|
||||||
|
# of every chunk. The action column carries no video, so reading it
|
||||||
|
# directly is both faster and immune to decode errors.
|
||||||
rng = np.random.default_rng(seed)
|
rng = np.random.default_rng(seed)
|
||||||
actions_buf: list[np.ndarray] = []
|
actions_buf: list[np.ndarray] = []
|
||||||
|
|
||||||
# Load just the metadata first to know episode boundaries.
|
# Load just the metadata first to know episode boundaries.
|
||||||
ds_meta_only = LeRobotDataset(dataset_repo_id, episodes=[0])
|
ds_meta_only = LeRobotDataset(dataset_repo_id, episodes=[0])
|
||||||
num_episodes = ds_meta_only.meta.total_episodes
|
num_episodes = ds_meta_only.meta.total_episodes
|
||||||
|
if "action" not in ds_meta_only.features:
|
||||||
|
available = ", ".join(sorted(ds_meta_only.features)) or "<none>"
|
||||||
|
raise RuntimeError(
|
||||||
|
f"FAST fit: dataset {dataset_repo_id!r} has no ``action`` feature. "
|
||||||
|
f"Available features: {available}."
|
||||||
|
)
|
||||||
del ds_meta_only
|
del ds_meta_only
|
||||||
|
|
||||||
samples_per_episode = max(1, n_samples // max(num_episodes, 1))
|
samples_per_episode = max(1, n_samples // max(num_episodes, 1))
|
||||||
collected = 0
|
collected = 0
|
||||||
eps_visited = 0
|
eps_visited = 0
|
||||||
|
short_episodes = 0
|
||||||
for ep_idx in rng.permutation(num_episodes):
|
for ep_idx in rng.permutation(num_episodes):
|
||||||
if collected >= n_samples:
|
if collected >= n_samples:
|
||||||
break
|
break
|
||||||
ep_idx = int(ep_idx)
|
ep_idx = int(ep_idx)
|
||||||
try:
|
try:
|
||||||
ds = LeRobotDataset(dataset_repo_id, episodes=[ep_idx])
|
ds = LeRobotDataset(dataset_repo_id, episodes=[ep_idx])
|
||||||
|
ep_actions = np.asarray(ds.hf_dataset["action"], dtype=np.float32)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
logger.warning("FAST fit: skipping episode %d: %s", ep_idx, exc)
|
logger.warning("FAST fit: skipping episode %d: %s", ep_idx, exc)
|
||||||
continue
|
continue
|
||||||
if len(ds) < chunk_size:
|
if ep_actions.ndim != 2 or ep_actions.shape[0] < chunk_size:
|
||||||
|
short_episodes += 1
|
||||||
continue
|
continue
|
||||||
# Sample ``samples_per_episode`` start indices uniformly within
|
# Sample ``samples_per_episode`` contiguous chunks uniformly.
|
||||||
# the episode.
|
starts = rng.integers(0, ep_actions.shape[0] - chunk_size + 1, size=samples_per_episode)
|
||||||
starts = rng.integers(0, len(ds) - chunk_size + 1, size=samples_per_episode)
|
|
||||||
for s in starts:
|
for s in starts:
|
||||||
try:
|
actions_buf.append(ep_actions[int(s) : int(s) + chunk_size])
|
||||||
chunk_actions = np.stack(
|
|
||||||
[
|
|
||||||
np.asarray(ds[int(s) + j]["action"].cpu().numpy())
|
|
||||||
for j in range(chunk_size)
|
|
||||||
],
|
|
||||||
axis=0,
|
|
||||||
)
|
|
||||||
except Exception as exc: # noqa: BLE001
|
|
||||||
logger.debug("FAST fit: chunk at ep=%d s=%d failed: %s", ep_idx, s, exc)
|
|
||||||
continue
|
|
||||||
actions_buf.append(chunk_actions)
|
|
||||||
collected += 1
|
collected += 1
|
||||||
if collected >= n_samples:
|
if collected >= n_samples:
|
||||||
break
|
break
|
||||||
@@ -178,9 +182,11 @@ def fit_fast_tokenizer(
|
|||||||
|
|
||||||
if not actions_buf:
|
if not actions_buf:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"FAST fit collected zero action chunks from {dataset_repo_id!r}. "
|
f"FAST fit collected zero action chunks from {dataset_repo_id!r}: "
|
||||||
"Check that the dataset has an ``action`` column and chunks of "
|
f"all {num_episodes} episodes were shorter than chunk_size="
|
||||||
f"length ≥ {chunk_size}."
|
f"{chunk_size} ({short_episodes} too short) or had an unreadable "
|
||||||
|
"``action`` column. Lower ``chunk_size`` to match your episode "
|
||||||
|
"lengths."
|
||||||
)
|
)
|
||||||
|
|
||||||
actions = np.stack(actions_buf, axis=0) # (N, H, D)
|
actions = np.stack(actions_buf, axis=0) # (N, H, D)
|
||||||
|
|||||||
Reference in New Issue
Block a user