diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index 73799cbc9..f38536994 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -1009,60 +1009,11 @@ class PI052Policy(PI05Policy): text_logits = lm_head(text_hidden.to(lm_head.weight.dtype)) preds = text_logits.argmax(dim=-1) - # Train/inference parity check — run select_message on the - # *same* prompt prefix (the language up to but not including - # the supervised span) and capture the auto-regressive - # generation. The first generated token MUST match the - # training-side argmax at the prompt-end position (both are - # ``argmax lm_head(h_last_prompt)`` over identical context); - # any divergence is a parity bug (mask, dtype, KI routing - # difference). Later tokens can diverge because training - # uses teacher forcing while inference free-runs. - inference_outputs: list[dict[str, Any]] = [] - for s in range(n): - row_labels = sub_labels[s] - sup_pos = (row_labels != -100).nonzero(as_tuple=True)[0] - if sup_pos.numel() == 0: - inference_outputs.append({"first_token": None, "decoded": ""}) - continue - first_sup = int(sup_pos[0].item()) - # Build a single-sample batch by *truncating* the token - # sequence to the prompt-only portion (length == first_sup), - # not by zero-masking. ``select_message`` reads the - # prompt-end hidden state via ``vlm_out[:, -1:]`` — the - # *last position* of the prefix — so a padded sequence - # would make it read a padding-token hidden state - # (PaliGemma's prior on those happens to be ````, - # which would falsely flag a parity diverge). The real - # runtime feeds ``tokenizer(prompt)`` without padding, - # so we mirror that here. - prompt_tokens = sub[OBS_LANGUAGE_TOKENS][s : s + 1, :first_sup] - prompt_mask_orig = sub[OBS_LANGUAGE_ATTENTION_MASK][s : s + 1, :first_sup] - inf_batch: dict[str, Any] = { - OBS_LANGUAGE_TOKENS: prompt_tokens, - OBS_LANGUAGE_ATTENTION_MASK: prompt_mask_orig, - } - for k, v in sub.items(): - if isinstance(k, str) and k.startswith("observation.images."): - inf_batch[k] = v[s : s + 1] - if "observation.state" in batch and torch.is_tensor(batch["observation.state"]): - inf_batch["observation.state"] = batch["observation.state"][s : s + 1] - try: - # Tight budget — we just want to see the model's - # opening continuation, not the full sequence. - decoded = self.select_message( - inf_batch, max_new_tokens=24, temperature=0.0, top_p=1.0 - ) - except Exception as exc: # noqa: BLE001 - decoded = f"" - inference_outputs.append({"first_sup_pos": first_sup, "decoded": decoded}) - return { "input_ids": lang_tokens.detach().cpu(), "attention_mask": lang_masks.detach().cpu(), "labels": sub_labels.detach().cpu(), "predictions": preds.detach().cpu(), - "inference": inference_outputs, } finally: if was_training: diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 0b0059955..2e9409ccb 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -224,7 +224,6 @@ def _print_debug_text_predictions( labels = debug["labels"] preds = debug["predictions"] attn = debug["attention_mask"] - inference = debug.get("inference") or [] n = ids.shape[0] print( @@ -251,7 +250,6 @@ def _print_debug_text_predictions( # Training-side teacher-forced argmax on the same prompt+target. n_sup = n_ok = 0 - first_sup_pred: int | None = None teacher_chars: list[int] = [] for i in range(1, real): label = sl[i] @@ -259,8 +257,6 @@ def _print_debug_text_predictions( continue n_sup += 1 pred = int(sp[i - 1]) - if first_sup_pred is None: - first_sup_pred = pred teacher_chars.append(pred) if label == pred: n_ok += 1 @@ -272,28 +268,6 @@ def _print_debug_text_predictions( f" training argmax (teacher-fed) : {teacher_text!r} acc={n_ok}/{n_sup}={acc:.1%}", flush=True, ) - - # Inference-side autoregressive output from the same prompt prefix. - inf_entry = inference[s] if s < len(inference) else None - if inf_entry: - inf_decoded = inf_entry.get("decoded", "") - print(f" inference (autoregressive) : {inf_decoded!r}", flush=True) - # First-token parity: training-side argmax at the prompt-end - # position MUST equal inference's first generated token — - # both compute argmax(lm_head(h_last_prompt)) on identical - # context. Any divergence signals a training↔inference bug. - if first_sup_pred is not None and inf_decoded and not inf_decoded.startswith("get`` 600 s); a + # 32 k-episode v3 dataset (e.g. ``robocasa_pretrain_human300_v4``) + # spends >13 min on rank 0 building the episode/frame index + # while ranks 1-N idle at ``wait_for_everyone()`` and crash with + # ``DistBackendError: ... wait timeout after 600000ms``. 2 h is + # plenty of headroom; fast paths are unaffected. + ipg_kwargs = InitProcessGroupKwargs(timeout=timedelta(hours=2)) # Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting. # Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training). force_cpu = cfg.trainable_config.device == "cpu" accelerator = Accelerator( step_scheduler_with_optimizer=False, - kwargs_handlers=[ddp_kwargs], + kwargs_handlers=[ddp_kwargs, ipg_kwargs], cpu=force_cpu, )