fix(pi052): avoid dense CE over padded tokens

Select only supervised text and FAST action-code positions before cross-entropy to avoid full-vocabulary loss tensors over padded sequences.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-05-18 18:40:34 +00:00
parent 7960cc14ec
commit 22c9c4905e
3 changed files with 36 additions and 13 deletions
@@ -37,7 +37,7 @@ import torch
pytest.importorskip("transformers")
from lerobot.policies.pi05.modeling_pi05 import make_att_2d_masks # noqa: E402
from lerobot.policies.pi052.modeling_pi052 import _mark_target_span_causal # noqa: E402
from lerobot.policies.pi052.modeling_pi052 import _mark_target_span_causal, _shifted_ce # noqa: E402
# ---------------------------------------------------------------------------
# A synthetic PI052 prefix layout: [images, prompt-lang, target-lang]
@@ -136,3 +136,14 @@ def test_unmarked_mask_is_bidirectional_the_bug():
"raw embed_prefix mask is bidirectional over language — the first "
"target token can see the last, which is the collapse bug"
)
def test_shifted_ce_returns_zero_when_no_text_positions_are_supervised():
logits = torch.randn(2, 4, 8, requires_grad=True)
labels = torch.full((2, 4), -100, dtype=torch.long)
loss = _shifted_ce(logits, labels)
assert loss.item() == 0
loss.backward()
assert logits.grad is not None
@@ -73,3 +73,15 @@ def test_fast_ce_masks_non_action_samples():
)
assert torch.allclose(loss, expected)
def test_fast_ce_returns_zero_when_no_action_code_positions_are_valid():
logits = torch.randn(2, 4, 8, requires_grad=True)
action_tokens = torch.tensor([[1, 2, 3, 4], [1, 2, 5, 6]])
action_code_mask = torch.zeros_like(action_tokens, dtype=torch.bool)
loss = _fast_ce(logits, action_tokens, action_code_mask, predict_actions_t=None)
assert loss.item() == 0
loss.backward()
assert logits.grad is not None