fix(pi052): decouple flow suffix RoPE positions from the FAST block

At training time the prefix is [images, language, FAST], so the flow action
suffix got position_ids offset by n_fast (33-111, varying per sample). At
inference there is no FAST block, so the suffix lands ~n_fast positions
earlier. Since the action expert uses RoPE, this shifts the flow->prefix
relative positions between training and deployment, corrupting the
conditioning and collapsing the predicted action distribution (pi052 is at
~0%, while pi05, which has no FAST block in its prefix, works). Offset the
flow suffix by the valid image+language token count only (excluding FAST) in
both _combined_prefix_and_flow and _amortized_prefix_and_flow, so that
training positions == inference positions.

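Also: recipe blend weights 0.30/0.55 -> 0.25/0.60 (to match the trained mix),
an env-gated EVAL_TASK_OVERRIDE diagnostic in lerobot_eval, and two env-gated
policy diagnostics: PI052_DEBUG_TENSORS=1 logs predicted-action, prompt/state
and raw/processed image statistics once each, and PI052_SUBTASK_USE_TASK=1
conditions the action expert on the raw task text instead of the generated
subtask.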

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn223
2026-06-29 16:40:55 +02:00
parent e1cf646e84
commit d099ac91b3
2 changed files with 102 additions and 3 deletions
+90 -3
View File
@@ -562,6 +562,20 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
import os as _os # noqa: PLC0415
if _os.environ.get("PI052_DEBUG_TENSORS") == "1" and not getattr(self, "_dbg_act_done", False):
import logging as _lg # noqa: PLC0415
_a = x_t.float()
ad = self.config.max_action_dim
_lg.getLogger(__name__).info(
"PI052_DEBUG predicted norm action chunk shape=%s min=%.3f max=%.3f mean=%.3f std=%.3f (real dims only) (expect ~[-1,1])",
tuple(x_t.shape), _a[..., :12].min().item(), _a[..., :12].max().item(),
_a[..., :12].mean().item(), _a[..., :12].std().item(),
)
self._dbg_act_done = True
return x_t
def denoise_step(
@@ -1382,6 +1396,17 @@ class PI052Policy(PreTrainedPolicy):
att_2d_masks[:, fast_end:, fast_start:fast_end] = False
position_ids = torch.cumsum(pad_masks, dim=1) - 1
if fast_len > 0:
# The flow suffix is a PARALLEL action representation to the FAST
# block, not a continuation of it (the two never attend to each
# other). At inference there is no FAST block, so the suffix RoPE
# positions start at the valid image+language count. Match that here
# so flow->prefix relative positions are train==inference; otherwise
# the suffix is offset by n_fast (per-sample, 33-111) and the trained
# head reads the wrong RoPE conditioning at deploy time.
non_fast_valid = prefix_pad[:, :non_fast_prefix_len].sum(dim=1, keepdim=True)
suffix_pos = non_fast_valid + torch.cumsum(suffix_pad, dim=1) - 1
position_ids = torch.cat([position_ids[:, : prefix_pad.shape[1]], suffix_pos], dim=1)
att_2d_masks_4d = self.model._prepare_attention_masks_4d(
att_2d_masks, dtype=prefix_embs.dtype
)
@@ -1492,8 +1517,14 @@ class PI052Policy(PreTrainedPolicy):
att_2d_4d = model._prepare_attention_masks_4d(att_2d, dtype=prefix_embs.dtype)
# Positions: prefix as usual; every block restarts at the prefix offset
# (each block is an independent denoising of the same chunk).
prefix_offsets = torch.sum(prefix_pad, dim=-1)[:, None]
# (each block is an independent denoising of the same chunk). The flow
# blocks are PARALLEL to the FAST block, not a continuation, so offset by
# the valid image+language count (excluding FAST) — matching inference
# (no FAST block) so flow->prefix RoPE positions are train==inference.
if fast_len > 0:
prefix_offsets = prefix_pad[:, :non_fast_prefix_len].sum(dim=-1)[:, None]
else:
prefix_offsets = torch.sum(prefix_pad, dim=-1)[:, None]
block_positions = prefix_offsets + torch.cumsum(suffix_pad, dim=1) - 1 # (B, chunk)
position_ids = torch.cat([torch.cumsum(prefix_pad, dim=1) - 1, block_positions.repeat(1, k)], dim=1)
@@ -1984,10 +2015,20 @@ class PI052Policy(PreTrainedPolicy):
# own task + observation, then stack the per-env prompts into a single
# (n, L) batch for the action expert. This keeps batch_size > 1 correct
# (env i is conditioned on env i's subtask, not a broadcast of env 0).
# Diagnostic toggle (PI052_SUBTASK_USE_TASK=1): skip the learned subtask
# generator and condition the action expert on the raw task text. Isolates
# whether the generator is the eval bottleneck — eval-only, off by default.
import os # noqa: PLC0415
use_task_directly = os.environ.get("PI052_SUBTASK_USE_TASK") == "1"
rows: list[tuple[Tensor, Tensor | None]] = []
tokenizer = None
for i in range(n):
if regenerate or not self.last_subtasks[i]:
if use_task_directly:
subtask = tasks[i]
self.last_subtasks[i] = subtask
elif regenerate or not self.last_subtasks[i]:
obs_i = self._slice_observation(batch, i)
subtask = self._generate_low_level_subtask(obs_i, tasks[i], i)
else:
@@ -2002,6 +2043,27 @@ class PI052Policy(PreTrainedPolicy):
[{"role": "user", "content": content}],
add_generation_prompt=False,
)
if (
os.environ.get("PI052_DEBUG_TENSORS") == "1"
and i == 0
and not getattr(self, "_dbg_prompt_done", False)
):
import logging as _lg # noqa: PLC0415
_tok = text_batch["tokenizer"]
_ids = text_batch["lang_tokens"][0]
_decoded = _tok.decode(_ids.tolist())
_log = _lg.getLogger(__name__)
_log.info("PI052_DEBUG eval low-level content[0]: %r", content)
_log.info("PI052_DEBUG eval decoded prompt[0]: %r", _decoded)
if torch.is_tensor(state_all):
_s = state_all[i].float()
_log.info(
"PI052_DEBUG eval norm state[0]: min=%.3f max=%.3f mean=%.3f | digits=%s",
_s.min().item(), _s.max().item(), _s.mean().item(),
discretize_state_str(state_all[i]),
)
self._dbg_prompt_done = True
rows.append((text_batch["lang_tokens"], text_batch["lang_masks"]))
tokenizer = text_batch["tokenizer"]
@@ -2496,10 +2558,25 @@ class PI052Policy(PreTrainedPolicy):
f"(batch: {batch.keys()}) (image_features: {self.config.image_features})"
)
# Diagnostic (PI052_DEBUG_TENSORS=1): dump raw + processed image stats
# once, to compare the eval env's image pipeline against training.
import os as _os # noqa: PLC0415
_dbg = _os.environ.get("PI052_DEBUG_TENSORS") == "1" and not getattr(self, "_dbg_img_done", False)
# Preprocess image features present in the batch
for key in present_img_keys:
img = batch[key]
if _dbg and key == present_img_keys[0]:
import logging as _lg # noqa: PLC0415
_r = img.float()
_lg.getLogger(__name__).info(
"PI052_DEBUG raw img[%s] shape=%s dtype=%s min=%.3f max=%.3f mean=%.3f",
key, tuple(img.shape), str(img.dtype), _r.min().item(), _r.max().item(), _r.mean().item(),
)
# Ensure tensor is on the same device as the model
if img.device != device:
img = img.to(device)
@@ -2522,6 +2599,16 @@ class PI052Policy(PreTrainedPolicy):
# Normalize from [0,1] to [-1,1] as expected by siglip
img = img * 2.0 - 1.0
if _dbg and key == present_img_keys[0]:
import logging as _lg # noqa: PLC0415
_p = img.float()
_lg.getLogger(__name__).info(
"PI052_DEBUG processed img[%s] shape=%s min=%.3f max=%.3f mean=%.3f (expect ~[-1,1])",
key, tuple(img.shape), _p.min().item(), _p.max().item(), _p.mean().item(),
)
self._dbg_img_done = True
# from openpi preprocess_observation_pytorch: Convert back to [B, C, H, W] format if it was originally channels-first
if is_channels_first:
img = img.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
+12
View File
@@ -52,6 +52,7 @@ You can learn about the CLI options for this script in the `EvalPipelineConfig`
import concurrent.futures as cf
import json
import logging
import os
import threading
import time
from collections import defaultdict
@@ -239,6 +240,17 @@ def rollout(
except (AttributeError, NotImplementedError):
observation["task"] = [""] * env.num_envs
# Diagnostic (EVAL_TASK_OVERRIDE): replace the env task string with a
# fixed hand-written instruction for every env. Isolates whether the
# action head can execute a given phrasing, independent of the env's
# own description. Logs the original once for comparison.
_task_override = os.environ.get("EVAL_TASK_OVERRIDE")
if _task_override:
if step == 0:
logging.info("EVAL_TASK_OVERRIDE active: env task[0]=%r -> %r",
observation["task"][0], _task_override)
observation["task"] = [_task_override] * env.num_envs
# Apply environment-specific preprocessing (e.g., LiberoProcessorStep for LIBERO)
observation = env_preprocessor(observation)