mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
fix(pi052): avoid dense CE over padded tokens
Select only supervised text and FAST action-code positions before cross-entropy to avoid full-vocabulary loss tensors over padded sequences.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -77,13 +77,13 @@ def _shifted_ce(logits: Tensor, labels: Tensor) -> Tensor:
|
|||||||
shift_logits = logits[:, :-1, :].contiguous()
|
shift_logits = logits[:, :-1, :].contiguous()
|
||||||
shift_labels = labels[:, 1:].contiguous().long()
|
shift_labels = labels[:, 1:].contiguous().long()
|
||||||
valid = shift_labels != -100
|
valid = shift_labels != -100
|
||||||
loss = F.cross_entropy(
|
if not bool(valid.any().item()):
|
||||||
shift_logits.reshape(-1, shift_logits.shape[-1]),
|
return shift_logits.sum() * 0.0
|
||||||
shift_labels.reshape(-1),
|
return F.cross_entropy(
|
||||||
ignore_index=-100,
|
shift_logits[valid],
|
||||||
reduction="sum",
|
shift_labels[valid],
|
||||||
|
reduction="mean",
|
||||||
)
|
)
|
||||||
return loss / valid.sum().clamp(min=1)
|
|
||||||
|
|
||||||
|
|
||||||
def _mark_target_span_causal(
|
def _mark_target_span_causal(
|
||||||
@@ -140,13 +140,13 @@ def _fast_ce(
|
|||||||
if predict_actions_t is not None:
|
if predict_actions_t is not None:
|
||||||
sample_mask = predict_actions_t[:, None].expand_as(shift_valid)
|
sample_mask = predict_actions_t[:, None].expand_as(shift_valid)
|
||||||
shift_valid = shift_valid & sample_mask
|
shift_valid = shift_valid & sample_mask
|
||||||
shift_targets = shift_targets.masked_fill(~shift_valid, -100)
|
if not bool(shift_valid.any().item()):
|
||||||
|
return shift_logits.sum() * 0.0
|
||||||
return F.cross_entropy(
|
return F.cross_entropy(
|
||||||
shift_logits.reshape(-1, shift_logits.shape[-1]),
|
shift_logits[shift_valid],
|
||||||
shift_targets.reshape(-1),
|
shift_targets[shift_valid],
|
||||||
ignore_index=-100,
|
reduction="mean",
|
||||||
reduction="sum",
|
)
|
||||||
) / shift_valid.sum().clamp(min=1)
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------
|
# ----------------------------------------------------------------------
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ import torch
|
|||||||
pytest.importorskip("transformers")
|
pytest.importorskip("transformers")
|
||||||
|
|
||||||
from lerobot.policies.pi05.modeling_pi05 import make_att_2d_masks # noqa: E402
|
from lerobot.policies.pi05.modeling_pi05 import make_att_2d_masks # noqa: E402
|
||||||
from lerobot.policies.pi052.modeling_pi052 import _mark_target_span_causal # noqa: E402
|
from lerobot.policies.pi052.modeling_pi052 import _mark_target_span_causal, _shifted_ce # noqa: E402
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# A synthetic PI052 prefix layout: [images, prompt-lang, target-lang]
|
# A synthetic PI052 prefix layout: [images, prompt-lang, target-lang]
|
||||||
@@ -136,3 +136,14 @@ def test_unmarked_mask_is_bidirectional_the_bug():
|
|||||||
"raw embed_prefix mask is bidirectional over language — the first "
|
"raw embed_prefix mask is bidirectional over language — the first "
|
||||||
"target token can see the last, which is the collapse bug"
|
"target token can see the last, which is the collapse bug"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_shifted_ce_returns_zero_when_no_text_positions_are_supervised():
|
||||||
|
logits = torch.randn(2, 4, 8, requires_grad=True)
|
||||||
|
labels = torch.full((2, 4), -100, dtype=torch.long)
|
||||||
|
|
||||||
|
loss = _shifted_ce(logits, labels)
|
||||||
|
|
||||||
|
assert loss.item() == 0
|
||||||
|
loss.backward()
|
||||||
|
assert logits.grad is not None
|
||||||
|
|||||||
@@ -73,3 +73,15 @@ def test_fast_ce_masks_non_action_samples():
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert torch.allclose(loss, expected)
|
assert torch.allclose(loss, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fast_ce_returns_zero_when_no_action_code_positions_are_valid():
|
||||||
|
logits = torch.randn(2, 4, 8, requires_grad=True)
|
||||||
|
action_tokens = torch.tensor([[1, 2, 3, 4], [1, 2, 5, 6]])
|
||||||
|
action_code_mask = torch.zeros_like(action_tokens, dtype=torch.bool)
|
||||||
|
|
||||||
|
loss = _fast_ce(logits, action_tokens, action_code_mask, predict_actions_t=None)
|
||||||
|
|
||||||
|
assert loss.item() == 0
|
||||||
|
loss.backward()
|
||||||
|
assert logits.grad is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user