mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
feat(pi052): FAST action CE loss + knowledge insulation + processor wiring
Three additions ported from ``pi05_full`` on branch ``feat/add-pi05``,
giving pi052 full paper-§III.B-C training capabilities alongside the
recipe-driven text supervision it already had:
* **Config flags** in PI052Config:
- ``enable_fast_action_loss`` default False
- ``action_tokenizer_name`` default "physical-intelligence/fast"
- ``max_action_tokens`` default 256
- ``fast_skip_tokens`` default 128
- ``fast_action_loss_weight`` default 1.0
- ``knowledge_insulation`` default False
* **Processor wiring** (processor_pi052.py): when
``enable_fast_action_loss=True``, append an
``ActionTokenizerProcessorStep`` after the text tokenizer. It
tokenises the action tensor with the FAST tokenizer and writes
ACTION_TOKENS / ACTION_TOKEN_MASK into ``COMPLEMENTARY_DATA`` —
the existing batch-collation pipeline forwards them as
``batch['action.tokens']`` / ``batch['action.token_mask']``.
* **FAST CE loss** (modeling_pi052.py::_compute_fast_action_loss):
Re-embeds the prefix [images, language], appends the FAST token
embeddings (using PaliGemma's shared embed_language_tokens),
forwards through the backbone, slices the trailing
``fast_len`` positions, applies the LM head, computes shifted
next-token CE with the action-mask gating the loss. The loss is
summed into ``forward()``'s total with ``fast_action_loss_weight``.
* **Knowledge insulation** (modeling_pi052.py::_compute_layer_ki +
_paligemma_forward_ki): port of pi05_full's per-layer attention
that detaches VLM K/V on the action-query path so action loss
gradients cannot flow back into the VLM's K/V projections. Bound
per-instance via ``types.MethodType`` so it doesn't leak into
stock ``pi05`` policies that share PaliGemmaWithExpertModel.
Activated automatically when ``config.knowledge_insulation=True``.
Combined with the existing recipe-driven text head, pi052 now
supports the full three-loss objective:
L = text_w·H(text) + fast_w·H(FAST actions) + flow_w·MSE(flow)
matching Eq. (1) of arxiv:2504.16054 §IV.D (α=10 by default for the
flow term, 1.0 each for text and FAST CE).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -99,6 +99,50 @@ class PI052Config(PI05Config):
|
|||||||
memory_dropout_prob: float = 0.0
|
memory_dropout_prob: float = 0.0
|
||||||
subtask_dropout_prob: float = 0.0
|
subtask_dropout_prob: float = 0.0
|
||||||
|
|
||||||
|
# FAST discrete-action supervision — paper §III.B-C ------------------
|
||||||
|
# When enabled, actions are *also* tokenised via the FAST tokenizer
|
||||||
|
# ("physical-intelligence/fast") and supervised with cross-entropy
|
||||||
|
# on the PaliGemma LM head — exactly as in the paper's pre-training
|
||||||
|
# objective (Eq. 1 mixes FAST CE + flow MSE + subtask CE). The
|
||||||
|
# ActionTokenizerProcessorStep is wired into the preprocessor
|
||||||
|
# pipeline when this flag is set; the loss is computed in
|
||||||
|
# PI052Policy.forward.
|
||||||
|
enable_fast_action_loss: bool = False
|
||||||
|
"""If True, tokenise actions with the FAST tokenizer and add a
|
||||||
|
cross-entropy loss on the LM head. Off by default because most
|
||||||
|
fine-tuning runs only need the flow head + text supervision; the
|
||||||
|
FAST CE term is most useful when training from a base PaliGemma
|
||||||
|
rather than an existing π0.5 checkpoint."""
|
||||||
|
|
||||||
|
action_tokenizer_name: str = "physical-intelligence/fast"
|
||||||
|
"""HF identifier for the FAST action tokenizer."""
|
||||||
|
|
||||||
|
max_action_tokens: int = 256
|
||||||
|
"""Maximum number of FAST tokens per action chunk."""
|
||||||
|
|
||||||
|
fast_skip_tokens: int = 128
|
||||||
|
"""Number of low-vocab tokens the FAST tokenizer skips to avoid
|
||||||
|
collisions with PaliGemma's text vocabulary."""
|
||||||
|
|
||||||
|
fast_action_loss_weight: float = 1.0
|
||||||
|
"""Weight on the FAST-action-token CE loss. Paper §III.C uses 1.0."""
|
||||||
|
|
||||||
|
# Knowledge insulation — paper §III.B --------------------------------
|
||||||
|
# When enabled, gradients from the action expert's flow loss are
|
||||||
|
# *blocked* from flowing back into the VLM's K/V projections. This
|
||||||
|
# prevents the action loss from over-fitting the language backbone
|
||||||
|
# to robot-specific features. Implementation requires a custom
|
||||||
|
# per-layer attention forward that uses ``.detach()`` on VLM K/V
|
||||||
|
# when computing attention for action queries — see
|
||||||
|
# ``pi05_full/modeling_pi05.py::compute_layer_complete_knowledge_insulation``
|
||||||
|
# on branch ``feat/add-pi05`` for the reference implementation.
|
||||||
|
knowledge_insulation: bool = False
|
||||||
|
"""If True, route every transformer layer through the KI
|
||||||
|
attention path (blocks action→VLM gradient flow on K/V).
|
||||||
|
Implemented by binding ``_paligemma_forward_ki`` onto the backbone
|
||||||
|
instance in ``PI052Policy.__init__`` (ported from pi05_full), so
|
||||||
|
stock ``pi05`` policies sharing the same class are unaffected."""
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
# Backbone needs gradients flowing through the text head when
|
# Backbone needs gradients flowing through the text head when
|
||||||
|
|||||||
@@ -39,10 +39,12 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import types
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||||
|
|
||||||
@@ -53,6 +55,190 @@ from .configuration_pi052 import PI052Config
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# Knowledge insulation — ported from pi05_full (branch ``feat/add-pi05``)
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
#
|
||||||
|
# Per-layer attention that splits the queries into VLM and action
|
||||||
|
# parts, computing attention for action queries with .detach()'d VLM
|
||||||
|
# K/V so the action loss's gradient cannot flow back into the VLM's K
|
||||||
|
# and V projections. Forward output is bit-equivalent to the standard
|
||||||
|
# layer; backward differs only on the path action_loss → VLM K/V.
|
||||||
|
|
||||||
|
def _compute_layer_ki(
    layer_idx,
    inputs_embeds,
    attention_mask,
    position_ids,
    adarms_cond,
    paligemma,
    gemma_expert,
):
    """Run one dual-expert transformer layer with knowledge insulation.

    Both streams (VLM first, action expert second) are projected to Q/K/V
    with their own layer weights and concatenated along the sequence axis.
    Attention is then evaluated in two passes:

    * VLM queries attend over the full K/V, so gradients flow as usual.
    * Action queries attend over K/V whose VLM slice has been
      ``.detach()``-ed, so nothing back-propagated through the action
      queries reaches the VLM's K/V projections.

    Args:
        layer_idx: Index of the layer to run in both layer stacks.
        inputs_embeds: Two hidden-state tensors, ``[vlm, action]``; axis 1 is
            the sequence axis (it is used as the VLM/action split point).
        attention_mask: Mask indexed as ``[:, :, query, key]`` over the
            concatenated sequence.
        position_ids: Position ids handed to the rotary embedding.
        adarms_cond: Two conditioning values, one per stream, forwarded to
            the layer norms as ``cond``.
        paligemma: VLM wrapper exposing ``language_model`` and
            ``model.language_model.rotary_emb``.
        gemma_expert: Action-expert wrapper exposing ``model.layers``.

    Returns:
        List with the two updated hidden-state tensors, in the same order as
        ``inputs_embeds``.
    """
    from transformers.models.gemma import modeling_gemma  # noqa: PLC0415

    models = [paligemma.language_model, gemma_expert.model]
    # The VLM layer's attention module supplies scaling and the module
    # argument for both attention passes.
    vlm_attn = paligemma.language_model.layers[layer_idx].self_attn
    query_states, key_states, value_states, gates = [], [], [], []

    # Everything before this index on the sequence axis is the VLM stream.
    vlm_len = inputs_embeds[0].shape[1]

    for i, hidden_states in enumerate(inputs_embeds):
        layer = models[i].layers[layer_idx]
        hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i])
        gates.append(gate)
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
        q = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        k = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        v = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        query_states.append(q)
        key_states.append(k)
        value_states.append(v)

    query_states = torch.cat(query_states, dim=2)
    key_states = torch.cat(key_states, dim=2)
    value_states = torch.cat(value_states, dim=2)

    # Placeholder tensor passed to the rotary embedding next to position_ids;
    # it carries the device/dtype of the queries.
    dummy = torch.zeros(
        query_states.shape[0],
        query_states.shape[2],
        query_states.shape[-1],
        device=query_states.device,
        dtype=query_states.dtype,
    )
    cos, sin = paligemma.model.language_model.rotary_emb(dummy, position_ids)
    query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
        query_states, key_states, cos, sin, unsqueeze_dim=1
    )

    batch_size = query_states.shape[0]
    scaling = vlm_attn.scaling

    # Split queries / K / V at the VLM-vs-action boundary.
    q_vlm = query_states[:, :, :vlm_len, :]
    q_action = query_states[:, :, vlm_len:, :]
    k_vlm = key_states[:, :, :vlm_len, :]
    k_action = key_states[:, :, vlm_len:, :]
    v_vlm = value_states[:, :, :vlm_len, :]
    v_action = value_states[:, :, vlm_len:, :]

    # Detach VLM K/V *only* on the path the action queries use; the VLM
    # queries keep the original, gradient-carrying K/V.
    k_for_action = torch.cat([k_vlm.detach(), k_action], dim=2)
    v_for_action = torch.cat([v_vlm.detach(), v_action], dim=2)

    mask_for_vlm = attention_mask[:, :, :vlm_len, :]
    mask_for_action = attention_mask[:, :, vlm_len:, :]

    att_vlm, _ = modeling_gemma.eager_attention_forward(
        vlm_attn, q_vlm, key_states, value_states, mask_for_vlm, scaling
    )
    att_action, _ = modeling_gemma.eager_attention_forward(
        vlm_attn, q_action, k_for_action, v_for_action, mask_for_action, scaling
    )
    att = torch.cat([att_vlm, att_action], dim=1)

    # Merge the trailing (heads, head_dim) axes. The width is taken from the
    # tensor itself instead of assuming 8 heads, so a different head count
    # cannot silently spill into the sequence axis.
    att = att.reshape(batch_size, att.shape[1], -1)

    outputs_embeds = []
    start = 0
    for i, hidden_states in enumerate(inputs_embeds):
        layer = models[i].layers[layer_idx]
        end = start + hidden_states.shape[1]
        if att.dtype != layer.self_attn.o_proj.weight.dtype:
            att = att.to(layer.self_attn.o_proj.weight.dtype)
        out_emb = layer.self_attn.o_proj(att[:, start:end])
        # First residual: the un-normalised stream input plus attention output.
        out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i])  # noqa: SLF001
        after_first = out_emb.clone()
        out_emb, gate = layer.post_attention_layernorm(out_emb.clone(), cond=adarms_cond[i])
        if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
            out_emb = out_emb.to(dtype=torch.bfloat16)
        out_emb = layer.mlp(out_emb)
        # Second residual around the MLP.
        out_emb = modeling_gemma._gated_residual(after_first, out_emb, gate)  # noqa: SLF001
        outputs_embeds.append(out_emb)
        start = end
    return outputs_embeds
