mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
fix(pi052): supervise only FAST action-code tokens
Mask the FAST auxiliary loss to discrete action-code tokens so wrapper formatting tokens do not affect action co-training. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -0,0 +1,75 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Regression tests for PI052 FAST action-code supervision."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.pi052.modeling_pi052 import _fast_ce # noqa: E402
|
||||
|
||||
|
||||
def test_fast_ce_supervises_only_discrete_action_codes():
|
||||
"""Wrapper tokens can be wrong without affecting the FAST action-code loss."""
|
||||
vocab_size = 8
|
||||
action_tokens = torch.tensor([[1, 2, 3, 4, 5, 0]])
|
||||
action_code_mask = torch.tensor([[False, False, True, True, False, False]])
|
||||
|
||||
logits = torch.zeros(1, action_tokens.shape[1], vocab_size)
|
||||
# Deliberately bad wrapper-token predictions. These should be ignored.
|
||||
logits[0, 0, 7] = 10.0 # target would be token 2
|
||||
logits[0, 3, 7] = 10.0 # target would be delimiter token 5
|
||||
# Correct action-code predictions: hidden t predicts target t + 1.
|
||||
logits[0, 1, 3] = 10.0
|
||||
logits[0, 2, 4] = 10.0
|
||||
|
||||
loss = _fast_ce(logits, action_tokens, action_code_mask, predict_actions_t=None)
|
||||
expected = F.cross_entropy(
|
||||
torch.stack([logits[0, 1], logits[0, 2]]),
|
||||
torch.tensor([3, 4]),
|
||||
reduction="mean",
|
||||
)
|
||||
|
||||
assert torch.allclose(loss, expected)
|
||||
|
||||
|
||||
def test_fast_ce_masks_non_action_samples():
|
||||
"""Recipe samples with predict_actions=False do not contribute FAST loss."""
|
||||
vocab_size = 8
|
||||
action_tokens = torch.tensor([[1, 2, 3, 4], [1, 2, 5, 6]])
|
||||
action_code_mask = torch.tensor(
|
||||
[[False, False, True, True], [False, False, True, True]]
|
||||
)
|
||||
predict_actions = torch.tensor([True, False])
|
||||
|
||||
logits = torch.zeros(2, action_tokens.shape[1], vocab_size)
|
||||
logits[0, 1, 3] = 10.0
|
||||
logits[0, 2, 4] = 10.0
|
||||
# Bad predictions in the masked sample should not matter.
|
||||
logits[1, 1, 7] = 10.0
|
||||
logits[1, 2, 7] = 10.0
|
||||
|
||||
loss = _fast_ce(logits, action_tokens, action_code_mask, predict_actions)
|
||||
expected = F.cross_entropy(
|
||||
torch.stack([logits[0, 1], logits[0, 2]]),
|
||||
torch.tensor([3, 4]),
|
||||
reduction="mean",
|
||||
)
|
||||
|
||||
assert torch.allclose(loss, expected)
|
||||
Reference in New Issue
Block a user