mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 08:47:05 +00:00
Merge remote-tracking branch 'origin/feat/smolvla-on-steerable' into feat/smolvla-on-steerable
This commit is contained in:
@@ -178,35 +178,76 @@ def fit_fast_tokenizer(
|
||||
rng = np.random.default_rng(seed)
|
||||
actions_buf: list[np.ndarray] = []
|
||||
|
||||
# Load just the metadata first to know episode boundaries.
|
||||
ds_meta_only = LeRobotDataset(dataset_repo_id, episodes=[0])
|
||||
num_episodes = ds_meta_only.meta.total_episodes
|
||||
if "action" not in ds_meta_only.features:
|
||||
available = ", ".join(sorted(ds_meta_only.features)) or "<none>"
|
||||
# Resolve the dataset's data parquet shards directly, sidestepping
|
||||
# ``LeRobotDataset(repo_id, episodes=[N])`` which on v3-format
|
||||
# datasets routes through HF datasets' split lookup and raises
|
||||
# ``ValueError: Instruction "train" corresponds to no data!`` for
|
||||
# every episode (job 22182985 looped through 13,293 skipped episodes
|
||||
# for ~2.5 h before NCCL killed it). Reading the ``action`` column
|
||||
# straight from the parquet shards is also faster: each per-episode
|
||||
# ``LeRobotDataset`` instantiation re-parses every meta file.
|
||||
from huggingface_hub import snapshot_download # noqa: PLC0415
|
||||
import pyarrow as _pa # noqa: PLC0415
|
||||
import pyarrow.parquet as _pq # noqa: PLC0415
|
||||
|
||||
snap = Path(snapshot_download(repo_id=dataset_repo_id, repo_type="dataset"))
|
||||
data_files = sorted((snap / "data").glob("chunk-*/file-*.parquet"))
|
||||
if not data_files:
|
||||
raise RuntimeError(
|
||||
f"FAST fit: dataset {dataset_repo_id!r} has no ``action`` feature. "
|
||||
f"Available features: {available}."
|
||||
f"FAST fit: no ``data/chunk-*/file-*.parquet`` shards found under {snap!s}."
|
||||
)
|
||||
del ds_meta_only
|
||||
|
||||
# Read just the (episode_index, action) columns once across all
|
||||
# shards. This is the same pattern used elsewhere in the codebase
|
||||
# for whole-dataset audits and stays under ~2 GB even on 32 k-episode
|
||||
# / 29 M-frame datasets because the action column is a fixed-length
|
||||
# float vector.
|
||||
tables = [_pq.read_table(f, columns=["episode_index", "action"]) for f in data_files]
|
||||
table = _pa.concat_tables(tables)
|
||||
eps = table["episode_index"].to_numpy()
|
||||
acts_col = table["action"]
|
||||
# ``action`` may be a fixed-shape ListArray or a 2-D NumericArray;
|
||||
# ``to_numpy(zero_copy_only=False)`` produces an object array of
|
||||
# 1-D NumPy actions either way, which we stack into (N, D).
|
||||
try:
|
||||
acts = np.stack(acts_col.to_numpy(zero_copy_only=False)).astype(np.float32)
|
||||
except Exception: # noqa: BLE001
|
||||
# Fallback path for nested-list types: flatten via to_pylist().
|
||||
acts = np.asarray(acts_col.to_pylist(), dtype=np.float32)
|
||||
if acts.ndim != 2:
|
||||
raise RuntimeError(
|
||||
f"FAST fit: expected ``action`` rows to be 1-D vectors; got shape {acts.shape}."
|
||||
)
|
||||
|
||||
# Episode index → slice (start, stop) into ``acts`` along axis 0.
|
||||
# ``eps`` is monotonically increasing within each parquet shard but
|
||||
# we make no assumption across shards — sort once and group.
|
||||
order = np.argsort(eps, kind="stable")
|
||||
eps_sorted = eps[order]
|
||||
boundaries = np.searchsorted(eps_sorted, np.arange(int(eps_sorted.max()) + 2))
|
||||
ep_to_slice: dict[int, tuple[int, int]] = {
|
||||
int(ep): (int(boundaries[ep]), int(boundaries[ep + 1]))
|
||||
for ep in range(len(boundaries) - 1)
|
||||
if boundaries[ep] < boundaries[ep + 1]
|
||||
}
|
||||
num_episodes = len(ep_to_slice)
|
||||
# ``acts`` is in original (un-sorted-by-episode) row order; reorder
|
||||
# so per-episode slices are contiguous.
|
||||
acts = acts[order]
|
||||
|
||||
samples_per_episode = max(1, n_samples // max(num_episodes, 1))
|
||||
collected = 0
|
||||
eps_visited = 0
|
||||
short_episodes = 0
|
||||
for ep_idx in rng.permutation(num_episodes):
|
||||
ep_indices = list(ep_to_slice.keys())
|
||||
for ep_idx in rng.permutation(ep_indices):
|
||||
if collected >= n_samples:
|
||||
break
|
||||
ep_idx = int(ep_idx)
|
||||
try:
|
||||
ds = LeRobotDataset(dataset_repo_id, episodes=[ep_idx])
|
||||
ep_actions = np.asarray(ds.hf_dataset["action"], dtype=np.float32)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("FAST fit: skipping episode %d: %s", ep_idx, exc)
|
||||
continue
|
||||
if ep_actions.ndim != 2 or ep_actions.shape[0] < chunk_size:
|
||||
start, stop = ep_to_slice[int(ep_idx)]
|
||||
ep_actions = acts[start:stop]
|
||||
if ep_actions.shape[0] < chunk_size:
|
||||
short_episodes += 1
|
||||
continue
|
||||
# Sample ``samples_per_episode`` contiguous chunks uniformly.
|
||||
starts = rng.integers(0, ep_actions.shape[0] - chunk_size + 1, size=samples_per_episode)
|
||||
for s in starts:
|
||||
actions_buf.append(ep_actions[int(s) : int(s) + chunk_size])
|
||||
|
||||
@@ -55,6 +55,17 @@ from .configuration_pi052 import PI052Config
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# FAST action-token vocab size (``lerobot/fast-action-tokenizer``). The
|
||||
# tokenizer maps a FAST BPE id ``t`` to the PaliGemma vocab id
|
||||
# ``vocab_size - 1 - fast_skip_tokens - t`` (see ``TokenizerProcessorStep``),
|
||||
# so action tokens occupy the top ``_FAST_ACTION_VOCAB_SIZE`` ids below the
|
||||
# ``fast_skip_tokens`` margin. The upper part collides with the reserved
|
||||
# ``<loc>`` block; the lower part sits just under it and otherwise leaks into
|
||||
# generated text as high-codepoint gibberish (the action-trained LM head puts
|
||||
# heavy mass on these ids), so ``select_message`` masks it.
|
||||
_FAST_ACTION_VOCAB_SIZE = 2048
|
||||
|
||||
|
||||
_HF_KERNELS_ENABLED = False
|
||||
|
||||
|
||||
@@ -998,60 +1009,11 @@ class PI052Policy(PI05Policy):
|
||||
text_logits = lm_head(text_hidden.to(lm_head.weight.dtype))
|
||||
preds = text_logits.argmax(dim=-1)
|
||||
|
||||
# Train/inference parity check — run select_message on the
|
||||
# *same* prompt prefix (the language up to but not including
|
||||
# the supervised span) and capture the auto-regressive
|
||||
# generation. The first generated token MUST match the
|
||||
# training-side argmax at the prompt-end position (both are
|
||||
# ``argmax lm_head(h_last_prompt)`` over identical context);
|
||||
# any divergence is a parity bug (mask, dtype, KI routing
|
||||
# difference). Later tokens can diverge because training
|
||||
# uses teacher forcing while inference free-runs.
|
||||
inference_outputs: list[dict[str, Any]] = []
|
||||
for s in range(n):
|
||||
row_labels = sub_labels[s]
|
||||
sup_pos = (row_labels != -100).nonzero(as_tuple=True)[0]
|
||||
if sup_pos.numel() == 0:
|
||||
inference_outputs.append({"first_token": None, "decoded": ""})
|
||||
continue
|
||||
first_sup = int(sup_pos[0].item())
|
||||
# Build a single-sample batch by *truncating* the token
|
||||
# sequence to the prompt-only portion (length == first_sup),
|
||||
# not by zero-masking. ``select_message`` reads the
|
||||
# prompt-end hidden state via ``vlm_out[:, -1:]`` — the
|
||||
# *last position* of the prefix — so a padded sequence
|
||||
# would make it read a padding-token hidden state
|
||||
# (PaliGemma's prior on those happens to be ``<loc>``,
|
||||
# which would falsely flag a parity divergence). The real
|
||||
# runtime feeds ``tokenizer(prompt)`` without padding,
|
||||
# so we mirror that here.
|
||||
prompt_tokens = sub[OBS_LANGUAGE_TOKENS][s : s + 1, :first_sup]
|
||||
prompt_mask_orig = sub[OBS_LANGUAGE_ATTENTION_MASK][s : s + 1, :first_sup]
|
||||
inf_batch: dict[str, Any] = {
|
||||
OBS_LANGUAGE_TOKENS: prompt_tokens,
|
||||
OBS_LANGUAGE_ATTENTION_MASK: prompt_mask_orig,
|
||||
}
|
||||
for k, v in sub.items():
|
||||
if isinstance(k, str) and k.startswith("observation.images."):
|
||||
inf_batch[k] = v[s : s + 1]
|
||||
if "observation.state" in batch and torch.is_tensor(batch["observation.state"]):
|
||||
inf_batch["observation.state"] = batch["observation.state"][s : s + 1]
|
||||
try:
|
||||
# Tight budget — we just want to see the model's
|
||||
# opening continuation, not the full sequence.
|
||||
decoded = self.select_message(
|
||||
inf_batch, max_new_tokens=24, temperature=0.0, top_p=1.0
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
decoded = f"<inference failed: {type(exc).__name__}: {exc}>"
|
||||
inference_outputs.append({"first_sup_pos": first_sup, "decoded": decoded})
|
||||
|
||||
return {
|
||||
"input_ids": lang_tokens.detach().cpu(),
|
||||
"attention_mask": lang_masks.detach().cpu(),
|
||||
"labels": sub_labels.detach().cpu(),
|
||||
"predictions": preds.detach().cpu(),
|
||||
"inference": inference_outputs,
|
||||
}
|
||||
finally:
|
||||
if was_training:
|
||||
@@ -1166,6 +1128,15 @@ class PI052Policy(PI05Policy):
|
||||
if special_ids and len(generated) < min_new_tokens:
|
||||
for sid in special_ids:
|
||||
logits_step[..., sid] = float("-inf")
|
||||
# Mask FAST action tokens that fall *below* the ``<loc>`` block.
|
||||
# They are never valid text, but the action-trained head leaks
|
||||
# them as gibberish; unlike the loc/seg block this region is never
|
||||
# legitimately emitted (even by VQA), so suppress it on every call.
|
||||
vocab_size = logits_step.shape[-1]
|
||||
fast_skip = int(getattr(self.config, "fast_skip_tokens", 128))
|
||||
fast_lo = vocab_size - 1 - fast_skip - (_FAST_ACTION_VOCAB_SIZE - 1)
|
||||
if 0 < fast_lo < 256000:
|
||||
logits_step[..., fast_lo:256000] = float("-inf")
|
||||
if suppress_loc_tokens:
|
||||
logits_step[..., 256000:257024] = float("-inf")
|
||||
next_ids = self._sample_next_token(logits_step, temperature, top_p)
|
||||
|
||||
@@ -224,7 +224,6 @@ def _print_debug_text_predictions(
|
||||
labels = debug["labels"]
|
||||
preds = debug["predictions"]
|
||||
attn = debug["attention_mask"]
|
||||
inference = debug.get("inference") or []
|
||||
|
||||
n = ids.shape[0]
|
||||
print(
|
||||
@@ -251,7 +250,6 @@ def _print_debug_text_predictions(
|
||||
|
||||
# Training-side teacher-forced argmax on the same prompt+target.
|
||||
n_sup = n_ok = 0
|
||||
first_sup_pred: int | None = None
|
||||
teacher_chars: list[int] = []
|
||||
for i in range(1, real):
|
||||
label = sl[i]
|
||||
@@ -259,8 +257,6 @@ def _print_debug_text_predictions(
|
||||
continue
|
||||
n_sup += 1
|
||||
pred = int(sp[i - 1])
|
||||
if first_sup_pred is None:
|
||||
first_sup_pred = pred
|
||||
teacher_chars.append(pred)
|
||||
if label == pred:
|
||||
n_ok += 1
|
||||
@@ -272,28 +268,6 @@ def _print_debug_text_predictions(
|
||||
f" training argmax (teacher-fed) : {teacher_text!r} acc={n_ok}/{n_sup}={acc:.1%}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Inference-side autoregressive output from the same prompt prefix.
|
||||
inf_entry = inference[s] if s < len(inference) else None
|
||||
if inf_entry:
|
||||
inf_decoded = inf_entry.get("decoded", "")
|
||||
print(f" inference (autoregressive) : {inf_decoded!r}", flush=True)
|
||||
# First-token parity: training-side argmax at the prompt-end
|
||||
# position MUST equal inference's first generated token —
|
||||
# both compute argmax(lm_head(h_last_prompt)) on identical
|
||||
# context. Any divergence signals a training↔inference bug.
|
||||
if first_sup_pred is not None and inf_decoded and not inf_decoded.startswith("<inference"):
|
||||
inf_ids = tokenizer(inf_decoded, add_special_tokens=False)["input_ids"]
|
||||
if inf_ids:
|
||||
inf_first = int(inf_ids[0])
|
||||
match = inf_first == first_sup_pred
|
||||
print(
|
||||
f" first-token parity : "
|
||||
f"train={first_sup_pred} ({tokenizer.decode([first_sup_pred])!r}) "
|
||||
f"vs infer={inf_first} ({tokenizer.decode([inf_first])!r}) "
|
||||
f"{'✓ MATCH' if match else '✗ DIVERGED — training/inference mismatch'}",
|
||||
flush=True,
|
||||
)
|
||||
print("=" * 60 + "\n", flush=True)
|
||||
|
||||
|
||||
@@ -381,15 +355,26 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
|
||||
# We set find_unused_parameters=True to handle models with conditional computation
|
||||
if accelerator is None:
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
from datetime import timedelta
|
||||
|
||||
from accelerate.utils import DistributedDataParallelKwargs, InitProcessGroupKwargs
|
||||
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
# Bump the c10d store-get / barrier timeout so the rank-0-only
|
||||
# ``make_dataset`` block below doesn't trigger a barrier crash on
|
||||
# large datasets. Default is 10 min (``store->get`` 600 s); a
|
||||
# 32 k-episode v3 dataset (e.g. ``robocasa_pretrain_human300_v4``)
|
||||
# spends >13 min on rank 0 building the episode/frame index
|
||||
# while ranks 1-N idle at ``wait_for_everyone()`` and crash with
|
||||
# ``DistBackendError: ... wait timeout after 600000ms``. 2 h is
|
||||
# plenty of headroom; fast paths are unaffected.
|
||||
ipg_kwargs = InitProcessGroupKwargs(timeout=timedelta(hours=2))
|
||||
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
|
||||
# Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training).
|
||||
force_cpu = cfg.trainable_config.device == "cpu"
|
||||
accelerator = Accelerator(
|
||||
step_scheduler_with_optimizer=False,
|
||||
kwargs_handlers=[ddp_kwargs],
|
||||
kwargs_handlers=[ddp_kwargs, ipg_kwargs],
|
||||
cpu=force_cpu,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user