mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
fix(pi052): supervise only FAST action-code tokens
Mask the FAST auxiliary loss to discrete action-code tokens so wrapper formatting tokens do not affect action co-training.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -124,19 +124,19 @@ def _mark_target_span_causal(
|
|||||||
def _fast_ce(
|
def _fast_ce(
|
||||||
fast_logits: Tensor,
|
fast_logits: Tensor,
|
||||||
action_tokens: Tensor,
|
action_tokens: Tensor,
|
||||||
action_mask: Tensor,
|
action_code_mask: Tensor,
|
||||||
predict_actions_t: Tensor | None,
|
predict_actions_t: Tensor | None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""FAST-CE with both token-pad masking and per-sample action gating.
|
"""FAST action-code CE with token-span masking and per-sample action gating.
|
||||||
|
|
||||||
``action_mask`` is the FAST tokenizer's padding mask; samples whose
|
``action_code_mask`` is true only on the discrete action-code tokens,
|
||||||
recipe sets ``predict_actions=False`` (e.g. plan_generation,
|
excluding the BOS / "Action: " / delimiter wrapper. Samples whose
|
||||||
ask_vqa_*) get *all* their FAST positions masked out via the
|
recipe sets ``predict_actions=False`` get all code positions masked
|
||||||
per-sample gate.
|
out via the per-sample gate.
|
||||||
"""
|
"""
|
||||||
shift_logits = fast_logits[:, :-1, :].contiguous()
|
shift_logits = fast_logits[:, :-1, :].contiguous()
|
||||||
shift_targets = action_tokens[:, 1:].contiguous().long()
|
shift_targets = action_tokens[:, 1:].contiguous().long()
|
||||||
shift_valid = action_mask[:, 1:].contiguous().bool()
|
shift_valid = action_code_mask[:, 1:].contiguous().bool()
|
||||||
if predict_actions_t is not None:
|
if predict_actions_t is not None:
|
||||||
sample_mask = predict_actions_t[:, None].expand_as(shift_valid)
|
sample_mask = predict_actions_t[:, None].expand_as(shift_valid)
|
||||||
shift_valid = shift_valid & sample_mask
|
shift_valid = shift_valid & sample_mask
|
||||||
@@ -429,13 +429,18 @@ class PI052Policy(PI05Policy):
|
|||||||
and self.config.fast_action_loss_weight > 0
|
and self.config.fast_action_loss_weight > 0
|
||||||
and (predict_actions_t is None or bool(predict_actions_t.any().item()))
|
and (predict_actions_t is None or bool(predict_actions_t.any().item()))
|
||||||
)
|
)
|
||||||
action_tokens = action_mask = None
|
action_tokens = action_mask = action_code_mask = None
|
||||||
if run_fast:
|
if run_fast:
|
||||||
from lerobot.utils.constants import ACTION_TOKEN_MASK, ACTION_TOKENS # noqa: PLC0415
|
from lerobot.utils.constants import ( # noqa: PLC0415
|
||||||
|
ACTION_CODE_TOKEN_MASK,
|
||||||
|
ACTION_TOKEN_MASK,
|
||||||
|
ACTION_TOKENS,
|
||||||
|
)
|
||||||
|
|
||||||
action_tokens = batch.get(ACTION_TOKENS)
|
action_tokens = batch.get(ACTION_TOKENS)
|
||||||
action_mask = batch.get(ACTION_TOKEN_MASK)
|
action_mask = batch.get(ACTION_TOKEN_MASK)
|
||||||
if action_tokens is None or action_mask is None:
|
action_code_mask = batch.get(ACTION_CODE_TOKEN_MASK)
|
||||||
|
if action_tokens is None or action_mask is None or action_code_mask is None:
|
||||||
run_fast = False
|
run_fast = False
|
||||||
|
|
||||||
# ------------------------------------------------------------
|
# ------------------------------------------------------------
|
||||||
@@ -457,6 +462,7 @@ class PI052Policy(PI05Policy):
|
|||||||
text_labels=text_labels if run_text else None,
|
text_labels=text_labels if run_text else None,
|
||||||
action_tokens=action_tokens if run_fast else None,
|
action_tokens=action_tokens if run_fast else None,
|
||||||
action_mask=action_mask if run_fast else None,
|
action_mask=action_mask if run_fast else None,
|
||||||
|
action_code_mask=action_code_mask if run_fast else None,
|
||||||
predict_actions_t=predict_actions_t,
|
predict_actions_t=predict_actions_t,
|
||||||
)
|
)
|
||||||
loss_dict["flow_loss"] = float(flow_loss.detach().item())
|
loss_dict["flow_loss"] = float(flow_loss.detach().item())
|
||||||
@@ -473,6 +479,7 @@ class PI052Policy(PI05Policy):
|
|||||||
text_labels=text_labels if run_text else None,
|
text_labels=text_labels if run_text else None,
|
||||||
action_tokens=action_tokens if run_fast else None,
|
action_tokens=action_tokens if run_fast else None,
|
||||||
action_mask=action_mask if run_fast else None,
|
action_mask=action_mask if run_fast else None,
|
||||||
|
action_code_mask=action_code_mask if run_fast else None,
|
||||||
predict_actions_t=predict_actions_t,
|
predict_actions_t=predict_actions_t,
|
||||||
)
|
)
|
||||||
if text_loss is not None:
|
if text_loss is not None:
|
||||||
@@ -508,6 +515,7 @@ class PI052Policy(PI05Policy):
|
|||||||
text_labels: Tensor | None,
|
text_labels: Tensor | None,
|
||||||
action_tokens: Tensor | None,
|
action_tokens: Tensor | None,
|
||||||
action_mask: Tensor | None,
|
action_mask: Tensor | None,
|
||||||
|
action_code_mask: Tensor | None,
|
||||||
predict_actions_t: Tensor | None = None,
|
predict_actions_t: Tensor | None = None,
|
||||||
) -> tuple[Tensor, Tensor | None, Tensor | None]:
|
) -> tuple[Tensor, Tensor | None, Tensor | None]:
|
||||||
"""Full fusion: flow + text + FAST in ONE backbone forward.
|
"""Full fusion: flow + text + FAST in ONE backbone forward.
|
||||||
@@ -647,10 +655,10 @@ class PI052Policy(PI05Policy):
|
|||||||
text_loss = _shifted_ce(text_logits, text_labels)
|
text_loss = _shifted_ce(text_logits, text_labels)
|
||||||
|
|
||||||
fast_loss: Tensor | None = None
|
fast_loss: Tensor | None = None
|
||||||
if fast_len > 0 and prefix_out is not None:
|
if fast_len > 0 and prefix_out is not None and action_code_mask is not None:
|
||||||
fast_hidden = prefix_out[:, -fast_len:, :]
|
fast_hidden = prefix_out[:, -fast_len:, :]
|
||||||
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))
|
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))
|
||||||
fast_loss = _fast_ce(fast_logits, action_tokens, action_mask, predict_actions_t)
|
fast_loss = _fast_ce(fast_logits, action_tokens, action_code_mask, predict_actions_t)
|
||||||
|
|
||||||
return flow_loss, text_loss, fast_loss
|
return flow_loss, text_loss, fast_loss
|
||||||
|
|
||||||
@@ -660,6 +668,7 @@ class PI052Policy(PI05Policy):
|
|||||||
text_labels: Tensor | None,
|
text_labels: Tensor | None,
|
||||||
action_tokens: Tensor | None,
|
action_tokens: Tensor | None,
|
||||||
action_mask: Tensor | None,
|
action_mask: Tensor | None,
|
||||||
|
action_code_mask: Tensor | None,
|
||||||
predict_actions_t: Tensor | None = None,
|
predict_actions_t: Tensor | None = None,
|
||||||
) -> tuple[Tensor | None, Tensor | None]:
|
) -> tuple[Tensor | None, Tensor | None]:
|
||||||
"""Single prefix forward → text CE + FAST CE.
|
"""Single prefix forward → text CE + FAST CE.
|
||||||
@@ -745,10 +754,14 @@ class PI052Policy(PI05Policy):
|
|||||||
text_loss = _shifted_ce(text_logits, text_labels)
|
text_loss = _shifted_ce(text_logits, text_labels)
|
||||||
|
|
||||||
fast_loss: Tensor | None = None
|
fast_loss: Tensor | None = None
|
||||||
if action_tokens is not None and action_mask is not None and fast_len > 0:
|
if (
|
||||||
|
action_tokens is not None
|
||||||
|
and action_code_mask is not None
|
||||||
|
and fast_len > 0
|
||||||
|
):
|
||||||
fast_hidden = vlm_out[:, -fast_len:, :]
|
fast_hidden = vlm_out[:, -fast_len:, :]
|
||||||
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))
|
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))
|
||||||
fast_loss = _fast_ce(fast_logits, action_tokens, action_mask, predict_actions_t)
|
fast_loss = _fast_ce(fast_logits, action_tokens, action_code_mask, predict_actions_t)
|
||||||
|
|
||||||
return text_loss, fast_loss
|
return text_loss, fast_loss
|
||||||
|
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ import torch
|
|||||||
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||||
from lerobot.types import EnvTransition, RobotObservation, TransitionKey
|
from lerobot.types import EnvTransition, RobotObservation, TransitionKey
|
||||||
from lerobot.utils.constants import (
|
from lerobot.utils.constants import (
|
||||||
|
ACTION_CODE_TOKEN_MASK,
|
||||||
ACTION_TOKEN_MASK,
|
ACTION_TOKEN_MASK,
|
||||||
ACTION_TOKENS,
|
ACTION_TOKENS,
|
||||||
OBS_LANGUAGE_ATTENTION_MASK,
|
OBS_LANGUAGE_ATTENTION_MASK,
|
||||||
@@ -412,14 +413,15 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
|||||||
# During inference, no action is available, skip tokenization
|
# During inference, no action is available, skip tokenization
|
||||||
return new_transition
|
return new_transition
|
||||||
|
|
||||||
# Tokenize and get both tokens and mask
|
# Tokenize and get masks for the full formatted sequence and the discrete action codes.
|
||||||
tokens, mask = self._tokenize_action(action)
|
tokens, mask, code_mask = self._tokenize_action(action)
|
||||||
|
|
||||||
# Store mask in complementary data
|
# Store mask in complementary data
|
||||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||||
if complementary_data is None:
|
if complementary_data is None:
|
||||||
complementary_data = {}
|
complementary_data = {}
|
||||||
complementary_data[ACTION_TOKEN_MASK] = mask
|
complementary_data[ACTION_TOKEN_MASK] = mask
|
||||||
|
complementary_data[ACTION_CODE_TOKEN_MASK] = code_mask
|
||||||
complementary_data[ACTION_TOKENS] = tokens
|
complementary_data[ACTION_TOKENS] = tokens
|
||||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||||
return new_transition
|
return new_transition
|
||||||
@@ -430,7 +432,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
|||||||
"""
|
"""
|
||||||
return self._paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
|
return self._paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
|
||||||
|
|
||||||
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Tokenizes the action tensor and creates a mask.
|
Tokenizes the action tensor and creates a mask.
|
||||||
|
|
||||||
@@ -459,6 +461,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
|||||||
# The fast tokenizer expects action data and returns token IDs
|
# The fast tokenizer expects action data and returns token IDs
|
||||||
tokens_list = []
|
tokens_list = []
|
||||||
masks_list = []
|
masks_list = []
|
||||||
|
code_masks_list = []
|
||||||
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
# Tokenize single action (move to CPU first as tokenizer uses scipy which requires numpy)
|
# Tokenize single action (move to CPU first as tokenizer uses scipy which requires numpy)
|
||||||
@@ -476,19 +479,26 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
|||||||
if tokens.dim() > 1:
|
if tokens.dim() > 1:
|
||||||
tokens = tokens.flatten()
|
tokens = tokens.flatten()
|
||||||
|
|
||||||
|
action_code_tokens = self._act_tokens_to_paligemma_tokens(tokens)
|
||||||
bos_id = self._paligemma_tokenizer.bos_token_id
|
bos_id = self._paligemma_tokenizer.bos_token_id
|
||||||
# add bos
|
prompt_tokens = torch.tensor(
|
||||||
|
self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False),
|
||||||
|
device=action.device,
|
||||||
|
)
|
||||||
|
end_tokens = torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device)
|
||||||
|
|
||||||
|
code_start = 1 + len(prompt_tokens)
|
||||||
|
code_end = code_start + len(action_code_tokens)
|
||||||
tokens = torch.cat(
|
tokens = torch.cat(
|
||||||
[
|
[
|
||||||
torch.tensor([bos_id], device=action.device),
|
torch.tensor([bos_id], device=action.device),
|
||||||
torch.tensor(
|
prompt_tokens,
|
||||||
self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False),
|
action_code_tokens,
|
||||||
device=action.device,
|
end_tokens,
|
||||||
),
|
|
||||||
self._act_tokens_to_paligemma_tokens(tokens),
|
|
||||||
torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device),
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
code_mask = torch.zeros(len(tokens), dtype=torch.bool, device=action.device)
|
||||||
|
code_mask[code_start:code_end] = True
|
||||||
|
|
||||||
# Truncate or pad to max_action_tokens
|
# Truncate or pad to max_action_tokens
|
||||||
if len(tokens) > self.max_action_tokens:
|
if len(tokens) > self.max_action_tokens:
|
||||||
@@ -497,44 +507,49 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
|||||||
"Consider increasing the `max_action_tokens` in your model config if this happens frequently."
|
"Consider increasing the `max_action_tokens` in your model config if this happens frequently."
|
||||||
)
|
)
|
||||||
tokens = tokens[: self.max_action_tokens]
|
tokens = tokens[: self.max_action_tokens]
|
||||||
|
code_mask = code_mask[: self.max_action_tokens]
|
||||||
mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device)
|
mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device)
|
||||||
else:
|
else:
|
||||||
|
pad_len = self.max_action_tokens - len(tokens)
|
||||||
mask = torch.cat(
|
mask = torch.cat(
|
||||||
[
|
[
|
||||||
torch.ones(len(tokens), dtype=torch.bool, device=action.device),
|
torch.ones(len(tokens), dtype=torch.bool, device=action.device),
|
||||||
torch.zeros(
|
torch.zeros(pad_len, dtype=torch.bool, device=action.device),
|
||||||
self.max_action_tokens - len(tokens), dtype=torch.bool, device=action.device
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
code_mask = torch.nn.functional.pad(code_mask, (0, pad_len), value=False)
|
||||||
# Pad tokens with zeros
|
# Pad tokens with zeros
|
||||||
tokens = torch.nn.functional.pad(tokens, (0, self.max_action_tokens - len(tokens)), value=0)
|
tokens = torch.nn.functional.pad(tokens, (0, pad_len), value=0)
|
||||||
|
|
||||||
tokens_list.append(tokens)
|
tokens_list.append(tokens)
|
||||||
masks_list.append(mask)
|
masks_list.append(mask)
|
||||||
|
code_masks_list.append(code_mask)
|
||||||
|
|
||||||
# Stack into batched tensors
|
# Stack into batched tensors
|
||||||
tokens_batch = torch.stack(tokens_list, dim=0) # (B, max_action_tokens)
|
tokens_batch = torch.stack(tokens_list, dim=0) # (B, max_action_tokens)
|
||||||
masks_batch = torch.stack(masks_list, dim=0) # (B, max_action_tokens)
|
masks_batch = torch.stack(masks_list, dim=0) # (B, max_action_tokens)
|
||||||
|
code_masks_batch = torch.stack(code_masks_list, dim=0) # (B, max_action_tokens)
|
||||||
|
|
||||||
# Remove batch dimension if input was single sample
|
# Remove batch dimension if input was single sample
|
||||||
if single_sample:
|
if single_sample:
|
||||||
tokens_batch = tokens_batch.squeeze(0)
|
tokens_batch = tokens_batch.squeeze(0)
|
||||||
masks_batch = masks_batch.squeeze(0)
|
masks_batch = masks_batch.squeeze(0)
|
||||||
|
code_masks_batch = code_masks_batch.squeeze(0)
|
||||||
|
|
||||||
# Move to the same device as the input
|
# Move to the same device as the input
|
||||||
if device is not None:
|
if device is not None:
|
||||||
tokens_batch = tokens_batch.to(device)
|
tokens_batch = tokens_batch.to(device)
|
||||||
masks_batch = masks_batch.to(device)
|
masks_batch = masks_batch.to(device)
|
||||||
|
code_masks_batch = code_masks_batch.to(device)
|
||||||
|
|
||||||
return tokens_batch, masks_batch
|
return tokens_batch, masks_batch, code_masks_batch
|
||||||
|
|
||||||
def action(self, action: torch.Tensor) -> torch.Tensor:
|
def action(self, action: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
This method is not used since we override __call__.
|
This method is not used since we override __call__.
|
||||||
Required by ActionProcessorStep ABC.
|
Required by ActionProcessorStep ABC.
|
||||||
"""
|
"""
|
||||||
tokens, _ = self._tokenize_action(action)
|
tokens, _, _ = self._tokenize_action(action)
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ ACTION = "action"
|
|||||||
ACTION_PREFIX = ACTION + "."
|
ACTION_PREFIX = ACTION + "."
|
||||||
ACTION_TOKENS = ACTION + ".tokens"
|
ACTION_TOKENS = ACTION + ".tokens"
|
||||||
ACTION_TOKEN_MASK = ACTION + ".token_mask"
|
ACTION_TOKEN_MASK = ACTION + ".token_mask"
|
||||||
|
ACTION_CODE_TOKEN_MASK = ACTION + ".code_token_mask"
|
||||||
REWARD = "next.reward"
|
REWARD = "next.reward"
|
||||||
TRUNCATED = "next.truncated"
|
TRUNCATED = "next.truncated"
|
||||||
DONE = "next.done"
|
DONE = "next.done"
|
||||||
|
|||||||
@@ -0,0 +1,75 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Regression tests for PI052 FAST action-code supervision."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
pytest.importorskip("transformers")
|
||||||
|
|
||||||
|
from lerobot.policies.pi052.modeling_pi052 import _fast_ce # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
def test_fast_ce_supervises_only_discrete_action_codes():
|
||||||
|
"""Wrapper tokens can be wrong without affecting the FAST action-code loss."""
|
||||||
|
vocab_size = 8
|
||||||
|
action_tokens = torch.tensor([[1, 2, 3, 4, 5, 0]])
|
||||||
|
action_code_mask = torch.tensor([[False, False, True, True, False, False]])
|
||||||
|
|
||||||
|
logits = torch.zeros(1, action_tokens.shape[1], vocab_size)
|
||||||
|
# Deliberately bad wrapper-token predictions. These should be ignored.
|
||||||
|
logits[0, 0, 7] = 10.0 # target would be token 2
|
||||||
|
logits[0, 3, 7] = 10.0 # target would be delimiter token 5
|
||||||
|
# Correct action-code predictions: hidden t predicts target t + 1.
|
||||||
|
logits[0, 1, 3] = 10.0
|
||||||
|
logits[0, 2, 4] = 10.0
|
||||||
|
|
||||||
|
loss = _fast_ce(logits, action_tokens, action_code_mask, predict_actions_t=None)
|
||||||
|
expected = F.cross_entropy(
|
||||||
|
torch.stack([logits[0, 1], logits[0, 2]]),
|
||||||
|
torch.tensor([3, 4]),
|
||||||
|
reduction="mean",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert torch.allclose(loss, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fast_ce_masks_non_action_samples():
|
||||||
|
"""Recipe samples with predict_actions=False do not contribute FAST loss."""
|
||||||
|
vocab_size = 8
|
||||||
|
action_tokens = torch.tensor([[1, 2, 3, 4], [1, 2, 5, 6]])
|
||||||
|
action_code_mask = torch.tensor(
|
||||||
|
[[False, False, True, True], [False, False, True, True]]
|
||||||
|
)
|
||||||
|
predict_actions = torch.tensor([True, False])
|
||||||
|
|
||||||
|
logits = torch.zeros(2, action_tokens.shape[1], vocab_size)
|
||||||
|
logits[0, 1, 3] = 10.0
|
||||||
|
logits[0, 2, 4] = 10.0
|
||||||
|
# Bad predictions in the masked sample should not matter.
|
||||||
|
logits[1, 1, 7] = 10.0
|
||||||
|
logits[1, 2, 7] = 10.0
|
||||||
|
|
||||||
|
loss = _fast_ce(logits, action_tokens, action_code_mask, predict_actions)
|
||||||
|
expected = F.cross_entropy(
|
||||||
|
torch.stack([logits[0, 1], logits[0, 2]]),
|
||||||
|
torch.tensor([3, 4]),
|
||||||
|
reduction="mean",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert torch.allclose(loss, expected)
|
||||||
Reference in New Issue
Block a user