diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index 7fbf7921c..b528eaca0 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -77,13 +77,13 @@ def _shifted_ce(logits: Tensor, labels: Tensor) -> Tensor: shift_logits = logits[:, :-1, :].contiguous() shift_labels = labels[:, 1:].contiguous().long() valid = shift_labels != -100 - loss = F.cross_entropy( - shift_logits.reshape(-1, shift_logits.shape[-1]), - shift_labels.reshape(-1), - ignore_index=-100, - reduction="sum", + if not bool(valid.any().item()): + return shift_logits.sum() * 0.0 + return F.cross_entropy( + shift_logits[valid], + shift_labels[valid], + reduction="mean", ) - return loss / valid.sum().clamp(min=1) def _mark_target_span_causal( @@ -140,13 +140,13 @@ def _fast_ce( if predict_actions_t is not None: sample_mask = predict_actions_t[:, None].expand_as(shift_valid) shift_valid = shift_valid & sample_mask - shift_targets = shift_targets.masked_fill(~shift_valid, -100) + if not bool(shift_valid.any().item()): + return shift_logits.sum() * 0.0 return F.cross_entropy( - shift_logits.reshape(-1, shift_logits.shape[-1]), - shift_targets.reshape(-1), - ignore_index=-100, - reduction="sum", - ) / shift_valid.sum().clamp(min=1) + shift_logits[shift_valid], + shift_targets[shift_valid], + reduction="mean", + ) # ---------------------------------------------------------------------- diff --git a/tests/policies/pi052/test_pi052_attention_masking.py b/tests/policies/pi052/test_pi052_attention_masking.py index 96ff4b479..5c74c5488 100644 --- a/tests/policies/pi052/test_pi052_attention_masking.py +++ b/tests/policies/pi052/test_pi052_attention_masking.py @@ -37,7 +37,7 @@ import torch pytest.importorskip("transformers") from lerobot.policies.pi05.modeling_pi05 import make_att_2d_masks # noqa: E402 -from lerobot.policies.pi052.modeling_pi052 import _mark_target_span_causal # noqa: E402 +from lerobot.policies.pi052.modeling_pi052 import _mark_target_span_causal, _shifted_ce # noqa: E402 # --------------------------------------------------------------------------- # A synthetic PI052 prefix layout: [images, prompt-lang, target-lang] @@ -136,3 +136,14 @@ def test_unmarked_mask_is_bidirectional_the_bug(): "raw embed_prefix mask is bidirectional over language — the first " "target token can see the last, which is the collapse bug" ) + + +def test_shifted_ce_returns_zero_when_no_text_positions_are_supervised(): + logits = torch.randn(2, 4, 8, requires_grad=True) + labels = torch.full((2, 4), -100, dtype=torch.long) + + loss = _shifted_ce(logits, labels) + + assert loss.item() == 0 + loss.backward() + assert logits.grad is not None diff --git a/tests/policies/pi052/test_pi052_fast_action_loss.py b/tests/policies/pi052/test_pi052_fast_action_loss.py index 9839db28c..c5575d6fd 100644 --- a/tests/policies/pi052/test_pi052_fast_action_loss.py +++ b/tests/policies/pi052/test_pi052_fast_action_loss.py @@ -73,3 +73,15 @@ def test_fast_ce_masks_non_action_samples(): ) assert torch.allclose(loss, expected) + + +def test_fast_ce_returns_zero_when_no_action_code_positions_are_valid(): + logits = torch.randn(2, 4, 8, requires_grad=True) + action_tokens = torch.tensor([[1, 2, 3, 4], [1, 2, 5, 6]]) + action_code_mask = torch.zeros_like(action_tokens, dtype=torch.bool) + + loss = _fast_ce(logits, action_tokens, action_code_mask, predict_actions_t=None) + + assert loss.item() == 0 + loss.backward() + assert logits.grad is not None