diff --git a/src/lerobot/policies/pi052/fit_fast_tokenizer.py b/src/lerobot/policies/pi052/fit_fast_tokenizer.py index 37155d98d..14e4217ca 100644 --- a/src/lerobot/policies/pi052/fit_fast_tokenizer.py +++ b/src/lerobot/policies/pi052/fit_fast_tokenizer.py @@ -40,7 +40,6 @@ from __future__ import annotations import hashlib import logging from pathlib import Path -from typing import Any import numpy as np @@ -133,44 +132,49 @@ def fit_fast_tokenizer( # Stream a single episode's worth of action chunks at a time so # we don't blow memory on huge datasets. Random episode + # random start offset gives a reasonable spread. + # + # Actions are read straight from the underlying HF dataset's + # ``action`` *column* — never via ``ds[i]``. ``ds[i]`` builds a full + # training item (delta-timestamp expansion + video decode + image + # transforms); a single bad video frame would then throw and, since + # the failure was swallowed at debug level, silently starve the fit + # of every chunk. The action column carries no video, so reading it + # directly is both faster and immune to decode errors. rng = np.random.default_rng(seed) actions_buf: list[np.ndarray] = [] # Load just the metadata first to know episode boundaries. ds_meta_only = LeRobotDataset(dataset_repo_id, episodes=[0]) num_episodes = ds_meta_only.meta.total_episodes + if "action" not in ds_meta_only.features: + available = ", ".join(sorted(ds_meta_only.features)) or "" + raise RuntimeError( + f"FAST fit: dataset {dataset_repo_id!r} has no ``action`` feature. " + f"Available features: {available}." + ) del ds_meta_only samples_per_episode = max(1, n_samples // max(num_episodes, 1)) collected = 0 eps_visited = 0 + short_episodes = 0 for ep_idx in rng.permutation(num_episodes): if collected >= n_samples: break ep_idx = int(ep_idx) try: ds = LeRobotDataset(dataset_repo_id, episodes=[ep_idx]) + ep_actions = np.asarray(ds.hf_dataset["action"], dtype=np.float32) except Exception as exc: # noqa: BLE001 logger.warning("FAST fit: skipping episode %d: %s", ep_idx, exc) continue - if len(ds) < chunk_size: + if ep_actions.ndim != 2 or ep_actions.shape[0] < chunk_size: + short_episodes += 1 continue - # Sample ``samples_per_episode`` start indices uniformly within - # the episode. - starts = rng.integers(0, len(ds) - chunk_size + 1, size=samples_per_episode) + # Sample ``samples_per_episode`` contiguous chunks uniformly. + starts = rng.integers(0, ep_actions.shape[0] - chunk_size + 1, size=samples_per_episode) for s in starts: - try: - chunk_actions = np.stack( - [ - np.asarray(ds[int(s) + j]["action"].cpu().numpy()) - for j in range(chunk_size) - ], - axis=0, - ) - except Exception as exc: # noqa: BLE001 - logger.debug("FAST fit: chunk at ep=%d s=%d failed: %s", ep_idx, s, exc) - continue - actions_buf.append(chunk_actions) + actions_buf.append(ep_actions[int(s) : int(s) + chunk_size]) collected += 1 if collected >= n_samples: break @@ -178,9 +182,11 @@ def fit_fast_tokenizer( if not actions_buf: raise RuntimeError( - f"FAST fit collected zero action chunks from {dataset_repo_id!r}. " - "Check that the dataset has an ``action`` column and chunks of " - f"length ≥ {chunk_size}." + f"FAST fit collected zero action chunks from {dataset_repo_id!r}: " + f"all {num_episodes} episodes were shorter than chunk_size=" + f"{chunk_size} ({short_episodes} too short) or had an unreadable " + "``action`` column. Lower ``chunk_size`` to match your episode " + "lengths." ) actions = np.stack(actions_buf, axis=0) # (N, H, D)