|
||||||
|
|
||||||
|
|
||||||
|
def _paligemma_forward_ki(
|
||||||
|
self,
|
||||||
|
attention_mask=None,
|
||||||
|
position_ids=None,
|
||||||
|
past_key_values=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
use_cache=None,
|
||||||
|
adarms_cond=None,
|
||||||
|
fill_kv_cache=None,
|
||||||
|
):
|
||||||
|
"""Replacement ``PaliGemmaWithExpertModel.forward`` that routes the
|
||||||
|
dual-expert layer pass through :func:`_compute_layer_ki`.
|
||||||
|
|
||||||
|
Bound onto the model instance when ``config.knowledge_insulation``
|
||||||
|
is True (see ``PI052Policy.__init__``). Single-expert branches
|
||||||
|
(VLM-only or action-only) reuse the parent's implementation
|
||||||
|
because there's no KI signal to add — KI only matters when
|
||||||
|
actions and VLM tokens are forwarded together.
|
||||||
|
"""
|
||||||
|
from ..pi05.modeling_pi05 import layernorm_forward # noqa: PLC0415
|
||||||
|
|
||||||
|
if adarms_cond is None:
|
||||||
|
adarms_cond = [None, None]
|
||||||
|
|
||||||
|
# Single-expert paths: defer to the bound class method via super().
|
||||||
|
if inputs_embeds[0] is None or inputs_embeds[1] is None:
|
||||||
|
return type(self).__bases__[0].forward(
|
||||||
|
self,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
adarms_cond=adarms_cond,
|
||||||
|
) if hasattr(self, "_pi052_orig_forward") else self._pi052_orig_forward(
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
adarms_cond=adarms_cond,
|
||||||
|
)
|
||||||
|
|
||||||
|
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||||
|
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||||
|
use_gc = (
|
||||||
|
hasattr(self.gemma_expert.model, "gradient_checkpointing")
|
||||||
|
and self.gemma_expert.model.gradient_checkpointing
|
||||||
|
and self.training
|
||||||
|
) or (
|
||||||
|
hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training
|
||||||
|
)
|
||||||
|
|
||||||
|
for layer_idx in range(num_layers):
|
||||||
|
if use_gc:
|
||||||
|
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||||
|
_compute_layer_ki,
|
||||||
|
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond,
|
||||||
|
use_reentrant=False, preserve_rng_state=False,
|
||||||
|
paligemma=self.paligemma, gemma_expert=self.gemma_expert,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
inputs_embeds = _compute_layer_ki(
|
||||||
|
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond,
|
||||||
|
paligemma=self.paligemma, gemma_expert=self.gemma_expert,
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs_embeds = []
|
||||||
|
for i, hidden_states in enumerate(inputs_embeds):
|
||||||
|
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||||
|
outputs_embeds.append(out_emb)
|
||||||
|
return [outputs_embeds[0], outputs_embeds[1]], None
|
||||||
|
|
||||||
|
|
||||||
class PI052Policy(PI05Policy):
|
class PI052Policy(PI05Policy):
|
||||||
"""π0.5 with the PaliGemma LM head re-enabled."""
|
"""π0.5 with the PaliGemma LM head re-enabled."""
|
||||||
|
|
||||||
@@ -68,6 +254,20 @@ class PI052Policy(PI05Policy):
|
|||||||
if config.text_loss_weight > 0 and config.unfreeze_lm_head:
|
if config.text_loss_weight > 0 and config.unfreeze_lm_head:
|
||||||
self._unfreeze_lm_head()
|
self._unfreeze_lm_head()
|
||||||
|
|
||||||
|
# Knowledge insulation: bind a custom ``forward`` on the
|
||||||
|
# PaliGemmaWithExpertModel instance that uses
|
||||||
|
# :func:`_compute_layer_ki` for the dual-expert layer pass.
|
||||||
|
# The bind is per-instance, so this doesn't leak into stock
|
||||||
|
# ``pi05`` policies that share the same class.
|
||||||
|
if getattr(config, "knowledge_insulation", False):
|
||||||
|
backbone = self.model.paligemma_with_expert
|
||||||
|
backbone._pi052_orig_forward = backbone.forward
|
||||||
|
backbone.forward = types.MethodType(_paligemma_forward_ki, backbone)
|
||||||
|
logger.info(
|
||||||
|
"PI052: knowledge insulation enabled — action→VLM K/V "
|
||||||
|
"gradients are blocked in attention."
|
||||||
|
)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Head unfreeze helper
|
# Head unfreeze helper
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -143,6 +343,22 @@ class PI052Policy(PI05Policy):
|
|||||||
else total + self.config.text_loss_weight * text_loss
|
else total + self.config.text_loss_weight * text_loss
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# FAST action-token CE loss (paper §III.C). When
|
||||||
|
# ``enable_fast_action_loss=True`` the preprocessor wrote
|
||||||
|
# ACTION_TOKENS / ACTION_TOKEN_MASK into the batch — we
|
||||||
|
# forward them through the PaliGemma backbone alongside the
|
||||||
|
# language prefix and compute CE on the action positions.
|
||||||
|
if getattr(self.config, "enable_fast_action_loss", False):
|
||||||
|
from lerobot.utils.constants import ACTION_TOKEN_MASK, ACTION_TOKENS # noqa: PLC0415
|
||||||
|
|
||||||
|
action_tokens = batch.get(ACTION_TOKENS)
|
||||||
|
action_mask = batch.get(ACTION_TOKEN_MASK)
|
||||||
|
if action_tokens is not None and action_mask is not None:
|
||||||
|
fast_loss = self._compute_fast_action_loss(batch, action_tokens, action_mask)
|
||||||
|
loss_dict["fast_action_loss"] = float(fast_loss.detach().item())
|
||||||
|
weighted = self.config.fast_action_loss_weight * fast_loss
|
||||||
|
total = weighted if total is None else total + weighted
|
||||||
|
|
||||||
if total is None:
|
if total is None:
|
||||||
# Both flow and text disabled — make this an obvious bug
|
# Both flow and text disabled — make this an obvious bug
|
||||||
# rather than a silent zero loss.
|
# rather than a silent zero loss.
|
||||||
@@ -161,6 +377,82 @@ class PI052Policy(PI05Policy):
|
|||||||
# Text loss
|
# Text loss
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _compute_fast_action_loss(
    self,
    batch: dict[str, Tensor],
    action_tokens: Tensor,
    action_mask: Tensor,
) -> Tensor:
    """Masked next-token cross-entropy over FAST action tokens.

    The image/language prefix is embedded, the FAST token embeddings are
    appended to it, and the combined sequence is run through the backbone
    with the action-expert slot set to ``None``. Hidden states at the
    trailing FAST positions go through PaliGemma's LM head and are scored
    against the FAST ids shifted by one position.

    Args:
        batch: Preprocessed batch. Must contain ``OBS_LANGUAGE_TOKENS`` and
            ``OBS_LANGUAGE_ATTENTION_MASK`` plus whatever
            ``self.model._preprocess_images`` reads.
        action_tokens: FAST token ids, unpacked as ``(batch, fast_len)``.
        action_mask: Same layout as ``action_tokens``; marks the positions
            that contribute to the loss.

    Returns:
        Scalar tensor: sum of per-token CE over masked-in target positions,
        divided by the number of such positions (clamped to at least 1).

    Raises:
        RuntimeError: If the backbone returns no VLM hidden states.
    """
    from ..pi05.modeling_pi05 import make_att_2d_masks  # noqa: PLC0415

    backbone = self.model.paligemma_with_expert

    # Embed the [images, language] prefix exactly as the policy does elsewhere.
    images, img_masks = self.model._preprocess_images(batch)
    prefix_embs, prefix_pad, prefix_att = self.model.embed_prefix(
        images,
        img_masks,
        batch[OBS_LANGUAGE_TOKENS],
        batch[OBS_LANGUAGE_ATTENTION_MASK],
    )

    # FAST ids go through the same token-embedding table as language tokens
    # and are scaled by sqrt(embedding width).
    width = prefix_embs.shape[-1]
    fast_emb = backbone.embed_language_tokens(action_tokens) * math.sqrt(width)

    bsize, fast_len = action_tokens.shape
    # Attention flag ``True`` for every FAST position: per the original
    # comment this makes them causal among themselves while still seeing
    # the prefix.
    fast_att = torch.ones((bsize, fast_len), dtype=torch.bool, device=prefix_embs.device)

    full_embs = torch.cat([prefix_embs, fast_emb], dim=1)
    full_pad = torch.cat([prefix_pad, action_mask.to(prefix_pad.dtype)], dim=1)
    full_att = torch.cat([prefix_att, fast_att], dim=1)

    att_2d = make_att_2d_masks(full_pad, full_att)
    # Positions count only non-padded entries.
    position_ids = torch.cumsum(full_pad, dim=1) - 1

    # NOTE(review): ``fill_kv_cache`` is passed through unchanged from the
    # original call; confirm the stock backbone forward accepts this keyword
    # when knowledge insulation is off.
    (vlm_out, _), _ = backbone.forward(
        attention_mask=att_2d,
        position_ids=position_ids,
        past_key_values=None,
        inputs_embeds=[full_embs, None],
        use_cache=False,
        fill_kv_cache=True,
    )
    if vlm_out is None:
        raise RuntimeError("PI052 FAST loss: VLM forward returned no hidden states.")

    # The FAST tokens were appended last, so they occupy the final
    # ``fast_len`` positions of the output.
    lm_head = backbone.paligemma.lm_head
    fast_hidden = vlm_out[:, -fast_len:, :]
    fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))

    # Logit at FAST position t is scored against FAST token t + 1.
    # NOTE(review): with this shift the first FAST token is never a target —
    # confirm that is intended.
    target_ids = action_tokens[:, 1:].contiguous()
    valid = action_mask[:, 1:].contiguous().float()
    next_tok_logits = fast_logits[:, :-1, :].contiguous()

    token_ce = F.cross_entropy(
        next_tok_logits.view(-1, next_tok_logits.size(-1)),
        target_ids.view(-1),
        reduction="none",
    ).view(target_ids.shape)

    # Mean over valid target positions; the clamp guards an all-masked batch.
    return (token_ce * valid).sum() / valid.sum().clamp_min(1.0)
