From ff1d58a46f1df7da754a16dcd45cd6b6bcbbf2c0 Mon Sep 17 00:00:00 2001 From: pepijn223 Date: Tue, 2 Jun 2026 13:06:51 +0200 Subject: [PATCH 1/3] pi052: suppress FAST action tokens in select_message text generation The FAST action tokenizer maps action codes to the top of the PaliGemma vocab (id = vocab_size-1-fast_skip_tokens-t). The lower part of that band sits just below the reserved block, so it escaped the existing suppress_loc_tokens mask and leaked into generated subtask/VQA/memory text as high-codepoint gibberish. Mask the FAST band on every select_message call so the high-level head emits clean language. Co-authored-by: Cursor --- src/lerobot/policies/pi052/modeling_pi052.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index ce8c3abc6..73799cbc9 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -55,6 +55,17 @@ from .configuration_pi052 import PI052Config logger = logging.getLogger(__name__) +# FAST action-token vocab size (``lerobot/fast-action-tokenizer``). The +# tokenizer maps a FAST BPE id ``t`` to the PaliGemma vocab id +# ``vocab_size - 1 - fast_skip_tokens - t`` (see ``TokenizerProcessorStep``), +# so action tokens occupy the top ``_FAST_ACTION_VOCAB_SIZE`` ids below the +# ``fast_skip_tokens`` margin. The upper part collides with the reserved +# ```` block; the lower part sits just under it and otherwise leaks into +# generated text as high-codepoint gibberish (the action-trained LM head puts +# heavy mass on these ids), so ``select_message`` masks it. +_FAST_ACTION_VOCAB_SIZE = 2048 + + _HF_KERNELS_ENABLED = False @@ -1166,6 +1177,15 @@ class PI052Policy(PI05Policy): if special_ids and len(generated) < min_new_tokens: for sid in special_ids: logits_step[..., sid] = float("-inf") + # Mask FAST action tokens that fall *below* the ```` block. + # They are never valid text, but the action-trained head leaks + # them as gibberish; unlike the loc/seg block this region is never + # legitimately emitted (even by VQA), so suppress it on every call. + vocab_size = logits_step.shape[-1] + fast_skip = int(getattr(self.config, "fast_skip_tokens", 128)) + fast_lo = vocab_size - 1 - fast_skip - (_FAST_ACTION_VOCAB_SIZE - 1) + if 0 < fast_lo < 256000: + logits_step[..., fast_lo:256000] = float("-inf") if suppress_loc_tokens: logits_step[..., 256000:257024] = float("-inf") next_ids = self._sample_next_token(logits_step, temperature, top_p) From 23419026d5d2403314b5cd023d2207650a8761b9 Mon Sep 17 00:00:00 2001 From: pepijn Date: Tue, 2 Jun 2026 15:50:40 +0000 Subject: [PATCH 2/3] pi052: parquet-direct FAST tokenizer fit (fix v3 dataset hang) ``fit_fast_tokenizer`` previously called ``LeRobotDataset(repo_id, episodes=[N])`` per sampled episode, which on v3-format datasets routes through HF datasets' split lookup and raises ``ValueError: Instruction "train" corresponds to no data!`` on every episode. On ``pepijn223/robocasa_pretrain_human300_v4`` (32 k episodes) this looped through 13,293 skipped episodes for ~2.5 h before the NCCL watchdog killed the run via the 2 h ALLREDUCE timeout (job 22182985). Switch to reading the ``action`` column directly from the dataset's ``data/chunk-*/file-*.parquet`` shards (same pattern as the audit scripts). Verified end-to-end on the 32 k-episode dataset: 1000 chunks collected from 1000 episodes in 70.7 s. Co-authored-by: Cursor --- .../policies/pi052/fit_fast_tokenizer.py | 77 ++++++++++++++----- 1 file changed, 59 insertions(+), 18 deletions(-) diff --git a/src/lerobot/policies/pi052/fit_fast_tokenizer.py b/src/lerobot/policies/pi052/fit_fast_tokenizer.py index e27c01343..2f8224c72 100644 --- a/src/lerobot/policies/pi052/fit_fast_tokenizer.py +++ b/src/lerobot/policies/pi052/fit_fast_tokenizer.py @@ -178,35 +178,76 @@ def fit_fast_tokenizer( rng = np.random.default_rng(seed) actions_buf: list[np.ndarray] = [] - # Load just the metadata first to know episode boundaries. - ds_meta_only = LeRobotDataset(dataset_repo_id, episodes=[0]) - num_episodes = ds_meta_only.meta.total_episodes - if "action" not in ds_meta_only.features: - available = ", ".join(sorted(ds_meta_only.features)) or "" + # Resolve the dataset's data parquet shards directly, sidestepping + # ``LeRobotDataset(repo_id, episodes=[N])`` which on v3-format + # datasets routes through HF datasets'' split lookup and raises + # ``ValueError: Instruction "train" corresponds to no data!`` for + # every episode (job 22182985 looped through 13,293 skipped episodes + # for ~2.5 h before NCCL killed it). Reading the ``action`` column + # straight from the parquet shards is also faster: each per-episode + # ``LeRobotDataset`` instantiation re-parses every meta file. + from huggingface_hub import snapshot_download # noqa: PLC0415 + import pyarrow as _pa # noqa: PLC0415 + import pyarrow.parquet as _pq # noqa: PLC0415 + + snap = Path(snapshot_download(repo_id=dataset_repo_id, repo_type="dataset")) + data_files = sorted((snap / "data").glob("chunk-*/file-*.parquet")) + if not data_files: raise RuntimeError( - f"FAST fit: dataset {dataset_repo_id!r} has no ``action`` feature. " - f"Available features: {available}." + f"FAST fit: no ``data/chunk-*/file-*.parquet`` shards found under {snap!s}." ) - del ds_meta_only + + # Read just the (episode_index, action) columns once across all + # shards. This is the same pattern used elsewhere in the codebase + # for whole-dataset audits and stays under ~2 GB even on 32 k-episode + # / 29 M-frame datasets because the action column is a fixed-length + # float vector. + tables = [_pq.read_table(f, columns=["episode_index", "action"]) for f in data_files] + table = _pa.concat_tables(tables) + eps = table["episode_index"].to_numpy() + acts_col = table["action"] + # ``action`` may be a fixed-shape ListArray or a 2-D NumericArray; + # ``to_numpy(zero_copy_only=False)`` produces an object array of + # 1-D NumPy actions either way, which we stack into (N, D). + try: + acts = np.stack(acts_col.to_numpy(zero_copy_only=False)).astype(np.float32) + except Exception: # noqa: BLE001 + # Fallback path for nested-list types: flatten via to_pylist(). + acts = np.asarray(acts_col.to_pylist(), dtype=np.float32) + if acts.ndim != 2: + raise RuntimeError( + f"FAST fit: expected ``action`` rows to be 1-D vectors; got shape {acts.shape}." + ) + + # Episode index → slice (start, stop) into ``acts`` along axis 0. + # ``eps`` is monotonically increasing within each parquet shard but + # we make no assumption across shards — sort once and group. + order = np.argsort(eps, kind="stable") + eps_sorted = eps[order] + boundaries = np.searchsorted(eps_sorted, np.arange(int(eps_sorted.max()) + 2)) + ep_to_slice: dict[int, tuple[int, int]] = { + int(ep): (int(boundaries[ep]), int(boundaries[ep + 1])) + for ep in range(len(boundaries) - 1) + if boundaries[ep] < boundaries[ep + 1] + } + num_episodes = len(ep_to_slice) + # ``acts`` is in original (un-sorted-by-episode) row order; reorder + # so per-episode slices are contiguous. + acts = acts[order] samples_per_episode = max(1, n_samples // max(num_episodes, 1)) collected = 0 eps_visited = 0 short_episodes = 0 - for ep_idx in rng.permutation(num_episodes): + ep_indices = list(ep_to_slice.keys()) + for ep_idx in rng.permutation(ep_indices): if collected >= n_samples: break - ep_idx = int(ep_idx) - try: - ds = LeRobotDataset(dataset_repo_id, episodes=[ep_idx]) - ep_actions = np.asarray(ds.hf_dataset["action"], dtype=np.float32) - except Exception as exc: # noqa: BLE001 - logger.warning("FAST fit: skipping episode %d: %s", ep_idx, exc) - continue - if ep_actions.ndim != 2 or ep_actions.shape[0] < chunk_size: + start, stop = ep_to_slice[int(ep_idx)] + ep_actions = acts[start:stop] + if ep_actions.shape[0] < chunk_size: short_episodes += 1 continue - # Sample ``samples_per_episode`` contiguous chunks uniformly. starts = rng.integers(0, ep_actions.shape[0] - chunk_size + 1, size=samples_per_episode) for s in starts: actions_buf.append(ep_actions[int(s) : int(s) + chunk_size]) From e660a51e787470aff3e3c2f3bdbdee46ec0bb57e Mon Sep 17 00:00:00 2001 From: pepijn Date: Thu, 4 Jun 2026 13:32:44 +0000 Subject: [PATCH 3/3] pi052(debug): drop misleading inference/parity dump from text preds MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The first-token parity check re-tokenized the decoded (stripped) inference string, so the leading-space SentencePiece variant always mismatched the training argmax — a false "DIVERGED" alarm. Remove the autoregressive inference print and parity comparison (and the now-dead per-sample select_message generation), keeping only the prompt, ground-truth target, and teacher-forced argmax accuracy. Co-authored-by: Cursor --- src/lerobot/policies/pi052/modeling_pi052.py | 49 -------------------- src/lerobot/scripts/lerobot_train.py | 41 ++++++---------- 2 files changed, 13 insertions(+), 77 deletions(-) diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index 73799cbc9..f38536994 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -1009,60 +1009,11 @@ class PI052Policy(PI05Policy): text_logits = lm_head(text_hidden.to(lm_head.weight.dtype)) preds = text_logits.argmax(dim=-1) - # Train/inference parity check — run select_message on the - # *same* prompt prefix (the language up to but not including - # the supervised span) and capture the auto-regressive - # generation. The first generated token MUST match the - # training-side argmax at the prompt-end position (both are - # ``argmax lm_head(h_last_prompt)`` over identical context); - # any divergence is a parity bug (mask, dtype, KI routing - # difference). Later tokens can diverge because training - # uses teacher forcing while inference free-runs. - inference_outputs: list[dict[str, Any]] = [] - for s in range(n): - row_labels = sub_labels[s] - sup_pos = (row_labels != -100).nonzero(as_tuple=True)[0] - if sup_pos.numel() == 0: - inference_outputs.append({"first_token": None, "decoded": ""}) - continue - first_sup = int(sup_pos[0].item()) - # Build a single-sample batch by *truncating* the token - # sequence to the prompt-only portion (length == first_sup), - # not by zero-masking. ``select_message`` reads the - # prompt-end hidden state via ``vlm_out[:, -1:]`` — the - # *last position* of the prefix — so a padded sequence - # would make it read a padding-token hidden state - # (PaliGemma's prior on those happens to be ````, - # which would falsely flag a parity diverge). The real - # runtime feeds ``tokenizer(prompt)`` without padding, - # so we mirror that here. - prompt_tokens = sub[OBS_LANGUAGE_TOKENS][s : s + 1, :first_sup] - prompt_mask_orig = sub[OBS_LANGUAGE_ATTENTION_MASK][s : s + 1, :first_sup] - inf_batch: dict[str, Any] = { - OBS_LANGUAGE_TOKENS: prompt_tokens, - OBS_LANGUAGE_ATTENTION_MASK: prompt_mask_orig, - } - for k, v in sub.items(): - if isinstance(k, str) and k.startswith("observation.images."): - inf_batch[k] = v[s : s + 1] - if "observation.state" in batch and torch.is_tensor(batch["observation.state"]): - inf_batch["observation.state"] = batch["observation.state"][s : s + 1] - try: - # Tight budget — we just want to see the model's - # opening continuation, not the full sequence. - decoded = self.select_message( - inf_batch, max_new_tokens=24, temperature=0.0, top_p=1.0 - ) - except Exception as exc: # noqa: BLE001 - decoded = f"" - inference_outputs.append({"first_sup_pos": first_sup, "decoded": decoded}) - return { "input_ids": lang_tokens.detach().cpu(), "attention_mask": lang_masks.detach().cpu(), "labels": sub_labels.detach().cpu(), "predictions": preds.detach().cpu(), - "inference": inference_outputs, } finally: if was_training: diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 0b0059955..2e9409ccb 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -224,7 +224,6 @@ def _print_debug_text_predictions( labels = debug["labels"] preds = debug["predictions"] attn = debug["attention_mask"] - inference = debug.get("inference") or [] n = ids.shape[0] print( @@ -251,7 +250,6 @@ def _print_debug_text_predictions( # Training-side teacher-forced argmax on the same prompt+target. n_sup = n_ok = 0 - first_sup_pred: int | None = None teacher_chars: list[int] = [] for i in range(1, real): label = sl[i] @@ -259,8 +257,6 @@ def _print_debug_text_predictions( continue n_sup += 1 pred = int(sp[i - 1]) - if first_sup_pred is None: - first_sup_pred = pred teacher_chars.append(pred) if label == pred: n_ok += 1 @@ -272,28 +268,6 @@ def _print_debug_text_predictions( f" training argmax (teacher-fed) : {teacher_text!r} acc={n_ok}/{n_sup}={acc:.1%}", flush=True, ) - - # Inference-side autoregressive output from the same prompt prefix. - inf_entry = inference[s] if s < len(inference) else None - if inf_entry: - inf_decoded = inf_entry.get("decoded", "") - print(f" inference (autoregressive) : {inf_decoded!r}", flush=True) - # First-token parity: training-side argmax at the prompt-end - # position MUST equal inference's first generated token — - # both compute argmax(lm_head(h_last_prompt)) on identical - # context. Any divergence signals a training↔inference bug. - if first_sup_pred is not None and inf_decoded and not inf_decoded.startswith("get`` 600 s); a + # 32 k-episode v3 dataset (e.g. ``robocasa_pretrain_human300_v4``) + # spends >13 min on rank 0 building the episode/frame index + # while ranks 1-N idle at ``wait_for_everyone()`` and crash with + # ``DistBackendError: ... wait timeout after 600000ms``. 2 h is + # plenty of headroom; fast paths are unaffected. + ipg_kwargs = InitProcessGroupKwargs(timeout=timedelta(hours=2)) # Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting. # Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training). force_cpu = cfg.trainable_config.device == "cpu" accelerator = Accelerator( step_scheduler_with_optimizer=False, - kwargs_handlers=[ddp_kwargs], + kwargs_handlers=[ddp_kwargs, ipg_kwargs], cpu=force_cpu, )