|
||||||
|
|
||||||
def _compute_text_loss(self, batch: dict[str, Tensor], text_labels: Tensor) -> Tensor:
|
def _compute_text_loss(self, batch: dict[str, Tensor], text_labels: Tensor) -> Tensor:
|
||||||
"""Cross-entropy on PaliGemma's LM head over the supervised span.
|
"""Cross-entropy on PaliGemma's LM head over the supervised span.
|
||||||
|
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ import torch
|
|||||||
from lerobot.configs.recipe import TrainingRecipe
|
from lerobot.configs.recipe import TrainingRecipe
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
AbsoluteActionsProcessorStep,
|
AbsoluteActionsProcessorStep,
|
||||||
|
ActionTokenizerProcessorStep,
|
||||||
AddBatchDimensionProcessorStep,
|
AddBatchDimensionProcessorStep,
|
||||||
DeviceProcessorStep,
|
DeviceProcessorStep,
|
||||||
NormalizerProcessorStep,
|
NormalizerProcessorStep,
|
||||||
@@ -101,9 +102,25 @@ def make_pi052_pre_post_processors(
|
|||||||
memory_dropout_prob=getattr(config, "memory_dropout_prob", 0.0),
|
memory_dropout_prob=getattr(config, "memory_dropout_prob", 0.0),
|
||||||
subtask_dropout_prob=getattr(config, "subtask_dropout_prob", 0.0),
|
subtask_dropout_prob=getattr(config, "subtask_dropout_prob", 0.0),
|
||||||
),
|
),
|
||||||
DeviceProcessorStep(device=config.device),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# FAST tokenizer for discrete-action CE supervision (paper §III.C).
|
||||||
|
# Only inserted when explicitly enabled — keeps the post-training-
|
||||||
|
# style recipe (flow + text) as the default. When on, the step
|
||||||
|
# writes ACTION_TOKENS / ACTION_TOKEN_MASK into
|
||||||
|
# ``COMPLEMENTARY_DATA`` and the modeling forward picks them up.
|
||||||
|
if getattr(config, "enable_fast_action_loss", False):
|
||||||
|
input_steps.append(
|
||||||
|
ActionTokenizerProcessorStep(
|
||||||
|
action_tokenizer_name=config.action_tokenizer_name,
|
||||||
|
max_action_tokens=config.max_action_tokens,
|
||||||
|
fast_skip_tokens=config.fast_skip_tokens,
|
||||||
|
paligemma_tokenizer_name="google/paligemma-3b-pt-224",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
input_steps.append(DeviceProcessorStep(device=config.device))
|
||||||
|
|
||||||
output_steps = [
|
output_steps = [
|
||||||
UnnormalizerProcessorStep(
|
UnnormalizerProcessorStep(
|
||||||
features=config.output_features,
|
features=config.output_features,
|
||||||
|
|||||||
Reference in New Issue
Block a user